Files
tyapi-server/internal/shared/middleware/auth.go
2025-07-20 20:53:26 +08:00

380 lines
9.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"net/http"
"strings"
"time"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
)
// JWTAuthMiddleware is the gin middleware that requires a valid JWT access
// token on the request. It also issues and validates access and refresh
// tokens (see GenerateToken, GenerateRefreshToken, ValidateRefreshToken).
//
// config supplies the JWT settings (Secret, ExpiresIn, RefreshExpiresIn);
// logger records rejected tokens.
type JWTAuthMiddleware struct {
	config *config.Config
	logger *zap.Logger
}
// NewJWTAuthMiddleware builds a JWTAuthMiddleware backed by cfg (for the JWT
// secret and lifetimes) and logger. Neither argument is checked for nil.
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
	mw := new(JWTAuthMiddleware)
	mw.config, mw.logger = cfg, logger
	return mw
}
// GetName returns the identifier under which this middleware is registered.
func (*JWTAuthMiddleware) GetName() string {
	const name = "jwt_auth"
	return name
}
// GetExpiresIn reports the configured lifetime of access tokens
// (config.JWT.ExpiresIn), i.e. the span GenerateToken adds to "now" for exp.
func (m *JWTAuthMiddleware) GetExpiresIn() time.Duration {
	jwtCfg := m.config.JWT
	return jwtCfg.ExpiresIn
}
// GetPriority returns this middleware's ordering priority. 60 is intended as a
// medium priority: after logging, before business handlers.
func (*JWTAuthMiddleware) GetPriority() int {
	const mediumPriority = 60
	return mediumPriority
}
// Handle returns the gin handler that enforces JWT authentication.
//
// The request must send "Authorization: Bearer <token>". The scheme name is
// matched case-insensitively (auth schemes are case-insensitive per RFC 7235)
// and whitespace around the token is ignored. The token must pass
// validateToken and carry a non-empty user_id. On success the claims are
// copied into the gin context under "user_id", "username", "email", "phone",
// "user_type" and "token_claims" (the *JWTClaims) and the chain continues;
// on any failure the request is aborted with a 401 JSON response.
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			m.respondUnauthorized(c, "缺少认证头部")
			return
		}

		// Split "<scheme> <token>" once; only the Bearer scheme is accepted.
		scheme, tokenString, found := strings.Cut(authHeader, " ")
		if !found || !strings.EqualFold(scheme, "Bearer") {
			m.respondUnauthorized(c, "认证头部格式无效")
			return
		}

		tokenString = strings.TrimSpace(tokenString)
		if tokenString == "" {
			m.respondUnauthorized(c, "缺少认证令牌")
			return
		}

		claims, err := m.validateToken(tokenString)
		if err != nil {
			m.logger.Warn("无效的认证令牌",
				zap.Error(err),
				zap.String("request_id", c.GetString("request_id")))
			m.respondUnauthorized(c, "认证令牌无效")
			return
		}
		// A correctly signed token that has no user ID (for example a refresh
		// token, which holds only registered claims) must not authenticate.
		if claims.UserID == "" {
			m.logger.Warn("无效的认证令牌",
				zap.String("reason", "missing user_id"),
				zap.String("request_id", c.GetString("request_id")))
			m.respondUnauthorized(c, "认证令牌无效")
			return
		}

		// Expose the caller's identity to downstream handlers.
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Set("phone", claims.Phone)
		c.Set("user_type", claims.UserType)
		c.Set("token_claims", claims)
		c.Next()
	}
}
// IsGlobal reports false: this middleware is not installed globally and must
// be attached explicitly to the routes that require authentication.
func (*JWTAuthMiddleware) IsGlobal() bool {
	const global = false
	return global
}
// JWTClaims is the claim set carried by access tokens: the application's user
// fields plus the standard registered claims (iss, sub, aud, exp, nbf, iat).
//
// GenerateToken fills Username with the phone argument; the comment there
// notes admins are meant to use their username instead.
// UserType is compared against "admin" by AdminAuthMiddleware.
type JWTClaims struct {
	UserID   string `json:"user_id"`
	Username string `json:"username"`
	Email    string `json:"email"`
	Phone    string `json:"phone"`
	UserType string `json:"user_type"` // user type (added field)
	jwt.RegisteredClaims
}
// validateToken parses tokenString as an HMAC-signed access token and returns
// its claims.
//
// The key is config.JWT.Secret, and tokens signed with a non-HMAC method are
// refused. jwt.ParseWithClaims also applies its default expiry / not-before
// checks. In addition, the audience must contain "tyapi-client", which is what
// GenerateToken writes. Without that check, a refresh token (same secret,
// audience "tyapi-refresh") would also be accepted here as an access token.
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{},
		func(token *jwt.Token) (interface{}, error) {
			// Refuse anything but HMAC so an attacker cannot switch algorithms.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, jwt.ErrSignatureInvalid
			}
			return []byte(m.config.JWT.Secret), nil
		},
		jwt.WithAudience("tyapi-client"),
	)
	if err != nil {
		return nil, err
	}
	claims, ok := token.Claims.(*JWTClaims)
	if !ok || !token.Valid {
		return nil, jwt.ErrSignatureInvalid
	}
	return claims, nil
}
// respondUnauthorized aborts the request with HTTP 401 and a JSON body of the
// form {success:false, message:"认证失败", error:<message>, request_id, timestamp},
// where timestamp is the current Unix time in seconds.
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
	body := gin.H{
		"success":    false,
		"message":    "认证失败",
		"error":      message,
		"request_id": c.GetString("request_id"),
		"timestamp":  time.Now().Unix(),
	}
	c.AbortWithStatusJSON(http.StatusUnauthorized, body)
}
// GenerateToken issues an HS256-signed access token signed with
// config.JWT.Secret.
//
// The token carries userID, phone, email and userType as custom claims;
// Username is set to phone (regular users log in by phone; per the original
// note admins would use their username). The registered claims are issuer
// "tyapi-server", subject userID, audience "tyapi-client", and a validity
// window from now until now + config.JWT.ExpiresIn.
func (m *JWTAuthMiddleware) GenerateToken(userID, phone, email, userType string) (string, error) {
	now := time.Now()
	registered := jwt.RegisteredClaims{
		Issuer:    "tyapi-server",
		Subject:   userID,
		Audience:  jwt.ClaimStrings{"tyapi-client"},
		IssuedAt:  jwt.NewNumericDate(now),
		NotBefore: jwt.NewNumericDate(now),
		ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
	}

	accessClaims := &JWTClaims{
		UserID:           userID,
		Username:         phone,
		Email:            email,
		Phone:            phone,
		UserType:         userType,
		RegisteredClaims: registered,
	}

	signingKey := []byte(m.config.JWT.Secret)
	return jwt.NewWithClaims(jwt.SigningMethodHS256, accessClaims).SignedString(signingKey)
}
// GenerateRefreshToken issues an HS256-signed refresh token for userID.
//
// The token carries only registered claims: issuer "tyapi-server", subject
// userID, audience "tyapi-refresh", and a validity window from now until
// now + config.JWT.RefreshExpiresIn. It is signed with the same
// config.JWT.Secret as access tokens; ValidateRefreshToken distinguishes the
// two by audience.
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
	issued := time.Now()
	lifetime := m.config.JWT.RefreshExpiresIn
	refreshClaims := jwt.RegisteredClaims{
		Issuer:    "tyapi-server",
		Subject:   userID,
		Audience:  jwt.ClaimStrings{"tyapi-refresh"},
		IssuedAt:  jwt.NewNumericDate(issued),
		NotBefore: jwt.NewNumericDate(issued),
		ExpiresAt: jwt.NewNumericDate(issued.Add(lifetime)),
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, &refreshClaims).
		SignedString([]byte(m.config.JWT.Secret))
	return signed, err
}
// ValidateRefreshToken verifies a refresh token produced by
// GenerateRefreshToken and returns the user ID stored in its subject.
//
// The token must be HMAC-signed with config.JWT.Secret and pass the
// expiry / not-before checks done by jwt.ParseWithClaims. Its first audience
// must be "tyapi-refresh", which rejects access tokens (audience
// "tyapi-client"); a mismatch is reported as jwt.ErrSignatureInvalid. The
// subject must be non-empty; otherwise jwt.ErrTokenInvalidSubject is returned
// instead of handing the caller "" as a user ID.
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
	token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
		// Refuse anything but HMAC so an attacker cannot switch algorithms.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.ErrSignatureInvalid
		}
		return []byte(m.config.JWT.Secret), nil
	})
	if err != nil {
		return "", err
	}

	claims, ok := token.Claims.(*jwt.RegisteredClaims)
	if !ok || !token.Valid {
		return "", jwt.ErrSignatureInvalid
	}

	// Only tokens addressed to the refresh endpoint qualify.
	if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
		return "", jwt.ErrSignatureInvalid
	}

	if claims.Subject == "" {
		return "", jwt.ErrTokenInvalidSubject
	}
	return claims.Subject, nil
}
// OptionalAuthMiddleware is a gin middleware for routes that serve both
// logged-in and anonymous users. It never rejects a request; it only records
// in the context whether a valid token was presented.
//
// jwtAuth supplies the token validation logic.
type OptionalAuthMiddleware struct {
	jwtAuth *JWTAuthMiddleware
}
// NewOptionalAuthMiddleware builds an OptionalAuthMiddleware that validates
// tokens through jwtAuth. jwtAuth is not checked for nil.
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
	mw := new(OptionalAuthMiddleware)
	mw.jwtAuth = jwtAuth
	return mw
}
// GetName returns the identifier under which this middleware is registered.
func (*OptionalAuthMiddleware) GetName() string {
	const name = "optional_auth"
	return name
}
// GetPriority returns this middleware's ordering priority, deliberately the
// same value (60) as JWTAuthMiddleware.
func (*OptionalAuthMiddleware) GetPriority() int {
	const sameAsJWTAuth = 60
	return sameAsJWTAuth
}
// Handle returns the gin handler for optional authentication.
//
// It always continues the chain. If the request carries
// "Authorization: Bearer <token>" (scheme matched case-insensitively, token
// whitespace-trimmed), the token passes validateToken, and the token has a
// non-empty user_id, then "is_authenticated" is set to true and the claims are
// copied into the context ("user_id", "username", "email", "phone",
// "user_type", "token_claims"). In every other case "is_authenticated" is set
// to false and no user fields are set.
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		// An empty header has no space, so it falls into the anonymous branch.
		scheme, tokenString, found := strings.Cut(c.GetHeader("Authorization"), " ")
		tokenString = strings.TrimSpace(tokenString)
		if !found || !strings.EqualFold(scheme, "Bearer") || tokenString == "" {
			c.Set("is_authenticated", false)
			c.Next()
			return
		}

		// An invalid token, or a signed token without a user ID (e.g. a
		// refresh token), is treated as anonymous rather than an error.
		claims, err := m.jwtAuth.validateToken(tokenString)
		if err != nil || claims.UserID == "" {
			c.Set("is_authenticated", false)
			c.Next()
			return
		}

		c.Set("is_authenticated", true)
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Set("phone", claims.Phone)
		c.Set("user_type", claims.UserType)
		c.Set("token_claims", claims)
		c.Next()
	}
}
// IsGlobal reports false: this middleware must be attached explicitly to the
// routes that want optional authentication.
func (*OptionalAuthMiddleware) IsGlobal() bool {
	const global = false
	return global
}
// AdminAuthMiddleware is a gin middleware that requires a valid JWT access
// token whose user_type claim is "admin".
//
// jwtAuth supplies the token validation logic; logger records rejected tokens.
type AdminAuthMiddleware struct {
	jwtAuth *JWTAuthMiddleware
	logger  *zap.Logger
}
// NewAdminAuthMiddleware builds an AdminAuthMiddleware that validates tokens
// through jwtAuth and logs rejections to logger. Neither is checked for nil.
func NewAdminAuthMiddleware(jwtAuth *JWTAuthMiddleware, logger *zap.Logger) *AdminAuthMiddleware {
	mw := new(AdminAuthMiddleware)
	mw.jwtAuth, mw.logger = jwtAuth, logger
	return mw
}
// GetName returns the identifier under which this middleware is registered.
func (*AdminAuthMiddleware) GetName() string {
	const name = "admin_auth"
	return name
}
// GetPriority returns this middleware's ordering priority, deliberately the
// same value (60) as JWTAuthMiddleware.
func (*AdminAuthMiddleware) GetPriority() int {
	const sameAsJWTAuth = 60
	return sameAsJWTAuth
}
// Handle returns the gin handler that admits only authenticated admins.
//
// Authentication follows JWTAuthMiddleware: an "Authorization: Bearer <token>"
// header is required (scheme matched case-insensitively, token
// whitespace-trimmed), and the token must pass validateToken and carry a
// non-empty user_id. Failures abort with 401. An authenticated token whose
// user_type is not "admin" aborts with 403. On success the claims are copied
// into the context ("user_id", "username", "email", "phone", "user_type",
// "token_claims") and the chain continues.
func (m *AdminAuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			m.respondUnauthorized(c, "缺少认证头部")
			return
		}

		// Split "<scheme> <token>" once; only the Bearer scheme is accepted.
		scheme, tokenString, found := strings.Cut(authHeader, " ")
		if !found || !strings.EqualFold(scheme, "Bearer") {
			m.respondUnauthorized(c, "认证头部格式无效")
			return
		}

		tokenString = strings.TrimSpace(tokenString)
		if tokenString == "" {
			m.respondUnauthorized(c, "缺少认证令牌")
			return
		}

		claims, err := m.jwtAuth.validateToken(tokenString)
		if err != nil {
			m.logger.Warn("无效的认证令牌",
				zap.Error(err),
				zap.String("request_id", c.GetString("request_id")))
			m.respondUnauthorized(c, "认证令牌无效")
			return
		}
		// A correctly signed token that has no user ID (for example a refresh
		// token) is an authentication failure, not an authorization one.
		if claims.UserID == "" {
			m.logger.Warn("无效的认证令牌",
				zap.String("reason", "missing user_id"),
				zap.String("request_id", c.GetString("request_id")))
			m.respondUnauthorized(c, "认证令牌无效")
			return
		}

		if claims.UserType != "admin" {
			m.respondForbidden(c, "需要管理员权限")
			return
		}

		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Set("phone", claims.Phone)
		c.Set("user_type", claims.UserType)
		c.Set("token_claims", claims)
		c.Next()
	}
}
// IsGlobal reports false: this middleware must be attached explicitly to the
// admin-only routes.
func (*AdminAuthMiddleware) IsGlobal() bool {
	const global = false
	return global
}
// respondForbidden aborts the request with HTTP 403 and a JSON body of the
// form {success:false, message:"权限不足", error:<message>, request_id, timestamp},
// where timestamp is the current Unix time in seconds.
func (m *AdminAuthMiddleware) respondForbidden(c *gin.Context, message string) {
	body := gin.H{
		"success":    false,
		"message":    "权限不足",
		"error":      message,
		"request_id": c.GetString("request_id"),
		"timestamp":  time.Now().Unix(),
	}
	c.AbortWithStatusJSON(http.StatusForbidden, body)
}
// respondUnauthorized aborts the request with HTTP 401 and a JSON body of the
// form {success:false, message:"认证失败", error:<message>, request_id, timestamp},
// where timestamp is the current Unix time in seconds.
func (m *AdminAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
	body := gin.H{
		"success":    false,
		"message":    "认证失败",
		"error":      message,
		"request_id": c.GetString("request_id"),
		"timestamp":  time.Now().Unix(),
	}
	c.AbortWithStatusJSON(http.StatusUnauthorized, body)
}