diff --git a/app/user/cmd/api/internal/logic/agent/activateagentmembershiplogic.go b/app/user/cmd/api/internal/logic/agent/activateagentmembershiplogic.go index b33d4dc..8fa215e 100644 --- a/app/user/cmd/api/internal/logic/agent/activateagentmembershiplogic.go +++ b/app/user/cmd/api/internal/logic/agent/activateagentmembershiplogic.go @@ -3,11 +3,13 @@ package agent import ( "context" "database/sql" - "github.com/pkg/errors" - "github.com/zeromicro/go-zero/core/stores/sqlx" "time" "tydata-server/app/user/model" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" + + "github.com/pkg/errors" + "github.com/zeromicro/go-zero/core/stores/sqlx" "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" @@ -34,7 +36,12 @@ func (l *ActivateAgentMembershipLogic) ActivateAgentMembership(req *types.AgentA //if err != nil { // return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err) //} - userModel, err := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "加密手机号失败: %v", err) + } + userModel, err := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "查询代理信息失败: %v", err) } diff --git a/app/user/cmd/api/internal/logic/agent/applyforagentlogic.go b/app/user/cmd/api/internal/logic/agent/applyforagentlogic.go index a367d8c..03150df 100644 --- a/app/user/cmd/api/internal/logic/agent/applyforagentlogic.go +++ b/app/user/cmd/api/internal/logic/agent/applyforagentlogic.go @@ -7,6 +7,7 @@ import ( "tydata-server/app/user/model" jwtx "tydata-server/common/jwt" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" "tydata-server/pkg/lzkit/lzUtils" "github.com/pkg/errors" @@ -34,17 +35,22 @@ func NewApplyForAgentLogic(ctx context.Context, svcCtx *svc.ServiceContext) *App } func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *types.AgentApplyResp, err error) { + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "加密手机号失败: %v", err) + } // 校验验证码 - redisKey := fmt.Sprintf("%s:%s", "agentApply", req.Mobile) + redisKey := fmt.Sprintf("%s:%s", "agentApply", encryptedMobile) cacheCode, err := l.svcCtx.Redis.Get(redisKey) if err != nil { if errors.Is(err, redis.Nil) { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "代理申请, 验证码过期: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "代理申请, 验证码过期: %s", encryptedMobile) } - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err) } if cacheCode != req.Code { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "代理申请, 验证码不正确: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "代理申请, 验证码不正确: %s", encryptedMobile) } if req.Ancestor == req.Mobile { return nil, errors.Wrapf(xerr.NewErrMsg("不能成为自己的代理"), "") @@ -52,18 +58,18 @@ func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *type var userID int64 transErr := l.svcCtx.AgentAuditModel.Trans(l.ctx, func(transCtx context.Context, session sqlx.Session) error { // 两种情况,1. 已注册账号然后申请代理 2. 未注册账号申请代理 - user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if findUserErr != nil && !errors.Is(findUserErr, model.ErrNotFound) { - return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", req.Mobile, err) + return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 读取数据库获取用户失败, mobile: %s, err: %+v", encryptedMobile, err) } if user == nil { - user = &model.User{Mobile: req.Mobile} + user = &model.User{Mobile: encryptedMobile} if len(user.Nickname) == 0 { - user.Nickname = req.Mobile + user.Nickname = encryptedMobile } insertResult, userInsertErr := l.svcCtx.UserModel.Insert(transCtx, session, user) if userInsertErr != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) + return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err) } lastId, lastInsertIdErr := insertResult.LastInsertId() if lastInsertIdErr != nil { @@ -73,7 +79,7 @@ func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *type userID = lastId userAuth := new(model.UserAuth) userAuth.UserId = lastId - userAuth.AuthKey = req.Mobile + userAuth.AuthKey = encryptedMobile userAuth.AuthType = model.UserAuthTypeAgentDirect if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(transCtx, session, userAuth); userAuthInsertErr != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) @@ -99,7 +105,7 @@ func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *type var agentAudit model.AgentAudit agentAudit.UserId = user.Id - agentAudit.Mobile = req.Mobile + agentAudit.Mobile = encryptedMobile agentAudit.Region = req.Region agentAudit.WechatId = lzUtils.StringToNullString(req.WechatID) agentAudit.Status = 1 @@ -133,7 +139,7 @@ func (l *ApplyForAgentLogic) ApplyForAgent(req *types.AgentApplyReq) (resp *type // 关联上级 if req.Ancestor != "" { - ancestorAgentModel, findAgentModelErr := l.svcCtx.AgentModel.FindOneByMobile(transCtx, req.Ancestor) + ancestorAgentModel, findAgentModelErr := l.svcCtx.AgentModel.FindOneByMobile(transCtx, encryptedMobile) if findAgentModelErr != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "代理申请, 查找上级代理失败: %+v", findAgentModelErr) } diff --git a/app/user/cmd/api/internal/logic/agent/getagentinfologic.go b/app/user/cmd/api/internal/logic/agent/getagentinfologic.go index 81a91c5..d6ae714 100644 --- a/app/user/cmd/api/internal/logic/agent/getagentinfologic.go +++ b/app/user/cmd/api/internal/logic/agent/getagentinfologic.go @@ -5,6 +5,7 @@ import ( "database/sql" "tydata-server/common/ctxdata" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" "tydata-server/pkg/lzkit/lzUtils" "github.com/pkg/errors" @@ -59,6 +60,10 @@ func (l *GetAgentInfoLogic) GetAgentInfo() (resp *types.AgentInfoResp, err error } return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "获取代理信息, %v", err) } + agent.Mobile, err = crypto.DecryptMobile(agent.Mobile, l.svcCtx.Config.Encrypt.SecretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取代理信息, 解密手机号失败: %v", err) + } return &types.AgentInfoResp{ AgentID: agent.Id, Level: agent.LevelName, diff --git a/app/user/cmd/api/internal/logic/auth/sendsmslogic.go b/app/user/cmd/api/internal/logic/auth/sendsmslogic.go index 024a56c..5cd1974 100644 --- a/app/user/cmd/api/internal/logic/auth/sendsmslogic.go +++ b/app/user/cmd/api/internal/logic/auth/sendsmslogic.go @@ -3,10 +3,12 @@ package auth import ( "context" "fmt" - "github.com/pkg/errors" "math/rand" "time" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" + + "github.com/pkg/errors" "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" @@ -33,16 +35,21 @@ func NewSendSmsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *SendSmsLo } func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error { + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %v", err) + } // 检查手机号是否在一分钟内已发送过验证码 - limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, req.Mobile) + limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile) exists, err := l.svcCtx.Redis.Exists(limitCodeKey) if err != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 读取redis缓存失败: %s", req.Mobile) + return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 读取redis缓存失败: %s", encryptedMobile) } if exists { // 如果 Redis 中已经存在标记,说明在 1 分钟内请求过,返回错误 - return errors.Wrapf(xerr.NewErrMsg("一分钟内不能重复发送验证码"), "短信发送, 手机号1分钟内重复请求发送二维码: %s", req.Mobile) + return errors.Wrapf(xerr.NewErrMsg("一分钟内不能重复发送验证码"), "短信发送, 手机号1分钟内重复请求发送验证码: %s", encryptedMobile) } code := fmt.Sprintf("%06d", rand.New(rand.NewSource(time.Now().UnixNano())).Intn(1000000)) @@ -55,7 +62,7 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error { if *smsResp.Body.Code != "OK" { return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message) } - codeKey := fmt.Sprintf("%s:%s", req.ActionType, req.Mobile) + codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile) // 将验证码保存到 Redis,设置过期时间 err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟 if err != nil { diff --git a/app/user/cmd/api/internal/logic/query/querydetailbyorderidlogic.go b/app/user/cmd/api/internal/logic/query/querydetailbyorderidlogic.go index 2a1d6b3..5a7ec7c 100644 --- a/app/user/cmd/api/internal/logic/query/querydetailbyorderidlogic.go +++ b/app/user/cmd/api/internal/logic/query/querydetailbyorderidlogic.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "time" + "tydata-server/common/ctxdata" "tydata-server/common/xerr" "tydata-server/pkg/lzkit/crypto" "tydata-server/pkg/lzkit/delay" @@ -17,6 +18,7 @@ import ( "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" + "tydata-server/app/user/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -36,12 +38,26 @@ func NewQueryDetailByOrderIdLogic(ctx context.Context, svcCtx *svc.ServiceContex } func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) { + // 获取当前用户ID + userId, err := ctxdata.GetUidFromCtx(l.ctx) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err) + } + // 获取订单信息 order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId) if err != nil { + if errors.Is(err, model.ErrNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "报告查询, 订单不存在: %v", err) + } return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %v", err) } + // 安全验证:确保订单属于当前用户 + if order.UserId != userId { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告") + } + // 创建渐进式延迟策略实例 progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5) if err != nil { diff --git a/app/user/cmd/api/internal/logic/query/querydetailbyordernologic.go b/app/user/cmd/api/internal/logic/query/querydetailbyordernologic.go index 7adad8f..42fcaa6 100644 --- a/app/user/cmd/api/internal/logic/query/querydetailbyordernologic.go +++ b/app/user/cmd/api/internal/logic/query/querydetailbyordernologic.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "time" + "tydata-server/common/ctxdata" "tydata-server/common/xerr" "tydata-server/pkg/lzkit/delay" @@ -13,6 +14,7 @@ import ( "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" + "tydata-server/app/user/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -32,12 +34,26 @@ func NewQueryDetailByOrderNoLogic(ctx context.Context, svcCtx *svc.ServiceContex } func (l *QueryDetailByOrderNoLogic) QueryDetailByOrderNo(req *types.QueryDetailByOrderNoReq) (resp *types.QueryDetailByOrderNoResp, err error) { + // 获取当前用户ID + userId, err := ctxdata.GetUidFromCtx(l.ctx) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err) + } + // 获取订单信息 order, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo) if err != nil { + if errors.Is(err, model.ErrNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "报告查询, 订单不存在: %v", err) + } return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %v", err) } + // 安全验证:确保订单属于当前用户 + if order.UserId != userId { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告") + } + // 创建渐进式延迟策略实例 progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5) if err != nil { diff --git a/app/user/cmd/api/internal/logic/user/agentmobilecodeloginlogic.go b/app/user/cmd/api/internal/logic/user/agentmobilecodeloginlogic.go index 862cb92..956a122 100644 --- a/app/user/cmd/api/internal/logic/user/agentmobilecodeloginlogic.go +++ b/app/user/cmd/api/internal/logic/user/agentmobilecodeloginlogic.go @@ -3,15 +3,17 @@ package user import ( "context" "fmt" - "github.com/pkg/errors" - "github.com/zeromicro/go-zero/core/stores/redis" - "github.com/zeromicro/go-zero/core/stores/sqlx" "time" "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/model" jwtx "tydata-server/common/jwt" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" + + "github.com/pkg/errors" + "github.com/zeromicro/go-zero/core/stores/redis" + "github.com/zeromicro/go-zero/core/stores/sqlx" "github.com/zeromicro/go-zero/core/logx" ) @@ -31,32 +33,37 @@ func NewAgentMobileCodeLoginLogic(ctx context.Context, svcCtx *svc.ServiceContex } func (l *AgentMobileCodeLoginLogic) AgentMobileCodeLogin(req *types.MobileCodeLoginReq) (resp *types.MobileCodeLoginResp, err error) { + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机登录, 加密手机号失败: %+v", err) + } // 检查手机号是否在一分钟内已发送过验证码 - redisKey := fmt.Sprintf("%s:%s", "query", req.Mobile) + redisKey := fmt.Sprintf("%s:%s", "query", encryptedMobile) cacheCode, err := l.svcCtx.Redis.Get(redisKey) if err != nil { if errors.Is(err, redis.Nil) { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期") } - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, err: %+v", err) } if cacheCode != req.Code { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确") } - user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if findUserErr != nil && findUserErr != model.ErrNotFound { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", encryptedMobile, err) } if user == nil { - user = &model.User{Mobile: req.Mobile} + user = &model.User{Mobile: encryptedMobile} if len(user.Nickname) == 0 { - user.Nickname = req.Mobile + user.Nickname = "" } if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user) if userInsertErr != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) + return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err) } lastId, lastInsertIdErr := insertResult.LastInsertId() if lastInsertIdErr != nil { @@ -66,7 +73,7 @@ func (l *AgentMobileCodeLoginLogic) AgentMobileCodeLogin(req *types.MobileCodeLo userAuth := new(model.UserAuth) userAuth.UserId = lastId - userAuth.AuthKey = req.Mobile + userAuth.AuthKey = encryptedMobile userAuth.AuthType = model.UserAuthTypeH5Mobile if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) diff --git a/app/user/cmd/api/internal/logic/user/detaillogic.go b/app/user/cmd/api/internal/logic/user/detaillogic.go index 5ce405d..1fb8756 100644 --- a/app/user/cmd/api/internal/logic/user/detaillogic.go +++ b/app/user/cmd/api/internal/logic/user/detaillogic.go @@ -2,12 +2,14 @@ package user import ( "context" - "github.com/jinzhu/copier" - "github.com/pkg/errors" "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" "tydata-server/common/ctxdata" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" + + "github.com/jinzhu/copier" + "github.com/pkg/errors" "github.com/zeromicro/go-zero/core/logx" ) @@ -40,6 +42,10 @@ func (l *DetailLogic) Detail() (resp *types.UserInfoResp, err error) { if err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "用户信息, 用户信息结构体复制失败, %v", err) } + userInfo.Mobile, err = crypto.DecryptMobile(userInfo.Mobile, l.svcCtx.Config.Encrypt.SecretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "用户信息, 解密手机号失败, %v", err) + } return &types.UserInfoResp{ UserInfo: userInfo, }, nil diff --git a/app/user/cmd/api/internal/logic/user/mobilecodeloginlogic.go b/app/user/cmd/api/internal/logic/user/mobilecodeloginlogic.go index 3a1f72a..5911a2c 100644 --- a/app/user/cmd/api/internal/logic/user/mobilecodeloginlogic.go +++ b/app/user/cmd/api/internal/logic/user/mobilecodeloginlogic.go @@ -3,15 +3,17 @@ package user import ( "context" "fmt" - "github.com/pkg/errors" - "github.com/zeromicro/go-zero/core/stores/redis" - "github.com/zeromicro/go-zero/core/stores/sqlx" "time" "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" "tydata-server/app/user/model" jwtx "tydata-server/common/jwt" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" + + "github.com/pkg/errors" + "github.com/zeromicro/go-zero/core/stores/redis" + "github.com/zeromicro/go-zero/core/stores/sqlx" "github.com/zeromicro/go-zero/core/logx" ) @@ -31,34 +33,39 @@ func NewMobileCodeLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *M } func (l *MobileCodeLoginLogic) MobileCodeLogin(req *types.MobileCodeLoginReq) (resp *types.MobileCodeLoginResp, err error) { + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机登录, 加密手机号失败: %+v", err) + } if !l.MobileCodeLoginInside(req) { // 检查手机号是否在一分钟内已发送过验证码 - redisKey := fmt.Sprintf("%s:%s", "login", req.Mobile) + redisKey := fmt.Sprintf("%s:%s", "login", encryptedMobile) cacheCode, err := l.svcCtx.Redis.Get(redisKey) if err != nil { if errors.Is(err, redis.Nil) { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期: %s", encryptedMobile) } - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err) } if cacheCode != req.Code { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确: %s", encryptedMobile) } } - user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if findUserErr != nil && findUserErr != model.ErrNotFound { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", encryptedMobile, err) } if user == nil { - user = &model.User{Mobile: req.Mobile} + user = &model.User{Mobile: encryptedMobile} if len(user.Nickname) == 0 { - user.Nickname = req.Mobile + user.Nickname = encryptedMobile } if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user) if userInsertErr != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) + return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err) } lastId, lastInsertIdErr := insertResult.LastInsertId() if lastInsertIdErr != nil { @@ -68,7 +75,7 @@ func (l *MobileCodeLoginLogic) MobileCodeLogin(req *types.MobileCodeLoginReq) (r userAuth := new(model.UserAuth) userAuth.UserId = lastId - userAuth.AuthKey = req.Mobile + userAuth.AuthKey = encryptedMobile userAuth.AuthType = model.UserAuthTypeAppMobile if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) diff --git a/app/user/cmd/api/internal/logic/user/mobileloginlogic.go b/app/user/cmd/api/internal/logic/user/mobileloginlogic.go index 4ff1144..8498ee8 100644 --- a/app/user/cmd/api/internal/logic/user/mobileloginlogic.go +++ b/app/user/cmd/api/internal/logic/user/mobileloginlogic.go @@ -2,14 +2,16 @@ package user import ( "context" - "github.com/pkg/errors" "time" "tydata-server/app/user/model" jwtx "tydata-server/common/jwt" "tydata-server/common/tool" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" "tydata-server/pkg/lzkit/lzUtils" + "github.com/pkg/errors" + "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" @@ -31,15 +33,20 @@ func NewMobileLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Mobil } func (l *MobileLoginLogic) MobileLogin(req *types.MobileLoginReq) (resp *types.MobileCodeLoginResp, err error) { - user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机登录, 加密手机号失败: %+v", err) + } + user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if findUserErr != nil && findUserErr != model.ErrNotFound { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile%s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile%s, err: %+v", encryptedMobile, err) } if user == nil { - return nil, errors.Wrapf(xerr.NewErrMsg("手机号码未注册"), "手机登录, 手机号未注册:%s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("手机号码未注册"), "手机登录, 手机号未注册:%s", encryptedMobile) } if !(tool.Md5ByString(req.Password) == lzUtils.NullStringToString(user.Password)) { - return nil, errors.Wrapf(xerr.NewErrMsg("密码不正确"), "手机登录, 密码匹配不正确%s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("密码不正确"), "手机登录, 密码匹配不正确%s", encryptedMobile) } token, generaErr := jwtx.GenerateJwtToken(user.Id, l.svcCtx.Config.JwtAuth.AccessSecret, l.svcCtx.Config.JwtAuth.AccessExpire) diff --git a/app/user/cmd/api/internal/logic/user/registerlogic.go b/app/user/cmd/api/internal/logic/user/registerlogic.go index 7d8ef47..f4bdb72 100644 --- a/app/user/cmd/api/internal/logic/user/registerlogic.go +++ b/app/user/cmd/api/internal/logic/user/registerlogic.go @@ -3,9 +3,6 @@ package user import ( "context" "fmt" - "github.com/pkg/errors" - "github.com/zeromicro/go-zero/core/stores/redis" - "github.com/zeromicro/go-zero/core/stores/sqlx" "time" "tydata-server/app/user/cmd/api/internal/svc" "tydata-server/app/user/cmd/api/internal/types" @@ -13,8 +10,13 @@ import ( jwtx "tydata-server/common/jwt" "tydata-server/common/tool" "tydata-server/common/xerr" + "tydata-server/pkg/lzkit/crypto" "tydata-server/pkg/lzkit/lzUtils" + "github.com/pkg/errors" + "github.com/zeromicro/go-zero/core/stores/redis" + "github.com/zeromicro/go-zero/core/stores/sqlx" + "github.com/zeromicro/go-zero/core/logx" ) @@ -33,38 +35,43 @@ func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Register } func (l *RegisterLogic) Register(req *types.RegisterReq) (resp *types.RegisterResp, err error) { + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机注册, 加密手机号失败: %+v", err) + } // 检查手机号是否在一分钟内已发送过验证码 - redisKey := fmt.Sprintf("%s:%s", "register", req.Mobile) + redisKey := fmt.Sprintf("%s:%s", "register", encryptedMobile) cacheCode, err := l.svcCtx.Redis.Get(redisKey) if err != nil { if errors.Is(err, redis.Nil) { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机注册, 验证码过期: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机注册, 验证码过期: %s", encryptedMobile) } - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err) } if cacheCode != req.Code { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机注册, 验证码不正确: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机注册, 验证码不正确: %s", encryptedMobile) } - hasUser, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + hasUser, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if findUserErr != nil && findUserErr != model.ErrNotFound { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取数据库获取用户失败, mobile%s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取数据库获取用户失败, mobile%s, err: %+v", encryptedMobile, err) } if hasUser != nil { - return nil, errors.Wrapf(xerr.NewErrMsg("该手机号码已注册"), "手机注册, 手机号码已注册, mobile:%s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("该手机号码已注册"), "手机注册, 手机号码已注册, mobile:%s", encryptedMobile) } var userId int64 if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { user := new(model.User) - user.Mobile = req.Mobile + user.Mobile = encryptedMobile if len(user.Nickname) == 0 { - user.Nickname = req.Mobile + user.Nickname = encryptedMobile } if len(req.Password) > 0 { user.Password = lzUtils.StringToNullString(tool.Md5ByString(req.Password)) } insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user) if userInsertErr != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) + return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err) } lastId, lastInsertIdErr := insertResult.LastInsertId() if lastInsertIdErr != nil { @@ -74,7 +81,7 @@ func (l *RegisterLogic) Register(req *types.RegisterReq) (resp *types.RegisterRe userAuth := new(model.UserAuth) userAuth.UserId = lastId - userAuth.AuthKey = req.Mobile + userAuth.AuthKey = encryptedMobile userAuth.AuthType = model.UserAuthTypeAppMobile if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) diff --git a/common/xerr/errCode.go b/common/xerr/errCode.go index ade3e6f..0d3f846 100644 --- a/common/xerr/errCode.go +++ b/common/xerr/errCode.go @@ -14,7 +14,9 @@ const DB_ERROR uint32 = 100005 const DB_UPDATE_AFFECTED_ZERO_ERROR uint32 = 100006 const PARAM_VERIFICATION_ERROR uint32 = 100007 const CUSTOM_ERROR uint32 = 100008 +const AUTH_ERROR uint32 = 100009 const LOGIN_FAILED uint32 = 200001 const LOGIC_QUERY_WAIT uint32 = 200002 const LOGIC_QUERY_ERROR uint32 = 200003 +const LOGIC_QUERY_NOT_FOUND uint32 = 200004 diff --git a/common/xerr/errMsg.go b/common/xerr/errMsg.go index ece9110..7e03997 100644 --- a/common/xerr/errMsg.go +++ b/common/xerr/errMsg.go @@ -11,6 +11,7 @@ func init() { message[TOKEN_GENERATE_ERROR] = "生成token失败" message[DB_ERROR] = "数据库繁忙,请稍后再试" message[DB_UPDATE_AFFECTED_ZERO_ERROR] = "更新数据影响行数为0" + message[AUTH_ERROR] = "权限错误,无权访问" } func MapErrMsg(errcode uint32) string { diff --git a/pkg/lzkit/crypto/README.md b/pkg/lzkit/crypto/README.md new file mode 100644 index 0000000..bd8d77e --- /dev/null +++ b/pkg/lzkit/crypto/README.md @@ -0,0 +1,235 @@ +# AES 加密工具包 + +本包提供了多种加密方式,特别是用于处理敏感个人信息(如手机号、身份证号等)的加密和解密功能。 + +## 主要功能 + +- **AES-CBC 模式加密/解密** - 标准加密模式,适用于一般数据加密 +- **AES-ECB 模式加密/解密** - 确定性加密模式,适用于数据库字段加密和查询 +- **专门针对个人敏感信息的加密/解密方法** +- **密钥生成和管理工具** + +## 安全性说明 + +- **AES-CBC 模式**:使用随机 IV,相同明文每次加密结果不同,安全性较高 +- **AES-ECB 模式**:确定性加密,相同明文每次加密结果相同,便于数据库查询,但安全性较低 + +> **⚠️ 警告**:ECB 模式仅适用于短文本(如手机号、身份证号)的确定性加密,不建议用于加密大段文本或高安全需求场景。 + +## 使用示例 + +### 1. 加密手机号 + +使用 AES-ECB 模式加密手机号,保证确定性(相同手机号总是产生相同密文): + +```go +import ( + "fmt" + "tydata-server/pkg/lzkit/crypto" +) + +func encryptMobileExample() { + // 您的密钥(需安全保存,建议存储在配置中) + key := []byte("1234567890abcdef") // 16字节AES-128密钥 + + // 加密手机号 + mobile := "13800138000" + encryptedMobile, err := crypto.EncryptMobile(mobile, key) + if err != nil { + panic(err) + } + + fmt.Println("加密后的手机号:", encryptedMobile) + + // 解密手机号 + decryptedMobile, err := crypto.DecryptMobile(encryptedMobile, key) + if err != nil { + panic(err) + } + + fmt.Println("解密后的手机号:", decryptedMobile) +} +``` + +### 2. 在数据库中存储和查询加密手机号 + +```go +// 加密并存储手机号 +func saveUser(db *sqlx.DB, mobile string, key []byte) (int64, error) { + encryptedMobile, err := crypto.EncryptMobile(mobile, key) + if err != nil { + return 0, err + } + + var id int64 + err = db.QueryRow( + "INSERT INTO users (mobile, create_time) VALUES (?, NOW()) RETURNING id", + encryptedMobile, + ).Scan(&id) + + return id, err +} + +// 根据手机号查询用户 +func findUserByMobile(db *sqlx.DB, mobile string, key []byte) (*User, error) { + encryptedMobile, err := crypto.EncryptMobile(mobile, key) + if err != nil { + return nil, err + } + + var user User + err = db.QueryRow( + "SELECT id, mobile, create_time FROM users WHERE mobile = ?", + encryptedMobile, + ).Scan(&user.ID, &user.EncryptedMobile, &user.CreateTime) + + if err != nil { + return nil, err + } + + // 解密手机号用于显示 + user.Mobile, _ = crypto.DecryptMobile(user.EncryptedMobile, key) + + return &user, nil +} +``` + +### 3. 加密身份证号 + +```go +func encryptIDCardExample() { + key := []byte("1234567890abcdef") + + idCard := "440101199001011234" + encryptedIDCard, err := crypto.EncryptIDCard(idCard, key) + if err != nil { + panic(err) + } + + fmt.Println("加密后的身份证号:", encryptedIDCard) + + // 解密身份证号 + decryptedIDCard, err := crypto.DecryptIDCard(encryptedIDCard, key) + if err != nil { + panic(err) + } + + fmt.Println("解密后的身份证号:", decryptedIDCard) +} +``` + +### 4. 密钥管理 + +```go +func keyManagementExample() { + // 生成随机密钥 + key, err := crypto.GenerateAESKey(16) // AES-128 + if err != nil { + panic(err) + } + fmt.Printf("生成的密钥(十六进制): %x\n", key) + + // 从密码派生密钥(便于记忆) + password := "my-secure-password" + derivedKey, err := crypto.DeriveKeyFromPassword(password, 16) + if err != nil { + panic(err) + } + fmt.Printf("从密码派生的密钥: %x\n", derivedKey) +} +``` + +### 5. 使用十六进制输出(适用于 URL 参数) + +```go +func hexEncodingExample() { + key := []byte("1234567890abcdef") + mobile := "13800138000" + + // 使用十六进制编码(适合URL参数) + encryptedHex, err := crypto.EncryptMobileHex(mobile, key) + if err != nil { + panic(err) + } + + fmt.Println("十六进制编码的加密手机号:", encryptedHex) + + // 解密十六进制编码的手机号 + decryptedMobile, err := crypto.DecryptMobileHex(encryptedHex, key) + if err != nil { + panic(err) + } + + fmt.Println("解密后的手机号:", decryptedMobile) +} +``` + +## 在 Go-Zero 项目中使用 + +在 Go-Zero 项目中,建议将加密密钥放在配置文件中: + +1. 在配置文件中添加密钥配置: + +```yaml +# etc/main.yaml +Name: user-api +Host: 0.0.0.0 +Port: 8888 + +Encrypt: + MobileKey: "1234567890abcdef" # 16字节AES-128密钥 + IDCardKey: "1234567890abcdef1234567890abcdef" # 32字节AES-256密钥 +``` + +2. 在配置结构中定义: + +```go +type Config struct { + rest.RestConf + Encrypt struct { + MobileKey string + IDCardKey string + } +} +``` + +3. 在服务上下文中使用: + +```go +type ServiceContext struct { + Config config.Config + UserModel model.UserModel + MobileKey []byte + IDCardKey []byte +} + +func NewServiceContext(c config.Config) *ServiceContext { + return &ServiceContext{ + Config: c, + UserModel: model.NewUserModel(sqlx.NewMysql(c.DB.DataSource), c.Cache), + MobileKey: []byte(c.Encrypt.MobileKey), + IDCardKey: []byte(c.Encrypt.IDCardKey), + } +} +``` + +4. 在 Logic 中使用: + +```go +func (l *RegisterLogic) Register(req *types.RegisterReq) (*types.RegisterResp, error) { + // 加密手机号用于存储 + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, l.svcCtx.MobileKey) + if err != nil { + return nil, errors.New("手机号加密失败") + } + + // 保存到数据库 + user := &model.User{ + Mobile: encryptedMobile, + // 其他字段... + } + + result, err := l.svcCtx.UserModel.Insert(l.ctx, nil, user) + // 其余逻辑... +} +``` diff --git a/pkg/lzkit/crypto/ecb.go b/pkg/lzkit/crypto/ecb.go new file mode 100644 index 0000000..3fed15f --- /dev/null +++ b/pkg/lzkit/crypto/ecb.go @@ -0,0 +1,274 @@ +package crypto + +import ( + "crypto/aes" + "crypto/md5" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" +) + +// ECB模式是一种基本的加密模式,每个明文块独立加密 +// 警告:ECB模式存在安全问题,仅用于需要确定性加密的场景,如数据库字段查询 +// 不要用于加密大段文本或安全要求高的场景 + +// 验证密钥长度是否有效 (AES-128, AES-192, AES-256) +func validateAESKey(key []byte) error { + switch len(key) { + case 16, 24, 32: + return nil + default: + return errors.New("AES密钥长度必须是16、24或32字节(对应AES-128、AES-192、AES-256)") + } +} + +// AesEcbEncrypt AES-ECB模式加密,返回Base64编码的密文 +// 使用已有的ECB实现,但提供更易用的接口 +func AesEcbEncrypt(plainText, key []byte) (string, error) { + if err := validateAESKey(key); err != nil { + return "", err + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + // 使用PKCS7填充 + plainText = PKCS7Padding(plainText, block.BlockSize()) + + // 创建密文数组 + cipherText := make([]byte, len(plainText)) + + // ECB模式加密,使用west_crypto.go中已有的实现 + mode := newECBEncrypter(block) + mode.CryptBlocks(cipherText, plainText) + + // 返回Base64编码的密文 + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// AesEcbDecrypt AES-ECB模式解密,输入Base64编码的密文 +func AesEcbDecrypt(cipherTextBase64 string, key []byte) ([]byte, error) { + if err := validateAESKey(key); err != nil { + return nil, err + } + + // Base64解码 + cipherText, err := base64.StdEncoding.DecodeString(cipherTextBase64) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // 检查密文长度 + if len(cipherText)%block.BlockSize() != 0 { + return nil, errors.New("密文长度必须是块大小的整数倍") + } + + // 创建明文数组 + plainText := make([]byte, len(cipherText)) + + // ECB模式解密,使用west_crypto.go中已有的实现 + mode := newECBDecrypter(block) + mode.CryptBlocks(plainText, cipherText) + + // 去除PKCS7填充 + plainText, err = PKCS7UnPadding(plainText) + if err != nil { + return nil, err + } + + return plainText, nil +} + +// AesEcbEncryptHex AES-ECB模式加密,返回十六进制编码的密文 +func AesEcbEncryptHex(plainText, key []byte) (string, error) { + if err := validateAESKey(key); err != nil { + return "", err + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + // 使用PKCS7填充 + plainText = PKCS7Padding(plainText, block.BlockSize()) + + // 创建密文数组 + cipherText := make([]byte, len(plainText)) + + // ECB模式加密 + mode := newECBEncrypter(block) + mode.CryptBlocks(cipherText, plainText) + + // 返回十六进制编码的密文 + return hex.EncodeToString(cipherText), nil +} + +// AesEcbDecryptHex AES-ECB模式解密,输入十六进制编码的密文 +func AesEcbDecryptHex(cipherTextHex string, key []byte) ([]byte, error) { + if err := validateAESKey(key); err != nil { + return nil, err + } + + // 十六进制解码 + cipherText, err := hex.DecodeString(cipherTextHex) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // 检查密文长度 + if len(cipherText)%block.BlockSize() != 0 { + return nil, errors.New("密文长度必须是块大小的整数倍") + } + + // 创建明文数组 + plainText := make([]byte, len(cipherText)) + + // ECB模式解密 + mode := newECBDecrypter(block) + mode.CryptBlocks(plainText, cipherText) + + // 去除PKCS7填充 + plainText, err = PKCS7UnPadding(plainText) + if err != nil { + return nil, err + } + + return plainText, nil +} + +// 以下是专门用于处理手机号等敏感数据的实用函数 + +// EncryptMobile 使用AES-ECB加密手机号,返回Base64编码 +// 该方法保证对相同手机号总是产生相同密文,便于数据库查询 +func EncryptMobile(mobile string, secretKey string) (string, error) { + key, decodeErr := hex.DecodeString(secretKey) + if decodeErr != nil { + return "", decodeErr + } + if mobile == "" { + return "", errors.New("手机号不能为空") + } + return AesEcbEncrypt([]byte(mobile), key) +} + +// DecryptMobile 解密手机号 +func DecryptMobile(encryptedMobile string, secretKey string) (string, error) { + key, decodeErr := hex.DecodeString(secretKey) + if decodeErr != nil { + return "", decodeErr + } + if encryptedMobile == "" { + return "", errors.New("加密手机号不能为空") + } + + bytes, err := AesEcbDecrypt(encryptedMobile, key) + if err != nil { + return "", fmt.Errorf("解密手机号失败: %v", err) + } + + return string(bytes), nil +} + +// EncryptMobileHex 使用AES-ECB加密手机号,返回十六进制编码(适用于URL参数) +func EncryptMobileHex(mobile string, key []byte) (string, error) { + if mobile == "" { + return "", errors.New("手机号不能为空") + } + return AesEcbEncryptHex([]byte(mobile), key) +} + +// DecryptMobileHex 解密十六进制编码的手机号 +func DecryptMobileHex(encryptedMobileHex string, key []byte) (string, error) { + if encryptedMobileHex == "" { + return "", errors.New("加密手机号不能为空") + } + + bytes, err := AesEcbDecryptHex(encryptedMobileHex, key) + if err != nil { + return "", fmt.Errorf("解密手机号失败: %v", err) + } + + return string(bytes), nil +} + +// EncryptIDCard 使用AES-ECB加密身份证号 +func EncryptIDCard(idCard string, key []byte) (string, error) { + if idCard == "" { + return "", errors.New("身份证号不能为空") + } + return AesEcbEncrypt([]byte(idCard), key) +} + +// DecryptIDCard 解密身份证号 +func DecryptIDCard(encryptedIDCard string, key []byte) (string, error) { + if encryptedIDCard == "" { + return "", errors.New("加密身份证号不能为空") + } + + bytes, err := AesEcbDecrypt(encryptedIDCard, key) + if err != nil { + return "", fmt.Errorf("解密身份证号失败: %v", err) + } + + return string(bytes), nil +} + +// IsEncrypted 检查字符串是否为Base64编码的加密数据 +func IsEncrypted(data string) bool { + // 检查是否是有效的Base64编码 + _, err := base64.StdEncoding.DecodeString(data) + return err == nil && len(data) >= 20 // 至少20个字符的Base64字符串 +} + +// GenerateAESKey 生成AES密钥 +// keySize: 可选16, 24, 32字节(对应AES-128, AES-192, AES-256) +func GenerateAESKey(keySize int) ([]byte, error) { + if keySize != 16 && keySize != 24 && keySize != 32 { + return nil, errors.New("密钥长度必须是16、24或32字节") + } + + key := make([]byte, keySize) + _, err := rand.Read(key) + if err != nil { + return nil, err + } + + return key, nil +} + +// DeriveKeyFromPassword 基于密码派生固定长度的AES密钥 +func DeriveKeyFromPassword(password string, keySize int) ([]byte, error) { + if keySize != 16 && keySize != 24 && keySize != 32 { + return nil, errors.New("密钥长度必须是16、24或32字节") + } + + // 使用PBKDF2或简单的方法从密码派生密钥 + // 这里使用简单的MD5方法,实际生产环境应使用更安全的PBKDF2 + hash := md5.New() + hash.Write([]byte(password)) + key := hash.Sum(nil) // 16字节 + + // 如果需要24或32字节,继续哈希 + if keySize > 16 { + hash.Reset() + hash.Write(key) + key = append(key, hash.Sum(nil)[:keySize-16]...) + } + + return key, nil +} diff --git a/pkg/lzkit/crypto/ecb_test.go b/pkg/lzkit/crypto/ecb_test.go new file mode 100644 index 0000000..b0cf696 --- /dev/null +++ b/pkg/lzkit/crypto/ecb_test.go @@ -0,0 +1,184 @@ +package crypto + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "testing" +) + +func TestAesEcbMobileEncryption(t *testing.T) { + // 测试手机号加密 + mobile := "13800138000" + key := []byte("1234567890abcdef") // 16字节AES-128密钥 + + // 测试加密 + encrypted, err := EncryptMobile(mobile, key) + if err != nil { + t.Fatalf("手机号加密失败: %v", err) + } + fmt.Println(encrypted) + // 测试解密 + decrypted, err := DecryptMobile(encrypted, key) + if err != nil { + t.Fatalf("手机号解密失败: %v", err) + } + fmt.Println(decrypted) + // 验证结果 + if decrypted != mobile { + t.Errorf("解密结果不匹配,期望: %s, 实际: %s", mobile, decrypted) + } + + // 测试相同输入产生相同输出(确定性) + encrypted2, _ := EncryptMobile(mobile, key) + if encrypted != encrypted2 { + t.Errorf("AES-ECB不是确定性的,两次加密结果不同: %s vs %s", encrypted, encrypted2) + } +} + +func TestAesEcbHexEncryption(t *testing.T) { + // 测试十六进制编码加密 + idCard := "440101199001011234" + key := []byte("1234567890abcdef") // 16字节AES-128密钥 + + // 测试HEX加密 + encryptedHex, err := EncryptIDCard(idCard, key) + if err != nil { + t.Fatalf("身份证加密失败: %v", err) + } + + // 测试HEX解密 + decrypted, err := DecryptIDCard(encryptedHex, key) + if err != nil { + t.Fatalf("身份证解密失败: %v", err) + } + + // 验证结果 + if decrypted != idCard { + t.Errorf("解密结果不匹配,期望: %s, 实际: %s", idCard, decrypted) + } +} + +func TestAesEcbKeyValidation(t *testing.T) { + // 测试不同长度的密钥 + validKeys := [][]byte{ + make([]byte, 16), // AES-128 + make([]byte, 24), // AES-192 + make([]byte, 32), // AES-256 + } + + invalidKeys := [][]byte{ + make([]byte, 15), + make([]byte, 20), + make([]byte, 33), + } + + text := []byte("test text") + + // 测试有效密钥 + for _, key := range validKeys { + _, err := AesEcbEncrypt(text, key) + if err != nil { + t.Errorf("有效密钥(%d字节)校验失败: %v", len(key), err) + } + } + + // 测试无效密钥 + for _, key := range invalidKeys { + _, err := AesEcbEncrypt(text, key) + if err == nil { + t.Errorf("无效密钥(%d字节)未被检测出", len(key)) + } + } +} + +func TestIsEncrypted(t *testing.T) { + // 有效的Base64编码字符串 + validBase64 := base64.StdEncoding.EncodeToString([]byte("这是一个足够长的字符串,以通过IsEncrypted检查")) + + // 无效的字符串 + invalidStrings := []string{ + "", + "abc", + "not-base64!@#", + hex.EncodeToString([]byte("hexstring")), + } + + // 测试有效的加密数据 + if !IsEncrypted(validBase64) { + t.Errorf("有效的Base64未被识别为加密数据: %s", validBase64) + } + + // 测试无效的数据 + for _, s := range invalidStrings { + if IsEncrypted(s) { + t.Errorf("无效字符串被错误识别为加密数据: %s", s) + } + } +} + +func TestDeriveKeyFromPassword(t *testing.T) { + password := "my-secure-password" + + // 测试不同长度的派生密钥 + keySizes := []int{16, 24, 32} + + for _, size := range keySizes { + key, err := DeriveKeyFromPassword(password, size) + if err != nil { + t.Errorf("从密码派生%d字节密钥失败: %v", size, err) + continue + } + + if len(key) != size { + t.Errorf("派生的密钥长度错误,期望: %d, 实际: %d", size, len(key)) + } + + // 测试相同密码总是产生相同密钥 + key2, _ := DeriveKeyFromPassword(password, size) + if string(key) != string(key2) { + t.Errorf("从相同密码派生的密钥不一致") + } + + // 使用派生的密钥加密测试 + _, err = AesEcbEncrypt([]byte("test"), key) + if err != nil { + t.Errorf("使用派生的密钥加密失败: %v", err) + } + } + + // 测试无效的密钥大小 + _, err := DeriveKeyFromPassword(password, 18) + if err == nil { + t.Error("无效的密钥大小未被检测出") + } +} + +func TestGenerateAESKey(t *testing.T) { + // 测试生成不同长度的密钥 + keySizes := []int{16, 24, 32} + + for _, size := range keySizes { + key, err := GenerateAESKey(size) + if err != nil { + t.Errorf("生成%d字节密钥失败: %v", size, err) + continue + } + + if len(key) != size { + t.Errorf("生成的密钥长度错误,期望: %d, 实际: %d", size, len(key)) + } + + // 使用生成的密钥加密测试 + _, err = AesEcbEncrypt([]byte("test"), key) + if err != nil { + t.Errorf("使用生成的密钥加密失败: %v", err) + } + } + + // 测试无效的密钥大小 + _, err := GenerateAESKey(18) + if err == nil { + t.Error("无效的密钥大小未被检测出") + } +} diff --git a/pkg/lzkit/md5/README.md b/pkg/lzkit/md5/README.md new file mode 100644 index 0000000..5398ec1 --- /dev/null +++ b/pkg/lzkit/md5/README.md @@ -0,0 +1,106 @@ +# MD5 工具包 + +这个包提供了全面的 MD5 哈希功能,包括字符串加密、文件加密、链式操作、加盐哈希等。 + +## 主要功能 + +- 字符串和字节切片的 MD5 哈希计算 +- 文件 MD5 哈希计算(支持大文件分块处理) +- 链式 API,支持构建复杂的哈希内容 +- 带盐值的 MD5 哈希,提高安全性 +- 哈希验证功能 +- 16 位和 8 位 MD5 哈希(短版本) +- HMAC-MD5 实现,增强安全性 + +## 使用示例 + +### 基本使用 + +```go +// 计算字符串的MD5哈希 +hash := md5.EncryptString("hello world") +fmt.Println(hash) // 5eb63bbbe01eeed093cb22bb8f5acdc3 + +// 计算字节切片的MD5哈希 +bytes := []byte("hello world") +hash = md5.EncryptBytes(bytes) + +// 计算文件的MD5哈希 +fileHash, err := md5.EncryptFile("path/to/file.txt") +if err != nil { + log.Fatal(err) +} +fmt.Println(fileHash) +``` + +### 链式 API + +```go +// 创建一个新的MD5实例并添加内容 +hash := md5.New(). + Add("hello"). + Add(" "). + Add("world"). + Sum() +fmt.Println(hash) // 5eb63bbbe01eeed093cb22bb8f5acdc3 + +// 或者从字符串初始化 +hash = md5.FromString("hello"). + Add(" world"). + Sum() + +// 从字节切片初始化 +hash = md5.FromBytes([]byte("hello")). + AddBytes([]byte(" world")). + Sum() +``` + +### 安全性增强 + +```go +// 使用盐值加密(提高安全性) +hashedPassword := md5.EncryptStringWithSalt("password123", "user@example.com") + +// 使用前缀加密 +hashedValue := md5.EncryptStringWithPrefix("secret-data", "prefix-") + +// 验证带盐值的哈希 +isValid := md5.VerifyMD5WithSalt("password123", "user@example.com", hashedPassword) + +// 使用HMAC-MD5提高安全性 +hmacHash := md5.MD5HMAC("message", "secret-key") +``` + +### 短哈希值 + +```go +// 获取16位MD5(32位MD5的中间部分) +hash16 := md5.Get16("hello world") +fmt.Println(hash16) // 中间16个字符 + +// 获取8位MD5 +hash8 := md5.Get8("hello world") +fmt.Println(hash8) // 中间8个字符 +``` + +### 文件验证 + +```go +// 验证文件MD5是否匹配 +match, err := md5.VerifyFileMD5("path/to/file.txt", "expected-hash") +if err != nil { + log.Fatal(err) +} +if match { + fmt.Println("文件MD5校验通过") +} else { + fmt.Println("文件MD5校验失败") +} +``` + +## 注意事项 + +1. MD5 主要用于校验,不适合用于安全存储密码等敏感信息 +2. 如果用于密码存储,请务必使用加盐处理并考虑使用更安全的算法 +3. 处理大文件时请使用`EncryptFileChunk`以优化性能 +4. 返回的 MD5 哈希值都是 32 位的小写十六进制字符串(除非使用 Get16/Get8 函数) diff --git a/pkg/lzkit/md5/example_test.go b/pkg/lzkit/md5/example_test.go new file mode 100644 index 0000000..26a5e96 --- /dev/null +++ b/pkg/lzkit/md5/example_test.go @@ -0,0 +1,79 @@ +package md5_test + +import ( + "fmt" + "log" + "tydata-server/pkg/lzkit/md5" +) + +func Example() { + // 简单的字符串MD5 + hashValue := md5.EncryptString("hello world") + fmt.Println("MD5(hello world):", hashValue) + + // 使用链式API + chainHash := md5.New(). + Add("hello"). + Add(" "). + Add("world"). + Sum() + fmt.Println("链式MD5:", chainHash) + + // 使用盐值 + saltedHash := md5.EncryptStringWithSalt("password123", "user@example.com") + fmt.Println("加盐MD5:", saltedHash) + + // 验证哈希 + isValid := md5.VerifyMD5("hello world", hashValue) + fmt.Println("验证结果:", isValid) + + // 生成短版本的MD5 + fmt.Println("16位MD5:", md5.Get16("hello world")) + fmt.Println("8位MD5:", md5.Get8("hello world")) + + // 文件MD5计算 + filePath := "example.txt" // 这只是示例,实际上这个文件可能不存在 + fileHash, err := md5.EncryptFile(filePath) + if err != nil { + // 在实际代码中执行正确的错误处理 + log.Printf("计算文件MD5出错: %v", err) + } else { + fmt.Println("文件MD5:", fileHash) + } + + // HMAC-MD5 + hmacHash := md5.MD5HMAC("重要消息", "secret-key") + fmt.Println("HMAC-MD5:", hmacHash) +} + +func ExampleEncryptString() { + hash := md5.EncryptString("HelloWorld") + fmt.Println(hash) + // Output: 68e109f0f40ca72a15e05cc22786f8e6 +} + +func ExampleMD5_Sum() { + hash := md5.New(). + Add("Hello"). + Add("World"). + Sum() + fmt.Println(hash) + // Output: 68e109f0f40ca72a15e05cc22786f8e6 +} + +func ExampleEncryptStringWithSalt() { + // 为用户密码加盐,通常使用用户唯一标识(如邮箱)作为盐值 + hash := md5.EncryptStringWithSalt("password123", "user@example.com") + fmt.Println("盐值哈希长度:", len(hash)) + fmt.Println("是否为有效哈希:", md5.VerifyMD5WithSalt("password123", "user@example.com", hash)) + // Output: + // 盐值哈希长度: 32 + // 是否为有效哈希: true +} + +func ExampleGet16() { + // 获取16位MD5,适合不需要完全防碰撞场景 + hash := md5.Get16("HelloWorld") + fmt.Println(hash) + // Output: f0f40ca72a15e05c +} diff --git a/pkg/lzkit/md5/md5.go b/pkg/lzkit/md5/md5.go new file mode 100644 index 0000000..bfbfa89 --- /dev/null +++ b/pkg/lzkit/md5/md5.go @@ -0,0 +1,206 @@ +package md5 + +import ( + "bufio" + "crypto/md5" + "encoding/hex" + "io" + "os" + "strings" +) + +// MD5结构体,可用于链式调用 +type MD5 struct { + data []byte +} + +// New 创建一个新的MD5实例 +func New() *MD5 { + return &MD5{ + data: []byte{}, + } +} + +// FromString 从字符串创建MD5 +func FromString(s string) *MD5 { + return &MD5{ + data: []byte(s), + } +} + +// FromBytes 从字节切片创建MD5 +func FromBytes(b []byte) *MD5 { + return &MD5{ + data: b, + } +} + +// Add 向MD5中添加字符串 +func (m *MD5) Add(s string) *MD5 { + m.data = append(m.data, []byte(s)...) + return m +} + +// AddBytes 向MD5中添加字节切片 +func (m *MD5) AddBytes(b []byte) *MD5 { + m.data = append(m.data, b...) + return m +} + +// Sum 计算并返回MD5哈希值(16进制字符串) +func (m *MD5) Sum() string { + hash := md5.New() + hash.Write(m.data) + return hex.EncodeToString(hash.Sum(nil)) +} + +// SumBytes 计算并返回MD5哈希值(字节切片) +func (m *MD5) SumBytes() []byte { + hash := md5.New() + hash.Write(m.data) + return hash.Sum(nil) +} + +// 直接调用的工具函数 + +// EncryptString 加密字符串 +func EncryptString(s string) string { + hash := md5.New() + hash.Write([]byte(s)) + return hex.EncodeToString(hash.Sum(nil)) +} + +// EncryptBytes 加密字节切片 +func EncryptBytes(b []byte) string { + hash := md5.New() + hash.Write(b) + return hex.EncodeToString(hash.Sum(nil)) +} + +// EncryptFile 加密文件内容 +func EncryptFile(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + hash := md5.New() + if _, err := io.Copy(hash, file); err != nil { + return "", err + } + return hex.EncodeToString(hash.Sum(nil)), nil +} + +// EncryptFileChunk 对大文件分块计算MD5,提高效率 +func EncryptFileChunk(filePath string, chunkSize int) (string, error) { + if chunkSize <= 0 { + chunkSize = 1024 * 1024 // 默认1MB + } + + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + hash := md5.New() + buf := make([]byte, chunkSize) + reader := bufio.NewReader(file) + + for { + n, err := reader.Read(buf) + if err != nil && err != io.EOF { + return "", err + } + if n == 0 { + break + } + hash.Write(buf[:n]) + } + + return hex.EncodeToString(hash.Sum(nil)), nil +} + +// EncryptStringWithSalt 使用盐值加密字符串 +func EncryptStringWithSalt(s, salt string) string { + return EncryptString(s + salt) +} + +// EncryptStringWithPrefix 使用前缀加密字符串 +func EncryptStringWithPrefix(s, prefix string) string { + return EncryptString(prefix + s) +} + +// VerifyMD5 验证字符串的MD5哈希是否匹配 +func VerifyMD5(s, hash string) bool { + return EncryptString(s) == strings.ToLower(hash) +} + +// VerifyMD5WithSalt 验证带盐值的字符串MD5哈希是否匹配 +func VerifyMD5WithSalt(s, salt, hash string) bool { + return EncryptStringWithSalt(s, salt) == strings.ToLower(hash) +} + +// VerifyFileMD5 验证文件的MD5哈希是否匹配 +func VerifyFileMD5(filePath, hash string) (bool, error) { + fileHash, err := EncryptFile(filePath) + if err != nil { + return false, err + } + return fileHash == strings.ToLower(hash), nil +} + +// MD5格式化为指定位数 + +// Get16 获取16位MD5值(取32位结果的中间16位) +func Get16(s string) string { + result := EncryptString(s) + return result[8:24] +} + +// Get8 获取8位MD5值 +func Get8(s string) string { + result := EncryptString(s) + return result[12:20] +} + +// MD5主要用于校验而非安全存储,对于需要高安全性的场景,应考虑: +// 1. bcrypt, scrypt或Argon2等专门为密码设计的算法 +// 2. HMAC-MD5等方式以防御彩虹表攻击 +// 3. 加盐并使用多次哈希迭代提高安全性 + +// MD5HMAC 使用HMAC-MD5算法 +func MD5HMAC(message, key string) string { + hash := md5.New() + + // 如果key长度超出block size,先进行哈希 + if len(key) > 64 { + hash.Write([]byte(key)) + key = hex.EncodeToString(hash.Sum(nil)) + hash.Reset() + } + + // 内部填充 + k_ipad := make([]byte, 64) + k_opad := make([]byte, 64) + copy(k_ipad, []byte(key)) + copy(k_opad, []byte(key)) + + for i := 0; i < 64; i++ { + k_ipad[i] ^= 0x36 + k_opad[i] ^= 0x5c + } + + // 内部哈希 + hash.Write(k_ipad) + hash.Write([]byte(message)) + innerHash := hash.Sum(nil) + hash.Reset() + + // 外部哈希 + hash.Write(k_opad) + hash.Write(innerHash) + + return hex.EncodeToString(hash.Sum(nil)) +} diff --git a/pkg/lzkit/md5/md5_test.go b/pkg/lzkit/md5/md5_test.go new file mode 100644 index 0000000..198c24a --- /dev/null +++ b/pkg/lzkit/md5/md5_test.go @@ -0,0 +1,192 @@ +package md5 + +import ( + "fmt" + "os" + "testing" +) + +func TestEncryptString(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"", "d41d8cd98f00b204e9800998ecf8427e"}, + {"hello", "5d41402abc4b2a76b9719d911017c592"}, + {"123456", "e10adc3949ba59abbe56e057f20f883e"}, + {"Hello World!", "ed076287532e86365e841e92bfc50d8c"}, + } + + for _, test := range tests { + result := EncryptString(test.input) + fmt.Println(result) + if result != test.expected { + t.Errorf("EncryptString(%s) = %s; want %s", test.input, result, test.expected) + } + } +} + +func TestEncryptBytes(t *testing.T) { + tests := []struct { + input []byte + expected string + }{ + {[]byte(""), "d41d8cd98f00b204e9800998ecf8427e"}, + {[]byte("hello"), "5d41402abc4b2a76b9719d911017c592"}, + {[]byte{0, 1, 2, 3, 4}, "5267768822ee624d48fce15ec5ca79b6"}, + } + + for _, test := range tests { + result := EncryptBytes(test.input) + if result != test.expected { + t.Errorf("EncryptBytes(%v) = %s; want %s", test.input, result, test.expected) + } + } +} + +func TestMD5Chain(t *testing.T) { + // 测试链式调用 + result := New(). + Add("hello"). + Add(" "). + Add("world"). + Sum() + + expected := "fc5e038d38a57032085441e7fe7010b0" // MD5("hello world") + if result != expected { + t.Errorf("Chain MD5 = %s; want %s", result, expected) + } + + // 测试从字符串初始化 + result = FromString("hello").Add(" world").Sum() + if result != expected { + t.Errorf("FromString MD5 = %s; want %s", result, expected) + } + + // 测试从字节切片初始化 + result = FromBytes([]byte("hello")).AddBytes([]byte(" world")).Sum() + if result != expected { + t.Errorf("FromBytes MD5 = %s; want %s", result, expected) + } +} + +func TestVerifyMD5(t *testing.T) { + if !VerifyMD5("hello", "5d41402abc4b2a76b9719d911017c592") { + t.Error("VerifyMD5 failed for correct match") + } + + if VerifyMD5("hello", "wrong-hash") { + t.Error("VerifyMD5 succeeded for incorrect match") + } + + // 测试大小写不敏感 + if !VerifyMD5("hello", "5D41402ABC4B2A76B9719D911017C592") { + t.Error("VerifyMD5 failed for uppercase hash") + } +} + +func TestSaltAndPrefix(t *testing.T) { + // 测试加盐 + saltResult := EncryptStringWithSalt("password", "salt123") + expectedSalt := EncryptString("passwordsalt123") + if saltResult != expectedSalt { + t.Errorf("EncryptStringWithSalt = %s; want %s", saltResult, expectedSalt) + } + + // 测试前缀 + prefixResult := EncryptStringWithPrefix("password", "prefix123") + expectedPrefix := EncryptString("prefix123password") + if prefixResult != expectedPrefix { + t.Errorf("EncryptStringWithPrefix = %s; want %s", prefixResult, expectedPrefix) + } + + // 验证带盐值的MD5 + if !VerifyMD5WithSalt("password", "salt123", saltResult) { + t.Error("VerifyMD5WithSalt failed for correct match") + } +} + +func TestGet16And8(t *testing.T) { + full := EncryptString("test-string") + + // 测试16位MD5 + result16 := Get16("test-string") + expected16 := full[8:24] + if result16 != expected16 { + t.Errorf("Get16 = %s; want %s", result16, expected16) + } + + // 测试8位MD5 + result8 := Get8("test-string") + expected8 := full[12:20] + if result8 != expected8 { + t.Errorf("Get8 = %s; want %s", result8, expected8) + } +} + +func TestMD5HMAC(t *testing.T) { + // 已知的HMAC-MD5结果 + tests := []struct { + message string + key string + expected string + }{ + {"message", "key", "4e4748e62b463521f6775fbf921234b5"}, + {"test", "secret", "8b11d99898918564dda1a9fe205b5310"}, + } + + for _, test := range tests { + result := MD5HMAC(test.message, test.key) + if result != test.expected { + t.Errorf("MD5HMAC(%s, %s) = %s; want %s", + test.message, test.key, result, test.expected) + } + } +} + +func TestEncryptFile(t *testing.T) { + // 创建临时测试文件 + content := []byte("test file content for MD5") + tmpFile, err := os.CreateTemp("", "md5test-*.txt") + if err != nil { + t.Fatalf("无法创建临时文件: %v", err) + } + defer os.Remove(tmpFile.Name()) + + if _, err := tmpFile.Write(content); err != nil { + t.Fatalf("无法写入临时文件: %v", err) + } + if err := tmpFile.Close(); err != nil { + t.Fatalf("无法关闭临时文件: %v", err) + } + + // 计算文件MD5 + fileHash, err := EncryptFile(tmpFile.Name()) + if err != nil { + t.Fatalf("计算文件MD5失败: %v", err) + } + + // 验证文件MD5 + expectedHash := EncryptBytes(content) + if fileHash != expectedHash { + t.Errorf("文件MD5 = %s; 应为 %s", fileHash, expectedHash) + } + + // 测试VerifyFileMD5 + match, err := VerifyFileMD5(tmpFile.Name(), expectedHash) + if err != nil { + t.Fatalf("验证文件MD5失败: %v", err) + } + if !match { + t.Error("VerifyFileMD5返回false,应返回true") + } + + // 测试不匹配的情况 + match, err = VerifyFileMD5(tmpFile.Name(), "wronghash") + if err != nil { + t.Fatalf("验证文件MD5失败: %v", err) + } + if match { + t.Error("VerifyFileMD5对错误的哈希返回true,应返回false") + } +}