84 lines
2.4 KiB
Go
84 lines
2.4 KiB
Go
|
|
package middleware
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"net/http"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"bd-server/app/main/api/internal/config"
|
|||
|
|
"bd-server/app/main/model"
|
|||
|
|
jwtx "bd-server/common/jwt"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
const (
|
|||
|
|
HeaderMembershipExpired = "X-Membership-Expired"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// MembershipExpiredInterceptor 检测代理会员是否过期,过期则写入响应头
|
|||
|
|
// 由于是全局中间件(在 AuthInterceptor 之前执行),需要自己从请求头解析 JWT
|
|||
|
|
type MembershipExpiredInterceptor struct {
|
|||
|
|
AgentModel model.AgentModel
|
|||
|
|
Config config.Config
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func NewMembershipExpiredInterceptor(agentModel model.AgentModel, c config.Config) *MembershipExpiredInterceptor {
|
|||
|
|
return &MembershipExpiredInterceptor{
|
|||
|
|
AgentModel: agentModel,
|
|||
|
|
Config: c,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (m *MembershipExpiredInterceptor) Handle(next http.HandlerFunc) http.HandlerFunc {
|
|||
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
// 先检测会员过期状态(必须在 next() 之前设置响应头,否则响应已发送 header 无法生效)
|
|||
|
|
needNotifyExpired := false
|
|||
|
|
|
|||
|
|
// 由于全局中间件在路由中间件之前执行,context 中可能没有 claims
|
|||
|
|
// 先尝试从 context 获取,如果没有则自己从 Authorization 头解析
|
|||
|
|
claims := getClaimsFromContext(r)
|
|||
|
|
if claims == nil {
|
|||
|
|
claims = parseClaimsFromAuthHeader(r, m.Config.JwtAuth.AccessSecret)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if claims != nil && claims.UserType == model.UserTypeNormal && r.URL.Path != "/api/v1/agent/info" {
|
|||
|
|
agent, err := m.AgentModel.FindOneByUserId(r.Context(), claims.UserId)
|
|||
|
|
if err == nil && agent != nil {
|
|||
|
|
// 到期时间已过,说明会员已过期(不管等级是否已降级为 normal)
|
|||
|
|
if agent.MembershipExpiryTime.Valid && !agent.MembershipExpiryTime.Time.After(time.Now()) {
|
|||
|
|
needNotifyExpired = true
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 在 next() 之前设置响应头(此时 headers 还没发送给客户端)
|
|||
|
|
if needNotifyExpired {
|
|||
|
|
w.Header().Set(HeaderMembershipExpired, "true")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 执行业务逻辑
|
|||
|
|
next(w, r)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func getClaimsFromContext(r *http.Request) *jwtx.JwtClaims {
|
|||
|
|
value := r.Context().Value(jwtx.ExtraKey)
|
|||
|
|
if value == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
if claims, ok := value.(*jwtx.JwtClaims); ok {
|
|||
|
|
return claims
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func parseClaimsFromAuthHeader(r *http.Request, secret string) *jwtx.JwtClaims {
|
|||
|
|
authHeader := r.Header.Get("Authorization")
|
|||
|
|
if authHeader == "" {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
claims, err := jwtx.ParseJwtToken(authHeader, secret)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
return claims
|
|||
|
|
}
|