diff --git a/internal/application/api/api_application_service.go b/internal/application/api/api_application_service.go index 23e9d50..75023e2 100644 --- a/internal/application/api/api_application_service.go +++ b/internal/application/api/api_application_service.go @@ -370,10 +370,9 @@ func extractParentAccessID(params map[string]interface{}) (string, bool) { // callExternalApi 同步调用外部API func (s *ApiApplicationServiceImpl) callExternalApi(ctx context.Context, cmd *commands.ApiCallCommand, validation *dto.ApiCallValidationResult) (string, error) { - // 查询白名单拦截:命中则返回「查询为空」,不调用上游、不扣费 - if s.queryWhitelistSvc != nil && - s.queryWhitelistSvc.ShouldReturnEmpty(ctx, validation.GetUserID(), cmd.ApiName, validation.RequestParams) { - return "", ErrQueryEmpty + // 查询白名单:应用层判断入参是否命中,并将命中的 api_codes 写入 context + if s.queryWhitelistSvc != nil { + ctx = s.queryWhitelistSvc.EnrichContext(ctx, validation.GetUserID(), validation.RequestParams) } // 创建CallContext diff --git a/internal/domains/api/services/api_request_service.go b/internal/domains/api/services/api_request_service.go index 36b6d29..11929e5 100644 --- a/internal/domains/api/services/api_request_service.go +++ b/internal/domains/api/services/api_request_service.go @@ -427,13 +427,26 @@ func registerAllProcessors(combService *comb.CombService) { "PDFG01GZ": pdfg.ProcessPDFG01GZRequest, } - // 批量注册到组合包服务 + // 批量注册到组合包服务(包装白名单:读 ctx 中命中的 api_code,嵌套子调用按子 api_code 判断) for apiCode, processor := range processorMap { - combService.RegisterProcessor(apiCode, processor) + wrapped := wrapProcessorWithWhitelist(apiCode, processor) + combService.RegisterProcessor(apiCode, wrapped) } // 同时设置全局处理器映射 - RequestProcessors = processorMap + RequestProcessors = make(map[string]processors.ProcessorFunc, len(processorMap)) + for apiCode, processor := range processorMap { + RequestProcessors[apiCode] = wrapProcessorWithWhitelist(apiCode, processor) + } +} + +func wrapProcessorWithWhitelist(apiCode string, processor processors.ProcessorFunc) processors.ProcessorFunc { + return func(ctx context.Context, params []byte, deps *processors.ProcessorDependencies) ([]byte, error) { + if processors.WhitelistShouldReturnEmpty(ctx, apiCode) { + return nil, processors.ErrNotFound + } + return processor(ctx, params, deps) + } } // 注册API处理器 - 现在通过registerAllProcessors统一管理 diff --git a/internal/domains/api/services/processors/dwbg/dwbg8b4d_processor.go b/internal/domains/api/services/processors/dwbg/dwbg8b4d_processor.go index 77d4421..ba9535b 100644 --- a/internal/domains/api/services/processors/dwbg/dwbg8b4d_processor.go +++ b/internal/domains/api/services/processors/dwbg/dwbg8b4d_processor.go @@ -255,28 +255,17 @@ func collectAPIData(ctx context.Context, params dto.DWBG8B4DReq, deps *processor return apiData } -// callProcessor 调用指定的处理器 +// callProcessor 调用指定的处理器(走注册表,含白名单包装) func callProcessor(ctx context.Context, apiCode string, params []byte, deps *processors.ProcessorDependencies) (interface{}, error) { - // 通过CombService获取处理器 - if combSvc, ok := deps.CombService.(interface { - GetProcessor(apiCode string) (processors.ProcessorFunc, bool) - }); ok { - processor, exists := combSvc.GetProcessor(apiCode) - if !exists { - return nil, fmt.Errorf("未找到处理器: %s", apiCode) - } - respBytes, err := processor(ctx, params, deps) - if err != nil { - return nil, err - } - var data interface{} - if err := json.Unmarshal(respBytes, &data); err != nil { - return nil, fmt.Errorf("解析响应失败: %w", err) - } - return data, nil + respBytes, err := processors.InvokeRegisteredProcessor(ctx, apiCode, params, deps) + if err != nil { + return nil, err } - - return nil, fmt.Errorf("无法获取处理器: %s,CombService不支持GetProcessor方法", apiCode) + var data interface{} + if err := json.Unmarshal(respBytes, &data); err != nil { + return nil, fmt.Errorf("解析响应失败: %w", err) + } + return data, nil } // exportAPIDataToJSON 将API数据导出为JSON文件,方便调试 diff --git a/internal/domains/api/services/processors/invoke.go b/internal/domains/api/services/processors/invoke.go new file mode 100644 index 0000000..964b0d1 --- /dev/null +++ b/internal/domains/api/services/processors/invoke.go @@ -0,0 +1,32 @@ +package processors + +import ( + "context" + "fmt" +) + +// ProcessorRegistry 已注册处理器查询(由 CombService 实现) +type ProcessorRegistry interface { + GetProcessor(apiCode string) (ProcessorFunc, bool) +} + +// InvokeRegisteredProcessor 通过注册表调用处理器(含白名单包装),聚合处理器内部转接应使用此方法。 +func InvokeRegisteredProcessor( + ctx context.Context, + apiCode string, + params []byte, + deps *ProcessorDependencies, +) ([]byte, error) { + if deps == nil || deps.CombService == nil { + return nil, fmt.Errorf("CombService 未配置,无法调用处理器: %s", apiCode) + } + registry, ok := deps.CombService.(ProcessorRegistry) + if !ok { + return nil, fmt.Errorf("CombService 不支持 GetProcessor,无法调用处理器: %s", apiCode) + } + processor, exists := registry.GetProcessor(apiCode) + if !exists { + return nil, fmt.Errorf("未找到处理器: %s", apiCode) + } + return processor(ctx, params, deps) +} diff --git a/internal/domains/api/services/processors/pdfg/pdfg01gz_processor.go b/internal/domains/api/services/processors/pdfg/pdfg01gz_processor.go index 7601082..918c871 100644 --- a/internal/domains/api/services/processors/pdfg/pdfg01gz_processor.go +++ b/internal/domains/api/services/processors/pdfg/pdfg01gz_processor.go @@ -349,29 +349,17 @@ func collectAPIData(ctx context.Context, params dto.PDFG01GZReq, deps *processor return apiData } -// callProcessor 调用指定的处理器 +// callProcessor 调用指定的处理器(走注册表,含白名单包装) func callProcessor(ctx context.Context, apiCode string, params []byte, deps *processors.ProcessorDependencies) (interface{}, error) { - // 通过CombService获取处理器 - if combSvc, ok := deps.CombService.(interface { - GetProcessor(apiCode string) (processors.ProcessorFunc, bool) - }); ok { - processor, exists := combSvc.GetProcessor(apiCode) - if !exists { - return nil, fmt.Errorf("未找到处理器: %s", apiCode) - } - respBytes, err := processor(ctx, params, deps) - if err != nil { - return nil, err - } - var data interface{} - if err := json.Unmarshal(respBytes, &data); err != nil { - return nil, fmt.Errorf("解析响应失败: %w", err) - } - return data, nil + respBytes, err := processors.InvokeRegisteredProcessor(ctx, apiCode, params, deps) + if err != nil { + return nil, err } - - // 如果无法通过CombService获取,返回错误 - return nil, fmt.Errorf("无法获取处理器: %s,CombService不支持GetProcessor方法", apiCode) + var data interface{} + if err := json.Unmarshal(respBytes, &data); err != nil { + return nil, fmt.Errorf("解析响应失败: %w", err) + } + return data, nil } // formatDataForPDF 格式化数据为PDF生成服务需要的格式 diff --git a/internal/domains/api/services/processors/qcxg/qcxg4d2e_processor.go b/internal/domains/api/services/processors/qcxg/qcxg4d2e_processor.go index 02c789a..8295a48 100644 --- a/internal/domains/api/services/processors/qcxg/qcxg4d2e_processor.go +++ b/internal/domains/api/services/processors/qcxg/qcxg4d2e_processor.go @@ -25,7 +25,7 @@ func ProcessQCXG4D2ERequest(ctx context.Context, params []byte, deps *processors return nil, errors.Join(processors.ErrSystem, err) } - raw, err := ProcessQCXGM4CLRequest(ctx, m4clParams, deps) + raw, err := processors.InvokeRegisteredProcessor(ctx, "QCXGM4CL", m4clParams, deps) if err != nil { return nil, err } diff --git a/internal/domains/api/services/processors/qcxg/qcxg5f3a_processor.go b/internal/domains/api/services/processors/qcxg/qcxg5f3a_processor.go index f7c1b9c..6e63093 100644 --- a/internal/domains/api/services/processors/qcxg/qcxg5f3a_processor.go +++ b/internal/domains/api/services/processors/qcxg/qcxg5f3a_processor.go @@ -28,7 +28,7 @@ func ProcessQCXG5F3ARequest(ctx context.Context, params []byte, deps *processors return nil, errors.Join(processors.ErrSystem, err) } - raw, err := ProcessQCXGM4CLRequest(ctx, m4clParams, deps) + raw, err := processors.InvokeRegisteredProcessor(ctx, "QCXGM4CL", m4clParams, deps) if err != nil { return nil, err } diff --git a/internal/domains/api/services/processors/qcxg/qcxg9p1c_processor.go b/internal/domains/api/services/processors/qcxg/qcxg9p1c_processor.go index d3ab4bf..a5d52fc 100644 --- a/internal/domains/api/services/processors/qcxg/qcxg9p1c_processor.go +++ b/internal/domains/api/services/processors/qcxg/qcxg9p1c_processor.go @@ -28,7 +28,7 @@ func ProcessQCXG9P1CRequest(ctx context.Context, params []byte, deps *processors return nil, errors.Join(processors.ErrSystem, err) } - raw, err := ProcessQCXGM4CLRequest(ctx, m4clParams, deps) + raw, err := processors.InvokeRegisteredProcessor(ctx, "QCXGM4CL", m4clParams, deps) if err != nil { return nil, err } diff --git a/internal/domains/api/services/processors/qygl/qygl3f8e_processor.go b/internal/domains/api/services/processors/qygl/qygl3f8e_processor.go index 036a4f0..563fc05 100644 --- a/internal/domains/api/services/processors/qygl/qygl3f8e_processor.go +++ b/internal/domains/api/services/processors/qygl/qygl3f8e_processor.go @@ -50,7 +50,7 @@ func ProcessQYGL3F8ERequest(ctx context.Context, params []byte, deps *processors return nil, errors.Join(processors.ErrSystem, err) } - b4c0Response, err := ProcessQYGL6S1BRequest(ctx, b4c0ParamsBytes, deps) + b4c0Response, err := processors.InvokeRegisteredProcessor(ctx, "QYGL6S1B", b4c0ParamsBytes, deps) if err != nil { log.Error("QYGL3F8E调用QYGL6S1B失败", zap.Error(err)) return nil, err // 错误已经是处理器标准错误,直接返回 @@ -620,25 +620,7 @@ func callProcessorSafely(ctx context.Context, processorType, entCode string, dep } var response []byte - switch processorType { - case "QYGL5A3C": - response, err = ProcessQYGL5A3CRequest(ctx, paramsBytes, deps) - case "QYGL8B4D": - response, err = ProcessQYGL8B4DRequest(ctx, paramsBytes, deps) - case "QYGL9E2F": - response, err = ProcessQYGL9E2FRequest(ctx, paramsBytes, deps) - case "QYGL7C1A": - response, err = ProcessQYGL7C1ARequest(ctx, paramsBytes, deps) - case "QYGL7D9A": - response, err = ProcessQYGL7D9ARequest(ctx, paramsBytes, deps) - case "QYGL4B2E": - response, err = ProcessQYGL4B2ERequest(ctx, paramsBytes, deps) - default: - log.Warn("QYGL3F8E未知的处理器类型", - zap.String("processor_type", processorType), - ) - return map[string]interface{}{} - } + response, err = processors.InvokeRegisteredProcessor(ctx, processorType, paramsBytes, deps) if err != nil { // 如果是查询为空错误,返回空对象 @@ -681,7 +663,7 @@ func callQYGL5S1IProcessorSafely(ctx context.Context, entCode string, entName st if err != nil { return map[string]interface{}{} } - response, err := ProcessQYGL5S1IRequest(ctx, paramsBytes, deps) + response, err := processors.InvokeRegisteredProcessor(ctx, "QYGL5S1I", paramsBytes, deps) if err != nil { return map[string]interface{}{} } diff --git a/internal/domains/api/services/processors/qygl/qyglj1u9_processor.go b/internal/domains/api/services/processors/qygl/qyglj1u9_processor.go index 0c5d3f6..e4328d4 100644 --- a/internal/domains/api/services/processors/qygl/qyglj1u9_processor.go +++ b/internal/domains/api/services/processors/qygl/qyglj1u9_processor.go @@ -39,7 +39,7 @@ func ProcessQYGLJ1U9Request(ctx context.Context, params []byte, deps *processors resultsCh := make(chan apiResult, 7) var wg sync.WaitGroup - call := func(key string, req interface{}, fn func(context.Context, []byte, *processors.ProcessorDependencies) ([]byte, error)) { + call := func(key, apiCode string, req interface{}) { wg.Add(1) go func() { defer wg.Done() @@ -48,7 +48,7 @@ func ProcessQYGLJ1U9Request(ctx context.Context, params []byte, deps *processors resultsCh <- apiResult{key: key, err: err} return } - resp, err := fn(ctx, b, deps) + resp, err := processors.InvokeRegisteredProcessor(ctx, apiCode, b, deps) if err != nil { resultsCh <- apiResult{key: key, err: err} return @@ -70,46 +70,46 @@ func ProcessQYGLJ1U9Request(ctx context.Context, params []byte, deps *processors } // 企业全量信息核验V2(QYGLUY3S) - call("jiguangFull", map[string]interface{}{ + call("jiguangFull", "QYGLUY3S", map[string]interface{}{ "ent_name": p.EntName, "ent_code": p.EntCode, - }, ProcessQYGLUY3SRequest) + }) // 企业股权结构全景(QYGLJ0Q1) - call("equityPanorama", map[string]interface{}{ + call("equityPanorama", "QYGLJ0Q1", map[string]interface{}{ "ent_name": p.EntName, - }, ProcessQYGLJ0Q1Request) + }) // 企业司法涉诉V2(QYGL5S1I) - call("judicialCertFull", map[string]interface{}{ + call("judicialCertFull", "QYGL5S1I", map[string]interface{}{ "ent_name": p.EntName, "ent_code": p.EntCode, - }, ProcessQYGL5S1IRequest) + }) // 企业年报信息核验(QYGLDJ12) - call("annualReport", map[string]interface{}{ + call("annualReport", "QYGLDJ12", map[string]interface{}{ "ent_name": p.EntName, "ent_code": p.EntCode, - }, ProcessQYGLDJ12Request) + }) // 企业税收违法核查(QYGL8848) - call("taxViolation", map[string]interface{}{ + call("taxViolation", "QYGL8848", map[string]interface{}{ "ent_name": p.EntName, "ent_code": p.EntCode, - }, ProcessQYGL8848Request) + }) // 欠税公告(QYGL7D9A,天眼查 OwnTax,keyword 为统一社会信用代码) - call("taxArrears", map[string]interface{}{ + call("taxArrears", "QYGL7D9A", map[string]interface{}{ "ent_code": p.EntCode, "page_size": 20, "page_num": 1, - }, ProcessQYGL7D9ARequest) + }) // 企业进出口信用核查(QYGLDJ33) - call("customsCredit", map[string]interface{}{ + call("customsCredit", "QYGLDJ33", map[string]interface{}{ "ent_name": p.EntName, "ent_code": p.EntCode, - }, ProcessQYGLDJ33Request) + }) wg.Wait() close(resultsCh) diff --git a/internal/domains/api/services/processors/whitelist_context.go b/internal/domains/api/services/processors/whitelist_context.go new file mode 100644 index 0000000..aca0510 --- /dev/null +++ b/internal/domains/api/services/processors/whitelist_context.go @@ -0,0 +1,58 @@ +package processors + +import "context" + +type whitelistContextKey struct{} + +// WhitelistMatch 请求级白名单命中项(身份匹配,不含当前顶层 api_code 过滤) +type WhitelistMatch struct { + ID string + APICodes []string + IsGlobal bool +} + +// WhitelistContext 写入 context 的白名单状态,供各处理器按 api_code 判断是否返回查询为空 +type WhitelistContext struct { + Matches []WhitelistMatch +} + +// WithWhitelistContext 将白名单命中结果写入 context +func WithWhitelistContext(ctx context.Context, matches []WhitelistMatch) context.Context { + if len(matches) == 0 { + return ctx + } + return context.WithValue(ctx, whitelistContextKey{}, &WhitelistContext{Matches: matches}) +} + +// WhitelistFromContext 从 context 读取白名单状态 +func WhitelistFromContext(ctx context.Context) *WhitelistContext { + wc, ok := ctx.Value(whitelistContextKey{}).(*WhitelistContext) + if !ok || wc == nil { + return nil + } + return wc +} + +// WhitelistShouldReturnEmpty 根据 context 中的白名单与当前 api_code 判断是否应返回查询为空。 +// 入参是否命中、命中哪些 api_code 由应用层 EnrichContext 写入;此处仅读 ctx。 +func WhitelistShouldReturnEmpty(ctx context.Context, apiCode string) bool { + wc := WhitelistFromContext(ctx) + if wc == nil { + return false + } + for _, m := range wc.Matches { + if matchesWhitelistAPICode(m.APICodes, apiCode) { + return true + } + } + return false +} + +func matchesWhitelistAPICode(apiCodes []string, apiCode string) bool { + for _, code := range apiCodes { + if code == "*" || code == apiCode { + return true + } + } + return false +} diff --git a/internal/domains/api/services/processors/whitelist_context_test.go b/internal/domains/api/services/processors/whitelist_context_test.go new file mode 100644 index 0000000..1272536 --- /dev/null +++ b/internal/domains/api/services/processors/whitelist_context_test.go @@ -0,0 +1,32 @@ +package processors + +import ( + "context" + "testing" +) + +func TestWhitelistShouldReturnEmpty_PerAPICode(t *testing.T) { + ctx := WithWhitelistContext(context.Background(), []WhitelistMatch{ + {ID: "1", APICodes: []string{"FLXG0V4B"}}, + }) + + if !WhitelistShouldReturnEmpty(ctx, "FLXG0V4B") { + t.Fatal("FLXG0V4B should hit") + } + if WhitelistShouldReturnEmpty(ctx, "JRZQ8A2D") { + t.Fatal("JRZQ8A2D should not hit entry scoped to FLXG0V4B") + } +} + +func TestWhitelistShouldReturnEmpty_WildcardMatchesAnyAPICode(t *testing.T) { + ctx := WithWhitelistContext(context.Background(), []WhitelistMatch{ + {ID: "1", APICodes: []string{"*"}}, + }) + + if !WhitelistShouldReturnEmpty(ctx, "QYGL8261") { + t.Fatal("wildcard should match any api_code in ctx") + } + if !WhitelistShouldReturnEmpty(ctx, "YYSY8B1C") { + t.Fatal("wildcard should match mobile-only api_code when in ctx") + } +} diff --git a/internal/domains/api/services/processors/yysy/yysy8c2d_processor.go b/internal/domains/api/services/processors/yysy/yysy8c2d_processor.go index ab6db2f..d8a820a 100644 --- a/internal/domains/api/services/processors/yysy/yysy8c2d_processor.go +++ b/internal/domains/api/services/processors/yysy/yysy8c2d_processor.go @@ -9,6 +9,6 @@ import ( // ProcessYYSY8C2DRequest YYSY8C2D API处理方法 - 运营商三要素查询 func ProcessYYSY8C2DRequest(ctx context.Context, params []byte, deps *processors.ProcessorDependencies) ([]byte, error) { - return ProcessYYSY9A1BRequest(ctx, params, deps) + return processors.InvokeRegisteredProcessor(ctx, "YYSY9A1B", params, deps) } diff --git a/internal/domains/api/services/query_whitelist_identity_api_test.go b/internal/domains/api/services/query_whitelist_identity_api_test.go index 7a0c07c..4794392 100644 --- a/internal/domains/api/services/query_whitelist_identity_api_test.go +++ b/internal/domains/api/services/query_whitelist_identity_api_test.go @@ -5,6 +5,7 @@ import ( "testing" "tyapi-server/internal/domains/api/entities" + "tyapi-server/internal/domains/api/services/processors" ) func TestRequiresIdentityInput_FLXG0V4B(t *testing.T) { @@ -29,7 +30,7 @@ func TestRequiresIdentityInput_COMB(t *testing.T) { } } -func TestShouldReturnEmpty_SkipsNonIdentityAPIEvenWithWildcard(t *testing.T) { +func TestEnrichContext_WildcardAppliesToConfiguredAPICodesInProcessor(t *testing.T) { idCard := "350681198611130611" hash := HashIDCard(idCard) svc := newTestQueryWhitelistService(&mockQueryWhitelistRepo{ @@ -43,10 +44,11 @@ func TestShouldReturnEmpty_SkipsNonIdentityAPIEvenWithWildcard(t *testing.T) { Status: entities.QueryWhitelistStatusEnabled, }, }, - }, false) + }) params := map[string]interface{}{"id_card": idCard, "name": "张三"} - if svc.ShouldReturnEmpty(context.Background(), "user-a", "QYGL8261", params) { - t.Fatal("non-identity API should not be intercepted even with api_codes=*") + ctx := svc.EnrichContext(context.Background(), "user-a", params) + if !processors.WhitelistShouldReturnEmpty(ctx, "YYSY8B1C") { + t.Fatal("processor layer should honor wildcard api_codes from ctx") } } diff --git a/internal/domains/api/services/query_whitelist_service.go b/internal/domains/api/services/query_whitelist_service.go index b968855..1a0c5af 100644 --- a/internal/domains/api/services/query_whitelist_service.go +++ b/internal/domains/api/services/query_whitelist_service.go @@ -8,12 +8,14 @@ import ( "tyapi-server/internal/domains/api/entities" "tyapi-server/internal/domains/api/repositories" + "tyapi-server/internal/domains/api/services/processors" "go.uber.org/zap" ) type QueryWhitelistService interface { - ShouldReturnEmpty(ctx context.Context, userID, apiCode string, params map[string]interface{}) bool + // EnrichContext 入参(姓名+身份证)命中白名单时,将命中的 api_codes 写入 context + EnrichContext(ctx context.Context, userID string, params map[string]interface{}) context.Context InvalidateCache(userID, idCardHash string) InvalidateAllCache() } @@ -24,84 +26,66 @@ type queryWhitelistSnapshot struct { } type QueryWhitelistServiceImpl struct { - repo repositories.QueryWhitelistRepository - formConfigService FormConfigService - logger *zap.Logger + repo repositories.QueryWhitelistRepository + logger *zap.Logger snapshot atomic.Pointer[queryWhitelistSnapshot] snapshotMu sync.Mutex - - // apiCode -> 是否要求身份证入参(FormConfig 反射结果,进程内永久缓存) - identityAPICache sync.Map } func NewQueryWhitelistService( repo repositories.QueryWhitelistRepository, - formConfigService FormConfigService, + _ FormConfigService, logger *zap.Logger, ) QueryWhitelistService { s := &QueryWhitelistServiceImpl{ - repo: repo, - formConfigService: formConfigService, - logger: logger, + repo: repo, + logger: logger, } return s } -// ShouldReturnEmpty 检查是否应返回「查询为空」。 -// 热路径:入参提取 → API 类型缓存 → 内存快照匹配,不逐请求查库。 -func (s *QueryWhitelistServiceImpl) ShouldReturnEmpty( +// EnrichContext 判断入参是否命中白名单,并将命中的 api_codes 写入 context,不拦截请求。 +// 热路径:姓名+身份证提取 → 内存快照匹配 → 写入 ctx,由各处理器按 api_code 返回查询为空。 +func (s *QueryWhitelistServiceImpl) EnrichContext( ctx context.Context, - userID, apiCode string, + userID string, params map[string]interface{}, -) bool { +) context.Context { identity := ExtractIdentityParams(params) if !identity.OK { - return false - } - if !s.requiresIdentityInput(ctx, apiCode) { - return false + return ctx } idCardHash := HashIDCard(identity.IDCard) entries, err := s.lookupEntries(ctx, userID, idCardHash) if err != nil { s.logger.Error("查询白名单快照失败", zap.Error(err), zap.String("user_id", userID)) - return false + return ctx } + matches := make([]processors.WhitelistMatch, 0, len(entries)) for _, entry := range entries { if !entry.IsEnabled() { continue } - if !entry.MatchesAPICode(apiCode) { - continue - } if !entry.MatchesName(identity.Name) { continue } s.logger.Info("命中查询白名单", zap.String("user_id", userID), - zap.String("api_code", apiCode), zap.String("whitelist_id", entry.ID), zap.Bool("is_global", entry.IsGlobal()), + zap.Strings("api_codes", entry.APICodes), ) - return true + matches = append(matches, processors.WhitelistMatch{ + ID: entry.ID, + APICodes: entry.APICodes, + IsGlobal: entry.IsGlobal(), + }) } - return false -} - -func (s *QueryWhitelistServiceImpl) requiresIdentityInput(ctx context.Context, apiCode string) bool { - if s.formConfigService == nil { - return false - } - if cached, ok := s.identityAPICache.Load(apiCode); ok { - return cached.(bool) - } - result := s.formConfigService.RequiresIdentityInput(ctx, apiCode) - s.identityAPICache.Store(apiCode, result) - return result + return processors.WithWhitelistContext(ctx, matches) } func (s *QueryWhitelistServiceImpl) lookupEntries(ctx context.Context, userID, idCardHash string) ([]*entities.QueryWhitelistEntry, error) { diff --git a/internal/domains/api/services/query_whitelist_service_test.go b/internal/domains/api/services/query_whitelist_service_test.go index d530c1c..214e181 100644 --- a/internal/domains/api/services/query_whitelist_service_test.go +++ b/internal/domains/api/services/query_whitelist_service_test.go @@ -6,6 +6,7 @@ import ( "tyapi-server/internal/domains/api/entities" "tyapi-server/internal/domains/api/repositories" + "tyapi-server/internal/domains/api/services/processors" "tyapi-server/internal/shared/interfaces" "go.uber.org/zap" @@ -61,20 +62,13 @@ func (m *mockQueryWhitelistRepo) FindAllEnabled(ctx context.Context) ([]*entitie return result, nil } -type mockFormConfigService struct { - requiresIdentity bool +func newTestQueryWhitelistService(repo repositories.QueryWhitelistRepository) QueryWhitelistService { + return NewQueryWhitelistService(repo, nil, zap.NewNop()) } -func (m *mockFormConfigService) GetFormConfig(ctx context.Context, apiCode string) (*FormConfig, error) { - return nil, nil -} - -func (m *mockFormConfigService) RequiresIdentityInput(ctx context.Context, apiCode string) bool { - return m.requiresIdentity -} - -func newTestQueryWhitelistService(repo repositories.QueryWhitelistRepository, requiresIdentity bool) QueryWhitelistService { - return NewQueryWhitelistService(repo, &mockFormConfigService{requiresIdentity: requiresIdentity}, zap.NewNop()) +func whitelistShouldReturnEmpty(svc QueryWhitelistService, ctx context.Context, userID, apiCode string, params map[string]interface{}) bool { + ctx = svc.EnrichContext(ctx, userID, params) + return processors.WhitelistShouldReturnEmpty(ctx, apiCode) } func TestShouldReturnEmpty_GlobalRule(t *testing.T) { @@ -91,10 +85,10 @@ func TestShouldReturnEmpty_GlobalRule(t *testing.T) { Status: entities.QueryWhitelistStatusEnabled, }, }, - }, true) + }) params := map[string]interface{}{"id_card": idCard, "name": "任意姓名"} - if !svc.ShouldReturnEmpty(context.Background(), "user-a", "FLXG0V4B", params) { + if !whitelistShouldReturnEmpty(svc, context.Background(), "user-a", "FLXG0V4B", params) { t.Fatal("global rule should hit for any user") } } @@ -113,16 +107,16 @@ func TestShouldReturnEmpty_UserSpecificRule(t *testing.T) { Status: entities.QueryWhitelistStatusEnabled, }, }, - }, true) + }) params := map[string]interface{}{"id_card": idCard, "name": "张三"} - if !svc.ShouldReturnEmpty(context.Background(), "user-a", "FLXG0V4B", params) { + if !whitelistShouldReturnEmpty(svc, context.Background(), "user-a", "FLXG0V4B", params) { t.Fatal("user-a should hit") } - if svc.ShouldReturnEmpty(context.Background(), "user-b", "FLXG0V4B", params) { + if whitelistShouldReturnEmpty(svc, context.Background(), "user-b", "FLXG0V4B", params) { t.Fatal("user-b should not hit user-a rule") } - if svc.ShouldReturnEmpty(context.Background(), "user-a", "JRZQ8A2D", params) { + if whitelistShouldReturnEmpty(svc, context.Background(), "user-a", "JRZQ8A2D", params) { t.Fatal("wrong api code should not hit") } } @@ -141,10 +135,38 @@ func TestShouldReturnEmpty_NameMismatch(t *testing.T) { Status: entities.QueryWhitelistStatusEnabled, }, }, - }, true) + }) params := map[string]interface{}{"id_card": idCard, "name": "李四"} - if svc.ShouldReturnEmpty(context.Background(), "user-a", "FLXG0V4B", params) { - t.Fatal("name mismatch should not hit") + ctx := svc.EnrichContext(context.Background(), "user-a", params) + if processors.WhitelistFromContext(ctx) != nil { + t.Fatal("name mismatch should not enrich context") + } +} + +func TestEnrichContext_PartialAPICodeForComb(t *testing.T) { + idCard := "350681198611130611" + hash := HashIDCard(idCard) + svc := newTestQueryWhitelistService(&mockQueryWhitelistRepo{ + entries: []*entities.QueryWhitelistEntry{ + { + ID: "4", + UserID: "user-a", + Name: "张三", + IDCardHash: hash, + APICodes: entities.APICodeList{"FLXG0V4B"}, + Status: entities.QueryWhitelistStatusEnabled, + }, + }, + }) + + params := map[string]interface{}{"id_card": idCard, "name": "张三"} + ctx := svc.EnrichContext(context.Background(), "user-a", params) + + if !processors.WhitelistShouldReturnEmpty(ctx, "FLXG0V4B") { + t.Fatal("FLXG0V4B should return empty in comb sub-call") + } + if processors.WhitelistShouldReturnEmpty(ctx, "JRZQ8A2D") { + t.Fatal("JRZQ8A2D should still call upstream in comb sub-call") } }