add
This commit is contained in:
26
internal/shared/ipgeo/city_coords.go
Normal file
26
internal/shared/ipgeo/city_coords.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package ipgeo
|
||||
|
||||
// Coord 城市经纬度
|
||||
type Coord struct {
|
||||
Lng float64
|
||||
Lat float64
|
||||
}
|
||||
|
||||
// CityCoordinates MVP阶段常用城市坐标
|
||||
var CityCoordinates = map[string]Coord{
|
||||
"北京市": {Lng: 116.4074, Lat: 39.9042},
|
||||
"上海市": {Lng: 121.4737, Lat: 31.2304},
|
||||
"广州市": {Lng: 113.2644, Lat: 23.1291},
|
||||
"深圳市": {Lng: 114.0579, Lat: 22.5431},
|
||||
"杭州市": {Lng: 120.1551, Lat: 30.2741},
|
||||
"成都市": {Lng: 104.0665, Lat: 30.5728},
|
||||
"武汉市": {Lng: 114.3055, Lat: 30.5928},
|
||||
"西安市": {Lng: 108.9398, Lat: 34.3416},
|
||||
"南京市": {Lng: 118.7969, Lat: 32.0603},
|
||||
"苏州市": {Lng: 120.5853, Lat: 31.2989},
|
||||
"重庆市": {Lng: 106.5516, Lat: 29.5630},
|
||||
"天津市": {Lng: 117.2009, Lat: 39.0842},
|
||||
"郑州市": {Lng: 113.6254, Lat: 34.7466},
|
||||
"长沙市": {Lng: 112.9388, Lat: 28.2282},
|
||||
"青岛市": {Lng: 120.3826, Lat: 36.0671},
|
||||
}
|
||||
134
internal/shared/ipgeo/ip_locator.go
Normal file
134
internal/shared/ipgeo/ip_locator.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package ipgeo
|
||||
|
||||
import (
|
||||
"net"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"tyapi-server/internal/domains/security/entities"
|
||||
|
||||
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Location IP解析后的地理信息
|
||||
type Location struct {
|
||||
Country string
|
||||
Province string
|
||||
City string
|
||||
ISP string
|
||||
Region string
|
||||
}
|
||||
|
||||
// Locator IP地理定位器
|
||||
type Locator struct {
|
||||
logger *zap.Logger
|
||||
searcher *xdb.Searcher
|
||||
}
|
||||
|
||||
// NewLocator 创建定位器,优先读取 resources/ipgeo/ip2region.xdb
|
||||
func NewLocator(logger *zap.Logger) *Locator {
|
||||
locator := &Locator{logger: logger}
|
||||
dbPath := filepath.Join("resources", "ipgeo", "ip2region.xdb")
|
||||
|
||||
cBuff, err := xdb.LoadContentFromFile(dbPath)
|
||||
if err != nil {
|
||||
logger.Warn("加载ip2region库失败,将使用降级定位", zap.String("db_path", dbPath), zap.Error(err))
|
||||
return locator
|
||||
}
|
||||
|
||||
header, err := xdb.LoadHeaderFromBuff(cBuff)
|
||||
if err != nil {
|
||||
logger.Warn("读取ip2region头信息失败,将使用降级定位", zap.Error(err))
|
||||
return locator
|
||||
}
|
||||
version, err := xdb.VersionFromHeader(header)
|
||||
if err != nil {
|
||||
logger.Warn("解析ip2region版本失败,将使用降级定位", zap.Error(err))
|
||||
return locator
|
||||
}
|
||||
|
||||
searcher, err := xdb.NewWithBuffer(version, cBuff)
|
||||
if err != nil {
|
||||
logger.Warn("初始化ip2region搜索器失败,将使用降级定位", zap.Error(err))
|
||||
return locator
|
||||
}
|
||||
locator.searcher = searcher
|
||||
|
||||
logger.Info("ip2region定位器初始化成功", zap.String("db_path", dbPath))
|
||||
return locator
|
||||
}
|
||||
|
||||
// LookupByIP 根据IP定位,失败返回 false
|
||||
func (l *Locator) LookupByIP(ip string) (Location, bool) {
|
||||
if ip == "" || isPrivateOrLocalIP(ip) || l.searcher == nil {
|
||||
return Location{}, false
|
||||
}
|
||||
|
||||
region, err := l.searcher.SearchByStr(ip)
|
||||
if err != nil {
|
||||
l.logger.Debug("ip2region查询失败", zap.String("ip", ip), zap.Error(err))
|
||||
return Location{}, false
|
||||
}
|
||||
loc := parseRegion(region)
|
||||
if loc.Region == "" {
|
||||
return Location{}, false
|
||||
}
|
||||
return loc, true
|
||||
}
|
||||
|
||||
// ToGeoPoint 将记录转换为地球飞线起点
|
||||
func (l *Locator) ToGeoPoint(record entities.SuspiciousIPRecord) (fromName string, lng float64, lat float64) {
|
||||
// 默认降级坐标:北京
|
||||
const defaultLng = 116.4074
|
||||
const defaultLat = 39.9042
|
||||
|
||||
loc, ok := l.LookupByIP(record.IP)
|
||||
if !ok {
|
||||
return record.IP, defaultLng, defaultLat
|
||||
}
|
||||
|
||||
cityName := strings.TrimSpace(loc.City)
|
||||
if cityName == "" || cityName == "0" {
|
||||
cityName = strings.TrimSpace(loc.Province)
|
||||
}
|
||||
if cityName == "" || cityName == "0" {
|
||||
return record.IP, defaultLng, defaultLat
|
||||
}
|
||||
|
||||
coord, exists := CityCoordinates[cityName]
|
||||
if !exists {
|
||||
// 降级:未命中城市映射,回默认坐标
|
||||
return cityName, defaultLng, defaultLat
|
||||
}
|
||||
return cityName, coord.Lng, coord.Lat
|
||||
}
|
||||
|
||||
func parseRegion(region string) Location {
|
||||
parts := strings.Split(region, "|")
|
||||
for len(parts) < 5 {
|
||||
parts = append(parts, "")
|
||||
}
|
||||
return Location{
|
||||
Country: normalizeField(parts[0]),
|
||||
Region: normalizeField(parts[1]),
|
||||
Province: normalizeField(parts[2]),
|
||||
City: normalizeField(parts[3]),
|
||||
ISP: normalizeField(parts[4]),
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeField(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "0" {
|
||||
return ""
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func isPrivateOrLocalIP(ip string) bool {
|
||||
parsed := net.ParseIP(ip)
|
||||
if parsed == nil {
|
||||
return true
|
||||
}
|
||||
return parsed.IsLoopback() || parsed.IsPrivate() || parsed.IsUnspecified() || parsed.IsLinkLocalUnicast()
|
||||
}
|
||||
@@ -3,16 +3,19 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
securityEntities "tyapi-server/internal/domains/security/entities"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DailyRateLimitConfig 每日限流配置
|
||||
@@ -45,6 +48,7 @@ type DailyRateLimitConfig struct {
|
||||
type DailyRateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
response interfaces.ResponseBuilder
|
||||
logger *zap.Logger
|
||||
limitConfig DailyRateLimitConfig
|
||||
@@ -54,6 +58,7 @@ type DailyRateLimitMiddleware struct {
|
||||
func NewDailyRateLimitMiddleware(
|
||||
cfg *config.Config,
|
||||
redis *redis.Client,
|
||||
db *gorm.DB,
|
||||
response interfaces.ResponseBuilder,
|
||||
logger *zap.Logger,
|
||||
limitConfig DailyRateLimitConfig,
|
||||
@@ -78,6 +83,7 @@ func NewDailyRateLimitMiddleware(
|
||||
return &DailyRateLimitMiddleware{
|
||||
config: cfg,
|
||||
redis: redis,
|
||||
db: db,
|
||||
response: response,
|
||||
logger: logger,
|
||||
limitConfig: limitConfig,
|
||||
@@ -154,7 +160,9 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
// 4. 检查并发限制
|
||||
if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
|
||||
concurrentCount, err := m.checkConcurrentLimit(ctx, clientIP)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_concurrent_limit")
|
||||
m.logger.Warn("并发请求超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
@@ -163,9 +171,14 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(concurrentCount, m.limitConfig.MaxConcurrent) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_concurrent_limit")
|
||||
}
|
||||
|
||||
// 5. 检查接口总请求次数限制
|
||||
if err := m.checkTotalLimit(ctx); err != nil {
|
||||
totalCount, err := m.checkTotalLimit(ctx)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_total_limit")
|
||||
m.logger.Warn("接口总请求次数超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
@@ -175,9 +188,14 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(totalCount+1, m.limitConfig.MaxRequestsPerDay) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_total_limit")
|
||||
}
|
||||
|
||||
// 6. 检查IP限制
|
||||
if err := m.checkIPLimit(ctx, clientIP); err != nil {
|
||||
ipCount, err := m.checkIPLimit(ctx, clientIP)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_ip_limit")
|
||||
m.logger.Warn("IP请求次数超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
@@ -187,6 +205,9 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(ipCount+1, m.limitConfig.MaxRequestsPerIP) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_ip_limit")
|
||||
}
|
||||
|
||||
// 7. 增加计数
|
||||
m.incrementCounters(ctx, clientIP)
|
||||
@@ -198,6 +219,38 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DailyRateLimitMiddleware) recordSuspiciousRequest(c *gin.Context, ip, reason string) {
|
||||
if m.db == nil {
|
||||
return
|
||||
}
|
||||
record := securityEntities.SuspiciousIPRecord{
|
||||
IP: ip,
|
||||
Path: c.Request.URL.Path,
|
||||
Method: c.Request.Method,
|
||||
RequestCount: 1,
|
||||
WindowSeconds: int(m.limitConfig.TTL.Seconds()),
|
||||
TriggerReason: reason,
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
}
|
||||
if record.WindowSeconds <= 0 {
|
||||
record.WindowSeconds = 10
|
||||
}
|
||||
if err := m.db.Create(&record).Error; err != nil {
|
||||
m.logger.Warn("记录每日限流可疑IP失败", zap.String("ip", ip), zap.String("reason", reason), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DailyRateLimitMiddleware) shouldRecordNearLimit(current, max int) bool {
|
||||
if max <= 0 {
|
||||
return false
|
||||
}
|
||||
threshold := int(math.Ceil(float64(max) * 0.8))
|
||||
if threshold < 1 {
|
||||
threshold = 1
|
||||
}
|
||||
return current >= threshold
|
||||
}
|
||||
|
||||
// isExcludedDomain 检查域名是否在排除列表中
|
||||
func (m *DailyRateLimitMiddleware) isExcludedDomain(host string) bool {
|
||||
for _, excludeDomain := range m.limitConfig.ExcludeDomains {
|
||||
@@ -360,13 +413,13 @@ func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
|
||||
}
|
||||
|
||||
// checkConcurrentLimit 检查并发限制
|
||||
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
|
||||
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) (int, error) {
|
||||
key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
|
||||
|
||||
// 获取当前并发数
|
||||
current, err := m.redis.Get(ctx, key).Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
return fmt.Errorf("获取并发计数失败: %w", err)
|
||||
return 0, fmt.Errorf("获取并发计数失败: %w", err)
|
||||
}
|
||||
|
||||
currentCount := 0
|
||||
@@ -377,7 +430,7 @@ func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, cli
|
||||
}
|
||||
|
||||
if currentCount >= m.limitConfig.MaxConcurrent {
|
||||
return fmt.Errorf("并发请求超限: %d", currentCount)
|
||||
return currentCount, fmt.Errorf("并发请求超限: %d", currentCount)
|
||||
}
|
||||
|
||||
// 增加并发计数
|
||||
@@ -390,7 +443,7 @@ func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, cli
|
||||
m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
return currentCount + 1, nil
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端IP地址(增强版)
|
||||
@@ -435,35 +488,35 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
|
||||
}
|
||||
|
||||
// checkTotalLimit 检查接口总请求次数限制
|
||||
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
|
||||
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) (int, error) {
|
||||
key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
|
||||
|
||||
count, err := m.getCounter(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取总请求计数失败: %w", err)
|
||||
return 0, fmt.Errorf("获取总请求计数失败: %w", err)
|
||||
}
|
||||
|
||||
if count >= m.limitConfig.MaxRequestsPerDay {
|
||||
return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
|
||||
return count, fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
|
||||
}
|
||||
|
||||
return nil
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// checkIPLimit 检查IP限制
|
||||
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
|
||||
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) (int, error) {
|
||||
key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
|
||||
|
||||
count, err := m.getCounter(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取IP计数失败: %w", err)
|
||||
return 0, fmt.Errorf("获取IP计数失败: %w", err)
|
||||
}
|
||||
|
||||
if count >= m.limitConfig.MaxRequestsPerIP {
|
||||
return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
|
||||
return count, fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
|
||||
}
|
||||
|
||||
return nil
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// incrementCounters 增加计数器
|
||||
|
||||
@@ -5,25 +5,32 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
"tyapi-server/internal/config"
|
||||
securityEntities "tyapi-server/internal/domains/security/entities"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
response interfaces.ResponseBuilder
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
limiters map[string]*rate.Limiter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
|
||||
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder, db *gorm.DB, logger *zap.Logger) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
config: cfg,
|
||||
response: response,
|
||||
db: db,
|
||||
logger: logger,
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
@@ -49,6 +56,8 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
|
||||
// 检查是否允许请求
|
||||
if !limiter.Allow() {
|
||||
m.recordSuspiciousRequest(c, clientID, "rate_limit")
|
||||
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
@@ -68,6 +77,28 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) recordSuspiciousRequest(c *gin.Context, ip, reason string) {
|
||||
if m.db == nil {
|
||||
return
|
||||
}
|
||||
windowSeconds := int(m.config.RateLimit.Window.Seconds())
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 1
|
||||
}
|
||||
record := securityEntities.SuspiciousIPRecord{
|
||||
IP: ip,
|
||||
Path: c.Request.URL.Path,
|
||||
Method: c.Request.Method,
|
||||
RequestCount: 1,
|
||||
WindowSeconds: windowSeconds,
|
||||
TriggerReason: reason,
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
}
|
||||
if err := m.db.Create(&record).Error; err != nil && m.logger != nil {
|
||||
m.logger.Warn("记录可疑IP失败", zap.String("ip", ip), zap.String("path", record.Path), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RateLimitMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
|
||||
Reference in New Issue
Block a user