57 Commits

Author SHA1 Message Date
user123456
a0df3b1a54 修复gist正则匹配 2025-07-28 04:46:08 +08:00
starry
70bf552daf Update release.yml 2025-07-27 12:16:18 +08:00
starry
d5e2abdcff Merge pull request #39 from sky22333/dev
优化代码结构,支持h2
2025-07-27 12:11:39 +08:00
user123456
07a926902a 优化代码格式 2025-07-27 10:58:20 +08:00
user123456
1881b5b1ba 增加HTTP2多路复用的支持 2025-07-27 10:25:52 +08:00
user123456
75e37158ef update 2025-07-27 08:05:36 +08:00
user123456
506de49586 IP白名单优化 2025-07-27 08:01:34 +08:00
user123456
dd704dc499 update 2025-07-27 07:37:35 +08:00
starry
9a8b850bce Delete src/test.exe 2025-07-27 06:15:42 +08:00
user123456
187e842445 拆分包结构 2025-07-27 05:50:34 +08:00
starry
badafd2899 Update README.md 2025-07-20 19:34:02 +08:00
starry
4bf075fcaf Update README.md 2025-07-18 21:12:47 +08:00
starry
208a239af3 修复cf导致的协议头问题,简化健康检查 2025-07-18 21:10:03 +08:00
starry
1fb97b5347 Merge pull request #34 from Thinker-Joe/main
Add registry mirror usage
2025-07-16 20:17:23 +08:00
Thinker-Joe
95c2e4fd68 Merge pull request #4 from Thinker-Joe/codex/readmeregistry-mirrors
Add registry mirror usage
2025-07-16 19:35:37 +08:00
Thinker-Joe
79fa21321f docs: add registry mirror usage 2025-07-16 19:35:10 +08:00
starry
c4c5993bd1 Update README.md 2025-06-30 18:19:14 +08:00
starry
d46fd3fec4 Update README.md 2025-06-28 08:46:24 +08:00
starry
279b48d432 Update README.md 2025-06-28 08:29:34 +08:00
starry
61f09192bb Update README.md 2025-06-27 09:06:44 +08:00
starry
d876809086 完善一些小细节 2025-06-27 08:50:04 +08:00
user123456
fe9156f878 Merge commit 'refs/pull/origin/28' 2025-06-21 00:30:51 +08:00
starry
35651e214f proxy字段修复 2025-06-21 00:15:27 +08:00
user123456
d373e0104d 获取更多镜像tag 2025-06-20 23:44:13 +08:00
starry
207a03a511 Merge pull request #25 from beck-8/me/op_proxy
优化代理配置
2025-06-19 23:00:44 +08:00
beck-8
5bd32cd6c1 go fmt . 2025-06-19 22:53:20 +08:00
beck-8
8c127a795b op http client proxy 2025-06-19 22:52:51 +08:00
user123456
2567652a7d 更新配置说明 2025-06-18 22:26:19 +08:00
user123456
c023e6a9c4 清理冗余written字段 2025-06-18 22:05:28 +08:00
user123456
44c6e4cd7b 修复双重写入 2025-06-18 21:29:56 +08:00
user123456
c22bd0637a 更新默认配置 2025-06-18 20:49:45 +08:00
user123456
a94b476726 移除冗余的限流智能判断逻辑 2025-06-18 20:44:26 +08:00
user123456
4c6751b862 限流改为全局应用 2025-06-18 19:44:32 +08:00
user123456
acc63d7b68 删除热重载 2025-06-18 19:14:13 +08:00
starry
d0b1ea8582 LF 2025-06-18 17:08:14 +08:00
starry
c607061dae LF 2025-06-18 17:07:43 +08:00
starry
143de7b254 Normalize all line endings to LF 2025-06-18 17:03:29 +08:00
user123456
51ace73b78 优化离线镜像的防抖以及日志 2025-06-18 16:04:53 +08:00
user123456
fa9e9210ab 默认为原始压缩层 2025-06-18 15:14:33 +08:00
user123456
f308410920 修复函数调用点传递 2025-06-18 15:00:41 +08:00
user123456
252dc319c6 优化离线包体积 2025-06-18 14:55:35 +08:00
user123456
29ceeef45b IPv6日志适配 2025-06-17 18:49:34 +08:00
user123456
182dced403 修复ipv6标准化的潜在BUG 2025-06-17 18:38:48 +08:00
user123456
aea36939a3 增加支持走代理 2025-06-17 18:18:17 +08:00
starry
4240c1452a Update README.md 2025-06-16 00:51:06 +08:00
starry
212c8e529d Update README.md 2025-06-15 16:18:54 +08:00
starry
3fd630159b Update config.toml 2025-06-14 14:11:14 +08:00
starry
17d827f50b Update README.md 2025-06-14 14:10:31 +08:00
starry
7dcbc839c6 Update README.md 2025-06-14 14:10:07 +08:00
starry
45ffebc820 Update README.md 2025-06-14 14:08:08 +08:00
starry
3027b1f218 Update README.md 2025-06-13 18:31:13 +08:00
starry
3d2c419ebe Update README.md 2025-06-13 18:30:14 +08:00
starry
b529fbfdd2 Update README.md 2025-06-13 18:29:38 +08:00
user123456
737c1dbf46 io.Copy 2025-06-13 17:58:13 +08:00
user123456
a67ef6c52c 离线镜像下载去掉缓存,避免缓存不完整导致空指针 2025-06-13 17:00:47 +08:00
starry
0adf11099e add 2025-06-13 16:25:27 +08:00
starry
dbb9432eb0 Create LICENSE 2025-06-13 14:11:56 +08:00
25 changed files with 3340 additions and 3224 deletions

1
.gitattributes vendored Normal file
View File

@@ -0,0 +1 @@
* text=auto eol=lf

View File

@@ -72,7 +72,7 @@ jobs:
cp hubproxy.service build/hubproxy/ cp hubproxy.service build/hubproxy/
# 复制安装脚本 # 复制安装脚本
cp install-service.sh build/hubproxy/ cp install.sh build/hubproxy/
# 创建README文件 # 创建README文件
cat > build/hubproxy/README.md << 'EOF' cat > build/hubproxy/README.md << 'EOF'
@@ -88,13 +88,13 @@ jobs:
# Linux AMD64 包 # Linux AMD64 包
mkdir -p linux-amd64/hubproxy mkdir -p linux-amd64/hubproxy
cp hubproxy/hubproxy-linux-amd64 linux-amd64/hubproxy/hubproxy cp hubproxy/hubproxy-linux-amd64 linux-amd64/hubproxy/hubproxy
cp hubproxy/config.toml hubproxy/hubproxy.service hubproxy/install-service.sh hubproxy/README.md linux-amd64/hubproxy/ cp hubproxy/config.toml hubproxy/hubproxy.service hubproxy/install.sh hubproxy/README.md linux-amd64/hubproxy/
tar -czf hubproxy-${{ steps.version.outputs.version }}-linux-amd64.tar.gz -C linux-amd64 hubproxy tar -czf hubproxy-${{ steps.version.outputs.version }}-linux-amd64.tar.gz -C linux-amd64 hubproxy
# Linux ARM64 包 # Linux ARM64 包
mkdir -p linux-arm64/hubproxy mkdir -p linux-arm64/hubproxy
cp hubproxy/hubproxy-linux-arm64 linux-arm64/hubproxy/hubproxy cp hubproxy/hubproxy-linux-arm64 linux-arm64/hubproxy/hubproxy
cp hubproxy/config.toml hubproxy/hubproxy.service hubproxy/install-service.sh hubproxy/README.md linux-arm64/hubproxy/ cp hubproxy/config.toml hubproxy/hubproxy.service hubproxy/install.sh hubproxy/README.md linux-arm64/hubproxy/
tar -czf hubproxy-${{ steps.version.outputs.version }}-linux-arm64.tar.gz -C linux-arm64 hubproxy tar -czf hubproxy-${{ steps.version.outputs.version }}-linux-arm64.tar.gz -C linux-arm64 hubproxy
# 列出生成的文件 # 列出生成的文件

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
.idea
.vscode
.DS_Store
hubproxy*

21
LICENSE Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 sky22333
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

158
README.md
View File

@@ -2,7 +2,11 @@
🚀 **Docker 和 GitHub 加速代理服务器** 🚀 **Docker 和 GitHub 加速代理服务器**
一个轻量级、高性能的多功能代理服务,提供 Docker 镜像加速、GitHub 文件加速等功能。 一个轻量级、高性能的多功能代理服务,提供 Docker 镜像加速、GitHub 文件加速、下载离线镜像、在线搜索 Docker 镜像等功能。
<p align="center">
<img src="https://count.getloli.com/get/@sky22333.hubproxy?theme=rule34" alt="Visitors">
</p>
## ✨ 特性 ## ✨ 特性
@@ -14,7 +18,8 @@
- 🚫 **仓库审计** - 强大的自定义黑名单白名单同时审计镜像仓库和GitHub仓库 - 🚫 **仓库审计** - 强大的自定义黑名单白名单同时审计镜像仓库和GitHub仓库
- 🔍 **镜像搜索** - 在线搜索 Docker 镜像 - 🔍 **镜像搜索** - 在线搜索 Docker 镜像
-**轻量高效** - 基于 Go 语言,单二进制文件运行,资源占用低,优雅的内存清理机制。 -**轻量高效** - 基于 Go 语言,单二进制文件运行,资源占用低,优雅的内存清理机制。
- 🔧 **配置热重载** - 统一配置管理,部分配置项支持热重载,无需重启服务 - 🔧 **统一配置** - 统一配置管理
## 🚀 快速开始 ## 🚀 快速开始
@@ -29,12 +34,14 @@ docker run -d \
### 一键安装 ### 一键脚本安装
```bash ```bash
curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install-service.sh | sudo bash curl -fsSL https://raw.githubusercontent.com/sky22333/hubproxy/main/install.sh | sudo bash
``` ```
也可以直接下载二进制文件执行`./hubproxy`使用无需配置文件即可启动内置默认配置支持所有功能。初始内存占用约18M二进制文件大小约12M
这个命令会: 这个命令会:
- 🔍 自动检测系统架构AMD64/ARM64 - 🔍 自动检测系统架构AMD64/ARM64
- 📥 从 GitHub Releases 下载最新版本 - 📥 从 GitHub Releases 下载最新版本
@@ -55,9 +62,26 @@ docker pull nginx
docker pull yourdomain.com/nginx docker pull yourdomain.com/nginx
# ghcr加速 # ghcr加速
docker pull yourdomain.com/ghcr.io/user/images docker pull yourdomain.com/ghcr.io/sky22333/hubproxy
# 符合Docker Registry API v2标准的仓库都支持
``` ```
当然也支持配置为全局镜像加速,在主机上新建(或编辑)`/etc/docker/daemon.json`
`"registry-mirrors"` 中加入域名:
```json
{
"registry-mirrors": [
"https://yourdomain.com"
]
}
```
若已设置其他加速地址,直接并列添加后保存,
再执行 `sudo systemctl restart docker` 重启docker服务让配置生效。
### GitHub 文件加速 ### GitHub 文件加速
```bash ```bash
@@ -66,16 +90,130 @@ https://github.com/user/repo/releases/download/v1.0.0/file.tar.gz
# 加速链接 # 加速链接
https://yourdomain.com/https://github.com/user/repo/releases/download/v1.0.0/file.tar.gz https://yourdomain.com/https://github.com/user/repo/releases/download/v1.0.0/file.tar.gz
# 加速下载仓库
git clone https://yourdomain.com/https://github.com/sky22333/hubproxy.git
``` ```
## ⚙️ 配置
<details>
<summary>config.toml 配置说明</summary>
## ⚙️ 提示 *此配置是默认配置,已经内置在程序中了,可以不用添加。*
主配置文件位于 `/opt/hubproxy/config.toml` ```
[server]
host = "0.0.0.0"
# 监听端口
port = 5000
# Github文件大小限制字节默认2GB
fileSize = 2147483648
# HTTP/2 多路复用,提升下载速度
enableH2C = false
[rateLimit]
# 每个IP每周期允许的请求数(注意Docker镜像会有多个层会消耗多个次数)
requestLimit = 500
# 限流周期(小时)
periodHours = 3.0
[security]
# IP白名单支持单个IP或IP段
# 白名单中的IP不受限流限制
whiteList = [
"127.0.0.1",
"172.17.0.0/16",
"192.168.1.0/24"
]
# IP黑名单支持单个IP或IP段
# 黑名单中的IP将被直接拒绝访问
blackList = [
"192.168.100.1",
"192.168.100.0/24"
]
[access]
# 代理服务白名单支持GitHub仓库和Docker镜像支持通配符
# 只允许访问白名单中的仓库/镜像,为空时不限制
whiteList = []
# 代理服务黑名单支持GitHub仓库和Docker镜像支持通配符
# 禁止访问黑名单中的仓库/镜像
blackList = [
"baduser/malicious-repo",
"*/malicious-repo",
"baduser/*"
]
# 代理配置,支持有用户名/密码认证和无认证模式
# 无认证: socks5://127.0.0.1:1080
# 有认证: socks5://username:password@127.0.0.1:1080
# 留空不使用代理
proxy = ""
[download]
# 批量下载离线镜像数量限制
maxImages = 10
# Registry映射配置支持多种镜像仓库上游
[registries]
# GitHub Container Registry
[registries."ghcr.io"]
upstream = "ghcr.io"
authHost = "ghcr.io/token"
authType = "github"
enabled = true
# Google Container Registry
[registries."gcr.io"]
upstream = "gcr.io"
authHost = "gcr.io/v2/token"
authType = "google"
enabled = true
# Quay.io Container Registry
[registries."quay.io"]
upstream = "quay.io"
authHost = "quay.io/v2/auth"
authType = "quay"
enabled = true
# Kubernetes Container Registry
[registries."registry.k8s.io"]
upstream = "registry.k8s.io"
authHost = "registry.k8s.io"
authType = "anonymous"
enabled = true
[tokenCache]
# 是否启用缓存(同时控制Token和Manifest缓存)显著提升性能
enabled = true
# 默认缓存时间(分钟)
defaultTTL = "20m"
```
</details>
容器内的配置文件位于 `/root/config.toml`
脚本部署配置文件位于 `/opt/hubproxy/config.toml`
为了IP限流能够正常运行反向代理需要传递IP头用来获取访客真实IP以caddy为例 为了IP限流能够正常运行反向代理需要传递IP头用来获取访客真实IP以caddy为例
``` ```
example.com {
reverse_proxy {
to 127.0.0.1:5000
header_up X-Real-IP {remote}
header_up X-Forwarded-For {remote}
header_up X-Forwarded-Proto {scheme}
}
}
```
cloudflare CDN
```
example.com { example.com {
reverse_proxy 127.0.0.1:5000 { reverse_proxy 127.0.0.1:5000 {
header_up X-Forwarded-For {http.request.header.CF-Connecting-IP} header_up X-Forwarded-For {http.request.header.CF-Connecting-IP}
@@ -87,6 +225,7 @@ example.com {
``` ```
## ⚠️ 免责声明 ## ⚠️ 免责声明
- 本程序仅供学习交流使用,请勿用于非法用途 - 本程序仅供学习交流使用,请勿用于非法用途
@@ -100,3 +239,8 @@ example.com {
**⭐ 如果这个项目对你有帮助,请给个 Star⭐** **⭐ 如果这个项目对你有帮助,请给个 Star⭐**
</div> </div>
## Star 趋势
[![Star 趋势](https://starchart.cc/sky22333/hubproxy.svg?variant=adaptive)](https://starchart.cc/sky22333/hubproxy)

View File

@@ -1,8 +1,8 @@
services: services:
ghproxy: hubproxy:
build: . build: .
restart: always restart: always
ports: ports:
- '5000:5000' - '5000:5000'
volumes: volumes:
- ./src/config.toml:/root/config.toml - ./src/config.toml:/root/config.toml

View File

@@ -1,32 +1,35 @@
[server] [server]
# 监听地址,默认监听所有接口
host = "0.0.0.0" host = "0.0.0.0"
# 监听端口 # 监听端口
port = 5000 port = 5000
# 文件大小限制字节默认2GB # Github文件大小限制字节默认2GB
fileSize = 2147483648 fileSize = 2147483648
# HTTP/2 多路复用
enableH2C = false
[rateLimit] [rateLimit]
# 每个IP每小时允许的请求数(Docker镜像每个层为一个请求) # 每个IP每周期允许的请求数
requestLimit = 200 requestLimit = 500
# 限流周期(小时) # 限流周期(小时)
periodHours = 1.0 periodHours = 3.0
[security] [security]
# IP白名单支持单个IP或CIDR格式 # IP白名单支持单个IP或IP段
# 白名单中的IP不受限流限制 # 白名单中的IP不受限流限制
whiteList = [ whiteList = [
"127.0.0.1", "127.0.0.1",
"172.17.0.0/16",
"192.168.1.0/24" "192.168.1.0/24"
] ]
# IP黑名单支持单个IP或CIDR格式 # IP黑名单支持单个IP或IP段
# 黑名单中的IP将被直接拒绝访问 # 黑名单中的IP将被直接拒绝访问
blackList = [ blackList = [
"192.168.100.1" "192.168.100.1",
"192.168.100.0/24"
] ]
[proxy] [access]
# 代理服务白名单支持GitHub仓库和Docker镜像支持通配符 # 代理服务白名单支持GitHub仓库和Docker镜像支持通配符
# 只允许访问白名单中的仓库/镜像,为空时不限制 # 只允许访问白名单中的仓库/镜像,为空时不限制
whiteList = [] whiteList = []
@@ -39,11 +42,17 @@ blackList = [
"baduser/*" "baduser/*"
] ]
# 代理配置,支持有用户名/密码认证和无认证模式
# 无认证: socks5://127.0.0.1:1080
# 有认证: socks5://username:password@127.0.0.1:1080
# 留空不使用代理
proxy = ""
[download] [download]
# 单次并发下载离线镜像数量限制 # 批量下载离线镜像数量限制
maxImages = 10 maxImages = 10
# Registry映射配置支持多种Container Registry # Registry映射配置支持多种镜像仓库上游
[registries] [registries]
# GitHub Container Registry # GitHub Container Registry
@@ -74,16 +83,8 @@ authHost = "registry.k8s.io"
authType = "anonymous" authType = "anonymous"
enabled = true enabled = true
# 私有Registry示例默认禁用
# [registries."harbor.company.com"]
# upstream = "harbor.company.com"
# authHost = "harbor.company.com/service/token"
# authType = "basic"
# enabled = false
# 缓存配置Docker临时Token和Manifest统一管理显著提升性能
[tokenCache] [tokenCache]
# 是否启用缓存(同时控制Token和Manifest缓存) # 是否启用缓存(同时控制Token和Manifest缓存)显著提升性能
enabled = true enabled = true
# 默认缓存时间 # 默认缓存时间(分钟)
defaultTTL = "20m" defaultTTL = "20m"

View File

@@ -1,4 +1,4 @@
package main package config
import ( import (
"fmt" "fmt"
@@ -9,58 +9,56 @@ import (
"time" "time"
"github.com/pelletier/go-toml/v2" "github.com/pelletier/go-toml/v2"
"github.com/spf13/viper"
"github.com/fsnotify/fsnotify"
) )
// RegistryMapping Registry映射配置 // RegistryMapping Registry映射配置
type RegistryMapping struct { type RegistryMapping struct {
Upstream string `toml:"upstream"` // 上游Registry地址 Upstream string `toml:"upstream"`
AuthHost string `toml:"authHost"` // 认证服务器地址 AuthHost string `toml:"authHost"`
AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic AuthType string `toml:"authType"`
Enabled bool `toml:"enabled"` // 是否启用 Enabled bool `toml:"enabled"`
} }
// AppConfig 应用配置结构体 // AppConfig 应用配置结构体
type AppConfig struct { type AppConfig struct {
Server struct { Server struct {
Host string `toml:"host"` // 监听地址 Host string `toml:"host"`
Port int `toml:"port"` // 监听端口 Port int `toml:"port"`
FileSize int64 `toml:"fileSize"` // 文件大小限制(字节) FileSize int64 `toml:"fileSize"`
EnableH2C bool `toml:"enableH2C"`
} `toml:"server"` } `toml:"server"`
RateLimit struct { RateLimit struct {
RequestLimit int `toml:"requestLimit"` // 每小时请求限制 RequestLimit int `toml:"requestLimit"`
PeriodHours float64 `toml:"periodHours"` // 限制周期(小时) PeriodHours float64 `toml:"periodHours"`
} `toml:"rateLimit"` } `toml:"rateLimit"`
Security struct { Security struct {
WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表 WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表 BlackList []string `toml:"blackList"`
} `toml:"security"` } `toml:"security"`
Proxy struct { Access struct {
WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别) WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别) BlackList []string `toml:"blackList"`
} `toml:"proxy"` Proxy string `toml:"proxy"`
} `toml:"access"`
Download struct { Download struct {
MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制 MaxImages int `toml:"maxImages"`
} `toml:"download"` } `toml:"download"`
Registries map[string]RegistryMapping `toml:"registries"` Registries map[string]RegistryMapping `toml:"registries"`
TokenCache struct { TokenCache struct {
Enabled bool `toml:"enabled"` // 是否启用token缓存 Enabled bool `toml:"enabled"`
DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间 DefaultTTL string `toml:"defaultTTL"`
} `toml:"tokenCache"` } `toml:"tokenCache"`
} }
var ( var (
appConfig *AppConfig appConfig *AppConfig
appConfigLock sync.RWMutex appConfigLock sync.RWMutex
isViperEnabled bool
viperInstance *viper.Viper
cachedConfig *AppConfig cachedConfig *AppConfig
configCacheTime time.Time configCacheTime time.Time
@@ -72,19 +70,21 @@ var (
func DefaultConfig() *AppConfig { func DefaultConfig() *AppConfig {
return &AppConfig{ return &AppConfig{
Server: struct { Server: struct {
Host string `toml:"host"` Host string `toml:"host"`
Port int `toml:"port"` Port int `toml:"port"`
FileSize int64 `toml:"fileSize"` FileSize int64 `toml:"fileSize"`
EnableH2C bool `toml:"enableH2C"`
}{ }{
Host: "0.0.0.0", Host: "0.0.0.0",
Port: 5000, Port: 5000,
FileSize: 2 * 1024 * 1024 * 1024, // 2GB FileSize: 2 * 1024 * 1024 * 1024, // 2GB
EnableH2C: false, // 默认关闭H2C
}, },
RateLimit: struct { RateLimit: struct {
RequestLimit int `toml:"requestLimit"` RequestLimit int `toml:"requestLimit"`
PeriodHours float64 `toml:"periodHours"` PeriodHours float64 `toml:"periodHours"`
}{ }{
RequestLimit: 20, RequestLimit: 200,
PeriodHours: 1.0, PeriodHours: 1.0,
}, },
Security: struct { Security: struct {
@@ -94,17 +94,19 @@ func DefaultConfig() *AppConfig {
WhiteList: []string{}, WhiteList: []string{},
BlackList: []string{}, BlackList: []string{},
}, },
Proxy: struct { Access: struct {
WhiteList []string `toml:"whiteList"` WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"` BlackList []string `toml:"blackList"`
Proxy string `toml:"proxy"`
}{ }{
WhiteList: []string{}, WhiteList: []string{},
BlackList: []string{}, BlackList: []string{},
Proxy: "",
}, },
Download: struct { Download: struct {
MaxImages int `toml:"maxImages"` MaxImages int `toml:"maxImages"`
}{ }{
MaxImages: 10, // 默认值最多同时下载10个镜像 MaxImages: 10,
}, },
Registries: map[string]RegistryMapping{ Registries: map[string]RegistryMapping{
"ghcr.io": { "ghcr.io": {
@@ -136,7 +138,7 @@ func DefaultConfig() *AppConfig {
Enabled bool `toml:"enabled"` Enabled bool `toml:"enabled"`
DefaultTTL string `toml:"defaultTTL"` DefaultTTL string `toml:"defaultTTL"`
}{ }{
Enabled: true, // docker认证的匿名Token缓存配置用于提升性能 Enabled: true,
DefaultTTL: "20m", DefaultTTL: "20m",
}, },
} }
@@ -152,11 +154,9 @@ func GetConfig() *AppConfig {
} }
configCacheMutex.RUnlock() configCacheMutex.RUnlock()
// 缓存过期,重新生成配置
configCacheMutex.Lock() configCacheMutex.Lock()
defer configCacheMutex.Unlock() defer configCacheMutex.Unlock()
// 双重检查,防止重复生成
if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL { if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL {
return cachedConfig return cachedConfig
} }
@@ -170,12 +170,11 @@ func GetConfig() *AppConfig {
return defaultCfg return defaultCfg
} }
// 生成新的配置深拷贝
configCopy := *appConfig configCopy := *appConfig
configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...) configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...)
configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...) configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...)
configCopy.Proxy.WhiteList = append([]string(nil), appConfig.Proxy.WhiteList...) configCopy.Access.WhiteList = append([]string(nil), appConfig.Access.WhiteList...)
configCopy.Proxy.BlackList = append([]string(nil), appConfig.Proxy.BlackList...) configCopy.Access.BlackList = append([]string(nil), appConfig.Access.BlackList...)
appConfigLock.RUnlock() appConfigLock.RUnlock()
cachedConfig = &configCopy cachedConfig = &configCopy
@@ -197,10 +196,8 @@ func setConfig(cfg *AppConfig) {
// LoadConfig 加载配置文件 // LoadConfig 加载配置文件
func LoadConfig() error { func LoadConfig() error {
// 首先使用默认配置
cfg := DefaultConfig() cfg := DefaultConfig()
// 尝试加载TOML配置文件
if data, err := os.ReadFile("config.toml"); err == nil { if data, err := os.ReadFile("config.toml"); err == nil {
if err := toml.Unmarshal(data, cfg); err != nil { if err := toml.Unmarshal(data, cfg); err != nil {
return fmt.Errorf("解析配置文件失败: %v", err) return fmt.Errorf("解析配置文件失败: %v", err)
@@ -209,109 +206,14 @@ func LoadConfig() error {
fmt.Println("未找到config.toml使用默认配置") fmt.Println("未找到config.toml使用默认配置")
} }
// 从环境变量覆盖配置
overrideFromEnv(cfg) overrideFromEnv(cfg)
// 设置配置
setConfig(cfg) setConfig(cfg)
if !isViperEnabled {
go enableViperHotReload()
}
return nil return nil
} }
func enableViperHotReload() {
if isViperEnabled {
return
}
// 创建Viper实例
viperInstance = viper.New()
// 配置Viper
viperInstance.SetConfigName("config")
viperInstance.SetConfigType("toml")
viperInstance.AddConfigPath(".")
// 读取配置文件
if err := viperInstance.ReadInConfig(); err != nil {
fmt.Printf("读取配置失败,继续使用当前配置: %v\n", err)
return
}
isViperEnabled = true
viperInstance.WatchConfig()
viperInstance.OnConfigChange(func(e fsnotify.Event) {
fmt.Printf("检测到配置文件变化: %s\n", e.Name)
hotReloadWithViper()
})
}
func hotReloadWithViper() {
start := time.Now()
fmt.Println("🔄 自动热重载...")
// 创建新配置
cfg := DefaultConfig()
// 使用Viper解析配置到结构体
if err := viperInstance.Unmarshal(cfg); err != nil {
fmt.Printf("❌ 配置解析失败: %v\n", err)
return
}
overrideFromEnv(cfg)
setConfig(cfg)
// 异步更新受影响的组件
go func() {
updateAffectedComponents()
fmt.Printf("✅ Viper配置热重载完成耗时: %v\n", time.Since(start))
}()
}
func updateAffectedComponents() {
// 重新初始化限流器
if globalLimiter != nil {
fmt.Println("📡 重新初始化限流器...")
initLimiter()
}
// 重新加载访问控制
fmt.Println("🔒 重新加载访问控制规则...")
if GlobalAccessController != nil {
GlobalAccessController.Reload()
}
fmt.Println("🌐 更新Registry配置映射...")
reloadRegistryConfig()
// 其他需要重新初始化的组件可以在这里添加
fmt.Println("🔧 组件更新完成")
}
func reloadRegistryConfig() {
cfg := GetConfig()
enabledCount := 0
// 统计启用的Registry数量
for _, mapping := range cfg.Registries {
if mapping.Enabled {
enabledCount++
}
}
fmt.Printf("🌐 Registry配置已更新: %d个启用\n", enabledCount)
}
// overrideFromEnv 从环境变量覆盖配置 // overrideFromEnv 从环境变量覆盖配置
func overrideFromEnv(cfg *AppConfig) { func overrideFromEnv(cfg *AppConfig) {
// 服务器配置
if val := os.Getenv("SERVER_HOST"); val != "" { if val := os.Getenv("SERVER_HOST"); val != "" {
cfg.Server.Host = val cfg.Server.Host = val
} }
@@ -320,13 +222,17 @@ func overrideFromEnv(cfg *AppConfig) {
cfg.Server.Port = port cfg.Server.Port = port
} }
} }
if val := os.Getenv("ENABLE_H2C"); val != "" {
if enable, err := strconv.ParseBool(val); err == nil {
cfg.Server.EnableH2C = enable
}
}
if val := os.Getenv("MAX_FILE_SIZE"); val != "" { if val := os.Getenv("MAX_FILE_SIZE"); val != "" {
if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 { if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 {
cfg.Server.FileSize = size cfg.Server.FileSize = size
} }
} }
// 限流配置
if val := os.Getenv("RATE_LIMIT"); val != "" { if val := os.Getenv("RATE_LIMIT"); val != "" {
if limit, err := strconv.Atoi(val); err == nil && limit > 0 { if limit, err := strconv.Atoi(val); err == nil && limit > 0 {
cfg.RateLimit.RequestLimit = limit cfg.RateLimit.RequestLimit = limit
@@ -338,7 +244,6 @@ func overrideFromEnv(cfg *AppConfig) {
} }
} }
// IP限制配置
if val := os.Getenv("IP_WHITELIST"); val != "" { if val := os.Getenv("IP_WHITELIST"); val != "" {
cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...) cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...)
} }
@@ -346,7 +251,6 @@ func overrideFromEnv(cfg *AppConfig) {
cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...) cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...)
} }
// 下载限制配置
if val := os.Getenv("MAX_IMAGES"); val != "" { if val := os.Getenv("MAX_IMAGES"); val != "" {
if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 { if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 {
cfg.Download.MaxImages = maxImages cfg.Download.MaxImages = maxImages

View File

@@ -3,11 +3,10 @@ module hubproxy
go 1.24.0 go 1.24.0
require ( require (
github.com/fsnotify/fsnotify v1.8.0
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/go-containerregistry v0.20.5 github.com/google/go-containerregistry v0.20.5
github.com/pelletier/go-toml/v2 v2.2.3 github.com/pelletier/go-toml/v2 v2.2.3
github.com/spf13/viper v1.20.1 golang.org/x/net v0.33.0
golang.org/x/time v0.11.0 golang.org/x/time v0.11.0
) )
@@ -25,11 +24,11 @@ require (
github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect
@@ -38,24 +37,17 @@ require (
github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/sagikazarmark/locafero v0.7.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect github.com/sirupsen/logrus v1.9.3 // indirect
github.com/sourcegraph/conc v0.3.0 // indirect
github.com/spf13/afero v1.12.0 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/spf13/pflag v1.0.6 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect github.com/ugorji/go/codec v1.2.12 // indirect
github.com/vbatts/tar-split v0.12.1 // indirect github.com/vbatts/tar-split v0.12.1 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.32.0 // indirect golang.org/x/crypto v0.32.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.14.0 // indirect golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.36.3 // indirect google.golang.org/protobuf v1.36.3 // indirect
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
) )

View File

@@ -8,6 +8,7 @@ github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8=
github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@@ -17,10 +18,6 @@ github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBi
github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
github.com/docker/docker-credential-helpers v0.9.3 h1:gAm/VtF9wgqJMoxzT3Gj5p4AqIjCBS4wrsOh9yRqcz8= github.com/docker/docker-credential-helpers v0.9.3 h1:gAm/VtF9wgqJMoxzT3Gj5p4AqIjCBS4wrsOh9yRqcz8=
github.com/docker/docker-credential-helpers v0.9.3/go.mod h1:x+4Gbw9aGmChi3qTLZj8Dfn0TD20M/fuWy0E5+WDeCo= github.com/docker/docker-credential-helpers v0.9.3/go.mod h1:x+4Gbw9aGmChi3qTLZj8Dfn0TD20M/fuWy0E5+WDeCo=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/8M=
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -35,8 +32,6 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIxtHqx8aGss=
github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -52,8 +47,11 @@ github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa02
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
@@ -77,22 +75,11 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/sagikazarmark/locafero v0.7.0 h1:5MqpDsTGNDhY8sGp0Aowyf0qKsPrhewaLSsFaodPcyo=
github.com/sagikazarmark/locafero v0.7.0/go.mod h1:2za3Cg5rMaTMoG/2Ulr9AwtFaIppKXTRYnozin4aB5k=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo=
github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0=
github.com/spf13/afero v1.12.0 h1:UcOPyRBYczmFn6yvphxkn9ZEOY65cpwGKb5mL36mrqs=
github.com/spf13/afero v1.12.0/go.mod h1:ZTlWwG4/ahT8W7T0WQ5uYmjI9duaLQGy3Q2OAl4sk/4=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o=
github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.20.1 h1:ZMi+z/lvLyPSCoNtFCpqjy0S4kPbirhpTMwl8BkW9X4=
github.com/spf13/viper v1.20.1/go.mod h1:P9Mdzt1zoHIG8m2eZQinpiBjo6kCmZSKBClNNqjJvu4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -101,20 +88,14 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8=
github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo=
github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI=
go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
@@ -136,8 +117,10 @@ golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU= google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU=
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -1,4 +1,4 @@
package main package handlers
import ( import (
"context" "context"
@@ -12,6 +12,8 @@ import (
"github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name" "github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/remote"
"hubproxy/config"
"hubproxy/utils"
) )
// DockerProxy Docker代理配置 // DockerProxy Docker代理配置
@@ -27,12 +29,10 @@ type RegistryDetector struct{}
// detectRegistryDomain 检测Registry域名并返回域名和剩余路径 // detectRegistryDomain 检测Registry域名并返回域名和剩余路径
func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) { func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
cfg := GetConfig() cfg := config.GetConfig()
// 检查路径是否以已知Registry域名开头
for domain := range cfg.Registries { for domain := range cfg.Registries {
if strings.HasPrefix(path, domain+"/") { if strings.HasPrefix(path, domain+"/") {
// 找到匹配的域名,返回域名和剩余路径
remainingPath := strings.TrimPrefix(path, domain+"/") remainingPath := strings.TrimPrefix(path, domain+"/")
return domain, remainingPath return domain, remainingPath
} }
@@ -43,7 +43,7 @@ func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
// isRegistryEnabled 检查Registry是否启用 // isRegistryEnabled 检查Registry是否启用
func (rd *RegistryDetector) isRegistryEnabled(domain string) bool { func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
cfg := GetConfig() cfg := config.GetConfig()
if mapping, exists := cfg.Registries[domain]; exists { if mapping, exists := cfg.Registries[domain]; exists {
return mapping.Enabled return mapping.Enabled
} }
@@ -51,27 +51,26 @@ func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
} }
// getRegistryMapping 获取Registry映射配置 // getRegistryMapping 获取Registry映射配置
func (rd *RegistryDetector) getRegistryMapping(domain string) (RegistryMapping, bool) { func (rd *RegistryDetector) getRegistryMapping(domain string) (config.RegistryMapping, bool) {
cfg := GetConfig() cfg := config.GetConfig()
mapping, exists := cfg.Registries[domain] mapping, exists := cfg.Registries[domain]
return mapping, exists && mapping.Enabled return mapping, exists && mapping.Enabled
} }
var registryDetector = &RegistryDetector{} var registryDetector = &RegistryDetector{}
// 初始化Docker代理 // InitDockerProxy 初始化Docker代理
func initDockerProxy() { func InitDockerProxy() {
// 创建目标registry
registry, err := name.NewRegistry("registry-1.docker.io") registry, err := name.NewRegistry("registry-1.docker.io")
if err != nil { if err != nil {
fmt.Printf("创建Docker registry失败: %v\n", err) fmt.Printf("创建Docker registry失败: %v\n", err)
return return
} }
// 配置代理选项
options := []remote.Option{ options := []remote.Option{
remote.WithAuth(authn.Anonymous), remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"), remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
} }
dockerProxy = &DockerProxy{ dockerProxy = &DockerProxy{
@@ -84,13 +83,11 @@ func initDockerProxy() {
func ProxyDockerRegistryGin(c *gin.Context) { func ProxyDockerRegistryGin(c *gin.Context) {
path := c.Request.URL.Path path := c.Request.URL.Path
// 处理 /v2/ API版本检查
if path == "/v2/" { if path == "/v2/" {
c.JSON(http.StatusOK, gin.H{}) c.JSON(http.StatusOK, gin.H{})
return return
} }
// 处理不同的API端点
if strings.HasPrefix(path, "/v2/") { if strings.HasPrefix(path, "/v2/") {
handleRegistryRequest(c, path) handleRegistryRequest(c, path)
} else { } else {
@@ -100,16 +97,13 @@ func ProxyDockerRegistryGin(c *gin.Context) {
// handleRegistryRequest 处理Registry请求 // handleRegistryRequest 处理Registry请求
func handleRegistryRequest(c *gin.Context, path string) { func handleRegistryRequest(c *gin.Context, path string) {
// 移除 /v2/ 前缀
pathWithoutV2 := strings.TrimPrefix(path, "/v2/") pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" { if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" {
if registryDetector.isRegistryEnabled(registryDomain) { if registryDetector.isRegistryEnabled(registryDomain) {
// 设置目标Registry信息到Context
c.Set("target_registry_domain", registryDomain) c.Set("target_registry_domain", registryDomain)
c.Set("target_path", remainingPath) c.Set("target_path", remainingPath)
// 处理多Registry请求
handleMultiRegistryRequest(c, registryDomain, remainingPath) handleMultiRegistryRequest(c, registryDomain, remainingPath)
return return
} }
@@ -121,19 +115,16 @@ func handleRegistryRequest(c *gin.Context, path string) {
return return
} }
// 自动处理官方镜像的library命名空间
if !strings.Contains(imageName, "/") { if !strings.Contains(imageName, "/") {
imageName = "library/" + imageName imageName = "library/" + imageName
} }
// Docker镜像访问控制检查 if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(imageName); !allowed {
if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed {
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason) fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制") c.String(http.StatusForbidden, "镜像访问被限制")
return return
} }
// 构建完整的镜像引用
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName) imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
switch apiType { switch apiType {
@@ -150,7 +141,6 @@ func handleRegistryRequest(c *gin.Context, path string) {
// parseRegistryPath 解析Registry路径 // parseRegistryPath 解析Registry路径
func parseRegistryPath(path string) (imageName, apiType, reference string) { func parseRegistryPath(path string) (imageName, apiType, reference string) {
// 查找API端点关键字
if idx := strings.Index(path, "/manifests/"); idx != -1 { if idx := strings.Index(path, "/manifests/"); idx != -1 {
imageName = path[:idx] imageName = path[:idx]
apiType = "manifests" apiType = "manifests"
@@ -177,13 +167,11 @@ func parseRegistryPath(path string) (imageName, apiType, reference string) {
// handleManifestRequest 处理manifest请求 // handleManifestRequest 处理manifest请求
func handleManifestRequest(c *gin.Context, imageRef, reference string) { func handleManifestRequest(c *gin.Context, imageRef, reference string) {
// Manifest缓存逻辑(仅对GET请求缓存) if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
if isCacheEnabled() && c.Request.Method == http.MethodGet { cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
cacheKey := buildManifestCacheKey(imageRef, reference)
// 优先从缓存获取 if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil { utils.WriteCachedResponse(c, cachedItem)
writeCachedResponse(c, cachedItem)
return return
} }
} }
@@ -191,12 +179,9 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
var ref name.Reference var ref name.Reference
var err error var err error
// 判断reference是digest还是tag
if strings.HasPrefix(reference, "sha256:") { if strings.HasPrefix(reference, "sha256:") {
// 是digest
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else { } else {
// 是tag
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference)) ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
} }
@@ -206,9 +191,7 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return return
} }
// 根据请求方法选择操作
if c.Request.Method == http.MethodHead { if c.Request.Method == http.MethodHead {
// HEAD请求使用remote.Head
desc, err := remote.Head(ref, dockerProxy.options...) desc, err := remote.Head(ref, dockerProxy.options...)
if err != nil { if err != nil {
fmt.Printf("HEAD请求失败: %v\n", err) fmt.Printf("HEAD请求失败: %v\n", err)
@@ -216,13 +199,11 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return return
} }
// 设置响应头
c.Header("Content-Type", string(desc.MediaType)) c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String()) c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size)) c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
c.Status(http.StatusOK) c.Status(http.StatusOK)
} else { } else {
// GET请求使用remote.Get
desc, err := remote.Get(ref, dockerProxy.options...) desc, err := remote.Get(ref, dockerProxy.options...)
if err != nil { if err != nil {
fmt.Printf("GET请求失败: %v\n", err) fmt.Printf("GET请求失败: %v\n", err)
@@ -230,33 +211,28 @@ func handleManifestRequest(c *gin.Context, imageRef, reference string) {
return return
} }
// 设置响应头
headers := map[string]string{ headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(), "Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)), "Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
} }
// 缓存响应 if utils.IsCacheEnabled() {
if isCacheEnabled() { cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
cacheKey := buildManifestCacheKey(imageRef, reference) ttl := utils.GetManifestTTL(reference)
ttl := getManifestTTL(reference) utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
} }
// 设置响应头
c.Header("Content-Type", string(desc.MediaType)) c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers { for key, value := range headers {
c.Header(key, value) c.Header(key, value)
} }
// 返回manifest内容
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest) c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
} }
} }
// handleBlobRequest 处理blob请求 // handleBlobRequest 处理blob请求
func handleBlobRequest(c *gin.Context, imageRef, digest string) { func handleBlobRequest(c *gin.Context, imageRef, digest string) {
// 构建digest引用
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil { if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err) fmt.Printf("解析digest引用失败: %v\n", err)
@@ -264,7 +240,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return return
} }
// 使用remote.Layer获取layer
layer, err := remote.Layer(digestRef, dockerProxy.options...) layer, err := remote.Layer(digestRef, dockerProxy.options...)
if err != nil { if err != nil {
fmt.Printf("获取layer失败: %v\n", err) fmt.Printf("获取layer失败: %v\n", err)
@@ -272,7 +247,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return return
} }
// 获取layer信息
size, err := layer.Size() size, err := layer.Size()
if err != nil { if err != nil {
fmt.Printf("获取layer大小失败: %v\n", err) fmt.Printf("获取layer大小失败: %v\n", err)
@@ -280,7 +254,6 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
return return
} }
// 获取layer内容
reader, err := layer.Compressed() reader, err := layer.Compressed()
if err != nil { if err != nil {
fmt.Printf("获取layer内容失败: %v\n", err) fmt.Printf("获取layer内容失败: %v\n", err)
@@ -289,19 +262,16 @@ func handleBlobRequest(c *gin.Context, imageRef, digest string) {
} }
defer reader.Close() defer reader.Close()
// 设置响应头
c.Header("Content-Type", "application/octet-stream") c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Length", fmt.Sprintf("%d", size)) c.Header("Content-Length", fmt.Sprintf("%d", size))
c.Header("Docker-Content-Digest", digest) c.Header("Docker-Content-Digest", digest)
// 流式传输blob内容
c.Status(http.StatusOK) c.Status(http.StatusOK)
io.Copy(c.Writer, reader) io.Copy(c.Writer, reader)
} }
// handleTagsRequest 处理tags列表请求 // handleTagsRequest 处理tags列表请求
func handleTagsRequest(c *gin.Context, imageRef string) { func handleTagsRequest(c *gin.Context, imageRef string) {
// 解析repository
repo, err := name.NewRepository(imageRef) repo, err := name.NewRepository(imageRef)
if err != nil { if err != nil {
fmt.Printf("解析repository失败: %v\n", err) fmt.Printf("解析repository失败: %v\n", err)
@@ -309,7 +279,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
return return
} }
// 使用remote.List获取tags
tags, err := remote.List(repo, dockerProxy.options...) tags, err := remote.List(repo, dockerProxy.options...)
if err != nil { if err != nil {
fmt.Printf("获取tags失败: %v\n", err) fmt.Printf("获取tags失败: %v\n", err)
@@ -317,7 +286,6 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
return return
} }
// 构建响应
response := map[string]interface{}{ response := map[string]interface{}{
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"), "name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
"tags": tags, "tags": tags,
@@ -326,10 +294,9 @@ func handleTagsRequest(c *gin.Context, imageRef string) {
c.JSON(http.StatusOK, response) c.JSON(http.StatusOK, response)
} }
// ProxyDockerAuthGin Docker认证代理(带缓存优化) // ProxyDockerAuthGin Docker认证代理
func ProxyDockerAuthGin(c *gin.Context) { func ProxyDockerAuthGin(c *gin.Context) {
// 检查是否启用token缓存 if utils.IsTokenCacheEnabled() {
if isTokenCacheEnabled() {
proxyDockerAuthWithCache(c) proxyDockerAuthWithCache(c)
} else { } else {
proxyDockerAuthOriginal(c) proxyDockerAuthOriginal(c)
@@ -338,36 +305,28 @@ func ProxyDockerAuthGin(c *gin.Context) {
// proxyDockerAuthWithCache 带缓存的认证代理 // proxyDockerAuthWithCache 带缓存的认证代理
func proxyDockerAuthWithCache(c *gin.Context) { func proxyDockerAuthWithCache(c *gin.Context) {
// 1. 构建缓存key基于完整的查询参数 cacheKey := utils.BuildTokenCacheKey(c.Request.URL.RawQuery)
cacheKey := buildTokenCacheKey(c.Request.URL.RawQuery)
// 2. 尝试从缓存获取token if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" {
if cachedToken := globalCache.GetToken(cacheKey); cachedToken != "" { utils.WriteTokenResponse(c, cachedToken)
writeTokenResponse(c, cachedToken)
return return
} }
// 3. 缓存未命中,创建响应记录器
recorder := &ResponseRecorder{ recorder := &ResponseRecorder{
ResponseWriter: c.Writer, ResponseWriter: c.Writer,
statusCode: 200, statusCode: 200,
} }
c.Writer = recorder c.Writer = recorder
// 4. 调用原有认证逻辑
proxyDockerAuthOriginal(c) proxyDockerAuthOriginal(c)
// 5. 如果认证成功,缓存响应
if recorder.statusCode == 200 && len(recorder.body) > 0 { if recorder.statusCode == 200 && len(recorder.body) > 0 {
ttl := extractTTLFromResponse(recorder.body) ttl := utils.ExtractTTLFromResponse(recorder.body)
globalCache.SetToken(cacheKey, string(recorder.body), ttl) utils.GlobalCache.SetToken(cacheKey, string(recorder.body), ttl)
} }
// 6. 写入实际响应(如果还没写入) c.Writer = recorder.ResponseWriter
if !recorder.written { c.Data(recorder.statusCode, "application/json", recorder.body)
c.Writer = recorder.ResponseWriter
c.Data(recorder.statusCode, "application/json", recorder.body)
}
} }
// ResponseRecorder HTTP响应记录器 // ResponseRecorder HTTP响应记录器
@@ -375,7 +334,6 @@ type ResponseRecorder struct {
gin.ResponseWriter gin.ResponseWriter
statusCode int statusCode int
body []byte body []byte
written bool
} }
func (r *ResponseRecorder) WriteHeader(code int) { func (r *ResponseRecorder) WriteHeader(code int) {
@@ -384,22 +342,18 @@ func (r *ResponseRecorder) WriteHeader(code int) {
func (r *ResponseRecorder) Write(data []byte) (int, error) { func (r *ResponseRecorder) Write(data []byte) (int, error) {
r.body = append(r.body, data...) r.body = append(r.body, data...)
r.written = true return len(data), nil
return r.ResponseWriter.Write(data)
} }
func proxyDockerAuthOriginal(c *gin.Context) { func proxyDockerAuthOriginal(c *gin.Context) {
var authURL string var authURL string
if targetDomain, exists := c.Get("target_registry_domain"); exists { if targetDomain, exists := c.Get("target_registry_domain"); exists {
if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found { if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found {
// 使用Registry特定的认证服务器
authURL = "https://" + mapping.AuthHost + c.Request.URL.Path authURL = "https://" + mapping.AuthHost + c.Request.URL.Path
} else { } else {
// fallback到默认Docker认证
authURL = "https://auth.docker.io" + c.Request.URL.Path authURL = "https://auth.docker.io" + c.Request.URL.Path
} }
} else { } else {
// 构建默认Docker认证URL
authURL = "https://auth.docker.io" + c.Request.URL.Path authURL = "https://auth.docker.io" + c.Request.URL.Path
} }
@@ -407,12 +361,11 @@ func proxyDockerAuthOriginal(c *gin.Context) {
authURL += "?" + c.Request.URL.RawQuery authURL += "?" + c.Request.URL.RawQuery
} }
// 创建HTTP客户端
client := &http.Client{ client := &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Transport: utils.GetGlobalHTTPClient().Transport,
} }
// 创建请求
req, err := http.NewRequestWithContext( req, err := http.NewRequestWithContext(
context.Background(), context.Background(),
c.Request.Method, c.Request.Method,
@@ -424,14 +377,12 @@ func proxyDockerAuthOriginal(c *gin.Context) {
return return
} }
// 复制请求头
for key, values := range c.Request.Header { for key, values := range c.Request.Header {
for _, value := range values { for _, value := range values {
req.Header.Add(key, value) req.Header.Add(key, value)
} }
} }
// 执行请求
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
c.String(http.StatusBadGateway, "Auth request failed") c.String(http.StatusBadGateway, "Auth request failed")
@@ -439,37 +390,30 @@ func proxyDockerAuthOriginal(c *gin.Context) {
} }
defer resp.Body.Close() defer resp.Body.Close()
// 获取当前代理的Host地址
proxyHost := c.Request.Host proxyHost := c.Request.Host
if proxyHost == "" { if proxyHost == "" {
// 使用配置中的服务器地址和端口 cfg := config.GetConfig()
cfg := GetConfig()
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port) proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
if cfg.Server.Host == "0.0.0.0" { if cfg.Server.Host == "0.0.0.0" {
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port) proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
} }
} }
// 复制响应头并重写认证URL
for key, values := range resp.Header { for key, values := range resp.Header {
for _, value := range values { for _, value := range values {
// 重写WWW-Authenticate头中的realm URL
if key == "Www-Authenticate" { if key == "Www-Authenticate" {
// 支持多Registry的URL重写
value = rewriteAuthHeader(value, proxyHost) value = rewriteAuthHeader(value, proxyHost)
} }
c.Header(key, value) c.Header(key, value)
} }
} }
// 返回响应
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
io.Copy(c.Writer, resp.Body) io.Copy(c.Writer, resp.Body)
} }
// rewriteAuthHeader 重写认证头 // rewriteAuthHeader 重写认证头
func rewriteAuthHeader(authHeader, proxyHost string) string { func rewriteAuthHeader(authHeader, proxyHost string) string {
// 重写各种Registry的认证URL
authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost) authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost) authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost) authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost)
@@ -480,32 +424,27 @@ func rewriteAuthHeader(authHeader, proxyHost string) string {
// handleMultiRegistryRequest 处理多Registry请求 // handleMultiRegistryRequest 处理多Registry请求
func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) { func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) {
// 获取Registry映射配置
mapping, exists := registryDetector.getRegistryMapping(registryDomain) mapping, exists := registryDetector.getRegistryMapping(registryDomain)
if !exists { if !exists {
c.String(http.StatusBadRequest, "Registry not configured") c.String(http.StatusBadRequest, "Registry not configured")
return return
} }
// 解析剩余路径
imageName, apiType, reference := parseRegistryPath(remainingPath) imageName, apiType, reference := parseRegistryPath(remainingPath)
if imageName == "" || apiType == "" { if imageName == "" || apiType == "" {
c.String(http.StatusBadRequest, "Invalid path format") c.String(http.StatusBadRequest, "Invalid path format")
return return
} }
// 访问控制检查(使用完整的镜像路径)
fullImageName := registryDomain + "/" + imageName fullImageName := registryDomain + "/" + imageName
if allowed, reason := GlobalAccessController.CheckDockerAccess(fullImageName); !allowed { if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason) fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制") c.String(http.StatusForbidden, "镜像访问被限制")
return return
} }
// 构建上游Registry引用
upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName) upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName)
// 根据API类型处理请求
switch apiType { switch apiType {
case "manifests": case "manifests":
handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping) handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping)
@@ -519,14 +458,12 @@ func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath st
} }
// handleUpstreamManifestRequest 处理上游Registry的manifest请求 // handleUpstreamManifestRequest 处理上游Registry的manifest请求
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping RegistryMapping) { func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping config.RegistryMapping) {
// Manifest缓存逻辑(仅对GET请求缓存) if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
if isCacheEnabled() && c.Request.Method == http.MethodGet { cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
cacheKey := buildManifestCacheKey(imageRef, reference)
// 优先从缓存获取 if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil { utils.WriteCachedResponse(c, cachedItem)
writeCachedResponse(c, cachedItem)
return return
} }
} }
@@ -534,7 +471,6 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
var ref name.Reference var ref name.Reference
var err error var err error
// 判断reference是digest还是tag
if strings.HasPrefix(reference, "sha256:") { if strings.HasPrefix(reference, "sha256:") {
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference)) ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else { } else {
@@ -547,10 +483,8 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
return return
} }
// 创建针对上游Registry的选项
options := createUpstreamOptions(mapping) options := createUpstreamOptions(mapping)
// 根据请求方法选择操作
if c.Request.Method == http.MethodHead { if c.Request.Method == http.MethodHead {
desc, err := remote.Head(ref, options...) desc, err := remote.Head(ref, options...)
if err != nil { if err != nil {
@@ -571,20 +505,17 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
return return
} }
// 设置响应头
headers := map[string]string{ headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(), "Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)), "Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
} }
// 缓存响应 if utils.IsCacheEnabled() {
if isCacheEnabled() { cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
cacheKey := buildManifestCacheKey(imageRef, reference) ttl := utils.GetManifestTTL(reference)
ttl := getManifestTTL(reference) utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
} }
// 设置响应头
c.Header("Content-Type", string(desc.MediaType)) c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers { for key, value := range headers {
c.Header(key, value) c.Header(key, value)
@@ -595,7 +526,7 @@ func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, m
} }
// handleUpstreamBlobRequest 处理上游Registry的blob请求 // handleUpstreamBlobRequest 处理上游Registry的blob请求
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping RegistryMapping) { func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping config.RegistryMapping) {
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest)) digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil { if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err) fmt.Printf("解析digest引用失败: %v\n", err)
@@ -635,7 +566,7 @@ func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping
} }
// handleUpstreamTagsRequest 处理上游Registry的tags请求 // handleUpstreamTagsRequest 处理上游Registry的tags请求
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping RegistryMapping) { func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping config.RegistryMapping) {
repo, err := name.NewRepository(imageRef) repo, err := name.NewRepository(imageRef)
if err != nil { if err != nil {
fmt.Printf("解析repository失败: %v\n", err) fmt.Printf("解析repository失败: %v\n", err)
@@ -660,13 +591,13 @@ func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping Registry
} }
// createUpstreamOptions 创建上游Registry选项 // createUpstreamOptions 创建上游Registry选项
func createUpstreamOptions(mapping RegistryMapping) []remote.Option { func createUpstreamOptions(mapping config.RegistryMapping) []remote.Option {
options := []remote.Option{ options := []remote.Option{
remote.WithAuth(authn.Anonymous), remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"), remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
} }
// 根据Registry类型添加特定的认证选项方便后续扩展
switch mapping.AuthType { switch mapping.AuthType {
case "github": case "github":
case "google": case "google":

213
src/handlers/github.go Normal file
View File

@@ -0,0 +1,213 @@
package handlers
import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"hubproxy/config"
"hubproxy/utils"
)
var (
// GitHub URL匹配正则表达式
githubExps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*`),
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?gist\.(?:githubusercontent|github)\.com/(.+?)/(.+?)/.+\.[a-zA-Z0-9]+$`),
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)`),
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)`),
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?`),
}
)
// GitHubProxyHandler GitHub代理处理器
func GitHubProxyHandler(c *gin.Context) {
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
for strings.HasPrefix(rawPath, "/") {
rawPath = strings.TrimPrefix(rawPath, "/")
}
// 自动补全协议头
if !strings.HasPrefix(rawPath, "https://") {
if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") {
rawPath = strings.Replace(rawPath, "http:/", "", 1)
rawPath = strings.Replace(rawPath, "https:/", "", 1)
} else if strings.HasPrefix(rawPath, "http://") {
rawPath = strings.TrimPrefix(rawPath, "http://")
}
rawPath = "https://" + rawPath
}
matches := CheckGitHubURL(rawPath)
if matches != nil {
if allowed, reason := utils.GlobalAccessController.CheckGitHubAccess(matches); !allowed {
var repoPath string
if len(matches) >= 2 {
username := matches[0]
repoName := strings.TrimSuffix(matches[1], ".git")
repoPath = username + "/" + repoName
}
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
c.String(http.StatusForbidden, reason)
return
}
} else {
c.String(http.StatusForbidden, "无效输入")
return
}
// 将blob链接转换为raw链接
if githubExps[1].MatchString(rawPath) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
}
ProxyGitHubRequest(c, rawPath)
}
// CheckGitHubURL 检查URL是否匹配GitHub模式
func CheckGitHubURL(u string) []string {
for _, exp := range githubExps {
if matches := exp.FindStringSubmatch(u); matches != nil {
return matches[1:]
}
}
return nil
}
// ProxyGitHubRequest 代理GitHub请求
func ProxyGitHubRequest(c *gin.Context, u string) {
proxyGitHubWithRedirect(c, u, 0)
}
// proxyGitHubWithRedirect 带重定向的GitHub代理请求
func proxyGitHubWithRedirect(c *gin.Context, u string, redirectCount int) {
const maxRedirects = 20
if redirectCount > maxRedirects {
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
return
}
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
// 复制请求头
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
req.Header.Del("Host")
resp, err := utils.GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭响应体失败: %v\n", err)
}
}()
// 检查文件大小限制
cfg := config.GetConfig()
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
c.String(http.StatusRequestEntityTooLarge,
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
return
}
}
// 清理安全相关的头
resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy")
resp.Header.Del("Strict-Transport-Security")
// 获取真实域名
realHost := c.Request.Header.Get("X-Forwarded-Host")
if realHost == "" {
realHost = c.Request.Host
}
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
realHost = "https://" + realHost
}
// 处理.sh文件的智能处理
if strings.HasSuffix(strings.ToLower(u), ".sh") {
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
processedBody, processedSize, err := utils.ProcessSmart(resp.Body, isGzipCompressed, realHost)
if err != nil {
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
processedBody = resp.Body
processedSize = 0
}
// 智能设置响应头
if processedSize > 0 {
resp.Header.Del("Content-Length")
resp.Header.Del("Content-Encoding")
resp.Header.Set("Transfer-Encoding", "chunked")
}
// 复制其他响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if CheckGitHubURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyGitHubWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 输出处理后的内容
if _, err := io.Copy(c.Writer, processedBody); err != nil {
return
}
} else {
// 复制响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if CheckGitHubURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyGitHubWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 直接流式转发
io.Copy(c.Writer, resp.Body)
}
}

View File

@@ -1,4 +1,4 @@
package main package handlers
import ( import (
"archive/tar" "archive/tar"
@@ -23,6 +23,8 @@ import (
"github.com/google/go-containerregistry/pkg/v1/partial" "github.com/google/go-containerregistry/pkg/v1/partial"
"github.com/google/go-containerregistry/pkg/v1/remote" "github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/types" "github.com/google/go-containerregistry/pkg/v1/types"
"hubproxy/config"
"hubproxy/utils"
) )
// DebounceEntry 防抖条目 // DebounceEntry 防抖条目
@@ -33,16 +35,18 @@ type DebounceEntry struct {
// DownloadDebouncer 下载防抖器 // DownloadDebouncer 下载防抖器
type DownloadDebouncer struct { type DownloadDebouncer struct {
mu sync.RWMutex mu sync.RWMutex
entries map[string]*DebounceEntry entries map[string]*DebounceEntry
window time.Duration window time.Duration
lastCleanup time.Time
} }
// NewDownloadDebouncer 创建下载防抖器 // NewDownloadDebouncer 创建下载防抖器
func NewDownloadDebouncer(window time.Duration) *DownloadDebouncer { func NewDownloadDebouncer(window time.Duration) *DownloadDebouncer {
return &DownloadDebouncer{ return &DownloadDebouncer{
entries: make(map[string]*DebounceEntry), entries: make(map[string]*DebounceEntry),
window: window, window: window,
lastCleanup: time.Now(),
} }
} }
@@ -56,19 +60,18 @@ func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool {
if entry, exists := d.entries[key]; exists { if entry, exists := d.entries[key]; exists {
if now.Sub(entry.LastRequest) < d.window { if now.Sub(entry.LastRequest) < d.window {
return false // 在防抖窗口内,拒绝请求 return false
} }
} }
// 更新或创建条目
d.entries[key] = &DebounceEntry{ d.entries[key] = &DebounceEntry{
LastRequest: now, LastRequest: now,
UserID: userID, UserID: userID,
} }
// 清理过期条目简单策略每100次请求清理一次 if time.Since(d.lastCleanup) > 5*time.Minute {
if len(d.entries)%100 == 0 {
d.cleanup(now) d.cleanup(now)
d.lastCleanup = now
} }
return true return true
@@ -85,51 +88,42 @@ func (d *DownloadDebouncer) cleanup(now time.Time) {
// generateContentFingerprint 生成内容指纹 // generateContentFingerprint 生成内容指纹
func generateContentFingerprint(images []string, platform string) string { func generateContentFingerprint(images []string, platform string) string {
// 对镜像列表排序确保顺序无关
sortedImages := make([]string, len(images)) sortedImages := make([]string, len(images))
copy(sortedImages, images) copy(sortedImages, images)
sort.Strings(sortedImages) sort.Strings(sortedImages)
// 组合内容:镜像列表 + 平台信息
content := strings.Join(sortedImages, "|") + ":" + platform content := strings.Join(sortedImages, "|") + ":" + platform
// 生成MD5哈希
hash := md5.Sum([]byte(content)) hash := md5.Sum([]byte(content))
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])
} }
// getUserID 获取用户标识 // getUserID 获取用户标识
func getUserID(c *gin.Context) string { func getUserID(c *gin.Context) string {
// 优先使用会话Cookie
if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" { if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" {
return "session:" + sessionID return "session:" + sessionID
} }
// 备用方案IP + User-Agent组合
ip := c.ClientIP() ip := c.ClientIP()
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
if userAgent == "" { if userAgent == "" {
userAgent = "unknown" userAgent = "unknown"
} }
// 生成简短标识
combined := ip + ":" + userAgent combined := ip + ":" + userAgent
hash := md5.Sum([]byte(combined)) hash := md5.Sum([]byte(combined))
return "ip:" + hex.EncodeToString(hash[:8]) // 只取前8字节 return "ip:" + hex.EncodeToString(hash[:8])
} }
// 全局防抖器实例
var ( var (
singleImageDebouncer *DownloadDebouncer singleImageDebouncer *DownloadDebouncer
batchImageDebouncer *DownloadDebouncer batchImageDebouncer *DownloadDebouncer
) )
// initDebouncer 初始化防抖器 // InitDebouncer 初始化防抖器
func initDebouncer() { func InitDebouncer() {
// 单个镜像5秒防抖窗口
singleImageDebouncer = NewDownloadDebouncer(5 * time.Second) singleImageDebouncer = NewDownloadDebouncer(5 * time.Second)
// 批量镜像30秒防抖窗口影响更大需要更长保护 batchImageDebouncer = NewDownloadDebouncer(60 * time.Second)
batchImageDebouncer = NewDownloadDebouncer(30 * time.Second)
} }
// ImageStreamer 镜像流式下载器 // ImageStreamer 镜像流式下载器
@@ -144,15 +138,15 @@ type ImageStreamerConfig struct {
} }
// NewImageStreamer 创建镜像下载器 // NewImageStreamer 创建镜像下载器
func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer { func NewImageStreamer(cfg *ImageStreamerConfig) *ImageStreamer {
if config == nil { if cfg == nil {
config = &ImageStreamerConfig{} cfg = &ImageStreamerConfig{}
} }
concurrency := config.Concurrency concurrency := cfg.Concurrency
if concurrency <= 0 { if concurrency <= 0 {
cfg := GetConfig() appCfg := config.GetConfig()
concurrency = cfg.Download.MaxImages concurrency = appCfg.Download.MaxImages
if concurrency <= 0 { if concurrency <= 0 {
concurrency = 10 concurrency = 10
} }
@@ -160,7 +154,7 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
remoteOptions := []remote.Option{ remoteOptions := []remote.Option{
remote.WithAuth(authn.Anonymous), remote.WithAuth(authn.Anonymous),
remote.WithTransport(GetGlobalHTTPClient().Transport), remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
} }
return &ImageStreamer{ return &ImageStreamer{
@@ -171,14 +165,15 @@ func NewImageStreamer(config *ImageStreamerConfig) *ImageStreamer {
// StreamOptions 下载选项 // StreamOptions 下载选项
type StreamOptions struct { type StreamOptions struct {
Platform string Platform string
Compression bool Compression bool
UseCompressedLayers bool
} }
// StreamImageToWriter 流式下载镜像到Writer // StreamImageToWriter 流式下载镜像到Writer
func (is *ImageStreamer) StreamImageToWriter(ctx context.Context, imageRef string, writer io.Writer, options *StreamOptions) error { func (is *ImageStreamer) StreamImageToWriter(ctx context.Context, imageRef string, writer io.Writer, options *StreamOptions) error {
if options == nil { if options == nil {
options = &StreamOptions{} options = &StreamOptions{UseCompressedLayers: true}
} }
ref, err := name.ParseReference(imageRef) ref, err := name.ParseReference(imageRef)
@@ -211,57 +206,13 @@ func (is *ImageStreamer) getImageDescriptor(ref name.Reference, options []remote
// getImageDescriptorWithPlatform 获取指定平台的镜像描述符 // getImageDescriptorWithPlatform 获取指定平台的镜像描述符
func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) { func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) {
if isCacheEnabled() { return remote.Get(ref, options...)
var reference string
if tagged, ok := ref.(name.Tag); ok {
reference = tagged.TagStr()
} else if digested, ok := ref.(name.Digest); ok {
reference = digested.DigestStr()
}
if reference != "" {
cacheKey := buildManifestCacheKeyWithPlatform(ref.Context().String(), reference, platform)
if cachedItem := globalCache.Get(cacheKey); cachedItem != nil {
desc := &remote.Descriptor{
Manifest: cachedItem.Data,
}
log.Printf("使用缓存的manifest: %s (平台: %s)", ref.String(), platform)
return desc, nil
}
}
}
desc, err := remote.Get(ref, options...)
if err != nil {
return nil, err
}
if isCacheEnabled() {
var reference string
if tagged, ok := ref.(name.Tag); ok {
reference = tagged.TagStr()
} else if digested, ok := ref.(name.Digest); ok {
reference = digested.DigestStr()
}
if reference != "" {
cacheKey := buildManifestCacheKeyWithPlatform(ref.Context().String(), reference, platform)
ttl := getManifestTTL(reference)
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
}
globalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
log.Printf("缓存manifest: %s (平台: %s, TTL: %v)", ref.String(), platform, ttl)
}
}
return desc, nil
} }
// StreamImageToGin 流式响应到Gin // StreamImageToGin 流式响应到Gin
func (is *ImageStreamer) StreamImageToGin(ctx context.Context, imageRef string, c *gin.Context, options *StreamOptions) error { func (is *ImageStreamer) StreamImageToGin(ctx context.Context, imageRef string, c *gin.Context, options *StreamOptions) error {
if options == nil { if options == nil {
options = &StreamOptions{} options = &StreamOptions{UseCompressedLayers: true}
} }
filename := strings.ReplaceAll(imageRef, "/", "_") + ".tar" filename := strings.ReplaceAll(imageRef, "/", "_") + ".tar"
@@ -320,16 +271,16 @@ func (is *ImageStreamer) streamImageLayers(ctx context.Context, img v1.Image, wr
log.Printf("镜像包含 %d 层", len(layers)) log.Printf("镜像包含 %d 层", len(layers))
return is.streamDockerFormat(ctx, tarWriter, img, layers, configFile, imageRef) return is.streamDockerFormat(ctx, tarWriter, img, layers, configFile, imageRef, options)
} }
// streamDockerFormat 生成Docker格式 // streamDockerFormat 生成Docker格式
func (is *ImageStreamer) streamDockerFormat(ctx context.Context, tarWriter *tar.Writer, img v1.Image, layers []v1.Layer, configFile *v1.ConfigFile, imageRef string) error { func (is *ImageStreamer) streamDockerFormat(ctx context.Context, tarWriter *tar.Writer, img v1.Image, layers []v1.Layer, configFile *v1.ConfigFile, imageRef string, options *StreamOptions) error {
return is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, nil, nil) return is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, nil, nil, options)
} }
// streamDockerFormatWithReturn 生成Docker格式并返回manifest和repositories信息 // streamDockerFormatWithReturn 生成Docker格式并返回manifest和repositories信息
func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWriter *tar.Writer, img v1.Image, layers []v1.Layer, configFile *v1.ConfigFile, imageRef string, manifestOut *map[string]interface{}, repositoriesOut *map[string]map[string]string) error { func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWriter *tar.Writer, img v1.Image, layers []v1.Layer, configFile *v1.ConfigFile, imageRef string, manifestOut *map[string]interface{}, repositoriesOut *map[string]map[string]string, options *StreamOptions) error {
configDigest, err := img.ConfigName() configDigest, err := img.ConfigName()
if err != nil { if err != nil {
return err return err
@@ -379,12 +330,23 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
return err return err
} }
uncompressedSize, err := partial.UncompressedSize(layer) var layerSize int64
if err != nil { var layerReader io.ReadCloser
return err
if options != nil && options.UseCompressedLayers {
layerSize, err = layer.Size()
if err != nil {
return err
}
layerReader, err = layer.Compressed()
} else {
layerSize, err = partial.UncompressedSize(layer)
if err != nil {
return err
}
layerReader, err = layer.Uncompressed()
} }
layerReader, err := layer.Uncompressed()
if err != nil { if err != nil {
return err return err
} }
@@ -392,7 +354,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
layerTarHeader := &tar.Header{ layerTarHeader := &tar.Header{
Name: layerDir + "/layer.tar", Name: layerDir + "/layer.tar",
Size: uncompressedSize, Size: layerSize,
Mode: 0644, Mode: 0644,
} }
@@ -412,12 +374,10 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
log.Printf("已处理层 %d/%d", i+1, len(layers)) log.Printf("已处理层 %d/%d", i+1, len(layers))
} }
// 构建单个镜像的manifest信息
singleManifest := map[string]interface{}{ singleManifest := map[string]interface{}{
"Config": configDigest.String() + ".json", "Config": configDigest.String() + ".json",
"RepoTags": []string{imageRef}, "RepoTags": []string{imageRef},
"Layers": func() []string { "Layers": func() []string {
var layers []string var layers []string
for _, digest := range layerDigests { for _, digest := range layerDigests {
layers = append(layers, digest+"/layer.tar") layers = append(layers, digest+"/layer.tar")
@@ -426,7 +386,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
}(), }(),
} }
// 构建repositories信息
repositories := make(map[string]map[string]string) repositories := make(map[string]map[string]string)
parts := strings.Split(imageRef, ":") parts := strings.Split(imageRef, ":")
if len(parts) == 2 { if len(parts) == 2 {
@@ -435,14 +394,12 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
repositories[repoName] = map[string]string{tag: configDigest.String()} repositories[repoName] = map[string]string{tag: configDigest.String()}
} }
// 如果是批量下载,返回信息而不写入文件
if manifestOut != nil && repositoriesOut != nil { if manifestOut != nil && repositoriesOut != nil {
*manifestOut = singleManifest *manifestOut = singleManifest
*repositoriesOut = repositories *repositoriesOut = repositories
return nil return nil
} }
// 单镜像下载直接写入manifest.json
manifest := []map[string]interface{}{singleManifest} manifest := []map[string]interface{}{singleManifest}
manifestData, err := json.Marshal(manifest) manifestData, err := json.Marshal(manifest)
@@ -464,7 +421,6 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
return err return err
} }
// 写入repositories文件
repositoriesData, err := json.Marshal(repositories) repositoriesData, err := json.Marshal(repositories)
if err != nil { if err != nil {
return err return err
@@ -484,7 +440,31 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
return err return err
} }
// streamSingleImageForBatch 为批量下载流式处理单个镜像 // processImageForBatch 处理镜像的公共逻辑
func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
layers, err := img.Layers()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像层失败: %w", err)
}
configFile, err := img.ConfigFile()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像配置失败: %w", err)
}
log.Printf("镜像包含 %d 层", len(layers))
var manifest map[string]interface{}
var repositories map[string]map[string]string
err = is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, &manifest, &repositories, options)
if err != nil {
return nil, nil, err
}
return manifest, repositories, nil
}
func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) { func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
ref, err := name.ParseReference(imageRef) ref, err := name.ParseReference(imageRef)
if err != nil { if err != nil {
@@ -498,80 +478,27 @@ func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWrite
return nil, nil, fmt.Errorf("获取镜像描述失败: %w", err) return nil, nil, fmt.Errorf("获取镜像描述失败: %w", err)
} }
var manifest map[string]interface{} var img v1.Image
var repositories map[string]map[string]string
switch desc.MediaType { switch desc.MediaType {
case types.OCIImageIndex, types.DockerManifestList: case types.OCIImageIndex, types.DockerManifestList:
// 处理多架构镜像,复用单个下载的逻辑 img, err = is.selectPlatformImage(desc, options)
img, err := is.selectPlatformImage(desc, options)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err) return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err)
} }
layers, err := img.Layers()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像层失败: %w", err)
}
configFile, err := img.ConfigFile()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像配置失败: %w", err)
}
log.Printf("镜像包含 %d 层", len(layers))
err = is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, &manifest, &repositories)
if err != nil {
return nil, nil, err
}
case types.OCIManifestSchema1, types.DockerManifestSchema2: case types.OCIManifestSchema1, types.DockerManifestSchema2:
img, err := desc.Image() img, err = desc.Image()
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("获取镜像失败: %w", err) return nil, nil, fmt.Errorf("获取镜像失败: %w", err)
} }
layers, err := img.Layers()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像层失败: %w", err)
}
configFile, err := img.ConfigFile()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像配置失败: %w", err)
}
log.Printf("镜像包含 %d 层", len(layers))
err = is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, &manifest, &repositories)
if err != nil {
return nil, nil, err
}
default: default:
img, err := desc.Image() img, err = desc.Image()
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("获取镜像失败: %w", err) return nil, nil, fmt.Errorf("获取镜像失败: %w", err)
} }
layers, err := img.Layers()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像层失败: %w", err)
}
configFile, err := img.ConfigFile()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像配置失败: %w", err)
}
log.Printf("镜像包含 %d 层", len(layers))
err = is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, &manifest, &repositories)
if err != nil {
return nil, nil, err
}
} }
return manifest, repositories, nil return is.processImageForBatch(ctx, img, tarWriter, imageRef, options)
} }
// selectPlatformImage 从多架构镜像中选择合适的平台镜像 // selectPlatformImage 从多架构镜像中选择合适的平台镜像
@@ -586,7 +513,6 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
return nil, fmt.Errorf("获取索引清单失败: %w", err) return nil, fmt.Errorf("获取索引清单失败: %w", err)
} }
// 选择合适的平台
var selectedDesc *v1.Descriptor var selectedDesc *v1.Descriptor
for _, m := range manifest.Manifests { for _, m := range manifest.Manifests {
if m.Platform == nil { if m.Platform == nil {
@@ -604,8 +530,8 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
} }
if m.Platform.OS == targetOS && if m.Platform.OS == targetOS &&
m.Platform.Architecture == targetArch && m.Platform.Architecture == targetArch &&
m.Platform.Variant == targetVariant { m.Platform.Variant == targetVariant {
selectedDesc = &m selectedDesc = &m
break break
} }
@@ -634,10 +560,9 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
var globalImageStreamer *ImageStreamer var globalImageStreamer *ImageStreamer
// initImageStreamer 初始化镜像下载器 // InitImageStreamer 初始化镜像下载器
func initImageStreamer() { func InitImageStreamer() {
globalImageStreamer = NewImageStreamer(nil) globalImageStreamer = NewImageStreamer(nil)
// 镜像下载器初始化完成
} }
// formatPlatformText 格式化平台文本 // formatPlatformText 格式化平台文本
@@ -648,13 +573,13 @@ func formatPlatformText(platform string) string {
return platform return platform
} }
// initImageTarRoutes 初始化镜像下载路由 // InitImageTarRoutes 初始化镜像下载路由
func initImageTarRoutes(router *gin.Engine) { func InitImageTarRoutes(router *gin.Engine) {
imageAPI := router.Group("/api/image") imageAPI := router.Group("/api/image")
{ {
imageAPI.GET("/download/:image", RateLimitMiddleware(globalLimiter), handleDirectImageDownload) imageAPI.GET("/download/:image", handleDirectImageDownload)
imageAPI.GET("/info/:image", RateLimitMiddleware(globalLimiter), handleImageInfo) imageAPI.GET("/info/:image", handleImageInfo)
imageAPI.POST("/batch", RateLimitMiddleware(globalLimiter), handleSimpleBatchDownload) imageAPI.POST("/batch", handleSimpleBatchDownload)
} }
} }
@@ -669,6 +594,7 @@ func handleDirectImageDownload(c *gin.Context) {
imageRef := strings.ReplaceAll(imageParam, "_", "/") imageRef := strings.ReplaceAll(imageParam, "_", "/")
platform := c.Query("platform") platform := c.Query("platform")
tag := c.DefaultQuery("tag", "") tag := c.DefaultQuery("tag", "")
useCompressed := c.DefaultQuery("compressed", "true") == "true"
if tag != "" && !strings.Contains(imageRef, ":") && !strings.Contains(imageRef, "@") { if tag != "" && !strings.Contains(imageRef, ":") && !strings.Contains(imageRef, "@") {
imageRef = imageRef + ":" + tag imageRef = imageRef + ":" + tag
@@ -681,21 +607,21 @@ func handleDirectImageDownload(c *gin.Context) {
return return
} }
// 防抖检查
userID := getUserID(c) userID := getUserID(c)
contentKey := generateContentFingerprint([]string{imageRef}, platform) contentKey := generateContentFingerprint([]string{imageRef}, platform)
if !singleImageDebouncer.ShouldAllow(userID, contentKey) { if !singleImageDebouncer.ShouldAllow(userID, contentKey) {
c.JSON(http.StatusTooManyRequests, gin.H{ c.JSON(http.StatusTooManyRequests, gin.H{
"error": "请求过于频繁,请稍后再试", "error": "请求过于频繁,请稍后再试",
"retry_after": 5, "retry_after": 5,
}) })
return return
} }
options := &StreamOptions{ options := &StreamOptions{
Platform: platform, Platform: platform,
Compression: false, Compression: false,
UseCompressedLayers: useCompressed,
} }
ctx := c.Request.Context() ctx := c.Request.Context()
@@ -711,8 +637,9 @@ func handleDirectImageDownload(c *gin.Context) {
// handleSimpleBatchDownload 处理批量下载 // handleSimpleBatchDownload 处理批量下载
func handleSimpleBatchDownload(c *gin.Context) { func handleSimpleBatchDownload(c *gin.Context) {
var req struct { var req struct {
Images []string `json:"images" binding:"required"` Images []string `json:"images" binding:"required"`
Platform string `json:"platform"` Platform string `json:"platform"`
UseCompressedLayers *bool `json:"useCompressedLayers"`
} }
if err := c.ShouldBindJSON(&req); err != nil { if err := c.ShouldBindJSON(&req); err != nil {
@@ -731,7 +658,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
} }
} }
cfg := GetConfig() cfg := config.GetConfig()
if len(req.Images) > cfg.Download.MaxImages { if len(req.Images) > cfg.Download.MaxImages {
c.JSON(http.StatusBadRequest, gin.H{ c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages), "error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages),
@@ -739,21 +666,26 @@ func handleSimpleBatchDownload(c *gin.Context) {
return return
} }
// 批量下载防抖检查
userID := getUserID(c) userID := getUserID(c)
contentKey := generateContentFingerprint(req.Images, req.Platform) contentKey := generateContentFingerprint(req.Images, req.Platform)
if !batchImageDebouncer.ShouldAllow(userID, contentKey) { if !batchImageDebouncer.ShouldAllow(userID, contentKey) {
c.JSON(http.StatusTooManyRequests, gin.H{ c.JSON(http.StatusTooManyRequests, gin.H{
"error": "批量下载请求过于频繁,请稍后再试", "error": "批量下载请求过于频繁,请稍后再试",
"retry_after": 30, "retry_after": 60,
}) })
return return
} }
useCompressed := true
if req.UseCompressedLayers != nil {
useCompressed = *req.UseCompressedLayers
}
options := &StreamOptions{ options := &StreamOptions{
Platform: req.Platform, Platform: req.Platform,
Compression: false, Compression: false,
UseCompressedLayers: useCompressed,
} }
ctx := c.Request.Context() ctx := c.Request.Context()
@@ -833,7 +765,7 @@ func handleImageInfo(c *gin.Context) {
// StreamMultipleImages 批量下载多个镜像 // StreamMultipleImages 批量下载多个镜像
func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []string, writer io.Writer, options *StreamOptions) error { func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []string, writer io.Writer, options *StreamOptions) error {
if options == nil { if options == nil {
options = &StreamOptions{} options = &StreamOptions{UseCompressedLayers: true}
} }
var finalWriter io.Writer = writer var finalWriter io.Writer = writer
@@ -849,7 +781,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
var allManifests []map[string]interface{} var allManifests []map[string]interface{}
var allRepositories = make(map[string]map[string]string) var allRepositories = make(map[string]map[string]string)
// 流式处理每个镜像
for i, imageRef := range imageRefs { for i, imageRef := range imageRefs {
select { select {
case <-ctx.Done(): case <-ctx.Done():
@@ -859,7 +790,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef) log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef)
// 防止单个镜像处理时间过长
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options) manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options)
cancel() cancel()
@@ -873,10 +803,8 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
return fmt.Errorf("镜像 %s manifest数据为空", imageRef) return fmt.Errorf("镜像 %s manifest数据为空", imageRef)
} }
// 收集manifest信息
allManifests = append(allManifests, manifest) allManifests = append(allManifests, manifest)
// 合并repositories信息
for repo, tags := range repositories { for repo, tags := range repositories {
if allRepositories[repo] == nil { if allRepositories[repo] == nil {
allRepositories[repo] = make(map[string]string) allRepositories[repo] = make(map[string]string)
@@ -887,7 +815,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
} }
} }
// 写入合并的manifest.json
manifestData, err := json.Marshal(allManifests) manifestData, err := json.Marshal(allManifests)
if err != nil { if err != nil {
return fmt.Errorf("序列化manifest失败: %w", err) return fmt.Errorf("序列化manifest失败: %w", err)
@@ -907,7 +834,6 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
return fmt.Errorf("写入manifest数据失败: %w", err) return fmt.Errorf("写入manifest数据失败: %w", err)
} }
// 写入合并的repositories文件
repositoriesData, err := json.Marshal(allRepositories) repositoriesData, err := json.Marshal(allRepositories)
if err != nil { if err != nil {
return fmt.Errorf("序列化repositories失败: %w", err) return fmt.Errorf("序列化repositories失败: %w", err)

View File

@@ -1,4 +1,4 @@
package main package handlers
import ( import (
"context" "context"
@@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"hubproxy/utils"
) )
// SearchResult Docker Hub搜索结果 // SearchResult Docker Hub搜索结果
@@ -25,27 +26,27 @@ type SearchResult struct {
// Repository 仓库信息 // Repository 仓库信息
type Repository struct { type Repository struct {
Name string `json:"repo_name"` Name string `json:"repo_name"`
Description string `json:"short_description"` Description string `json:"short_description"`
IsOfficial bool `json:"is_official"` IsOfficial bool `json:"is_official"`
IsAutomated bool `json:"is_automated"` IsAutomated bool `json:"is_automated"`
StarCount int `json:"star_count"` StarCount int `json:"star_count"`
PullCount int `json:"pull_count"` PullCount int `json:"pull_count"`
RepoOwner string `json:"repo_owner"` RepoOwner string `json:"repo_owner"`
LastUpdated string `json:"last_updated"` LastUpdated string `json:"last_updated"`
Status int `json:"status"` Status int `json:"status"`
Organization string `json:"affiliation"` Organization string `json:"affiliation"`
PullsLastWeek int `json:"pulls_last_week"` PullsLastWeek int `json:"pulls_last_week"`
Namespace string `json:"namespace"` Namespace string `json:"namespace"`
} }
// TagInfo 标签信息 // TagInfo 标签信息
type TagInfo struct { type TagInfo struct {
Name string `json:"name"` Name string `json:"name"`
FullSize int64 `json:"full_size"` FullSize int64 `json:"full_size"`
LastUpdated time.Time `json:"last_updated"` LastUpdated time.Time `json:"last_updated"`
LastPusher string `json:"last_pusher"` LastPusher string `json:"last_pusher"`
Images []Image `json:"images"` Images []Image `json:"images"`
Vulnerabilities struct { Vulnerabilities struct {
Critical int `json:"critical"` Critical int `json:"critical"`
High int `json:"high"` High int `json:"high"`
@@ -66,20 +67,27 @@ type Image struct {
Size int64 `json:"size"` Size int64 `json:"size"`
} }
// TagPageResult 分页标签结果
type TagPageResult struct {
Tags []TagInfo `json:"tags"`
HasMore bool `json:"has_more"`
}
type cacheEntry struct { type cacheEntry struct {
data interface{} data interface{}
timestamp time.Time expiresAt time.Time
} }
const ( const (
maxCacheSize = 1000 // 最大缓存条目数 maxCacheSize = 1000
cacheTTL = 30 * time.Minute maxPaginationCache = 200
cacheTTL = 30 * time.Minute
) )
type Cache struct { type Cache struct {
data map[string]cacheEntry data map[string]cacheEntry
mu sync.RWMutex mu sync.RWMutex
maxSize int maxSize int
} }
var ( var (
@@ -98,7 +106,7 @@ func (c *Cache) Get(key string) (interface{}, bool) {
return nil, false return nil, false
} }
if time.Since(entry.timestamp) > cacheTTL { if time.Now().After(entry.expiresAt) {
c.mu.Lock() c.mu.Lock()
delete(c.data, key) delete(c.data, key)
c.mu.Unlock() c.mu.Unlock()
@@ -109,49 +117,43 @@ func (c *Cache) Get(key string) (interface{}, bool) {
} }
func (c *Cache) Set(key string, data interface{}) { func (c *Cache) Set(key string, data interface{}) {
c.SetWithTTL(key, data, cacheTTL)
}
func (c *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
now := time.Now()
for k, v := range c.data {
if now.Sub(v.timestamp) > cacheTTL {
delete(c.data, k)
}
}
if len(c.data) >= c.maxSize { if len(c.data) >= c.maxSize {
toDelete := len(c.data) / 4 c.cleanupExpiredLocked()
for k := range c.data {
if toDelete <= 0 {
break
}
delete(c.data, k)
toDelete--
}
} }
c.data[key] = cacheEntry{ c.data[key] = cacheEntry{
data: data, data: data,
timestamp: now, expiresAt: time.Now().Add(ttl),
} }
} }
func (c *Cache) Cleanup() { func (c *Cache) Cleanup() {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
c.cleanupExpiredLocked()
}
func (c *Cache) cleanupExpiredLocked() {
now := time.Now() now := time.Now()
for key, entry := range c.data { for key, entry := range c.data {
if now.Sub(entry.timestamp) > cacheTTL { if now.After(entry.expiresAt) {
delete(c.data, key) delete(c.data, key)
} }
} }
} }
// 定期清理过期缓存
func init() { func init() {
go func() { go func() {
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C { for range ticker.C {
searchCache.Cleanup() searchCache.Cleanup()
} }
@@ -163,67 +165,85 @@ func filterSearchResults(results []Repository, query string) []Repository {
filtered := make([]Repository, 0) filtered := make([]Repository, 0)
for _, repo := range results { for _, repo := range results {
// 标准化仓库名称
repoName := strings.ToLower(repo.Name) repoName := strings.ToLower(repo.Name)
repoDesc := strings.ToLower(repo.Description) repoDesc := strings.ToLower(repo.Description)
// 计算相关性得分
score := 0 score := 0
// 完全匹配
if repoName == searchTerm { if repoName == searchTerm {
score += 100 score += 100
} }
// 前缀匹配
if strings.HasPrefix(repoName, searchTerm) { if strings.HasPrefix(repoName, searchTerm) {
score += 50 score += 50
} }
// 包含匹配
if strings.Contains(repoName, searchTerm) { if strings.Contains(repoName, searchTerm) {
score += 30 score += 30
} }
// 描述匹配
if strings.Contains(repoDesc, searchTerm) { if strings.Contains(repoDesc, searchTerm) {
score += 10 score += 10
} }
// 官方镜像加分
if repo.IsOfficial { if repo.IsOfficial {
score += 20 score += 20
} }
// 分数达到阈值的结果才保留
if score > 0 { if score > 0 {
filtered = append(filtered, repo) filtered = append(filtered, repo)
} }
} }
// 按相关性排序
sort.Slice(filtered, func(i, j int) bool { sort.Slice(filtered, func(i, j int) bool {
// 优先考虑官方镜像
if filtered[i].IsOfficial != filtered[j].IsOfficial { if filtered[i].IsOfficial != filtered[j].IsOfficial {
return filtered[i].IsOfficial return filtered[i].IsOfficial
} }
// 其次考虑拉取次数
return filtered[i].PullCount > filtered[j].PullCount return filtered[i].PullCount > filtered[j].PullCount
}) })
return filtered return filtered
} }
// normalizeRepository 统一规范化仓库信息
func normalizeRepository(repo *Repository) {
if repo.IsOfficial {
repo.Namespace = "library"
if !strings.Contains(repo.Name, "/") {
repo.Name = "library/" + repo.Name
}
} else {
if repo.Namespace == "" && repo.RepoOwner != "" {
repo.Namespace = repo.RepoOwner
}
if strings.Contains(repo.Name, "/") {
parts := strings.Split(repo.Name, "/")
if len(parts) > 1 {
if repo.Namespace == "" {
repo.Namespace = parts[0]
}
repo.Name = parts[len(parts)-1]
}
}
}
}
// searchDockerHub 搜索镜像 // searchDockerHub 搜索镜像
func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) { func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) {
return searchDockerHubWithDepth(ctx, query, page, pageSize, 0)
}
func searchDockerHubWithDepth(ctx context.Context, query string, page, pageSize int, depth int) (*SearchResult, error) {
if depth > 1 {
return nil, fmt.Errorf("搜索请求过于复杂,请尝试更具体的关键词")
}
cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize) cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize)
// 尝试从缓存获取
if cached, ok := searchCache.Get(cacheKey); ok { if cached, ok := searchCache.Get(cacheKey); ok {
return cached.(*SearchResult), nil return cached.(*SearchResult), nil
} }
// 判断是否是用户/仓库格式的搜索
isUserRepo := strings.Contains(query, "/") isUserRepo := strings.Contains(query, "/")
var namespace, repoName string var namespace, repoName string
@@ -235,20 +255,17 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
} }
} }
// 构建搜索URL
baseURL := "https://registry.hub.docker.com/v2" baseURL := "https://registry.hub.docker.com/v2"
var fullURL string var fullURL string
var params url.Values var params url.Values
if isUserRepo && namespace != "" { if isUserRepo && namespace != "" {
// 如果是用户/仓库格式使用repositories接口
fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace) fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace)
params = url.Values{ params = url.Values{
"page": {fmt.Sprintf("%d", page)}, "page": {fmt.Sprintf("%d", page)},
"page_size": {fmt.Sprintf("%d", pageSize)}, "page_size": {fmt.Sprintf("%d", pageSize)},
} }
} else { } else {
// 普通搜索
fullURL = baseURL + "/search/repositories/" fullURL = baseURL + "/search/repositories/"
params = url.Values{ params = url.Values{
"query": {query}, "query": {query},
@@ -259,16 +276,11 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
fullURL = fullURL + "?" + params.Encode() fullURL = fullURL + "?" + params.Encode()
// 使用统一的搜索HTTP客户端 resp, err := utils.GetSearchHTTPClient().Get(fullURL)
resp, err := GetSearchHTTPClient().Get(fullURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("请求Docker Hub API失败: %v", err) return nil, fmt.Errorf("请求Docker Hub API失败: %v", err)
} }
defer func() { defer safeCloseResponseBody(resp.Body, "搜索响应体")
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭搜索响应体失败: %v\n", err)
}
}()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
@@ -281,8 +293,7 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
return nil, fmt.Errorf("请求过于频繁,请稍后重试") return nil, fmt.Errorf("请求过于频繁,请稍后重试")
case http.StatusNotFound: case http.StatusNotFound:
if isUserRepo && namespace != "" { if isUserRepo && namespace != "" {
// 如果用户仓库搜索失败,尝试普通搜索 return searchDockerHubWithDepth(ctx, repoName, page, pageSize, depth+1)
return searchDockerHub(ctx, repoName, page, pageSize)
} }
return nil, fmt.Errorf("未找到相关镜像") return nil, fmt.Errorf("未找到相关镜像")
case http.StatusBadGateway, http.StatusServiceUnavailable: case http.StatusBadGateway, http.StatusServiceUnavailable:
@@ -292,10 +303,8 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
} }
} }
// 解析响应
var result *SearchResult var result *SearchResult
if isUserRepo && namespace != "" { if isUserRepo && namespace != "" {
// 解析用户仓库列表响应
var userRepos struct { var userRepos struct {
Count int `json:"count"` Count int `json:"count"`
Next string `json:"next"` Next string `json:"next"`
@@ -306,7 +315,6 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
return nil, fmt.Errorf("解析响应失败: %v", err) return nil, fmt.Errorf("解析响应失败: %v", err)
} }
// 转换为SearchResult格式
result = &SearchResult{ result = &SearchResult{
Count: userRepos.Count, Count: userRepos.Count,
Next: userRepos.Next, Next: userRepos.Next,
@@ -314,52 +322,29 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
Results: make([]Repository, 0), Results: make([]Repository, 0),
} }
// 处理结果
for _, repo := range userRepos.Results { for _, repo := range userRepos.Results {
// 如果指定了仓库名,只保留匹配的结果
if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) { if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) {
// 确保设置正确的命名空间和名称
repo.Namespace = namespace repo.Namespace = namespace
if !strings.Contains(repo.Name, "/") { normalizeRepository(&repo)
repo.Name = fmt.Sprintf("%s/%s", namespace, repo.Name)
}
result.Results = append(result.Results, repo) result.Results = append(result.Results, repo)
} }
} }
// 如果没有找到结果,尝试普通搜索
if len(result.Results) == 0 { if len(result.Results) == 0 {
return searchDockerHub(ctx, repoName, page, pageSize) return searchDockerHubWithDepth(ctx, repoName, page, pageSize, depth+1)
} }
result.Count = len(result.Results) result.Count = len(result.Results)
} else { } else {
// 解析普通搜索响应
result = &SearchResult{} result = &SearchResult{}
if err := json.Unmarshal(body, &result); err != nil { if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %v", err) return nil, fmt.Errorf("解析响应失败: %v", err)
} }
// 处理搜索结果
for i := range result.Results { for i := range result.Results {
if result.Results[i].IsOfficial { normalizeRepository(&result.Results[i])
if !strings.Contains(result.Results[i].Name, "/") {
result.Results[i].Name = "library/" + result.Results[i].Name
}
result.Results[i].Namespace = "library"
} else {
parts := strings.Split(result.Results[i].Name, "/")
if len(parts) > 1 {
result.Results[i].Namespace = parts[0]
result.Results[i].Name = parts[1]
} else if result.Results[i].RepoOwner != "" {
result.Results[i].Namespace = result.Results[i].RepoOwner
result.Results[i].Name = fmt.Sprintf("%s/%s", result.Results[i].RepoOwner, result.Results[i].Name)
}
}
} }
// 如果是用户/仓库搜索,过滤结果
if isUserRepo && namespace != "" { if isUserRepo && namespace != "" {
filteredResults := make([]Repository, 0) filteredResults := make([]Repository, 0)
for _, repo := range result.Results { for _, repo := range result.Results {
@@ -372,22 +357,19 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
} }
} }
// 缓存结果
searchCache.Set(cacheKey, result) searchCache.Set(cacheKey, result)
return result, nil return result, nil
} }
// 判断错误是否可重试
func isRetryableError(err error) bool { func isRetryableError(err error) bool {
if err == nil { if err == nil {
return false return false
} }
// 网络错误、超时等可以重试
if strings.Contains(err.Error(), "timeout") || if strings.Contains(err.Error(), "timeout") ||
strings.Contains(err.Error(), "connection refused") || strings.Contains(err.Error(), "connection refused") ||
strings.Contains(err.Error(), "no such host") || strings.Contains(err.Error(), "no such host") ||
strings.Contains(err.Error(), "too many requests") { strings.Contains(err.Error(), "too many requests") {
return true return true
} }
@@ -395,106 +377,183 @@ func isRetryableError(err error) bool {
} }
// getRepositoryTags 获取仓库标签信息 // getRepositoryTags 获取仓库标签信息
func getRepositoryTags(ctx context.Context, namespace, name string) ([]TagInfo, error) { func getRepositoryTags(ctx context.Context, namespace, name string, page, pageSize int) ([]TagInfo, bool, error) {
if namespace == "" || name == "" { if namespace == "" || name == "" {
return nil, fmt.Errorf("无效输入:命名空间和名称不能为空") return nil, false, fmt.Errorf("无效输入:命名空间和名称不能为空")
} }
cacheKey := fmt.Sprintf("tags:%s:%s", namespace, name) if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 100
}
cacheKey := fmt.Sprintf("tags:%s:%s:page_%d", namespace, name, page)
if cached, ok := searchCache.Get(cacheKey); ok { if cached, ok := searchCache.Get(cacheKey); ok {
return cached.([]TagInfo), nil result := cached.(TagPageResult)
return result.Tags, result.HasMore, nil
} }
// 构建API URL
baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name) baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name)
params := url.Values{} params := url.Values{}
params.Set("page_size", "100") params.Set("page", fmt.Sprintf("%d", page))
params.Set("page_size", fmt.Sprintf("%d", pageSize))
params.Set("ordering", "last_updated") params.Set("ordering", "last_updated")
fullURL := baseURL + "?" + params.Encode() fullURL := baseURL + "?" + params.Encode()
// 使用统一的搜索HTTP客户端 pageResult, err := fetchTagPage(ctx, fullURL, 3)
resp, err := GetSearchHTTPClient().Get(fullURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err) return nil, false, fmt.Errorf("获取标签失败: %v", err)
} }
defer func() {
if err := resp.Body.Close(); err != nil { hasMore := pageResult.Next != ""
fmt.Printf("关闭搜索响应体失败: %v\n", err)
result := TagPageResult{Tags: pageResult.Results, HasMore: hasMore}
searchCache.SetWithTTL(cacheKey, result, 30*time.Minute)
return pageResult.Results, hasMore, nil
}
func fetchTagPage(ctx context.Context, url string, maxRetries int) (*struct {
Count int `json:"count"`
Next string `json:"next"`
Previous string `json:"previous"`
Results []TagInfo `json:"results"`
}, error) {
var lastErr error
for retry := 0; retry < maxRetries; retry++ {
if retry > 0 {
time.Sleep(time.Duration(retry) * 500 * time.Millisecond)
} }
}()
// 读取响应体 resp, err := utils.GetSearchHTTPClient().Get(url)
body, err := io.ReadAll(resp.Body) if err != nil {
if err != nil { lastErr = err
return nil, fmt.Errorf("读取响应失败: %v", err) if isRetryableError(err) && retry < maxRetries-1 {
continue
}
return nil, fmt.Errorf("发送请求失败: %v", err)
}
body, err := func() ([]byte, error) {
defer safeCloseResponseBody(resp.Body, "标签响应体")
return io.ReadAll(resp.Body)
}()
if err != nil {
lastErr = err
if retry < maxRetries-1 {
continue
}
return nil, fmt.Errorf("读取响应失败: %v", err)
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("状态码=%d, 响应=%s", resp.StatusCode, string(body))
if resp.StatusCode >= 400 && resp.StatusCode < 500 && resp.StatusCode != 429 {
return nil, fmt.Errorf("请求失败: %v", lastErr)
}
if retry < maxRetries-1 {
continue
}
return nil, fmt.Errorf("请求失败: %v", lastErr)
}
var result struct {
Count int `json:"count"`
Next string `json:"next"`
Previous string `json:"previous"`
Results []TagInfo `json:"results"`
}
if err := json.Unmarshal(body, &result); err != nil {
lastErr = err
if retry < maxRetries-1 {
continue
}
return nil, fmt.Errorf("解析响应失败: %v", err)
}
return &result, nil
} }
// 检查响应状态码 return nil, lastErr
if resp.StatusCode != http.StatusOK { }
return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body))
func parsePaginationParams(c *gin.Context, defaultPageSize int) (page, pageSize int) {
page = 1
pageSize = defaultPageSize
if p := c.Query("page"); p != "" {
fmt.Sscanf(p, "%d", &page)
}
if ps := c.Query("page_size"); ps != "" {
fmt.Sscanf(ps, "%d", &pageSize)
} }
// 解析响应 return page, pageSize
var result struct { }
Count int `json:"count"`
Next string `json:"next"`
Previous string `json:"previous"`
Results []TagInfo `json:"results"`
}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %v", err)
}
// 缓存结果 func safeCloseResponseBody(body io.ReadCloser, context string) {
searchCache.Set(cacheKey, result.Results) if body != nil {
return result.Results, nil if err := body.Close(); err != nil {
fmt.Printf("关闭%s失败: %v\n", context, err)
}
}
}
func sendErrorResponse(c *gin.Context, message string) {
c.JSON(http.StatusBadRequest, gin.H{"error": message})
} }
// RegisterSearchRoute 注册搜索相关路由 // RegisterSearchRoute 注册搜索相关路由
func RegisterSearchRoute(r *gin.Engine) { func RegisterSearchRoute(r *gin.Engine) {
// 搜索镜像
r.GET("/search", func(c *gin.Context) { r.GET("/search", func(c *gin.Context) {
query := c.Query("q") query := c.Query("q")
if query == "" { if query == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "搜索关键词不能为空"}) sendErrorResponse(c, "搜索关键词不能为空")
return return
} }
page := 1 page, pageSize := parsePaginationParams(c, 25)
pageSize := 25
if p := c.Query("page"); p != "" {
fmt.Sscanf(p, "%d", &page)
}
if ps := c.Query("page_size"); ps != "" {
fmt.Sscanf(ps, "%d", &pageSize)
}
result, err := searchDockerHub(c.Request.Context(), query, page, pageSize) result, err := searchDockerHub(c.Request.Context(), query, page, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) sendErrorResponse(c, err.Error())
return return
} }
c.JSON(http.StatusOK, result) c.JSON(http.StatusOK, result)
}) })
// 获取标签信息
r.GET("/tags/:namespace/:name", func(c *gin.Context) { r.GET("/tags/:namespace/:name", func(c *gin.Context) {
namespace := c.Param("namespace") namespace := c.Param("namespace")
name := c.Param("name") name := c.Param("name")
if namespace == "" || name == "" { if namespace == "" || name == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "命名空间和名称不能为空"}) sendErrorResponse(c, "命名空间和名称不能为空")
return return
} }
tags, err := getRepositoryTags(c.Request.Context(), namespace, name) page, pageSize := parsePaginationParams(c, 100)
tags, hasMore, err := getRepositoryTags(c.Request.Context(), namespace, name, page, pageSize)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) sendErrorResponse(c, err.Error())
return return
} }
c.JSON(http.StatusOK, tags) if c.Query("page") != "" || c.Query("page_size") != "" {
c.JSON(http.StatusOK, gin.H{
"tags": tags,
"has_more": hasMore,
"page": page,
"page_size": pageSize,
})
} else {
c.JSON(http.StatusOK, tags)
}
}) })
} }

View File

@@ -3,15 +3,17 @@ package main
import ( import (
"embed" "embed"
"fmt" "fmt"
"io"
"log" "log"
"net/http" "net/http"
"regexp"
"strconv"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
"hubproxy/config"
"hubproxy/handlers"
"hubproxy/utils"
) )
//go:embed public/* //go:embed public/*
@@ -32,19 +34,7 @@ func serveEmbedFile(c *gin.Context, filename string) {
} }
var ( var (
exps = []*regexp.Regexp{ globalLimiter *utils.IPRateLimiter
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`),
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`),
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`),
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`),
}
globalLimiter *IPRateLimiter
// 服务启动时间 // 服务启动时间
serviceStartTime = time.Now() serviceStartTime = time.Now()
@@ -52,25 +42,25 @@ var (
func main() { func main() {
// 加载配置 // 加载配置
if err := LoadConfig(); err != nil { if err := config.LoadConfig(); err != nil {
fmt.Printf("配置加载失败: %v\n", err) fmt.Printf("配置加载失败: %v\n", err)
return return
} }
// 初始化HTTP客户端 // 初始化HTTP客户端
initHTTPClients() utils.InitHTTPClients()
// 初始化限流器 // 初始化限流器
initLimiter() globalLimiter = utils.InitGlobalLimiter()
// 初始化Docker流式代理 // 初始化Docker流式代理
initDockerProxy() handlers.InitDockerProxy()
// 初始化镜像流式下载器 // 初始化镜像流式下载器
initImageStreamer() handlers.InitImageStreamer()
// 初始化防抖器 // 初始化防抖器
initDebouncer() handlers.InitDebouncer()
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
router := gin.Default() router := gin.Default()
@@ -84,11 +74,14 @@ func main() {
}) })
})) }))
// 全局限流中间件
router.Use(utils.RateLimitMiddleware(globalLimiter))
// 初始化监控端点 // 初始化监控端点
initHealthRoutes(router) initHealthRoutes(router)
// 初始化镜像tar下载路由 // 初始化镜像tar下载路由
initImageTarRoutes(router) handlers.InitImageTarRoutes(router)
// 静态文件路由 // 静态文件路由
router.GET("/", func(c *gin.Context) { router.GET("/", func(c *gin.Context) {
@@ -110,272 +103,87 @@ func main() {
}) })
// 注册dockerhub搜索路由 // 注册dockerhub搜索路由
RegisterSearchRoute(router) handlers.RegisterSearchRoute(router)
// 注册Docker认证路由/token* // 注册Docker认证路由
router.Any("/token", RateLimitMiddleware(globalLimiter), ProxyDockerAuthGin) router.Any("/token", handlers.ProxyDockerAuthGin)
router.Any("/token/*path", RateLimitMiddleware(globalLimiter), ProxyDockerAuthGin) router.Any("/token/*path", handlers.ProxyDockerAuthGin)
// 注册Docker Registry代理路由 // 注册Docker Registry代理路由
router.Any("/v2/*path", RateLimitMiddleware(globalLimiter), ProxyDockerRegistryGin) router.Any("/v2/*path", handlers.ProxyDockerRegistryGin)
// 注册GitHub代理路由NoRoute处理器
router.NoRoute(handlers.GitHubProxyHandler)
// 注册NoRoute处理器 cfg := config.GetConfig()
router.NoRoute(RateLimitMiddleware(globalLimiter), handler)
cfg := GetConfig()
fmt.Printf("🚀 HubProxy 启动成功\n") fmt.Printf("🚀 HubProxy 启动成功\n")
fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port) fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port)
fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours) fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours)
// 显示HTTP/2支持状态
if cfg.Server.EnableH2C {
fmt.Printf("H2c: 已启用\n")
}
fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n") fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n")
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)) // 创建HTTP2服务器
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
ReadTimeout: 60 * time.Second,
WriteTimeout: 300 * time.Second,
IdleTimeout: 120 * time.Second,
}
// 根据配置决定是否启用H2C
if cfg.Server.EnableH2C {
h2cHandler := h2c.NewHandler(router, &http2.Server{
MaxConcurrentStreams: 250,
IdleTimeout: 300 * time.Second,
MaxReadFrameSize: 4 << 20,
MaxUploadBufferPerConnection: 8 << 20,
MaxUploadBufferPerStream: 2 << 20,
})
server.Handler = h2cHandler
} else {
server.Handler = router
}
err := server.ListenAndServe()
if err != nil { if err != nil {
fmt.Printf("启动服务失败: %v\n", err) fmt.Printf("启动服务失败: %v\n", err)
} }
} }
func handler(c *gin.Context) { // 简单的健康检查
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") func formatDuration(d time.Duration) string {
if d < time.Minute {
for strings.HasPrefix(rawPath, "/") { return fmt.Sprintf("%d秒", int(d.Seconds()))
rawPath = strings.TrimPrefix(rawPath, "/") } else if d < time.Hour {
} return fmt.Sprintf("%d分钟%d秒", int(d.Minutes()), int(d.Seconds())%60)
} else if d < 24*time.Hour {
if !strings.HasPrefix(rawPath, "http") { return fmt.Sprintf("%d小时%d分钟", int(d.Hours()), int(d.Minutes())%60)
c.String(http.StatusForbidden, "无效输入")
return
}
matches := checkURL(rawPath)
if matches != nil {
// GitHub仓库访问控制检查
if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed {
// 构建仓库名用于日志
var repoPath string
if len(matches) >= 2 {
username := matches[0]
repoName := strings.TrimSuffix(matches[1], ".git")
repoPath = username + "/" + repoName
}
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
c.String(http.StatusForbidden, reason)
return
}
} else { } else {
c.String(http.StatusForbidden, "无效输入") days := int(d.Hours()) / 24
return hours := int(d.Hours()) % 24
} return fmt.Sprintf("%d天%d小时", days, hours)
if exps[1].MatchString(rawPath) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
}
proxy(c, rawPath)
}
func proxy(c *gin.Context, u string) {
proxyWithRedirect(c, u, 0)
}
func proxyWithRedirect(c *gin.Context, u string, redirectCount int) {
// 限制最大重定向次数,防止无限递归
const maxRedirects = 20
if redirectCount > maxRedirects {
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
return
}
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
req.Header.Del("Host")
resp, err := GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭响应体失败: %v\n", err)
}
}()
// 检查文件大小限制
cfg := GetConfig()
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
c.String(http.StatusRequestEntityTooLarge,
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
return
}
}
// 清理安全相关的头
resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy")
resp.Header.Del("Strict-Transport-Security")
// 获取真实域名
realHost := c.Request.Header.Get("X-Forwarded-Host")
if realHost == "" {
realHost = c.Request.Host
}
// 如果域名中没有协议前缀添加https://
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
realHost = "https://" + realHost
}
if strings.HasSuffix(strings.ToLower(u), ".sh") {
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost)
if err != nil {
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
processedBody = resp.Body
processedSize = 0
}
// 智能设置响应头
if processedSize > 0 {
resp.Header.Del("Content-Length")
resp.Header.Del("Content-Encoding")
resp.Header.Set("Transfer-Encoding", "chunked")
}
// 复制其他响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
if location := resp.Header.Get("Location"); location != "" {
if checkURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 输出处理后的内容
if _, err := io.Copy(c.Writer, processedBody); err != nil {
return
}
} else {
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if checkURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 直接流式转发
if _, err := io.Copy(c.Writer, resp.Body); err != nil {
fmt.Printf("直接代理失败: %v\n", err)
}
} }
} }
func checkURL(u string) []string { func getUptimeInfo() (time.Duration, float64, string) {
for _, exp := range exps { uptime := time.Since(serviceStartTime)
if matches := exp.FindStringSubmatch(u); matches != nil { return uptime, uptime.Seconds(), formatDuration(uptime)
return matches[1:]
}
}
return nil
} }
// 初始化健康监控路由
func initHealthRoutes(router *gin.Engine) { func initHealthRoutes(router *gin.Engine) {
// 健康检查端点
router.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"timestamp": time.Now().Unix(),
"uptime": time.Since(serviceStartTime).Seconds(),
"service": "hubproxy",
})
})
// 就绪检查端点
router.GET("/ready", func(c *gin.Context) { router.GET("/ready", func(c *gin.Context) {
checks := make(map[string]string) _, uptimeSec, uptimeHuman := getUptimeInfo()
allReady := true c.JSON(http.StatusOK, gin.H{
"ready": true,
if GetConfig() != nil { "service": "hubproxy",
checks["config"] = "ok" "start_time_unix": serviceStartTime.Unix(),
} else { "uptime_sec": uptimeSec,
checks["config"] = "failed" "uptime_human": uptimeHuman,
allReady = false
}
// 检查全局缓存状态
if globalCache != nil {
checks["cache"] = "ok"
} else {
checks["cache"] = "failed"
allReady = false
}
// 检查限流器状态
if globalLimiter != nil {
checks["ratelimiter"] = "ok"
} else {
checks["ratelimiter"] = "failed"
allReady = false
}
// 检查镜像下载器状态
if globalImageStreamer != nil {
checks["imagestreamer"] = "ok"
} else {
checks["imagestreamer"] = "failed"
allReady = false
}
// 检查HTTP客户端状态
if GetGlobalHTTPClient() != nil {
checks["httpclient"] = "ok"
} else {
checks["httpclient"] = "failed"
allReady = false
}
status := http.StatusOK
if !allReady {
status = http.StatusServiceUnavailable
}
c.JSON(status, gin.H{
"ready": allReady,
"checks": checks,
"timestamp": time.Now().Unix(),
"uptime": time.Since(serviceStartTime).Seconds(),
}) })
}) })
} }

View File

@@ -399,6 +399,67 @@
100% { transform: rotate(360deg); } 100% { transform: rotate(360deg); }
} }
/* 切换开关样式 */
.switch-container {
display: flex;
align-items: center;
gap: 0.75rem;
margin-bottom: 1.5rem;
}
.switch {
position: relative;
display: inline-block;
width: 50px;
height: 24px;
}
.switch input {
opacity: 0;
width: 0;
height: 0;
}
.slider {
position: absolute;
cursor: pointer;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: var(--muted);
transition: 0.2s;
border-radius: 24px;
border: 1px solid var(--border);
}
.slider:before {
position: absolute;
content: "";
height: 18px;
width: 18px;
left: 2px;
bottom: 2px;
background-color: white;
transition: 0.2s;
border-radius: 50%;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
input:checked + .slider {
background-color: var(--primary);
}
input:checked + .slider:before {
transform: translateX(26px);
}
.switch-label {
font-weight: 500;
color: var(--foreground);
cursor: pointer;
}
.hidden { .hidden {
display: none; display: none;
} }
@@ -520,7 +581,7 @@
</div> </div>
<div class="feature"> <div class="feature">
<span class="feature-icon">💾</span> <span class="feature-icon">💾</span>
<span>无需打包</span> <span>无需等待</span>
</div> </div>
<div class="feature"> <div class="feature">
<span class="feature-icon">🏗️</span> <span class="feature-icon">🏗️</span>
@@ -559,6 +620,14 @@
</div> </div>
</div> </div>
<div class="switch-container">
<label class="switch">
<input type="checkbox" id="compressedToggle" checked>
<span class="slider"></span>
</label>
<label for="compressedToggle" class="switch-label">使用压缩层(减小包体积)</label>
</div>
<button type="submit" class="btn btn-primary btn-full" id="downloadBtn"> <button type="submit" class="btn btn-primary btn-full" id="downloadBtn">
<span id="downloadText">立即下载</span> <span id="downloadText">立即下载</span>
<span id="downloadLoading" class="loading hidden"></span> <span id="downloadLoading" class="loading hidden"></span>
@@ -573,7 +642,7 @@
<form id="batchForm"> <form id="batchForm">
<div class="form-group"> <div class="form-group">
<label class="form-label" for="imagesTextarea">镜像列表,每行一个,会将多个镜像自动合并,符合官方标准,完全兼容docker load</label> <label class="form-label" for="imagesTextarea">镜像列表每行一个会将多个镜像自动合并符合官方标准兼容docker load</label>
<textarea <textarea
id="imagesTextarea" id="imagesTextarea"
class="textarea" class="textarea"
@@ -595,6 +664,14 @@
</div> </div>
</div> </div>
<div class="switch-container">
<label class="switch">
<input type="checkbox" id="batchCompressedToggle" checked>
<span class="slider"></span>
</label>
<label for="batchCompressedToggle" class="switch-label">使用压缩层(减小包体积)</label>
</div>
<button type="submit" class="btn btn-primary btn-full" id="batchDownloadBtn"> <button type="submit" class="btn btn-primary btn-full" id="batchDownloadBtn">
<span id="batchDownloadText">开始下载</span> <span id="batchDownloadText">开始下载</span>
<span id="batchDownloadLoading" class="loading hidden"></span> <span id="batchDownloadLoading" class="loading hidden"></span>
@@ -651,12 +728,18 @@
} }
} }
function buildDownloadUrl(imageName, platform = '') { function buildDownloadUrl(imageName, platform = '', useCompressed = true) {
const encodedImage = imageName.replace(/\//g, '_'); const encodedImage = imageName.replace(/\//g, '_');
let url = `/api/image/download/${encodedImage}`; let url = `/api/image/download/${encodedImage}`;
const params = new URLSearchParams();
if (platform && platform.trim()) { if (platform && platform.trim()) {
url += `?platform=${encodeURIComponent(platform.trim())}`; params.append('platform', platform.trim());
}
params.append('compressed', useCompressed.toString());
if (params.toString()) {
url += '?' + params.toString();
} }
return url; return url;
@@ -672,11 +755,12 @@
} }
const platform = document.getElementById('platformInput').value.trim(); const platform = document.getElementById('platformInput').value.trim();
const useCompressed = document.getElementById('compressedToggle').checked;
hideStatus('singleStatus'); hideStatus('singleStatus');
setButtonLoading('downloadBtn', 'downloadText', 'downloadLoading', true); setButtonLoading('downloadBtn', 'downloadText', 'downloadLoading', true);
const downloadUrl = buildDownloadUrl(imageName, platform); const downloadUrl = buildDownloadUrl(imageName, platform, useCompressed);
const link = document.createElement('a'); const link = document.createElement('a');
link.href = downloadUrl; link.href = downloadUrl;
@@ -711,9 +795,11 @@
} }
const platform = document.getElementById('batchPlatformInput').value.trim(); const platform = document.getElementById('batchPlatformInput').value.trim();
const useCompressed = document.getElementById('batchCompressedToggle').checked;
const options = { const options = {
images: images images: images,
useCompressedLayers: useCompressed
}; };
if (platform) { if (platform) {

View File

@@ -609,10 +609,10 @@
<div class="card"> <div class="card">
<div class="card-header"> <div class="card-header">
<h2 class="card-title"> <h2 class="card-title">
⚡ 快速生成加速链接 ⚡ 快速转换加速链接
</h2> </h2>
<p class="card-description"> <p class="card-description">
输入GitHub文件或仓库链接自动转换加速链接可以直接在Github域名前面加上本站域名使用。 输入GitHub文件链接自动转换加速链接可以直接在Github文件链接前加上本站域名使用。
</p> </p>
</div> </div>
@@ -622,7 +622,7 @@
type="text" type="text"
class="input" class="input"
id="githubLinkInput" id="githubLinkInput"
placeholder="请输入GitHub链接例如https://github.com/user/repo/releases/download/..." placeholder="请输入GitHub文件链接例如https://github.com/user/repo/releases/download/..."
> >
<button class="button button-primary" id="formatButton"> <button class="button button-primary" id="formatButton">
获取加速链接 获取加速链接
@@ -653,12 +653,12 @@
🐳 Docker 镜像加速 🐳 Docker 镜像加速
</h3> </h3>
<p class="card-description"> <p class="card-description">
支持多种Registry,在镜像名前添加本站域名即可加速下载。 支持多种镜像仓库,在镜像名前添加本站域名即可加速下载。
</p> </p>
</div> </div>
<button class="docker-button" id="dockerButton"> <button class="docker-button" id="dockerButton">
查看 Docker 镜像加速配置 查看 Docker 镜像加速使用说明
</button> </button>
</div> </div>
</div> </div>
@@ -669,23 +669,23 @@
<button class="close-button" id="closeModal">&times;</button> <button class="close-button" id="closeModal">&times;</button>
<div class="modal-header"> <div class="modal-header">
<h2 class="modal-title">Docker 镜像加速</h2> <h2 class="modal-title">Docker 镜像加速</h2>
<p>支持多种Registry,在镜像名前添加本站域名即可加速下载。</p> <p>支持多种镜像仓库,在镜像名前添加本站域名即可加速下载。</p>
</div> </div>
<div class="domain-examples"> <div class="domain-examples">
<strong>Docker Hub 官方镜像:</strong> <strong>Docker 官方镜像:</strong>
docker pull <span class="domain-base"></span>/nginx docker pull <span class="domain-base"></span>/nginx
<strong>Docker Hub 第三方镜像:</strong> <strong>Docker 镜像:</strong>
docker pull <span class="domain-base"></span>/user/image docker pull <span class="domain-base"></span>/user/image
<strong>GitHub Container Registry</strong> <strong>ghcr.io 镜像</strong>
docker pull <span class="domain-base"></span>/ghcr.io/user/image docker pull <span class="domain-base"></span>/ghcr.io/user/image
<strong>Quay.io Registry</strong> <strong>Quay.io 镜像</strong>
docker pull <span class="domain-base"></span>/quay.io/org/image docker pull <span class="domain-base"></span>/quay.io/org/image
<strong>Kubernetes Registry</strong> <strong>K8s 镜像</strong>
docker pull <span class="domain-base"></span>/registry.k8s.io/pause:3.8 docker pull <span class="domain-base"></span>/registry.k8s.io/pause:3.8
</div> </div>
</div> </div>

View File

@@ -778,7 +778,12 @@
</div> </div>
</div> </div>
<div class="tag-list" id="tagList"></div> <div class="tag-list" id="tagList">
<div class="pagination" id="tagPagination" style="display: none;">
<button id="tagPrevPage" disabled>上一页</button>
<button id="tagNextPage" disabled>下一页</button>
</div>
</div>
</div> </div>
<div id="toast"></div> <div id="toast"></div>
@@ -854,6 +859,10 @@
let currentQuery = ''; let currentQuery = '';
let currentRepo = null; let currentRepo = null;
// 标签分页相关变量
let currentTagPage = 1;
let totalTagPages = 1;
document.getElementById('searchButton').addEventListener('click', () => { document.getElementById('searchButton').addEventListener('click', () => {
currentPage = 1; currentPage = 1;
performSearch(); performSearch();
@@ -884,6 +893,21 @@
showSearchResults(); showSearchResults();
}); });
// 使用事件委托处理分页按钮点击避免DOM重建导致事件丢失
document.addEventListener('click', (e) => {
if (e.target.id === 'tagPrevPage') {
if (currentTagPage > 1) {
currentTagPage--;
loadTagPage();
}
} else if (e.target.id === 'tagNextPage') {
if (currentTagPage < totalTagPages) {
currentTagPage++;
loadTagPage();
}
}
});
function showLoading() { function showLoading() {
document.querySelector('.loading').style.display = 'block'; document.querySelector('.loading').style.display = 'block';
} }
@@ -901,71 +925,135 @@
}, 3000); }, 3000);
} }
function updatePagination() { // 统一分页更新函数(支持搜索和标签分页)
const prevButton = document.getElementById('prevPage'); function updatePagination(config = {}) {
const nextButton = document.getElementById('nextPage'); const {
currentPage: page = currentPage,
totalPages: total = totalPages,
prefix = ''
} = config;
prevButton.disabled = currentPage <= 1; const prevButtonId = prefix ? `${prefix}PrevPage` : 'prevPage';
nextButton.disabled = currentPage >= totalPages; const nextButtonId = prefix ? `${prefix}NextPage` : 'nextPage';
const paginationId = prefix ? `${prefix}Pagination` : '.pagination';
const prevButton = document.getElementById(prevButtonId);
const nextButton = document.getElementById(nextButtonId);
const paginationDiv = prefix ? document.getElementById(paginationId) : document.querySelector(paginationId);
if (!prevButton || !nextButton || !paginationDiv) {
return; // 静默处理,避免控制台警告
}
// 更新按钮状态
prevButton.disabled = page <= 1;
nextButton.disabled = page >= total;
// 更新或创建页面信息
const pageInfoId = prefix ? `${prefix}PageInfo` : 'pageInfo';
let pageInfo = document.getElementById(pageInfoId);
const paginationDiv = document.querySelector('.pagination');
let pageInfo = document.getElementById('pageInfo');
if (!pageInfo) { if (!pageInfo) {
const container = document.createElement('div'); pageInfo = createPageInfo(pageInfoId, prefix, total);
container.id = 'pageInfo'; paginationDiv.insertBefore(pageInfo, nextButton);
container.style.margin = '0 10px';
container.style.display = 'flex';
container.style.alignItems = 'center';
container.style.gap = '10px';
const pageText = document.createElement('span');
pageText.id = 'pageText';
const jumpInput = document.createElement('input');
jumpInput.type = 'number';
jumpInput.min = '1';
jumpInput.id = 'jumpPage';
jumpInput.style.width = '60px';
jumpInput.style.padding = '4px';
jumpInput.style.borderRadius = '4px';
jumpInput.style.border = '1px solid var(--border)';
jumpInput.style.backgroundColor = 'var(--input)';
jumpInput.style.color = 'var(--foreground)';
const jumpButton = document.createElement('button');
jumpButton.textContent = '跳转';
jumpButton.className = 'btn search-button';
jumpButton.style.padding = '4px 8px';
jumpButton.onclick = () => {
const page = parseInt(jumpInput.value);
if (page && page >= 1 && page <= totalPages) {
currentPage = page;
performSearch();
} else {
showToast('请输入有效的页码');
}
};
container.appendChild(pageText);
container.appendChild(jumpInput);
container.appendChild(jumpButton);
paginationDiv.insertBefore(container, nextButton);
pageInfo = container;
} }
const pageText = document.getElementById('pageText'); updatePageInfo(pageInfo, page, total, prefix);
pageText.textContent = `${currentPage} / ${totalPages || 1} 页 共 ${totalPages || 1}`; paginationDiv.style.display = total > 1 ? 'flex' : 'none';
const jumpInput = document.getElementById('jumpPage');
if (jumpInput) {
jumpInput.max = totalPages;
jumpInput.value = currentPage;
}
paginationDiv.style.display = totalPages > 1 ? 'flex' : 'none';
} }
// 创建页面信息元素
function createPageInfo(pageInfoId, prefix, total) {
const container = document.createElement('div');
container.id = pageInfoId;
container.style.cssText = 'margin: 0 10px; display: flex; align-items: center; gap: 10px;';
const pageText = document.createElement('span');
pageText.id = prefix ? `${prefix}PageText` : 'pageText';
const jumpInput = document.createElement('input');
jumpInput.type = 'number';
jumpInput.min = '1';
jumpInput.max = prefix === 'tag' ? total : Math.min(total, 100); // 搜索页面限制100页
jumpInput.id = prefix ? `${prefix}JumpPage` : 'jumpPage';
jumpInput.style.cssText = 'width: 60px; padding: 4px; border-radius: 4px; border: 1px solid var(--border); background-color: var(--input); color: var(--foreground);';
const jumpButton = document.createElement('button');
jumpButton.textContent = '跳转';
jumpButton.className = 'btn search-button';
jumpButton.style.padding = '4px 8px';
jumpButton.onclick = () => handlePageJump(jumpInput, prefix, total);
container.append(pageText, jumpInput, jumpButton);
return container;
}
// 更新页面信息显示
function updatePageInfo(pageInfo, page, total, prefix) {
const pageText = pageInfo.querySelector('span');
const jumpInput = pageInfo.querySelector('input');
// 标签分页显示策略:根据是否确定总页数显示不同格式
const isTagPagination = prefix === 'tag';
const maxDisplayPages = isTagPagination ? total : Math.min(total, 100);
const pageTextContent = isTagPagination
? `${page}` + (total > page ? ` (至少 ${total} 页)` : ` (共 ${total} 页)`)
: `${page} / ${maxDisplayPages} 页 共 ${maxDisplayPages}` + (total > 100 ? ' (最多100页)' : '');
pageText.textContent = pageTextContent;
jumpInput.max = maxDisplayPages;
jumpInput.value = page;
}
// 处理页面跳转
function handlePageJump(jumpInput, prefix, total) {
const inputPage = parseInt(jumpInput.value);
const maxPage = prefix === 'tag' ? total : Math.min(total, 100);
if (!inputPage || inputPage < 1 || inputPage > maxPage) {
const limitText = prefix === 'tag' ? '页码' : '页码 (最多100页)';
showToast(`请输入有效的${limitText}`);
return;
}
if (prefix === 'tag') {
currentTagPage = inputPage;
loadTagPage();
} else {
currentPage = inputPage;
performSearch();
}
}
// 统一仓库信息处理
function parseRepositoryInfo(repo) {
const namespace = repo.namespace || (repo.is_official ? 'library' : '');
let name = repo.name || repo.repo_name || '';
// 清理名称,确保不包含命名空间前缀
if (name.includes('/')) {
const parts = name.split('/');
name = parts[parts.length - 1];
}
const cleanName = name.replace(/^library\//, '');
const fullRepoName = repo.is_official ? cleanName : `${namespace}/${cleanName}`;
return {
namespace,
name,
cleanName,
fullRepoName
};
}
// 分页更新函数
const updateSearchPagination = () => updatePagination();
const updateTagPagination = () => updatePagination({
currentPage: currentTagPage,
totalPages: totalTagPages,
prefix: 'tag'
});
function showSearchResults() { function showSearchResults() {
document.querySelector('.search-results').style.display = 'block'; document.querySelector('.search-results').style.display = 'block';
document.querySelector('.tag-list').style.display = 'none'; document.querySelector('.tag-list').style.display = 'none';
@@ -1006,7 +1094,7 @@
throw new Error(data.error || '搜索请求失败'); throw new Error(data.error || '搜索请求失败');
} }
totalPages = Math.ceil(data.count / 25); totalPages = Math.min(Math.ceil(data.count / 25), 100);
updatePagination(); updatePagination();
displayResults(data.results, targetRepo); displayResults(data.results, targetRepo);
@@ -1108,23 +1196,58 @@
}); });
} }
// 内存管理
async function loadTags(namespace, name) { async function loadTags(namespace, name) {
currentTagPage = 1;
await loadTagPage(namespace, name);
}
async function loadTagPage(namespace = null, name = null) {
showLoading(); showLoading();
try { try {
if (!namespace || !name) { // 如果传入了新的namespace和name更新currentRepo
if (namespace && name) {
// 清理旧数据,防止内存泄露
cleanupOldTagData();
}
// 获取当前仓库信息
const repoInfo = parseRepositoryInfo(currentRepo);
const currentNamespace = namespace || repoInfo.namespace;
const currentName = name || repoInfo.name;
// 调试日志
console.log(`loadTagPage: namespace=${currentNamespace}, name=${currentName}, page=${currentTagPage}`);
if (!currentNamespace || !currentName) {
showToast('命名空间和镜像名称不能为空'); showToast('命名空间和镜像名称不能为空');
return; return;
} }
const response = await fetch(`/tags/${encodeURIComponent(namespace)}/${encodeURIComponent(name)}`); const response = await fetch(`/tags/${encodeURIComponent(currentNamespace)}/${encodeURIComponent(currentName)}?page=${currentTagPage}&page_size=100`);
if (!response.ok) { if (!response.ok) {
const errorText = await response.text(); const errorText = await response.text();
throw new Error(errorText || '获取标签信息失败'); throw new Error(errorText || '获取标签信息失败');
} }
const data = await response.json(); const data = await response.json();
displayTags(data);
showTagList(); // 改进的总页数计算:使用更准确的分页策略
if (data.has_more) {
// 如果还有更多页面,至少有当前页+1页但可能更多
totalTagPages = Math.max(currentTagPage + 1, totalTagPages);
} else {
// 如果没有更多页面,当前页就是最后一页
totalTagPages = currentTagPage;
}
displayTags(data.tags, data.has_more);
updateTagPagination();
if (namespace && name) {
showTagList();
}
} catch (error) { } catch (error) {
console.error('加载标签错误:', error); console.error('加载标签错误:', error);
showToast(error.message || '获取标签信息失败,请稍后重试'); showToast(error.message || '获取标签信息失败,请稍后重试');
@@ -1133,12 +1256,24 @@
} }
} }
function displayTags(tags) { function cleanupOldTagData() {
// 清理全局变量,释放内存
if (window.currentPageTags) {
window.currentPageTags.length = 0;
window.currentPageTags = null;
}
// 清理DOM缓存
const tagsContainer = document.getElementById('tagsContainer');
if (tagsContainer) {
tagsContainer.innerHTML = '';
}
}
function displayTags(tags, hasMore = false) {
const tagList = document.getElementById('tagList'); const tagList = document.getElementById('tagList');
const namespace = currentRepo.namespace || (currentRepo.is_official ? 'library' : ''); const repoInfo = parseRepositoryInfo(currentRepo);
const name = currentRepo.name || currentRepo.repo_name || ''; const { fullRepoName } = repoInfo;
const cleanName = name.replace(/^library\//, '');
const fullRepoName = currentRepo.is_official ? cleanName : `${namespace}/${cleanName}`;
let header = ` let header = `
<div class="tag-header"> <div class="tag-header">
@@ -1165,22 +1300,60 @@
<button class="tag-search-clear" onclick="clearTagSearch()">×</button> <button class="tag-search-clear" onclick="clearTagSearch()">×</button>
</div> </div>
<div id="tagsContainer"></div> <div id="tagsContainer"></div>
<div class="pagination" id="tagPagination" style="display: none;">
<button id="tagPrevPage" disabled>上一页</button>
<button id="tagNextPage" disabled>下一页</button>
</div>
`; `;
tagList.innerHTML = header; tagList.innerHTML = header;
window.allTags = tags; // 存储当前页标签数据
window.currentPageTags = tags;
renderFilteredTags(tags); renderFilteredTags(tags);
} }
function renderFilteredTags(filteredTags) { function renderFilteredTags(filteredTags) {
const tagsContainer = document.getElementById('tagsContainer'); const tagsContainer = document.getElementById('tagsContainer');
const namespace = currentRepo.namespace || (currentRepo.is_official ? 'library' : ''); const repoInfo = parseRepositoryInfo(currentRepo);
const name = currentRepo.name || currentRepo.repo_name || ''; const { fullRepoName } = repoInfo;
const cleanName = name.replace(/^library\//, '');
const fullRepoName = currentRepo.is_official ? cleanName : `${namespace}/${cleanName}`;
let tagsHtml = filteredTags.map(tag => { if (filteredTags.length === 0) {
tagsContainer.innerHTML = '<div class="text-center" style="padding: 20px;">未找到匹配的标签</div>';
return;
}
// 渐进式渲染:分批处理大数据集
const BATCH_SIZE = 50;
if (filteredTags.length <= BATCH_SIZE) {
// 小数据集:直接渲染
renderTagsBatch(filteredTags, fullRepoName, tagsContainer, true);
} else {
// 大数据集:分批渲染
tagsContainer.innerHTML = ''; // 清空容器
let currentBatch = 0;
function renderNextBatch() {
const start = currentBatch * BATCH_SIZE;
const end = Math.min(start + BATCH_SIZE, filteredTags.length);
const batch = filteredTags.slice(start, end);
renderTagsBatch(batch, fullRepoName, tagsContainer, false);
currentBatch++;
if (end < filteredTags.length) {
// 使用requestAnimationFrame确保UI响应性
requestAnimationFrame(renderNextBatch);
}
}
renderNextBatch();
}
}
function renderTagsBatch(tags, fullRepoName, container, replaceContent = false) {
const tagsHtml = tags.map(tag => {
const vulnIndicators = Object.entries(tag.vulnerabilities || {}) const vulnIndicators = Object.entries(tag.vulnerabilities || {})
.map(([level, count]) => count > 0 ? `<span class="vulnerability-dot vulnerability-${level.toLowerCase()}" title="${level}: ${count}"></span>` : '') .map(([level, count]) => count > 0 ? `<span class="vulnerability-dot vulnerability-${level.toLowerCase()}" title="${level}: ${count}"></span>` : '')
.join(''); .join('');
@@ -1212,23 +1385,23 @@
`; `;
}).join(''); }).join('');
if (filteredTags.length === 0) { if (replaceContent) {
tagsHtml = '<div class="text-center" style="padding: 20px;">未找到匹配的标签</div>'; container.innerHTML = tagsHtml;
} else {
container.insertAdjacentHTML('beforeend', tagsHtml);
} }
tagsContainer.innerHTML = tagsHtml;
} }
function filterTags(searchText) { function filterTags(searchText) {
if (!window.allTags) return; if (!window.currentPageTags) return;
const searchLower = searchText.toLowerCase(); const searchLower = searchText.toLowerCase();
let filteredTags; let filteredTags;
if (!searchText) { if (!searchText) {
filteredTags = window.allTags; filteredTags = window.currentPageTags;
} else { } else {
const scoredTags = window.allTags.map(tag => { const scoredTags = window.currentPageTags.map(tag => {
const name = tag.name.toLowerCase(); const name = tag.name.toLowerCase();
let score = 0; let score = 0;
@@ -1263,6 +1436,8 @@
} }
} }
function copyToClipboard(text) { function copyToClipboard(text) {
navigator.clipboard.writeText(text).then(() => { navigator.clipboard.writeText(text).then(() => {
showToast('已复制到剪贴板'); showToast('已复制到剪贴板');

View File

@@ -1,108 +0,0 @@
package main
import (
"strings"
"sync"
"time"
)
// SmartRateLimit 智能限流会话管理
type SmartRateLimit struct {
sessions sync.Map
}
// PullSession Docker拉取会话
type PullSession struct {
LastManifestTime time.Time
RequestCount int
}
// 全局智能限流实例
var smartLimiter = &SmartRateLimit{}
const (
// manifest请求后的活跃窗口时间
activeWindowDuration = 3 * time.Minute
// 活跃窗口内最大免费blob请求数(防止滥用)
maxFreeBlobRequests = 100
sessionCleanupInterval = 10 * time.Minute
sessionExpireTime = 30 * time.Minute
)
func init() {
go smartLimiter.cleanupSessions()
}
// ShouldSkipRateLimit 判断是否应该跳过限流计数
func (s *SmartRateLimit) ShouldSkipRateLimit(ip, path string) bool {
requestType, _ := parseRequestInfo(path)
if requestType != "manifests" && requestType != "blobs" {
return false
}
sessionKey := ip
sessionInterface, _ := s.sessions.LoadOrStore(sessionKey, &PullSession{})
session := sessionInterface.(*PullSession)
now := time.Now()
if requestType == "manifests" {
session.LastManifestTime = now
session.RequestCount = 0
return false
}
if requestType == "blobs" {
if !session.LastManifestTime.IsZero() &&
now.Sub(session.LastManifestTime) <= activeWindowDuration {
session.RequestCount++
if session.RequestCount <= maxFreeBlobRequests {
return true
}
}
}
return false
}
func parseRequestInfo(path string) (requestType, imageRef string) {
path = strings.TrimPrefix(path, "/v2/")
if idx := strings.Index(path, "/manifests/"); idx != -1 {
return "manifests", path[:idx]
}
if idx := strings.Index(path, "/blobs/"); idx != -1 {
return "blobs", path[:idx]
}
if idx := strings.Index(path, "/tags/"); idx != -1 {
return "tags", path[:idx]
}
return "unknown", ""
}
// cleanupSessions 定期清理过期会话,防止内存泄露
func (s *SmartRateLimit) cleanupSessions() {
ticker := time.NewTicker(sessionCleanupInterval)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
expiredKeys := make([]string, 0)
s.sessions.Range(func(key, value interface{}) bool {
session := value.(*PullSession)
if !session.LastManifestTime.IsZero() &&
now.Sub(session.LastManifestTime) > sessionExpireTime {
expiredKeys = append(expiredKeys, key.(string))
}
return true
})
for _, key := range expiredKeys {
s.sessions.Delete(key)
}
}
}

View File

@@ -1,8 +1,10 @@
package main package utils
import ( import (
"strings" "strings"
"sync" "sync"
"hubproxy/config"
) )
// ResourceType 资源类型 // ResourceType 资源类型
@@ -26,7 +28,7 @@ type DockerImageInfo struct {
FullName string FullName string
} }
// 全局访问控制器实例 // GlobalAccessController 全局访问控制器实例
var GlobalAccessController = &AccessController{} var GlobalAccessController = &AccessController{}
// ParseDockerImage 解析Docker镜像名称 // ParseDockerImage 解析Docker镜像名称
@@ -79,21 +81,18 @@ func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo {
// CheckDockerAccess 检查Docker镜像访问权限 // CheckDockerAccess 检查Docker镜像访问权限
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) { func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
cfg := GetConfig() cfg := config.GetConfig()
// 解析镜像名称
imageInfo := ac.ParseDockerImage(image) imageInfo := ac.ParseDockerImage(image)
// 检查白名单(如果配置了白名单,则只允许白名单中的镜像) if len(cfg.Access.WhiteList) > 0 {
if len(cfg.Proxy.WhiteList) > 0 { if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) {
if !ac.matchImageInList(imageInfo, cfg.Proxy.WhiteList) {
return false, "不在Docker镜像白名单内" return false, "不在Docker镜像白名单内"
} }
} }
// 检查黑名单 if len(cfg.Access.BlackList) > 0 {
if len(cfg.Proxy.BlackList) > 0 { if ac.matchImageInList(imageInfo, cfg.Access.BlackList) {
if ac.matchImageInList(imageInfo, cfg.Proxy.BlackList) {
return false, "Docker镜像在黑名单内" return false, "Docker镜像在黑名单内"
} }
} }
@@ -107,15 +106,13 @@ func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, r
return false, "无效的GitHub仓库格式" return false, "无效的GitHub仓库格式"
} }
cfg := GetConfig() cfg := config.GetConfig()
// 检查白名单 if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) {
if len(cfg.Proxy.WhiteList) > 0 && !ac.checkList(matches, cfg.Proxy.WhiteList) {
return false, "不在GitHub仓库白名单内" return false, "不在GitHub仓库白名单内"
} }
// 检查黑名单 if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) {
if len(cfg.Proxy.BlackList) > 0 && ac.checkList(matches, cfg.Proxy.BlackList) {
return false, "GitHub仓库在黑名单内" return false, "GitHub仓库在黑名单内"
} }
@@ -185,17 +182,14 @@ func (ac *AccessController) checkList(matches, list []string) bool {
continue continue
} }
// 支持多种匹配模式
if fullRepo == item { if fullRepo == item {
return true return true
} }
// 用户级匹配
if item == username || item == username+"/*" { if item == username || item == username+"/*" {
return true return true
} }
// 前缀匹配(支持通配符)
if strings.HasSuffix(item, "*") { if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*") prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullRepo, prefix) { if strings.HasPrefix(fullRepo, prefix) {
@@ -203,18 +197,9 @@ func (ac *AccessController) checkList(matches, list []string) bool {
} }
} }
// 子仓库匹配(防止 user/repo 匹配到 user/repo-fork
if strings.HasPrefix(fullRepo, item+"/") { if strings.HasPrefix(fullRepo, item+"/") {
return true return true
} }
} }
return false return false
} }
// Reload 热重载访问控制规则
func (ac *AccessController) Reload() {
ac.mu.Lock()
defer ac.mu.Unlock()
// 访问控制器本身不缓存配置
}

View File

@@ -1,4 +1,4 @@
package main package utils
import ( import (
"crypto/md5" "crypto/md5"
@@ -9,22 +9,23 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"hubproxy/config"
) )
// CachedItem 通用缓存项支持Token和Manifest // CachedItem 通用缓存项
type CachedItem struct { type CachedItem struct {
Data []byte // 缓存数据(token字符串或manifest字节) Data []byte
ContentType string // 内容类型 ContentType string
Headers map[string]string // 额外的响应头 Headers map[string]string
ExpiresAt time.Time // 过期时间 ExpiresAt time.Time
} }
// UniversalCache 通用缓存支持Token和Manifest // UniversalCache 通用缓存
type UniversalCache struct { type UniversalCache struct {
cache sync.Map cache sync.Map
} }
var globalCache = &UniversalCache{} var GlobalCache = &UniversalCache{}
// Get 获取缓存项 // Get 获取缓存项
func (c *UniversalCache) Get(key string) *CachedItem { func (c *UniversalCache) Get(key string) *CachedItem {
@@ -57,30 +58,22 @@ func (c *UniversalCache) SetToken(key, token string, ttl time.Duration) {
c.Set(key, []byte(token), "application/json", nil, ttl) c.Set(key, []byte(token), "application/json", nil, ttl)
} }
// buildCacheKey 构建稳定的缓存key // BuildCacheKey 构建稳定的缓存key
func buildCacheKey(prefix, query string) string { func BuildCacheKey(prefix, query string) string {
return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query))) return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query)))
} }
func buildTokenCacheKey(query string) string { func BuildTokenCacheKey(query string) string {
return buildCacheKey("token", query) return BuildCacheKey("token", query)
} }
func buildManifestCacheKey(imageRef, reference string) string { func BuildManifestCacheKey(imageRef, reference string) string {
key := fmt.Sprintf("%s:%s", imageRef, reference) key := fmt.Sprintf("%s:%s", imageRef, reference)
return buildCacheKey("manifest", key) return BuildCacheKey("manifest", key)
} }
func buildManifestCacheKeyWithPlatform(imageRef, reference, platform string) string { func GetManifestTTL(reference string) time.Duration {
if platform == "" { cfg := config.GetConfig()
platform = "default"
}
key := fmt.Sprintf("%s:%s@%s", imageRef, reference, platform)
return buildCacheKey("manifest", key)
}
func getManifestTTL(reference string) time.Duration {
cfg := GetConfig()
defaultTTL := 30 * time.Minute defaultTTL := 30 * time.Minute
if cfg.TokenCache.DefaultTTL != "" { if cfg.TokenCache.DefaultTTL != "" {
if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil { if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil {
@@ -92,24 +85,20 @@ func getManifestTTL(reference string) time.Duration {
return 24 * time.Hour return 24 * time.Hour
} }
// mutable tag的智能判断
if reference == "latest" || reference == "main" || reference == "master" || if reference == "latest" || reference == "main" || reference == "master" ||
reference == "dev" || reference == "develop" { reference == "dev" || reference == "develop" {
// 热门可变标签: 短期缓存
return 10 * time.Minute return 10 * time.Minute
} }
// 普通tag: 中等缓存时间
return defaultTTL return defaultTTL
} }
// extractTTLFromResponse 从响应中智能提取TTL // ExtractTTLFromResponse 从响应中智能提取TTL
func extractTTLFromResponse(responseBody []byte) time.Duration { func ExtractTTLFromResponse(responseBody []byte) time.Duration {
var tokenResp struct { var tokenResp struct {
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
} }
// 默认30分钟TTL确保稳定性
defaultTTL := 30 * time.Minute defaultTTL := 30 * time.Minute
if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 { if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 {
@@ -122,32 +111,54 @@ func extractTTLFromResponse(responseBody []byte) time.Duration {
return defaultTTL return defaultTTL
} }
func writeTokenResponse(c *gin.Context, cachedBody string) { func WriteTokenResponse(c *gin.Context, cachedBody string) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
c.String(200, cachedBody) c.String(200, cachedBody)
} }
func writeCachedResponse(c *gin.Context, item *CachedItem) { func WriteCachedResponse(c *gin.Context, item *CachedItem) {
if item.ContentType != "" { if item.ContentType != "" {
c.Header("Content-Type", item.ContentType) c.Header("Content-Type", item.ContentType)
} }
// 设置额外的响应头
for key, value := range item.Headers { for key, value := range item.Headers {
c.Header(key, value) c.Header(key, value)
} }
// 返回数据
c.Data(200, item.ContentType, item.Data) c.Data(200, item.ContentType, item.Data)
} }
// isCacheEnabled 检查缓存是否启用 // IsCacheEnabled 检查缓存是否启用
func isCacheEnabled() bool { func IsCacheEnabled() bool {
cfg := GetConfig() cfg := config.GetConfig()
return cfg.TokenCache.Enabled return cfg.TokenCache.Enabled
} }
// isTokenCacheEnabled 检查token缓存是否启用(向后兼容) // IsTokenCacheEnabled 检查token缓存是否启用
func isTokenCacheEnabled() bool { func IsTokenCacheEnabled() bool {
return isCacheEnabled() return IsCacheEnabled()
}
// 定期清理过期缓存
func init() {
go func() {
ticker := time.NewTicker(20 * time.Minute)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
expiredKeys := make([]string, 0)
GlobalCache.cache.Range(func(key, value interface{}) bool {
if cached := value.(*CachedItem); now.After(cached.ExpiresAt) {
expiredKeys = append(expiredKeys, key.(string))
}
return true
})
for _, key := range expiredKeys {
GlobalCache.cache.Delete(key)
}
}
}()
} }

View File

@@ -1,23 +1,31 @@
package main package utils
import ( import (
"net" "net"
"net/http" "net/http"
"os"
"time" "time"
"hubproxy/config"
) )
var ( var (
// 全局HTTP客户端 - 用于代理请求(长超时)
globalHTTPClient *http.Client globalHTTPClient *http.Client
// 搜索HTTP客户端 - 用于API请求短超时
searchHTTPClient *http.Client searchHTTPClient *http.Client
) )
// initHTTPClients 初始化HTTP客户端 // InitHTTPClients 初始化HTTP客户端
func initHTTPClients() { func InitHTTPClients() {
// 代理客户端配置 - 适用于大文件传输 cfg := config.GetConfig()
if p := cfg.Access.Proxy; p != "" {
os.Setenv("HTTP_PROXY", p)
os.Setenv("HTTPS_PROXY", p)
}
globalHTTPClient = &http.Client{ globalHTTPClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
@@ -31,10 +39,10 @@ func initHTTPClients() {
}, },
} }
// 搜索客户端配置 - 适用于API调用
searchHTTPClient = &http.Client{ searchHTTPClient = &http.Client{
Timeout: 10 * time.Second, Timeout: 10 * time.Second,
Transport: &http.Transport{ Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
@@ -48,12 +56,12 @@ func initHTTPClients() {
} }
} }
// GetGlobalHTTPClient 获取全局HTTP客户端(用于代理) // GetGlobalHTTPClient 获取全局HTTP客户端
func GetGlobalHTTPClient() *http.Client { func GetGlobalHTTPClient() *http.Client {
return globalHTTPClient return globalHTTPClient
} }
// GetSearchHTTPClient 获取搜索HTTP客户端用于API调用 // GetSearchHTTPClient 获取搜索HTTP客户端
func GetSearchHTTPClient() *http.Client { func GetSearchHTTPClient() *http.Client {
return searchHTTPClient return searchHTTPClient
} }

View File

@@ -1,4 +1,4 @@
package main package utils
import ( import (
"bytes" "bytes"
@@ -41,7 +41,6 @@ func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reade
func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) { func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) {
var reader io.Reader = input var reader io.Reader = input
// 处理gzip压缩
if isCompressed { if isCompressed {
peek := make([]byte, 2) peek := make([]byte, 2)
n, err := input.Read(peek) n, err := input.Read(peek)

View File

@@ -1,4 +1,4 @@
package main package utils
import ( import (
"fmt" "fmt"
@@ -9,22 +9,23 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"hubproxy/config"
) )
const ( const (
// 清理间隔
CleanupInterval = 10 * time.Minute CleanupInterval = 10 * time.Minute
MaxIPCacheSize = 10000 MaxIPCacheSize = 10000
) )
// IPRateLimiter IP限流器结构体 // IPRateLimiter IP限流器结构体
type IPRateLimiter struct { type IPRateLimiter struct {
ips map[string]*rateLimiterEntry // IP到限流器的映射 ips map[string]*rateLimiterEntry
mu *sync.RWMutex // 读写锁,保证并发安全 mu *sync.RWMutex
r rate.Limit // 速率限制(每秒允许的请求数) r rate.Limit
b int // 令牌桶容量(突发请求数) b int
whitelist []*net.IPNet // 白名单IP段 whitelist []*net.IPNet
blacklist []*net.IPNet // 黑名单IP段 blacklist []*net.IPNet
whitelistLimiter *rate.Limiter // 全局共享的白名单限流器
} }
// rateLimiterEntry 限流器条目 // rateLimiterEntry 限流器条目
@@ -33,15 +34,15 @@ type rateLimiterEntry struct {
lastAccess time.Time lastAccess time.Time
} }
// initGlobalLimiter 初始化全局限流器 // InitGlobalLimiter 初始化全局限流器
func initGlobalLimiter() *IPRateLimiter { func InitGlobalLimiter() *IPRateLimiter {
cfg := GetConfig() cfg := config.GetConfig()
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList)) whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
for _, item := range cfg.Security.WhiteList { for _, item := range cfg.Security.WhiteList {
if item = strings.TrimSpace(item); item != "" { if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") { if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式 item = item + "/32"
} }
_, ipnet, err := net.ParseCIDR(item) _, ipnet, err := net.ParseCIDR(item)
if err == nil { if err == nil {
@@ -52,12 +53,11 @@ func initGlobalLimiter() *IPRateLimiter {
} }
} }
// 解析黑名单IP段
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList)) blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
for _, item := range cfg.Security.BlackList { for _, item := range cfg.Security.BlackList {
if item = strings.TrimSpace(item); item != "" { if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") { if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式 item = item + "/32"
} }
_, ipnet, err := net.ParseCIDR(item) _, ipnet, err := net.ParseCIDR(item)
if err == nil { if err == nil {
@@ -68,7 +68,6 @@ func initGlobalLimiter() *IPRateLimiter {
} }
} }
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600)) ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
burstSize := cfg.RateLimit.RequestLimit burstSize := cfg.RateLimit.RequestLimit
@@ -77,25 +76,20 @@ func initGlobalLimiter() *IPRateLimiter {
} }
limiter := &IPRateLimiter{ limiter := &IPRateLimiter{
ips: make(map[string]*rateLimiterEntry), ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
r: ratePerSecond, r: ratePerSecond,
b: burstSize, b: burstSize,
whitelist: whitelist, whitelist: whitelist,
blacklist: blacklist, blacklist: blacklist,
whitelistLimiter: rate.NewLimiter(rate.Inf, burstSize),
} }
// 启动定期清理goroutine
go limiter.cleanupRoutine() go limiter.cleanupRoutine()
return limiter return limiter
} }
// initLimiter 初始化限流器
func initLimiter() {
globalLimiter = initGlobalLimiter()
}
// cleanupRoutine 定期清理过期的限流器 // cleanupRoutine 定期清理过期的限流器
func (i *IPRateLimiter) cleanupRoutine() { func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(CleanupInterval) ticker := time.NewTicker(CleanupInterval)
@@ -105,25 +99,20 @@ func (i *IPRateLimiter) cleanupRoutine() {
now := time.Now() now := time.Now()
expired := make([]string, 0) expired := make([]string, 0)
// 查找过期的条目
i.mu.RLock() i.mu.RLock()
for ip, entry := range i.ips { for ip, entry := range i.ips {
// 如果最后访问时间超过1小时认为过期
if now.Sub(entry.lastAccess) > 1*time.Hour { if now.Sub(entry.lastAccess) > 1*time.Hour {
expired = append(expired, ip) expired = append(expired, ip)
} }
} }
i.mu.RUnlock() i.mu.RUnlock()
// 如果有过期条目或者缓存过大,进行清理
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize { if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
i.mu.Lock() i.mu.Lock()
// 删除过期条目
for _, ip := range expired { for _, ip := range expired {
delete(i.ips, ip) delete(i.ips, ip)
} }
// 如果缓存仍然过大,全部清理
if len(i.ips) > MaxIPCacheSize { if len(i.ips) > MaxIPCacheSize {
i.ips = make(map[string]*rateLimiterEntry) i.ips = make(map[string]*rateLimiterEntry)
} }
@@ -132,26 +121,34 @@ func (i *IPRateLimiter) cleanupRoutine() {
} }
} }
// extractIPFromAddress 从地址中提取纯IP,去除端口号 // extractIPFromAddress 从地址中提取纯IP
func extractIPFromAddress(address string) string { func extractIPFromAddress(address string) string {
// 处理IPv6地址 [::1]:8080 格式 if host, _, err := net.SplitHostPort(address); err == nil {
if strings.HasPrefix(address, "[") { return host
if endIndex := strings.Index(address, "]"); endIndex != -1 {
return address[1:endIndex]
}
} }
// 处理IPv4地址 192.168.1.1:8080 格式
if lastColon := strings.LastIndex(address, ":"); lastColon != -1 {
return address[:lastColon]
}
return address return address
} }
// normalizeIPForRateLimit 标准化IP地址用于限流
func normalizeIPForRateLimit(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ipStr
}
if ip.To4() != nil {
return ipStr
}
ipv6 := ip.To16()
for i := 8; i < 16; i++ {
ipv6[i] = 0
}
return ipv6.String() + "/64"
}
// isIPInCIDRList 检查IP是否在CIDR列表中 // isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
// 先提取纯IP地址
cleanIP := extractIPFromAddress(ip) cleanIP := extractIPFromAddress(ip)
parsedIP := net.ParseIP(cleanIP) parsedIP := net.ParseIP(cleanIP)
if parsedIP == nil { if parsedIP == nil {
@@ -166,30 +163,29 @@ func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
return false return false
} }
// GetLimiter 获取指定IP的限流器,同时返回是否允许访问 // GetLimiter 获取指定IP的限流器
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// 提取纯IP地址
cleanIP := extractIPFromAddress(ip) cleanIP := extractIPFromAddress(ip)
// 检查是否在黑名单中
if isIPInCIDRList(cleanIP, i.blacklist) { if isIPInCIDRList(cleanIP, i.blacklist) {
return nil, false return nil, false
} }
// 检查是否在白名单中
if isIPInCIDRList(cleanIP, i.whitelist) { if isIPInCIDRList(cleanIP, i.whitelist) {
return rate.NewLimiter(rate.Inf, i.b), true return i.whitelistLimiter, true
} }
normalizedIP := normalizeIPForRateLimit(cleanIP)
now := time.Now() now := time.Now()
i.mu.RLock() i.mu.RLock()
entry, exists := i.ips[cleanIP] entry, exists := i.ips[normalizedIP]
i.mu.RUnlock() i.mu.RUnlock()
if exists { if exists {
i.mu.Lock() i.mu.Lock()
if entry, stillExists := i.ips[cleanIP]; stillExists { if entry, stillExists := i.ips[normalizedIP]; stillExists {
entry.lastAccess = now entry.lastAccess = now
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true
@@ -198,7 +194,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
} }
i.mu.Lock() i.mu.Lock()
if entry, exists := i.ips[cleanIP]; exists { if entry, exists := i.ips[normalizedIP]; exists {
entry.lastAccess = now entry.lastAccess = now
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true
@@ -208,7 +204,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
limiter: rate.NewLimiter(i.r, i.b), limiter: rate.NewLimiter(i.r, i.b),
lastAccess: now, lastAccess: now,
} }
i.ips[cleanIP] = entry i.ips[normalizedIP] = entry
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true
@@ -217,40 +213,44 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// RateLimitMiddleware 速率限制中间件 // RateLimitMiddleware 速率限制中间件
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 获取客户端真实IP path := c.Request.URL.Path
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
strings.HasPrefix(path, "/public/") {
c.Next()
return
}
var ip string var ip string
// 优先尝试从请求头获取真实IP
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" { if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For可能包含多个IP取第一个
ips := strings.Split(forwarded, ",") ips := strings.Split(forwarded, ",")
ip = strings.TrimSpace(ips[0]) ip = strings.TrimSpace(ips[0])
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" { } else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
// 如果有X-Real-IP头
ip = realIP ip = realIP
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" { } else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
// 某些代理可能使用此头
ips := strings.Split(remoteIP, ",") ips := strings.Split(remoteIP, ",")
ip = strings.TrimSpace(ips[0]) ip = strings.TrimSpace(ips[0])
} else { } else {
// 回退到ClientIP方法
ip = c.ClientIP() ip = c.ClientIP()
} }
// 提取纯IP地址去除端口号
cleanIP := extractIPFromAddress(ip) cleanIP := extractIPFromAddress(ip)
// 日志记录请求IP和头信息 normalizedIP := normalizeIPForRateLimit(cleanIP)
fmt.Printf("请求IP: %s (去除端口后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", if cleanIP != normalizedIP {
ip, fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
cleanIP, ip, cleanIP, normalizedIP,
c.GetHeader("X-Forwarded-For"), c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP")) c.GetHeader("X-Real-IP"))
} else {
fmt.Printf("请求IP: %s (提纯后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
ip, cleanIP,
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"))
}
// 获取限流器并检查是否允许访问
ipLimiter, allowed := limiter.GetLimiter(cleanIP) ipLimiter, allowed := limiter.GetLimiter(cleanIP)
// 如果IP在黑名单中
if !allowed { if !allowed {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"error": "您已被限制访问", "error": "您已被限制访问",
@@ -259,11 +259,7 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return return
} }
// 智能限流判断:检查是否应该跳过限流计数 if !ipLimiter.Allow() {
shouldSkip := smartLimiter.ShouldSkipRateLimit(cleanIP, c.Request.URL.Path)
// 只有在不跳过的情况下才检查限流
if !shouldSkip && !ipLimiter.Allow() {
c.JSON(429, gin.H{ c.JSON(429, gin.H{
"error": "请求频率过快,暂时限制访问", "error": "请求频率过快,暂时限制访问",
}) })
@@ -274,26 +270,3 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
c.Next() c.Next()
} }
} }
// ApplyRateLimit 应用限流到特定路由
func ApplyRateLimit(router *gin.Engine, path string, method string, handler gin.HandlerFunc) {
// 使用全局限流器
limiter := globalLimiter
if limiter == nil {
limiter = initGlobalLimiter()
}
// 根据HTTP方法应用限流
switch method {
case "GET":
router.GET(path, RateLimitMiddleware(limiter), handler)
case "POST":
router.POST(path, RateLimitMiddleware(limiter), handler)
case "PUT":
router.PUT(path, RateLimitMiddleware(limiter), handler)
case "DELETE":
router.DELETE(path, RateLimitMiddleware(limiter), handler)
default:
router.Any(path, RateLimitMiddleware(limiter), handler)
}
}