This commit is contained in:
@@ -5,6 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"tyapi-server/internal/application/api/commands"
|
||||
"tyapi-server/internal/application/api/dto"
|
||||
@@ -37,8 +39,8 @@ type ApiApplicationService interface {
|
||||
GetUserApiKeys(ctx context.Context, userID string) (*dto.ApiKeysResponse, error)
|
||||
|
||||
// 用户白名单管理
|
||||
GetUserWhiteList(ctx context.Context, userID string) (*dto.WhiteListListResponse, error)
|
||||
AddWhiteListIP(ctx context.Context, userID string, ipAddress string) error
|
||||
GetUserWhiteList(ctx context.Context, userID string, remarkKeyword string) (*dto.WhiteListListResponse, error)
|
||||
AddWhiteListIP(ctx context.Context, userID string, ipAddress string, remark string) error
|
||||
DeleteWhiteListIP(ctx context.Context, userID string, ipAddress string) error
|
||||
|
||||
// 获取用户API调用记录
|
||||
@@ -466,7 +468,7 @@ func (s *ApiApplicationServiceImpl) GetUserApiKeys(ctx context.Context, userID s
|
||||
}
|
||||
|
||||
// GetUserWhiteList 获取用户白名单列表
|
||||
func (s *ApiApplicationServiceImpl) GetUserWhiteList(ctx context.Context, userID string) (*dto.WhiteListListResponse, error) {
|
||||
func (s *ApiApplicationServiceImpl) GetUserWhiteList(ctx context.Context, userID string, remarkKeyword string) (*dto.WhiteListListResponse, error) {
|
||||
apiUser, err := s.apiUserService.LoadApiUserByUserId(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -474,28 +476,49 @@ func (s *ApiApplicationServiceImpl) GetUserWhiteList(ctx context.Context, userID
|
||||
|
||||
// 确保WhiteList不为nil
|
||||
if apiUser.WhiteList == nil {
|
||||
apiUser.WhiteList = []string{}
|
||||
apiUser.WhiteList = entities.WhiteList{}
|
||||
}
|
||||
|
||||
// 将白名单字符串数组转换为响应格式
|
||||
// 将白名单转换为响应格式
|
||||
var items []dto.WhiteListResponse
|
||||
for _, ip := range apiUser.WhiteList {
|
||||
for _, item := range apiUser.WhiteList {
|
||||
// 如果提供了备注关键词,进行模糊匹配过滤
|
||||
if remarkKeyword != "" {
|
||||
if !contains(item.Remark, remarkKeyword) {
|
||||
continue // 不匹配则跳过
|
||||
}
|
||||
}
|
||||
|
||||
items = append(items, dto.WhiteListResponse{
|
||||
ID: apiUser.ID, // 使用API用户ID作为标识
|
||||
UserID: apiUser.UserId,
|
||||
IPAddress: ip,
|
||||
CreatedAt: apiUser.CreatedAt, // 使用API用户创建时间
|
||||
IPAddress: item.IPAddress,
|
||||
Remark: item.Remark, // 备注
|
||||
CreatedAt: item.AddedAt, // 使用每个IP的实际添加时间
|
||||
})
|
||||
}
|
||||
|
||||
// 按添加时间降序排序(新的排在前面)
|
||||
sort.Slice(items, func(i, j int) bool {
|
||||
return items[i].CreatedAt.After(items[j].CreatedAt)
|
||||
})
|
||||
|
||||
return &dto.WhiteListListResponse{
|
||||
Items: items,
|
||||
Total: len(items),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// contains 检查字符串是否包含子字符串(不区分大小写)
|
||||
func contains(s, substr string) bool {
|
||||
if substr == "" {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
|
||||
}
|
||||
|
||||
// AddWhiteListIP 添加白名单IP
|
||||
func (s *ApiApplicationServiceImpl) AddWhiteListIP(ctx context.Context, userID string, ipAddress string) error {
|
||||
func (s *ApiApplicationServiceImpl) AddWhiteListIP(ctx context.Context, userID string, ipAddress string, remark string) error {
|
||||
apiUser, err := s.apiUserService.LoadApiUserByUserId(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -503,11 +526,11 @@ func (s *ApiApplicationServiceImpl) AddWhiteListIP(ctx context.Context, userID s
|
||||
|
||||
// 确保WhiteList不为nil
|
||||
if apiUser.WhiteList == nil {
|
||||
apiUser.WhiteList = []string{}
|
||||
apiUser.WhiteList = entities.WhiteList{}
|
||||
}
|
||||
|
||||
// 使用实体的领域方法添加IP到白名单
|
||||
err = apiUser.AddToWhiteList(ipAddress)
|
||||
// 使用实体的领域方法添加IP到白名单(会自动记录添加时间和备注)
|
||||
err = apiUser.AddToWhiteList(ipAddress, remark)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -530,7 +553,7 @@ func (s *ApiApplicationServiceImpl) DeleteWhiteListIP(ctx context.Context, userI
|
||||
|
||||
// 确保WhiteList不为nil
|
||||
if apiUser.WhiteList == nil {
|
||||
apiUser.WhiteList = []string{}
|
||||
apiUser.WhiteList = entities.WhiteList{}
|
||||
}
|
||||
|
||||
// 使用实体的领域方法删除IP
|
||||
|
||||
@@ -26,11 +26,13 @@ type WhiteListResponse struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"user_id"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
Remark string `json:"remark"` // 备注
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type WhiteListRequest struct {
|
||||
IPAddress string `json:"ip_address" binding:"required,ip"`
|
||||
Remark string `json:"remark"` // 备注(可选)
|
||||
}
|
||||
|
||||
type WhiteListListResponse struct {
|
||||
|
||||
@@ -2,7 +2,9 @@ package entities
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
@@ -18,6 +20,78 @@ const (
|
||||
ApiUserStatusFrozen = "frozen"
|
||||
)
|
||||
|
||||
// WhiteListItem 白名单项,包含IP地址、添加时间和备注
|
||||
type WhiteListItem struct {
|
||||
IPAddress string `json:"ip_address"` // IP地址
|
||||
AddedAt time.Time `json:"added_at"` // 添加时间
|
||||
Remark string `json:"remark"` // 备注
|
||||
}
|
||||
|
||||
// WhiteList 白名单类型,支持向后兼容(旧的字符串数组格式)
|
||||
type WhiteList []WhiteListItem
|
||||
|
||||
// Value 实现 driver.Valuer 接口,用于数据库写入
|
||||
func (w WhiteList) Value() (driver.Value, error) {
|
||||
if w == nil {
|
||||
return "[]", nil
|
||||
}
|
||||
data, err := json.Marshal(w)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Scan 实现 sql.Scanner 接口,用于数据库读取(支持向后兼容)
|
||||
func (w *WhiteList) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*w = WhiteList{}
|
||||
return nil
|
||||
}
|
||||
|
||||
var bytes []byte
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return errors.New("无法扫描 WhiteList 类型")
|
||||
}
|
||||
|
||||
if len(bytes) == 0 || string(bytes) == "[]" || string(bytes) == "null" {
|
||||
*w = WhiteList{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 首先尝试解析为新格式(结构体数组)
|
||||
var items []WhiteListItem
|
||||
if err := json.Unmarshal(bytes, &items); err == nil {
|
||||
// 成功解析为新格式
|
||||
*w = WhiteList(items)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果失败,尝试解析为旧格式(字符串数组)
|
||||
var oldFormat []string
|
||||
if err := json.Unmarshal(bytes, &oldFormat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 将旧格式转换为新格式
|
||||
now := time.Now()
|
||||
items = make([]WhiteListItem, 0, len(oldFormat))
|
||||
for _, ip := range oldFormat {
|
||||
items = append(items, WhiteListItem{
|
||||
IPAddress: ip,
|
||||
AddedAt: now, // 使用当前时间作为添加时间(因为旧数据没有时间信息)
|
||||
Remark: "", // 旧数据没有备注信息
|
||||
})
|
||||
}
|
||||
*w = WhiteList(items)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApiUser API用户(聚合根)
|
||||
type ApiUser struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(64)" json:"id"`
|
||||
@@ -25,7 +99,7 @@ type ApiUser struct {
|
||||
AccessId string `gorm:"type:varchar(64);not null;uniqueIndex" json:"access_id"`
|
||||
SecretKey string `gorm:"type:varchar(128);not null" json:"secret_key"`
|
||||
Status string `gorm:"type:varchar(20);not null;default:'normal'" json:"status"`
|
||||
WhiteList []string `gorm:"type:json;serializer:json;default:'[]'" json:"white_list"` // 支持多个白名单
|
||||
WhiteList WhiteList `gorm:"type:json;default:'[]'" json:"white_list"` // 支持多个白名单,包含IP和添加时间,支持向后兼容
|
||||
|
||||
// 余额预警配置
|
||||
BalanceAlertEnabled bool `gorm:"default:true" json:"balance_alert_enabled" comment:"是否启用余额预警"`
|
||||
@@ -41,7 +115,7 @@ type ApiUser struct {
|
||||
// IsWhiteListed 校验IP/域名是否在白名单
|
||||
func (u *ApiUser) IsWhiteListed(target string) bool {
|
||||
for _, w := range u.WhiteList {
|
||||
if w == target {
|
||||
if w.IPAddress == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -77,7 +151,7 @@ func NewApiUser(userId string, defaultAlertEnabled bool, defaultAlertThreshold f
|
||||
AccessId: accessId,
|
||||
SecretKey: secretKey,
|
||||
Status: ApiUserStatusNormal,
|
||||
WhiteList: []string{},
|
||||
WhiteList: WhiteList{},
|
||||
BalanceAlertEnabled: defaultAlertEnabled,
|
||||
BalanceAlertThreshold: defaultAlertThreshold,
|
||||
}, nil
|
||||
@@ -90,12 +164,12 @@ func (u *ApiUser) Freeze() {
|
||||
func (u *ApiUser) Unfreeze() {
|
||||
u.Status = ApiUserStatusNormal
|
||||
}
|
||||
func (u *ApiUser) UpdateWhiteList(list []string) {
|
||||
u.WhiteList = list
|
||||
func (u *ApiUser) UpdateWhiteList(list []WhiteListItem) {
|
||||
u.WhiteList = WhiteList(list)
|
||||
}
|
||||
|
||||
// AddToWhiteList 新增白名单项(防御性校验)
|
||||
func (u *ApiUser) AddToWhiteList(entry string) error {
|
||||
func (u *ApiUser) AddToWhiteList(entry string, remark string) error {
|
||||
if len(u.WhiteList) >= 10 {
|
||||
return errors.New("白名单最多只能有10个")
|
||||
}
|
||||
@@ -103,27 +177,31 @@ func (u *ApiUser) AddToWhiteList(entry string) error {
|
||||
return errors.New("非法IP")
|
||||
}
|
||||
for _, w := range u.WhiteList {
|
||||
if w == entry {
|
||||
if w.IPAddress == entry {
|
||||
return errors.New("白名单已存在")
|
||||
}
|
||||
}
|
||||
u.WhiteList = append(u.WhiteList, entry)
|
||||
u.WhiteList = append(u.WhiteList, WhiteListItem{
|
||||
IPAddress: entry,
|
||||
AddedAt: time.Now(),
|
||||
Remark: remark,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeforeUpdate GORM钩子:更新前确保WhiteList不为nil
|
||||
func (u *ApiUser) BeforeUpdate(tx *gorm.DB) error {
|
||||
if u.WhiteList == nil {
|
||||
u.WhiteList = []string{}
|
||||
u.WhiteList = WhiteList{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveFromWhiteList 删除白名单项
|
||||
func (u *ApiUser) RemoveFromWhiteList(entry string) error {
|
||||
newList := make([]string, 0, len(u.WhiteList))
|
||||
newList := make([]WhiteListItem, 0, len(u.WhiteList))
|
||||
for _, w := range u.WhiteList {
|
||||
if w != entry {
|
||||
if w.IPAddress != entry {
|
||||
newList = append(newList, w)
|
||||
}
|
||||
}
|
||||
@@ -216,9 +294,9 @@ func (u *ApiUser) Validate() error {
|
||||
if len(u.WhiteList) > 10 {
|
||||
return errors.New("白名单最多只能有10个")
|
||||
}
|
||||
for _, ip := range u.WhiteList {
|
||||
if net.ParseIP(ip) == nil {
|
||||
return errors.New("白名单项必须为合法IP地址: " + ip)
|
||||
for _, item := range u.WhiteList {
|
||||
if net.ParseIP(item.IPAddress) == nil {
|
||||
return errors.New("白名单项必须为合法IP地址: " + item.IPAddress)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -259,7 +337,26 @@ func (c *ApiUser) BeforeCreate(tx *gorm.DB) error {
|
||||
c.ID = uuid.New().String()
|
||||
}
|
||||
if c.WhiteList == nil {
|
||||
c.WhiteList = []string{}
|
||||
c.WhiteList = WhiteList{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AfterFind GORM钩子:查询后处理数据,确保AddedAt不为零值
|
||||
func (u *ApiUser) AfterFind(tx *gorm.DB) error {
|
||||
// 如果 WhiteList 为空,初始化为空数组
|
||||
if u.WhiteList == nil {
|
||||
u.WhiteList = WhiteList{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 确保所有项的AddedAt不为零值(处理可能从旧数据迁移的情况)
|
||||
now := time.Now()
|
||||
for i := range u.WhiteList {
|
||||
if u.WhiteList[i].AddedAt.IsZero() {
|
||||
u.WhiteList[i].AddedAt = now
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/domains/api/entities"
|
||||
repo "tyapi-server/internal/domains/api/repositories"
|
||||
@@ -10,7 +11,7 @@ import (
|
||||
type ApiUserAggregateService interface {
|
||||
CreateApiUser(ctx context.Context, apiUserId string) error
|
||||
UpdateWhiteList(ctx context.Context, apiUserId string, whiteList []string) error
|
||||
AddToWhiteList(ctx context.Context, apiUserId string, entry string) error
|
||||
AddToWhiteList(ctx context.Context, apiUserId string, entry string, remark string) error
|
||||
RemoveFromWhiteList(ctx context.Context, apiUserId string, entry string) error
|
||||
FreezeApiUser(ctx context.Context, apiUserId string) error
|
||||
UnfreezeApiUser(ctx context.Context, apiUserId string) error
|
||||
@@ -44,16 +45,25 @@ func (s *ApiUserAggregateServiceImpl) UpdateWhiteList(ctx context.Context, apiUs
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
apiUser.UpdateWhiteList(whiteList)
|
||||
// 将字符串数组转换为WhiteListItem数组
|
||||
items := make([]entities.WhiteListItem, 0, len(whiteList))
|
||||
now := time.Now()
|
||||
for _, ip := range whiteList {
|
||||
items = append(items, entities.WhiteListItem{
|
||||
IPAddress: ip,
|
||||
AddedAt: now, // 批量更新时使用当前时间
|
||||
})
|
||||
}
|
||||
apiUser.UpdateWhiteList(items) // UpdateWhiteList 会转换为 WhiteList 类型
|
||||
return s.repo.Update(ctx, apiUser)
|
||||
}
|
||||
|
||||
func (s *ApiUserAggregateServiceImpl) AddToWhiteList(ctx context.Context, apiUserId string, entry string) error {
|
||||
func (s *ApiUserAggregateServiceImpl) AddToWhiteList(ctx context.Context, apiUserId string, entry string, remark string) error {
|
||||
apiUser, err := s.repo.FindByUserId(ctx, apiUserId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = apiUser.AddToWhiteList(entry)
|
||||
err = apiUser.AddToWhiteList(entry, remark)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -90,7 +100,6 @@ func (s *ApiUserAggregateServiceImpl) UnfreezeApiUser(ctx context.Context, apiUs
|
||||
return s.repo.Update(ctx, apiUser)
|
||||
}
|
||||
|
||||
|
||||
func (s *ApiUserAggregateServiceImpl) LoadApiUserByAccessId(ctx context.Context, accessId string) (*entities.ApiUser, error) {
|
||||
return s.repo.FindByAccessId(ctx, accessId)
|
||||
}
|
||||
@@ -103,7 +112,7 @@ func (s *ApiUserAggregateServiceImpl) LoadApiUserByUserId(ctx context.Context, a
|
||||
|
||||
// 确保WhiteList不为nil
|
||||
if apiUser.WhiteList == nil {
|
||||
apiUser.WhiteList = []string{}
|
||||
apiUser.WhiteList = entities.WhiteList{}
|
||||
}
|
||||
|
||||
return apiUser, nil
|
||||
@@ -117,7 +126,7 @@ func (s *ApiUserAggregateServiceImpl) SaveApiUser(ctx context.Context, apiUser *
|
||||
if exists != nil {
|
||||
// 确保WhiteList不为nil
|
||||
if apiUser.WhiteList == nil {
|
||||
apiUser.WhiteList = []string{}
|
||||
apiUser.WhiteList = []entities.WhiteListItem{}
|
||||
}
|
||||
return s.repo.Update(ctx, apiUser)
|
||||
} else {
|
||||
|
||||
@@ -110,7 +110,10 @@ func (h *ApiHandler) GetUserWhiteList(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetUserWhiteList(c.Request.Context(), userID)
|
||||
// 获取查询参数
|
||||
remarkKeyword := c.Query("remark") // 备注模糊查询关键词
|
||||
|
||||
result, err := h.appService.GetUserWhiteList(c.Request.Context(), userID, remarkKeyword)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户白名单失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
@@ -134,7 +137,7 @@ func (h *ApiHandler) AddWhiteListIP(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.appService.AddWhiteListIP(c.Request.Context(), userID, req.IPAddress)
|
||||
err := h.appService.AddWhiteListIP(c.Request.Context(), userID, req.IPAddress, req.Remark)
|
||||
if err != nil {
|
||||
h.logger.Error("添加白名单IP失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
|
||||
@@ -223,7 +223,6 @@ func (fm *FontManager) getWatermarkFontPaths() []string {
|
||||
return fm.buildFontPaths(fontNames)
|
||||
}
|
||||
|
||||
|
||||
// buildFontPaths 构建字体文件路径列表(仅从resources/pdf/fonts加载)
|
||||
// 返回所有存在的字体文件的绝对路径
|
||||
func (fm *FontManager) buildFontPaths(fontNames []string) []string {
|
||||
|
||||
39
scripts/migrate_whitelist.sql
Normal file
39
scripts/migrate_whitelist.sql
Normal file
@@ -0,0 +1,39 @@
|
||||
-- 白名单数据结构迁移脚本
|
||||
-- 将旧的字符串数组格式转换为新的结构体数组格式(包含IP和添加时间)
|
||||
--
|
||||
-- 执行前请备份数据库!
|
||||
--
|
||||
-- 使用方法:
|
||||
-- psql -U your_user -d your_database -f migrate_whitelist.sql
|
||||
|
||||
-- 开始事务
|
||||
BEGIN;
|
||||
|
||||
-- 更新 api_users 表中的 white_list 字段
|
||||
-- 将旧的字符串数组格式: ["ip1", "ip2"]
|
||||
-- 转换为新格式: [{"ip_address": "ip1", "added_at": "2025-12-04T15:20:19Z"}, ...]
|
||||
|
||||
UPDATE api_users
|
||||
SET white_list = (
|
||||
SELECT json_agg(
|
||||
json_build_object(
|
||||
'ip_address', ip_value,
|
||||
'added_at', COALESCE(
|
||||
(SELECT updated_at FROM api_users WHERE id = api_users.id),
|
||||
NOW()
|
||||
)
|
||||
)
|
||||
)
|
||||
FROM json_array_elements_text(white_list::json) AS ip_value
|
||||
)
|
||||
WHERE white_list IS NOT NULL
|
||||
AND white_list != '[]'::json
|
||||
AND white_list::text NOT LIKE '[{%' -- 排除已经是新格式的数据
|
||||
AND json_array_length(white_list::json) > 0;
|
||||
|
||||
-- 提交事务
|
||||
COMMIT;
|
||||
|
||||
-- 验证迁移结果(可选)
|
||||
-- SELECT id, white_list FROM api_users WHERE white_list IS NOT NULL LIMIT 5;
|
||||
|
||||
Reference in New Issue
Block a user