package cache

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
)

// ========== Distributed lock ==========

// The lock scripts are built once at package level (the usage pattern shown
// in the go-redis documentation) instead of on every Release/Refresh call.
var (
	// lockReleaseScript deletes KEYS[1] only while it still holds ARGV[1].
	lockReleaseScript = redis.NewScript(`
if redis.call("get", KEYS[1]) == ARGV[1] then
	return redis.call("del", KEYS[1])
else
	return 0
end
`)

	// lockRefreshScript resets the TTL of KEYS[1] to ARGV[2] milliseconds,
	// but only while the key still holds ARGV[1].
	lockRefreshScript = redis.NewScript(`
if redis.call("get", KEYS[1]) == ARGV[1] then
	return redis.call("pexpire", KEYS[1], ARGV[2])
else
	return 0
end
`)
)

// Lock is a Redis-backed distributed lock.
//
// key is the Redis key that represents the lock, value is the token stored
// under it (Release and Refresh compare against it before touching the key),
// and expiry is the TTL applied when acquiring and refreshing.
type Lock struct {
	repo   RedisRepo
	key    string
	value  string
	expiry time.Duration
}

// NewLock creates a Lock for key. value should be unique per holder, because
// it is the only thing Release and Refresh use to verify ownership.
func NewLock(repo RedisRepo, key, value string, expiry time.Duration) *Lock {
	return &Lock{
		repo:   repo,
		key:    key,
		value:  value,
		expiry: expiry,
	}
}

// Acquire tries to take the lock by delegating to repo.SetNX with the lock's
// key, value and expiry, and returns its result unchanged (presumably true
// when the key was set, i.e. the lock was obtained).
func (l *Lock) Acquire(ctx context.Context) (bool, error) {
	return l.repo.SetNX(ctx, l.key, l.value, l.expiry)
}

// Release deletes the lock key if it still holds this lock's value.
//
// Only the script's error is returned: releasing a lock that is no longer
// held (the script returns 0) is not reported as an error.
func (l *Lock) Release(ctx context.Context) error {
	return lockReleaseScript.Run(ctx, l.repo.Client(), []string{l.key}, l.value).Err()
}

// Refresh resets the lock's TTL to expiry if the key still holds this lock's
// value. It returns true when the TTL was updated and false when the script
// reported 0 (value mismatch or missing key).
func (l *Lock) Refresh(ctx context.Context) (bool, error) {
	// Int64 reports an unexpected reply type as an error instead of panicking
	// the way an unchecked type assertion would.
	updated, err := lockRefreshScript.Run(
		ctx, l.repo.Client(), []string{l.key}, l.value, l.expiry.Milliseconds(),
	).Int64()
	if err != nil {
		return false, err
	}
	return updated == 1, nil
}

// ========== JSON serialization ==========

// SetJSON marshals value to JSON and stores it under key with the given ttl.
// It returns a wrapped error when marshalling fails, otherwise whatever
// repo.Set returns.
func SetJSON(ctx context.Context, repo RedisRepo, key string, value interface{}, ttl time.Duration) error {
	data, err := json.Marshal(value)
	if err != nil {
		return fmt.Errorf("json marshal 失败: %w", err)
	}
	return repo.Set(ctx, key, string(data), ttl)
}

// GetJSON reads key and unmarshals its JSON content into dest.
//
// Errors from repo.Get are returned as-is. An empty string from repo.Get is
// treated as a missing key and reported as an error; invalid JSON yields a
// wrapped unmarshal error.
func GetJSON(ctx context.Context, repo RedisRepo, key string, dest interface{}) error {
	val, err := repo.Get(ctx, key)
	if err != nil {
		return err
	}
	if val == "" {
		return errors.New("key 不存在")
	}
	if err := json.Unmarshal([]byte(val), dest); err != nil {
		return fmt.Errorf("json unmarshal 失败: %w", err)
	}
	return nil
}

// ========== Rate limiter ==========

// RateLimiter is a Redis-backed, counter-based rate limiter: calls are
// counted under key and compared against limit. (It is not a token bucket.)
type RateLimiter struct {
	repo     RedisRepo
	key      string
	limit    int64         // maximum count allowed within one window
	interval time.Duration // window length, applied as the counter key's TTL
}
NewRateLimiter 创建限流器 func NewRateLimiter(repo RedisRepo, key string, limit int64, interval time.Duration) *RateLimiter { return &RateLimiter{ repo: repo, key: key, limit: limit, interval: interval, } } // Allow 检查是否允许(简单计数器实现) func (r *RateLimiter) Allow(ctx context.Context) (bool, error) { pipe := r.repo.Pipeline() incrCmd := pipe.Incr(ctx, r.key) pipe.Expire(ctx, r.key, r.interval) if _, err := pipe.Exec(ctx); err != nil { return false, err } count, err := incrCmd.Result() if err != nil { return false, err } return count <= r.limit, nil } // AllowN 检查是否允许N次 func (r *RateLimiter) AllowN(ctx context.Context, n int64) (bool, error) { pipe := r.repo.Pipeline() incrCmd := pipe.IncrBy(ctx, r.key, n) pipe.Expire(ctx, r.key, r.interval) if _, err := pipe.Exec(ctx); err != nil { return false, err } count, err := incrCmd.Result() if err != nil { return false, err } return count <= r.limit, nil } // Remaining 获取剩余次数 func (r *RateLimiter) Remaining(ctx context.Context) (int64, error) { val, err := r.repo.Get(ctx, r.key) if err != nil { return r.limit, nil } if val == "" { return r.limit, nil } var count int64 if _, err := fmt.Sscanf(val, "%d", &count); err != nil { return 0, err } remaining := r.limit - count if remaining < 0 { remaining = 0 } return remaining, nil } // Reset 重置限流器 func (r *RateLimiter) Reset(ctx context.Context) error { _, err := r.repo.Del(ctx, r.key) return err } // ========== 缓存装饰器 ========== // CacheDecorator 缓存装饰器 type CacheDecorator struct { repo RedisRepo ttl time.Duration } // NewCacheDecorator 创建缓存装饰器 func NewCacheDecorator(repo RedisRepo, ttl time.Duration) *CacheDecorator { return &CacheDecorator{ repo: repo, ttl: ttl, } } // GetOrSet 获取或设置缓存 func (c *CacheDecorator) GetOrSet(ctx context.Context, key string, dest interface{}, loader func() (interface{}, error)) error { // 尝试从缓存获取 err := GetJSON(ctx, c.repo, key, dest) if err == nil { return nil } // 缓存未命中,执行加载函数 data, err := loader() if err != nil { return fmt.Errorf("loader 执行失败: %w", err) } // 
设置缓存 if err := SetJSON(ctx, c.repo, key, data, c.ttl); err != nil { // 缓存设置失败不影响主流程 // 可以记录日志 } // 将数据赋值给dest dataBytes, _ := json.Marshal(data) return json.Unmarshal(dataBytes, dest) } // ========== 布隆过滤器(简单实现) ========== // BloomFilter 布隆过滤器 type BloomFilter struct { repo RedisRepo key string size uint64 } // NewBloomFilter 创建布隆过滤器 func NewBloomFilter(repo RedisRepo, key string, size uint64) *BloomFilter { return &BloomFilter{ repo: repo, key: key, size: size, } } // Add 添加元素 func (b *BloomFilter) Add(ctx context.Context, value string) error { // 简单hash hash1 := hashString(value, 0) % b.size hash2 := hashString(value, 1) % b.size hash3 := hashString(value, 2) % b.size pipe := b.repo.Pipeline() pipe.SetBit(ctx, b.key, int64(hash1), 1) pipe.SetBit(ctx, b.key, int64(hash2), 1) pipe.SetBit(ctx, b.key, int64(hash3), 1) _, err := pipe.Exec(ctx) return err } // Exists 检查元素是否存在 func (b *BloomFilter) Exists(ctx context.Context, value string) (bool, error) { hash1 := hashString(value, 0) % b.size hash2 := hashString(value, 1) % b.size hash3 := hashString(value, 2) % b.size pipe := b.repo.Pipeline() bit1Cmd := pipe.GetBit(ctx, b.key, int64(hash1)) bit2Cmd := pipe.GetBit(ctx, b.key, int64(hash2)) bit3Cmd := pipe.GetBit(ctx, b.key, int64(hash3)) if _, err := pipe.Exec(ctx); err != nil { return false, err } bit1, _ := bit1Cmd.Result() bit2, _ := bit2Cmd.Result() bit3, _ := bit3Cmd.Result() return bit1 == 1 && bit2 == 1 && bit3 == 1, nil } // hashString 简单字符串hash func hashString(s string, seed uint64) uint64 { hash := seed for i := 0; i < len(s); i++ { hash = hash*31 + uint64(s[i]) } return hash }