tianyuan-api-server/apps/api/internal/middleware/apiauthinterceptormiddleware.go

137 lines
4.1 KiB
Go
Raw Normal View History

2024-10-02 00:57:17 +08:00
package middleware
import (
"context"
"fmt"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
2024-10-04 23:07:49 +08:00
"net"
"net/http"
"strings"
2024-10-02 00:57:17 +08:00
"tianyuan-api/apps/sentinel/client/secret"
"tianyuan-api/apps/sentinel/client/userproduct"
"tianyuan-api/apps/sentinel/client/whitelist"
"tianyuan-api/apps/sentinel/sentinel"
2024-10-12 20:41:55 +08:00
"tianyuan-api/apps/user/user"
"tianyuan-api/pkg/crypto"
2024-10-15 20:52:51 +08:00
"tianyuan-api/pkg/errs"
"tianyuan-api/pkg/response"
2024-10-02 00:57:17 +08:00
)
type ApiAuthInterceptorMiddleware struct {
WhitelistRpc sentinel.WhitelistClient
SecretRpc sentinel.SecretClient
UserProductRpc sentinel.UserProductClient
2024-10-12 20:41:55 +08:00
UserRpc user.UserClient
2024-10-02 00:57:17 +08:00
Rds *redis.Redis
}
func NewApiAuthInterceptorMiddleware(
whitelistRpc sentinel.WhitelistClient,
secretRpc sentinel.SecretClient,
userProductRpc sentinel.UserProductClient,
2024-10-12 20:41:55 +08:00
userRpc user.UserClient,
2024-10-02 00:57:17 +08:00
rds *redis.Redis) *ApiAuthInterceptorMiddleware {
return &ApiAuthInterceptorMiddleware{
WhitelistRpc: whitelistRpc,
SecretRpc: secretRpc,
UserProductRpc: userProductRpc,
2024-10-12 20:41:55 +08:00
UserRpc: userRpc,
2024-10-02 00:57:17 +08:00
Rds: rds,
}
}
func (m *ApiAuthInterceptorMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 1. 查找IP白名单
clientIP := r.Header.Get("X-Forwarded-For")
if clientIP == "" {
clientIP = r.Header.Get("X-Real-IP")
}
if clientIP == "" {
clientIP, _, _ = net.SplitHostPort(r.RemoteAddr)
}
logx.Infof("当前请求IP%s", clientIP)
redisKey := "whitelist_ips"
isMember, err := m.Rds.SismemberCtx(r.Context(), redisKey, clientIP)
if err == nil && isMember {
// 如果缓存中存在该IP继续执行后续鉴权操作
// 此处不调用 next(w, r),而是继续后续鉴权逻辑
// 后续鉴权逻辑将继续执行
} else {
2024-10-15 20:52:51 +08:00
isAllowedResp, matchWhitelistByIpErr := m.WhitelistRpc.MatchWhitelistByIp(r.Context(), &whitelist.MatchWhitelistByIpRequest{Ip: clientIP})
if matchWhitelistByIpErr != nil {
response.Fail(r.Context(), w, errs.ErrSystem, nil)
2024-10-02 00:57:17 +08:00
return
}
if !isAllowedResp.Match {
logx.Debugf("未经授权的IP%s", clientIP)
2024-10-15 20:52:51 +08:00
response.Fail(r.Context(), w, errs.ErrUnauthorizedIP, nil)
2024-10-02 00:57:17 +08:00
return
}
}
// 2、查找相关accessId
accessId := r.Header.Get("Access-Id")
if accessId == "" {
2024-10-15 20:52:51 +08:00
response.Fail(r.Context(), w, errs.ErrMissingAccessID, nil)
2024-10-02 00:57:17 +08:00
return
}
secrets, err := m.SecretRpc.GetSecretBySecretId(r.Context(), &secret.GetSecretBySecretIdRequest{SecretId: accessId})
if err != nil {
2024-10-15 20:52:51 +08:00
response.Fail(r.Context(), w, errs.ErrSystem, nil)
2024-10-02 00:57:17 +08:00
return
}
if secrets.Id == 0 {
2024-10-15 20:52:51 +08:00
response.Fail(r.Context(), w, errs.ErrUnauthorizedAccessID, nil)
2024-10-02 00:57:17 +08:00
return
}
userId := secrets.UserId
2024-10-12 20:41:55 +08:00
// 3、额度是否冻结
info, err := m.UserRpc.GetUserInfo(r.Context(), &user.UserInfoReq{UserId: userId})
if err != nil {
2024-10-15 20:52:51 +08:00
response.Fail(r.Context(), w, errs.ErrSystem, nil)
2024-10-12 20:41:55 +08:00
return
}
if info.QuotaExceeded == 1 {
2024-10-15 20:52:51 +08:00
response.Fail(r.Context(), w, errs.ErrInsufficientBalance, nil)
2024-10-12 20:41:55 +08:00
return
}
// 4、是否有开通该产品
2024-10-02 00:57:17 +08:00
pathParts := strings.Split(r.URL.Path, "/")
productCode := pathParts[len(pathParts)-1]
userProductRedisKey := fmt.Sprintf("user_products:%d", userId)
isMemberUserProduct, err := m.Rds.SismemberCtx(r.Context(), userProductRedisKey, productCode)
if err == nil && isMemberUserProduct {
} else {
isUserProductAllowedResp, err := m.UserProductRpc.MatchingUserIdProductCode(r.Context(), &userproduct.MatchingUserIdProductCodeRequest{Id: userId, ProductCode: productCode})
if err != nil {
2024-10-15 20:52:51 +08:00
response.Fail(r.Context(), w, errs.ErrSystem, nil)
2024-10-02 00:57:17 +08:00
return
}
if !isUserProductAllowedResp.Match {
2024-10-15 20:52:51 +08:00
response.Fail(r.Context(), w, errs.ErrProductNotAvailable, nil)
2024-10-02 00:57:17 +08:00
return
}
}
// 将 userId 存入 context供后续逻辑使用
ctx := context.WithValue(r.Context(), "userId", userId)
2024-10-12 20:41:55 +08:00
ctx = context.WithValue(ctx, "secretKey", secrets.AesKey)
ctx = context.WithValue(ctx, "productCode", productCode)
// 生成流水号
transactionID := crypto.GenerateTransactionID()
ctx = context.WithValue(ctx, "transactionID", transactionID)
2024-10-02 00:57:17 +08:00
next(w, r.WithContext(ctx))
}
}