Files
base-golang/pkg/database/tool.go
2026-02-07 15:36:21 +08:00

153 lines
3.5 KiB
Go

package database
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"fmt"
"reflect"
"time"
"gorm.io/gorm"
)
// NullTime is a nullable time.Time usable both as a database column and as a
// JSON field. It has the same layout as sql.NullTime: Valid reports whether
// Time holds a value.
//
// Scan and UnmarshalJSON use pointer receivers because they modify the value.
// Value and MarshalJSON use value receivers so they are in the method set of
// both NullTime and *NullTime.
type NullTime sql.NullTime

// Scan implements sql.Scanner by delegating to sql.NullTime.Scan; a nil
// source yields an invalid (NULL) NullTime with a zero Time.
func (n *NullTime) Scan(value any) error {
	return (*sql.NullTime)(n).Scan(value)
}

// Value implements driver.Valuer: nil (SQL NULL) when invalid, otherwise Time.
func (n NullTime) Value() (driver.Value, error) {
	if !n.Valid {
		return nil, nil
	}
	return n.Time, nil
}

// MarshalJSON encodes an invalid NullTime as JSON null, and a valid one
// exactly as encoding/json encodes a time.Time.
func (n NullTime) MarshalJSON() ([]byte, error) {
	if n.Valid {
		return json.Marshal(n.Time)
	}
	return json.Marshal(nil)
}

// UnmarshalJSON decodes JSON null into an invalid NullTime with a zero Time
// (the same state Scan produces for a NULL column). Any other value is decoded
// as a time.Time.
//
// The time is decoded into a temporary and only committed on success. On
// malformed input, time.Time's decoder overwrites its target with the zero
// time, so decoding directly into n.Time would leave Valid stale next to a
// clobbered Time. Here, n is left untouched when an error is returned.
func (n *NullTime) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		*n = NullTime{}
		return nil
	}
	var t time.Time
	if err := json.Unmarshal(b, &t); err != nil {
		return err
	}
	*n = NullTime{Time: t, Valid: true}
	return nil
}
/*-----------------------------------------------------------*/
// PaginateList is the result envelope returned by Paginate.
type PaginateList struct {
	Page  int64 `json:"page"`  // page number as passed to Paginate (not normalized)
	Size  int64 `json:"size"`  // page size as passed to Paginate (not normalized)
	Total int64 `json:"total"` // row count from the COUNT query, run before LIMIT/OFFSET are applied
	List  any   `json:"list"`  // on success, the model pointer Paginate filled; on error, an empty []any
}
// Paginate runs a paginated query and returns the requested page together
// with the total number of matching rows.
//
// model must be a non-nil pointer, typically to a slice of records. It is used
// as the GORM model for the COUNT and as the destination of the page of rows.
// On success it is returned as PaginateList.List.
//
// page is 1-based and the offset is (page-1)*size. For page < 1 that offset is
// negative, which GORM presumably treats as "no offset" — confirm against the
// GORM version in use. size is passed to LIMIT unchanged. Page and Size are
// echoed back exactly as given.
//
// db may already carry conditions (Where, Order, Scopes, ...). It is wrapped
// in a new GORM session so the COUNT and the SELECT each start from their own
// copy of those conditions, rather than both mutating one shared chained
// statement. Paginate takes no context; pass db.WithContext(ctx) to bind one.
//
// Errors:
//   - If model is not a non-nil pointer, Paginate returns (nil, error).
//   - If the COUNT or the SELECT fails, the error is returned together with a
//     non-nil *PaginateList whose List is an empty []any and whose Total is
//     whatever the COUNT produced.
func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) {
	rv := reflect.ValueOf(model)
	if rv.Kind() != reflect.Pointer {
		return nil, fmt.Errorf("model must be pointer")
	}
	if rv.IsNil() {
		// A typed nil pointer cannot receive rows; reject it before querying.
		return nil, fmt.Errorf("model must be non-nil pointer")
	}

	// failed builds the result returned alongside an error. It carries the same
	// metadata as a successful call, but an empty list instead of the (possibly
	// partially filled) model.
	failed := func(total int64, err error) (*PaginateList, error) {
		return &PaginateList{
			Page:  page,
			Size:  size,
			Total: total,
			List:  make([]any, 0),
		}, err
	}

	// A new session makes every chain started from it clone db's statement,
	// which is GORM's documented way to reuse a possibly-chained *gorm.DB for
	// several finisher calls. Without it, Count and Find would share (and
	// mutate) one statement.
	session := db.Session(&gorm.Session{})

	var total int64
	if err := session.Model(model).Count(&total).Error; err != nil {
		return failed(total, err)
	}

	offset := size * (page - 1)
	if err := session.Limit(int(size)).Offset(int(offset)).Find(model).Error; err != nil {
		return failed(total, err)
	}

	return &PaginateList{
		Page:  page,
		Size:  size,
		Total: total,
		List:  model,
	}, nil
}
// BatchInsert inserts records via GORM's CreateInBatches, with the query
// bound to ctx.
//
// records is handed to GORM unchanged (presumably a slice, or a pointer to a
// slice, of models — GORM decides what it accepts). batchSize is the batch
// size passed to CreateInBatches; any value below 1 falls back to 100.
func BatchInsert(ctx context.Context, db *gorm.DB, records interface{}, batchSize int) error {
	const defaultBatchSize = 100

	size := batchSize
	if size < 1 {
		size = defaultBatchSize
	}

	result := db.WithContext(ctx).CreateInBatches(records, size)
	return result.Error
}
// ExistsBy reports whether at least one row of model's table matches the
// given condition.
//
// query and args are passed unchanged to GORM's Where. The check runs a COUNT
// over all matching rows and reports whether it is positive. If the query
// fails, it returns false together with the error.
func ExistsBy(ctx context.Context, db *gorm.DB, model interface{}, query interface{}, args ...interface{}) (bool, error) {
	var matches int64

	scoped := db.WithContext(ctx).Model(model).Where(query, args...)
	if err := scoped.Count(&matches).Error; err != nil {
		return false, err
	}

	found := matches > 0
	return found, nil
}
// FindByIDs loads every row whose id is in ids into dest, using a single
// "id IN ?" query bound to ctx.
//
// An empty or nil ids slice short-circuits: no query is issued, dest is left
// untouched, and nil is returned.
func FindByIDs(ctx context.Context, db *gorm.DB, dest interface{}, ids []uint) error {
	if len(ids) > 0 {
		return db.WithContext(ctx).
			Where("id IN ?", ids).
			Find(dest).
			Error
	}
	return nil
}
// SoftDeleteByIDs deletes, through GORM's Delete, the rows of model's table
// whose id is in ids. The query is bound to ctx.
//
// NOTE(review): GORM only performs a soft delete (setting deleted_at) when
// model has a gorm.DeletedAt field, as BaseModel does. For other models,
// e.g. ones embedding SimpleModel, the rows are presumably removed outright.
//
// An empty or nil ids slice is a no-op that returns nil without querying.
func SoftDeleteByIDs(ctx context.Context, db *gorm.DB, model interface{}, ids []uint) error {
	if len(ids) != 0 {
		scoped := db.WithContext(ctx).Where("id IN ?", ids)
		return scoped.Delete(model).Error
	}
	return nil
}
// UpdateFields applies the column/value pairs in fields to the row of model's
// table whose id equals id, using GORM's map-based Updates bound to ctx.
//
// The keys of fields are passed to GORM as-is. Any error from the UPDATE is
// returned unchanged.
func UpdateFields(ctx context.Context, db *gorm.DB, model interface{}, id uint, fields map[string]interface{}) error {
	scoped := db.WithContext(ctx).Model(model).Where("id = ?", id)
	if err := scoped.Updates(fields).Error; err != nil {
		return err
	}
	return nil
}
// Transaction is a small helper that runs fn via gorm.DB.Transaction on a
// handle bound to ctx.
//
// fn receives the transaction handle. Per GORM's Transaction contract, a nil
// return commits and a non-nil error rolls back. The resulting error is
// returned to the caller.
func Transaction(ctx context.Context, db *gorm.DB, fn func(*gorm.DB) error) error {
	scoped := db.WithContext(ctx)
	return scoped.Transaction(fn)
}
// BaseModel is the base model: the columns shared by models that support
// soft deletion. It contains:
//   - a `primarykey` ID;
//   - CreatedAt/UpdatedAt timestamps filled in by GORM (the
//     `autoCreateTime`/`autoUpdateTime` tags);
//   - an indexed DeletedAt column.
//
// NOTE(review): gorm.DeletedAt is a struct type (based on sql.NullTime), and
// encoding/json never considers a struct value "empty". The `omitempty`
// option on DeletedAt therefore appears to have no effect: deleted_at is
// always emitted, presumably as null for rows that are not deleted — confirm
// before relying on it.
type BaseModel struct {
	ID        uint           `gorm:"primarykey" json:"id"`
	CreatedAt time.Time      `gorm:"autoCreateTime" json:"created_at"`
	UpdatedAt time.Time      `gorm:"autoUpdateTime" json:"updated_at"`
	DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"`
}
// SimpleModel is a simple model without soft deletion. It contains a
// `primarykey` ID plus CreatedAt/UpdatedAt timestamps filled in by GORM (the
// `autoCreateTime`/`autoUpdateTime` tags). It has no DeletedAt field, so
// deleting a model that embeds it is not a soft delete.
type SimpleModel struct {
	ID        uint      `gorm:"primarykey" json:"id"`
	CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
	UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
}