package middleware

import (
	"context"
	"net"
	"net/http"
	"strings"
)

// ReqHeaderCtxMiddleware copies selected request headers, plus the client IP,
// into the request context before calling next.
//
// Values are stored only when non-empty, under these keys:
//
//	X-Brand        -> "brand"
//	X-Platform     -> "platform"
//	X-Promote-Key  -> "promoteKey"
//	getClientIP(r) -> "client_ip"
//
// NOTE(review): the keys are plain strings, which `go vet`/staticcheck (SA1029)
// flag because any package can collide with them. They are kept unchanged
// because readers elsewhere presumably call ctx.Value("brand") etc.; moving to
// an unexported key type would require updating those readers too.
func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		// setIfPresent attaches value under key, skipping empty values so
		// downstream ctx.Value lookups see nil rather than "".
		setIfPresent := func(key, value string) {
			if value != "" {
				ctx = context.WithValue(ctx, key, value)
			}
		}

		setIfPresent("brand", r.Header.Get("X-Brand"))
		setIfPresent("platform", r.Header.Get("X-Platform"))
		setIfPresent("promoteKey", r.Header.Get("X-Promote-Key"))
		setIfPresent("client_ip", getClientIP(r))

		next(w, r.WithContext(ctx))
	}
}

// getClientIP returns a best-effort client IP address for r.
//
// Sources are tried in order:
//  1. the left-most entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. X-Client-IP,
//  4. the host part of r.RemoteAddr.
//
// Header values that are empty after trimming whitespace are skipped. It
// returns "unknown" only when no source yields a value and r.RemoteAddr is
// empty.
//
// SECURITY NOTE(review): all three headers are client-controlled unless a
// trusted reverse proxy overwrites them. Do not rely on the result for
// authorization or rate limiting unless that is guaranteed for this deployment.
func getClientIP(r *http.Request) string {
	// X-Forwarded-For is "client, proxy1, proxy2"; the left-most entry is the
	// original client as reported by the first proxy.
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		if i := strings.IndexByte(xff, ','); i >= 0 {
			xff = xff[:i]
		}
		if ip := strings.TrimSpace(xff); ip != "" {
			return ip
		}
	}

	for _, name := range [...]string{"X-Real-IP", "X-Client-IP"} {
		if ip := strings.TrimSpace(r.Header.Get(name)); ip != "" {
			return ip
		}
	}

	// Direct connection: RemoteAddr is normally "host:port", with IPv6 hosts
	// bracketed ("[::1]:8080"). net.SplitHostPort removes both the port and the
	// brackets; splitting on the last ':' by hand would keep "[::1]" and would
	// truncate a port-less IPv6 address.
	if r.RemoteAddr != "" {
		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
			return host
		}
		// Not in host:port form (e.g. a bare address); use it verbatim.
		return r.RemoteAddr
	}
	return "unknown"
}