增加IP限流中间件
This commit is contained in:
316
ghproxy/ratelimiter.go
Normal file
316
ghproxy/ratelimiter.go
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/time/rate"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IP限流配置
|
||||||
|
var (
|
||||||
|
// 默认限流:每个IP每1小时允许20个请求
|
||||||
|
DefaultRateLimit = 20.0 // 默认限制请求数
|
||||||
|
DefaultRatePeriodHours = 1.0 // 默认时间周期(小时)
|
||||||
|
|
||||||
|
// 白名单列表,支持IP和CIDR格式,如:"192.168.1.1", "10.0.0.0/8"
|
||||||
|
WhitelistIPs = []string{
|
||||||
|
"127.0.0.1", // 本地回环地址
|
||||||
|
"10.0.0.0/8", // 内网地址段
|
||||||
|
"172.16.0.0/12", // 内网地址段
|
||||||
|
"192.168.0.0/16", // 内网地址段
|
||||||
|
}
|
||||||
|
|
||||||
|
// 黑名单列表,支持IP和CIDR格式
|
||||||
|
BlacklistIPs = []string{
|
||||||
|
// 示例: "1.2.3.4", "5.6.7.0/24"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理间隔:多久清理一次过期的限流器
|
||||||
|
CleanupInterval = 1 * time.Hour
|
||||||
|
|
||||||
|
// IP限流器缓存上限,超过此数量将触发清理
|
||||||
|
MaxIPCacheSize = 10000
|
||||||
|
)
|
||||||
|
|
||||||
|
// IPRateLimiter 定义IP限流器结构
|
||||||
|
type IPRateLimiter struct {
|
||||||
|
ips map[string]*rateLimiterEntry // IP到限流器的映射
|
||||||
|
mu *sync.RWMutex // 读写锁,保证并发安全
|
||||||
|
r rate.Limit // 速率限制(每秒允许的请求数)
|
||||||
|
b int // 令牌桶容量(突发请求数)
|
||||||
|
whitelist []*net.IPNet // 白名单IP段
|
||||||
|
blacklist []*net.IPNet // 黑名单IP段
|
||||||
|
}
|
||||||
|
|
||||||
|
// rateLimiterEntry 限流器条目,包含限流器和最后访问时间
|
||||||
|
type rateLimiterEntry struct {
|
||||||
|
limiter *rate.Limiter // 限流器
|
||||||
|
lastAccess time.Time // 最后访问时间
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewIPRateLimiter 创建新的IP限流器
|
||||||
|
func NewIPRateLimiter() *IPRateLimiter {
|
||||||
|
// 从环境变量读取限流配置(如果有)
|
||||||
|
rateLimit := DefaultRateLimit
|
||||||
|
ratePeriod := DefaultRatePeriodHours
|
||||||
|
|
||||||
|
if val, exists := os.LookupEnv("RATE_LIMIT"); exists {
|
||||||
|
if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 {
|
||||||
|
rateLimit = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if val, exists := os.LookupEnv("RATE_PERIOD_HOURS"); exists {
|
||||||
|
if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 {
|
||||||
|
ratePeriod = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从环境变量读取白名单(如果有)
|
||||||
|
whitelistIPs := WhitelistIPs
|
||||||
|
if val, exists := os.LookupEnv("IP_WHITELIST"); exists && val != "" {
|
||||||
|
whitelistIPs = append(whitelistIPs, strings.Split(val, ",")...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从环境变量读取黑名单(如果有)
|
||||||
|
blacklistIPs := BlacklistIPs
|
||||||
|
if val, exists := os.LookupEnv("IP_BLACKLIST"); exists && val != "" {
|
||||||
|
blacklistIPs = append(blacklistIPs, strings.Split(val, ",")...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析白名单IP段
|
||||||
|
whitelist := make([]*net.IPNet, 0, len(whitelistIPs))
|
||||||
|
for _, item := range whitelistIPs {
|
||||||
|
if item = strings.TrimSpace(item); item != "" {
|
||||||
|
if !strings.Contains(item, "/") {
|
||||||
|
item = item + "/32" // 单个IP转为CIDR格式
|
||||||
|
}
|
||||||
|
_, ipnet, err := net.ParseCIDR(item)
|
||||||
|
if err == nil {
|
||||||
|
whitelist = append(whitelist, ipnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析黑名单IP段
|
||||||
|
blacklist := make([]*net.IPNet, 0, len(blacklistIPs))
|
||||||
|
for _, item := range blacklistIPs {
|
||||||
|
if item = strings.TrimSpace(item); item != "" {
|
||||||
|
if !strings.Contains(item, "/") {
|
||||||
|
item = item + "/32" // 单个IP转为CIDR格式
|
||||||
|
}
|
||||||
|
_, ipnet, err := net.ParseCIDR(item)
|
||||||
|
if err == nil {
|
||||||
|
blacklist = append(blacklist, ipnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
|
||||||
|
// rate.Limit的单位是每秒允许的请求数
|
||||||
|
ratePerSecond := rate.Limit(rateLimit / (ratePeriod * 3600))
|
||||||
|
|
||||||
|
limiter := &IPRateLimiter{
|
||||||
|
ips: make(map[string]*rateLimiterEntry),
|
||||||
|
mu: &sync.RWMutex{},
|
||||||
|
r: ratePerSecond,
|
||||||
|
b: int(rateLimit), // 令牌桶容量设为允许的请求总数
|
||||||
|
whitelist: whitelist,
|
||||||
|
blacklist: blacklist,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动定期清理goroutine
|
||||||
|
go limiter.cleanupRoutine()
|
||||||
|
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupRoutine 定期清理过期的限流器
|
||||||
|
func (i *IPRateLimiter) cleanupRoutine() {
|
||||||
|
ticker := time.NewTicker(CleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
now := time.Now()
|
||||||
|
expired := make([]string, 0)
|
||||||
|
|
||||||
|
// 查找过期的条目
|
||||||
|
i.mu.RLock()
|
||||||
|
for ip, entry := range i.ips {
|
||||||
|
// 如果最后访问时间超过1小时,认为过期
|
||||||
|
if now.Sub(entry.lastAccess) > 1*time.Hour {
|
||||||
|
expired = append(expired, ip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i.mu.RUnlock()
|
||||||
|
|
||||||
|
// 如果有过期条目或者缓存过大,进行清理
|
||||||
|
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
|
||||||
|
i.mu.Lock()
|
||||||
|
// 删除过期条目
|
||||||
|
for _, ip := range expired {
|
||||||
|
delete(i.ips, ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果缓存仍然过大,全部清理
|
||||||
|
if len(i.ips) > MaxIPCacheSize {
|
||||||
|
i.ips = make(map[string]*rateLimiterEntry)
|
||||||
|
}
|
||||||
|
i.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isIPInCIDRList 检查IP是否在CIDR列表中
|
||||||
|
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
|
||||||
|
parsedIP := net.ParseIP(ip)
|
||||||
|
if parsedIP == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cidr := range cidrList {
|
||||||
|
if cidr.Contains(parsedIP) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLimiter 获取指定IP的限流器,同时返回是否允许访问
|
||||||
|
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
|
||||||
|
// 检查是否在黑名单中
|
||||||
|
if isIPInCIDRList(ip, i.blacklist) {
|
||||||
|
return nil, false // 黑名单中的IP不允许访问
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否在白名单中
|
||||||
|
if isIPInCIDRList(ip, i.whitelist) {
|
||||||
|
return rate.NewLimiter(rate.Inf, i.b), true // 白名单中的IP不受限制
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从缓存获取限流器
|
||||||
|
i.mu.RLock()
|
||||||
|
entry, exists := i.ips[ip]
|
||||||
|
i.mu.RUnlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
// 创建新的限流器
|
||||||
|
i.mu.Lock()
|
||||||
|
entry = &rateLimiterEntry{
|
||||||
|
limiter: rate.NewLimiter(i.r, i.b),
|
||||||
|
lastAccess: now,
|
||||||
|
}
|
||||||
|
i.ips[ip] = entry
|
||||||
|
i.mu.Unlock()
|
||||||
|
} else {
|
||||||
|
// 更新最后访问时间
|
||||||
|
i.mu.Lock()
|
||||||
|
entry.lastAccess = now
|
||||||
|
i.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.limiter, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimitMiddleware 速率限制中间件
|
||||||
|
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// 获取客户端真实IP
|
||||||
|
var ip string
|
||||||
|
|
||||||
|
// 优先尝试从请求头获取真实IP
|
||||||
|
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
|
||||||
|
// X-Forwarded-For可能包含多个IP,取第一个
|
||||||
|
ips := strings.Split(forwarded, ",")
|
||||||
|
ip = strings.TrimSpace(ips[0])
|
||||||
|
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
|
||||||
|
// 如果有X-Real-IP头
|
||||||
|
ip = realIP
|
||||||
|
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
|
||||||
|
// 某些代理可能使用此头
|
||||||
|
ips := strings.Split(remoteIP, ",")
|
||||||
|
ip = strings.TrimSpace(ips[0])
|
||||||
|
} else {
|
||||||
|
// 回退到ClientIP方法
|
||||||
|
ip = c.ClientIP()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 日志记录请求IP和头信息(调试用)
|
||||||
|
fmt.Printf("请求IP: %s, X-Forwarded-For: %s, X-Real-IP: %s\n",
|
||||||
|
ip,
|
||||||
|
c.GetHeader("X-Forwarded-For"),
|
||||||
|
c.GetHeader("X-Real-IP"))
|
||||||
|
|
||||||
|
// 获取限流器并检查是否允许访问
|
||||||
|
ipLimiter, allowed := limiter.GetLimiter(ip)
|
||||||
|
|
||||||
|
// 如果IP在黑名单中
|
||||||
|
if !allowed {
|
||||||
|
c.JSON(403, gin.H{
|
||||||
|
"error": "您的IP已被限制访问",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否允许本次请求
|
||||||
|
if !ipLimiter.Allow() {
|
||||||
|
c.JSON(429, gin.H{
|
||||||
|
"error": "请求频率过高,每两小时最多下载10个镜像包",
|
||||||
|
})
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 允许请求继续处理
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyRateLimit 应用限流到特定路由
|
||||||
|
func ApplyRateLimit(router *gin.Engine, path string, method string, handler gin.HandlerFunc) {
|
||||||
|
// 创建限流器(如果未创建)
|
||||||
|
limiter := NewIPRateLimiter()
|
||||||
|
|
||||||
|
// 根据HTTP方法应用限流
|
||||||
|
switch method {
|
||||||
|
case "GET":
|
||||||
|
router.GET(path, RateLimitMiddleware(limiter), handler)
|
||||||
|
case "POST":
|
||||||
|
router.POST(path, RateLimitMiddleware(limiter), handler)
|
||||||
|
case "PUT":
|
||||||
|
router.PUT(path, RateLimitMiddleware(limiter), handler)
|
||||||
|
case "DELETE":
|
||||||
|
router.DELETE(path, RateLimitMiddleware(limiter), handler)
|
||||||
|
default:
|
||||||
|
router.Any(path, RateLimitMiddleware(limiter), handler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 示例:使用此限流器
|
||||||
|
/*
|
||||||
|
func initSkopeoRoutes(router *gin.Engine) {
|
||||||
|
|
||||||
|
os.MkdirAll("./temp", 0755)
|
||||||
|
|
||||||
|
router.GET("/ws/:taskId", handleWebSocket)
|
||||||
|
|
||||||
|
// 对下载API应用限流
|
||||||
|
ApplyRateLimit(router, "/api/download", "POST", handleDownload)
|
||||||
|
|
||||||
|
// 常规路由
|
||||||
|
router.GET("/api/task/:taskId", getTaskStatus)
|
||||||
|
router.GET("/api/files/:filename", serveFile)
|
||||||
|
|
||||||
|
go cleanupTempFiles()
|
||||||
|
}
|
||||||
|
*/
|
||||||
@@ -99,8 +99,8 @@ func initSkopeoRoutes(router *gin.Engine) {
|
|||||||
// WebSocket路由 - 用于实时获取进度
|
// WebSocket路由 - 用于实时获取进度
|
||||||
router.GET("/ws/:taskId", handleWebSocket)
|
router.GET("/ws/:taskId", handleWebSocket)
|
||||||
|
|
||||||
// 创建下载任务
|
// 创建下载任务,应用限流中间件
|
||||||
router.POST("/api/download", handleDownload)
|
ApplyRateLimit(router, "/api/download", "POST", handleDownload)
|
||||||
|
|
||||||
// 获取任务状态
|
// 获取任务状态
|
||||||
router.GET("/api/task/:taskId", getTaskStatus)
|
router.GET("/api/task/:taskId", getTaskStatus)
|
||||||
@@ -979,10 +979,34 @@ func fileExists(path string) bool {
|
|||||||
|
|
||||||
// 清理过期临时文件
|
// 清理过期临时文件
|
||||||
func cleanupTempFiles() {
|
func cleanupTempFiles() {
|
||||||
for {
|
// 创建两个定时器
|
||||||
time.Sleep(1 * time.Hour)
|
hourlyTicker := time.NewTicker(1 * time.Hour)
|
||||||
|
fiveMinTicker := time.NewTicker(5 * time.Minute)
|
||||||
// 遍历temp目录
|
|
||||||
|
// 清理所有文件的函数
|
||||||
|
cleanAll := func() {
|
||||||
|
fmt.Printf("执行清理所有临时文件\n")
|
||||||
|
entries, err := os.ReadDir("./temp")
|
||||||
|
if err == nil {
|
||||||
|
for _, entry := range entries {
|
||||||
|
entryPath := filepath.Join("./temp", entry.Name())
|
||||||
|
info, err := entry.Info()
|
||||||
|
if err == nil {
|
||||||
|
if info.IsDir() {
|
||||||
|
os.RemoveAll(entryPath)
|
||||||
|
} else {
|
||||||
|
os.Remove(entryPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fmt.Printf("清理临时文件失败: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查文件大小并在需要时清理
|
||||||
|
checkSizeAndClean := func() {
|
||||||
|
var totalSize int64 = 0
|
||||||
err := filepath.Walk("./temp", func(path string, info os.FileInfo, err error) error {
|
err := filepath.Walk("./temp", func(path string, info os.FileInfo, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -993,20 +1017,36 @@ func cleanupTempFiles() {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果文件或目录超过2小时未修改,则删除
|
if !info.IsDir() {
|
||||||
if time.Since(info.ModTime()) > 2*time.Hour {
|
totalSize += info.Size()
|
||||||
if info.IsDir() {
|
|
||||||
os.RemoveAll(path)
|
|
||||||
return filepath.SkipDir
|
|
||||||
}
|
|
||||||
os.Remove(path)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("清理临时文件失败: %v\n", err)
|
fmt.Printf("计算临时文件总大小失败: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果总大小超过10GB,清理所有文件,防止恶意下载导致磁盘爆满
|
||||||
|
if totalSize > 10*1024*1024*1024 { // 15GB
|
||||||
|
fmt.Printf("临时文件总大小超过10GB (当前: %.2f GB),清理所有文件\n", float64(totalSize)/(1024*1024*1024))
|
||||||
|
cleanAll()
|
||||||
|
} else {
|
||||||
|
fmt.Printf("临时文件总大小: %.2f GB\n", float64(totalSize)/(1024*1024*1024))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 主循环
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-hourlyTicker.C:
|
||||||
|
// 每小时清理所有文件
|
||||||
|
cleanAll()
|
||||||
|
case <-fiveMinTicker.C:
|
||||||
|
// 每5分钟检查一次总文件大小
|
||||||
|
checkSizeAndClean()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user