137 lines
4.3 KiB
Go
137 lines
4.3 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"github.com/zeromicro/go-zero/core/logx"
|
||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||
xhttp "github.com/zeromicro/x/http"
|
||
"net"
|
||
"net/http"
|
||
"strings"
|
||
"tianyuan-api/apps/sentinel/client/secret"
|
||
"tianyuan-api/apps/sentinel/client/userproduct"
|
||
"tianyuan-api/apps/sentinel/client/whitelist"
|
||
"tianyuan-api/apps/sentinel/sentinel"
|
||
"tianyuan-api/apps/user/user"
|
||
"tianyuan-api/pkg/crypto"
|
||
)
|
||
|
||
type ApiAuthInterceptorMiddleware struct {
|
||
WhitelistRpc sentinel.WhitelistClient
|
||
SecretRpc sentinel.SecretClient
|
||
UserProductRpc sentinel.UserProductClient
|
||
UserRpc user.UserClient
|
||
Rds *redis.Redis
|
||
}
|
||
|
||
func NewApiAuthInterceptorMiddleware(
|
||
whitelistRpc sentinel.WhitelistClient,
|
||
secretRpc sentinel.SecretClient,
|
||
userProductRpc sentinel.UserProductClient,
|
||
userRpc user.UserClient,
|
||
rds *redis.Redis) *ApiAuthInterceptorMiddleware {
|
||
return &ApiAuthInterceptorMiddleware{
|
||
WhitelistRpc: whitelistRpc,
|
||
SecretRpc: secretRpc,
|
||
UserProductRpc: userProductRpc,
|
||
UserRpc: userRpc,
|
||
Rds: rds,
|
||
}
|
||
}
|
||
|
||
func (m *ApiAuthInterceptorMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
|
||
// 1. 查找IP白名单
|
||
clientIP := r.Header.Get("X-Forwarded-For")
|
||
if clientIP == "" {
|
||
clientIP = r.Header.Get("X-Real-IP")
|
||
}
|
||
if clientIP == "" {
|
||
clientIP, _, _ = net.SplitHostPort(r.RemoteAddr)
|
||
}
|
||
logx.Infof("当前请求IP:%s", clientIP)
|
||
redisKey := "whitelist_ips"
|
||
isMember, err := m.Rds.SismemberCtx(r.Context(), redisKey, clientIP)
|
||
if err == nil && isMember {
|
||
// 如果缓存中存在该IP,继续执行后续鉴权操作
|
||
// 此处不调用 next(w, r),而是继续后续鉴权逻辑
|
||
// 后续鉴权逻辑将继续执行
|
||
} else {
|
||
isAllowedResp, err := m.WhitelistRpc.MatchWhitelistByIp(r.Context(), &whitelist.MatchWhitelistByIpRequest{Ip: clientIP})
|
||
if err != nil {
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("系统错误,请联系管理员"))
|
||
return
|
||
}
|
||
if !isAllowedResp.Match {
|
||
logx.Debugf("未经授权的IP%s", clientIP)
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("未经授权的IP"))
|
||
return
|
||
}
|
||
}
|
||
|
||
// 2、查找相关accessId
|
||
accessId := r.Header.Get("Access-Id")
|
||
if accessId == "" {
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("缺少Access-Id"))
|
||
return
|
||
}
|
||
secrets, err := m.SecretRpc.GetSecretBySecretId(r.Context(), &secret.GetSecretBySecretIdRequest{SecretId: accessId})
|
||
if err != nil {
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("系统错误"))
|
||
return
|
||
}
|
||
if secrets.Id == 0 {
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("未经授权的AccessId"))
|
||
return
|
||
}
|
||
|
||
userId := secrets.UserId
|
||
|
||
// 3、额度是否冻结
|
||
info, err := m.UserRpc.GetUserInfo(r.Context(), &user.UserInfoReq{UserId: userId})
|
||
if err != nil {
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("系统错误,请联系管理员"))
|
||
return
|
||
}
|
||
|
||
if info.QuotaExceeded == 1 {
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("账户余额不足,无法请求"))
|
||
return
|
||
}
|
||
|
||
// 4、是否有开通该产品
|
||
pathParts := strings.Split(r.URL.Path, "/")
|
||
productCode := pathParts[len(pathParts)-1]
|
||
userProductRedisKey := fmt.Sprintf("user_products:%d", userId)
|
||
|
||
isMemberUserProduct, err := m.Rds.SismemberCtx(r.Context(), userProductRedisKey, productCode)
|
||
if err == nil && isMemberUserProduct {
|
||
|
||
} else {
|
||
isUserProductAllowedResp, err := m.UserProductRpc.MatchingUserIdProductCode(r.Context(), &userproduct.MatchingUserIdProductCodeRequest{Id: userId, ProductCode: productCode})
|
||
if err != nil {
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("系统错误,请联系管理员"))
|
||
return
|
||
}
|
||
if !isUserProductAllowedResp.Match {
|
||
xhttp.JsonBaseResponseCtx(r.Context(), w, errors.New("未开通此产品"))
|
||
return
|
||
}
|
||
}
|
||
|
||
// 将 userId 存入 context,供后续逻辑使用
|
||
ctx := context.WithValue(r.Context(), "userId", userId)
|
||
ctx = context.WithValue(ctx, "secretKey", secrets.AesKey)
|
||
ctx = context.WithValue(ctx, "productCode", productCode)
|
||
|
||
// 生成流水号
|
||
transactionID := crypto.GenerateTransactionID()
|
||
ctx = context.WithValue(ctx, "transactionID", transactionID)
|
||
|
||
next(w, r.WithContext(ctx))
|
||
}
|
||
}
|