248 lines
6.0 KiB
Go
248 lines
6.0 KiB
Go
package database
|
||
|
||
import (
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"os"
|
||
"time"
|
||
|
||
"gitea.bvbej.com/bvbej/base-golang/pkg/time_parse"
|
||
"gitea.bvbej.com/bvbej/base-golang/pkg/trace"
|
||
"gorm.io/driver/mysql"
|
||
"gorm.io/gorm"
|
||
"gorm.io/gorm/logger"
|
||
"gorm.io/gorm/schema"
|
||
"gorm.io/gorm/utils"
|
||
)
|
||
|
||
// Names under which TracePlugin registers its hooks in each GORM callback
// chain it instruments.
const (
	callBackBeforeName = "core:before"
	callBackAfterName  = "core:after"
)

// Keys used with gorm.DB InstanceSet/InstanceGet to pass values between the
// repository getters and the trace callbacks.
const (
	// startTime holds the time.Time stamped by the before callback.
	startTime = "_start_time"
	// traceCtxName holds the trace.T attached by GetRead/GetWrite.
	traceCtxName = "_trace_ctx_name"
)
|
||
|
||
var _ MysqlRepo = (*mysqlRepo)(nil)
|
||
|
||
type MysqlRepo interface {
|
||
i()
|
||
GetRead(options ...Option) *gorm.DB
|
||
GetWrite(options ...Option) *gorm.DB
|
||
Close() error
|
||
}
|
||
|
||
// MySQLConfig describes a MySQL deployment with separate read and write
// endpoints, plus connection-pool settings shared by both. The struct tags
// map it onto YAML configuration.
type MySQLConfig struct {
	// Read is the endpoint behind MysqlRepo.GetRead.
	Read struct {
		// Addr is the TCP address placed in the DSN's tcp(...) part.
		Addr string `yaml:"addr"`
		User string `yaml:"user"`
		Pass string `yaml:"pass"`
		// Name is the database name placed after the "/" in the DSN.
		Name string `yaml:"name"`
	} `yaml:"read"`

	// Write is the endpoint behind MysqlRepo.GetWrite.
	Write struct {
		// Addr is the TCP address placed in the DSN's tcp(...) part.
		Addr string `yaml:"addr"`
		User string `yaml:"user"`
		Pass string `yaml:"pass"`
		// Name is the database name placed after the "/" in the DSN.
		Name string `yaml:"name"`
	} `yaml:"write"`

	// Base holds pool settings applied to both endpoints.
	Base struct {
		// MaxOpenConn is the maximum number of open connections.
		MaxOpenConn int `yaml:"maxOpenConn"`
		// MaxIdleConn is the maximum number of idle connections.
		MaxIdleConn int `yaml:"maxIdleConn"`
		// ConnMaxLifeTime is the maximum connection lifetime expressed as a
		// count of MINUTES: the value is multiplied by time.Minute before use.
		ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"`
	} `yaml:"base"`
}
|
||
|
||
type mysqlRepo struct {
|
||
read *gorm.DB
|
||
write *gorm.DB
|
||
}
|
||
|
||
func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
|
||
dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
|
||
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
|
||
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &mysqlRepo{
|
||
read: dbr,
|
||
write: dbw,
|
||
}, nil
|
||
}
|
||
|
||
// i is the unexported marker method of MysqlRepo; it has no behavior.
func (*mysqlRepo) i() {}
|
||
|
||
func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
|
||
opt := newOption()
|
||
for _, f := range options {
|
||
f(opt)
|
||
}
|
||
|
||
db := d.read
|
||
if opt.Trace != nil {
|
||
db.InstanceSet(traceCtxName, opt.Trace)
|
||
}
|
||
|
||
return db
|
||
}
|
||
|
||
func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
|
||
opt := newOption()
|
||
for _, f := range options {
|
||
f(opt)
|
||
}
|
||
|
||
db := d.write
|
||
if opt.Trace != nil {
|
||
db.InstanceSet(traceCtxName, opt.Trace)
|
||
}
|
||
|
||
return db
|
||
}
|
||
|
||
func (d *mysqlRepo) Close() (err error) {
|
||
rdb, err1 := d.read.DB()
|
||
if err1 != nil {
|
||
err = errors.Join(err1)
|
||
}
|
||
err2 := rdb.Close()
|
||
if err2 != nil {
|
||
err = errors.Join(err2)
|
||
}
|
||
|
||
wdb, err3 := d.write.DB()
|
||
if err3 != nil {
|
||
err = errors.Join(err3)
|
||
}
|
||
err4 := wdb.Close()
|
||
if err4 != nil {
|
||
err = errors.Join(err4)
|
||
}
|
||
|
||
return err
|
||
}
|
||
|
||
func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
|
||
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
|
||
user,
|
||
pass,
|
||
addr,
|
||
dbName,
|
||
true,
|
||
"Local")
|
||
|
||
// 日志配置
|
||
newLogger := logger.New(
|
||
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
||
logger.Config{
|
||
SlowThreshold: time.Second, // 慢SQL阈值
|
||
Colorful: true, // 彩色打印
|
||
IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
|
||
LogLevel: logger.Error, // 日志级别
|
||
},
|
||
)
|
||
|
||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||
NamingStrategy: schema.NamingStrategy{
|
||
SingularTable: true,
|
||
},
|
||
Logger: newLogger,
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, errors.Join(err, fmt.Errorf("[db connection failed] Database name: %s", dbName))
|
||
}
|
||
|
||
db.Set("gorm:table_options", "CHARSET=utf8mb4")
|
||
|
||
sqlDB, err := db.DB()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 设置连接池 用于设置最大打开的连接数,默认值为0表示不限制.设置最大的连接数,可以避免并发太高导致连接mysql出现too many connections的错误。
|
||
sqlDB.SetMaxOpenConns(maxOpenConn)
|
||
|
||
// 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
|
||
sqlDB.SetMaxIdleConns(maxIdleConn)
|
||
|
||
// 设置最大连接超时
|
||
sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
|
||
|
||
// 使用插件
|
||
err = db.Use(&TracePlugin{})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return db, nil
|
||
}
|
||
|
||
/***************************************************************/
|
||
|
||
// TracePlugin is a GORM plugin that installs the before/after callbacks. When
// a statement instance carries a trace.T, those callbacks append the executed
// SQL, the affected rows and the elapsed time to it.
type TracePlugin struct{}

// Name reports the plugin's name to GORM.
func (*TracePlugin) Name() string {
	const pluginName = "TracePlugin"
	return pluginName
}
|
||
|
||
func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
|
||
// 开始前
|
||
_ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
|
||
_ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
|
||
_ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
|
||
_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
|
||
_ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
|
||
_ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
|
||
|
||
// 结束后
|
||
_ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
|
||
_ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
|
||
_ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
|
||
_ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
|
||
_ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
|
||
_ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
|
||
return
|
||
}
|
||
|
||
func before(db *gorm.DB) {
|
||
db.InstanceSet(startTime, time.Now())
|
||
}
|
||
|
||
func after(db *gorm.DB) {
|
||
_traceCtx, isExist := db.InstanceGet(traceCtxName)
|
||
if !isExist {
|
||
return
|
||
}
|
||
_trace, ok := _traceCtx.(trace.T)
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
_ts, isExist := db.InstanceGet(startTime)
|
||
if !isExist {
|
||
return
|
||
}
|
||
|
||
ts, ok := _ts.(time.Time)
|
||
if !ok {
|
||
return
|
||
}
|
||
|
||
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
|
||
|
||
sqlInfo := new(trace.SQL)
|
||
sqlInfo.Timestamp = time_parse.CSTLayoutString()
|
||
sqlInfo.SQL = sql
|
||
sqlInfo.Stack = utils.FileWithLineNum()
|
||
sqlInfo.Rows = db.Statement.RowsAffected
|
||
sqlInfo.CostSeconds = time.Since(ts).Seconds()
|
||
_trace.AppendSQL(sqlInfo)
|
||
}
|