修复ipv6标准化的潜在BUG

This commit is contained in:
user123456
2025-06-17 18:38:48 +08:00
parent aea36939a3
commit 182dced403
3 changed files with 32 additions and 19 deletions

View File

@@ -171,11 +171,11 @@ func handler(c *gin.Context) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
} }
proxy(c, rawPath) proxyRequest(c, rawPath)
} }
func proxy(c *gin.Context, u string) { func proxyRequest(c *gin.Context, u string) {
proxyWithRedirect(c, u, 0) proxyWithRedirect(c, u, 0)
} }

View File

@@ -132,23 +132,33 @@ func (i *IPRateLimiter) cleanupRoutine() {
} }
} }
// extractIPFromAddress 从地址中提取纯IP,去除端口号 // extractIPFromAddress 从地址中提取纯IP
func extractIPFromAddress(address string) string { func extractIPFromAddress(address string) string {
// 处理IPv6地址 [::1]:8080 格式 if host, _, err := net.SplitHostPort(address); err == nil {
if strings.HasPrefix(address, "[") { return host
if endIndex := strings.Index(address, "]"); endIndex != -1 {
return address[1:endIndex]
}
} }
// 处理IPv4地址 192.168.1.1:8080 格式
if lastColon := strings.LastIndex(address, ":"); lastColon != -1 {
return address[:lastColon]
}
return address return address
} }
// normalizeIPForRateLimit 标准化IP地址用于限流IPv4保持不变IPv6标准化为/64网段
func normalizeIPForRateLimit(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ipStr // 解析失败,返回原值
}
if ip.To4() != nil {
return ipStr // IPv4保持不变
}
// IPv6标准化为 /64 网段
ipv6 := ip.To16()
for i := 8; i < 16; i++ {
ipv6[i] = 0 // 清零后64位
}
return ipv6.String() + "/64"
}
// isIPInCIDRList 检查IP是否在CIDR列表中 // isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
// 先提取纯IP地址 // 先提取纯IP地址
@@ -181,15 +191,18 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
return rate.NewLimiter(rate.Inf, i.b), true return rate.NewLimiter(rate.Inf, i.b), true
} }
// 标准化IP用于限流IPv4保持不变IPv6标准化为/64网段
normalizedIP := normalizeIPForRateLimit(cleanIP)
now := time.Now() now := time.Now()
i.mu.RLock() i.mu.RLock()
entry, exists := i.ips[cleanIP] entry, exists := i.ips[normalizedIP]
i.mu.RUnlock() i.mu.RUnlock()
if exists { if exists {
i.mu.Lock() i.mu.Lock()
if entry, stillExists := i.ips[cleanIP]; stillExists { if entry, stillExists := i.ips[normalizedIP]; stillExists {
entry.lastAccess = now entry.lastAccess = now
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true
@@ -198,7 +211,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
} }
i.mu.Lock() i.mu.Lock()
if entry, exists := i.ips[cleanIP]; exists { if entry, exists := i.ips[normalizedIP]; exists {
entry.lastAccess = now entry.lastAccess = now
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true
@@ -208,7 +221,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
limiter: rate.NewLimiter(i.r, i.b), limiter: rate.NewLimiter(i.r, i.b),
lastAccess: now, lastAccess: now,
} }
i.ips[cleanIP] = entry i.ips[normalizedIP] = entry
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true

View File

@@ -41,7 +41,7 @@ func (s *SmartRateLimit) ShouldSkipRateLimit(ip, path string) bool {
return false return false
} }
sessionKey := ip sessionKey := normalizeIPForRateLimit(ip)
sessionInterface, _ := s.sessions.LoadOrStore(sessionKey, &PullSession{}) sessionInterface, _ := s.sessions.LoadOrStore(sessionKey, &PullSession{})
session := sessionInterface.(*PullSession) session := sessionInterface.(*PullSession)