This commit is contained in:
2025-07-11 21:05:58 +08:00
parent 5b4392894f
commit e3d64e7485
74 changed files with 14379 additions and 697 deletions

View File

@@ -34,6 +34,12 @@ type ChangePasswordRequest struct {
Code string `json:"code" binding:"required,len=6" example:"123456"`
}
// UpdateProfileRequest 更新用户信息请求
type UpdateProfileRequest struct {
Phone string `json:"phone" binding:"omitempty,len=11" example:"13800138000"`
// 可以在这里添加更多用户信息字段,如昵称、头像等
}
// UserResponse 用户响应
type UserResponse struct {
ID string `json:"id" example:"123e4567-e89b-12d3-a456-426614174000"`

View File

@@ -6,50 +6,58 @@ import (
"gorm.io/gorm"
)
// SMSCode 短信验证码记录
// SMSCode 短信验证码记录实体
// 记录用户发送的所有短信验证码,支持多种使用场景
// 包含验证码的有效期管理、使用状态跟踪、安全审计等功能
type SMSCode struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
Phone string `gorm:"type:varchar(20);not null;index" json:"phone"`
Code string `gorm:"type:varchar(10);not null" json:"-"` // 不返回给前端
Scene SMSScene `gorm:"type:varchar(20);not null" json:"scene"`
Used bool `gorm:"default:false" json:"used"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
UsedAt *time.Time `json:"used_at,omitempty"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 基础标识
ID string `gorm:"primaryKey;type:varchar(36)" json:"id" comment:"短信验证码记录唯一标识"`
Phone string `gorm:"type:varchar(20);not null;index" json:"phone" comment:"接收手机号"`
Code string `gorm:"type:varchar(10);not null" json:"-" comment:"验证码内容(不返回给前端)"`
Scene SMSScene `gorm:"type:varchar(20);not null" json:"scene" comment:"使用场景"`
Used bool `gorm:"default:false" json:"used" comment:"是否已使用"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at" comment:"过期时间"`
UsedAt *time.Time `json:"used_at,omitempty" comment:"使用时间"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at" comment:"创建时间"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at" comment:"更新时间"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-" comment:"软删除时间"`
// 额外信息
IP string `gorm:"type:varchar(45)" json:"ip"`
UserAgent string `gorm:"type:varchar(500)" json:"user_agent"`
// 额外信息 - 安全审计相关数据
IP string `gorm:"type:varchar(45)" json:"ip" comment:"发送IP地址"`
UserAgent string `gorm:"type:varchar(500)" json:"user_agent" comment:"客户端信息"`
}
// SMSScene 短信验证码使用场景
// SMSScene 短信验证码使用场景枚举
// 定义系统中所有需要使用短信验证码的业务场景
type SMSScene string
const (
SMSSceneRegister SMSScene = "register" // 注册
SMSSceneLogin SMSScene = "login" // 登录
SMSSceneChangePassword SMSScene = "change_password" // 修改密码
SMSSceneResetPassword SMSScene = "reset_password" // 重置密码
SMSSceneBind SMSScene = "bind" // 绑定手机号
SMSSceneUnbind SMSScene = "unbind" // 解绑手机号
SMSSceneRegister SMSScene = "register" // 注册 - 新用户注册验证
SMSSceneLogin SMSScene = "login" // 登录 - 手机号登录验证
SMSSceneChangePassword SMSScene = "change_password" // 修改密码 - 修改密码验证
SMSSceneResetPassword SMSScene = "reset_password" // 重置密码 - 忘记密码重置
SMSSceneBind SMSScene = "bind" // 绑定手机号 - 绑定新手机号
SMSSceneUnbind SMSScene = "unbind" // 解绑手机号 - 解绑当前手机号
)
// 实现 Entity 接口
// 实现 Entity 接口 - 提供统一的实体管理接口
// GetID 获取实体唯一标识
func (s *SMSCode) GetID() string {
return s.ID
}
// GetCreatedAt 获取创建时间
func (s *SMSCode) GetCreatedAt() time.Time {
return s.CreatedAt
}
// GetUpdatedAt 获取更新时间
func (s *SMSCode) GetUpdatedAt() time.Time {
return s.UpdatedAt
}
// Validate 验证短信验证码
// 检查短信验证码记录的必填字段是否完整,确保数据的有效性
func (s *SMSCode) Validate() error {
if s.Phone == "" {
return &ValidationError{Message: "手机号不能为空"}
@@ -64,24 +72,253 @@ func (s *SMSCode) Validate() error {
return &ValidationError{Message: "过期时间不能为空"}
}
// 验证手机号格式
if !IsValidPhoneFormat(s.Phone) {
return &ValidationError{Message: "手机号格式无效"}
}
// 验证验证码格式
if err := s.validateCodeFormat(); err != nil {
return err
}
return nil
}
// 业务方法
func (s *SMSCode) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
// ================ 业务方法 ================
// VerifyCode 验证验证码
// 检查输入的验证码是否匹配且有效
func (s *SMSCode) VerifyCode(inputCode string) error {
// 1. 检查验证码是否已使用
if s.Used {
return &ValidationError{Message: "验证码已被使用"}
}
// 2. 检查验证码是否已过期
if s.IsExpired() {
return &ValidationError{Message: "验证码已过期"}
}
// 3. 检查验证码是否匹配
if s.Code != inputCode {
return &ValidationError{Message: "验证码错误"}
}
// 4. 标记为已使用
s.MarkAsUsed()
return nil
}
// IsExpired 检查验证码是否已过期
// 判断当前时间是否超过验证码的有效期
func (s *SMSCode) IsExpired() bool {
return time.Now().After(s.ExpiresAt) || time.Now().Equal(s.ExpiresAt)
}
// IsValid 检查验证码是否有效
// 综合判断验证码是否可用,包括未使用和未过期两个条件
func (s *SMSCode) IsValid() bool {
return !s.Used && !s.IsExpired()
}
// MarkAsUsed 标记验证码为已使用
// 在验证码被成功使用后调用,记录使用时间并标记状态
func (s *SMSCode) MarkAsUsed() {
s.Used = true
now := time.Now()
s.UsedAt = &now
}
// CanResend 检查是否可以重新发送验证码
// 基于时间间隔和场景判断是否允许重新发送
func (s *SMSCode) CanResend(minInterval time.Duration) bool {
// 如果验证码已使用或已过期,可以重新发送
if s.Used || s.IsExpired() {
return true
}
// 检查距离上次发送的时间间隔
timeSinceCreated := time.Since(s.CreatedAt)
return timeSinceCreated >= minInterval
}
// GetRemainingTime 获取验证码剩余有效时间
func (s *SMSCode) GetRemainingTime() time.Duration {
if s.IsExpired() {
return 0
}
return s.ExpiresAt.Sub(time.Now())
}
// IsRecentlySent 检查是否最近发送过验证码
func (s *SMSCode) IsRecentlySent(within time.Duration) bool {
return time.Since(s.CreatedAt) < within
}
// GetMaskedCode 获取脱敏的验证码(用于日志记录)
func (s *SMSCode) GetMaskedCode() string {
if len(s.Code) < 3 {
return "***"
}
return s.Code[:1] + "***" + s.Code[len(s.Code)-1:]
}
// GetMaskedPhone 获取脱敏的手机号
func (s *SMSCode) GetMaskedPhone() string {
if len(s.Phone) < 7 {
return s.Phone
}
return s.Phone[:3] + "****" + s.Phone[len(s.Phone)-4:]
}
// ================ 场景相关方法 ================
// IsSceneValid 检查场景是否有效
func (s *SMSCode) IsSceneValid() bool {
validScenes := []SMSScene{
SMSSceneRegister,
SMSSceneLogin,
SMSSceneChangePassword,
SMSSceneResetPassword,
SMSSceneBind,
SMSSceneUnbind,
}
for _, scene := range validScenes {
if s.Scene == scene {
return true
}
}
return false
}
// GetSceneName 获取场景的中文名称
func (s *SMSCode) GetSceneName() string {
sceneNames := map[SMSScene]string{
SMSSceneRegister: "用户注册",
SMSSceneLogin: "用户登录",
SMSSceneChangePassword: "修改密码",
SMSSceneResetPassword: "重置密码",
SMSSceneBind: "绑定手机号",
SMSSceneUnbind: "解绑手机号",
}
if name, exists := sceneNames[s.Scene]; exists {
return name
}
return string(s.Scene)
}
// ================ 安全相关方法 ================
// IsSuspicious 检查是否存在可疑行为
func (s *SMSCode) IsSuspicious() bool {
// 检查IP地址是否为空可能表示异常
if s.IP == "" {
return true
}
// 检查UserAgent是否为空可能表示异常
if s.UserAgent == "" {
return true
}
// 可以添加更多安全检查逻辑
// 例如检查IP是否来自异常地区、UserAgent是否异常等
return false
}
// GetSecurityInfo 获取安全信息摘要
func (s *SMSCode) GetSecurityInfo() map[string]interface{} {
return map[string]interface{}{
"ip": s.IP,
"user_agent": s.UserAgent,
"suspicious": s.IsSuspicious(),
"scene": s.GetSceneName(),
"created_at": s.CreatedAt,
}
}
// ================ 私有辅助方法 ================
// validateCodeFormat 验证验证码格式
func (s *SMSCode) validateCodeFormat() error {
// 检查验证码长度
if len(s.Code) < 4 || len(s.Code) > 10 {
return &ValidationError{Message: "验证码长度必须在4-10位之间"}
}
// 检查验证码是否只包含数字
for _, char := range s.Code {
if char < '0' || char > '9' {
return &ValidationError{Message: "验证码只能包含数字"}
}
}
return nil
}
// ================ 静态工具方法 ================
// IsValidScene 检查场景是否有效(静态方法)
func IsValidScene(scene SMSScene) bool {
validScenes := []SMSScene{
SMSSceneRegister,
SMSSceneLogin,
SMSSceneChangePassword,
SMSSceneResetPassword,
SMSSceneBind,
SMSSceneUnbind,
}
for _, validScene := range validScenes {
if scene == validScene {
return true
}
}
return false
}
// GetSceneName 获取场景的中文名称(静态方法)
func GetSceneName(scene SMSScene) string {
sceneNames := map[SMSScene]string{
SMSSceneRegister: "用户注册",
SMSSceneLogin: "用户登录",
SMSSceneChangePassword: "修改密码",
SMSSceneResetPassword: "重置密码",
SMSSceneBind: "绑定手机号",
SMSSceneUnbind: "解绑手机号",
}
if name, exists := sceneNames[scene]; exists {
return name
}
return string(scene)
}
// NewSMSCode 创建新的短信验证码(工厂方法)
func NewSMSCode(phone, code string, scene SMSScene, expireTime time.Duration, clientIP, userAgent string) (*SMSCode, error) {
smsCode := &SMSCode{
Phone: phone,
Code: code,
Scene: scene,
Used: false,
ExpiresAt: time.Now().Add(expireTime),
IP: clientIP,
UserAgent: userAgent,
}
// 验证实体
if err := smsCode.Validate(); err != nil {
return nil, err
}
return smsCode, nil
}
// TableName 指定表名
func (SMSCode) TableName() string {
return "sms_codes"

View File

@@ -0,0 +1,681 @@
package entities
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestSMSCode_Validate(t *testing.T) {
tests := []struct {
name string
smsCode *SMSCode
wantErr bool
}{
{
name: "有效验证码",
smsCode: &SMSCode{
Phone: "13800138000",
Code: "123456",
Scene: SMSSceneRegister,
ExpiresAt: time.Now().Add(time.Hour),
},
wantErr: false,
},
{
name: "手机号为空",
smsCode: &SMSCode{
Phone: "",
Code: "123456",
Scene: SMSSceneRegister,
ExpiresAt: time.Now().Add(time.Hour),
},
wantErr: true,
},
{
name: "验证码为空",
smsCode: &SMSCode{
Phone: "13800138000",
Code: "",
Scene: SMSSceneRegister,
ExpiresAt: time.Now().Add(time.Hour),
},
wantErr: true,
},
{
name: "场景为空",
smsCode: &SMSCode{
Phone: "13800138000",
Code: "123456",
Scene: "",
ExpiresAt: time.Now().Add(time.Hour),
},
wantErr: true,
},
{
name: "过期时间为零",
smsCode: &SMSCode{
Phone: "13800138000",
Code: "123456",
Scene: SMSSceneRegister,
ExpiresAt: time.Time{},
},
wantErr: true,
},
{
name: "手机号格式无效",
smsCode: &SMSCode{
Phone: "123",
Code: "123456",
Scene: SMSSceneRegister,
ExpiresAt: time.Now().Add(time.Hour),
},
wantErr: true,
},
{
name: "验证码格式无效-包含字母",
smsCode: &SMSCode{
Phone: "13800138000",
Code: "12345a",
Scene: SMSSceneRegister,
ExpiresAt: time.Now().Add(time.Hour),
},
wantErr: true,
},
{
name: "验证码长度过短",
smsCode: &SMSCode{
Phone: "13800138000",
Code: "123",
Scene: SMSSceneRegister,
ExpiresAt: time.Now().Add(time.Hour),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.smsCode.Validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestSMSCode_VerifyCode(t *testing.T) {
now := time.Now()
expiresAt := now.Add(time.Hour)
tests := []struct {
name string
smsCode *SMSCode
inputCode string
wantErr bool
}{
{
name: "验证码正确",
smsCode: &SMSCode{
Code: "123456",
Used: false,
ExpiresAt: expiresAt,
},
inputCode: "123456",
wantErr: false,
},
{
name: "验证码错误",
smsCode: &SMSCode{
Code: "123456",
Used: false,
ExpiresAt: expiresAt,
},
inputCode: "654321",
wantErr: true,
},
{
name: "验证码已使用",
smsCode: &SMSCode{
Code: "123456",
Used: true,
ExpiresAt: expiresAt,
},
inputCode: "123456",
wantErr: true,
},
{
name: "验证码已过期",
smsCode: &SMSCode{
Code: "123456",
Used: false,
ExpiresAt: now.Add(-time.Hour),
},
inputCode: "123456",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.smsCode.VerifyCode(tt.inputCode)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
// 验证码正确时应该被标记为已使用
assert.True(t, tt.smsCode.Used)
assert.NotNil(t, tt.smsCode.UsedAt)
}
})
}
}
func TestSMSCode_IsExpired(t *testing.T) {
now := time.Now()
tests := []struct {
name string
smsCode *SMSCode
expected bool
}{
{
name: "未过期",
smsCode: &SMSCode{
ExpiresAt: now.Add(time.Hour),
},
expected: false,
},
{
name: "已过期",
smsCode: &SMSCode{
ExpiresAt: now.Add(-time.Hour),
},
expected: true,
},
{
name: "刚好过期",
smsCode: &SMSCode{
ExpiresAt: now,
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.smsCode.IsExpired()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSMSCode_IsValid(t *testing.T) {
now := time.Now()
tests := []struct {
name string
smsCode *SMSCode
expected bool
}{
{
name: "有效验证码",
smsCode: &SMSCode{
Used: false,
ExpiresAt: now.Add(time.Hour),
},
expected: true,
},
{
name: "已使用",
smsCode: &SMSCode{
Used: true,
ExpiresAt: now.Add(time.Hour),
},
expected: false,
},
{
name: "已过期",
smsCode: &SMSCode{
Used: false,
ExpiresAt: now.Add(-time.Hour),
},
expected: false,
},
{
name: "已使用且已过期",
smsCode: &SMSCode{
Used: true,
ExpiresAt: now.Add(-time.Hour),
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.smsCode.IsValid()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSMSCode_CanResend(t *testing.T) {
now := time.Now()
minInterval := 60 * time.Second
tests := []struct {
name string
smsCode *SMSCode
expected bool
}{
{
name: "已使用-可以重发",
smsCode: &SMSCode{
Used: true,
CreatedAt: now.Add(-30 * time.Second),
},
expected: true,
},
{
name: "已过期-可以重发",
smsCode: &SMSCode{
Used: false,
ExpiresAt: now.Add(-time.Hour),
CreatedAt: now.Add(-30 * time.Second),
},
expected: true,
},
{
name: "未过期且未使用-间隔足够-可以重发",
smsCode: &SMSCode{
Used: false,
ExpiresAt: now.Add(time.Hour),
CreatedAt: now.Add(-2 * time.Minute),
},
expected: true,
},
{
name: "未过期且未使用-间隔不足-不能重发",
smsCode: &SMSCode{
Used: false,
ExpiresAt: now.Add(time.Hour),
CreatedAt: now.Add(-30 * time.Second),
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.smsCode.CanResend(minInterval)
assert.Equal(t, tt.expected, result)
})
}
}
func TestSMSCode_GetRemainingTime(t *testing.T) {
now := time.Now()
tests := []struct {
name string
smsCode *SMSCode
expected time.Duration
}{
{
name: "未过期",
smsCode: &SMSCode{
ExpiresAt: now.Add(time.Hour),
},
expected: time.Hour,
},
{
name: "已过期",
smsCode: &SMSCode{
ExpiresAt: now.Add(-time.Hour),
},
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.smsCode.GetRemainingTime()
// 由于时间计算可能有微小差异,我们检查是否在合理范围内
if tt.expected > 0 {
assert.True(t, result > 0)
assert.True(t, result <= tt.expected)
} else {
assert.Equal(t, tt.expected, result)
}
})
}
}
func TestSMSCode_GetMaskedCode(t *testing.T) {
tests := []struct {
name string
code string
expected string
}{
{
name: "6位验证码",
code: "123456",
expected: "1***6",
},
{
name: "4位验证码",
code: "1234",
expected: "1***4",
},
{
name: "短验证码",
code: "12",
expected: "***",
},
{
name: "单字符",
code: "1",
expected: "***",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
smsCode := &SMSCode{Code: tt.code}
result := smsCode.GetMaskedCode()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSMSCode_GetMaskedPhone(t *testing.T) {
tests := []struct {
name string
phone string
expected string
}{
{
name: "标准手机号",
phone: "13800138000",
expected: "138****8000",
},
{
name: "短手机号",
phone: "138001",
expected: "138001",
},
{
name: "空手机号",
phone: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
smsCode := &SMSCode{Phone: tt.phone}
result := smsCode.GetMaskedPhone()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSMSCode_IsSceneValid(t *testing.T) {
tests := []struct {
name string
scene SMSScene
expected bool
}{
{
name: "注册场景",
scene: SMSSceneRegister,
expected: true,
},
{
name: "登录场景",
scene: SMSSceneLogin,
expected: true,
},
{
name: "无效场景",
scene: "invalid",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
smsCode := &SMSCode{Scene: tt.scene}
result := smsCode.IsSceneValid()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSMSCode_GetSceneName(t *testing.T) {
tests := []struct {
name string
scene SMSScene
expected string
}{
{
name: "注册场景",
scene: SMSSceneRegister,
expected: "用户注册",
},
{
name: "登录场景",
scene: SMSSceneLogin,
expected: "用户登录",
},
{
name: "无效场景",
scene: "invalid",
expected: "invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
smsCode := &SMSCode{Scene: tt.scene}
result := smsCode.GetSceneName()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSMSCode_IsSuspicious(t *testing.T) {
tests := []struct {
name string
smsCode *SMSCode
expected bool
}{
{
name: "正常记录",
smsCode: &SMSCode{
IP: "192.168.1.1",
UserAgent: "Mozilla/5.0",
},
expected: false,
},
{
name: "IP为空-可疑",
smsCode: &SMSCode{
IP: "",
UserAgent: "Mozilla/5.0",
},
expected: true,
},
{
name: "UserAgent为空-可疑",
smsCode: &SMSCode{
IP: "192.168.1.1",
UserAgent: "",
},
expected: true,
},
{
name: "IP和UserAgent都为空-可疑",
smsCode: &SMSCode{
IP: "",
UserAgent: "",
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.smsCode.IsSuspicious()
assert.Equal(t, tt.expected, result)
})
}
}
func TestSMSCode_GetSecurityInfo(t *testing.T) {
now := time.Now()
smsCode := &SMSCode{
IP: "192.168.1.1",
UserAgent: "Mozilla/5.0",
Scene: SMSSceneRegister,
CreatedAt: now,
}
securityInfo := smsCode.GetSecurityInfo()
assert.Equal(t, "192.168.1.1", securityInfo["ip"])
assert.Equal(t, "Mozilla/5.0", securityInfo["user_agent"])
assert.Equal(t, false, securityInfo["suspicious"])
assert.Equal(t, "用户注册", securityInfo["scene"])
assert.Equal(t, now, securityInfo["created_at"])
}
func TestIsValidScene(t *testing.T) {
tests := []struct {
name string
scene SMSScene
expected bool
}{
{
name: "注册场景",
scene: SMSSceneRegister,
expected: true,
},
{
name: "登录场景",
scene: SMSSceneLogin,
expected: true,
},
{
name: "无效场景",
scene: "invalid",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsValidScene(tt.scene)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetSceneName(t *testing.T) {
tests := []struct {
name string
scene SMSScene
expected string
}{
{
name: "注册场景",
scene: SMSSceneRegister,
expected: "用户注册",
},
{
name: "登录场景",
scene: SMSSceneLogin,
expected: "用户登录",
},
{
name: "无效场景",
scene: "invalid",
expected: "invalid",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetSceneName(tt.scene)
assert.Equal(t, tt.expected, result)
})
}
}
func TestNewSMSCode(t *testing.T) {
tests := []struct {
name string
phone string
code string
scene SMSScene
expireTime time.Duration
clientIP string
userAgent string
expectError bool
}{
{
name: "有效参数",
phone: "13800138000",
code: "123456",
scene: SMSSceneRegister,
expireTime: time.Hour,
clientIP: "192.168.1.1",
userAgent: "Mozilla/5.0",
expectError: false,
},
{
name: "无效手机号",
phone: "123",
code: "123456",
scene: SMSSceneRegister,
expireTime: time.Hour,
clientIP: "192.168.1.1",
userAgent: "Mozilla/5.0",
expectError: true,
},
{
name: "无效验证码",
phone: "13800138000",
code: "123",
scene: SMSSceneRegister,
expireTime: time.Hour,
clientIP: "192.168.1.1",
userAgent: "Mozilla/5.0",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
smsCode, err := NewSMSCode(tt.phone, tt.code, tt.scene, tt.expireTime, tt.clientIP, tt.userAgent)
if tt.expectError {
assert.Error(t, err)
assert.Nil(t, smsCode)
} else {
assert.NoError(t, err)
assert.NotNil(t, smsCode)
assert.Equal(t, tt.phone, smsCode.Phone)
assert.Equal(t, tt.code, smsCode.Code)
assert.Equal(t, tt.scene, smsCode.Scene)
assert.Equal(t, tt.clientIP, smsCode.IP)
assert.Equal(t, tt.userAgent, smsCode.UserAgent)
assert.False(t, smsCode.Used)
assert.True(t, smsCode.ExpiresAt.After(time.Now()))
}
})
}
}

View File

@@ -1,35 +1,48 @@
package entities
import (
"errors"
"fmt"
"regexp"
"time"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// User 用户实体
// 系统用户的核心信息,提供基础的账户管理功能
// 支持手机号登录密码加密存储实现Entity接口便于统一管理
type User struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
Phone string `gorm:"uniqueIndex;type:varchar(20);not null" json:"phone"`
Password string `gorm:"type:varchar(255);not null" json:"-"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 基础标识
ID string `gorm:"primaryKey;type:varchar(36)" json:"id" comment:"用户唯一标识"`
Phone string `gorm:"uniqueIndex;type:varchar(20);not null" json:"phone" comment:"手机号码(登录账号)"`
Password string `gorm:"type:varchar(255);not null" json:"-" comment:"登录密码(加密存储,不返回前端)"`
// 时间戳字段
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at" comment:"创建时间"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at" comment:"更新时间"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-" comment:"软删除时间"`
}
// 实现 Entity 接口
// 实现 Entity 接口 - 提供统一的实体管理接口
// GetID 获取实体唯一标识
func (u *User) GetID() string {
return u.ID
}
// GetCreatedAt 获取创建时间
func (u *User) GetCreatedAt() time.Time {
return u.CreatedAt
}
// GetUpdatedAt 获取更新时间
func (u *User) GetUpdatedAt() time.Time {
return u.UpdatedAt
}
// 验证方法
// Validate 验证用户信息
// 检查用户必填字段是否完整,确保数据的有效性
func (u *User) Validate() error {
if u.Phone == "" {
return NewValidationError("手机号不能为空")
@@ -37,23 +50,226 @@ func (u *User) Validate() error {
if u.Password == "" {
return NewValidationError("密码不能为空")
}
// 验证手机号格式
if !u.IsValidPhone() {
return NewValidationError("手机号格式无效")
}
// 验证密码强度
if err := u.validatePasswordStrength(u.Password); err != nil {
return err
}
return nil
}
// ================ 业务方法 ================
// ChangePassword 修改密码
// 验证旧密码,检查新密码强度,更新密码
func (u *User) ChangePassword(oldPassword, newPassword, confirmPassword string) error {
// 1. 验证确认密码
if newPassword != confirmPassword {
return NewValidationError("新密码和确认新密码不匹配")
}
// 2. 验证旧密码
if !u.CheckPassword(oldPassword) {
return NewValidationError("当前密码错误")
}
// 3. 验证新密码强度
if err := u.validatePasswordStrength(newPassword); err != nil {
return err
}
// 4. 检查新密码不能与旧密码相同
if u.CheckPassword(newPassword) {
return NewValidationError("新密码不能与当前密码相同")
}
// 5. 更新密码
hashedPassword, err := u.hashPassword(newPassword)
if err != nil {
return fmt.Errorf("密码加密失败: %w", err)
}
u.Password = hashedPassword
return nil
}
// CheckPassword 验证密码是否正确
func (u *User) CheckPassword(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(password))
return err == nil
}
// SetPassword 设置密码(用于注册或重置密码)
func (u *User) SetPassword(password string) error {
// 验证密码强度
if err := u.validatePasswordStrength(password); err != nil {
return err
}
// 加密密码
hashedPassword, err := u.hashPassword(password)
if err != nil {
return fmt.Errorf("密码加密失败: %w", err)
}
u.Password = hashedPassword
return nil
}
// CanLogin 检查用户是否可以登录
func (u *User) CanLogin() bool {
// 检查用户是否被删除
if !u.DeletedAt.Time.IsZero() {
return false
}
// 检查必要字段是否存在
if u.Phone == "" || u.Password == "" {
return false
}
return true
}
// IsActive 检查用户是否处于活跃状态
func (u *User) IsActive() bool {
return u.DeletedAt.Time.IsZero()
}
// IsDeleted 检查用户是否已被删除
func (u *User) IsDeleted() bool {
return !u.DeletedAt.Time.IsZero()
}
// ================ 手机号相关方法 ================
// IsValidPhone 验证手机号格式
func (u *User) IsValidPhone() bool {
return IsValidPhoneFormat(u.Phone)
}
// SetPhone 设置手机号
func (u *User) SetPhone(phone string) error {
if !IsValidPhoneFormat(phone) {
return NewValidationError("手机号格式无效")
}
u.Phone = phone
return nil
}
// GetMaskedPhone 获取脱敏的手机号
func (u *User) GetMaskedPhone() string {
if len(u.Phone) < 7 {
return u.Phone
}
return u.Phone[:3] + "****" + u.Phone[len(u.Phone)-4:]
}
// ================ 私有辅助方法 ================
// hashPassword 加密密码
func (u *User) hashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashedBytes), nil
}
// validatePasswordStrength 验证密码强度
func (u *User) validatePasswordStrength(password string) error {
if len(password) < 8 {
return NewValidationError("密码长度至少8位")
}
if len(password) > 128 {
return NewValidationError("密码长度不能超过128位")
}
// 检查是否包含数字
hasDigit := regexp.MustCompile(`[0-9]`).MatchString(password)
if !hasDigit {
return NewValidationError("密码必须包含数字")
}
// 检查是否包含字母
hasLetter := regexp.MustCompile(`[a-zA-Z]`).MatchString(password)
if !hasLetter {
return NewValidationError("密码必须包含字母")
}
// 检查是否包含特殊字符(可选,可以根据需求调整)
hasSpecial := regexp.MustCompile(`[!@#$%^&*()_+\-=\[\]{};':"\\|,.<>\/?]`).MatchString(password)
if !hasSpecial {
return NewValidationError("密码必须包含特殊字符")
}
return nil
}
// ================ 静态工具方法 ================
// IsValidPhoneFormat 验证手机号格式(静态方法)
func IsValidPhoneFormat(phone string) bool {
if phone == "" {
return false
}
// 中国手机号验证11位数字以1开头
pattern := `^1[3-9]\d{9}$`
matched, _ := regexp.MatchString(pattern, phone)
return matched
}
// NewUser 创建新用户(工厂方法)
func NewUser(phone, password string) (*User, error) {
user := &User{
ID: "", // 由数据库或调用方设置
Phone: phone,
}
// 验证手机号
if err := user.SetPhone(phone); err != nil {
return nil, err
}
// 设置密码
if err := user.SetPassword(password); err != nil {
return nil, err
}
return user, nil
}
// TableName 指定表名
func (User) TableName() string {
return "users"
}
// ValidationError 验证错误
// 自定义验证错误类型,提供结构化的错误信息
type ValidationError struct {
Message string
}
// Error 实现error接口
func (e *ValidationError) Error() string {
return e.Message
}
// NewValidationError 创建新的验证错误
// 工厂方法,用于创建验证错误实例
func NewValidationError(message string) *ValidationError {
return &ValidationError{Message: message}
}
// IsValidationError 检查是否为验证错误
func IsValidationError(err error) bool {
var validationErr *ValidationError
return errors.As(err, &validationErr)
}

View File

@@ -0,0 +1,338 @@
package entities
import (
"testing"
)
func TestUser_ChangePassword(t *testing.T) {
// 创建测试用户
user, err := NewUser("13800138000", "OldPassword123!")
if err != nil {
t.Fatalf("创建用户失败: %v", err)
}
tests := []struct {
name string
oldPassword string
newPassword string
confirmPassword string
wantErr bool
errorContains string
}{
{
name: "正常修改密码",
oldPassword: "OldPassword123!",
newPassword: "NewPassword123!",
confirmPassword: "NewPassword123!",
wantErr: false,
},
{
name: "旧密码错误",
oldPassword: "WrongPassword123!",
newPassword: "NewPassword123!",
confirmPassword: "NewPassword123!",
wantErr: true,
errorContains: "当前密码错误",
},
{
name: "确认密码不匹配",
oldPassword: "OldPassword123!",
newPassword: "NewPassword123!",
confirmPassword: "DifferentPassword123!",
wantErr: true,
errorContains: "新密码和确认新密码不匹配",
},
{
name: "新密码与旧密码相同",
oldPassword: "OldPassword123!",
newPassword: "OldPassword123!",
confirmPassword: "OldPassword123!",
wantErr: true,
errorContains: "新密码不能与当前密码相同",
},
{
name: "新密码强度不足",
oldPassword: "OldPassword123!",
newPassword: "weak",
confirmPassword: "weak",
wantErr: true,
errorContains: "密码长度至少8位",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 重置用户密码为初始状态
user.SetPassword("OldPassword123!")
err := user.ChangePassword(tt.oldPassword, tt.newPassword, tt.confirmPassword)
if tt.wantErr {
if err == nil {
t.Errorf("期望错误但没有得到错误")
return
}
if tt.errorContains != "" && !contains(err.Error(), tt.errorContains) {
t.Errorf("错误信息不包含期望的内容,期望包含: %s, 实际: %s", tt.errorContains, err.Error())
}
} else {
if err != nil {
t.Errorf("不期望错误但得到了错误: %v", err)
}
// 验证密码确实被修改了
if !user.CheckPassword(tt.newPassword) {
t.Errorf("密码修改后验证失败")
}
}
})
}
}
func TestUser_CheckPassword(t *testing.T) {
user, err := NewUser("13800138000", "TestPassword123!")
if err != nil {
t.Fatalf("创建用户失败: %v", err)
}
tests := []struct {
name string
password string
want bool
}{
{
name: "正确密码",
password: "TestPassword123!",
want: true,
},
{
name: "错误密码",
password: "WrongPassword123!",
want: false,
},
{
name: "空密码",
password: "",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := user.CheckPassword(tt.password)
if got != tt.want {
t.Errorf("CheckPassword() = %v, want %v", got, tt.want)
}
})
}
}
func TestUser_SetPhone(t *testing.T) {
user := &User{}
tests := []struct {
name string
phone string
wantErr bool
}{
{
name: "有效手机号",
phone: "13800138000",
wantErr: false,
},
{
name: "无效手机号-太短",
phone: "1380013800",
wantErr: true,
},
{
name: "无效手机号-太长",
phone: "138001380000",
wantErr: true,
},
{
name: "无效手机号-格式错误",
phone: "1380013800a",
wantErr: true,
},
{
name: "无效手机号-不以1开头",
phone: "23800138000",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := user.SetPhone(tt.phone)
if tt.wantErr {
if err == nil {
t.Errorf("期望错误但没有得到错误")
}
} else {
if err != nil {
t.Errorf("不期望错误但得到了错误: %v", err)
}
if user.Phone != tt.phone {
t.Errorf("手机号设置失败,期望: %s, 实际: %s", tt.phone, user.Phone)
}
}
})
}
}
func TestUser_GetMaskedPhone(t *testing.T) {
tests := []struct {
name string
phone string
expected string
}{
{
name: "正常手机号",
phone: "13800138000",
expected: "138****8000",
},
{
name: "短手机号",
phone: "138001",
expected: "138001",
},
{
name: "空手机号",
phone: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user := &User{Phone: tt.phone}
got := user.GetMaskedPhone()
if got != tt.expected {
t.Errorf("GetMaskedPhone() = %v, want %v", got, tt.expected)
}
})
}
}
func TestIsValidPhoneFormat(t *testing.T) {
tests := []struct {
name string
phone string
want bool
}{
{
name: "有效手机号-13开头",
phone: "13800138000",
want: true,
},
{
name: "有效手机号-15开头",
phone: "15800138000",
want: true,
},
{
name: "有效手机号-18开头",
phone: "18800138000",
want: true,
},
{
name: "无效手机号-12开头",
phone: "12800138000",
want: false,
},
{
name: "无效手机号-20开头",
phone: "20800138000",
want: false,
},
{
name: "无效手机号-太短",
phone: "1380013800",
want: false,
},
{
name: "无效手机号-太长",
phone: "138001380000",
want: false,
},
{
name: "无效手机号-包含字母",
phone: "1380013800a",
want: false,
},
{
name: "空手机号",
phone: "",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsValidPhoneFormat(tt.phone)
if got != tt.want {
t.Errorf("IsValidPhoneFormat() = %v, want %v", got, tt.want)
}
})
}
}
func TestNewUser(t *testing.T) {
tests := []struct {
name string
phone string
password string
wantErr bool
}{
{
name: "有效用户信息",
phone: "13800138000",
password: "TestPassword123!",
wantErr: false,
},
{
name: "无效手机号",
phone: "1380013800",
password: "TestPassword123!",
wantErr: true,
},
{
name: "无效密码",
phone: "13800138000",
password: "weak",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
user, err := NewUser(tt.phone, tt.password)
if tt.wantErr {
if err == nil {
t.Errorf("期望错误但没有得到错误")
}
} else {
if err != nil {
t.Errorf("不期望错误但得到了错误: %v", err)
}
if user.Phone != tt.phone {
t.Errorf("手机号设置失败,期望: %s, 实际: %s", tt.phone, user.Phone)
}
if !user.CheckPassword(tt.password) {
t.Errorf("密码设置失败")
}
}
})
}
}
// 辅助函数
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || (len(s) > len(substr) && (s[:len(substr)] == substr || s[len(s)-len(substr):] == substr || func() bool {
for i := 1; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}())))
}

View File

@@ -81,6 +81,34 @@ func (r *SMSCodeRepository) MarkAsUsed(ctx context.Context, id string) error {
return nil
}
// Update 更新验证码记录
func (r *SMSCodeRepository) Update(ctx context.Context, smsCode *entities.SMSCode) error {
if err := r.db.WithContext(ctx).Save(smsCode).Error; err != nil {
r.logger.Error("更新验证码记录失败", zap.Error(err))
return err
}
// 更新缓存
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
r.cache.Set(ctx, cacheKey, smsCode, 5*time.Minute)
r.logger.Info("验证码记录更新成功", zap.String("code_id", smsCode.ID))
return nil
}
// GetRecentCode 获取最近的验证码记录(不限制有效性)
func (r *SMSCodeRepository) GetRecentCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
Where("phone = ? AND scene = ?", phone, scene).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
return nil, err
}
return &smsCode, nil
}
// CleanupExpired 清理过期的验证码
func (r *SMSCodeRepository) CleanupExpired(ctx context.Context) error {
result := r.db.WithContext(ctx).

View File

@@ -133,6 +133,41 @@ func (r *UserRepository) Delete(ctx context.Context, id string) error {
return nil
}
// SoftDelete 软删除用户
func (r *UserRepository) SoftDelete(ctx context.Context, id string) error {
// 先获取用户信息用于清除缓存
user, err := r.GetByID(ctx, id)
if err != nil {
return err
}
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("软删除用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, id)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户软删除成功", zap.String("user_id", id))
return nil
}
// Restore 恢复软删除的用户
func (r *UserRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, id)
r.logger.Info("用户恢复成功", zap.String("user_id", id))
return nil
}
// List 分页获取用户列表
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*entities.User, error) {
var users []*entities.User

View File

@@ -43,83 +43,132 @@ func NewSMSCodeService(
// SendCode 发送验证码
func (s *SMSCodeService) SendCode(ctx context.Context, phone string, scene entities.SMSScene, clientIP, userAgent string) error {
// 检查频率限制
// 1. 检查频率限制
if err := s.checkRateLimit(ctx, phone); err != nil {
return err
}
// 生成验证码
// 2. 生成验证码
code := s.smsClient.GenerateCode(s.config.CodeLength)
// 创建SMS验证码记录
smsCode := &entities.SMSCode{
ID: uuid.New().String(),
Phone: phone,
Code: code,
Scene: scene,
IP: clientIP,
UserAgent: userAgent,
Used: false,
ExpiresAt: time.Now().Add(s.config.ExpireTime),
// 3. 使用工厂方法创建SMS验证码记录
smsCode, err := entities.NewSMSCode(phone, code, scene, s.config.ExpireTime, clientIP, userAgent)
if err != nil {
return fmt.Errorf("创建验证码记录失败: %w", err)
}
// 保存验证码
// 4. 设置ID
smsCode.ID = uuid.New().String()
// 5. 保存验证码
if err := s.repo.Create(ctx, smsCode); err != nil {
s.logger.Error("保存短信验证码失败",
zap.String("phone", phone),
zap.String("scene", string(scene)),
zap.String("phone", smsCode.GetMaskedPhone()),
zap.String("scene", smsCode.GetSceneName()),
zap.Error(err))
return fmt.Errorf("保存验证码失败: %w", err)
}
// 发送短信
// 6. 发送短信
if err := s.smsClient.SendVerificationCode(ctx, phone, code); err != nil {
// 记录发送失败但不删除验证码记录,让其自然过期
s.logger.Error("发送短信验证码失败",
zap.String("phone", phone),
zap.String("code", code),
zap.String("phone", smsCode.GetMaskedPhone()),
zap.String("code", smsCode.GetMaskedCode()),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
// 更新发送记录缓存
// 7. 更新发送记录缓存
s.updateSendRecord(ctx, phone)
s.logger.Info("短信验证码发送成功",
zap.String("phone", phone),
zap.String("scene", string(scene)))
zap.String("phone", smsCode.GetMaskedPhone()),
zap.String("scene", smsCode.GetSceneName()),
zap.String("remaining_time", smsCode.GetRemainingTime().String()))
return nil
}
// VerifyCode 验证验证码
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, code string, scene entities.SMSScene) error {
// 根据手机号和场景获取有效的验证码记录
// 1. 根据手机号和场景获取有效的验证码记录
smsCode, err := s.repo.GetValidCode(ctx, phone, scene)
if err != nil {
return fmt.Errorf("验证码无效或已过期")
}
// 验证验证码是否匹配
if smsCode.Code != code {
return fmt.Errorf("验证码无效或已过期")
// 2. 使用实体的验证方法
if err := smsCode.VerifyCode(code); err != nil {
return err
}
// 标记验证码为已使用
if err := s.repo.MarkAsUsed(ctx, smsCode.ID); err != nil {
s.logger.Error("标记验证码为已使用失败",
// 3. 保存更新后的验证码状态
if err := s.repo.Update(ctx, smsCode); err != nil {
s.logger.Error("更新验证码状态失败",
zap.String("code_id", smsCode.ID),
zap.Error(err))
return fmt.Errorf("验证码状态更新失败")
}
s.logger.Info("短信验证码验证成功",
zap.String("phone", phone),
zap.String("scene", string(scene)))
zap.String("phone", smsCode.GetMaskedPhone()),
zap.String("scene", smsCode.GetSceneName()))
return nil
}
// CanResendCode 检查是否可以重新发送验证码
func (s *SMSCodeService) CanResendCode(ctx context.Context, phone string, scene entities.SMSScene) (bool, error) {
// 1. 获取最近的验证码记录
recentCode, err := s.repo.GetRecentCode(ctx, phone, scene)
if err != nil {
// 如果没有记录,可以发送
return true, nil
}
// 2. 使用实体的方法检查是否可以重新发送
canResend := recentCode.CanResend(s.config.RateLimit.MinInterval)
// 3. 记录检查结果
if !canResend {
remainingTime := s.config.RateLimit.MinInterval - time.Since(recentCode.CreatedAt)
s.logger.Info("验证码发送频率限制",
zap.String("phone", recentCode.GetMaskedPhone()),
zap.String("scene", recentCode.GetSceneName()),
zap.Duration("remaining_wait_time", remainingTime))
}
return canResend, nil
}
// GetCodeStatus 获取验证码状态信息
func (s *SMSCodeService) GetCodeStatus(ctx context.Context, phone string, scene entities.SMSScene) (map[string]interface{}, error) {
// 1. 获取最近的验证码记录
recentCode, err := s.repo.GetRecentCode(ctx, phone, scene)
if err != nil {
return map[string]interface{}{
"has_code": false,
"message": "没有找到验证码记录",
}, nil
}
// 2. 构建状态信息
status := map[string]interface{}{
"has_code": true,
"is_valid": recentCode.IsValid(),
"is_expired": recentCode.IsExpired(),
"is_used": recentCode.Used,
"remaining_time": recentCode.GetRemainingTime().String(),
"scene": recentCode.GetSceneName(),
"can_resend": recentCode.CanResend(s.config.RateLimit.MinInterval),
"created_at": recentCode.CreatedAt,
"security_info": recentCode.GetSecurityInfo(),
}
return status, nil
}
// checkRateLimit 检查发送频率限制
func (s *SMSCodeService) checkRateLimit(ctx context.Context, phone string) error {
now := time.Now()

View File

@@ -3,11 +3,9 @@ package services
import (
"context"
"fmt"
"regexp"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
"tyapi-server/internal/domains/user/dto"
"tyapi-server/internal/domains/user/entities"
@@ -64,44 +62,32 @@ func (s *UserService) Shutdown(ctx context.Context) error {
// Register 用户注册
func (s *UserService) Register(ctx context.Context, registerReq *dto.RegisterRequest) (*entities.User, error) {
// 验证手机号格式
if !s.isValidPhone(registerReq.Phone) {
return nil, fmt.Errorf("手机号格式无效")
}
// 验证密码确认
if registerReq.Password != registerReq.ConfirmPassword {
return nil, fmt.Errorf("密码和确认密码不匹配")
}
// 验证短信验证码
// 1. 验证短信验证码
if err := s.smsCodeService.VerifyCode(ctx, registerReq.Phone, registerReq.Code, entities.SMSSceneRegister); err != nil {
return nil, fmt.Errorf("验证码验证失败: %w", err)
}
// 检查手机号是否已存在
// 2. 检查手机号是否已存在
if err := s.checkPhoneDuplicate(ctx, registerReq.Phone); err != nil {
return nil, err
}
// 创建用户实体
user := registerReq.ToEntity()
// 3. 使用工厂方法创建用户实体(业务规则验证在实体中完成)
user, err := entities.NewUser(registerReq.Phone, registerReq.Password)
if err != nil {
return nil, fmt.Errorf("创建用户失败: %w", err)
}
// 4. 设置用户ID
user.ID = uuid.New().String()
// 哈希密码
hashedPassword, err := s.hashPassword(registerReq.Password)
if err != nil {
return nil, fmt.Errorf("密码加密失败: %w", err)
}
user.Password = hashedPassword
// 保存用户
// 5. 保存用户
if err := s.repo.Create(ctx, user); err != nil {
s.logger.Error("创建用户失败", zap.Error(err))
return nil, fmt.Errorf("创建用户失败: %w", err)
}
// 发布用户注册事件
// 6. 发布用户注册事件
event := events.NewUserRegisteredEvent(user, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("发布用户注册事件失败", zap.Error(err))
@@ -116,18 +102,23 @@ func (s *UserService) Register(ctx context.Context, registerReq *dto.RegisterReq
// LoginWithPassword 密码登录
func (s *UserService) LoginWithPassword(ctx context.Context, loginReq *dto.LoginWithPasswordRequest) (*entities.User, error) {
// 根据手机号查找用户
// 1. 根据手机号查找用户
user, err := s.repo.FindByPhone(ctx, loginReq.Phone)
if err != nil {
return nil, fmt.Errorf("用户名或密码错误")
}
// 验证密码
if !s.checkPassword(loginReq.Password, user.Password) {
// 2. 检查用户是否可以登录(委托给实体)
if !user.CanLogin() {
return nil, fmt.Errorf("用户状态异常,无法登录")
}
// 3. 验证密码(委托给实体)
if !user.CheckPassword(loginReq.Password) {
return nil, fmt.Errorf("用户名或密码错误")
}
// 发布用户登录事件
// 4. 发布用户登录事件
event := events.NewUserLoggedInEvent(
user.ID, user.Phone,
s.getClientIP(ctx), s.getUserAgent(ctx),
@@ -145,18 +136,23 @@ func (s *UserService) LoginWithPassword(ctx context.Context, loginReq *dto.Login
// LoginWithSMS 短信验证码登录
func (s *UserService) LoginWithSMS(ctx context.Context, loginReq *dto.LoginWithSMSRequest) (*entities.User, error) {
// 验证短信验证码
// 1. 验证短信验证码
if err := s.smsCodeService.VerifyCode(ctx, loginReq.Phone, loginReq.Code, entities.SMSSceneLogin); err != nil {
return nil, fmt.Errorf("验证码验证失败: %w", err)
}
// 根据手机号查找用户
// 2. 根据手机号查找用户
user, err := s.repo.FindByPhone(ctx, loginReq.Phone)
if err != nil {
return nil, fmt.Errorf("用户不存在")
}
// 发布用户登录事件
// 3. 检查用户是否可以登录(委托给实体)
if !user.CanLogin() {
return nil, fmt.Errorf("用户状态异常,无法登录")
}
// 4. 发布用户登录事件
event := events.NewUserLoggedInEvent(
user.ID, user.Phone,
s.getClientIP(ctx), s.getUserAgent(ctx),
@@ -174,40 +170,28 @@ func (s *UserService) LoginWithSMS(ctx context.Context, loginReq *dto.LoginWithS
// ChangePassword 修改密码
func (s *UserService) ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error {
// 验证新密码确认
if req.NewPassword != req.ConfirmNewPassword {
return fmt.Errorf("新密码和确认新密码不匹配")
}
// 获取用户信息
// 1. 获取用户信息
user, err := s.repo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("用户不存在: %w", err)
}
// 验证短信验证码
// 2. 验证短信验证码
if err := s.smsCodeService.VerifyCode(ctx, user.Phone, req.Code, entities.SMSSceneChangePassword); err != nil {
return fmt.Errorf("验证码验证失败: %w", err)
}
// 验证当前密码
if !s.checkPassword(req.OldPassword, user.Password) {
return fmt.Errorf("当前密码错误")
// 3. 执行业务逻辑(委托给实体)
if err := user.ChangePassword(req.OldPassword, req.NewPassword, req.ConfirmNewPassword); err != nil {
return err
}
// 哈希新密码
hashedPassword, err := s.hashPassword(req.NewPassword)
if err != nil {
return fmt.Errorf("新密码加密失败: %w", err)
}
// 更新密码
user.Password = hashedPassword
// 4. 保存用户
if err := s.repo.Update(ctx, user); err != nil {
return fmt.Errorf("密码更新失败: %w", err)
}
// 发布密码修改事件
// 5. 发布密码修改事件
event := events.NewUserPasswordChangedEvent(user.ID, user.Phone, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("发布密码修改事件失败", zap.Error(err))
@@ -232,7 +216,61 @@ func (s *UserService) GetByID(ctx context.Context, id string) (*entities.User, e
return user, nil
}
// 工具方法
// UpdateUserProfile 更新用户信息
func (s *UserService) UpdateUserProfile(ctx context.Context, userID string, req *dto.UpdateProfileRequest) (*entities.User, error) {
// 1. 获取用户信息
user, err := s.repo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("用户不存在: %w", err)
}
// 2. 更新手机号(如果需要)
if req.Phone != "" && req.Phone != user.Phone {
// 检查新手机号是否已存在
if err := s.checkPhoneDuplicate(ctx, req.Phone); err != nil {
return nil, err
}
// 使用实体的方法设置手机号
if err := user.SetPhone(req.Phone); err != nil {
return nil, err
}
}
// 3. 保存用户
if err := s.repo.Update(ctx, user); err != nil {
return nil, fmt.Errorf("更新用户信息失败: %w", err)
}
s.logger.Info("用户信息更新成功", zap.String("user_id", userID))
return user, nil
}
// DeactivateUser 停用用户
func (s *UserService) DeactivateUser(ctx context.Context, userID string) error {
// 1. 获取用户信息
user, err := s.repo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("用户不存在: %w", err)
}
// 2. 检查用户状态
if user.IsDeleted() {
return fmt.Errorf("用户已被停用")
}
// 3. 软删除用户(这里需要调用仓储的软删除方法)
if err := s.repo.SoftDelete(ctx, userID); err != nil {
return fmt.Errorf("停用用户失败: %w", err)
}
s.logger.Info("用户停用成功", zap.String("user_id", userID))
return nil
}
// ================ 工具方法 ================
// checkPhoneDuplicate 检查手机号重复
func (s *UserService) checkPhoneDuplicate(ctx context.Context, phone string) error {
@@ -242,34 +280,6 @@ func (s *UserService) checkPhoneDuplicate(ctx context.Context, phone string) err
return nil
}
// hashPassword 加密密码
func (s *UserService) hashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hashedBytes), nil
}
// checkPassword 验证密码
func (s *UserService) checkPassword(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
// isValidPhone 验证手机号格式
func (s *UserService) isValidPhone(phone string) bool {
// 简单的中国手机号验证11位数字以1开头
pattern := `^1[3-9]\d{9}$`
matched, _ := regexp.MatchString(pattern, phone)
return matched
}
// generateUserID 生成用户ID
func (s *UserService) generateUserID() string {
return uuid.New().String()
}
// getCorrelationID 获取关联ID
func (s *UserService) getCorrelationID(ctx context.Context) string {
if id := ctx.Value("correlation_id"); id != nil {