Files
tyc-server/app/main/api/internal/middleware/global/reqHeaderCtxMiddleware.go

62 lines
1.4 KiB
Go
Raw Normal View History

2025-05-09 17:54:28 +08:00
package middleware
import (
	"context"
	"net"
	"net/http"
	"strings"
)
func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
brand := r.Header.Get("X-Brand")
platform := r.Header.Get("X-Platform")
promoteValue := r.Header.Get("X-Promote-Key")
2025-08-31 14:18:31 +08:00
clientIP := getClientIP(r)
2025-05-09 17:54:28 +08:00
ctx := r.Context()
if brand != "" {
2025-05-09 20:26:22 +08:00
ctx = context.WithValue(ctx, "brand", brand)
2025-05-09 17:54:28 +08:00
}
if platform != "" {
2025-05-09 20:26:22 +08:00
ctx = context.WithValue(ctx, "platform", platform)
2025-05-09 17:54:28 +08:00
}
if promoteValue != "" {
2025-05-09 20:26:22 +08:00
ctx = context.WithValue(ctx, "promoteKey", promoteValue)
2025-05-09 17:54:28 +08:00
}
2025-08-31 14:18:31 +08:00
if clientIP != "" {
ctx = context.WithValue(ctx, "client_ip", clientIP)
}
2025-05-09 17:54:28 +08:00
r = r.WithContext(ctx)
next(w, r)
}
}
2025-08-31 14:18:31 +08:00
// getClientIP returns a best-effort client IP address for the request.
//
// Sources are consulted in this order, and the first one that yields a
// non-blank value wins:
//
//  1. the first entry of the X-Forwarded-For header
//  2. the X-Real-IP header
//  3. the X-Client-IP header
//  4. the host part of r.RemoteAddr
//
// If r.RemoteAddr cannot be split into host and port it is returned
// unchanged. When none of the sources is set, the literal "unknown" is
// returned.
//
// NOTE(review): the header values are used as-is and are not validated as IP
// addresses; they are only trustworthy if a trusted proxy sets or overwrites
// them — confirm against the deployment.
func getClientIP(r *http.Request) string {
	// Proxy headers first. X-Forwarded-For may carry a comma-separated chain;
	// the first entry is the originating client.
	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		first := forwarded
		if comma := strings.Index(forwarded, ","); comma != -1 {
			first = forwarded[:comma]
		}
		// A blank first entry is treated as absent so the remaining sources
		// are still consulted instead of returning an empty string.
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	for _, name := range []string{"X-Real-IP", "X-Client-IP"} {
		if ip := strings.TrimSpace(r.Header.Get(name)); ip != "" {
			return ip
		}
	}

	// Direct connection: strip the port. net.SplitHostPort understands the
	// bracketed IPv6 form ("[::1]:8080"), which a plain split on the last
	// colon does not.
	if r.RemoteAddr != "" {
		host, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			// No port (or otherwise unparsable): keep the raw value.
			return r.RemoteAddr
		}
		return host
	}

	return "unknown"
}