380 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			380 lines
		
	
	
		
			9.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package middleware
 | ||
| 
 | ||
| import (
 | ||
| 	"net/http"
 | ||
| 	"strings"
 | ||
| 	"time"
 | ||
| 
 | ||
| 	"tyapi-server/internal/config"
 | ||
| 
 | ||
| 	"github.com/gin-gonic/gin"
 | ||
| 	"github.com/golang-jwt/jwt/v5"
 | ||
| 	"go.uber.org/zap"
 | ||
| )
 | ||
| 
 | ||
| // JWTAuthMiddleware JWT认证中间件
 | ||
| type JWTAuthMiddleware struct {
 | ||
| 	config *config.Config
 | ||
| 	logger *zap.Logger
 | ||
| }
 | ||
| 
 | ||
| // NewJWTAuthMiddleware 创建JWT认证中间件
 | ||
| func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
 | ||
| 	return &JWTAuthMiddleware{
 | ||
| 		config: cfg,
 | ||
| 		logger: logger,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *JWTAuthMiddleware) GetName() string {
 | ||
| 	return "jwt_auth"
 | ||
| }
 | ||
| 
 | ||
| // GetExpiresIn 返回JWT过期时间
 | ||
| func (m *JWTAuthMiddleware) GetExpiresIn() time.Duration {
 | ||
| 	return m.config.JWT.ExpiresIn
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *JWTAuthMiddleware) GetPriority() int {
 | ||
| 	return 60 // 中等优先级,在日志之后,业务处理之前
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 获取Authorization头部
 | ||
| 		authHeader := c.GetHeader("Authorization")
 | ||
| 		if authHeader == "" {
 | ||
| 			m.respondUnauthorized(c, "缺少认证头部")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 检查Bearer前缀
 | ||
| 		const bearerPrefix = "Bearer "
 | ||
| 		if !strings.HasPrefix(authHeader, bearerPrefix) {
 | ||
| 			m.respondUnauthorized(c, "认证头部格式无效")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 提取token
 | ||
| 		tokenString := authHeader[len(bearerPrefix):]
 | ||
| 		if tokenString == "" {
 | ||
| 			m.respondUnauthorized(c, "缺少认证令牌")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 验证token
 | ||
| 		claims, err := m.validateToken(tokenString)
 | ||
| 		if err != nil {
 | ||
| 			m.logger.Warn("无效的认证令牌",
 | ||
| 				zap.Error(err),
 | ||
| 				zap.String("request_id", c.GetString("request_id")))
 | ||
| 			m.respondUnauthorized(c, "认证令牌无效")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 将用户信息添加到上下文
 | ||
| 		c.Set("user_id", claims.UserID)
 | ||
| 		c.Set("username", claims.Username)
 | ||
| 		c.Set("email", claims.Email)
 | ||
| 		c.Set("phone", claims.Phone)
 | ||
| 		c.Set("user_type", claims.UserType)
 | ||
| 		c.Set("token_claims", claims)
 | ||
| 
 | ||
| 		c.Next()
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *JWTAuthMiddleware) IsGlobal() bool {
 | ||
| 	return false // 不是全局中间件,需要手动应用到需要认证的路由
 | ||
| }
 | ||
| 
 | ||
| // JWTClaims JWT声明结构
 | ||
| type JWTClaims struct {
 | ||
| 	UserID   string `json:"user_id"`
 | ||
| 	Username string `json:"username"`
 | ||
| 	Email    string `json:"email"`
 | ||
| 	Phone    string `json:"phone"`
 | ||
| 	UserType string `json:"user_type"` // 新增:用户类型
 | ||
| 	jwt.RegisteredClaims
 | ||
| }
 | ||
| 
 | ||
| // validateToken 验证JWT token
 | ||
| func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
 | ||
| 	token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
 | ||
| 		// 验证签名方法
 | ||
| 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 | ||
| 			return nil, jwt.ErrSignatureInvalid
 | ||
| 		}
 | ||
| 		return []byte(m.config.JWT.Secret), nil
 | ||
| 	})
 | ||
| 
 | ||
| 	if err != nil {
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 
 | ||
| 	claims, ok := token.Claims.(*JWTClaims)
 | ||
| 	if !ok || !token.Valid {
 | ||
| 		return nil, jwt.ErrSignatureInvalid
 | ||
| 	}
 | ||
| 
 | ||
| 	return claims, nil
 | ||
| }
 | ||
| 
 | ||
| // respondUnauthorized 返回未授权响应
 | ||
| func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
 | ||
| 	c.JSON(http.StatusUnauthorized, gin.H{
 | ||
| 		"success":    false,
 | ||
| 		"message":    "认证失败",
 | ||
| 		"error":      message,
 | ||
| 		"request_id": c.GetString("request_id"),
 | ||
| 		"timestamp":  time.Now().Unix(),
 | ||
| 	})
 | ||
| 	c.Abort()
 | ||
| }
 | ||
| 
 | ||
| // GenerateToken 生成JWT token
 | ||
| func (m *JWTAuthMiddleware) GenerateToken(userID, phone, email, userType string) (string, error) {
 | ||
| 	now := time.Now()
 | ||
| 
 | ||
| 	claims := &JWTClaims{
 | ||
| 		UserID:   userID,
 | ||
| 		Username: phone, // 普通用户用手机号,管理员用用户名
 | ||
| 		Email:    email,
 | ||
| 		Phone:    phone,
 | ||
| 		UserType: userType, // 新增:用户类型
 | ||
| 		RegisteredClaims: jwt.RegisteredClaims{
 | ||
| 			Issuer:    "tyapi-server",
 | ||
| 			Subject:   userID,
 | ||
| 			Audience:  []string{"tyapi-client"},
 | ||
| 			ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
 | ||
| 			NotBefore: jwt.NewNumericDate(now),
 | ||
| 			IssuedAt:  jwt.NewNumericDate(now),
 | ||
| 		},
 | ||
| 	}
 | ||
| 
 | ||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 | ||
| 	return token.SignedString([]byte(m.config.JWT.Secret))
 | ||
| }
 | ||
| 
 | ||
| // GenerateRefreshToken 生成刷新token
 | ||
| func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
 | ||
| 	now := time.Now()
 | ||
| 
 | ||
| 	claims := &jwt.RegisteredClaims{
 | ||
| 		Issuer:    "tyapi-server",
 | ||
| 		Subject:   userID,
 | ||
| 		Audience:  []string{"tyapi-refresh"},
 | ||
| 		ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.RefreshExpiresIn)),
 | ||
| 		NotBefore: jwt.NewNumericDate(now),
 | ||
| 		IssuedAt:  jwt.NewNumericDate(now),
 | ||
| 	}
 | ||
| 
 | ||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 | ||
| 	return token.SignedString([]byte(m.config.JWT.Secret))
 | ||
| }
 | ||
| 
 | ||
| // ValidateRefreshToken 验证刷新token
 | ||
| func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
 | ||
| 	token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
 | ||
| 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 | ||
| 			return nil, jwt.ErrSignatureInvalid
 | ||
| 		}
 | ||
| 		return []byte(m.config.JWT.Secret), nil
 | ||
| 	})
 | ||
| 
 | ||
| 	if err != nil {
 | ||
| 		return "", err
 | ||
| 	}
 | ||
| 
 | ||
| 	claims, ok := token.Claims.(*jwt.RegisteredClaims)
 | ||
| 	if !ok || !token.Valid {
 | ||
| 		return "", jwt.ErrSignatureInvalid
 | ||
| 	}
 | ||
| 
 | ||
| 	// 检查是否为刷新token
 | ||
| 	if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
 | ||
| 		return "", jwt.ErrSignatureInvalid
 | ||
| 	}
 | ||
| 
 | ||
| 	return claims.Subject, nil
 | ||
| }
 | ||
| 
 | ||
| // OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录)
 | ||
| type OptionalAuthMiddleware struct {
 | ||
| 	jwtAuth *JWTAuthMiddleware
 | ||
| }
 | ||
| 
 | ||
| // NewOptionalAuthMiddleware 创建可选认证中间件
 | ||
| func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
 | ||
| 	return &OptionalAuthMiddleware{
 | ||
| 		jwtAuth: jwtAuth,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *OptionalAuthMiddleware) GetName() string {
 | ||
| 	return "optional_auth"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *OptionalAuthMiddleware) GetPriority() int {
 | ||
| 	return 60 // 与JWT认证中间件相同
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 获取Authorization头部
 | ||
| 		authHeader := c.GetHeader("Authorization")
 | ||
| 		if authHeader == "" {
 | ||
| 			// 没有认证头部,设置匿名用户标识
 | ||
| 			c.Set("is_authenticated", false)
 | ||
| 			c.Next()
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 检查Bearer前缀
 | ||
| 		const bearerPrefix = "Bearer "
 | ||
| 		if !strings.HasPrefix(authHeader, bearerPrefix) {
 | ||
| 			c.Set("is_authenticated", false)
 | ||
| 			c.Next()
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 提取并验证token
 | ||
| 		tokenString := authHeader[len(bearerPrefix):]
 | ||
| 		claims, err := m.jwtAuth.validateToken(tokenString)
 | ||
| 		if err != nil {
 | ||
| 			// token无效,但不返回错误,设置为未认证
 | ||
| 			c.Set("is_authenticated", false)
 | ||
| 			c.Next()
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// token有效,设置用户信息
 | ||
| 		c.Set("is_authenticated", true)
 | ||
| 		c.Set("user_id", claims.UserID)
 | ||
| 		c.Set("username", claims.Username)
 | ||
| 		c.Set("email", claims.Email)
 | ||
| 		c.Set("phone", claims.Phone)
 | ||
| 		c.Set("user_type", claims.UserType)
 | ||
| 		c.Set("token_claims", claims)
 | ||
| 
 | ||
| 		c.Next()
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *OptionalAuthMiddleware) IsGlobal() bool {
 | ||
| 	return false
 | ||
| }
 | ||
| 
 | ||
| // AdminAuthMiddleware 管理员认证中间件
 | ||
| type AdminAuthMiddleware struct {
 | ||
| 	jwtAuth *JWTAuthMiddleware
 | ||
| 	logger  *zap.Logger
 | ||
| }
 | ||
| 
 | ||
| // NewAdminAuthMiddleware 创建管理员认证中间件
 | ||
| func NewAdminAuthMiddleware(jwtAuth *JWTAuthMiddleware, logger *zap.Logger) *AdminAuthMiddleware {
 | ||
| 	return &AdminAuthMiddleware{
 | ||
| 		jwtAuth: jwtAuth,
 | ||
| 		logger:  logger,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *AdminAuthMiddleware) GetName() string {
 | ||
| 	return "admin_auth"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *AdminAuthMiddleware) GetPriority() int {
 | ||
| 	return 60 // 与JWT认证中间件相同
 | ||
| }
 | ||
| 
 | ||
| // Handle 管理员认证处理
 | ||
| func (m *AdminAuthMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 首先进行JWT认证
 | ||
| 		authHeader := c.GetHeader("Authorization")
 | ||
| 		if authHeader == "" {
 | ||
| 			m.respondUnauthorized(c, "缺少认证头部")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 检查Bearer前缀
 | ||
| 		const bearerPrefix = "Bearer "
 | ||
| 		if !strings.HasPrefix(authHeader, bearerPrefix) {
 | ||
| 			m.respondUnauthorized(c, "认证头部格式无效")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 提取token
 | ||
| 		tokenString := authHeader[len(bearerPrefix):]
 | ||
| 		if tokenString == "" {
 | ||
| 			m.respondUnauthorized(c, "缺少认证令牌")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 验证token
 | ||
| 		claims, err := m.jwtAuth.validateToken(tokenString)
 | ||
| 		if err != nil {
 | ||
| 			m.logger.Warn("无效的认证令牌",
 | ||
| 				zap.Error(err),
 | ||
| 				zap.String("request_id", c.GetString("request_id")))
 | ||
| 			m.respondUnauthorized(c, "认证令牌无效")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 检查用户类型是否为管理员
 | ||
| 		if claims.UserType != "admin" {
 | ||
| 			m.respondForbidden(c, "需要管理员权限")
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 设置用户信息到上下文
 | ||
| 		c.Set("user_id", claims.UserID)
 | ||
| 		c.Set("username", claims.Username)
 | ||
| 		c.Set("email", claims.Email)
 | ||
| 		c.Set("phone", claims.Phone)
 | ||
| 		c.Set("user_type", claims.UserType)
 | ||
| 		c.Set("token_claims", claims)
 | ||
| 
 | ||
| 		c.Next()
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *AdminAuthMiddleware) IsGlobal() bool {
 | ||
| 	return false
 | ||
| }
 | ||
| 
 | ||
| // respondForbidden 返回禁止访问响应
 | ||
| func (m *AdminAuthMiddleware) respondForbidden(c *gin.Context, message string) {
 | ||
| 	c.JSON(http.StatusForbidden, gin.H{
 | ||
| 		"success":    false,
 | ||
| 		"message":    "权限不足",
 | ||
| 		"error":      message,
 | ||
| 		"request_id": c.GetString("request_id"),
 | ||
| 		"timestamp":  time.Now().Unix(),
 | ||
| 	})
 | ||
| 	c.Abort()
 | ||
| }
 | ||
| 
 | ||
| // respondUnauthorized 返回未授权响应
 | ||
| func (m *AdminAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
 | ||
| 	c.JSON(http.StatusUnauthorized, gin.H{
 | ||
| 		"success":    false,
 | ||
| 		"message":    "认证失败",
 | ||
| 		"error":      message,
 | ||
| 		"request_id": c.GetString("request_id"),
 | ||
| 		"timestamp":  time.Now().Unix(),
 | ||
| 	})
 | ||
| 	c.Abort()
 | ||
| }
 |