// Package database provides a MySQL access layer built on GORM with separate
// read and write connection pools.
package database

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"io"
	"log"
	"net/url"
	"os"
	"strings"
	"sync"
	"time"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/schema"
)

// ErrClosed is returned by Transaction, Ping, Stats and HealthCheck after
// Close has been called.
var ErrClosed = errors.New("数据库连接已关闭")

// MysqlRepo is a MySQL repository backed by a read connection and a write
// connection.
type MysqlRepo interface {
	// DB returns the read connection, or the write connection when
	// forceWrite[0] is true, bound to ctx when ctx is non-nil.
	DB(ctx context.Context, forceWrite ...bool) *gorm.DB
	// GetRead returns the read connection.
	GetRead() *gorm.DB
	// GetWrite returns the write connection.
	GetWrite() *gorm.DB
	// Transaction runs fc inside a transaction on the write connection.
	Transaction(ctx context.Context, fc func(*gorm.DB) error) error
	// Close closes both connection pools (and the log file, if one was opened).
	Close() error
	// Ping checks that both databases are reachable.
	Ping(ctx context.Context) error
	// Stats returns connection pool statistics for both pools.
	Stats() (*DBStats, error)
	// HealthCheck pings both databases and checks that neither pool is saturated.
	HealthCheck(ctx context.Context) error
}

// DBStats holds connection pool statistics for the read and write pools.
type DBStats struct {
	Read  sql.DBStats `json:"read"`
	Write sql.DBStats `json:"write"`
}

// MySQLConfig is the complete MySQL configuration.
type MySQLConfig struct {
	Read   DBConnConfig `yaml:"read" json:"read"`
	Write  DBConnConfig `yaml:"write" json:"write"`
	Base   BaseConfig   `yaml:"base" json:"base"`
	Logger LoggerConfig `yaml:"logger" json:"logger"`
}

// DBConnConfig describes a single MySQL connection.
//
// Timeout, ReadTimeout and WriteTimeout accept either a bare number of
// seconds (any value below one second, e.g. 10, is interpreted as seconds,
// which keeps older configs working) or a real time.Duration of at least one
// second (e.g. 10*time.Second, or a string such as "10s" parsed by a config
// decoder that supports durations).
type DBConnConfig struct {
	Addr         string        `yaml:"addr" json:"addr"`
	User         string        `yaml:"user" json:"user"`
	Pass         string        `yaml:"pass" json:"pass"`
	Name         string        `yaml:"name" json:"name"`
	Charset      string        `yaml:"charset" json:"charset"`             // default utf8mb4
	Collation    string        `yaml:"collation" json:"collation"`         // default utf8mb4_unicode_ci
	Loc          string        `yaml:"loc" json:"loc"`                     // default Local
	ParseTime    bool          `yaml:"parse_time" json:"parse_time"`       // always set to true by mergeDefaultConfig
	Timeout      time.Duration `yaml:"timeout" json:"timeout"`             // connection timeout, default 10s
	ReadTimeout  time.Duration `yaml:"read_timeout" json:"read_timeout"`   // read timeout, default 30s
	WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` // write timeout, default 30s
}

// BaseConfig holds connection pool settings applied to both pools.
// Zero or negative values leave the database/sql defaults in place.
// ConnMaxLifetime and ConnMaxIdleTime follow the same seconds-or-duration
// rule as the DBConnConfig timeouts.
type BaseConfig struct {
	MaxOpenConn     int           `yaml:"max_open_conn" json:"max_open_conn"`
	MaxIdleConn     int           `yaml:"max_idle_conn" json:"max_idle_conn"`
	ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" json:"conn_max_lifetime"`
	ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time" json:"conn_max_idle_time"`
}

// LoggerConfig configures the GORM logger. SlowThreshold is a real
// time.Duration (e.g. 200*time.Millisecond).
type LoggerConfig struct {
	SlowThreshold             time.Duration `yaml:"slow_threshold" json:"slow_threshold"`
	Colorful                  bool          `yaml:"colorful" json:"colorful"`
	IgnoreRecordNotFoundError bool          `yaml:"ignore_record_not_found_error" json:"ignore_record_not_found_error"`
	LogLevel                  string        `yaml:"log_level" json:"log_level"`   // silent, error, warn, info
	LogOutput                 string        `yaml:"log_output" json:"log_output"` // stdout, file
	LogFile                   string        `yaml:"log_file" json:"log_file"`     // log file path when LogOutput is "file"
}

// DefaultMySQLConfig returns the default pool and logger settings. Read and
// Write are left empty and must be filled in by the caller.
func DefaultMySQLConfig() *MySQLConfig {
	return &MySQLConfig{
		Base: BaseConfig{
			MaxOpenConn:     100,
			MaxIdleConn:     10,
			ConnMaxLifetime: 3600, // seconds (1 hour)
			ConnMaxIdleTime: 600,  // seconds (10 minutes)
		},
		Logger: LoggerConfig{
			SlowThreshold:             200 * time.Millisecond,
			Colorful:                  true,
			IgnoreRecordNotFoundError: true,
			LogLevel:                  "warn",
			LogOutput:                 "stdout",
		},
	}
}

// mysqlRepo is the MysqlRepo implementation.
type mysqlRepo struct {
	read    *gorm.DB
	write   *gorm.DB
	config  *MySQLConfig
	logFile *os.File // log file shared by both loggers; nil when logging to stdout
	mu      sync.RWMutex
	closed  bool
}

var _ MysqlRepo = (*mysqlRepo)(nil)

// NewMysql validates cfg, applies connection defaults and opens the read and
// write connections. A nil cfg uses DefaultMySQLConfig, which fails validation
// because it has no addresses. cfg is copied, so the caller's struct is never
// modified.
func NewMysql(cfg *MySQLConfig) (MysqlRepo, error) {
	if cfg == nil {
		cfg = DefaultMySQLConfig()
	}
	// Work on a private copy: the defaults below must not leak into the
	// caller's struct, and later changes by the caller must not affect us.
	local := *cfg
	cfg = &local

	mergeDefaultConfig(cfg)
	if err := validateMySQLConfig(cfg); err != nil {
		return nil, err
	}

	// Open the log destination once and share a single GORM logger between
	// both pools, so a log file is opened exactly once and can be closed.
	logWriter := getLogWriter(cfg.Logger)
	var logFile *os.File
	if logWriter != os.Stdout {
		logFile = logWriter
	}
	gormLogger := newGormLogger(logWriter, cfg.Logger)

	dbr, err := openDB("read", cfg.Read, cfg.Base, gormLogger)
	if err != nil {
		if logFile != nil {
			_ = logFile.Close() // best-effort cleanup; the connect error is what matters
		}
		return nil, fmt.Errorf("连接读库失败: %w", err)
	}

	dbw, err := openDB("write", cfg.Write, cfg.Base, gormLogger)
	if err != nil {
		_ = closeDB(dbr) // best-effort cleanup of the already-open read pool
		if logFile != nil {
			_ = logFile.Close()
		}
		return nil, fmt.Errorf("连接写库失败: %w", err)
	}

	return &mysqlRepo{
		read:    dbr,
		write:   dbw,
		config:  cfg,
		logFile: logFile,
	}, nil
}

// DB returns the write connection when forceWrite[0] is true and the read
// connection otherwise, bound to ctx when ctx is non-nil. It does not check
// whether the repository has been closed.
func (m *mysqlRepo) DB(ctx context.Context, forceWrite ...bool) *gorm.DB {
	m.mu.RLock()
	defer m.mu.RUnlock()

	db := m.read
	if len(forceWrite) > 0 && forceWrite[0] {
		db = m.write
	}
	if ctx != nil {
		db = db.WithContext(ctx)
	}
	return db
}

// GetRead returns the read connection.
func (m *mysqlRepo) GetRead() *gorm.DB {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.read
}

// GetWrite returns the write connection.
func (m *mysqlRepo) GetWrite() *gorm.DB {
	m.mu.RLock()
	defer m.mu.RUnlock()
	return m.write
}

// Transaction runs fc in a transaction on the write connection, bound to ctx
// when ctx is non-nil. It returns ErrClosed after Close.
func (m *mysqlRepo) Transaction(ctx context.Context, fc func(*gorm.DB) error) error {
	m.mu.RLock()
	if m.closed {
		m.mu.RUnlock()
		return ErrClosed
	}
	db := m.write
	m.mu.RUnlock()

	if ctx != nil {
		db = db.WithContext(ctx)
	}
	return db.Transaction(fc)
}

// Close closes both pools and the log file (if any). It is idempotent; errors
// from each close are joined.
func (m *mysqlRepo) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.closed {
		return nil
	}

	var errs []error
	if err := closeDB(m.read); err != nil {
		errs = append(errs, fmt.Errorf("关闭读库失败: %w", err))
	}
	if err := closeDB(m.write); err != nil {
		errs = append(errs, fmt.Errorf("关闭写库失败: %w", err))
	}
	if m.logFile != nil {
		if err := m.logFile.Close(); err != nil {
			errs = append(errs, fmt.Errorf("关闭日志文件失败: %w", err))
		}
		m.logFile = nil
	}

	m.closed = true
	return errors.Join(errs...) // nil when errs is empty
}

// Ping pings the read database and then the write database. A nil ctx is
// replaced by context.Background(). It returns ErrClosed after Close.
func (m *mysqlRepo) Ping(ctx context.Context) error {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.closed {
		return ErrClosed
	}
	if ctx == nil {
		ctx = context.Background()
	}

	readDB, err := m.read.DB()
	if err != nil {
		return fmt.Errorf("获取读库失败: %w", err)
	}
	if err := readDB.PingContext(ctx); err != nil {
		return fmt.Errorf("Ping读库失败: %w", err)
	}

	writeDB, err := m.write.DB()
	if err != nil {
		return fmt.Errorf("获取写库失败: %w", err)
	}
	if err := writeDB.PingContext(ctx); err != nil {
		return fmt.Errorf("Ping写库失败: %w", err)
	}
	return nil
}

// Stats returns database/sql pool statistics for both pools. It returns
// ErrClosed after Close.
func (m *mysqlRepo) Stats() (*DBStats, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.closed {
		return nil, ErrClosed
	}

	readDB, err := m.read.DB()
	if err != nil {
		return nil, fmt.Errorf("获取读库失败: %w", err)
	}
	writeDB, err := m.write.DB()
	if err != nil {
		return nil, fmt.Errorf("获取写库失败: %w", err)
	}

	return &DBStats{
		Read:  readDB.Stats(),
		Write: writeDB.Stats(),
	}, nil
}

// HealthCheck pings both databases, then reports an error if either pool has
// every allowed connection in use.
func (m *mysqlRepo) HealthCheck(ctx context.Context) error {
	if err := m.Ping(ctx); err != nil {
		return err
	}

	stats, err := m.Stats()
	if err != nil {
		return err
	}

	if poolSaturated(stats.Read) {
		return errors.New("读库连接池已满")
	}
	if poolSaturated(stats.Write) {
		return errors.New("写库连接池已满")
	}
	return nil
}

// poolSaturated reports whether all connections the pool may open are in use.
// Idle connections do not count (they are still available), and a pool with no
// limit (MaxOpenConnections == 0) is never saturated.
func poolSaturated(s sql.DBStats) bool {
	return s.MaxOpenConnections > 0 && s.InUse >= s.MaxOpenConnections
}

// mergeDefaultConfig fills unset connection fields of both Read and Write with
// defaults and forces ParseTime on. Base and Logger are left untouched.
func mergeDefaultConfig(cfg *MySQLConfig) {
	mergeConnDefaults(&cfg.Read)
	mergeConnDefaults(&cfg.Write)
}

// mergeConnDefaults applies the default charset, collation, location and
// timeouts (expressed as seconds) to one connection config.
func mergeConnDefaults(conn *DBConnConfig) {
	if conn.Charset == "" {
		conn.Charset = "utf8mb4"
	}
	if conn.Collation == "" {
		conn.Collation = "utf8mb4_unicode_ci"
	}
	if conn.Loc == "" {
		conn.Loc = "Local"
	}
	if conn.Timeout == 0 {
		conn.Timeout = 10 // seconds
	}
	if conn.ReadTimeout == 0 {
		conn.ReadTimeout = 30 // seconds
	}
	if conn.WriteTimeout == 0 {
		conn.WriteTimeout = 30 // seconds
	}
	conn.ParseTime = true
}

// validateMySQLConfig checks that both connections have an address, a user
// and a database name.
func validateMySQLConfig(cfg *MySQLConfig) error {
	if err := validateConn("读库", cfg.Read); err != nil {
		return err
	}
	return validateConn("写库", cfg.Write)
}

// validateConn checks one connection config; label prefixes the error message.
func validateConn(label string, conn DBConnConfig) error {
	switch {
	case conn.Addr == "":
		return fmt.Errorf("%s地址不能为空", label)
	case conn.User == "":
		return fmt.Errorf("%s用户名不能为空", label)
	case conn.Name == "":
		return fmt.Errorf("%s数据库名不能为空", label)
	}
	return nil
}

// normalizeSeconds converts a configured duration into a real time.Duration.
// Values strictly between -1s and 1s are treated as a bare count of seconds
// (how the defaults in this file are written, e.g. 30 -> 30s); anything else
// is assumed to already be a time.Duration (e.g. 30*time.Second).
func normalizeSeconds(d time.Duration) time.Duration {
	if d > -time.Second && d < time.Second {
		return d * time.Second
	}
	return d
}

// buildDSN builds a go-sql-driver/mysql DSN for a TCP connection.
func buildDSN(conn DBConnConfig) string {
	return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&collation=%s&parseTime=%t&loc=%s&timeout=%s&readTimeout=%s&writeTimeout=%s",
		conn.User, conn.Pass, conn.Addr, conn.Name,
		conn.Charset, conn.Collation, conn.ParseTime,
		// The driver expects loc to be query-escaped: a raw "/" (as in
		// "Asia/Shanghai") breaks DSN parsing, which splits on the last "/".
		url.QueryEscape(conn.Loc),
		// Duration.String() yields driver-parsable values such as "10s".
		normalizeSeconds(conn.Timeout),
		normalizeSeconds(conn.ReadTimeout),
		normalizeSeconds(conn.WriteTimeout),
	)
}

// dbConnect opens a connection with its own logger built from logCfg.
// NOTE: every call goes through getLogWriter, so with file output each call
// opens a file handle that nothing closes; NewMysql uses openDB with one shared
// logger instead.
func dbConnect(name string, conn DBConnConfig, base BaseConfig, logCfg LoggerConfig) (*gorm.DB, error) {
	return openDB(name, conn, base, newGormLogger(getLogWriter(logCfg), logCfg))
}

// newGormLogger builds a GORM logger writing to w with the settings in cfg.
func newGormLogger(w io.Writer, cfg LoggerConfig) logger.Interface {
	return logger.New(
		log.New(w, "\r\n", log.LstdFlags),
		logger.Config{
			SlowThreshold:             cfg.SlowThreshold,
			Colorful:                  cfg.Colorful,
			IgnoreRecordNotFoundError: cfg.IgnoreRecordNotFoundError,
			LogLevel:                  parseLogLevel(cfg.LogLevel),
		},
	)
}

// openDB opens a GORM connection for conn using gormLogger and applies the
// pool settings in base. name ("read"/"write") is used in error messages.
func openDB(name string, conn DBConnConfig, base BaseConfig, gormLogger logger.Interface) (*gorm.DB, error) {
	db, err := gorm.Open(mysql.Open(buildDSN(conn)), &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			SingularTable: true,
		},
		Logger:                                   gormLogger,
		SkipDefaultTransaction:                   true,
		PrepareStmt:                              true,
		DisableForeignKeyConstraintWhenMigrating: true,
	})
	if err != nil {
		return nil, fmt.Errorf("打开数据库连接失败 [%s]: %w", name, err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("获取sqlDB失败: %w", err)
	}

	// Only override database/sql defaults for explicitly configured values.
	if base.MaxOpenConn > 0 {
		sqlDB.SetMaxOpenConns(base.MaxOpenConn)
	}
	if base.MaxIdleConn > 0 {
		sqlDB.SetMaxIdleConns(base.MaxIdleConn)
	}
	if base.ConnMaxLifetime > 0 {
		sqlDB.SetConnMaxLifetime(normalizeSeconds(base.ConnMaxLifetime))
	}
	if base.ConnMaxIdleTime > 0 {
		sqlDB.SetConnMaxIdleTime(normalizeSeconds(base.ConnMaxIdleTime))
	}
	return db, nil
}

// closeDB closes the sql.DB underlying db; a nil db is a no-op.
func closeDB(db *gorm.DB) error {
	if db == nil {
		return nil
	}
	sqlDB, err := db.DB()
	if err != nil {
		return err
	}
	return sqlDB.Close()
}

// parseLogLevel maps a level name (case-insensitive, surrounding spaces
// ignored) to a GORM log level; unknown or empty names map to Warn.
func parseLogLevel(level string) logger.LogLevel {
	switch strings.ToLower(strings.TrimSpace(level)) {
	case "silent":
		return logger.Silent
	case "error":
		return logger.Error
	case "warn":
		return logger.Warn
	case "info":
		return logger.Info
	default:
		return logger.Warn
	}
}

// getLogWriter returns the log destination: the file cfg.LogFile (opened for
// append, created if missing) when LogOutput is "file", otherwise os.Stdout.
// If the file cannot be opened it logs the error and falls back to os.Stdout.
// The caller owns a returned file and should close it.
func getLogWriter(cfg LoggerConfig) *os.File {
	if cfg.LogOutput == "file" && cfg.LogFile != "" {
		file, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
		if err != nil {
			log.Printf("打开日志文件失败: %v,使用标准输出", err)
			return os.Stdout
		}
		return file
	}
	return os.Stdout
}