first commit

This commit is contained in:
2024-07-23 10:23:43 +08:00
commit 7b4c2521a3
126 changed files with 15931 additions and 0 deletions

77
pkg/database/mongo.go Normal file
View File

@ -0,0 +1,77 @@
package database
import (
"context"
"fmt"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"time"
)
var _ MongoDB = (*mongoDB)(nil)
type MongoDB interface {
i()
GetDB() *mongo.Database
Close() error
}
type MongoDBConfig struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
Timeout time.Duration `yaml:"timeout"`
}
type mongoDB struct {
client *mongo.Client
db *mongo.Database
timeout time.Duration
}
func (m *mongoDB) i() {}
func NewMongoDB(cfg MongoDBConfig) (MongoDB, error) {
timeout := cfg.Timeout * time.Second
connectCtx, connectCancelFunc := context.WithTimeout(context.Background(), timeout)
defer connectCancelFunc()
var auth string
if len(cfg.User) > 0 && len(cfg.Pass) > 0 {
auth = fmt.Sprintf("%s:%s@", cfg.User, cfg.Pass)
}
client, err := mongo.Connect(connectCtx, options.Client().ApplyURI(
fmt.Sprintf("mongodb://%s%s", auth, cfg.Addr),
))
if err != nil {
return nil, err
}
pingCtx, pingCancelFunc := context.WithTimeout(context.Background(), timeout)
defer pingCancelFunc()
err = client.Ping(pingCtx, readpref.Primary())
if err != nil {
return nil, err
}
return &mongoDB{
client: client,
db: client.Database(cfg.Name),
timeout: timeout,
}, nil
}
func (m *mongoDB) GetDB() *mongo.Database {
return m.db
}
func (m *mongoDB) Close() error {
disconnectCtx, disconnectCancelFunc := context.WithTimeout(context.Background(), m.timeout)
defer disconnectCancelFunc()
err := m.client.Disconnect(disconnectCtx)
if err != nil {
return err
}
return nil
}

247
pkg/database/mysql.go Normal file
View File

@ -0,0 +1,247 @@
package database
import (
"errors"
"fmt"
"log"
"os"
"time"
"git.bvbej.com/bvbej/base-golang/pkg/time_parse"
"git.bvbej.com/bvbej/base-golang/pkg/trace"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
const (
callBackBeforeName = "core:before"
callBackAfterName = "core:after"
startTime = "_start_time"
traceCtxName = "_trace_ctx_name"
)
var _ MysqlRepo = (*mysqlRepo)(nil)
type MysqlRepo interface {
i()
GetRead(options ...Option) *gorm.DB
GetWrite(options ...Option) *gorm.DB
Close() error
}
type MySQLConfig struct {
Read struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
} `yaml:"read"`
Write struct {
Addr string `yaml:"addr"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Name string `yaml:"name"`
} `yaml:"write"`
Base struct {
MaxOpenConn int `yaml:"maxOpenConn"` //最大连接数
MaxIdleConn int `yaml:"maxIdleConn"` //最大空闲连接数
ConnMaxLifeTime time.Duration `yaml:"connMaxLifeTime"` //最大连接超时(分钟)
} `yaml:"base"`
}
type mysqlRepo struct {
read *gorm.DB
write *gorm.DB
}
func NewMysql(cfg MySQLConfig) (MysqlRepo, error) {
dbr, err := dbConnect(cfg.Read.User, cfg.Read.Pass, cfg.Read.Addr, cfg.Read.Name,
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
if err != nil {
return nil, err
}
dbw, err := dbConnect(cfg.Write.User, cfg.Write.Pass, cfg.Write.Addr, cfg.Write.Name,
cfg.Base.MaxOpenConn, cfg.Base.MaxIdleConn, cfg.Base.ConnMaxLifeTime)
if err != nil {
return nil, err
}
return &mysqlRepo{
read: dbr,
write: dbw,
}, nil
}
func (d *mysqlRepo) i() {}
func (d *mysqlRepo) GetRead(options ...Option) *gorm.DB {
opt := newOption()
for _, f := range options {
f(opt)
}
db := d.read
if opt.Trace != nil {
db.InstanceSet(traceCtxName, opt.Trace)
}
return db
}
func (d *mysqlRepo) GetWrite(options ...Option) *gorm.DB {
opt := newOption()
for _, f := range options {
f(opt)
}
db := d.write
if opt.Trace != nil {
db.InstanceSet(traceCtxName, opt.Trace)
}
return db
}
func (d *mysqlRepo) Close() (err error) {
rdb, err1 := d.read.DB()
if err1 != nil {
err = errors.Join(err1)
}
err2 := rdb.Close()
if err2 != nil {
err = errors.Join(err2)
}
wdb, err3 := d.write.DB()
if err3 != nil {
err = errors.Join(err3)
}
err4 := wdb.Close()
if err4 != nil {
err = errors.Join(err4)
}
return err
}
func dbConnect(user, pass, addr, dbName string, maxOpenConn, maxIdleConn int, connMaxLifeTime time.Duration) (*gorm.DB, error) {
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
user,
pass,
addr,
dbName,
true,
"Local")
// 日志配置
newLogger := logger.New(
log.New(os.Stdout, "\r\n", log.LstdFlags),
logger.Config{
SlowThreshold: time.Second, // 慢SQL阈值
Colorful: true, // 彩色打印
IgnoreRecordNotFoundError: true, // 忽略记录未找到错误
LogLevel: logger.Error, // 日志级别
},
)
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
Logger: newLogger,
})
if err != nil {
return nil, errors.Join(err, fmt.Errorf("[db connection failed] Database name: %s", dbName))
}
db.Set("gorm:table_options", "CHARSET=utf8mb4")
sqlDB, err := db.DB()
if err != nil {
return nil, err
}
// 设置连接池 用于设置最大打开的连接数默认值为0表示不限制.设置最大的连接数可以避免并发太高导致连接mysql出现too many connections的错误。
sqlDB.SetMaxOpenConns(maxOpenConn)
// 设置最大连接数 用于设置闲置的连接数.设置闲置的连接数则当开启的一个连接使用完成后可以放在池里等候下一次使用。
sqlDB.SetMaxIdleConns(maxIdleConn)
// 设置最大连接超时
sqlDB.SetConnMaxLifetime(time.Minute * connMaxLifeTime)
// 使用插件
err = db.Use(&TracePlugin{})
if err != nil {
return nil, err
}
return db, nil
}
/***************************************************************/
type TracePlugin struct{}
func (op *TracePlugin) Name() string {
return "TracePlugin"
}
func (op *TracePlugin) Initialize(db *gorm.DB) (err error) {
// 开始前
_ = db.Callback().Create().Before("gorm:before_create").Register(callBackBeforeName, before)
_ = db.Callback().Query().Before("gorm:query").Register(callBackBeforeName, before)
_ = db.Callback().Delete().Before("gorm:before_delete").Register(callBackBeforeName, before)
_ = db.Callback().Update().Before("gorm:setup_reflect_value").Register(callBackBeforeName, before)
_ = db.Callback().Row().Before("gorm:row").Register(callBackBeforeName, before)
_ = db.Callback().Raw().Before("gorm:raw").Register(callBackBeforeName, before)
// 结束后
_ = db.Callback().Create().After("gorm:after_create").Register(callBackAfterName, after)
_ = db.Callback().Query().After("gorm:after_query").Register(callBackAfterName, after)
_ = db.Callback().Delete().After("gorm:after_delete").Register(callBackAfterName, after)
_ = db.Callback().Update().After("gorm:after_update").Register(callBackAfterName, after)
_ = db.Callback().Row().After("gorm:row").Register(callBackAfterName, after)
_ = db.Callback().Raw().After("gorm:raw").Register(callBackAfterName, after)
return
}
func before(db *gorm.DB) {
db.InstanceSet(startTime, time.Now())
}
func after(db *gorm.DB) {
_traceCtx, isExist := db.InstanceGet(traceCtxName)
if !isExist {
return
}
_trace, ok := _traceCtx.(trace.T)
if !ok {
return
}
_ts, isExist := db.InstanceGet(startTime)
if !isExist {
return
}
ts, ok := _ts.(time.Time)
if !ok {
return
}
sql := db.Dialector.Explain(db.Statement.SQL.String(), db.Statement.Vars...)
sqlInfo := new(trace.SQL)
sqlInfo.Timestamp = time_parse.CSTLayoutString()
sqlInfo.SQL = sql
sqlInfo.Stack = utils.FileWithLineNum()
sqlInfo.Rows = db.Statement.RowsAffected
sqlInfo.CostSeconds = time.Since(ts).Seconds()
_trace.AppendSQL(sqlInfo)
}

88
pkg/database/tool.go Normal file
View File

@ -0,0 +1,88 @@
package database
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"gorm.io/gorm"
"reflect"
)
type NullTime sql.NullTime
func (n *NullTime) Scan(value any) error {
return (*sql.NullTime)(n).Scan(value)
}
func (n NullTime) Value() (driver.Value, error) {
if !n.Valid {
return nil, nil
}
return n.Time, nil
}
func (n NullTime) MarshalJSON() ([]byte, error) {
if n.Valid {
return json.Marshal(n.Time)
}
return json.Marshal(nil)
}
func (n *NullTime) UnmarshalJSON(b []byte) error {
if string(b) == "null" {
n.Valid = false
return nil
}
err := json.Unmarshal(b, &n.Time)
if err == nil {
n.Valid = true
}
return err
}
/*-----------------------------------------------------------*/
type PaginateList struct {
Page int64 `json:"page"`
Size int64 `json:"size"`
Total int64 `json:"total"`
List any `json:"list"`
}
func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
ptr := reflect.ValueOf(model)
if ptr.Kind() != reflect.Ptr {
return nil, fmt.Errorf("model must be pointer")
}
var total int64
err := db.Model(model).Count(&total).Error
if err != nil {
return &PaginateList{
Page: page,
Size: size,
Total: total,
List: make([]any, 0),
}, err
}
offset := size * (page - 1)
err = db.Limit(int(size)).Offset(int(offset)).Find(model).Error
if err != nil {
return &PaginateList{
Page: page,
Size: size,
Total: total,
List: make([]any, 0),
}, err
}
return &PaginateList{
Page: page,
Size: size,
Total: total,
List: model,
}, nil
}

23
pkg/database/trace.go Normal file
View File

@ -0,0 +1,23 @@
package database
import "git.bvbej.com/bvbej/base-golang/pkg/trace"
type Trace = trace.T
type Option func(*option)
func WithTrace(t Trace) Option {
return func(opt *option) {
if t != nil {
opt.Trace = t.(*trace.Trace)
}
}
}
func newOption() *option {
return &option{}
}
type option struct {
Trace *trace.Trace
}