This commit is contained in:
2026-06-19 10:49:13 +08:00
parent 82f759586a
commit d71a23fa57
13 changed files with 696 additions and 39 deletions

View File

@@ -34,6 +34,7 @@ type QueryWhitelistEntryResponse struct {
APICodes []string `json:"api_codes"`
Status string `json:"status"`
Remark string `json:"remark"`
OperationIP string `json:"operation_ip,omitempty"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
@@ -51,6 +52,19 @@ type QueryWhitelistImportLegacyResponse struct {
Total int `json:"total"`
}
// QueryWhitelistPublicEncryptedRequest 公开接口外层请求data 为 AES 密文)
type QueryWhitelistPublicEncryptedRequest struct {
Data string `json:"data" binding:"required"`
}
// QueryWhitelistPublicPayload 公开接口解密后的业务参数(不含 key/access_id身份由请求头 + 解密成功证明)
type QueryWhitelistPublicPayload struct {
Name string `json:"name"`
IDCard string `json:"id_card"`
APICodes []string `json:"api_codes"`
Remark string `json:"remark"`
}
func NewQueryWhitelistEntryResponse(entry *entities.QueryWhitelistEntry) QueryWhitelistEntryResponse {
apiCodes := []string(entry.APICodes)
if apiCodes == nil {
@@ -65,6 +79,7 @@ func NewQueryWhitelistEntryResponse(entry *entities.QueryWhitelistEntry) QueryWh
APICodes: apiCodes,
Status: entry.Status,
Remark: entry.Remark,
OperationIP: entry.OperationIP,
CreatedAt: entry.CreatedAt,
UpdatedAt: entry.UpdatedAt,
}

View File

@@ -25,6 +25,10 @@ var (
ErrBusiness = errors.New("业务失败")
ErrSubordinateLinkNotFound = errors.New("非子账号无法使用master_accessid")
ErrSubordinateParentMismatch = errors.New("master_accessid与主账号不匹配")
ErrMissingMgmtKey = errors.New("缺少管理密钥")
ErrInvalidMgmtKey = errors.New("管理密钥无效")
ErrPublicAPIDisabled = errors.New("公开接口未启用")
ErrWhitelistExists = errors.New("规则已存在")
)
// 错误码映射 - 严格按照用户要求
@@ -50,6 +54,10 @@ var ErrorCodeMap = map[error]int{
ErrBusiness: 2001,
ErrSubordinateLinkNotFound: 1301,
ErrSubordinateParentMismatch: 1302,
ErrMissingMgmtKey: 1010,
ErrInvalidMgmtKey: 1011,
ErrPublicAPIDisabled: 1012,
ErrWhitelistExists: 1013,
}
// GetErrorCode 获取错误对应的错误码

View File

@@ -7,6 +7,7 @@ import (
"strings"
"tyapi-server/internal/application/api/dto"
"tyapi-server/internal/config"
"tyapi-server/internal/domains/api/entities"
"tyapi-server/internal/domains/api/repositories"
api_services "tyapi-server/internal/domains/api/services"
@@ -18,6 +19,7 @@ import (
type QueryWhitelistApplicationService interface {
CreateEntry(ctx context.Context, adminUserID string, req *dto.QueryWhitelistEntryRequest) (*dto.QueryWhitelistEntryResponse, error)
CreateEntryPublic(ctx context.Context, headerAccessID, managementKey, clientIP, encryptedData string) (*dto.QueryWhitelistEntryResponse, string, error)
UpdateEntry(ctx context.Context, adminUserID, id string, req *dto.QueryWhitelistEntryUpdateRequest) (*dto.QueryWhitelistEntryResponse, error)
UpdateEntryStatus(ctx context.Context, adminUserID, id, status string) (*dto.QueryWhitelistEntryResponse, error)
DeleteEntry(ctx context.Context, id string) error
@@ -29,17 +31,23 @@ type QueryWhitelistApplicationService interface {
type QueryWhitelistApplicationServiceImpl struct {
repo repositories.QueryWhitelistRepository
queryWhitelistSvc api_services.QueryWhitelistService
apiUserService api_services.ApiUserAggregateService
config *config.Config
logger *zap.Logger
}
func NewQueryWhitelistApplicationService(
repo repositories.QueryWhitelistRepository,
queryWhitelistSvc api_services.QueryWhitelistService,
apiUserService api_services.ApiUserAggregateService,
config *config.Config,
logger *zap.Logger,
) QueryWhitelistApplicationService {
return &QueryWhitelistApplicationServiceImpl{
repo: repo,
queryWhitelistSvc: queryWhitelistSvc,
apiUserService: apiUserService,
config: config,
logger: logger,
}
}
@@ -49,35 +57,7 @@ func (s *QueryWhitelistApplicationServiceImpl) CreateEntry(
adminUserID string,
req *dto.QueryWhitelistEntryRequest,
) (*dto.QueryWhitelistEntryResponse, error) {
if err := validateQueryWhitelistRequest(req.UserID, req.Name, req.IDCard, req.APICodes); err != nil {
return nil, err
}
idCardHash := api_services.HashIDCard(req.IDCard)
exists, err := s.repo.ExistsByUserIDCardHashAndName(ctx, req.UserID, idCardHash, strings.TrimSpace(req.Name), "")
if err != nil {
return nil, err
}
if exists {
return nil, fmt.Errorf("该用户下已存在相同的身份证与姓名规则")
}
entry := &entities.QueryWhitelistEntry{
UserID: strings.TrimSpace(req.UserID),
Name: normalizeWhitelistName(req.Name),
IDCardHash: idCardHash,
IDCardMasked: api_services.MaskIDCard(req.IDCard),
APICodes: entities.APICodeList(req.APICodes),
Status: entities.QueryWhitelistStatusEnabled,
Remark: strings.TrimSpace(req.Remark),
CreatedBy: &adminUserID,
}
if err := s.repo.Create(ctx, entry); err != nil {
return nil, err
}
s.queryWhitelistSvc.InvalidateCache(entry.UserID, idCardHash)
resp := dto.NewQueryWhitelistEntryResponse(entry)
return &resp, nil
return s.createEntry(ctx, adminUserID, req, "")
}
func (s *QueryWhitelistApplicationServiceImpl) UpdateEntry(

View File

@@ -0,0 +1,213 @@
package api
import (
"context"
"crypto/subtle"
"encoding/json"
"strings"
"tyapi-server/internal/application/api/dto"
"tyapi-server/internal/config"
"tyapi-server/internal/domains/api/entities"
api_services "tyapi-server/internal/domains/api/services"
"tyapi-server/internal/shared/crypto"
"fmt"
"go.uber.org/zap"
)
const queryWhitelistMgmtKeyHeader = "Whitelist-Mgmt-Key"
// QueryWhitelistMgmtKeyHeader 公开接口管理密钥请求头名
func QueryWhitelistMgmtKeyHeader() string {
return queryWhitelistMgmtKeyHeader
}
// CreateEntryPublic 公开接口:解密业务参数后为当前 access_id 对应用户添加规则(仅对该用户生效)
func (s *QueryWhitelistApplicationServiceImpl) CreateEntryPublic(
ctx context.Context,
headerAccessID, managementKey, clientIP, encryptedData string,
) (*dto.QueryWhitelistEntryResponse, string, error) {
if !s.config.QueryWhitelist.PublicAPI.Enabled {
return nil, "", ErrPublicAPIDisabled
}
if strings.TrimSpace(managementKey) == "" {
return nil, "", ErrMissingMgmtKey
}
if !constantTimeEqual(managementKey, s.config.QueryWhitelist.PublicAPI.ManagementKey) {
return nil, "", ErrInvalidMgmtKey
}
headerAccessID = strings.TrimSpace(headerAccessID)
if headerAccessID == "" {
return nil, "", ErrMissingAccessId
}
apiUser, err := s.apiUserService.LoadApiUserByAccessId(ctx, headerAccessID)
if err != nil {
s.logger.Warn("公开白名单接口 AccessId 无效", zap.String("access_id", headerAccessID), zap.Error(err))
return nil, "", ErrInvalidAccessId
}
if apiUser.IsFrozen() {
return nil, "", ErrFrozenAccount
}
decrypted, err := crypto.AesDecrypt(encryptedData, apiUser.SecretKey)
if err != nil {
s.logger.Warn("公开白名单接口解密失败", zap.String("access_id", headerAccessID), zap.Error(err))
return nil, "", ErrDecryptFail
}
var raw map[string]json.RawMessage
if err := json.Unmarshal(decrypted, &raw); err != nil {
return nil, "", ErrRequestParam
}
if err := validatePublicAPICodesField(raw["api_codes"]); err != nil {
return nil, "", MapWhitelistAppError(err)
}
var payload dto.QueryWhitelistPublicPayload
if err := json.Unmarshal(decrypted, &payload); err != nil {
return nil, "", ErrRequestParam
}
if !s.config.App.IsDevelopment() && !apiUser.IsWhiteListed(clientIP) {
s.logger.Warn("公开白名单接口 IP 未授权",
zap.String("access_id", headerAccessID),
zap.String("client_ip", clientIP))
return nil, "", ErrInvalidIP
}
createdBy := "public_api:" + headerAccessID
resp, err := s.createEntry(ctx, createdBy, &dto.QueryWhitelistEntryRequest{
UserID: apiUser.UserId,
Name: payload.Name,
IDCard: payload.IDCard,
APICodes: payload.APICodes,
Remark: payload.Remark,
}, clientIP)
return resp, apiUser.SecretKey, MapWhitelistAppError(err)
}
func (s *QueryWhitelistApplicationServiceImpl) createEntry(
ctx context.Context,
createdBy string,
req *dto.QueryWhitelistEntryRequest,
operationIP string,
) (*dto.QueryWhitelistEntryResponse, error) {
if err := validateQueryWhitelistRequest(req.UserID, req.Name, req.IDCard, req.APICodes); err != nil {
return nil, err
}
idCardHash := api_services.HashIDCard(req.IDCard)
exists, err := s.repo.ExistsByUserIDCardHashAndName(ctx, req.UserID, idCardHash, strings.TrimSpace(req.Name), "")
if err != nil {
return nil, err
}
if exists {
return nil, fmt.Errorf("该用户下已存在相同的身份证与姓名规则")
}
entry := &entities.QueryWhitelistEntry{
UserID: strings.TrimSpace(req.UserID),
Name: normalizeWhitelistName(req.Name),
IDCardHash: idCardHash,
IDCardMasked: api_services.MaskIDCard(req.IDCard),
APICodes: entities.APICodeList(req.APICodes),
Status: entities.QueryWhitelistStatusEnabled,
Remark: strings.TrimSpace(req.Remark),
OperationIP: strings.TrimSpace(operationIP),
CreatedBy: &createdBy,
}
if err := s.repo.Create(ctx, entry); err != nil {
return nil, err
}
s.queryWhitelistSvc.InvalidateCache(entry.UserID, idCardHash)
resp := dto.NewQueryWhitelistEntryResponse(entry)
return &resp, nil
}
func constantTimeEqual(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
// validatePublicAPICodesField 公开接口api_codes 必须为 string 数组,且禁止通配符 *
func validatePublicAPICodesField(raw json.RawMessage) error {
if len(raw) == 0 {
return fmt.Errorf("api_codes 必填且必须为 string 数组")
}
var probe interface{}
if err := json.Unmarshal(raw, &probe); err != nil {
return fmt.Errorf("api_codes 必须为 string 数组")
}
arr, ok := probe.([]interface{})
if !ok {
return fmt.Errorf("api_codes 必须为 string 数组,不可传字符串或其他类型")
}
if len(arr) == 0 {
return fmt.Errorf("api_codes 不能为空")
}
codes := make([]string, 0, len(arr))
for _, item := range arr {
s, ok := item.(string)
if !ok {
return fmt.Errorf("api_codes 数组元素必须为 string")
}
code := strings.TrimSpace(s)
if code == "" {
return fmt.Errorf("api_codes 不能包含空值")
}
if code == "*" {
return fmt.Errorf("公开接口不允许 api_codes 使用通配符 *")
}
codes = append(codes, code)
}
return validateAPICodes(codes)
}
// ValidatePublicAPIManagementKey 校验平台管理密钥
func ValidatePublicAPIManagementKey(cfg *config.Config, key string) error {
if !cfg.QueryWhitelist.PublicAPI.Enabled {
return ErrPublicAPIDisabled
}
if strings.TrimSpace(key) == "" {
return ErrMissingMgmtKey
}
if !constantTimeEqual(key, cfg.QueryWhitelist.PublicAPI.ManagementKey) {
return ErrInvalidMgmtKey
}
return nil
}
// MapWhitelistAppError 将校验错误映射为公开 API 错误码
func MapWhitelistAppError(err error) error {
if err == nil {
return nil
}
msg := err.Error()
switch {
case strings.Contains(msg, "身份证号格式不正确"),
strings.Contains(msg, "api_codes"),
strings.Contains(msg, "name 不能为空"),
strings.Contains(msg, "user_id 不能为空"):
return ErrRequestParam
case strings.Contains(msg, "已存在"):
return ErrWhitelistExists
default:
return err
}
}
// PublicAPIErrorMessage 公开接口错误文案
func PublicAPIErrorMessage(err error) string {
if err == nil {
return ""
}
switch err {
case ErrWhitelistExists:
return "规则已存在"
case ErrRequestParam:
return "请求参数结构不正确"
default:
return GetErrorMessage(err)
}
}

View File

@@ -0,0 +1,34 @@
package api
import (
"encoding/json"
"testing"
)
func TestValidatePublicAPICodesField(t *testing.T) {
tests := []struct {
name string
raw string
wantErr bool
}{
{name: "valid array", raw: `["FLXG0V4B","JRZQ8A2D"]`},
{name: "string not array", raw: `"FLXG0V4B"`, wantErr: true},
{name: "wildcard star", raw: `["*"]`, wantErr: true},
{name: "star mixed", raw: `["FLXG0V4B","*"]`, wantErr: true},
{name: "empty array", raw: `[]`, wantErr: true},
{name: "missing field", raw: "", wantErr: true},
{name: "number element", raw: `[123]`, wantErr: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var raw json.RawMessage
if tt.raw != "" {
raw = json.RawMessage(tt.raw)
}
err := validatePublicAPICodesField(raw)
if (err != nil) != tt.wantErr {
t.Fatalf("validatePublicAPICodesField() err=%v wantErr=%v", err, tt.wantErr)
}
})
}
}