Files
tyapi-server/internal/shared/middleware/auth.go

262 lines
6.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"net/http"
"strings"
"time"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
)
// JWTAuthMiddleware is the JWT authentication middleware.
type JWTAuthMiddleware struct {
	// config is the application configuration; the methods of this type
	// read the JWT secret and token lifetimes from it.
	config *config.Config

	// logger receives a warning whenever a presented token is rejected.
	logger *zap.Logger
}
// NewJWTAuthMiddleware creates a JWT authentication middleware that uses cfg
// for its JWT settings and logger for reporting rejected tokens.
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
	mw := new(JWTAuthMiddleware)
	mw.config = cfg
	mw.logger = logger
	return mw
}
// GetName returns the name of this middleware, "jwt_auth".
func (m *JWTAuthMiddleware) GetName() string {
	const name = "jwt_auth"
	return name
}
// GetPriority returns the priority of this middleware.
//
// The value 60 is a medium priority: per the original author's note it places
// the middleware after logging and before business handling.
func (m *JWTAuthMiddleware) GetPriority() int {
	const priority = 60
	return priority
}
// Handle returns the gin handler that enforces JWT authentication.
//
// The handler expects an "Authorization: Bearer <token>" request header. If
// the header is missing, does not start with "Bearer ", carries an empty
// token, or the token fails validation, the request is aborted with a 401
// response and no further handlers run.
//
// On success the following keys are stored in the gin context before the next
// handler runs: "is_authenticated" (true), "user_id", "username", "email" and
// "token_claims" (the parsed *JWTClaims). "is_authenticated" is set so that
// handlers see the same keys regardless of whether a route is guarded by this
// middleware or by OptionalAuthMiddleware.
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Read the Authorization header.
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			m.respondUnauthorized(c, "Missing authorization header")
			return
		}

		// Only the Bearer scheme is accepted.
		const bearerPrefix = "Bearer "
		if !strings.HasPrefix(authHeader, bearerPrefix) {
			m.respondUnauthorized(c, "Invalid authorization header format")
			return
		}

		// Everything after the prefix is the token.
		tokenString := authHeader[len(bearerPrefix):]
		if tokenString == "" {
			m.respondUnauthorized(c, "Missing token")
			return
		}

		// Validate the token; the underlying reason is logged but not
		// returned to the client.
		claims, err := m.validateToken(tokenString)
		if err != nil {
			m.logger.Warn("Invalid token",
				zap.Error(err),
				zap.String("request_id", c.GetString("request_id")))
			m.respondUnauthorized(c, "Invalid token")
			return
		}

		// Expose the authenticated user to downstream handlers.
		c.Set("is_authenticated", true)
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Set("token_claims", claims)
		c.Next()
	}
}
// IsGlobal reports whether this middleware is a global middleware.
//
// It always reports false: per the original author's note, the middleware is
// applied manually to the routes that require authentication.
func (m *JWTAuthMiddleware) IsGlobal() bool {
	const global = false
	return global
}
// JWTClaims is the claim set carried by access tokens: three custom user
// fields plus the registered JWT claims.
type JWTClaims struct {
	// UserID is serialized as the "user_id" claim.
	UserID string `json:"user_id"`

	// Username is serialized as the "username" claim.
	Username string `json:"username"`

	// Email is serialized as the "email" claim.
	Email string `json:"email"`

	// RegisteredClaims embeds the registered claims from the jwt package.
	jwt.RegisteredClaims
}
// validateToken parses tokenString as an HMAC-signed access token and returns
// its claims.
//
// The token must be signed with the configured JWT secret and must list
// "tyapi-client" (the audience GenerateToken writes) in its "aud" claim.
// Refresh tokens are signed with the same secret and algorithm but carry the
// "tyapi-refresh" audience; without the audience requirement they would be
// accepted here as access tokens with an empty UserID.
//
// A non-nil error means the token must not be trusted; the returned claims
// are nil in that case.
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
	token, err := jwt.ParseWithClaims(
		tokenString,
		&JWTClaims{},
		func(token *jwt.Token) (interface{}, error) {
			// Accept HMAC signing methods only, so the token header cannot
			// select a different algorithm family.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, jwt.ErrSignatureInvalid
			}
			return []byte(m.config.JWT.Secret), nil
		},
		// Reject tokens that were not issued as access tokens.
		jwt.WithAudience("tyapi-client"),
	)
	if err != nil {
		return nil, err
	}

	claims, ok := token.Claims.(*JWTClaims)
	if !ok || !token.Valid {
		return nil, jwt.ErrSignatureInvalid
	}
	return claims, nil
}
// respondUnauthorized writes a 401 JSON response and aborts the handler chain.
//
// The body contains "success" (false), "message" ("Unauthorized"), "error"
// (the supplied message), "request_id" (read from the gin context) and
// "timestamp" (current Unix time in seconds). A "WWW-Authenticate: Bearer"
// challenge header is sent because a 401 response is required to carry one.
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
	// The header must be set before the body is written.
	c.Header("WWW-Authenticate", "Bearer")
	c.JSON(http.StatusUnauthorized, gin.H{
		"success":    false,
		"message":    "Unauthorized",
		"error":      message,
		"request_id": c.GetString("request_id"),
		"timestamp":  time.Now().Unix(),
	})
	c.Abort()
}
// GenerateToken creates a signed access token for the given user.
//
// The token carries userID, username and email as custom claims. Its
// registered claims are: issuer "tyapi-server", subject userID, audience
// "tyapi-client", issued-at and not-before set to the current time, and
// expiry set to the current time plus the configured JWT.ExpiresIn. The token
// is signed with HS256 using the configured JWT secret.
func (m *JWTAuthMiddleware) GenerateToken(userID, username, email string) (string, error) {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(m.config.JWT.ExpiresIn)

	registered := jwt.RegisteredClaims{
		Issuer:    "tyapi-server",
		Subject:   userID,
		Audience:  []string{"tyapi-client"},
		ExpiresAt: jwt.NewNumericDate(expiresAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
	}
	payload := &JWTClaims{
		UserID:           userID,
		Username:         username,
		Email:            email,
		RegisteredClaims: registered,
	}

	signingKey := []byte(m.config.JWT.Secret)
	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(signingKey)
}
// GenerateRefreshToken creates a signed refresh token for userID.
//
// The token holds registered claims only: issuer "tyapi-server", subject
// userID, audience "tyapi-refresh", issued-at and not-before set to the
// current time, and expiry set to the current time plus the configured
// JWT.RefreshExpiresIn. It is signed with HS256 using the configured JWT
// secret.
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(m.config.JWT.RefreshExpiresIn)

	payload := &jwt.RegisteredClaims{
		Issuer:    "tyapi-server",
		Subject:   userID,
		Audience:  []string{"tyapi-refresh"},
		ExpiresAt: jwt.NewNumericDate(expiresAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
		IssuedAt:  jwt.NewNumericDate(issuedAt),
	}

	signingKey := []byte(m.config.JWT.Secret)
	return jwt.NewWithClaims(jwt.SigningMethodHS256, payload).SignedString(signingKey)
}
// ValidateRefreshToken verifies a refresh token and returns its subject.
//
// The token must be HMAC-signed with the configured JWT secret, must pass the
// jwt parser's validation, and must have "tyapi-refresh" as the first entry
// of its audience. On failure the returned string is empty; parse errors are
// returned as-is, and every other rejection is reported as
// jwt.ErrSignatureInvalid.
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
	// keyFunc hands out the secret only for HMAC-signed tokens.
	keyFunc := func(t *jwt.Token) (interface{}, error) {
		if _, isHMAC := t.Method.(*jwt.SigningMethodHMAC); isHMAC {
			return []byte(m.config.JWT.Secret), nil
		}
		return nil, jwt.ErrSignatureInvalid
	}

	parsed, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, keyFunc)
	if err != nil {
		return "", err
	}

	registered, ok := parsed.Claims.(*jwt.RegisteredClaims)
	switch {
	case !ok, !parsed.Valid:
		return "", jwt.ErrSignatureInvalid
	// Only tokens issued as refresh tokens are accepted. Cases are evaluated
	// left to right, so the index is read only when the audience is non-empty.
	case len(registered.Audience) == 0, registered.Audience[0] != "tyapi-refresh":
		return "", jwt.ErrSignatureInvalid
	}
	return registered.Subject, nil
}
// OptionalAuthMiddleware is the optional authentication middleware, for
// routes where the user may or may not be logged in.
type OptionalAuthMiddleware struct {
	// jwtAuth is the JWT middleware whose token validation is reused.
	jwtAuth *JWTAuthMiddleware
}
// NewOptionalAuthMiddleware creates an optional authentication middleware
// that validates tokens through jwtAuth.
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
	mw := new(OptionalAuthMiddleware)
	mw.jwtAuth = jwtAuth
	return mw
}
// GetName returns the name of this middleware, "optional_auth".
func (m *OptionalAuthMiddleware) GetName() string {
	const name = "optional_auth"
	return name
}
// GetPriority returns the priority of this middleware, 60, the same value
// the JWT authentication middleware uses.
func (m *OptionalAuthMiddleware) GetPriority() int {
	const priority = 60
	return priority
}
// Handle returns the gin handler that performs optional authentication.
//
// The handler never rejects a request. When the Authorization header is
// absent, does not start with "Bearer ", or carries a token that fails
// validation, "is_authenticated" is set to false in the gin context and the
// next handler runs. When the token is valid, "is_authenticated" is set to
// true and "user_id", "username", "email" and "token_claims" are populated
// from the token before the next handler runs.
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
	const bearerPrefix = "Bearer "

	// continueAnonymous marks the request as unauthenticated and lets it
	// proceed down the chain.
	continueAnonymous := func(c *gin.Context) {
		c.Set("is_authenticated", false)
		c.Next()
	}

	return func(c *gin.Context) {
		header := c.GetHeader("Authorization")

		// No header, or a non-Bearer scheme: treat the caller as anonymous.
		if header == "" || !strings.HasPrefix(header, bearerPrefix) {
			continueAnonymous(c)
			return
		}

		// An invalid token is not an error here; the caller is simply
		// treated as anonymous.
		claims, err := m.jwtAuth.validateToken(header[len(bearerPrefix):])
		if err != nil {
			continueAnonymous(c)
			return
		}

		// Valid token: expose the user to downstream handlers.
		c.Set("is_authenticated", true)
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Set("token_claims", claims)
		c.Next()
	}
}
// IsGlobal reports whether this middleware is a global middleware; it always
// reports false.
func (m *OptionalAuthMiddleware) IsGlobal() bool {
	const global = false
	return global
}