Files
base-golang/pkg/database/mysql.go
2026-02-07 15:36:21 +08:00

500 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package database
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"net/url"
	"os"
	"sync"
	"time"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/schema"
)
// MysqlRepo is the MySQL repository interface. It exposes a read handle and a
// write handle together with lifecycle and diagnostic helpers.
type MysqlRepo interface {
	// DB returns a database handle, choosing between the read and the write
	// handle. The write handle is selected when the first forceWrite value is
	// true; otherwise the read handle is returned.
	DB(ctx context.Context, forceWrite ...bool) *gorm.DB
	// GetRead returns the read handle.
	GetRead() *gorm.DB
	// GetWrite returns the write handle.
	GetWrite() *gorm.DB
	// Transaction runs fc inside a transaction.
	Transaction(ctx context.Context, fc func(*gorm.DB) error) error
	// Close closes the underlying connections.
	Close() error
	// Ping checks that the connections are reachable.
	Ping(ctx context.Context) error
	// Stats returns connection-pool statistics for both handles.
	Stats() (*DBStats, error)
	// HealthCheck performs a health check.
	HealthCheck(ctx context.Context) error
}
// DBStats bundles the database/sql connection-pool statistics of the read and
// the write handle.
type DBStats struct {
	// Read holds the pool statistics of the read handle.
	Read sql.DBStats `json:"read"`
	// Write holds the pool statistics of the write handle.
	Write sql.DBStats `json:"write"`
}
// MySQLConfig is the top-level MySQL configuration.
type MySQLConfig struct {
	// Read describes the connection used for the read handle.
	Read DBConnConfig `yaml:"read" json:"read"`
	// Write describes the connection used for the write handle.
	Write DBConnConfig `yaml:"write" json:"write"`
	// Base holds the connection-pool settings; NewMysql passes the same value
	// to both connections.
	Base BaseConfig `yaml:"base" json:"base"`
	// Logger holds the GORM logger settings; NewMysql passes the same value
	// to both connections.
	Logger LoggerConfig `yaml:"logger" json:"logger"`
}
// DBConnConfig describes a single database connection.
//
// NOTE: although the three timeout fields are typed time.Duration, this file
// treats them as a plain number of seconds: mergeDefaultConfig assigns bare
// integers (10, 30) and buildDSN formats int(value) followed by "s". A value
// such as 10*time.Second would therefore be rendered as 10000000000s.
type DBConnConfig struct {
	Addr string `yaml:"addr" json:"addr"`
	User string `yaml:"user" json:"user"`
	Pass string `yaml:"pass" json:"pass"`
	Name string `yaml:"name" json:"name"`
	Charset string `yaml:"charset" json:"charset"` // defaults to utf8mb4
	Collation string `yaml:"collation" json:"collation"` // defaults to utf8mb4_unicode_ci
	Loc string `yaml:"loc" json:"loc"` // defaults to Local
	ParseTime bool `yaml:"parse_time" json:"parse_time"` // always forced to true by mergeDefaultConfig
	Timeout time.Duration `yaml:"timeout" json:"timeout"` // connect timeout (seconds), defaults to 10
	ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"` // read timeout (seconds), defaults to 30
	WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` // write timeout (seconds), defaults to 30
}
// BaseConfig holds the connection-pool settings. dbConnect applies each value
// only when it is greater than zero.
type BaseConfig struct {
	MaxOpenConn int `yaml:"max_open_conn" json:"max_open_conn"`
	MaxIdleConn int `yaml:"max_idle_conn" json:"max_idle_conn"`
	ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" json:"conn_max_lifetime"` // seconds; dbConnect multiplies it by time.Second
	ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time" json:"conn_max_idle_time"` // seconds; dbConnect multiplies it by time.Second
}
// LoggerConfig holds the settings used to build the GORM logger.
type LoggerConfig struct {
	// SlowThreshold is forwarded to the GORM logger configuration.
	SlowThreshold time.Duration `yaml:"slow_threshold" json:"slow_threshold"`
	// Colorful is forwarded to the GORM logger configuration.
	Colorful bool `yaml:"colorful" json:"colorful"`
	// IgnoreRecordNotFoundError is forwarded to the GORM logger configuration.
	IgnoreRecordNotFoundError bool `yaml:"ignore_record_not_found_error" json:"ignore_record_not_found_error"`
	// LogLevel is one of "silent", "error", "warn" or "info"; parseLogLevel
	// maps any other value to the warn level.
	LogLevel string `yaml:"log_level" json:"log_level"`
	LogOutput string `yaml:"log_output" json:"log_output"` // stdout, file
	LogFile string `yaml:"log_file" json:"log_file"` // log file path, used only when LogOutput is "file"
}
// DefaultMySQLConfig returns a freshly allocated configuration carrying the
// default pool and logger settings. The Read and Write connection sections are
// left at their zero values and must be filled in by the caller.
func DefaultMySQLConfig() *MySQLConfig {
	var cfg MySQLConfig

	// Pool defaults. The two lifetimes are second counts, not real durations:
	// dbConnect multiplies them by time.Second before applying them.
	cfg.Base.MaxOpenConn = 100
	cfg.Base.MaxIdleConn = 10
	cfg.Base.ConnMaxLifetime = 3600 // one hour
	cfg.Base.ConnMaxIdleTime = 600  // ten minutes

	// Logger defaults.
	cfg.Logger.SlowThreshold = 200 * time.Millisecond
	cfg.Logger.Colorful = true
	cfg.Logger.IgnoreRecordNotFoundError = true
	cfg.Logger.LogLevel = "warn"
	cfg.Logger.LogOutput = "stdout"

	return &cfg
}
// mysqlRepo is the MysqlRepo implementation backed by two GORM handles.
type mysqlRepo struct {
	// read is the handle opened from MySQLConfig.Read.
	read *gorm.DB
	// write is the handle opened from MySQLConfig.Write.
	write *gorm.DB
	// config is the configuration the repository was built from; HealthCheck
	// reads the pool limit from it.
	config *MySQLConfig
	// mu is taken by every method that reads the handles or the closed flag;
	// Close takes it exclusively.
	mu sync.RWMutex
	// closed is set by Close once both handles have been closed.
	closed bool
}
// NewMysql builds a MysqlRepo with one read and one write connection.
//
// A nil cfg is replaced by DefaultMySQLConfig. Connection defaults are merged
// into cfg in place (the caller's value is modified) and the result is
// validated before any connection is attempted. If the write connection cannot
// be opened, the already opened read connection is closed again.
func NewMysql(cfg *MySQLConfig) (MysqlRepo, error) {
	if cfg == nil {
		cfg = DefaultMySQLConfig()
	}

	// Fill in per-connection defaults first so validation sees the final values.
	mergeDefaultConfig(cfg)
	if err := validateMySQLConfig(cfg); err != nil {
		return nil, err
	}

	reader, err := dbConnect("read", cfg.Read, cfg.Base, cfg.Logger)
	if err != nil {
		return nil, fmt.Errorf("连接读库失败: %w", err)
	}

	writer, err := dbConnect("write", cfg.Write, cfg.Base, cfg.Logger)
	if err != nil {
		// Best-effort cleanup; the connect error is the one worth reporting.
		_ = closeDB(reader)
		return nil, fmt.Errorf("连接写库失败: %w", err)
	}

	return &mysqlRepo{
		read:   reader,
		write:  writer,
		config: cfg,
		closed: false,
	}, nil
}
// DB returns the read handle, or the write handle when the first forceWrite
// value is true. A non-nil ctx is attached to the returned handle. The closed
// flag is not consulted here.
func (m *mysqlRepo) DB(ctx context.Context, forceWrite ...bool) *gorm.DB {
	m.mu.RLock()
	defer m.mu.RUnlock()

	handle := m.read
	if len(forceWrite) > 0 && forceWrite[0] {
		handle = m.write
	}
	if ctx == nil {
		return handle
	}
	return handle.WithContext(ctx)
}
// GetRead returns the read handle as stored in the repository.
func (m *mysqlRepo) GetRead() *gorm.DB {
	m.mu.RLock()
	handle := m.read
	m.mu.RUnlock()
	return handle
}
// GetWrite returns the write handle as stored in the repository.
func (m *mysqlRepo) GetWrite() *gorm.DB {
	m.mu.RLock()
	handle := m.write
	m.mu.RUnlock()
	return handle
}
// Transaction runs fc inside a transaction on the write handle and returns
// whatever the transaction returns. It fails immediately when the repository
// has been closed. A non-nil ctx is attached to the handle first.
func (m *mysqlRepo) Transaction(ctx context.Context, fc func(*gorm.DB) error) error {
	// Snapshot the state under the lock, then release it so the lock is not
	// held while the transaction runs.
	m.mu.RLock()
	closed, tx := m.closed, m.write
	m.mu.RUnlock()

	if closed {
		return fmt.Errorf("数据库连接已关闭")
	}
	if ctx != nil {
		tx = tx.WithContext(ctx)
	}
	return tx.Transaction(fc)
}
// Close closes the read and the write handle and marks the repository closed.
// Both handles are always attempted; their failures are combined into a single
// error. Calling Close on an already closed repository is a no-op.
func (m *mysqlRepo) Close() error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.closed {
		return nil
	}

	var readErr, writeErr error
	if err := closeDB(m.read); err != nil {
		readErr = fmt.Errorf("关闭读库失败: %w", err)
	}
	if err := closeDB(m.write); err != nil {
		writeErr = fmt.Errorf("关闭写库失败: %w", err)
	}
	m.closed = true

	// errors.Join drops nil operands and returns nil when none remain.
	return errors.Join(readErr, writeErr)
}
// Ping pings the read handle and then the write handle, stopping at the first
// failure. A nil ctx is replaced by context.Background. It fails immediately
// when the repository has been closed. The read lock is held until both pings
// have finished.
func (m *mysqlRepo) Ping(ctx context.Context) error {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.closed {
		return fmt.Errorf("数据库连接已关闭")
	}
	if ctx == nil {
		ctx = context.Background()
	}

	// Checked in order: read first, then write.
	targets := [...]struct {
		db        *gorm.DB
		handleFmt string
		pingFmt   string
	}{
		{m.read, "获取读库失败: %w", "Ping读库失败: %w"},
		{m.write, "获取写库失败: %w", "Ping写库失败: %w"},
	}
	for _, target := range targets {
		pool, err := target.db.DB()
		if err != nil {
			return fmt.Errorf(target.handleFmt, err)
		}
		if err := pool.PingContext(ctx); err != nil {
			return fmt.Errorf(target.pingFmt, err)
		}
	}
	return nil
}
// Stats returns the database/sql pool statistics of the read and the write
// handle. It fails when the repository has been closed or when either
// underlying *sql.DB cannot be obtained.
func (m *mysqlRepo) Stats() (*DBStats, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	if m.closed {
		return nil, fmt.Errorf("数据库连接已关闭")
	}

	readPool, err := m.read.DB()
	if err != nil {
		return nil, fmt.Errorf("获取读库失败: %w", err)
	}
	writePool, err := m.write.DB()
	if err != nil {
		return nil, fmt.Errorf("获取写库失败: %w", err)
	}

	snapshot := new(DBStats)
	snapshot.Read = readPool.Stats()
	snapshot.Write = writePool.Stats()
	return snapshot, nil
}
// HealthCheck pings both handles and then verifies that neither connection
// pool has reached the configured MaxOpenConn limit.
//
// It returns the Ping or Stats error unchanged, or a "pool full" error when
// the number of open connections of a handle is at or above the limit.
func (m *mysqlRepo) HealthCheck(ctx context.Context) error {
	if err := m.Ping(ctx); err != nil {
		return err
	}
	stats, err := m.Stats()
	if err != nil {
		return err
	}

	// A non-positive MaxOpenConn means "no limit": dbConnect never applies it
	// and database/sql treats it as unlimited. Comparing against it would make
	// OpenConnections >= 0 trivially true and report every pool as full.
	limit := m.config.Base.MaxOpenConn
	if limit <= 0 {
		return nil
	}

	// Check pool saturation against the configured limit.
	if stats.Read.OpenConnections >= limit {
		return fmt.Errorf("读库连接池已满")
	}
	if stats.Write.OpenConnections >= limit {
		return fmt.Errorf("写库连接池已满")
	}
	return nil
}
// mergeDefaultConfig fills in defaults for the read and the write connection
// in place. Empty strings and zero timeouts are replaced; ParseTime is always
// set to true regardless of its previous value. The Base and Logger sections
// are not touched.
func mergeDefaultConfig(cfg *MySQLConfig) {
	for _, conn := range []*DBConnConfig{&cfg.Read, &cfg.Write} {
		if conn.Charset == "" {
			conn.Charset = "utf8mb4"
		}
		if conn.Collation == "" {
			conn.Collation = "utf8mb4_unicode_ci"
		}
		if conn.Loc == "" {
			conn.Loc = "Local"
		}
		// Timeouts are raw second counts (see buildDSN), not real durations.
		if conn.Timeout == 0 {
			conn.Timeout = 10
		}
		if conn.ReadTimeout == 0 {
			conn.ReadTimeout = 30
		}
		if conn.WriteTimeout == 0 {
			conn.WriteTimeout = 30
		}
		conn.ParseTime = true
	}
}
// validateMySQLConfig checks that the address, user name and database name of
// the read and the write connection are all non-empty. It returns an error for
// the first missing field (read fields are checked before write fields) and
// nil when everything is present.
func validateMySQLConfig(cfg *MySQLConfig) error {
	switch {
	case cfg.Read.Addr == "":
		return fmt.Errorf("读库地址不能为空")
	case cfg.Read.User == "":
		return fmt.Errorf("读库用户名不能为空")
	case cfg.Read.Name == "":
		return fmt.Errorf("读库数据库名不能为空")
	case cfg.Write.Addr == "":
		return fmt.Errorf("写库地址不能为空")
	case cfg.Write.User == "":
		return fmt.Errorf("写库用户名不能为空")
	case cfg.Write.Name == "":
		return fmt.Errorf("写库数据库名不能为空")
	}
	return nil
}
// buildDSN renders conn as a go-sql-driver/mysql data source name of the form
//
//	user:pass@tcp(addr)/name?charset=...&collation=...&parseTime=...&loc=...&timeout=Ns&readTimeout=Ns&writeTimeout=Ns
//
// The loc value is query-escaped: IANA zone names such as "Asia/Shanghai"
// contain a '/', and the driver requires parameter values to be URL-escaped
// (its documentation gives loc=US%2FPacific as the example). Names without
// special characters, including the default "Local", are emitted unchanged.
//
// The three timeouts are interpreted as whole seconds: the numeric value of
// each field is written followed by "s".
func buildDSN(conn DBConnConfig) string {
	return fmt.Sprintf(
		"%s:%s@tcp(%s)/%s?charset=%s&collation=%s&parseTime=%t&loc=%s&timeout=%ds&readTimeout=%ds&writeTimeout=%ds",
		conn.User,
		conn.Pass,
		conn.Addr,
		conn.Name,
		conn.Charset,
		conn.Collation,
		conn.ParseTime,
		url.QueryEscape(conn.Loc),
		int(conn.Timeout),
		int(conn.ReadTimeout),
		int(conn.WriteTimeout),
	)
}
// dbConnect opens a GORM MySQL handle for conn and applies the pool settings
// from base.
//
// name is only used to label the error returned when opening fails. logCfg
// selects the log level and log destination of the GORM logger. Pool settings
// are applied only when they are greater than zero; the two lifetimes in base
// are second counts and are multiplied by time.Second here.
//
// When the connection cannot be established, a log file opened for this
// connection is closed again so the descriptor is not leaked.
func dbConnect(name string, conn DBConnConfig, base BaseConfig, logCfg LoggerConfig) (*gorm.DB, error) {
	// Build the DSN.
	dsn := buildDSN(conn)

	// Configure the logger.
	logLevel := parseLogLevel(logCfg.LogLevel)
	logWriter := getLogWriter(logCfg)
	// releaseLog closes the log destination on failure paths. getLogWriter
	// hands out os.Stdout for console logging, which must never be closed.
	releaseLog := func() {
		if logWriter != os.Stdout {
			_ = logWriter.Close()
		}
	}
	newLogger := logger.New(
		log.New(logWriter, "\r\n", log.LstdFlags),
		logger.Config{
			SlowThreshold:             logCfg.SlowThreshold,
			Colorful:                  logCfg.Colorful,
			IgnoreRecordNotFoundError: logCfg.IgnoreRecordNotFoundError,
			LogLevel:                  logLevel,
		},
	)

	// Open the database connection.
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		NamingStrategy: schema.NamingStrategy{
			SingularTable: true,
		},
		Logger:                                   newLogger,
		SkipDefaultTransaction:                   true,
		PrepareStmt:                              true,
		DisableForeignKeyConstraintWhenMigrating: true,
	})
	if err != nil {
		releaseLog()
		return nil, fmt.Errorf("打开数据库连接失败 [%s]: %w", name, err)
	}

	// Fetch the underlying *sql.DB to configure the pool.
	sqlDB, err := db.DB()
	if err != nil {
		releaseLog()
		return nil, fmt.Errorf("获取sqlDB失败: %w", err)
	}

	// Apply pool settings; non-positive values keep the database/sql defaults.
	if base.MaxOpenConn > 0 {
		sqlDB.SetMaxOpenConns(base.MaxOpenConn)
	}
	if base.MaxIdleConn > 0 {
		sqlDB.SetMaxIdleConns(base.MaxIdleConn)
	}
	if base.ConnMaxLifetime > 0 {
		sqlDB.SetConnMaxLifetime(base.ConnMaxLifetime * time.Second)
	}
	if base.ConnMaxIdleTime > 0 {
		sqlDB.SetConnMaxIdleTime(base.ConnMaxIdleTime * time.Second)
	}
	return db, nil
}
// closeDB closes the *sql.DB underlying db. A nil db is a no-op. It returns
// the error from obtaining the underlying pool, or else the error from closing
// it.
func closeDB(db *gorm.DB) error {
	if db == nil {
		return nil
	}
	pool, err := db.DB()
	if err == nil {
		err = pool.Close()
	}
	return err
}
// parseLogLevel maps a textual level ("silent", "error", "warn", "info") to
// the matching GORM log level. Matching is exact; any other value yields the
// warn level.
func parseLogLevel(level string) logger.LogLevel {
	switch level {
	case "silent":
		return logger.Silent
	case "error":
		return logger.Error
	case "info":
		return logger.Info
	}
	// "warn" and every unrecognised value share the same level.
	return logger.Warn
}
// getLogWriter returns the destination for GORM log output. The configured log
// file is opened (created if missing, append-only) when LogOutput is "file"
// and LogFile is non-empty. In every other case, and when the file cannot be
// opened, os.Stdout is returned; an open failure is reported through the
// standard logger.
func getLogWriter(cfg LoggerConfig) *os.File {
	if cfg.LogOutput != "file" || cfg.LogFile == "" {
		return os.Stdout
	}

	const openFlags = os.O_CREATE | os.O_WRONLY | os.O_APPEND
	f, err := os.OpenFile(cfg.LogFile, openFlags, 0666)
	if err == nil {
		return f
	}

	log.Printf("打开日志文件失败: %v使用标准输出", err)
	return os.Stdout
}