重构优化

This commit is contained in:
starry
2025-06-11 11:58:57 +08:00
committed by GitHub
parent 86ca361057
commit a87f76dbd0
26 changed files with 5530 additions and 4569 deletions

View File

@@ -46,7 +46,7 @@ jobs:
- name: Build and push Docker image
run: |
cd ghproxy
cd src
docker buildx build --push \
--platform linux/amd64,linux/arm64 \
--tag ghcr.io/${{ env.REPO_LOWER }}:${{ env.VERSION }} \

View File

@@ -1,15 +0,0 @@
hub.{$DOMAIN} {
reverse_proxy * ghproxy:5000
}
docker.{$DOMAIN} {
@v2_manifest_blob path_regexp v2_rewrite ^/v2/([^/]+)/(manifests|blobs)/(.*)$
handle @v2_manifest_blob {
rewrite * /v2/library/{re.v2_rewrite.1}/{re.v2_rewrite.2}/{re.v2_rewrite.3}
}
reverse_proxy * docker:5000
}
ghcr.{$DOMAIN} {
reverse_proxy * ghcr:5000
}

View File

@@ -1,31 +0,0 @@
services:
caddy:
image: caddy:alpine
container_name: caddy
ports:
- "80:80"
- "443:443"
volumes:
- ./Caddyfile:/etc/caddy/Caddyfile
environment:
- DOMAIN=example.com # 修改为你的根域名
restart: always
ghcr:
image: "registry:2.8.3"
container_name: "ghcr"
restart: "always"
volumes:
- "./ghcr/config.yml:/etc/docker/registry/config.yml"
docker:
image: "registry:2.8.3"
container_name: "docker"
restart: "always"
volumes:
- "./docker/config.yml:/etc/docker/registry/config.yml"
ghproxy:
image: "ghcr.io/sky22333/hubproxy"
container_name: "ghproxy"
restart: "always"

View File

@@ -1,16 +0,0 @@
version: 0.1
storage:
filesystem:
rootdirectory: /var/lib/registry
delete:
enabled: true
maintenance:
uploadpurging:
enabled: true
age: 72h
dryrun: false
interval: 1m
http:
addr: 0.0.0.0:5000
proxy:
remoteurl: https://registry-1.docker.io

View File

@@ -1,16 +0,0 @@
version: 0.1
storage:
filesystem:
rootdirectory: /var/lib/registry
delete:
enabled: true
maintenance:
uploadpurging:
enabled: true
age: 72h
dryrun: false
interval: 1m
http:
addr: 0.0.0.0:5000
proxy:
remoteurl: https://ghcr.io

View File

@@ -1,8 +0,0 @@
{
"whiteList": [
],
"blackList": [
"example1",
"login"
]
}

View File

@@ -1,6 +0,0 @@
services:
ghproxy:
build: .
restart: always
ports:
- '5000:5000'

View File

@@ -1,11 +1,11 @@
FROM golang:1.23-alpine AS builder
FROM golang:1.24-alpine AS builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -trimpath -o ghproxy .
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -trimpath -o hubproxy .
FROM alpine
@@ -14,8 +14,8 @@ WORKDIR /root/
# 安装skopeo
RUN apk add --no-cache skopeo && mkdir -p temp && chmod 700 temp
COPY --from=builder /app/ghproxy .
COPY --from=builder /app/config.json .
COPY --from=builder /app/hubproxy .
COPY --from=builder /app/config.toml .
COPY --from=builder /app/public ./public
CMD ["./ghproxy"]
CMD ["./hubproxy"]

226
src/access_control.go Normal file
View File

@@ -0,0 +1,226 @@
package main
import (
"strings"
"sync"
)
// ResourceType 资源类型
type ResourceType string
const (
ResourceTypeGitHub ResourceType = "github"
ResourceTypeDocker ResourceType = "docker"
)
// AccessController 统一访问控制器
type AccessController struct {
mu sync.RWMutex
}
// DockerImageInfo Docker镜像信息
type DockerImageInfo struct {
Namespace string
Repository string
Tag string
FullName string
}
// 全局访问控制器实例
var GlobalAccessController = &AccessController{}
// ParseDockerImage 解析Docker镜像名称
func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo {
// 移除可能的协议前缀
image = strings.TrimPrefix(image, "docker://")
// 分离标签
var tag string
if idx := strings.LastIndex(image, ":"); idx != -1 {
// 检查是否是端口号而不是标签(包含斜杠)
part := image[idx+1:]
if !strings.Contains(part, "/") {
tag = part
image = image[:idx]
}
}
if tag == "" {
tag = "latest"
}
// 分离命名空间和仓库名
var namespace, repository string
if strings.Contains(image, "/") {
// 处理自定义registry的情况如 registry.com/user/repo
parts := strings.Split(image, "/")
if len(parts) >= 2 {
// 检查第一部分是否是域名(包含.
if strings.Contains(parts[0], ".") {
// 跳过registry域名取用户名和仓库名
if len(parts) >= 3 {
namespace = parts[1]
repository = parts[2]
} else {
namespace = "library"
repository = parts[1]
}
} else {
// 标准格式user/repo
namespace = parts[0]
repository = parts[1]
}
}
} else {
// 官方镜像,如 nginx
namespace = "library"
repository = image
}
fullName := namespace + "/" + repository
return DockerImageInfo{
Namespace: namespace,
Repository: repository,
Tag: tag,
FullName: fullName,
}
}
// CheckDockerAccess 检查Docker镜像访问权限
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
cfg := GetConfig()
// 解析镜像名称
imageInfo := ac.ParseDockerImage(image)
// 检查白名单(如果配置了白名单,则只允许白名单中的镜像)
if len(cfg.Proxy.WhiteList) > 0 {
if !ac.matchImageInList(imageInfo, cfg.Proxy.WhiteList) {
return false, "不在Docker镜像白名单内"
}
}
// 检查黑名单
if len(cfg.Proxy.BlackList) > 0 {
if ac.matchImageInList(imageInfo, cfg.Proxy.BlackList) {
return false, "Docker镜像在黑名单内"
}
}
return true, ""
}
// CheckGitHubAccess 检查GitHub仓库访问权限
func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, reason string) {
if len(matches) < 2 {
return false, "无效的GitHub仓库格式"
}
cfg := GetConfig()
// 检查白名单
if len(cfg.Proxy.WhiteList) > 0 && !ac.checkList(matches, cfg.Proxy.WhiteList) {
return false, "不在GitHub仓库白名单内"
}
// 检查黑名单
if len(cfg.Proxy.BlackList) > 0 && ac.checkList(matches, cfg.Proxy.BlackList) {
return false, "GitHub仓库在黑名单内"
}
return true, ""
}
// matchImageInList 检查Docker镜像是否在指定列表中
func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []string) bool {
fullName := strings.ToLower(imageInfo.FullName)
namespace := strings.ToLower(imageInfo.Namespace)
for _, item := range list {
item = strings.ToLower(strings.TrimSpace(item))
if item == "" {
continue
}
if fullName == item {
return true
}
if item == namespace || item == namespace+"/*" {
return true
}
if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullName, prefix) {
return true
}
}
if strings.HasPrefix(item, "*/") {
repoPattern := strings.TrimPrefix(item, "*/")
if strings.HasSuffix(repoPattern, "*") {
repoPrefix := strings.TrimSuffix(repoPattern, "*")
if strings.HasPrefix(imageInfo.Repository, repoPrefix) {
return true
}
} else {
if strings.ToLower(imageInfo.Repository) == repoPattern {
return true
}
}
}
// 5. 子仓库匹配(防止 user/repo 匹配到 user/repo-fork
if strings.HasPrefix(fullName, item+"/") {
return true
}
}
return false
}
// checkList GitHub仓库检查逻辑
func (ac *AccessController) checkList(matches, list []string) bool {
if len(matches) < 2 {
return false
}
// 组合用户名和仓库名,处理.git后缀
username := strings.ToLower(strings.TrimSpace(matches[0]))
repoName := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(matches[1], ".git")))
fullRepo := username + "/" + repoName
for _, item := range list {
item = strings.ToLower(strings.TrimSpace(item))
if item == "" {
continue
}
// 支持多种匹配模式:
// 1. 精确匹配: "vaxilu/x-ui"
// 2. 用户级匹配: "vaxilu/*" 或 "vaxilu"
// 3. 前缀匹配: "vaxilu/x-ui-*"
if fullRepo == item {
return true
}
// 用户级匹配
if item == username || item == username+"/*" {
return true
}
// 前缀匹配(支持通配符)
if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullRepo, prefix) {
return true
}
}
// 子仓库匹配(防止 user/repo 匹配到 user/repo-fork
if strings.HasPrefix(fullRepo, item+"/") {
return true
}
}
return false
}

195
src/config.go Normal file
View File

@@ -0,0 +1,195 @@
package main
import (
"fmt"
"os"
"strconv"
"strings"
"sync"
"github.com/pelletier/go-toml/v2"
)
// AppConfig 应用配置结构体
type AppConfig struct {
Server struct {
Host string `toml:"host"` // 监听地址
Port int `toml:"port"` // 监听端口
FileSize int64 `toml:"fileSize"` // 文件大小限制(字节)
} `toml:"server"`
RateLimit struct {
RequestLimit int `toml:"requestLimit"` // 每小时请求限制
PeriodHours float64 `toml:"periodHours"` // 限制周期(小时)
} `toml:"rateLimit"`
Security struct {
WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表
BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表
} `toml:"security"`
Proxy struct {
WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别)
BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别)
} `toml:"proxy"`
Download struct {
MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制
} `toml:"download"`
}
var (
appConfig *AppConfig
appConfigLock sync.RWMutex
)
// DefaultConfig 返回默认配置
func DefaultConfig() *AppConfig {
return &AppConfig{
Server: struct {
Host string `toml:"host"`
Port int `toml:"port"`
FileSize int64 `toml:"fileSize"`
}{
Host: "0.0.0.0",
Port: 5000,
FileSize: 2 * 1024 * 1024 * 1024, // 2GB
},
RateLimit: struct {
RequestLimit int `toml:"requestLimit"`
PeriodHours float64 `toml:"periodHours"`
}{
RequestLimit: 20,
PeriodHours: 1.0,
},
Security: struct {
WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"`
}{
WhiteList: []string{},
BlackList: []string{},
},
Proxy: struct {
WhiteList []string `toml:"whiteList"`
BlackList []string `toml:"blackList"`
}{
WhiteList: []string{},
BlackList: []string{},
},
Download: struct {
MaxImages int `toml:"maxImages"`
}{
MaxImages: 10, // 默认值最多同时下载10个镜像
},
}
}
// GetConfig 安全地获取配置副本
func GetConfig() *AppConfig {
appConfigLock.RLock()
defer appConfigLock.RUnlock()
if appConfig == nil {
return DefaultConfig()
}
// 返回配置的深拷贝
configCopy := *appConfig
configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...)
configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...)
configCopy.Proxy.WhiteList = append([]string(nil), appConfig.Proxy.WhiteList...)
configCopy.Proxy.BlackList = append([]string(nil), appConfig.Proxy.BlackList...)
return &configCopy
}
// setConfig 安全地设置配置
func setConfig(cfg *AppConfig) {
appConfigLock.Lock()
defer appConfigLock.Unlock()
appConfig = cfg
}
// LoadConfig 加载配置文件
func LoadConfig() error {
// 首先使用默认配置
cfg := DefaultConfig()
// 尝试加载TOML配置文件
if data, err := os.ReadFile("config.toml"); err == nil {
if err := toml.Unmarshal(data, cfg); err != nil {
return fmt.Errorf("解析配置文件失败: %v", err)
}
} else {
fmt.Println("未找到config.toml使用默认配置")
}
// 从环境变量覆盖配置
overrideFromEnv(cfg)
// 设置配置
setConfig(cfg)
fmt.Printf("配置加载成功: 监听 %s:%d, 文件大小限制 %d MB, 限流 %d请求/%g小时, 离线镜像并发数 %d\n",
cfg.Server.Host, cfg.Server.Port, cfg.Server.FileSize/(1024*1024),
cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours, cfg.Download.MaxImages)
return nil
}
// overrideFromEnv 从环境变量覆盖配置
func overrideFromEnv(cfg *AppConfig) {
// 服务器配置
if val := os.Getenv("SERVER_HOST"); val != "" {
cfg.Server.Host = val
}
if val := os.Getenv("SERVER_PORT"); val != "" {
if port, err := strconv.Atoi(val); err == nil && port > 0 {
cfg.Server.Port = port
}
}
if val := os.Getenv("MAX_FILE_SIZE"); val != "" {
if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 {
cfg.Server.FileSize = size
}
}
// 限流配置
if val := os.Getenv("RATE_LIMIT"); val != "" {
if limit, err := strconv.Atoi(val); err == nil && limit > 0 {
cfg.RateLimit.RequestLimit = limit
}
}
if val := os.Getenv("RATE_PERIOD_HOURS"); val != "" {
if period, err := strconv.ParseFloat(val, 64); err == nil && period > 0 {
cfg.RateLimit.PeriodHours = period
}
}
// IP限制配置
if val := os.Getenv("IP_WHITELIST"); val != "" {
cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...)
}
if val := os.Getenv("IP_BLACKLIST"); val != "" {
cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...)
}
// 下载限制配置
if val := os.Getenv("MAX_IMAGES"); val != "" {
if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 {
cfg.Download.MaxImages = maxImages
}
}
}
// CreateDefaultConfigFile 创建默认配置文件
func CreateDefaultConfigFile() error {
cfg := DefaultConfig()
data, err := toml.Marshal(cfg)
if err != nil {
return fmt.Errorf("序列化默认配置失败: %v", err)
}
return os.WriteFile("config.toml", data, 0644)
}

45
src/config.toml Normal file
View File

@@ -0,0 +1,45 @@
[server]
# 监听地址,默认监听所有接口
host = "0.0.0.0"
# 监听端口
port = 5000
# 文件大小限制字节默认2GB
fileSize = 2147483648
[rateLimit]
# 每个IP每小时允许的请求数
requestLimit = 200
# 限流周期(小时)
periodHours = 1.0
[security]
# IP白名单支持单个IP或CIDR格式
# 白名单中的IP不受限流限制
whiteList = [
"127.0.0.1",
"192.168.1.0/24"
]
# IP黑名单支持单个IP或CIDR格式
# 黑名单中的IP将被直接拒绝访问
blackList = [
"192.168.100.1"
]
[proxy]
# 代理服务白名单支持GitHub仓库和Docker镜像支持通配符
# 只允许访问白名单中的仓库/镜像,为空时不限制
whiteList = []
# 代理服务黑名单支持GitHub仓库和Docker镜像支持通配符
# 禁止访问黑名单中的仓库/镜像
blackList = [
"baduser/malicious-repo",
"thesadboy/x-ui",
"vaxilu/x-ui",
"vaxilu/*"
]
[download]
# 单次并发下载离线镜像数量限制
maxImages = 10

8
src/docker-compose.yml Normal file
View File

@@ -0,0 +1,8 @@
services:
ghproxy:
build: .
restart: always
ports:
- '5000:5000'
volumes:
- ./config.toml:/root/config.toml

323
src/docker.go Normal file
View File

@@ -0,0 +1,323 @@
package main
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
)
// DockerProxy Docker代理配置
type DockerProxy struct {
registry name.Registry
options []remote.Option
}
var dockerProxy *DockerProxy
// 初始化Docker代理
func initDockerProxy() {
// 创建目标registry
registry, err := name.NewRegistry("registry-1.docker.io")
if err != nil {
fmt.Printf("创建Docker registry失败: %v\n", err)
return
}
// 配置代理选项
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("ghproxy/go-containerregistry"),
}
dockerProxy = &DockerProxy{
registry: registry,
options: options,
}
fmt.Printf("Docker代理已初始化\n")
}
// ProxyDockerRegistryGin 标准Docker Registry API v2代理
func ProxyDockerRegistryGin(c *gin.Context) {
path := c.Request.URL.Path
// 处理 /v2/ API版本检查
if path == "/v2/" {
c.JSON(http.StatusOK, gin.H{})
return
}
// 处理不同的API端点
if strings.HasPrefix(path, "/v2/") {
handleRegistryRequest(c, path)
} else {
c.String(http.StatusNotFound, "Docker Registry API v2 only")
}
}
// handleRegistryRequest 处理Registry请求
func handleRegistryRequest(c *gin.Context, path string) {
// 移除 /v2/ 前缀
pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
// 解析路径
imageName, apiType, reference := parseRegistryPath(pathWithoutV2)
if imageName == "" || apiType == "" {
c.String(http.StatusBadRequest, "Invalid path format")
return
}
// 自动处理官方镜像的library命名空间
if !strings.Contains(imageName, "/") {
imageName = "library/" + imageName
}
// Docker镜像访问控制检查
if allowed, reason := GlobalAccessController.CheckDockerAccess(imageName); !allowed {
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
return
}
// 构建完整的镜像引用
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
switch apiType {
case "manifests":
handleManifestRequest(c, imageRef, reference)
case "blobs":
handleBlobRequest(c, imageRef, reference)
case "tags":
handleTagsRequest(c, imageRef)
default:
c.String(http.StatusNotFound, "API endpoint not found")
}
}
// parseRegistryPath 解析Registry路径
func parseRegistryPath(path string) (imageName, apiType, reference string) {
// 查找API端点关键字
if idx := strings.Index(path, "/manifests/"); idx != -1 {
imageName = path[:idx]
apiType = "manifests"
reference = path[idx+len("/manifests/"):]
return
}
if idx := strings.Index(path, "/blobs/"); idx != -1 {
imageName = path[:idx]
apiType = "blobs"
reference = path[idx+len("/blobs/"):]
return
}
if idx := strings.Index(path, "/tags/list"); idx != -1 {
imageName = path[:idx]
apiType = "tags"
reference = "list"
return
}
return "", "", ""
}
// handleManifestRequest 处理manifest请求
func handleManifestRequest(c *gin.Context, imageRef, reference string) {
var ref name.Reference
var err error
// 判断reference是digest还是tag
if strings.HasPrefix(reference, "sha256:") {
// 是digest
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
// 是tag
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
}
if err != nil {
fmt.Printf("解析镜像引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid reference")
return
}
// 根据请求方法选择操作
if c.Request.Method == http.MethodHead {
// HEAD请求使用remote.Head
desc, err := remote.Head(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("HEAD请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
// 设置响应头
c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
c.Status(http.StatusOK)
} else {
// GET请求使用remote.Get
desc, err := remote.Get(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("GET请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
// 设置响应头
c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", len(desc.Manifest)))
// 返回manifest内容
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
}
}
// handleBlobRequest 处理blob请求
func handleBlobRequest(c *gin.Context, imageRef, digest string) {
// 构建digest引用
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid digest reference")
return
}
// 使用remote.Layer获取layer
layer, err := remote.Layer(digestRef, dockerProxy.options...)
if err != nil {
fmt.Printf("获取layer失败: %v\n", err)
c.String(http.StatusNotFound, "Layer not found")
return
}
// 获取layer信息
size, err := layer.Size()
if err != nil {
fmt.Printf("获取layer大小失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer size")
return
}
// 获取layer内容
reader, err := layer.Compressed()
if err != nil {
fmt.Printf("获取layer内容失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer content")
return
}
defer reader.Close()
// 设置响应头
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Length", fmt.Sprintf("%d", size))
c.Header("Docker-Content-Digest", digest)
// 流式传输blob内容
c.Status(http.StatusOK)
io.Copy(c.Writer, reader)
}
// handleTagsRequest 处理tags列表请求
func handleTagsRequest(c *gin.Context, imageRef string) {
// 解析repository
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid repository")
return
}
// 使用remote.List获取tags
tags, err := remote.List(repo, dockerProxy.options...)
if err != nil {
fmt.Printf("获取tags失败: %v\n", err)
c.String(http.StatusNotFound, "Tags not found")
return
}
// 构建响应
response := map[string]interface{}{
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
"tags": tags,
}
c.JSON(http.StatusOK, response)
}
// ProxyDockerAuthGin Docker认证代理
func ProxyDockerAuthGin(c *gin.Context) {
// 构建认证URL
authURL := "https://auth.docker.io" + c.Request.URL.Path
if c.Request.URL.RawQuery != "" {
authURL += "?" + c.Request.URL.RawQuery
}
// 创建HTTP客户端
client := &http.Client{
Timeout: 30 * time.Second,
}
// 创建请求
req, err := http.NewRequestWithContext(
context.Background(),
c.Request.Method,
authURL,
c.Request.Body,
)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to create request")
return
}
// 复制请求头
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
// 执行请求
resp, err := client.Do(req)
if err != nil {
c.String(http.StatusBadGateway, "Auth request failed")
return
}
defer resp.Body.Close()
// 获取当前代理的Host地址
proxyHost := c.Request.Host
if proxyHost == "" {
// 使用配置中的服务器地址和端口
cfg := GetConfig()
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
if cfg.Server.Host == "0.0.0.0" {
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
}
}
// 复制响应头并重写认证URL
for key, values := range resp.Header {
for _, value := range values {
// 重写WWW-Authenticate头中的realm URL
if key == "Www-Authenticate" && strings.Contains(value, "auth.docker.io") {
value = strings.ReplaceAll(value, "https://auth.docker.io", "http://"+proxyHost)
}
c.Header(key, value)
}
}
// 返回响应
c.Status(resp.StatusCode)
io.Copy(c.Writer, resp.Body)
}

View File

@@ -1,12 +1,12 @@
module ghproxy
module hubproxy
go 1.23.0
toolchain go1.24.1
go 1.24.0
require (
github.com/gin-gonic/gin v1.10.0
github.com/google/go-containerregistry v0.20.5
github.com/gorilla/websocket v1.5.1
github.com/pelletier/go-toml/v2 v2.2.2
golang.org/x/sync v0.14.0
golang.org/x/time v0.11.0
)
@@ -16,6 +16,10 @@ require (
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect
github.com/docker/cli v28.1.1+incompatible // indirect
github.com/docker/distribution v2.8.3+incompatible // indirect
github.com/docker/docker-credential-helpers v0.9.3 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@@ -23,19 +27,25 @@ require (
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/vbatts/tar-split v0.12.1 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.15.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
google.golang.org/protobuf v1.36.3 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@@ -6,9 +6,17 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8=
github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/docker/cli v28.1.1+incompatible h1:eyUemzeI45DY7eDPuwUcmDyDj1pM98oD5MdSpiItp8k=
github.com/docker/cli v28.1.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8=
github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk=
github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w=
github.com/docker/docker-credential-helpers v0.9.3 h1:gAm/VtF9wgqJMoxzT3Gj5p4AqIjCBS4wrsOh9yRqcz8=
github.com/docker/docker-credential-helpers v0.9.3/go.mod h1:x+4Gbw9aGmChi3qTLZj8Dfn0TD20M/fuWy0E5+WDeCo=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
@@ -25,13 +33,17 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/go-containerregistry v0.20.5 h1:4RnlYcDs5hoA++CeFjlbZ/U9Yp1EuWr+UhhTyYQjOP0=
github.com/google/go-containerregistry v0.20.5/go.mod h1:Q14vdOOzug02bwnhMkZKD4e30pDaD9W65qzXpyzF49E=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
@@ -40,15 +52,25 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040=
github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -65,6 +87,8 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo=
github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
@@ -74,22 +98,23 @@ golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0=
golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
google.golang.org/protobuf v1.36.3 h1:82DV7MYdb8anAVi3qge1wSnMDrnKK7ebr+I0hHRN1BU=
google.golang.org/protobuf v1.36.3/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.0.3 h1:4AuOwCGf4lLR9u3YOe2awrHygurzhO/HeQ6laiA6Sx0=
gotest.tools/v3 v3.0.3/go.mod h1:Z7Lb0S5l+klDB31fvDQX8ss/FlKDxtlFlw3Oa8Ymbl8=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

59
src/http_client.go Normal file
View File

@@ -0,0 +1,59 @@
package main
import (
"net"
"net/http"
"time"
)
var (
// 全局HTTP客户端 - 用于代理请求(长超时)
globalHTTPClient *http.Client
// 搜索HTTP客户端 - 用于API请求短超时
searchHTTPClient *http.Client
)
// initHTTPClients 初始化HTTP客户端
func initHTTPClients() {
// 代理客户端配置 - 适用于大文件传输
globalHTTPClient = &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 300 * time.Second,
},
}
// 搜索客户端配置 - 适用于API调用
searchHTTPClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
DisableCompression: false,
},
}
}
// GetGlobalHTTPClient 获取全局HTTP客户端用于代理
func GetGlobalHTTPClient() *http.Client {
return globalHTTPClient
}
// GetSearchHTTPClient 获取搜索HTTP客户端用于API调用
func GetSearchHTTPClient() *http.Client {
return searchHTTPClient
}

View File

@@ -1,24 +1,13 @@
package main
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net"
"net/http"
"os"
"regexp"
"strconv"
"strings"
"sync"
"time"
)
const (
sizeLimit = 1024 * 1024 * 1024 * 2 // 允许的文件大小默认2GB
host = "0.0.0.0" // 监听地址
port = 5000 // 监听端口
)
var (
@@ -34,78 +23,68 @@ var (
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`),
}
httpClient *http.Client
config *Config
configLock sync.RWMutex
globalLimiter *IPRateLimiter
)
type Config struct {
WhiteList []string `json:"whiteList"`
BlackList []string `json:"blackList"`
}
func main() {
// 加载配置
if err := LoadConfig(); err != nil {
fmt.Printf("配置加载失败: %v\n", err)
return
}
// 初始化HTTP客户端
initHTTPClients()
// 初始化限流器
initLimiter()
// 初始化Docker流式代理
initDockerProxy()
gin.SetMode(gin.ReleaseMode)
router := gin.Default()
httpClient = &http.Client{
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 300 * time.Second,
},
}
loadConfig()
go func() {
for {
time.Sleep(10 * time.Minute)
loadConfig()
}
}()
// 初始化Skopeo相关路由 - 在任何通配符路由之前注册
// 初始化skopeo路由静态文件和API路由
initSkopeoRoutes(router)
// 单独处理根路径请求,避免冲突
// 单独处理根路径请求
router.GET("/", func(c *gin.Context) {
c.File("./public/index.html")
})
// 指定具体的静态文件路径,避免使用通配符
// 指定具体的静态文件路径
router.Static("/public", "./public")
// 对于.html等特定文件注册
router.GET("/skopeo.html", func(c *gin.Context) {
c.File("./public/skopeo.html")
})
router.GET("/search.html", func(c *gin.Context) {
c.File("./public/search.html")
})
// 图标文件
router.GET("/favicon.ico", func(c *gin.Context) {
c.File("./public/favicon.ico")
})
// 注册dockerhub搜索路由
RegisterSearchRoute(router)
// 创建GitHub文件下载专用的限流器
githubLimiter := NewIPRateLimiter()
// 注册Docker认证路由/token*
router.Any("/token", RateLimitMiddleware(globalLimiter), ProxyDockerAuthGin)
router.Any("/token/*path", RateLimitMiddleware(globalLimiter), ProxyDockerAuthGin)
// 注册Docker Registry代理路由
router.Any("/v2/*path", RateLimitMiddleware(globalLimiter), ProxyDockerRegistryGin)
// 注册NoRoute处理器应用限流中间件
router.NoRoute(RateLimitMiddleware(githubLimiter), handler)
router.NoRoute(RateLimitMiddleware(globalLimiter), handler)
err := router.Run(fmt.Sprintf("%s:%d", host, port))
cfg := GetConfig()
fmt.Printf("启动成功项目地址https://github.com/sky22333/hubproxy \n")
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port))
if err != nil {
fmt.Printf("Error starting server: %v\n", err)
fmt.Printf("启动服务失败: %v\n", err)
}
}
@@ -123,12 +102,17 @@ func handler(c *gin.Context) {
matches := checkURL(rawPath)
if matches != nil {
if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) {
c.String(http.StatusForbidden, "不在白名单内,限制访问。")
return
// GitHub仓库访问控制检查
if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed {
// 构建仓库名用于日志
var repoPath string
if len(matches) >= 2 {
username := matches[0]
repoName := strings.TrimSuffix(matches[1], ".git")
repoPath = username + "/" + repoName
}
if len(config.BlackList) > 0 && checkList(matches, config.BlackList) {
c.String(http.StatusForbidden, "黑名单限制访问")
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
c.String(http.StatusForbidden, reason)
return
}
} else {
@@ -143,7 +127,19 @@ func handler(c *gin.Context) {
proxy(c, rawPath)
}
func proxy(c *gin.Context, u string) {
proxyWithRedirect(c, u, 0)
}
func proxyWithRedirect(c *gin.Context, u string, redirectCount int) {
// 限制最大重定向次数,防止无限递归
const maxRedirects = 20
if redirectCount > maxRedirects {
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
return
}
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
@@ -157,22 +153,23 @@ func proxy(c *gin.Context, u string) {
}
req.Header.Del("Host")
resp, err := httpClient.Do(req)
resp, err := GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
defer func(Body io.ReadCloser) {
err := Body.Close()
if err != nil {
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭响应体失败: %v\n", err)
}
}(resp.Body)
}()
// 检查文件大小限制
cfg := GetConfig()
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
if size, err := strconv.Atoi(contentLength); err == nil && size > sizeLimit {
c.String(http.StatusRequestEntityTooLarge, "File too large.")
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
c.String(http.StatusRequestEntityTooLarge,
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
return
}
}
@@ -200,7 +197,8 @@ func proxy(c *gin.Context, u string) {
if checkURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxy(c, location)
// 递归处理重定向,增加计数防止无限循环
proxyWithRedirect(c, location, redirectCount+1)
return
}
}
@@ -236,31 +234,6 @@ func proxy(c *gin.Context, u string) {
}
}
func loadConfig() {
file, err := os.Open("config.json")
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
return
}
defer func(file *os.File) {
err := file.Close()
if err != nil {
}
}(file)
var newConfig Config
decoder := json.NewDecoder(file)
if err := decoder.Decode(&newConfig); err != nil {
fmt.Printf("Error decoding config: %v\n", err)
return
}
configLock.Lock()
config = &newConfig
configLock.Unlock()
}
func checkURL(u string) []string {
for _, exp := range exps {
if matches := exp.FindStringSubmatch(u); matches != nil {
@@ -270,11 +243,4 @@ func checkURL(u string) []string {
return nil
}
func checkList(matches, list []string) bool {
for _, item := range list {
if strings.HasPrefix(matches[0], item) {
return true
}
}
return false
}

View File

@@ -163,7 +163,7 @@ func processLine(line string, host string, lineNum int) string {
})
}
// modifyGitHubURL 修改GitHub URL添加代理域名前缀
// 判断代理域名前缀
func modifyGitHubURL(url string, host string) string {
for _, domain := range gitHubDomains {
hasHttps := strings.HasPrefix(url, "https://"+domain)

View File

Before

Width:  |  Height:  |  Size: 2.0 KiB

After

Width:  |  Height:  |  Size: 2.0 KiB

View File

@@ -3,8 +3,6 @@ package main
import (
"fmt"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
@@ -13,33 +11,14 @@ import (
"golang.org/x/time/rate"
)
// IP限流配置
var (
// 默认限流每个IP每1小时允许20个请求
DefaultRateLimit = 20.0 // 默认限制请求数
DefaultRatePeriodHours = 1.0 // 默认时间周期(小时)
// 白名单列表支持IP和CIDR格式"192.168.1.1", "10.0.0.0/8"
WhitelistIPs = []string{
"127.0.0.1", // 本地回环地址
"10.0.0.0/8", // 内网地址段
"172.16.0.0/12", // 内网地址段
"192.168.0.0/16", // 内网地址段
}
// 黑名单列表支持IP和CIDR格式
BlacklistIPs = []string{
// 示例: "1.2.3.4", "5.6.7.0/24"
}
// 清理间隔:多久清理一次过期的限流器
CleanupInterval = 1 * time.Hour
// IP限流器缓存上限超过此数量将触发清理
const (
// 清理间隔
CleanupInterval = 10 * time.Minute
// 最大IP缓存数量防止内存过度占用
MaxIPCacheSize = 10000
)
// IPRateLimiter 定义IP限流器结构
// IPRateLimiter IP限流器结构
type IPRateLimiter struct {
ips map[string]*rateLimiterEntry // IP到限流器的映射
mu *sync.RWMutex // 读写锁,保证并发安全
@@ -49,45 +28,20 @@ type IPRateLimiter struct {
blacklist []*net.IPNet // 黑名单IP段
}
// rateLimiterEntry 限流器条目,包含限流器和最后访问时间
// rateLimiterEntry 限流器条目
type rateLimiterEntry struct {
limiter *rate.Limiter // 限流器
lastAccess time.Time // 最后访问时间
}
// NewIPRateLimiter 创建新的IP限流器
func NewIPRateLimiter() *IPRateLimiter {
// 从环境变量读取限流配置(如果有)
rateLimit := DefaultRateLimit
ratePeriod := DefaultRatePeriodHours
if val, exists := os.LookupEnv("RATE_LIMIT"); exists {
if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 {
rateLimit = parsed
}
}
if val, exists := os.LookupEnv("RATE_PERIOD_HOURS"); exists {
if parsed, err := strconv.ParseFloat(val, 64); err == nil && parsed > 0 {
ratePeriod = parsed
}
}
// 从环境变量读取白名单(如果有)
whitelistIPs := WhitelistIPs
if val, exists := os.LookupEnv("IP_WHITELIST"); exists && val != "" {
whitelistIPs = append(whitelistIPs, strings.Split(val, ",")...)
}
// 从环境变量读取黑名单(如果有)
blacklistIPs := BlacklistIPs
if val, exists := os.LookupEnv("IP_BLACKLIST"); exists && val != "" {
blacklistIPs = append(blacklistIPs, strings.Split(val, ",")...)
}
// initGlobalLimiter 初始化全局限流器
func initGlobalLimiter() *IPRateLimiter {
// 获取配置
cfg := GetConfig()
// 解析白名单IP段
whitelist := make([]*net.IPNet, 0, len(whitelistIPs))
for _, item := range whitelistIPs {
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
for _, item := range cfg.Security.WhiteList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
@@ -95,13 +49,15 @@ func NewIPRateLimiter() *IPRateLimiter {
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
whitelist = append(whitelist, ipnet)
} else {
fmt.Printf("警告: 无效的白名单IP格式: %s\n", item)
}
}
}
// 解析黑名单IP段
blacklist := make([]*net.IPNet, 0, len(blacklistIPs))
for _, item := range blacklistIPs {
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
for _, item := range cfg.Security.BlackList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式
@@ -109,19 +65,26 @@ func NewIPRateLimiter() *IPRateLimiter {
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
blacklist = append(blacklist, ipnet)
} else {
fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item)
}
}
}
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
// rate.Limit的单位是每秒允许的请求数
ratePerSecond := rate.Limit(rateLimit / (ratePeriod * 3600))
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
// 令牌桶容量设置为最大突发请求数,建议设为限制值的一半以允许合理突发
burstSize := cfg.RateLimit.RequestLimit
if burstSize < 1 {
burstSize = 1 // 至少允许1个请求
}
limiter := &IPRateLimiter{
ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{},
r: ratePerSecond,
b: int(rateLimit), // 令牌桶容量设为允许的请求总数
b: burstSize,
whitelist: whitelist,
blacklist: blacklist,
}
@@ -129,9 +92,17 @@ func NewIPRateLimiter() *IPRateLimiter {
// 启动定期清理goroutine
go limiter.cleanupRoutine()
fmt.Printf("限流器初始化: %d请求/%g小时, 白名单 %d个, 黑名单 %d个\n",
cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours, len(whitelist), len(blacklist))
return limiter
}
// initLimiter 初始化限流器(保持向后兼容)
func initLimiter() {
globalLimiter = initGlobalLimiter()
}
// cleanupRoutine 定期清理过期的限流器
func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(CleanupInterval)
@@ -168,9 +139,29 @@ func (i *IPRateLimiter) cleanupRoutine() {
}
}
// extractIPFromAddress 从地址中提取纯IP去除端口号
func extractIPFromAddress(address string) string {
// 处理IPv6地址 [::1]:8080 格式
if strings.HasPrefix(address, "[") {
if endIndex := strings.Index(address, "]"); endIndex != -1 {
return address[1:endIndex]
}
}
// 处理IPv4地址 192.168.1.1:8080 格式
if lastColon := strings.LastIndex(address, ":"); lastColon != -1 {
return address[:lastColon]
}
// 如果没有端口号,直接返回
return address
}
// isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
parsedIP := net.ParseIP(ip)
// 先提取纯IP地址
cleanIP := extractIPFromAddress(ip)
parsedIP := net.ParseIP(cleanIP)
if parsedIP == nil {
return false
}
@@ -185,19 +176,22 @@ func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
// GetLimiter 获取指定IP的限流器同时返回是否允许访问
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// 提取纯IP地址
cleanIP := extractIPFromAddress(ip)
// 检查是否在黑名单中
if isIPInCIDRList(ip, i.blacklist) {
if isIPInCIDRList(cleanIP, i.blacklist) {
return nil, false // 黑名单中的IP不允许访问
}
// 检查是否在白名单中
if isIPInCIDRList(ip, i.whitelist) {
if isIPInCIDRList(cleanIP, i.whitelist) {
return rate.NewLimiter(rate.Inf, i.b), true // 白名单中的IP不受限制
}
// 从缓存获取限流器
// 使用纯IP作为缓存键
i.mu.RLock()
entry, exists := i.ips[ip]
entry, exists := i.ips[cleanIP]
i.mu.RUnlock()
now := time.Now()
@@ -209,7 +203,7 @@ func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
limiter: rate.NewLimiter(i.r, i.b),
lastAccess: now,
}
i.ips[ip] = entry
i.ips[cleanIP] = entry
i.mu.Unlock()
} else {
// 更新最后访问时间
@@ -244,14 +238,18 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
ip = c.ClientIP()
}
// 日志记录请求IP和头信息调试用
fmt.Printf("请求IP: %s, X-Forwarded-For: %s, X-Real-IP: %s\n",
// 提取纯IP地址去除端口号
cleanIP := extractIPFromAddress(ip)
// 日志记录请求IP和头信息
fmt.Printf("请求IP: %s (去除端口后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
ip,
cleanIP,
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"))
// 获取限流器并检查是否允许访问
ipLimiter, allowed := limiter.GetLimiter(ip)
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
// 如果IP在黑名单中
if !allowed {
@@ -278,8 +276,11 @@ func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
// ApplyRateLimit 应用限流到特定路由
func ApplyRateLimit(router *gin.Engine, path string, method string, handler gin.HandlerFunc) {
// 创建限流器(如果未创建)
limiter := NewIPRateLimiter()
// 使用全局限流器
limiter := globalLimiter
if limiter == nil {
limiter = initGlobalLimiter()
}
// 根据HTTP方法应用限流
switch method {

View File

@@ -89,17 +89,7 @@ var (
}
)
// 添加全局HTTP客户端配置
var defaultHTTPClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
DisableCompression: true,
DisableKeepAlives: false,
MaxIdleConnsPerHost: 10,
},
}
// HTTP客户端配置在 http_client.go 中统一管理
func (c *Cache) Get(key string) (interface{}, bool) {
c.mu.RLock()
@@ -267,68 +257,41 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
fullURL = fullURL + "?" + params.Encode()
// 创建请求
req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil)
// 使用统一的搜索HTTP客户端
resp, err := GetSearchHTTPClient().Get(fullURL)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %v", err)
return nil, fmt.Errorf("请求Docker Hub API失败: %v", err)
}
// 设置请求头
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
// 使用全局HTTP客户端
client := defaultHTTPClient
var result *SearchResult
var lastErr error
// 重试逻辑
for retries := 3; retries > 0; retries-- {
resp, err := client.Do(req)
if err != nil {
lastErr = fmt.Errorf("发送请求失败: %v", err)
if !isRetryableError(err) {
break
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭搜索响应体失败: %v\n", err)
}
time.Sleep(time.Second * time.Duration(4-retries))
continue
}
defer resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
lastErr = fmt.Errorf("读取响应失败: %v", err)
if !isRetryableError(err) {
break
}
time.Sleep(time.Second * time.Duration(4-retries))
continue
return nil, fmt.Errorf("读取响应失败: %v", err)
}
if resp.StatusCode != http.StatusOK {
switch resp.StatusCode {
case http.StatusTooManyRequests:
lastErr = fmt.Errorf("请求过于频繁,请稍后重试")
return nil, fmt.Errorf("请求过于频繁,请稍后重试")
case http.StatusNotFound:
if isUserRepo && namespace != "" {
// 如果用户仓库搜索失败,尝试普通搜索
return searchDockerHub(ctx, repoName, page, pageSize)
}
lastErr = fmt.Errorf("未找到相关镜像")
return nil, fmt.Errorf("未找到相关镜像")
case http.StatusBadGateway, http.StatusServiceUnavailable:
lastErr = fmt.Errorf("Docker Hub服务暂时不可用请稍后重试")
return nil, fmt.Errorf("Docker Hub服务暂时不可用请稍后重试")
default:
lastErr = fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body))
return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body))
}
if !isRetryableError(lastErr) {
break
}
time.Sleep(time.Second * time.Duration(4-retries))
continue
}
// 解析响应
var result *SearchResult
if isUserRepo && namespace != "" {
// 解析用户仓库列表响应
var userRepos struct {
@@ -338,8 +301,7 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
Results []Repository `json:"results"`
}
if err := json.Unmarshal(body, &userRepos); err != nil {
lastErr = fmt.Errorf("解析响应失败: %v", err)
break
return nil, fmt.Errorf("解析响应失败: %v", err)
}
// 转换为SearchResult格式
@@ -373,8 +335,7 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
// 解析普通搜索响应
result = &SearchResult{}
if err := json.Unmarshal(body, &result); err != nil {
lastErr = fmt.Errorf("解析响应失败: %v", err)
break
return nil, fmt.Errorf("解析响应失败: %v", err)
}
// 处理搜索结果
@@ -409,15 +370,6 @@ func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*Se
}
}
// 成功获取结果,跳出重试循环
lastErr = nil
break
}
if lastErr != nil {
return nil, fmt.Errorf("搜索失败: %v", lastErr)
}
// 缓存结果
searchCache.Set(cacheKey, result)
return result, nil
@@ -459,24 +411,16 @@ func getRepositoryTags(ctx context.Context, namespace, name string) ([]TagInfo,
fullURL := baseURL + "?" + params.Encode()
// 使用全局HTTP客户端
client := defaultHTTPClient
req, err := http.NewRequestWithContext(ctx, "GET", fullURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %v", err)
}
// 添加必要的请求头
req.Header.Set("Accept", "application/json")
req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
// 发送请求
resp, err := client.Do(req)
// 使用统一的搜索HTTP客户端
resp, err := GetSearchHTTPClient().Get(fullURL)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err)
}
defer resp.Body.Close()
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭搜索响应体失败: %v\n", err)
}
}()
// 读取响应体
body, err := io.ReadAll(resp.Body)

View File

@@ -55,6 +55,9 @@ type DownloadTask struct {
ProgressLock sync.RWMutex `json:"-"` // 进度锁
ImageLock sync.RWMutex `json:"-"` // 镜像列表锁
updateChan chan *ProgressUpdate `json:"-"` // 进度更新通道
done chan struct{} `json:"-"` // 用于安全关闭goroutine
once sync.Once `json:"-"` // 确保只关闭一次
createTime time.Time `json:"-"` // 创建时间,用于清理
}
// 进度更新消息
@@ -110,6 +113,12 @@ func initSkopeoRoutes(router *gin.Engine) {
// 启动清理过期文件的goroutine
go cleanupTempFiles()
// 启动WebSocket连接清理goroutine
go cleanupWebSocketConnections()
// 启动过期任务清理goroutine
go cleanupExpiredTasks()
}
// 处理WebSocket连接
@@ -146,6 +155,10 @@ func handleWebSocket(c *gin.Context) {
tasksLock.Unlock()
}
// 设置WebSocket超时
conn.SetReadDeadline(time.Now().Add(120 * time.Second))
conn.SetWriteDeadline(time.Now().Add(60 * time.Second))
// 处理WebSocket关闭
conn.SetCloseHandler(func(code int, text string) error {
client.CloseOnce.Do(func() {
@@ -231,22 +244,30 @@ func generateTaskID() string {
// 初始化任务并启动进度处理器
func initTask(task *DownloadTask) {
// 创建进度更新通道
// 创建进度更新通道和控制通道
task.updateChan = make(chan *ProgressUpdate, 100)
task.done = make(chan struct{})
task.createTime = time.Now()
// 启动进度处理goroutine
go func() {
for update := range task.updateChan {
defer func() {
if r := recover(); r != nil {
fmt.Printf("任务 %s 进度处理goroutine异常: %v\n", task.ID, r)
}
}()
// 处理消息的函数
processUpdate := func(update *ProgressUpdate) {
if update == nil {
// 通道关闭信号
break
return
}
// 获取更新的镜像
task.ImageLock.RLock()
if update.ImageIndex < 0 || update.ImageIndex >= len(task.Images) {
task.ImageLock.RUnlock()
continue
return
}
imgTask := task.Images[update.ImageIndex]
task.ImageLock.RUnlock()
@@ -291,11 +312,74 @@ func initTask(task *DownloadTask) {
// 发送更新到客户端
sendTaskUpdate(task)
}
// 主处理循环
for {
select {
case update := <-task.updateChan:
if update == nil {
// 通道关闭信号,直接退出
return
}
processUpdate(update)
case <-task.done:
// 收到关闭信号进入drain模式处理剩余消息
goto drainMode
}
}
drainMode:
// 处理通道中剩余的所有消息,确保不丢失任何更新
for {
select {
case update := <-task.updateChan:
if update == nil {
// 通道关闭,安全退出
return
}
processUpdate(update)
default:
// 没有更多待处理的消息,安全退出
return
}
}
}()
}
// 安全关闭任务的goroutine和通道
func (task *DownloadTask) Close() {
task.once.Do(func() {
close(task.done)
// 给一点时间让goroutine退出然后安全关闭updateChan
time.AfterFunc(100*time.Millisecond, func() {
task.safeCloseUpdateChan()
})
})
}
// 安全关闭updateChan防止重复关闭
func (task *DownloadTask) safeCloseUpdateChan() {
defer func() {
if r := recover(); r != nil {
// 捕获关闭已关闭channel的panic忽略它
fmt.Printf("任务 %s: updateChan已经关闭\n", task.ID)
}
}()
close(task.updateChan)
}
// 发送进度更新
func sendProgressUpdate(task *DownloadTask, index int, progress float64, status string, errorMsg string) {
// 检查任务是否已经关闭
select {
case <-task.done:
// 任务已关闭,不发送更新
return
default:
}
// 安全发送进度更新
select {
case task.updateChan <- &ProgressUpdate{
TaskID: task.ID,
@@ -305,6 +389,9 @@ func sendProgressUpdate(task *DownloadTask, index int, progress float64, status
Error: errorMsg,
}:
// 成功发送
case <-task.done:
// 在发送过程中任务被关闭
return
default:
// 通道已满,丢弃更新
fmt.Printf("Warning: Update channel for task %s is full\n", task.ID)
@@ -360,9 +447,29 @@ func handleDownload(c *gin.Context) {
return
}
// 添加镜像数量限制10个防止恶意刷流量
if len(req.Images) > 10 {
c.JSON(http.StatusBadRequest, gin.H{"error": "您下载的数量太多,宝宝承受不住"})
// Docker镜像访问控制检查
for _, image := range req.Images {
if allowed, reason := GlobalAccessController.CheckDockerAccess(image); !allowed {
fmt.Printf("Docker镜像 %s 下载被拒绝: %s\n", image, reason)
c.JSON(http.StatusForbidden, gin.H{
"error": fmt.Sprintf("镜像 %s 访问被限制: %s", image, reason),
})
return
}
}
// 获取配置中的镜像数量限制
cfg := GetConfig()
maxImages := cfg.Download.MaxImages
if maxImages <= 0 {
maxImages = 10 // 安全默认值,防止配置错误
}
// 检查镜像数量限制,防止恶意刷流量
if len(req.Images) > maxImages {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("单次下载镜像数量超过限制,最多允许 %d 个镜像", maxImages),
})
return
}
@@ -400,9 +507,11 @@ func handleDownload(c *gin.Context) {
// 异步处理下载
go func() {
defer func() {
// 任务完成后安全关闭更新通道
task.safeCloseUpdateChan()
}()
processDownloadTask(task, req.Platform)
// 任务完成后关闭更新通道
close(task.updateChan)
}()
c.JSON(http.StatusOK, gin.H{
@@ -482,6 +591,8 @@ func processDownloadTask(task *DownloadTask, platform string) {
task.Status = StatusFailed
task.StatusLock.Unlock()
sendTaskUpdate(task)
// 任务失败时关闭goroutine
task.Close()
return
}
@@ -557,6 +668,9 @@ func processDownloadTask(task *DownloadTask, platform string) {
// 确保所有进度都达到100%
ensureTaskCompletion(task)
// 任务完成时关闭goroutine
task.Close()
fmt.Printf("任务 %s 全部完成: %d/%d\n", task.ID, task.CompletedCount, task.TotalCount)
}
@@ -737,12 +851,29 @@ func downloadImageWithContext(ctx context.Context, task *DownloadTask, index int
// 读取标准输出
go func() {
defer func() {
// 确保pipe在goroutine退出时关闭
stdout.Close()
}()
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
for {
// 检查context是否已取消
select {
case <-ctx.Done():
return
default:
}
if !scanner.Scan() {
break // EOF或错误正常退出
}
output := scanner.Text()
fmt.Printf("镜像 %s 标准输出: %s\n", imgTask.Image, output)
select {
case outputChan <- output:
case <-ctx.Done():
return
default:
// 通道已满,丢弃
}
@@ -751,12 +882,29 @@ func downloadImageWithContext(ctx context.Context, task *DownloadTask, index int
// 读取错误输出
go func() {
defer func() {
// 确保pipe在goroutine退出时关闭
stderr.Close()
}()
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
for {
// 检查context是否已取消
select {
case <-ctx.Done():
return
default:
}
if !scanner.Scan() {
break // EOF或错误正常退出
}
output := scanner.Text()
fmt.Printf("镜像 %s 错误输出: %s\n", imgTask.Image, output)
select {
case outputChan <- output:
case <-ctx.Done():
return
default:
// 通道已满,丢弃
}
@@ -911,17 +1059,15 @@ func sendTaskUpdate(task *DownloadTask) {
}
}
// 发送单个镜像更新 - 保持兼容性
func sendImageUpdate(task *DownloadTask, imageIndex int) {
sendTaskUpdate(task)
}
// 提供文件下载
func serveFile(c *gin.Context) {
filename := c.Param("filename")
// 安全检查,防止任意文件访问
if strings.Contains(filename, "..") {
// 增强安全检查,防止路径遍历攻击
if strings.Contains(filename, "..") ||
strings.Contains(filename, "/") ||
strings.Contains(filename, "\\") ||
strings.Contains(filename, "\x00") {
c.JSON(http.StatusForbidden, gin.H{"error": "无效的文件名"})
return
}
@@ -1118,3 +1264,108 @@ func checkForCompletionMarkers(output string) bool {
return false
}
// cleanupWebSocketConnections 定期清理无效的WebSocket连接
func cleanupWebSocketConnections() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
clientLock.Lock()
disconnectedClients := make([]string, 0)
for taskID, client := range clients {
// 检查连接是否仍然活跃
if err := client.Conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
// 连接已断开,标记待清理
disconnectedClients = append(disconnectedClients, taskID)
}
}
// 清理断开的连接
for _, taskID := range disconnectedClients {
if client, exists := clients[taskID]; exists {
client.CloseOnce.Do(func() {
close(client.Send)
client.Conn.Close()
})
delete(clients, taskID)
}
}
clientLock.Unlock()
if len(disconnectedClients) > 0 {
fmt.Printf("清理了 %d 个断开的WebSocket连接\n", len(disconnectedClients))
}
}
}
// cleanupExpiredTasks 清理过期任务
func cleanupExpiredTasks() {
ticker := time.NewTicker(30 * time.Minute) // 每30分钟清理一次
defer ticker.Stop()
for range ticker.C {
now := time.Now()
expiredTasks := make([]string, 0)
tasksLock.Lock()
for taskID, task := range tasks {
// 清理超过2小时的已完成任务或超过6小时的任何任务
isExpired := false
task.StatusLock.RLock()
taskStatus := task.Status
task.StatusLock.RUnlock()
// 已完成或失败的任务2小时后清理
if (taskStatus == StatusCompleted || taskStatus == StatusFailed) &&
now.Sub(task.createTime) > 2*time.Hour {
isExpired = true
}
// 任何任务6小时后强制清理
if now.Sub(task.createTime) > 6*time.Hour {
isExpired = true
}
if isExpired {
expiredTasks = append(expiredTasks, taskID)
}
}
// 清理过期任务
for _, taskID := range expiredTasks {
if task, exists := tasks[taskID]; exists {
// 安全关闭任务的goroutine
task.Close()
// 清理临时文件
if task.TempDir != "" {
os.RemoveAll(task.TempDir)
}
if task.OutputFile != "" && fileExists(task.OutputFile) {
os.Remove(task.OutputFile)
}
delete(tasks, taskID)
}
}
tasksLock.Unlock()
if len(expiredTasks) > 0 {
fmt.Printf("清理了 %d 个过期任务\n", len(expiredTasks))
}
// 输出统计信息
tasksLock.Lock()
activeTaskCount := len(tasks)
tasksLock.Unlock()
clientLock.Lock()
activeClientCount := len(clients)
clientLock.Unlock()
fmt.Printf("当前活跃任务: %d, 活跃WebSocket连接: %d\n", activeTaskCount, activeClientCount)
}
}