package database import ( "context" "database/sql" "database/sql/driver" "encoding/json" "fmt" "reflect" "time" "gorm.io/gorm" ) type NullTime sql.NullTime func (n *NullTime) Scan(value any) error { return (*sql.NullTime)(n).Scan(value) } func (n NullTime) Value() (driver.Value, error) { if !n.Valid { return nil, nil } return n.Time, nil } func (n NullTime) MarshalJSON() ([]byte, error) { if n.Valid { return json.Marshal(n.Time) } return json.Marshal(nil) } func (n *NullTime) UnmarshalJSON(b []byte) error { if string(b) == "null" { n.Valid = false return nil } err := json.Unmarshal(b, &n.Time) if err == nil { n.Valid = true } return err } /*-----------------------------------------------------------*/ type PaginateList struct { Page int64 `json:"page"` Size int64 `json:"size"` Total int64 `json:"total"` List any `json:"list"` } // Paginate 分页查询 func Paginate(db *gorm.DB, model any, page, size int64) (*PaginateList, error) { ptr := reflect.ValueOf(model) if ptr.Kind() != reflect.Ptr { return nil, fmt.Errorf("model must be pointer") } var total int64 err := db.Model(model).Count(&total).Error if err != nil { return &PaginateList{ Page: page, Size: size, Total: total, List: make([]any, 0), }, err } offset := size * (page - 1) err = db.Limit(int(size)).Offset(int(offset)).Find(model).Error if err != nil { return &PaginateList{ Page: page, Size: size, Total: total, List: make([]any, 0), }, err } return &PaginateList{ Page: page, Size: size, Total: total, List: model, }, nil } // BatchInsert 批量插入 func BatchInsert(ctx context.Context, db *gorm.DB, records interface{}, batchSize int) error { if batchSize <= 0 { batchSize = 100 } return db.WithContext(ctx).CreateInBatches(records, batchSize).Error } // ExistsBy 检查记录是否存在 func ExistsBy(ctx context.Context, db *gorm.DB, model interface{}, query interface{}, args ...interface{}) (bool, error) { var count int64 err := db.WithContext(ctx).Model(model).Where(query, args...).Count(&count).Error if err != nil { return false, err } return count > 0, nil } // FindByIDs 根据ID列表查询 func FindByIDs(ctx context.Context, db *gorm.DB, dest interface{}, ids []uint) error { if len(ids) == 0 { return nil } return db.WithContext(ctx).Where("id IN ?", ids).Find(dest).Error } // SoftDeleteByIDs 批量软删除 func SoftDeleteByIDs(ctx context.Context, db *gorm.DB, model interface{}, ids []uint) error { if len(ids) == 0 { return nil } return db.WithContext(ctx).Where("id IN ?", ids).Delete(model).Error } // UpdateFields 更新指定字段 func UpdateFields(ctx context.Context, db *gorm.DB, model interface{}, id uint, fields map[string]interface{}) error { return db.WithContext(ctx).Model(model).Where("id = ?", id).Updates(fields).Error } // Transaction 简化的事务助手 func Transaction(ctx context.Context, db *gorm.DB, fn func(*gorm.DB) error) error { return db.WithContext(ctx).Transaction(fn) } // BaseModel 基础模型 type BaseModel struct { ID uint `gorm:"primarykey" json:"id"` CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` DeletedAt gorm.DeletedAt `gorm:"index" json:"deleted_at,omitempty"` } // SimpleModel 简单模型(无软删除) type SimpleModel struct { ID uint `gorm:"primarykey" json:"id"` CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"` UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"` }