[🚀] mysql

This commit is contained in:
2026-02-07 15:36:21 +08:00
parent 078967c601
commit 3005002379
6 changed files with 818 additions and 340 deletions

View File

@@ -1,247 +1,499 @@
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"os"
"sync"
"time"
"gitea.bvbej.com/bvbej/base-golang/pkg/time_parse"
"gitea.bvbej.com/bvbej/base-golang/pkg/trace"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
const (
callBackBeforeName = "core:before"
callBackAfterName = "core:after"
startTime = "_start_time"
traceCtxName = "_trace_ctx_name"
)
var _ MysqlRepo = (*mysqlRepo)(nil)
// MysqlRepo MySQL接口
type MysqlRepo interface {
i()
GetRead(options ...Option) *gorm.DB
GetWrite(options ...Option) *gorm.DB
// DB 获取数据库实例(自动选择读写)
DB(ctx context.Context, forceWrite ...bool) *gorm.DB
// GetRead 获取读库
GetRead() *gorm.DB
// GetWrite 获取写库
GetWrite() *gorm.DB
// Transaction 执行事务
Transaction(ctx context.Context, fc func(*gorm.DB) error) error
// Close 关闭连接
Close() error
// Ping 检查连接
Ping(ctx context.Context) error
// Stats 获取连接池统计信息
Stats() (*DBStats, error)
// HealthCheck 健康检查
HealthCheck(ctx context.Context) error
}
// DBStats 数据库统计信息
type DBStats struct {
Read sql.DBStats `json:"read"`
Write sql.DBStats `json:"write"`
}
// MySQLConfig MySQL配置
type MySQLConfig struct {
Read struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
} `yaml:"read"`
Write struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
} `yaml:"write"`
Base struct {
MaxOpenConn int `yaml:"maxOpenConn"` //最大连接数
MaxIdleConn int `yaml:"maxIdleConn"` //最大空闲连接数
ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"` //最大连接超时(分钟)
} `yaml:"base"`
Read DBConnConfig `yaml:"read" json:"read"`
Write DBConnConfig `yaml:"write" json:"write"`
Base BaseConfig `yaml:"base" json:"base"`
Logger LoggerConfig `yaml:"logger" json:"logger"`
}
// DBConnConfig 数据库连接配置
type DBConnConfig struct {
Addr string `yaml:"addr" json:"addr"`
User string `yaml:"user" json:"user"`
Pass string `yaml:"pass" json:"pass"`
Name string `yaml:"name" json:"name"`
Charset string `yaml:"charset" json:"charset"` // 默认 utf8mb4
Collation string `yaml:"collation" json:"collation"` // 默认 utf8mb4_unicode_ci
Loc string `yaml:"loc" json:"loc"` // 默认 Local
ParseTime bool `yaml:"parse_time" json:"parse_time"` // 默认 true
Timeout time.Duration `yaml:"timeout" json:"timeout"` // 连接超时(秒)
ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"` // 读超时(秒)
WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` // 写超时(秒)
}
// BaseConfig 基础配置
type BaseConfig struct {
MaxOpenConn int `yaml:"max_open_conn" json:"max_open_conn"`
MaxIdleConn int `yaml:"max_idle_conn" json:"max_idle_conn"`
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" json:"conn_max_lifetime"` // 秒
ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time" json:"conn_max_idle_time"` // 秒
}
// LoggerConfig 日志配置
type LoggerConfig struct {
SlowThreshold time.Duration `yaml:"slow_threshold" json:"slow_threshold"`
Colorful bool `yaml:"colorful" json:"colorful"`
IgnoreRecordNotFoundError bool `yaml:"ignore_record_not_found_error" json:"ignore_record_not_found_error"`
LogLevel string `yaml:"log_level" json:"log_level"`
LogOutput string `yaml:"log_output" json:"log_output"` // stdout, file
LogFile string `yaml:"log_file" json:"log_file"` // 日志文件路径
}
// DefaultMySQLConfig 默认配置
func DefaultMySQLConfig() *MySQLConfig {
return &MySQLConfig{
Base: BaseConfig{
MaxOpenConn: 100,
MaxIdleConn: 10,
ConnMaxLifetime: 3600, // 1小时
ConnMaxIdleTime: 600, // 10分钟
},
Logger: LoggerConfig{
SlowThreshold: 200 * time.Millisecond,
Colorful: true,
IgnoreRecordNotFoundError: true,
LogLevel: "warn",
LogOutput: "stdout",
},
}
}
// mysqlRepo MySQL实现
type mysqlRepo struct {
read *gorm.DB
write *gorm.DB
read *gorm.DB
write *gorm.DB
config *MySQLConfig
mu sync.RWMutex
closed bool
}
func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
if err != nil {
// NewMysql 创建MySQL实例
func NewMysql(cfg *MySQLConfig) (MysqlRepo, error) {
if cfg == nil {
cfg = DefaultMySQLConfig()
}
// 合并默认配置
mergeDefaultConfig(cfg)
// 验证配置
if err := validateMySQLConfig(cfg); err != nil {
return nil, err
}
dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
// 连接读库
dbr, err := dbConnect("read", cfg.Read, cfg.Base, cfg.Logger)
if err != nil {
return nil, err
return nil, fmt.Errorf("连接读库失败: %w", err)
}
return &mysqlRepo{
read: dbr,
write: dbw,
// 连接写库
dbw, err := dbConnect("write", cfg.Write, cfg.Base, cfg.Logger)
if err != nil {
_ = closeDB(dbr)
return nil, fmt.Errorf("连接写库失败: %w", err)
}
repo := &mysqlRepo{
read: dbr,
write: dbw,
config: cfg,
closed: false,
}
return repo, nil
}
// DB 获取数据库实例(自动选择读写)
func (m *mysqlRepo) DB(ctx context.Context, forceWrite ...bool) *gorm.DB {
m.mu.RLock()
defer m.mu.RUnlock()
useWrite := len(forceWrite) > 0 && forceWrite[0]
var db *gorm.DB
if useWrite {
db = m.write
} else {
db = m.read
}
if ctx != nil {
db = db.WithContext(ctx)
}
return db
}
// GetRead 获取读库
func (m *mysqlRepo) GetRead() *gorm.DB {
m.mu.RLock()
defer m.mu.RUnlock()
return m.read
}
// GetWrite 获取写库
func (m *mysqlRepo) GetWrite() *gorm.DB {
m.mu.RLock()
defer m.mu.RUnlock()
return m.write
}
// Transaction 执行事务
func (m *mysqlRepo) Transaction(ctx context.Context, fc func(*gorm.DB) error) error {
m.mu.RLock()
if m.closed {
m.mu.RUnlock()
return fmt.Errorf("数据库连接已关闭")
}
db := m.write
m.mu.RUnlock()
if ctx != nil {
db = db.WithContext(ctx)
}
return db.Transaction(fc)
}
// Close 关闭连接
func (m *mysqlRepo) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.closed {
return nil
}
var errs []error
if err := closeDB(m.read); err != nil {
errs = append(errs, fmt.Errorf("关闭读库失败: %w", err))
}
if err := closeDB(m.write); err != nil {
errs = append(errs, fmt.Errorf("关闭写库失败: %w", err))
}
m.closed = true
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
// Ping 检查连接
func (m *mysqlRepo) Ping(ctx context.Context) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return fmt.Errorf("数据库连接已关闭")
}
if ctx == nil {
ctx = context.Background()
}
// Ping读库
readDB, err := m.read.DB()
if err != nil {
return fmt.Errorf("获取读库失败: %w", err)
}
if err := readDB.PingContext(ctx); err != nil {
return fmt.Errorf("Ping读库失败: %w", err)
}
// Ping写库
writeDB, err := m.write.DB()
if err != nil {
return fmt.Errorf("获取写库失败: %w", err)
}
if err := writeDB.PingContext(ctx); err != nil {
return fmt.Errorf("Ping写库失败: %w", err)
}
return nil
}
// Stats 获取连接池统计信息
func (m *mysqlRepo) Stats() (*DBStats, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return nil, fmt.Errorf("数据库连接已关闭")
}
readDB, err := m.read.DB()
if err != nil {
return nil, fmt.Errorf("获取读库失败: %w", err)
}
writeDB, err := m.write.DB()
if err != nil {
return nil, fmt.Errorf("获取写库失败: %w", err)
}
return &DBStats{
Read: readDB.Stats(),
Write: writeDB.Stats(),
}, nil
}
func (d *mysqlRepo) i() {}
func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
opt := newOption()
for _, f := range options {
f(opt)
// HealthCheck 健康检查
func (m *mysqlRepo) HealthCheck(ctx context.Context) error {
if err := m.Ping(ctx); err != nil {
return err
}
db := d.read
if opt.Trace != nil {
db.InstanceSet(traceCtxName, opt.Trace)
stats, err := m.Stats()
if err != nil {
return err
}
return db
// 检查连接池状态
if stats.Read.OpenConnections >= m.config.Base.MaxOpenConn {
return fmt.Errorf("读库连接池已满")
}
if stats.Write.OpenConnections >= m.config.Base.MaxOpenConn {
return fmt.Errorf("写库连接池已满")
}
return nil
}
func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
opt := newOption()
for _, f := range options {
f(opt)
// mergeDefaultConfig 合并默认配置
func mergeDefaultConfig(cfg *MySQLConfig) {
// 读库默认配置
if cfg.Read.Charset == "" {
cfg.Read.Charset = "utf8mb4"
}
db := d.write
if opt.Trace != nil {
db.InstanceSet(traceCtxName, opt.Trace)
if cfg.Read.Collation == "" {
cfg.Read.Collation = "utf8mb4_unicode_ci"
}
if cfg.Read.Loc == "" {
cfg.Read.Loc = "Local"
}
if cfg.Read.Timeout == 0 {
cfg.Read.Timeout = 10 // 10秒
}
if cfg.Read.ReadTimeout == 0 {
cfg.Read.ReadTimeout = 30 // 30秒
}
if cfg.Read.WriteTimeout == 0 {
cfg.Read.WriteTimeout = 30 // 30秒
}
cfg.Read.ParseTime = true
return db
// 写库默认配置
if cfg.Write.Charset == "" {
cfg.Write.Charset = "utf8mb4"
}
if cfg.Write.Collation == "" {
cfg.Write.Collation = "utf8mb4_unicode_ci"
}
if cfg.Write.Loc == "" {
cfg.Write.Loc = "Local"
}
if cfg.Write.Timeout == 0 {
cfg.Write.Timeout = 10
}
if cfg.Write.ReadTimeout == 0 {
cfg.Write.ReadTimeout = 30
}
if cfg.Write.WriteTimeout == 0 {
cfg.Write.WriteTimeout = 30
}
cfg.Write.ParseTime = true
}
func (d *mysqlRepo) Close() (err error) {
rdb, err1 := d.read.DB()
if err1 != nil {
err = errors.Join(err1)
// validateMySQLConfig 验证配置
func validateMySQLConfig(cfg *MySQLConfig) error {
if cfg.Read.Addr == "" {
return fmt.Errorf("读库地址不能为空")
}
err2 := rdb.Close()
if err2 != nil {
err = errors.Join(err2)
if cfg.Read.User == "" {
return fmt.Errorf("读库用户名不能为空")
}
if cfg.Read.Name == "" {
return fmt.Errorf("读库数据库名不能为空")
}
wdb, err3 := d.write.DB()
if err3 != nil {
err = errors.Join(err3)
if cfg.Write.Addr == "" {
return fmt.Errorf("写库地址不能为空")
}
err4 := wdb.Close()
if err4 != nil {
err = errors.Join(err4)
if cfg.Write.User == "" {
return fmt.Errorf("写库用户名不能为空")
}
if cfg.Write.Name == "" {
return fmt.Errorf("写库数据库名不能为空")
}
return err
return nil
}
func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
user,
pass,
addr,
dbName,
true,
"Local")
// buildDSN 构建DSN
func buildDSN(conn DBConnConfig) string {
return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&collation=%s&parseTime=%t&loc=%s&timeout=%ds&readTimeout=%ds&writeTimeout=%ds",
conn.User,
conn.Pass,
conn.Addr,
conn.Name,
conn.Charset,
conn.Collation,
conn.ParseTime,
conn.Loc,
int(conn.Timeout),
int(conn.ReadTimeout),
int(conn.WriteTimeout),
)
}
// dbConnect 连接数据库
func dbConnect(name string, conn DBConnConfig, base BaseConfig, logCfg LoggerConfig) (*gorm.DB, error) {
// 构建DSN
dsn := buildDSN(conn)
// 配置日志
logLevel := parseLogLevel(logCfg.LogLevel)
logWriter := getLogWriter(logCfg)
// 日志配置
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
log.New(logWriter, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: time.Second, // 慢SQL阈值
Colorful: true, // 彩色打印
IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
LogLevel: logger.Error, // 日志级别
SlowThreshold: logCfg.SlowThreshold,
Colorful: logCfg.Colorful,
IgnoreRecordNotFoundError: logCfg.IgnoreRecordNotFoundError,
LogLevel: logLevel,
},
)
// 打开数据库连接
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
Logger: newLogger,
Logger: newLogger,
SkipDefaultTransaction: true,
PrepareStmt: true,
DisableForeignKeyConstraintWhenMigrating: true,
})
if err != nil {
return nil, errors.Join(err, fmt.Errorf("[db connection failed] Database name: %s", dbName))
return nil, fmt.Errorf("打开数据库连接失败 [%s]: %w", name, err)
}
db.Set("gorm:table_options", "CHARSET=utf8mb4")
// 获取底层sqlDB
sqlDB, err := db.DB()
if err != nil {
return nil, err
return nil, fmt.Errorf("获取sqlDB失败: %w", err)
}
// 设置连接池 用于设置最大打开的连接数默认值为0表示不限制.设置最大的连接数可以避免并发太高导致连接mysql出现too many connections的错误。
sqlDB.SetMaxOpenConns(maxOpenConn)
// 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
sqlDB.SetMaxIdleConns(maxIdleConn)
// 设置最大连接超时
sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
// 使用插件
err = db.Use(&TracePlugin{})
if err != nil {
return nil, err
// 设置连接池参数
if base.MaxOpenConn > 0 {
sqlDB.SetMaxOpenConns(base.MaxOpenConn)
}
if base.MaxIdleConn > 0 {
sqlDB.SetMaxIdleConns(base.MaxIdleConn)
}
if base.ConnMaxLifetime > 0 {
sqlDB.SetConnMaxLifetime(base.ConnMaxLifetime * time.Second)
}
if base.ConnMaxIdleTime > 0 {
sqlDB.SetConnMaxIdleTime(base.ConnMaxIdleTime * time.Second)
}
return db, nil
}
/***************************************************************/
type TracePlugin struct{}
func (op *TracePlugin) Name() string {
return "TracePlugin"
// closeDB 关闭数据库
func closeDB(db *gorm.DB) error {
if db == nil {
return nil
}
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
// 开始前
_ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
_ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
_ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
_ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
_ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
// 结束后
_ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
_ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
_ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
_ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
_ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
_ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
return
// parseLogLevel 解析日志级别
func parseLogLevel(level string) logger.LogLevel {
switch level {
case "silent":
return logger.Silent
case "error":
return logger.Error
case "warn":
return logger.Warn
case "info":
return logger.Info
default:
return logger.Warn
}
}
func before(db *gorm.DB) {
db.InstanceSet(startTime, time.Now())
}
func after(db *gorm.DB) {
_traceCtx, isExist := db.InstanceGet(traceCtxName)
if !isExist {
return
}
_trace, ok := _traceCtx.(trace.T)
if !ok {
return
}
_ts, isExist := db.InstanceGet(startTime)
if !isExist {
return
}
ts, ok := _ts.(time.Time)
if !ok {
return
}
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
sqlInfo := new(trace.SQL)
sqlInfo.Timestamp = time_parse.CSTLayoutString()
sqlInfo.SQL = sql
sqlInfo.Stack = utils.FileWithLineNum()
sqlInfo.Rows = db.Statement.RowsAffected
sqlInfo.CostSeconds = time.Since(ts).Seconds()
_trace.AppendSQL(sqlInfo)
// getLogWriter 获取日志输出
func getLogWriter(cfg LoggerConfig) *os.File {
if cfg.LogOutput == "file" && cfg.LogFile != "" {
file, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
log.Printf("打开日志文件失败: %v使用标准输出", err)
return os.Stdout
}
return file
}
return os.Stdout
}