增加IP限流中间件

This commit is contained in:
NewName
2025-05-17 16:06:44 +08:00
parent 78f3416f95
commit c8c9ff61f7
2 changed files with 370 additions and 14 deletions

316
ghproxy/ratelimiter.go Normal file
View File

@@ -0,0 +1,316 @@
package main
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
// IP限流配置
var (
// 默认限流每个IP每1小时允许20个请求
DefaultRateLimit = 20.0 // 默认限制请求数
DefaultRatePeriodHours = 1.0 // 默认时间周期(小时)
// 白名单列表支持IP和CIDR格式"192.168.1.1", "10.0.0.0/8"
WhitelistIPs = []string{
"127.0.0.1", // 本地回环地址
"10.0.0.0/8", // 内网地址段
"172.16.0.0/12", // 内网地址段
"192.168.0.0/16", // 内网地址段
}
// 黑名单列表支持IP和CIDR格式
BlacklistIPs = []string{
// 示例: "1.2.3.4", "5.6.7.0/24"
}
// 清理间隔:多久清理一次过期的限流器
CleanupInterval = 1 * time.Hour
// IP限流器缓存上限超过此数量将触发清理
MaxIPCacheSize = 10000
)
// IPRateLimiter 定义IP限流器结构
type IPRateLimiter struct {
ips map[string]*rateLimiterEntry // IP到限流器的映射
mu *sync.RWMutex // 读写锁,保证并发安全
r rate.Limit // 速率限制(每秒允许的请求数)
b int // 令牌桶容量(突发请求数)
whitelist []*net.IPNet // 白名单IP段
blacklist []*net.IPNet // 黑名单IP段
}
// rateLimiterEntry 限流器条目,包含限流器和最后访问时间
type rateLimiterEntry struct {
limiter *rate.Limiter // 限流器
lastAccess time.Time // 最后访问时间
}
// NewIPRateLimiter 创建新的IP限流器
func NewIPRateLimiter() *IPRateLimiter {
// 从环境变量读取限流配置(如果有)
rateLimit := DefaultRateLimit
ratePeriod := DefaultRatePeriodHours
if val, exists := os.LookupEnv("RATE_LIMIT"); exists {
if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 {
rateLimit = parsed
}
}
if val, exists := os.LookupEnv("RATE_PERIOD_HOURS"); exists {
if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 {
ratePeriod = parsed
}
}
// 从环境变量读取白名单(如果有)
whitelistIPs := WhitelistIPs
if val, exists := os.LookupEnv("IP_WHITELIST"); exists && val != "" {
whitelistIPs = append(whitelistIPs, strings.Split(val, ",")...)
}
// 从环境变量读取黑名单(如果有)
blacklistIPs := BlacklistIPs
if val, exists := os.LookupEnv("IP_BLACKLIST"); exists && val != "" {
blacklistIPs = append(blacklistIPs, strings.Split(val, ",")...)
}
// 解析白名单IP段
whitelist := make([]*net.IPNet, 0, len(whitelistIPs))
for _, item := range whitelistIPs {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
whitelist = append(whitelist, ipnet)
}
}
}
// 解析黑名单IP段
blacklist := make([]*net.IPNet, 0, len(blacklistIPs))
for _, item := range blacklistIPs {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
blacklist = append(blacklist, ipnet)
}
}
}
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
// rate.Limit的单位是每秒允许的请求数
ratePerSecond := rate.Limit(rateLimit / (ratePeriod * 3600))
limiter := &IPRateLimiter{
ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{},
r: ratePerSecond,
b: int(rateLimit), // 令牌桶容量设为允许的请求总数
whitelist: whitelist,
blacklist: blacklist,
}
// 启动定期清理goroutine
go limiter.cleanupRoutine()
return limiter
}
// cleanupRoutine 定期清理过期的限流器
func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
expired := make([]string, 0)
// 查找过期的条目
i.mu.RLock()
for ip, entry := range i.ips {
// 如果最后访问时间超过1小时认为过期
if now.Sub(entry.lastAccess) > 1*time.Hour {
expired = append(expired, ip)
}
}
i.mu.RUnlock()
// 如果有过期条目或者缓存过大,进行清理
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
i.mu.Lock()
// 删除过期条目
for _, ip := range expired {
delete(i.ips, ip)
}
// 如果缓存仍然过大,全部清理
if len(i.ips) > MaxIPCacheSize {
i.ips = make(map[string]*rateLimiterEntry)
}
i.mu.Unlock()
}
}
}
// isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
parsedIP := net.ParseIP(ip)
if parsedIP == nil {
return false
}
for _, cidr := range cidrList {
if cidr.Contains(parsedIP) {
return true
}
}
return false
}
// GetLimiter 获取指定IP的限流器同时返回是否允许访问
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// 检查是否在黑名单中
if isIPInCIDRList(ip, i.blacklist) {
return nil, false // 黑名单中的IP不允许访问
}
// 检查是否在白名单中
if isIPInCIDRList(ip, i.whitelist) {
return rate.NewLimiter(rate.Inf, i.b), true // 白名单中的IP不受限制
}
// 从缓存获取限流器
i.mu.RLock()
entry, exists := i.ips[ip]
i.mu.RUnlock()
now := time.Now()
if !exists {
// 创建新的限流器
i.mu.Lock()
entry = &rateLimiterEntry{
limiter: rate.NewLimiter(i.r, i.b),
lastAccess: now,
}
i.ips[ip] = entry
i.mu.Unlock()
} else {
// 更新最后访问时间
i.mu.Lock()
entry.lastAccess = now
i.mu.Unlock()
}
return entry.limiter, true
}
// RateLimitMiddleware 速率限制中间件
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
// 获取客户端真实IP
var ip string
// 优先尝试从请求头获取真实IP
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For可能包含多个IP取第一个
ips := strings.Split(forwarded, ",")
ip = strings.TrimSpace(ips[0])
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
// 如果有X-Real-IP头
ip = realIP
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
// 某些代理可能使用此头
ips := strings.Split(remoteIP, ",")
ip = strings.TrimSpace(ips[0])
} else {
// 回退到ClientIP方法
ip = c.ClientIP()
}
// 日志记录请求IP和头信息调试用
fmt.Printf("请求IP: %s, X-Forwarded-For: %s, X-Real-IP: %s\n",
ip,
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"))
// 获取限流器并检查是否允许访问
ipLimiter, allowed := limiter.GetLimiter(ip)
// 如果IP在黑名单中
if !allowed {
c.JSON(403, gin.H{
"error": "您的IP已被限制访问",
})
c.Abort()
return
}
// 检查是否允许本次请求
if !ipLimiter.Allow() {
c.JSON(429, gin.H{
"error": "请求频率过高每两小时最多下载10个镜像包",
})
c.Abort()
return
}
// 允许请求继续处理
c.Next()
}
}
// ApplyRateLimit 应用限流到特定路由
func ApplyRateLimit(router *gin.Engine, path string, method string, handler gin.HandlerFunc) {
// 创建限流器(如果未创建)
limiter := NewIPRateLimiter()
// 根据HTTP方法应用限流
switch method {
case "GET":
router.GET(path, RateLimitMiddleware(limiter), handler)
case "POST":
router.POST(path, RateLimitMiddleware(limiter), handler)
case "PUT":
router.PUT(path, RateLimitMiddleware(limiter), handler)
case "DELETE":
router.DELETE(path, RateLimitMiddleware(limiter), handler)
default:
router.Any(path, RateLimitMiddleware(limiter), handler)
}
}
// 示例:使用此限流器
/*
func initSkopeoRoutes(router *gin.Engine) {
os.MkdirAll("./temp", 0755)
router.GET("/ws/:taskId", handleWebSocket)
// 对下载API应用限流
ApplyRateLimit(router, "/api/download", "POST", handleDownload)
// 常规路由
router.GET("/api/task/:taskId", getTaskStatus)
router.GET("/api/files/:filename", serveFile)
go cleanupTempFiles()
}
*/

View File

@@ -99,8 +99,8 @@ func initSkopeoRoutes(router *gin.Engine) {
// WebSocket路由 - 用于实时获取进度
router.GET("/ws/:taskId", handleWebSocket)
// 创建下载任务
router.POST("/api/download", handleDownload)
// 创建下载任务,应用限流中间件
ApplyRateLimit(router, "/api/download", "POST", handleDownload)
// 获取任务状态
router.GET("/api/task/:taskId", getTaskStatus)
@@ -979,10 +979,34 @@ func fileExists(path string) bool {
// 清理过期临时文件
func cleanupTempFiles() {
for {
time.Sleep(1 * time.Hour)
// 遍历temp目录
// 创建两个定时器
hourlyTicker := time.NewTicker(1 * time.Hour)
fiveMinTicker := time.NewTicker(5 * time.Minute)
// 清理所有文件的函数
cleanAll := func() {
fmt.Printf("执行清理所有临时文件\n")
entries, err := os.ReadDir("./temp")
if err == nil {
for _, entry := range entries {
entryPath := filepath.Join("./temp", entry.Name())
info, err := entry.Info()
if err == nil {
if info.IsDir() {
os.RemoveAll(entryPath)
} else {
os.Remove(entryPath)
}
}
}
} else {
fmt.Printf("清理临时文件失败: %v\n", err)
}
}
// 检查文件大小并在需要时清理
checkSizeAndClean := func() {
var totalSize int64 = 0
err := filepath.Walk("./temp", func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
@@ -993,20 +1017,36 @@ func cleanupTempFiles() {
return nil
}
// 如果文件或目录超过2小时未修改则删除
if time.Since(info.ModTime()) > 2*time.Hour {
if info.IsDir() {
os.RemoveAll(path)
return filepath.SkipDir
}
os.Remove(path)
if !info.IsDir() {
totalSize += info.Size()
}
return nil
})
if err != nil {
fmt.Printf("清理临时文件失败: %v\n", err)
fmt.Printf("计算临时文件总大小失败: %v\n", err)
return
}
// 如果总大小超过10GB清理所有文件防止恶意下载导致磁盘爆满
if totalSize > 10*1024*1024*1024 { // 15GB
fmt.Printf("临时文件总大小超过10GB (当前: %.2f GB),清理所有文件\n", float64(totalSize)/(1024*1024*1024))
cleanAll()
} else {
fmt.Printf("临时文件总大小: %.2f GB\n", float64(totalSize)/(1024*1024*1024))
}
}
// 主循环
for {
select {
case <-hourlyTicker.C:
// 每小时清理所有文件
cleanAll()
case <-fiveMinTicker.C:
// 每5分钟检查一次总文件大小
checkSizeAndClean()
}
}
}