This commit is contained in:
2026-06-19 10:56:52 +08:00
parent d71a23fa57
commit 2dcba2e5d9
10 changed files with 310 additions and 80 deletions

View File

@@ -4,6 +4,8 @@ import (
"context"
"crypto/subtle"
"encoding/json"
"errors"
"fmt"
"strings"
"tyapi-server/internal/application/api/dto"
@@ -11,9 +13,9 @@ import (
"tyapi-server/internal/domains/api/entities"
api_services "tyapi-server/internal/domains/api/services"
"tyapi-server/internal/shared/crypto"
"fmt"
"go.uber.org/zap"
"gorm.io/gorm"
)
const queryWhitelistMgmtKeyHeader = "Whitelist-Mgmt-Key"
@@ -23,61 +25,17 @@ func QueryWhitelistMgmtKeyHeader() string {
return queryWhitelistMgmtKeyHeader
}
// CreateEntryPublic 公开接口:解密业务参数后为当前 access_id 对应用户添加规则(仅对该用户生效
// CreateEntryPublic 公开接口:新建规则(同用户+身份证+姓名已存在则拒绝
func (s *QueryWhitelistApplicationServiceImpl) CreateEntryPublic(
ctx context.Context,
headerAccessID, managementKey, clientIP, encryptedData string,
) (*dto.QueryWhitelistEntryResponse, string, error) {
if !s.config.QueryWhitelist.PublicAPI.Enabled {
return nil, "", ErrPublicAPIDisabled
}
if strings.TrimSpace(managementKey) == "" {
return nil, "", ErrMissingMgmtKey
}
if !constantTimeEqual(managementKey, s.config.QueryWhitelist.PublicAPI.ManagementKey) {
return nil, "", ErrInvalidMgmtKey
}
headerAccessID = strings.TrimSpace(headerAccessID)
if headerAccessID == "" {
return nil, "", ErrMissingAccessId
}
apiUser, err := s.apiUserService.LoadApiUserByAccessId(ctx, headerAccessID)
apiUser, payload, err := s.preparePublicRequest(ctx, headerAccessID, managementKey, clientIP, encryptedData)
if err != nil {
s.logger.Warn("公开白名单接口 AccessId 无效", zap.String("access_id", headerAccessID), zap.Error(err))
return nil, "", ErrInvalidAccessId
}
if apiUser.IsFrozen() {
return nil, "", ErrFrozenAccount
return nil, "", err
}
decrypted, err := crypto.AesDecrypt(encryptedData, apiUser.SecretKey)
if err != nil {
s.logger.Warn("公开白名单接口解密失败", zap.String("access_id", headerAccessID), zap.Error(err))
return nil, "", ErrDecryptFail
}
var raw map[string]json.RawMessage
if err := json.Unmarshal(decrypted, &raw); err != nil {
return nil, "", ErrRequestParam
}
if err := validatePublicAPICodesField(raw["api_codes"]); err != nil {
return nil, "", MapWhitelistAppError(err)
}
var payload dto.QueryWhitelistPublicPayload
if err := json.Unmarshal(decrypted, &payload); err != nil {
return nil, "", ErrRequestParam
}
if !s.config.App.IsDevelopment() && !apiUser.IsWhiteListed(clientIP) {
s.logger.Warn("公开白名单接口 IP 未授权",
zap.String("access_id", headerAccessID),
zap.String("client_ip", clientIP))
return nil, "", ErrInvalidIP
}
createdBy := "public_api:" + headerAccessID
createdBy := "public_api:" + strings.TrimSpace(headerAccessID)
resp, err := s.createEntry(ctx, createdBy, &dto.QueryWhitelistEntryRequest{
UserID: apiUser.UserId,
Name: payload.Name,
@@ -88,6 +46,115 @@ func (s *QueryWhitelistApplicationServiceImpl) CreateEntryPublic(
return resp, apiUser.SecretKey, MapWhitelistAppError(err)
}
// AppendEntryPublic 公开接口:向已有规则追加 api_codes去重合并不新建记录
func (s *QueryWhitelistApplicationServiceImpl) AppendEntryPublic(
ctx context.Context,
headerAccessID, managementKey, clientIP, encryptedData string,
) (*dto.QueryWhitelistEntryResponse, string, error) {
apiUser, payload, err := s.preparePublicRequest(ctx, headerAccessID, managementKey, clientIP, encryptedData)
if err != nil {
return nil, "", err
}
if err := validateIDCard(payload.IDCard); err != nil {
return nil, "", MapWhitelistAppError(err)
}
if strings.TrimSpace(payload.Name) == "" {
return nil, "", ErrRequestParam
}
userID := apiUser.UserId
idCardHash := api_services.HashIDCard(payload.IDCard)
name := normalizeWhitelistName(payload.Name)
entry, err := s.repo.FindByUserIDCardHashAndName(ctx, userID, idCardHash, name)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, "", ErrWhitelistNotFound
}
return nil, "", err
}
for _, code := range entry.APICodes {
if code == "*" {
return nil, "", ErrRequestParam
}
}
merged := mergeAPICodes([]string(entry.APICodes), payload.APICodes)
entry.APICodes = entities.APICodeList(merged)
if remark := strings.TrimSpace(payload.Remark); remark != "" {
entry.Remark = remark
}
updatedBy := "public_api_append:" + strings.TrimSpace(headerAccessID)
entry.UpdatedBy = &updatedBy
entry.OperationIP = strings.TrimSpace(clientIP)
if err := s.repo.Update(ctx, entry); err != nil {
return nil, "", err
}
s.queryWhitelistSvc.InvalidateCache(entry.UserID, idCardHash)
resp := dto.NewQueryWhitelistEntryResponse(entry)
return &resp, apiUser.SecretKey, nil
}
func (s *QueryWhitelistApplicationServiceImpl) preparePublicRequest(
ctx context.Context,
headerAccessID, managementKey, clientIP, encryptedData string,
) (*entities.ApiUser, dto.QueryWhitelistPublicPayload, error) {
if !s.config.QueryWhitelist.PublicAPI.Enabled {
return nil, dto.QueryWhitelistPublicPayload{}, ErrPublicAPIDisabled
}
if strings.TrimSpace(managementKey) == "" {
return nil, dto.QueryWhitelistPublicPayload{}, ErrMissingMgmtKey
}
if !constantTimeEqual(managementKey, s.config.QueryWhitelist.PublicAPI.ManagementKey) {
return nil, dto.QueryWhitelistPublicPayload{}, ErrInvalidMgmtKey
}
headerAccessID = strings.TrimSpace(headerAccessID)
if headerAccessID == "" {
return nil, dto.QueryWhitelistPublicPayload{}, ErrMissingAccessId
}
apiUser, err := s.apiUserService.LoadApiUserByAccessId(ctx, headerAccessID)
if err != nil {
s.logger.Warn("公开白名单接口 AccessId 无效", zap.String("access_id", headerAccessID), zap.Error(err))
return nil, dto.QueryWhitelistPublicPayload{}, ErrInvalidAccessId
}
if apiUser.IsFrozen() {
return nil, dto.QueryWhitelistPublicPayload{}, ErrFrozenAccount
}
decrypted, err := crypto.AesDecrypt(encryptedData, apiUser.SecretKey)
if err != nil {
s.logger.Warn("公开白名单接口解密失败", zap.String("access_id", headerAccessID), zap.Error(err))
return nil, dto.QueryWhitelistPublicPayload{}, ErrDecryptFail
}
var raw map[string]json.RawMessage
if err := json.Unmarshal(decrypted, &raw); err != nil {
return nil, dto.QueryWhitelistPublicPayload{}, ErrRequestParam
}
if err := validatePublicAPICodesField(raw["api_codes"]); err != nil {
return nil, dto.QueryWhitelistPublicPayload{}, MapWhitelistAppError(err)
}
var payload dto.QueryWhitelistPublicPayload
if err := json.Unmarshal(decrypted, &payload); err != nil {
return nil, dto.QueryWhitelistPublicPayload{}, ErrRequestParam
}
if !s.config.App.IsDevelopment() && !apiUser.IsWhiteListed(clientIP) {
s.logger.Warn("公开白名单接口 IP 未授权",
zap.String("access_id", headerAccessID),
zap.String("client_ip", clientIP))
return nil, dto.QueryWhitelistPublicPayload{}, ErrInvalidIP
}
return apiUser, payload, nil
}
func (s *QueryWhitelistApplicationServiceImpl) createEntry(
ctx context.Context,
createdBy string,
@@ -126,6 +193,35 @@ func (s *QueryWhitelistApplicationServiceImpl) createEntry(
return &resp, nil
}
// mergeAPICodes 合并接口编码并去重,保持原有顺序,新增编码追加在后
func mergeAPICodes(existing, incoming []string) []string {
seen := make(map[string]struct{}, len(existing)+len(incoming))
result := make([]string, 0, len(existing)+len(incoming))
for _, code := range existing {
code = strings.TrimSpace(code)
if code == "" {
continue
}
if _, ok := seen[code]; ok {
continue
}
seen[code] = struct{}{}
result = append(result, code)
}
for _, code := range incoming {
code = strings.TrimSpace(code)
if code == "" {
continue
}
if _, ok := seen[code]; ok {
continue
}
seen[code] = struct{}{}
result = append(result, code)
}
return result
}
func constantTimeEqual(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
@@ -205,6 +301,8 @@ func PublicAPIErrorMessage(err error) string {
switch err {
case ErrWhitelistExists:
return "规则已存在"
case ErrWhitelistNotFound:
return "规则不存在,请先调用创建接口"
case ErrRequestParam:
return "请求参数结构不正确"
default: