tianyuan-api-server/apps/sentinel/internal/model/whitelistmodel.go
2024-10-02 00:57:17 +08:00

88 lines
2.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"context"
"fmt"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/sqlc"
"github.com/zeromicro/go-zero/core/stores/sqlx"
)
// Compile-time assertion that *customWhitelistModel implements WhitelistModel.
var _ WhitelistModel = (*customWhitelistModel)(nil)

// WhitelistModel is the model for the whitelist table. It embeds the
// whitelistModel interface (declared elsewhere in this package) and adds the
// hand-written queries below. To extend it, declare new methods here and
// implement them on customWhitelistModel.
type WhitelistModel interface {
	whitelistModel

	// IsIpInWhitelist reports whether ip is present in the whitelist.
	IsIpInWhitelist(ctx context.Context, ip string) (bool, error)

	// FindWhitelistList returns one page of whitelist entries for userId
	// together with a total row count.
	FindWhitelistList(ctx context.Context, userId, page, pageSize int64) ([]*Whitelist, int64, error)
}

// customWhitelistModel implements WhitelistModel. It embeds
// *defaultWhitelistModel for the base table operations and keeps a Redis
// client that IsIpInWhitelist uses as a cache.
type customWhitelistModel struct {
	*defaultWhitelistModel
	rds *redis.Redis
}
// NewWhitelistModel constructs a WhitelistModel.
//
// conn, c and opts are forwarded unchanged to newWhitelistModel to build the
// embedded base model. rds is stored on the result and used by
// IsIpInWhitelist.
func NewWhitelistModel(rds *redis.Redis, conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WhitelistModel {
	base := newWhitelistModel(conn, c, opts...)
	model := &customWhitelistModel{
		defaultWhitelistModel: base,
		rds:                   rds,
	}
	return model
}
// IsIpInWhitelist reports whether ip is present in the whitelist table.
//
// The Redis set "whitelist_ips" is consulted first as a positive-only cache:
// if ip is a member, the function returns true without touching the database.
// On a cache miss, or when Redis returns an error, the database is queried.
// An IP found there is added to the set so later lookups hit Redis.
//
// Return values:
//   - (false, nil) when the database has no matching row.
//   - (false, err) for any other database error.
//   - (true, nil) whenever membership is confirmed. This holds even if
//     writing the IP back to Redis fails, because the cache fill is
//     best-effort and the database is authoritative.
//
// NOTE(review): this file only ever adds members to "whitelist_ips". If an IP
// is deleted from the table, nothing here removes it from the set, so it may
// keep passing this check. Confirm that the delete path issues an SREM or
// otherwise invalidates the set.
func (m *customWhitelistModel) IsIpInWhitelist(ctx context.Context, ip string) (bool, error) {
	const redisKey = "whitelist_ips"

	// Redis errors are deliberately ignored here. A cache outage should
	// degrade to a database lookup, not fail the whitelist check.
	if isMember, err := m.rds.SismemberCtx(ctx, redisKey, ip); err == nil && isMember {
		return true, nil
	}

	// Use the model's configured table name, as FindWhitelistList does,
	// rather than a hard-coded one.
	query := fmt.Sprintf("SELECT whitelist_ip FROM %s WHERE whitelist_ip = ? LIMIT 1", m.table)
	var found string
	switch err := m.QueryRowNoCacheCtx(ctx, &found, query, ip); err {
	case nil:
		// Row exists; fall through to the cache fill below.
	case sqlc.ErrNotFound:
		return false, nil
	default:
		return false, err
	}

	// Best-effort cache fill: the database has already confirmed membership,
	// so a failed SADD must not turn this into a rejection. A later call will
	// simply hit the database again.
	_, _ = m.rds.SaddCtx(ctx, redisKey, ip)
	return true, nil
}
// FindWhitelistList returns one page of the whitelist rows whose user_id is
// userId, ordered by created_at descending. It also returns the total number
// of rows for that same user, so callers can compute the page count.
//
// page is 1-based. Values below 1 are treated as 1 so that the computed
// OFFSET is never negative, which MySQL rejects. pageSize is used as the
// LIMIT row count.
//
// Either query failing returns (nil, 0, err).
func (m *customWhitelistModel) FindWhitelistList(ctx context.Context, userId, page, pageSize int64) ([]*Whitelist, int64, error) {
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pageSize

	var whitelist []*Whitelist
	query := fmt.Sprintf("SELECT %s FROM %s WHERE user_id = ? ORDER BY created_at DESC LIMIT ?,?", whitelistRows, m.table)
	if err := m.QueryRowsNoCacheCtx(ctx, &whitelist, query, userId, offset, pageSize); err != nil {
		return nil, 0, err
	}

	// The total must use the same user_id filter as the page query.
	// Counting the whole table would include other users' entries in this
	// user's pagination total.
	var total int64
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE user_id = ?", m.table)
	if err := m.QueryRowNoCacheCtx(ctx, &total, countQuery, userId); err != nil {
		return nil, 0, err
	}

	return whitelist, total, nil
}