Files
tyapi-server/internal/domains/api/services/query_whitelist_service.go
2026-06-20 17:28:24 +08:00

170 lines
4.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package services
import (
"context"
"sync"
"sync/atomic"
"time"
"tyapi-server/internal/domains/api/entities"
"tyapi-server/internal/domains/api/repositories"
"tyapi-server/internal/domains/api/services/processors"
"go.uber.org/zap"
)
type (
	// QueryWhitelistService matches the identity carried by a request against the
	// query whitelist and manages the cached whitelist data.
	QueryWhitelistService interface {
		// EnrichContext writes the matched api_codes into the returned context when the
		// input (name + ID card number) hits the whitelist.
		EnrichContext(ctx context.Context, userID string, params map[string]interface{}) context.Context
		// InvalidateCache is called when whitelist data for the given user / ID-card
		// hash has changed, so cached data can be refreshed.
		InvalidateCache(userID, idCardHash string)
		// InvalidateAllCache is called when whitelist data may have changed globally.
		InvalidateAllCache()
	}

	// queryWhitelistSnapshot holds every enabled rule, indexed by id_card_hash.
	// The request hot path only reads this in-memory structure.
	queryWhitelistSnapshot struct {
		byHash map[string][]*entities.QueryWhitelistEntry
	}

	// QueryWhitelistServiceImpl implements QueryWhitelistService on top of an
	// in-memory snapshot loaded from repo.
	QueryWhitelistServiceImpl struct {
		repo   repositories.QueryWhitelistRepository
		logger *zap.Logger
		// snapshot is published and read atomically; nil means "not loaded".
		snapshot atomic.Pointer[queryWhitelistSnapshot]
		// snapshotMu serializes loads and refreshes of snapshot.
		snapshotMu sync.Mutex
	}
)
// NewQueryWhitelistService builds a QueryWhitelistService backed by repo.
//
// The FormConfigService argument is accepted but discarded; presumably it is kept so
// the constructor signature stays compatible with existing wiring — confirm before
// removing it. No data is loaded here: the snapshot is loaded lazily on first lookup.
func NewQueryWhitelistService(
	repo repositories.QueryWhitelistRepository,
	_ FormConfigService,
	logger *zap.Logger,
) QueryWhitelistService {
	return &QueryWhitelistServiceImpl{
		repo:   repo,
		logger: logger,
	}
}
// EnrichContext checks whether the request input hits the whitelist and, if so, writes
// the matched api_codes into the returned context. It never rejects a request: when
// the name + ID card cannot be extracted from params, or the snapshot lookup fails,
// ctx is returned unchanged (the lookup failure is logged).
//
// Hot path: extract name + ID card → match against the in-memory snapshot → store the
// matches in ctx. Per the original design note, each processor then uses the matches
// by api_code to return an empty query result.
func (s *QueryWhitelistServiceImpl) EnrichContext(
	ctx context.Context,
	userID string,
	params map[string]interface{},
) context.Context {
	id := ExtractIdentityParams(params)
	if !id.OK {
		return ctx
	}

	candidates, err := s.lookupEntries(ctx, userID, HashIDCard(id.IDCard))
	if err != nil {
		s.logger.Error("查询白名单快照失败", zap.Error(err), zap.String("user_id", userID))
		return ctx
	}

	hits := make([]processors.WhitelistMatch, 0, len(candidates))
	for _, e := range candidates {
		// An entry counts only if it is enabled and its name matches the request.
		if e.IsEnabled() && e.MatchesName(id.Name) {
			s.logger.Info("命中查询白名单",
				zap.String("user_id", userID),
				zap.String("whitelist_id", e.ID),
				zap.Bool("is_global", e.IsGlobal()),
				zap.Strings("api_codes", e.APICodes),
			)
			hits = append(hits, processors.WhitelistMatch{
				ID:       e.ID,
				APICodes: e.APICodes,
				IsGlobal: e.IsGlobal(),
			})
		}
	}
	return processors.WithWhitelistContext(ctx, hits)
}
// lookupEntries returns the snapshot entries stored under idCardHash that apply to
// userID: entries owned by that user plus global entries (those whose UserID is
// entities.QueryWhitelistGlobalUserID). It returns nil, nil when nothing is stored
// under the hash, and propagates any error from loading the snapshot.
func (s *QueryWhitelistServiceImpl) lookupEntries(ctx context.Context, userID, idCardHash string) ([]*entities.QueryWhitelistEntry, error) {
	snap, err := s.getSnapshot(ctx)
	if err != nil {
		return nil, err
	}

	bucket := snap.byHash[idCardHash]
	if len(bucket) == 0 {
		return nil, nil
	}

	applicable := make([]*entities.QueryWhitelistEntry, 0, len(bucket))
	for _, e := range bucket {
		switch e.UserID {
		case userID, entities.QueryWhitelistGlobalUserID:
			applicable = append(applicable, e)
		}
	}
	return applicable, nil
}
// getSnapshot returns the currently published snapshot. When none is published
// (never loaded, or dropped), it loads one synchronously via reloadSnapshot.
func (s *QueryWhitelistServiceImpl) getSnapshot(ctx context.Context) (*queryWhitelistSnapshot, error) {
	snap := s.snapshot.Load()
	if snap == nil {
		return s.reloadSnapshot(ctx)
	}
	return snap, nil
}
// reloadSnapshot loads all enabled entries from the repository, indexes them by
// IDCardHash and publishes the result as the current snapshot.
//
// snapshotMu serializes loaders. The snapshot is re-checked after the lock is taken,
// so callers that were waiting reuse a snapshot published in the meantime instead of
// querying the repository again. Repository errors are returned as-is and nothing is
// published.
func (s *QueryWhitelistServiceImpl) reloadSnapshot(ctx context.Context) (*queryWhitelistSnapshot, error) {
	s.snapshotMu.Lock()
	defer s.snapshotMu.Unlock()

	if existing := s.snapshot.Load(); existing != nil {
		return existing, nil
	}

	all, err := s.repo.FindAllEnabled(ctx)
	if err != nil {
		return nil, err
	}

	index := make(map[string][]*entities.QueryWhitelistEntry, len(all))
	for _, e := range all {
		key := e.IDCardHash
		index[key] = append(index[key], e)
	}

	fresh := &queryWhitelistSnapshot{byHash: index}
	s.snapshot.Store(fresh)
	s.logger.Info("查询白名单快照已加载", zap.Int("entries", len(all)), zap.Int("hash_buckets", len(index)))
	return fresh, nil
}
// refreshSnapshotAsync rebuilds the whitelist snapshot in a background goroutine after
// an admin-side change, so the write request that triggered it is not blocked on the
// repository query (bounded by a 10s timeout).
//
// The previous snapshot stays published while the new one is loading. Concurrent
// EnrichContext calls therefore keep reading it lock-free, rather than seeing nil,
// falling into reloadSnapshot and waiting on snapshotMu for the whole query. On
// success the new snapshot replaces the old one atomically. On failure the snapshot
// is dropped, so the next request reloads it synchronously instead of continuing to
// serve rules known to be out of date.
//
// Refreshes and synchronous reloads are serialized by snapshotMu.
func (s *QueryWhitelistServiceImpl) refreshSnapshotAsync() {
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		s.snapshotMu.Lock()
		defer s.snapshotMu.Unlock()

		// Load before touching the published snapshot; clearing it first would push
		// every concurrent reader onto the mutex for the duration of this query.
		entries, err := s.repo.FindAllEnabled(ctx)
		if err != nil {
			// Drop the stale snapshot so the next request retries the load.
			s.snapshot.Store(nil)
			s.logger.Error("刷新查询白名单快照失败", zap.Error(err))
			return
		}

		byHash := make(map[string][]*entities.QueryWhitelistEntry, len(entries))
		for _, entry := range entries {
			byHash[entry.IDCardHash] = append(byHash[entry.IDCardHash], entry)
		}
		s.snapshot.Store(&queryWhitelistSnapshot{byHash: byHash})
		s.logger.Info("查询白名单快照已刷新", zap.Int("entries", len(entries)))
	}()
}
// InvalidateCache is called after whitelist data changes for a user / ID-card hash.
// The snapshot is not partitioned by either key, so both arguments are ignored and the
// whole snapshot is refreshed, exactly as InvalidateAllCache does.
func (s *QueryWhitelistServiceImpl) InvalidateCache(_, _ string) {
	s.InvalidateAllCache()
}

// InvalidateAllCache schedules a background refresh of the whole whitelist snapshot
// via refreshSnapshotAsync and returns without waiting for it.
func (s *QueryWhitelistServiceImpl) InvalidateAllCache() {
	s.refreshSnapshotAsync()
}