288 lines
6.2 KiB
Go
288 lines
6.2 KiB
Go
package cache
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// ========== 分布式锁 ==========
|
|
|
|
// Lock 分布式锁
|
|
type Lock struct {
|
|
repo RedisRepo
|
|
key string
|
|
value string
|
|
expiry time.Duration
|
|
}
|
|
|
|
// NewLock 创建分布式锁
|
|
func NewLock(repo RedisRepo, key, value string, expiry time.Duration) *Lock {
|
|
return &Lock{
|
|
repo: repo,
|
|
key: key,
|
|
value: value,
|
|
expiry: expiry,
|
|
}
|
|
}
|
|
|
|
// Acquire 获取锁
|
|
func (l *Lock) Acquire(ctx context.Context) (bool, error) {
|
|
return l.repo.SetNX(ctx, l.key, l.value, l.expiry)
|
|
}
|
|
|
|
// Release 释放锁
|
|
func (l *Lock) Release(ctx context.Context) error {
|
|
script := redis.NewScript(`
|
|
if redis.call("get", KEYS[1]) == ARGV[1] then
|
|
return redis.call("del", KEYS[1])
|
|
else
|
|
return 0
|
|
end
|
|
`)
|
|
|
|
return script.Run(ctx, l.repo.Client(), []string{l.key}, l.value).Err()
|
|
}
|
|
|
|
// Refresh 刷新锁
|
|
func (l *Lock) Refresh(ctx context.Context) (bool, error) {
|
|
script := redis.NewScript(`
|
|
if redis.call("get", KEYS[1]) == ARGV[1] then
|
|
return redis.call("pexpire", KEYS[1], ARGV[2])
|
|
else
|
|
return 0
|
|
end
|
|
`)
|
|
|
|
result, err := script.Run(ctx, l.repo.Client(), []string{l.key}, l.value, l.expiry.Milliseconds()).Result()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return result.(int64) == 1, nil
|
|
}
|
|
|
|
// ========== JSON 序列化 ==========
|
|
|
|
// SetJSON 设置JSON对象
|
|
func SetJSON(ctx context.Context, repo RedisRepo, key string, value interface{}, ttl time.Duration) error {
|
|
data, err := json.Marshal(value)
|
|
if err != nil {
|
|
return fmt.Errorf("json marshal 失败: %w", err)
|
|
}
|
|
return repo.Set(ctx, key, string(data), ttl)
|
|
}
|
|
|
|
// GetJSON 获取JSON对象
|
|
func GetJSON(ctx context.Context, repo RedisRepo, key string, dest interface{}) error {
|
|
val, err := repo.Get(ctx, key)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if val == "" {
|
|
return errors.New("key 不存在")
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(val), dest); err != nil {
|
|
return fmt.Errorf("json unmarshal 失败: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ========== 限流器 ==========
|
|
|
|
// RateLimiter 限流器(令牌桶算法)
|
|
type RateLimiter struct {
|
|
repo RedisRepo
|
|
key string
|
|
limit int64 // 最大令牌数
|
|
interval time.Duration // 时间窗口
|
|
}
|
|
|
|
// NewRateLimiter 创建限流器
|
|
func NewRateLimiter(repo RedisRepo, key string, limit int64, interval time.Duration) *RateLimiter {
|
|
return &RateLimiter{
|
|
repo: repo,
|
|
key: key,
|
|
limit: limit,
|
|
interval: interval,
|
|
}
|
|
}
|
|
|
|
// Allow 检查是否允许(简单计数器实现)
|
|
func (r *RateLimiter) Allow(ctx context.Context) (bool, error) {
|
|
pipe := r.repo.Pipeline()
|
|
|
|
incrCmd := pipe.Incr(ctx, r.key)
|
|
pipe.Expire(ctx, r.key, r.interval)
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
count, err := incrCmd.Result()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return count <= r.limit, nil
|
|
}
|
|
|
|
// AllowN 检查是否允许N次
|
|
func (r *RateLimiter) AllowN(ctx context.Context, n int64) (bool, error) {
|
|
pipe := r.repo.Pipeline()
|
|
|
|
incrCmd := pipe.IncrBy(ctx, r.key, n)
|
|
pipe.Expire(ctx, r.key, r.interval)
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
count, err := incrCmd.Result()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return count <= r.limit, nil
|
|
}
|
|
|
|
// Remaining 获取剩余次数
|
|
func (r *RateLimiter) Remaining(ctx context.Context) (int64, error) {
|
|
val, err := r.repo.Get(ctx, r.key)
|
|
if err != nil {
|
|
return r.limit, nil
|
|
}
|
|
if val == "" {
|
|
return r.limit, nil
|
|
}
|
|
|
|
var count int64
|
|
if _, err := fmt.Sscanf(val, "%d", &count); err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
remaining := r.limit - count
|
|
if remaining < 0 {
|
|
remaining = 0
|
|
}
|
|
|
|
return remaining, nil
|
|
}
|
|
|
|
// Reset 重置限流器
|
|
func (r *RateLimiter) Reset(ctx context.Context) error {
|
|
_, err := r.repo.Del(ctx, r.key)
|
|
return err
|
|
}
|
|
|
|
// ========== 缓存装饰器 ==========
|
|
|
|
// CacheDecorator 缓存装饰器
|
|
type CacheDecorator struct {
|
|
repo RedisRepo
|
|
ttl time.Duration
|
|
}
|
|
|
|
// NewCacheDecorator 创建缓存装饰器
|
|
func NewCacheDecorator(repo RedisRepo, ttl time.Duration) *CacheDecorator {
|
|
return &CacheDecorator{
|
|
repo: repo,
|
|
ttl: ttl,
|
|
}
|
|
}
|
|
|
|
// GetOrSet 获取或设置缓存
|
|
func (c *CacheDecorator) GetOrSet(ctx context.Context, key string, dest interface{}, loader func() (interface{}, error)) error {
|
|
// 尝试从缓存获取
|
|
err := GetJSON(ctx, c.repo, key, dest)
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
// 缓存未命中,执行加载函数
|
|
data, err := loader()
|
|
if err != nil {
|
|
return fmt.Errorf("loader 执行失败: %w", err)
|
|
}
|
|
|
|
// 设置缓存
|
|
if err := SetJSON(ctx, c.repo, key, data, c.ttl); err != nil {
|
|
// 缓存设置失败不影响主流程
|
|
// 可以记录日志
|
|
}
|
|
|
|
// 将数据赋值给dest
|
|
dataBytes, _ := json.Marshal(data)
|
|
return json.Unmarshal(dataBytes, dest)
|
|
}
|
|
|
|
// ========== 布隆过滤器(简单实现) ==========
|
|
|
|
// BloomFilter 布隆过滤器
|
|
type BloomFilter struct {
|
|
repo RedisRepo
|
|
key string
|
|
size uint64
|
|
}
|
|
|
|
// NewBloomFilter 创建布隆过滤器
|
|
func NewBloomFilter(repo RedisRepo, key string, size uint64) *BloomFilter {
|
|
return &BloomFilter{
|
|
repo: repo,
|
|
key: key,
|
|
size: size,
|
|
}
|
|
}
|
|
|
|
// Add 添加元素
|
|
func (b *BloomFilter) Add(ctx context.Context, value string) error {
|
|
// 简单hash
|
|
hash1 := hashString(value, 0) % b.size
|
|
hash2 := hashString(value, 1) % b.size
|
|
hash3 := hashString(value, 2) % b.size
|
|
|
|
pipe := b.repo.Pipeline()
|
|
pipe.SetBit(ctx, b.key, int64(hash1), 1)
|
|
pipe.SetBit(ctx, b.key, int64(hash2), 1)
|
|
pipe.SetBit(ctx, b.key, int64(hash3), 1)
|
|
|
|
_, err := pipe.Exec(ctx)
|
|
return err
|
|
}
|
|
|
|
// Exists 检查元素是否存在
|
|
func (b *BloomFilter) Exists(ctx context.Context, value string) (bool, error) {
|
|
hash1 := hashString(value, 0) % b.size
|
|
hash2 := hashString(value, 1) % b.size
|
|
hash3 := hashString(value, 2) % b.size
|
|
|
|
pipe := b.repo.Pipeline()
|
|
bit1Cmd := pipe.GetBit(ctx, b.key, int64(hash1))
|
|
bit2Cmd := pipe.GetBit(ctx, b.key, int64(hash2))
|
|
bit3Cmd := pipe.GetBit(ctx, b.key, int64(hash3))
|
|
|
|
if _, err := pipe.Exec(ctx); err != nil {
|
|
return false, err
|
|
}
|
|
|
|
bit1, _ := bit1Cmd.Result()
|
|
bit2, _ := bit2Cmd.Result()
|
|
bit3, _ := bit3Cmd.Result()
|
|
|
|
return bit1 == 1 && bit2 == 1 && bit3 == 1, nil
|
|
}
|
|
|
|
// hashString computes a simple polynomial hash of s.
//
// Starting from seed, every byte of s is folded in as acc = acc*31 + byte.
// The arithmetic is done in uint64 and wraps around on overflow. An empty
// string yields seed unchanged.
func hashString(s string, seed uint64) uint64 {
	const multiplier = 31

	acc := seed
	for _, b := range []byte(s) {
		acc = acc*multiplier + uint64(b)
	}
	return acc
}
|