first commit
This commit is contained in:
		
							
								
								
									
										77
									
								
								pkg/database/mongo.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								pkg/database/mongo.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,77 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"go.mongodb.org/mongo-driver/mongo"
 | 
			
		||||
	"go.mongodb.org/mongo-driver/mongo/options"
 | 
			
		||||
	"go.mongodb.org/mongo-driver/mongo/readpref"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ MongoDB = (*mongoDB)(nil)
 | 
			
		||||
 | 
			
		||||
type MongoDB interface {
 | 
			
		||||
	i()
 | 
			
		||||
	GetDB() *mongo.Database
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MongoDBConfig struct {
 | 
			
		||||
	Addr    string        `yaml:"addr"`
 | 
			
		||||
	User    string        `yaml:"user"`
 | 
			
		||||
	Pass    string        `yaml:"pass"`
 | 
			
		||||
	Name    string        `yaml:"name"`
 | 
			
		||||
	Timeout time.Duration `yaml:"timeout"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mongoDB struct {
 | 
			
		||||
	client  *mongo.Client
 | 
			
		||||
	db      *mongo.Database
 | 
			
		||||
	timeout time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mongoDB) i() {}
 | 
			
		||||
 | 
			
		||||
func NewMongoDB(cfg MongoDBConfig) (MongoDB, error) {
 | 
			
		||||
	timeout := cfg.Timeout * time.Second
 | 
			
		||||
	connectCtx, connectCancelFunc := context.WithTimeout(context.Background(), timeout)
 | 
			
		||||
	defer connectCancelFunc()
 | 
			
		||||
	var auth string
 | 
			
		||||
	if len(cfg.User) > 0 && len(cfg.Pass) > 0 {
 | 
			
		||||
		auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
 | 
			
		||||
	}
 | 
			
		||||
	client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(
 | 
			
		||||
		fmt.Sprintf("mongodb://%s%s", auth, cfg.Addr),
 | 
			
		||||
	))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	pingCtx, pingCancelFunc := context.WithTimeout(context.Background(), timeout)
 | 
			
		||||
	defer pingCancelFunc()
 | 
			
		||||
	err = client.Ping(pingCtx, readpref.Primary())
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &mongoDB{
 | 
			
		||||
		client:  client,
 | 
			
		||||
		db:      client.Database(cfg.Name),
 | 
			
		||||
		timeout: timeout,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mongoDB) GetDB() *mongo.Database {
 | 
			
		||||
	return m.db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *mongoDB) Close() error {
 | 
			
		||||
	disconnectCtx, disconnectCancelFunc := context.WithTimeout(context.Background(), m.timeout)
 | 
			
		||||
	defer disconnectCancelFunc()
 | 
			
		||||
	err := m.client.Disconnect(disconnectCtx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										247
									
								
								pkg/database/mysql.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										247
									
								
								pkg/database/mysql.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,247 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
 | 
			
		||||
	"git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
	"gorm.io/driver/mysql"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"gorm.io/gorm/logger"
 | 
			
		||||
	"gorm.io/gorm/schema"
 | 
			
		||||
	"gorm.io/gorm/utils"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	callBackBeforeName = "core:before"
 | 
			
		||||
	callBackAfterName  = "core:after"
 | 
			
		||||
	startTime          = "_start_time"
 | 
			
		||||
	traceCtxName       = "_trace_ctx_name"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ MysqlRepo = (*mysqlRepo)(nil)
 | 
			
		||||
 | 
			
		||||
type MysqlRepo interface {
 | 
			
		||||
	i()
 | 
			
		||||
	GetRead(options ...Option) *gorm.DB
 | 
			
		||||
	GetWrite(options ...Option) *gorm.DB
 | 
			
		||||
	Close() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MySQLConfig struct {
 | 
			
		||||
	Read struct {
 | 
			
		||||
		Addr string `yaml:"addr"`
 | 
			
		||||
		User string `yaml:"user"`
 | 
			
		||||
		Pass string `yaml:"pass"`
 | 
			
		||||
		Name string `yaml:"name"`
 | 
			
		||||
	} `yaml:"read"`
 | 
			
		||||
	Write struct {
 | 
			
		||||
		Addr string `yaml:"addr"`
 | 
			
		||||
		User string `yaml:"user"`
 | 
			
		||||
		Pass string `yaml:"pass"`
 | 
			
		||||
		Name string `yaml:"name"`
 | 
			
		||||
	} `yaml:"write"`
 | 
			
		||||
	Base struct {
 | 
			
		||||
		MaxOpenConn     int           `yaml:"maxOpenConn"`     //最大连接数
 | 
			
		||||
		MaxIdleConn     int           `yaml:"maxIdleConn"`     //最大空闲连接数
 | 
			
		||||
		ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"` //最大连接超时(分钟)
 | 
			
		||||
	} `yaml:"base"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mysqlRepo struct {
 | 
			
		||||
	read  *gorm.DB
 | 
			
		||||
	write *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
 | 
			
		||||
	dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
 | 
			
		||||
		cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
 | 
			
		||||
		cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &mysqlRepo{
 | 
			
		||||
		read:  dbr,
 | 
			
		||||
		write: dbw,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *mysqlRepo) i() {}
 | 
			
		||||
 | 
			
		||||
func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db := d.read
 | 
			
		||||
	if opt.Trace != nil {
 | 
			
		||||
		db.InstanceSet(traceCtxName, opt.Trace)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
 | 
			
		||||
	opt := newOption()
 | 
			
		||||
	for _, f := range options {
 | 
			
		||||
		f(opt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db := d.write
 | 
			
		||||
	if opt.Trace != nil {
 | 
			
		||||
		db.InstanceSet(traceCtxName, opt.Trace)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *mysqlRepo) Close() (err error) {
 | 
			
		||||
	rdb, err1 := d.read.DB()
 | 
			
		||||
	if err1 != nil {
 | 
			
		||||
		err = errors.Join(err1)
 | 
			
		||||
	}
 | 
			
		||||
	err2 := rdb.Close()
 | 
			
		||||
	if err2 != nil {
 | 
			
		||||
		err = errors.Join(err2)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	wdb, err3 := d.write.DB()
 | 
			
		||||
	if err3 != nil {
 | 
			
		||||
		err = errors.Join(err3)
 | 
			
		||||
	}
 | 
			
		||||
	err4 := wdb.Close()
 | 
			
		||||
	if err4 != nil {
 | 
			
		||||
		err = errors.Join(err4)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
 | 
			
		||||
	dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
 | 
			
		||||
		user,
 | 
			
		||||
		pass,
 | 
			
		||||
		addr,
 | 
			
		||||
		dbName,
 | 
			
		||||
		true,
 | 
			
		||||
		"Local")
 | 
			
		||||
 | 
			
		||||
	// 日志配置
 | 
			
		||||
	newLogger := logger.New(
 | 
			
		||||
		log.New(os.Stdout, "\r\n", log.LstdFlags),
 | 
			
		||||
		logger.Config{
 | 
			
		||||
			SlowThreshold:             time.Second,  // 慢SQL阈值
 | 
			
		||||
			Colorful:                  true,         // 彩色打印
 | 
			
		||||
			IgnoreRecordNotFoundError: true,         // 忽略记录未找到错误
 | 
			
		||||
			LogLevel:                  logger.Error, // 日志级别
 | 
			
		||||
		},
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
 | 
			
		||||
		NamingStrategy: schema.NamingStrategy{
 | 
			
		||||
			SingularTable: true,
 | 
			
		||||
		},
 | 
			
		||||
		Logger: newLogger,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errors.Join(err, fmt.Errorf("[db connection failed] Database name: %s", dbName))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db.Set("gorm:table_options", "CHARSET=utf8mb4")
 | 
			
		||||
 | 
			
		||||
	sqlDB, err := db.DB()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 设置连接池 用于设置最大打开的连接数,默认值为0表示不限制.设置最大的连接数,可以避免并发太高导致连接mysql出现too many connections的错误。
 | 
			
		||||
	sqlDB.SetMaxOpenConns(maxOpenConn)
 | 
			
		||||
 | 
			
		||||
	// 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
 | 
			
		||||
	sqlDB.SetMaxIdleConns(maxIdleConn)
 | 
			
		||||
 | 
			
		||||
	// 设置最大连接超时
 | 
			
		||||
	sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
 | 
			
		||||
 | 
			
		||||
	// 使用插件
 | 
			
		||||
	err = db.Use(&TracePlugin{})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return db, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/***************************************************************/
 | 
			
		||||
 | 
			
		||||
type TracePlugin struct{}
 | 
			
		||||
 | 
			
		||||
func (op *TracePlugin) Name() string {
 | 
			
		||||
	return "TracePlugin"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
 | 
			
		||||
	// 开始前
 | 
			
		||||
	_ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
 | 
			
		||||
	_ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
 | 
			
		||||
 | 
			
		||||
	// 结束后
 | 
			
		||||
	_ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
 | 
			
		||||
	_ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func before(db *gorm.DB) {
 | 
			
		||||
	db.InstanceSet(startTime, time.Now())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func after(db *gorm.DB) {
 | 
			
		||||
	_traceCtx, isExist := db.InstanceGet(traceCtxName)
 | 
			
		||||
	if !isExist {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	_trace, ok := _traceCtx.(trace.T)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_ts, isExist := db.InstanceGet(startTime)
 | 
			
		||||
	if !isExist {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ts, ok := _ts.(time.Time)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
 | 
			
		||||
 | 
			
		||||
	sqlInfo := new(trace.SQL)
 | 
			
		||||
	sqlInfo.Timestamp = time_parse.CSTLayoutString()
 | 
			
		||||
	sqlInfo.SQL = sql
 | 
			
		||||
	sqlInfo.Stack = utils.FileWithLineNum()
 | 
			
		||||
	sqlInfo.Rows = db.Statement.RowsAffected
 | 
			
		||||
	sqlInfo.CostSeconds = time.Since(ts).Seconds()
 | 
			
		||||
	_trace.AppendSQL(sqlInfo)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										88
									
								
								pkg/database/tool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								pkg/database/tool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"reflect"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type NullTime sql.NullTime
 | 
			
		||||
 | 
			
		||||
func (n *NullTime) Scan(value any) error {
 | 
			
		||||
	return (*sql.NullTime)(n).Scan(value)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n NullTime) Value() (driver.Value, error) {
 | 
			
		||||
	if !n.Valid {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	return n.Time, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n NullTime) MarshalJSON() ([]byte, error) {
 | 
			
		||||
	if n.Valid {
 | 
			
		||||
		return json.Marshal(n.Time)
 | 
			
		||||
	}
 | 
			
		||||
	return json.Marshal(nil)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (n *NullTime) UnmarshalJSON(b []byte) error {
 | 
			
		||||
	if string(b) == "null" {
 | 
			
		||||
		n.Valid = false
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	err := json.Unmarshal(b, &n.Time)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		n.Valid = true
 | 
			
		||||
	}
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*-----------------------------------------------------------*/
 | 
			
		||||
 | 
			
		||||
type PaginateList struct {
 | 
			
		||||
	Page  int64 `json:"page"`
 | 
			
		||||
	Size  int64 `json:"size"`
 | 
			
		||||
	Total int64 `json:"total"`
 | 
			
		||||
	List  any   `json:"list"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
 | 
			
		||||
	ptr := reflect.ValueOf(model)
 | 
			
		||||
	if ptr.Kind() != reflect.Ptr {
 | 
			
		||||
		return nil, fmt.Errorf("model must be pointer")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var total int64
 | 
			
		||||
	err := db.Model(model).Count(&total).Error
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &PaginateList{
 | 
			
		||||
			Page:  page,
 | 
			
		||||
			Size:  size,
 | 
			
		||||
			Total: total,
 | 
			
		||||
			List:  make([]any, 0),
 | 
			
		||||
		}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	offset := size * (page - 1)
 | 
			
		||||
	err = db.Limit(int(size)).Offset(int(offset)).Find(model).Error
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return &PaginateList{
 | 
			
		||||
			Page:  page,
 | 
			
		||||
			Size:  size,
 | 
			
		||||
			Total: total,
 | 
			
		||||
			List:  make([]any, 0),
 | 
			
		||||
		}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &PaginateList{
 | 
			
		||||
		Page:  page,
 | 
			
		||||
		Size:  size,
 | 
			
		||||
		Total: total,
 | 
			
		||||
		List:  model,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										23
									
								
								pkg/database/trace.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								pkg/database/trace.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,23 @@
 | 
			
		||||
package database
 | 
			
		||||
 | 
			
		||||
import "git.bvbej.com/bvbej/base-golang/pkg/trace"
 | 
			
		||||
 | 
			
		||||
type Trace = trace.T
 | 
			
		||||
 | 
			
		||||
type Option func(*option)
 | 
			
		||||
 | 
			
		||||
func WithTrace(t Trace) Option {
 | 
			
		||||
	return func(opt *option) {
 | 
			
		||||
		if t != nil {
 | 
			
		||||
			opt.Trace = t.(*trace.Trace)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newOption() *option {
 | 
			
		||||
	return &option{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type option struct {
 | 
			
		||||
	Trace *trace.Trace
 | 
			
		||||
}
 | 
			
		||||
		Reference in New Issue
	
	Block a user