[🚀] mysql
This commit is contained in:
@@ -3,75 +3,262 @@ package database
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/mongo/readpref"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ MongoDB = (*mongoDB)(nil)
|
||||
|
||||
// MongoDB MongoDB接口
|
||||
type MongoDB interface {
|
||||
i()
|
||||
// GetDB 获取数据库实例
|
||||
GetDB() *mongo.Database
|
||||
Close() error
|
||||
|
||||
// GetClient 获取客户端实例
|
||||
GetClient() *mongo.Client
|
||||
|
||||
// GetCollection 获取集合
|
||||
GetCollection(name string) *mongo.Collection
|
||||
|
||||
// Ping 检查连接
|
||||
Ping(ctx context.Context) error
|
||||
|
||||
// Close 关闭连接
|
||||
Close(ctx context.Context) error
|
||||
|
||||
// WithContext 创建带超时的上下文
|
||||
WithContext(ctx context.Context) (context.Context, context.CancelFunc)
|
||||
}
|
||||
|
||||
// MongoDBConfig MongoDB配置
|
||||
type MongoDBConfig struct {
|
||||
Addr string `yaml:"addr"`
|
||||
User string `yaml:"user"`
|
||||
Pass string `yaml:"pass"`
|
||||
Name string `yaml:"name"`
|
||||
Timeout time.Duration `yaml:"timeout"`
|
||||
// 地址,支持多个地址: "localhost:27017,localhost:27018"
|
||||
Addr string `yaml:"addr" json:"addr"`
|
||||
|
||||
// 用户名
|
||||
User string `yaml:"user" json:"user"`
|
||||
|
||||
// 密码
|
||||
Pass string `yaml:"pass" json:"pass"`
|
||||
|
||||
// 数据库名称
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// 连接超时(秒)
|
||||
Timeout time.Duration `yaml:"timeout" json:"timeout"`
|
||||
|
||||
// 最大连接池大小
|
||||
MaxPoolSize uint64 `yaml:"max_pool_size" json:"max_pool_size"`
|
||||
|
||||
// 最小连接池大小
|
||||
MinPoolSize uint64 `yaml:"min_pool_size" json:"min_pool_size"`
|
||||
|
||||
// 最大连接空闲时间(秒)
|
||||
MaxConnIdleTime time.Duration `yaml:"max_conn_idle_time" json:"max_conn_idle_time"`
|
||||
|
||||
// 是否使用副本集
|
||||
ReplicaSet string `yaml:"replica_set" json:"replica_set"`
|
||||
|
||||
// 是否使用TLS
|
||||
UseTLS bool `yaml:"use_tls" json:"use_tls"`
|
||||
|
||||
// 认证数据库
|
||||
AuthSource string `yaml:"auth_source" json:"auth_source"`
|
||||
}
|
||||
|
||||
// DefaultMongoDBConfig 默认配置
|
||||
func DefaultMongoDBConfig() *MongoDBConfig {
|
||||
return &MongoDBConfig{
|
||||
Timeout: 10, // 10秒
|
||||
MaxPoolSize: 100,
|
||||
MinPoolSize: 10,
|
||||
MaxConnIdleTime: 60, // 60秒
|
||||
AuthSource: "admin",
|
||||
}
|
||||
}
|
||||
|
||||
// mongoDB MongoDB实现
|
||||
type mongoDB struct {
|
||||
client *mongo.Client
|
||||
db *mongo.Database
|
||||
config *MongoDBConfig
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func (m *mongoDB) i() {}
|
||||
// NewMongoDB 创建MongoDB实例
|
||||
func NewMongoDB(cfg *MongoDBConfig) (MongoDB, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("配置不能为空")
|
||||
}
|
||||
|
||||
// 验证必填参数
|
||||
if cfg.Addr == "" {
|
||||
return nil, fmt.Errorf("地址不能为空")
|
||||
}
|
||||
if cfg.Name == "" {
|
||||
return nil, fmt.Errorf("数据库名称不能为空")
|
||||
}
|
||||
|
||||
func NewMongoDB(cfg MongoDBConfig) (MongoDB, error) {
|
||||
timeout := cfg.Timeout * time.Second
|
||||
connectCtx, connectCancelFunc := context.WithTimeout(context.Background(), timeout)
|
||||
defer connectCancelFunc()
|
||||
var auth string
|
||||
if len(cfg.User) > 0 && len(cfg.Pass) > 0 {
|
||||
auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
|
||||
}
|
||||
client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(
|
||||
fmt.Sprintf("mongodb://%s%s", auth, cfg.Addr),
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if timeout == 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
pingCtx, pingCancelFunc := context.WithTimeout(context.Background(), timeout)
|
||||
defer pingCancelFunc()
|
||||
err = client.Ping(pingCtx, readpref.Primary())
|
||||
// 构建连接选项
|
||||
clientOpts := options.Client()
|
||||
|
||||
// 构建URI
|
||||
uri := buildMongoURI(cfg)
|
||||
clientOpts.ApplyURI(uri)
|
||||
|
||||
// 设置连接池
|
||||
if cfg.MaxPoolSize > 0 {
|
||||
clientOpts.SetMaxPoolSize(cfg.MaxPoolSize)
|
||||
}
|
||||
if cfg.MinPoolSize > 0 {
|
||||
clientOpts.SetMinPoolSize(cfg.MinPoolSize)
|
||||
}
|
||||
if cfg.MaxConnIdleTime > 0 {
|
||||
clientOpts.SetMaxConnIdleTime(cfg.MaxConnIdleTime * time.Second)
|
||||
}
|
||||
|
||||
// 设置超时
|
||||
clientOpts.SetConnectTimeout(timeout)
|
||||
clientOpts.SetServerSelectionTimeout(timeout)
|
||||
|
||||
// 连接MongoDB
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
client, err := mongo.Connect(ctx, clientOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("连接MongoDB失败: %w", err)
|
||||
}
|
||||
|
||||
// Ping测试连接
|
||||
pingCtx, pingCancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer pingCancel()
|
||||
|
||||
if err = client.Ping(pingCtx, readpref.Primary()); err != nil {
|
||||
_ = client.Disconnect(context.Background())
|
||||
return nil, fmt.Errorf("Ping MongoDB失败: %w", err)
|
||||
}
|
||||
|
||||
return &mongoDB{
|
||||
client: client,
|
||||
db: client.Database(cfg.Name),
|
||||
config: cfg,
|
||||
timeout: timeout,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetDB 获取数据库实例
|
||||
func (m *mongoDB) GetDB() *mongo.Database {
|
||||
return m.db
|
||||
}
|
||||
|
||||
func (m *mongoDB) Close() error {
|
||||
disconnectCtx, disconnectCancelFunc := context.WithTimeout(context.Background(), m.timeout)
|
||||
defer disconnectCancelFunc()
|
||||
err := m.client.Disconnect(disconnectCtx)
|
||||
// GetClient 获取客户端实例
|
||||
func (m *mongoDB) GetClient() *mongo.Client {
|
||||
return m.client
|
||||
}
|
||||
|
||||
// GetCollection 获取集合
|
||||
func (m *mongoDB) GetCollection(name string) *mongo.Collection {
|
||||
return m.db.Collection(name)
|
||||
}
|
||||
|
||||
// Ping 检查连接
|
||||
func (m *mongoDB) Ping(ctx context.Context) error {
|
||||
pingCtx, cancel := m.WithContext(ctx)
|
||||
defer cancel()
|
||||
|
||||
return m.client.Ping(pingCtx, readpref.Primary())
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (m *mongoDB) Close(ctx context.Context) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
disconnectCtx, cancel := context.WithTimeout(ctx, m.timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := m.client.Disconnect(disconnectCtx); err != nil {
|
||||
return fmt.Errorf("断开MongoDB连接失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithContext 创建带超时的上下文
|
||||
func (m *mongoDB) WithContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithTimeout(ctx, m.timeout)
|
||||
}
|
||||
|
||||
// buildMongoURI 构建MongoDB URI
|
||||
func buildMongoURI(cfg *MongoDBConfig) string {
|
||||
var auth string
|
||||
if cfg.User != "" && cfg.Pass != "" {
|
||||
auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
|
||||
}
|
||||
|
||||
uri := fmt.Sprintf("mongodb://%s%s/%s", auth, cfg.Addr, cfg.Name)
|
||||
|
||||
// 添加查询参数
|
||||
params := []string{}
|
||||
|
||||
if cfg.AuthSource != "" {
|
||||
params = append(params, fmt.Sprintf("authSource=%s", cfg.AuthSource))
|
||||
}
|
||||
|
||||
if cfg.ReplicaSet != "" {
|
||||
params = append(params, fmt.Sprintf("replicaSet=%s", cfg.ReplicaSet))
|
||||
}
|
||||
|
||||
if cfg.UseTLS {
|
||||
params = append(params, "tls=true")
|
||||
}
|
||||
|
||||
if len(params) > 0 {
|
||||
uri += "?"
|
||||
for i, param := range params {
|
||||
if i > 0 {
|
||||
uri += "&"
|
||||
}
|
||||
uri += param
|
||||
}
|
||||
}
|
||||
|
||||
return uri
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
|
||||
// EnsureIndexes 确保索引存在
|
||||
func EnsureIndexes(ctx context.Context, collection *mongo.Collection, indexes []mongo.IndexModel) error {
|
||||
if len(indexes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := collection.Indexes().CreateMany(ctx, indexes)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("创建索引失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DropIndexes 删除索引
|
||||
func DropIndexes(ctx context.Context, collection *mongo.Collection, indexNames []string) error {
|
||||
for _, name := range indexNames {
|
||||
if _, err := collection.Indexes().DropOne(ctx, name); err != nil {
|
||||
return fmt.Errorf("删除索引 %s 失败: %w", name, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,247 +1,499 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.bvbej.com/bvbej/base-golang/pkg/time_parse"
|
||||
"gitea.bvbej.com/bvbej/base-golang/pkg/trace"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"gorm.io/gorm/utils"
|
||||
)
|
||||
|
||||
const (
|
||||
callBackBeforeName = "core:before"
|
||||
callBackAfterName = "core:after"
|
||||
startTime = "_start_time"
|
||||
traceCtxName = "_trace_ctx_name"
|
||||
)
|
||||
|
||||
var _ MysqlRepo = (*mysqlRepo)(nil)
|
||||
|
||||
// MysqlRepo MySQL接口
|
||||
type MysqlRepo interface {
|
||||
i()
|
||||
GetRead(options ...Option) *gorm.DB
|
||||
GetWrite(options ...Option) *gorm.DB
|
||||
// DB 获取数据库实例(自动选择读写)
|
||||
DB(ctx context.Context, forceWrite ...bool) *gorm.DB
|
||||
|
||||
// GetRead 获取读库
|
||||
GetRead() *gorm.DB
|
||||
|
||||
// GetWrite 获取写库
|
||||
GetWrite() *gorm.DB
|
||||
|
||||
// Transaction 执行事务
|
||||
Transaction(ctx context.Context, fc func(*gorm.DB) error) error
|
||||
|
||||
// Close 关闭连接
|
||||
Close() error
|
||||
|
||||
// Ping 检查连接
|
||||
Ping(ctx context.Context) error
|
||||
|
||||
// Stats 获取连接池统计信息
|
||||
Stats() (*DBStats, error)
|
||||
|
||||
// HealthCheck 健康检查
|
||||
HealthCheck(ctx context.Context) error
|
||||
}
|
||||
|
||||
// DBStats 数据库统计信息
|
||||
type DBStats struct {
|
||||
Read sql.DBStats `json:"read"`
|
||||
Write sql.DBStats `json:"write"`
|
||||
}
|
||||
|
||||
// MySQLConfig MySQL配置
|
||||
type MySQLConfig struct {
|
||||
Read struct {
|
||||
Addr string `yaml:"addr"`
|
||||
User string `yaml:"user"`
|
||||
Pass string `yaml:"pass"`
|
||||
Name string `yaml:"name"`
|
||||
} `yaml:"read"`
|
||||
Write struct {
|
||||
Addr string `yaml:"addr"`
|
||||
User string `yaml:"user"`
|
||||
Pass string `yaml:"pass"`
|
||||
Name string `yaml:"name"`
|
||||
} `yaml:"write"`
|
||||
Base struct {
|
||||
MaxOpenConn int `yaml:"maxOpenConn"` //最大连接数
|
||||
MaxIdleConn int `yaml:"maxIdleConn"` //最大空闲连接数
|
||||
ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"` //最大连接超时(分钟)
|
||||
} `yaml:"base"`
|
||||
Read DBConnConfig `yaml:"read" json:"read"`
|
||||
Write DBConnConfig `yaml:"write" json:"write"`
|
||||
Base BaseConfig `yaml:"base" json:"base"`
|
||||
Logger LoggerConfig `yaml:"logger" json:"logger"`
|
||||
}
|
||||
|
||||
// DBConnConfig 数据库连接配置
|
||||
type DBConnConfig struct {
|
||||
Addr string `yaml:"addr" json:"addr"`
|
||||
User string `yaml:"user" json:"user"`
|
||||
Pass string `yaml:"pass" json:"pass"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Charset string `yaml:"charset" json:"charset"` // 默认 utf8mb4
|
||||
Collation string `yaml:"collation" json:"collation"` // 默认 utf8mb4_unicode_ci
|
||||
Loc string `yaml:"loc" json:"loc"` // 默认 Local
|
||||
ParseTime bool `yaml:"parse_time" json:"parse_time"` // 默认 true
|
||||
Timeout time.Duration `yaml:"timeout" json:"timeout"` // 连接超时(秒)
|
||||
ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"` // 读超时(秒)
|
||||
WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` // 写超时(秒)
|
||||
}
|
||||
|
||||
// BaseConfig 基础配置
|
||||
type BaseConfig struct {
|
||||
MaxOpenConn int `yaml:"max_open_conn" json:"max_open_conn"`
|
||||
MaxIdleConn int `yaml:"max_idle_conn" json:"max_idle_conn"`
|
||||
ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" json:"conn_max_lifetime"` // 秒
|
||||
ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time" json:"conn_max_idle_time"` // 秒
|
||||
}
|
||||
|
||||
// LoggerConfig 日志配置
|
||||
type LoggerConfig struct {
|
||||
SlowThreshold time.Duration `yaml:"slow_threshold" json:"slow_threshold"`
|
||||
Colorful bool `yaml:"colorful" json:"colorful"`
|
||||
IgnoreRecordNotFoundError bool `yaml:"ignore_record_not_found_error" json:"ignore_record_not_found_error"`
|
||||
LogLevel string `yaml:"log_level" json:"log_level"`
|
||||
LogOutput string `yaml:"log_output" json:"log_output"` // stdout, file
|
||||
LogFile string `yaml:"log_file" json:"log_file"` // 日志文件路径
|
||||
}
|
||||
|
||||
// DefaultMySQLConfig 默认配置
|
||||
func DefaultMySQLConfig() *MySQLConfig {
|
||||
return &MySQLConfig{
|
||||
Base: BaseConfig{
|
||||
MaxOpenConn: 100,
|
||||
MaxIdleConn: 10,
|
||||
ConnMaxLifetime: 3600, // 1小时
|
||||
ConnMaxIdleTime: 600, // 10分钟
|
||||
},
|
||||
Logger: LoggerConfig{
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
Colorful: true,
|
||||
IgnoreRecordNotFoundError: true,
|
||||
LogLevel: "warn",
|
||||
LogOutput: "stdout",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// mysqlRepo MySQL实现
|
||||
type mysqlRepo struct {
|
||||
read *gorm.DB
|
||||
write *gorm.DB
|
||||
read *gorm.DB
|
||||
write *gorm.DB
|
||||
config *MySQLConfig
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
|
||||
dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
|
||||
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
|
||||
if err != nil {
|
||||
// NewMysql 创建MySQL实例
|
||||
func NewMysql(cfg *MySQLConfig) (MysqlRepo, error) {
|
||||
if cfg == nil {
|
||||
cfg = DefaultMySQLConfig()
|
||||
}
|
||||
|
||||
// 合并默认配置
|
||||
mergeDefaultConfig(cfg)
|
||||
|
||||
// 验证配置
|
||||
if err := validateMySQLConfig(cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
|
||||
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
|
||||
// 连接读库
|
||||
dbr, err := dbConnect("read", cfg.Read, cfg.Base, cfg.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("连接读库失败: %w", err)
|
||||
}
|
||||
|
||||
return &mysqlRepo{
|
||||
read: dbr,
|
||||
write: dbw,
|
||||
// 连接写库
|
||||
dbw, err := dbConnect("write", cfg.Write, cfg.Base, cfg.Logger)
|
||||
if err != nil {
|
||||
_ = closeDB(dbr)
|
||||
return nil, fmt.Errorf("连接写库失败: %w", err)
|
||||
}
|
||||
|
||||
repo := &mysqlRepo{
|
||||
read: dbr,
|
||||
write: dbw,
|
||||
config: cfg,
|
||||
closed: false,
|
||||
}
|
||||
|
||||
return repo, nil
|
||||
}
|
||||
|
||||
// DB 获取数据库实例(自动选择读写)
|
||||
func (m *mysqlRepo) DB(ctx context.Context, forceWrite ...bool) *gorm.DB {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
useWrite := len(forceWrite) > 0 && forceWrite[0]
|
||||
|
||||
var db *gorm.DB
|
||||
if useWrite {
|
||||
db = m.write
|
||||
} else {
|
||||
db = m.read
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
db = db.WithContext(ctx)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// GetRead 获取读库
|
||||
func (m *mysqlRepo) GetRead() *gorm.DB {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.read
|
||||
}
|
||||
|
||||
// GetWrite 获取写库
|
||||
func (m *mysqlRepo) GetWrite() *gorm.DB {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.write
|
||||
}
|
||||
|
||||
// Transaction 执行事务
|
||||
func (m *mysqlRepo) Transaction(ctx context.Context, fc func(*gorm.DB) error) error {
|
||||
m.mu.RLock()
|
||||
if m.closed {
|
||||
m.mu.RUnlock()
|
||||
return fmt.Errorf("数据库连接已关闭")
|
||||
}
|
||||
db := m.write
|
||||
m.mu.RUnlock()
|
||||
|
||||
if ctx != nil {
|
||||
db = db.WithContext(ctx)
|
||||
}
|
||||
|
||||
return db.Transaction(fc)
|
||||
}
|
||||
|
||||
// Close 关闭连接
|
||||
func (m *mysqlRepo) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs []error
|
||||
if err := closeDB(m.read); err != nil {
|
||||
errs = append(errs, fmt.Errorf("关闭读库失败: %w", err))
|
||||
}
|
||||
if err := closeDB(m.write); err != nil {
|
||||
errs = append(errs, fmt.Errorf("关闭写库失败: %w", err))
|
||||
}
|
||||
|
||||
m.closed = true
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errors.Join(errs...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ping 检查连接
|
||||
func (m *mysqlRepo) Ping(ctx context.Context) error {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return fmt.Errorf("数据库连接已关闭")
|
||||
}
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Ping读库
|
||||
readDB, err := m.read.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取读库失败: %w", err)
|
||||
}
|
||||
if err := readDB.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("Ping读库失败: %w", err)
|
||||
}
|
||||
|
||||
// Ping写库
|
||||
writeDB, err := m.write.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取写库失败: %w", err)
|
||||
}
|
||||
if err := writeDB.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("Ping写库失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stats 获取连接池统计信息
|
||||
func (m *mysqlRepo) Stats() (*DBStats, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
if m.closed {
|
||||
return nil, fmt.Errorf("数据库连接已关闭")
|
||||
}
|
||||
|
||||
readDB, err := m.read.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取读库失败: %w", err)
|
||||
}
|
||||
|
||||
writeDB, err := m.write.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取写库失败: %w", err)
|
||||
}
|
||||
|
||||
return &DBStats{
|
||||
Read: readDB.Stats(),
|
||||
Write: writeDB.Stats(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *mysqlRepo) i() {}
|
||||
|
||||
func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
|
||||
opt := newOption()
|
||||
for _, f := range options {
|
||||
f(opt)
|
||||
// HealthCheck 健康检查
|
||||
func (m *mysqlRepo) HealthCheck(ctx context.Context) error {
|
||||
if err := m.Ping(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db := d.read
|
||||
if opt.Trace != nil {
|
||||
db.InstanceSet(traceCtxName, opt.Trace)
|
||||
stats, err := m.Stats()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db
|
||||
// 检查连接池状态
|
||||
if stats.Read.OpenConnections >= m.config.Base.MaxOpenConn {
|
||||
return fmt.Errorf("读库连接池已满")
|
||||
}
|
||||
if stats.Write.OpenConnections >= m.config.Base.MaxOpenConn {
|
||||
return fmt.Errorf("写库连接池已满")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
|
||||
opt := newOption()
|
||||
for _, f := range options {
|
||||
f(opt)
|
||||
// mergeDefaultConfig 合并默认配置
|
||||
func mergeDefaultConfig(cfg *MySQLConfig) {
|
||||
// 读库默认配置
|
||||
if cfg.Read.Charset == "" {
|
||||
cfg.Read.Charset = "utf8mb4"
|
||||
}
|
||||
|
||||
db := d.write
|
||||
if opt.Trace != nil {
|
||||
db.InstanceSet(traceCtxName, opt.Trace)
|
||||
if cfg.Read.Collation == "" {
|
||||
cfg.Read.Collation = "utf8mb4_unicode_ci"
|
||||
}
|
||||
if cfg.Read.Loc == "" {
|
||||
cfg.Read.Loc = "Local"
|
||||
}
|
||||
if cfg.Read.Timeout == 0 {
|
||||
cfg.Read.Timeout = 10 // 10秒
|
||||
}
|
||||
if cfg.Read.ReadTimeout == 0 {
|
||||
cfg.Read.ReadTimeout = 30 // 30秒
|
||||
}
|
||||
if cfg.Read.WriteTimeout == 0 {
|
||||
cfg.Read.WriteTimeout = 30 // 30秒
|
||||
}
|
||||
cfg.Read.ParseTime = true
|
||||
|
||||
return db
|
||||
// 写库默认配置
|
||||
if cfg.Write.Charset == "" {
|
||||
cfg.Write.Charset = "utf8mb4"
|
||||
}
|
||||
if cfg.Write.Collation == "" {
|
||||
cfg.Write.Collation = "utf8mb4_unicode_ci"
|
||||
}
|
||||
if cfg.Write.Loc == "" {
|
||||
cfg.Write.Loc = "Local"
|
||||
}
|
||||
if cfg.Write.Timeout == 0 {
|
||||
cfg.Write.Timeout = 10
|
||||
}
|
||||
if cfg.Write.ReadTimeout == 0 {
|
||||
cfg.Write.ReadTimeout = 30
|
||||
}
|
||||
if cfg.Write.WriteTimeout == 0 {
|
||||
cfg.Write.WriteTimeout = 30
|
||||
}
|
||||
cfg.Write.ParseTime = true
|
||||
}
|
||||
|
||||
func (d *mysqlRepo) Close() (err error) {
|
||||
rdb, err1 := d.read.DB()
|
||||
if err1 != nil {
|
||||
err = errors.Join(err1)
|
||||
// validateMySQLConfig 验证配置
|
||||
func validateMySQLConfig(cfg *MySQLConfig) error {
|
||||
if cfg.Read.Addr == "" {
|
||||
return fmt.Errorf("读库地址不能为空")
|
||||
}
|
||||
err2 := rdb.Close()
|
||||
if err2 != nil {
|
||||
err = errors.Join(err2)
|
||||
if cfg.Read.User == "" {
|
||||
return fmt.Errorf("读库用户名不能为空")
|
||||
}
|
||||
if cfg.Read.Name == "" {
|
||||
return fmt.Errorf("读库数据库名不能为空")
|
||||
}
|
||||
|
||||
wdb, err3 := d.write.DB()
|
||||
if err3 != nil {
|
||||
err = errors.Join(err3)
|
||||
if cfg.Write.Addr == "" {
|
||||
return fmt.Errorf("写库地址不能为空")
|
||||
}
|
||||
err4 := wdb.Close()
|
||||
if err4 != nil {
|
||||
err = errors.Join(err4)
|
||||
if cfg.Write.User == "" {
|
||||
return fmt.Errorf("写库用户名不能为空")
|
||||
}
|
||||
if cfg.Write.Name == "" {
|
||||
return fmt.Errorf("写库数据库名不能为空")
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
|
||||
user,
|
||||
pass,
|
||||
addr,
|
||||
dbName,
|
||||
true,
|
||||
"Local")
|
||||
// buildDSN 构建DSN
|
||||
func buildDSN(conn DBConnConfig) string {
|
||||
return fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&collation=%s&parseTime=%t&loc=%s&timeout=%ds&readTimeout=%ds&writeTimeout=%ds",
|
||||
conn.User,
|
||||
conn.Pass,
|
||||
conn.Addr,
|
||||
conn.Name,
|
||||
conn.Charset,
|
||||
conn.Collation,
|
||||
conn.ParseTime,
|
||||
conn.Loc,
|
||||
int(conn.Timeout),
|
||||
int(conn.ReadTimeout),
|
||||
int(conn.WriteTimeout),
|
||||
)
|
||||
}
|
||||
|
||||
// dbConnect 连接数据库
|
||||
func dbConnect(name string, conn DBConnConfig, base BaseConfig, logCfg LoggerConfig) (*gorm.DB, error) {
|
||||
// 构建DSN
|
||||
dsn := buildDSN(conn)
|
||||
|
||||
// 配置日志
|
||||
logLevel := parseLogLevel(logCfg.LogLevel)
|
||||
logWriter := getLogWriter(logCfg)
|
||||
|
||||
// 日志配置
|
||||
newLogger := logger.New(
|
||||
log.New(os.Stdout, "\r\n", log.LstdFlags),
|
||||
log.New(logWriter, "\r\n", log.LstdFlags),
|
||||
logger.Config{
|
||||
SlowThreshold: time.Second, // 慢SQL阈值
|
||||
Colorful: true, // 彩色打印
|
||||
IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
|
||||
LogLevel: logger.Error, // 日志级别
|
||||
SlowThreshold: logCfg.SlowThreshold,
|
||||
Colorful: logCfg.Colorful,
|
||||
IgnoreRecordNotFoundError: logCfg.IgnoreRecordNotFoundError,
|
||||
LogLevel: logLevel,
|
||||
},
|
||||
)
|
||||
|
||||
// 打开数据库连接
|
||||
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
SingularTable: true,
|
||||
},
|
||||
Logger: newLogger,
|
||||
Logger: newLogger,
|
||||
SkipDefaultTransaction: true,
|
||||
PrepareStmt: true,
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Join(err, fmt.Errorf("[db connection failed] Database name: %s", dbName))
|
||||
return nil, fmt.Errorf("打开数据库连接失败 [%s]: %w", name, err)
|
||||
}
|
||||
|
||||
db.Set("gorm:table_options", "CHARSET=utf8mb4")
|
||||
|
||||
// 获取底层sqlDB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("获取sqlDB失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置连接池 用于设置最大打开的连接数,默认值为0表示不限制.设置最大的连接数,可以避免并发太高导致连接mysql出现too many connections的错误。
|
||||
sqlDB.SetMaxOpenConns(maxOpenConn)
|
||||
|
||||
// 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
|
||||
sqlDB.SetMaxIdleConns(maxIdleConn)
|
||||
|
||||
// 设置最大连接超时
|
||||
sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
|
||||
|
||||
// 使用插件
|
||||
err = db.Use(&TracePlugin{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 设置连接池参数
|
||||
if base.MaxOpenConn > 0 {
|
||||
sqlDB.SetMaxOpenConns(base.MaxOpenConn)
|
||||
}
|
||||
if base.MaxIdleConn > 0 {
|
||||
sqlDB.SetMaxIdleConns(base.MaxIdleConn)
|
||||
}
|
||||
if base.ConnMaxLifetime > 0 {
|
||||
sqlDB.SetConnMaxLifetime(base.ConnMaxLifetime * time.Second)
|
||||
}
|
||||
if base.ConnMaxIdleTime > 0 {
|
||||
sqlDB.SetConnMaxIdleTime(base.ConnMaxIdleTime * time.Second)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
/***************************************************************/
|
||||
|
||||
type TracePlugin struct{}
|
||||
|
||||
func (op *TracePlugin) Name() string {
|
||||
return "TracePlugin"
|
||||
// closeDB 关闭数据库
|
||||
func closeDB(db *gorm.DB) error {
|
||||
if db == nil {
|
||||
return nil
|
||||
}
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
|
||||
// 开始前
|
||||
_ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
|
||||
_ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
|
||||
_ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
|
||||
_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
|
||||
_ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
|
||||
_ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
|
||||
|
||||
// 结束后
|
||||
_ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
|
||||
_ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
|
||||
_ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
|
||||
_ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
|
||||
_ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
|
||||
_ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
|
||||
return
|
||||
// parseLogLevel 解析日志级别
|
||||
func parseLogLevel(level string) logger.LogLevel {
|
||||
switch level {
|
||||
case "silent":
|
||||
return logger.Silent
|
||||
case "error":
|
||||
return logger.Error
|
||||
case "warn":
|
||||
return logger.Warn
|
||||
case "info":
|
||||
return logger.Info
|
||||
default:
|
||||
return logger.Warn
|
||||
}
|
||||
}
|
||||
|
||||
func before(db *gorm.DB) {
|
||||
db.InstanceSet(startTime, time.Now())
|
||||
}
|
||||
|
||||
func after(db *gorm.DB) {
|
||||
_traceCtx, isExist := db.InstanceGet(traceCtxName)
|
||||
if !isExist {
|
||||
return
|
||||
}
|
||||
_trace, ok := _traceCtx.(trace.T)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
_ts, isExist := db.InstanceGet(startTime)
|
||||
if !isExist {
|
||||
return
|
||||
}
|
||||
|
||||
ts, ok := _ts.(time.Time)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
|
||||
|
||||
sqlInfo := new(trace.SQL)
|
||||
sqlInfo.Timestamp = time_parse.CSTLayoutString()
|
||||
sqlInfo.SQL = sql
|
||||
sqlInfo.Stack = utils.FileWithLineNum()
|
||||
sqlInfo.Rows = db.Statement.RowsAffected
|
||||
sqlInfo.CostSeconds = time.Since(ts).Seconds()
|
||||
_trace.AppendSQL(sqlInfo)
|
||||
// getLogWriter 获取日志输出
|
||||
func getLogWriter(cfg LoggerConfig) *os.File {
|
||||
if cfg.LogOutput == "file" && cfg.LogFile != "" {
|
||||
file, err := os.OpenFile(cfg.LogFile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
log.Printf("打开日志文件失败: %v,使用标准输出", err)
|
||||
return os.Stdout
|
||||
}
|
||||
return file
|
||||
}
|
||||
return os.Stdout
|
||||
}
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"gorm.io/gorm"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type NullTime sql.NullTime
|
||||
@@ -50,6 +53,7 @@ type PaginateList struct {
|
||||
List any `json:"list"`
|
||||
}
|
||||
|
||||
// Paginate 分页查询
|
||||
func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
|
||||
ptr := reflect.ValueOf(model)
|
||||
if ptr.Kind() != reflect.Ptr {
|
||||
@@ -86,3 +90,63 @@ func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
|
||||
List: model,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BatchInsert 批量插入
|
||||
func BatchInsert(ctx context.Context, db *gorm.DB, records interface{}, batchSize int) error {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 100
|
||||
}
|
||||
|
||||
return db.WithContext(ctx).CreateInBatches(records, batchSize).Error
|
||||
}
|
||||
|
||||
// ExistsBy 检查记录是否存在
|
||||
func ExistsBy(ctx context.Context, db *gorm.DB, model interface{}, query interface{}, args ...interface{}) (bool, error) {
|
||||
var count int64
|
||||
err := db.WithContext(ctx).Model(model).Where(query, args...).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// FindByIDs 根据ID列表查询
|
||||
func FindByIDs(ctx context.Context, db *gorm.DB, dest interface{}, ids []uint) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
return db.WithContext(ctx).Where("id IN ?", ids).Find(dest).Error
|
||||
}
|
||||
|
||||
// SoftDeleteByIDs 批量软删除
|
||||
func SoftDeleteByIDs(ctx context.Context, db *gorm.DB, model interface{}, ids []uint) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
return db.WithContext(ctx).Where("id IN ?", ids).Delete(model).Error
|
||||
}
|
||||
|
||||
// UpdateFields 更新指定字段
|
||||
func UpdateFields(ctx context.Context, db *gorm.DB, model interface{}, id uint, fields map[string]interface{}) error {
|
||||
return db.WithContext(ctx).Model(model).Where("id = ?", id).Updates(fields).Error
|
||||
}
|
||||
|
||||
// Transaction 简化的事务助手
|
||||
func Transaction(ctx context.Context, db *gorm.DB, fn func(*gorm.DB) error) error {
|
||||
return db.WithContext(ctx).Transaction(fn)
|
||||
}
|
||||
|
||||
// BaseModel 基础模型
|
||||
type BaseModel struct {
|
||||
ID uint `gorm:"primarykey" json:"id"`
|
||||
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
|
||||
}
|
||||
|
||||
// SimpleModel 简单模型(无软删除)
|
||||
type SimpleModel struct {
|
||||
ID uint `gorm:"primarykey" json:"id"`
|
||||
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
|
||||
}
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
package database
|
||||
|
||||
import "gitea.bvbej.com/bvbej/base-golang/pkg/trace"
|
||||
|
||||
type Trace = trace.T
|
||||
|
||||
type Option func(*option)
|
||||
|
||||
func WithTrace(t Trace) Option {
|
||||
return func(opt *option) {
|
||||
if t != nil {
|
||||
opt.Trace = t.(*trace.Trace)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newOption() *option {
|
||||
return &option{}
|
||||
}
|
||||
|
||||
type option struct {
|
||||
Trace *trace.Trace
|
||||
}
|
||||
Reference in New Issue
Block a user