84 lines
2.4 KiB
Go
84 lines
2.4 KiB
Go
package middleware
|
||
|
||
import (
|
||
"net/http"
|
||
"time"
|
||
|
||
"bd-server/app/main/api/internal/config"
|
||
"bd-server/app/main/model"
|
||
jwtx "bd-server/common/jwt"
|
||
)
|
||
|
||
const (
|
||
HeaderMembershipExpired = "X-Membership-Expired"
|
||
)
|
||
|
||
// MembershipExpiredInterceptor 检测代理会员是否过期,过期则写入响应头
|
||
// 由于是全局中间件(在 AuthInterceptor 之前执行),需要自己从请求头解析 JWT
|
||
type MembershipExpiredInterceptor struct {
|
||
AgentModel model.AgentModel
|
||
Config config.Config
|
||
}
|
||
|
||
func NewMembershipExpiredInterceptor(agentModel model.AgentModel, c config.Config) *MembershipExpiredInterceptor {
|
||
return &MembershipExpiredInterceptor{
|
||
AgentModel: agentModel,
|
||
Config: c,
|
||
}
|
||
}
|
||
|
||
func (m *MembershipExpiredInterceptor) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
// 先检测会员过期状态(必须在 next() 之前设置响应头,否则响应已发送 header 无法生效)
|
||
needNotifyExpired := false
|
||
|
||
// 由于全局中间件在路由中间件之前执行,context 中可能没有 claims
|
||
// 先尝试从 context 获取,如果没有则自己从 Authorization 头解析
|
||
claims := getClaimsFromContext(r)
|
||
if claims == nil {
|
||
claims = parseClaimsFromAuthHeader(r, m.Config.JwtAuth.AccessSecret)
|
||
}
|
||
|
||
if claims != nil && claims.UserType == model.UserTypeNormal && r.URL.Path != "/api/v1/agent/info" {
|
||
agent, err := m.AgentModel.FindOneByUserId(r.Context(), claims.UserId)
|
||
if err == nil && agent != nil {
|
||
// 到期时间已过,说明会员已过期(不管等级是否已降级为 normal)
|
||
if agent.MembershipExpiryTime.Valid && !agent.MembershipExpiryTime.Time.After(time.Now()) {
|
||
needNotifyExpired = true
|
||
}
|
||
}
|
||
}
|
||
|
||
// 在 next() 之前设置响应头(此时 headers 还没发送给客户端)
|
||
if needNotifyExpired {
|
||
w.Header().Set(HeaderMembershipExpired, "true")
|
||
}
|
||
|
||
// 执行业务逻辑
|
||
next(w, r)
|
||
}
|
||
}
|
||
|
||
func getClaimsFromContext(r *http.Request) *jwtx.JwtClaims {
|
||
value := r.Context().Value(jwtx.ExtraKey)
|
||
if value == nil {
|
||
return nil
|
||
}
|
||
if claims, ok := value.(*jwtx.JwtClaims); ok {
|
||
return claims
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func parseClaimsFromAuthHeader(r *http.Request, secret string) *jwtx.JwtClaims {
|
||
authHeader := r.Header.Get("Authorization")
|
||
if authHeader == "" {
|
||
return nil
|
||
}
|
||
claims, err := jwtx.ParseJwtToken(authHeader, secret)
|
||
if err != nil {
|
||
return nil
|
||
}
|
||
return claims
|
||
}
|