Files
tyapi-server/internal/application/api/query_whitelist_public_service.go
2026-06-19 10:49:13 +08:00

214 lines
6.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package api
import (
	"context"
	"crypto/sha256"
	"crypto/subtle"
	"encoding/json"
	"fmt"
	"strings"

	"go.uber.org/zap"

	"tyapi-server/internal/application/api/dto"
	"tyapi-server/internal/config"
	"tyapi-server/internal/domains/api/entities"
	api_services "tyapi-server/internal/domains/api/services"
	"tyapi-server/internal/shared/crypto"
)
// queryWhitelistMgmtKeyHeader is the request header that carries the platform
// management key for the public whitelist API.
const queryWhitelistMgmtKeyHeader = "Whitelist-Mgmt-Key"

// QueryWhitelistMgmtKeyHeader returns the name of the request header from
// which the public whitelist API reads the management key.
func QueryWhitelistMgmtKeyHeader() string {
	name := queryWhitelistMgmtKeyHeader
	return name
}
// CreateEntryPublic is the public-API entry point for adding a whitelist rule.
// It checks that the public API is enabled, verifies the management key,
// resolves the API user from the Access-Id header, and enforces the frozen
// state and the per-user IP whitelist. It then decrypts the business payload
// with that user's secret key and creates a rule scoped to that user only.
//
// On success it returns the created entry and the API user's secret key. If
// rule creation itself fails, the secret key is still returned together with
// the mapped error (presumably so the handler can encrypt the error response —
// TODO confirm with the caller). Earlier failures return an empty key.
func (s *QueryWhitelistApplicationServiceImpl) CreateEntryPublic(
	ctx context.Context,
	headerAccessID, managementKey, clientIP, encryptedData string,
) (*dto.QueryWhitelistEntryResponse, string, error) {
	if !s.config.QueryWhitelist.PublicAPI.Enabled {
		return nil, "", ErrPublicAPIDisabled
	}
	if strings.TrimSpace(managementKey) == "" {
		return nil, "", ErrMissingMgmtKey
	}
	if !constantTimeEqual(managementKey, s.config.QueryWhitelist.PublicAPI.ManagementKey) {
		return nil, "", ErrInvalidMgmtKey
	}

	headerAccessID = strings.TrimSpace(headerAccessID)
	if headerAccessID == "" {
		return nil, "", ErrMissingAccessId
	}
	apiUser, err := s.apiUserService.LoadApiUserByAccessId(ctx, headerAccessID)
	if err != nil {
		s.logger.Warn("公开白名单接口 AccessId 无效", zap.String("access_id", headerAccessID), zap.Error(err))
		return nil, "", ErrInvalidAccessId
	}
	if apiUser.IsFrozen() {
		return nil, "", ErrFrozenAccount
	}

	// Enforce the per-user IP whitelist BEFORE touching the ciphertext. Checking
	// it after decryption would let a non-whitelisted client distinguish
	// "decrypt failed" / "bad params" from "IP not allowed". That is a
	// decryption/validation oracle, and it wastes work on unauthorized traffic.
	if !s.config.App.IsDevelopment() && !apiUser.IsWhiteListed(clientIP) {
		s.logger.Warn("公开白名单接口 IP 未授权",
			zap.String("access_id", headerAccessID),
			zap.String("client_ip", clientIP))
		return nil, "", ErrInvalidIP
	}

	decrypted, err := crypto.AesDecrypt(encryptedData, apiUser.SecretKey)
	if err != nil {
		s.logger.Warn("公开白名单接口解密失败", zap.String("access_id", headerAccessID), zap.Error(err))
		return nil, "", ErrDecryptFail
	}

	// Decode generically first so the raw JSON type of api_codes can be checked
	// (a string array is required, wildcard forbidden) before the typed decode.
	var raw map[string]json.RawMessage
	if err := json.Unmarshal(decrypted, &raw); err != nil {
		return nil, "", ErrRequestParam
	}
	if err := validatePublicAPICodesField(raw["api_codes"]); err != nil {
		return nil, "", MapWhitelistAppError(err)
	}
	var payload dto.QueryWhitelistPublicPayload
	if err := json.Unmarshal(decrypted, &payload); err != nil {
		return nil, "", ErrRequestParam
	}

	// Rules created through the public API are attributed to the calling access id.
	createdBy := "public_api:" + headerAccessID
	resp, err := s.createEntry(ctx, createdBy, &dto.QueryWhitelistEntryRequest{
		UserID:   apiUser.UserId,
		Name:     payload.Name,
		IDCard:   payload.IDCard,
		APICodes: payload.APICodes,
		Remark:   payload.Remark,
	}, clientIP)
	return resp, apiUser.SecretKey, MapWhitelistAppError(err)
}
// createEntry validates req and rejects a rule that duplicates an existing
// (user, ID-card hash, name) combination. Otherwise it persists a new enabled
// whitelist entry and invalidates the cached lookup for that user and ID card.
//
// createdBy is recorded as the entry's creator, and operationIP is stored
// trimmed. On a duplicate it returns an error whose message contains "已存在";
// MapWhitelistAppError relies on that substring to map it to ErrWhitelistExists.
func (s *QueryWhitelistApplicationServiceImpl) createEntry(
	ctx context.Context,
	createdBy string,
	req *dto.QueryWhitelistEntryRequest,
	operationIP string,
) (*dto.QueryWhitelistEntryResponse, error) {
	if err := validateQueryWhitelistRequest(req.UserID, req.Name, req.IDCard, req.APICodes); err != nil {
		return nil, err
	}

	// Normalize once and use the SAME values for the duplicate check and the
	// persisted row. Otherwise inputs that normalize differently (e.g. a user ID
	// with surrounding spaces, or a name that normalizeWhitelistName rewrites
	// beyond TrimSpace) could slip past the uniqueness check.
	userID := strings.TrimSpace(req.UserID)
	name := normalizeWhitelistName(req.Name)
	idCardHash := api_services.HashIDCard(req.IDCard)

	exists, err := s.repo.ExistsByUserIDCardHashAndName(ctx, userID, idCardHash, name, "")
	if err != nil {
		return nil, err
	}
	if exists {
		return nil, fmt.Errorf("该用户下已存在相同的身份证与姓名规则")
	}

	entry := &entities.QueryWhitelistEntry{
		UserID:       userID,
		Name:         name,
		IDCardHash:   idCardHash,
		IDCardMasked: api_services.MaskIDCard(req.IDCard),
		APICodes:     entities.APICodeList(req.APICodes),
		Status:       entities.QueryWhitelistStatusEnabled,
		Remark:       strings.TrimSpace(req.Remark),
		OperationIP:  strings.TrimSpace(operationIP),
		CreatedBy:    &createdBy,
	}
	if err := s.repo.Create(ctx, entry); err != nil {
		return nil, err
	}
	// Drop any cached decision for this user/ID card so the new rule takes effect.
	s.queryWhitelistSvc.InvalidateCache(entry.UserID, idCardHash)

	resp := dto.NewQueryWhitelistEntryResponse(entry)
	return &resp, nil
}
// constantTimeEqual reports whether a and b are equal without leaking timing
// information about their contents or their lengths.
//
// subtle.ConstantTimeCompare is constant-time only for equal-length inputs; it
// returns immediately when the lengths differ, which would reveal the secret's
// length. Hashing both sides to fixed-size SHA-256 digests first removes that
// signal while preserving equality semantics.
func constantTimeEqual(a, b string) bool {
	digestA := sha256.Sum256([]byte(a))
	digestB := sha256.Sum256([]byte(b))
	return subtle.ConstantTimeCompare(digestA[:], digestB[:]) == 1
}
// validatePublicAPICodesField checks the raw "api_codes" value of a public-API
// payload. The value must be present, be a non-empty JSON array, and contain
// only strings that are non-blank after trimming. The wildcard "*" is not
// allowed on the public API. The trimmed codes are then passed on to
// validateAPICodes.
func validatePublicAPICodesField(raw json.RawMessage) error {
	if len(raw) == 0 {
		return fmt.Errorf("api_codes 必填且必须为 string 数组")
	}

	var decoded interface{}
	if json.Unmarshal(raw, &decoded) != nil {
		return fmt.Errorf("api_codes 必须为 string 数组")
	}

	items, isArray := decoded.([]interface{})
	switch {
	case !isArray:
		return fmt.Errorf("api_codes 必须为 string 数组,不可传字符串或其他类型")
	case len(items) == 0:
		return fmt.Errorf("api_codes 不能为空")
	}

	trimmed := make([]string, len(items))
	for idx, item := range items {
		text, isString := item.(string)
		if !isString {
			return fmt.Errorf("api_codes 数组元素必须为 string")
		}
		switch code := strings.TrimSpace(text); code {
		case "":
			return fmt.Errorf("api_codes 不能包含空值")
		case "*":
			return fmt.Errorf("公开接口不允许 api_codes 使用通配符 *")
		default:
			trimmed[idx] = code
		}
	}
	return validateAPICodes(trimmed)
}
// ValidatePublicAPIManagementKey verifies the platform management key supplied
// to the public whitelist API. The checks run in this order:
//   - the public API is disabled: ErrPublicAPIDisabled
//   - key is blank: ErrMissingMgmtKey
//   - key does not match the configured key: ErrInvalidMgmtKey
//
// It returns nil when every check passes.
func ValidatePublicAPIManagementKey(cfg *config.Config, key string) error {
	publicAPI := cfg.QueryWhitelist.PublicAPI
	switch {
	case !publicAPI.Enabled:
		return ErrPublicAPIDisabled
	case strings.TrimSpace(key) == "":
		return ErrMissingMgmtKey
	case !constantTimeEqual(key, publicAPI.ManagementKey):
		return ErrInvalidMgmtKey
	default:
		return nil
	}
}
// MapWhitelistAppError translates validation and duplicate-rule errors into
// public-API sentinel errors, based on substrings of the error message:
//   - parameter-validation messages map to ErrRequestParam
//   - messages containing "已存在" map to ErrWhitelistExists
//
// nil maps to nil, and any other error is returned unchanged.
func MapWhitelistAppError(err error) error {
	if err == nil {
		return nil
	}
	text := err.Error()

	// Parameter markers are checked before the "already exists" marker, matching
	// the original precedence.
	paramMarkers := []string{
		"身份证号格式不正确",
		"api_codes",
		"name 不能为空",
		"user_id 不能为空",
	}
	for _, marker := range paramMarkers {
		if strings.Contains(text, marker) {
			return ErrRequestParam
		}
	}
	if strings.Contains(text, "已存在") {
		return ErrWhitelistExists
	}
	return err
}
// PublicAPIErrorMessage returns the user-facing message for a public-API error.
// It returns "" for nil and uses dedicated text for ErrWhitelistExists and
// ErrRequestParam (compared by identity, not errors.Is). Every other error is
// delegated to GetErrorMessage.
func PublicAPIErrorMessage(err error) string {
	if err == nil {
		return ""
	}
	if err == ErrWhitelistExists {
		return "规则已存在"
	}
	if err == ErrRequestParam {
		return "请求参数结构不正确"
	}
	return GetErrorMessage(err)
}