Files
tyapi-server/internal/shared/middleware/auth.go

380 lines
9.7 KiB
Go
Raw Normal View History

package middleware
import (
"net/http"
"strings"
"time"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
)
// JWTAuthMiddleware JWT认证中间件
type JWTAuthMiddleware struct {
config *config.Config
logger *zap.Logger
}
// NewJWTAuthMiddleware 创建JWT认证中间件
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
return &JWTAuthMiddleware{
config: cfg,
logger: logger,
}
}
// GetName 返回中间件名称
func (m *JWTAuthMiddleware) GetName() string {
return "jwt_auth"
}
2025-07-20 20:53:26 +08:00
// GetExpiresIn 返回JWT过期时间
func (m *JWTAuthMiddleware) GetExpiresIn() time.Duration {
return m.config.JWT.ExpiresIn
}
// GetPriority 返回中间件优先级
func (m *JWTAuthMiddleware) GetPriority() int {
return 60 // 中等优先级,在日志之后,业务处理之前
}
// Handle 返回中间件处理函数
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取Authorization头部
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
2025-07-02 16:17:59 +08:00
m.respondUnauthorized(c, "缺少认证头部")
return
}
// 检查Bearer前缀
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
2025-07-02 16:17:59 +08:00
m.respondUnauthorized(c, "认证头部格式无效")
return
}
// 提取token
tokenString := authHeader[len(bearerPrefix):]
if tokenString == "" {
2025-07-02 16:17:59 +08:00
m.respondUnauthorized(c, "缺少认证令牌")
return
}
// 验证token
claims, err := m.validateToken(tokenString)
if err != nil {
2025-07-02 16:17:59 +08:00
m.logger.Warn("无效的认证令牌",
zap.Error(err),
zap.String("request_id", c.GetString("request_id")))
2025-07-02 16:17:59 +08:00
m.respondUnauthorized(c, "认证令牌无效")
return
}
// 将用户信息添加到上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
2025-07-20 20:53:26 +08:00
c.Set("phone", claims.Phone)
c.Set("user_type", claims.UserType)
c.Set("token_claims", claims)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *JWTAuthMiddleware) IsGlobal() bool {
return false // 不是全局中间件,需要手动应用到需要认证的路由
}
// JWTClaims JWT声明结构
type JWTClaims struct {
UserID string `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
2025-07-20 20:53:26 +08:00
Phone string `json:"phone"`
UserType string `json:"user_type"` // 新增:用户类型
jwt.RegisteredClaims
}
// validateToken 验证JWT token
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(m.config.JWT.Secret), nil
})
if err != nil {
return nil, err
}
claims, ok := token.Claims.(*JWTClaims)
if !ok || !token.Valid {
return nil, jwt.ErrSignatureInvalid
}
return claims, nil
}
// respondUnauthorized 返回未授权响应
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
2025-07-02 16:17:59 +08:00
"message": "认证失败",
"error": message,
"request_id": c.GetString("request_id"),
"timestamp": time.Now().Unix(),
})
c.Abort()
}
// GenerateToken 生成JWT token
2025-07-20 20:53:26 +08:00
func (m *JWTAuthMiddleware) GenerateToken(userID, phone, email, userType string) (string, error) {
now := time.Now()
claims := &JWTClaims{
UserID: userID,
2025-07-20 20:53:26 +08:00
Username: phone, // 普通用户用手机号,管理员用用户名
Email: email,
2025-07-20 20:53:26 +08:00
Phone: phone,
UserType: userType, // 新增:用户类型
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "tyapi-server",
Subject: userID,
Audience: []string{"tyapi-client"},
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
NotBefore: jwt.NewNumericDate(now),
IssuedAt: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(m.config.JWT.Secret))
}
// GenerateRefreshToken issues an HS256-signed refresh token for userID.
//
// It carries only registered claims: issuer "tyapi-server", subject userID,
// audience "tyapi-refresh" (the marker ValidateRefreshToken checks). It is
// valid from now until now + cfg.JWT.RefreshExpiresIn.
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
	issued := time.Now()
	expires := issued.Add(m.config.JWT.RefreshExpiresIn)

	refreshClaims := &jwt.RegisteredClaims{
		Issuer:    "tyapi-server",
		Subject:   userID,
		Audience:  []string{"tyapi-refresh"},
		ExpiresAt: jwt.NewNumericDate(expires),
		NotBefore: jwt.NewNumericDate(issued),
		IssuedAt:  jwt.NewNumericDate(issued),
	}

	secret := []byte(m.config.JWT.Secret)
	return jwt.NewWithClaims(jwt.SigningMethodHS256, refreshClaims).SignedString(secret)
}
// ValidateRefreshToken verifies a refresh token produced by
// GenerateRefreshToken and returns the user ID stored in its subject.
//
// The token must:
//   - be HMAC-signed with cfg.JWT.Secret and pass the library's time checks;
//   - have issuer "tyapi-server";
//   - list "tyapi-refresh" as its first audience, which rejects access tokens;
//   - have a non-empty subject.
//
// It returns "" and a non-nil error when any check fails.
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
	token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{},
		func(token *jwt.Token) (interface{}, error) {
			// Only HMAC algorithms are acceptable for a shared-secret key.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, jwt.ErrSignatureInvalid
			}
			return []byte(m.config.JWT.Secret), nil
		},
		jwt.WithIssuer("tyapi-server"),
	)
	if err != nil {
		return "", err
	}

	claims, ok := token.Claims.(*jwt.RegisteredClaims)
	if !ok || !token.Valid {
		return "", jwt.ErrSignatureInvalid
	}

	// Only tokens minted as refresh tokens may be exchanged here.
	if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
		return "", jwt.ErrSignatureInvalid
	}

	// An empty subject would hand callers an empty user ID.
	if claims.Subject == "" {
		return "", jwt.ErrTokenInvalidSubject
	}
	return claims.Subject, nil
}
// OptionalAuthMiddleware authenticates a request when it carries a valid
// bearer token. Requests without one still proceed, with "is_authenticated"
// set to false; the user may or may not be logged in.
type OptionalAuthMiddleware struct {
	jwtAuth *JWTAuthMiddleware
}

// NewOptionalAuthMiddleware wraps jwtAuth, whose token validation is reused.
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
	mw := new(OptionalAuthMiddleware)
	mw.jwtAuth = jwtAuth
	return mw
}

// GetName reports the identifier this middleware is registered under.
func (m *OptionalAuthMiddleware) GetName() string {
	const name = "optional_auth"
	return name
}

// GetPriority reports the ordering priority, the same value (60) used by the
// JWT authentication middleware.
func (m *OptionalAuthMiddleware) GetPriority() int {
	const priority = 60
	return priority
}
// Handle 返回中间件处理函数
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取Authorization头部
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// 没有认证头部,设置匿名用户标识
c.Set("is_authenticated", false)
c.Next()
return
}
// 检查Bearer前缀
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
c.Set("is_authenticated", false)
c.Next()
return
}
// 提取并验证token
tokenString := authHeader[len(bearerPrefix):]
claims, err := m.jwtAuth.validateToken(tokenString)
if err != nil {
// token无效但不返回错误设置为未认证
c.Set("is_authenticated", false)
c.Next()
return
}
// token有效设置用户信息
c.Set("is_authenticated", true)
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
2025-07-20 20:53:26 +08:00
c.Set("phone", claims.Phone)
c.Set("user_type", claims.UserType)
c.Set("token_claims", claims)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *OptionalAuthMiddleware) IsGlobal() bool {
return false
}
2025-07-20 20:53:26 +08:00
// AdminAuthMiddleware 管理员认证中间件
type AdminAuthMiddleware struct {
jwtAuth *JWTAuthMiddleware
logger *zap.Logger
}
// NewAdminAuthMiddleware 创建管理员认证中间件
func NewAdminAuthMiddleware(jwtAuth *JWTAuthMiddleware, logger *zap.Logger) *AdminAuthMiddleware {
return &AdminAuthMiddleware{
jwtAuth: jwtAuth,
logger: logger,
}
}
// GetName 返回中间件名称
func (m *AdminAuthMiddleware) GetName() string {
return "admin_auth"
}
// GetPriority 返回中间件优先级
func (m *AdminAuthMiddleware) GetPriority() int {
return 60 // 与JWT认证中间件相同
}
// Handle returns the gin handler that admits only authenticated administrators.
//
// The "Bearer" scheme in the Authorization header is compared
// case-insensitively (RFC 7235 §2.1) and whitespace around the token is
// trimmed. Outcomes:
//   - missing header, wrong scheme, or empty/invalid token: 401 via
//     respondUnauthorized; invalid tokens are also logged;
//   - valid token whose user_type claim is not "admin": 403 via
//     respondForbidden;
//   - otherwise: the claims are stored in the gin context (plus
//     "is_authenticated" = true) and the chain continues.
func (m *AdminAuthMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			m.respondUnauthorized(c, "缺少认证头部")
			return
		}

		// Split "<scheme> <token>" once; auth schemes are case-insensitive.
		scheme, rest, found := strings.Cut(authHeader, " ")
		if !found || !strings.EqualFold(scheme, "Bearer") {
			m.respondUnauthorized(c, "认证头部格式无效")
			return
		}

		tokenString := strings.TrimSpace(rest)
		if tokenString == "" {
			m.respondUnauthorized(c, "缺少认证令牌")
			return
		}

		claims, err := m.jwtAuth.validateToken(tokenString)
		if err != nil {
			m.logger.Warn("无效的认证令牌",
				zap.Error(err),
				zap.String("request_id", c.GetString("request_id")))
			m.respondUnauthorized(c, "认证令牌无效")
			return
		}

		// Authenticated, but only administrators may pass.
		if claims.UserType != "admin" {
			m.respondForbidden(c, "需要管理员权限")
			return
		}

		c.Set("is_authenticated", true)
		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("email", claims.Email)
		c.Set("phone", claims.Phone)
		c.Set("user_type", claims.UserType)
		c.Set("token_claims", claims)
		c.Next()
	}
}

// IsGlobal reports false: the middleware must be attached to admin routes
// explicitly.
func (m *AdminAuthMiddleware) IsGlobal() bool {
	return false
}
// respondForbidden aborts the request with HTTP 403 and the standard JSON
// error envelope; message states which permission was missing.
func (m *AdminAuthMiddleware) respondForbidden(c *gin.Context, message string) {
	m.abortWithEnvelope(c, http.StatusForbidden, "权限不足", message)
}

// respondUnauthorized aborts the request with HTTP 401 and the standard JSON
// error envelope; message describes why authentication failed.
func (m *AdminAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
	m.abortWithEnvelope(c, http.StatusUnauthorized, "认证失败", message)
}

// abortWithEnvelope writes the error body shared by both responses:
// success=false, a summary, the detail message, the request's "request_id"
// context value and a Unix timestamp. It then aborts the remaining handler
// chain.
func (m *AdminAuthMiddleware) abortWithEnvelope(c *gin.Context, status int, summary, detail string) {
	body := gin.H{
		"success":    false,
		"message":    summary,
		"error":      detail,
		"request_id": c.GetString("request_id"),
		"timestamp":  time.Now().Unix(),
	}
	c.JSON(status, body)
	c.Abort()
}