This commit is contained in:
liangzai 2025-04-08 12:49:19 +08:00
parent 39af208ea3
commit 14b5d10992
20 changed files with 1437 additions and 67 deletions

View File

@ -3,11 +3,13 @@ package agent
import ( import (
"context" "context"
"database/sql" "database/sql"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"time" "time"
"tydata-server/app/user/model" "tydata-server/app/user/model"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
@ -34,7 +36,12 @@ func (l *ActivateAgentMembershipLogic) ActivateAgentMembership(req *types.AgentA
//if err != nil { //if err != nil {
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err) // return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
//} //}
userModel, err := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) secretKey := l.svcCtx.Config.Encrypt.SecretKey
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "加密手机号失败: %v", err)
}
userModel, err := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile)
if err != nil { if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "查询代理信息失败: %v", err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "查询代理信息失败: %v", err)
} }

View File

@ -7,6 +7,7 @@ import (
"tydata-server/app/user/model" "tydata-server/app/user/model"
jwtx "tydata-server/common/jwt" jwtx "tydata-server/common/jwt"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"tydata-server/pkg/lzkit/lzUtils" "tydata-server/pkg/lzkit/lzUtils"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -34,17 +35,22 @@ func NewApplyForAgentLogic(ctx context.Context, svcCtx *svc.ServiceContext) *App
} }
func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *types.AgentApplyResp, err error) { func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *types.AgentApplyResp, err error) {
secretKey := l.svcCtx.Config.Encrypt.SecretKey
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "加密手机号失败: %v", err)
}
// 校验验证码 // 校验验证码
redisKey := fmt.Sprintf("%s:%s", "agentApply", req.Mobile) redisKey := fmt.Sprintf("%s:%s", "agentApply", encryptedMobile)
cacheCode, err := l.svcCtx.Redis.Get(redisKey) cacheCode, err := l.svcCtx.Redis.Get(redisKey)
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "代理申请, 验证码过期: %s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "代理申请, 验证码过期: %s", encryptedMobile)
} }
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err)
} }
if cacheCode != req.Code { if cacheCode != req.Code {
return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "代理申请, 验证码不正确: %s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "代理申请, 验证码不正确: %s", encryptedMobile)
} }
if req.Ancestor == req.Mobile { if req.Ancestor == req.Mobile {
return nil, errors.Wrapf(xerr.NewErrMsg("不能成为自己的代理"), "") return nil, errors.Wrapf(xerr.NewErrMsg("不能成为自己的代理"), "")
@ -52,18 +58,18 @@ func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *type
var userID int64 var userID int64
transErr := l.svcCtx.AgentAuditModel.Trans(l.ctx, func(transCtx context.Context, session sqlx.Session) error { transErr := l.svcCtx.AgentAuditModel.Trans(l.ctx, func(transCtx context.Context, session sqlx.Session) error {
// 两种情况1. 已注册账号然后申请代理 2. 未注册账号申请代理 // 两种情况1. 已注册账号然后申请代理 2. 未注册账号申请代理
user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile)
if findUserErr != nil && !errors.Is(findUserErr, model.ErrNotFound) { if findUserErr != nil && !errors.Is(findUserErr, model.ErrNotFound) {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", req.Mobile, err) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 读取数据库获取用户失败, mobile: %s, err: %+v", encryptedMobile, err)
} }
if user == nil { if user == nil {
user = &model.User{Mobile: req.Mobile} user = &model.User{Mobile: encryptedMobile}
if len(user.Nickname) == 0 { if len(user.Nickname) == 0 {
user.Nickname = req.Mobile user.Nickname = encryptedMobile
} }
insertResult, userInsertErr := l.svcCtx.UserModel.Insert(transCtx, session, user) insertResult, userInsertErr := l.svcCtx.UserModel.Insert(transCtx, session, user)
if userInsertErr != nil { if userInsertErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err)
} }
lastId, lastInsertIdErr := insertResult.LastInsertId() lastId, lastInsertIdErr := insertResult.LastInsertId()
if lastInsertIdErr != nil { if lastInsertIdErr != nil {
@ -73,7 +79,7 @@ func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *type
userID = lastId userID = lastId
userAuth := new(model.UserAuth) userAuth := new(model.UserAuth)
userAuth.UserId = lastId userAuth.UserId = lastId
userAuth.AuthKey = req.Mobile userAuth.AuthKey = encryptedMobile
userAuth.AuthType = model.UserAuthTypeAgentDirect userAuth.AuthType = model.UserAuthTypeAgentDirect
if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(transCtx, session, userAuth); userAuthInsertErr != nil { if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(transCtx, session, userAuth); userAuthInsertErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 数据库插入用户认证失败, err:%+v", userAuthInsertErr)
@ -99,7 +105,7 @@ func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *type
var agentAudit model.AgentAudit var agentAudit model.AgentAudit
agentAudit.UserId = user.Id agentAudit.UserId = user.Id
agentAudit.Mobile = req.Mobile agentAudit.Mobile = encryptedMobile
agentAudit.Region = req.Region agentAudit.Region = req.Region
agentAudit.WechatId = lzUtils.StringToNullString(req.WechatID) agentAudit.WechatId = lzUtils.StringToNullString(req.WechatID)
agentAudit.Status = 1 agentAudit.Status = 1
@ -133,7 +139,7 @@ func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *type
// 关联上级 // 关联上级
if req.Ancestor != "" { if req.Ancestor != "" {
ancestorAgentModel, findAgentModelErr := l.svcCtx.AgentModel.FindOneByMobile(transCtx, req.Ancestor) ancestorAgentModel, findAgentModelErr := l.svcCtx.AgentModel.FindOneByMobile(transCtx, encryptedMobile)
if findAgentModelErr != nil { if findAgentModelErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 查找上级代理失败: %+v", findAgentModelErr) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 查找上级代理失败: %+v", findAgentModelErr)
} }

View File

@ -5,6 +5,7 @@ import (
"database/sql" "database/sql"
"tydata-server/common/ctxdata" "tydata-server/common/ctxdata"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"tydata-server/pkg/lzkit/lzUtils" "tydata-server/pkg/lzkit/lzUtils"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -59,6 +60,10 @@ func (l *GetAgentInfoLogic) GetAgentInfo() (resp *types.AgentInfoResp, err error
} }
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "获取代理信息, %v", err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "获取代理信息, %v", err)
} }
agent.Mobile, err = crypto.DecryptMobile(agent.Mobile, l.svcCtx.Config.Encrypt.SecretKey)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取代理信息, 解密手机号失败: %v", err)
}
return &types.AgentInfoResp{ return &types.AgentInfoResp{
AgentID: agent.Id, AgentID: agent.Id,
Level: agent.LevelName, Level: agent.LevelName,

View File

@ -3,10 +3,12 @@ package auth
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pkg/errors"
"math/rand" "math/rand"
"time" "time"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"github.com/pkg/errors"
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
@ -33,16 +35,21 @@ func NewSendSmsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *SendSmsLo
} }
func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error { func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
secretKey := l.svcCtx.Config.Encrypt.SecretKey
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey)
if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %v", err)
}
// 检查手机号是否在一分钟内已发送过验证码 // 检查手机号是否在一分钟内已发送过验证码
limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, req.Mobile) limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile)
exists, err := l.svcCtx.Redis.Exists(limitCodeKey) exists, err := l.svcCtx.Redis.Exists(limitCodeKey)
if err != nil { if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 读取redis缓存失败: %s", req.Mobile) return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 读取redis缓存失败: %s", encryptedMobile)
} }
if exists { if exists {
// 如果 Redis 中已经存在标记,说明在 1 分钟内请求过,返回错误 // 如果 Redis 中已经存在标记,说明在 1 分钟内请求过,返回错误
return errors.Wrapf(xerr.NewErrMsg("一分钟内不能重复发送验证码"), "短信发送, 手机号1分钟内重复请求发送二维码: %s", req.Mobile) return errors.Wrapf(xerr.NewErrMsg("一分钟内不能重复发送验证码"), "短信发送, 手机号1分钟内重复请求发送验证码: %s", encryptedMobile)
} }
code := fmt.Sprintf("%06d", rand.New(rand.NewSource(time.Now().UnixNano())).Intn(1000000)) code := fmt.Sprintf("%06d", rand.New(rand.NewSource(time.Now().UnixNano())).Intn(1000000))
@ -55,7 +62,7 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
if *smsResp.Body.Code != "OK" { if *smsResp.Body.Code != "OK" {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message) return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message)
} }
codeKey := fmt.Sprintf("%s:%s", req.ActionType, req.Mobile) codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile)
// 将验证码保存到 Redis设置过期时间 // 将验证码保存到 Redis设置过期时间
err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟 err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟
if err != nil { if err != nil {

View File

@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"time" "time"
"tydata-server/common/ctxdata"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto" "tydata-server/pkg/lzkit/crypto"
"tydata-server/pkg/lzkit/delay" "tydata-server/pkg/lzkit/delay"
@ -17,6 +18,7 @@ import (
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
"tydata-server/app/user/model"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -36,12 +38,26 @@ func NewQueryDetailByOrderIdLogic(ctx context.Context, svcCtx *svc.ServiceContex
} }
func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) { func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) {
// 获取当前用户ID
userId, err := ctxdata.GetUidFromCtx(l.ctx)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
}
// 获取订单信息 // 获取订单信息
order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId) order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId)
if err != nil { if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "报告查询, 订单不存在: %v", err)
}
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %v", err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %v", err)
} }
// 安全验证:确保订单属于当前用户
if order.UserId != userId {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
}
// 创建渐进式延迟策略实例 // 创建渐进式延迟策略实例
progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5) progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5)
if err != nil { if err != nil {

View File

@ -5,6 +5,7 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"time" "time"
"tydata-server/common/ctxdata"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/delay" "tydata-server/pkg/lzkit/delay"
@ -13,6 +14,7 @@ import (
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
"tydata-server/app/user/model"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -32,12 +34,26 @@ func NewQueryDetailByOrderNoLogic(ctx context.Context, svcCtx *svc.ServiceContex
} }
func (l *QueryDetailByOrderNoLogic) QueryDetailByOrderNo(req *types.QueryDetailByOrderNoReq) (resp *types.QueryDetailByOrderNoResp, err error) { func (l *QueryDetailByOrderNoLogic) QueryDetailByOrderNo(req *types.QueryDetailByOrderNoReq) (resp *types.QueryDetailByOrderNoResp, err error) {
// 获取当前用户ID
userId, err := ctxdata.GetUidFromCtx(l.ctx)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
}
// 获取订单信息 // 获取订单信息
order, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo) order, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo)
if err != nil { if err != nil {
if errors.Is(err, model.ErrNotFound) {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "报告查询, 订单不存在: %v", err)
}
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %v", err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %v", err)
} }
// 安全验证:确保订单属于当前用户
if order.UserId != userId {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
}
// 创建渐进式延迟策略实例 // 创建渐进式延迟策略实例
progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5) progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5)
if err != nil { if err != nil {

View File

@ -3,15 +3,17 @@ package user
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"time" "time"
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
"tydata-server/app/user/model" "tydata-server/app/user/model"
jwtx "tydata-server/common/jwt" jwtx "tydata-server/common/jwt"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -31,32 +33,37 @@ func NewAgentMobileCodeLoginLogic(ctx context.Context, svcCtx *svc.ServiceContex
} }
func (l *AgentMobileCodeLoginLogic) AgentMobileCodeLogin(req *types.MobileCodeLoginReq) (resp *types.MobileCodeLoginResp, err error) { func (l *AgentMobileCodeLoginLogic) AgentMobileCodeLogin(req *types.MobileCodeLoginReq) (resp *types.MobileCodeLoginResp, err error) {
secretKey := l.svcCtx.Config.Encrypt.SecretKey
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机登录, 加密手机号失败: %+v", err)
}
// 检查手机号是否在一分钟内已发送过验证码 // 检查手机号是否在一分钟内已发送过验证码
redisKey := fmt.Sprintf("%s:%s", "query", req.Mobile) redisKey := fmt.Sprintf("%s:%s", "query", encryptedMobile)
cacheCode, err := l.svcCtx.Redis.Get(redisKey) cacheCode, err := l.svcCtx.Redis.Get(redisKey)
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期: %s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期")
} }
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, err: %+v", err)
} }
if cacheCode != req.Code { if cacheCode != req.Code {
return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确: %s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确")
} }
user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile)
if findUserErr != nil && findUserErr != model.ErrNotFound { if findUserErr != nil && findUserErr != model.ErrNotFound {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", req.Mobile, err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", encryptedMobile, err)
} }
if user == nil { if user == nil {
user = &model.User{Mobile: req.Mobile} user = &model.User{Mobile: encryptedMobile}
if len(user.Nickname) == 0 { if len(user.Nickname) == 0 {
user.Nickname = req.Mobile user.Nickname = ""
} }
if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error {
insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user) insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user)
if userInsertErr != nil { if userInsertErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err)
} }
lastId, lastInsertIdErr := insertResult.LastInsertId() lastId, lastInsertIdErr := insertResult.LastInsertId()
if lastInsertIdErr != nil { if lastInsertIdErr != nil {
@ -66,7 +73,7 @@ func (l *AgentMobileCodeLoginLogic) AgentMobileCodeLogin(req *types.MobileCodeLo
userAuth := new(model.UserAuth) userAuth := new(model.UserAuth)
userAuth.UserId = lastId userAuth.UserId = lastId
userAuth.AuthKey = req.Mobile userAuth.AuthKey = encryptedMobile
userAuth.AuthType = model.UserAuthTypeH5Mobile userAuth.AuthType = model.UserAuthTypeH5Mobile
if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil { if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr)

View File

@ -2,12 +2,14 @@ package user
import ( import (
"context" "context"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
"tydata-server/common/ctxdata" "tydata-server/common/ctxdata"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -40,6 +42,10 @@ func (l *DetailLogic) Detail() (resp *types.UserInfoResp, err error) {
if err != nil { if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "用户信息, 用户信息结构体复制失败, %v", err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "用户信息, 用户信息结构体复制失败, %v", err)
} }
userInfo.Mobile, err = crypto.DecryptMobile(userInfo.Mobile, l.svcCtx.Config.Encrypt.SecretKey)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "用户信息, 解密手机号失败, %v", err)
}
return &types.UserInfoResp{ return &types.UserInfoResp{
UserInfo: userInfo, UserInfo: userInfo,
}, nil }, nil

View File

@ -3,15 +3,17 @@ package user
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"time" "time"
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
"tydata-server/app/user/model" "tydata-server/app/user/model"
jwtx "tydata-server/common/jwt" jwtx "tydata-server/common/jwt"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -31,34 +33,39 @@ func NewMobileCodeLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *M
} }
func (l *MobileCodeLoginLogic) MobileCodeLogin(req *types.MobileCodeLoginReq) (resp *types.MobileCodeLoginResp, err error) { func (l *MobileCodeLoginLogic) MobileCodeLogin(req *types.MobileCodeLoginReq) (resp *types.MobileCodeLoginResp, err error) {
secretKey := l.svcCtx.Config.Encrypt.SecretKey
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机登录, 加密手机号失败: %+v", err)
}
if !l.MobileCodeLoginInside(req) { if !l.MobileCodeLoginInside(req) {
// 检查手机号是否在一分钟内已发送过验证码 // 检查手机号是否在一分钟内已发送过验证码
redisKey := fmt.Sprintf("%s:%s", "login", req.Mobile) redisKey := fmt.Sprintf("%s:%s", "login", encryptedMobile)
cacheCode, err := l.svcCtx.Redis.Get(redisKey) cacheCode, err := l.svcCtx.Redis.Get(redisKey)
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期: %s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期: %s", encryptedMobile)
} }
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err)
} }
if cacheCode != req.Code { if cacheCode != req.Code {
return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确: %s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确: %s", encryptedMobile)
} }
} }
user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile)
if findUserErr != nil && findUserErr != model.ErrNotFound { if findUserErr != nil && findUserErr != model.ErrNotFound {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", req.Mobile, err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", encryptedMobile, err)
} }
if user == nil { if user == nil {
user = &model.User{Mobile: req.Mobile} user = &model.User{Mobile: encryptedMobile}
if len(user.Nickname) == 0 { if len(user.Nickname) == 0 {
user.Nickname = req.Mobile user.Nickname = encryptedMobile
} }
if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error {
insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user) insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user)
if userInsertErr != nil { if userInsertErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err)
} }
lastId, lastInsertIdErr := insertResult.LastInsertId() lastId, lastInsertIdErr := insertResult.LastInsertId()
if lastInsertIdErr != nil { if lastInsertIdErr != nil {
@ -68,7 +75,7 @@ func (l *MobileCodeLoginLogic) MobileCodeLogin(req *types.MobileCodeLoginReq) (r
userAuth := new(model.UserAuth) userAuth := new(model.UserAuth)
userAuth.UserId = lastId userAuth.UserId = lastId
userAuth.AuthKey = req.Mobile userAuth.AuthKey = encryptedMobile
userAuth.AuthType = model.UserAuthTypeAppMobile userAuth.AuthType = model.UserAuthTypeAppMobile
if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil { if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr)

View File

@ -2,14 +2,16 @@ package user
import ( import (
"context" "context"
"github.com/pkg/errors"
"time" "time"
"tydata-server/app/user/model" "tydata-server/app/user/model"
jwtx "tydata-server/common/jwt" jwtx "tydata-server/common/jwt"
"tydata-server/common/tool" "tydata-server/common/tool"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"tydata-server/pkg/lzkit/lzUtils" "tydata-server/pkg/lzkit/lzUtils"
"github.com/pkg/errors"
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
@ -31,15 +33,20 @@ func NewMobileLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Mobil
} }
func (l *MobileLoginLogic) MobileLogin(req *types.MobileLoginReq) (resp *types.MobileCodeLoginResp, err error) { func (l *MobileLoginLogic) MobileLogin(req *types.MobileLoginReq) (resp *types.MobileCodeLoginResp, err error) {
user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) secretKey := l.svcCtx.Config.Encrypt.SecretKey
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机登录, 加密手机号失败: %+v", err)
}
user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile)
if findUserErr != nil && findUserErr != model.ErrNotFound { if findUserErr != nil && findUserErr != model.ErrNotFound {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile%s, err: %+v", req.Mobile, err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile%s, err: %+v", encryptedMobile, err)
} }
if user == nil { if user == nil {
return nil, errors.Wrapf(xerr.NewErrMsg("手机号码未注册"), "手机登录, 手机号未注册:%s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("手机号码未注册"), "手机登录, 手机号未注册:%s", encryptedMobile)
} }
if !(tool.Md5ByString(req.Password) == lzUtils.NullStringToString(user.Password)) { if !(tool.Md5ByString(req.Password) == lzUtils.NullStringToString(user.Password)) {
return nil, errors.Wrapf(xerr.NewErrMsg("密码不正确"), "手机登录, 密码匹配不正确%s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("密码不正确"), "手机登录, 密码匹配不正确%s", encryptedMobile)
} }
token, generaErr := jwtx.GenerateJwtToken(user.Id, l.svcCtx.Config.JwtAuth.AccessSecret, l.svcCtx.Config.JwtAuth.AccessExpire) token, generaErr := jwtx.GenerateJwtToken(user.Id, l.svcCtx.Config.JwtAuth.AccessSecret, l.svcCtx.Config.JwtAuth.AccessExpire)

View File

@ -3,9 +3,6 @@ package user
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"time" "time"
"tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/svc"
"tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/cmd/api/internal/types"
@ -13,8 +10,13 @@ import (
jwtx "tydata-server/common/jwt" jwtx "tydata-server/common/jwt"
"tydata-server/common/tool" "tydata-server/common/tool"
"tydata-server/common/xerr" "tydata-server/common/xerr"
"tydata-server/pkg/lzkit/crypto"
"tydata-server/pkg/lzkit/lzUtils" "tydata-server/pkg/lzkit/lzUtils"
"github.com/pkg/errors"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -33,38 +35,43 @@ func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Register
} }
func (l *RegisterLogic) Register(req *types.RegisterReq) (resp *types.RegisterResp, err error) { func (l *RegisterLogic) Register(req *types.RegisterReq) (resp *types.RegisterResp, err error) {
secretKey := l.svcCtx.Config.Encrypt.SecretKey
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机注册, 加密手机号失败: %+v", err)
}
// 检查手机号是否在一分钟内已发送过验证码 // 检查手机号是否在一分钟内已发送过验证码
redisKey := fmt.Sprintf("%s:%s", "register", req.Mobile) redisKey := fmt.Sprintf("%s:%s", "register", encryptedMobile)
cacheCode, err := l.svcCtx.Redis.Get(redisKey) cacheCode, err := l.svcCtx.Redis.Get(redisKey)
if err != nil { if err != nil {
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机注册, 验证码过期: %s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机注册, 验证码过期: %s", encryptedMobile)
} }
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err)
} }
if cacheCode != req.Code { if cacheCode != req.Code {
return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机注册, 验证码不正确: %s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机注册, 验证码不正确: %s", encryptedMobile)
} }
hasUser, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) hasUser, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile)
if findUserErr != nil && findUserErr != model.ErrNotFound { if findUserErr != nil && findUserErr != model.ErrNotFound {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取数据库获取用户失败, mobile%s, err: %+v", req.Mobile, err) return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取数据库获取用户失败, mobile%s, err: %+v", encryptedMobile, err)
} }
if hasUser != nil { if hasUser != nil {
return nil, errors.Wrapf(xerr.NewErrMsg("该手机号码已注册"), "手机注册, 手机号码已注册, mobile:%s", req.Mobile) return nil, errors.Wrapf(xerr.NewErrMsg("该手机号码已注册"), "手机注册, 手机号码已注册, mobile:%s", encryptedMobile)
} }
var userId int64 var userId int64
if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error {
user := new(model.User) user := new(model.User)
user.Mobile = req.Mobile user.Mobile = encryptedMobile
if len(user.Nickname) == 0 { if len(user.Nickname) == 0 {
user.Nickname = req.Mobile user.Nickname = encryptedMobile
} }
if len(req.Password) > 0 { if len(req.Password) > 0 {
user.Password = lzUtils.StringToNullString(tool.Md5ByString(req.Password)) user.Password = lzUtils.StringToNullString(tool.Md5ByString(req.Password))
} }
insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user) insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user)
if userInsertErr != nil { if userInsertErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err)
} }
lastId, lastInsertIdErr := insertResult.LastInsertId() lastId, lastInsertIdErr := insertResult.LastInsertId()
if lastInsertIdErr != nil { if lastInsertIdErr != nil {
@ -74,7 +81,7 @@ func (l *RegisterLogic) Register(req *types.RegisterReq) (resp *types.RegisterRe
userAuth := new(model.UserAuth) userAuth := new(model.UserAuth)
userAuth.UserId = lastId userAuth.UserId = lastId
userAuth.AuthKey = req.Mobile userAuth.AuthKey = encryptedMobile
userAuth.AuthType = model.UserAuthTypeAppMobile userAuth.AuthType = model.UserAuthTypeAppMobile
if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil { if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr)

View File

@ -14,7 +14,9 @@ const DB_ERROR uint32 = 100005
const DB_UPDATE_AFFECTED_ZERO_ERROR uint32 = 100006 const DB_UPDATE_AFFECTED_ZERO_ERROR uint32 = 100006
const PARAM_VERIFICATION_ERROR uint32 = 100007 const PARAM_VERIFICATION_ERROR uint32 = 100007
const CUSTOM_ERROR uint32 = 100008 const CUSTOM_ERROR uint32 = 100008
const AUTH_ERROR uint32 = 100009
const LOGIN_FAILED uint32 = 200001 const LOGIN_FAILED uint32 = 200001
const LOGIC_QUERY_WAIT uint32 = 200002 const LOGIC_QUERY_WAIT uint32 = 200002
const LOGIC_QUERY_ERROR uint32 = 200003 const LOGIC_QUERY_ERROR uint32 = 200003
const LOGIC_QUERY_NOT_FOUND uint32 = 200004

View File

@ -11,6 +11,7 @@ func init() {
message[TOKEN_GENERATE_ERROR] = "生成token失败" message[TOKEN_GENERATE_ERROR] = "生成token失败"
message[DB_ERROR] = "数据库繁忙,请稍后再试" message[DB_ERROR] = "数据库繁忙,请稍后再试"
message[DB_UPDATE_AFFECTED_ZERO_ERROR] = "更新数据影响行数为0" message[DB_UPDATE_AFFECTED_ZERO_ERROR] = "更新数据影响行数为0"
message[AUTH_ERROR] = "权限错误,无权访问"
} }
func MapErrMsg(errcode uint32) string { func MapErrMsg(errcode uint32) string {

235
pkg/lzkit/crypto/README.md Normal file
View File

@ -0,0 +1,235 @@
# AES 加密工具包
本包提供了多种加密方式,特别是用于处理敏感个人信息(如手机号、身份证号等)的加密和解密功能。
## 主要功能
- **AES-CBC 模式加密/解密** - 标准加密模式,适用于一般数据加密
- **AES-ECB 模式加密/解密** - 确定性加密模式,适用于数据库字段加密和查询
- **专门针对个人敏感信息的加密/解密方法**
- **密钥生成和管理工具**
## 安全性说明
- **AES-CBC 模式**:使用随机 IV相同明文每次加密结果不同安全性较高
- **AES-ECB 模式**:确定性加密,相同明文每次加密结果相同,便于数据库查询,但安全性较低
> **⚠️ 警告**ECB 模式仅适用于短文本(如手机号、身份证号)的确定性加密,不建议用于加密大段文本或高安全需求场景。
## 使用示例
### 1. 加密手机号
使用 AES-ECB 模式加密手机号,保证确定性(相同手机号总是产生相同密文)
```go
import (
"fmt"
"tydata-server/pkg/lzkit/crypto"
)
func encryptMobileExample() {
// 您的密钥(需安全保存,建议存储在配置中)
key := []byte("1234567890abcdef") // 16字节AES-128密钥
// 加密手机号
mobile := "13800138000"
encryptedMobile, err := crypto.EncryptMobile(mobile, key)
if err != nil {
panic(err)
}
fmt.Println("加密后的手机号:", encryptedMobile)
// 解密手机号
decryptedMobile, err := crypto.DecryptMobile(encryptedMobile, key)
if err != nil {
panic(err)
}
fmt.Println("解密后的手机号:", decryptedMobile)
}
```
### 2. 在数据库中存储和查询加密手机号
```go
// 加密并存储手机号
func saveUser(db *sqlx.DB, mobile string, key []byte) (int64, error) {
encryptedMobile, err := crypto.EncryptMobile(mobile, key)
if err != nil {
return 0, err
}
var id int64
err = db.QueryRow(
"INSERT INTO users (mobile, create_time) VALUES (?, NOW()) RETURNING id",
encryptedMobile,
).Scan(&id)
return id, err
}
// 根据手机号查询用户
func findUserByMobile(db *sqlx.DB, mobile string, key []byte) (*User, error) {
encryptedMobile, err := crypto.EncryptMobile(mobile, key)
if err != nil {
return nil, err
}
var user User
err = db.QueryRow(
"SELECT id, mobile, create_time FROM users WHERE mobile = ?",
encryptedMobile,
).Scan(&user.ID, &user.EncryptedMobile, &user.CreateTime)
if err != nil {
return nil, err
}
// 解密手机号用于显示
user.Mobile, _ = crypto.DecryptMobile(user.EncryptedMobile, key)
return &user, nil
}
```
### 3. 加密身份证号
```go
func encryptIDCardExample() {
key := []byte("1234567890abcdef")
idCard := "440101199001011234"
encryptedIDCard, err := crypto.EncryptIDCard(idCard, key)
if err != nil {
panic(err)
}
fmt.Println("加密后的身份证号:", encryptedIDCard)
// 解密身份证号
decryptedIDCard, err := crypto.DecryptIDCard(encryptedIDCard, key)
if err != nil {
panic(err)
}
fmt.Println("解密后的身份证号:", decryptedIDCard)
}
```
### 4. 密钥管理
```go
func keyManagementExample() {
// 生成随机密钥
key, err := crypto.GenerateAESKey(16) // AES-128
if err != nil {
panic(err)
}
fmt.Printf("生成的密钥(十六进制): %x\n", key)
// 从密码派生密钥(便于记忆)
password := "my-secure-password"
derivedKey, err := crypto.DeriveKeyFromPassword(password, 16)
if err != nil {
panic(err)
}
fmt.Printf("从密码派生的密钥: %x\n", derivedKey)
}
```
### 5. 使用十六进制输出(适用于 URL 参数)
```go
func hexEncodingExample() {
key := []byte("1234567890abcdef")
mobile := "13800138000"
// 使用十六进制编码(适合URL参数)
encryptedHex, err := crypto.EncryptMobileHex(mobile, key)
if err != nil {
panic(err)
}
fmt.Println("十六进制编码的加密手机号:", encryptedHex)
// 解密十六进制编码的手机号
decryptedMobile, err := crypto.DecryptMobileHex(encryptedHex, key)
if err != nil {
panic(err)
}
fmt.Println("解密后的手机号:", decryptedMobile)
}
```
## 在 Go-Zero 项目中使用
在 Go-Zero 项目中,建议将加密密钥放在配置文件中:
1. 在配置文件中添加密钥配置:
```yaml
# etc/main.yaml
Name: user-api
Host: 0.0.0.0
Port: 8888
Encrypt:
MobileKey: "1234567890abcdef" # 16字节AES-128密钥
IDCardKey: "1234567890abcdef1234567890abcdef" # 32字节AES-256密钥
```
2. 在配置结构中定义:
```go
type Config struct {
rest.RestConf
Encrypt struct {
MobileKey string
IDCardKey string
}
}
```
3. 在服务上下文中使用:
```go
type ServiceContext struct {
Config config.Config
UserModel model.UserModel
MobileKey []byte
IDCardKey []byte
}
func NewServiceContext(c config.Config) *ServiceContext {
return &ServiceContext{
Config: c,
UserModel: model.NewUserModel(sqlx.NewMysql(c.DB.DataSource), c.Cache),
MobileKey: []byte(c.Encrypt.MobileKey),
IDCardKey: []byte(c.Encrypt.IDCardKey),
}
}
```
4. 在 Logic 中使用:
```go
func (l *RegisterLogic) Register(req *types.RegisterReq) (*types.RegisterResp, error) {
// 加密手机号用于存储
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, l.svcCtx.MobileKey)
if err != nil {
return nil, errors.New("手机号加密失败")
}
// 保存到数据库
user := &model.User{
Mobile: encryptedMobile,
// 其他字段...
}
result, err := l.svcCtx.UserModel.Insert(l.ctx, nil, user)
// 其余逻辑...
}
```

274
pkg/lzkit/crypto/ecb.go Normal file
View File

@ -0,0 +1,274 @@
package crypto
import (
"crypto/aes"
"crypto/md5"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
)
// ECB模式是一种基本的加密模式每个明文块独立加密
// 警告ECB模式存在安全问题仅用于需要确定性加密的场景如数据库字段查询
// 不要用于加密大段文本或安全要求高的场景
// 验证密钥长度是否有效 (AES-128, AES-192, AES-256)
func validateAESKey(key []byte) error {
switch len(key) {
case 16, 24, 32:
return nil
default:
return errors.New("AES密钥长度必须是16、24或32字节(对应AES-128、AES-192、AES-256)")
}
}
// AesEcbEncrypt AES-ECB模式加密返回Base64编码的密文
// 使用已有的ECB实现但提供更易用的接口
func AesEcbEncrypt(plainText, key []byte) (string, error) {
if err := validateAESKey(key); err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
// 使用PKCS7填充
plainText = PKCS7Padding(plainText, block.BlockSize())
// 创建密文数组
cipherText := make([]byte, len(plainText))
// ECB模式加密使用west_crypto.go中已有的实现
mode := newECBEncrypter(block)
mode.CryptBlocks(cipherText, plainText)
// 返回Base64编码的密文
return base64.StdEncoding.EncodeToString(cipherText), nil
}
// AesEcbDecrypt AES-ECB模式解密输入Base64编码的密文
func AesEcbDecrypt(cipherTextBase64 string, key []byte) ([]byte, error) {
if err := validateAESKey(key); err != nil {
return nil, err
}
// Base64解码
cipherText, err := base64.StdEncoding.DecodeString(cipherTextBase64)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// 检查密文长度
if len(cipherText)%block.BlockSize() != 0 {
return nil, errors.New("密文长度必须是块大小的整数倍")
}
// 创建明文数组
plainText := make([]byte, len(cipherText))
// ECB模式解密使用west_crypto.go中已有的实现
mode := newECBDecrypter(block)
mode.CryptBlocks(plainText, cipherText)
// 去除PKCS7填充
plainText, err = PKCS7UnPadding(plainText)
if err != nil {
return nil, err
}
return plainText, nil
}
// AesEcbEncryptHex AES-ECB模式加密返回十六进制编码的密文
func AesEcbEncryptHex(plainText, key []byte) (string, error) {
if err := validateAESKey(key); err != nil {
return "", err
}
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
// 使用PKCS7填充
plainText = PKCS7Padding(plainText, block.BlockSize())
// 创建密文数组
cipherText := make([]byte, len(plainText))
// ECB模式加密
mode := newECBEncrypter(block)
mode.CryptBlocks(cipherText, plainText)
// 返回十六进制编码的密文
return hex.EncodeToString(cipherText), nil
}
// AesEcbDecryptHex AES-ECB模式解密输入十六进制编码的密文
func AesEcbDecryptHex(cipherTextHex string, key []byte) ([]byte, error) {
if err := validateAESKey(key); err != nil {
return nil, err
}
// 十六进制解码
cipherText, err := hex.DecodeString(cipherTextHex)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// 检查密文长度
if len(cipherText)%block.BlockSize() != 0 {
return nil, errors.New("密文长度必须是块大小的整数倍")
}
// 创建明文数组
plainText := make([]byte, len(cipherText))
// ECB模式解密
mode := newECBDecrypter(block)
mode.CryptBlocks(plainText, cipherText)
// 去除PKCS7填充
plainText, err = PKCS7UnPadding(plainText)
if err != nil {
return nil, err
}
return plainText, nil
}
// 以下是专门用于处理手机号等敏感数据的实用函数
// EncryptMobile 使用AES-ECB加密手机号返回Base64编码
// 该方法保证对相同手机号总是产生相同密文,便于数据库查询
func EncryptMobile(mobile string, secretKey string) (string, error) {
key, decodeErr := hex.DecodeString(secretKey)
if decodeErr != nil {
return "", decodeErr
}
if mobile == "" {
return "", errors.New("手机号不能为空")
}
return AesEcbEncrypt([]byte(mobile), key)
}
// DecryptMobile 解密手机号
func DecryptMobile(encryptedMobile string, secretKey string) (string, error) {
key, decodeErr := hex.DecodeString(secretKey)
if decodeErr != nil {
return "", decodeErr
}
if encryptedMobile == "" {
return "", errors.New("加密手机号不能为空")
}
bytes, err := AesEcbDecrypt(encryptedMobile, key)
if err != nil {
return "", fmt.Errorf("解密手机号失败: %v", err)
}
return string(bytes), nil
}
// EncryptMobileHex 使用AES-ECB加密手机号返回十六进制编码(适用于URL参数)
func EncryptMobileHex(mobile string, key []byte) (string, error) {
if mobile == "" {
return "", errors.New("手机号不能为空")
}
return AesEcbEncryptHex([]byte(mobile), key)
}
// DecryptMobileHex 解密十六进制编码的手机号
func DecryptMobileHex(encryptedMobileHex string, key []byte) (string, error) {
if encryptedMobileHex == "" {
return "", errors.New("加密手机号不能为空")
}
bytes, err := AesEcbDecryptHex(encryptedMobileHex, key)
if err != nil {
return "", fmt.Errorf("解密手机号失败: %v", err)
}
return string(bytes), nil
}
// EncryptIDCard 使用AES-ECB加密身份证号
func EncryptIDCard(idCard string, key []byte) (string, error) {
if idCard == "" {
return "", errors.New("身份证号不能为空")
}
return AesEcbEncrypt([]byte(idCard), key)
}
// DecryptIDCard 解密身份证号
func DecryptIDCard(encryptedIDCard string, key []byte) (string, error) {
if encryptedIDCard == "" {
return "", errors.New("加密身份证号不能为空")
}
bytes, err := AesEcbDecrypt(encryptedIDCard, key)
if err != nil {
return "", fmt.Errorf("解密身份证号失败: %v", err)
}
return string(bytes), nil
}
// IsEncrypted 检查字符串是否为Base64编码的加密数据
func IsEncrypted(data string) bool {
// 检查是否是有效的Base64编码
_, err := base64.StdEncoding.DecodeString(data)
return err == nil && len(data) >= 20 // 至少20个字符的Base64字符串
}
// GenerateAESKey 生成AES密钥
// keySize: 可选16, 24, 32字节(对应AES-128, AES-192, AES-256)
func GenerateAESKey(keySize int) ([]byte, error) {
if keySize != 16 && keySize != 24 && keySize != 32 {
return nil, errors.New("密钥长度必须是16、24或32字节")
}
key := make([]byte, keySize)
_, err := rand.Read(key)
if err != nil {
return nil, err
}
return key, nil
}
// DeriveKeyFromPassword 基于密码派生固定长度的AES密钥
func DeriveKeyFromPassword(password string, keySize int) ([]byte, error) {
if keySize != 16 && keySize != 24 && keySize != 32 {
return nil, errors.New("密钥长度必须是16、24或32字节")
}
// 使用PBKDF2或简单的方法从密码派生密钥
// 这里使用简单的MD5方法实际生产环境应使用更安全的PBKDF2
hash := md5.New()
hash.Write([]byte(password))
key := hash.Sum(nil) // 16字节
// 如果需要24或32字节继续哈希
if keySize > 16 {
hash.Reset()
hash.Write(key)
key = append(key, hash.Sum(nil)[:keySize-16]...)
}
return key, nil
}

View File

@ -0,0 +1,184 @@
package crypto
import (
"encoding/base64"
"encoding/hex"
"fmt"
"testing"
)
func TestAesEcbMobileEncryption(t *testing.T) {
// 测试手机号加密
mobile := "13800138000"
key := []byte("1234567890abcdef") // 16字节AES-128密钥
// 测试加密
encrypted, err := EncryptMobile(mobile, key)
if err != nil {
t.Fatalf("手机号加密失败: %v", err)
}
fmt.Println(encrypted)
// 测试解密
decrypted, err := DecryptMobile(encrypted, key)
if err != nil {
t.Fatalf("手机号解密失败: %v", err)
}
fmt.Println(decrypted)
// 验证结果
if decrypted != mobile {
t.Errorf("解密结果不匹配,期望: %s, 实际: %s", mobile, decrypted)
}
// 测试相同输入产生相同输出(确定性)
encrypted2, _ := EncryptMobile(mobile, key)
if encrypted != encrypted2 {
t.Errorf("AES-ECB不是确定性的两次加密结果不同: %s vs %s", encrypted, encrypted2)
}
}
func TestAesEcbHexEncryption(t *testing.T) {
// 测试十六进制编码加密
idCard := "440101199001011234"
key := []byte("1234567890abcdef") // 16字节AES-128密钥
// 测试HEX加密
encryptedHex, err := EncryptIDCard(idCard, key)
if err != nil {
t.Fatalf("身份证加密失败: %v", err)
}
// 测试HEX解密
decrypted, err := DecryptIDCard(encryptedHex, key)
if err != nil {
t.Fatalf("身份证解密失败: %v", err)
}
// 验证结果
if decrypted != idCard {
t.Errorf("解密结果不匹配,期望: %s, 实际: %s", idCard, decrypted)
}
}
func TestAesEcbKeyValidation(t *testing.T) {
// 测试不同长度的密钥
validKeys := [][]byte{
make([]byte, 16), // AES-128
make([]byte, 24), // AES-192
make([]byte, 32), // AES-256
}
invalidKeys := [][]byte{
make([]byte, 15),
make([]byte, 20),
make([]byte, 33),
}
text := []byte("test text")
// 测试有效密钥
for _, key := range validKeys {
_, err := AesEcbEncrypt(text, key)
if err != nil {
t.Errorf("有效密钥(%d字节)校验失败: %v", len(key), err)
}
}
// 测试无效密钥
for _, key := range invalidKeys {
_, err := AesEcbEncrypt(text, key)
if err == nil {
t.Errorf("无效密钥(%d字节)未被检测出", len(key))
}
}
}
func TestIsEncrypted(t *testing.T) {
// 有效的Base64编码字符串
validBase64 := base64.StdEncoding.EncodeToString([]byte("这是一个足够长的字符串以通过IsEncrypted检查"))
// 无效的字符串
invalidStrings := []string{
"",
"abc",
"not-base64!@#",
hex.EncodeToString([]byte("hexstring")),
}
// 测试有效的加密数据
if !IsEncrypted(validBase64) {
t.Errorf("有效的Base64未被识别为加密数据: %s", validBase64)
}
// 测试无效的数据
for _, s := range invalidStrings {
if IsEncrypted(s) {
t.Errorf("无效字符串被错误识别为加密数据: %s", s)
}
}
}
func TestDeriveKeyFromPassword(t *testing.T) {
password := "my-secure-password"
// 测试不同长度的派生密钥
keySizes := []int{16, 24, 32}
for _, size := range keySizes {
key, err := DeriveKeyFromPassword(password, size)
if err != nil {
t.Errorf("从密码派生%d字节密钥失败: %v", size, err)
continue
}
if len(key) != size {
t.Errorf("派生的密钥长度错误,期望: %d, 实际: %d", size, len(key))
}
// 测试相同密码总是产生相同密钥
key2, _ := DeriveKeyFromPassword(password, size)
if string(key) != string(key2) {
t.Errorf("从相同密码派生的密钥不一致")
}
// 使用派生的密钥加密测试
_, err = AesEcbEncrypt([]byte("test"), key)
if err != nil {
t.Errorf("使用派生的密钥加密失败: %v", err)
}
}
// 测试无效的密钥大小
_, err := DeriveKeyFromPassword(password, 18)
if err == nil {
t.Error("无效的密钥大小未被检测出")
}
}
func TestGenerateAESKey(t *testing.T) {
// 测试生成不同长度的密钥
keySizes := []int{16, 24, 32}
for _, size := range keySizes {
key, err := GenerateAESKey(size)
if err != nil {
t.Errorf("生成%d字节密钥失败: %v", size, err)
continue
}
if len(key) != size {
t.Errorf("生成的密钥长度错误,期望: %d, 实际: %d", size, len(key))
}
// 使用生成的密钥加密测试
_, err = AesEcbEncrypt([]byte("test"), key)
if err != nil {
t.Errorf("使用生成的密钥加密失败: %v", err)
}
}
// 测试无效的密钥大小
_, err := GenerateAESKey(18)
if err == nil {
t.Error("无效的密钥大小未被检测出")
}
}

106
pkg/lzkit/md5/README.md Normal file
View File

@ -0,0 +1,106 @@
# MD5 工具包
这个包提供了全面的 MD5 哈希功能,包括字符串加密、文件加密、链式操作、加盐哈希等。
## 主要功能
- 字符串和字节切片的 MD5 哈希计算
- 文件 MD5 哈希计算(支持大文件分块处理)
- 链式 API支持构建复杂的哈希内容
- 带盐值的 MD5 哈希,提高安全性
- 哈希验证功能
- 16 位和 8 位 MD5 哈希(短版本)
- HMAC-MD5 实现,增强安全性
## 使用示例
### 基本使用
```go
// 计算字符串的MD5哈希
hash := md5.EncryptString("hello world")
fmt.Println(hash) // 5eb63bbbe01eeed093cb22bb8f5acdc3
// 计算字节切片的MD5哈希
bytes := []byte("hello world")
hash = md5.EncryptBytes(bytes)
// 计算文件的MD5哈希
fileHash, err := md5.EncryptFile("path/to/file.txt")
if err != nil {
log.Fatal(err)
}
fmt.Println(fileHash)
```
### 链式 API
```go
// 创建一个新的MD5实例并添加内容
hash := md5.New().
Add("hello").
Add(" ").
Add("world").
Sum()
fmt.Println(hash) // 5eb63bbbe01eeed093cb22bb8f5acdc3
// 或者从字符串初始化
hash = md5.FromString("hello").
Add(" world").
Sum()
// 从字节切片初始化
hash = md5.FromBytes([]byte("hello")).
AddBytes([]byte(" world")).
Sum()
```
### 安全性增强
```go
// 使用盐值加密(提高安全性)
hashedPassword := md5.EncryptStringWithSalt("password123", "user@example.com")
// 使用前缀加密
hashedValue := md5.EncryptStringWithPrefix("secret-data", "prefix-")
// 验证带盐值的哈希
isValid := md5.VerifyMD5WithSalt("password123", "user@example.com", hashedPassword)
// 使用HMAC-MD5提高安全性
hmacHash := md5.MD5HMAC("message", "secret-key")
```
### 短哈希值
```go
// 获取16位MD532位MD5的中间部分
hash16 := md5.Get16("hello world")
fmt.Println(hash16) // 中间16个字符
// 获取8位MD5
hash8 := md5.Get8("hello world")
fmt.Println(hash8) // 中间8个字符
```
### 文件验证
```go
// 验证文件MD5是否匹配
match, err := md5.VerifyFileMD5("path/to/file.txt", "expected-hash")
if err != nil {
log.Fatal(err)
}
if match {
fmt.Println("文件MD5校验通过")
} else {
fmt.Println("文件MD5校验失败")
}
```
## 注意事项
1. MD5 主要用于校验,不适合用于安全存储密码等敏感信息
2. 如果用于密码存储,请务必使用加盐处理并考虑使用更安全的算法
3. 处理大文件时请使用`EncryptFileChunk`以优化性能
4. 返回的 MD5 哈希值都是 32 位的小写十六进制字符串(除非使用 Get16/Get8 函数)

View File

@ -0,0 +1,79 @@
package md5_test
import (
"fmt"
"log"
"tydata-server/pkg/lzkit/md5"
)
func Example() {
// 简单的字符串MD5
hashValue := md5.EncryptString("hello world")
fmt.Println("MD5(hello world):", hashValue)
// 使用链式API
chainHash := md5.New().
Add("hello").
Add(" ").
Add("world").
Sum()
fmt.Println("链式MD5:", chainHash)
// 使用盐值
saltedHash := md5.EncryptStringWithSalt("password123", "user@example.com")
fmt.Println("加盐MD5:", saltedHash)
// 验证哈希
isValid := md5.VerifyMD5("hello world", hashValue)
fmt.Println("验证结果:", isValid)
// 生成短版本的MD5
fmt.Println("16位MD5:", md5.Get16("hello world"))
fmt.Println("8位MD5:", md5.Get8("hello world"))
// 文件MD5计算
filePath := "example.txt" // 这只是示例,实际上这个文件可能不存在
fileHash, err := md5.EncryptFile(filePath)
if err != nil {
// 在实际代码中执行正确的错误处理
log.Printf("计算文件MD5出错: %v", err)
} else {
fmt.Println("文件MD5:", fileHash)
}
// HMAC-MD5
hmacHash := md5.MD5HMAC("重要消息", "secret-key")
fmt.Println("HMAC-MD5:", hmacHash)
}
func ExampleEncryptString() {
hash := md5.EncryptString("HelloWorld")
fmt.Println(hash)
// Output: 68e109f0f40ca72a15e05cc22786f8e6
}
func ExampleMD5_Sum() {
hash := md5.New().
Add("Hello").
Add("World").
Sum()
fmt.Println(hash)
// Output: 68e109f0f40ca72a15e05cc22786f8e6
}
func ExampleEncryptStringWithSalt() {
// 为用户密码加盐,通常使用用户唯一标识(如邮箱)作为盐值
hash := md5.EncryptStringWithSalt("password123", "user@example.com")
fmt.Println("盐值哈希长度:", len(hash))
fmt.Println("是否为有效哈希:", md5.VerifyMD5WithSalt("password123", "user@example.com", hash))
// Output:
// 盐值哈希长度: 32
// 是否为有效哈希: true
}
func ExampleGet16() {
// 获取16位MD5适合不需要完全防碰撞场景
hash := md5.Get16("HelloWorld")
fmt.Println(hash)
// Output: f0f40ca72a15e05c
}

206
pkg/lzkit/md5/md5.go Normal file
View File

@ -0,0 +1,206 @@
package md5
import (
"bufio"
"crypto/md5"
"encoding/hex"
"io"
"os"
"strings"
)
// MD5结构体可用于链式调用
type MD5 struct {
data []byte
}
// New 创建一个新的MD5实例
func New() *MD5 {
return &MD5{
data: []byte{},
}
}
// FromString 从字符串创建MD5
func FromString(s string) *MD5 {
return &MD5{
data: []byte(s),
}
}
// FromBytes 从字节切片创建MD5
func FromBytes(b []byte) *MD5 {
return &MD5{
data: b,
}
}
// Add 向MD5中添加字符串
func (m *MD5) Add(s string) *MD5 {
m.data = append(m.data, []byte(s)...)
return m
}
// AddBytes 向MD5中添加字节切片
func (m *MD5) AddBytes(b []byte) *MD5 {
m.data = append(m.data, b...)
return m
}
// Sum 计算并返回MD5哈希值(16进制字符串)
func (m *MD5) Sum() string {
hash := md5.New()
hash.Write(m.data)
return hex.EncodeToString(hash.Sum(nil))
}
// SumBytes 计算并返回MD5哈希值(字节切片)
func (m *MD5) SumBytes() []byte {
hash := md5.New()
hash.Write(m.data)
return hash.Sum(nil)
}
// 直接调用的工具函数
// EncryptString 加密字符串
func EncryptString(s string) string {
hash := md5.New()
hash.Write([]byte(s))
return hex.EncodeToString(hash.Sum(nil))
}
// EncryptBytes 加密字节切片
func EncryptBytes(b []byte) string {
hash := md5.New()
hash.Write(b)
return hex.EncodeToString(hash.Sum(nil))
}
// EncryptFile 加密文件内容
func EncryptFile(filePath string) (string, error) {
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
hash := md5.New()
if _, err := io.Copy(hash, file); err != nil {
return "", err
}
return hex.EncodeToString(hash.Sum(nil)), nil
}
// EncryptFileChunk 对大文件分块计算MD5提高效率
func EncryptFileChunk(filePath string, chunkSize int) (string, error) {
if chunkSize <= 0 {
chunkSize = 1024 * 1024 // 默认1MB
}
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
hash := md5.New()
buf := make([]byte, chunkSize)
reader := bufio.NewReader(file)
for {
n, err := reader.Read(buf)
if err != nil && err != io.EOF {
return "", err
}
if n == 0 {
break
}
hash.Write(buf[:n])
}
return hex.EncodeToString(hash.Sum(nil)), nil
}
// EncryptStringWithSalt 使用盐值加密字符串
func EncryptStringWithSalt(s, salt string) string {
return EncryptString(s + salt)
}
// EncryptStringWithPrefix 使用前缀加密字符串
func EncryptStringWithPrefix(s, prefix string) string {
return EncryptString(prefix + s)
}
// VerifyMD5 验证字符串的MD5哈希是否匹配
func VerifyMD5(s, hash string) bool {
return EncryptString(s) == strings.ToLower(hash)
}
// VerifyMD5WithSalt 验证带盐值的字符串MD5哈希是否匹配
func VerifyMD5WithSalt(s, salt, hash string) bool {
return EncryptStringWithSalt(s, salt) == strings.ToLower(hash)
}
// VerifyFileMD5 验证文件的MD5哈希是否匹配
func VerifyFileMD5(filePath, hash string) (bool, error) {
fileHash, err := EncryptFile(filePath)
if err != nil {
return false, err
}
return fileHash == strings.ToLower(hash), nil
}
// MD5格式化为指定位数
// Get16 获取16位MD5值(取32位结果的中间16位)
func Get16(s string) string {
result := EncryptString(s)
return result[8:24]
}
// Get8 获取8位MD5值
func Get8(s string) string {
result := EncryptString(s)
return result[12:20]
}
// MD5主要用于校验而非安全存储对于需要高安全性的场景应考虑:
// 1. bcrypt, scrypt或Argon2等专门为密码设计的算法
// 2. HMAC-MD5等方式以防御彩虹表攻击
// 3. 加盐并使用多次哈希迭代提高安全性
// MD5HMAC 使用HMAC-MD5算法
func MD5HMAC(message, key string) string {
hash := md5.New()
// 如果key长度超出block size先进行哈希
if len(key) > 64 {
hash.Write([]byte(key))
key = hex.EncodeToString(hash.Sum(nil))
hash.Reset()
}
// 内部填充
k_ipad := make([]byte, 64)
k_opad := make([]byte, 64)
copy(k_ipad, []byte(key))
copy(k_opad, []byte(key))
for i := 0; i < 64; i++ {
k_ipad[i] ^= 0x36
k_opad[i] ^= 0x5c
}
// 内部哈希
hash.Write(k_ipad)
hash.Write([]byte(message))
innerHash := hash.Sum(nil)
hash.Reset()
// 外部哈希
hash.Write(k_opad)
hash.Write(innerHash)
return hex.EncodeToString(hash.Sum(nil))
}

192
pkg/lzkit/md5/md5_test.go Normal file
View File

@ -0,0 +1,192 @@
package md5
import (
"fmt"
"os"
"testing"
)
func TestEncryptString(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"", "d41d8cd98f00b204e9800998ecf8427e"},
{"hello", "5d41402abc4b2a76b9719d911017c592"},
{"123456", "e10adc3949ba59abbe56e057f20f883e"},
{"Hello World!", "ed076287532e86365e841e92bfc50d8c"},
}
for _, test := range tests {
result := EncryptString(test.input)
fmt.Println(result)
if result != test.expected {
t.Errorf("EncryptString(%s) = %s; want %s", test.input, result, test.expected)
}
}
}
func TestEncryptBytes(t *testing.T) {
tests := []struct {
input []byte
expected string
}{
{[]byte(""), "d41d8cd98f00b204e9800998ecf8427e"},
{[]byte("hello"), "5d41402abc4b2a76b9719d911017c592"},
{[]byte{0, 1, 2, 3, 4}, "5267768822ee624d48fce15ec5ca79b6"},
}
for _, test := range tests {
result := EncryptBytes(test.input)
if result != test.expected {
t.Errorf("EncryptBytes(%v) = %s; want %s", test.input, result, test.expected)
}
}
}
func TestMD5Chain(t *testing.T) {
// 测试链式调用
result := New().
Add("hello").
Add(" ").
Add("world").
Sum()
expected := "fc5e038d38a57032085441e7fe7010b0" // MD5("hello world")
if result != expected {
t.Errorf("Chain MD5 = %s; want %s", result, expected)
}
// 测试从字符串初始化
result = FromString("hello").Add(" world").Sum()
if result != expected {
t.Errorf("FromString MD5 = %s; want %s", result, expected)
}
// 测试从字节切片初始化
result = FromBytes([]byte("hello")).AddBytes([]byte(" world")).Sum()
if result != expected {
t.Errorf("FromBytes MD5 = %s; want %s", result, expected)
}
}
func TestVerifyMD5(t *testing.T) {
if !VerifyMD5("hello", "5d41402abc4b2a76b9719d911017c592") {
t.Error("VerifyMD5 failed for correct match")
}
if VerifyMD5("hello", "wrong-hash") {
t.Error("VerifyMD5 succeeded for incorrect match")
}
// 测试大小写不敏感
if !VerifyMD5("hello", "5D41402ABC4B2A76B9719D911017C592") {
t.Error("VerifyMD5 failed for uppercase hash")
}
}
func TestSaltAndPrefix(t *testing.T) {
// 测试加盐
saltResult := EncryptStringWithSalt("password", "salt123")
expectedSalt := EncryptString("passwordsalt123")
if saltResult != expectedSalt {
t.Errorf("EncryptStringWithSalt = %s; want %s", saltResult, expectedSalt)
}
// 测试前缀
prefixResult := EncryptStringWithPrefix("password", "prefix123")
expectedPrefix := EncryptString("prefix123password")
if prefixResult != expectedPrefix {
t.Errorf("EncryptStringWithPrefix = %s; want %s", prefixResult, expectedPrefix)
}
// 验证带盐值的MD5
if !VerifyMD5WithSalt("password", "salt123", saltResult) {
t.Error("VerifyMD5WithSalt failed for correct match")
}
}
func TestGet16And8(t *testing.T) {
full := EncryptString("test-string")
// 测试16位MD5
result16 := Get16("test-string")
expected16 := full[8:24]
if result16 != expected16 {
t.Errorf("Get16 = %s; want %s", result16, expected16)
}
// 测试8位MD5
result8 := Get8("test-string")
expected8 := full[12:20]
if result8 != expected8 {
t.Errorf("Get8 = %s; want %s", result8, expected8)
}
}
func TestMD5HMAC(t *testing.T) {
// 已知的HMAC-MD5结果
tests := []struct {
message string
key string
expected string
}{
{"message", "key", "4e4748e62b463521f6775fbf921234b5"},
{"test", "secret", "8b11d99898918564dda1a9fe205b5310"},
}
for _, test := range tests {
result := MD5HMAC(test.message, test.key)
if result != test.expected {
t.Errorf("MD5HMAC(%s, %s) = %s; want %s",
test.message, test.key, result, test.expected)
}
}
}
func TestEncryptFile(t *testing.T) {
// 创建临时测试文件
content := []byte("test file content for MD5")
tmpFile, err := os.CreateTemp("", "md5test-*.txt")
if err != nil {
t.Fatalf("无法创建临时文件: %v", err)
}
defer os.Remove(tmpFile.Name())
if _, err := tmpFile.Write(content); err != nil {
t.Fatalf("无法写入临时文件: %v", err)
}
if err := tmpFile.Close(); err != nil {
t.Fatalf("无法关闭临时文件: %v", err)
}
// 计算文件MD5
fileHash, err := EncryptFile(tmpFile.Name())
if err != nil {
t.Fatalf("计算文件MD5失败: %v", err)
}
// 验证文件MD5
expectedHash := EncryptBytes(content)
if fileHash != expectedHash {
t.Errorf("文件MD5 = %s; 应为 %s", fileHash, expectedHash)
}
// 测试VerifyFileMD5
match, err := VerifyFileMD5(tmpFile.Name(), expectedHash)
if err != nil {
t.Fatalf("验证文件MD5失败: %v", err)
}
if !match {
t.Error("VerifyFileMD5返回false应返回true")
}
// 测试不匹配的情况
match, err = VerifyFileMD5(tmpFile.Name(), "wronghash")
if err != nil {
t.Fatalf("验证文件MD5失败: %v", err)
}
if match {
t.Error("VerifyFileMD5对错误的哈希返回true应返回false")
}
}