package cache import ( "context" "encoding/json" "fmt" "strings" "time" "github.com/redis/go-redis/v9" "go.uber.org/zap" "tyapi-server/internal/shared/interfaces" ) // RedisCache Redis缓存实现 type RedisCache struct { client *redis.Client logger *zap.Logger prefix string // 统计信息 hits int64 misses int64 } // NewRedisCache 创建Redis缓存实例 func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache { return &RedisCache{ client: client, logger: logger, prefix: prefix, } } // Name 返回服务名称 func (r *RedisCache) Name() string { return "redis-cache" } // Initialize 初始化服务 func (r *RedisCache) Initialize(ctx context.Context) error { // 测试连接 _, err := r.client.Ping(ctx).Result() if err != nil { r.logger.Error("Failed to connect to Redis", zap.Error(err)) return fmt.Errorf("redis connection failed: %w", err) } r.logger.Info("Redis cache service initialized") return nil } // HealthCheck 健康检查 func (r *RedisCache) HealthCheck(ctx context.Context) error { _, err := r.client.Ping(ctx).Result() return err } // Shutdown 关闭服务 func (r *RedisCache) Shutdown(ctx context.Context) error { return r.client.Close() } // Get 获取缓存值 func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error { fullKey := r.getFullKey(key) val, err := r.client.Get(ctx, fullKey).Result() if err != nil { if err == redis.Nil { r.misses++ return fmt.Errorf("cache miss: key %s not found", key) } r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err)) return err } r.hits++ return json.Unmarshal([]byte(val), dest) } // Set 设置缓存值 func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error { fullKey := r.getFullKey(key) data, err := json.Marshal(value) if err != nil { r.logger.Error("序列化缓存数据失败", zap.String("key", key), zap.Error(err)) return fmt.Errorf("failed to marshal value: %w", err) } var expiration time.Duration if len(ttl) > 0 { switch v := ttl[0].(type) { case time.Duration: expiration = v case int: expiration = time.Duration(v) * time.Second case string: expiration, _ = time.ParseDuration(v) default: expiration = 24 * time.Hour // 默认24小时 } } else { expiration = 24 * time.Hour // 默认24小时 } err = r.client.Set(ctx, fullKey, data, expiration).Err() if err != nil { r.logger.Error("设置缓存失败", zap.String("key", key), zap.Error(err)) return err } r.logger.Debug("设置缓存成功", zap.String("key", key), zap.Duration("ttl", expiration)) return nil } // Delete 删除缓存 func (r *RedisCache) Delete(ctx context.Context, keys ...string) error { if len(keys) == 0 { return nil } fullKeys := make([]string, len(keys)) for i, key := range keys { fullKeys[i] = r.getFullKey(key) } err := r.client.Del(ctx, fullKeys...).Err() if err != nil { r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err)) return err } return nil } // Exists 检查键是否存在 func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) { fullKey := r.getFullKey(key) count, err := r.client.Exists(ctx, fullKey).Result() if err != nil { return false, err } return count > 0, nil } // GetMultiple 批量获取 func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) { if len(keys) == 0 { return make(map[string]interface{}), nil } fullKeys := make([]string, len(keys)) for i, key := range keys { fullKeys[i] = r.getFullKey(key) } values, err := r.client.MGet(ctx, fullKeys...).Result() if err != nil { r.logger.Error("批量获取缓存失败", zap.Strings("keys", keys), zap.Error(err)) return nil, err } result := make(map[string]interface{}) for i, val := range values { if val != nil { var data interface{} // 修复:改进JSON反序列化错误处理 if err := json.Unmarshal([]byte(val.(string)), &data); err != nil { r.logger.Warn("反序列化缓存数据失败", zap.String("key", keys[i]), zap.String("value", val.(string)), zap.Error(err)) continue } result[keys[i]] = data } } return result, nil } // SetMultiple 批量设置 func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error { if len(data) == 0 { return nil } var expiration time.Duration if len(ttl) > 0 { switch v := ttl[0].(type) { case time.Duration: expiration = v case int: expiration = time.Duration(v) * time.Second default: expiration = 24 * time.Hour } } else { expiration = 24 * time.Hour } pipe := r.client.Pipeline() for key, value := range data { fullKey := r.getFullKey(key) jsonData, err := json.Marshal(value) if err != nil { continue } pipe.Set(ctx, fullKey, jsonData, expiration) } _, err := pipe.Exec(ctx) return err } // DeletePattern 按模式删除 func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error { // 修复:避免重复添加前缀 var fullPattern string if strings.HasPrefix(pattern, r.prefix+":") { fullPattern = pattern } else { fullPattern = r.getFullKey(pattern) } // 检查上下文是否已取消 if ctx.Err() != nil { return ctx.Err() } var cursor uint64 var totalDeleted int64 maxIterations := 100 // 防止无限循环 iteration := 0 for { // 检查迭代次数限制 iteration++ if iteration > maxIterations { r.logger.Warn("缓存删除操作达到最大迭代次数限制", zap.String("pattern", fullPattern), zap.Int("max_iterations", maxIterations), zap.Int64("total_deleted", totalDeleted), ) break } // 检查上下文是否已取消 if ctx.Err() != nil { r.logger.Warn("缓存删除操作被取消", zap.String("pattern", fullPattern), zap.Int64("total_deleted", totalDeleted), zap.Error(ctx.Err()), ) return ctx.Err() } // 执行SCAN操作 keys, next, err := r.client.Scan(ctx, cursor, fullPattern, 1000).Result() if err != nil { // 如果是上下文取消错误,直接返回 if err == context.Canceled || err == context.DeadlineExceeded { r.logger.Warn("缓存删除操作被取消", zap.String("pattern", fullPattern), zap.Int64("total_deleted", totalDeleted), zap.Error(err), ) return err } r.logger.Error("扫描缓存键失败", zap.String("pattern", fullPattern), zap.Error(err)) return err } // 批量删除找到的键 if len(keys) > 0 { // 使用pipeline批量删除,提高性能 pipe := r.client.Pipeline() pipe.Del(ctx, keys...) cmds, err := pipe.Exec(ctx) if err != nil { r.logger.Error("批量删除缓存键失败", zap.Strings("keys", keys), zap.Error(err)) return err } // 统计删除的键数量 for _, cmd := range cmds { if delCmd, ok := cmd.(*redis.IntCmd); ok { if deleted, err := delCmd.Result(); err == nil { totalDeleted += deleted } } } r.logger.Debug("批量删除缓存键", zap.Strings("keys", keys), zap.Int("batch_size", len(keys)), zap.Int64("total_deleted", totalDeleted), ) } cursor = next if cursor == 0 { break } } r.logger.Debug("缓存模式删除完成", zap.String("pattern", fullPattern), zap.Int64("total_deleted", totalDeleted), zap.Int("iterations", iteration), ) return nil } // Keys 获取匹配的键 func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) { fullPattern := r.getFullKey(pattern) keys, err := r.client.Keys(ctx, fullPattern).Result() if err != nil { return nil, err } // 移除前缀 result := make([]string, len(keys)) prefixLen := len(r.prefix) + 1 // +1 for ":" for i, key := range keys { if len(key) > prefixLen { result[i] = key[prefixLen:] } else { result[i] = key } } return result, nil } // Stats 获取缓存统计 func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) { dbSize, _ := r.client.DBSize(ctx).Result() return interfaces.CacheStats{ Hits: r.hits, Misses: r.misses, Keys: dbSize, Memory: 0, // 暂时设为0,后续可解析Redis info Connections: 0, // 暂时设为0,后续可解析Redis info }, nil } // getFullKey 获取完整键名 func (r *RedisCache) getFullKey(key string) string { if r.prefix == "" { return key } return fmt.Sprintf("%s:%s", r.prefix, key) } // Flush 清空所有缓存 func (r *RedisCache) Flush(ctx context.Context) error { if r.prefix == "" { return r.client.FlushDB(ctx).Err() } // 只删除带前缀的键 return r.DeletePattern(ctx, "*") } // GetClient 获取原始Redis客户端 func (r *RedisCache) GetClient() *redis.Client { return r.client }