diff --git a/pkg/cache/helper.go b/pkg/cache/helper.go new file mode 100644 index 0000000..ebc41bc --- /dev/null +++ b/pkg/cache/helper.go @@ -0,0 +1,287 @@ +package cache + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/redis/go-redis/v9" +) + +// ========== 分布式锁 ========== + +// Lock 分布式锁 +type Lock struct { + repo RedisRepo + key string + value string + expiry time.Duration +} + +// NewLock 创建分布式锁 +func NewLock(repo RedisRepo, key, value string, expiry time.Duration) *Lock { + return &Lock{ + repo: repo, + key: key, + value: value, + expiry: expiry, + } +} + +// Acquire 获取锁 +func (l *Lock) Acquire(ctx context.Context) (bool, error) { + return l.repo.SetNX(ctx, l.key, l.value, l.expiry) +} + +// Release 释放锁 +func (l *Lock) Release(ctx context.Context) error { + script := redis.NewScript(` + if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("del", KEYS[1]) + else + return 0 + end + `) + + return script.Run(ctx, l.repo.Client(), []string{l.key}, l.value).Err() +} + +// Refresh 刷新锁 +func (l *Lock) Refresh(ctx context.Context) (bool, error) { + script := redis.NewScript(` + if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("pexpire", KEYS[1], ARGV[2]) + else + return 0 + end + `) + + result, err := script.Run(ctx, l.repo.Client(), []string{l.key}, l.value, l.expiry.Milliseconds()).Result() + if err != nil { + return false, err + } + + return result.(int64) == 1, nil +} + +// ========== JSON 序列化 ========== + +// SetJSON 设置JSON对象 +func SetJSON(ctx context.Context, repo RedisRepo, key string, value interface{}, ttl time.Duration) error { + data, err := json.Marshal(value) + if err != nil { + return fmt.Errorf("json marshal 失败: %w", err) + } + return repo.Set(ctx, key, string(data), ttl) +} + +// GetJSON 获取JSON对象 +func GetJSON(ctx context.Context, repo RedisRepo, key string, dest interface{}) error { + val, err := repo.Get(ctx, key) + if err != nil { + return err + } + if val == "" { + return errors.New("key 不存在") + } + + if err := json.Unmarshal([]byte(val), dest); err != nil { + return fmt.Errorf("json unmarshal 失败: %w", err) + } + return nil +} + +// ========== 限流器 ========== + +// RateLimiter 限流器(令牌桶算法) +type RateLimiter struct { + repo RedisRepo + key string + limit int64 // 最大令牌数 + interval time.Duration // 时间窗口 +} + +// NewRateLimiter 创建限流器 +func NewRateLimiter(repo RedisRepo, key string, limit int64, interval time.Duration) *RateLimiter { + return &RateLimiter{ + repo: repo, + key: key, + limit: limit, + interval: interval, + } +} + +// Allow 检查是否允许(简单计数器实现) +func (r *RateLimiter) Allow(ctx context.Context) (bool, error) { + pipe := r.repo.Pipeline() + + incrCmd := pipe.Incr(ctx, r.key) + pipe.Expire(ctx, r.key, r.interval) + + if _, err := pipe.Exec(ctx); err != nil { + return false, err + } + + count, err := incrCmd.Result() + if err != nil { + return false, err + } + + return count <= r.limit, nil +} + +// AllowN 检查是否允许N次 +func (r *RateLimiter) AllowN(ctx context.Context, n int64) (bool, error) { + pipe := r.repo.Pipeline() + + incrCmd := pipe.IncrBy(ctx, r.key, n) + pipe.Expire(ctx, r.key, r.interval) + + if _, err := pipe.Exec(ctx); err != nil { + return false, err + } + + count, err := incrCmd.Result() + if err != nil { + return false, err + } + + return count <= r.limit, nil +} + +// Remaining 获取剩余次数 +func (r *RateLimiter) Remaining(ctx context.Context) (int64, error) { + val, err := r.repo.Get(ctx, r.key) + if err != nil { + return r.limit, nil + } + if val == "" { + return r.limit, nil + } + + var count int64 + if _, err := fmt.Sscanf(val, "%d", &count); err != nil { + return 0, err + } + + remaining := r.limit - count + if remaining < 0 { + remaining = 0 + } + + return remaining, nil +} + +// Reset 重置限流器 +func (r *RateLimiter) Reset(ctx context.Context) error { + _, err := r.repo.Del(ctx, r.key) + return err +} + +// ========== 缓存装饰器 ========== + +// CacheDecorator 缓存装饰器 +type CacheDecorator struct { + repo RedisRepo + ttl time.Duration +} + +// NewCacheDecorator 创建缓存装饰器 +func NewCacheDecorator(repo RedisRepo, ttl time.Duration) *CacheDecorator { + return &CacheDecorator{ + repo: repo, + ttl: ttl, + } +} + +// GetOrSet 获取或设置缓存 +func (c *CacheDecorator) GetOrSet(ctx context.Context, key string, dest interface{}, loader func() (interface{}, error)) error { + // 尝试从缓存获取 + err := GetJSON(ctx, c.repo, key, dest) + if err == nil { + return nil + } + + // 缓存未命中,执行加载函数 + data, err := loader() + if err != nil { + return fmt.Errorf("loader 执行失败: %w", err) + } + + // 设置缓存 + if err := SetJSON(ctx, c.repo, key, data, c.ttl); err != nil { + // 缓存设置失败不影响主流程 + // 可以记录日志 + } + + // 将数据赋值给dest + dataBytes, _ := json.Marshal(data) + return json.Unmarshal(dataBytes, dest) +} + +// ========== 布隆过滤器(简单实现) ========== + +// BloomFilter 布隆过滤器 +type BloomFilter struct { + repo RedisRepo + key string + size uint64 +} + +// NewBloomFilter 创建布隆过滤器 +func NewBloomFilter(repo RedisRepo, key string, size uint64) *BloomFilter { + return &BloomFilter{ + repo: repo, + key: key, + size: size, + } +} + +// Add 添加元素 +func (b *BloomFilter) Add(ctx context.Context, value string) error { + // 简单hash + hash1 := hashString(value, 0) % b.size + hash2 := hashString(value, 1) % b.size + hash3 := hashString(value, 2) % b.size + + pipe := b.repo.Pipeline() + pipe.SetBit(ctx, b.key, int64(hash1), 1) + pipe.SetBit(ctx, b.key, int64(hash2), 1) + pipe.SetBit(ctx, b.key, int64(hash3), 1) + + _, err := pipe.Exec(ctx) + return err +} + +// Exists 检查元素是否存在 +func (b *BloomFilter) Exists(ctx context.Context, value string) (bool, error) { + hash1 := hashString(value, 0) % b.size + hash2 := hashString(value, 1) % b.size + hash3 := hashString(value, 2) % b.size + + pipe := b.repo.Pipeline() + bit1Cmd := pipe.GetBit(ctx, b.key, int64(hash1)) + bit2Cmd := pipe.GetBit(ctx, b.key, int64(hash2)) + bit3Cmd := pipe.GetBit(ctx, b.key, int64(hash3)) + + if _, err := pipe.Exec(ctx); err != nil { + return false, err + } + + bit1, _ := bit1Cmd.Result() + bit2, _ := bit2Cmd.Result() + bit3, _ := bit3Cmd.Result() + + return bit1 == 1 && bit2 == 1 && bit3 == 1, nil +} + +// hashString 简单字符串hash +func hashString(s string, seed uint64) uint64 { + hash := seed + for i := 0; i < len(s); i++ { + hash = hash*31 + uint64(s[i]) + } + return hash +} diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index 62303ba..cfbb38c 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -4,479 +4,716 @@ import ( "context" "errors" "fmt" + "sync" "time" - "gitea.bvbej.com/bvbej/base-golang/pkg/time_parse" - "gitea.bvbej.com/bvbej/base-golang/pkg/trace" "github.com/redis/go-redis/v9" ) -type Option func(*option) - -type Trace = trace.T - -type option struct { - Trace *trace.Trace - Redis *trace.Redis -} - -type RedisConfig struct { - Addr string `yaml:"addr"` - Pass string `yaml:"pass"` - DB int `yaml:"db"` - MaxRetries int `yaml:"maxRetries"` // 最大重试次数 - PoolSize int `yaml:"poolSize"` // Redis连接池大小 - MinIdleConn int `yaml:"minIdleConn"` // 最小空闲连接数 -} - -func newOption() *option { - return &option{} -} - -var _ Repo = (*cacheRepo)(nil) - -type Repo interface { - i() +// RedisRepo Redis接口 +type RedisRepo interface { + // Client 获取原生客户端 Client() *redis.Client - Set(key, value string, ttl time.Duration, options ...Option) error - Get(key string, options ...Option) (string, error) - TTL(key string) (time.Duration, error) - Expire(key string, ttl time.Duration) bool - ExpireAt(key string, ttl time.Time) bool - Del(key string, options ...Option) bool - Exists(keys ...string) bool - Incr(key string, options ...Option) (int64, error) - Decr(key string, options ...Option) (int64, error) - HGet(key, field string, options ...Option) (string, error) - HSet(key, field, value string, options ...Option) error - HDel(key, field string, options ...Option) error - HGetAll(key string, options ...Option) (map[string]string, error) - HIncrBy(key, field string, incr int64, options ...Option) (int64, error) - HIncrByFloat(key, field string, incr float64, options ...Option) (float64, error) - LPush(key, value string, options ...Option) error - LLen(key string, options ...Option) (int64, error) - BRPop(key string, timeout time.Duration, options ...Option) (string, error) + + // String 操作 + Set(ctx context.Context, key, value string, ttl time.Duration) error + Get(ctx context.Context, key string) (string, error) + GetSet(ctx context.Context, key, value string) (string, error) + SetNX(ctx context.Context, key, value string, ttl time.Duration) (bool, error) + MGet(ctx context.Context, keys ...string) ([]interface{}, error) + MSet(ctx context.Context, pairs ...interface{}) error + + // Key 操作 + Del(ctx context.Context, keys ...string) (int64, error) + Exists(ctx context.Context, keys ...string) (int64, error) + Expire(ctx context.Context, key string, ttl time.Duration) (bool, error) + ExpireAt(ctx context.Context, key string, tm time.Time) (bool, error) + TTL(ctx context.Context, key string) (time.Duration, error) + Rename(ctx context.Context, key, newKey string) error + + // 计数器操作 + Incr(ctx context.Context, key string) (int64, error) + IncrBy(ctx context.Context, key string, value int64) (int64, error) + Decr(ctx context.Context, key string) (int64, error) + DecrBy(ctx context.Context, key string, value int64) (int64, error) + + // Hash 操作 + HSet(ctx context.Context, key string, values ...interface{}) error + HGet(ctx context.Context, key, field string) (string, error) + HGetAll(ctx context.Context, key string) (map[string]string, error) + HDel(ctx context.Context, key string, fields ...string) (int64, error) + HExists(ctx context.Context, key, field string) (bool, error) + HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) + HIncrByFloat(ctx context.Context, key, field string, incr float64) (float64, error) + HKeys(ctx context.Context, key string) ([]string, error) + HLen(ctx context.Context, key string) (int64, error) + + // List 操作 + LPush(ctx context.Context, key string, values ...interface{}) (int64, error) + RPush(ctx context.Context, key string, values ...interface{}) (int64, error) + LPop(ctx context.Context, key string) (string, error) + RPop(ctx context.Context, key string) (string, error) + BLPop(ctx context.Context, timeout time.Duration, keys ...string) ([]string, error) + BRPop(ctx context.Context, timeout time.Duration, keys ...string) ([]string, error) + LLen(ctx context.Context, key string) (int64, error) + LRange(ctx context.Context, key string, start, stop int64) ([]string, error) + + // Set 操作 + SAdd(ctx context.Context, key string, members ...interface{}) (int64, error) + SMembers(ctx context.Context, key string) ([]string, error) + SIsMember(ctx context.Context, key string, member interface{}) (bool, error) + SRem(ctx context.Context, key string, members ...interface{}) (int64, error) + SCard(ctx context.Context, key string) (int64, error) + + // ZSet 操作 + ZAdd(ctx context.Context, key string, members ...redis.Z) (int64, error) + ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) + ZRangeWithScores(ctx context.Context, key string, start, stop int64) ([]redis.Z, error) + ZRevRange(ctx context.Context, key string, start, stop int64) ([]string, error) + ZRem(ctx context.Context, key string, members ...interface{}) (int64, error) + ZCard(ctx context.Context, key string) (int64, error) + ZScore(ctx context.Context, key, member string) (float64, error) + + // Pipeline 操作 + Pipeline() redis.Pipeliner + TxPipeline() redis.Pipeliner + + // Pub/Sub 操作 + Publish(ctx context.Context, channel string, message interface{}) error + Subscribe(ctx context.Context, channels ...string) *redis.PubSub + + // 健康检查 + Ping(ctx context.Context) error + Info(ctx context.Context, section ...string) (string, error) + + // 连接管理 Close() error + PoolStats() *redis.PoolStats } -type cacheRepo struct { - client *redis.Client - ctx context.Context +// RedisConfig Redis配置 +type RedisConfig struct { + // 基础配置 + Addr string `yaml:"addr" json:"addr"` // 地址,如: localhost:6379 + Password string `yaml:"password" json:"password"` // 密码 + DB int `yaml:"db" json:"db"` // 数据库编号 + + // 连接池配置 + PoolSize int `yaml:"pool_size" json:"pool_size"` // 最大连接数 + MinIdleConns int `yaml:"min_idle_conns" json:"min_idle_conns"` // 最小空闲连接数 + + // 超时配置 + DialTimeout time.Duration `yaml:"dial_timeout" json:"dial_timeout"` // 连接超时 + ReadTimeout time.Duration `yaml:"read_timeout" json:"read_timeout"` // 读超时 + WriteTimeout time.Duration `yaml:"write_timeout" json:"write_timeout"` // 写超时 + PoolTimeout time.Duration `yaml:"pool_timeout" json:"pool_timeout"` // 连接池超时 + + // 重试配置 + MaxRetries int `yaml:"max_retries" json:"max_retries"` // 最大重试次数 + MinRetryBackoff time.Duration `yaml:"min_retry_backoff" json:"min_retry_backoff"` // 最小重试间隔 + MaxRetryBackoff time.Duration `yaml:"max_retry_backoff" json:"max_retry_backoff"` // 最大重试间隔 + + // 连接存活 + ConnMaxIdleTime time.Duration `yaml:"conn_max_idle_time" json:"conn_max_idle_time"` // 最大空闲时间 + ConnMaxLifetime time.Duration `yaml:"conn_max_lifetime" json:"conn_max_lifetime"` // 最大生命周期 } -func New(cfg RedisConfig) (Repo, error) { - client := redis.NewClient(&redis.Options{ - Addr: cfg.Addr, - Password: cfg.Pass, - DB: cfg.DB, - MaxRetries: cfg.MaxRetries, - PoolSize: cfg.PoolSize, - MinIdleConns: cfg.MinIdleConn, - }) - ctx := context.TODO() - if err := client.Ping(ctx).Err(); err != nil { - return nil, errors.Join(err, errors.New("ping redis err")) +// DefaultRedisConfig 默认配置 +func DefaultRedisConfig() *RedisConfig { + return &RedisConfig{ + Addr: "localhost:6379", + Password: "", + DB: 0, + PoolSize: 100, + MinIdleConns: 10, + DialTimeout: 5 * time.Second, + ReadTimeout: 3 * time.Second, + WriteTimeout: 3 * time.Second, + PoolTimeout: 4 * time.Second, + MaxRetries: 3, + MinRetryBackoff: 8 * time.Millisecond, + MaxRetryBackoff: 512 * time.Millisecond, + ConnMaxIdleTime: 5 * time.Minute, + ConnMaxLifetime: 30 * time.Minute, } - return &cacheRepo{ +} + +// redisRepo Redis实现 +type redisRepo struct { + client *redis.Client + config *RedisConfig + mu sync.RWMutex + closed bool +} + +// NewRedis 创建Redis实例 +func NewRedis(cfg *RedisConfig) (RedisRepo, error) { + if cfg == nil { + cfg = DefaultRedisConfig() + } + + // 合并默认配置 + mergeDefaultConfig(cfg) + + // 创建客户端 + client := redis.NewClient(&redis.Options{ + Addr: cfg.Addr, + Password: cfg.Password, + DB: cfg.DB, + PoolSize: cfg.PoolSize, + MinIdleConns: cfg.MinIdleConns, + DialTimeout: cfg.DialTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + PoolTimeout: cfg.PoolTimeout, + MaxRetries: cfg.MaxRetries, + MinRetryBackoff: cfg.MinRetryBackoff, + MaxRetryBackoff: cfg.MaxRetryBackoff, + ConnMaxIdleTime: cfg.ConnMaxIdleTime, + ConnMaxLifetime: cfg.ConnMaxLifetime, + }) + + // Ping 测试连接 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := client.Ping(ctx).Err(); err != nil { + return nil, fmt.Errorf("redis ping 失败: %w", err) + } + + return &redisRepo{ client: client, - ctx: ctx, + config: cfg, + closed: false, }, nil } -func WithTrace(t Trace) Option { - return func(opt *option) { - if t != nil { - opt.Trace = t.(*trace.Trace) - opt.Redis = new(trace.Redis) - } - } +// Client 获取原生客户端 +func (r *redisRepo) Client() *redis.Client { + return r.client } -func (c *cacheRepo) i() {} +// ========== String 操作 ========== -func (c *cacheRepo) Client() *redis.Client { - return c.client +// Set 设置值 +func (r *redisRepo) Set(ctx context.Context, key, value string, ttl time.Duration) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.Set(ctx, key, value, ttl).Err() } -func (c *cacheRepo) Set(key, value string, ttl time.Duration, options ...Option) error { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "set" - opt.Redis.Key = key - opt.Redis.Value = value - opt.Redis.TTL = ttl.Minutes() - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) +// Get 获取值 +func (r *redisRepo) Get(ctx context.Context, key string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err } - if err := c.client.Set(c.ctx, key, value, ttl).Err(); err != nil { - return errors.Join(err, fmt.Errorf("redis set key: %s err", key)) + val, err := r.client.Get(ctx, key).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// GetSet 设置新值并返回旧值 +func (r *redisRepo) GetSet(ctx context.Context, key, value string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err } + val, err := r.client.GetSet(ctx, key, value).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// SetNX 仅当key不存在时设置 +func (r *redisRepo) SetNX(ctx context.Context, key, value string, ttl time.Duration) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.SetNX(ctx, key, value, ttl).Result() +} + +// MGet 批量获取 +func (r *redisRepo) MGet(ctx context.Context, keys ...string) ([]interface{}, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.MGet(ctx, keys...).Result() +} + +// MSet 批量设置 +func (r *redisRepo) MSet(ctx context.Context, pairs ...interface{}) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.MSet(ctx, pairs...).Err() +} + +// ========== Key 操作 ========== + +// Del 删除键 +func (r *redisRepo) Del(ctx context.Context, keys ...string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.Del(ctx, keys...).Result() +} + +// Exists 检查键是否存在 +func (r *redisRepo) Exists(ctx context.Context, keys ...string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.Exists(ctx, keys...).Result() +} + +// Expire 设置过期时间 +func (r *redisRepo) Expire(ctx context.Context, key string, ttl time.Duration) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.Expire(ctx, key, ttl).Result() +} + +// ExpireAt 设置过期时间点 +func (r *redisRepo) ExpireAt(ctx context.Context, key string, tm time.Time) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.ExpireAt(ctx, key, tm).Result() +} + +// TTL 获取剩余时间 +func (r *redisRepo) TTL(ctx context.Context, key string) (time.Duration, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.TTL(ctx, key).Result() +} + +// Rename 重命名键 +func (r *redisRepo) Rename(ctx context.Context, key, newKey string) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.Rename(ctx, key, newKey).Err() +} + +// ========== 计数器操作 ========== + +// Incr 自增1 +func (r *redisRepo) Incr(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.Incr(ctx, key).Result() +} + +// IncrBy 自增指定值 +func (r *redisRepo) IncrBy(ctx context.Context, key string, value int64) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.IncrBy(ctx, key, value).Result() +} + +// Decr 自减1 +func (r *redisRepo) Decr(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.Decr(ctx, key).Result() +} + +// DecrBy 自减指定值 +func (r *redisRepo) DecrBy(ctx context.Context, key string, value int64) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.DecrBy(ctx, key, value).Result() +} + +// ========== Hash 操作 ========== + +// HSet 设置hash字段 +func (r *redisRepo) HSet(ctx context.Context, key string, values ...interface{}) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.HSet(ctx, key, values...).Err() +} + +// HGet 获取hash字段 +func (r *redisRepo) HGet(ctx context.Context, key, field string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err + } + + val, err := r.client.HGet(ctx, key, field).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// HGetAll 获取所有hash字段 +func (r *redisRepo) HGetAll(ctx context.Context, key string) (map[string]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.HGetAll(ctx, key).Result() +} + +// HDel 删除hash字段 +func (r *redisRepo) HDel(ctx context.Context, key string, fields ...string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.HDel(ctx, key, fields...).Result() +} + +// HExists 检查hash字段是否存在 +func (r *redisRepo) HExists(ctx context.Context, key, field string) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.HExists(ctx, key, field).Result() +} + +// HIncrBy hash字段自增 +func (r *redisRepo) HIncrBy(ctx context.Context, key, field string, incr int64) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.HIncrBy(ctx, key, field, incr).Result() +} + +// HIncrByFloat hash字段浮点自增 +func (r *redisRepo) HIncrByFloat(ctx context.Context, key, field string, incr float64) (float64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.HIncrByFloat(ctx, key, field, incr).Result() +} + +// HKeys 获取所有hash字段名 +func (r *redisRepo) HKeys(ctx context.Context, key string) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.HKeys(ctx, key).Result() +} + +// HLen 获取hash字段数量 +func (r *redisRepo) HLen(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.HLen(ctx, key).Result() +} + +// ========== List 操作 ========== + +// LPush 从左侧推入 +func (r *redisRepo) LPush(ctx context.Context, key string, values ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.LPush(ctx, key, values...).Result() +} + +// RPush 从右侧推入 +func (r *redisRepo) RPush(ctx context.Context, key string, values ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.RPush(ctx, key, values...).Result() +} + +// LPop 从左侧弹出 +func (r *redisRepo) LPop(ctx context.Context, key string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err + } + + val, err := r.client.LPop(ctx, key).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// RPop 从右侧弹出 +func (r *redisRepo) RPop(ctx context.Context, key string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err + } + + val, err := r.client.RPop(ctx, key).Result() + if err == redis.Nil { + return "", nil + } + return val, err +} + +// BLPop 阻塞式从左侧弹出 +func (r *redisRepo) BLPop(ctx context.Context, timeout time.Duration, keys ...string) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.BLPop(ctx, timeout, keys...).Result() +} + +// BRPop 阻塞式从右侧弹出 +func (r *redisRepo) BRPop(ctx context.Context, timeout time.Duration, keys ...string) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.BRPop(ctx, timeout, keys...).Result() +} + +// LLen 获取列表长度 +func (r *redisRepo) LLen(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.LLen(ctx, key).Result() +} + +// LRange 获取列表范围 +func (r *redisRepo) LRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.LRange(ctx, key, start, stop).Result() +} + +// ========== Set 操作 ========== + +// SAdd 添加集合成员 +func (r *redisRepo) SAdd(ctx context.Context, key string, members ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.SAdd(ctx, key, members...).Result() +} + +// SMembers 获取所有集合成员 +func (r *redisRepo) SMembers(ctx context.Context, key string) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.SMembers(ctx, key).Result() +} + +// SIsMember 检查是否是集合成员 +func (r *redisRepo) SIsMember(ctx context.Context, key string, member interface{}) (bool, error) { + if err := r.checkClosed(); err != nil { + return false, err + } + return r.client.SIsMember(ctx, key, member).Result() +} + +// SRem 删除集合成员 +func (r *redisRepo) SRem(ctx context.Context, key string, members ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.SRem(ctx, key, members...).Result() +} + +// SCard 获取集合成员数量 +func (r *redisRepo) SCard(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.SCard(ctx, key).Result() +} + +// ========== ZSet 操作 ========== + +// ZAdd 添加有序集合成员 +func (r *redisRepo) ZAdd(ctx context.Context, key string, members ...redis.Z) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.ZAdd(ctx, key, members...).Result() +} + +// ZRange 获取有序集合范围 +func (r *redisRepo) ZRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.ZRange(ctx, key, start, stop).Result() +} + +// ZRangeWithScores 获取有序集合范围(带分数) +func (r *redisRepo) ZRangeWithScores(ctx context.Context, key string, start, stop int64) ([]redis.Z, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.ZRangeWithScores(ctx, key, start, stop).Result() +} + +// ZRevRange 倒序获取有序集合范围 +func (r *redisRepo) ZRevRange(ctx context.Context, key string, start, stop int64) ([]string, error) { + if err := r.checkClosed(); err != nil { + return nil, err + } + return r.client.ZRevRange(ctx, key, start, stop).Result() +} + +// ZRem 删除有序集合成员 +func (r *redisRepo) ZRem(ctx context.Context, key string, members ...interface{}) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.ZRem(ctx, key, members...).Result() +} + +// ZCard 获取有序集合成员数量 +func (r *redisRepo) ZCard(ctx context.Context, key string) (int64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.ZCard(ctx, key).Result() +} + +// ZScore 获取成员分数 +func (r *redisRepo) ZScore(ctx context.Context, key, member string) (float64, error) { + if err := r.checkClosed(); err != nil { + return 0, err + } + return r.client.ZScore(ctx, key, member).Result() +} + +// ========== Pipeline 操作 ========== + +// Pipeline 获取pipeline +func (r *redisRepo) Pipeline() redis.Pipeliner { + return r.client.Pipeline() +} + +// TxPipeline 获取事务pipeline +func (r *redisRepo) TxPipeline() redis.Pipeliner { + return r.client.TxPipeline() +} + +// ========== Pub/Sub 操作 ========== + +// Publish 发布消息 +func (r *redisRepo) Publish(ctx context.Context, channel string, message interface{}) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.Publish(ctx, channel, message).Err() +} + +// Subscribe 订阅频道 +func (r *redisRepo) Subscribe(ctx context.Context, channels ...string) *redis.PubSub { + return r.client.Subscribe(ctx, channels...) +} + +// ========== 健康检查 ========== + +// Ping 检查连接 +func (r *redisRepo) Ping(ctx context.Context) error { + if err := r.checkClosed(); err != nil { + return err + } + return r.client.Ping(ctx).Err() +} + +// Info 获取服务器信息 +func (r *redisRepo) Info(ctx context.Context, section ...string) (string, error) { + if err := r.checkClosed(); err != nil { + return "", err + } + return r.client.Info(ctx, section...).Result() +} + +// ========== 连接管理 ========== + +// Close 关闭连接 +func (r *redisRepo) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.closed { + return nil + } + + r.closed = true + return r.client.Close() +} + +// PoolStats 获取连接池统计 +func (r *redisRepo) PoolStats() *redis.PoolStats { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.closed { + return nil + } + + stats := r.client.PoolStats() + return stats +} + +// ========== 内部方法 ========== + +// checkClosed 检查是否已关闭 +func (r *redisRepo) checkClosed() error { + r.mu.RLock() + defer r.mu.RUnlock() + + if r.closed { + return errors.New("redis 连接已关闭") + } return nil } -func (c *cacheRepo) Get(key string, options ...Option) (string, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "get" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) +// mergeDefaultConfig 合并默认配置 +func mergeDefaultConfig(cfg *RedisConfig) { + if cfg.Addr == "" { + cfg.Addr = "localhost:6379" } - - value, err := c.client.Get(c.ctx, key).Result() - if err != nil { - return "", errors.Join(err, fmt.Errorf("redis get key: %s err", key)) + if cfg.PoolSize == 0 { + cfg.PoolSize = 100 + } + if cfg.MinIdleConns == 0 { + cfg.MinIdleConns = 10 + } + if cfg.DialTimeout == 0 { + cfg.DialTimeout = 5 * time.Second + } + if cfg.ReadTimeout == 0 { + cfg.ReadTimeout = 3 * time.Second + } + if cfg.WriteTimeout == 0 { + cfg.WriteTimeout = 3 * time.Second + } + if cfg.PoolTimeout == 0 { + cfg.PoolTimeout = 4 * time.Second + } + if cfg.MaxRetries == 0 { + cfg.MaxRetries = 3 + } + if cfg.MinRetryBackoff == 0 { + cfg.MinRetryBackoff = 8 * time.Millisecond + } + if cfg.MaxRetryBackoff == 0 { + cfg.MaxRetryBackoff = 512 * time.Millisecond + } + if cfg.ConnMaxIdleTime == 0 { + cfg.ConnMaxIdleTime = 5 * time.Minute + } + if cfg.ConnMaxLifetime == 0 { + cfg.ConnMaxLifetime = 30 * time.Minute } - - return value, nil -} - -func (c *cacheRepo) TTL(key string) (time.Duration, error) { - ttl, err := c.client.TTL(c.ctx, key).Result() - if err != nil { - return -1, errors.Join(err, fmt.Errorf("redis get key: %s err", key)) - } - - return ttl, nil -} - -func (c *cacheRepo) Expire(key string, ttl time.Duration) bool { - ok, _ := c.client.Expire(c.ctx, key, ttl).Result() - return ok -} - -func (c *cacheRepo) ExpireAt(key string, ttl time.Time) bool { - ok, _ := c.client.ExpireAt(c.ctx, key, ttl).Result() - return ok -} - -func (c *cacheRepo) Exists(keys ...string) bool { - if len(keys) == 0 { - return true - } - value, _ := c.client.Exists(c.ctx, keys...).Result() - return value > 0 -} - -func (c *cacheRepo) Del(key string, options ...Option) bool { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "del" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - if key == "" { - return true - } - - value, _ := c.client.Del(c.ctx, key).Result() - return value > 0 -} - -func (c *cacheRepo) Incr(key string, options ...Option) (int64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "incr" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - value, err := c.client.Incr(c.ctx, key).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis incr key: %s err", key)) - } - return value, nil -} - -func (c *cacheRepo) Decr(key string, options ...Option) (int64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "decr" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - value, err := c.client.Decr(c.ctx, key).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis decr key: %s err", key)) - } - return value, nil -} - -func (c *cacheRepo) HGet(key, field string, options ...Option) (string, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash get" - opt.Redis.Key = key - opt.Redis.Value = field - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.HGet(c.ctx, key, field).Result() - if err != nil { - return "", errors.Join(err, fmt.Errorf("redis hget key: %s field: %s err", key, field)) - } - - return value, nil -} - -func (c *cacheRepo) HSet(key, field, value string, options ...Option) error { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash set" - opt.Redis.Key = key - opt.Redis.Value = field + "/" + value - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - if err := c.client.HSet(c.ctx, key, field, value).Err(); err != nil { - return errors.Join(err, fmt.Errorf("redis hset key: %s field: %s err", key, field)) - } - - return nil -} - -func (c *cacheRepo) HDel(key, field string, options ...Option) error { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash del" - opt.Redis.Key = key - opt.Redis.Value = field - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - if err := c.client.HDel(c.ctx, key, field).Err(); err != nil { - return errors.Join(err, fmt.Errorf("redis hdel key: %s field: %s err", key, field)) - } - - return nil -} - -func (c *cacheRepo) HGetAll(key string, options ...Option) (map[string]string, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash get all" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.HGetAll(c.ctx, key).Result() - if err != nil { - return nil, errors.Join(err, fmt.Errorf("redis hget all key: %s err", key)) - } - - return value, nil -} - -func (c *cacheRepo) HIncrBy(key, field string, incr int64, options ...Option) (int64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash incr int64" - opt.Redis.Key = key - opt.Redis.Value = fmt.Sprintf("field:%s incr:%d", field, incr) - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.HIncrBy(c.ctx, key, field, incr).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis hash incr int64 key: %s err", key)) - } - - return value, nil -} - -func (c *cacheRepo) HIncrByFloat(key, field string, incr float64, options ...Option) (float64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "hash incr float64" - opt.Redis.Key = key - opt.Redis.Value = fmt.Sprintf("field:%s incr:%d", field, incr) - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.HIncrByFloat(c.ctx, key, field, incr).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis hash incr float64 key: %s err", key)) - } - - return value, nil -} - -func (c *cacheRepo) LPush(key, value string, options ...Option) error { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "list push" - opt.Redis.Key = key - opt.Redis.Value = value - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - _, err := c.client.LPush(c.ctx, key, value).Result() - if err != nil { - return errors.Join(err, fmt.Errorf("redis list push key: %s value: %s err", key, value)) - } - - return nil -} - -func (c *cacheRepo) LLen(key string, options ...Option) (int64, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "list len" - opt.Redis.Key = key - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.LLen(c.ctx, key).Result() - if err != nil { - return 0, errors.Join(err, fmt.Errorf("redis list len key: %s err", key)) - } - - return value, nil -} - -func (c *cacheRepo) BRPop(key string, timeout time.Duration, options ...Option) (string, error) { - ts := time.Now() - opt := newOption() - defer func() { - if opt.Trace != nil { - opt.Redis.Timestamp = time_parse.CSTLayoutString() - opt.Redis.Handle = "list brpop" - opt.Redis.Key = key - opt.Redis.TTL = timeout.Seconds() - opt.Redis.CostSeconds = time.Since(ts).Seconds() - opt.Trace.AppendRedis(opt.Redis) - } - }() - - for _, f := range options { - f(opt) - } - - value, err := c.client.BRPop(c.ctx, timeout, key).Result() - if err != nil { - return "", errors.Join(err, fmt.Errorf("redis list len key: %s err", key)) - } - - return value[1], nil -} - -func (c *cacheRepo) Close() error { - return c.client.Close() } diff --git a/pkg/crypto/client/axios/encryptedAxios.ts b/pkg/crypto/client/axios/encryptedAxios.ts new file mode 100644 index 0000000..ab0c922 --- /dev/null +++ b/pkg/crypto/client/axios/encryptedAxios.ts @@ -0,0 +1,202 @@ +import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse, InternalAxiosRequestConfig } from 'axios'; +import { IEncryptor, ISigner, CryptoConfig, EncryptedRequest, EncryptedResponse } from '../crypto/interface'; +import { uint8ArrayToBase64, base64ToUint8Array, stringToUint8Array, uint8ArrayToString } from '../utils/base64'; +import { generateUUID } from '../utils/uuid'; + +/** + * 加密Axios实例 + */ +export class EncryptedAxios { + private axiosInstance: AxiosInstance; + private encryptor: IEncryptor | null = null; + private signer: ISigner | null = null; + private config: CryptoConfig; + + constructor( + encryptor?: IEncryptor, + signer?: ISigner, + config: CryptoConfig = {}, + axiosConfig?: AxiosRequestConfig + ) { + this.encryptor = encryptor || null; + this.signer = signer || null; + this.config = { + timestampWindow: 5 * 60 * 1000, // 默认5分钟 + enableTimestamp: true, + enableSignature: true, + ...config, + }; + + // 创建axios实例 + this.axiosInstance = axios.create(axiosConfig); + + // 添加请求拦截器 + this.axiosInstance.interceptors.request.use( + this.encryptRequestInterceptor.bind(this), + (error) => Promise.reject(error) + ); + + // 添加响应拦截器 + this.axiosInstance.interceptors.response.use( + this.decryptResponseInterceptor.bind(this), + (error) => Promise.reject(error) + ); + } + + /** + * 请求拦截器 - 加密请求数据 + */ + private async encryptRequestInterceptor( + config: InternalAxiosRequestConfig + ): Promise { + // 放行GET和OPTIONS请求 + if (config.method?.toUpperCase() === 'GET' || config.method?.toUpperCase() === 'OPTIONS') { + return config; + } + + // 如果没有配置加密器,直接返回 + if (!this.encryptor) { + return config; + } + + try { + // 将请求数据转换为JSON字符串 + const plaintext = JSON.stringify(config.data || {}); + const plaintextBytes = stringToUint8Array(plaintext); + + // 加密数据 + const ciphertext = await this.encryptor.encrypt(plaintextBytes); + const encryptedData = uint8ArrayToBase64(ciphertext); + + // 构建加密请求体 + const encryptedRequest: EncryptedRequest = { + data: encryptedData, + timestamp: Date.now(), + request_id: generateUUID(), + algorithm: this.encryptor.name(), + }; + + // 生成签名 + if (this.config.enableSignature && this.signer) { + const signature = await this.signer.sign(plaintextBytes); + encryptedRequest.signature = uint8ArrayToBase64(signature); + } + + // 替换请求数据 + config.data = encryptedRequest; + config.headers['Content-Type'] = 'application/json'; + + // 保存request_id供响应使用 + config.headers['X-Request-ID'] = encryptedRequest.request_id; + + } catch (error) { + console.error('加密请求失败:', error); + throw error; + } + + return config; + } + + /** + * 响应拦截器 - 解密响应数据 + */ + private async decryptResponseInterceptor( + response: AxiosResponse + ): Promise { + // 如果没有配置加密器或响应不是加密格式,直接返回 + if (!this.encryptor || !response.data || typeof response.data !== 'object') { + return response; + } + + // 检查是否是加密响应 + const encryptedResponse = response.data as EncryptedResponse; + if (!encryptedResponse.data || !encryptedResponse.request_id) { + // 不是加密响应,直接返回 + return response; + } + + try { + // 验证时间戳 + if (this.config.enableTimestamp) { + this.verifyTimestamp(encryptedResponse.timestamp); + } + + // 解密数据 + const ciphertext = base64ToUint8Array(encryptedResponse.data); + const plaintext = await this.encryptor.decrypt(ciphertext); + + // 验证签名 + if (this.config.enableSignature && this.signer && encryptedResponse.signature) { + const signature = base64ToUint8Array(encryptedResponse.signature); + const isValid = await this.signer.verify(plaintext, signature); + if (!isValid) { + throw new Error('签名验证失败'); + } + } + + // 将解密后的数据转换为JSON对象 + const decryptedData = uint8ArrayToString(plaintext); + response.data = JSON.parse(decryptedData); + + } catch (error) { + console.error('解密响应失败:', error); + throw error; + } + + return response; + } + + /** + * 验证时间戳 + */ + private verifyTimestamp(timestamp: number): void { + const now = Date.now(); + const diff = Math.abs(now - timestamp); + + if (diff > (this.config.timestampWindow || 5 * 60 * 1000)) { + throw new Error('请求超时'); + } + } + + /** + * 获取axios实例 + */ + getInstance(): AxiosInstance { + return this.axiosInstance; + } + + /** + * GET 请求 + */ + get(url: string, config?: AxiosRequestConfig): Promise> { + return this.axiosInstance.get(url, config); + } + + /** + * POST 请求 + */ + post(url: string, data?: any, config?: AxiosRequestConfig): Promise> { + return this.axiosInstance.post(url, data, config); + } + + /** + * PUT 请求 + */ + put(url: string, data?: any, config?: AxiosRequestConfig): Promise> { + return this.axiosInstance.put(url, data, config); + } + + /** + * DELETE 请求 + */ + delete(url: string, config?: AxiosRequestConfig): Promise> { + return this.axiosInstance.delete(url, config); + } + + /** + * PATCH 请求 + */ + patch(url: string, data?: any, config?: AxiosRequestConfig): Promise> { + return this.axiosInstance.patch(url, data, config); + } +} diff --git a/pkg/crypto/client/crypto/aes.ts b/pkg/crypto/client/crypto/aes.ts new file mode 100644 index 0000000..5fad134 --- /dev/null +++ b/pkg/crypto/client/crypto/aes.ts @@ -0,0 +1,90 @@ +import { IEncryptor } from './interface'; + +/** + * AES-GCM加密器 + */ +export class AESEncryptor implements IEncryptor { + private key: CryptoKey | null = null; + + constructor(keyString: string) { + this.importKey(keyString); + } + + /** + * 导入密钥 + */ + private async importKey(keyString: string): Promise { + const encoder = new TextEncoder(); + const keyData = encoder.encode(keyString.padEnd(32, '0').substring(0, 32)); + + this.key = await crypto.subtle.importKey( + 'raw', + keyData, + { + name: 'AES-GCM', + length: 256, + }, + false, + ['encrypt', 'decrypt'] + ); + } + + /** + * 加密数据 + */ + async encrypt(plaintext: Uint8Array): Promise { + if (!this.key) { + throw new Error('密钥未设置'); + } + + // 生成随机IV + const iv = crypto.getRandomValues(new Uint8Array(12)); + + const encrypted = await crypto.subtle.encrypt( + { + name: 'AES-GCM', + iv: iv, + }, + this.key, + plaintext + ); + + // 将IV和密文拼接在一起 + const result = new Uint8Array(iv.length + encrypted.byteLength); + result.set(iv, 0); + result.set(new Uint8Array(encrypted), iv.length); + + return result; + } + + /** + * 解密数据 + */ + async decrypt(ciphertext: Uint8Array): Promise { + if (!this.key) { + throw new Error('密钥未设置'); + } + + // 提取IV + const iv = ciphertext.slice(0, 12); + const data = ciphertext.slice(12); + + const decrypted = await crypto.subtle.decrypt( + { + name: 'AES-GCM', + iv: iv, + }, + this.key, + data + ); + + return new Uint8Array(decrypted); + } + + /** + * 返回算法名称 + */ + name(): string { + return 'AES-GCM-256'; + } +} diff --git a/pkg/crypto/client/crypto/hmac.ts b/pkg/crypto/client/crypto/hmac.ts new file mode 100644 index 0000000..ab1b8aa --- /dev/null +++ b/pkg/crypto/client/crypto/hmac.ts @@ -0,0 +1,54 @@ +import { ISigner } from './interface'; + +/** + * HMAC签名器 + */ +export class HMACSigner implements ISigner { + private key: CryptoKey | null = null; + + constructor(keyString: string) { + this.importKey(keyString); + } + + /** + * 导入密钥 + */ + private async importKey(keyString: string): Promise { + const encoder = new TextEncoder(); + const keyData = encoder.encode(keyString); + + this.key = await crypto.subtle.importKey( + 'raw', + keyData, + { + name: 'HMAC', + hash: 'SHA-256', + }, + false, + ['sign', 'verify'] + ); + } + + /** + * 生成签名 + */ + async sign(data: Uint8Array): Promise { + if (!this.key) { + throw new Error('密钥未设置'); + } + + const signature = await crypto.subtle.sign('HMAC', this.key, data); + return new Uint8Array(signature); + } + + /** + * 验证签名 + */ + async verify(data: Uint8Array, signature: Uint8Array): Promise { + if (!this.key) { + throw new Error('密钥未设置'); + } + + return await crypto.subtle.verify('HMAC', this.key, signature, data); + } +} diff --git a/pkg/crypto/client/crypto/interface.ts b/pkg/crypto/client/crypto/interface.ts new file mode 100644 index 0000000..5ae1cff --- /dev/null +++ b/pkg/crypto/client/crypto/interface.ts @@ -0,0 +1,50 @@ +/** + * 加密器接口 + */ +export interface IEncryptor { + encrypt(plaintext: Uint8Array): Promise; + decrypt(ciphertext: Uint8Array): Promise; + name(): string; +} + +/** + * 签名器接口 + */ +export interface ISigner { + sign(data: Uint8Array): Promise; + verify(data: Uint8Array, signature: Uint8Array): Promise; +} + +/** + * 配置选项 + */ +export interface CryptoConfig { + secretKey?: string; // 对称加密密钥 + signKey?: string; // 签名密钥 + publicKey?: string; // RSA公钥(PEM格式) + privateKey?: string; // RSA私钥(PEM格式) + timestampWindow?: number; // 时间戳窗口(毫秒) + enableTimestamp?: boolean; // 是否启用时间戳验证 + enableSignature?: boolean; // 是否启用签名 +} + +/** + * 加密请求体 + */ +export interface EncryptedRequest { + data: string; // Base64编码的加密数据 + signature?: string; // Base64编码的签名 + timestamp: number; // 时间戳 + request_id: string; // 请求ID + algorithm: string; // 加密算法名称 +} + +/** + * 加密响应体 + */ +export interface EncryptedResponse { + data: string; // Base64编码的加密数据 + signature?: string; // Base64编码的签名 + timestamp: number; // 时间戳 + request_id: string; // 请求ID +} diff --git a/pkg/crypto/client/crypto/rsa.ts b/pkg/crypto/client/crypto/rsa.ts new file mode 100644 index 0000000..80c6d42 --- /dev/null +++ b/pkg/crypto/client/crypto/rsa.ts @@ -0,0 +1,125 @@ +import { IEncryptor } from './interface'; + +/** + * RSA加密器(使用Web Crypto API) + */ +export class RSAEncryptor implements IEncryptor { + private publicKey: CryptoKey | null = null; + private privateKey: CryptoKey | null = null; + + constructor(publicKeyPEM?: string, privateKeyPEM?: string) { + if (publicKeyPEM) { + this.importPublicKey(publicKeyPEM); + } + if (privateKeyPEM) { + this.importPrivateKey(privateKeyPEM); + } + } + + /** + * 导入公钥(PEM格式) + */ + private async importPublicKey(pem: string): Promise { + const pemHeader = '-----BEGIN PUBLIC KEY-----'; + const pemFooter = '-----END PUBLIC KEY-----'; + const pemContents = pem + .replace(pemHeader, '') + .replace(pemFooter, '') + .replace(/\s/g, ''); + + const binaryDer = this.base64ToArrayBuffer(pemContents); + + this.publicKey = await crypto.subtle.importKey( + 'spki', + binaryDer, + { + name: 'RSA-OAEP', + hash: 'SHA-256', + }, + true, + ['encrypt'] + ); + } + + /** + * 导入私钥(PEM格式) + */ + private async importPrivateKey(pem: string): Promise { + const pemHeader = '-----BEGIN PRIVATE KEY-----'; + const pemFooter = '-----END PRIVATE KEY-----'; + const pemContents = pem + .replace(pemHeader, '') + .replace(pemFooter, '') + .replace(/\s/g, ''); + + const binaryDer = this.base64ToArrayBuffer(pemContents); + + this.privateKey = await crypto.subtle.importKey( + 'pkcs8', + binaryDer, + { + name: 'RSA-OAEP', + hash: 'SHA-256', + }, + true, + ['decrypt'] + ); + } + + /** + * Base64转ArrayBuffer + */ + private base64ToArrayBuffer(base64: string): ArrayBuffer { + const binaryString = atob(base64); + const bytes = new Uint8Array(binaryString.length); + for (let i = 0; i < binaryString.length; i++) { + bytes[i] = binaryString.charCodeAt(i); + } + return bytes.buffer; + } + + /** + * 加密数据 + */ + async encrypt(plaintext: Uint8Array): Promise { + if (!this.publicKey) { + throw new Error('公钥未设置'); + } + + const encrypted = await crypto.subtle.encrypt( + { + name: 'RSA-OAEP', + }, + this.publicKey, + plaintext + ); + + return new Uint8Array(encrypted); + } + + /** + * 解密数据 + */ + async decrypt(ciphertext: Uint8Array): Promise { + if (!this.privateKey) { + throw new Error('私钥未设置'); + } + + const decrypted = await crypto.subtle.decrypt( + { + name: 'RSA-OAEP', + }, + this.privateKey, + ciphertext + ); + + return new Uint8Array(decrypted); + } + + /** + * 返回算法名称 + */ + name(): string { + return 'RSA-OAEP-SHA256'; + } +} diff --git a/pkg/crypto/client/index.ts b/pkg/crypto/client/index.ts new file mode 100644 index 0000000..d70ce1f --- /dev/null +++ b/pkg/crypto/client/index.ts @@ -0,0 +1,7 @@ +export * from './crypto/interface'; +export * from './crypto/rsa'; +export * from './crypto/hmac'; +export * from './crypto/aes'; +export * from './axios/encryptedAxios'; +export * from './utils/base64'; +export * from './utils/uuid'; diff --git a/pkg/crypto/client/utils/base64.ts b/pkg/crypto/client/utils/base64.ts new file mode 100644 index 0000000..2e9108d --- /dev/null +++ b/pkg/crypto/client/utils/base64.ts @@ -0,0 +1,38 @@ +/** + * Uint8Array 转 Base64 + */ +export function uint8ArrayToBase64(bytes: Uint8Array): string { + let binary = ''; + for (let i = 0; i < bytes.length; i++) { + binary += String.fromCharCode(bytes[i]); + } + return btoa(binary); +} + +/** + * Base64 转 Uint8Array + */ +export function base64ToUint8Array(base64: string): Uint8Array { + const binary = atob(base64); + const bytes = new Uint8Array(binary.length); + for (let i = 0; i < binary.length; i++) { + bytes[i] = binary.charCodeAt(i); + } + return bytes; +} + +/** + * 字符串转 Uint8Array + */ +export function stringToUint8Array(str: string): Uint8Array { + const encoder = new TextEncoder(); + return encoder.encode(str); +} + +/** + * Uint8Array 转字符串 + */ +export function uint8ArrayToString(bytes: Uint8Array): string { + const decoder = new TextDecoder(); + return decoder.decode(bytes); +} diff --git a/pkg/crypto/client/utils/uuid.ts b/pkg/crypto/client/utils/uuid.ts new file mode 100644 index 0000000..fff1fa6 --- /dev/null +++ b/pkg/crypto/client/utils/uuid.ts @@ -0,0 +1,10 @@ +/** + * 生成UUID v4 + */ +export function generateUUID(): string { + return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, (c) => { + const r = (Math.random() * 16) | 0; + const v = c === 'x' ? r : (r & 0x3) | 0x8; + return v.toString(16); + }); +}