This commit is contained in:
2025-12-04 16:17:27 +08:00
parent d9c2d9f103
commit 0f5c4f4303
9 changed files with 227 additions and 55 deletions

View File

@@ -2,7 +2,9 @@ package entities
import (
"crypto/rand"
"database/sql/driver"
"encoding/hex"
"encoding/json"
"errors"
"io"
"net"
@@ -18,14 +20,86 @@ const (
ApiUserStatusFrozen = "frozen"
)
// WhiteListItem 白名单项包含IP地址、添加时间和备注
type WhiteListItem struct {
IPAddress string `json:"ip_address"` // IP地址
AddedAt time.Time `json:"added_at"` // 添加时间
Remark string `json:"remark"` // 备注
}
// WhiteList 白名单类型,支持向后兼容(旧的字符串数组格式)
type WhiteList []WhiteListItem
// Value 实现 driver.Valuer 接口,用于数据库写入
func (w WhiteList) Value() (driver.Value, error) {
if w == nil {
return "[]", nil
}
data, err := json.Marshal(w)
if err != nil {
return nil, err
}
return string(data), nil
}
// Scan 实现 sql.Scanner 接口,用于数据库读取(支持向后兼容)
func (w *WhiteList) Scan(value interface{}) error {
if value == nil {
*w = WhiteList{}
return nil
}
var bytes []byte
switch v := value.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return errors.New("无法扫描 WhiteList 类型")
}
if len(bytes) == 0 || string(bytes) == "[]" || string(bytes) == "null" {
*w = WhiteList{}
return nil
}
// 首先尝试解析为新格式(结构体数组)
var items []WhiteListItem
if err := json.Unmarshal(bytes, &items); err == nil {
// 成功解析为新格式
*w = WhiteList(items)
return nil
}
// 如果失败,尝试解析为旧格式(字符串数组)
var oldFormat []string
if err := json.Unmarshal(bytes, &oldFormat); err != nil {
return err
}
// 将旧格式转换为新格式
now := time.Now()
items = make([]WhiteListItem, 0, len(oldFormat))
for _, ip := range oldFormat {
items = append(items, WhiteListItem{
IPAddress: ip,
AddedAt: now, // 使用当前时间作为添加时间(因为旧数据没有时间信息)
Remark: "", // 旧数据没有备注信息
})
}
*w = WhiteList(items)
return nil
}
// ApiUser API用户聚合根
type ApiUser struct {
ID string `gorm:"primaryKey;type:varchar(64)" json:"id"`
UserId string `gorm:"type:varchar(36);not null;uniqueIndex" json:"user_id"`
AccessId string `gorm:"type:varchar(64);not null;uniqueIndex" json:"access_id"`
SecretKey string `gorm:"type:varchar(128);not null" json:"secret_key"`
Status string `gorm:"type:varchar(20);not null;default:'normal'" json:"status"`
WhiteList []string `gorm:"type:json;serializer:json;default:'[]'" json:"white_list"` // 支持多个白名单
ID string `gorm:"primaryKey;type:varchar(64)" json:"id"`
UserId string `gorm:"type:varchar(36);not null;uniqueIndex" json:"user_id"`
AccessId string `gorm:"type:varchar(64);not null;uniqueIndex" json:"access_id"`
SecretKey string `gorm:"type:varchar(128);not null" json:"secret_key"`
Status string `gorm:"type:varchar(20);not null;default:'normal'" json:"status"`
WhiteList WhiteList `gorm:"type:json;default:'[]'" json:"white_list"` // 支持多个白名单包含IP和添加时间支持向后兼容
// 余额预警配置
BalanceAlertEnabled bool `gorm:"default:true" json:"balance_alert_enabled" comment:"是否启用余额预警"`
@@ -41,7 +115,7 @@ type ApiUser struct {
// IsWhiteListed 校验IP/域名是否在白名单
func (u *ApiUser) IsWhiteListed(target string) bool {
for _, w := range u.WhiteList {
if w == target {
if w.IPAddress == target {
return true
}
}
@@ -77,7 +151,7 @@ func NewApiUser(userId string, defaultAlertEnabled bool, defaultAlertThreshold f
AccessId: accessId,
SecretKey: secretKey,
Status: ApiUserStatusNormal,
WhiteList: []string{},
WhiteList: WhiteList{},
BalanceAlertEnabled: defaultAlertEnabled,
BalanceAlertThreshold: defaultAlertThreshold,
}, nil
@@ -90,12 +164,12 @@ func (u *ApiUser) Freeze() {
func (u *ApiUser) Unfreeze() {
u.Status = ApiUserStatusNormal
}
func (u *ApiUser) UpdateWhiteList(list []string) {
u.WhiteList = list
func (u *ApiUser) UpdateWhiteList(list []WhiteListItem) {
u.WhiteList = WhiteList(list)
}
// AddToWhiteList 新增白名单项(防御性校验)
func (u *ApiUser) AddToWhiteList(entry string) error {
func (u *ApiUser) AddToWhiteList(entry string, remark string) error {
if len(u.WhiteList) >= 10 {
return errors.New("白名单最多只能有10个")
}
@@ -103,27 +177,31 @@ func (u *ApiUser) AddToWhiteList(entry string) error {
return errors.New("非法IP")
}
for _, w := range u.WhiteList {
if w == entry {
if w.IPAddress == entry {
return errors.New("白名单已存在")
}
}
u.WhiteList = append(u.WhiteList, entry)
u.WhiteList = append(u.WhiteList, WhiteListItem{
IPAddress: entry,
AddedAt: time.Now(),
Remark: remark,
})
return nil
}
// BeforeUpdate GORM钩子更新前确保WhiteList不为nil
func (u *ApiUser) BeforeUpdate(tx *gorm.DB) error {
if u.WhiteList == nil {
u.WhiteList = []string{}
u.WhiteList = WhiteList{}
}
return nil
}
// RemoveFromWhiteList 删除白名单项
func (u *ApiUser) RemoveFromWhiteList(entry string) error {
newList := make([]string, 0, len(u.WhiteList))
newList := make([]WhiteListItem, 0, len(u.WhiteList))
for _, w := range u.WhiteList {
if w != entry {
if w.IPAddress != entry {
newList = append(newList, w)
}
}
@@ -216,9 +294,9 @@ func (u *ApiUser) Validate() error {
if len(u.WhiteList) > 10 {
return errors.New("白名单最多只能有10个")
}
for _, ip := range u.WhiteList {
if net.ParseIP(ip) == nil {
return errors.New("白名单项必须为合法IP地址: " + ip)
for _, item := range u.WhiteList {
if net.ParseIP(item.IPAddress) == nil {
return errors.New("白名单项必须为合法IP地址: " + item.IPAddress)
}
}
return nil
@@ -259,7 +337,26 @@ func (c *ApiUser) BeforeCreate(tx *gorm.DB) error {
c.ID = uuid.New().String()
}
if c.WhiteList == nil {
c.WhiteList = []string{}
c.WhiteList = WhiteList{}
}
return nil
}
// AfterFind GORM钩子查询后处理数据确保AddedAt不为零值
func (u *ApiUser) AfterFind(tx *gorm.DB) error {
// 如果 WhiteList 为空,初始化为空数组
if u.WhiteList == nil {
u.WhiteList = WhiteList{}
return nil
}
// 确保所有项的AddedAt不为零值处理可能从旧数据迁移的情况
now := time.Now()
for i := range u.WhiteList {
if u.WhiteList[i].AddedAt.IsZero() {
u.WhiteList[i].AddedAt = now
}
}
return nil
}