262 lines
6.6 KiB
Go
262 lines
6.6 KiB
Go
package middleware
|
||
|
||
import (
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"tyapi-server/internal/config"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/golang-jwt/jwt/v5"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// JWTAuthMiddleware JWT认证中间件
|
||
type JWTAuthMiddleware struct {
|
||
config *config.Config
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// NewJWTAuthMiddleware 创建JWT认证中间件
|
||
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
|
||
return &JWTAuthMiddleware{
|
||
config: cfg,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// GetName 返回中间件名称
|
||
func (m *JWTAuthMiddleware) GetName() string {
|
||
return "jwt_auth"
|
||
}
|
||
|
||
// GetPriority 返回中间件优先级
|
||
func (m *JWTAuthMiddleware) GetPriority() int {
|
||
return 60 // 中等优先级,在日志之后,业务处理之前
|
||
}
|
||
|
||
// Handle returns the gin handler that enforces JWT authentication on the
// routes it is attached to.
//
// The request must carry "Authorization: Bearer <token>". The scheme name is
// matched case-insensitively (auth schemes are case-insensitive per RFC 7235)
// and whitespace around the token is ignored. On success the token's user
// fields are stored in the gin context as "user_id", "username", "email" and
// "token_claims". It also sets "is_authenticated" to true, the same key
// OptionalAuthMiddleware uses, and the chain continues. Any failure aborts
// the request with a 401 JSON response via respondUnauthorized. Invalid
// tokens are additionally logged at warn level together with the request_id.
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
	const bearerScheme = "Bearer"

	return func(c *gin.Context) {
		authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
		if authHeader == "" {
			m.respondUnauthorized(c, "Missing authorization header")
			return
		}

		// Split "<scheme> <credentials>" at the first space; the scheme
		// comparison is case-insensitive so "bearer" is accepted too.
		scheme, credentials, _ := strings.Cut(authHeader, " ")
		if !strings.EqualFold(scheme, bearerScheme) {
			m.respondUnauthorized(c, "Invalid authorization header format")
			return
		}

		tokenString := strings.TrimSpace(credentials)
		if tokenString == "" {
			m.respondUnauthorized(c, "Missing token")
			return
		}

		claims, err := m.validateToken(tokenString)
		if err != nil {
			m.logger.Warn("Invalid token",
				zap.Error(err),
				zap.String("request_id", c.GetString("request_id")))
			m.respondUnauthorized(c, "Invalid token")
			return
		}

		// Expose the authenticated identity to downstream handlers.
		c.Set("is_authenticated", true)
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Set("token_claims", claims)

		c.Next()
	}
}
|
||
|
||
// IsGlobal 是否为全局中间件
|
||
func (m *JWTAuthMiddleware) IsGlobal() bool {
|
||
return false // 不是全局中间件,需要手动应用到需要认证的路由
|
||
}
|
||
|
||
// JWTClaims JWT声明结构
|
||
type JWTClaims struct {
|
||
UserID string `json:"user_id"`
|
||
Username string `json:"username"`
|
||
Email string `json:"email"`
|
||
jwt.RegisteredClaims
|
||
}
|
||
|
||
// validateToken 验证JWT token
|
||
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
|
||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||
// 验证签名方法
|
||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
return nil, jwt.ErrSignatureInvalid
|
||
}
|
||
return []byte(m.config.JWT.Secret), nil
|
||
})
|
||
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
claims, ok := token.Claims.(*JWTClaims)
|
||
if !ok || !token.Valid {
|
||
return nil, jwt.ErrSignatureInvalid
|
||
}
|
||
|
||
return claims, nil
|
||
}
|
||
|
||
// respondUnauthorized aborts the remaining handler chain and writes a 401
// JSON response. message goes into the "error" field, and "message" is always
// "Unauthorized". The body also echoes the context's "request_id" (empty if
// unset) and the current Unix time in seconds.
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
	body := gin.H{
		"success":    false,
		"message":    "Unauthorized",
		"error":      message,
		"request_id": c.GetString("request_id"),
		"timestamp":  time.Now().Unix(),
	}
	c.AbortWithStatusJSON(http.StatusUnauthorized, body)
}
|
||
|
||
// GenerateToken issues an HS256-signed access token for the given user.
//
// The payload holds the user fields plus issuer "tyapi-server", subject
// userID and audience "tyapi-client". Issued-at and not-before are set to
// now, and expiry to now + config.JWT.ExpiresIn. It returns the serialized
// token or the signing error.
func (m *JWTAuthMiddleware) GenerateToken(userID, username, email string) (string, error) {
	issuedAt := time.Now()
	expiry := issuedAt.Add(m.config.JWT.ExpiresIn)

	registered := jwt.RegisteredClaims{
		Issuer:    "tyapi-server",
		Subject:   userID,
		Audience:  []string{"tyapi-client"},
		ExpiresAt: jwt.NewNumericDate(expiry),
		NotBefore: jwt.NewNumericDate(issuedAt),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
	}

	payload := &JWTClaims{
		UserID:           userID,
		Username:         username,
		Email:            email,
		RegisteredClaims: registered,
	}

	secret := []byte(m.config.JWT.Secret)
	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(secret)
}
|
||
|
||
// GenerateRefreshToken issues an HS256-signed refresh token for userID.
//
// The token holds only registered claims: issuer "tyapi-server", subject
// userID and audience "tyapi-refresh". That audience is what
// ValidateRefreshToken looks for. Issued-at and not-before are now, and
// expiry is now + config.JWT.RefreshExpiresIn. It returns the serialized
// token or the signing error.
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
	issuedAt := time.Now()
	expiry := issuedAt.Add(m.config.JWT.RefreshExpiresIn)

	payload := &jwt.RegisteredClaims{
		Issuer:    "tyapi-server",
		Subject:   userID,
		Audience:  []string{"tyapi-refresh"},
		ExpiresAt: jwt.NewNumericDate(expiry),
		NotBefore: jwt.NewNumericDate(issuedAt),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
	}

	secret := []byte(m.config.JWT.Secret)
	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(secret)
}
|
||
|
||
// ValidateRefreshToken 验证刷新token
|
||
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
|
||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||
return nil, jwt.ErrSignatureInvalid
|
||
}
|
||
return []byte(m.config.JWT.Secret), nil
|
||
})
|
||
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||
if !ok || !token.Valid {
|
||
return "", jwt.ErrSignatureInvalid
|
||
}
|
||
|
||
// 检查是否为刷新token
|
||
if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
|
||
return "", jwt.ErrSignatureInvalid
|
||
}
|
||
|
||
return claims.Subject, nil
|
||
}
|
||
|
||
// OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录)
|
||
type OptionalAuthMiddleware struct {
|
||
jwtAuth *JWTAuthMiddleware
|
||
}
|
||
|
||
// NewOptionalAuthMiddleware 创建可选认证中间件
|
||
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
|
||
return &OptionalAuthMiddleware{
|
||
jwtAuth: jwtAuth,
|
||
}
|
||
}
|
||
|
||
// GetName 返回中间件名称
|
||
func (m *OptionalAuthMiddleware) GetName() string {
|
||
return "optional_auth"
|
||
}
|
||
|
||
// GetPriority 返回中间件优先级
|
||
func (m *OptionalAuthMiddleware) GetPriority() int {
|
||
return 60 // 与JWT认证中间件相同
|
||
}
|
||
|
||
// Handle returns the gin handler for optional authentication. It never
// aborts the request.
//
// If the request carries "Authorization: Bearer <token>" and the token passes
// validation, the handler does two things. It sets "is_authenticated" to true
// and stores "user_id", "username", "email" and "token_claims" in the context.
// The scheme is matched case-insensitively per RFC 7235, and whitespace
// around the token is ignored. In every other case (no header, another
// scheme, empty or invalid token) it sets "is_authenticated" to false. The
// request then continues as anonymous.
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
	const bearerScheme = "Bearer"

	return func(c *gin.Context) {
		// anonymous marks the request unauthenticated and continues; a
		// missing or unusable token is deliberately not an error here.
		anonymous := func() {
			c.Set("is_authenticated", false)
			c.Next()
		}

		authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
		if authHeader == "" {
			anonymous()
			return
		}

		// Split "<scheme> <credentials>" at the first space.
		scheme, credentials, _ := strings.Cut(authHeader, " ")
		if !strings.EqualFold(scheme, bearerScheme) {
			anonymous()
			return
		}

		claims, err := m.jwtAuth.validateToken(strings.TrimSpace(credentials))
		if err != nil {
			anonymous()
			return
		}

		// Valid token: expose the user's identity to downstream handlers.
		c.Set("is_authenticated", true)
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Set("token_claims", claims)

		c.Next()
	}
}
|
||
|
||
// IsGlobal 是否为全局中间件
|
||
func (m *OptionalAuthMiddleware) IsGlobal() bool {
|
||
return false
|
||
}
|