Merge pull request #25 from beck-8/me/op_proxy

优化代理配置
This commit is contained in:
starry
2025-06-19 23:00:44 +08:00
committed by GitHub
14 changed files with 2592 additions and 2628 deletions

4
.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
.idea
.vscode
.DS_Store
hubproxy*

View File

@@ -138,11 +138,17 @@ blackList = [
"baduser/*" "baduser/*"
] ]
# SOCKS5代理配置,支持有用户名/密码认证和无认证模式 # 代理配置,支持有用户名/密码认证和无认证模式
# 无认证: socks5://127.0.0.1:1080 # 无认证: socks5://127.0.0.1:1080
# 有认证: socks5://username:password@127.0.0.1:1080 # 有认证: socks5://username:password@127.0.0.1:1080
# HTTP 代理示例
# http://username:password@127.0.0.1:7890
# SOCKS5 代理示例
# socks5://username:password@127.0.0.1:1080
# SOCKS5H 代理示例
# socks5h://username:password@127.0.0.1:1080
# 留空不使用代理 # 留空不使用代理
socks5 = "" proxy = ""
[download] [download]
# 批量下载离线镜像数量限制 # 批量下载离线镜像数量限制

View File

@@ -1,214 +1,212 @@
package main package main
import ( import (
"strings" "strings"
"sync" "sync"
) )
// ResourceType 资源类型 // ResourceType 资源类型
type ResourceType string type ResourceType string
const ( const (
ResourceTypeGitHub ResourceType = "github" ResourceTypeGitHub ResourceType = "github"
ResourceTypeDocker ResourceType = "docker" ResourceTypeDocker ResourceType = "docker"
) )
// AccessController 统一访问控制器 // AccessController 统一访问控制器
type AccessController struct { type AccessController struct {
mu sync.RWMutex mu sync.RWMutex
} }
// DockerImageInfo Docker镜像信息 // DockerImageInfo Docker镜像信息
type DockerImageInfo struct { type DockerImageInfo struct {
Namespace string Namespace string
Repository string Repository string
Tag string Tag string
FullName string FullName string
} }
// 全局访问控制器实例 // 全局访问控制器实例
var GlobalAccessController = &AccessController{} var GlobalAccessController = &AccessController{}
// ParseDockerImage 解析Docker镜像名称 // ParseDockerImage 解析Docker镜像名称
func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo { func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo {
image = strings.TrimPrefix(image, "docker://") image = strings.TrimPrefix(image, "docker://")
var tag string var tag string
if idx := strings.LastIndex(image, ":"); idx != -1 { if idx := strings.LastIndex(image, ":"); idx != -1 {
part := image[idx+1:] part := image[idx+1:]
if !strings.Contains(part, "/") { if !strings.Contains(part, "/") {
tag = part tag = part
image = image[:idx] image = image[:idx]
} }
} }
if tag == "" { if tag == "" {
tag = "latest" tag = "latest"
} }
var namespace, repository string var namespace, repository string
if strings.Contains(image, "/") { if strings.Contains(image, "/") {
parts := strings.Split(image, "/") parts := strings.Split(image, "/")
if len(parts) >= 2 { if len(parts) >= 2 {
if strings.Contains(parts[0], ".") { if strings.Contains(parts[0], ".") {
if len(parts) >= 3 { if len(parts) >= 3 {
namespace = parts[1] namespace = parts[1]
repository = parts[2] repository = parts[2]
} else { } else {
namespace = "library" namespace = "library"
repository = parts[1] repository = parts[1]
} }
} else { } else {
namespace = parts[0] namespace = parts[0]
repository = parts[1] repository = parts[1]
} }
} }
} else { } else {
namespace = "library" namespace = "library"
repository = image repository = image
} }
fullName := namespace + "/" + repository fullName := namespace + "/" + repository
return DockerImageInfo{ return DockerImageInfo{
Namespace: namespace, Namespace: namespace,
Repository: repository, Repository: repository,
Tag: tag, Tag: tag,
FullName: fullName, FullName: fullName,
} }
} }
// CheckDockerAccess 检查Docker镜像访问权限 // CheckDockerAccess 检查Docker镜像访问权限
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) { func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
cfg := GetConfig() cfg := GetConfig()
// 解析镜像名称 // 解析镜像名称
imageInfo := ac.ParseDockerImage(image) imageInfo := ac.ParseDockerImage(image)
// 检查白名单(如果配置了白名单,则只允许白名单中的镜像) // 检查白名单(如果配置了白名单,则只允许白名单中的镜像)
if len(cfg.Proxy.WhiteList) > 0 { if len(cfg.Access.WhiteList) > 0 {
if !ac.matchImageInList(imageInfo, cfg.Proxy.WhiteList) { if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) {
return false, "不在Docker镜像白名单内" return false, "不在Docker镜像白名单内"
} }
} }
// 检查黑名单 // 检查黑名单
if len(cfg.Proxy.BlackList) > 0 { if len(cfg.Access.BlackList) > 0 {
if ac.matchImageInList(imageInfo, cfg.Proxy.BlackList) { if ac.matchImageInList(imageInfo, cfg.Access.BlackList) {
return false, "Docker镜像在黑名单内" return false, "Docker镜像在黑名单内"
} }
} }
return true, "" return true, ""
} }
// CheckGitHubAccess 检查GitHub仓库访问权限 // CheckGitHubAccess 检查GitHub仓库访问权限
func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, reason string) { func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, reason string) {
if len(matches) < 2 { if len(matches) < 2 {
return false, "无效的GitHub仓库格式" return false, "无效的GitHub仓库格式"
} }
cfg := GetConfig() cfg := GetConfig()
// 检查白名单 // 检查白名单
if len(cfg.Proxy.WhiteList) > 0 && !ac.checkList(matches, cfg.Proxy.WhiteList) { if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) {
return false, "不在GitHub仓库白名单内" return false, "不在GitHub仓库白名单内"
} }
// 检查黑名单 // 检查黑名单
if len(cfg.Proxy.BlackList) > 0 && ac.checkList(matches, cfg.Proxy.BlackList) { if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) {
return false, "GitHub仓库在黑名单内" return false, "GitHub仓库在黑名单内"
} }
return true, "" return true, ""
} }
// matchImageInList 检查Docker镜像是否在指定列表中 // matchImageInList 检查Docker镜像是否在指定列表中
func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []string) bool { func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []string) bool {
fullName := strings.ToLower(imageInfo.FullName) fullName := strings.ToLower(imageInfo.FullName)
namespace := strings.ToLower(imageInfo.Namespace) namespace := strings.ToLower(imageInfo.Namespace)
for _, item := range list { for _, item := range list {
item = strings.ToLower(strings.TrimSpace(item)) item = strings.ToLower(strings.TrimSpace(item))
if item == "" { if item == "" {
continue continue
} }
if fullName == item { if fullName == item {
return true return true
} }
if item == namespace || item == namespace+"/*" { if item == namespace || item == namespace+"/*" {
return true return true
} }
if strings.HasSuffix(item, "*") { if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*") prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullName, prefix) { if strings.HasPrefix(fullName, prefix) {
return true return true
} }
} }
if strings.HasPrefix(item, "*/") { if strings.HasPrefix(item, "*/") {
repoPattern := strings.TrimPrefix(item, "*/") repoPattern := strings.TrimPrefix(item, "*/")
if strings.HasSuffix(repoPattern, "*") { if strings.HasSuffix(repoPattern, "*") {
repoPrefix := strings.TrimSuffix(repoPattern, "*") repoPrefix := strings.TrimSuffix(repoPattern, "*")
if strings.HasPrefix(imageInfo.Repository, repoPrefix) { if strings.HasPrefix(imageInfo.Repository, repoPrefix) {
return true return true
} }
} else { } else {
if strings.ToLower(imageInfo.Repository) == repoPattern { if strings.ToLower(imageInfo.Repository) == repoPattern {
return true return true
} }
} }
} }
if strings.HasPrefix(fullName, item+"/") { if strings.HasPrefix(fullName, item+"/") {
return true return true
} }
} }
return false return false
} }
// checkList GitHub仓库检查逻辑 // checkList GitHub仓库检查逻辑
func (ac *AccessController) checkList(matches, list []string) bool { func (ac *AccessController) checkList(matches, list []string) bool {
if len(matches) < 2 { if len(matches) < 2 {
return false return false
} }
username := strings.ToLower(strings.TrimSpace(matches[0])) username := strings.ToLower(strings.TrimSpace(matches[0]))
repoName := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(matches[1], ".git"))) repoName := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(matches[1], ".git")))
fullRepo := username + "/" + repoName fullRepo := username + "/" + repoName
for _, item := range list { for _, item := range list {
item = strings.ToLower(strings.TrimSpace(item)) item = strings.ToLower(strings.TrimSpace(item))
if item == "" { if item == "" {
continue continue
} }
// 支持多种匹配模式 // 支持多种匹配模式
if fullRepo == item { if fullRepo == item {
return true return true
} }
// 用户级匹配 // 用户级匹配
if item == username || item == username+"/*" { if item == username || item == username+"/*" {
return true return true
} }
// 前缀匹配(支持通配符) // 前缀匹配(支持通配符)
if strings.HasSuffix(item, "*") { if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*") prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullRepo, prefix) { if strings.HasPrefix(fullRepo, prefix) {
return true return true
} }
} }
// 子仓库匹配(防止 user/repo 匹配到 user/repo-fork // 子仓库匹配(防止 user/repo 匹配到 user/repo-fork
if strings.HasPrefix(fullRepo, item+"/") { if strings.HasPrefix(fullRepo, item+"/") {
return true return true
} }
} }
return false return false
} }

View File

@@ -1,275 +1,276 @@
package main package main
import ( import (
"fmt" "fmt"
"os" "os"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/pelletier/go-toml/v2" "github.com/pelletier/go-toml/v2"
) )
// RegistryMapping Registry映射配置 // RegistryMapping Registry映射配置
type RegistryMapping struct { type RegistryMapping struct {
Upstream string `toml:"upstream"` // 上游Registry地址 Upstream string `toml:"upstream"` // 上游Registry地址
AuthHost string `toml:"authHost"` // 认证服务器地址 AuthHost string `toml:"authHost"` // 认证服务器地址
AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic AuthType string `toml:"authType"` // 认证类型: docker/github/google/basic
Enabled bool `toml:"enabled"` // 是否启用 Enabled bool `toml:"enabled"` // 是否启用
} }
// AppConfig 应用配置结构体 // AppConfig 应用配置结构体
type AppConfig struct { type AppConfig struct {
Server struct { Server struct {
Host string `toml:"host"` // 监听地址 Host string `toml:"host"` // 监听地址
Port int `toml:"port"` // 监听端口 Port int `toml:"port"` // 监听端口
FileSize int64 `toml:"fileSize"` // 文件大小限制(字节) FileSize int64 `toml:"fileSize"` // 文件大小限制(字节)
} `toml:"server"` } `toml:"server"`
RateLimit struct { RateLimit struct {
RequestLimit int `toml:"requestLimit"` // 每小时请求限制 RequestLimit int `toml:"requestLimit"` // 每小时请求限制
PeriodHours float64 `toml:"periodHours"` // 限制周期(小时) PeriodHours float64 `toml:"periodHours"` // 限制周期(小时)
} `toml:"rateLimit"` } `toml:"rateLimit"`
Security struct { Security struct {
WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表 WhiteList []string `toml:"whiteList"` // 白名单IP/CIDR列表
BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表 BlackList []string `toml:"blackList"` // 黑名单IP/CIDR列表
} `toml:"security"` } `toml:"security"`
Proxy struct { Access struct {
WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别) WhiteList []string `toml:"whiteList"` // 代理白名单(仓库级别)
BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别) BlackList []string `toml:"blackList"` // 代理黑名单(仓库级别)
Socks5 string `toml:"socks5"` // SOCKS5代理地址: socks5://[user:pass@]host:port Proxy string `toml:"proxy"` // 代理地址: 支持 http/https/socks5/socks5h
} `toml:"proxy"` } `toml:"proxy"`
Download struct { Download struct {
MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制 MaxImages int `toml:"maxImages"` // 单次下载最大镜像数量限制
} `toml:"download"` } `toml:"download"`
Registries map[string]RegistryMapping `toml:"registries"` Registries map[string]RegistryMapping `toml:"registries"`
TokenCache struct { TokenCache struct {
Enabled bool `toml:"enabled"` // 是否启用token缓存 Enabled bool `toml:"enabled"` // 是否启用token缓存
DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间 DefaultTTL string `toml:"defaultTTL"` // 默认缓存时间
} `toml:"tokenCache"` } `toml:"tokenCache"`
} }
var ( var (
appConfig *AppConfig appConfig *AppConfig
appConfigLock sync.RWMutex appConfigLock sync.RWMutex
cachedConfig *AppConfig cachedConfig *AppConfig
configCacheTime time.Time configCacheTime time.Time
configCacheTTL = 5 * time.Second configCacheTTL = 5 * time.Second
configCacheMutex sync.RWMutex configCacheMutex sync.RWMutex
) )
// DefaultConfig 返回默认配置 // todo:Refactoring is needed
func DefaultConfig() *AppConfig { // DefaultConfig 返回默认配置
return &AppConfig{ func DefaultConfig() *AppConfig {
Server: struct { return &AppConfig{
Host string `toml:"host"` Server: struct {
Port int `toml:"port"` Host string `toml:"host"`
FileSize int64 `toml:"fileSize"` Port int `toml:"port"`
}{ FileSize int64 `toml:"fileSize"`
Host: "0.0.0.0", }{
Port: 5000, Host: "0.0.0.0",
FileSize: 2 * 1024 * 1024 * 1024, // 2GB Port: 5000,
}, FileSize: 2 * 1024 * 1024 * 1024, // 2GB
RateLimit: struct { },
RequestLimit int `toml:"requestLimit"` RateLimit: struct {
PeriodHours float64 `toml:"periodHours"` RequestLimit int `toml:"requestLimit"`
}{ PeriodHours float64 `toml:"periodHours"`
RequestLimit: 20, }{
PeriodHours: 1.0, RequestLimit: 20,
}, PeriodHours: 1.0,
Security: struct { },
WhiteList []string `toml:"whiteList"` Security: struct {
BlackList []string `toml:"blackList"` WhiteList []string `toml:"whiteList"`
}{ BlackList []string `toml:"blackList"`
WhiteList: []string{}, }{
BlackList: []string{}, WhiteList: []string{},
}, BlackList: []string{},
Proxy: struct { },
WhiteList []string `toml:"whiteList"` Access: struct {
BlackList []string `toml:"blackList"` WhiteList []string `toml:"whiteList"`
Socks5 string `toml:"socks5"` BlackList []string `toml:"blackList"`
}{ Proxy string `toml:"proxy"`
WhiteList: []string{}, }{
BlackList: []string{}, WhiteList: []string{},
Socks5: "", // 默认不使用代理 BlackList: []string{},
}, Proxy: "", // 默认不使用代理
Download: struct { },
MaxImages int `toml:"maxImages"` Download: struct {
}{ MaxImages int `toml:"maxImages"`
MaxImages: 10, // 默认值最多同时下载10个镜像 }{
}, MaxImages: 10, // 默认值最多同时下载10个镜像
Registries: map[string]RegistryMapping{ },
"ghcr.io": { Registries: map[string]RegistryMapping{
Upstream: "ghcr.io", "ghcr.io": {
AuthHost: "ghcr.io/token", Upstream: "ghcr.io",
AuthType: "github", AuthHost: "ghcr.io/token",
Enabled: true, AuthType: "github",
}, Enabled: true,
"gcr.io": { },
Upstream: "gcr.io", "gcr.io": {
AuthHost: "gcr.io/v2/token", Upstream: "gcr.io",
AuthType: "google", AuthHost: "gcr.io/v2/token",
Enabled: true, AuthType: "google",
}, Enabled: true,
"quay.io": { },
Upstream: "quay.io", "quay.io": {
AuthHost: "quay.io/v2/auth", Upstream: "quay.io",
AuthType: "quay", AuthHost: "quay.io/v2/auth",
Enabled: true, AuthType: "quay",
}, Enabled: true,
"registry.k8s.io": { },
Upstream: "registry.k8s.io", "registry.k8s.io": {
AuthHost: "registry.k8s.io", Upstream: "registry.k8s.io",
AuthType: "anonymous", AuthHost: "registry.k8s.io",
Enabled: true, AuthType: "anonymous",
}, Enabled: true,
}, },
TokenCache: struct { },
Enabled bool `toml:"enabled"` TokenCache: struct {
DefaultTTL string `toml:"defaultTTL"` Enabled bool `toml:"enabled"`
}{ DefaultTTL string `toml:"defaultTTL"`
Enabled: true, // docker认证的匿名Token缓存配置用于提升性能 }{
DefaultTTL: "20m", Enabled: true, // docker认证的匿名Token缓存配置用于提升性能
}, DefaultTTL: "20m",
} },
} }
}
// GetConfig 安全地获取配置副本
func GetConfig() *AppConfig { // GetConfig 安全地获取配置副本
configCacheMutex.RLock() func GetConfig() *AppConfig {
if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL { configCacheMutex.RLock()
config := cachedConfig if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL {
configCacheMutex.RUnlock() config := cachedConfig
return config configCacheMutex.RUnlock()
} return config
configCacheMutex.RUnlock() }
configCacheMutex.RUnlock()
// 缓存过期,重新生成配置
configCacheMutex.Lock() // 缓存过期,重新生成配置
defer configCacheMutex.Unlock() configCacheMutex.Lock()
defer configCacheMutex.Unlock()
// 双重检查,防止重复生成
if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL { // 双重检查,防止重复生成
return cachedConfig if cachedConfig != nil && time.Since(configCacheTime) < configCacheTTL {
} return cachedConfig
}
appConfigLock.RLock()
if appConfig == nil { appConfigLock.RLock()
appConfigLock.RUnlock() if appConfig == nil {
defaultCfg := DefaultConfig() appConfigLock.RUnlock()
cachedConfig = defaultCfg defaultCfg := DefaultConfig()
configCacheTime = time.Now() cachedConfig = defaultCfg
return defaultCfg configCacheTime = time.Now()
} return defaultCfg
}
// 生成新的配置深拷贝
configCopy := *appConfig // 生成新的配置深拷贝
configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...) configCopy := *appConfig
configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...) configCopy.Security.WhiteList = append([]string(nil), appConfig.Security.WhiteList...)
configCopy.Proxy.WhiteList = append([]string(nil), appConfig.Proxy.WhiteList...) configCopy.Security.BlackList = append([]string(nil), appConfig.Security.BlackList...)
configCopy.Proxy.BlackList = append([]string(nil), appConfig.Proxy.BlackList...) configCopy.Access.WhiteList = append([]string(nil), appConfig.Access.WhiteList...)
appConfigLock.RUnlock() configCopy.Access.BlackList = append([]string(nil), appConfig.Access.BlackList...)
appConfigLock.RUnlock()
cachedConfig = &configCopy
configCacheTime = time.Now() cachedConfig = &configCopy
configCacheTime = time.Now()
return cachedConfig
} return cachedConfig
}
// setConfig 安全地设置配置
func setConfig(cfg *AppConfig) { // setConfig 安全地设置配置
appConfigLock.Lock() func setConfig(cfg *AppConfig) {
defer appConfigLock.Unlock() appConfigLock.Lock()
appConfig = cfg defer appConfigLock.Unlock()
appConfig = cfg
configCacheMutex.Lock()
cachedConfig = nil configCacheMutex.Lock()
configCacheMutex.Unlock() cachedConfig = nil
} configCacheMutex.Unlock()
}
// LoadConfig 加载配置文件
func LoadConfig() error { // LoadConfig 加载配置文件
// 首先使用默认配置 func LoadConfig() error {
cfg := DefaultConfig() // 首先使用默认配置
cfg := DefaultConfig()
// 尝试加载TOML配置文件
if data, err := os.ReadFile("config.toml"); err == nil { // 尝试加载TOML配置文件
if err := toml.Unmarshal(data, cfg); err != nil { if data, err := os.ReadFile("config.toml"); err == nil {
return fmt.Errorf("解析配置文件失败: %v", err) if err := toml.Unmarshal(data, cfg); err != nil {
} return fmt.Errorf("解析配置文件失败: %v", err)
} else { }
fmt.Println("未找到config.toml使用默认配置") } else {
} fmt.Println("未找到config.toml使用默认配置")
}
// 从环境变量覆盖配置
overrideFromEnv(cfg) // 从环境变量覆盖配置
overrideFromEnv(cfg)
// 设置配置
setConfig(cfg) // 设置配置
setConfig(cfg)
return nil
} return nil
}
// overrideFromEnv 从环境变量覆盖配置
func overrideFromEnv(cfg *AppConfig) { // overrideFromEnv 从环境变量覆盖配置
// 服务器配置 func overrideFromEnv(cfg *AppConfig) {
if val := os.Getenv("SERVER_HOST"); val != "" { // 服务器配置
cfg.Server.Host = val if val := os.Getenv("SERVER_HOST"); val != "" {
} cfg.Server.Host = val
if val := os.Getenv("SERVER_PORT"); val != "" { }
if port, err := strconv.Atoi(val); err == nil && port > 0 { if val := os.Getenv("SERVER_PORT"); val != "" {
cfg.Server.Port = port if port, err := strconv.Atoi(val); err == nil && port > 0 {
} cfg.Server.Port = port
} }
if val := os.Getenv("MAX_FILE_SIZE"); val != "" { }
if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 { if val := os.Getenv("MAX_FILE_SIZE"); val != "" {
cfg.Server.FileSize = size if size, err := strconv.ParseInt(val, 10, 64); err == nil && size > 0 {
} cfg.Server.FileSize = size
} }
}
// 限流配置
if val := os.Getenv("RATE_LIMIT"); val != "" { // 限流配置
if limit, err := strconv.Atoi(val); err == nil && limit > 0 { if val := os.Getenv("RATE_LIMIT"); val != "" {
cfg.RateLimit.RequestLimit = limit if limit, err := strconv.Atoi(val); err == nil && limit > 0 {
} cfg.RateLimit.RequestLimit = limit
} }
if val := os.Getenv("RATE_PERIOD_HOURS"); val != "" { }
if period, err := strconv.ParseFloat(val, 64); err == nil && period > 0 { if val := os.Getenv("RATE_PERIOD_HOURS"); val != "" {
cfg.RateLimit.PeriodHours = period if period, err := strconv.ParseFloat(val, 64); err == nil && period > 0 {
} cfg.RateLimit.PeriodHours = period
} }
}
// IP限制配置
if val := os.Getenv("IP_WHITELIST"); val != "" { // IP限制配置
cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...) if val := os.Getenv("IP_WHITELIST"); val != "" {
} cfg.Security.WhiteList = append(cfg.Security.WhiteList, strings.Split(val, ",")...)
if val := os.Getenv("IP_BLACKLIST"); val != "" { }
cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...) if val := os.Getenv("IP_BLACKLIST"); val != "" {
} cfg.Security.BlackList = append(cfg.Security.BlackList, strings.Split(val, ",")...)
}
// 下载限制配置
if val := os.Getenv("MAX_IMAGES"); val != "" { // 下载限制配置
if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 { if val := os.Getenv("MAX_IMAGES"); val != "" {
cfg.Download.MaxImages = maxImages if maxImages, err := strconv.Atoi(val); err == nil && maxImages > 0 {
} cfg.Download.MaxImages = maxImages
} }
} }
}
// CreateDefaultConfigFile 创建默认配置文件
func CreateDefaultConfigFile() error { // CreateDefaultConfigFile 创建默认配置文件
cfg := DefaultConfig() func CreateDefaultConfigFile() error {
cfg := DefaultConfig()
data, err := toml.Marshal(cfg)
if err != nil { data, err := toml.Marshal(cfg)
return fmt.Errorf("序列化默认配置失败: %v", err) if err != nil {
} return fmt.Errorf("序列化默认配置失败: %v", err)
}
return os.WriteFile("config.toml", data, 0644)
} return os.WriteFile("config.toml", data, 0644)
}

View File

@@ -26,7 +26,7 @@ blackList = [
"192.168.100.0/24" "192.168.100.0/24"
] ]
[proxy] [access]
# 代理服务白名单支持GitHub仓库和Docker镜像支持通配符 # 代理服务白名单支持GitHub仓库和Docker镜像支持通配符
# 只允许访问白名单中的仓库/镜像,为空时不限制 # 只允许访问白名单中的仓库/镜像,为空时不限制
whiteList = [] whiteList = []
@@ -39,11 +39,17 @@ blackList = [
"baduser/*" "baduser/*"
] ]
# SOCKS5代理配置,支持有用户名/密码认证和无认证模式 # 代理配置,支持有用户名/密码认证和无认证模式
# 无认证: socks5://127.0.0.1:1080 # 无认证: socks5://127.0.0.1:1080
# 有认证: socks5://username:password@127.0.0.1:1080 # 有认证: socks5://username:password@127.0.0.1:1080
# HTTP 代理示例
# http://username:password@127.0.0.1:7890
# SOCKS5 代理示例
# socks5://username:password@127.0.0.1:1080
# SOCKS5H 代理示例
# socks5h://username:password@127.0.0.1:1080
# 留空不使用代理 # 留空不使用代理
socks5 = "" proxy = ""
[download] [download]
# 批量下载离线镜像数量限制 # 批量下载离线镜像数量限制

File diff suppressed because it is too large Load Diff

View File

@@ -6,7 +6,6 @@ require (
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/go-containerregistry v0.20.5 github.com/google/go-containerregistry v0.20.5
github.com/pelletier/go-toml/v2 v2.2.3 github.com/pelletier/go-toml/v2 v2.2.3
golang.org/x/net v0.33.0
golang.org/x/time v0.11.0 golang.org/x/time v0.11.0
) )
@@ -44,6 +43,7 @@ require (
github.com/vbatts/tar-split v0.12.1 // indirect github.com/vbatts/tar-split v0.12.1 // indirect
golang.org/x/arch v0.8.0 // indirect golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.32.0 // indirect golang.org/x/crypto v0.32.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.14.0 // indirect golang.org/x/sync v0.14.0 // indirect
golang.org/x/sys v0.33.0 // indirect golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.21.0 // indirect

View File

@@ -1,113 +1,68 @@
package main package main
import ( import (
"context" "net"
"log" "net/http"
"net" "os"
"net/http" "time"
"net/url" )
"time"
var (
"golang.org/x/net/proxy" // 全局HTTP客户端 - 用于代理请求(长超时)
) globalHTTPClient *http.Client
// 搜索HTTP客户端 - 用于API请求短超时
var ( searchHTTPClient *http.Client
// 全局HTTP客户端 - 用于代理请求(长超时) )
globalHTTPClient *http.Client
// 搜索HTTP客户端 - 用于API请求短超时 // initHTTPClients 初始化HTTP客户端
searchHTTPClient *http.Client func initHTTPClients() {
) cfg := GetConfig()
// initHTTPClients 初始化HTTP客户端 if p := cfg.Access.Proxy; p != "" {
func initHTTPClients() { os.Setenv("HTTP_PROXY", p)
cfg := GetConfig() os.Setenv("HTTPS_PROXY", p)
}
// 创建DialContext函数支持SOCKS5代理 // 代理客户端配置 - 适用于大文件传输
createDialContext := func(timeout time.Duration) func(ctx context.Context, network, addr string) (net.Conn, error) { globalHTTPClient = &http.Client{
if cfg.Proxy.Socks5 == "" { Transport: &http.Transport{
// 没有配置代理,使用直连 Proxy: http.ProxyFromEnvironment,
dialer := &net.Dialer{ DialContext: (&net.Dialer{
Timeout: timeout, Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second, KeepAlive: 30 * time.Second,
} }).DialContext,
return dialer.DialContext MaxIdleConns: 1000,
} MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 90 * time.Second,
// 解析SOCKS5代理URL TLSHandshakeTimeout: 10 * time.Second,
proxyURL, err := url.Parse(cfg.Proxy.Socks5) ExpectContinueTimeout: 1 * time.Second,
if err != nil { ResponseHeaderTimeout: 300 * time.Second,
log.Printf("SOCKS5代理配置错误使用直连: %v", err) },
dialer := &net.Dialer{ }
Timeout: timeout,
KeepAlive: 30 * time.Second, // 搜索客户端配置 - 适用于API调用
} searchHTTPClient = &http.Client{
return dialer.DialContext Timeout: 10 * time.Second,
} Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
// 创建基础dialer DialContext: (&net.Dialer{
baseDialer := &net.Dialer{ Timeout: 5 * time.Second,
Timeout: timeout, KeepAlive: 30 * time.Second,
KeepAlive: 30 * time.Second, }).DialContext,
} MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
// 创建SOCKS5代理dialer IdleConnTimeout: 90 * time.Second,
var auth *proxy.Auth TLSHandshakeTimeout: 5 * time.Second,
if proxyURL.User != nil { DisableCompression: false,
if password, ok := proxyURL.User.Password(); ok { },
auth = &proxy.Auth{ }
User: proxyURL.User.Username(), }
Password: password,
} // GetGlobalHTTPClient 获取全局HTTP客户端用于代理
} func GetGlobalHTTPClient() *http.Client {
} return globalHTTPClient
}
socks5Dialer, err := proxy.SOCKS5("tcp", proxyURL.Host, auth, baseDialer)
if err != nil { // GetSearchHTTPClient 获取搜索HTTP客户端用于API调用
log.Printf("创建SOCKS5代理失败使用直连: %v", err) func GetSearchHTTPClient() *http.Client {
return baseDialer.DialContext return searchHTTPClient
} }
log.Printf("使用SOCKS5代理: %s", proxyURL.Host)
// 返回带上下文的dial函数
return func(ctx context.Context, network, addr string) (net.Conn, error) {
return socks5Dialer.Dial(network, addr)
}
}
// 代理客户端配置 - 适用于大文件传输
globalHTTPClient = &http.Client{
Transport: &http.Transport{
DialContext: createDialContext(30 * time.Second),
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 300 * time.Second,
},
}
// 搜索客户端配置 - 适用于API调用
searchHTTPClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
DialContext: createDialContext(5 * time.Second),
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
DisableCompression: false,
},
}
}
// GetGlobalHTTPClient 获取全局HTTP客户端用于代理
func GetGlobalHTTPClient() *http.Client {
return globalHTTPClient
}
// GetSearchHTTPClient 获取搜索HTTP客户端用于API调用
func GetSearchHTTPClient() *http.Client {
return searchHTTPClient
}

View File

@@ -52,28 +52,28 @@ func NewDownloadDebouncer(window time.Duration) *DownloadDebouncer {
func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool { func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool {
d.mu.Lock() d.mu.Lock()
defer d.mu.Unlock() defer d.mu.Unlock()
key := userID + ":" + contentKey key := userID + ":" + contentKey
now := time.Now() now := time.Now()
if entry, exists := d.entries[key]; exists { if entry, exists := d.entries[key]; exists {
if now.Sub(entry.LastRequest) < d.window { if now.Sub(entry.LastRequest) < d.window {
return false // 在防抖窗口内,拒绝请求 return false // 在防抖窗口内,拒绝请求
} }
} }
// 更新或创建条目 // 更新或创建条目
d.entries[key] = &DebounceEntry{ d.entries[key] = &DebounceEntry{
LastRequest: now, LastRequest: now,
UserID: userID, UserID: userID,
} }
// 清理过期条目每5分钟清理一次 // 清理过期条目每5分钟清理一次
if time.Since(d.lastCleanup) > 5*time.Minute { if time.Since(d.lastCleanup) > 5*time.Minute {
d.cleanup(now) d.cleanup(now)
d.lastCleanup = now d.lastCleanup = now
} }
return true return true
} }
@@ -92,10 +92,10 @@ func generateContentFingerprint(images []string, platform string) string {
sortedImages := make([]string, len(images)) sortedImages := make([]string, len(images))
copy(sortedImages, images) copy(sortedImages, images)
sort.Strings(sortedImages) sort.Strings(sortedImages)
// 组合内容:镜像列表 + 平台信息 // 组合内容:镜像列表 + 平台信息
content := strings.Join(sortedImages, "|") + ":" + platform content := strings.Join(sortedImages, "|") + ":" + platform
// 生成MD5哈希 // 生成MD5哈希
hash := md5.Sum([]byte(content)) hash := md5.Sum([]byte(content))
return hex.EncodeToString(hash[:]) return hex.EncodeToString(hash[:])
@@ -107,14 +107,14 @@ func getUserID(c *gin.Context) string {
if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" { if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" {
return "session:" + sessionID return "session:" + sessionID
} }
// 备用方案IP + User-Agent组合 // 备用方案IP + User-Agent组合
ip := c.ClientIP() ip := c.ClientIP()
userAgent := c.GetHeader("User-Agent") userAgent := c.GetHeader("User-Agent")
if userAgent == "" { if userAgent == "" {
userAgent = "unknown" userAgent = "unknown"
} }
// 生成简短标识 // 生成简短标识
combined := ip + ":" + userAgent combined := ip + ":" + userAgent
hash := md5.Sum([]byte(combined)) hash := md5.Sum([]byte(combined))
@@ -228,7 +228,7 @@ func (is *ImageStreamer) StreamImageToGin(ctx context.Context, imageRef string,
filename := strings.ReplaceAll(imageRef, "/", "_") + ".tar" filename := strings.ReplaceAll(imageRef, "/", "_") + ".tar"
c.Header("Content-Type", "application/octet-stream") c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename)) c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename))
if options.Compression { if options.Compression {
c.Header("Content-Encoding", "gzip") c.Header("Content-Encoding", "gzip")
} }
@@ -295,18 +295,18 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
if err != nil { if err != nil {
return err return err
} }
configData, err := json.Marshal(configFile) configData, err := json.Marshal(configFile)
if err != nil { if err != nil {
return err return err
} }
configHeader := &tar.Header{ configHeader := &tar.Header{
Name: configDigest.String() + ".json", Name: configDigest.String() + ".json",
Size: int64(len(configData)), Size: int64(len(configData)),
Mode: 0644, Mode: 0644,
} }
if err := tarWriter.WriteHeader(configHeader); err != nil { if err := tarWriter.WriteHeader(configHeader); err != nil {
return err return err
} }
@@ -335,14 +335,14 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
Typeflag: tar.TypeDir, Typeflag: tar.TypeDir,
Mode: 0755, Mode: 0755,
} }
if err := tarWriter.WriteHeader(layerHeader); err != nil { if err := tarWriter.WriteHeader(layerHeader); err != nil {
return err return err
} }
var layerSize int64 var layerSize int64
var layerReader io.ReadCloser var layerReader io.ReadCloser
// 根据配置选择使用压缩层或未压缩层 // 根据配置选择使用压缩层或未压缩层
if options != nil && options.UseCompressedLayers { if options != nil && options.UseCompressedLayers {
layerSize, err = layer.Size() layerSize, err = layer.Size()
@@ -357,7 +357,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
} }
layerReader, err = layer.Uncompressed() layerReader, err = layer.Uncompressed()
} }
if err != nil { if err != nil {
return err return err
} }
@@ -368,7 +368,7 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
Size: layerSize, Size: layerSize,
Mode: 0644, Mode: 0644,
} }
if err := tarWriter.WriteHeader(layerTarHeader); err != nil { if err := tarWriter.WriteHeader(layerTarHeader); err != nil {
return err return err
} }
@@ -385,12 +385,11 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
log.Printf("已处理层 %d/%d", i+1, len(layers)) log.Printf("已处理层 %d/%d", i+1, len(layers))
} }
// 构建单个镜像的manifest信息 // 构建单个镜像的manifest信息
singleManifest := map[string]interface{}{ singleManifest := map[string]interface{}{
"Config": configDigest.String() + ".json", "Config": configDigest.String() + ".json",
"RepoTags": []string{imageRef}, "RepoTags": []string{imageRef},
"Layers": func() []string { "Layers": func() []string {
var layers []string var layers []string
for _, digest := range layerDigests { for _, digest := range layerDigests {
layers = append(layers, digest+"/layer.tar") layers = append(layers, digest+"/layer.tar")
@@ -417,22 +416,22 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
// 单镜像下载直接写入manifest.json // 单镜像下载直接写入manifest.json
manifest := []map[string]interface{}{singleManifest} manifest := []map[string]interface{}{singleManifest}
manifestData, err := json.Marshal(manifest) manifestData, err := json.Marshal(manifest)
if err != nil { if err != nil {
return err return err
} }
manifestHeader := &tar.Header{ manifestHeader := &tar.Header{
Name: "manifest.json", Name: "manifest.json",
Size: int64(len(manifestData)), Size: int64(len(manifestData)),
Mode: 0644, Mode: 0644,
} }
if err := tarWriter.WriteHeader(manifestHeader); err != nil { if err := tarWriter.WriteHeader(manifestHeader); err != nil {
return err return err
} }
if _, err := tarWriter.Write(manifestData); err != nil { if _, err := tarWriter.Write(manifestData); err != nil {
return err return err
} }
@@ -442,17 +441,17 @@ func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWr
if err != nil { if err != nil {
return err return err
} }
repositoriesHeader := &tar.Header{ repositoriesHeader := &tar.Header{
Name: "repositories", Name: "repositories",
Size: int64(len(repositoriesData)), Size: int64(len(repositoriesData)),
Mode: 0644, Mode: 0644,
} }
if err := tarWriter.WriteHeader(repositoriesHeader); err != nil { if err := tarWriter.WriteHeader(repositoriesHeader); err != nil {
return err return err
} }
_, err = tarWriter.Write(repositoriesData) _, err = tarWriter.Write(repositoriesData)
return err return err
} }
@@ -473,12 +472,12 @@ func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image,
var manifest map[string]interface{} var manifest map[string]interface{}
var repositories map[string]map[string]string var repositories map[string]map[string]string
err = is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, &manifest, &repositories, options) err = is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, &manifest, &repositories, options)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
return manifest, repositories, nil return manifest, repositories, nil
} }
@@ -537,7 +536,7 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
if m.Platform == nil { if m.Platform == nil {
continue continue
} }
if options.Platform != "" { if options.Platform != "" {
platformParts := strings.Split(options.Platform, "/") platformParts := strings.Split(options.Platform, "/")
if len(platformParts) >= 2 { if len(platformParts) >= 2 {
@@ -547,10 +546,10 @@ func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *S
if len(platformParts) >= 3 { if len(platformParts) >= 3 {
targetVariant = platformParts[2] targetVariant = platformParts[2]
} }
if m.Platform.OS == targetOS && if m.Platform.OS == targetOS &&
m.Platform.Architecture == targetArch && m.Platform.Architecture == targetArch &&
m.Platform.Variant == targetVariant { m.Platform.Variant == targetVariant {
selectedDesc = &m selectedDesc = &m
break break
} }
@@ -629,10 +628,10 @@ func handleDirectImageDownload(c *gin.Context) {
// 防抖检查 // 防抖检查
userID := getUserID(c) userID := getUserID(c)
contentKey := generateContentFingerprint([]string{imageRef}, platform) contentKey := generateContentFingerprint([]string{imageRef}, platform)
if !singleImageDebouncer.ShouldAllow(userID, contentKey) { if !singleImageDebouncer.ShouldAllow(userID, contentKey) {
c.JSON(http.StatusTooManyRequests, gin.H{ c.JSON(http.StatusTooManyRequests, gin.H{
"error": "请求过于频繁,请稍后再试", "error": "请求过于频繁,请稍后再试",
"retry_after": 5, "retry_after": 5,
}) })
return return
@@ -689,10 +688,10 @@ func handleSimpleBatchDownload(c *gin.Context) {
// 批量下载防抖检查 // 批量下载防抖检查
userID := getUserID(c) userID := getUserID(c)
contentKey := generateContentFingerprint(req.Images, req.Platform) contentKey := generateContentFingerprint(req.Images, req.Platform)
if !batchImageDebouncer.ShouldAllow(userID, contentKey) { if !batchImageDebouncer.ShouldAllow(userID, contentKey) {
c.JSON(http.StatusTooManyRequests, gin.H{ c.JSON(http.StatusTooManyRequests, gin.H{
"error": "批量下载请求过于频繁,请稍后再试", "error": "批量下载请求过于频繁,请稍后再试",
"retry_after": 60, "retry_after": 60,
}) })
return return
@@ -713,7 +712,7 @@ func handleSimpleBatchDownload(c *gin.Context) {
log.Printf("批量下载 %d 个镜像 (平台: %s)", len(req.Images), formatPlatformText(req.Platform)) log.Printf("批量下载 %d 个镜像 (平台: %s)", len(req.Images), formatPlatformText(req.Platform))
filename := fmt.Sprintf("batch_%d_images.tar", len(req.Images)) filename := fmt.Sprintf("batch_%d_images.tar", len(req.Images))
c.Header("Content-Type", "application/octet-stream") c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename)) c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename))
@@ -811,12 +810,12 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
} }
log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef) log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef)
// 防止单个镜像处理时间过长 // 防止单个镜像处理时间过长
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options) manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options)
cancel() cancel()
if err != nil { if err != nil {
log.Printf("下载镜像 %s 失败: %v", imageRef, err) log.Printf("下载镜像 %s 失败: %v", imageRef, err)
return fmt.Errorf("下载镜像 %s 失败: %w", imageRef, err) return fmt.Errorf("下载镜像 %s 失败: %w", imageRef, err)
@@ -845,17 +844,17 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
if err != nil { if err != nil {
return fmt.Errorf("序列化manifest失败: %w", err) return fmt.Errorf("序列化manifest失败: %w", err)
} }
manifestHeader := &tar.Header{ manifestHeader := &tar.Header{
Name: "manifest.json", Name: "manifest.json",
Size: int64(len(manifestData)), Size: int64(len(manifestData)),
Mode: 0644, Mode: 0644,
} }
if err := tarWriter.WriteHeader(manifestHeader); err != nil { if err := tarWriter.WriteHeader(manifestHeader); err != nil {
return fmt.Errorf("写入manifest header失败: %w", err) return fmt.Errorf("写入manifest header失败: %w", err)
} }
if _, err := tarWriter.Write(manifestData); err != nil { if _, err := tarWriter.Write(manifestData); err != nil {
return fmt.Errorf("写入manifest数据失败: %w", err) return fmt.Errorf("写入manifest数据失败: %w", err)
} }
@@ -865,21 +864,21 @@ func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []s
if err != nil { if err != nil {
return fmt.Errorf("序列化repositories失败: %w", err) return fmt.Errorf("序列化repositories失败: %w", err)
} }
repositoriesHeader := &tar.Header{ repositoriesHeader := &tar.Header{
Name: "repositories", Name: "repositories",
Size: int64(len(repositoriesData)), Size: int64(len(repositoriesData)),
Mode: 0644, Mode: 0644,
} }
if err := tarWriter.WriteHeader(repositoriesHeader); err != nil { if err := tarWriter.WriteHeader(repositoriesHeader); err != nil {
return fmt.Errorf("写入repositories header失败: %w", err) return fmt.Errorf("写入repositories header失败: %w", err)
} }
if _, err := tarWriter.Write(repositoriesData); err != nil { if _, err := tarWriter.Write(repositoriesData); err != nil {
return fmt.Errorf("写入repositories数据失败: %w", err) return fmt.Errorf("写入repositories数据失败: %w", err)
} }
log.Printf("批量下载完成,共处理 %d 个镜像", len(imageRefs)) log.Printf("批量下载完成,共处理 %d 个镜像", len(imageRefs))
return nil return nil
} }

View File

@@ -1,382 +1,379 @@
package main package main
import ( import (
"embed" "embed"
"fmt" "fmt"
"io" "io"
"log" "log"
"net/http" "net/http"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
//go:embed public/* //go:embed public/*
var staticFiles embed.FS var staticFiles embed.FS
// 服务嵌入的静态文件 // 服务嵌入的静态文件
func serveEmbedFile(c *gin.Context, filename string) { func serveEmbedFile(c *gin.Context, filename string) {
data, err := staticFiles.ReadFile(filename) data, err := staticFiles.ReadFile(filename)
if err != nil { if err != nil {
c.Status(404) c.Status(404)
return return
} }
contentType := "text/html; charset=utf-8" contentType := "text/html; charset=utf-8"
if strings.HasSuffix(filename, ".ico") { if strings.HasSuffix(filename, ".ico") {
contentType = "image/x-icon" contentType = "image/x-icon"
} }
c.Data(200, contentType, data) c.Data(200, contentType, data)
} }
var ( var (
exps = []*regexp.Regexp{ exps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`),
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`), regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+$`),
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`), regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`), regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`), regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)$`),
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`), regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`), regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`), regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?$`),
} }
globalLimiter *IPRateLimiter globalLimiter *IPRateLimiter
// 服务启动时间 // 服务启动时间
serviceStartTime = time.Now() serviceStartTime = time.Now()
) )
func main() { func main() {
// 加载配置 // 加载配置
if err := LoadConfig(); err != nil { if err := LoadConfig(); err != nil {
fmt.Printf("配置加载失败: %v\n", err) fmt.Printf("配置加载失败: %v\n", err)
return return
} }
// 初始化HTTP客户端 // 初始化HTTP客户端
initHTTPClients() initHTTPClients()
// 初始化限流器 // 初始化限流器
initLimiter() initLimiter()
// 初始化Docker流式代理 // 初始化Docker流式代理
initDockerProxy() initDockerProxy()
// 初始化镜像流式下载器 // 初始化镜像流式下载器
initImageStreamer() initImageStreamer()
// 初始化防抖器 // 初始化防抖器
initDebouncer() initDebouncer()
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
router := gin.Default() router := gin.Default()
// 全局Panic恢复保护 // 全局Panic恢复保护
router.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) { router.Use(gin.CustomRecovery(func(c *gin.Context, recovered interface{}) {
log.Printf("🚨 Panic recovered: %v", recovered) log.Printf("🚨 Panic recovered: %v", recovered)
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": "Internal server error", "error": "Internal server error",
"code": "INTERNAL_ERROR", "code": "INTERNAL_ERROR",
}) })
})) }))
// 全局限流中间件 - 应用到所有路由 // 全局限流中间件 - 应用到所有路由
router.Use(RateLimitMiddleware(globalLimiter)) router.Use(RateLimitMiddleware(globalLimiter))
// 初始化监控端点 // 初始化监控端点
initHealthRoutes(router) initHealthRoutes(router)
// 初始化镜像tar下载路由 // 初始化镜像tar下载路由
initImageTarRoutes(router) initImageTarRoutes(router)
// 静态文件路由 // 静态文件路由
router.GET("/", func(c *gin.Context) { router.GET("/", func(c *gin.Context) {
serveEmbedFile(c, "public/index.html") serveEmbedFile(c, "public/index.html")
}) })
router.GET("/public/*filepath", func(c *gin.Context) { router.GET("/public/*filepath", func(c *gin.Context) {
filepath := strings.TrimPrefix(c.Param("filepath"), "/") filepath := strings.TrimPrefix(c.Param("filepath"), "/")
serveEmbedFile(c, "public/"+filepath) serveEmbedFile(c, "public/"+filepath)
}) })
router.GET("/images.html", func(c *gin.Context) { router.GET("/images.html", func(c *gin.Context) {
serveEmbedFile(c, "public/images.html") serveEmbedFile(c, "public/images.html")
}) })
router.GET("/search.html", func(c *gin.Context) { router.GET("/search.html", func(c *gin.Context) {
serveEmbedFile(c, "public/search.html") serveEmbedFile(c, "public/search.html")
}) })
router.GET("/favicon.ico", func(c *gin.Context) { router.GET("/favicon.ico", func(c *gin.Context) {
serveEmbedFile(c, "public/favicon.ico") serveEmbedFile(c, "public/favicon.ico")
}) })
// 注册dockerhub搜索路由 // 注册dockerhub搜索路由
RegisterSearchRoute(router) RegisterSearchRoute(router)
// 注册Docker认证路由/token* // 注册Docker认证路由/token*
router.Any("/token", ProxyDockerAuthGin) router.Any("/token", ProxyDockerAuthGin)
router.Any("/token/*path", ProxyDockerAuthGin) router.Any("/token/*path", ProxyDockerAuthGin)
// 注册Docker Registry代理路由 // 注册Docker Registry代理路由
router.Any("/v2/*path", ProxyDockerRegistryGin) router.Any("/v2/*path", ProxyDockerRegistryGin)
// 注册NoRoute处理器
// 注册NoRoute处理器 router.NoRoute(handler)
router.NoRoute(handler)
cfg := GetConfig()
cfg := GetConfig() fmt.Printf("🚀 HubProxy 启动成功\n")
fmt.Printf("🚀 HubProxy 启动成功\n") fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port)
fmt.Printf("📡 监听地址: %s:%d\n", cfg.Server.Host, cfg.Server.Port) fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours)
fmt.Printf("⚡ 限流配置: %d请求/%g小时\n", cfg.RateLimit.RequestLimit, cfg.RateLimit.PeriodHours) fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n")
fmt.Printf("🔗 项目地址: https://github.com/sky22333/hubproxy\n")
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port))
err := router.Run(fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)) if err != nil {
if err != nil { fmt.Printf("启动服务失败: %v\n", err)
fmt.Printf("启动服务失败: %v\n", err) }
} }
}
func handler(c *gin.Context) {
func handler(c *gin.Context) { rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
for strings.HasPrefix(rawPath, "/") {
for strings.HasPrefix(rawPath, "/") { rawPath = strings.TrimPrefix(rawPath, "/")
rawPath = strings.TrimPrefix(rawPath, "/") }
}
if !strings.HasPrefix(rawPath, "http") {
if !strings.HasPrefix(rawPath, "http") { c.String(http.StatusForbidden, "无效输入")
c.String(http.StatusForbidden, "无效输入") return
return }
}
matches := checkURL(rawPath)
matches := checkURL(rawPath) if matches != nil {
if matches != nil { // GitHub仓库访问控制检查
// GitHub仓库访问控制检查 if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed {
if allowed, reason := GlobalAccessController.CheckGitHubAccess(matches); !allowed { // 构建仓库名用于日志
// 构建仓库名用于日志 var repoPath string
var repoPath string if len(matches) >= 2 {
if len(matches) >= 2 { username := matches[0]
username := matches[0] repoName := strings.TrimSuffix(matches[1], ".git")
repoName := strings.TrimSuffix(matches[1], ".git") repoPath = username + "/" + repoName
repoPath = username + "/" + repoName }
} fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason) c.String(http.StatusForbidden, reason)
c.String(http.StatusForbidden, reason) return
return }
} } else {
} else { c.String(http.StatusForbidden, "无效输入")
c.String(http.StatusForbidden, "无效输入") return
return }
}
if exps[1].MatchString(rawPath) {
if exps[1].MatchString(rawPath) { rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) }
}
proxyRequest(c, rawPath)
proxyRequest(c, rawPath) }
}
func proxyRequest(c *gin.Context, u string) {
proxyWithRedirect(c, u, 0)
func proxyRequest(c *gin.Context, u string) { }
proxyWithRedirect(c, u, 0)
} func proxyWithRedirect(c *gin.Context, u string, redirectCount int) {
// 限制最大重定向次数,防止无限递归
const maxRedirects = 20
func proxyWithRedirect(c *gin.Context, u string, redirectCount int) { if redirectCount > maxRedirects {
// 限制最大重定向次数,防止无限递归 c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
const maxRedirects = 20 return
if redirectCount > maxRedirects { }
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向") req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
return if err != nil {
} c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) return
if err != nil { }
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return for key, values := range c.Request.Header {
} for _, value := range values {
req.Header.Add(key, value)
for key, values := range c.Request.Header { }
for _, value := range values { }
req.Header.Add(key, value) req.Header.Del("Host")
}
} resp, err := GetGlobalHTTPClient().Do(req)
req.Header.Del("Host") if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
resp, err := GetGlobalHTTPClient().Do(req) return
if err != nil { }
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) defer func() {
return if err := resp.Body.Close(); err != nil {
} fmt.Printf("关闭响应体失败: %v\n", err)
defer func() { }
if err := resp.Body.Close(); err != nil { }()
fmt.Printf("关闭响应体失败: %v\n", err)
} // 检查文件大小限制
}() cfg := GetConfig()
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
// 检查文件大小限制 if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
cfg := GetConfig() c.String(http.StatusRequestEntityTooLarge,
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" { fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize { return
c.String(http.StatusRequestEntityTooLarge, }
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024))) }
return
} // 清理安全相关的头
} resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy")
// 清理安全相关的头 resp.Header.Del("Strict-Transport-Security")
resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy") // 获取真实域名
resp.Header.Del("Strict-Transport-Security") realHost := c.Request.Header.Get("X-Forwarded-Host")
if realHost == "" {
// 获取真实域名 realHost = c.Request.Host
realHost := c.Request.Header.Get("X-Forwarded-Host") }
if realHost == "" { // 如果域名中没有协议前缀添加https://
realHost = c.Request.Host if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
} realHost = "https://" + realHost
// 如果域名中没有协议前缀添加https:// }
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
realHost = "https://" + realHost if strings.HasSuffix(strings.ToLower(u), ".sh") {
} isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
if strings.HasSuffix(strings.ToLower(u), ".sh") { processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost)
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip" if err != nil {
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
processedBody, processedSize, err := ProcessSmart(resp.Body, isGzipCompressed, realHost) processedBody = resp.Body
if err != nil { processedSize = 0
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err) }
processedBody = resp.Body
processedSize = 0 // 智能设置响应头
} if processedSize > 0 {
resp.Header.Del("Content-Length")
// 智能设置响应头 resp.Header.Del("Content-Encoding")
if processedSize > 0 { resp.Header.Set("Transfer-Encoding", "chunked")
resp.Header.Del("Content-Length") }
resp.Header.Del("Content-Encoding")
resp.Header.Set("Transfer-Encoding", "chunked") // 复制其他响应头
} for key, values := range resp.Header {
for _, value := range values {
// 复制其他响应头 c.Header(key, value)
for key, values := range resp.Header { }
for _, value := range values { }
c.Header(key, value)
} if location := resp.Header.Get("Location"); location != "" {
} if checkURL(location) != nil {
c.Header("Location", "/"+location)
if location := resp.Header.Get("Location"); location != "" { } else {
if checkURL(location) != nil { proxyWithRedirect(c, location, redirectCount+1)
c.Header("Location", "/"+location) return
} else { }
proxyWithRedirect(c, location, redirectCount+1) }
return
} c.Status(resp.StatusCode)
}
// 输出处理后的内容
c.Status(resp.StatusCode) if _, err := io.Copy(c.Writer, processedBody); err != nil {
return
// 输出处理后的内容 }
if _, err := io.Copy(c.Writer, processedBody); err != nil { } else {
return for key, values := range resp.Header {
} for _, value := range values {
} else { c.Header(key, value)
for key, values := range resp.Header { }
for _, value := range values { }
c.Header(key, value)
} // 处理重定向
} if location := resp.Header.Get("Location"); location != "" {
if checkURL(location) != nil {
// 处理重定向 c.Header("Location", "/"+location)
if location := resp.Header.Get("Location"); location != "" { } else {
if checkURL(location) != nil { proxyWithRedirect(c, location, redirectCount+1)
c.Header("Location", "/"+location) return
} else { }
proxyWithRedirect(c, location, redirectCount+1) }
return
} c.Status(resp.StatusCode)
}
// 直接流式转发
c.Status(resp.StatusCode) io.Copy(c.Writer, resp.Body)
}
// 直接流式转发 }
io.Copy(c.Writer, resp.Body)
} func checkURL(u string) []string {
} for _, exp := range exps {
if matches := exp.FindStringSubmatch(u); matches != nil {
func checkURL(u string) []string { return matches[1:]
for _, exp := range exps { }
if matches := exp.FindStringSubmatch(u); matches != nil { }
return matches[1:] return nil
} }
}
return nil // 初始化健康监控路由
} func initHealthRoutes(router *gin.Engine) {
// 健康检查端点
// 初始化健康监控路由 router.GET("/health", func(c *gin.Context) {
func initHealthRoutes(router *gin.Engine) { c.JSON(http.StatusOK, gin.H{
// 健康检查端点 "status": "healthy",
router.GET("/health", func(c *gin.Context) { "timestamp": time.Now().Unix(),
c.JSON(http.StatusOK, gin.H{ "uptime": time.Since(serviceStartTime).Seconds(),
"status": "healthy", "service": "hubproxy",
"timestamp": time.Now().Unix(), })
"uptime": time.Since(serviceStartTime).Seconds(), })
"service": "hubproxy",
}) // 就绪检查端点
}) router.GET("/ready", func(c *gin.Context) {
checks := make(map[string]string)
// 就绪检查端点 allReady := true
router.GET("/ready", func(c *gin.Context) {
checks := make(map[string]string) if GetConfig() != nil {
allReady := true checks["config"] = "ok"
} else {
if GetConfig() != nil { checks["config"] = "failed"
checks["config"] = "ok" allReady = false
} else { }
checks["config"] = "failed"
allReady = false // 检查全局缓存状态
} if globalCache != nil {
checks["cache"] = "ok"
// 检查全局缓存状态 } else {
if globalCache != nil { checks["cache"] = "failed"
checks["cache"] = "ok" allReady = false
} else { }
checks["cache"] = "failed"
allReady = false // 检查限流器状态
} if globalLimiter != nil {
checks["ratelimiter"] = "ok"
// 检查限流器状态 } else {
if globalLimiter != nil { checks["ratelimiter"] = "failed"
checks["ratelimiter"] = "ok" allReady = false
} else { }
checks["ratelimiter"] = "failed"
allReady = false // 检查镜像下载器状态
} if globalImageStreamer != nil {
checks["imagestreamer"] = "ok"
// 检查镜像下载器状态 } else {
if globalImageStreamer != nil { checks["imagestreamer"] = "failed"
checks["imagestreamer"] = "ok" allReady = false
} else { }
checks["imagestreamer"] = "failed"
allReady = false // 检查HTTP客户端状态
} if GetGlobalHTTPClient() != nil {
checks["httpclient"] = "ok"
// 检查HTTP客户端状态 } else {
if GetGlobalHTTPClient() != nil { checks["httpclient"] = "failed"
checks["httpclient"] = "ok" allReady = false
} else { }
checks["httpclient"] = "failed"
allReady = false status := http.StatusOK
} if !allReady {
status = http.StatusServiceUnavailable
status := http.StatusOK }
if !allReady {
status = http.StatusServiceUnavailable c.JSON(status, gin.H{
} "ready": allReady,
"checks": checks,
c.JSON(status, gin.H{ "timestamp": time.Now().Unix(),
"ready": allReady, "uptime": time.Since(serviceStartTime).Seconds(),
"checks": checks, })
"timestamp": time.Now().Unix(), })
"uptime": time.Since(serviceStartTime).Seconds(), }
})
})
}

View File

@@ -1,95 +1,95 @@
package main package main
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
"regexp" "regexp"
"strings" "strings"
) )
// GitHub URL正则表达式 // GitHub URL正则表达式
var githubRegex = regexp.MustCompile(`https?://(?:github\.com|raw\.githubusercontent\.com|raw\.github\.com|gist\.githubusercontent\.com|gist\.github\.com|api\.github\.com)[^\s'"]+`) var githubRegex = regexp.MustCompile(`https?://(?:github\.com|raw\.githubusercontent\.com|raw\.github\.com|gist\.githubusercontent\.com|gist\.github\.com|api\.github\.com)[^\s'"]+`)
// ProcessSmart Shell脚本智能处理函数 // ProcessSmart Shell脚本智能处理函数
func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reader, int64, error) { func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reader, int64, error) {
defer input.Close() defer input.Close()
content, err := readShellContent(input, isCompressed) content, err := readShellContent(input, isCompressed)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("内容读取失败: %v", err) return nil, 0, fmt.Errorf("内容读取失败: %v", err)
} }
if len(content) == 0 { if len(content) == 0 {
return strings.NewReader(""), 0, nil return strings.NewReader(""), 0, nil
} }
if len(content) > 10*1024*1024 { if len(content) > 10*1024*1024 {
return strings.NewReader(content), int64(len(content)), nil return strings.NewReader(content), int64(len(content)), nil
} }
if !strings.Contains(content, "github.com") && !strings.Contains(content, "githubusercontent.com") { if !strings.Contains(content, "github.com") && !strings.Contains(content, "githubusercontent.com") {
return strings.NewReader(content), int64(len(content)), nil return strings.NewReader(content), int64(len(content)), nil
} }
processed := processGitHubURLs(content, host) processed := processGitHubURLs(content, host)
return strings.NewReader(processed), int64(len(processed)), nil return strings.NewReader(processed), int64(len(processed)), nil
} }
func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) { func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) {
var reader io.Reader = input var reader io.Reader = input
// 处理gzip压缩 // 处理gzip压缩
if isCompressed { if isCompressed {
peek := make([]byte, 2) peek := make([]byte, 2)
n, err := input.Read(peek) n, err := input.Read(peek)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
return "", fmt.Errorf("读取数据失败: %v", err) return "", fmt.Errorf("读取数据失败: %v", err)
} }
if n >= 2 && peek[0] == 0x1f && peek[1] == 0x8b { if n >= 2 && peek[0] == 0x1f && peek[1] == 0x8b {
combinedReader := io.MultiReader(bytes.NewReader(peek[:n]), input) combinedReader := io.MultiReader(bytes.NewReader(peek[:n]), input)
gzReader, err := gzip.NewReader(combinedReader) gzReader, err := gzip.NewReader(combinedReader)
if err != nil { if err != nil {
return "", fmt.Errorf("gzip解压失败: %v", err) return "", fmt.Errorf("gzip解压失败: %v", err)
} }
defer gzReader.Close() defer gzReader.Close()
reader = gzReader reader = gzReader
} else { } else {
reader = io.MultiReader(bytes.NewReader(peek[:n]), input) reader = io.MultiReader(bytes.NewReader(peek[:n]), input)
} }
} }
data, err := io.ReadAll(reader) data, err := io.ReadAll(reader)
if err != nil { if err != nil {
return "", fmt.Errorf("读取内容失败: %v", err) return "", fmt.Errorf("读取内容失败: %v", err)
} }
return string(data), nil return string(data), nil
} }
func processGitHubURLs(content, host string) string { func processGitHubURLs(content, host string) string {
return githubRegex.ReplaceAllStringFunc(content, func(url string) string { return githubRegex.ReplaceAllStringFunc(content, func(url string) string {
return transformURL(url, host) return transformURL(url, host)
}) })
} }
// transformURL URL转换函数 // transformURL URL转换函数
func transformURL(url, host string) string { func transformURL(url, host string) string {
if strings.Contains(url, host) { if strings.Contains(url, host) {
return url return url
} }
if strings.HasPrefix(url, "http://") { if strings.HasPrefix(url, "http://") {
url = "https" + url[4:] url = "https" + url[4:]
} else if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "//") { } else if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "//") {
url = "https://" + url url = "https://" + url
} }
cleanHost := strings.TrimPrefix(host, "https://") cleanHost := strings.TrimPrefix(host, "https://")
cleanHost = strings.TrimPrefix(cleanHost, "http://") cleanHost = strings.TrimPrefix(cleanHost, "http://")
cleanHost = strings.TrimSuffix(cleanHost, "/") cleanHost = strings.TrimSuffix(cleanHost, "/")
return cleanHost + "/" + url return cleanHost + "/" + url
} }

View File

@@ -1,303 +1,301 @@
package main package main
import ( import (
"fmt" "fmt"
"net" "net"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/time/rate" "golang.org/x/time/rate"
) )
const ( const (
// 清理间隔 // 清理间隔
CleanupInterval = 10 * time.Minute CleanupInterval = 10 * time.Minute
MaxIPCacheSize = 10000 MaxIPCacheSize = 10000
) )
// IPRateLimiter IP限流器结构体 // IPRateLimiter IP限流器结构体
type IPRateLimiter struct { type IPRateLimiter struct {
ips map[string]*rateLimiterEntry // IP到限流器的映射 ips map[string]*rateLimiterEntry // IP到限流器的映射
mu *sync.RWMutex // 读写锁,保证并发安全 mu *sync.RWMutex // 读写锁,保证并发安全
r rate.Limit // 速率限制(每秒允许的请求数) r rate.Limit // 速率限制(每秒允许的请求数)
b int // 令牌桶容量(突发请求数) b int // 令牌桶容量(突发请求数)
whitelist []*net.IPNet // 白名单IP段 whitelist []*net.IPNet // 白名单IP段
blacklist []*net.IPNet // 黑名单IP段 blacklist []*net.IPNet // 黑名单IP段
} }
// rateLimiterEntry 限流器条目 // rateLimiterEntry 限流器条目
type rateLimiterEntry struct { type rateLimiterEntry struct {
limiter *rate.Limiter limiter *rate.Limiter
lastAccess time.Time lastAccess time.Time
} }
// initGlobalLimiter 初始化全局限流器 // initGlobalLimiter 初始化全局限流器
func initGlobalLimiter() *IPRateLimiter { func initGlobalLimiter() *IPRateLimiter {
cfg := GetConfig() cfg := GetConfig()
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList)) whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
for _, item := range cfg.Security.WhiteList { for _, item := range cfg.Security.WhiteList {
if item = strings.TrimSpace(item); item != "" { if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") { if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式 item = item + "/32" // 单个IP转为CIDR格式
} }
_, ipnet, err := net.ParseCIDR(item) _, ipnet, err := net.ParseCIDR(item)
if err == nil { if err == nil {
whitelist = append(whitelist, ipnet) whitelist = append(whitelist, ipnet)
} else { } else {
fmt.Printf("警告: 无效的白名单IP格式: %s\n", item) fmt.Printf("警告: 无效的白名单IP格式: %s\n", item)
} }
} }
} }
// 解析黑名单IP段 // 解析黑名单IP段
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList)) blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
for _, item := range cfg.Security.BlackList { for _, item := range cfg.Security.BlackList {
if item = strings.TrimSpace(item); item != "" { if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") { if !strings.Contains(item, "/") {
item = item + "/32" // 单个IP转为CIDR格式 item = item + "/32" // 单个IP转为CIDR格式
} }
_, ipnet, err := net.ParseCIDR(item) _, ipnet, err := net.ParseCIDR(item)
if err == nil { if err == nil {
blacklist = append(blacklist, ipnet) blacklist = append(blacklist, ipnet)
} else { } else {
fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item) fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item)
} }
} }
} }
// 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求" // 计算速率:将 "每N小时X个请求" 转换为 "每秒Y个请求"
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600)) ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
burstSize := cfg.RateLimit.RequestLimit burstSize := cfg.RateLimit.RequestLimit
if burstSize < 1 { if burstSize < 1 {
burstSize = 1 burstSize = 1
} }
limiter := &IPRateLimiter{ limiter := &IPRateLimiter{
ips: make(map[string]*rateLimiterEntry), ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{}, mu: &sync.RWMutex{},
r: ratePerSecond, r: ratePerSecond,
b: burstSize, b: burstSize,
whitelist: whitelist, whitelist: whitelist,
blacklist: blacklist, blacklist: blacklist,
} }
// 启动定期清理goroutine // 启动定期清理goroutine
go limiter.cleanupRoutine() go limiter.cleanupRoutine()
return limiter return limiter
} }
// initLimiter 初始化限流器 // initLimiter 初始化限流器
func initLimiter() { func initLimiter() {
globalLimiter = initGlobalLimiter() globalLimiter = initGlobalLimiter()
} }
// cleanupRoutine 定期清理过期的限流器 // cleanupRoutine 定期清理过期的限流器
func (i *IPRateLimiter) cleanupRoutine() { func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(CleanupInterval) ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
now := time.Now() now := time.Now()
expired := make([]string, 0) expired := make([]string, 0)
// 查找过期的条目 // 查找过期的条目
i.mu.RLock() i.mu.RLock()
for ip, entry := range i.ips { for ip, entry := range i.ips {
// 如果最后访问时间超过1小时认为过期 // 如果最后访问时间超过1小时认为过期
if now.Sub(entry.lastAccess) > 1*time.Hour { if now.Sub(entry.lastAccess) > 1*time.Hour {
expired = append(expired, ip) expired = append(expired, ip)
} }
} }
i.mu.RUnlock() i.mu.RUnlock()
// 如果有过期条目或者缓存过大,进行清理 // 如果有过期条目或者缓存过大,进行清理
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize { if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
i.mu.Lock() i.mu.Lock()
// 删除过期条目 // 删除过期条目
for _, ip := range expired { for _, ip := range expired {
delete(i.ips, ip) delete(i.ips, ip)
} }
// 如果缓存仍然过大,全部清理 // 如果缓存仍然过大,全部清理
if len(i.ips) > MaxIPCacheSize { if len(i.ips) > MaxIPCacheSize {
i.ips = make(map[string]*rateLimiterEntry) i.ips = make(map[string]*rateLimiterEntry)
} }
i.mu.Unlock() i.mu.Unlock()
} }
} }
} }
// extractIPFromAddress 从地址中提取纯IP // extractIPFromAddress 从地址中提取纯IP
func extractIPFromAddress(address string) string { func extractIPFromAddress(address string) string {
if host, _, err := net.SplitHostPort(address); err == nil { if host, _, err := net.SplitHostPort(address); err == nil {
return host return host
} }
return address return address
} }
// normalizeIPForRateLimit 标准化IP地址用于限流IPv4保持不变IPv6标准化为/64网段 // normalizeIPForRateLimit 标准化IP地址用于限流IPv4保持不变IPv6标准化为/64网段
func normalizeIPForRateLimit(ipStr string) string { func normalizeIPForRateLimit(ipStr string) string {
ip := net.ParseIP(ipStr) ip := net.ParseIP(ipStr)
if ip == nil { if ip == nil {
return ipStr // 解析失败,返回原值 return ipStr // 解析失败,返回原值
} }
if ip.To4() != nil { if ip.To4() != nil {
return ipStr // IPv4保持不变 return ipStr // IPv4保持不变
} }
// IPv6标准化为 /64 网段 // IPv6标准化为 /64 网段
ipv6 := ip.To16() ipv6 := ip.To16()
for i := 8; i < 16; i++ { for i := 8; i < 16; i++ {
ipv6[i] = 0 // 清零后64位 ipv6[i] = 0 // 清零后64位
} }
return ipv6.String() + "/64" return ipv6.String() + "/64"
} }
// isIPInCIDRList 检查IP是否在CIDR列表中 // isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool { func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
// 先提取纯IP地址 // 先提取纯IP地址
cleanIP := extractIPFromAddress(ip) cleanIP := extractIPFromAddress(ip)
parsedIP := net.ParseIP(cleanIP) parsedIP := net.ParseIP(cleanIP)
if parsedIP == nil { if parsedIP == nil {
return false return false
} }
for _, cidr := range cidrList { for _, cidr := range cidrList {
if cidr.Contains(parsedIP) { if cidr.Contains(parsedIP) {
return true return true
} }
} }
return false return false
} }
// GetLimiter 获取指定IP的限流器同时返回是否允许访问 // GetLimiter 获取指定IP的限流器同时返回是否允许访问
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) { func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
// 提取纯IP地址 // 提取纯IP地址
cleanIP := extractIPFromAddress(ip) cleanIP := extractIPFromAddress(ip)
// 检查是否在黑名单中 // 检查是否在黑名单中
if isIPInCIDRList(cleanIP, i.blacklist) { if isIPInCIDRList(cleanIP, i.blacklist) {
return nil, false return nil, false
} }
// 检查是否在白名单中 // 检查是否在白名单中
if isIPInCIDRList(cleanIP, i.whitelist) { if isIPInCIDRList(cleanIP, i.whitelist) {
return rate.NewLimiter(rate.Inf, i.b), true return rate.NewLimiter(rate.Inf, i.b), true
} }
// 标准化IP用于限流IPv4保持不变IPv6标准化为/64网段 // 标准化IP用于限流IPv4保持不变IPv6标准化为/64网段
normalizedIP := normalizeIPForRateLimit(cleanIP) normalizedIP := normalizeIPForRateLimit(cleanIP)
now := time.Now() now := time.Now()
i.mu.RLock() i.mu.RLock()
entry, exists := i.ips[normalizedIP] entry, exists := i.ips[normalizedIP]
i.mu.RUnlock() i.mu.RUnlock()
if exists { if exists {
i.mu.Lock() i.mu.Lock()
if entry, stillExists := i.ips[normalizedIP]; stillExists { if entry, stillExists := i.ips[normalizedIP]; stillExists {
entry.lastAccess = now entry.lastAccess = now
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true
} }
i.mu.Unlock() i.mu.Unlock()
} }
i.mu.Lock() i.mu.Lock()
if entry, exists := i.ips[normalizedIP]; exists { if entry, exists := i.ips[normalizedIP]; exists {
entry.lastAccess = now entry.lastAccess = now
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true
} }
entry = &rateLimiterEntry{ entry = &rateLimiterEntry{
limiter: rate.NewLimiter(i.r, i.b), limiter: rate.NewLimiter(i.r, i.b),
lastAccess: now, lastAccess: now,
} }
i.ips[normalizedIP] = entry i.ips[normalizedIP] = entry
i.mu.Unlock() i.mu.Unlock()
return entry.limiter, true return entry.limiter, true
} }
// RateLimitMiddleware 速率限制中间件 // RateLimitMiddleware 速率限制中间件
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc { func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
// 静态文件豁免:跳过限流检查 // 静态文件豁免:跳过限流检查
path := c.Request.URL.Path path := c.Request.URL.Path
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" || if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
strings.HasPrefix(path, "/public/") { strings.HasPrefix(path, "/public/") {
c.Next() c.Next()
return return
} }
// 获取客户端真实IP // 获取客户端真实IP
var ip string var ip string
// 优先尝试从请求头获取真实IP // 优先尝试从请求头获取真实IP
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" { if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
// X-Forwarded-For可能包含多个IP取第一个 // X-Forwarded-For可能包含多个IP取第一个
ips := strings.Split(forwarded, ",") ips := strings.Split(forwarded, ",")
ip = strings.TrimSpace(ips[0]) ip = strings.TrimSpace(ips[0])
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" { } else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
// 如果有X-Real-IP头 // 如果有X-Real-IP头
ip = realIP ip = realIP
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" { } else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
// 某些代理可能使用此头 // 某些代理可能使用此头
ips := strings.Split(remoteIP, ",") ips := strings.Split(remoteIP, ",")
ip = strings.TrimSpace(ips[0]) ip = strings.TrimSpace(ips[0])
} else { } else {
// 回退到ClientIP方法 // 回退到ClientIP方法
ip = c.ClientIP() ip = c.ClientIP()
} }
// 提取纯IP地址去除可能存在的端口 // 提取纯IP地址去除可能存在的端口
cleanIP := extractIPFromAddress(ip) cleanIP := extractIPFromAddress(ip)
// 日志记录请求IP和头信息 // 日志记录请求IP和头信息
normalizedIP := normalizeIPForRateLimit(cleanIP) normalizedIP := normalizeIPForRateLimit(cleanIP)
if cleanIP != normalizedIP { if cleanIP != normalizedIP {
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
ip, cleanIP, normalizedIP, ip, cleanIP, normalizedIP,
c.GetHeader("X-Forwarded-For"), c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP")) c.GetHeader("X-Real-IP"))
} else { } else {
fmt.Printf("请求IP: %s (提纯后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n", fmt.Printf("请求IP: %s (提纯后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
ip, cleanIP, ip, cleanIP,
c.GetHeader("X-Forwarded-For"), c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP")) c.GetHeader("X-Real-IP"))
} }
// 获取限流器并检查是否允许访问 // 获取限流器并检查是否允许访问
ipLimiter, allowed := limiter.GetLimiter(cleanIP) ipLimiter, allowed := limiter.GetLimiter(cleanIP)
// 如果IP在黑名单中 // 如果IP在黑名单中
if !allowed { if !allowed {
c.JSON(403, gin.H{ c.JSON(403, gin.H{
"error": "您已被限制访问", "error": "您已被限制访问",
}) })
c.Abort() c.Abort()
return return
} }
// 检查限流 // 检查限流
if !ipLimiter.Allow() { if !ipLimiter.Allow() {
c.JSON(429, gin.H{ c.JSON(429, gin.H{
"error": "请求频率过快,暂时限制访问", "error": "请求频率过快,暂时限制访问",
}) })
c.Abort() c.Abort()
return return
} }
c.Next() c.Next()
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -13,10 +13,10 @@ import (
// CachedItem 通用缓存项支持Token和Manifest // CachedItem 通用缓存项支持Token和Manifest
type CachedItem struct { type CachedItem struct {
Data []byte // 缓存数据(token字符串或manifest字节) Data []byte // 缓存数据(token字符串或manifest字节)
ContentType string // 内容类型 ContentType string // 内容类型
Headers map[string]string // 额外的响应头 Headers map[string]string // 额外的响应头
ExpiresAt time.Time // 过期时间 ExpiresAt time.Time // 过期时间
} }
// UniversalCache 通用缓存支持Token和Manifest // UniversalCache 通用缓存支持Token和Manifest
@@ -79,18 +79,18 @@ func getManifestTTL(reference string) time.Duration {
defaultTTL = parsed defaultTTL = parsed
} }
} }
if strings.HasPrefix(reference, "sha256:") { if strings.HasPrefix(reference, "sha256:") {
return 24 * time.Hour return 24 * time.Hour
} }
// mutable tag的智能判断 // mutable tag的智能判断
if reference == "latest" || reference == "main" || reference == "master" || if reference == "latest" || reference == "main" || reference == "master" ||
reference == "dev" || reference == "develop" { reference == "dev" || reference == "develop" {
// 热门可变标签: 短期缓存 // 热门可变标签: 短期缓存
return 10 * time.Minute return 10 * time.Minute
} }
return defaultTTL return defaultTTL
} }
@@ -99,17 +99,17 @@ func extractTTLFromResponse(responseBody []byte) time.Duration {
var tokenResp struct { var tokenResp struct {
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
} }
// 默认30分钟TTL确保稳定性 // 默认30分钟TTL确保稳定性
defaultTTL := 30 * time.Minute defaultTTL := 30 * time.Minute
if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 { if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 {
safeTTL := time.Duration(tokenResp.ExpiresIn-300) * time.Second safeTTL := time.Duration(tokenResp.ExpiresIn-300) * time.Second
if safeTTL > 5*time.Minute { if safeTTL > 5*time.Minute {
return safeTTL return safeTTL
} }
} }
return defaultTTL return defaultTTL
} }
@@ -122,12 +122,12 @@ func writeCachedResponse(c *gin.Context, item *CachedItem) {
if item.ContentType != "" { if item.ContentType != "" {
c.Header("Content-Type", item.ContentType) c.Header("Content-Type", item.ContentType)
} }
// 设置额外的响应头 // 设置额外的响应头
for key, value := range item.Headers { for key, value := range item.Headers {
c.Header(key, value) c.Header(key, value)
} }
// 返回数据 // 返回数据
c.Data(200, item.ContentType, item.Data) c.Data(200, item.ContentType, item.Data)
} }
@@ -148,21 +148,21 @@ func init() {
go func() { go func() {
ticker := time.NewTicker(20 * time.Minute) ticker := time.NewTicker(20 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
now := time.Now() now := time.Now()
expiredKeys := make([]string, 0) expiredKeys := make([]string, 0)
globalCache.cache.Range(func(key, value interface{}) bool { globalCache.cache.Range(func(key, value interface{}) bool {
if cached := value.(*CachedItem); now.After(cached.ExpiresAt) { if cached := value.(*CachedItem); now.After(cached.ExpiresAt) {
expiredKeys = append(expiredKeys, key.(string)) expiredKeys = append(expiredKeys, key.(string))
} }
return true return true
}) })
for _, key := range expiredKeys { for _, key := range expiredKeys {
globalCache.cache.Delete(key) globalCache.cache.Delete(key)
} }
} }
}() }()
} }