[🚀] mysql

This commit is contained in:
2026-02-07 15:36:21 +08:00
parent 078967c601
commit 3005002379
6 changed files with 818 additions and 340 deletions

View File

@@ -3,75 +3,262 @@ package database
import (
"context"
"fmt"
"time"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"time"
)
var _ MongoDB = (*mongoDB)(nil)
// MongoDB MongoDB接口
type MongoDB interface {
i()
// GetDB 获取数据库实例
GetDB() *mongo.Database
Close() error
// GetClient 获取客户端实例
GetClient() *mongo.Client
// GetCollection 获取集合
GetCollection(name string) *mongo.Collection
// Ping 检查连接
Ping(ctx context.Context) error
// Close 关闭连接
Close(ctx context.Context) error
// WithContext 创建带超时的上下文
WithContext(ctx context.Context) (context.Context, context.CancelFunc)
}
// MongoDBConfig MongoDB配置
type MongoDBConfig struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
Timeout time.Duration `yaml:"timeout"`
// 地址,支持多个地址: "localhost:27017,localhost:27018"
Addr string `yaml:"addr" json:"addr"`
// 用户名
User string `yaml:"user" json:"user"`
// 密码
Pass string `yaml:"pass" json:"pass"`
// 数据库名称
Name string `yaml:"name" json:"name"`
// 连接超时(秒)
Timeout time.Duration `yaml:"timeout" json:"timeout"`
// 最大连接池大小
MaxPoolSize uint64 `yaml:"max_pool_size" json:"max_pool_size"`
// 最小连接池大小
MinPoolSize uint64 `yaml:"min_pool_size" json:"min_pool_size"`
// 最大连接空闲时间(秒)
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time" json:"max_conn_idle_time"`
// 是否使用副本集
ReplicaSet string `yaml:"replica_set" json:"replica_set"`
// 是否使用TLS
UseTLS bool `yaml:"use_tls" json:"use_tls"`
// 认证数据库
AuthSource string `yaml:"auth_source" json:"auth_source"`
}
// DefaultMongoDBConfig 默认配置
func DefaultMongoDBConfig() *MongoDBConfig {
return &MongoDBConfig{
Timeout: 10, // 10秒
MaxPoolSize: 100,
MinPoolSize: 10,
MaxConnIdleTime: 60, // 60秒
AuthSource: "admin",
}
}
// mongoDB MongoDB实现
type mongoDB struct {
client *mongo.Client
db *mongo.Database
config *MongoDBConfig
timeout time.Duration
}
func (m *mongoDB) i() {}
// NewMongoDB 创建MongoDB实例
func NewMongoDB(cfg *MongoDBConfig) (MongoDB, error) {
if cfg == nil {
return nil, fmt.Errorf("配置不能为空")
}
// 验证必填参数
if cfg.Addr == "" {
return nil, fmt.Errorf("地址不能为空")
}
if cfg.Name == "" {
return nil, fmt.Errorf("数据库名称不能为空")
}
func NewMongoDB(cfg MongoDBConfig) (MongoDB, error) {
timeout := cfg.Timeout * time.Second
connectCtx, connectCancelFunc := context.WithTimeout(context.Background(), timeout)
defer connectCancelFunc()
var auth string
if len(cfg.User) > 0 && len(cfg.Pass) > 0 {
auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
}
client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(
fmt.Sprintf("mongodb://%s%s", auth, cfg.Addr),
))
if err != nil {
return nil, err
if timeout == 0 {
timeout = 10 * time.Second
}
pingCtx, pingCancelFunc := context.WithTimeout(context.Background(), timeout)
defer pingCancelFunc()
err = client.Ping(pingCtx, readpref.Primary())
// 构建连接选项
clientOpts := options.Client()
// 构建URI
uri := buildMongoURI(cfg)
clientOpts.ApplyURI(uri)
// 设置连接池
if cfg.MaxPoolSize > 0 {
clientOpts.SetMaxPoolSize(cfg.MaxPoolSize)
}
if cfg.MinPoolSize > 0 {
clientOpts.SetMinPoolSize(cfg.MinPoolSize)
}
if cfg.MaxConnIdleTime > 0 {
clientOpts.SetMaxConnIdleTime(cfg.MaxConnIdleTime * time.Second)
}
// 设置超时
clientOpts.SetConnectTimeout(timeout)
clientOpts.SetServerSelectionTimeout(timeout)
// 连接MongoDB
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, err := mongo.Connect(ctx, clientOpts)
if err != nil {
return nil, err
return nil, fmt.Errorf("连接MongoDB失败: %w", err)
}
// Ping测试连接
pingCtx, pingCancel := context.WithTimeout(context.Background(), timeout)
defer pingCancel()
if err = client.Ping(pingCtx, readpref.Primary()); err != nil {
_ = client.Disconnect(context.Background())
return nil, fmt.Errorf("Ping MongoDB失败: %w", err)
}
return &mongoDB{
client: client,
db: client.Database(cfg.Name),
config: cfg,
timeout: timeout,
}, nil
}
// GetDB 获取数据库实例
func (m *mongoDB) GetDB() *mongo.Database {
return m.db
}
func (m *mongoDB) Close() error {
disconnectCtx, disconnectCancelFunc := context.WithTimeout(context.Background(), m.timeout)
defer disconnectCancelFunc()
err := m.client.Disconnect(disconnectCtx)
// GetClient 获取客户端实例
func (m *mongoDB) GetClient() *mongo.Client {
return m.client
}
// GetCollection 获取集合
func (m *mongoDB) GetCollection(name string) *mongo.Collection {
return m.db.Collection(name)
}
// Ping 检查连接
func (m *mongoDB) Ping(ctx context.Context) error {
pingCtx, cancel := m.WithContext(ctx)
defer cancel()
return m.client.Ping(pingCtx, readpref.Primary())
}
// Close 关闭连接
func (m *mongoDB) Close(ctx context.Context) error {
if ctx == nil {
ctx = context.Background()
}
disconnectCtx, cancel := context.WithTimeout(ctx, m.timeout)
defer cancel()
if err := m.client.Disconnect(disconnectCtx); err != nil {
return fmt.Errorf("断开MongoDB连接失败: %w", err)
}
return nil
}
// WithContext 创建带超时的上下文
func (m *mongoDB) WithContext(ctx context.Context) (context.Context, context.CancelFunc) {
if ctx == nil {
ctx = context.Background()
}
return context.WithTimeout(ctx, m.timeout)
}
// buildMongoURI 构建MongoDB URI
func buildMongoURI(cfg *MongoDBConfig) string {
var auth string
if cfg.User != "" && cfg.Pass != "" {
auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
}
uri := fmt.Sprintf("mongodb://%s%s/%s", auth, cfg.Addr, cfg.Name)
// 添加查询参数
params := []string{}
if cfg.AuthSource != "" {
params = append(params, fmt.Sprintf("authSource=%s", cfg.AuthSource))
}
if cfg.ReplicaSet != "" {
params = append(params, fmt.Sprintf("replicaSet=%s", cfg.ReplicaSet))
}
if cfg.UseTLS {
params = append(params, "tls=true")
}
if len(params) > 0 {
uri += "?"
for i, param := range params {
if i > 0 {
uri += "&"
}
uri += param
}
}
return uri
}
// Helper functions
// EnsureIndexes 确保索引存在
func EnsureIndexes(ctx context.Context, collection *mongo.Collection, indexes []mongo.IndexModel) error {
if len(indexes) == 0 {
return nil
}
_, err := collection.Indexes().CreateMany(ctx, indexes)
if err != nil {
return err
return fmt.Errorf("创建索引失败: %w", err)
}
return nil
}
// DropIndexes 删除索引
func DropIndexes(ctx context.Context, collection *mongo.Collection, indexNames []string) error {
for _, name := range indexNames {
if _, err := collection.Indexes().DropOne(ctx, name); err != nil {
return fmt.Errorf("删除索引 %s 失败: %w", name, err)
}
}
return nil
}

View File

@@ -1,247 +1,499 @@
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"log"
"os"
"sync"
"time"
"gitea.bvbej.com/bvbej/base-golang/pkg/time_parse"
"gitea.bvbej.com/bvbej/base-golang/pkg/trace"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
const (
callBackBeforeName = "core:before"
callBackAfterName = "core:after"
startTime = "_start_time"
traceCtxName = "_trace_ctx_name"
)
var _ MysqlRepo = (*mysqlRepo)(nil)
// MysqlRepo MySQL接口
type MysqlRepo interface {
i()
GetRead(options ...Option) *gorm.DB
GetWrite(options ...Option) *gorm.DB
// DB 获取数据库实例(自动选择读写)
DB(ctx context.Context, forceWrite ...bool) *gorm.DB
// GetRead 获取读库
GetRead() *gorm.DB
// GetWrite 获取写库
GetWrite() *gorm.DB
// Transaction 执行事务
Transaction(ctx context.Context, fc func(*gorm.DB) error) error
// Close 关闭连接
Close() error
// Ping 检查连接
Ping(ctx context.Context) error
// Stats 获取连接池统计信息
Stats() (*DBStats, error)
// HealthCheck 健康检查
HealthCheck(ctx context.Context) error
}
// DBStats 数据库统计信息
type DBStats struct {
Read sql.DBStats `json:"read"`
Write sql.DBStats `json:"write"`
}
// MySQLConfig MySQL配置
type MySQLConfig struct {
Read struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
} `yaml:"read"`
Write struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
} `yaml:"write"`
Base struct {
MaxOpenConn int `yaml:"maxOpenConn"` //最大连接数
MaxIdleConn int `yaml:"maxIdleConn"` //最大空闲连接数
ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"` //最大连接超时(分钟)
} `yaml:"base"`
Read DBConnConfig `yaml:"read" json:"read"`
Write DBConnConfig `yaml:"write" json:"write"`
Base BaseConfig `yaml:"base" json:"base"`
Logger LoggerConfig `yaml:"logger" json:"logger"`
}
// DBConnConfig 数据库连接配置
type DBConnConfig struct {
Addr string `yaml:"addr" json:"addr"`
User string `yaml:"user" json:"user"`
Pass string `yaml:"pass" json:"pass"`
Name string `yaml:"name" json:"name"`
Charset string `yaml:"charset" json:"charset"` // 默认 utf8mb4
Collation string `yaml:"collation" json:"collation"` // 默认 utf8mb4_unicode_ci
Loc string `yaml:"loc" json:"loc"` // 默认 Local
ParseTime bool `yaml:"parse_time" json:"parse_time"` // 默认 true
Timeout time.Duration `yaml:"timeout" json:"timeout"` // 连接超时(秒)
ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"` // 读超时(秒)
WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` // 写超时(秒)
}
// BaseConfig 基础配置
type BaseConfig struct {
MaxOpenConn int `yaml:"max_open_conn" json:"max_open_conn"`
MaxIdleConn int `yaml:"max_idle_conn" json:"max_idle_conn"`
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" json:"conn_max_lifetime"` // 秒
ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time" json:"conn_max_idle_time"` // 秒
}
// LoggerConfig 日志配置
type LoggerConfig struct {
SlowThreshold time.Duration `yaml:"slow_threshold" json:"slow_threshold"`
Colorful bool `yaml:"colorful" json:"colorful"`
IgnoreRecordNotFoundError bool `yaml:"ignore_record_not_found_error" json:"ignore_record_not_found_error"`
LogLevel string `yaml:"log_level" json:"log_level"`
LogOutput string `yaml:"log_output" json:"log_output"` // stdout, file
LogFile string `yaml:"log_file" json:"log_file"` // 日志文件路径
}
// DefaultMySQLConfig 默认配置
func DefaultMySQLConfig() *MySQLConfig {
return &MySQLConfig{
Base: BaseConfig{
MaxOpenConn: 100,
MaxIdleConn: 10,
ConnMaxLifetime: 3600, // 1小时
ConnMaxIdleTime: 600, // 10分钟
},
Logger: LoggerConfig{
SlowThreshold: 200 * time.Millisecond,
Colorful: true,
IgnoreRecordNotFoundError: true,
LogLevel: "warn",
LogOutput: "stdout",
},
}
}
// mysqlRepo MySQL实现
type mysqlRepo struct {
read *gorm.DB
write *gorm.DB
read *gorm.DB
write *gorm.DB
config *MySQLConfig
mu sync.RWMutex
closed bool
}
func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
if err != nil {
// NewMysql 创建MySQL实例
func NewMysql(cfg *MySQLConfig) (MysqlRepo, error) {
if cfg == nil {
cfg = DefaultMySQLConfig()
}
// 合并默认配置
mergeDefaultConfig(cfg)
// 验证配置
if err := validateMySQLConfig(cfg); err != nil {
return nil, err
}
dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
// 连接读库
dbr, err := dbConnect("read", cfg.Read, cfg.Base, cfg.Logger)
if err != nil {
return nil, err
return nil, fmt.Errorf("连接读库失败: %w", err)
}
return &mysqlRepo{
read: dbr,
write: dbw,
// 连接写库
dbw, err := dbConnect("write", cfg.Write, cfg.Base, cfg.Logger)
if err != nil {
_ = closeDB(dbr)
return nil, fmt.Errorf("连接写库失败: %w", err)
}
repo := &mysqlRepo{
read: dbr,
write: dbw,
config: cfg,
closed: false,
}
return repo, nil
}
// DB 获取数据库实例(自动选择读写)
func (m *mysqlRepo) DB(ctx context.Context, forceWrite ...bool) *gorm.DB {
m.mu.RLock()
defer m.mu.RUnlock()
useWrite := len(forceWrite) > 0 && forceWrite[0]
var db *gorm.DB
if useWrite {
db = m.write
} else {
db = m.read
}
if ctx != nil {
db = db.WithContext(ctx)
}
return db
}
// GetRead 获取读库
func (m *mysqlRepo) GetRead() *gorm.DB {
m.mu.RLock()
defer m.mu.RUnlock()
return m.read
}
// GetWrite 获取写库
func (m *mysqlRepo) GetWrite() *gorm.DB {
m.mu.RLock()
defer m.mu.RUnlock()
return m.write
}
// Transaction 执行事务
func (m *mysqlRepo) Transaction(ctx context.Context, fc func(*gorm.DB) error) error {
m.mu.RLock()
if m.closed {
m.mu.RUnlock()
return fmt.Errorf("数据库连接已关闭")
}
db := m.write
m.mu.RUnlock()
if ctx != nil {
db = db.WithContext(ctx)
}
return db.Transaction(fc)
}
// Close 关闭连接
func (m *mysqlRepo) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.closed {
return nil
}
var errs []error
if err := closeDB(m.read); err != nil {
errs = append(errs, fmt.Errorf("关闭读库失败: %w", err))
}
if err := closeDB(m.write); err != nil {
errs = append(errs, fmt.Errorf("关闭写库失败: %w", err))
}
m.closed = true
if len(errs) > 0 {
return errors.Join(errs...)
}
return nil
}
// Ping 检查连接
func (m *mysqlRepo) Ping(ctx context.Context) error {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return fmt.Errorf("数据库连接已关闭")
}
if ctx == nil {
ctx = context.Background()
}
// Ping读库
readDB, err := m.read.DB()
if err != nil {
return fmt.Errorf("获取读库失败: %w", err)
}
if err := readDB.PingContext(ctx); err != nil {
return fmt.Errorf("Ping读库失败: %w", err)
}
// Ping写库
writeDB, err := m.write.DB()
if err != nil {
return fmt.Errorf("获取写库失败: %w", err)
}
if err := writeDB.PingContext(ctx); err != nil {
return fmt.Errorf("Ping写库失败: %w", err)
}
return nil
}
// Stats 获取连接池统计信息
func (m *mysqlRepo) Stats() (*DBStats, error) {
m.mu.RLock()
defer m.mu.RUnlock()
if m.closed {
return nil, fmt.Errorf("数据库连接已关闭")
}
readDB, err := m.read.DB()
if err != nil {
return nil, fmt.Errorf("获取读库失败: %w", err)
}
writeDB, err := m.write.DB()
if err != nil {
return nil, fmt.Errorf("获取写库失败: %w", err)
}
return &DBStats{
Read: readDB.Stats(),
Write: writeDB.Stats(),
}, nil
}
func (d *mysqlRepo) i() {}
func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
opt := newOption()
for _, f := range options {
f(opt)
// HealthCheck 健康检查
func (m *mysqlRepo) HealthCheck(ctx context.Context) error {
if err := m.Ping(ctx); err != nil {
return err
}
db := d.read
if opt.Trace != nil {
db.InstanceSet(traceCtxName, opt.Trace)
stats, err := m.Stats()
if err != nil {
return err
}
return db
// 检查连接池状态
if stats.Read.OpenConnections >= m.config.Base.MaxOpenConn {
return fmt.Errorf("读库连接池已满")
}
if stats.Write.OpenConnections >= m.config.Base.MaxOpenConn {
return fmt.Errorf("写库连接池已满")
}
return nil
}
func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
opt := newOption()
for _, f := range options {
f(opt)
// mergeDefaultConfig 合并默认配置
func mergeDefaultConfig(cfg *MySQLConfig) {
// 读库默认配置
if cfg.Read.Charset == "" {
cfg.Read.Charset = "utf8mb4"
}
db := d.write
if opt.Trace != nil {
db.InstanceSet(traceCtxName, opt.Trace)
if cfg.Read.Collation == "" {
cfg.Read.Collation = "utf8mb4_unicode_ci"
}
if cfg.Read.Loc == "" {
cfg.Read.Loc = "Local"
}
if cfg.Read.Timeout == 0 {
cfg.Read.Timeout = 10 // 10秒
}
if cfg.Read.ReadTimeout == 0 {
cfg.Read.ReadTimeout = 30 // 30秒
}
if cfg.Read.WriteTimeout == 0 {
cfg.Read.WriteTimeout = 30 // 30秒
}
cfg.Read.ParseTime = true
return db
// 写库默认配置
if cfg.Write.Charset == "" {
cfg.Write.Charset = "utf8mb4"
}
if cfg.Write.Collation == "" {
cfg.Write.Collation = "utf8mb4_unicode_ci"
}
if cfg.Write.Loc == "" {
cfg.Write.Loc = "Local"
}
if cfg.Write.Timeout == 0 {
cfg.Write.Timeout = 10
}
if cfg.Write.ReadTimeout == 0 {
cfg.Write.ReadTimeout = 30
}
if cfg.Write.WriteTimeout == 0 {
cfg.Write.WriteTimeout = 30
}
cfg.Write.ParseTime = true
}
func (d *mysqlRepo) Close() (err error) {
rdb, err1 := d.read.DB()
if err1 != nil {
err = errors.Join(err1)
// validateMySQLConfig 验证配置
func validateMySQLConfig(cfg *MySQLConfig) error {
if cfg.Read.Addr == "" {
return fmt.Errorf("读库地址不能为空")
}
err2 := rdb.Close()
if err2 != nil {
err = errors.Join(err2)
if cfg.Read.User == "" {
return fmt.Errorf("读库用户名不能为空")
}
if cfg.Read.Name == "" {
return fmt.Errorf("读库数据库名不能为空")
}
wdb, err3 := d.write.DB()
if err3 != nil {
err = errors.Join(err3)
if cfg.Write.Addr == "" {
return fmt.Errorf("写库地址不能为空")
}
err4 := wdb.Close()
if err4 != nil {
err = errors.Join(err4)
if cfg.Write.User == "" {
return fmt.Errorf("写库用户名不能为空")
}
if cfg.Write.Name == "" {
return fmt.Errorf("写库数据库名不能为空")
}
return err
return nil
}
func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
user,
pass,
addr,
dbName,
true,
"Local")
// buildDSN 构建DSN
func buildDSN(conn DBConnConfig) string {
return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&collation=%s&parseTime=%t&loc=%s&timeout=%ds&readTimeout=%ds&writeTimeout=%ds",
conn.User,
conn.Pass,
conn.Addr,
conn.Name,
conn.Charset,
conn.Collation,
conn.ParseTime,
conn.Loc,
int(conn.Timeout),
int(conn.ReadTimeout),
int(conn.WriteTimeout),
)
}
// dbConnect 连接数据库
func dbConnect(name string, conn DBConnConfig, base BaseConfig, logCfg LoggerConfig) (*gorm.DB, error) {
// 构建DSN
dsn := buildDSN(conn)
// 配置日志
logLevel := parseLogLevel(logCfg.LogLevel)
logWriter := getLogWriter(logCfg)
// 日志配置
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
log.New(logWriter, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: time.Second, // 慢SQL阈值
Colorful: true, // 彩色打印
IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
LogLevel: logger.Error, // 日志级别
SlowThreshold: logCfg.SlowThreshold,
Colorful: logCfg.Colorful,
IgnoreRecordNotFoundError: logCfg.IgnoreRecordNotFoundError,
LogLevel: logLevel,
},
)
// 打开数据库连接
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
Logger: newLogger,
Logger: newLogger,
SkipDefaultTransaction: true,
PrepareStmt: true,
DisableForeignKeyConstraintWhenMigrating: true,
})
if err != nil {
return nil, errors.Join(err, fmt.Errorf("[db connection failed] Database name: %s", dbName))
return nil, fmt.Errorf("打开数据库连接失败 [%s]: %w", name, err)
}
db.Set("gorm:table_options", "CHARSET=utf8mb4")
// 获取底层sqlDB
sqlDB, err := db.DB()
if err != nil {
return nil, err
return nil, fmt.Errorf("获取sqlDB失败: %w", err)
}
// 设置连接池 用于设置最大打开的连接数默认值为0表示不限制.设置最大的连接数可以避免并发太高导致连接mysql出现too many connections的错误。
sqlDB.SetMaxOpenConns(maxOpenConn)
// 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
sqlDB.SetMaxIdleConns(maxIdleConn)
// 设置最大连接超时
sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
// 使用插件
err = db.Use(&TracePlugin{})
if err != nil {
return nil, err
// 设置连接池参数
if base.MaxOpenConn > 0 {
sqlDB.SetMaxOpenConns(base.MaxOpenConn)
}
if base.MaxIdleConn > 0 {
sqlDB.SetMaxIdleConns(base.MaxIdleConn)
}
if base.ConnMaxLifetime > 0 {
sqlDB.SetConnMaxLifetime(base.ConnMaxLifetime * time.Second)
}
if base.ConnMaxIdleTime > 0 {
sqlDB.SetConnMaxIdleTime(base.ConnMaxIdleTime * time.Second)
}
return db, nil
}
/***************************************************************/
type TracePlugin struct{}
func (op *TracePlugin) Name() string {
return "TracePlugin"
// closeDB 关闭数据库
func closeDB(db *gorm.DB) error {
if db == nil {
return nil
}
sqlDB, err := db.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
// 开始前
_ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
_ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
_ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
_ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
_ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
// 结束后
_ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
_ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
_ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
_ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
_ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
_ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
return
// parseLogLevel 解析日志级别
func parseLogLevel(level string) logger.LogLevel {
switch level {
case "silent":
return logger.Silent
case "error":
return logger.Error
case "warn":
return logger.Warn
case "info":
return logger.Info
default:
return logger.Warn
}
}
func before(db *gorm.DB) {
db.InstanceSet(startTime, time.Now())
}
func after(db *gorm.DB) {
_traceCtx, isExist := db.InstanceGet(traceCtxName)
if !isExist {
return
}
_trace, ok := _traceCtx.(trace.T)
if !ok {
return
}
_ts, isExist := db.InstanceGet(startTime)
if !isExist {
return
}
ts, ok := _ts.(time.Time)
if !ok {
return
}
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
sqlInfo := new(trace.SQL)
sqlInfo.Timestamp = time_parse.CSTLayoutString()
sqlInfo.SQL = sql
sqlInfo.Stack = utils.FileWithLineNum()
sqlInfo.Rows = db.Statement.RowsAffected
sqlInfo.CostSeconds = time.Since(ts).Seconds()
_trace.AppendSQL(sqlInfo)
// getLogWriter 获取日志输出
func getLogWriter(cfg LoggerConfig) *os.File {
if cfg.LogOutput == "file" && cfg.LogFile != "" {
file, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
log.Printf("打开日志文件失败: %v使用标准输出", err)
return os.Stdout
}
return file
}
return os.Stdout
}

View File

@@ -1,12 +1,15 @@
package database
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"gorm.io/gorm"
"reflect"
"time"
"gorm.io/gorm"
)
type NullTime sql.NullTime
@@ -50,6 +53,7 @@ type PaginateList struct {
List any `json:"list"`
}
// Paginate 分页查询
func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
ptr := reflect.ValueOf(model)
if ptr.Kind() != reflect.Ptr {
@@ -86,3 +90,63 @@ func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
List: model,
}, nil
}
// BatchInsert 批量插入
func BatchInsert(ctx context.Context, db *gorm.DB, records interface{}, batchSize int) error {
if batchSize <= 0 {
batchSize = 100
}
return db.WithContext(ctx).CreateInBatches(records, batchSize).Error
}
// ExistsBy 检查记录是否存在
func ExistsBy(ctx context.Context, db *gorm.DB, model interface{}, query interface{}, args ...interface{}) (bool, error) {
var count int64
err := db.WithContext(ctx).Model(model).Where(query, args...).Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
// FindByIDs 根据ID列表查询
func FindByIDs(ctx context.Context, db *gorm.DB, dest interface{}, ids []uint) error {
if len(ids) == 0 {
return nil
}
return db.WithContext(ctx).Where("id IN ?", ids).Find(dest).Error
}
// SoftDeleteByIDs 批量软删除
func SoftDeleteByIDs(ctx context.Context, db *gorm.DB, model interface{}, ids []uint) error {
if len(ids) == 0 {
return nil
}
return db.WithContext(ctx).Where("id IN ?", ids).Delete(model).Error
}
// UpdateFields 更新指定字段
func UpdateFields(ctx context.Context, db *gorm.DB, model interface{}, id uint, fields map[string]interface{}) error {
return db.WithContext(ctx).Model(model).Where("id = ?", id).Updates(fields).Error
}
// Transaction 简化的事务助手
func Transaction(ctx context.Context, db *gorm.DB, fn func(*gorm.DB) error) error {
return db.WithContext(ctx).Transaction(fn)
}
// BaseModel 基础模型
type BaseModel struct {
ID uint `gorm:"primarykey" json:"id"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}
// SimpleModel 简单模型(无软删除)
type SimpleModel struct {
ID uint `gorm:"primarykey" json:"id"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
}

View File

@@ -1,23 +0,0 @@
package database
import "gitea.bvbej.com/bvbej/base-golang/pkg/trace"
type Trace = trace.T
type Option func(*option)
func WithTrace(t Trace) Option {
return func(opt *option) {
if t != nil {
opt.Trace = t.(*trace.Trace)
}
}
}
func newOption() *option {
return &option{}
}
type option struct {
Trace *trace.Trace
}