Files
tyapi-server/internal/infrastructure/cache/redis_cache.go
2025-09-01 18:29:59 +08:00

386 lines
8.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cache
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"

	"tyapi-server/internal/shared/interfaces"
)
// RedisCache is a Redis-backed cache implementation.
//
// Keys passed to its methods are namespaced as "<prefix>:<key>" by
// getFullKey (no namespacing when prefix is empty). Values are stored
// JSON-encoded by Set/SetMultiple and decoded by Get/GetMultiple.
type RedisCache struct {
	client *redis.Client
	logger *zap.Logger
	prefix string
	// Statistics: hits and misses are incremented by Get and reported by
	// Stats. When a RedisCache is shared between goroutines, these fields
	// must only be read and written through sync/atomic.
	hits   int64
	misses int64
}
// NewRedisCache wraps an existing go-redis client in a RedisCache.
//
// prefix namespaces every key as "<prefix>:<key>"; an empty prefix disables
// namespacing. The hit/miss counters start at zero.
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
	c := new(RedisCache)
	c.client = client
	c.logger = logger
	c.prefix = prefix
	return c
}
// Name returns the service name of this cache, "redis-cache".
func (r *RedisCache) Name() string {
	const serviceName = "redis-cache"
	return serviceName
}
// Initialize checks connectivity by sending a PING to Redis.
//
// A failed PING is logged and returned wrapped as "redis connection failed";
// on success an informational message is logged.
func (r *RedisCache) Initialize(ctx context.Context) error {
	if err := r.client.Ping(ctx).Err(); err != nil {
		r.logger.Error("Failed to connect to Redis", zap.Error(err))
		return fmt.Errorf("redis connection failed: %w", err)
	}

	r.logger.Info("Redis cache service initialized")
	return nil
}
// HealthCheck sends a PING to Redis and returns its error, or nil if the
// server answered.
func (r *RedisCache) HealthCheck(ctx context.Context) error {
	return r.client.Ping(ctx).Err()
}
// Shutdown closes the underlying Redis client and returns the close error.
// ctx is accepted for the service interface but is not used.
func (r *RedisCache) Shutdown(ctx context.Context) error {
	if err := r.client.Close(); err != nil {
		return err
	}
	return nil
}
// Get loads the value stored under key and JSON-decodes it into dest.
//
// A missing key returns an error of the form "cache miss: key <key> not found"
// and increments the miss counter. Any other Redis error is logged and returned
// unchanged. After a successful read the hit counter is incremented and the
// result of json.Unmarshal is returned.
//
// The counters are updated with sync/atomic so that concurrent callers sharing
// one RedisCache do not race on them.
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
	fullKey := r.getFullKey(key)
	val, err := r.client.Get(ctx, fullKey).Result()
	if err != nil {
		if err == redis.Nil {
			atomic.AddInt64(&r.misses, 1)
			return fmt.Errorf("cache miss: key %s not found", key)
		}
		r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
		return err
	}
	atomic.AddInt64(&r.hits, 1)
	return json.Unmarshal([]byte(val), dest)
}
// Set JSON-encodes value and stores it under key.
//
// The optional ttl argument accepts:
//   - time.Duration: used as-is
//   - int: a number of seconds
//   - string: parsed with time.ParseDuration (e.g. "30m"); an unparsable
//     string is logged and the default TTL is used instead
//   - anything else, or no ttl at all: the default TTL of 24 hours
//
// Marshal failures and Redis errors are logged and returned.
//
// NOTE(review): an explicit zero or negative duration is passed straight to
// go-redis, which stores the key without an expiration; confirm callers rely
// on that.
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
	const defaultTTL = 24 * time.Hour

	fullKey := r.getFullKey(key)
	data, err := json.Marshal(value)
	if err != nil {
		r.logger.Error("序列化缓存数据失败", zap.String("key", key), zap.Error(err))
		return fmt.Errorf("failed to marshal value: %w", err)
	}

	expiration := defaultTTL
	if len(ttl) > 0 {
		switch v := ttl[0].(type) {
		case time.Duration:
			expiration = v
		case int:
			expiration = time.Duration(v) * time.Second
		case string:
			// time.ParseDuration yields 0 on failure, and a 0 expiration would
			// store the key with no TTL at all, so keep the default instead.
			if d, perr := time.ParseDuration(v); perr == nil {
				expiration = d
			} else {
				r.logger.Warn("invalid cache TTL, falling back to default",
					zap.String("key", key),
					zap.String("ttl", v),
					zap.Duration("default", defaultTTL),
					zap.Error(perr))
			}
		}
	}

	if err := r.client.Set(ctx, fullKey, data, expiration).Err(); err != nil {
		r.logger.Error("设置缓存失败", zap.String("key", key), zap.Error(err))
		return err
	}
	r.logger.Debug("设置缓存成功", zap.String("key", key), zap.Duration("ttl", expiration))
	return nil
}
// Delete removes the given keys (after prefixing) with a single DEL command.
//
// Calling it with no keys is a no-op. A Redis error is logged and returned.
// Keys that do not exist are not an error.
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}

	prefixed := make([]string, 0, len(keys))
	for _, k := range keys {
		prefixed = append(prefixed, r.getFullKey(k))
	}

	if err := r.client.Del(ctx, prefixed...).Err(); err != nil {
		r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
		return err
	}
	return nil
}
// Exists reports whether key (after prefixing) is present in Redis.
// On a Redis error it returns false together with that error.
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
	cmd := r.client.Exists(ctx, r.getFullKey(key))
	if err := cmd.Err(); err != nil {
		return false, err
	}
	return cmd.Val() > 0, nil
}
// GetMultiple fetches several keys with one MGET and JSON-decodes each value
// into a generic interface{}.
//
// The returned map is keyed by the caller's (unprefixed) keys. It contains
// only the keys that exist in Redis and decode successfully; missing keys are
// simply absent. Values that are not strings or are not valid JSON are logged
// and skipped rather than failing the whole batch. An empty keys slice
// returns an empty, non-nil map without contacting Redis. An MGET error is
// logged and returned with a nil map.
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
	if len(keys) == 0 {
		return make(map[string]interface{}), nil
	}

	fullKeys := make([]string, len(keys))
	for i, key := range keys {
		fullKeys[i] = r.getFullKey(key)
	}

	values, err := r.client.MGet(ctx, fullKeys...).Result()
	if err != nil {
		r.logger.Error("批量获取缓存失败", zap.Strings("keys", keys), zap.Error(err))
		return nil, err
	}

	result := make(map[string]interface{}, len(values))
	for i, val := range values {
		if val == nil {
			continue // key does not exist
		}
		// A checked assertion: an unexpected reply type must not panic the caller.
		raw, ok := val.(string)
		if !ok {
			r.logger.Warn("反序列化缓存数据失败",
				zap.String("key", keys[i]),
				zap.String("type", fmt.Sprintf("%T", val)))
			continue
		}

		var data interface{}
		if err := json.Unmarshal([]byte(raw), &data); err != nil {
			r.logger.Warn("反序列化缓存数据失败",
				zap.String("key", keys[i]),
				zap.String("value", raw),
				zap.Error(err))
			continue
		}
		result[keys[i]] = data
	}
	return result, nil
}
// SetMultiple JSON-encodes every value in data and writes all entries in one
// pipeline, using a single TTL for all of them.
//
// The optional ttl argument accepts a time.Duration (used as-is), an int
// (seconds), or a string parsed with time.ParseDuration. An unparsable
// string, any other type, or no ttl means the default of 24 hours.
//
// This is best-effort per entry: values that fail to marshal are logged and
// skipped, and the remaining entries are still written. An error from
// executing the pipeline is logged and returned. An empty map is a no-op.
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
	if len(data) == 0 {
		return nil
	}

	const defaultTTL = 24 * time.Hour
	expiration := defaultTTL
	if len(ttl) > 0 {
		switch v := ttl[0].(type) {
		case time.Duration:
			expiration = v
		case int:
			expiration = time.Duration(v) * time.Second
		case string:
			// time.ParseDuration yields 0 on failure, which would store the
			// keys with no TTL, so keep the default in that case.
			if d, perr := time.ParseDuration(v); perr == nil {
				expiration = d
			} else {
				r.logger.Warn("invalid cache TTL, falling back to default",
					zap.String("ttl", v),
					zap.Duration("default", defaultTTL),
					zap.Error(perr))
			}
		}
	}

	pipe := r.client.Pipeline()
	for key, value := range data {
		jsonData, err := json.Marshal(value)
		if err != nil {
			// Skip just this entry, but make the skip visible.
			r.logger.Warn("序列化缓存数据失败", zap.String("key", key), zap.Error(err))
			continue
		}
		pipe.Set(ctx, r.getFullKey(key), jsonData, expiration)
	}

	if _, err := pipe.Exec(ctx); err != nil {
		r.logger.Error("批量设置缓存失败", zap.Int("count", len(data)), zap.Error(err))
		return err
	}
	return nil
}
// DeletePattern deletes every key matching the Redis glob pattern.
//
// The cache prefix is prepended unless pattern already starts with
// "<prefix>:", so both bare and fully qualified patterns are accepted.
// Matching keys are discovered with SCAN in batches of up to 1000, and each
// batch is removed with one DEL whose reply counts the keys actually deleted.
// Unlike KEYS, SCAN does not block the server.
//
// ctx is checked before every SCAN call. Cancellation is logged and ctx's
// error is returned. SCAN and DEL failures are logged and returned. Keys
// deleted in earlier batches stay deleted whenever an error is returned.
//
// NOTE(review): the walk stops after maxIterations SCAN calls (about 100k
// examined keys) with only a warning and a nil error, so on a large database
// the deletion can be silently incomplete. This best-effort behavior is kept
// as-is; confirm it is acceptable for Flush and other invalidation callers.
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
	// Avoid prefixing a pattern that already carries the cache prefix.
	fullPattern := pattern
	if !strings.HasPrefix(pattern, r.prefix+":") {
		fullPattern = r.getFullKey(pattern)
	}

	const (
		scanCount     = 1000
		maxIterations = 100 // safety valve against a SCAN that never completes
	)

	var (
		cursor       uint64
		totalDeleted int64
		iteration    int
	)
	for {
		iteration++
		if iteration > maxIterations {
			r.logger.Warn("缓存删除操作达到最大迭代次数限制",
				zap.String("pattern", fullPattern),
				zap.Int("max_iterations", maxIterations),
				zap.Int64("total_deleted", totalDeleted),
			)
			break
		}

		if err := ctx.Err(); err != nil {
			r.logger.Warn("缓存删除操作被取消",
				zap.String("pattern", fullPattern),
				zap.Int64("total_deleted", totalDeleted),
				zap.Error(err),
			)
			return err
		}

		keys, next, err := r.client.Scan(ctx, cursor, fullPattern, scanCount).Result()
		if err != nil {
			// If ctx is done, treat the failure as a cancellation regardless of
			// how the client wrapped or reported it.
			if ctx.Err() != nil {
				r.logger.Warn("缓存删除操作被取消",
					zap.String("pattern", fullPattern),
					zap.Int64("total_deleted", totalDeleted),
					zap.Error(err),
				)
				return err
			}
			r.logger.Error("扫描缓存键失败",
				zap.String("pattern", fullPattern),
				zap.Error(err))
			return err
		}

		if len(keys) > 0 {
			deleted, err := r.client.Del(ctx, keys...).Result()
			if err != nil {
				r.logger.Error("批量删除缓存键失败",
					zap.Strings("keys", keys),
					zap.Error(err))
				return err
			}
			totalDeleted += deleted
			r.logger.Debug("批量删除缓存键",
				zap.Strings("keys", keys),
				zap.Int("batch_size", len(keys)),
				zap.Int64("total_deleted", totalDeleted),
			)
		}

		cursor = next
		if cursor == 0 {
			break
		}
	}

	r.logger.Debug("缓存模式删除完成",
		zap.String("pattern", fullPattern),
		zap.Int64("total_deleted", totalDeleted),
		zap.Int("iterations", iteration),
	)
	return nil
}
// Keys returns the keys matching pattern, with the cache prefix removed.
//
// pattern is prefixed exactly like a single key, so "*" means "every key under
// this cache's prefix". With an empty prefix, keys are returned unchanged.
// Errors from Redis are returned with a nil slice.
//
// NOTE(review): this uses the KEYS command, which walks the whole database in
// one blocking call. DeletePattern already uses SCAN; consider doing the same
// here if the database can be large.
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
	keys, err := r.client.Keys(ctx, r.getFullKey(pattern)).Result()
	if err != nil {
		return nil, err
	}

	// Strip only the "<prefix>:" that getFullKey actually added. With an
	// empty prefix nothing was added, so nothing must be removed; blindly
	// dropping len(prefix)+1 bytes would chop the first byte of every key.
	strip := ""
	if r.prefix != "" {
		strip = r.prefix + ":"
	}

	result := make([]string, len(keys))
	for i, key := range keys {
		result[i] = strings.TrimPrefix(key, strip)
	}
	return result, nil
}
// Stats returns the hit/miss counters recorded by Get plus the size of the
// current Redis database.
//
// Keys is the DBSIZE of the whole selected database, not only the keys under
// this cache's prefix. Memory and Connections are not collected yet and are
// always 0. A DBSIZE failure is logged and reported as Keys == 0; Stats itself
// never returns an error.
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
	dbSize, err := r.client.DBSize(ctx).Result()
	if err != nil {
		// Best-effort: stats stay available, but the failure is made visible.
		r.logger.Warn("Failed to get Redis DB size", zap.Error(err))
		dbSize = 0
	}

	return interfaces.CacheStats{
		// Atomic loads pair with the counter updates in Get under concurrency.
		Hits:        atomic.LoadInt64(&r.hits),
		Misses:      atomic.LoadInt64(&r.misses),
		Keys:        dbSize,
		Memory:      0, // TODO: parse from Redis INFO
		Connections: 0, // TODO: parse from Redis INFO
	}, nil
}
// getFullKey namespaces key as "<prefix>:<key>". With an empty prefix the key
// is returned unchanged.
func (r *RedisCache) getFullKey(key string) string {
	if r.prefix != "" {
		return r.prefix + ":" + key
	}
	return key
}
// Flush removes this cache's data.
//
// With a prefix, only keys under "<prefix>:" are removed (DeletePattern with
// "*"). Without a prefix, this cache's keys cannot be told apart from anything
// else, so the ENTIRE current Redis database is cleared with FLUSHDB.
func (r *RedisCache) Flush(ctx context.Context) error {
	if r.prefix != "" {
		// Restrict deletion to keys carrying this cache's prefix.
		return r.DeletePattern(ctx, "*")
	}
	return r.client.FlushDB(ctx).Err()
}
// GetClient exposes the underlying go-redis client for operations this wrapper
// does not cover. Keys used through it are not prefixed by this cache.
func (r *RedisCache) GetClient() *redis.Client {
	client := r.client
	return client
}