Files
bd-server/app/main/api/internal/middleware/membershipinterceptormiddleware.go
2026-05-08 11:30:05 +08:00

84 lines
2.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"net/http"
"time"
"bd-server/app/main/api/internal/config"
"bd-server/app/main/model"
jwtx "bd-server/common/jwt"
)
// HeaderMembershipExpired is the response header that MembershipExpiredInterceptor
// sets to "true" when the requesting user's agent membership has expired.
const HeaderMembershipExpired = "X-Membership-Expired"
// MembershipExpiredInterceptor checks whether the requesting agent's membership
// has expired and, if so, writes a response header.
// Because it is a global middleware (it runs before AuthInterceptor), it has to
// parse the JWT from the request headers itself.
type MembershipExpiredInterceptor struct {
	// AgentModel is used to look up the agent record for the requesting user.
	AgentModel model.AgentModel

	// Config supplies the JWT access secret used to parse the Authorization header.
	Config config.Config
}
// NewMembershipExpiredInterceptor builds an interceptor that looks agents up via
// agentModel and reads the JWT access secret from c.
func NewMembershipExpiredInterceptor(agentModel model.AgentModel, c config.Config) *MembershipExpiredInterceptor {
	interceptor := new(MembershipExpiredInterceptor)
	interceptor.AgentModel = agentModel
	interceptor.Config = c
	return interceptor
}
// Handle wraps next so that, when the caller's agent membership has expired,
// the response carries HeaderMembershipExpired: true. The request is always
// forwarded to next; this middleware never blocks it.
func (m *MembershipExpiredInterceptor) Handle(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// The header has to be set before next runs: once the handler writes
		// the status line, further header changes are not sent to the client.
		if m.membershipExpired(r) {
			w.Header().Set(HeaderMembershipExpired, "true")
		}
		next(w, r)
	}
}

// membershipExpired reports whether r comes from a normal user whose agent
// membership expiry time is at or before now. Any failure along the way
// (no/invalid token, lookup error, missing agent) yields false.
func (m *MembershipExpiredInterceptor) membershipExpired(r *http.Request) bool {
	// Global middleware runs before route middleware, so the context may not
	// hold claims yet; fall back to parsing the Authorization header directly.
	claims := getClaimsFromContext(r)
	if claims == nil {
		claims = parseClaimsFromAuthHeader(r, m.Config.JwtAuth.AccessSecret)
	}
	if claims == nil || claims.UserType != model.UserTypeNormal {
		return false
	}
	// This path is never flagged. NOTE(review): presumably because the
	// endpoint itself reports membership state — confirm.
	if r.URL.Path == "/api/v1/agent/info" {
		return false
	}

	agent, err := m.AgentModel.FindOneByUserId(r.Context(), claims.UserId)
	if err != nil || agent == nil {
		return false
	}

	// A reached expiry time means the membership has expired, regardless of
	// whether the agent level has already been downgraded to normal.
	expiry := agent.MembershipExpiryTime
	return expiry.Valid && !expiry.Time.After(time.Now())
}
// getClaimsFromContext returns the *jwtx.JwtClaims stored in the request context
// under jwtx.ExtraKey, or nil when nothing is stored there or the stored value
// has a different type.
func getClaimsFromContext(r *http.Request) *jwtx.JwtClaims {
	// The comma-ok assertion yields nil for both a missing value and a value
	// of another type, so no separate nil check is needed.
	stored, _ := r.Context().Value(jwtx.ExtraKey).(*jwtx.JwtClaims)
	return stored
}
// parseClaimsFromAuthHeader parses the Authorization header of r with secret and
// returns the resulting claims, or nil if the header is empty or parsing fails.
// The raw header value is passed to jwtx.ParseJwtToken unchanged.
// NOTE(review): whether a "Bearer " prefix is accepted depends on
// jwtx.ParseJwtToken — confirm.
func parseClaimsFromAuthHeader(r *http.Request, secret string) *jwtx.JwtClaims {
	token := r.Header.Get("Authorization")
	if token == "" {
		return nil
	}
	if parsed, err := jwtx.ParseJwtToken(token, secret); err == nil {
		return parsed
	}
	return nil
}