tianyuan-api-server/apps/api/internal/middleware/apiauthinterceptormiddleware.go
2024-10-12 20:41:55 +08:00

137 lines
4.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"context"
"errors"
"fmt"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
xhttp "github.com/zeromicro/x/http"
"net"
"net/http"
"strings"
"tianyuan-api/apps/sentinel/client/secret"
"tianyuan-api/apps/sentinel/client/userproduct"
"tianyuan-api/apps/sentinel/client/whitelist"
"tianyuan-api/apps/sentinel/sentinel"
"tianyuan-api/apps/user/user"
"tianyuan-api/pkg/crypto"
)
// ApiAuthInterceptorMiddleware authenticates incoming API requests before they
// reach the wrapped handler: IP whitelist, Access-Id secret lookup, account
// quota flag and product subscription. See Handle for the exact order.
type ApiAuthInterceptorMiddleware struct {
	// WhitelistRpc is consulted when the client IP is not found in the Redis cache.
	WhitelistRpc sentinel.WhitelistClient
	// SecretRpc resolves the Access-Id header to a secret (owner UserId, AesKey).
	SecretRpc sentinel.SecretClient
	// UserProductRpc is consulted when the product is not found in the Redis cache.
	UserProductRpc sentinel.UserProductClient
	// UserRpc provides user info, including the QuotaExceeded flag.
	UserRpc user.UserClient
	// Rds holds the cached sets "whitelist_ips" and "user_products:<userId>".
	Rds *redis.Redis
}
// NewApiAuthInterceptorMiddleware assembles the middleware from its RPC
// clients and Redis handle. Each dependency is stored exactly as passed; no
// nil checks are performed here.
func NewApiAuthInterceptorMiddleware(
	whitelistRpc sentinel.WhitelistClient,
	secretRpc sentinel.SecretClient,
	userProductRpc sentinel.UserProductClient,
	userRpc user.UserClient,
	rds *redis.Redis,
) *ApiAuthInterceptorMiddleware {
	mw := new(ApiAuthInterceptorMiddleware)
	mw.WhitelistRpc = whitelistRpc
	mw.SecretRpc = secretRpc
	mw.UserProductRpc = userProductRpc
	mw.UserRpc = userRpc
	mw.Rds = rds
	return mw
}
// Handle returns next wrapped in the API authentication chain. The request
// reaches next only if every check below passes, in this order:
//
//  1. IP whitelist: the client IP (see resolveClientIP) is a member of the
//     Redis set "whitelist_ips", or, on a cache miss or Redis error,
//     WhitelistRpc.MatchWhitelistByIp reports a match.
//  2. Access-Id: the header is present and SecretRpc.GetSecretBySecretId
//     returns a secret with a non-zero Id.
//  3. Quota: UserRpc.GetUserInfo for the secret's owner does not report
//     QuotaExceeded == 1.
//  4. Product: the last segment of the URL path (the product code) is a
//     member of the Redis set "user_products:<userId>", or, on a cache miss
//     or Redis error, UserProductRpc.MatchingUserIdProductCode reports a match.
//
// Any failure writes an error body via xhttp.JsonBaseResponseCtx and stops
// the chain. On success next receives a request whose context carries the
// string keys "userId", "secretKey" (the secret's AesKey), "productCode" and
// "transactionID" (from crypto.GenerateTransactionID).
func (m *ApiAuthInterceptorMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		// 1. IP whitelist.
		clientIP := m.resolveClientIP(r)
		logx.Infof("当前请求IP%s", clientIP)
		ipCached, err := m.Rds.SismemberCtx(ctx, "whitelist_ips", clientIP)
		if err != nil {
			// The Redis set is only a cache; fall through to the RPC below.
			logx.Errorf("whitelist cache lookup for IP %q: %v", clientIP, err)
		}
		if err != nil || !ipCached {
			matchResp, err := m.WhitelistRpc.MatchWhitelistByIp(ctx, &whitelist.MatchWhitelistByIpRequest{Ip: clientIP})
			if err != nil {
				logx.Errorf("whitelist rpc for IP %q: %v", clientIP, err)
				xhttp.JsonBaseResponseCtx(ctx, w, errors.New("系统错误,请联系管理员"))
				return
			}
			if !matchResp.Match {
				logx.Debugf("未经授权的IP%s", clientIP)
				xhttp.JsonBaseResponseCtx(ctx, w, errors.New("未经授权的IP"))
				return
			}
		}

		// 2. The Access-Id header must resolve to an existing secret.
		accessId := r.Header.Get("Access-Id")
		if accessId == "" {
			xhttp.JsonBaseResponseCtx(ctx, w, errors.New("缺少Access-Id"))
			return
		}
		secrets, err := m.SecretRpc.GetSecretBySecretId(ctx, &secret.GetSecretBySecretIdRequest{SecretId: accessId})
		if err != nil {
			logx.Errorf("secret rpc for Access-Id %q: %v", accessId, err)
			xhttp.JsonBaseResponseCtx(ctx, w, errors.New("系统错误"))
			return
		}
		if secrets.Id == 0 {
			xhttp.JsonBaseResponseCtx(ctx, w, errors.New("未经授权的AccessId"))
			return
		}
		userId := secrets.UserId

		// 3. Reject accounts whose quota is flagged as exceeded.
		info, err := m.UserRpc.GetUserInfo(ctx, &user.UserInfoReq{UserId: userId})
		if err != nil {
			logx.Errorf("user info rpc for user %v: %v", userId, err)
			xhttp.JsonBaseResponseCtx(ctx, w, errors.New("系统错误,请联系管理员"))
			return
		}
		if info.QuotaExceeded == 1 {
			xhttp.JsonBaseResponseCtx(ctx, w, errors.New("账户余额不足,无法请求"))
			return
		}

		// 4. The user must have the product named by the last URL path segment.
		// A path without '/' yields the whole path, as strings.Split would.
		productCode := r.URL.Path[strings.LastIndexByte(r.URL.Path, '/')+1:]
		userProductRedisKey := fmt.Sprintf("user_products:%d", userId)
		productCached, err := m.Rds.SismemberCtx(ctx, userProductRedisKey, productCode)
		if err != nil {
			// Cache only; fall through to the RPC below.
			logx.Errorf("user product cache lookup %q for %q: %v", userProductRedisKey, productCode, err)
		}
		if err != nil || !productCached {
			matchResp, err := m.UserProductRpc.MatchingUserIdProductCode(ctx, &userproduct.MatchingUserIdProductCodeRequest{Id: userId, ProductCode: productCode})
			if err != nil {
				logx.Errorf("user product rpc for user %v, product %q: %v", userId, productCode, err)
				xhttp.JsonBaseResponseCtx(ctx, w, errors.New("系统错误,请联系管理员"))
				return
			}
			if !matchResp.Match {
				xhttp.JsonBaseResponseCtx(ctx, w, errors.New("未开通此产品"))
				return
			}
		}

		// Pass the auth results downstream. Plain string keys are kept on purpose:
		// presumably downstream logic reads them back via ctx.Value with these
		// exact strings (verify before switching to a private key type).
		ctx = context.WithValue(ctx, "userId", userId)
		ctx = context.WithValue(ctx, "secretKey", secrets.AesKey)
		ctx = context.WithValue(ctx, "productCode", productCode)
		// Generate the transaction (serial) number for this request.
		transactionID := crypto.GenerateTransactionID()
		ctx = context.WithValue(ctx, "transactionID", transactionID)
		next(w, r.WithContext(ctx))
	}
}

// resolveClientIP returns the IP used for the whitelist check. It prefers
// X-Forwarded-For, then X-Real-IP, then the host part of r.RemoteAddr.
//
// X-Forwarded-For may carry a comma-separated chain, possibly spread over
// several header lines ("client, proxy1, ..."). The rightmost non-empty entry
// is used because it is the one appended by the proxy closest to this server.
// Earlier entries come from the caller and are trivially spoofable.
//
// NOTE(review): both headers are fully client-controlled unless a trusted
// reverse proxy sets or appends them. If this service is reachable directly,
// the IP whitelist can be bypassed by sending a forged header.
func (m *ApiAuthInterceptorMiddleware) resolveClientIP(r *http.Request) string {
	if xff := strings.Join(r.Header.Values("X-Forwarded-For"), ","); xff != "" {
		hops := strings.Split(xff, ",")
		for i := len(hops) - 1; i >= 0; i-- {
			if ip := strings.TrimSpace(hops[i]); ip != "" {
				return ip
			}
		}
	}
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port: use it verbatim rather than an empty IP.
		return r.RemoteAddr
	}
	return host
}