Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6193a07837 | ||
|
|
bb2f7bcda6 | ||
|
|
4ec36da9b5 | ||
|
|
83a1211067 | ||
|
|
367038a4b5 | ||
|
|
a0df3b1a54 | ||
|
|
70bf552daf | ||
|
|
d5e2abdcff | ||
|
|
07a926902a | ||
|
|
1881b5b1ba | ||
|
|
75e37158ef | ||
|
|
506de49586 | ||
|
|
dd704dc499 | ||
|
|
9a8b850bce | ||
|
|
187e842445 | ||
|
|
badafd2899 |
3
.gitattributes
vendored
3
.gitattributes
vendored
@@ -1 +1,2 @@
|
|||||||
* text=auto eol=lf
|
* text=auto eol=lf
|
||||||
|
*.html linguist-vendored
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
|||||||
cp hubproxy.service build/hubproxy/
|
cp hubproxy.service build/hubproxy/
|
||||||
|
|
||||||
# 复制安装脚本
|
# 复制安装脚本
|
||||||
cp install-service.sh build/hubproxy/
|
cp install.sh build/hubproxy/
|
||||||
|
|
||||||
# 创建README文件
|
# 创建README文件
|
||||||
cat > build/hubproxy/README.md << 'EOF'
|
cat > build/hubproxy/README.md << 'EOF'
|
||||||
@@ -88,13 +88,13 @@ jobs:
|
|||||||
# Linux AMD64 包
|
# Linux AMD64 包
|
||||||
mkdir -p linux-amd64/hubproxy
|
mkdir -p linux-amd64/hubproxy
|
||||||
cp hubproxy/hubproxy-linux-amd64 linux-amd64/hubproxy/hubproxy
|
cp hubproxy/hubproxy-linux-amd64 linux-amd64/hubproxy/hubproxy
|
||||||
cp hubproxy/config.toml hubproxy/hubproxy.service hubproxy/install-service.sh hubproxy/README.md linux-amd64/hubproxy/
|
cp hubproxy/config.toml hubproxy/hubproxy.service hubproxy/install.sh hubproxy/README.md linux-amd64/hubproxy/
|
||||||
tar -czf hubproxy-${{ steps.version.outputs.version }}-linux-amd64.tar.gz -C linux-amd64 hubproxy
|
tar -czf hubproxy-${{ steps.version.outputs.version }}-linux-amd64.tar.gz -C linux-amd64 hubproxy
|
||||||
|
|
||||||
# Linux ARM64 包
|
# Linux ARM64 包
|
||||||
mkdir -p linux-arm64/hubproxy
|
mkdir -p linux-arm64/hubproxy
|
||||||
cp hubproxy/hubproxy-linux-arm64 linux-arm64/hubproxy/hubproxy
|
cp hubproxy/hubproxy-linux-arm64 linux-arm64/hubproxy/hubproxy
|
||||||
cp hubproxy/config.toml hubproxy/hubproxy.service hubproxy/install-service.sh hubproxy/README.md linux-arm64/hubproxy/
|
cp hubproxy/config.toml hubproxy/hubproxy.service hubproxy/install.sh hubproxy/README.md linux-arm64/hubproxy/
|
||||||
tar -czf hubproxy-${{ steps.version.outputs.version }}-linux-arm64.tar.gz -C linux-arm64 hubproxy
|
tar -czf hubproxy-${{ steps.version.outputs.version }}-linux-arm64.tar.gz -C linux-arm64 hubproxy
|
||||||
|
|
||||||
# 列出生成的文件
|
# 列出生成的文件
|
||||||
@@ -125,4 +125,4 @@ jobs:
|
|||||||
build/checksums.txt
|
build/checksums.txt
|
||||||
draft: false
|
draft: false
|
||||||
prerelease: false
|
prerelease: false
|
||||||
token: ${{ secrets.GITHUB_TOKEN }}
|
token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,4 +1,5 @@
|
|||||||
.idea
|
.idea
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
hubproxy*
|
hubproxy*
|
||||||
|
!hubproxy.service
|
||||||
19
README.md
19
README.md
@@ -37,7 +37,7 @@ docker run -d \
|
|||||||
### 一键脚本安装
|
### 一键脚本安装
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install-service.sh | sudo bash
|
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install.sh | sudo bash
|
||||||
```
|
```
|
||||||
|
|
||||||
也可以直接下载二进制文件执行`./hubproxy`使用,无需配置文件即可启动,内置默认配置,支持所有功能。初始内存占用约18M,二进制文件大小约12M
|
也可以直接下载二进制文件执行`./hubproxy`使用,无需配置文件即可启动,内置默认配置,支持所有功能。初始内存占用约18M,二进制文件大小约12M
|
||||||
@@ -109,12 +109,14 @@ host = "0.0.0.0"
|
|||||||
port = 5000
|
port = 5000
|
||||||
# Github文件大小限制(字节),默认2GB
|
# Github文件大小限制(字节),默认2GB
|
||||||
fileSize = 2147483648
|
fileSize = 2147483648
|
||||||
|
# HTTP/2 多路复用,提升下载速度
|
||||||
|
enableH2C = false
|
||||||
|
|
||||||
[rateLimit]
|
[rateLimit]
|
||||||
# 每个IP每小时允许的请求数(注意Docker镜像会有多个层,会消耗多个次数)
|
# 每个IP每周期允许的请求数(注意Docker镜像会有多个层,会消耗多个次数)
|
||||||
requestLimit = 500
|
requestLimit = 500
|
||||||
# 限流周期(小时)
|
# 限流周期(小时)
|
||||||
periodHours = 1.0
|
periodHours = 3.0
|
||||||
|
|
||||||
[security]
|
[security]
|
||||||
# IP白名单,支持单个IP或IP段
|
# IP白名单,支持单个IP或IP段
|
||||||
@@ -132,7 +134,7 @@ blackList = [
|
|||||||
"192.168.100.0/24"
|
"192.168.100.0/24"
|
||||||
]
|
]
|
||||||
|
|
||||||
[proxy]
|
[access]
|
||||||
# 代理服务白名单(支持GitHub仓库和Docker镜像,支持通配符)
|
# 代理服务白名单(支持GitHub仓库和Docker镜像,支持通配符)
|
||||||
# 只允许访问白名单中的仓库/镜像,为空时不限制
|
# 只允许访问白名单中的仓库/镜像,为空时不限制
|
||||||
whiteList = []
|
whiteList = []
|
||||||
@@ -148,12 +150,6 @@ blackList = [
|
|||||||
# 代理配置,支持有用户名/密码认证和无认证模式
|
# 代理配置,支持有用户名/密码认证和无认证模式
|
||||||
# 无认证: socks5://127.0.0.1:1080
|
# 无认证: socks5://127.0.0.1:1080
|
||||||
# 有认证: socks5://username:password@127.0.0.1:1080
|
# 有认证: socks5://username:password@127.0.0.1:1080
|
||||||
# HTTP 代理示例
|
|
||||||
# http://username:password@127.0.0.1:7890
|
|
||||||
# SOCKS5 代理示例
|
|
||||||
# socks5://username:password@127.0.0.1:1080
|
|
||||||
# SOCKS5H 代理示例
|
|
||||||
# socks5h://username:password@127.0.0.1:1080
|
|
||||||
# 留空不使用代理
|
# 留空不使用代理
|
||||||
proxy = ""
|
proxy = ""
|
||||||
|
|
||||||
@@ -246,4 +242,5 @@ example.com {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
[](https://www.star-history.com/#sky22333/hubproxy&Date)
|
## Star 趋势
|
||||||
|
[](https://starchart.cc/sky22333/hubproxy)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
services:
|
services:
|
||||||
hubproxy:
|
hubproxy:
|
||||||
build: .
|
build: .
|
||||||
restart: always
|
restart: always
|
||||||
ports:
|
ports:
|
||||||
- '5000:5000'
|
- '5000:5000'
|
||||||
volumes:
|
volumes:
|
||||||
- ./src/config.toml:/root/config.toml
|
- ./src/config.toml:/root/config.toml
|
||||||
@@ -4,12 +4,14 @@ host = "0.0.0.0"
|
|||||||
port = 5000
|
port = 5000
|
||||||
# Github文件大小限制(字节),默认2GB
|
# Github文件大小限制(字节),默认2GB
|
||||||
fileSize = 2147483648
|
fileSize = 2147483648
|
||||||
|
# HTTP/2 多路复用
|
||||||
|
enableH2C = false
|
||||||
|
|
||||||
[rateLimit]
|
[rateLimit]
|
||||||
# 每个IP每小时允许的请求数(注意Docker镜像会有多个层,会消耗多个次数)
|
# 每个IP每周期允许的请求数
|
||||||
requestLimit = 500
|
requestLimit = 500
|
||||||
# 限流周期(小时)
|
# 限流周期(小时)
|
||||||
periodHours = 1.0
|
periodHours = 3.0
|
||||||
|
|
||||||
[security]
|
[security]
|
||||||
# IP白名单,支持单个IP或IP段
|
# IP白名单,支持单个IP或IP段
|
||||||
@@ -43,12 +45,6 @@ blackList = [
|
|||||||
# 代理配置,支持有用户名/密码认证和无认证模式
|
# 代理配置,支持有用户名/密码认证和无认证模式
|
||||||
# 无认证: socks5://127.0.0.1:1080
|
# 无认证: socks5://127.0.0.1:1080
|
||||||
# 有认证: socks5://username:password@127.0.0.1:1080
|
# 有认证: socks5://username:password@127.0.0.1:1080
|
||||||
# HTTP 代理示例
|
|
||||||
# http://username:password@127.0.0.1:7890
|
|
||||||
# SOCKS5 代理示例
|
|
||||||
# socks5://username:password@127.0.0.1:1080
|
|
||||||
# SOCKS5H 代理示例
|
|
||||||
# socks5h://username:password@127.0.0.1:1080
|
|
||||||
# 留空不使用代理
|
# 留空不使用代理
|
||||||
proxy = ""
|
proxy = ""
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -13,45 +13,46 @@ import (
|
|||||||
|
|
||||||
// RegistryMapping Registry映射配置
|
// RegistryMapping Registry映射配置
|
||||||
type RegistryMapping struct {
|
type RegistryMapping struct {
|
||||||
Upstream string `toml:"upstream"` // 上游Registry地址
|
Upstream string `toml:"upstream"`
|
||||||
AuthHost string `toml:"authHost"` // 认证服务器地址
|
AuthHost string `toml:"authHost"`
|
||||||
AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic
|
AuthType string `toml:"authType"`
|
||||||
Enabled bool `toml:"enabled"` // 是否启用
|
Enabled bool `toml:"enabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppConfig 应用配置结构体
|
// AppConfig 应用配置结构体
|
||||||
type AppConfig struct {
|
type AppConfig struct {
|
||||||
Server struct {
|
Server struct {
|
||||||
Host string `toml:"host"` // 监听地址
|
Host string `toml:"host"`
|
||||||
Port int `toml:"port"` // 监听端口
|
Port int `toml:"port"`
|
||||||
FileSize int64 `toml:"fileSize"` // 文件大小限制(字节)
|
FileSize int64 `toml:"fileSize"`
|
||||||
|
EnableH2C bool `toml:"enableH2C"`
|
||||||
} `toml:"server"`
|
} `toml:"server"`
|
||||||
|
|
||||||
RateLimit struct {
|
RateLimit struct {
|
||||||
RequestLimit int `toml:"requestLimit"` // 每小时请求限制
|
RequestLimit int `toml:"requestLimit"`
|
||||||
PeriodHours float64 `toml:"periodHours"` // 限制周期(小时)
|
PeriodHours float64 `toml:"periodHours"`
|
||||||
} `toml:"rateLimit"`
|
} `toml:"rateLimit"`
|
||||||
|
|
||||||
Security struct {
|
Security struct {
|
||||||
WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表
|
WhiteList []string `toml:"whiteList"`
|
||||||
BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表
|
BlackList []string `toml:"blackList"`
|
||||||
} `toml:"security"`
|
} `toml:"security"`
|
||||||
|
|
||||||
Access struct {
|
Access struct {
|
||||||
WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别)
|
WhiteList []string `toml:"whiteList"`
|
||||||
BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别)
|
BlackList []string `toml:"blackList"`
|
||||||
Proxy string `toml:"proxy"` // 代理地址: 支持 http/https/socks5/socks5h
|
Proxy string `toml:"proxy"`
|
||||||
} `toml:"access"`
|
} `toml:"access"`
|
||||||
|
|
||||||
Download struct {
|
Download struct {
|
||||||
MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制
|
MaxImages int `toml:"maxImages"`
|
||||||
} `toml:"download"`
|
} `toml:"download"`
|
||||||
|
|
||||||
Registries map[string]RegistryMapping `toml:"registries"`
|
Registries map[string]RegistryMapping `toml:"registries"`
|
||||||
|
|
||||||
TokenCache struct {
|
TokenCache struct {
|
||||||
Enabled bool `toml:"enabled"` // 是否启用token缓存
|
Enabled bool `toml:"enabled"`
|
||||||
DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间
|
DefaultTTL string `toml:"defaultTTL"`
|
||||||
} `toml:"tokenCache"`
|
} `toml:"tokenCache"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -65,24 +66,25 @@ var (
|
|||||||
configCacheMutex sync.RWMutex
|
configCacheMutex sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
// todo:Refactoring is needed
|
|
||||||
// DefaultConfig 返回默认配置
|
// DefaultConfig 返回默认配置
|
||||||
func DefaultConfig() *AppConfig {
|
func DefaultConfig() *AppConfig {
|
||||||
return &AppConfig{
|
return &AppConfig{
|
||||||
Server: struct {
|
Server: struct {
|
||||||
Host string `toml:"host"`
|
Host string `toml:"host"`
|
||||||
Port int `toml:"port"`
|
Port int `toml:"port"`
|
||||||
FileSize int64 `toml:"fileSize"`
|
FileSize int64 `toml:"fileSize"`
|
||||||
|
EnableH2C bool `toml:"enableH2C"`
|
||||||
}{
|
}{
|
||||||
Host: "0.0.0.0",
|
Host: "0.0.0.0",
|
||||||
Port: 5000,
|
Port: 5000,
|
||||||
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
|
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
|
||||||
|
EnableH2C: false, // 默认关闭H2C
|
||||||
},
|
},
|
||||||
RateLimit: struct {
|
RateLimit: struct {
|
||||||
RequestLimit int `toml:"requestLimit"`
|
RequestLimit int `toml:"requestLimit"`
|
||||||
PeriodHours float64 `toml:"periodHours"`
|
PeriodHours float64 `toml:"periodHours"`
|
||||||
}{
|
}{
|
||||||
RequestLimit: 20,
|
RequestLimit: 200,
|
||||||
PeriodHours: 1.0,
|
PeriodHours: 1.0,
|
||||||
},
|
},
|
||||||
Security: struct {
|
Security: struct {
|
||||||
@@ -99,12 +101,12 @@ func DefaultConfig() *AppConfig {
|
|||||||
}{
|
}{
|
||||||
WhiteList: []string{},
|
WhiteList: []string{},
|
||||||
BlackList: []string{},
|
BlackList: []string{},
|
||||||
Proxy: "", // 默认不使用代理
|
Proxy: "",
|
||||||
},
|
},
|
||||||
Download: struct {
|
Download: struct {
|
||||||
MaxImages int `toml:"maxImages"`
|
MaxImages int `toml:"maxImages"`
|
||||||
}{
|
}{
|
||||||
MaxImages: 10, // 默认值:最多同时下载10个镜像
|
MaxImages: 10,
|
||||||
},
|
},
|
||||||
Registries: map[string]RegistryMapping{
|
Registries: map[string]RegistryMapping{
|
||||||
"ghcr.io": {
|
"ghcr.io": {
|
||||||
@@ -136,7 +138,7 @@ func DefaultConfig() *AppConfig {
|
|||||||
Enabled bool `toml:"enabled"`
|
Enabled bool `toml:"enabled"`
|
||||||
DefaultTTL string `toml:"defaultTTL"`
|
DefaultTTL string `toml:"defaultTTL"`
|
||||||
}{
|
}{
|
||||||
Enabled: true, // docker认证的匿名Token缓存配置,用于提升性能
|
Enabled: true,
|
||||||
DefaultTTL: "20m",
|
DefaultTTL: "20m",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -152,11 +154,9 @@ func GetConfig() *AppConfig {
|
|||||||
}
|
}
|
||||||
configCacheMutex.RUnlock()
|
configCacheMutex.RUnlock()
|
||||||
|
|
||||||
// 缓存过期,重新生成配置
|
|
||||||
configCacheMutex.Lock()
|
configCacheMutex.Lock()
|
||||||
defer configCacheMutex.Unlock()
|
defer configCacheMutex.Unlock()
|
||||||
|
|
||||||
// 双重检查,防止重复生成
|
|
||||||
if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL {
|
if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL {
|
||||||
return cachedConfig
|
return cachedConfig
|
||||||
}
|
}
|
||||||
@@ -170,7 +170,6 @@ func GetConfig() *AppConfig {
|
|||||||
return defaultCfg
|
return defaultCfg
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成新的配置深拷贝
|
|
||||||
configCopy := *appConfig
|
configCopy := *appConfig
|
||||||
configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...)
|
configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...)
|
||||||
configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...)
|
configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...)
|
||||||
@@ -197,10 +196,8 @@ func setConfig(cfg *AppConfig) {
|
|||||||
|
|
||||||
// LoadConfig 加载配置文件
|
// LoadConfig 加载配置文件
|
||||||
func LoadConfig() error {
|
func LoadConfig() error {
|
||||||
// 首先使用默认配置
|
|
||||||
cfg := DefaultConfig()
|
cfg := DefaultConfig()
|
||||||
|
|
||||||
// 尝试加载TOML配置文件
|
|
||||||
if data, err := os.ReadFile("config.toml"); err == nil {
|
if data, err := os.ReadFile("config.toml"); err == nil {
|
||||||
if err := toml.Unmarshal(data, cfg); err != nil {
|
if err := toml.Unmarshal(data, cfg); err != nil {
|
||||||
return fmt.Errorf("解析配置文件失败: %v", err)
|
return fmt.Errorf("解析配置文件失败: %v", err)
|
||||||
@@ -209,10 +206,7 @@ func LoadConfig() error {
|
|||||||
fmt.Println("未找到config.toml,使用默认配置")
|
fmt.Println("未找到config.toml,使用默认配置")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从环境变量覆盖配置
|
|
||||||
overrideFromEnv(cfg)
|
overrideFromEnv(cfg)
|
||||||
|
|
||||||
// 设置配置
|
|
||||||
setConfig(cfg)
|
setConfig(cfg)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -220,7 +214,6 @@ func LoadConfig() error {
|
|||||||
|
|
||||||
// overrideFromEnv 从环境变量覆盖配置
|
// overrideFromEnv 从环境变量覆盖配置
|
||||||
func overrideFromEnv(cfg *AppConfig) {
|
func overrideFromEnv(cfg *AppConfig) {
|
||||||
// 服务器配置
|
|
||||||
if val := os.Getenv("SERVER_HOST"); val != "" {
|
if val := os.Getenv("SERVER_HOST"); val != "" {
|
||||||
cfg.Server.Host = val
|
cfg.Server.Host = val
|
||||||
}
|
}
|
||||||
@@ -229,13 +222,17 @@ func overrideFromEnv(cfg *AppConfig) {
|
|||||||
cfg.Server.Port = port
|
cfg.Server.Port = port
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if val := os.Getenv("ENABLE_H2C"); val != "" {
|
||||||
|
if enable, err := strconv.ParseBool(val); err == nil {
|
||||||
|
cfg.Server.EnableH2C = enable
|
||||||
|
}
|
||||||
|
}
|
||||||
if val := os.Getenv("MAX_FILE_SIZE"); val != "" {
|
if val := os.Getenv("MAX_FILE_SIZE"); val != "" {
|
||||||
if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 {
|
if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 {
|
||||||
cfg.Server.FileSize = size
|
cfg.Server.FileSize = size
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 限流配置
|
|
||||||
if val := os.Getenv("RATE_LIMIT"); val != "" {
|
if val := os.Getenv("RATE_LIMIT"); val != "" {
|
||||||
if limit, err := strconv.Atoi(val); err == nil && limit > 0 {
|
if limit, err := strconv.Atoi(val); err == nil && limit > 0 {
|
||||||
cfg.RateLimit.RequestLimit = limit
|
cfg.RateLimit.RequestLimit = limit
|
||||||
@@ -247,7 +244,6 @@ func overrideFromEnv(cfg *AppConfig) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IP限制配置
|
|
||||||
if val := os.Getenv("IP_WHITELIST"); val != "" {
|
if val := os.Getenv("IP_WHITELIST"); val != "" {
|
||||||
cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...)
|
cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...)
|
||||||
}
|
}
|
||||||
@@ -255,7 +251,6 @@ func overrideFromEnv(cfg *AppConfig) {
|
|||||||
cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...)
|
cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 下载限制配置
|
|
||||||
if val := os.Getenv("MAX_IMAGES"); val != "" {
|
if val := os.Getenv("MAX_IMAGES"); val != "" {
|
||||||
if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 {
|
if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 {
|
||||||
cfg.Download.MaxImages = maxImages
|
cfg.Download.MaxImages = maxImages
|
||||||
@@ -6,6 +6,7 @@ require (
|
|||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
github.com/google/go-containerregistry v0.20.5
|
github.com/google/go-containerregistry v0.20.5
|
||||||
github.com/pelletier/go-toml/v2 v2.2.3
|
github.com/pelletier/go-toml/v2 v2.2.3
|
||||||
|
golang.org/x/net v0.33.0
|
||||||
golang.org/x/time v0.11.0
|
golang.org/x/time v0.11.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,7 +44,6 @@ require (
|
|||||||
github.com/vbatts/tar-split v0.12.1 // indirect
|
github.com/vbatts/tar-split v0.12.1 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
golang.org/x/crypto v0.32.0 // indirect
|
golang.org/x/crypto v0.32.0 // indirect
|
||||||
golang.org/x/net v0.33.0 // indirect
|
|
||||||
golang.org/x/sync v0.14.0 // indirect
|
golang.org/x/sync v0.14.0 // indirect
|
||||||
golang.org/x/sys v0.33.0 // indirect
|
golang.org/x/sys v0.33.0 // indirect
|
||||||
golang.org/x/text v0.21.0 // indirect
|
golang.org/x/text v0.21.0 // indirect
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"github.com/google/go-containerregistry/pkg/authn"
|
"github.com/google/go-containerregistry/pkg/authn"
|
||||||
"github.com/google/go-containerregistry/pkg/name"
|
"github.com/google/go-containerregistry/pkg/name"
|
||||||
"github.com/google/go-containerregistry/pkg/v1/remote"
|
"github.com/google/go-containerregistry/pkg/v1/remote"
|
||||||
|
"hubproxy/config"
|
||||||
|
"hubproxy/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DockerProxy Docker代理配置
|
// DockerProxy Docker代理配置
|
||||||
@@ -27,12 +29,10 @@ type RegistryDetector struct{}
|
|||||||
|
|
||||||
// detectRegistryDomain 检测Registry域名并返回域名和剩余路径
|
// detectRegistryDomain 检测Registry域名并返回域名和剩余路径
|
||||||
func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
|
func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
|
|
||||||
// 检查路径是否以已知Registry域名开头
|
|
||||||
for domain := range cfg.Registries {
|
for domain := range cfg.Registries {
|
||||||
if strings.HasPrefix(path, domain+"/") {
|
if strings.HasPrefix(path, domain+"/") {
|
||||||
// 找到匹配的域名,返回域名和剩余路径
|
|
||||||
remainingPath := strings.TrimPrefix(path, domain+"/")
|
remainingPath := strings.TrimPrefix(path, domain+"/")
|
||||||
return domain, remainingPath
|
return domain, remainingPath
|
||||||
}
|
}
|
||||||
@@ -43,7 +43,7 @@ func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
|
|||||||
|
|
||||||
// isRegistryEnabled 检查Registry是否启用
|
// isRegistryEnabled 检查Registry是否启用
|
||||||
func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
|
func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
if mapping, exists := cfg.Registries[domain]; exists {
|
if mapping, exists := cfg.Registries[domain]; exists {
|
||||||
return mapping.Enabled
|
return mapping.Enabled
|
||||||
}
|
}
|
||||||
@@ -51,28 +51,26 @@ func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getRegistryMapping 获取Registry映射配置
|
// getRegistryMapping 获取Registry映射配置
|
||||||
func (rd *RegistryDetector) getRegistryMapping(domain string) (RegistryMapping, bool) {
|
func (rd *RegistryDetector) getRegistryMapping(domain string) (config.RegistryMapping, bool) {
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
mapping, exists := cfg.Registries[domain]
|
mapping, exists := cfg.Registries[domain]
|
||||||
return mapping, exists && mapping.Enabled
|
return mapping, exists && mapping.Enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
var registryDetector = &RegistryDetector{}
|
var registryDetector = &RegistryDetector{}
|
||||||
|
|
||||||
// 初始化Docker代理
|
// InitDockerProxy 初始化Docker代理
|
||||||
func initDockerProxy() {
|
func InitDockerProxy() {
|
||||||
// 创建目标registry
|
|
||||||
registry, err := name.NewRegistry("registry-1.docker.io")
|
registry, err := name.NewRegistry("registry-1.docker.io")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("创建Docker registry失败: %v\n", err)
|
fmt.Printf("创建Docker registry失败: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 配置代理选项
|
|
||||||
options := []remote.Option{
|
options := []remote.Option{
|
||||||
remote.WithAuth(authn.Anonymous),
|
remote.WithAuth(authn.Anonymous),
|
||||||
remote.WithUserAgent("hubproxy/go-containerregistry"),
|
remote.WithUserAgent("hubproxy/go-containerregistry"),
|
||||||
remote.WithTransport(GetGlobalHTTPClient().Transport),
|
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
|
||||||
}
|
}
|
||||||
|
|
||||||
dockerProxy = &DockerProxy{
|
dockerProxy = &DockerProxy{
|
||||||
@@ -85,13 +83,11 @@ func initDockerProxy() {
|
|||||||
func ProxyDockerRegistryGin(c *gin.Context) {
|
func ProxyDockerRegistryGin(c *gin.Context) {
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
// 处理 /v2/ API版本检查
|
|
||||||
if path == "/v2/" {
|
if path == "/v2/" {
|
||||||
c.JSON(http.StatusOK, gin.H{})
|
c.JSON(http.StatusOK, gin.H{})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理不同的API端点
|
|
||||||
if strings.HasPrefix(path, "/v2/") {
|
if strings.HasPrefix(path, "/v2/") {
|
||||||
handleRegistryRequest(c, path)
|
handleRegistryRequest(c, path)
|
||||||
} else {
|
} else {
|
||||||
@@ -101,16 +97,13 @@ func ProxyDockerRegistryGin(c *gin.Context) {
|
|||||||
|
|
||||||
// handleRegistryRequest 处理Registry请求
|
// handleRegistryRequest 处理Registry请求
|
||||||
func handleRegistryRequest(c *gin.Context, path string) {
|
func handleRegistryRequest(c *gin.Context, path string) {
|
||||||
// 移除 /v2/ 前缀
|
|
||||||
pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
|
pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
|
||||||
|
|
||||||
if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" {
|
if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" {
|
||||||
if registryDetector.isRegistryEnabled(registryDomain) {
|
if registryDetector.isRegistryEnabled(registryDomain) {
|
||||||
// 设置目标Registry信息到Context
|
|
||||||
c.Set("target_registry_domain", registryDomain)
|
c.Set("target_registry_domain", registryDomain)
|
||||||
c.Set("target_path", remainingPath)
|
c.Set("target_path", remainingPath)
|
||||||
|
|
||||||
// 处理多Registry请求
|
|
||||||
handleMultiRegistryRequest(c, registryDomain, remainingPath)
|
handleMultiRegistryRequest(c, registryDomain, remainingPath)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -122,19 +115,16 @@ func handleRegistryRequest(c *gin.Context, path string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 自动处理官方镜像的library命名空间
|
|
||||||
if !strings.Contains(imageName, "/") {
|
if !strings.Contains(imageName, "/") {
|
||||||
imageName = "library/" + imageName
|
imageName = "library/" + imageName
|
||||||
}
|
}
|
||||||
|
|
||||||
// Docker镜像访问控制检查
|
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(imageName); !allowed {
|
||||||
if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed {
|
|
||||||
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
|
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
|
||||||
c.String(http.StatusForbidden, "镜像访问被限制")
|
c.String(http.StatusForbidden, "镜像访问被限制")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建完整的镜像引用
|
|
||||||
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
|
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
|
||||||
|
|
||||||
switch apiType {
|
switch apiType {
|
||||||
@@ -151,7 +141,6 @@ func handleRegistryRequest(c *gin.Context, path string) {
|
|||||||
|
|
||||||
// parseRegistryPath 解析Registry路径
|
// parseRegistryPath 解析Registry路径
|
||||||
func parseRegistryPath(path string) (imageName, apiType, reference string) {
|
func parseRegistryPath(path string) (imageName, apiType, reference string) {
|
||||||
// 查找API端点关键字
|
|
||||||
if idx := strings.Index(path, "/manifests/"); idx != -1 {
|
if idx := strings.Index(path, "/manifests/"); idx != -1 {
|
||||||
imageName = path[:idx]
|
imageName = path[:idx]
|
||||||
apiType = "manifests"
|
apiType = "manifests"
|
||||||
@@ -178,13 +167,11 @@ func parseRegistryPath(path string) (imageName, apiType, reference string) {
|
|||||||
|
|
||||||
// handleManifestRequest 处理manifest请求
|
// handleManifestRequest 处理manifest请求
|
||||||
func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
||||||
// Manifest缓存逻辑(仅对GET请求缓存)
|
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
|
||||||
if isCacheEnabled() && c.Request.Method == http.MethodGet {
|
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
|
||||||
cacheKey := buildManifestCacheKey(imageRef, reference)
|
|
||||||
|
|
||||||
// 优先从缓存获取
|
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
|
||||||
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
|
utils.WriteCachedResponse(c, cachedItem)
|
||||||
writeCachedResponse(c, cachedItem)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,12 +179,9 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
|||||||
var ref name.Reference
|
var ref name.Reference
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// 判断reference是digest还是tag
|
|
||||||
if strings.HasPrefix(reference, "sha256:") {
|
if strings.HasPrefix(reference, "sha256:") {
|
||||||
// 是digest
|
|
||||||
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
|
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
|
||||||
} else {
|
} else {
|
||||||
// 是tag
|
|
||||||
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
|
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -207,9 +191,7 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据请求方法选择操作
|
|
||||||
if c.Request.Method == http.MethodHead {
|
if c.Request.Method == http.MethodHead {
|
||||||
// HEAD请求,使用remote.Head
|
|
||||||
desc, err := remote.Head(ref, dockerProxy.options...)
|
desc, err := remote.Head(ref, dockerProxy.options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("HEAD请求失败: %v\n", err)
|
fmt.Printf("HEAD请求失败: %v\n", err)
|
||||||
@@ -217,13 +199,11 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置响应头
|
|
||||||
c.Header("Content-Type", string(desc.MediaType))
|
c.Header("Content-Type", string(desc.MediaType))
|
||||||
c.Header("Docker-Content-Digest", desc.Digest.String())
|
c.Header("Docker-Content-Digest", desc.Digest.String())
|
||||||
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
|
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
} else {
|
} else {
|
||||||
// GET请求,使用remote.Get
|
|
||||||
desc, err := remote.Get(ref, dockerProxy.options...)
|
desc, err := remote.Get(ref, dockerProxy.options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("GET请求失败: %v\n", err)
|
fmt.Printf("GET请求失败: %v\n", err)
|
||||||
@@ -231,33 +211,28 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置响应头
|
|
||||||
headers := map[string]string{
|
headers := map[string]string{
|
||||||
"Docker-Content-Digest": desc.Digest.String(),
|
"Docker-Content-Digest": desc.Digest.String(),
|
||||||
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
|
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 缓存响应
|
if utils.IsCacheEnabled() {
|
||||||
if isCacheEnabled() {
|
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
|
||||||
cacheKey := buildManifestCacheKey(imageRef, reference)
|
ttl := utils.GetManifestTTL(reference)
|
||||||
ttl := getManifestTTL(reference)
|
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
|
||||||
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置响应头
|
|
||||||
c.Header("Content-Type", string(desc.MediaType))
|
c.Header("Content-Type", string(desc.MediaType))
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回manifest内容
|
|
||||||
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
|
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleBlobRequest 处理blob请求
|
// handleBlobRequest 处理blob请求
|
||||||
func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
||||||
// 构建digest引用
|
|
||||||
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
|
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("解析digest引用失败: %v\n", err)
|
fmt.Printf("解析digest引用失败: %v\n", err)
|
||||||
@@ -265,7 +240,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用remote.Layer获取layer
|
|
||||||
layer, err := remote.Layer(digestRef, dockerProxy.options...)
|
layer, err := remote.Layer(digestRef, dockerProxy.options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("获取layer失败: %v\n", err)
|
fmt.Printf("获取layer失败: %v\n", err)
|
||||||
@@ -273,7 +247,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取layer信息
|
|
||||||
size, err := layer.Size()
|
size, err := layer.Size()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("获取layer大小失败: %v\n", err)
|
fmt.Printf("获取layer大小失败: %v\n", err)
|
||||||
@@ -281,7 +254,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取layer内容
|
|
||||||
reader, err := layer.Compressed()
|
reader, err := layer.Compressed()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("获取layer内容失败: %v\n", err)
|
fmt.Printf("获取layer内容失败: %v\n", err)
|
||||||
@@ -290,19 +262,16 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
|
|||||||
}
|
}
|
||||||
defer reader.Close()
|
defer reader.Close()
|
||||||
|
|
||||||
// 设置响应头
|
|
||||||
c.Header("Content-Type", "application/octet-stream")
|
c.Header("Content-Type", "application/octet-stream")
|
||||||
c.Header("Content-Length", fmt.Sprintf("%d", size))
|
c.Header("Content-Length", fmt.Sprintf("%d", size))
|
||||||
c.Header("Docker-Content-Digest", digest)
|
c.Header("Docker-Content-Digest", digest)
|
||||||
|
|
||||||
// 流式传输blob内容
|
|
||||||
c.Status(http.StatusOK)
|
c.Status(http.StatusOK)
|
||||||
io.Copy(c.Writer, reader)
|
io.Copy(c.Writer, reader)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleTagsRequest 处理tags列表请求
|
// handleTagsRequest 处理tags列表请求
|
||||||
func handleTagsRequest(c *gin.Context, imageRef string) {
|
func handleTagsRequest(c *gin.Context, imageRef string) {
|
||||||
// 解析repository
|
|
||||||
repo, err := name.NewRepository(imageRef)
|
repo, err := name.NewRepository(imageRef)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("解析repository失败: %v\n", err)
|
fmt.Printf("解析repository失败: %v\n", err)
|
||||||
@@ -310,7 +279,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用remote.List获取tags
|
|
||||||
tags, err := remote.List(repo, dockerProxy.options...)
|
tags, err := remote.List(repo, dockerProxy.options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("获取tags失败: %v\n", err)
|
fmt.Printf("获取tags失败: %v\n", err)
|
||||||
@@ -318,7 +286,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建响应
|
|
||||||
response := map[string]interface{}{
|
response := map[string]interface{}{
|
||||||
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
|
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
|
||||||
"tags": tags,
|
"tags": tags,
|
||||||
@@ -327,10 +294,9 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
|
|||||||
c.JSON(http.StatusOK, response)
|
c.JSON(http.StatusOK, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProxyDockerAuthGin Docker认证代理(带缓存优化)
|
// ProxyDockerAuthGin Docker认证代理
|
||||||
func ProxyDockerAuthGin(c *gin.Context) {
|
func ProxyDockerAuthGin(c *gin.Context) {
|
||||||
// 检查是否启用token缓存
|
if utils.IsTokenCacheEnabled() {
|
||||||
if isTokenCacheEnabled() {
|
|
||||||
proxyDockerAuthWithCache(c)
|
proxyDockerAuthWithCache(c)
|
||||||
} else {
|
} else {
|
||||||
proxyDockerAuthOriginal(c)
|
proxyDockerAuthOriginal(c)
|
||||||
@@ -339,32 +305,26 @@ func ProxyDockerAuthGin(c *gin.Context) {
|
|||||||
|
|
||||||
// proxyDockerAuthWithCache 带缓存的认证代理
|
// proxyDockerAuthWithCache 带缓存的认证代理
|
||||||
func proxyDockerAuthWithCache(c *gin.Context) {
|
func proxyDockerAuthWithCache(c *gin.Context) {
|
||||||
// 1. 构建缓存key(基于完整的查询参数)
|
cacheKey := utils.BuildTokenCacheKey(c.Request.URL.RawQuery)
|
||||||
cacheKey := buildTokenCacheKey(c.Request.URL.RawQuery)
|
|
||||||
|
|
||||||
// 2. 尝试从缓存获取token
|
if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" {
|
||||||
if cachedToken := globalCache.GetToken(cacheKey); cachedToken != "" {
|
utils.WriteTokenResponse(c, cachedToken)
|
||||||
writeTokenResponse(c, cachedToken)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 缓存未命中,创建响应记录器
|
|
||||||
recorder := &ResponseRecorder{
|
recorder := &ResponseRecorder{
|
||||||
ResponseWriter: c.Writer,
|
ResponseWriter: c.Writer,
|
||||||
statusCode: 200,
|
statusCode: 200,
|
||||||
}
|
}
|
||||||
c.Writer = recorder
|
c.Writer = recorder
|
||||||
|
|
||||||
// 4. 调用原有认证逻辑
|
|
||||||
proxyDockerAuthOriginal(c)
|
proxyDockerAuthOriginal(c)
|
||||||
|
|
||||||
// 5. 如果认证成功,缓存响应
|
|
||||||
if recorder.statusCode == 200 && len(recorder.body) > 0 {
|
if recorder.statusCode == 200 && len(recorder.body) > 0 {
|
||||||
ttl := extractTTLFromResponse(recorder.body)
|
ttl := utils.ExtractTTLFromResponse(recorder.body)
|
||||||
globalCache.SetToken(cacheKey, string(recorder.body), ttl)
|
utils.GlobalCache.SetToken(cacheKey, string(recorder.body), ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 6. 写入实际响应
|
|
||||||
c.Writer = recorder.ResponseWriter
|
c.Writer = recorder.ResponseWriter
|
||||||
c.Data(recorder.statusCode, "application/json", recorder.body)
|
c.Data(recorder.statusCode, "application/json", recorder.body)
|
||||||
}
|
}
|
||||||
@@ -389,14 +349,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
|
|||||||
var authURL string
|
var authURL string
|
||||||
if targetDomain, exists := c.Get("target_registry_domain"); exists {
|
if targetDomain, exists := c.Get("target_registry_domain"); exists {
|
||||||
if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found {
|
if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found {
|
||||||
// 使用Registry特定的认证服务器
|
|
||||||
authURL = "https://" + mapping.AuthHost + c.Request.URL.Path
|
authURL = "https://" + mapping.AuthHost + c.Request.URL.Path
|
||||||
} else {
|
} else {
|
||||||
// fallback到默认Docker认证
|
|
||||||
authURL = "https://auth.docker.io" + c.Request.URL.Path
|
authURL = "https://auth.docker.io" + c.Request.URL.Path
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 构建默认Docker认证URL
|
|
||||||
authURL = "https://auth.docker.io" + c.Request.URL.Path
|
authURL = "https://auth.docker.io" + c.Request.URL.Path
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -404,13 +361,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
|
|||||||
authURL += "?" + c.Request.URL.RawQuery
|
authURL += "?" + c.Request.URL.RawQuery
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建HTTP客户端,复用全局传输配置(包含代理设置)
|
|
||||||
client := &http.Client{
|
client := &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
Transport: GetGlobalHTTPClient().Transport,
|
Transport: utils.GetGlobalHTTPClient().Transport,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建请求
|
|
||||||
req, err := http.NewRequestWithContext(
|
req, err := http.NewRequestWithContext(
|
||||||
context.Background(),
|
context.Background(),
|
||||||
c.Request.Method,
|
c.Request.Method,
|
||||||
@@ -422,14 +377,12 @@ func proxyDockerAuthOriginal(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制请求头
|
|
||||||
for key, values := range c.Request.Header {
|
for key, values := range c.Request.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
req.Header.Add(key, value)
|
req.Header.Add(key, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 执行请求
|
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.String(http.StatusBadGateway, "Auth request failed")
|
c.String(http.StatusBadGateway, "Auth request failed")
|
||||||
@@ -437,37 +390,30 @@ func proxyDockerAuthOriginal(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
// 获取当前代理的Host地址
|
|
||||||
proxyHost := c.Request.Host
|
proxyHost := c.Request.Host
|
||||||
if proxyHost == "" {
|
if proxyHost == "" {
|
||||||
// 使用配置中的服务器地址和端口
|
cfg := config.GetConfig()
|
||||||
cfg := GetConfig()
|
|
||||||
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
|
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||||
if cfg.Server.Host == "0.0.0.0" {
|
if cfg.Server.Host == "0.0.0.0" {
|
||||||
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
|
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 复制响应头并重写认证URL
|
|
||||||
for key, values := range resp.Header {
|
for key, values := range resp.Header {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
// 重写WWW-Authenticate头中的realm URL
|
|
||||||
if key == "Www-Authenticate" {
|
if key == "Www-Authenticate" {
|
||||||
// 支持多Registry的URL重写
|
|
||||||
value = rewriteAuthHeader(value, proxyHost)
|
value = rewriteAuthHeader(value, proxyHost)
|
||||||
}
|
}
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回响应
|
|
||||||
c.Status(resp.StatusCode)
|
c.Status(resp.StatusCode)
|
||||||
io.Copy(c.Writer, resp.Body)
|
io.Copy(c.Writer, resp.Body)
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewriteAuthHeader 重写认证头
|
// rewriteAuthHeader 重写认证头
|
||||||
func rewriteAuthHeader(authHeader, proxyHost string) string {
|
func rewriteAuthHeader(authHeader, proxyHost string) string {
|
||||||
// 重写各种Registry的认证URL
|
|
||||||
authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost)
|
authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost)
|
||||||
authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost)
|
authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost)
|
||||||
authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost)
|
authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost)
|
||||||
@@ -478,32 +424,27 @@ func rewriteAuthHeader(authHeader, proxyHost string) string {
|
|||||||
|
|
||||||
// handleMultiRegistryRequest 处理多Registry请求
|
// handleMultiRegistryRequest 处理多Registry请求
|
||||||
func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) {
|
func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) {
|
||||||
// 获取Registry映射配置
|
|
||||||
mapping, exists := registryDetector.getRegistryMapping(registryDomain)
|
mapping, exists := registryDetector.getRegistryMapping(registryDomain)
|
||||||
if !exists {
|
if !exists {
|
||||||
c.String(http.StatusBadRequest, "Registry not configured")
|
c.String(http.StatusBadRequest, "Registry not configured")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析剩余路径
|
|
||||||
imageName, apiType, reference := parseRegistryPath(remainingPath)
|
imageName, apiType, reference := parseRegistryPath(remainingPath)
|
||||||
if imageName == "" || apiType == "" {
|
if imageName == "" || apiType == "" {
|
||||||
c.String(http.StatusBadRequest, "Invalid path format")
|
c.String(http.StatusBadRequest, "Invalid path format")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 访问控制检查(使用完整的镜像路径)
|
|
||||||
fullImageName := registryDomain + "/" + imageName
|
fullImageName := registryDomain + "/" + imageName
|
||||||
if allowed, reason := GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
|
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
|
||||||
fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason)
|
fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason)
|
||||||
c.String(http.StatusForbidden, "镜像访问被限制")
|
c.String(http.StatusForbidden, "镜像访问被限制")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建上游Registry引用
|
|
||||||
upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName)
|
upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName)
|
||||||
|
|
||||||
// 根据API类型处理请求
|
|
||||||
switch apiType {
|
switch apiType {
|
||||||
case "manifests":
|
case "manifests":
|
||||||
handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping)
|
handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping)
|
||||||
@@ -517,14 +458,12 @@ func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath st
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleUpstreamManifestRequest 处理上游Registry的manifest请求
|
// handleUpstreamManifestRequest 处理上游Registry的manifest请求
|
||||||
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping RegistryMapping) {
|
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping config.RegistryMapping) {
|
||||||
// Manifest缓存逻辑(仅对GET请求缓存)
|
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
|
||||||
if isCacheEnabled() && c.Request.Method == http.MethodGet {
|
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
|
||||||
cacheKey := buildManifestCacheKey(imageRef, reference)
|
|
||||||
|
|
||||||
// 优先从缓存获取
|
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
|
||||||
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
|
utils.WriteCachedResponse(c, cachedItem)
|
||||||
writeCachedResponse(c, cachedItem)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -532,7 +471,6 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
|
|||||||
var ref name.Reference
|
var ref name.Reference
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// 判断reference是digest还是tag
|
|
||||||
if strings.HasPrefix(reference, "sha256:") {
|
if strings.HasPrefix(reference, "sha256:") {
|
||||||
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
|
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
|
||||||
} else {
|
} else {
|
||||||
@@ -545,10 +483,8 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 创建针对上游Registry的选项
|
|
||||||
options := createUpstreamOptions(mapping)
|
options := createUpstreamOptions(mapping)
|
||||||
|
|
||||||
// 根据请求方法选择操作
|
|
||||||
if c.Request.Method == http.MethodHead {
|
if c.Request.Method == http.MethodHead {
|
||||||
desc, err := remote.Head(ref, options...)
|
desc, err := remote.Head(ref, options...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -569,20 +505,17 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置响应头
|
|
||||||
headers := map[string]string{
|
headers := map[string]string{
|
||||||
"Docker-Content-Digest": desc.Digest.String(),
|
"Docker-Content-Digest": desc.Digest.String(),
|
||||||
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
|
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 缓存响应
|
if utils.IsCacheEnabled() {
|
||||||
if isCacheEnabled() {
|
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
|
||||||
cacheKey := buildManifestCacheKey(imageRef, reference)
|
ttl := utils.GetManifestTTL(reference)
|
||||||
ttl := getManifestTTL(reference)
|
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
|
||||||
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置响应头
|
|
||||||
c.Header("Content-Type", string(desc.MediaType))
|
c.Header("Content-Type", string(desc.MediaType))
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
@@ -593,7 +526,7 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleUpstreamBlobRequest 处理上游Registry的blob请求
|
// handleUpstreamBlobRequest 处理上游Registry的blob请求
|
||||||
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping RegistryMapping) {
|
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping config.RegistryMapping) {
|
||||||
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
|
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("解析digest引用失败: %v\n", err)
|
fmt.Printf("解析digest引用失败: %v\n", err)
|
||||||
@@ -633,7 +566,7 @@ func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleUpstreamTagsRequest 处理上游Registry的tags请求
|
// handleUpstreamTagsRequest 处理上游Registry的tags请求
|
||||||
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping RegistryMapping) {
|
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping config.RegistryMapping) {
|
||||||
repo, err := name.NewRepository(imageRef)
|
repo, err := name.NewRepository(imageRef)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("解析repository失败: %v\n", err)
|
fmt.Printf("解析repository失败: %v\n", err)
|
||||||
@@ -658,14 +591,13 @@ func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping Registry
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createUpstreamOptions 创建上游Registry选项
|
// createUpstreamOptions 创建上游Registry选项
|
||||||
func createUpstreamOptions(mapping RegistryMapping) []remote.Option {
|
func createUpstreamOptions(mapping config.RegistryMapping) []remote.Option {
|
||||||
options := []remote.Option{
|
options := []remote.Option{
|
||||||
remote.WithAuth(authn.Anonymous),
|
remote.WithAuth(authn.Anonymous),
|
||||||
remote.WithUserAgent("hubproxy/go-containerregistry"),
|
remote.WithUserAgent("hubproxy/go-containerregistry"),
|
||||||
remote.WithTransport(GetGlobalHTTPClient().Transport),
|
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据Registry类型添加特定的认证选项(方便后续扩展)
|
|
||||||
switch mapping.AuthType {
|
switch mapping.AuthType {
|
||||||
case "github":
|
case "github":
|
||||||
case "google":
|
case "google":
|
||||||
219
src/handlers/github.go
Normal file
219
src/handlers/github.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"hubproxy/config"
|
||||||
|
"hubproxy/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// GitHub URL匹配正则表达式
|
||||||
|
githubExps = []*regexp.Regexp{
|
||||||
|
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(.+?)/(.+?)/.+\.[a-zA-Z0-9]+$`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)`),
|
||||||
|
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?`),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// GitHubProxyHandler GitHub代理处理器
|
||||||
|
func GitHubProxyHandler(c *gin.Context) {
|
||||||
|
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
|
||||||
|
|
||||||
|
for strings.HasPrefix(rawPath, "/") {
|
||||||
|
rawPath = strings.TrimPrefix(rawPath, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 自动补全协议头
|
||||||
|
if !strings.HasPrefix(rawPath, "https://") {
|
||||||
|
if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") {
|
||||||
|
rawPath = strings.Replace(rawPath, "http:/", "", 1)
|
||||||
|
rawPath = strings.Replace(rawPath, "https:/", "", 1)
|
||||||
|
} else if strings.HasPrefix(rawPath, "http://") {
|
||||||
|
rawPath = strings.TrimPrefix(rawPath, "http://")
|
||||||
|
}
|
||||||
|
rawPath = "https://" + rawPath
|
||||||
|
}
|
||||||
|
|
||||||
|
matches := CheckGitHubURL(rawPath)
|
||||||
|
if matches != nil {
|
||||||
|
if allowed, reason := utils.GlobalAccessController.CheckGitHubAccess(matches); !allowed {
|
||||||
|
var repoPath string
|
||||||
|
if len(matches) >= 2 {
|
||||||
|
username := matches[0]
|
||||||
|
repoName := strings.TrimSuffix(matches[1], ".git")
|
||||||
|
repoPath = username + "/" + repoName
|
||||||
|
}
|
||||||
|
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
|
||||||
|
c.String(http.StatusForbidden, reason)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.String(http.StatusForbidden, "无效输入")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将blob链接转换为raw链接
|
||||||
|
if githubExps[1].MatchString(rawPath) {
|
||||||
|
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
ProxyGitHubRequest(c, rawPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckGitHubURL 检查URL是否匹配GitHub模式
|
||||||
|
func CheckGitHubURL(u string) []string {
|
||||||
|
for _, exp := range githubExps {
|
||||||
|
if matches := exp.FindStringSubmatch(u); matches != nil {
|
||||||
|
return matches[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProxyGitHubRequest 代理GitHub请求
|
||||||
|
func ProxyGitHubRequest(c *gin.Context, u string) {
|
||||||
|
proxyGitHubWithRedirect(c, u, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyGitHubWithRedirect 带重定向的GitHub代理请求
|
||||||
|
func proxyGitHubWithRedirect(c *gin.Context, u string, redirectCount int) {
|
||||||
|
const maxRedirects = 20
|
||||||
|
if redirectCount > maxRedirects {
|
||||||
|
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复制请求头
|
||||||
|
for key, values := range c.Request.Header {
|
||||||
|
for _, value := range values {
|
||||||
|
req.Header.Add(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
req.Header.Del("Host")
|
||||||
|
|
||||||
|
resp, err := utils.GetGlobalHTTPClient().Do(req)
|
||||||
|
if err != nil {
|
||||||
|
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := resp.Body.Close(); err != nil {
|
||||||
|
fmt.Printf("关闭响应体失败: %v\n", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 如果Github上游404,则返回错误信息
|
||||||
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
c.String(http.StatusForbidden, "无效的GitHub地址")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查文件大小限制
|
||||||
|
cfg := config.GetConfig()
|
||||||
|
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
|
||||||
|
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
|
||||||
|
c.String(http.StatusRequestEntityTooLarge,
|
||||||
|
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理安全相关的头
|
||||||
|
resp.Header.Del("Content-Security-Policy")
|
||||||
|
resp.Header.Del("Referrer-Policy")
|
||||||
|
resp.Header.Del("Strict-Transport-Security")
|
||||||
|
|
||||||
|
// 获取真实域名
|
||||||
|
realHost := c.Request.Header.Get("X-Forwarded-Host")
|
||||||
|
if realHost == "" {
|
||||||
|
realHost = c.Request.Host
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
|
||||||
|
realHost = "https://" + realHost
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理.sh文件的智能处理
|
||||||
|
if strings.HasSuffix(strings.ToLower(u), ".sh") {
|
||||||
|
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
|
||||||
|
|
||||||
|
processedBody, processedSize, err := utils.ProcessSmart(resp.Body, isGzipCompressed, realHost)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
|
||||||
|
processedBody = resp.Body
|
||||||
|
processedSize = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// 智能设置响应头
|
||||||
|
if processedSize > 0 {
|
||||||
|
resp.Header.Del("Content-Length")
|
||||||
|
resp.Header.Del("Content-Encoding")
|
||||||
|
resp.Header.Set("Transfer-Encoding", "chunked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复制其他响应头
|
||||||
|
for key, values := range resp.Header {
|
||||||
|
for _, value := range values {
|
||||||
|
c.Header(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理重定向
|
||||||
|
if location := resp.Header.Get("Location"); location != "" {
|
||||||
|
if CheckGitHubURL(location) != nil {
|
||||||
|
c.Header("Location", "/"+location)
|
||||||
|
} else {
|
||||||
|
proxyGitHubWithRedirect(c, location, redirectCount+1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(resp.StatusCode)
|
||||||
|
|
||||||
|
// 输出处理后的内容
|
||||||
|
if _, err := io.Copy(c.Writer, processedBody); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 复制响应头
|
||||||
|
for key, values := range resp.Header {
|
||||||
|
for _, value := range values {
|
||||||
|
c.Header(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理重定向
|
||||||
|
if location := resp.Header.Get("Location"); location != "" {
|
||||||
|
if CheckGitHubURL(location) != nil {
|
||||||
|
c.Header("Location", "/"+location)
|
||||||
|
} else {
|
||||||
|
proxyGitHubWithRedirect(c, location, redirectCount+1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Status(resp.StatusCode)
|
||||||
|
|
||||||
|
// 直接流式转发
|
||||||
|
io.Copy(c.Writer, resp.Body)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/tar"
|
"archive/tar"
|
||||||
@@ -23,6 +23,8 @@ import (
|
|||||||
"github.com/google/go-containerregistry/pkg/v1/partial"
|
"github.com/google/go-containerregistry/pkg/v1/partial"
|
||||||
"github.com/google/go-containerregistry/pkg/v1/remote"
|
"github.com/google/go-containerregistry/pkg/v1/remote"
|
||||||
"github.com/google/go-containerregistry/pkg/v1/types"
|
"github.com/google/go-containerregistry/pkg/v1/types"
|
||||||
|
"hubproxy/config"
|
||||||
|
"hubproxy/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DebounceEntry 防抖条目
|
// DebounceEntry 防抖条目
|
||||||
@@ -58,17 +60,15 @@ func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool {
|
|||||||
|
|
||||||
if entry, exists := d.entries[key]; exists {
|
if entry, exists := d.entries[key]; exists {
|
||||||
if now.Sub(entry.LastRequest) < d.window {
|
if now.Sub(entry.LastRequest) < d.window {
|
||||||
return false // 在防抖窗口内,拒绝请求
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新或创建条目
|
|
||||||
d.entries[key] = &DebounceEntry{
|
d.entries[key] = &DebounceEntry{
|
||||||
LastRequest: now,
|
LastRequest: now,
|
||||||
UserID: userID,
|
UserID: userID,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清理过期条目(每5分钟清理一次)
|
|
||||||
if time.Since(d.lastCleanup) > 5*time.Minute {
|
if time.Since(d.lastCleanup) > 5*time.Minute {
|
||||||
d.cleanup(now)
|
d.cleanup(now)
|
||||||
d.lastCleanup = now
|
d.lastCleanup = now
|
||||||
@@ -88,50 +88,41 @@ func (d *DownloadDebouncer) cleanup(now time.Time) {
|
|||||||
|
|
||||||
// generateContentFingerprint 生成内容指纹
|
// generateContentFingerprint 生成内容指纹
|
||||||
func generateContentFingerprint(images []string, platform string) string {
|
func generateContentFingerprint(images []string, platform string) string {
|
||||||
// 对镜像列表排序确保顺序无关
|
|
||||||
sortedImages := make([]string, len(images))
|
sortedImages := make([]string, len(images))
|
||||||
copy(sortedImages, images)
|
copy(sortedImages, images)
|
||||||
sort.Strings(sortedImages)
|
sort.Strings(sortedImages)
|
||||||
|
|
||||||
// 组合内容:镜像列表 + 平台信息
|
|
||||||
content := strings.Join(sortedImages, "|") + ":" + platform
|
content := strings.Join(sortedImages, "|") + ":" + platform
|
||||||
|
|
||||||
// 生成MD5哈希
|
|
||||||
hash := md5.Sum([]byte(content))
|
hash := md5.Sum([]byte(content))
|
||||||
return hex.EncodeToString(hash[:])
|
return hex.EncodeToString(hash[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserID 获取用户标识
|
// getUserID 获取用户标识
|
||||||
func getUserID(c *gin.Context) string {
|
func getUserID(c *gin.Context) string {
|
||||||
// 优先使用会话Cookie
|
|
||||||
if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" {
|
if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" {
|
||||||
return "session:" + sessionID
|
return "session:" + sessionID
|
||||||
}
|
}
|
||||||
|
|
||||||
// 备用方案:IP + User-Agent组合
|
|
||||||
ip := c.ClientIP()
|
ip := c.ClientIP()
|
||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
if userAgent == "" {
|
if userAgent == "" {
|
||||||
userAgent = "unknown"
|
userAgent = "unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成简短标识
|
|
||||||
combined := ip + ":" + userAgent
|
combined := ip + ":" + userAgent
|
||||||
hash := md5.Sum([]byte(combined))
|
hash := md5.Sum([]byte(combined))
|
||||||
return "ip:" + hex.EncodeToString(hash[:8]) // 只取前8字节
|
return "ip:" + hex.EncodeToString(hash[:8])
|
||||||
}
|
}
|
||||||
|
|
||||||
// 全局防抖器实例
|
|
||||||
var (
|
var (
|
||||||
singleImageDebouncer *DownloadDebouncer
|
singleImageDebouncer *DownloadDebouncer
|
||||||
batchImageDebouncer *DownloadDebouncer
|
batchImageDebouncer *DownloadDebouncer
|
||||||
)
|
)
|
||||||
|
|
||||||
// initDebouncer 初始化防抖器
|
// InitDebouncer 初始化防抖器
|
||||||
func initDebouncer() {
|
func InitDebouncer() {
|
||||||
// 单个镜像:5秒防抖窗口
|
|
||||||
singleImageDebouncer = NewDownloadDebouncer(5 * time.Second)
|
singleImageDebouncer = NewDownloadDebouncer(5 * time.Second)
|
||||||
// 批量镜像:60秒防抖窗口
|
|
||||||
batchImageDebouncer = NewDownloadDebouncer(60 * time.Second)
|
batchImageDebouncer = NewDownloadDebouncer(60 * time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,15 +138,15 @@ type ImageStreamerConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewImageStreamer 创建镜像下载器
|
// NewImageStreamer 创建镜像下载器
|
||||||
func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
|
func NewImageStreamer(cfg *ImageStreamerConfig) *ImageStreamer {
|
||||||
if config == nil {
|
if cfg == nil {
|
||||||
config = &ImageStreamerConfig{}
|
cfg = &ImageStreamerConfig{}
|
||||||
}
|
}
|
||||||
|
|
||||||
concurrency := config.Concurrency
|
concurrency := cfg.Concurrency
|
||||||
if concurrency <= 0 {
|
if concurrency <= 0 {
|
||||||
cfg := GetConfig()
|
appCfg := config.GetConfig()
|
||||||
concurrency = cfg.Download.MaxImages
|
concurrency = appCfg.Download.MaxImages
|
||||||
if concurrency <= 0 {
|
if concurrency <= 0 {
|
||||||
concurrency = 10
|
concurrency = 10
|
||||||
}
|
}
|
||||||
@@ -163,7 +154,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
|
|||||||
|
|
||||||
remoteOptions := []remote.Option{
|
remoteOptions := []remote.Option{
|
||||||
remote.WithAuth(authn.Anonymous),
|
remote.WithAuth(authn.Anonymous),
|
||||||
remote.WithTransport(GetGlobalHTTPClient().Transport),
|
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
|
||||||
}
|
}
|
||||||
|
|
||||||
return &ImageStreamer{
|
return &ImageStreamer{
|
||||||
@@ -176,7 +167,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
|
|||||||
type StreamOptions struct {
|
type StreamOptions struct {
|
||||||
Platform string
|
Platform string
|
||||||
Compression bool
|
Compression bool
|
||||||
UseCompressedLayers bool // 是否保存原始压缩层,默认开启
|
UseCompressedLayers bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// StreamImageToWriter 流式下载镜像到Writer
|
// StreamImageToWriter 流式下载镜像到Writer
|
||||||
@@ -215,7 +206,6 @@ func (is *ImageStreamer) getImageDescriptor(ref name.Reference, options []remote
|
|||||||
|
|
||||||
// getImageDescriptorWithPlatform 获取指定平台的镜像描述符
|
// getImageDescriptorWithPlatform 获取指定平台的镜像描述符
|
||||||
func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) {
|
func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) {
|
||||||
// 直接从网络获取完整的descriptor,确保对象完整性
|
|
||||||
return remote.Get(ref, options...)
|
return remote.Get(ref, options...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,7 +333,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
|||||||
var layerSize int64
|
var layerSize int64
|
||||||
var layerReader io.ReadCloser
|
var layerReader io.ReadCloser
|
||||||
|
|
||||||
// 根据配置选择使用压缩层或未压缩层
|
|
||||||
if options != nil && options.UseCompressedLayers {
|
if options != nil && options.UseCompressedLayers {
|
||||||
layerSize, err = layer.Size()
|
layerSize, err = layer.Size()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -385,7 +374,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
|||||||
log.Printf("已处理层 %d/%d", i+1, len(layers))
|
log.Printf("已处理层 %d/%d", i+1, len(layers))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建单个镜像的manifest信息
|
|
||||||
singleManifest := map[string]interface{}{
|
singleManifest := map[string]interface{}{
|
||||||
"Config": configDigest.String() + ".json",
|
"Config": configDigest.String() + ".json",
|
||||||
"RepoTags": []string{imageRef},
|
"RepoTags": []string{imageRef},
|
||||||
@@ -398,7 +386,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
|||||||
}(),
|
}(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建repositories信息
|
|
||||||
repositories := make(map[string]map[string]string)
|
repositories := make(map[string]map[string]string)
|
||||||
parts := strings.Split(imageRef, ":")
|
parts := strings.Split(imageRef, ":")
|
||||||
if len(parts) == 2 {
|
if len(parts) == 2 {
|
||||||
@@ -407,14 +394,12 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
|||||||
repositories[repoName] = map[string]string{tag: configDigest.String()}
|
repositories[repoName] = map[string]string{tag: configDigest.String()}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果是批量下载,返回信息而不写入文件
|
|
||||||
if manifestOut != nil && repositoriesOut != nil {
|
if manifestOut != nil && repositoriesOut != nil {
|
||||||
*manifestOut = singleManifest
|
*manifestOut = singleManifest
|
||||||
*repositoriesOut = repositories
|
*repositoriesOut = repositories
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 单镜像下载,直接写入manifest.json
|
|
||||||
manifest := []map[string]interface{}{singleManifest}
|
manifest := []map[string]interface{}{singleManifest}
|
||||||
|
|
||||||
manifestData, err := json.Marshal(manifest)
|
manifestData, err := json.Marshal(manifest)
|
||||||
@@ -436,7 +421,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入repositories文件
|
|
||||||
repositoriesData, err := json.Marshal(repositories)
|
repositoriesData, err := json.Marshal(repositories)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -456,7 +440,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// processImageForBatch 处理镜像的公共逻辑,用于批量下载
|
// processImageForBatch 处理镜像的公共逻辑
|
||||||
func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
|
func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
|
||||||
layers, err := img.Layers()
|
layers, err := img.Layers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -498,7 +482,6 @@ func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWrite
|
|||||||
|
|
||||||
switch desc.MediaType {
|
switch desc.MediaType {
|
||||||
case types.OCIImageIndex, types.DockerManifestList:
|
case types.OCIImageIndex, types.DockerManifestList:
|
||||||
// 处理多架构镜像
|
|
||||||
img, err = is.selectPlatformImage(desc, options)
|
img, err = is.selectPlatformImage(desc, options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err)
|
return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err)
|
||||||
@@ -530,7 +513,6 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
|
|||||||
return nil, fmt.Errorf("获取索引清单失败: %w", err)
|
return nil, fmt.Errorf("获取索引清单失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 选择合适的平台
|
|
||||||
var selectedDesc *v1.Descriptor
|
var selectedDesc *v1.Descriptor
|
||||||
for _, m := range manifest.Manifests {
|
for _, m := range manifest.Manifests {
|
||||||
if m.Platform == nil {
|
if m.Platform == nil {
|
||||||
@@ -578,8 +560,8 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
|
|||||||
|
|
||||||
var globalImageStreamer *ImageStreamer
|
var globalImageStreamer *ImageStreamer
|
||||||
|
|
||||||
// initImageStreamer 初始化镜像下载器
|
// InitImageStreamer 初始化镜像下载器
|
||||||
func initImageStreamer() {
|
func InitImageStreamer() {
|
||||||
globalImageStreamer = NewImageStreamer(nil)
|
globalImageStreamer = NewImageStreamer(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -591,8 +573,8 @@ func formatPlatformText(platform string) string {
|
|||||||
return platform
|
return platform
|
||||||
}
|
}
|
||||||
|
|
||||||
// initImageTarRoutes 初始化镜像下载路由
|
// InitImageTarRoutes 初始化镜像下载路由
|
||||||
func initImageTarRoutes(router *gin.Engine) {
|
func InitImageTarRoutes(router *gin.Engine) {
|
||||||
imageAPI := router.Group("/api/image")
|
imageAPI := router.Group("/api/image")
|
||||||
{
|
{
|
||||||
imageAPI.GET("/download/:image", handleDirectImageDownload)
|
imageAPI.GET("/download/:image", handleDirectImageDownload)
|
||||||
@@ -625,7 +607,6 @@ func handleDirectImageDownload(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 防抖检查
|
|
||||||
userID := getUserID(c)
|
userID := getUserID(c)
|
||||||
contentKey := generateContentFingerprint([]string{imageRef}, platform)
|
contentKey := generateContentFingerprint([]string{imageRef}, platform)
|
||||||
|
|
||||||
@@ -677,7 +658,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
if len(req.Images) > cfg.Download.MaxImages {
|
if len(req.Images) > cfg.Download.MaxImages {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{
|
c.JSON(http.StatusBadRequest, gin.H{
|
||||||
"error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages),
|
"error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages),
|
||||||
@@ -685,7 +666,6 @@ func handleSimpleBatchDownload(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 批量下载防抖检查
|
|
||||||
userID := getUserID(c)
|
userID := getUserID(c)
|
||||||
contentKey := generateContentFingerprint(req.Images, req.Platform)
|
contentKey := generateContentFingerprint(req.Images, req.Platform)
|
||||||
|
|
||||||
@@ -697,7 +677,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
useCompressed := true // 默认启用原始压缩层
|
useCompressed := true
|
||||||
if req.UseCompressedLayers != nil {
|
if req.UseCompressedLayers != nil {
|
||||||
useCompressed = *req.UseCompressedLayers
|
useCompressed = *req.UseCompressedLayers
|
||||||
}
|
}
|
||||||
@@ -801,7 +781,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
|||||||
var allManifests []map[string]interface{}
|
var allManifests []map[string]interface{}
|
||||||
var allRepositories = make(map[string]map[string]string)
|
var allRepositories = make(map[string]map[string]string)
|
||||||
|
|
||||||
// 流式处理每个镜像
|
|
||||||
for i, imageRef := range imageRefs {
|
for i, imageRef := range imageRefs {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -811,7 +790,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
|||||||
|
|
||||||
log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef)
|
log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef)
|
||||||
|
|
||||||
// 防止单个镜像处理时间过长
|
|
||||||
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
|
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
|
||||||
manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options)
|
manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options)
|
||||||
cancel()
|
cancel()
|
||||||
@@ -825,10 +803,8 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
|||||||
return fmt.Errorf("镜像 %s manifest数据为空", imageRef)
|
return fmt.Errorf("镜像 %s manifest数据为空", imageRef)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 收集manifest信息
|
|
||||||
allManifests = append(allManifests, manifest)
|
allManifests = append(allManifests, manifest)
|
||||||
|
|
||||||
// 合并repositories信息
|
|
||||||
for repo, tags := range repositories {
|
for repo, tags := range repositories {
|
||||||
if allRepositories[repo] == nil {
|
if allRepositories[repo] == nil {
|
||||||
allRepositories[repo] = make(map[string]string)
|
allRepositories[repo] = make(map[string]string)
|
||||||
@@ -839,7 +815,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入合并的manifest.json
|
|
||||||
manifestData, err := json.Marshal(allManifests)
|
manifestData, err := json.Marshal(allManifests)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("序列化manifest失败: %w", err)
|
return fmt.Errorf("序列化manifest失败: %w", err)
|
||||||
@@ -859,7 +834,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
|
|||||||
return fmt.Errorf("写入manifest数据失败: %w", err)
|
return fmt.Errorf("写入manifest数据失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 写入合并的repositories文件
|
|
||||||
repositoriesData, err := json.Marshal(allRepositories)
|
repositoriesData, err := json.Marshal(allRepositories)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("序列化repositories失败: %w", err)
|
return fmt.Errorf("序列化repositories失败: %w", err)
|
||||||
File diff suppressed because it is too large
Load Diff
309
src/main.go
309
src/main.go
@@ -3,15 +3,17 @@ package main
|
|||||||
import (
|
import (
|
||||||
"embed"
|
"embed"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/h2c"
|
||||||
|
"hubproxy/config"
|
||||||
|
"hubproxy/handlers"
|
||||||
|
"hubproxy/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed public/*
|
//go:embed public/*
|
||||||
@@ -32,19 +34,7 @@ func serveEmbedFile(c *gin.Context, filename string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
exps = []*regexp.Regexp{
|
globalLimiter *utils.IPRateLimiter
|
||||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
|
|
||||||
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`),
|
|
||||||
}
|
|
||||||
globalLimiter *IPRateLimiter
|
|
||||||
|
|
||||||
// 服务启动时间
|
// 服务启动时间
|
||||||
serviceStartTime = time.Now()
|
serviceStartTime = time.Now()
|
||||||
@@ -52,25 +42,25 @@ var (
|
|||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// 加载配置
|
// 加载配置
|
||||||
if err := LoadConfig(); err != nil {
|
if err := config.LoadConfig(); err != nil {
|
||||||
fmt.Printf("配置加载失败: %v\n", err)
|
fmt.Printf("配置加载失败: %v\n", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 初始化HTTP客户端
|
// 初始化HTTP客户端
|
||||||
initHTTPClients()
|
utils.InitHTTPClients()
|
||||||
|
|
||||||
// 初始化限流器
|
// 初始化限流器
|
||||||
initLimiter()
|
globalLimiter = utils.InitGlobalLimiter()
|
||||||
|
|
||||||
// 初始化Docker流式代理
|
// 初始化Docker流式代理
|
||||||
initDockerProxy()
|
handlers.InitDockerProxy()
|
||||||
|
|
||||||
// 初始化镜像流式下载器
|
// 初始化镜像流式下载器
|
||||||
initImageStreamer()
|
handlers.InitImageStreamer()
|
||||||
|
|
||||||
// 初始化防抖器
|
// 初始化防抖器
|
||||||
initDebouncer()
|
handlers.InitDebouncer()
|
||||||
|
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
router := gin.Default()
|
router := gin.Default()
|
||||||
@@ -84,14 +74,14 @@ func main() {
|
|||||||
})
|
})
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// 全局限流中间件 - 应用到所有路由
|
// 全局限流中间件
|
||||||
router.Use(RateLimitMiddleware(globalLimiter))
|
router.Use(utils.RateLimitMiddleware(globalLimiter))
|
||||||
|
|
||||||
// 初始化监控端点
|
// 初始化监控端点
|
||||||
initHealthRoutes(router)
|
initHealthRoutes(router)
|
||||||
|
|
||||||
// 初始化镜像tar下载路由
|
// 初始化镜像tar下载路由
|
||||||
initImageTarRoutes(router)
|
handlers.InitImageTarRoutes(router)
|
||||||
|
|
||||||
// 静态文件路由
|
// 静态文件路由
|
||||||
router.GET("/", func(c *gin.Context) {
|
router.GET("/", func(c *gin.Context) {
|
||||||
@@ -113,217 +103,60 @@ func main() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// 注册dockerhub搜索路由
|
// 注册dockerhub搜索路由
|
||||||
RegisterSearchRoute(router)
|
handlers.RegisterSearchRoute(router)
|
||||||
|
|
||||||
// 注册Docker认证路由(/token*)
|
// 注册Docker认证路由
|
||||||
router.Any("/token", ProxyDockerAuthGin)
|
router.Any("/token", handlers.ProxyDockerAuthGin)
|
||||||
router.Any("/token/*path", ProxyDockerAuthGin)
|
router.Any("/token/*path", handlers.ProxyDockerAuthGin)
|
||||||
|
|
||||||
// 注册Docker Registry代理路由
|
// 注册Docker Registry代理路由
|
||||||
router.Any("/v2/*path", ProxyDockerRegistryGin)
|
router.Any("/v2/*path", handlers.ProxyDockerRegistryGin)
|
||||||
|
|
||||||
// 注册NoRoute处理器
|
// 注册GitHub代理路由(NoRoute处理器)
|
||||||
router.NoRoute(handler)
|
router.NoRoute(handlers.GitHubProxyHandler)
|
||||||
|
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
fmt.Printf("🚀 HubProxy 启动成功\n")
|
fmt.Printf("HubProxy 启动成功\n")
|
||||||
fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port)
|
fmt.Printf("监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port)
|
||||||
fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours)
|
fmt.Printf("限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours)
|
||||||
fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n")
|
|
||||||
|
|
||||||
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port))
|
// 显示HTTP/2支持状态
|
||||||
|
if cfg.Server.EnableH2C {
|
||||||
|
fmt.Printf("H2c: 已启用\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("版本号: v1.1.6\n")
|
||||||
|
fmt.Printf("项目地址: https://github.com/sky22333/hubproxy\n")
|
||||||
|
|
||||||
|
// 创建HTTP2服务器
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||||
|
ReadTimeout: 60 * time.Second,
|
||||||
|
WriteTimeout: 300 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据配置决定是否启用H2C
|
||||||
|
if cfg.Server.EnableH2C {
|
||||||
|
h2cHandler := h2c.NewHandler(router, &http2.Server{
|
||||||
|
MaxConcurrentStreams: 250,
|
||||||
|
IdleTimeout: 300 * time.Second,
|
||||||
|
MaxReadFrameSize: 4 << 20,
|
||||||
|
MaxUploadBufferPerConnection: 8 << 20,
|
||||||
|
MaxUploadBufferPerStream: 2 << 20,
|
||||||
|
})
|
||||||
|
server.Handler = h2cHandler
|
||||||
|
} else {
|
||||||
|
server.Handler = router
|
||||||
|
}
|
||||||
|
|
||||||
|
err := server.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("启动服务失败: %v\n", err)
|
fmt.Printf("启动服务失败: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handler(c *gin.Context) {
|
|
||||||
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
|
|
||||||
|
|
||||||
for strings.HasPrefix(rawPath, "/") {
|
|
||||||
rawPath = strings.TrimPrefix(rawPath, "/")
|
|
||||||
}
|
|
||||||
// 自动补全协议头
|
|
||||||
if !strings.HasPrefix(rawPath, "https://") {
|
|
||||||
// 修复 http:/ 和 https:/ 的情况
|
|
||||||
if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") {
|
|
||||||
rawPath = strings.Replace(rawPath, "http:/", "", 1)
|
|
||||||
rawPath = strings.Replace(rawPath, "https:/", "", 1)
|
|
||||||
} else if strings.HasPrefix(rawPath, "http://") {
|
|
||||||
rawPath = strings.TrimPrefix(rawPath, "http://")
|
|
||||||
}
|
|
||||||
rawPath = "https://" + rawPath
|
|
||||||
}
|
|
||||||
|
|
||||||
matches := checkURL(rawPath)
|
|
||||||
if matches != nil {
|
|
||||||
// GitHub仓库访问控制检查
|
|
||||||
if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed {
|
|
||||||
// 构建仓库名用于日志
|
|
||||||
var repoPath string
|
|
||||||
if len(matches) >= 2 {
|
|
||||||
username := matches[0]
|
|
||||||
repoName := strings.TrimSuffix(matches[1], ".git")
|
|
||||||
repoPath = username + "/" + repoName
|
|
||||||
}
|
|
||||||
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
|
|
||||||
c.String(http.StatusForbidden, reason)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
c.String(http.StatusForbidden, "无效输入")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if exps[1].MatchString(rawPath) {
|
|
||||||
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
proxyRequest(c, rawPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func proxyRequest(c *gin.Context, u string) {
|
|
||||||
proxyWithRedirect(c, u, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func proxyWithRedirect(c *gin.Context, u string, redirectCount int) {
|
|
||||||
// 限制最大重定向次数,防止无限递归
|
|
||||||
const maxRedirects = 20
|
|
||||||
if redirectCount > maxRedirects {
|
|
||||||
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, values := range c.Request.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
req.Header.Add(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
req.Header.Del("Host")
|
|
||||||
|
|
||||||
resp, err := GetGlobalHTTPClient().Do(req)
|
|
||||||
if err != nil {
|
|
||||||
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := resp.Body.Close(); err != nil {
|
|
||||||
fmt.Printf("关闭响应体失败: %v\n", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 检查文件大小限制
|
|
||||||
cfg := GetConfig()
|
|
||||||
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
|
|
||||||
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
|
|
||||||
c.String(http.StatusRequestEntityTooLarge,
|
|
||||||
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 清理安全相关的头
|
|
||||||
resp.Header.Del("Content-Security-Policy")
|
|
||||||
resp.Header.Del("Referrer-Policy")
|
|
||||||
resp.Header.Del("Strict-Transport-Security")
|
|
||||||
|
|
||||||
// 获取真实域名
|
|
||||||
realHost := c.Request.Header.Get("X-Forwarded-Host")
|
|
||||||
if realHost == "" {
|
|
||||||
realHost = c.Request.Host
|
|
||||||
}
|
|
||||||
// 如果域名中没有协议前缀,添加https://
|
|
||||||
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
|
|
||||||
realHost = "https://" + realHost
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasSuffix(strings.ToLower(u), ".sh") {
|
|
||||||
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
|
|
||||||
|
|
||||||
processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
|
|
||||||
processedBody = resp.Body
|
|
||||||
processedSize = 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// 智能设置响应头
|
|
||||||
if processedSize > 0 {
|
|
||||||
resp.Header.Del("Content-Length")
|
|
||||||
resp.Header.Del("Content-Encoding")
|
|
||||||
resp.Header.Set("Transfer-Encoding", "chunked")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 复制其他响应头
|
|
||||||
for key, values := range resp.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
c.Header(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if location := resp.Header.Get("Location"); location != "" {
|
|
||||||
if checkURL(location) != nil {
|
|
||||||
c.Header("Location", "/"+location)
|
|
||||||
} else {
|
|
||||||
proxyWithRedirect(c, location, redirectCount+1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(resp.StatusCode)
|
|
||||||
|
|
||||||
// 输出处理后的内容
|
|
||||||
if _, err := io.Copy(c.Writer, processedBody); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for key, values := range resp.Header {
|
|
||||||
for _, value := range values {
|
|
||||||
c.Header(key, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 处理重定向
|
|
||||||
if location := resp.Header.Get("Location"); location != "" {
|
|
||||||
if checkURL(location) != nil {
|
|
||||||
c.Header("Location", "/"+location)
|
|
||||||
} else {
|
|
||||||
proxyWithRedirect(c, location, redirectCount+1)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Status(resp.StatusCode)
|
|
||||||
|
|
||||||
// 直接流式转发
|
|
||||||
io.Copy(c.Writer, resp.Body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkURL(u string) []string {
|
|
||||||
for _, exp := range exps {
|
|
||||||
if matches := exp.FindStringSubmatch(u); matches != nil {
|
|
||||||
return matches[1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 简单的健康检查
|
// 简单的健康检查
|
||||||
func formatBeijingTime(t time.Time) string {
|
|
||||||
loc, err := time.LoadLocation("Asia/Shanghai")
|
|
||||||
if err != nil {
|
|
||||||
loc = time.FixedZone("CST", 8*3600) // 兜底时区
|
|
||||||
}
|
|
||||||
return t.In(loc).Format("2006-01-02 15:04:05")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 转换为可读时间
|
|
||||||
func formatDuration(d time.Duration) string {
|
func formatDuration(d time.Duration) string {
|
||||||
if d < time.Minute {
|
if d < time.Minute {
|
||||||
return fmt.Sprintf("%d秒", int(d.Seconds()))
|
return fmt.Sprintf("%d秒", int(d.Seconds()))
|
||||||
@@ -338,26 +171,20 @@ func formatDuration(d time.Duration) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func initHealthRoutes(router *gin.Engine) {
|
func getUptimeInfo() (time.Duration, float64, string) {
|
||||||
router.GET("/health", func(c *gin.Context) {
|
uptime := time.Since(serviceStartTime)
|
||||||
uptime := time.Since(serviceStartTime)
|
return uptime, uptime.Seconds(), formatDuration(uptime)
|
||||||
c.JSON(http.StatusOK, gin.H{
|
}
|
||||||
"status": "healthy",
|
|
||||||
"timestamp_unix": serviceStartTime.Unix(),
|
|
||||||
"uptime_sec": uptime.Seconds(),
|
|
||||||
"service": "hubproxy",
|
|
||||||
"start_time_bj": formatBeijingTime(serviceStartTime),
|
|
||||||
"uptime_human": formatDuration(uptime),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
|
func initHealthRoutes(router *gin.Engine) {
|
||||||
router.GET("/ready", func(c *gin.Context) {
|
router.GET("/ready", func(c *gin.Context) {
|
||||||
uptime := time.Since(serviceStartTime)
|
_, uptimeSec, uptimeHuman := getUptimeInfo()
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"ready": true,
|
"ready": true,
|
||||||
"timestamp_unix": time.Now().Unix(),
|
"service": "hubproxy",
|
||||||
"uptime_sec": uptime.Seconds(),
|
"start_time_unix": serviceStartTime.Unix(),
|
||||||
"uptime_human": formatDuration(uptime),
|
"uptime_sec": uptimeSec,
|
||||||
|
"uptime_human": uptimeHuman,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package main
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"hubproxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ResourceType 资源类型
|
// ResourceType 资源类型
|
||||||
@@ -26,7 +28,7 @@ type DockerImageInfo struct {
|
|||||||
FullName string
|
FullName string
|
||||||
}
|
}
|
||||||
|
|
||||||
// 全局访问控制器实例
|
// GlobalAccessController 全局访问控制器实例
|
||||||
var GlobalAccessController = &AccessController{}
|
var GlobalAccessController = &AccessController{}
|
||||||
|
|
||||||
// ParseDockerImage 解析Docker镜像名称
|
// ParseDockerImage 解析Docker镜像名称
|
||||||
@@ -79,19 +81,16 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo {
|
|||||||
|
|
||||||
// CheckDockerAccess 检查Docker镜像访问权限
|
// CheckDockerAccess 检查Docker镜像访问权限
|
||||||
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
|
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
|
|
||||||
// 解析镜像名称
|
|
||||||
imageInfo := ac.ParseDockerImage(image)
|
imageInfo := ac.ParseDockerImage(image)
|
||||||
|
|
||||||
// 检查白名单(如果配置了白名单,则只允许白名单中的镜像)
|
|
||||||
if len(cfg.Access.WhiteList) > 0 {
|
if len(cfg.Access.WhiteList) > 0 {
|
||||||
if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) {
|
if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) {
|
||||||
return false, "不在Docker镜像白名单内"
|
return false, "不在Docker镜像白名单内"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查黑名单
|
|
||||||
if len(cfg.Access.BlackList) > 0 {
|
if len(cfg.Access.BlackList) > 0 {
|
||||||
if ac.matchImageInList(imageInfo, cfg.Access.BlackList) {
|
if ac.matchImageInList(imageInfo, cfg.Access.BlackList) {
|
||||||
return false, "Docker镜像在黑名单内"
|
return false, "Docker镜像在黑名单内"
|
||||||
@@ -107,14 +106,12 @@ func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, r
|
|||||||
return false, "无效的GitHub仓库格式"
|
return false, "无效的GitHub仓库格式"
|
||||||
}
|
}
|
||||||
|
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
|
|
||||||
// 检查白名单
|
|
||||||
if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) {
|
if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) {
|
||||||
return false, "不在GitHub仓库白名单内"
|
return false, "不在GitHub仓库白名单内"
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查黑名单
|
|
||||||
if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) {
|
if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) {
|
||||||
return false, "GitHub仓库在黑名单内"
|
return false, "GitHub仓库在黑名单内"
|
||||||
}
|
}
|
||||||
@@ -185,17 +182,14 @@ func (ac *AccessController) checkList(matches, list []string) bool {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 支持多种匹配模式
|
|
||||||
if fullRepo == item {
|
if fullRepo == item {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 用户级匹配
|
|
||||||
if item == username || item == username+"/*" {
|
if item == username || item == username+"/*" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 前缀匹配(支持通配符)
|
|
||||||
if strings.HasSuffix(item, "*") {
|
if strings.HasSuffix(item, "*") {
|
||||||
prefix := strings.TrimSuffix(item, "*")
|
prefix := strings.TrimSuffix(item, "*")
|
||||||
if strings.HasPrefix(fullRepo, prefix) {
|
if strings.HasPrefix(fullRepo, prefix) {
|
||||||
@@ -203,7 +197,6 @@ func (ac *AccessController) checkList(matches, list []string) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 子仓库匹配(防止 user/repo 匹配到 user/repo-fork)
|
|
||||||
if strings.HasPrefix(fullRepo, item+"/") {
|
if strings.HasPrefix(fullRepo, item+"/") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
@@ -9,22 +9,23 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"hubproxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CachedItem 通用缓存项,支持Token和Manifest
|
// CachedItem 通用缓存项
|
||||||
type CachedItem struct {
|
type CachedItem struct {
|
||||||
Data []byte // 缓存数据(token字符串或manifest字节)
|
Data []byte
|
||||||
ContentType string // 内容类型
|
ContentType string
|
||||||
Headers map[string]string // 额外的响应头
|
Headers map[string]string
|
||||||
ExpiresAt time.Time // 过期时间
|
ExpiresAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// UniversalCache 通用缓存,支持Token和Manifest
|
// UniversalCache 通用缓存
|
||||||
type UniversalCache struct {
|
type UniversalCache struct {
|
||||||
cache sync.Map
|
cache sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
var globalCache = &UniversalCache{}
|
var GlobalCache = &UniversalCache{}
|
||||||
|
|
||||||
// Get 获取缓存项
|
// Get 获取缓存项
|
||||||
func (c *UniversalCache) Get(key string) *CachedItem {
|
func (c *UniversalCache) Get(key string) *CachedItem {
|
||||||
@@ -57,22 +58,22 @@ func (c *UniversalCache) SetToken(key, token string, ttl time.Duration) {
|
|||||||
c.Set(key, []byte(token), "application/json", nil, ttl)
|
c.Set(key, []byte(token), "application/json", nil, ttl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildCacheKey 构建稳定的缓存key
|
// BuildCacheKey 构建稳定的缓存key
|
||||||
func buildCacheKey(prefix, query string) string {
|
func BuildCacheKey(prefix, query string) string {
|
||||||
return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query)))
|
return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query)))
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTokenCacheKey(query string) string {
|
func BuildTokenCacheKey(query string) string {
|
||||||
return buildCacheKey("token", query)
|
return BuildCacheKey("token", query)
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildManifestCacheKey(imageRef, reference string) string {
|
func BuildManifestCacheKey(imageRef, reference string) string {
|
||||||
key := fmt.Sprintf("%s:%s", imageRef, reference)
|
key := fmt.Sprintf("%s:%s", imageRef, reference)
|
||||||
return buildCacheKey("manifest", key)
|
return BuildCacheKey("manifest", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getManifestTTL(reference string) time.Duration {
|
func GetManifestTTL(reference string) time.Duration {
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
defaultTTL := 30 * time.Minute
|
defaultTTL := 30 * time.Minute
|
||||||
if cfg.TokenCache.DefaultTTL != "" {
|
if cfg.TokenCache.DefaultTTL != "" {
|
||||||
if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil {
|
if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil {
|
||||||
@@ -84,23 +85,20 @@ func getManifestTTL(reference string) time.Duration {
|
|||||||
return 24 * time.Hour
|
return 24 * time.Hour
|
||||||
}
|
}
|
||||||
|
|
||||||
// mutable tag的智能判断
|
|
||||||
if reference == "latest" || reference == "main" || reference == "master" ||
|
if reference == "latest" || reference == "main" || reference == "master" ||
|
||||||
reference == "dev" || reference == "develop" {
|
reference == "dev" || reference == "develop" {
|
||||||
// 热门可变标签: 短期缓存
|
|
||||||
return 10 * time.Minute
|
return 10 * time.Minute
|
||||||
}
|
}
|
||||||
|
|
||||||
return defaultTTL
|
return defaultTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractTTLFromResponse 从响应中智能提取TTL
|
// ExtractTTLFromResponse 从响应中智能提取TTL
|
||||||
func extractTTLFromResponse(responseBody []byte) time.Duration {
|
func ExtractTTLFromResponse(responseBody []byte) time.Duration {
|
||||||
var tokenResp struct {
|
var tokenResp struct {
|
||||||
ExpiresIn int `json:"expires_in"`
|
ExpiresIn int `json:"expires_in"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// 默认30分钟TTL,确保稳定性
|
|
||||||
defaultTTL := 30 * time.Minute
|
defaultTTL := 30 * time.Minute
|
||||||
|
|
||||||
if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 {
|
if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 {
|
||||||
@@ -113,37 +111,35 @@ func extractTTLFromResponse(responseBody []byte) time.Duration {
|
|||||||
return defaultTTL
|
return defaultTTL
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeTokenResponse(c *gin.Context, cachedBody string) {
|
func WriteTokenResponse(c *gin.Context, cachedBody string) {
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
c.String(200, cachedBody)
|
c.String(200, cachedBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func writeCachedResponse(c *gin.Context, item *CachedItem) {
|
func WriteCachedResponse(c *gin.Context, item *CachedItem) {
|
||||||
if item.ContentType != "" {
|
if item.ContentType != "" {
|
||||||
c.Header("Content-Type", item.ContentType)
|
c.Header("Content-Type", item.ContentType)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 设置额外的响应头
|
|
||||||
for key, value := range item.Headers {
|
for key, value := range item.Headers {
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回数据
|
|
||||||
c.Data(200, item.ContentType, item.Data)
|
c.Data(200, item.ContentType, item.Data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// isCacheEnabled 检查缓存是否启用
|
// IsCacheEnabled 检查缓存是否启用
|
||||||
func isCacheEnabled() bool {
|
func IsCacheEnabled() bool {
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
return cfg.TokenCache.Enabled
|
return cfg.TokenCache.Enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTokenCacheEnabled 检查token缓存是否启用(向后兼容)
|
// IsTokenCacheEnabled 检查token缓存是否启用
|
||||||
func isTokenCacheEnabled() bool {
|
func IsTokenCacheEnabled() bool {
|
||||||
return isCacheEnabled()
|
return IsCacheEnabled()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 定期清理过期缓存,防止内存泄漏
|
// 定期清理过期缓存
|
||||||
func init() {
|
func init() {
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(20 * time.Minute)
|
ticker := time.NewTicker(20 * time.Minute)
|
||||||
@@ -153,7 +149,7 @@ func init() {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
expiredKeys := make([]string, 0)
|
expiredKeys := make([]string, 0)
|
||||||
|
|
||||||
globalCache.cache.Range(func(key, value interface{}) bool {
|
GlobalCache.cache.Range(func(key, value interface{}) bool {
|
||||||
if cached := value.(*CachedItem); now.After(cached.ExpiresAt) {
|
if cached := value.(*CachedItem); now.After(cached.ExpiresAt) {
|
||||||
expiredKeys = append(expiredKeys, key.(string))
|
expiredKeys = append(expiredKeys, key.(string))
|
||||||
}
|
}
|
||||||
@@ -161,7 +157,7 @@ func init() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
for _, key := range expiredKeys {
|
for _, key := range expiredKeys {
|
||||||
globalCache.cache.Delete(key)
|
GlobalCache.cache.Delete(key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -1,28 +1,28 @@
|
|||||||
package main
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"hubproxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// 全局HTTP客户端 - 用于代理请求(长超时)
|
|
||||||
globalHTTPClient *http.Client
|
globalHTTPClient *http.Client
|
||||||
// 搜索HTTP客户端 - 用于API请求(短超时)
|
|
||||||
searchHTTPClient *http.Client
|
searchHTTPClient *http.Client
|
||||||
)
|
)
|
||||||
|
|
||||||
// initHTTPClients 初始化HTTP客户端
|
// InitHTTPClients 初始化HTTP客户端
|
||||||
func initHTTPClients() {
|
func InitHTTPClients() {
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
|
|
||||||
if p := cfg.Access.Proxy; p != "" {
|
if p := cfg.Access.Proxy; p != "" {
|
||||||
os.Setenv("HTTP_PROXY", p)
|
os.Setenv("HTTP_PROXY", p)
|
||||||
os.Setenv("HTTPS_PROXY", p)
|
os.Setenv("HTTPS_PROXY", p)
|
||||||
}
|
}
|
||||||
// 代理客户端配置 - 适用于大文件传输
|
|
||||||
globalHTTPClient = &http.Client{
|
globalHTTPClient = &http.Client{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
Proxy: http.ProxyFromEnvironment,
|
Proxy: http.ProxyFromEnvironment,
|
||||||
@@ -39,7 +39,6 @@ func initHTTPClients() {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// 搜索客户端配置 - 适用于API调用
|
|
||||||
searchHTTPClient = &http.Client{
|
searchHTTPClient = &http.Client{
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
@@ -57,12 +56,12 @@ func initHTTPClients() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理)
|
// GetGlobalHTTPClient 获取全局HTTP客户端
|
||||||
func GetGlobalHTTPClient() *http.Client {
|
func GetGlobalHTTPClient() *http.Client {
|
||||||
return globalHTTPClient
|
return globalHTTPClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSearchHTTPClient 获取搜索HTTP客户端(用于API调用)
|
// GetSearchHTTPClient 获取搜索HTTP客户端
|
||||||
func GetSearchHTTPClient() *http.Client {
|
func GetSearchHTTPClient() *http.Client {
|
||||||
return searchHTTPClient
|
return searchHTTPClient
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
@@ -41,7 +41,6 @@ func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reade
|
|||||||
func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) {
|
func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) {
|
||||||
var reader io.Reader = input
|
var reader io.Reader = input
|
||||||
|
|
||||||
// 处理gzip压缩
|
|
||||||
if isCompressed {
|
if isCompressed {
|
||||||
peek := make([]byte, 2)
|
peek := make([]byte, 2)
|
||||||
n, err := input.Read(peek)
|
n, err := input.Read(peek)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package main
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -9,22 +9,23 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
|
"hubproxy/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// 清理间隔
|
|
||||||
CleanupInterval = 10 * time.Minute
|
CleanupInterval = 10 * time.Minute
|
||||||
MaxIPCacheSize = 10000
|
MaxIPCacheSize = 10000
|
||||||
)
|
)
|
||||||
|
|
||||||
// IPRateLimiter IP限流器结构体
|
// IPRateLimiter IP限流器结构体
|
||||||
type IPRateLimiter struct {
|
type IPRateLimiter struct {
|
||||||
ips map[string]*rateLimiterEntry // IP到限流器的映射
|
ips map[string]*rateLimiterEntry
|
||||||
mu *sync.RWMutex // 读写锁,保证并发安全
|
mu *sync.RWMutex
|
||||||
r rate.Limit // 速率限制(每秒允许的请求数)
|
r rate.Limit
|
||||||
b int // 令牌桶容量(突发请求数)
|
b int
|
||||||
whitelist []*net.IPNet // 白名单IP段
|
whitelist []*net.IPNet
|
||||||
blacklist []*net.IPNet // 黑名单IP段
|
blacklist []*net.IPNet
|
||||||
|
whitelistLimiter *rate.Limiter // 全局共享的白名单限流器
|
||||||
}
|
}
|
||||||
|
|
||||||
// rateLimiterEntry 限流器条目
|
// rateLimiterEntry 限流器条目
|
||||||
@@ -33,15 +34,15 @@ type rateLimiterEntry struct {
|
|||||||
lastAccess time.Time
|
lastAccess time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// initGlobalLimiter 初始化全局限流器
|
// InitGlobalLimiter 初始化全局限流器
|
||||||
func initGlobalLimiter() *IPRateLimiter {
|
func InitGlobalLimiter() *IPRateLimiter {
|
||||||
cfg := GetConfig()
|
cfg := config.GetConfig()
|
||||||
|
|
||||||
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
|
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
|
||||||
for _, item := range cfg.Security.WhiteList {
|
for _, item := range cfg.Security.WhiteList {
|
||||||
if item = strings.TrimSpace(item); item != "" {
|
if item = strings.TrimSpace(item); item != "" {
|
||||||
if !strings.Contains(item, "/") {
|
if !strings.Contains(item, "/") {
|
||||||
item = item + "/32" // 单个IP转为CIDR格式
|
item = item + "/32"
|
||||||
}
|
}
|
||||||
_, ipnet, err := net.ParseCIDR(item)
|
_, ipnet, err := net.ParseCIDR(item)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -52,12 +53,11 @@ func initGlobalLimiter() *IPRateLimiter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析黑名单IP段
|
|
||||||
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
|
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
|
||||||
for _, item := range cfg.Security.BlackList {
|
for _, item := range cfg.Security.BlackList {
|
||||||
if item = strings.TrimSpace(item); item != "" {
|
if item = strings.TrimSpace(item); item != "" {
|
||||||
if !strings.Contains(item, "/") {
|
if !strings.Contains(item, "/") {
|
||||||
item = item + "/32" // 单个IP转为CIDR格式
|
item = item + "/32"
|
||||||
}
|
}
|
||||||
_, ipnet, err := net.ParseCIDR(item)
|
_, ipnet, err := net.ParseCIDR(item)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -68,34 +68,25 @@ func initGlobalLimiter() *IPRateLimiter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
|
|
||||||
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
|
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
|
||||||
|
|
||||||
burstSize := cfg.RateLimit.RequestLimit
|
burstSize := cfg.RateLimit.RequestLimit
|
||||||
if burstSize < 1 {
|
|
||||||
burstSize = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
limiter := &IPRateLimiter{
|
limiter := &IPRateLimiter{
|
||||||
ips: make(map[string]*rateLimiterEntry),
|
ips: make(map[string]*rateLimiterEntry),
|
||||||
mu: &sync.RWMutex{},
|
mu: &sync.RWMutex{},
|
||||||
r: ratePerSecond,
|
r: ratePerSecond,
|
||||||
b: burstSize,
|
b: burstSize,
|
||||||
whitelist: whitelist,
|
whitelist: whitelist,
|
||||||
blacklist: blacklist,
|
blacklist: blacklist,
|
||||||
|
whitelistLimiter: rate.NewLimiter(rate.Inf, burstSize),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 启动定期清理goroutine
|
|
||||||
go limiter.cleanupRoutine()
|
go limiter.cleanupRoutine()
|
||||||
|
|
||||||
return limiter
|
return limiter
|
||||||
}
|
}
|
||||||
|
|
||||||
// initLimiter 初始化限流器
|
|
||||||
func initLimiter() {
|
|
||||||
globalLimiter = initGlobalLimiter()
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupRoutine 定期清理过期的限流器
|
// cleanupRoutine 定期清理过期的限流器
|
||||||
func (i *IPRateLimiter) cleanupRoutine() {
|
func (i *IPRateLimiter) cleanupRoutine() {
|
||||||
ticker := time.NewTicker(CleanupInterval)
|
ticker := time.NewTicker(CleanupInterval)
|
||||||
@@ -105,25 +96,20 @@ func (i *IPRateLimiter) cleanupRoutine() {
|
|||||||
now := time.Now()
|
now := time.Now()
|
||||||
expired := make([]string, 0)
|
expired := make([]string, 0)
|
||||||
|
|
||||||
// 查找过期的条目
|
|
||||||
i.mu.RLock()
|
i.mu.RLock()
|
||||||
for ip, entry := range i.ips {
|
for ip, entry := range i.ips {
|
||||||
// 如果最后访问时间超过1小时,认为过期
|
|
||||||
if now.Sub(entry.lastAccess) > 1*time.Hour {
|
if now.Sub(entry.lastAccess) > 1*time.Hour {
|
||||||
expired = append(expired, ip)
|
expired = append(expired, ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
i.mu.RUnlock()
|
i.mu.RUnlock()
|
||||||
|
|
||||||
// 如果有过期条目或者缓存过大,进行清理
|
|
||||||
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
|
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
|
||||||
i.mu.Lock()
|
i.mu.Lock()
|
||||||
// 删除过期条目
|
|
||||||
for _, ip := range expired {
|
for _, ip := range expired {
|
||||||
delete(i.ips, ip)
|
delete(i.ips, ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果缓存仍然过大,全部清理
|
|
||||||
if len(i.ips) > MaxIPCacheSize {
|
if len(i.ips) > MaxIPCacheSize {
|
||||||
i.ips = make(map[string]*rateLimiterEntry)
|
i.ips = make(map[string]*rateLimiterEntry)
|
||||||
}
|
}
|
||||||
@@ -140,28 +126,26 @@ func extractIPFromAddress(address string) string {
|
|||||||
return address
|
return address
|
||||||
}
|
}
|
||||||
|
|
||||||
// normalizeIPForRateLimit 标准化IP地址用于限流:IPv4保持不变,IPv6标准化为/64网段
|
// normalizeIPForRateLimit 标准化IP地址用于限流
|
||||||
func normalizeIPForRateLimit(ipStr string) string {
|
func normalizeIPForRateLimit(ipStr string) string {
|
||||||
ip := net.ParseIP(ipStr)
|
ip := net.ParseIP(ipStr)
|
||||||
if ip == nil {
|
if ip == nil {
|
||||||
return ipStr // 解析失败,返回原值
|
return ipStr
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.To4() != nil {
|
if ip.To4() != nil {
|
||||||
return ipStr // IPv4保持不变
|
return ipStr
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPv6:标准化为 /64 网段
|
|
||||||
ipv6 := ip.To16()
|
ipv6 := ip.To16()
|
||||||
for i := 8; i < 16; i++ {
|
for i := 8; i < 16; i++ {
|
||||||
ipv6[i] = 0 // 清零后64位
|
ipv6[i] = 0
|
||||||
}
|
}
|
||||||
return ipv6.String() + "/64"
|
return ipv6.String() + "/64"
|
||||||
}
|
}
|
||||||
|
|
||||||
// isIPInCIDRList 检查IP是否在CIDR列表中
|
// isIPInCIDRList 检查IP是否在CIDR列表中
|
||||||
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
|
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
|
||||||
// 先提取纯IP地址
|
|
||||||
cleanIP := extractIPFromAddress(ip)
|
cleanIP := extractIPFromAddress(ip)
|
||||||
parsedIP := net.ParseIP(cleanIP)
|
parsedIP := net.ParseIP(cleanIP)
|
||||||
if parsedIP == nil {
|
if parsedIP == nil {
|
||||||
@@ -176,22 +160,18 @@ func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLimiter 获取指定IP的限流器,同时返回是否允许访问
|
// GetLimiter 获取指定IP的限流器
|
||||||
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
|
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
|
||||||
// 提取纯IP地址
|
|
||||||
cleanIP := extractIPFromAddress(ip)
|
cleanIP := extractIPFromAddress(ip)
|
||||||
|
|
||||||
// 检查是否在黑名单中
|
|
||||||
if isIPInCIDRList(cleanIP, i.blacklist) {
|
if isIPInCIDRList(cleanIP, i.blacklist) {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否在白名单中
|
|
||||||
if isIPInCIDRList(cleanIP, i.whitelist) {
|
if isIPInCIDRList(cleanIP, i.whitelist) {
|
||||||
return rate.NewLimiter(rate.Inf, i.b), true
|
return i.whitelistLimiter, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// 标准化IP用于限流:IPv4保持不变,IPv6标准化为/64网段
|
|
||||||
normalizedIP := normalizeIPForRateLimit(cleanIP)
|
normalizedIP := normalizeIPForRateLimit(cleanIP)
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
@@ -230,7 +210,6 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
|
|||||||
// RateLimitMiddleware 速率限制中间件
|
// RateLimitMiddleware 速率限制中间件
|
||||||
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// 静态文件豁免:跳过限流检查
|
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
|
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
|
||||||
strings.HasPrefix(path, "/public/") {
|
strings.HasPrefix(path, "/public/") {
|
||||||
@@ -238,30 +217,22 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取客户端真实IP
|
|
||||||
var ip string
|
var ip string
|
||||||
|
|
||||||
// 优先尝试从请求头获取真实IP
|
|
||||||
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
|
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
|
||||||
// X-Forwarded-For可能包含多个IP,取第一个
|
|
||||||
ips := strings.Split(forwarded, ",")
|
ips := strings.Split(forwarded, ",")
|
||||||
ip = strings.TrimSpace(ips[0])
|
ip = strings.TrimSpace(ips[0])
|
||||||
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
|
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
|
||||||
// 如果有X-Real-IP头
|
|
||||||
ip = realIP
|
ip = realIP
|
||||||
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
|
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
|
||||||
// 某些代理可能使用此头
|
|
||||||
ips := strings.Split(remoteIP, ",")
|
ips := strings.Split(remoteIP, ",")
|
||||||
ip = strings.TrimSpace(ips[0])
|
ip = strings.TrimSpace(ips[0])
|
||||||
} else {
|
} else {
|
||||||
// 回退到ClientIP方法
|
|
||||||
ip = c.ClientIP()
|
ip = c.ClientIP()
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提取纯IP地址(去除可能存在的端口)
|
|
||||||
cleanIP := extractIPFromAddress(ip)
|
cleanIP := extractIPFromAddress(ip)
|
||||||
|
|
||||||
// 日志记录请求IP和头信息
|
|
||||||
normalizedIP := normalizeIPForRateLimit(cleanIP)
|
normalizedIP := normalizeIPForRateLimit(cleanIP)
|
||||||
if cleanIP != normalizedIP {
|
if cleanIP != normalizedIP {
|
||||||
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
|
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
|
||||||
@@ -275,10 +246,8 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
|||||||
c.GetHeader("X-Real-IP"))
|
c.GetHeader("X-Real-IP"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取限流器并检查是否允许访问
|
|
||||||
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
|
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
|
||||||
|
|
||||||
// 如果IP在黑名单中
|
|
||||||
if !allowed {
|
if !allowed {
|
||||||
c.JSON(403, gin.H{
|
c.JSON(403, gin.H{
|
||||||
"error": "您已被限制访问",
|
"error": "您已被限制访问",
|
||||||
@@ -287,7 +256,6 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查限流
|
|
||||||
if !ipLimiter.Allow() {
|
if !ipLimiter.Allow() {
|
||||||
c.JSON(429, gin.H{
|
c.JSON(429, gin.H{
|
||||||
"error": "请求频率过快,暂时限制访问",
|
"error": "请求频率过快,暂时限制访问",
|
||||||
Reference in New Issue
Block a user