增加IP限流中间件
This commit is contained in:
316
ghproxy/ratelimiter.go
Normal file
316
ghproxy/ratelimiter.go
Normal file
@@ -0,0 +1,316 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// IP限流配置
|
||||
var (
|
||||
// 默认限流:每个IP每1小时允许20个请求
|
||||
DefaultRateLimit = 20.0 // 默认限制请求数
|
||||
DefaultRatePeriodHours = 1.0 // 默认时间周期(小时)
|
||||
|
||||
// 白名单列表,支持IP和CIDR格式,如:"192.168.1.1", "10.0.0.0/8"
|
||||
WhitelistIPs = []string{
|
||||
"127.0.0.1", // 本地回环地址
|
||||
"10.0.0.0/8", // 内网地址段
|
||||
"172.16.0.0/12", // 内网地址段
|
||||
"192.168.0.0/16", // 内网地址段
|
||||
}
|
||||
|
||||
// 黑名单列表,支持IP和CIDR格式
|
||||
BlacklistIPs = []string{
|
||||
// 示例: "1.2.3.4", "5.6.7.0/24"
|
||||
}
|
||||
|
||||
// 清理间隔:多久清理一次过期的限流器
|
||||
CleanupInterval = 1 * time.Hour
|
||||
|
||||
// IP限流器缓存上限,超过此数量将触发清理
|
||||
MaxIPCacheSize = 10000
|
||||
)
|
||||
|
||||
// IPRateLimiter 定义IP限流器结构
|
||||
type IPRateLimiter struct {
|
||||
ips map[string]*rateLimiterEntry // IP到限流器的映射
|
||||
mu *sync.RWMutex // 读写锁,保证并发安全
|
||||
r rate.Limit // 速率限制(每秒允许的请求数)
|
||||
b int // 令牌桶容量(突发请求数)
|
||||
whitelist []*net.IPNet // 白名单IP段
|
||||
blacklist []*net.IPNet // 黑名单IP段
|
||||
}
|
||||
|
||||
// rateLimiterEntry 限流器条目,包含限流器和最后访问时间
|
||||
type rateLimiterEntry struct {
|
||||
limiter *rate.Limiter // 限流器
|
||||
lastAccess time.Time // 最后访问时间
|
||||
}
|
||||
|
||||
// NewIPRateLimiter 创建新的IP限流器
|
||||
func NewIPRateLimiter() *IPRateLimiter {
|
||||
// 从环境变量读取限流配置(如果有)
|
||||
rateLimit := DefaultRateLimit
|
||||
ratePeriod := DefaultRatePeriodHours
|
||||
|
||||
if val, exists := os.LookupEnv("RATE_LIMIT"); exists {
|
||||
if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 {
|
||||
rateLimit = parsed
|
||||
}
|
||||
}
|
||||
|
||||
if val, exists := os.LookupEnv("RATE_PERIOD_HOURS"); exists {
|
||||
if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 {
|
||||
ratePeriod = parsed
|
||||
}
|
||||
}
|
||||
|
||||
// 从环境变量读取白名单(如果有)
|
||||
whitelistIPs := WhitelistIPs
|
||||
if val, exists := os.LookupEnv("IP_WHITELIST"); exists && val != "" {
|
||||
whitelistIPs = append(whitelistIPs, strings.Split(val, ",")...)
|
||||
}
|
||||
|
||||
// 从环境变量读取黑名单(如果有)
|
||||
blacklistIPs := BlacklistIPs
|
||||
if val, exists := os.LookupEnv("IP_BLACKLIST"); exists && val != "" {
|
||||
blacklistIPs = append(blacklistIPs, strings.Split(val, ",")...)
|
||||
}
|
||||
|
||||
// 解析白名单IP段
|
||||
whitelist := make([]*net.IPNet, 0, len(whitelistIPs))
|
||||
for _, item := range whitelistIPs {
|
||||
if item = strings.TrimSpace(item); item != "" {
|
||||
if !strings.Contains(item, "/") {
|
||||
item = item + "/32" // 单个IP转为CIDR格式
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(item)
|
||||
if err == nil {
|
||||
whitelist = append(whitelist, ipnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 解析黑名单IP段
|
||||
blacklist := make([]*net.IPNet, 0, len(blacklistIPs))
|
||||
for _, item := range blacklistIPs {
|
||||
if item = strings.TrimSpace(item); item != "" {
|
||||
if !strings.Contains(item, "/") {
|
||||
item = item + "/32" // 单个IP转为CIDR格式
|
||||
}
|
||||
_, ipnet, err := net.ParseCIDR(item)
|
||||
if err == nil {
|
||||
blacklist = append(blacklist, ipnet)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
|
||||
// rate.Limit的单位是每秒允许的请求数
|
||||
ratePerSecond := rate.Limit(rateLimit / (ratePeriod * 3600))
|
||||
|
||||
limiter := &IPRateLimiter{
|
||||
ips: make(map[string]*rateLimiterEntry),
|
||||
mu: &sync.RWMutex{},
|
||||
r: ratePerSecond,
|
||||
b: int(rateLimit), // 令牌桶容量设为允许的请求总数
|
||||
whitelist: whitelist,
|
||||
blacklist: blacklist,
|
||||
}
|
||||
|
||||
// 启动定期清理goroutine
|
||||
go limiter.cleanupRoutine()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupRoutine 定期清理过期的限流器
|
||||
func (i *IPRateLimiter) cleanupRoutine() {
|
||||
ticker := time.NewTicker(CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
expired := make([]string, 0)
|
||||
|
||||
// 查找过期的条目
|
||||
i.mu.RLock()
|
||||
for ip, entry := range i.ips {
|
||||
// 如果最后访问时间超过1小时,认为过期
|
||||
if now.Sub(entry.lastAccess) > 1*time.Hour {
|
||||
expired = append(expired, ip)
|
||||
}
|
||||
}
|
||||
i.mu.RUnlock()
|
||||
|
||||
// 如果有过期条目或者缓存过大,进行清理
|
||||
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
|
||||
i.mu.Lock()
|
||||
// 删除过期条目
|
||||
for _, ip := range expired {
|
||||
delete(i.ips, ip)
|
||||
}
|
||||
|
||||
// 如果缓存仍然过大,全部清理
|
||||
if len(i.ips) > MaxIPCacheSize {
|
||||
i.ips = make(map[string]*rateLimiterEntry)
|
||||
}
|
||||
i.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isIPInCIDRList 检查IP是否在CIDR列表中
|
||||
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, cidr := range cidrList {
|
||||
if cidr.Contains(parsedIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetLimiter 获取指定IP的限流器,同时返回是否允许访问
|
||||
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
|
||||
// 检查是否在黑名单中
|
||||
if isIPInCIDRList(ip, i.blacklist) {
|
||||
return nil, false // 黑名单中的IP不允许访问
|
||||
}
|
||||
|
||||
// 检查是否在白名单中
|
||||
if isIPInCIDRList(ip, i.whitelist) {
|
||||
return rate.NewLimiter(rate.Inf, i.b), true // 白名单中的IP不受限制
|
||||
}
|
||||
|
||||
// 从缓存获取限流器
|
||||
i.mu.RLock()
|
||||
entry, exists := i.ips[ip]
|
||||
i.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
if !exists {
|
||||
// 创建新的限流器
|
||||
i.mu.Lock()
|
||||
entry = &rateLimiterEntry{
|
||||
limiter: rate.NewLimiter(i.r, i.b),
|
||||
lastAccess: now,
|
||||
}
|
||||
i.ips[ip] = entry
|
||||
i.mu.Unlock()
|
||||
} else {
|
||||
// 更新最后访问时间
|
||||
i.mu.Lock()
|
||||
entry.lastAccess = now
|
||||
i.mu.Unlock()
|
||||
}
|
||||
|
||||
return entry.limiter, true
|
||||
}
|
||||
|
||||
// RateLimitMiddleware 速率限制中间件
|
||||
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取客户端真实IP
|
||||
var ip string
|
||||
|
||||
// 优先尝试从请求头获取真实IP
|
||||
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
|
||||
// X-Forwarded-For可能包含多个IP,取第一个
|
||||
ips := strings.Split(forwarded, ",")
|
||||
ip = strings.TrimSpace(ips[0])
|
||||
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
|
||||
// 如果有X-Real-IP头
|
||||
ip = realIP
|
||||
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
|
||||
// 某些代理可能使用此头
|
||||
ips := strings.Split(remoteIP, ",")
|
||||
ip = strings.TrimSpace(ips[0])
|
||||
} else {
|
||||
// 回退到ClientIP方法
|
||||
ip = c.ClientIP()
|
||||
}
|
||||
|
||||
// 日志记录请求IP和头信息(调试用)
|
||||
fmt.Printf("请求IP: %s, X-Forwarded-For: %s, X-Real-IP: %s\n",
|
||||
ip,
|
||||
c.GetHeader("X-Forwarded-For"),
|
||||
c.GetHeader("X-Real-IP"))
|
||||
|
||||
// 获取限流器并检查是否允许访问
|
||||
ipLimiter, allowed := limiter.GetLimiter(ip)
|
||||
|
||||
// 如果IP在黑名单中
|
||||
if !allowed {
|
||||
c.JSON(403, gin.H{
|
||||
"error": "您的IP已被限制访问",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否允许本次请求
|
||||
if !ipLimiter.Allow() {
|
||||
c.JSON(429, gin.H{
|
||||
"error": "请求频率过高,每两小时最多下载10个镜像包",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 允许请求继续处理
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyRateLimit 应用限流到特定路由
|
||||
func ApplyRateLimit(router *gin.Engine, path string, method string, handler gin.HandlerFunc) {
|
||||
// 创建限流器(如果未创建)
|
||||
limiter := NewIPRateLimiter()
|
||||
|
||||
// 根据HTTP方法应用限流
|
||||
switch method {
|
||||
case "GET":
|
||||
router.GET(path, RateLimitMiddleware(limiter), handler)
|
||||
case "POST":
|
||||
router.POST(path, RateLimitMiddleware(limiter), handler)
|
||||
case "PUT":
|
||||
router.PUT(path, RateLimitMiddleware(limiter), handler)
|
||||
case "DELETE":
|
||||
router.DELETE(path, RateLimitMiddleware(limiter), handler)
|
||||
default:
|
||||
router.Any(path, RateLimitMiddleware(limiter), handler)
|
||||
}
|
||||
}
|
||||
|
||||
// 示例:使用此限流器
|
||||
/*
|
||||
func initSkopeoRoutes(router *gin.Engine) {
|
||||
|
||||
os.MkdirAll("./temp", 0755)
|
||||
|
||||
router.GET("/ws/:taskId", handleWebSocket)
|
||||
|
||||
// 对下载API应用限流
|
||||
ApplyRateLimit(router, "/api/download", "POST", handleDownload)
|
||||
|
||||
// 常规路由
|
||||
router.GET("/api/task/:taskId", getTaskStatus)
|
||||
router.GET("/api/files/:filename", serveFile)
|
||||
|
||||
go cleanupTempFiles()
|
||||
}
|
||||
*/
|
||||
@@ -99,8 +99,8 @@ func initSkopeoRoutes(router *gin.Engine) {
|
||||
// WebSocket路由 - 用于实时获取进度
|
||||
router.GET("/ws/:taskId", handleWebSocket)
|
||||
|
||||
// 创建下载任务
|
||||
router.POST("/api/download", handleDownload)
|
||||
// 创建下载任务,应用限流中间件
|
||||
ApplyRateLimit(router, "/api/download", "POST", handleDownload)
|
||||
|
||||
// 获取任务状态
|
||||
router.GET("/api/task/:taskId", getTaskStatus)
|
||||
@@ -979,10 +979,34 @@ func fileExists(path string) bool {
|
||||
|
||||
// 清理过期临时文件
|
||||
func cleanupTempFiles() {
|
||||
for {
|
||||
time.Sleep(1 * time.Hour)
|
||||
|
||||
// 遍历temp目录
|
||||
// 创建两个定时器
|
||||
hourlyTicker := time.NewTicker(1 * time.Hour)
|
||||
fiveMinTicker := time.NewTicker(5 * time.Minute)
|
||||
|
||||
// 清理所有文件的函数
|
||||
cleanAll := func() {
|
||||
fmt.Printf("执行清理所有临时文件\n")
|
||||
entries, err := os.ReadDir("./temp")
|
||||
if err == nil {
|
||||
for _, entry := range entries {
|
||||
entryPath := filepath.Join("./temp", entry.Name())
|
||||
info, err := entry.Info()
|
||||
if err == nil {
|
||||
if info.IsDir() {
|
||||
os.RemoveAll(entryPath)
|
||||
} else {
|
||||
os.Remove(entryPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("清理临时文件失败: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查文件大小并在需要时清理
|
||||
checkSizeAndClean := func() {
|
||||
var totalSize int64 = 0
|
||||
err := filepath.Walk("./temp", func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -993,20 +1017,36 @@ func cleanupTempFiles() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 如果文件或目录超过2小时未修改,则删除
|
||||
if time.Since(info.ModTime()) > 2*time.Hour {
|
||||
if info.IsDir() {
|
||||
os.RemoveAll(path)
|
||||
return filepath.SkipDir
|
||||
}
|
||||
os.Remove(path)
|
||||
if !info.IsDir() {
|
||||
totalSize += info.Size()
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("清理临时文件失败: %v\n", err)
|
||||
fmt.Printf("计算临时文件总大小失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果总大小超过10GB,清理所有文件,防止恶意下载导致磁盘爆满
|
||||
if totalSize > 10*1024*1024*1024 { // 15GB
|
||||
fmt.Printf("临时文件总大小超过10GB (当前: %.2f GB),清理所有文件\n", float64(totalSize)/(1024*1024*1024))
|
||||
cleanAll()
|
||||
} else {
|
||||
fmt.Printf("临时文件总大小: %.2f GB\n", float64(totalSize)/(1024*1024*1024))
|
||||
}
|
||||
}
|
||||
|
||||
// 主循环
|
||||
for {
|
||||
select {
|
||||
case <-hourlyTicker.C:
|
||||
// 每小时清理所有文件
|
||||
cleanAll()
|
||||
case <-fiveMinTicker.C:
|
||||
// 每5分钟检查一次总文件大小
|
||||
checkSizeAndClean()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user