385 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			385 lines
		
	
	
		
			8.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package cache
 | ||
| 
 | ||
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"

	"tyapi-server/internal/shared/interfaces"
)
 | ||
| 
 | ||
| // RedisCache Redis缓存实现
 | ||
| type RedisCache struct {
 | ||
| 	client *redis.Client
 | ||
| 	logger *zap.Logger
 | ||
| 	prefix string
 | ||
| 
 | ||
| 	// 统计信息
 | ||
| 	hits   int64
 | ||
| 	misses int64
 | ||
| }
 | ||
| 
 | ||
| // NewRedisCache 创建Redis缓存实例
 | ||
| func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
 | ||
| 	return &RedisCache{
 | ||
| 		client: client,
 | ||
| 		logger: logger,
 | ||
| 		prefix: prefix,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // Name 返回服务名称
 | ||
| func (r *RedisCache) Name() string {
 | ||
| 	return "redis-cache"
 | ||
| }
 | ||
| 
 | ||
| // Initialize 初始化服务
 | ||
| func (r *RedisCache) Initialize(ctx context.Context) error {
 | ||
| 	// 测试连接
 | ||
| 	_, err := r.client.Ping(ctx).Result()
 | ||
| 	if err != nil {
 | ||
| 		r.logger.Error("Failed to connect to Redis", zap.Error(err))
 | ||
| 		return fmt.Errorf("redis connection failed: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	r.logger.Info("Redis cache service initialized")
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // HealthCheck 健康检查
 | ||
| func (r *RedisCache) HealthCheck(ctx context.Context) error {
 | ||
| 	_, err := r.client.Ping(ctx).Result()
 | ||
| 	return err
 | ||
| }
 | ||
| 
 | ||
| // Shutdown 关闭服务
 | ||
| func (r *RedisCache) Shutdown(ctx context.Context) error {
 | ||
| 	return r.client.Close()
 | ||
| }
 | ||
| 
 | ||
| // Get 获取缓存值
 | ||
| func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
 | ||
| 	fullKey := r.getFullKey(key)
 | ||
| 
 | ||
| 	val, err := r.client.Get(ctx, fullKey).Result()
 | ||
| 	if err != nil {
 | ||
| 		if err == redis.Nil {
 | ||
| 			r.misses++
 | ||
| 			return fmt.Errorf("cache miss: key %s not found", key)
 | ||
| 		}
 | ||
| 		r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	r.hits++
 | ||
| 	return json.Unmarshal([]byte(val), dest)
 | ||
| }
 | ||
| 
 | ||
| // Set 设置缓存值
 | ||
| func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
 | ||
| 	fullKey := r.getFullKey(key)
 | ||
| 
 | ||
| 	data, err := json.Marshal(value)
 | ||
| 	if err != nil {
 | ||
| 		r.logger.Error("序列化缓存数据失败", zap.String("key", key), zap.Error(err))
 | ||
| 		return fmt.Errorf("failed to marshal value: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	var expiration time.Duration
 | ||
| 	if len(ttl) > 0 {
 | ||
| 		switch v := ttl[0].(type) {
 | ||
| 		case time.Duration:
 | ||
| 			expiration = v
 | ||
| 		case int:
 | ||
| 			expiration = time.Duration(v) * time.Second
 | ||
| 		case string:
 | ||
| 			expiration, _ = time.ParseDuration(v)
 | ||
| 		default:
 | ||
| 			expiration = 24 * time.Hour // 默认24小时
 | ||
| 		}
 | ||
| 	} else {
 | ||
| 		expiration = 24 * time.Hour // 默认24小时
 | ||
| 	}
 | ||
| 
 | ||
| 	err = r.client.Set(ctx, fullKey, data, expiration).Err()
 | ||
| 	if err != nil {
 | ||
| 		r.logger.Error("设置缓存失败", zap.String("key", key), zap.Error(err))
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	r.logger.Debug("设置缓存成功", zap.String("key", key), zap.Duration("ttl", expiration))
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // Delete 删除缓存
 | ||
| func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
 | ||
| 	if len(keys) == 0 {
 | ||
| 		return nil
 | ||
| 	}
 | ||
| 
 | ||
| 	fullKeys := make([]string, len(keys))
 | ||
| 	for i, key := range keys {
 | ||
| 		fullKeys[i] = r.getFullKey(key)
 | ||
| 	}
 | ||
| 
 | ||
| 	err := r.client.Del(ctx, fullKeys...).Err()
 | ||
| 	if err != nil {
 | ||
| 		r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // Exists 检查键是否存在
 | ||
| func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
 | ||
| 	fullKey := r.getFullKey(key)
 | ||
| 
 | ||
| 	count, err := r.client.Exists(ctx, fullKey).Result()
 | ||
| 	if err != nil {
 | ||
| 		return false, err
 | ||
| 	}
 | ||
| 
 | ||
| 	return count > 0, nil
 | ||
| }
 | ||
| 
 | ||
| // GetMultiple 批量获取
 | ||
| func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
 | ||
| 	if len(keys) == 0 {
 | ||
| 		return make(map[string]interface{}), nil
 | ||
| 	}
 | ||
| 
 | ||
| 	fullKeys := make([]string, len(keys))
 | ||
| 	for i, key := range keys {
 | ||
| 		fullKeys[i] = r.getFullKey(key)
 | ||
| 	}
 | ||
| 
 | ||
| 	values, err := r.client.MGet(ctx, fullKeys...).Result()
 | ||
| 	if err != nil {
 | ||
| 		r.logger.Error("批量获取缓存失败", zap.Strings("keys", keys), zap.Error(err))
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 
 | ||
| 	result := make(map[string]interface{})
 | ||
| 	for i, val := range values {
 | ||
| 		if val != nil {
 | ||
| 			var data interface{}
 | ||
| 			// 修复:改进JSON反序列化错误处理
 | ||
| 			if err := json.Unmarshal([]byte(val.(string)), &data); err != nil {
 | ||
| 				r.logger.Warn("反序列化缓存数据失败", 
 | ||
| 					zap.String("key", keys[i]), 
 | ||
| 					zap.String("value", val.(string)), 
 | ||
| 					zap.Error(err))
 | ||
| 				continue
 | ||
| 			}
 | ||
| 			result[keys[i]] = data
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	return result, nil
 | ||
| }
 | ||
| 
 | ||
| // SetMultiple 批量设置
 | ||
| func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
 | ||
| 	if len(data) == 0 {
 | ||
| 		return nil
 | ||
| 	}
 | ||
| 
 | ||
| 	var expiration time.Duration
 | ||
| 	if len(ttl) > 0 {
 | ||
| 		switch v := ttl[0].(type) {
 | ||
| 		case time.Duration:
 | ||
| 			expiration = v
 | ||
| 		case int:
 | ||
| 			expiration = time.Duration(v) * time.Second
 | ||
| 		default:
 | ||
| 			expiration = 24 * time.Hour
 | ||
| 		}
 | ||
| 	} else {
 | ||
| 		expiration = 24 * time.Hour
 | ||
| 	}
 | ||
| 
 | ||
| 	pipe := r.client.Pipeline()
 | ||
| 	for key, value := range data {
 | ||
| 		fullKey := r.getFullKey(key)
 | ||
| 		jsonData, err := json.Marshal(value)
 | ||
| 		if err != nil {
 | ||
| 			continue
 | ||
| 		}
 | ||
| 		pipe.Set(ctx, fullKey, jsonData, expiration)
 | ||
| 	}
 | ||
| 
 | ||
| 	_, err := pipe.Exec(ctx)
 | ||
| 	return err
 | ||
| }
 | ||
| 
 | ||
| // DeletePattern 按模式删除
 | ||
| func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
 | ||
| 	// 修复:避免重复添加前缀
 | ||
| 	var fullPattern string
 | ||
| 	if strings.HasPrefix(pattern, r.prefix+":") {
 | ||
| 		fullPattern = pattern
 | ||
| 	} else {
 | ||
| 		fullPattern = r.getFullKey(pattern)
 | ||
| 	}
 | ||
| 	
 | ||
| 	// 检查上下文是否已取消
 | ||
| 	if ctx.Err() != nil {
 | ||
| 		return ctx.Err()
 | ||
| 	}
 | ||
| 	
 | ||
| 	var cursor uint64
 | ||
| 	var totalDeleted int64
 | ||
| 	maxIterations := 100 // 防止无限循环
 | ||
| 	iteration := 0
 | ||
| 	
 | ||
| 	for {
 | ||
| 		// 检查迭代次数限制
 | ||
| 		iteration++
 | ||
| 		if iteration > maxIterations {
 | ||
| 			r.logger.Warn("缓存删除操作达到最大迭代次数限制",
 | ||
| 				zap.String("pattern", fullPattern),
 | ||
| 				zap.Int("max_iterations", maxIterations),
 | ||
| 				zap.Int64("total_deleted", totalDeleted),
 | ||
| 			)
 | ||
| 			break
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 检查上下文是否已取消
 | ||
| 		if ctx.Err() != nil {
 | ||
| 			r.logger.Warn("缓存删除操作被取消",
 | ||
| 				zap.String("pattern", fullPattern),
 | ||
| 				zap.Int64("total_deleted", totalDeleted),
 | ||
| 				zap.Error(ctx.Err()),
 | ||
| 			)
 | ||
| 			return ctx.Err()
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 执行SCAN操作
 | ||
| 		keys, next, err := r.client.Scan(ctx, cursor, fullPattern, 1000).Result()
 | ||
| 		if err != nil {
 | ||
| 			// 如果是上下文取消错误,直接返回
 | ||
| 			if err == context.Canceled || err == context.DeadlineExceeded {
 | ||
| 				r.logger.Warn("缓存删除操作被取消",
 | ||
| 					zap.String("pattern", fullPattern),
 | ||
| 					zap.Int64("total_deleted", totalDeleted),
 | ||
| 					zap.Error(err),
 | ||
| 				)
 | ||
| 				return err
 | ||
| 			}
 | ||
| 			
 | ||
| 			r.logger.Error("扫描缓存键失败", 
 | ||
| 				zap.String("pattern", fullPattern), 
 | ||
| 				zap.Error(err))
 | ||
| 			return err
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 批量删除找到的键
 | ||
| 		if len(keys) > 0 {
 | ||
| 			// 使用pipeline批量删除,提高性能
 | ||
| 			pipe := r.client.Pipeline()
 | ||
| 			pipe.Del(ctx, keys...)
 | ||
| 			
 | ||
| 			cmds, err := pipe.Exec(ctx)
 | ||
| 			if err != nil {
 | ||
| 				r.logger.Error("批量删除缓存键失败", 
 | ||
| 					zap.Strings("keys", keys), 
 | ||
| 					zap.Error(err))
 | ||
| 				return err
 | ||
| 			}
 | ||
| 			
 | ||
| 			// 统计删除的键数量
 | ||
| 			for _, cmd := range cmds {
 | ||
| 				if delCmd, ok := cmd.(*redis.IntCmd); ok {
 | ||
| 					if deleted, err := delCmd.Result(); err == nil {
 | ||
| 						totalDeleted += deleted
 | ||
| 					}
 | ||
| 				}
 | ||
| 			}
 | ||
| 			
 | ||
| 			r.logger.Debug("批量删除缓存键", 
 | ||
| 				zap.Strings("keys", keys),
 | ||
| 				zap.Int("batch_size", len(keys)),
 | ||
| 				zap.Int64("total_deleted", totalDeleted),
 | ||
| 			)
 | ||
| 		}
 | ||
| 		
 | ||
| 		cursor = next
 | ||
| 		if cursor == 0 {
 | ||
| 			break
 | ||
| 		}
 | ||
| 	}
 | ||
| 	
 | ||
| 	r.logger.Debug("缓存模式删除完成",
 | ||
| 		zap.String("pattern", fullPattern),
 | ||
| 		zap.Int64("total_deleted", totalDeleted),
 | ||
| 		zap.Int("iterations", iteration),
 | ||
| 	)
 | ||
| 	
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // Keys 获取匹配的键
 | ||
| func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
 | ||
| 	fullPattern := r.getFullKey(pattern)
 | ||
| 
 | ||
| 	keys, err := r.client.Keys(ctx, fullPattern).Result()
 | ||
| 	if err != nil {
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 
 | ||
| 	// 移除前缀
 | ||
| 	result := make([]string, len(keys))
 | ||
| 	prefixLen := len(r.prefix) + 1 // +1 for ":"
 | ||
| 	for i, key := range keys {
 | ||
| 		if len(key) > prefixLen {
 | ||
| 			result[i] = key[prefixLen:]
 | ||
| 		} else {
 | ||
| 			result[i] = key
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	return result, nil
 | ||
| }
 | ||
| 
 | ||
| // Stats 获取缓存统计
 | ||
| func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
 | ||
| 	dbSize, _ := r.client.DBSize(ctx).Result()
 | ||
| 
 | ||
| 	return interfaces.CacheStats{
 | ||
| 		Hits:        r.hits,
 | ||
| 		Misses:      r.misses,
 | ||
| 		Keys:        dbSize,
 | ||
| 		Memory:      0, // 暂时设为0,后续可解析Redis info
 | ||
| 		Connections: 0, // 暂时设为0,后续可解析Redis info
 | ||
| 	}, nil
 | ||
| }
 | ||
| 
 | ||
| // getFullKey 获取完整键名
 | ||
| func (r *RedisCache) getFullKey(key string) string {
 | ||
| 	if r.prefix == "" {
 | ||
| 		return key
 | ||
| 	}
 | ||
| 	return fmt.Sprintf("%s:%s", r.prefix, key)
 | ||
| }
 | ||
| 
 | ||
| // Flush 清空所有缓存
 | ||
| func (r *RedisCache) Flush(ctx context.Context) error {
 | ||
| 	if r.prefix == "" {
 | ||
| 		return r.client.FlushDB(ctx).Err()
 | ||
| 	}
 | ||
| 
 | ||
| 	// 只删除带前缀的键
 | ||
| 	return r.DeletePattern(ctx, "*")
 | ||
| }
 | ||
| 
 | ||
| // GetClient 获取原始Redis客户端
 | ||
| func (r *RedisCache) GetClient() *redis.Client {
 | ||
| 	return r.client
 | ||
| }
 |