Files
tyapi-server/internal/infrastructure/cache/redis_cache.go

285 lines
6.1 KiB
Go
Raw Normal View History

package cache
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"sync/atomic"
	"time"

	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"

	"tyapi-server/internal/shared/interfaces"
)
// RedisCache is a cache implementation backed by a Redis client.
type RedisCache struct {
	client *redis.Client // underlying Redis client used for every command
	logger *zap.Logger   // logger used to report failures and lifecycle events
	prefix string        // optional namespace; when non-empty it is joined to each key with ":"

	// Statistics: number of successful lookups and of cache misses recorded by Get.
	hits, misses int64
}
// NewRedisCache builds a RedisCache around the given client, logger and key
// prefix. It does not contact Redis; the hit/miss counters start at zero.
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
	cache := new(RedisCache)
	cache.client = client
	cache.logger = logger
	cache.prefix = prefix
	return cache
}
// Name returns the fixed identifier of this service.
func (r *RedisCache) Name() string {
	const serviceName = "redis-cache"
	return serviceName
}
// Initialize checks that Redis is reachable by sending a PING.
// A failure is logged and returned wrapped; on success an informational
// message is logged and nil is returned.
func (r *RedisCache) Initialize(ctx context.Context) error {
	if err := r.client.Ping(ctx).Err(); err != nil {
		r.logger.Error("Failed to connect to Redis", zap.Error(err))
		return fmt.Errorf("redis connection failed: %w", err)
	}

	r.logger.Info("Redis cache service initialized")
	return nil
}
// HealthCheck sends a PING to Redis and returns the resulting error, if any.
func (r *RedisCache) HealthCheck(ctx context.Context) error {
	return r.client.Ping(ctx).Err()
}
// Shutdown closes the underlying Redis client and returns its close error.
// The context argument is not used.
func (r *RedisCache) Shutdown(ctx context.Context) error {
	closeErr := r.client.Close()
	return closeErr
}
// Get reads the value stored under key and JSON-decodes it into dest.
//
// dest must be something json.Unmarshal can decode into (normally a non-nil
// pointer). A missing key is counted as a miss and reported as a
// "cache miss" error. Any other Redis error is logged and returned unchanged.
// A successful read is counted as a hit before decoding, so a decode failure
// still counts as a hit; the decode error is returned as-is.
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
	fullKey := r.getFullKey(key)

	val, err := r.client.Get(ctx, fullKey).Result()
	if err != nil {
		// errors.Is also recognises redis.Nil when it has been wrapped.
		if errors.Is(err, redis.Nil) {
			// The cache is shared between goroutines, so the counters are
			// updated atomically to avoid a data race.
			atomic.AddInt64(&r.misses, 1)
			return fmt.Errorf("cache miss: key %s not found", key)
		}
		r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
		return err
	}

	atomic.AddInt64(&r.hits, 1)
	return json.Unmarshal([]byte(val), dest)
}
// Set JSON-encodes value and stores it under key.
//
// Only the first element of ttl is considered:
//   - time.Duration is used as-is;
//   - int is interpreted as a number of seconds;
//   - string is parsed with time.ParseDuration (for example "90s", "5m");
//   - any other type, a missing ttl, or an unparsable string falls back to
//     24 hours.
//
// It returns an error if the value cannot be marshalled or if the Redis SET
// command fails (the latter is also logged).
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
	const defaultTTL = 24 * time.Hour

	fullKey := r.getFullKey(key)

	data, err := json.Marshal(value)
	if err != nil {
		return fmt.Errorf("failed to marshal value: %w", err)
	}

	expiration := defaultTTL
	if len(ttl) > 0 {
		switch v := ttl[0].(type) {
		case time.Duration:
			expiration = v
		case int:
			expiration = time.Duration(v) * time.Second
		case string:
			parsed, parseErr := time.ParseDuration(v)
			if parseErr != nil {
				// Do not let a malformed TTL degrade to a zero expiration,
				// which go-redis treats as "no expiration": keep the default
				// and make the problem visible instead.
				r.logger.Warn("Invalid cache TTL, using default",
					zap.String("key", key), zap.String("ttl", v), zap.Error(parseErr))
			} else {
				expiration = parsed
			}
		}
	}

	if err := r.client.Set(ctx, fullKey, data, expiration).Err(); err != nil {
		r.logger.Error("Failed to set cache", zap.String("key", key), zap.Error(err))
		return err
	}
	return nil
}
// Delete removes the given keys from the cache. Calling it with no keys is a
// no-op. A Redis failure is logged (with the caller-supplied, un-prefixed
// keys) and returned.
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}

	// Translate every logical key into its prefixed Redis key.
	qualified := make([]string, 0, len(keys))
	for _, k := range keys {
		qualified = append(qualified, r.getFullKey(k))
	}

	if err := r.client.Del(ctx, qualified...).Err(); err != nil {
		r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
		return err
	}
	return nil
}
// Exists reports whether key is present in the cache. On a Redis error it
// returns false together with that error.
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
	n, err := r.client.Exists(ctx, r.getFullKey(key)).Result()
	if err != nil {
		return false, err
	}

	found := n > 0
	return found, nil
}
// GetMultiple fetches several keys with a single MGET and returns the
// JSON-decoded values indexed by the caller-supplied (un-prefixed) keys.
//
// Keys that are missing are omitted from the result. Entries whose stored
// value is not a string or is not valid JSON are also omitted, and a warning
// is logged for each. An empty keys slice yields an empty, non-nil map.
// The hit/miss counters are not updated by this method.
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
	if len(keys) == 0 {
		return make(map[string]interface{}), nil
	}

	fullKeys := make([]string, len(keys))
	for i, key := range keys {
		fullKeys[i] = r.getFullKey(key)
	}

	values, err := r.client.MGet(ctx, fullKeys...).Result()
	if err != nil {
		return nil, err
	}

	result := make(map[string]interface{}, len(values))
	for i, val := range values {
		if val == nil {
			// nil marks a key that does not exist.
			continue
		}

		// Checked assertion: an unexpected reply type must not panic.
		raw, ok := val.(string)
		if !ok {
			r.logger.Warn("Unexpected cache value type", zap.String("key", keys[i]))
			continue
		}

		var data interface{}
		if err := json.Unmarshal([]byte(raw), &data); err != nil {
			r.logger.Warn("Failed to decode cached value", zap.String("key", keys[i]), zap.Error(err))
			continue
		}
		result[keys[i]] = data
	}
	return result, nil
}
// SetMultiple JSON-encodes every entry of data and stores them all through a
// single pipeline, using the same expiration for each key.
//
// Only the first element of ttl is considered:
//   - time.Duration is used as-is;
//   - int is interpreted as a number of seconds;
//   - string is parsed with time.ParseDuration, matching Set;
//   - any other type, a missing ttl, or an unparsable string falls back to
//     24 hours.
//
// Entries that cannot be marshalled are skipped with a logged warning; the
// remaining entries are still written. The returned error is the one reported
// by executing the pipeline. An empty data map is a no-op.
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
	if len(data) == 0 {
		return nil
	}

	const defaultTTL = 24 * time.Hour

	expiration := defaultTTL
	if len(ttl) > 0 {
		switch v := ttl[0].(type) {
		case time.Duration:
			expiration = v
		case int:
			expiration = time.Duration(v) * time.Second
		case string:
			if parsed, parseErr := time.ParseDuration(v); parseErr == nil {
				expiration = parsed
			} else {
				r.logger.Warn("Invalid cache TTL, using default",
					zap.String("ttl", v), zap.Error(parseErr))
			}
		}
	}

	pipe := r.client.Pipeline()
	for key, value := range data {
		jsonData, err := json.Marshal(value)
		if err != nil {
			// Best effort: keep writing the other entries, but leave a trace
			// so the dropped key is not lost silently.
			r.logger.Warn("Failed to marshal cache value, skipping",
				zap.String("key", key), zap.Error(err))
			continue
		}
		pipe.Set(ctx, r.getFullKey(key), jsonData, expiration)
	}

	_, err := pipe.Exec(ctx)
	return err
}
// DeletePattern deletes every key matching the glob pattern, which is first
// qualified with the cache prefix.
//
// The keyspace is walked incrementally with SCAN rather than KEYS, and
// matches are removed in bounded batches, so a large database is not blocked
// by one long-running command or one huge DEL. The operation is not atomic:
// if an error occurs part-way through, keys deleted so far stay deleted and
// the error is returned.
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
	// Used both as the SCAN COUNT hint and as the maximum keys per DEL.
	const batchSize = 500

	fullPattern := r.getFullKey(pattern)
	batch := make([]string, 0, batchSize)

	iter := r.client.Scan(ctx, 0, fullPattern, batchSize).Iterator()
	for iter.Next(ctx) {
		batch = append(batch, iter.Val())
		if len(batch) < batchSize {
			continue
		}
		if err := r.client.Del(ctx, batch...).Err(); err != nil {
			return err
		}
		batch = batch[:0]
	}
	if err := iter.Err(); err != nil {
		return err
	}

	// Remove whatever is left in the final, partially filled batch.
	if len(batch) > 0 {
		return r.client.Del(ctx, batch...).Err()
	}
	return nil
}
// Keys returns the keys matching the glob pattern, with the cache prefix
// removed so they can be passed back to Get, Set or Delete.
//
// The pattern is qualified with the cache prefix before being sent to Redis.
// The returned slice is always non-nil.
//
// NOTE(review): this still issues the KEYS command, which inspects the whole
// keyspace in one call; consider a SCAN-based variant for large databases.
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
	fullPattern := r.getFullKey(pattern)

	keys, err := r.client.Keys(ctx, fullPattern).Result()
	if err != nil {
		return nil, err
	}

	// getFullKey adds "<prefix>:" only when a prefix is configured, so the
	// part to strip must be empty when there is none. Always stripping
	// len(prefix)+1 bytes would chop the first character off every key of an
	// un-prefixed cache.
	namespace := ""
	if r.prefix != "" {
		namespace = r.prefix + ":"
	}

	result := make([]string, len(keys))
	for i, key := range keys {
		if len(key) >= len(namespace) && key[:len(namespace)] == namespace {
			result[i] = key[len(namespace):]
		} else {
			// Not in this cache's namespace; return it untouched.
			result[i] = key
		}
	}
	return result, nil
}
// Stats returns a snapshot of the cache statistics.
//
// Hits and Misses are the counters maintained by Get. Keys is the value
// reported by DBSIZE, i.e. it is not restricted to keys carrying this cache's
// prefix. Memory and Connections are not collected yet and are always 0.
//
// If DBSIZE fails, a zero CacheStats and the wrapped error are returned
// instead of silently reporting zero keys.
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
	dbSize, err := r.client.DBSize(ctx).Result()
	if err != nil {
		return interfaces.CacheStats{}, fmt.Errorf("redis dbsize failed: %w", err)
	}

	return interfaces.CacheStats{
		// Atomic loads pair with concurrent counter updates from Get.
		Hits:        atomic.LoadInt64(&r.hits),
		Misses:      atomic.LoadInt64(&r.misses),
		Keys:        dbSize,
		Memory:      0, // placeholder; could later be parsed from Redis INFO
		Connections: 0, // placeholder; could later be parsed from Redis INFO
	}, nil
}
// getFullKey maps a logical key to the key actually used in Redis:
// "<prefix>:<key>" when a prefix is configured, otherwise key unchanged.
func (r *RedisCache) getFullKey(key string) string {
	if r.prefix != "" {
		return r.prefix + ":" + key
	}
	return key
}
// Flush clears the cache. When a prefix is configured only keys under that
// prefix are removed (via DeletePattern with "*"); with no prefix the whole
// currently selected database is flushed.
func (r *RedisCache) Flush(ctx context.Context) error {
	if r.prefix != "" {
		// Restrict deletion to this cache's namespace.
		return r.DeletePattern(ctx, "*")
	}
	return r.client.FlushDB(ctx).Err()
}
// GetClient exposes the underlying Redis client for callers that need
// commands not wrapped by RedisCache.
func (r *RedisCache) GetClient() *redis.Client {
	underlying := r.client
	return underlying
}