package api import ( "context" "crypto/subtle" "encoding/json" "strings" "tyapi-server/internal/application/api/dto" "tyapi-server/internal/config" "tyapi-server/internal/domains/api/entities" api_services "tyapi-server/internal/domains/api/services" "tyapi-server/internal/shared/crypto" "fmt" "go.uber.org/zap" ) const queryWhitelistMgmtKeyHeader = "Whitelist-Mgmt-Key" // QueryWhitelistMgmtKeyHeader 公开接口管理密钥请求头名 func QueryWhitelistMgmtKeyHeader() string { return queryWhitelistMgmtKeyHeader } // CreateEntryPublic 公开接口:解密业务参数后为当前 access_id 对应用户添加规则(仅对该用户生效) func (s *QueryWhitelistApplicationServiceImpl) CreateEntryPublic( ctx context.Context, headerAccessID, managementKey, clientIP, encryptedData string, ) (*dto.QueryWhitelistEntryResponse, string, error) { if !s.config.QueryWhitelist.PublicAPI.Enabled { return nil, "", ErrPublicAPIDisabled } if strings.TrimSpace(managementKey) == "" { return nil, "", ErrMissingMgmtKey } if !constantTimeEqual(managementKey, s.config.QueryWhitelist.PublicAPI.ManagementKey) { return nil, "", ErrInvalidMgmtKey } headerAccessID = strings.TrimSpace(headerAccessID) if headerAccessID == "" { return nil, "", ErrMissingAccessId } apiUser, err := s.apiUserService.LoadApiUserByAccessId(ctx, headerAccessID) if err != nil { s.logger.Warn("公开白名单接口 AccessId 无效", zap.String("access_id", headerAccessID), zap.Error(err)) return nil, "", ErrInvalidAccessId } if apiUser.IsFrozen() { return nil, "", ErrFrozenAccount } decrypted, err := crypto.AesDecrypt(encryptedData, apiUser.SecretKey) if err != nil { s.logger.Warn("公开白名单接口解密失败", zap.String("access_id", headerAccessID), zap.Error(err)) return nil, "", ErrDecryptFail } var raw map[string]json.RawMessage if err := json.Unmarshal(decrypted, &raw); err != nil { return nil, "", ErrRequestParam } if err := validatePublicAPICodesField(raw["api_codes"]); err != nil { return nil, "", MapWhitelistAppError(err) } var payload dto.QueryWhitelistPublicPayload if err := json.Unmarshal(decrypted, &payload); err != nil { return nil, "", ErrRequestParam } if !s.config.App.IsDevelopment() && !apiUser.IsWhiteListed(clientIP) { s.logger.Warn("公开白名单接口 IP 未授权", zap.String("access_id", headerAccessID), zap.String("client_ip", clientIP)) return nil, "", ErrInvalidIP } createdBy := "public_api:" + headerAccessID resp, err := s.createEntry(ctx, createdBy, &dto.QueryWhitelistEntryRequest{ UserID: apiUser.UserId, Name: payload.Name, IDCard: payload.IDCard, APICodes: payload.APICodes, Remark: payload.Remark, }, clientIP) return resp, apiUser.SecretKey, MapWhitelistAppError(err) } func (s *QueryWhitelistApplicationServiceImpl) createEntry( ctx context.Context, createdBy string, req *dto.QueryWhitelistEntryRequest, operationIP string, ) (*dto.QueryWhitelistEntryResponse, error) { if err := validateQueryWhitelistRequest(req.UserID, req.Name, req.IDCard, req.APICodes); err != nil { return nil, err } idCardHash := api_services.HashIDCard(req.IDCard) exists, err := s.repo.ExistsByUserIDCardHashAndName(ctx, req.UserID, idCardHash, strings.TrimSpace(req.Name), "") if err != nil { return nil, err } if exists { return nil, fmt.Errorf("该用户下已存在相同的身份证与姓名规则") } entry := &entities.QueryWhitelistEntry{ UserID: strings.TrimSpace(req.UserID), Name: normalizeWhitelistName(req.Name), IDCardHash: idCardHash, IDCardMasked: api_services.MaskIDCard(req.IDCard), APICodes: entities.APICodeList(req.APICodes), Status: entities.QueryWhitelistStatusEnabled, Remark: strings.TrimSpace(req.Remark), OperationIP: strings.TrimSpace(operationIP), CreatedBy: &createdBy, } if err := s.repo.Create(ctx, entry); err != nil { return nil, err } s.queryWhitelistSvc.InvalidateCache(entry.UserID, idCardHash) resp := dto.NewQueryWhitelistEntryResponse(entry) return &resp, nil } func constantTimeEqual(a, b string) bool { return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 } // validatePublicAPICodesField 公开接口:api_codes 必须为 string 数组,且禁止通配符 * func validatePublicAPICodesField(raw json.RawMessage) error { if len(raw) == 0 { return fmt.Errorf("api_codes 必填且必须为 string 数组") } var probe interface{} if err := json.Unmarshal(raw, &probe); err != nil { return fmt.Errorf("api_codes 必须为 string 数组") } arr, ok := probe.([]interface{}) if !ok { return fmt.Errorf("api_codes 必须为 string 数组,不可传字符串或其他类型") } if len(arr) == 0 { return fmt.Errorf("api_codes 不能为空") } codes := make([]string, 0, len(arr)) for _, item := range arr { s, ok := item.(string) if !ok { return fmt.Errorf("api_codes 数组元素必须为 string") } code := strings.TrimSpace(s) if code == "" { return fmt.Errorf("api_codes 不能包含空值") } if code == "*" { return fmt.Errorf("公开接口不允许 api_codes 使用通配符 *") } codes = append(codes, code) } return validateAPICodes(codes) } // ValidatePublicAPIManagementKey 校验平台管理密钥 func ValidatePublicAPIManagementKey(cfg *config.Config, key string) error { if !cfg.QueryWhitelist.PublicAPI.Enabled { return ErrPublicAPIDisabled } if strings.TrimSpace(key) == "" { return ErrMissingMgmtKey } if !constantTimeEqual(key, cfg.QueryWhitelist.PublicAPI.ManagementKey) { return ErrInvalidMgmtKey } return nil } // MapWhitelistAppError 将校验错误映射为公开 API 错误码 func MapWhitelistAppError(err error) error { if err == nil { return nil } msg := err.Error() switch { case strings.Contains(msg, "身份证号格式不正确"), strings.Contains(msg, "api_codes"), strings.Contains(msg, "name 不能为空"), strings.Contains(msg, "user_id 不能为空"): return ErrRequestParam case strings.Contains(msg, "已存在"): return ErrWhitelistExists default: return err } } // PublicAPIErrorMessage 公开接口错误文案 func PublicAPIErrorMessage(err error) string { if err == nil { return "" } switch err { case ErrWhitelistExists: return "规则已存在" case ErrRequestParam: return "请求参数结构不正确" default: return GetErrorMessage(err) } }