diff --git a/pkg/cache/helper.go b/pkg/cache/helper.go new file mode 100644 index 0000000..ebc41bc --- /dev/null +++ b/pkg/cache/helper.go @@ -0,0 +1,287 @@ +package cache + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +// ========== 分布式锁 ========== + +// Lock 分布式锁 +type Lock struct { + repo RedisRepo + key string + value string + expiry time.Duration +} + +// NewLock 创建分布式锁 +func NewLock(repo RedisRepo, key, value string, expiry time.Duration) *Lock { + return &Lock{ + repo: repo, + key: key, + value: value, + expiry: expiry, + } +} + +// Acquire 获取锁 +func (l *Lock) Acquire(ctx context.Context) (bool, error) { + return l.repo.SetNX(ctx, l.key, l.value, l.expiry) +} + +// Release 释放锁 +func (l *Lock) Release(ctx context.Context) error { + script := redis.NewScript(` + if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("del", KEYS[1]) + else + return 0 + end + `) + + return script.Run(ctx, l.repo.Client(), []string{l.key}, l.value).Err() +} + +// Refresh 刷新锁 +func (l *Lock) Refresh(ctx context.Context) (bool, error) { + script := redis.NewScript(` + if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("pexpire", KEYS[1], ARGV[2]) + else + return 0 + end + `) + + result, err := script.Run(ctx, l.repo.Client(), []string{l.key}, l.value, l.expiry.Milliseconds()).Result() + if err != nil { + return false, err + } + + return result.(int64) == 1, nil +} + +// ========== JSON 序列化 ========== + +// SetJSON 设置JSON对象 +func SetJSON(ctx context.Context, repo RedisRepo, key string, value interface{}, ttl time.Duration) error { + data, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("json marshal 失败: %w", err) + } + return repo.Set(ctx, key, string(data), ttl) +} + +// GetJSON 获取JSON对象 +func GetJSON(ctx context.Context, repo RedisRepo, key string, dest interface{}) error { + val, err := repo.Get(ctx, key) + if err != nil { + return err + } + if val == "" { + return errors.New("key 不存在") + } + + if err := json.Unmarshal([]byte(val), dest); err != nil { + return fmt.Errorf("json unmarshal 失败: %w", err) + } + return nil +} + +// ========== 限流器 ========== + +// RateLimiter 限流器(令牌桶算法) +type RateLimiter struct { + repo RedisRepo + key string + limit int64 // 最大令牌数 + interval time.Duration // 时间窗口 +} + +// NewRateLimiter 创建限流器 +func NewRateLimiter(repo RedisRepo, key string, limit int64, interval time.Duration) *RateLimiter { + return &RateLimiter{ + repo: repo, + key: key, + limit: limit, + interval: interval, + } +} + +// Allow 检查是否允许(简单计数器实现) +func (r *RateLimiter) Allow(ctx context.Context) (bool, error) { + pipe := r.repo.Pipeline() + + incrCmd := pipe.Incr(ctx, r.key) + pipe.Expire(ctx, r.key, r.interval) + + if _, err := pipe.Exec(ctx); err != nil { + return false, err + } + + count, err := incrCmd.Result() + if err != nil { + return false, err + } + + return count <= r.limit, nil +} + +// AllowN 检查是否允许N次 +func (r *RateLimiter) AllowN(ctx context.Context, n int64) (bool, error) { + pipe := r.repo.Pipeline() + + incrCmd := pipe.IncrBy(ctx, r.key, n) + pipe.Expire(ctx, r.key, r.interval) + + if _, err := pipe.Exec(ctx); err != nil { + return false, err + } + + count, err := incrCmd.Result() + if err != nil { + return false, err + } + + return count <= r.limit, nil +} + +// Remaining 获取剩余次数 +func (r *RateLimiter) Remaining(ctx context.Context) (int64, error) { + val, err := r.repo.Get(ctx, r.key) + if err != nil { + return r.limit, nil + } + if val == "" { + return r.limit, nil + } + + var count int64 + if _, err := fmt.Sscanf(val, "%d", &count); err != nil { + return 0, err + } + + remaining := r.limit - count + if remaining < 0 { + remaining = 0 + } + + return remaining, nil +} + +// Reset 重置限流器 +func (r *RateLimiter) Reset(ctx context.Context) error { + _, err := r.repo.Del(ctx, r.key) + return err +} + +// ========== 缓存装饰器 ========== + +// CacheDecorator 缓存装饰器 +type CacheDecorator struct { + repo RedisRepo + ttl time.Duration +} + +// NewCacheDecorator 创建缓存装饰器 +func NewCacheDecorator(repo RedisRepo, ttl time.Duration) *CacheDecorator { + return &CacheDecorator{ + repo: repo, + ttl: ttl, + } +} + +// GetOrSet 获取或设置缓存 +func (c *CacheDecorator) GetOrSet(ctx context.Context, key string, dest interface{}, loader func() (interface{}, error)) error { + // 尝试从缓存获取 + err := GetJSON(ctx, c.repo, key, dest) + if err == nil { + return nil + } + + // 缓存未命中,执行加载函数 + data, err := loader() + if err != nil { + return fmt.Errorf("loader 执行失败: %w", err) + } + + // 设置缓存 + if err := SetJSON(ctx, c.repo, key, data, c.ttl); err != nil { + // 缓存设置失败不影响主流程 + // 可以记录日志 + } + + // 将数据赋值给dest + dataBytes, _ := json.Marshal(data) + return json.Unmarshal(dataBytes, dest) +} + +// ========== 布隆过滤器(简单实现) ========== + +// BloomFilter 布隆过滤器 +type BloomFilter struct { + repo RedisRepo + key string + size uint64 +} + +// NewBloomFilter 创建布隆过滤器 +func NewBloomFilter(repo RedisRepo, key string, size uint64) *BloomFilter { + return &BloomFilter{ + repo: repo, + key: key, + size: size, + } +} + +// Add 添加元素 +func (b *BloomFilter) Add(ctx context.Context, value string) error { + // 简单hash + hash1 := hashString(value, 0) % b.size + hash2 := hashString(value, 1) % b.size + hash3 := hashString(value, 2) % b.size + + pipe := b.repo.Pipeline() + pipe.SetBit(ctx, b.key, int64(hash1), 1) + pipe.SetBit(ctx, b.key, int64(hash2), 1) + pipe.SetBit(ctx, b.key, int64(hash3), 1) + + _, err := pipe.Exec(ctx) + return err +} + +// Exists 检查元素是否存在 +func (b *BloomFilter) Exists(ctx context.Context, value string) (bool, error) { + hash1 := hashString(value, 0) % b.size + hash2 := hashString(value, 1) % b.size + hash3 := hashString(value, 2) % b.size + + pipe := b.repo.Pipeline() + bit1Cmd := pipe.GetBit(ctx, b.key, int64(hash1)) + bit2Cmd := pipe.GetBit(ctx, b.key, int64(hash2)) + bit3Cmd := pipe.GetBit(ctx, b.key, int64(hash3)) + + if _, err := pipe.Exec(ctx); err != nil { + return false, err + } + + bit1, _ := bit1Cmd.Result() + bit2, _ := bit2Cmd.Result() + bit3, _ := bit3Cmd.Result() + + return bit1 == 1 && bit2 == 1 && bit3 == 1, nil +} + +// hashString 简单字符串hash +func hashString(s string, seed uint64) uint64 { + hash := seed + for i := 0; i < len(s); i++ { + hash = hash*31 + uint64(s[i]) + } + return hash +} diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index 62303ba..cfbb38c 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -4,479 +4,716 @@ import ( "context" "errors" "fmt" + "sync" "time" - "gitea.bvbej.com/bvbej/base-golang/pkg/time_parse" - "gitea.bvbej.com/bvbej/base-golang/pkg/trace" "github.com/redis/go-redis/v9" ) -type Option func(*option) - -type Trace = trace.T - -type option struct { - Trace *trace.Trace - Redis *trace.Redis -} - -type RedisConfig struct { - Addr string `yaml:"addr"` - Pass string `yaml:"pass"` - DB int `yaml:"db"` - MaxRetries int `yaml:"maxRetries"` // 最大重试次数 - PoolSize int `yaml:"poolSize"` // Redis连接池大小 - MinIdleConn int `yaml:"minIdleConn"` // 最小空闲连接数 -} - -func newOption() *option { - return &option{} -} - -var _ Repo = (*cacheRepo)(nil) - -type Repo interface { - i() +// RedisRepo Redis接口 +type RedisRepo interface { + // Client 获取原生客户端 Client() *redis.Client - Set(key, value string, ttl time.Duration, options ...Option) error - Get(key string, options ...Option) (string, error) - TTL(key string) (time.Duration, error) - Expire(key string, ttl time.Duration) bool - ExpireAt(key string, ttl time.Time) bool - Del(key string, options ...Option) bool - Exists(keys ...string) bool - Incr(key string, options ...Option) (int64, error) - Decr(key string, options ...Option) (int64, error) - HGet(key, field string, options ...Option) (string, error) - HSet(key, field, value string, options ...Option) error - HDel(key, field string, options ...Option) error - HGetAll(key string, options ...Option) (map[string]string, error) - HIncrBy(key, field string, incr int64, options ...Option) (int64, error) - HIncrByFloat(key, field string, incr float64, options ...Option) (float64, error) - LPush(key, value string, options ...Option) error - LLen(key string, options ...Option) (int64, error) - BRPop(key string, timeout time.Duration, options ...Option) (string, error) + + // String 操作 + Set(ctx context.Context, key, value string, ttl time.Duration) error + Get(ctx context.Context, key string) (string, error) + GetSet(ctx context.Context, key, value string) (string, error) + SetNX(ctx context.Context, key, value string, ttl time.Duration) (bool, error) + MGet(ctx context.Context, keys ...string) ([]interface{}, error) + MSet(ctx context.Context, pairs ...interface{}) error + + // Key 操作 + Del(ctx context.Context, keys ...string) (int64, error) + Exists(ctx context.Context, keys ...string) (int64, error) + Expire(ctx context.Context, key string, ttl time.Duration) (bool, error) + ExpireAt(ctx context.Context, key string, tm time.Time) (bool, error) + TTL(ctx context.Context, key string) (time.Duration, error) + Rename(ctx context.Context, key, newKey string) error + + // 计数器操作 + Incr(ctx context.Context, key string) (int64, error) + IncrBy(ctx context.Context, key string, value int64) (int64, error) + Decr(ctx context.Context, key string) (int64, error) + DecrBy(ctx context.Context, key string, value int64) (int64, error) + + // Hash 操作 + HSet(ctx context.Context, key string, values ...interface{}) error + HGet(ctx context.Context, key, field string) (string, error) + HGetAll(ctx context.Context, key string) (map[string]string, error) + HDel(ctx context.Context, key string, fields ...string) (int64, error) + HExists(ctx context.Context, key, field string) (bool, error) + HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) + HIncrByFloat(ctx context.Context, key, field string, incr float64) (float64, error) + HKeys(ctx context.Context, key string) ([]string, error) + HLen(ctx context.Context, key string) (int64, error) + + // List 操作 + LPush(ctx context.Context, key string, values ...interface{}) (int64, error) + RPush(ctx context.Context, key string, values ...interface{}) (int64, error) + LPop(ctx context.Context, key string) (string, error) + RPop(ctx context.Context, key string) (string, error) + BLPop(ctx context.Context, timeout time.Duration, keys ...string) ([]string, error) + BRPop(ctx context.Context, timeout time.Duration, keys ...string) ([]string, error) + LLen(ctx context.Context, key string) (int64, error) + LRange(ctx context.Context, key string, start, stop int64) ([]string, error) + + // Set 操作 + SAdd(ctx context.Context, key string, members ...interface{}) (int64, error) + SMembers(ctx context.Context, key string) ([]string, error) + SIsMember(ctx context.Context, key string, member interface{}) (bool, error) + SRem(ctx context.Context, key string, members ...interface{}) (int64, error) + SCard(ctx context.Context, key string) (int64, error) + + // ZSet 操作 + ZAdd(ctx context.Context, key string, members ...redis.Z) (int64, error) + ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) + ZRangeWithScores(ctx context.Context, key string, start, stop int64) ([]redis.Z, error) + ZRevRange(ctx context.Context, key string, start, stop int64) ([]string, error) + ZRem(ctx context.Context, key string, members ...interface{}) (int64, error) + ZCard(ctx context.Context, key string) (int64, error) + ZScore(ctx context.Context, key, member string) (float64, error) + + // Pipeline 操作 + Pipeline() redis.Pipeliner + TxPipeline() redis.Pipeliner + + // Pub/Sub 操作 + Publish(ctx context.Context, channel string, message interface{}) error + Subscribe(ctx context.Context, channels ...string) *redis.PubSub + + // 健康检查 + Ping(ctx context.Context) error + Info(ctx context.Context, section ...string) (string, error) + + // 连接管理 Close() error + PoolStats() *redis.PoolStats } -type cacheRepo struct { - client *redis.Client - ctx context.Context +// RedisConfig Redis配置 +type RedisConfig struct { + // 基础配置 + Addr string `yaml:"addr" json:"addr"` // 地址,如: localhost:6379 + Password string `yaml:"password" json:"password"` // 密码 + DB int `yaml:"db" json:"db"` // 数据库编号 + + // 连接池配置 + PoolSize int `yaml:"pool_size" json:"pool_size"` // 最大连接数 + MinIdleConns int `yaml:"min_idle_conns" json:"min_idle_conns"` // 最小空闲连接数 + + // 超时配置 + DialTimeout time.Duration `yaml:"dial_timeout" json:"dial_timeout"` // 连接超时 + ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"` // 读超时 + WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` // 写超时 + PoolTimeout time.Duration `yaml:"pool_timeout" json:"pool_timeout"` // 连接池超时 + + // 重试配置 + MaxRetries int `yaml:"max_retries" json:"max_retries"` // 最大重试次数 + MinRetryBackoff time.Duration `yaml:"min_retry_backoff" json:"min_retry_backoff"` // 最小重试间隔 + MaxRetryBackoff time.Duration `yaml:"max_retry_backoff" json:"max_retry_backoff"` // 最大重试间隔 + + // 连接存活 + ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time" json:"conn_max_idle_time"` // 最大空闲时间 + ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" json:"conn_max_lifetime"` // 最大生命周期 } -func New(cfg RedisConfig) (Repo, error) { - client := redis.NewClient(&redis.Options{ - Addr: cfg.Addr, - Password: cfg.Pass, - DB: cfg.DB, - MaxRetries: cfg.MaxRetries, - PoolSize: cfg.PoolSize, - MinIdleConns: cfg.MinIdleConn, - }) - ctx := context.TODO() - if err := client.Ping(ctx).Err(); err != nil { - return nil, errors.Join(err, errors.New("ping redis err")) +// DefaultRedisConfig 默认配置 +func DefaultRedisConfig() *RedisConfig { + return &RedisConfig{ + Addr: "localhost:6379", + Password: "", + DB: 0, + PoolSize: 100, + MinIdleConns: 10, + DialTimeout: 5 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 3 * time.Second, + PoolTimeout: 4 * time.Second, + MaxRetries: 3, + MinRetryBackoff: 8 * time.Millisecond, + MaxRetryBackoff: 512 * time.Millisecond, + ConnMaxIdleTime: 5 * time.Minute, + ConnMaxLifetime: 30 * time.Minute, } - return &cacheRepo{ +} + +// redisRepo Redis实现 +type redisRepo struct { + client *redis.Client + config *RedisConfig + mu sync.RWMutex + closed bool +} + +// NewRedis 创建Redis实例 +func NewRedis(cfg *RedisConfig) (RedisRepo, error) { + if cfg == nil { + cfg = DefaultRedisConfig() + } + + // 合并默认配置 + mergeDefaultConfig(cfg) + + // 创建客户端 + client := redis.NewClient(&redis.Options{ + Addr: cfg.Addr, + Password: cfg.Password, + DB: cfg.DB, + PoolSize: cfg.PoolSize, + MinIdleConns: cfg.MinIdleConns, + DialTimeout: cfg.DialTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + PoolTimeout: cfg.PoolTimeout, + MaxRetries: cfg.MaxRetries, + MinRetryBackoff: cfg.MinRetryBackoff, + MaxRetryBackoff: cfg.MaxRetryBackoff, + ConnMaxIdleTime: cfg.ConnMaxIdleTime, + ConnMaxLifetime: cfg.ConnMaxLifetime, + }) + + // Ping 测试连接 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("redis ping 失败: %w", err) + } + + return &redisRepo{ client: client, - ctx: ctx, + config: cfg, + closed: false, }, nil } -func WithTrace(t Trace) Option { - return func(opt *option) { - if t != nil { - opt.Trace = t.(*trace.Trace) - opt.Redis = new(trace.Redis) - } - } +// Client 获取原生客户端 +func (r *redisRepo) Client() *redis.Client { + return r.client } -func (c *cacheRepo) i() {} +// ========== String 操作 ========== -func (c *cacheRepo) Client() *redis.Client { - return c.client +// Set 设置值 +func (r *redisRepo) Set(ctx context.Context, key, value string, ttl time.Duration) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.Set(ctx, key, value, ttl).Err() } -func (c *cacheRepo) Set(key, value string, ttl time.Duration, options ...Option) error { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "set" - opt.Redis.Key = key - opt.Redis.Value = value - opt.Redis.TTL = ttl.Minutes() - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) +// Get 获取值 +func (r *redisRepo) Get(ctx context.Context, key string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err } - if err := c.client.Set(c.ctx, key, value, ttl).Err(); err != nil { - return errors.Join(err, fmt.Errorf("redis set key: %s err", key)) + val, err := r.client.Get(ctx, key).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// GetSet 设置新值并返回旧值 +func (r *redisRepo) GetSet(ctx context.Context, key, value string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err } + val, err := r.client.GetSet(ctx, key, value).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// SetNX 仅当key不存在时设置 +func (r *redisRepo) SetNX(ctx context.Context, key, value string, ttl time.Duration) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.SetNX(ctx, key, value, ttl).Result() +} + +// MGet 批量获取 +func (r *redisRepo) MGet(ctx context.Context, keys ...string) ([]interface{}, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.MGet(ctx, keys...).Result() +} + +// MSet 批量设置 +func (r *redisRepo) MSet(ctx context.Context, pairs ...interface{}) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.MSet(ctx, pairs...).Err() +} + +// ========== Key 操作 ========== + +// Del 删除键 +func (r *redisRepo) Del(ctx context.Context, keys ...string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.Del(ctx, keys...).Result() +} + +// Exists 检查键是否存在 +func (r *redisRepo) Exists(ctx context.Context, keys ...string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.Exists(ctx, keys...).Result() +} + +// Expire 设置过期时间 +func (r *redisRepo) Expire(ctx context.Context, key string, ttl time.Duration) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.Expire(ctx, key, ttl).Result() +} + +// ExpireAt 设置过期时间点 +func (r *redisRepo) ExpireAt(ctx context.Context, key string, tm time.Time) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.ExpireAt(ctx, key, tm).Result() +} + +// TTL 获取剩余时间 +func (r *redisRepo) TTL(ctx context.Context, key string) (time.Duration, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.TTL(ctx, key).Result() +} + +// Rename 重命名键 +func (r *redisRepo) Rename(ctx context.Context, key, newKey string) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.Rename(ctx, key, newKey).Err() +} + +// ========== 计数器操作 ========== + +// Incr 自增1 +func (r *redisRepo) Incr(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.Incr(ctx, key).Result() +} + +// IncrBy 自增指定值 +func (r *redisRepo) IncrBy(ctx context.Context, key string, value int64) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.IncrBy(ctx, key, value).Result() +} + +// Decr 自减1 +func (r *redisRepo) Decr(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.Decr(ctx, key).Result() +} + +// DecrBy 自减指定值 +func (r *redisRepo) DecrBy(ctx context.Context, key string, value int64) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.DecrBy(ctx, key, value).Result() +} + +// ========== Hash 操作 ========== + +// HSet 设置hash字段 +func (r *redisRepo) HSet(ctx context.Context, key string, values ...interface{}) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.HSet(ctx, key, values...).Err() +} + +// HGet 获取hash字段 +func (r *redisRepo) HGet(ctx context.Context, key, field string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err + } + + val, err := r.client.HGet(ctx, key, field).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// HGetAll 获取所有hash字段 +func (r *redisRepo) HGetAll(ctx context.Context, key string) (map[string]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.HGetAll(ctx, key).Result() +} + +// HDel 删除hash字段 +func (r *redisRepo) HDel(ctx context.Context, key string, fields ...string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.HDel(ctx, key, fields...).Result() +} + +// HExists 检查hash字段是否存在 +func (r *redisRepo) HExists(ctx context.Context, key, field string) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.HExists(ctx, key, field).Result() +} + +// HIncrBy hash字段自增 +func (r *redisRepo) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.HIncrBy(ctx, key, field, incr).Result() +} + +// HIncrByFloat hash字段浮点自增 +func (r *redisRepo) HIncrByFloat(ctx context.Context, key, field string, incr float64) (float64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.HIncrByFloat(ctx, key, field, incr).Result() +} + +// HKeys 获取所有hash字段名 +func (r *redisRepo) HKeys(ctx context.Context, key string) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.HKeys(ctx, key).Result() +} + +// HLen 获取hash字段数量 +func (r *redisRepo) HLen(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.HLen(ctx, key).Result() +} + +// ========== List 操作 ========== + +// LPush 从左侧推入 +func (r *redisRepo) LPush(ctx context.Context, key string, values ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.LPush(ctx, key, values...).Result() +} + +// RPush 从右侧推入 +func (r *redisRepo) RPush(ctx context.Context, key string, values ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.RPush(ctx, key, values...).Result() +} + +// LPop 从左侧弹出 +func (r *redisRepo) LPop(ctx context.Context, key string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err + } + + val, err := r.client.LPop(ctx, key).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// RPop 从右侧弹出 +func (r *redisRepo) RPop(ctx context.Context, key string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err + } + + val, err := r.client.RPop(ctx, key).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// BLPop 阻塞式从左侧弹出 +func (r *redisRepo) BLPop(ctx context.Context, timeout time.Duration, keys ...string) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.BLPop(ctx, timeout, keys...).Result() +} + +// BRPop 阻塞式从右侧弹出 +func (r *redisRepo) BRPop(ctx context.Context, timeout time.Duration, keys ...string) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.BRPop(ctx, timeout, keys...).Result() +} + +// LLen 获取列表长度 +func (r *redisRepo) LLen(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.LLen(ctx, key).Result() +} + +// LRange 获取列表范围 +func (r *redisRepo) LRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.LRange(ctx, key, start, stop).Result() +} + +// ========== Set 操作 ========== + +// SAdd 添加集合成员 +func (r *redisRepo) SAdd(ctx context.Context, key string, members ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.SAdd(ctx, key, members...).Result() +} + +// SMembers 获取所有集合成员 +func (r *redisRepo) SMembers(ctx context.Context, key string) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.SMembers(ctx, key).Result() +} + +// SIsMember 检查是否是集合成员 +func (r *redisRepo) SIsMember(ctx context.Context, key string, member interface{}) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.SIsMember(ctx, key, member).Result() +} + +// SRem 删除集合成员 +func (r *redisRepo) SRem(ctx context.Context, key string, members ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.SRem(ctx, key, members...).Result() +} + +// SCard 获取集合成员数量 +func (r *redisRepo) SCard(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.SCard(ctx, key).Result() +} + +// ========== ZSet 操作 ========== + +// ZAdd 添加有序集合成员 +func (r *redisRepo) ZAdd(ctx context.Context, key string, members ...redis.Z) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.ZAdd(ctx, key, members...).Result() +} + +// ZRange 获取有序集合范围 +func (r *redisRepo) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.ZRange(ctx, key, start, stop).Result() +} + +// ZRangeWithScores 获取有序集合范围(带分数) +func (r *redisRepo) ZRangeWithScores(ctx context.Context, key string, start, stop int64) ([]redis.Z, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.ZRangeWithScores(ctx, key, start, stop).Result() +} + +// ZRevRange 倒序获取有序集合范围 +func (r *redisRepo) ZRevRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.ZRevRange(ctx, key, start, stop).Result() +} + +// ZRem 删除有序集合成员 +func (r *redisRepo) ZRem(ctx context.Context, key string, members ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.ZRem(ctx, key, members...).Result() +} + +// ZCard 获取有序集合成员数量 +func (r *redisRepo) ZCard(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.ZCard(ctx, key).Result() +} + +// ZScore 获取成员分数 +func (r *redisRepo) ZScore(ctx context.Context, key, member string) (float64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.ZScore(ctx, key, member).Result() +} + +// ========== Pipeline 操作 ========== + +// Pipeline 获取pipeline +func (r *redisRepo) Pipeline() redis.Pipeliner { + return r.client.Pipeline() +} + +// TxPipeline 获取事务pipeline +func (r *redisRepo) TxPipeline() redis.Pipeliner { + return r.client.TxPipeline() +} + +// ========== Pub/Sub 操作 ========== + +// Publish 发布消息 +func (r *redisRepo) Publish(ctx context.Context, channel string, message interface{}) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.Publish(ctx, channel, message).Err() +} + +// Subscribe 订阅频道 +func (r *redisRepo) Subscribe(ctx context.Context, channels ...string) *redis.PubSub { + return r.client.Subscribe(ctx, channels...) +} + +// ========== 健康检查 ========== + +// Ping 检查连接 +func (r *redisRepo) Ping(ctx context.Context) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.Ping(ctx).Err() +} + +// Info 获取服务器信息 +func (r *redisRepo) Info(ctx context.Context, section ...string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err + } + return r.client.Info(ctx, section...).Result() +} + +// ========== 连接管理 ========== + +// Close 关闭连接 +func (r *redisRepo) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return nil + } + + r.closed = true + return r.client.Close() +} + +// PoolStats 获取连接池统计 +func (r *redisRepo) PoolStats() *redis.PoolStats { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.closed { + return nil + } + + stats := r.client.PoolStats() + return stats +} + +// ========== 内部方法 ========== + +// checkClosed 检查是否已关闭 +func (r *redisRepo) checkClosed() error { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.closed { + return errors.New("redis 连接已关闭") + } return nil } -func (c *cacheRepo) Get(key string, options ...Option) (string, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "get" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) +// mergeDefaultConfig 合并默认配置 +func mergeDefaultConfig(cfg *RedisConfig) { + if cfg.Addr == "" { + cfg.Addr = "localhost:6379" } - - value, err := c.client.Get(c.ctx, key).Result() - if err != nil { - return "", errors.Join(err, fmt.Errorf("redis get key: %s err", key)) + if cfg.PoolSize == 0 { + cfg.PoolSize = 100 + } + if cfg.MinIdleConns == 0 { + cfg.MinIdleConns = 10 + } + if cfg.DialTimeout == 0 { + cfg.DialTimeout = 5 * time.Second + } + if cfg.ReadTimeout == 0 { + cfg.ReadTimeout = 3 * time.Second + } + if cfg.WriteTimeout == 0 { + cfg.WriteTimeout = 3 * time.Second + } + if cfg.PoolTimeout == 0 { + cfg.PoolTimeout = 4 * time.Second + } + if cfg.MaxRetries == 0 { + cfg.MaxRetries = 3 + } + if cfg.MinRetryBackoff == 0 { + cfg.MinRetryBackoff = 8 * time.Millisecond + } + if cfg.MaxRetryBackoff == 0 { + cfg.MaxRetryBackoff = 512 * time.Millisecond + } + if cfg.ConnMaxIdleTime == 0 { + cfg.ConnMaxIdleTime = 5 * time.Minute + } + if cfg.ConnMaxLifetime == 0 { + cfg.ConnMaxLifetime = 30 * time.Minute } - - return value, nil -} - -func (c *cacheRepo) TTL(key string) (time.Duration, error) { - ttl, err := c.client.TTL(c.ctx, key).Result() - if err != nil { - return -1, errors.Join(err, fmt.Errorf("redis get key: %s err", key)) - } - - return ttl, nil -} - -func (c *cacheRepo) Expire(key string, ttl time.Duration) bool { - ok, _ := c.client.Expire(c.ctx, key, ttl).Result() - return ok -} - -func (c *cacheRepo) ExpireAt(key string, ttl time.Time) bool { - ok, _ := c.client.ExpireAt(c.ctx, key, ttl).Result() - return ok -} - -func (c *cacheRepo) Exists(keys ...string) bool { - if len(keys) == 0 { - return true - } - value, _ := c.client.Exists(c.ctx, keys...).Result() - return value > 0 -} - -func (c *cacheRepo) Del(key string, options ...Option) bool { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "del" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - if key == "" { - return true - } - - value, _ := c.client.Del(c.ctx, key).Result() - return value > 0 -} - -func (c *cacheRepo) Incr(key string, options ...Option) (int64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "incr" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - value, err := c.client.Incr(c.ctx, key).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis incr key: %s err", key)) - } - return value, nil -} - -func (c *cacheRepo) Decr(key string, options ...Option) (int64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "decr" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - value, err := c.client.Decr(c.ctx, key).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis decr key: %s err", key)) - } - return value, nil -} - -func (c *cacheRepo) HGet(key, field string, options ...Option) (string, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash get" - opt.Redis.Key = key - opt.Redis.Value = field - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.HGet(c.ctx, key, field).Result() - if err != nil { - return "", errors.Join(err, fmt.Errorf("redis hget key: %s field: %s err", key, field)) - } - - return value, nil -} - -func (c *cacheRepo) HSet(key, field, value string, options ...Option) error { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash set" - opt.Redis.Key = key - opt.Redis.Value = field + "/" + value - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - if err := c.client.HSet(c.ctx, key, field, value).Err(); err != nil { - return errors.Join(err, fmt.Errorf("redis hset key: %s field: %s err", key, field)) - } - - return nil -} - -func (c *cacheRepo) HDel(key, field string, options ...Option) error { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash del" - opt.Redis.Key = key - opt.Redis.Value = field - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - if err := c.client.HDel(c.ctx, key, field).Err(); err != nil { - return errors.Join(err, fmt.Errorf("redis hdel key: %s field: %s err", key, field)) - } - - return nil -} - -func (c *cacheRepo) HGetAll(key string, options ...Option) (map[string]string, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash get all" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.HGetAll(c.ctx, key).Result() - if err != nil { - return nil, errors.Join(err, fmt.Errorf("redis hget all key: %s err", key)) - } - - return value, nil -} - -func (c *cacheRepo) HIncrBy(key, field string, incr int64, options ...Option) (int64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash incr int64" - opt.Redis.Key = key - opt.Redis.Value = fmt.Sprintf("field:%s incr:%d", field, incr) - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.HIncrBy(c.ctx, key, field, incr).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis hash incr int64 key: %s err", key)) - } - - return value, nil -} - -func (c *cacheRepo) HIncrByFloat(key, field string, incr float64, options ...Option) (float64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash incr float64" - opt.Redis.Key = key - opt.Redis.Value = fmt.Sprintf("field:%s incr:%d", field, incr) - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.HIncrByFloat(c.ctx, key, field, incr).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis hash incr float64 key: %s err", key)) - } - - return value, nil -} - -func (c *cacheRepo) LPush(key, value string, options ...Option) error { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "list push" - opt.Redis.Key = key - opt.Redis.Value = value - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - _, err := c.client.LPush(c.ctx, key, value).Result() - if err != nil { - return errors.Join(err, fmt.Errorf("redis list push key: %s value: %s err", key, value)) - } - - return nil -} - -func (c *cacheRepo) LLen(key string, options ...Option) (int64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "list len" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.LLen(c.ctx, key).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis list len key: %s err", key)) - } - - return value, nil -} - -func (c *cacheRepo) BRPop(key string, timeout time.Duration, options ...Option) (string, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "list brpop" - opt.Redis.Key = key - opt.Redis.TTL = timeout.Seconds() - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.BRPop(c.ctx, timeout, key).Result() - if err != nil { - return "", errors.Join(err, fmt.Errorf("redis list len key: %s err", key)) - } - - return value[1], nil -} - -func (c *cacheRepo) Close() error { - return c.client.Close() }