500 lines
12 KiB
Go
500 lines
12 KiB
Go
package database
|
||
|
||
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"log"
	"net/url"
	"os"
	"sync"
	"time"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/schema"
)
|
||
|
||
// MysqlRepo MySQL接口
|
||
type MysqlRepo interface {
|
||
// DB 获取数据库实例(自动选择读写)
|
||
DB(ctx context.Context, forceWrite ...bool) *gorm.DB
|
||
|
||
// GetRead 获取读库
|
||
GetRead() *gorm.DB
|
||
|
||
// GetWrite 获取写库
|
||
GetWrite() *gorm.DB
|
||
|
||
// Transaction 执行事务
|
||
Transaction(ctx context.Context, fc func(*gorm.DB) error) error
|
||
|
||
// Close 关闭连接
|
||
Close() error
|
||
|
||
// Ping 检查连接
|
||
Ping(ctx context.Context) error
|
||
|
||
// Stats 获取连接池统计信息
|
||
Stats() (*DBStats, error)
|
||
|
||
// HealthCheck 健康检查
|
||
HealthCheck(ctx context.Context) error
|
||
}
|
||
|
||
// DBStats bundles the database/sql pool statistics of the read connection and
// the write connection.
type DBStats struct {
	Read  sql.DBStats `json:"read"`  // pool statistics of the read connection
	Write sql.DBStats `json:"write"` // pool statistics of the write connection
}

// MySQLConfig MySQL配置
|
||
type MySQLConfig struct {
|
||
Read DBConnConfig `yaml:"read" json:"read"`
|
||
Write DBConnConfig `yaml:"write" json:"write"`
|
||
Base BaseConfig `yaml:"base" json:"base"`
|
||
Logger LoggerConfig `yaml:"logger" json:"logger"`
|
||
}
|
||
|
||
// DBConnConfig describes one MySQL connection; buildDSN renders it into a DSN.
// Charset, Collation, Loc, ParseTime and the three timeouts are defaulted by
// mergeDefaultConfig when left empty or zero.
type DBConnConfig struct {
	Addr string `yaml:"addr" json:"addr"` // placed inside tcp(...) in the DSN, typically host:port
	User string `yaml:"user" json:"user"`
	Pass string `yaml:"pass" json:"pass"`
	Name string `yaml:"name" json:"name"` // database name

	Charset   string `yaml:"charset" json:"charset"`       // default "utf8mb4"
	Collation string `yaml:"collation" json:"collation"`   // default "utf8mb4_unicode_ci"
	Loc       string `yaml:"loc" json:"loc"`               // time zone name, default "Local"
	ParseTime bool   `yaml:"parse_time" json:"parse_time"` // forced to true by mergeDefaultConfig

	// The timeouts are whole-second counts, not real durations: buildDSN
	// writes int(value) followed by "s". Defaults are 10, 30 and 30.
	Timeout      time.Duration `yaml:"timeout" json:"timeout"`             // dial timeout
	ReadTimeout  time.Duration `yaml:"read_timeout" json:"read_timeout"`   // I/O read timeout
	WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` // I/O write timeout
}

// BaseConfig holds connection-pool settings shared by the read and the write
// connection. dbConnect applies a setting only when it is > 0; otherwise the
// database/sql default stays in effect.
type BaseConfig struct {
	MaxOpenConn int `yaml:"max_open_conn" json:"max_open_conn"`
	MaxIdleConn int `yaml:"max_idle_conn" json:"max_idle_conn"`

	// Whole-second counts: dbConnect multiplies these by time.Second.
	ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" json:"conn_max_lifetime"`
	ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time" json:"conn_max_idle_time"`
}

// LoggerConfig configures the gorm SQL logger built for each connection.
type LoggerConfig struct {
	// SlowThreshold is a real duration (DefaultMySQLConfig uses 200ms), unlike
	// the whole-second counts in DBConnConfig and BaseConfig.
	SlowThreshold             time.Duration `yaml:"slow_threshold" json:"slow_threshold"`
	Colorful                  bool          `yaml:"colorful" json:"colorful"`
	IgnoreRecordNotFoundError bool          `yaml:"ignore_record_not_found_error" json:"ignore_record_not_found_error"`

	// LogLevel is "silent", "error", "warn" or "info"; any other value means "warn".
	LogLevel string `yaml:"log_level" json:"log_level"`

	// LogOutput is "stdout" or "file". Output goes to LogFile only when
	// LogOutput is "file" and LogFile is non-empty; otherwise it goes to stdout.
	LogOutput string `yaml:"log_output" json:"log_output"`
	LogFile   string `yaml:"log_file" json:"log_file"` // path of the log file
}

// DefaultMySQLConfig 默认配置
|
||
func DefaultMySQLConfig() *MySQLConfig {
|
||
return &MySQLConfig{
|
||
Base: BaseConfig{
|
||
MaxOpenConn: 100,
|
||
MaxIdleConn: 10,
|
||
ConnMaxLifetime: 3600, // 1小时
|
||
ConnMaxIdleTime: 600, // 10分钟
|
||
},
|
||
Logger: LoggerConfig{
|
||
SlowThreshold: 200 * time.Millisecond,
|
||
Colorful: true,
|
||
IgnoreRecordNotFoundError: true,
|
||
LogLevel: "warn",
|
||
LogOutput: "stdout",
|
||
},
|
||
}
|
||
}
|
||
|
||
// mysqlRepo MySQL实现
|
||
type mysqlRepo struct {
|
||
read *gorm.DB
|
||
write *gorm.DB
|
||
config *MySQLConfig
|
||
mu sync.RWMutex
|
||
closed bool
|
||
}
|
||
|
||
// NewMysql 创建MySQL实例
|
||
func NewMysql(cfg *MySQLConfig) (MysqlRepo, error) {
|
||
if cfg == nil {
|
||
cfg = DefaultMySQLConfig()
|
||
}
|
||
|
||
// 合并默认配置
|
||
mergeDefaultConfig(cfg)
|
||
|
||
// 验证配置
|
||
if err := validateMySQLConfig(cfg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 连接读库
|
||
dbr, err := dbConnect("read", cfg.Read, cfg.Base, cfg.Logger)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("连接读库失败: %w", err)
|
||
}
|
||
|
||
// 连接写库
|
||
dbw, err := dbConnect("write", cfg.Write, cfg.Base, cfg.Logger)
|
||
if err != nil {
|
||
_ = closeDB(dbr)
|
||
return nil, fmt.Errorf("连接写库失败: %w", err)
|
||
}
|
||
|
||
repo := &mysqlRepo{
|
||
read: dbr,
|
||
write: dbw,
|
||
config: cfg,
|
||
closed: false,
|
||
}
|
||
|
||
return repo, nil
|
||
}
|
||
|
||
// DB 获取数据库实例(自动选择读写)
|
||
func (m *mysqlRepo) DB(ctx context.Context, forceWrite ...bool) *gorm.DB {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
|
||
useWrite := len(forceWrite) > 0 && forceWrite[0]
|
||
|
||
var db *gorm.DB
|
||
if useWrite {
|
||
db = m.write
|
||
} else {
|
||
db = m.read
|
||
}
|
||
|
||
if ctx != nil {
|
||
db = db.WithContext(ctx)
|
||
}
|
||
|
||
return db
|
||
}
|
||
|
||
// GetRead 获取读库
|
||
func (m *mysqlRepo) GetRead() *gorm.DB {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
return m.read
|
||
}
|
||
|
||
// GetWrite 获取写库
|
||
func (m *mysqlRepo) GetWrite() *gorm.DB {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
return m.write
|
||
}
|
||
|
||
// Transaction 执行事务
|
||
func (m *mysqlRepo) Transaction(ctx context.Context, fc func(*gorm.DB) error) error {
|
||
m.mu.RLock()
|
||
if m.closed {
|
||
m.mu.RUnlock()
|
||
return fmt.Errorf("数据库连接已关闭")
|
||
}
|
||
db := m.write
|
||
m.mu.RUnlock()
|
||
|
||
if ctx != nil {
|
||
db = db.WithContext(ctx)
|
||
}
|
||
|
||
return db.Transaction(fc)
|
||
}
|
||
|
||
// Close 关闭连接
|
||
func (m *mysqlRepo) Close() error {
|
||
m.mu.Lock()
|
||
defer m.mu.Unlock()
|
||
|
||
if m.closed {
|
||
return nil
|
||
}
|
||
|
||
var errs []error
|
||
if err := closeDB(m.read); err != nil {
|
||
errs = append(errs, fmt.Errorf("关闭读库失败: %w", err))
|
||
}
|
||
if err := closeDB(m.write); err != nil {
|
||
errs = append(errs, fmt.Errorf("关闭写库失败: %w", err))
|
||
}
|
||
|
||
m.closed = true
|
||
|
||
if len(errs) > 0 {
|
||
return errors.Join(errs...)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Ping 检查连接
|
||
func (m *mysqlRepo) Ping(ctx context.Context) error {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
|
||
if m.closed {
|
||
return fmt.Errorf("数据库连接已关闭")
|
||
}
|
||
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
|
||
// Ping读库
|
||
readDB, err := m.read.DB()
|
||
if err != nil {
|
||
return fmt.Errorf("获取读库失败: %w", err)
|
||
}
|
||
if err := readDB.PingContext(ctx); err != nil {
|
||
return fmt.Errorf("Ping读库失败: %w", err)
|
||
}
|
||
|
||
// Ping写库
|
||
writeDB, err := m.write.DB()
|
||
if err != nil {
|
||
return fmt.Errorf("获取写库失败: %w", err)
|
||
}
|
||
if err := writeDB.PingContext(ctx); err != nil {
|
||
return fmt.Errorf("Ping写库失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// Stats 获取连接池统计信息
|
||
func (m *mysqlRepo) Stats() (*DBStats, error) {
|
||
m.mu.RLock()
|
||
defer m.mu.RUnlock()
|
||
|
||
if m.closed {
|
||
return nil, fmt.Errorf("数据库连接已关闭")
|
||
}
|
||
|
||
readDB, err := m.read.DB()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取读库失败: %w", err)
|
||
}
|
||
|
||
writeDB, err := m.write.DB()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取写库失败: %w", err)
|
||
}
|
||
|
||
return &DBStats{
|
||
Read: readDB.Stats(),
|
||
Write: writeDB.Stats(),
|
||
}, nil
|
||
}
|
||
|
||
// HealthCheck 健康检查
|
||
func (m *mysqlRepo) HealthCheck(ctx context.Context) error {
|
||
if err := m.Ping(ctx); err != nil {
|
||
return err
|
||
}
|
||
|
||
stats, err := m.Stats()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 检查连接池状态
|
||
if stats.Read.OpenConnections >= m.config.Base.MaxOpenConn {
|
||
return fmt.Errorf("读库连接池已满")
|
||
}
|
||
if stats.Write.OpenConnections >= m.config.Base.MaxOpenConn {
|
||
return fmt.Errorf("写库连接池已满")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// mergeDefaultConfig 合并默认配置
|
||
func mergeDefaultConfig(cfg *MySQLConfig) {
|
||
// 读库默认配置
|
||
if cfg.Read.Charset == "" {
|
||
cfg.Read.Charset = "utf8mb4"
|
||
}
|
||
if cfg.Read.Collation == "" {
|
||
cfg.Read.Collation = "utf8mb4_unicode_ci"
|
||
}
|
||
if cfg.Read.Loc == "" {
|
||
cfg.Read.Loc = "Local"
|
||
}
|
||
if cfg.Read.Timeout == 0 {
|
||
cfg.Read.Timeout = 10 // 10秒
|
||
}
|
||
if cfg.Read.ReadTimeout == 0 {
|
||
cfg.Read.ReadTimeout = 30 // 30秒
|
||
}
|
||
if cfg.Read.WriteTimeout == 0 {
|
||
cfg.Read.WriteTimeout = 30 // 30秒
|
||
}
|
||
cfg.Read.ParseTime = true
|
||
|
||
// 写库默认配置
|
||
if cfg.Write.Charset == "" {
|
||
cfg.Write.Charset = "utf8mb4"
|
||
}
|
||
if cfg.Write.Collation == "" {
|
||
cfg.Write.Collation = "utf8mb4_unicode_ci"
|
||
}
|
||
if cfg.Write.Loc == "" {
|
||
cfg.Write.Loc = "Local"
|
||
}
|
||
if cfg.Write.Timeout == 0 {
|
||
cfg.Write.Timeout = 10
|
||
}
|
||
if cfg.Write.ReadTimeout == 0 {
|
||
cfg.Write.ReadTimeout = 30
|
||
}
|
||
if cfg.Write.WriteTimeout == 0 {
|
||
cfg.Write.WriteTimeout = 30
|
||
}
|
||
cfg.Write.ParseTime = true
|
||
}
|
||
|
||
// validateMySQLConfig 验证配置
|
||
func validateMySQLConfig(cfg *MySQLConfig) error {
|
||
if cfg.Read.Addr == "" {
|
||
return fmt.Errorf("读库地址不能为空")
|
||
}
|
||
if cfg.Read.User == "" {
|
||
return fmt.Errorf("读库用户名不能为空")
|
||
}
|
||
if cfg.Read.Name == "" {
|
||
return fmt.Errorf("读库数据库名不能为空")
|
||
}
|
||
|
||
if cfg.Write.Addr == "" {
|
||
return fmt.Errorf("写库地址不能为空")
|
||
}
|
||
if cfg.Write.User == "" {
|
||
return fmt.Errorf("写库用户名不能为空")
|
||
}
|
||
if cfg.Write.Name == "" {
|
||
return fmt.Errorf("写库数据库名不能为空")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// buildDSN 构建DSN
|
||
func buildDSN(conn DBConnConfig) string {
|
||
return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&collation=%s&parseTime=%t&loc=%s&timeout=%ds&readTimeout=%ds&writeTimeout=%ds",
|
||
conn.User,
|
||
conn.Pass,
|
||
conn.Addr,
|
||
conn.Name,
|
||
conn.Charset,
|
||
conn.Collation,
|
||
conn.ParseTime,
|
||
conn.Loc,
|
||
int(conn.Timeout),
|
||
int(conn.ReadTimeout),
|
||
int(conn.WriteTimeout),
|
||
)
|
||
}
|
||
|
||
// dbConnect 连接数据库
|
||
func dbConnect(name string, conn DBConnConfig, base BaseConfig, logCfg LoggerConfig) (*gorm.DB, error) {
|
||
// 构建DSN
|
||
dsn := buildDSN(conn)
|
||
|
||
// 配置日志
|
||
logLevel := parseLogLevel(logCfg.LogLevel)
|
||
logWriter := getLogWriter(logCfg)
|
||
|
||
newLogger := logger.New(
|
||
log.New(logWriter, "\r\n", log.LstdFlags),
|
||
logger.Config{
|
||
SlowThreshold: logCfg.SlowThreshold,
|
||
Colorful: logCfg.Colorful,
|
||
IgnoreRecordNotFoundError: logCfg.IgnoreRecordNotFoundError,
|
||
LogLevel: logLevel,
|
||
},
|
||
)
|
||
|
||
// 打开数据库连接
|
||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||
NamingStrategy: schema.NamingStrategy{
|
||
SingularTable: true,
|
||
},
|
||
Logger: newLogger,
|
||
SkipDefaultTransaction: true,
|
||
PrepareStmt: true,
|
||
DisableForeignKeyConstraintWhenMigrating: true,
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, fmt.Errorf("打开数据库连接失败 [%s]: %w", name, err)
|
||
}
|
||
|
||
// 获取底层sqlDB
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("获取sqlDB失败: %w", err)
|
||
}
|
||
|
||
// 设置连接池参数
|
||
if base.MaxOpenConn > 0 {
|
||
sqlDB.SetMaxOpenConns(base.MaxOpenConn)
|
||
}
|
||
if base.MaxIdleConn > 0 {
|
||
sqlDB.SetMaxIdleConns(base.MaxIdleConn)
|
||
}
|
||
if base.ConnMaxLifetime > 0 {
|
||
sqlDB.SetConnMaxLifetime(base.ConnMaxLifetime * time.Second)
|
||
}
|
||
if base.ConnMaxIdleTime > 0 {
|
||
sqlDB.SetConnMaxIdleTime(base.ConnMaxIdleTime * time.Second)
|
||
}
|
||
|
||
return db, nil
|
||
}
|
||
|
||
// closeDB 关闭数据库
|
||
func closeDB(db *gorm.DB) error {
|
||
if db == nil {
|
||
return nil
|
||
}
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return sqlDB.Close()
|
||
}
|
||
|
||
// parseLogLevel 解析日志级别
|
||
func parseLogLevel(level string) logger.LogLevel {
|
||
switch level {
|
||
case "silent":
|
||
return logger.Silent
|
||
case "error":
|
||
return logger.Error
|
||
case "warn":
|
||
return logger.Warn
|
||
case "info":
|
||
return logger.Info
|
||
default:
|
||
return logger.Warn
|
||
}
|
||
}
|
||
|
||
// getLogWriter 获取日志输出
|
||
func getLogWriter(cfg LoggerConfig) *os.File {
|
||
if cfg.LogOutput == "file" && cfg.LogFile != "" {
|
||
file, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||
if err != nil {
|
||
log.Printf("打开日志文件失败: %v,使用标准输出", err)
|
||
return os.Stdout
|
||
}
|
||
return file
|
||
}
|
||
return os.Stdout
|
||
}
|