package model

import (
	"context"
	"fmt"

	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/redis"
	"github.com/zeromicro/go-zero/core/stores/sqlc"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

var _ WhitelistModel = (*customWhitelistModel)(nil)

type (
	// WhitelistModel is an interface to be customized, add more methods here,
	// and implement the added methods in customWhitelistModel.
	WhitelistModel interface {
		whitelistModel

		// IsIpInWhitelist reports whether ip is a whitelisted IP, consulting
		// the Redis set first and the database second.
		IsIpInWhitelist(ctx context.Context, ip string) (bool, error)

		// FindWhitelistList returns one page of the whitelist rows owned by
		// userId together with the total number of rows owned by that user.
		FindWhitelistList(ctx context.Context, userId, page, pageSize int64) ([]*Whitelist, int64, error)
	}

	// customWhitelistModel extends the generated default model with a Redis
	// client used by the hand-written lookup methods.
	customWhitelistModel struct {
		*defaultWhitelistModel
		rds *redis.Redis
	}
)

// NewWhitelistModel returns a model for the database table.
//
// rds is the Redis client used by IsIpInWhitelist; conn, c and opts are
// forwarded unchanged to the generated newWhitelistModel constructor.
func NewWhitelistModel(rds *redis.Redis, conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WhitelistModel {
	return &customWhitelistModel{
		rds:                   rds,
		defaultWhitelistModel: newWhitelistModel(conn, c, opts...),
	}
}

// IsIpInWhitelist reports whether ip is present in the whitelist.
//
// Lookup order:
//  1. The Redis set "whitelist_ips" is checked. A hit returns true at once.
//     A Redis error on this read is not returned; the lookup simply falls
//     through to the database.
//  2. The whitelist table is queried for the IP. A missing row yields
//     (false, nil); any other database error is returned as-is.
//  3. On a database hit the IP is added to the Redis set. If that write
//     fails, the Redis error is returned and the result is false.
//
// NOTE(review): nothing in this method removes members from the Redis set or
// sets an expiry on it, so an IP deleted from the table may keep matching in
// step 1 unless another code path clears the set — confirm.
func (m *customWhitelistModel) IsIpInWhitelist(ctx context.Context, ip string) (bool, error) {
	// Redis set key holding whitelisted IPs that were already confirmed.
	const redisKey = "whitelist_ips"

	// 1. Cache lookup. Only a successful, positive answer short-circuits.
	isMember, err := m.rds.SismemberCtx(ctx, redisKey, ip)
	if err == nil && isMember {
		return true, nil
	}

	// 2. Cache miss (or cache error): ask the database.
	query := `SELECT whitelist_ip FROM whitelist WHERE whitelist_ip = ? LIMIT 1`
	var dbIp string
	if err = m.QueryRowNoCacheCtx(ctx, &dbIp, query, ip); err != nil {
		if err == sqlc.ErrNotFound {
			// No matching row: the IP is simply not whitelisted.
			return false, nil
		}
		return false, err
	}

	// 3. Database hit: remember the IP in the Redis set for later lookups.
	if _, redisErr := m.rds.SaddCtx(ctx, redisKey, ip); redisErr != nil {
		return false, redisErr
	}
	return true, nil
}

// FindWhitelistList returns one page of whitelist rows belonging to userId,
// newest first (ordered by created_at descending), plus the total number of
// rows belonging to that user.
//
// page is 1-based; values below 1 are treated as 1 so the computed offset is
// never negative. pageSize is passed through to the LIMIT clause unchanged.
//
// On any query error the returned slice is nil and the total is 0.
func (m *customWhitelistModel) FindWhitelistList(ctx context.Context, userId, page, pageSize int64) ([]*Whitelist, int64, error) {
	// Guard against page <= 0, which would otherwise produce a negative offset.
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize

	var whitelist []*Whitelist
	query := fmt.Sprintf("SELECT %s FROM %s WHERE user_id = ? ORDER BY created_at DESC LIMIT ?,?", whitelistRows, m.table)
	if err := m.QueryRowsNoCacheCtx(ctx, &whitelist, query, userId, offset, pageSize); err != nil {
		return nil, 0, err
	}

	// Count with the same user filter as the page query so that the total
	// describes the list being paginated rather than the whole table.
	var total int64
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE user_id = ?", m.table)
	if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, userId); err != nil {
		return nil, 0, err
	}

	return whitelist, total, nil
}