@@ -4,6 +4,8 @@ import (
"context"
"crypto/subtle"
"encoding/json"
"errors"
"fmt"
"strings"
"tyapi-server/internal/application/api/dto"
@@ -11,9 +13,9 @@ import (
"tyapi-server/internal/domains/api/entities"
api_services "tyapi-server/internal/domains/api/services"
"tyapi-server/internal/shared/crypto"
"fmt"
"go.uber.org/zap"
"gorm.io/gorm"
)
const queryWhitelistMgmtKeyHeader = "Whitelist-Mgmt-Key"
@@ -23,61 +25,17 @@ func QueryWhitelistMgmtKeyHeader() string {
return queryWhitelistMgmtKeyHeader
}
// CreateEntryPublic 公开接口:解密业务参数后为当前 access_id 对应用户添加规则(仅对该用户生效 )
// CreateEntryPublic 公开接口:新建规则(同用户+身份证+姓名已存在则拒绝 )
func ( s * QueryWhitelistApplicationServiceImpl ) CreateEntryPublic (
ctx context . Context ,
headerAccessID , managementKey , clientIP , encryptedData string ,
) ( * dto . QueryWhitelistEntryResponse , string , error ) {
if ! s . config . QueryWhitelist . PublicAPI . Enabled {
return nil , "" , ErrPublicAPIDisabled
}
if strings . TrimSpace ( managementKey ) == "" {
return nil , "" , ErrMissingMgmtKey
}
if ! constantTimeEqual ( managementKey , s . config . QueryWhitelist . PublicAPI . ManagementKey ) {
return nil , "" , ErrInvalidMgmtKey
}
headerAccessID = strings . TrimSpace ( headerAccessID )
if headerAccessID == "" {
return nil , "" , ErrMissingAccessId
}
apiUser , err := s . apiUserService . LoadApiUserByAccessId ( ctx , headerAccessID )
apiUser , payload , err := s . preparePublicRequest ( ctx , headerAccessID , managementKey , clientIP , encryptedData )
if err != nil {
s . logger . Warn ( "公开白名单接口 AccessId 无效" , zap . String ( "access_id" , headerAccessID ) , zap . Error ( err ) )
return nil , "" , ErrInvalidAccessId
}
if apiUser . IsFrozen ( ) {
return nil , "" , ErrFrozenAccount
return nil , "" , err
}
decrypted , err := crypto . AesDecrypt ( encryptedData , apiUser . SecretKey )
if err != nil {
s . logger . Warn ( "公开白名单接口解密失败" , zap . String ( "access_id" , headerAccessID ) , zap . Error ( err ) )
return nil , "" , ErrDecryptFail
}
var raw map [ string ] json . RawMessage
if err := json . Unmarshal ( decrypted , & raw ) ; err != nil {
return nil , "" , ErrRequestParam
}
if err := validatePublicAPICodesField ( raw [ "api_codes" ] ) ; err != nil {
return nil , "" , MapWhitelistAppError ( err )
}
var payload dto . QueryWhitelistPublicPayload
if err := json . Unmarshal ( decrypted , & payload ) ; err != nil {
return nil , "" , ErrRequestParam
}
if ! s . config . App . IsDevelopment ( ) && ! apiUser . IsWhiteListed ( clientIP ) {
s . logger . Warn ( "公开白名单接口 IP 未授权" ,
zap . String ( "access_id" , headerAccessID ) ,
zap . String ( "client_ip" , clientIP ) )
return nil , "" , ErrInvalidIP
}
createdBy := "public_api:" + headerAccessID
createdBy := "public_api:" + strings . TrimSpace ( headerAccessID )
resp , err := s . createEntry ( ctx , createdBy , & dto . QueryWhitelistEntryRequest {
UserID : apiUser . UserId ,
Name : payload . Name ,
@@ -88,6 +46,115 @@ func (s *QueryWhitelistApplicationServiceImpl) CreateEntryPublic(
return resp , apiUser . SecretKey , MapWhitelistAppError ( err )
}
// AppendEntryPublic 公开接口:向已有规则追加 api_codes( 去重合并, 不新建记录)
func ( s * QueryWhitelistApplicationServiceImpl ) AppendEntryPublic (
ctx context . Context ,
headerAccessID , managementKey , clientIP , encryptedData string ,
) ( * dto . QueryWhitelistEntryResponse , string , error ) {
apiUser , payload , err := s . preparePublicRequest ( ctx , headerAccessID , managementKey , clientIP , encryptedData )
if err != nil {
return nil , "" , err
}
if err := validateIDCard ( payload . IDCard ) ; err != nil {
return nil , "" , MapWhitelistAppError ( err )
}
if strings . TrimSpace ( payload . Name ) == "" {
return nil , "" , ErrRequestParam
}
userID := apiUser . UserId
idCardHash := api_services . HashIDCard ( payload . IDCard )
name := normalizeWhitelistName ( payload . Name )
entry , err := s . repo . FindByUserIDCardHashAndName ( ctx , userID , idCardHash , name )
if err != nil {
if errors . Is ( err , gorm . ErrRecordNotFound ) {
return nil , "" , ErrWhitelistNotFound
}
return nil , "" , err
}
for _ , code := range entry . APICodes {
if code == "*" {
return nil , "" , ErrRequestParam
}
}
merged := mergeAPICodes ( [ ] string ( entry . APICodes ) , payload . APICodes )
entry . APICodes = entities . APICodeList ( merged )
if remark := strings . TrimSpace ( payload . Remark ) ; remark != "" {
entry . Remark = remark
}
updatedBy := "public_api_append:" + strings . TrimSpace ( headerAccessID )
entry . UpdatedBy = & updatedBy
entry . OperationIP = strings . TrimSpace ( clientIP )
if err := s . repo . Update ( ctx , entry ) ; err != nil {
return nil , "" , err
}
s . queryWhitelistSvc . InvalidateCache ( entry . UserID , idCardHash )
resp := dto . NewQueryWhitelistEntryResponse ( entry )
return & resp , apiUser . SecretKey , nil
}
func ( s * QueryWhitelistApplicationServiceImpl ) preparePublicRequest (
ctx context . Context ,
headerAccessID , managementKey , clientIP , encryptedData string ,
) ( * entities . ApiUser , dto . QueryWhitelistPublicPayload , error ) {
if ! s . config . QueryWhitelist . PublicAPI . Enabled {
return nil , dto . QueryWhitelistPublicPayload { } , ErrPublicAPIDisabled
}
if strings . TrimSpace ( managementKey ) == "" {
return nil , dto . QueryWhitelistPublicPayload { } , ErrMissingMgmtKey
}
if ! constantTimeEqual ( managementKey , s . config . QueryWhitelist . PublicAPI . ManagementKey ) {
return nil , dto . QueryWhitelistPublicPayload { } , ErrInvalidMgmtKey
}
headerAccessID = strings . TrimSpace ( headerAccessID )
if headerAccessID == "" {
return nil , dto . QueryWhitelistPublicPayload { } , ErrMissingAccessId
}
apiUser , err := s . apiUserService . LoadApiUserByAccessId ( ctx , headerAccessID )
if err != nil {
s . logger . Warn ( "公开白名单接口 AccessId 无效" , zap . String ( "access_id" , headerAccessID ) , zap . Error ( err ) )
return nil , dto . QueryWhitelistPublicPayload { } , ErrInvalidAccessId
}
if apiUser . IsFrozen ( ) {
return nil , dto . QueryWhitelistPublicPayload { } , ErrFrozenAccount
}
decrypted , err := crypto . AesDecrypt ( encryptedData , apiUser . SecretKey )
if err != nil {
s . logger . Warn ( "公开白名单接口解密失败" , zap . String ( "access_id" , headerAccessID ) , zap . Error ( err ) )
return nil , dto . QueryWhitelistPublicPayload { } , ErrDecryptFail
}
var raw map [ string ] json . RawMessage
if err := json . Unmarshal ( decrypted , & raw ) ; err != nil {
return nil , dto . QueryWhitelistPublicPayload { } , ErrRequestParam
}
if err := validatePublicAPICodesField ( raw [ "api_codes" ] ) ; err != nil {
return nil , dto . QueryWhitelistPublicPayload { } , MapWhitelistAppError ( err )
}
var payload dto . QueryWhitelistPublicPayload
if err := json . Unmarshal ( decrypted , & payload ) ; err != nil {
return nil , dto . QueryWhitelistPublicPayload { } , ErrRequestParam
}
if ! s . config . App . IsDevelopment ( ) && ! apiUser . IsWhiteListed ( clientIP ) {
s . logger . Warn ( "公开白名单接口 IP 未授权" ,
zap . String ( "access_id" , headerAccessID ) ,
zap . String ( "client_ip" , clientIP ) )
return nil , dto . QueryWhitelistPublicPayload { } , ErrInvalidIP
}
return apiUser , payload , nil
}
func ( s * QueryWhitelistApplicationServiceImpl ) createEntry (
ctx context . Context ,
createdBy string ,
@@ -126,6 +193,35 @@ func (s *QueryWhitelistApplicationServiceImpl) createEntry(
return & resp , nil
}
// mergeAPICodes 合并接口编码并去重,保持原有顺序,新增编码追加在后
func mergeAPICodes ( existing , incoming [ ] string ) [ ] string {
seen := make ( map [ string ] struct { } , len ( existing ) + len ( incoming ) )
result := make ( [ ] string , 0 , len ( existing ) + len ( incoming ) )
for _ , code := range existing {
code = strings . TrimSpace ( code )
if code == "" {
continue
}
if _ , ok := seen [ code ] ; ok {
continue
}
seen [ code ] = struct { } { }
result = append ( result , code )
}
for _ , code := range incoming {
code = strings . TrimSpace ( code )
if code == "" {
continue
}
if _ , ok := seen [ code ] ; ok {
continue
}
seen [ code ] = struct { } { }
result = append ( result , code )
}
return result
}
func constantTimeEqual ( a , b string ) bool {
return subtle . ConstantTimeCompare ( [ ] byte ( a ) , [ ] byte ( b ) ) == 1
}
@@ -205,6 +301,8 @@ func PublicAPIErrorMessage(err error) string {
switch err {
case ErrWhitelistExists :
return "规则已存在"
case ErrWhitelistNotFound :
return "规则不存在,请先调用创建接口"
case ErrRequestParam :
return "请求参数结构不正确"
default :