diff --git a/app/user/cmd/api/internal/logic/auth/sendsmslogic.go b/app/user/cmd/api/internal/logic/auth/sendsmslogic.go index ce389fe..de63946 100644 --- a/app/user/cmd/api/internal/logic/auth/sendsmslogic.go +++ b/app/user/cmd/api/internal/logic/auth/sendsmslogic.go @@ -6,6 +6,7 @@ import ( "math/rand" "time" "tyc-server/common/xerr" + "tyc-server/pkg/lzkit/crypto" "github.com/pkg/errors" @@ -34,16 +35,21 @@ func NewSendSmsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *SendSmsLo } func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error { + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %+v", err) + } // 检查手机号是否在一分钟内已发送过验证码 - limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, req.Mobile) + limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile) exists, err := l.svcCtx.Redis.Exists(limitCodeKey) if err != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 读取redis缓存失败: %s", req.Mobile) + return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 读取redis缓存失败: %s", encryptedMobile) } if exists { // 如果 Redis 中已经存在标记,说明在 1 分钟内请求过,返回错误 - return errors.Wrapf(xerr.NewErrMsg("一分钟内不能重复发送验证码"), "短信发送, 手机号1分钟内重复请求发送二维码: %s", req.Mobile) + return errors.Wrapf(xerr.NewErrMsg("一分钟内不能重复发送验证码"), "短信发送, 手机号1分钟内重复请求发送验证码: %s", encryptedMobile) } code := fmt.Sprintf("%06d", rand.New(rand.NewSource(time.Now().UnixNano())).Intn(1000000)) @@ -56,7 +62,7 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error { if *smsResp.Body.Code != "OK" { return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message) } - codeKey := fmt.Sprintf("%s:%s", req.ActionType, req.Mobile) + codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile) // 将验证码保存到 Redis,设置过期时间 err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟 if err != nil { diff --git a/app/user/cmd/api/internal/logic/pay/alipaycallbacklogic.go b/app/user/cmd/api/internal/logic/pay/alipaycallbacklogic.go index 19dbc9a..e137fbf 100644 --- a/app/user/cmd/api/internal/logic/pay/alipaycallbacklogic.go +++ b/app/user/cmd/api/internal/logic/pay/alipaycallbacklogic.go @@ -60,7 +60,7 @@ func (l *AlipayCallbackLogic) AlipayCallback(w http.ResponseWriter, r *http.Requ default: return nil } - order.PlatformOrderId = lzUtils.StringToNullString(notification.OutTradeNo) + order.PlatformOrderId = lzUtils.StringToNullString(notification.TradeNo) if updateErr := l.svcCtx.OrderModel.UpdateWithVersion(l.ctx, nil, order); updateErr != nil { logx.Errorf("支付宝支付回调,修改订单信息失败: %+v", updateErr) return nil diff --git a/app/user/cmd/api/internal/logic/pay/paymentlogic.go b/app/user/cmd/api/internal/logic/pay/paymentlogic.go index 4325506..e12a0e7 100644 --- a/app/user/cmd/api/internal/logic/pay/paymentlogic.go +++ b/app/user/cmd/api/internal/logic/pay/paymentlogic.go @@ -2,7 +2,6 @@ package pay import ( "context" - "encoding/hex" "encoding/json" "fmt" "tyc-server/app/user/cmd/api/internal/svc" @@ -10,7 +9,6 @@ import ( "tyc-server/app/user/model" "tyc-server/common/ctxdata" "tyc-server/common/xerr" - "tyc-server/pkg/lzkit/crypto" "github.com/pkg/errors" "github.com/zeromicro/go-zero/core/logx" @@ -41,7 +39,8 @@ func (l *PaymentLogic) Payment(req *types.PaymentReq) (resp *types.PaymentResp, if !ok { return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "生成订单, 获取平台失败, %+v", getUidErr) } - redisKey := fmt.Sprintf("%d:%s", userID, req.Id) + outTradeNo := req.Id + redisKey := fmt.Sprintf("%d:%s", userID, outTradeNo) cache, cacheErr := l.svcCtx.Redis.GetCtx(l.ctx, redisKey) if cacheErr != nil { return nil, cacheErr @@ -56,21 +55,7 @@ func (l *PaymentLogic) Payment(req *types.PaymentReq) (resp *types.PaymentResp, if err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "生成订单, 查找产品错误: %+v", err) } - secretKey := l.svcCtx.Config.Encrypt.SecretKey - key, decodeErr := hex.DecodeString(secretKey) - if decodeErr != nil { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "生成订单, 获取AES密钥失败: %+v", decodeErr) - } - params, marshalErr := json.Marshal(data.Params) - if marshalErr != nil { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "生成订单, 序列化参数失败: %+v", marshalErr) - } - encryptParams, aesEncryptErr := crypto.AesEncrypt(params, key) - if aesEncryptErr != nil { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "生成订单, 加密参数失败: %+v", aesEncryptErr) - } var prepayData interface{} - var outTradeNo string var amount float64 user, err := l.svcCtx.UserModel.FindOne(l.ctx, userID) if err != nil { @@ -85,13 +70,10 @@ func (l *PaymentLogic) Payment(req *types.PaymentReq) (resp *types.PaymentResp, var createOrderErr error if req.PayMethod == "wechat" { - outTradeNo = l.svcCtx.WechatPayService.GenerateOutTradeNo() prepayData, createOrderErr = l.svcCtx.WechatPayService.CreateWechatOrder(l.ctx, amount, product.ProductName, outTradeNo) } else if req.PayMethod == "alipay" { - outTradeNo = l.svcCtx.AlipayService.GenerateOutTradeNo() prepayData, createOrderErr = l.svcCtx.AlipayService.CreateAlipayOrder(l.ctx, amount, product.ProductName, outTradeNo, brand) } else if req.PayMethod == "appleiap" { - outTradeNo = l.svcCtx.ApplePayService.GenerateOutTradeNo() prepayData = l.svcCtx.ApplePayService.GetIappayAppID(product.ProductEn) } if createOrderErr != nil { @@ -99,17 +81,12 @@ func (l *PaymentLogic) Payment(req *types.PaymentReq) (resp *types.PaymentResp, } var orderID int64 transErr := l.svcCtx.OrderModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { - - paymentScene := "app" - if brand == "tyc" { - paymentScene = "tyc" - } order := model.Order{ OrderNo: outTradeNo, UserId: userID, ProductId: product.Id, PaymentPlatform: req.PayMethod, - PaymentScene: paymentScene, + PaymentScene: "tyc", Amount: amount, Status: "pending", } @@ -122,17 +99,6 @@ func (l *PaymentLogic) Payment(req *types.PaymentReq) (resp *types.PaymentResp, return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "生成订单, 获取保存订单ID失败: %+v", lastInsertIdErr) } orderID = insertedOrderID - query := model.Query{ - OrderId: orderID, - UserId: userID, - ProductId: product.Id, - QueryParams: encryptParams, - QueryState: "pending", - } - _, insertQueryErr := l.svcCtx.QueryModel.Insert(l.ctx, session, &query) - if insertQueryErr != nil { - return insertQueryErr - } return nil }) if transErr != nil { diff --git a/app/user/cmd/api/internal/logic/query/querydetailbyorderidlogic.go b/app/user/cmd/api/internal/logic/query/querydetailbyorderidlogic.go index c023bb0..176a172 100644 --- a/app/user/cmd/api/internal/logic/query/querydetailbyorderidlogic.go +++ b/app/user/cmd/api/internal/logic/query/querydetailbyorderidlogic.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "encoding/json" "time" + "tyc-server/common/ctxdata" "tyc-server/common/xerr" "tyc-server/pkg/lzkit/crypto" "tyc-server/pkg/lzkit/delay" @@ -16,6 +17,7 @@ import ( "tyc-server/app/user/cmd/api/internal/svc" "tyc-server/app/user/cmd/api/internal/types" + "tyc-server/app/user/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -35,12 +37,24 @@ func NewQueryDetailByOrderIdLogic(ctx context.Context, svcCtx *svc.ServiceContex } func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) { + // 获取当前用户ID + userId, err := ctxdata.GetUidFromCtx(l.ctx) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err) + } + // 获取订单信息 order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId) if err != nil { + if errors.Is(err, model.ErrNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "报告查询, 订单不存在: %v", err) + } return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %+v", err) } - + // 安全验证:确保订单属于当前用户 + if order.UserId != userId { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告") + } // 创建渐进式延迟策略实例 progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5) if err != nil { diff --git a/app/user/cmd/api/internal/logic/query/querydetailbyordernologic.go b/app/user/cmd/api/internal/logic/query/querydetailbyordernologic.go index 1121aa3..48785a6 100644 --- a/app/user/cmd/api/internal/logic/query/querydetailbyordernologic.go +++ b/app/user/cmd/api/internal/logic/query/querydetailbyordernologic.go @@ -4,6 +4,7 @@ import ( "context" "encoding/hex" "time" + "tyc-server/common/ctxdata" "tyc-server/common/xerr" "tyc-server/pkg/lzkit/delay" @@ -12,6 +13,7 @@ import ( "tyc-server/app/user/cmd/api/internal/svc" "tyc-server/app/user/cmd/api/internal/types" + "tyc-server/app/user/model" "github.com/zeromicro/go-zero/core/logx" ) @@ -31,12 +33,25 @@ func NewQueryDetailByOrderNoLogic(ctx context.Context, svcCtx *svc.ServiceContex } func (l *QueryDetailByOrderNoLogic) QueryDetailByOrderNo(req *types.QueryDetailByOrderNoReq) (resp *types.QueryDetailByOrderNoResp, err error) { + // 获取当前用户ID + userId, err := ctxdata.GetUidFromCtx(l.ctx) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err) + } + // 获取订单信息 order, err := l.svcCtx.OrderModel.FindOneByOrderNo(l.ctx, req.OrderNo) if err != nil { + if errors.Is(err, model.ErrNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "报告查询, 订单不存在: %v", err) + } return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %+v", err) } + // 安全验证:确保订单属于当前用户 + if order.UserId != userId { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告") + } // 创建渐进式延迟策略实例 progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5) if err != nil { diff --git a/app/user/cmd/api/internal/logic/query/queryprovisionalorderlogic.go b/app/user/cmd/api/internal/logic/query/queryprovisionalorderlogic.go index e25cb03..2bd246d 100644 --- a/app/user/cmd/api/internal/logic/query/queryprovisionalorderlogic.go +++ b/app/user/cmd/api/internal/logic/query/queryprovisionalorderlogic.go @@ -2,12 +2,14 @@ package query import ( "context" + "encoding/hex" "encoding/json" "fmt" "tyc-server/app/user/cmd/api/internal/svc" "tyc-server/app/user/cmd/api/internal/types" "tyc-server/common/ctxdata" "tyc-server/common/xerr" + "tyc-server/pkg/lzkit/crypto" "github.com/jinzhu/copier" "github.com/pkg/errors" @@ -54,8 +56,22 @@ func (l *QueryProvisionalOrderLogic) QueryProvisionalOrder(req *types.QueryProvi if err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取临时订单, 用户信息结构体复制失败: %+v", err) } + secretKey := l.svcCtx.Config.Encrypt.SecretKey + key, decodeErr := hex.DecodeString(secretKey) + if decodeErr != nil { + return nil, fmt.Errorf("获取AES密钥失败: %+v", decodeErr) + } + decryptData, aesdecryptErr := crypto.AesDecrypt(data.Params, key) + if aesdecryptErr != nil { + return nil, fmt.Errorf("解密参数失败: %+v", aesdecryptErr) + } + queryParams := make(map[string]interface{}) + err = json.Unmarshal(decryptData, &queryParams) + if err != nil { + return nil, fmt.Errorf("解析解密数据失败: %+v", err) + } return &types.QueryProvisionalOrderResp{ - QueryParams: data.Params, + QueryParams: queryParams, Product: product, }, nil } diff --git a/app/user/cmd/api/internal/logic/query/queryservicelogic.go b/app/user/cmd/api/internal/logic/query/queryservicelogic.go index 5847316..4609ca4 100644 --- a/app/user/cmd/api/internal/logic/query/queryservicelogic.go +++ b/app/user/cmd/api/internal/logic/query/queryservicelogic.go @@ -53,16 +53,21 @@ func (l *QueryServiceLogic) DecryptData(data string) ([]byte, error) { // 校验验证码 func (l *QueryServiceLogic) VerifyCode(mobile string, code string) error { - codeRedisKey := fmt.Sprintf("%s:%s", "query", mobile) + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(mobile, secretKey) + if err != nil { + return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "加密手机号失败: %+v", err) + } + codeRedisKey := fmt.Sprintf("%s:%s", "query", encryptedMobile) cacheCode, err := l.svcCtx.Redis.Get(codeRedisKey) if err != nil { if errors.Is(err, redis.Nil) { - return errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "验证码过期: %s", mobile) + return errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "验证码过期: %s", encryptedMobile) } - return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "读取验证码redis缓存失败, mobile: %s, err: %+v", mobile, err) + return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err) } if cacheCode != code { - return errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "验证码不正确: %s", mobile) + return errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "验证码不正确: %s", encryptedMobile) } return nil } @@ -101,15 +106,29 @@ func (l *QueryServiceLogic) Verify(Name string, IDCard string, Mobile string) er // 缓存 func (l *QueryServiceLogic) CacheData(params map[string]interface{}, Product string, userID int64) (string, error) { + + secretKey := l.svcCtx.Config.Encrypt.SecretKey + key, decodeErr := hex.DecodeString(secretKey) + if decodeErr != nil { + return "", errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "查询服务, 获取AES密钥失败: %+v", decodeErr) + } + paramsMarshal, marshalErr := json.Marshal(params) + if marshalErr != nil { + return "", errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "查询服务, 序列化参数失败: %+v", marshalErr) + } + encryptParams, aesEncryptErr := crypto.AesEncrypt(paramsMarshal, key) + if aesEncryptErr != nil { + return "", errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "查询服务, 加密参数失败: %+v", aesEncryptErr) + } queryCache := types.QueryCacheLoad{ - Params: params, + Params: encryptParams, Product: Product, } jsonData, marshalErr := json.Marshal(queryCache) if marshalErr != nil { return "", errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "查询服务, 序列化参数失败: %+v", marshalErr) } - outTradeNo := l.svcCtx.WechatPayService.GenerateOutTradeNo() + outTradeNo := l.svcCtx.AlipayService.GenerateOutTradeNo() redisKey := fmt.Sprintf("%d:%s", userID, outTradeNo) cacheErr := l.svcCtx.Redis.SetexCtx(l.ctx, redisKey, string(jsonData), int(2*time.Hour)) if cacheErr != nil { diff --git a/app/user/cmd/api/internal/logic/user/detaillogic.go b/app/user/cmd/api/internal/logic/user/detaillogic.go index 6202997..8f25116 100644 --- a/app/user/cmd/api/internal/logic/user/detaillogic.go +++ b/app/user/cmd/api/internal/logic/user/detaillogic.go @@ -4,8 +4,10 @@ import ( "context" "tyc-server/app/user/cmd/api/internal/svc" "tyc-server/app/user/cmd/api/internal/types" + "tyc-server/app/user/model" "tyc-server/common/ctxdata" "tyc-server/common/xerr" + "tyc-server/pkg/lzkit/crypto" "github.com/jinzhu/copier" "github.com/pkg/errors" @@ -34,6 +36,9 @@ func (l *DetailLogic) Detail() (resp *types.UserInfoResp, err error) { } user, err := l.svcCtx.UserModel.FindOne(l.ctx, userID) if err != nil { + if errors.Is(err, model.ErrNotFound) { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.USER_NOT_FOUND), "用户信息, 用户不存在, %v", err) + } return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "用户信息, 数据库查询用户信息失败, %+v", err) } var userInfo types.User @@ -41,6 +46,10 @@ func (l *DetailLogic) Detail() (resp *types.UserInfoResp, err error) { if err != nil { return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "用户信息, 用户信息结构体复制失败, %+v", err) } + userInfo.Mobile, err = crypto.DecryptMobile(userInfo.Mobile, l.svcCtx.Config.Encrypt.SecretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "用户信息, 解密手机号失败, %v", err) + } return &types.UserInfoResp{ UserInfo: userInfo, }, nil diff --git a/app/user/cmd/api/internal/logic/user/mobilecodeloginlogic.go b/app/user/cmd/api/internal/logic/user/mobilecodeloginlogic.go index cb7e844..fdfbed6 100644 --- a/app/user/cmd/api/internal/logic/user/mobilecodeloginlogic.go +++ b/app/user/cmd/api/internal/logic/user/mobilecodeloginlogic.go @@ -2,6 +2,7 @@ package user import ( "context" + "database/sql" "fmt" "time" "tyc-server/app/user/cmd/api/internal/svc" @@ -9,6 +10,7 @@ import ( "tyc-server/app/user/model" jwtx "tyc-server/common/jwt" "tyc-server/common/xerr" + "tyc-server/pkg/lzkit/crypto" "github.com/pkg/errors" "github.com/zeromicro/go-zero/core/stores/redis" @@ -32,34 +34,41 @@ func NewMobileCodeLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *M } func (l *MobileCodeLoginLogic) MobileCodeLogin(req *types.MobileCodeLoginReq) (resp *types.MobileCodeLoginResp, err error) { - if !l.MobileCodeLoginInside(req) { - // 检查手机号是否在一分钟内已发送过验证码 - redisKey := fmt.Sprintf("%s:%s", "login", req.Mobile) - cacheCode, err := l.svcCtx.Redis.Get(redisKey) - if err != nil { - if errors.Is(err, redis.Nil) { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期: %s", req.Mobile) - } - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) - } - if cacheCode != req.Code { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确: %s", req.Mobile) - } + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机登录, 加密手机号失败: %+v", err) } - user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + // 检查手机号是否在一分钟内已发送过验证码 + redisKey := fmt.Sprintf("%s:%s", "login", encryptedMobile) + cacheCode, err := l.svcCtx.Redis.Get(redisKey) + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机登录, 验证码过期: %s", encryptedMobile) + } + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err) + } + if cacheCode != req.Code { + return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机登录, 验证码不正确: %s", encryptedMobile) + } + + user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if findUserErr != nil && findUserErr != model.ErrNotFound { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile: %s, err: %+v", encryptedMobile, err) } if user == nil { - user = &model.User{Mobile: req.Mobile} - if len(user.Nickname) == 0 { - user.Nickname = req.Mobile + user = &model.User{Mobile: encryptedMobile} + if user.Nickname.Valid && user.Nickname.String != "" { + user.Nickname = sql.NullString{ + String: encryptedMobile, + Valid: true, + } } if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user) if userInsertErr != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) + return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile: %s, err: %+v", encryptedMobile, err) } lastId, lastInsertIdErr := insertResult.LastInsertId() if lastInsertIdErr != nil { @@ -69,7 +78,7 @@ func (l *MobileCodeLoginLogic) MobileCodeLogin(req *types.MobileCodeLoginReq) (r userAuth := new(model.UserAuth) userAuth.UserId = lastId - userAuth.AuthKey = req.Mobile + userAuth.AuthKey = encryptedMobile userAuth.AuthType = model.UserAuthTypeAppMobile if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) @@ -92,9 +101,3 @@ func (l *MobileCodeLoginLogic) MobileCodeLogin(req *types.MobileCodeLoginReq) (r RefreshAfter: now + l.svcCtx.Config.JwtAuth.RefreshAfter, }, nil } -func (l *MobileCodeLoginLogic) MobileCodeLoginInside(req *types.MobileCodeLoginReq) (pass bool) { - if req.Mobile == "17776203797" && req.Code == "688629" { - return true - } - return false -} diff --git a/app/user/cmd/api/internal/logic/user/mobileloginlogic.go b/app/user/cmd/api/internal/logic/user/mobileloginlogic.go index cab1ab3..d5a7dd6 100644 --- a/app/user/cmd/api/internal/logic/user/mobileloginlogic.go +++ b/app/user/cmd/api/internal/logic/user/mobileloginlogic.go @@ -7,6 +7,7 @@ import ( jwtx "tyc-server/common/jwt" "tyc-server/common/tool" "tyc-server/common/xerr" + "tyc-server/pkg/lzkit/crypto" "tyc-server/pkg/lzkit/lzUtils" "github.com/pkg/errors" @@ -32,15 +33,20 @@ func NewMobileLoginLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Mobil } func (l *MobileLoginLogic) MobileLogin(req *types.MobileLoginReq) (resp *types.MobileCodeLoginResp, err error) { - user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机登录, 加密手机号失败: %+v", err) + } + user, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if findUserErr != nil && findUserErr != model.ErrNotFound { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile%s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机登录, 读取数据库获取用户失败, mobile%s, err: %+v", encryptedMobile, err) } if user == nil { - return nil, errors.Wrapf(xerr.NewErrMsg("手机号码未注册"), "手机登录, 手机号未注册:%s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("手机号码未注册"), "手机登录, 手机号未注册:%s", encryptedMobile) } if !(tool.Md5ByString(req.Password) == lzUtils.NullStringToString(user.Password)) { - return nil, errors.Wrapf(xerr.NewErrMsg("密码不正确"), "手机登录, 密码匹配不正确%s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("密码不正确"), "手机登录, 密码匹配不正确%s", encryptedMobile) } token, generaErr := jwtx.GenerateJwtToken(user.Id, l.svcCtx.Config.JwtAuth.AccessSecret, l.svcCtx.Config.JwtAuth.AccessExpire) diff --git a/app/user/cmd/api/internal/logic/user/registerlogic.go b/app/user/cmd/api/internal/logic/user/registerlogic.go index 060bcc9..9459543 100644 --- a/app/user/cmd/api/internal/logic/user/registerlogic.go +++ b/app/user/cmd/api/internal/logic/user/registerlogic.go @@ -2,6 +2,7 @@ package user import ( "context" + "database/sql" "fmt" "time" "tyc-server/app/user/cmd/api/internal/svc" @@ -10,6 +11,7 @@ import ( jwtx "tyc-server/common/jwt" "tyc-server/common/tool" "tyc-server/common/xerr" + "tyc-server/pkg/lzkit/crypto" "tyc-server/pkg/lzkit/lzUtils" "github.com/pkg/errors" @@ -34,38 +36,46 @@ func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Register } func (l *RegisterLogic) Register(req *types.RegisterReq) (resp *types.RegisterResp, err error) { + secretKey := l.svcCtx.Config.Encrypt.SecretKey + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey) + if err != nil { + return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "手机注册, 加密手机号失败: %+v", err) + } // 检查手机号是否在一分钟内已发送过验证码 - redisKey := fmt.Sprintf("%s:%s", "register", req.Mobile) + redisKey := fmt.Sprintf("%s:%s", "register", encryptedMobile) cacheCode, err := l.svcCtx.Redis.Get(redisKey) if err != nil { if errors.Is(err, redis.Nil) { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机注册, 验证码过期: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码已过期"), "手机注册, 验证码过期: %s", encryptedMobile) } - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取验证码redis缓存失败, mobile: %s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取验证码redis缓存失败, mobile: %s, err: %+v", encryptedMobile, err) } if cacheCode != req.Code { - return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机注册, 验证码不正确: %s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("验证码不正确"), "手机注册, 验证码不正确: %s", encryptedMobile) } - hasUser, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, req.Mobile) + hasUser, findUserErr := l.svcCtx.UserModel.FindOneByMobile(l.ctx, encryptedMobile) if findUserErr != nil && findUserErr != model.ErrNotFound { - return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取数据库获取用户失败, mobile%s, err: %+v", req.Mobile, err) + return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 读取数据库获取用户失败, mobile%s, err: %+v", encryptedMobile, err) } if hasUser != nil { - return nil, errors.Wrapf(xerr.NewErrMsg("该手机号码已注册"), "手机注册, 手机号码已注册, mobile:%s", req.Mobile) + return nil, errors.Wrapf(xerr.NewErrMsg("该手机号码已注册"), "手机注册, 手机号码已注册, mobile:%s", encryptedMobile) } var userId int64 if transErr := l.svcCtx.UserModel.Trans(l.ctx, func(ctx context.Context, session sqlx.Session) error { user := new(model.User) - user.Mobile = req.Mobile - if len(user.Nickname) == 0 { - user.Nickname = req.Mobile + user.Mobile = encryptedMobile + if user.Nickname.Valid && user.Nickname.String != "" { + user.Nickname = sql.NullString{ + String: encryptedMobile, + Valid: true, + } } if len(req.Password) > 0 { user.Password = lzUtils.StringToNullString(tool.Md5ByString(req.Password)) } insertResult, userInsertErr := l.svcCtx.UserModel.Insert(ctx, session, user) if userInsertErr != nil { - return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", req.Mobile, err) + return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入新用户失败, mobile%s, err: %+v", encryptedMobile, err) } lastId, lastInsertIdErr := insertResult.LastInsertId() if lastInsertIdErr != nil { @@ -75,7 +85,7 @@ func (l *RegisterLogic) Register(req *types.RegisterReq) (resp *types.RegisterRe userAuth := new(model.UserAuth) userAuth.UserId = lastId - userAuth.AuthKey = req.Mobile + userAuth.AuthKey = encryptedMobile userAuth.AuthType = model.UserAuthTypeAppMobile if _, userAuthInsertErr := l.svcCtx.UserAuthModel.Insert(ctx, session, userAuth); userAuthInsertErr != nil { return errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "手机注册, 数据库插入用户认证失败, err:%+v", userAuthInsertErr) diff --git a/app/user/cmd/api/internal/queue/cleanQueryData.go b/app/user/cmd/api/internal/queue/cleanQueryData.go index 79f550a..85cdc9e 100644 --- a/app/user/cmd/api/internal/queue/cleanQueryData.go +++ b/app/user/cmd/api/internal/queue/cleanQueryData.go @@ -2,13 +2,14 @@ package queue import ( "context" - "fmt" + "time" "tyc-server/app/user/cmd/api/internal/svc" "github.com/hibiken/asynq" + "github.com/zeromicro/go-zero/core/logx" ) -const TASKTIME = "0 2 * * *" +const TASKTIME = "0 3 * * *" type CleanQueryDataHandler struct { svcCtx *svc.ServiceContext @@ -21,30 +22,19 @@ func NewCleanQueryDataHandler(svcCtx *svc.ServiceContext) *CleanQueryDataHandler } func (l *CleanQueryDataHandler) ProcessTask(ctx context.Context, t *asynq.Task) error { - fmt.Println(TASKTIME) - //// 计算 30 天前的时间 - //threshold := time.Now().AddDate(0, 0, -30) - // - //// 1. 构造查询条件,查找创建时间小于 threshold 的记录 - //// 假设表中“创建时间”字段为 create_time,且字段类型对应数据库里实际的列名 - //rowBuilder := l.svcCtx.QueryModel.SelectBuilder(). - // Where(squirrel.Lt{"create_time": threshold}) - // - //// 2. 调用 FindAll 获取所有符合条件的记录 - //// orderBy 这里可传空字符串或你想要的排序字段 - //records, err := l.svcCtx.QueryModel.FindAll(ctx, rowBuilder, "") - //if err != nil { - // return err - //} - // - //// 3. 遍历记录,逐条删除 - //// 假设你的 Delete 方法签名是:Delete(ctx context.Context, id int64) error - //// 并且 records[i] 中有一个字段 ID 用于表示主键 - //for _, record := range records { - // if err := l.svcCtx.QueryModel.Delete(ctx, record.Id); err != nil { - // return err - // } - //} + now := time.Now().Format("2006-01-02 15:04:05") + logx.Infof("%s - 开始执行查询数据清理任务", now) + // 计算3天前的时间 + threeDaysAgo := time.Now().AddDate(0, 0, -3) + + // 调用QueryModel删除3天前的数据 + result, err := l.svcCtx.QueryModel.DeleteBefore(ctx, threeDaysAgo) + if err != nil { + logx.Errorf("%s - 清理查询数据失败: %v", time.Now().Format("2006-01-02 15:04:05"), err) + return err + } + + logx.Infof("%s - 查询数据清理完成,共删除 %d 条记录", time.Now().Format("2006-01-02 15:04:05"), result) return nil } diff --git a/app/user/cmd/api/internal/queue/paySuccessNotify.go b/app/user/cmd/api/internal/queue/paySuccessNotify.go index 09db05a..95a59f9 100644 --- a/app/user/cmd/api/internal/queue/paySuccessNotify.go +++ b/app/user/cmd/api/internal/queue/paySuccessNotify.go @@ -5,7 +5,10 @@ import ( "encoding/hex" "encoding/json" "fmt" + "regexp" + "strings" "tyc-server/app/user/cmd/api/internal/svc" + "tyc-server/app/user/cmd/api/internal/types" "tyc-server/app/user/model" "tyc-server/pkg/lzkit/crypto" "tyc-server/pkg/lzkit/lzUtils" @@ -48,38 +51,61 @@ func (l *PaySuccessNotifyUserHandler) ProcessTask(ctx context.Context, t *asynq. return fmt.Errorf("找不到相关产品: orderID: %d, productID: %d", payload.OrderID, order.ProductId) } - query, findQueryErr := l.svcCtx.QueryModel.FindOneByOrderId(ctx, order.Id) - if findQueryErr != nil { - findQueryErr = fmt.Errorf("获取任务请求参数失败: %v", findQueryErr) - logx.Errorf("处理任务失败,原因: %v", findQueryErr) - return asynq.SkipRetry - } - if query.QueryState != model.QueryStatePending { - err = fmt.Errorf("查询已处理: %d", query.Id) - logx.Errorf("处理任务失败,原因: %v", err) - return asynq.SkipRetry - } - - query.QueryState = model.QueryStateProcessing - updateQueryErr := l.svcCtx.QueryModel.UpdateWithVersion(ctx, nil, query) - if updateQueryErr != nil { - handleErrorErr := fmt.Errorf("更新查询状态失败,订单ID: %d, 错误: %v", order.Id, updateQueryErr) - return l.handleError(ctx, handleErrorErr, order, query) + redisKey := fmt.Sprintf("%d:%s", order.UserId, order.OrderNo) + cache, cacheErr := l.svcCtx.Redis.GetCtx(ctx, redisKey) + if cacheErr != nil { + return fmt.Errorf("获取缓存内容失败: %+v", cacheErr) } secretKey := l.svcCtx.Config.Encrypt.SecretKey key, decodeErr := hex.DecodeString(secretKey) if decodeErr != nil { - handleErrorErr := fmt.Errorf("获取AES密钥失败: %v", decodeErr) - return l.handleError(ctx, handleErrorErr, order, query) + return fmt.Errorf("获取AES密钥失败: %+v", decodeErr) } - - decryptData, aesdecryptErr := crypto.AesDecrypt(query.QueryParams, key) + var data types.QueryCacheLoad + err = json.Unmarshal([]byte(cache), &data) + if err != nil { + return fmt.Errorf("解析缓存内容失败: %+v", err) + } + decryptData, aesdecryptErr := crypto.AesDecrypt(data.Params, key) if aesdecryptErr != nil { - handleErrorErr := fmt.Errorf("解密响应信息失败: %v", aesdecryptErr) - return l.handleError(ctx, handleErrorErr, order, query) + return fmt.Errorf("解密参数失败: %+v", aesdecryptErr) + } + // 敏感数据脱敏处理 + desensitizedParams, err := l.desensitizeParams(decryptData) + if err != nil { + return fmt.Errorf("脱敏处理失败: %+v", err) } + // 对脱敏后的数据进行AES加密 + encryptedParams, encryptErr := crypto.AesEncrypt(desensitizedParams, key) + if encryptErr != nil { + return fmt.Errorf("加密脱敏数据失败: %+v", encryptErr) + } + + query := &model.Query{ + OrderId: order.Id, + UserId: order.UserId, + ProductId: product.Id, + QueryParams: encryptedParams, + QueryState: model.QueryStateProcessing, + } + result, insertQueryErr := l.svcCtx.QueryModel.Insert(ctx, nil, query) + if insertQueryErr != nil { + return fmt.Errorf("保存查询失败: %+v", insertQueryErr) + } + + // 获取插入后的ID + queryId, err := result.LastInsertId() + if err != nil { + return fmt.Errorf("获取插入的查询ID失败: %+v", err) + } + + // 从数据库中查询完整的查询记录 + query, err = l.svcCtx.QueryModel.FindOne(ctx, queryId) + if err != nil { + return fmt.Errorf("获取插入后的查询记录失败: %+v", err) + } combinedResponse, err := l.svcCtx.ApiRequestService.ProcessRequests(ctx, decryptData, product.Id) if err != nil { handleErrorErr := fmt.Errorf("处理请求失败: %v", err) @@ -88,21 +114,26 @@ func (l *PaySuccessNotifyUserHandler) ProcessTask(ctx context.Context, t *asynq. // 加密返回响应 encryptData, aesEncryptErr := crypto.AesEncrypt(combinedResponse, key) if aesEncryptErr != nil { - handleErrorErr := fmt.Errorf("加密响应信息失败: %v", aesEncryptErr) - return l.handleError(ctx, handleErrorErr, order, query) + err = fmt.Errorf("加密响应信息失败: %v", aesEncryptErr) + return l.handleError(ctx, err, order, query) } query.QueryData = lzUtils.StringToNullString(encryptData) updateErr := l.svcCtx.QueryModel.UpdateWithVersion(ctx, nil, query) if updateErr != nil { - handleErrorErr := fmt.Errorf("保存响应数据失败: %v", updateErr) - return l.handleError(ctx, handleErrorErr, order, query) + err = fmt.Errorf("保存响应数据失败: %v", updateErr) + return l.handleError(ctx, err, order, query) } - query.QueryState = model.QueryStateSuccess - updateQueryErr = l.svcCtx.QueryModel.UpdateWithVersion(ctx, nil, query) + query.QueryState = "success" + updateQueryErr := l.svcCtx.QueryModel.UpdateWithVersion(ctx, nil, query) if updateQueryErr != nil { - handleErrorErr := fmt.Errorf("修改查询状态失败: %v", updateQueryErr) - return l.handleError(ctx, handleErrorErr, order, query) + updateQueryErr = fmt.Errorf("修改查询状态失败: %v", updateQueryErr) + return l.handleError(ctx, updateQueryErr, order, query) + } + + _, delErr := l.svcCtx.Redis.DelCtx(ctx, redisKey) + if delErr != nil { + logx.Errorf("删除Redis缓存失败,但任务已成功处理,订单ID: %d, 错误: %v", order.Id, delErr) } return nil @@ -111,7 +142,11 @@ func (l *PaySuccessNotifyUserHandler) ProcessTask(ctx context.Context, t *asynq. // 定义一个中间件函数 func (l *PaySuccessNotifyUserHandler) handleError(ctx context.Context, err error, order *model.Order, query *model.Query) error { logx.Errorf("处理任务失败,原因: %v", err) - + redisKey := fmt.Sprintf("%d:%s", order.UserId, order.OrderNo) + _, delErr := l.svcCtx.Redis.DelCtx(ctx, redisKey) + if delErr != nil { + logx.Errorf("删除Redis缓存失败,订单ID: %d, 错误: %v", order.Id, delErr) + } if order.Status == "paid" && query.QueryState == model.QueryStateProcessing { // 更新查询状态为失败 query.QueryState = model.QueryStateFailed @@ -155,3 +190,177 @@ func (l *PaySuccessNotifyUserHandler) handleError(ctx context.Context, err error return asynq.SkipRetry } + +// desensitizeParams 对敏感数据进行脱敏处理 +func (l *PaySuccessNotifyUserHandler) desensitizeParams(data []byte) ([]byte, error) { + // 解析JSON数据到map + var paramsMap map[string]interface{} + if err := json.Unmarshal(data, ¶msMap); err != nil { + return nil, fmt.Errorf("解析JSON数据失败: %v", err) + } + + // 处理可能包含敏感信息的字段 + for key, value := range paramsMap { + if strValue, ok := value.(string); ok { + // 根据字段名和内容判断并脱敏 + if isNameField(key) && len(strValue) > 0 { + // 姓名脱敏 + paramsMap[key] = maskName(strValue) + } else if isIDCardField(key) && len(strValue) > 10 { + // 身份证号脱敏 + paramsMap[key] = maskIDCard(strValue) + } else if isPhoneField(key) && len(strValue) >= 8 { + // 手机号脱敏 + paramsMap[key] = maskPhone(strValue) + } else if len(strValue) > 3 { + // 其他所有未匹配的字段都进行通用脱敏 + paramsMap[key] = maskGeneral(strValue) + } + } else if mapValue, ok := value.(map[string]interface{}); ok { + // 递归处理嵌套的map + for subKey, subValue := range mapValue { + if subStrValue, ok := subValue.(string); ok { + if isNameField(subKey) && len(subStrValue) > 0 { + mapValue[subKey] = maskName(subStrValue) + } else if isIDCardField(subKey) && len(subStrValue) > 10 { + mapValue[subKey] = maskIDCard(subStrValue) + } else if isPhoneField(subKey) && len(subStrValue) >= 8 { + mapValue[subKey] = maskPhone(subStrValue) + } else if len(subStrValue) > 3 { + // 其他所有未匹配的字段都进行通用脱敏 + mapValue[subKey] = maskGeneral(subStrValue) + } + } + } + } + } + + // 将处理后的map重新序列化为JSON + return json.Marshal(paramsMap) +} + +// 判断是否为姓名字段 +func isNameField(key string) bool { + key = strings.ToLower(key) + return strings.Contains(key, "name") || strings.Contains(key, "姓名") || + strings.Contains(key, "owner") || strings.Contains(key, "user") +} + +// 判断是否为身份证字段 +func isIDCardField(key string) bool { + key = strings.ToLower(key) + return strings.Contains(key, "idcard") || strings.Contains(key, "id_card") || + strings.Contains(key, "身份证") || strings.Contains(key, "证件号") +} + +// 判断是否为手机号字段 +func isPhoneField(key string) bool { + key = strings.ToLower(key) + return strings.Contains(key, "phone") || strings.Contains(key, "mobile") || + strings.Contains(key, "手机") || strings.Contains(key, "电话") +} + +// 判断是否包含敏感数据模式 +func containsSensitivePattern(value string) bool { + // 检查是否包含连续的数字或字母模式 + numPattern := regexp.MustCompile(`\d{6,}`) + return numPattern.MatchString(value) +} + +// 姓名脱敏 +func maskName(name string) string { + // 将字符串转换为rune切片以正确处理中文字符 + runes := []rune(name) + length := len(runes) + + if length <= 1 { + return name + } + + if length == 2 { + // 两个字:保留第一个字,第二个字用*替代 + return string(runes[0]) + "*" + } + + // 三个字及以上:保留首尾字,中间用*替代 + first := string(runes[0]) + last := string(runes[length-1]) + mask := strings.Repeat("*", length-2) + + return first + mask + last +} + +// 身份证号脱敏 +func maskIDCard(idCard string) string { + length := len(idCard) + if length <= 10 { + return idCard // 如果长度太短,可能不是身份证,不处理 + } + // 保留前3位和后4位 + return idCard[:3] + strings.Repeat("*", length-7) + idCard[length-4:] +} + +// 手机号脱敏 +func maskPhone(phone string) string { + length := len(phone) + if length < 8 { + return phone // 如果长度太短,可能不是手机号,不处理 + } + // 保留前3位和后4位 + return phone[:3] + strings.Repeat("*", length-7) + phone[length-4:] +} + +// 通用敏感信息脱敏 - 根据字符串长度比例进行脱敏 +func maskGeneral(value string) string { + length := len(value) + + // 小于3个字符的不脱敏 + if length <= 3 { + return value + } + + // 根据字符串长度计算保留字符数 + var prefixLen, suffixLen int + + switch { + case length <= 6: // 短字符串 + // 保留首尾各1个字符 + prefixLen, suffixLen = 1, 1 + case length <= 10: // 中等长度字符串 + // 保留首部30%和尾部20%的字符 + prefixLen = int(float64(length) * 0.3) + suffixLen = int(float64(length) * 0.2) + case length <= 20: // 较长字符串 + // 保留首部25%和尾部15%的字符 + prefixLen = int(float64(length) * 0.25) + suffixLen = int(float64(length) * 0.15) + default: // 非常长的字符串 + // 保留首部20%和尾部10%的字符 + prefixLen = int(float64(length) * 0.2) + suffixLen = int(float64(length) * 0.1) + } + + // 确保至少有一个字符被保留 + if prefixLen < 1 { + prefixLen = 1 + } + if suffixLen < 1 { + suffixLen = 1 + } + + // 确保前缀和后缀总长不超过总长度的80% + if prefixLen+suffixLen > int(float64(length)*0.8) { + // 调整为总长度的80% + totalVisible := int(float64(length) * 0.8) + // 前缀占60%,后缀占40% + prefixLen = int(float64(totalVisible) * 0.6) + suffixLen = totalVisible - prefixLen + } + + // 创建脱敏后的字符串 + prefix := value[:prefixLen] + suffix := value[length-suffixLen:] + masked := strings.Repeat("*", length-prefixLen-suffixLen) + + return prefix + masked + suffix +} diff --git a/app/user/cmd/api/internal/service/alipayService.go b/app/user/cmd/api/internal/service/alipayService.go index 732ad0a..32620fa 100644 --- a/app/user/cmd/api/internal/service/alipayService.go +++ b/app/user/cmd/api/internal/service/alipayService.go @@ -2,10 +2,12 @@ package service import ( "context" + "crypto/rand" + "encoding/hex" "fmt" - mathrand "math/rand" "net/http" "strconv" + "sync/atomic" "time" "tyc-server/app/user/cmd/api/internal/config" "tyc-server/pkg/lzkit/lzUtils" @@ -160,31 +162,35 @@ func (a *AliPayService) QueryOrderStatus(ctx context.Context, outTradeNo string) return nil, fmt.Errorf("查询支付宝订单失败: %v", resp.SubMsg) } -// GenerateOutTradeNo 生成唯一订单号的函数 -func (a *AliPayService) GenerateOutTradeNo() string { - length := 16 - // 获取当前时间戳 - timestamp := time.Now().UnixNano() +// 添加全局原子计数器 +var alipayOrderCounter uint32 = 0 - // 转换为字符串 +// GenerateOutTradeNo 生成唯一订单号的函数 - 优化版本 +func (a *AliPayService) GenerateOutTradeNo() string { + + // 获取当前时间戳(毫秒级) + timestamp := time.Now().UnixMilli() timeStr := strconv.FormatInt(timestamp, 10) - // 生成随机数 - mathrand.Seed(time.Now().UnixNano()) - randomPart := strconv.Itoa(mathrand.Intn(1000000)) + // 原子递增计数器 + counter := atomic.AddUint32(&alipayOrderCounter, 1) - // 组合时间戳和随机数 - combined := timeStr + randomPart + // 生成4字节真随机数 + randomBytes := make([]byte, 4) + _, err := rand.Read(randomBytes) + if err != nil { + // 如果随机数生成失败,回退到使用时间纳秒数据 + randomBytes = []byte(strconv.FormatInt(time.Now().UnixNano()%1000000, 16)) + } + randomHex := hex.EncodeToString(randomBytes) - // 如果长度超出指定值,则截断;如果不够,则填充随机字符 - if len(combined) >= length { - return combined[:length] + // 组合所有部分: 前缀 + 时间戳 + 计数器 + 随机数 + orderNo := fmt.Sprintf("%s%06x%s", timeStr[:10], counter%0xFFFFFF, randomHex[:6]) + + // 确保长度不超过32字符(大多数支付平台的限制) + if len(orderNo) > 32 { + orderNo = orderNo[:32] } - // 如果长度不够,填充0 - for len(combined) < length { - combined += strconv.Itoa(mathrand.Intn(10)) // 填充随机数 - } - - return combined + return orderNo } diff --git a/app/user/cmd/api/internal/types/cache.go b/app/user/cmd/api/internal/types/cache.go index 84ba143..554d157 100644 --- a/app/user/cmd/api/internal/types/cache.go +++ b/app/user/cmd/api/internal/types/cache.go @@ -7,6 +7,6 @@ type QueryCache struct { Product string `json:"product_id"` } type QueryCacheLoad struct { - Product string `json:"product_en"` - Params map[string]interface{} `json:"params"` + Product string `json:"product_en"` + Params string `json:"params"` } diff --git a/app/user/model/queryModel.go b/app/user/model/queryModel.go index 88486c7..d7b5aee 100644 --- a/app/user/model/queryModel.go +++ b/app/user/model/queryModel.go @@ -1,6 +1,11 @@ package model import ( + "context" + "fmt" + "time" + "tyc-server/common/globalkey" + "github.com/zeromicro/go-zero/core/stores/cache" "github.com/zeromicro/go-zero/core/stores/sqlx" ) @@ -12,6 +17,7 @@ type ( // and implement the added methods in customQueryModel. QueryModel interface { queryModel + DeleteBefore(ctx context.Context, before time.Time) (int64, error) } customQueryModel struct { @@ -25,3 +31,29 @@ func NewQueryModel(conn sqlx.SqlConn, c cache.CacheConf) QueryModel { defaultQueryModel: newQueryModel(conn, c), } } +func (m *customQueryModel) DeleteBefore(ctx context.Context, before time.Time) (int64, error) { + var affected int64 = 0 + + // 使用事务处理批量删除 + err := m.defaultQueryModel.Trans(ctx, func(ctx context.Context, session sqlx.Session) error { + query := fmt.Sprintf("DELETE FROM %s WHERE create_time < ? AND del_state = ?", m.defaultQueryModel.table) + result, err := session.ExecCtx(ctx, query, before.Format("2006-01-02 15:04:05"), globalkey.DelStateNo) + if err != nil { + return err + } + + rows, err := result.RowsAffected() + if err != nil { + return err + } + + affected = rows + return nil + }) + + if err != nil { + return 0, err + } + + return affected, nil +} diff --git a/common/interceptor/rpcserver/loggerInterceptor.go b/common/interceptor/rpcserver/loggerInterceptor.go index 98ef6ce..ffe7c2a 100644 --- a/common/interceptor/rpcserver/loggerInterceptor.go +++ b/common/interceptor/rpcserver/loggerInterceptor.go @@ -3,7 +3,7 @@ package rpcserver import ( "context" - "qnc-server/common/xerr" + "tyc-server/common/xerr" "github.com/pkg/errors" "github.com/zeromicro/go-zero/core/logx" diff --git a/common/uniqueid/sn.go b/common/uniqueid/sn.go index 9660b9d..50b010e 100644 --- a/common/uniqueid/sn.go +++ b/common/uniqueid/sn.go @@ -2,8 +2,8 @@ package uniqueid import ( "fmt" - "qnc-server/common/tool" "time" + "tyc-server/common/tool" ) // 生成sn单号 diff --git a/common/xerr/errCode.go b/common/xerr/errCode.go index ade3e6f..4d872e0 100644 --- a/common/xerr/errCode.go +++ b/common/xerr/errCode.go @@ -14,7 +14,9 @@ const DB_ERROR uint32 = 100005 const DB_UPDATE_AFFECTED_ZERO_ERROR uint32 = 100006 const PARAM_VERIFICATION_ERROR uint32 = 100007 const CUSTOM_ERROR uint32 = 100008 +const USER_NOT_FOUND uint32 = 100009 const LOGIN_FAILED uint32 = 200001 const LOGIC_QUERY_WAIT uint32 = 200002 const LOGIC_QUERY_ERROR uint32 = 200003 +const LOGIC_QUERY_NOT_FOUND uint32 = 200004 diff --git a/pkg/lzkit/crypto/README.md b/pkg/lzkit/crypto/README.md new file mode 100644 index 0000000..bd8d77e --- /dev/null +++ b/pkg/lzkit/crypto/README.md @@ -0,0 +1,235 @@ +# AES 加密工具包 + +本包提供了多种加密方式,特别是用于处理敏感个人信息(如手机号、身份证号等)的加密和解密功能。 + +## 主要功能 + +- **AES-CBC 模式加密/解密** - 标准加密模式,适用于一般数据加密 +- **AES-ECB 模式加密/解密** - 确定性加密模式,适用于数据库字段加密和查询 +- **专门针对个人敏感信息的加密/解密方法** +- **密钥生成和管理工具** + +## 安全性说明 + +- **AES-CBC 模式**:使用随机 IV,相同明文每次加密结果不同,安全性较高 +- **AES-ECB 模式**:确定性加密,相同明文每次加密结果相同,便于数据库查询,但安全性较低 + +> **⚠️ 警告**:ECB 模式仅适用于短文本(如手机号、身份证号)的确定性加密,不建议用于加密大段文本或高安全需求场景。 + +## 使用示例 + +### 1. 加密手机号 + +使用 AES-ECB 模式加密手机号,保证确定性(相同手机号总是产生相同密文): + +```go +import ( + "fmt" + "tydata-server/pkg/lzkit/crypto" +) + +func encryptMobileExample() { + // 您的密钥(需安全保存,建议存储在配置中) + key := []byte("1234567890abcdef") // 16字节AES-128密钥 + + // 加密手机号 + mobile := "13800138000" + encryptedMobile, err := crypto.EncryptMobile(mobile, key) + if err != nil { + panic(err) + } + + fmt.Println("加密后的手机号:", encryptedMobile) + + // 解密手机号 + decryptedMobile, err := crypto.DecryptMobile(encryptedMobile, key) + if err != nil { + panic(err) + } + + fmt.Println("解密后的手机号:", decryptedMobile) +} +``` + +### 2. 在数据库中存储和查询加密手机号 + +```go +// 加密并存储手机号 +func saveUser(db *sqlx.DB, mobile string, key []byte) (int64, error) { + encryptedMobile, err := crypto.EncryptMobile(mobile, key) + if err != nil { + return 0, err + } + + var id int64 + err = db.QueryRow( + "INSERT INTO users (mobile, create_time) VALUES (?, NOW()) RETURNING id", + encryptedMobile, + ).Scan(&id) + + return id, err +} + +// 根据手机号查询用户 +func findUserByMobile(db *sqlx.DB, mobile string, key []byte) (*User, error) { + encryptedMobile, err := crypto.EncryptMobile(mobile, key) + if err != nil { + return nil, err + } + + var user User + err = db.QueryRow( + "SELECT id, mobile, create_time FROM users WHERE mobile = ?", + encryptedMobile, + ).Scan(&user.ID, &user.EncryptedMobile, &user.CreateTime) + + if err != nil { + return nil, err + } + + // 解密手机号用于显示 + user.Mobile, _ = crypto.DecryptMobile(user.EncryptedMobile, key) + + return &user, nil +} +``` + +### 3. 加密身份证号 + +```go +func encryptIDCardExample() { + key := []byte("1234567890abcdef") + + idCard := "440101199001011234" + encryptedIDCard, err := crypto.EncryptIDCard(idCard, key) + if err != nil { + panic(err) + } + + fmt.Println("加密后的身份证号:", encryptedIDCard) + + // 解密身份证号 + decryptedIDCard, err := crypto.DecryptIDCard(encryptedIDCard, key) + if err != nil { + panic(err) + } + + fmt.Println("解密后的身份证号:", decryptedIDCard) +} +``` + +### 4. 密钥管理 + +```go +func keyManagementExample() { + // 生成随机密钥 + key, err := crypto.GenerateAESKey(16) // AES-128 + if err != nil { + panic(err) + } + fmt.Printf("生成的密钥(十六进制): %x\n", key) + + // 从密码派生密钥(便于记忆) + password := "my-secure-password" + derivedKey, err := crypto.DeriveKeyFromPassword(password, 16) + if err != nil { + panic(err) + } + fmt.Printf("从密码派生的密钥: %x\n", derivedKey) +} +``` + +### 5. 使用十六进制输出(适用于 URL 参数) + +```go +func hexEncodingExample() { + key := []byte("1234567890abcdef") + mobile := "13800138000" + + // 使用十六进制编码(适合URL参数) + encryptedHex, err := crypto.EncryptMobileHex(mobile, key) + if err != nil { + panic(err) + } + + fmt.Println("十六进制编码的加密手机号:", encryptedHex) + + // 解密十六进制编码的手机号 + decryptedMobile, err := crypto.DecryptMobileHex(encryptedHex, key) + if err != nil { + panic(err) + } + + fmt.Println("解密后的手机号:", decryptedMobile) +} +``` + +## 在 Go-Zero 项目中使用 + +在 Go-Zero 项目中,建议将加密密钥放在配置文件中: + +1. 在配置文件中添加密钥配置: + +```yaml +# etc/main.yaml +Name: user-api +Host: 0.0.0.0 +Port: 8888 + +Encrypt: + MobileKey: "1234567890abcdef" # 16字节AES-128密钥 + IDCardKey: "1234567890abcdef1234567890abcdef" # 32字节AES-256密钥 +``` + +2. 在配置结构中定义: + +```go +type Config struct { + rest.RestConf + Encrypt struct { + MobileKey string + IDCardKey string + } +} +``` + +3. 在服务上下文中使用: + +```go +type ServiceContext struct { + Config config.Config + UserModel model.UserModel + MobileKey []byte + IDCardKey []byte +} + +func NewServiceContext(c config.Config) *ServiceContext { + return &ServiceContext{ + Config: c, + UserModel: model.NewUserModel(sqlx.NewMysql(c.DB.DataSource), c.Cache), + MobileKey: []byte(c.Encrypt.MobileKey), + IDCardKey: []byte(c.Encrypt.IDCardKey), + } +} +``` + +4. 在 Logic 中使用: + +```go +func (l *RegisterLogic) Register(req *types.RegisterReq) (*types.RegisterResp, error) { + // 加密手机号用于存储 + encryptedMobile, err := crypto.EncryptMobile(req.Mobile, l.svcCtx.MobileKey) + if err != nil { + return nil, errors.New("手机号加密失败") + } + + // 保存到数据库 + user := &model.User{ + Mobile: encryptedMobile, + // 其他字段... + } + + result, err := l.svcCtx.UserModel.Insert(l.ctx, nil, user) + // 其余逻辑... +} +``` diff --git a/pkg/lzkit/crypto/crypto_url.go b/pkg/lzkit/crypto/crypto_url.go new file mode 100644 index 0000000..777db50 --- /dev/null +++ b/pkg/lzkit/crypto/crypto_url.go @@ -0,0 +1,67 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "errors" + "io" +) + +// AES CBC模式加密,Base64传入传出 +func AesEncryptURL(plainText, key []byte) (string, error) { + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + blockSize := block.BlockSize() + plainText = PKCS7Padding(plainText, blockSize) + + cipherText := make([]byte, blockSize+len(plainText)) + iv := cipherText[:blockSize] // 使用前blockSize字节作为IV + _, err = io.ReadFull(rand.Reader, iv) + if err != nil { + return "", err + } + + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(cipherText[blockSize:], plainText) + + return base64.URLEncoding.EncodeToString(cipherText), nil +} + +// AES CBC模式解密,Base64传入传出 +func AesDecryptURL(cipherTextBase64 string, key []byte) ([]byte, error) { + cipherText, err := base64.URLEncoding.DecodeString(cipherTextBase64) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + blockSize := block.BlockSize() + if len(cipherText) < blockSize { + return nil, errors.New("ciphertext too short") + } + + iv := cipherText[:blockSize] + cipherText = cipherText[blockSize:] + + if len(cipherText)%blockSize != 0 { + return nil, errors.New("ciphertext is not a multiple of the block size") + } + + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(cipherText, cipherText) + + plainText, err := PKCS7UnPadding(cipherText) + if err != nil { + return nil, err + } + + return plainText, nil +} diff --git a/pkg/lzkit/crypto/ecb.go b/pkg/lzkit/crypto/ecb.go new file mode 100644 index 0000000..3fed15f --- /dev/null +++ b/pkg/lzkit/crypto/ecb.go @@ -0,0 +1,274 @@ +package crypto + +import ( + "crypto/aes" + "crypto/md5" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "errors" + "fmt" +) + +// ECB模式是一种基本的加密模式,每个明文块独立加密 +// 警告:ECB模式存在安全问题,仅用于需要确定性加密的场景,如数据库字段查询 +// 不要用于加密大段文本或安全要求高的场景 + +// 验证密钥长度是否有效 (AES-128, AES-192, AES-256) +func validateAESKey(key []byte) error { + switch len(key) { + case 16, 24, 32: + return nil + default: + return errors.New("AES密钥长度必须是16、24或32字节(对应AES-128、AES-192、AES-256)") + } +} + +// AesEcbEncrypt AES-ECB模式加密,返回Base64编码的密文 +// 使用已有的ECB实现,但提供更易用的接口 +func AesEcbEncrypt(plainText, key []byte) (string, error) { + if err := validateAESKey(key); err != nil { + return "", err + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + // 使用PKCS7填充 + plainText = PKCS7Padding(plainText, block.BlockSize()) + + // 创建密文数组 + cipherText := make([]byte, len(plainText)) + + // ECB模式加密,使用west_crypto.go中已有的实现 + mode := newECBEncrypter(block) + mode.CryptBlocks(cipherText, plainText) + + // 返回Base64编码的密文 + return base64.StdEncoding.EncodeToString(cipherText), nil +} + +// AesEcbDecrypt AES-ECB模式解密,输入Base64编码的密文 +func AesEcbDecrypt(cipherTextBase64 string, key []byte) ([]byte, error) { + if err := validateAESKey(key); err != nil { + return nil, err + } + + // Base64解码 + cipherText, err := base64.StdEncoding.DecodeString(cipherTextBase64) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // 检查密文长度 + if len(cipherText)%block.BlockSize() != 0 { + return nil, errors.New("密文长度必须是块大小的整数倍") + } + + // 创建明文数组 + plainText := make([]byte, len(cipherText)) + + // ECB模式解密,使用west_crypto.go中已有的实现 + mode := newECBDecrypter(block) + mode.CryptBlocks(plainText, cipherText) + + // 去除PKCS7填充 + plainText, err = PKCS7UnPadding(plainText) + if err != nil { + return nil, err + } + + return plainText, nil +} + +// AesEcbEncryptHex AES-ECB模式加密,返回十六进制编码的密文 +func AesEcbEncryptHex(plainText, key []byte) (string, error) { + if err := validateAESKey(key); err != nil { + return "", err + } + + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + + // 使用PKCS7填充 + plainText = PKCS7Padding(plainText, block.BlockSize()) + + // 创建密文数组 + cipherText := make([]byte, len(plainText)) + + // ECB模式加密 + mode := newECBEncrypter(block) + mode.CryptBlocks(cipherText, plainText) + + // 返回十六进制编码的密文 + return hex.EncodeToString(cipherText), nil +} + +// AesEcbDecryptHex AES-ECB模式解密,输入十六进制编码的密文 +func AesEcbDecryptHex(cipherTextHex string, key []byte) ([]byte, error) { + if err := validateAESKey(key); err != nil { + return nil, err + } + + // 十六进制解码 + cipherText, err := hex.DecodeString(cipherTextHex) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + // 检查密文长度 + if len(cipherText)%block.BlockSize() != 0 { + return nil, errors.New("密文长度必须是块大小的整数倍") + } + + // 创建明文数组 + plainText := make([]byte, len(cipherText)) + + // ECB模式解密 + mode := newECBDecrypter(block) + mode.CryptBlocks(plainText, cipherText) + + // 去除PKCS7填充 + plainText, err = PKCS7UnPadding(plainText) + if err != nil { + return nil, err + } + + return plainText, nil +} + +// 以下是专门用于处理手机号等敏感数据的实用函数 + +// EncryptMobile 使用AES-ECB加密手机号,返回Base64编码 +// 该方法保证对相同手机号总是产生相同密文,便于数据库查询 +func EncryptMobile(mobile string, secretKey string) (string, error) { + key, decodeErr := hex.DecodeString(secretKey) + if decodeErr != nil { + return "", decodeErr + } + if mobile == "" { + return "", errors.New("手机号不能为空") + } + return AesEcbEncrypt([]byte(mobile), key) +} + +// DecryptMobile 解密手机号 +func DecryptMobile(encryptedMobile string, secretKey string) (string, error) { + key, decodeErr := hex.DecodeString(secretKey) + if decodeErr != nil { + return "", decodeErr + } + if encryptedMobile == "" { + return "", errors.New("加密手机号不能为空") + } + + bytes, err := AesEcbDecrypt(encryptedMobile, key) + if err != nil { + return "", fmt.Errorf("解密手机号失败: %v", err) + } + + return string(bytes), nil +} + +// EncryptMobileHex 使用AES-ECB加密手机号,返回十六进制编码(适用于URL参数) +func EncryptMobileHex(mobile string, key []byte) (string, error) { + if mobile == "" { + return "", errors.New("手机号不能为空") + } + return AesEcbEncryptHex([]byte(mobile), key) +} + +// DecryptMobileHex 解密十六进制编码的手机号 +func DecryptMobileHex(encryptedMobileHex string, key []byte) (string, error) { + if encryptedMobileHex == "" { + return "", errors.New("加密手机号不能为空") + } + + bytes, err := AesEcbDecryptHex(encryptedMobileHex, key) + if err != nil { + return "", fmt.Errorf("解密手机号失败: %v", err) + } + + return string(bytes), nil +} + +// EncryptIDCard 使用AES-ECB加密身份证号 +func EncryptIDCard(idCard string, key []byte) (string, error) { + if idCard == "" { + return "", errors.New("身份证号不能为空") + } + return AesEcbEncrypt([]byte(idCard), key) +} + +// DecryptIDCard 解密身份证号 +func DecryptIDCard(encryptedIDCard string, key []byte) (string, error) { + if encryptedIDCard == "" { + return "", errors.New("加密身份证号不能为空") + } + + bytes, err := AesEcbDecrypt(encryptedIDCard, key) + if err != nil { + return "", fmt.Errorf("解密身份证号失败: %v", err) + } + + return string(bytes), nil +} + +// IsEncrypted 检查字符串是否为Base64编码的加密数据 +func IsEncrypted(data string) bool { + // 检查是否是有效的Base64编码 + _, err := base64.StdEncoding.DecodeString(data) + return err == nil && len(data) >= 20 // 至少20个字符的Base64字符串 +} + +// GenerateAESKey 生成AES密钥 +// keySize: 可选16, 24, 32字节(对应AES-128, AES-192, AES-256) +func GenerateAESKey(keySize int) ([]byte, error) { + if keySize != 16 && keySize != 24 && keySize != 32 { + return nil, errors.New("密钥长度必须是16、24或32字节") + } + + key := make([]byte, keySize) + _, err := rand.Read(key) + if err != nil { + return nil, err + } + + return key, nil +} + +// DeriveKeyFromPassword 基于密码派生固定长度的AES密钥 +func DeriveKeyFromPassword(password string, keySize int) ([]byte, error) { + if keySize != 16 && keySize != 24 && keySize != 32 { + return nil, errors.New("密钥长度必须是16、24或32字节") + } + + // 使用PBKDF2或简单的方法从密码派生密钥 + // 这里使用简单的MD5方法,实际生产环境应使用更安全的PBKDF2 + hash := md5.New() + hash.Write([]byte(password)) + key := hash.Sum(nil) // 16字节 + + // 如果需要24或32字节,继续哈希 + if keySize > 16 { + hash.Reset() + hash.Write(key) + key = append(key, hash.Sum(nil)[:keySize-16]...) + } + + return key, nil +} diff --git a/pkg/lzkit/crypto/ecb_test.go b/pkg/lzkit/crypto/ecb_test.go new file mode 100644 index 0000000..a497495 --- /dev/null +++ b/pkg/lzkit/crypto/ecb_test.go @@ -0,0 +1,185 @@ +package crypto + +import ( + "encoding/base64" + "encoding/hex" + "fmt" + "testing" +) + +func TestAesEcbMobileEncryption(t *testing.T) { + // 测试手机号加密 + mobile := "13800138000" + key := []byte("1234567890abcdef") // 16字节AES-128密钥 + + keyStr := hex.EncodeToString(key) + // 测试加密 + encrypted, err := EncryptMobile(mobile, keyStr) + if err != nil { + t.Fatalf("手机号加密失败: %v", err) + } + fmt.Println(encrypted) + // 测试解密 + decrypted, err := DecryptMobile(encrypted, keyStr) + if err != nil { + t.Fatalf("手机号解密失败: %v", err) + } + fmt.Println(decrypted) + // 验证结果 + if decrypted != mobile { + t.Errorf("解密结果不匹配,期望: %s, 实际: %s", mobile, decrypted) + } + + // 测试相同输入产生相同输出(确定性) + encrypted2, _ := EncryptMobile(mobile, keyStr) + if encrypted != encrypted2 { + t.Errorf("AES-ECB不是确定性的,两次加密结果不同: %s vs %s", encrypted, encrypted2) + } +} + +func TestAesEcbHexEncryption(t *testing.T) { + // 测试十六进制编码加密 + idCard := "440101199001011234" + key := []byte("1234567890abcdef") // 16字节AES-128密钥 + + // 测试HEX加密 + encryptedHex, err := EncryptIDCard(idCard, key) + if err != nil { + t.Fatalf("身份证加密失败: %v", err) + } + + // 测试HEX解密 + decrypted, err := DecryptIDCard(encryptedHex, key) + if err != nil { + t.Fatalf("身份证解密失败: %v", err) + } + + // 验证结果 + if decrypted != idCard { + t.Errorf("解密结果不匹配,期望: %s, 实际: %s", idCard, decrypted) + } +} + +func TestAesEcbKeyValidation(t *testing.T) { + // 测试不同长度的密钥 + validKeys := [][]byte{ + make([]byte, 16), // AES-128 + make([]byte, 24), // AES-192 + make([]byte, 32), // AES-256 + } + + invalidKeys := [][]byte{ + make([]byte, 15), + make([]byte, 20), + make([]byte, 33), + } + + text := []byte("test text") + + // 测试有效密钥 + for _, key := range validKeys { + _, err := AesEcbEncrypt(text, key) + if err != nil { + t.Errorf("有效密钥(%d字节)校验失败: %v", len(key), err) + } + } + + // 测试无效密钥 + for _, key := range invalidKeys { + _, err := AesEcbEncrypt(text, key) + if err == nil { + t.Errorf("无效密钥(%d字节)未被检测出", len(key)) + } + } +} + +func TestIsEncrypted(t *testing.T) { + // 有效的Base64编码字符串 + validBase64 := base64.StdEncoding.EncodeToString([]byte("这是一个足够长的字符串,以通过IsEncrypted检查")) + + // 无效的字符串 + invalidStrings := []string{ + "", + "abc", + "not-base64!@#", + hex.EncodeToString([]byte("hexstring")), + } + + // 测试有效的加密数据 + if !IsEncrypted(validBase64) { + t.Errorf("有效的Base64未被识别为加密数据: %s", validBase64) + } + + // 测试无效的数据 + for _, s := range invalidStrings { + if IsEncrypted(s) { + t.Errorf("无效字符串被错误识别为加密数据: %s", s) + } + } +} + +func TestDeriveKeyFromPassword(t *testing.T) { + password := "my-secure-password" + + // 测试不同长度的派生密钥 + keySizes := []int{16, 24, 32} + + for _, size := range keySizes { + key, err := DeriveKeyFromPassword(password, size) + if err != nil { + t.Errorf("从密码派生%d字节密钥失败: %v", size, err) + continue + } + + if len(key) != size { + t.Errorf("派生的密钥长度错误,期望: %d, 实际: %d", size, len(key)) + } + + // 测试相同密码总是产生相同密钥 + key2, _ := DeriveKeyFromPassword(password, size) + if string(key) != string(key2) { + t.Errorf("从相同密码派生的密钥不一致") + } + + // 使用派生的密钥加密测试 + _, err = AesEcbEncrypt([]byte("test"), key) + if err != nil { + t.Errorf("使用派生的密钥加密失败: %v", err) + } + } + + // 测试无效的密钥大小 + _, err := DeriveKeyFromPassword(password, 18) + if err == nil { + t.Error("无效的密钥大小未被检测出") + } +} + +func TestGenerateAESKey(t *testing.T) { + // 测试生成不同长度的密钥 + keySizes := []int{16, 24, 32} + + for _, size := range keySizes { + key, err := GenerateAESKey(size) + if err != nil { + t.Errorf("生成%d字节密钥失败: %v", size, err) + continue + } + + if len(key) != size { + t.Errorf("生成的密钥长度错误,期望: %d, 实际: %d", size, len(key)) + } + + // 使用生成的密钥加密测试 + _, err = AesEcbEncrypt([]byte("test"), key) + if err != nil { + t.Errorf("使用生成的密钥加密失败: %v", err) + } + } + + // 测试无效的密钥大小 + _, err := GenerateAESKey(18) + if err == nil { + t.Error("无效的密钥大小未被检测出") + } +}