This commit is contained in:
2025-07-20 20:53:26 +08:00
parent 83bf9aea7d
commit 8ad1d7288e
158 changed files with 18156 additions and 13188 deletions

View File

@@ -0,0 +1,495 @@
package cache
import (
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"reflect"
"strings"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/shared/interfaces"
)
// GormCachePlugin GORM缓存插件
type GormCachePlugin struct {
cache interfaces.CacheService
logger *zap.Logger
config CacheConfig
}
// CacheConfig 缓存配置
type CacheConfig struct {
// 基础配置
DefaultTTL time.Duration `json:"default_ttl"` // 默认TTL
TablePrefix string `json:"table_prefix"` // 表前缀
EnabledTables []string `json:"enabled_tables"` // 启用缓存的表
DisabledTables []string `json:"disabled_tables"` // 禁用缓存的表
// 查询配置
MaxCacheSize int `json:"max_cache_size"` // 单次查询最大缓存记录数
CacheComplexSQL bool `json:"cache_complex_sql"` // 是否缓存复杂SQL
// 高级特性
EnableStats bool `json:"enable_stats"` // 启用统计
EnableWarmup bool `json:"enable_warmup"` // 启用预热
PenetrationGuard bool `json:"penetration_guard"` // 缓存穿透保护
BloomFilter bool `json:"bloom_filter"` // 布隆过滤器
// 失效策略
AutoInvalidate bool `json:"auto_invalidate"` // 自动失效
InvalidateDelay time.Duration `json:"invalidate_delay"` // 延迟失效时间
}
// DefaultCacheConfig 默认缓存配置
func DefaultCacheConfig() CacheConfig {
return CacheConfig{
DefaultTTL: 30 * time.Minute,
TablePrefix: "gorm_cache",
MaxCacheSize: 1000,
CacheComplexSQL: false,
EnableStats: true,
EnableWarmup: false,
PenetrationGuard: true,
BloomFilter: false,
AutoInvalidate: true,
InvalidateDelay: 100 * time.Millisecond,
}
}
// NewGormCachePlugin 创建GORM缓存插件
func NewGormCachePlugin(cache interfaces.CacheService, logger *zap.Logger, config ...CacheConfig) *GormCachePlugin {
cfg := DefaultCacheConfig()
if len(config) > 0 {
cfg = config[0]
}
return &GormCachePlugin{
cache: cache,
logger: logger,
config: cfg,
}
}
// Name 插件名称
func (p *GormCachePlugin) Name() string {
return "gorm-cache-plugin"
}
// Initialize 初始化插件
func (p *GormCachePlugin) Initialize(db *gorm.DB) error {
p.logger.Info("初始化GORM缓存插件",
zap.Duration("default_ttl", p.config.DefaultTTL),
zap.Bool("auto_invalidate", p.config.AutoInvalidate),
zap.Bool("penetration_guard", p.config.PenetrationGuard),
)
// 注册回调函数
return p.registerCallbacks(db)
}
// registerCallbacks 注册GORM回调
func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error {
// Query回调 - 查询时检查缓存
db.Callback().Query().Before("gorm:query").Register("cache:before_query", p.beforeQuery)
db.Callback().Query().After("gorm:query").Register("cache:after_query", p.afterQuery)
// Create回调 - 创建时失效缓存
db.Callback().Create().After("gorm:create").Register("cache:after_create", p.afterCreate)
// Update回调 - 更新时失效缓存
db.Callback().Update().After("gorm:update").Register("cache:after_update", p.afterUpdate)
// Delete回调 - 删除时失效缓存
db.Callback().Delete().After("gorm:delete").Register("cache:after_delete", p.afterDelete)
return nil
}
// ================ 查询回调 ================
// beforeQuery 查询前回调
func (p *GormCachePlugin) beforeQuery(db *gorm.DB) {
// 检查是否启用缓存
if !p.shouldCache(db) {
return
}
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}
// 生成缓存键
cacheKey := p.generateCacheKey(db)
// 从缓存获取结果
var cachedResult CachedResult
if err := p.cache.Get(ctx, cacheKey, &cachedResult); err == nil {
p.logger.Debug("缓存命中",
zap.String("cache_key", cacheKey),
zap.String("table", db.Statement.Table),
)
// 恢复查询结果
if err := p.restoreFromCache(db, &cachedResult); err == nil {
// 设置标记,跳过实际查询
db.Statement.Set("cache:hit", true)
db.Statement.Set("cache:key", cacheKey)
// 更新统计
if p.config.EnableStats {
p.updateStats("hit", db.Statement.Table)
}
return
}
}
// 缓存未命中,设置标记
db.Statement.Set("cache:miss", true)
db.Statement.Set("cache:key", cacheKey)
if p.config.EnableStats {
p.updateStats("miss", db.Statement.Table)
}
}
// afterQuery 查询后回调
func (p *GormCachePlugin) afterQuery(db *gorm.DB) {
// 检查是否缓存未命中
if _, ok := db.Statement.Get("cache:miss"); !ok {
return
}
// 检查查询是否成功
if db.Error != nil {
return
}
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}
cacheKey, _ := db.Statement.Get("cache:key")
// 将查询结果保存到缓存
if err := p.saveToCache(ctx, cacheKey.(string), db); err != nil {
p.logger.Warn("保存查询结果到缓存失败",
zap.String("cache_key", cacheKey.(string)),
zap.Error(err),
)
}
}
// ================ CUD回调 ================
// afterCreate 创建后回调
func (p *GormCachePlugin) afterCreate(db *gorm.DB) {
if !p.config.AutoInvalidate || db.Error != nil {
return
}
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
// afterUpdate 更新后回调
func (p *GormCachePlugin) afterUpdate(db *gorm.DB) {
if !p.config.AutoInvalidate || db.Error != nil {
return
}
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
// afterDelete 删除后回调
func (p *GormCachePlugin) afterDelete(db *gorm.DB) {
if !p.config.AutoInvalidate || db.Error != nil {
return
}
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
// ================ 缓存管理方法 ================
// shouldCache 判断是否应该缓存
func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool {
// 检查是否明确禁用缓存
if value, ok := db.Statement.Get("cache:disabled"); ok && value.(bool) {
return false
}
// 检查是否明确启用缓存
if value, ok := db.Statement.Get("cache:enabled"); ok && value.(bool) {
return true
}
// 检查表是否在禁用列表中
for _, table := range p.config.DisabledTables {
if table == db.Statement.Table {
return false
}
}
// 检查表是否在启用列表中(如果配置了启用列表)
if len(p.config.EnabledTables) > 0 {
for _, table := range p.config.EnabledTables {
if table == db.Statement.Table {
return true
}
}
return false
}
// 检查是否为复杂查询
if !p.config.CacheComplexSQL && p.isComplexQuery(db) {
return false
}
return true
}
// isComplexQuery 判断是否为复杂查询
func (p *GormCachePlugin) isComplexQuery(db *gorm.DB) bool {
sql := db.Statement.SQL.String()
// 检查是否包含复杂操作
complexKeywords := []string{
"JOIN", "UNION", "SUBQUERY", "GROUP BY",
"HAVING", "WINDOW", "RECURSIVE",
}
upperSQL := strings.ToUpper(sql)
for _, keyword := range complexKeywords {
if strings.Contains(upperSQL, keyword) {
return true
}
}
return false
}
// generateCacheKey 生成缓存键
func (p *GormCachePlugin) generateCacheKey(db *gorm.DB) string {
// 构建缓存键的组成部分
keyParts := []string{
p.config.TablePrefix,
db.Statement.Table,
}
// 添加SQL语句hash
sqlHash := p.hashSQL(db.Statement.SQL.String(), db.Statement.Vars)
keyParts = append(keyParts, sqlHash)
return strings.Join(keyParts, ":")
}
// hashSQL 对SQL语句和参数进行hash
func (p *GormCachePlugin) hashSQL(sql string, vars []interface{}) string {
// 将SQL和参数组合
combined := sql
for _, v := range vars {
combined += fmt.Sprintf(":%v", v)
}
// 计算MD5 hash
hasher := md5.New()
hasher.Write([]byte(combined))
return hex.EncodeToString(hasher.Sum(nil))
}
// CachedResult 缓存结果结构
type CachedResult struct {
Data interface{} `json:"data"`
RowCount int64 `json:"row_count"`
Timestamp time.Time `json:"timestamp"`
}
// saveToCache 保存结果到缓存
func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db *gorm.DB) error {
// 检查结果大小限制
if db.Statement.RowsAffected > int64(p.config.MaxCacheSize) {
p.logger.Debug("查询结果过大,跳过缓存",
zap.String("cache_key", cacheKey),
zap.Int64("rows", db.Statement.RowsAffected),
)
return nil
}
// 获取查询结果
dest := db.Statement.Dest
if dest == nil {
return fmt.Errorf("查询结果为空")
}
// 构建缓存结果
result := CachedResult{
Data: dest,
RowCount: db.Statement.RowsAffected,
Timestamp: time.Now(),
}
// 获取TTL
ttl := p.getTTL(db)
// 保存到缓存
if err := p.cache.Set(ctx, cacheKey, result, ttl); err != nil {
return fmt.Errorf("保存到缓存失败: %w", err)
}
p.logger.Debug("查询结果已缓存",
zap.String("cache_key", cacheKey),
zap.Int64("rows", db.Statement.RowsAffected),
zap.Duration("ttl", ttl),
)
return nil
}
// restoreFromCache 从缓存恢复结果
func (p *GormCachePlugin) restoreFromCache(db *gorm.DB, cachedResult *CachedResult) error {
if cachedResult.Data == nil {
return fmt.Errorf("缓存数据为空")
}
// 反序列化到目标对象
destValue := reflect.ValueOf(db.Statement.Dest)
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
return fmt.Errorf("目标对象必须是指针")
}
// 将缓存数据复制到目标
cachedValue := reflect.ValueOf(cachedResult.Data)
if !cachedValue.Type().AssignableTo(destValue.Elem().Type()) {
// 尝试JSON转换
jsonData, err := json.Marshal(cachedResult.Data)
if err != nil {
return fmt.Errorf("缓存数据类型不匹配")
}
if err := json.Unmarshal(jsonData, db.Statement.Dest); err != nil {
return fmt.Errorf("JSON反序列化失败: %w", err)
}
} else {
destValue.Elem().Set(cachedValue)
}
// 设置影响行数
db.Statement.RowsAffected = cachedResult.RowCount
return nil
}
// getTTL 获取TTL
func (p *GormCachePlugin) getTTL(db *gorm.DB) time.Duration {
// 检查是否设置了自定义TTL
if value, ok := db.Statement.Get("cache:ttl"); ok {
if ttl, ok := value.(time.Duration); ok {
return ttl
}
}
return p.config.DefaultTTL
}
// invalidateTableCache 失效表相关缓存
func (p *GormCachePlugin) invalidateTableCache(ctx context.Context, table string) {
if ctx == nil {
ctx = context.Background()
}
// 延迟失效(避免并发问题)
if p.config.InvalidateDelay > 0 {
time.AfterFunc(p.config.InvalidateDelay, func() {
p.doInvalidateTableCache(ctx, table)
})
} else {
p.doInvalidateTableCache(ctx, table)
}
}
// doInvalidateTableCache 执行缓存失效
func (p *GormCachePlugin) doInvalidateTableCache(ctx context.Context, table string) {
pattern := fmt.Sprintf("%s:%s:*", p.config.TablePrefix, table)
if err := p.cache.DeletePattern(ctx, pattern); err != nil {
p.logger.Warn("失效表缓存失败",
zap.String("table", table),
zap.String("pattern", pattern),
zap.Error(err),
)
return
}
p.logger.Debug("表缓存已失效",
zap.String("table", table),
zap.String("pattern", pattern),
)
}
// updateStats 更新统计信息
func (p *GormCachePlugin) updateStats(operation, table string) {
// 这里可以接入Prometheus等监控系统
p.logger.Debug("缓存统计",
zap.String("operation", operation),
zap.String("table", table),
)
}
// ================ 高级功能 ================
// WarmupCache 预热缓存
func (p *GormCachePlugin) WarmupCache(ctx context.Context, db *gorm.DB, queries []string) error {
if !p.config.EnableWarmup {
return fmt.Errorf("缓存预热未启用")
}
for _, query := range queries {
if err := db.Raw(query).Error; err != nil {
p.logger.Warn("缓存预热失败",
zap.String("query", query),
zap.Error(err),
)
}
}
return nil
}
// GetCacheStats 获取缓存统计
func (p *GormCachePlugin) GetCacheStats(ctx context.Context) (map[string]interface{}, error) {
stats, err := p.cache.Stats(ctx)
if err != nil {
return nil, err
}
return map[string]interface{}{
"hits": stats.Hits,
"misses": stats.Misses,
"keys": stats.Keys,
"memory": stats.Memory,
"connections": stats.Connections,
"config": p.config,
}, nil
}
// SetCacheEnabled 设置缓存启用状态
func (p *GormCachePlugin) SetCacheEnabled(db *gorm.DB, enabled bool) *gorm.DB {
return db.Set("cache:enabled", enabled)
}
// SetCacheDisabled 设置缓存禁用状态
func (p *GormCachePlugin) SetCacheDisabled(db *gorm.DB, disabled bool) *gorm.DB {
return db.Set("cache:disabled", disabled)
}
// SetCacheTTL 设置缓存TTL
func (p *GormCachePlugin) SetCacheTTL(db *gorm.DB, ttl time.Duration) *gorm.DB {
return db.Set("cache:ttl", ttl)
}