312 lines
9.5 KiB
Go
312 lines
9.5 KiB
Go
package api
|
||
|
||
import (
|
||
"context"
|
||
"crypto/subtle"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
|
||
"tyapi-server/internal/application/api/dto"
|
||
"tyapi-server/internal/config"
|
||
"tyapi-server/internal/domains/api/entities"
|
||
api_services "tyapi-server/internal/domains/api/services"
|
||
"tyapi-server/internal/shared/crypto"
|
||
|
||
"go.uber.org/zap"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
const (
	// queryWhitelistMgmtKeyHeader is the name of the request header that
	// carries the platform management key for the public whitelist endpoints.
	queryWhitelistMgmtKeyHeader = "Whitelist-Mgmt-Key"
)

// QueryWhitelistMgmtKeyHeader returns the name of the request header that
// carries the management key for the public query-whitelist API.
func QueryWhitelistMgmtKeyHeader() string { return queryWhitelistMgmtKeyHeader }
|
||
|
||
// CreateEntryPublic 公开接口:新建规则(同用户+身份证+姓名已存在则拒绝)
|
||
func (s *QueryWhitelistApplicationServiceImpl) CreateEntryPublic(
|
||
ctx context.Context,
|
||
headerAccessID, managementKey, clientIP, encryptedData string,
|
||
) (*dto.QueryWhitelistEntryResponse, string, error) {
|
||
apiUser, payload, err := s.preparePublicRequest(ctx, headerAccessID, managementKey, clientIP, encryptedData)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
createdBy := "public_api:" + strings.TrimSpace(headerAccessID)
|
||
resp, err := s.createEntry(ctx, createdBy, &dto.QueryWhitelistEntryRequest{
|
||
UserID: apiUser.UserId,
|
||
Name: payload.Name,
|
||
IDCard: payload.IDCard,
|
||
APICodes: payload.APICodes,
|
||
Remark: payload.Remark,
|
||
}, clientIP)
|
||
return resp, apiUser.SecretKey, MapWhitelistAppError(err)
|
||
}
|
||
|
||
// AppendEntryPublic 公开接口:向已有规则追加 api_codes(去重合并,不新建记录)
|
||
func (s *QueryWhitelistApplicationServiceImpl) AppendEntryPublic(
|
||
ctx context.Context,
|
||
headerAccessID, managementKey, clientIP, encryptedData string,
|
||
) (*dto.QueryWhitelistEntryResponse, string, error) {
|
||
apiUser, payload, err := s.preparePublicRequest(ctx, headerAccessID, managementKey, clientIP, encryptedData)
|
||
if err != nil {
|
||
return nil, "", err
|
||
}
|
||
|
||
if err := validateIDCard(payload.IDCard); err != nil {
|
||
return nil, "", MapWhitelistAppError(err)
|
||
}
|
||
if strings.TrimSpace(payload.Name) == "" {
|
||
return nil, "", ErrRequestParam
|
||
}
|
||
|
||
userID := apiUser.UserId
|
||
idCardHash := api_services.HashIDCard(payload.IDCard)
|
||
name := normalizeWhitelistName(payload.Name)
|
||
|
||
entry, err := s.repo.FindByUserIDCardHashAndName(ctx, userID, idCardHash, name)
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, "", ErrWhitelistNotFound
|
||
}
|
||
return nil, "", err
|
||
}
|
||
|
||
for _, code := range entry.APICodes {
|
||
if code == "*" {
|
||
return nil, "", ErrRequestParam
|
||
}
|
||
}
|
||
|
||
merged := mergeAPICodes([]string(entry.APICodes), payload.APICodes)
|
||
entry.APICodes = entities.APICodeList(merged)
|
||
if remark := strings.TrimSpace(payload.Remark); remark != "" {
|
||
entry.Remark = remark
|
||
}
|
||
updatedBy := "public_api_append:" + strings.TrimSpace(headerAccessID)
|
||
entry.UpdatedBy = &updatedBy
|
||
entry.OperationIP = strings.TrimSpace(clientIP)
|
||
|
||
if err := s.repo.Update(ctx, entry); err != nil {
|
||
return nil, "", err
|
||
}
|
||
s.queryWhitelistSvc.InvalidateCache(entry.UserID, idCardHash)
|
||
|
||
resp := dto.NewQueryWhitelistEntryResponse(entry)
|
||
return &resp, apiUser.SecretKey, nil
|
||
}
|
||
|
||
func (s *QueryWhitelistApplicationServiceImpl) preparePublicRequest(
|
||
ctx context.Context,
|
||
headerAccessID, managementKey, clientIP, encryptedData string,
|
||
) (*entities.ApiUser, dto.QueryWhitelistPublicPayload, error) {
|
||
if !s.config.QueryWhitelist.PublicAPI.Enabled {
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrPublicAPIDisabled
|
||
}
|
||
if strings.TrimSpace(managementKey) == "" {
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrMissingMgmtKey
|
||
}
|
||
if !constantTimeEqual(managementKey, s.config.QueryWhitelist.PublicAPI.ManagementKey) {
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrInvalidMgmtKey
|
||
}
|
||
headerAccessID = strings.TrimSpace(headerAccessID)
|
||
if headerAccessID == "" {
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrMissingAccessId
|
||
}
|
||
|
||
apiUser, err := s.apiUserService.LoadApiUserByAccessId(ctx, headerAccessID)
|
||
if err != nil {
|
||
s.logger.Warn("公开白名单接口 AccessId 无效", zap.String("access_id", headerAccessID), zap.Error(err))
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrInvalidAccessId
|
||
}
|
||
if apiUser.IsFrozen() {
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrFrozenAccount
|
||
}
|
||
|
||
decrypted, err := crypto.AesDecrypt(encryptedData, apiUser.SecretKey)
|
||
if err != nil {
|
||
s.logger.Warn("公开白名单接口解密失败", zap.String("access_id", headerAccessID), zap.Error(err))
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrDecryptFail
|
||
}
|
||
|
||
var raw map[string]json.RawMessage
|
||
if err := json.Unmarshal(decrypted, &raw); err != nil {
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrRequestParam
|
||
}
|
||
if err := validatePublicAPICodesField(raw["api_codes"]); err != nil {
|
||
return nil, dto.QueryWhitelistPublicPayload{}, MapWhitelistAppError(err)
|
||
}
|
||
|
||
var payload dto.QueryWhitelistPublicPayload
|
||
if err := json.Unmarshal(decrypted, &payload); err != nil {
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrRequestParam
|
||
}
|
||
|
||
if !s.config.App.IsDevelopment() && !apiUser.IsWhiteListed(clientIP) {
|
||
s.logger.Warn("公开白名单接口 IP 未授权",
|
||
zap.String("access_id", headerAccessID),
|
||
zap.String("client_ip", clientIP))
|
||
return nil, dto.QueryWhitelistPublicPayload{}, ErrInvalidIP
|
||
}
|
||
|
||
return apiUser, payload, nil
|
||
}
|
||
|
||
func (s *QueryWhitelistApplicationServiceImpl) createEntry(
|
||
ctx context.Context,
|
||
createdBy string,
|
||
req *dto.QueryWhitelistEntryRequest,
|
||
operationIP string,
|
||
) (*dto.QueryWhitelistEntryResponse, error) {
|
||
if err := validateQueryWhitelistRequest(req.UserID, req.Name, req.IDCard, req.APICodes); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
idCardHash := api_services.HashIDCard(req.IDCard)
|
||
exists, err := s.repo.ExistsByUserIDCardHashAndName(ctx, req.UserID, idCardHash, strings.TrimSpace(req.Name), "")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if exists {
|
||
return nil, fmt.Errorf("该用户下已存在相同的身份证与姓名规则")
|
||
}
|
||
|
||
entry := &entities.QueryWhitelistEntry{
|
||
UserID: strings.TrimSpace(req.UserID),
|
||
Name: normalizeWhitelistName(req.Name),
|
||
IDCardHash: idCardHash,
|
||
IDCardMasked: api_services.MaskIDCard(req.IDCard),
|
||
APICodes: entities.APICodeList(req.APICodes),
|
||
Status: entities.QueryWhitelistStatusEnabled,
|
||
Remark: strings.TrimSpace(req.Remark),
|
||
OperationIP: strings.TrimSpace(operationIP),
|
||
CreatedBy: &createdBy,
|
||
}
|
||
if err := s.repo.Create(ctx, entry); err != nil {
|
||
return nil, err
|
||
}
|
||
s.queryWhitelistSvc.InvalidateCache(entry.UserID, idCardHash)
|
||
resp := dto.NewQueryWhitelistEntryResponse(entry)
|
||
return &resp, nil
|
||
}
|
||
|
||
// mergeAPICodes merges two API-code lists into a new slice.
//
// Every code is whitespace-trimmed. Blank codes and repeats (compared after
// trimming) are dropped. The first occurrence wins, so codes from existing
// keep their order and new codes from incoming are appended after them.
// The result is never nil, even when both inputs are empty.
func mergeAPICodes(existing, incoming []string) []string {
	total := len(existing) + len(incoming)
	merged := make([]string, 0, total)
	present := make(map[string]struct{}, total)

	for _, group := range [][]string{existing, incoming} {
		for _, raw := range group {
			c := strings.TrimSpace(raw)
			if _, dup := present[c]; dup || c == "" {
				continue
			}
			present[c] = struct{}{}
			merged = append(merged, c)
		}
	}
	return merged
}
|
||
|
||
// constantTimeEqual reports whether a and b are equal. It uses
// subtle.ConstantTimeCompare, so the time taken does not depend on the
// contents. A length mismatch is detected immediately, so lengths are not
// hidden.
func constantTimeEqual(a, b string) bool {
	left, right := []byte(a), []byte(b)
	return subtle.ConstantTimeCompare(left, right) != 0
}
|
||
|
||
// validatePublicAPICodesField 公开接口:api_codes 必须为 string 数组,且禁止通配符 *
|
||
func validatePublicAPICodesField(raw json.RawMessage) error {
|
||
if len(raw) == 0 {
|
||
return fmt.Errorf("api_codes 必填且必须为 string 数组")
|
||
}
|
||
var probe interface{}
|
||
if err := json.Unmarshal(raw, &probe); err != nil {
|
||
return fmt.Errorf("api_codes 必须为 string 数组")
|
||
}
|
||
arr, ok := probe.([]interface{})
|
||
if !ok {
|
||
return fmt.Errorf("api_codes 必须为 string 数组,不可传字符串或其他类型")
|
||
}
|
||
if len(arr) == 0 {
|
||
return fmt.Errorf("api_codes 不能为空")
|
||
}
|
||
codes := make([]string, 0, len(arr))
|
||
for _, item := range arr {
|
||
s, ok := item.(string)
|
||
if !ok {
|
||
return fmt.Errorf("api_codes 数组元素必须为 string")
|
||
}
|
||
code := strings.TrimSpace(s)
|
||
if code == "" {
|
||
return fmt.Errorf("api_codes 不能包含空值")
|
||
}
|
||
if code == "*" {
|
||
return fmt.Errorf("公开接口不允许 api_codes 使用通配符 *")
|
||
}
|
||
codes = append(codes, code)
|
||
}
|
||
return validateAPICodes(codes)
|
||
}
|
||
|
||
// ValidatePublicAPIManagementKey 校验平台管理密钥
|
||
func ValidatePublicAPIManagementKey(cfg *config.Config, key string) error {
|
||
if !cfg.QueryWhitelist.PublicAPI.Enabled {
|
||
return ErrPublicAPIDisabled
|
||
}
|
||
if strings.TrimSpace(key) == "" {
|
||
return ErrMissingMgmtKey
|
||
}
|
||
if !constantTimeEqual(key, cfg.QueryWhitelist.PublicAPI.ManagementKey) {
|
||
return ErrInvalidMgmtKey
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// MapWhitelistAppError 将校验错误映射为公开 API 错误码
|
||
func MapWhitelistAppError(err error) error {
|
||
if err == nil {
|
||
return nil
|
||
}
|
||
msg := err.Error()
|
||
switch {
|
||
case strings.Contains(msg, "身份证号格式不正确"),
|
||
strings.Contains(msg, "api_codes"),
|
||
strings.Contains(msg, "name 不能为空"),
|
||
strings.Contains(msg, "user_id 不能为空"):
|
||
return ErrRequestParam
|
||
case strings.Contains(msg, "已存在"):
|
||
return ErrWhitelistExists
|
||
default:
|
||
return err
|
||
}
|
||
}
|
||
|
||
// PublicAPIErrorMessage 公开接口错误文案
|
||
func PublicAPIErrorMessage(err error) string {
|
||
if err == nil {
|
||
return ""
|
||
}
|
||
switch err {
|
||
case ErrWhitelistExists:
|
||
return "规则已存在"
|
||
case ErrWhitelistNotFound:
|
||
return "规则不存在,请先调用创建接口"
|
||
case ErrRequestParam:
|
||
return "请求参数结构不正确"
|
||
default:
|
||
return GetErrorMessage(err)
|
||
}
|
||
}
|