270 lines
5.8 KiB
Go
270 lines
5.8 KiB
Go
package utils
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"golang.org/x/time/rate"
|
|
"hubproxy/config"
|
|
)
|
|
|
|
// Tuning values for the per-IP limiter cache.
const (
	// CleanupInterval is the period of the ticker that drives the background
	// sweep of the limiter cache.
	CleanupInterval = 10 * time.Minute

	// MaxIPCacheSize is the number of cached limiters above which the
	// background sweep throws the whole cache away.
	MaxIPCacheSize = 10000
)
|
|
|
|
// IPRateLimiter IP限流器结构体
|
|
type IPRateLimiter struct {
|
|
ips map[string]*rateLimiterEntry
|
|
mu *sync.RWMutex
|
|
r rate.Limit
|
|
b int
|
|
whitelist []*net.IPNet
|
|
blacklist []*net.IPNet
|
|
}
|
|
|
|
// rateLimiterEntry 限流器条目
|
|
type rateLimiterEntry struct {
|
|
limiter *rate.Limiter
|
|
lastAccess time.Time
|
|
}
|
|
|
|
// InitGlobalLimiter 初始化全局限流器
|
|
func InitGlobalLimiter() *IPRateLimiter {
|
|
cfg := config.GetConfig()
|
|
|
|
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
|
|
for _, item := range cfg.Security.WhiteList {
|
|
if item = strings.TrimSpace(item); item != "" {
|
|
if !strings.Contains(item, "/") {
|
|
item = item + "/32"
|
|
}
|
|
_, ipnet, err := net.ParseCIDR(item)
|
|
if err == nil {
|
|
whitelist = append(whitelist, ipnet)
|
|
} else {
|
|
fmt.Printf("警告: 无效的白名单IP格式: %s\n", item)
|
|
}
|
|
}
|
|
}
|
|
|
|
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
|
|
for _, item := range cfg.Security.BlackList {
|
|
if item = strings.TrimSpace(item); item != "" {
|
|
if !strings.Contains(item, "/") {
|
|
item = item + "/32"
|
|
}
|
|
_, ipnet, err := net.ParseCIDR(item)
|
|
if err == nil {
|
|
blacklist = append(blacklist, ipnet)
|
|
} else {
|
|
fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item)
|
|
}
|
|
}
|
|
}
|
|
|
|
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
|
|
|
|
burstSize := cfg.RateLimit.RequestLimit
|
|
if burstSize < 1 {
|
|
burstSize = 1
|
|
}
|
|
|
|
limiter := &IPRateLimiter{
|
|
ips: make(map[string]*rateLimiterEntry),
|
|
mu: &sync.RWMutex{},
|
|
r: ratePerSecond,
|
|
b: burstSize,
|
|
whitelist: whitelist,
|
|
blacklist: blacklist,
|
|
}
|
|
|
|
go limiter.cleanupRoutine()
|
|
|
|
return limiter
|
|
}
|
|
|
|
// cleanupRoutine 定期清理过期的限流器
|
|
func (i *IPRateLimiter) cleanupRoutine() {
|
|
ticker := time.NewTicker(CleanupInterval)
|
|
defer ticker.Stop()
|
|
|
|
for range ticker.C {
|
|
now := time.Now()
|
|
expired := make([]string, 0)
|
|
|
|
i.mu.RLock()
|
|
for ip, entry := range i.ips {
|
|
if now.Sub(entry.lastAccess) > 1*time.Hour {
|
|
expired = append(expired, ip)
|
|
}
|
|
}
|
|
i.mu.RUnlock()
|
|
|
|
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
|
|
i.mu.Lock()
|
|
for _, ip := range expired {
|
|
delete(i.ips, ip)
|
|
}
|
|
|
|
if len(i.ips) > MaxIPCacheSize {
|
|
i.ips = make(map[string]*rateLimiterEntry)
|
|
}
|
|
i.mu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
// extractIPFromAddress strips transport decoration from address and returns
// the bare host part.
//
//   - "host:port" and "[v6]:port" yield the host.
//   - "[v6]" (a bracketed IPv6 literal without a port) yields the literal
//     without brackets.
//   - Anything else, including a bare IPv4/IPv6 address or an empty string, is
//     returned unchanged.
func extractIPFromAddress(address string) string {
	if host, _, err := net.SplitHostPort(address); err == nil {
		return host
	}

	// SplitHostPort rejects "[::1]" because there is no port. Unwrap it here,
	// but only when the content really is an IP, so arbitrary bracketed text
	// is left as it was.
	if n := len(address); n > 2 && address[0] == '[' && address[n-1] == ']' {
		if inner := address[1 : n-1]; net.ParseIP(inner) != nil {
			return inner
		}
	}

	return address
}
|
|
|
|
// normalizeIPForRateLimit maps an address string to the key used for rate
// limiting.
//
//   - Text that does not parse as an IP is returned unchanged.
//   - IPv4 addresses, including IPv4-mapped IPv6 forms such as
//     "::ffff:1.2.3.4", are returned in canonical dotted-quad form so every
//     spelling of one host shares a key.
//   - Other IPv6 addresses are reduced to their /64 network and returned as
//     "<prefix>/64", so all hosts in one /64 share a key.
func normalizeIPForRateLimit(ipStr string) string {
	ip := net.ParseIP(ipStr)
	if ip == nil {
		return ipStr
	}

	if v4 := ip.To4(); v4 != nil {
		return v4.String()
	}

	// Zero the low 64 bits (the interface identifier).
	return ip.Mask(net.CIDRMask(64, 128)).String() + "/64"
}
|
|
|
|
// isIPInCIDRList 检查IP是否在CIDR列表中
|
|
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
|
|
cleanIP := extractIPFromAddress(ip)
|
|
parsedIP := net.ParseIP(cleanIP)
|
|
if parsedIP == nil {
|
|
return false
|
|
}
|
|
|
|
for _, cidr := range cidrList {
|
|
if cidr.Contains(parsedIP) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// GetLimiter 获取指定IP的限流器
|
|
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
|
|
cleanIP := extractIPFromAddress(ip)
|
|
|
|
if isIPInCIDRList(cleanIP, i.blacklist) {
|
|
return nil, false
|
|
}
|
|
|
|
if isIPInCIDRList(cleanIP, i.whitelist) {
|
|
return rate.NewLimiter(rate.Inf, i.b), true
|
|
}
|
|
|
|
normalizedIP := normalizeIPForRateLimit(cleanIP)
|
|
|
|
now := time.Now()
|
|
|
|
i.mu.RLock()
|
|
entry, exists := i.ips[normalizedIP]
|
|
i.mu.RUnlock()
|
|
|
|
if exists {
|
|
i.mu.Lock()
|
|
if entry, stillExists := i.ips[normalizedIP]; stillExists {
|
|
entry.lastAccess = now
|
|
i.mu.Unlock()
|
|
return entry.limiter, true
|
|
}
|
|
i.mu.Unlock()
|
|
}
|
|
|
|
i.mu.Lock()
|
|
if entry, exists := i.ips[normalizedIP]; exists {
|
|
entry.lastAccess = now
|
|
i.mu.Unlock()
|
|
return entry.limiter, true
|
|
}
|
|
|
|
entry = &rateLimiterEntry{
|
|
limiter: rate.NewLimiter(i.r, i.b),
|
|
lastAccess: now,
|
|
}
|
|
i.ips[normalizedIP] = entry
|
|
i.mu.Unlock()
|
|
|
|
return entry.limiter, true
|
|
}
|
|
|
|
// RateLimitMiddleware 速率限制中间件
|
|
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
path := c.Request.URL.Path
|
|
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
|
|
strings.HasPrefix(path, "/public/") {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
var ip string
|
|
|
|
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
|
|
ips := strings.Split(forwarded, ",")
|
|
ip = strings.TrimSpace(ips[0])
|
|
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
|
|
ip = realIP
|
|
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
|
|
ips := strings.Split(remoteIP, ",")
|
|
ip = strings.TrimSpace(ips[0])
|
|
} else {
|
|
ip = c.ClientIP()
|
|
}
|
|
|
|
cleanIP := extractIPFromAddress(ip)
|
|
|
|
normalizedIP := normalizeIPForRateLimit(cleanIP)
|
|
if cleanIP != normalizedIP {
|
|
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
|
|
ip, cleanIP, normalizedIP,
|
|
c.GetHeader("X-Forwarded-For"),
|
|
c.GetHeader("X-Real-IP"))
|
|
} else {
|
|
fmt.Printf("请求IP: %s (提纯后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
|
|
ip, cleanIP,
|
|
c.GetHeader("X-Forwarded-For"),
|
|
c.GetHeader("X-Real-IP"))
|
|
}
|
|
|
|
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
|
|
|
|
if !allowed {
|
|
c.JSON(403, gin.H{
|
|
"error": "您已被限制访问",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
if !ipLimiter.Allow() {
|
|
c.JSON(429, gin.H{
|
|
"error": "请求频率过快,暂时限制访问",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
} |