| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | package middleware | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import ( | 
					
						
							|  |  |  |  | 	"net/http" | 
					
						
							|  |  |  |  | 	"strings" | 
					
						
							|  |  |  |  | 	"time" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	"tyapi-server/internal/config" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	"github.com/gin-gonic/gin" | 
					
						
							|  |  |  |  | 	"github.com/golang-jwt/jwt/v5" | 
					
						
							|  |  |  |  | 	"go.uber.org/zap" | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // JWTAuthMiddleware JWT认证中间件 | 
					
						
							|  |  |  |  | type JWTAuthMiddleware struct { | 
					
						
							|  |  |  |  | 	config *config.Config | 
					
						
							|  |  |  |  | 	logger *zap.Logger | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // NewJWTAuthMiddleware 创建JWT认证中间件 | 
					
						
							|  |  |  |  | func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware { | 
					
						
							|  |  |  |  | 	return &JWTAuthMiddleware{ | 
					
						
							|  |  |  |  | 		config: cfg, | 
					
						
							|  |  |  |  | 		logger: logger, | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetName 返回中间件名称 | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) GetName() string { | 
					
						
							|  |  |  |  | 	return "jwt_auth" | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetPriority 返回中间件优先级 | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) GetPriority() int { | 
					
						
							|  |  |  |  | 	return 60 // 中等优先级,在日志之后,业务处理之前 | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // Handle 返回中间件处理函数 | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc { | 
					
						
							|  |  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  |  | 		// 获取Authorization头部 | 
					
						
							|  |  |  |  | 		authHeader := c.GetHeader("Authorization") | 
					
						
							|  |  |  |  | 		if authHeader == "" { | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 			m.respondUnauthorized(c, "缺少认证头部") | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 检查Bearer前缀 | 
					
						
							|  |  |  |  | 		const bearerPrefix = "Bearer " | 
					
						
							|  |  |  |  | 		if !strings.HasPrefix(authHeader, bearerPrefix) { | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 			m.respondUnauthorized(c, "认证头部格式无效") | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 提取token | 
					
						
							|  |  |  |  | 		tokenString := authHeader[len(bearerPrefix):] | 
					
						
							|  |  |  |  | 		if tokenString == "" { | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 			m.respondUnauthorized(c, "缺少认证令牌") | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 验证token | 
					
						
							|  |  |  |  | 		claims, err := m.validateToken(tokenString) | 
					
						
							|  |  |  |  | 		if err != nil { | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 			m.logger.Warn("无效的认证令牌", | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 				zap.Error(err), | 
					
						
							|  |  |  |  | 				zap.String("request_id", c.GetString("request_id"))) | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 			m.respondUnauthorized(c, "认证令牌无效") | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 将用户信息添加到上下文 | 
					
						
							|  |  |  |  | 		c.Set("user_id", claims.UserID) | 
					
						
							|  |  |  |  | 		c.Set("username", claims.Username) | 
					
						
							|  |  |  |  | 		c.Set("email", claims.Email) | 
					
						
							|  |  |  |  | 		c.Set("token_claims", claims) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		c.Next() | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // IsGlobal 是否为全局中间件 | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) IsGlobal() bool { | 
					
						
							|  |  |  |  | 	return false // 不是全局中间件,需要手动应用到需要认证的路由 | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							// JWTClaims is the claim set of the access tokens issued by GenerateToken.
// It contains the custom user_id, username and email fields plus the embedded
// registered claims (iss, sub, aud, exp, nbf, iat, ...).
//
// NOTE: HS256 tokens are signed, not encrypted, so every field here can be
// read by anyone who holds the token. The field order also fixes the order
// of keys in the serialized payload.
type JWTClaims struct {
	UserID   string `json:"user_id"`  // user ID passed to GenerateToken (also copied into Subject there)
	Username string `json:"username"` // username passed to GenerateToken
	Email    string `json:"email"`    // email passed to GenerateToken
	jwt.RegisteredClaims
}
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // validateToken 验证JWT token | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) { | 
					
						
							|  |  |  |  | 	token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) { | 
					
						
							|  |  |  |  | 		// 验证签名方法 | 
					
						
							|  |  |  |  | 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | 
					
						
							|  |  |  |  | 			return nil, jwt.ErrSignatureInvalid | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return []byte(m.config.JWT.Secret), nil | 
					
						
							|  |  |  |  | 	}) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	if err != nil { | 
					
						
							|  |  |  |  | 		return nil, err | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	claims, ok := token.Claims.(*JWTClaims) | 
					
						
							|  |  |  |  | 	if !ok || !token.Valid { | 
					
						
							|  |  |  |  | 		return nil, jwt.ErrSignatureInvalid | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	return claims, nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // respondUnauthorized 返回未授权响应 | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) { | 
					
						
							|  |  |  |  | 	c.JSON(http.StatusUnauthorized, gin.H{ | 
					
						
							|  |  |  |  | 		"success":    false, | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 		"message":    "认证失败", | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 		"error":      message, | 
					
						
							|  |  |  |  | 		"request_id": c.GetString("request_id"), | 
					
						
							|  |  |  |  | 		"timestamp":  time.Now().Unix(), | 
					
						
							|  |  |  |  | 	}) | 
					
						
							|  |  |  |  | 	c.Abort() | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GenerateToken 生成JWT token | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) GenerateToken(userID, username, email string) (string, error) { | 
					
						
							|  |  |  |  | 	now := time.Now() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	claims := &JWTClaims{ | 
					
						
							|  |  |  |  | 		UserID:   userID, | 
					
						
							|  |  |  |  | 		Username: username, | 
					
						
							|  |  |  |  | 		Email:    email, | 
					
						
							|  |  |  |  | 		RegisteredClaims: jwt.RegisteredClaims{ | 
					
						
							|  |  |  |  | 			Issuer:    "tyapi-server", | 
					
						
							|  |  |  |  | 			Subject:   userID, | 
					
						
							|  |  |  |  | 			Audience:  []string{"tyapi-client"}, | 
					
						
							|  |  |  |  | 			ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)), | 
					
						
							|  |  |  |  | 			NotBefore: jwt.NewNumericDate(now), | 
					
						
							|  |  |  |  | 			IssuedAt:  jwt.NewNumericDate(now), | 
					
						
							|  |  |  |  | 		}, | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | 
					
						
							|  |  |  |  | 	return token.SignedString([]byte(m.config.JWT.Secret)) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GenerateRefreshToken 生成刷新token | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) { | 
					
						
							|  |  |  |  | 	now := time.Now() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	claims := &jwt.RegisteredClaims{ | 
					
						
							|  |  |  |  | 		Issuer:    "tyapi-server", | 
					
						
							|  |  |  |  | 		Subject:   userID, | 
					
						
							|  |  |  |  | 		Audience:  []string{"tyapi-refresh"}, | 
					
						
							|  |  |  |  | 		ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.RefreshExpiresIn)), | 
					
						
							|  |  |  |  | 		NotBefore: jwt.NewNumericDate(now), | 
					
						
							|  |  |  |  | 		IssuedAt:  jwt.NewNumericDate(now), | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | 
					
						
							|  |  |  |  | 	return token.SignedString([]byte(m.config.JWT.Secret)) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // ValidateRefreshToken 验证刷新token | 
					
						
							|  |  |  |  | func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) { | 
					
						
							|  |  |  |  | 	token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) { | 
					
						
							|  |  |  |  | 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | 
					
						
							|  |  |  |  | 			return nil, jwt.ErrSignatureInvalid | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return []byte(m.config.JWT.Secret), nil | 
					
						
							|  |  |  |  | 	}) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	if err != nil { | 
					
						
							|  |  |  |  | 		return "", err | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	claims, ok := token.Claims.(*jwt.RegisteredClaims) | 
					
						
							|  |  |  |  | 	if !ok || !token.Valid { | 
					
						
							|  |  |  |  | 		return "", jwt.ErrSignatureInvalid | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// 检查是否为刷新token | 
					
						
							|  |  |  |  | 	if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" { | 
					
						
							|  |  |  |  | 		return "", jwt.ErrSignatureInvalid | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	return claims.Subject, nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录) | 
					
						
							|  |  |  |  | type OptionalAuthMiddleware struct { | 
					
						
							|  |  |  |  | 	jwtAuth *JWTAuthMiddleware | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // NewOptionalAuthMiddleware 创建可选认证中间件 | 
					
						
							|  |  |  |  | func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware { | 
					
						
							|  |  |  |  | 	return &OptionalAuthMiddleware{ | 
					
						
							|  |  |  |  | 		jwtAuth: jwtAuth, | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetName 返回中间件名称 | 
					
						
							|  |  |  |  | func (m *OptionalAuthMiddleware) GetName() string { | 
					
						
							|  |  |  |  | 	return "optional_auth" | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetPriority 返回中间件优先级 | 
					
						
							|  |  |  |  | func (m *OptionalAuthMiddleware) GetPriority() int { | 
					
						
							|  |  |  |  | 	return 60 // 与JWT认证中间件相同 | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // Handle 返回中间件处理函数 | 
					
						
							|  |  |  |  | func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc { | 
					
						
							|  |  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  |  | 		// 获取Authorization头部 | 
					
						
							|  |  |  |  | 		authHeader := c.GetHeader("Authorization") | 
					
						
							|  |  |  |  | 		if authHeader == "" { | 
					
						
							|  |  |  |  | 			// 没有认证头部,设置匿名用户标识 | 
					
						
							|  |  |  |  | 			c.Set("is_authenticated", false) | 
					
						
							|  |  |  |  | 			c.Next() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 检查Bearer前缀 | 
					
						
							|  |  |  |  | 		const bearerPrefix = "Bearer " | 
					
						
							|  |  |  |  | 		if !strings.HasPrefix(authHeader, bearerPrefix) { | 
					
						
							|  |  |  |  | 			c.Set("is_authenticated", false) | 
					
						
							|  |  |  |  | 			c.Next() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 提取并验证token | 
					
						
							|  |  |  |  | 		tokenString := authHeader[len(bearerPrefix):] | 
					
						
							|  |  |  |  | 		claims, err := m.jwtAuth.validateToken(tokenString) | 
					
						
							|  |  |  |  | 		if err != nil { | 
					
						
							|  |  |  |  | 			// token无效,但不返回错误,设置为未认证 | 
					
						
							|  |  |  |  | 			c.Set("is_authenticated", false) | 
					
						
							|  |  |  |  | 			c.Next() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// token有效,设置用户信息 | 
					
						
							|  |  |  |  | 		c.Set("is_authenticated", true) | 
					
						
							|  |  |  |  | 		c.Set("user_id", claims.UserID) | 
					
						
							|  |  |  |  | 		c.Set("username", claims.Username) | 
					
						
							|  |  |  |  | 		c.Set("email", claims.Email) | 
					
						
							|  |  |  |  | 		c.Set("token_claims", claims) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		c.Next() | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // IsGlobal 是否为全局中间件 | 
					
						
							|  |  |  |  | func (m *OptionalAuthMiddleware) IsGlobal() bool { | 
					
						
							|  |  |  |  | 	return false | 
					
						
							|  |  |  |  | } |