拆分包结构

This commit is contained in:
user123456
2025-07-27 05:50:34 +08:00
parent badafd2899
commit 187e842445
15 changed files with 956 additions and 1148 deletions

608
src/handlers/docker.go Normal file
View File

@@ -0,0 +1,608 @@
package handlers
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1/remote"
"hubproxy/config"
"hubproxy/utils"
)
// DockerProxy Docker代理配置
type DockerProxy struct {
registry name.Registry
options []remote.Option
}
var dockerProxy *DockerProxy
// RegistryDetector Registry检测器
type RegistryDetector struct{}
// detectRegistryDomain 检测Registry域名并返回域名和剩余路径
func (rd *RegistryDetector) detectRegistryDomain(path string) (string, string) {
cfg := config.GetConfig()
for domain := range cfg.Registries {
if strings.HasPrefix(path, domain+"/") {
remainingPath := strings.TrimPrefix(path, domain+"/")
return domain, remainingPath
}
}
return "", path
}
// isRegistryEnabled 检查Registry是否启用
func (rd *RegistryDetector) isRegistryEnabled(domain string) bool {
cfg := config.GetConfig()
if mapping, exists := cfg.Registries[domain]; exists {
return mapping.Enabled
}
return false
}
// getRegistryMapping 获取Registry映射配置
func (rd *RegistryDetector) getRegistryMapping(domain string) (config.RegistryMapping, bool) {
cfg := config.GetConfig()
mapping, exists := cfg.Registries[domain]
return mapping, exists && mapping.Enabled
}
var registryDetector = &RegistryDetector{}
// InitDockerProxy 初始化Docker代理
func InitDockerProxy() {
registry, err := name.NewRegistry("registry-1.docker.io")
if err != nil {
fmt.Printf("创建Docker registry失败: %v\n", err)
return
}
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
dockerProxy = &DockerProxy{
registry: registry,
options: options,
}
}
// ProxyDockerRegistryGin 标准Docker Registry API v2代理
func ProxyDockerRegistryGin(c *gin.Context) {
path := c.Request.URL.Path
if path == "/v2/" {
c.JSON(http.StatusOK, gin.H{})
return
}
if strings.HasPrefix(path, "/v2/") {
handleRegistryRequest(c, path)
} else {
c.String(http.StatusNotFound, "Docker Registry API v2 only")
}
}
// handleRegistryRequest 处理Registry请求
func handleRegistryRequest(c *gin.Context, path string) {
pathWithoutV2 := strings.TrimPrefix(path, "/v2/")
if registryDomain, remainingPath := registryDetector.detectRegistryDomain(pathWithoutV2); registryDomain != "" {
if registryDetector.isRegistryEnabled(registryDomain) {
c.Set("target_registry_domain", registryDomain)
c.Set("target_path", remainingPath)
handleMultiRegistryRequest(c, registryDomain, remainingPath)
return
}
}
imageName, apiType, reference := parseRegistryPath(pathWithoutV2)
if imageName == "" || apiType == "" {
c.String(http.StatusBadRequest, "Invalid path format")
return
}
if !strings.Contains(imageName, "/") {
imageName = "library/" + imageName
}
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(imageName); !allowed {
fmt.Printf("Docker镜像 %s 访问被拒绝: %s\n", imageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
return
}
imageRef := fmt.Sprintf("%s/%s", dockerProxy.registry.Name(), imageName)
switch apiType {
case "manifests":
handleManifestRequest(c, imageRef, reference)
case "blobs":
handleBlobRequest(c, imageRef, reference)
case "tags":
handleTagsRequest(c, imageRef)
default:
c.String(http.StatusNotFound, "API endpoint not found")
}
}
// parseRegistryPath 解析Registry路径
func parseRegistryPath(path string) (imageName, apiType, reference string) {
if idx := strings.Index(path, "/manifests/"); idx != -1 {
imageName = path[:idx]
apiType = "manifests"
reference = path[idx+len("/manifests/"):]
return
}
if idx := strings.Index(path, "/blobs/"); idx != -1 {
imageName = path[:idx]
apiType = "blobs"
reference = path[idx+len("/blobs/"):]
return
}
if idx := strings.Index(path, "/tags/list"); idx != -1 {
imageName = path[:idx]
apiType = "tags"
reference = "list"
return
}
return "", "", ""
}
// handleManifestRequest 处理manifest请求
func handleManifestRequest(c *gin.Context, imageRef, reference string) {
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
utils.WriteCachedResponse(c, cachedItem)
return
}
}
var ref name.Reference
var err error
if strings.HasPrefix(reference, "sha256:") {
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
}
if err != nil {
fmt.Printf("解析镜像引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid reference")
return
}
if c.Request.Method == http.MethodHead {
desc, err := remote.Head(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("HEAD请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
c.Status(http.StatusOK)
} else {
desc, err := remote.Get(ref, dockerProxy.options...)
if err != nil {
fmt.Printf("GET请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
}
if utils.IsCacheEnabled() {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
ttl := utils.GetManifestTTL(reference)
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
}
c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers {
c.Header(key, value)
}
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
}
}
// handleBlobRequest 处理blob请求
func handleBlobRequest(c *gin.Context, imageRef, digest string) {
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid digest reference")
return
}
layer, err := remote.Layer(digestRef, dockerProxy.options...)
if err != nil {
fmt.Printf("获取layer失败: %v\n", err)
c.String(http.StatusNotFound, "Layer not found")
return
}
size, err := layer.Size()
if err != nil {
fmt.Printf("获取layer大小失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer size")
return
}
reader, err := layer.Compressed()
if err != nil {
fmt.Printf("获取layer内容失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer content")
return
}
defer reader.Close()
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Length", fmt.Sprintf("%d", size))
c.Header("Docker-Content-Digest", digest)
c.Status(http.StatusOK)
io.Copy(c.Writer, reader)
}
// handleTagsRequest 处理tags列表请求
func handleTagsRequest(c *gin.Context, imageRef string) {
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid repository")
return
}
tags, err := remote.List(repo, dockerProxy.options...)
if err != nil {
fmt.Printf("获取tags失败: %v\n", err)
c.String(http.StatusNotFound, "Tags not found")
return
}
response := map[string]interface{}{
"name": strings.TrimPrefix(imageRef, dockerProxy.registry.Name()+"/"),
"tags": tags,
}
c.JSON(http.StatusOK, response)
}
// ProxyDockerAuthGin Docker认证代理
func ProxyDockerAuthGin(c *gin.Context) {
if utils.IsTokenCacheEnabled() {
proxyDockerAuthWithCache(c)
} else {
proxyDockerAuthOriginal(c)
}
}
// proxyDockerAuthWithCache 带缓存的认证代理
func proxyDockerAuthWithCache(c *gin.Context) {
cacheKey := utils.BuildTokenCacheKey(c.Request.URL.RawQuery)
if cachedToken := utils.GlobalCache.GetToken(cacheKey); cachedToken != "" {
utils.WriteTokenResponse(c, cachedToken)
return
}
recorder := &ResponseRecorder{
ResponseWriter: c.Writer,
statusCode: 200,
}
c.Writer = recorder
proxyDockerAuthOriginal(c)
if recorder.statusCode == 200 && len(recorder.body) > 0 {
ttl := utils.ExtractTTLFromResponse(recorder.body)
utils.GlobalCache.SetToken(cacheKey, string(recorder.body), ttl)
}
c.Writer = recorder.ResponseWriter
c.Data(recorder.statusCode, "application/json", recorder.body)
}
// ResponseRecorder HTTP响应记录器
type ResponseRecorder struct {
gin.ResponseWriter
statusCode int
body []byte
}
func (r *ResponseRecorder) WriteHeader(code int) {
r.statusCode = code
}
func (r *ResponseRecorder) Write(data []byte) (int, error) {
r.body = append(r.body, data...)
return len(data), nil
}
func proxyDockerAuthOriginal(c *gin.Context) {
var authURL string
if targetDomain, exists := c.Get("target_registry_domain"); exists {
if mapping, found := registryDetector.getRegistryMapping(targetDomain.(string)); found {
authURL = "https://" + mapping.AuthHost + c.Request.URL.Path
} else {
authURL = "https://auth.docker.io" + c.Request.URL.Path
}
} else {
authURL = "https://auth.docker.io" + c.Request.URL.Path
}
if c.Request.URL.RawQuery != "" {
authURL += "?" + c.Request.URL.RawQuery
}
client := &http.Client{
Timeout: 30 * time.Second,
Transport: utils.GetGlobalHTTPClient().Transport,
}
req, err := http.NewRequestWithContext(
context.Background(),
c.Request.Method,
authURL,
c.Request.Body,
)
if err != nil {
c.String(http.StatusInternalServerError, "Failed to create request")
return
}
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
resp, err := client.Do(req)
if err != nil {
c.String(http.StatusBadGateway, "Auth request failed")
return
}
defer resp.Body.Close()
proxyHost := c.Request.Host
if proxyHost == "" {
cfg := config.GetConfig()
proxyHost = fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
if cfg.Server.Host == "0.0.0.0" {
proxyHost = fmt.Sprintf("localhost:%d", cfg.Server.Port)
}
}
for key, values := range resp.Header {
for _, value := range values {
if key == "Www-Authenticate" {
value = rewriteAuthHeader(value, proxyHost)
}
c.Header(key, value)
}
}
c.Status(resp.StatusCode)
io.Copy(c.Writer, resp.Body)
}
// rewriteAuthHeader 重写认证头
func rewriteAuthHeader(authHeader, proxyHost string) string {
authHeader = strings.ReplaceAll(authHeader, "https://auth.docker.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://ghcr.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://gcr.io", "http://"+proxyHost)
authHeader = strings.ReplaceAll(authHeader, "https://quay.io", "http://"+proxyHost)
return authHeader
}
// handleMultiRegistryRequest 处理多Registry请求
func handleMultiRegistryRequest(c *gin.Context, registryDomain, remainingPath string) {
mapping, exists := registryDetector.getRegistryMapping(registryDomain)
if !exists {
c.String(http.StatusBadRequest, "Registry not configured")
return
}
imageName, apiType, reference := parseRegistryPath(remainingPath)
if imageName == "" || apiType == "" {
c.String(http.StatusBadRequest, "Invalid path format")
return
}
fullImageName := registryDomain + "/" + imageName
if allowed, reason := utils.GlobalAccessController.CheckDockerAccess(fullImageName); !allowed {
fmt.Printf("镜像 %s 访问被拒绝: %s\n", fullImageName, reason)
c.String(http.StatusForbidden, "镜像访问被限制")
return
}
upstreamImageRef := fmt.Sprintf("%s/%s", mapping.Upstream, imageName)
switch apiType {
case "manifests":
handleUpstreamManifestRequest(c, upstreamImageRef, reference, mapping)
case "blobs":
handleUpstreamBlobRequest(c, upstreamImageRef, reference, mapping)
case "tags":
handleUpstreamTagsRequest(c, upstreamImageRef, mapping)
default:
c.String(http.StatusNotFound, "API endpoint not found")
}
}
// handleUpstreamManifestRequest 处理上游Registry的manifest请求
func handleUpstreamManifestRequest(c *gin.Context, imageRef, reference string, mapping config.RegistryMapping) {
if utils.IsCacheEnabled() && c.Request.Method == http.MethodGet {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
if cachedItem := utils.GlobalCache.Get(cacheKey); cachedItem != nil {
utils.WriteCachedResponse(c, cachedItem)
return
}
}
var ref name.Reference
var err error
if strings.HasPrefix(reference, "sha256:") {
ref, err = name.NewDigest(fmt.Sprintf("%s@%s", imageRef, reference))
} else {
ref, err = name.NewTag(fmt.Sprintf("%s:%s", imageRef, reference))
}
if err != nil {
fmt.Printf("解析镜像引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid reference")
return
}
options := createUpstreamOptions(mapping)
if c.Request.Method == http.MethodHead {
desc, err := remote.Head(ref, options...)
if err != nil {
fmt.Printf("HEAD请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
c.Header("Content-Type", string(desc.MediaType))
c.Header("Docker-Content-Digest", desc.Digest.String())
c.Header("Content-Length", fmt.Sprintf("%d", desc.Size))
c.Status(http.StatusOK)
} else {
desc, err := remote.Get(ref, options...)
if err != nil {
fmt.Printf("GET请求失败: %v\n", err)
c.String(http.StatusNotFound, "Manifest not found")
return
}
headers := map[string]string{
"Docker-Content-Digest": desc.Digest.String(),
"Content-Length": fmt.Sprintf("%d", len(desc.Manifest)),
}
if utils.IsCacheEnabled() {
cacheKey := utils.BuildManifestCacheKey(imageRef, reference)
ttl := utils.GetManifestTTL(reference)
utils.GlobalCache.Set(cacheKey, desc.Manifest, string(desc.MediaType), headers, ttl)
}
c.Header("Content-Type", string(desc.MediaType))
for key, value := range headers {
c.Header(key, value)
}
c.Data(http.StatusOK, string(desc.MediaType), desc.Manifest)
}
}
// handleUpstreamBlobRequest 处理上游Registry的blob请求
func handleUpstreamBlobRequest(c *gin.Context, imageRef, digest string, mapping config.RegistryMapping) {
digestRef, err := name.NewDigest(fmt.Sprintf("%s@%s", imageRef, digest))
if err != nil {
fmt.Printf("解析digest引用失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid digest reference")
return
}
options := createUpstreamOptions(mapping)
layer, err := remote.Layer(digestRef, options...)
if err != nil {
fmt.Printf("获取layer失败: %v\n", err)
c.String(http.StatusNotFound, "Layer not found")
return
}
size, err := layer.Size()
if err != nil {
fmt.Printf("获取layer大小失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer size")
return
}
reader, err := layer.Compressed()
if err != nil {
fmt.Printf("获取layer内容失败: %v\n", err)
c.String(http.StatusInternalServerError, "Failed to get layer content")
return
}
defer reader.Close()
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Length", fmt.Sprintf("%d", size))
c.Header("Docker-Content-Digest", digest)
c.Status(http.StatusOK)
io.Copy(c.Writer, reader)
}
// handleUpstreamTagsRequest 处理上游Registry的tags请求
func handleUpstreamTagsRequest(c *gin.Context, imageRef string, mapping config.RegistryMapping) {
repo, err := name.NewRepository(imageRef)
if err != nil {
fmt.Printf("解析repository失败: %v\n", err)
c.String(http.StatusBadRequest, "Invalid repository")
return
}
options := createUpstreamOptions(mapping)
tags, err := remote.List(repo, options...)
if err != nil {
fmt.Printf("获取tags失败: %v\n", err)
c.String(http.StatusNotFound, "Tags not found")
return
}
response := map[string]interface{}{
"name": strings.TrimPrefix(imageRef, mapping.Upstream+"/"),
"tags": tags,
}
c.JSON(http.StatusOK, response)
}
// createUpstreamOptions 创建上游Registry选项
func createUpstreamOptions(mapping config.RegistryMapping) []remote.Option {
options := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithUserAgent("hubproxy/go-containerregistry"),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
switch mapping.AuthType {
case "github":
case "google":
case "quay":
}
return options
}

213
src/handlers/github.go Normal file
View File

@@ -0,0 +1,213 @@
package handlers
import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"hubproxy/config"
"hubproxy/utils"
)
var (
// GitHub URL匹配正则表达式
githubExps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*`),
regexp.MustCompile(`^(?:https?://)?raw\.github(?:usercontent|)\.com/([^/]+)/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?gist\.github(?:usercontent|)\.com/([^/]+)/.+?/.+`),
regexp.MustCompile(`^(?:https?://)?api\.github\.com/repos/([^/]+)/([^/]+)/.*`),
regexp.MustCompile(`^(?:https?://)?huggingface\.co(?:/spaces)?/([^/]+)/(.+)`),
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)`),
regexp.MustCompile(`^(?:https?://)?(github|opengraph)\.githubassets\.com/([^/]+)/.+?`),
}
)
// GitHubProxyHandler GitHub代理处理器
func GitHubProxyHandler(c *gin.Context) {
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
for strings.HasPrefix(rawPath, "/") {
rawPath = strings.TrimPrefix(rawPath, "/")
}
// 自动补全协议头
if !strings.HasPrefix(rawPath, "https://") {
if strings.HasPrefix(rawPath, "http:/") || strings.HasPrefix(rawPath, "https:/") {
rawPath = strings.Replace(rawPath, "http:/", "", 1)
rawPath = strings.Replace(rawPath, "https:/", "", 1)
} else if strings.HasPrefix(rawPath, "http://") {
rawPath = strings.TrimPrefix(rawPath, "http://")
}
rawPath = "https://" + rawPath
}
matches := CheckGitHubURL(rawPath)
if matches != nil {
if allowed, reason := utils.GlobalAccessController.CheckGitHubAccess(matches); !allowed {
var repoPath string
if len(matches) >= 2 {
username := matches[0]
repoName := strings.TrimSuffix(matches[1], ".git")
repoPath = username + "/" + repoName
}
fmt.Printf("GitHub仓库 %s 访问被拒绝: %s\n", repoPath, reason)
c.String(http.StatusForbidden, reason)
return
}
} else {
c.String(http.StatusForbidden, "无效输入")
return
}
// 将blob链接转换为raw链接
if githubExps[1].MatchString(rawPath) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
}
ProxyGitHubRequest(c, rawPath)
}
// CheckGitHubURL 检查URL是否匹配GitHub模式
func CheckGitHubURL(u string) []string {
for _, exp := range githubExps {
if matches := exp.FindStringSubmatch(u); matches != nil {
return matches[1:]
}
}
return nil
}
// ProxyGitHubRequest 代理GitHub请求
func ProxyGitHubRequest(c *gin.Context, u string) {
proxyGitHubWithRedirect(c, u, 0)
}
// proxyGitHubWithRedirect 带重定向的GitHub代理请求
func proxyGitHubWithRedirect(c *gin.Context, u string, redirectCount int) {
const maxRedirects = 20
if redirectCount > maxRedirects {
c.String(http.StatusLoopDetected, "重定向次数过多,可能存在循环重定向")
return
}
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
// 复制请求头
for key, values := range c.Request.Header {
for _, value := range values {
req.Header.Add(key, value)
}
}
req.Header.Del("Host")
resp, err := utils.GetGlobalHTTPClient().Do(req)
if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
fmt.Printf("关闭响应体失败: %v\n", err)
}
}()
// 检查文件大小限制
cfg := config.GetConfig()
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
if size, err := strconv.ParseInt(contentLength, 10, 64); err == nil && size > cfg.Server.FileSize {
c.String(http.StatusRequestEntityTooLarge,
fmt.Sprintf("文件过大,限制大小: %d MB", cfg.Server.FileSize/(1024*1024)))
return
}
}
// 清理安全相关的头
resp.Header.Del("Content-Security-Policy")
resp.Header.Del("Referrer-Policy")
resp.Header.Del("Strict-Transport-Security")
// 获取真实域名
realHost := c.Request.Header.Get("X-Forwarded-Host")
if realHost == "" {
realHost = c.Request.Host
}
if !strings.HasPrefix(realHost, "http://") && !strings.HasPrefix(realHost, "https://") {
realHost = "https://" + realHost
}
// 处理.sh文件的智能处理
if strings.HasSuffix(strings.ToLower(u), ".sh") {
isGzipCompressed := resp.Header.Get("Content-Encoding") == "gzip"
processedBody, processedSize, err := utils.ProcessSmart(resp.Body, isGzipCompressed, realHost)
if err != nil {
fmt.Printf("智能处理失败,回退到直接代理: %v\n", err)
processedBody = resp.Body
processedSize = 0
}
// 智能设置响应头
if processedSize > 0 {
resp.Header.Del("Content-Length")
resp.Header.Del("Content-Encoding")
resp.Header.Set("Transfer-Encoding", "chunked")
}
// 复制其他响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if CheckGitHubURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyGitHubWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 输出处理后的内容
if _, err := io.Copy(c.Writer, processedBody); err != nil {
return
}
} else {
// 复制响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 处理重定向
if location := resp.Header.Get("Location"); location != "" {
if CheckGitHubURL(location) != nil {
c.Header("Location", "/"+location)
} else {
proxyGitHubWithRedirect(c, location, redirectCount+1)
return
}
}
c.Status(resp.StatusCode)
// 直接流式转发
io.Copy(c.Writer, resp.Body)
}
}

858
src/handlers/imagetar.go Normal file
View File

@@ -0,0 +1,858 @@
package handlers
import (
"archive/tar"
"compress/gzip"
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"sort"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-containerregistry/pkg/authn"
"github.com/google/go-containerregistry/pkg/name"
"github.com/google/go-containerregistry/pkg/v1"
"github.com/google/go-containerregistry/pkg/v1/partial"
"github.com/google/go-containerregistry/pkg/v1/remote"
"github.com/google/go-containerregistry/pkg/v1/types"
"hubproxy/config"
"hubproxy/utils"
)
// DebounceEntry 防抖条目
type DebounceEntry struct {
LastRequest time.Time
UserID string
}
// DownloadDebouncer 下载防抖器
type DownloadDebouncer struct {
mu sync.RWMutex
entries map[string]*DebounceEntry
window time.Duration
lastCleanup time.Time
}
// NewDownloadDebouncer 创建下载防抖器
func NewDownloadDebouncer(window time.Duration) *DownloadDebouncer {
return &DownloadDebouncer{
entries: make(map[string]*DebounceEntry),
window: window,
lastCleanup: time.Now(),
}
}
// ShouldAllow 检查是否应该允许请求
func (d *DownloadDebouncer) ShouldAllow(userID, contentKey string) bool {
d.mu.Lock()
defer d.mu.Unlock()
key := userID + ":" + contentKey
now := time.Now()
if entry, exists := d.entries[key]; exists {
if now.Sub(entry.LastRequest) < d.window {
return false
}
}
d.entries[key] = &DebounceEntry{
LastRequest: now,
UserID: userID,
}
if time.Since(d.lastCleanup) > 5*time.Minute {
d.cleanup(now)
d.lastCleanup = now
}
return true
}
// cleanup 清理过期条目
func (d *DownloadDebouncer) cleanup(now time.Time) {
for key, entry := range d.entries {
if now.Sub(entry.LastRequest) > d.window*2 {
delete(d.entries, key)
}
}
}
// generateContentFingerprint 生成内容指纹
func generateContentFingerprint(images []string, platform string) string {
sortedImages := make([]string, len(images))
copy(sortedImages, images)
sort.Strings(sortedImages)
content := strings.Join(sortedImages, "|") + ":" + platform
hash := md5.Sum([]byte(content))
return hex.EncodeToString(hash[:])
}
// getUserID 获取用户标识
func getUserID(c *gin.Context) string {
if sessionID, err := c.Cookie("session_id"); err == nil && sessionID != "" {
return "session:" + sessionID
}
ip := c.ClientIP()
userAgent := c.GetHeader("User-Agent")
if userAgent == "" {
userAgent = "unknown"
}
combined := ip + ":" + userAgent
hash := md5.Sum([]byte(combined))
return "ip:" + hex.EncodeToString(hash[:8])
}
var (
singleImageDebouncer *DownloadDebouncer
batchImageDebouncer *DownloadDebouncer
)
// InitDebouncer 初始化防抖器
func InitDebouncer() {
singleImageDebouncer = NewDownloadDebouncer(5 * time.Second)
batchImageDebouncer = NewDownloadDebouncer(60 * time.Second)
}
// ImageStreamer 镜像流式下载器
type ImageStreamer struct {
concurrency int
remoteOptions []remote.Option
}
// ImageStreamerConfig 下载器配置
type ImageStreamerConfig struct {
Concurrency int
}
// NewImageStreamer 创建镜像下载器
func NewImageStreamer(cfg *ImageStreamerConfig) *ImageStreamer {
if cfg == nil {
cfg = &ImageStreamerConfig{}
}
concurrency := cfg.Concurrency
if concurrency <= 0 {
appCfg := config.GetConfig()
concurrency = appCfg.Download.MaxImages
if concurrency <= 0 {
concurrency = 10
}
}
remoteOptions := []remote.Option{
remote.WithAuth(authn.Anonymous),
remote.WithTransport(utils.GetGlobalHTTPClient().Transport),
}
return &ImageStreamer{
concurrency: concurrency,
remoteOptions: remoteOptions,
}
}
// StreamOptions 下载选项
type StreamOptions struct {
Platform string
Compression bool
UseCompressedLayers bool
}
// StreamImageToWriter 流式下载镜像到Writer
func (is *ImageStreamer) StreamImageToWriter(ctx context.Context, imageRef string, writer io.Writer, options *StreamOptions) error {
if options == nil {
options = &StreamOptions{UseCompressedLayers: true}
}
ref, err := name.ParseReference(imageRef)
if err != nil {
return fmt.Errorf("解析镜像引用失败: %w", err)
}
log.Printf("开始下载镜像: %s", ref.String())
contextOptions := append(is.remoteOptions, remote.WithContext(ctx))
desc, err := is.getImageDescriptorWithPlatform(ref, contextOptions, options.Platform)
if err != nil {
return fmt.Errorf("获取镜像描述失败: %w", err)
}
switch desc.MediaType {
case types.OCIImageIndex, types.DockerManifestList:
return is.streamMultiArchImage(ctx, desc, writer, options, contextOptions, imageRef)
case types.OCIManifestSchema1, types.DockerManifestSchema2:
return is.streamSingleImage(ctx, desc, writer, options, contextOptions, imageRef)
default:
return is.streamSingleImage(ctx, desc, writer, options, contextOptions, imageRef)
}
}
// getImageDescriptor 获取镜像描述符
func (is *ImageStreamer) getImageDescriptor(ref name.Reference, options []remote.Option) (*remote.Descriptor, error) {
return is.getImageDescriptorWithPlatform(ref, options, "")
}
// getImageDescriptorWithPlatform 获取指定平台的镜像描述符
func (is *ImageStreamer) getImageDescriptorWithPlatform(ref name.Reference, options []remote.Option, platform string) (*remote.Descriptor, error) {
return remote.Get(ref, options...)
}
// StreamImageToGin 流式响应到Gin
func (is *ImageStreamer) StreamImageToGin(ctx context.Context, imageRef string, c *gin.Context, options *StreamOptions) error {
if options == nil {
options = &StreamOptions{UseCompressedLayers: true}
}
filename := strings.ReplaceAll(imageRef, "/", "_") + ".tar"
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename))
if options.Compression {
c.Header("Content-Encoding", "gzip")
}
return is.StreamImageToWriter(ctx, imageRef, c.Writer, options)
}
// streamMultiArchImage 处理多架构镜像
func (is *ImageStreamer) streamMultiArchImage(ctx context.Context, desc *remote.Descriptor, writer io.Writer, options *StreamOptions, remoteOptions []remote.Option, imageRef string) error {
img, err := is.selectPlatformImage(desc, options)
if err != nil {
return err
}
return is.streamImageLayers(ctx, img, writer, options, imageRef)
}
// streamSingleImage 处理单架构镜像
func (is *ImageStreamer) streamSingleImage(ctx context.Context, desc *remote.Descriptor, writer io.Writer, options *StreamOptions, remoteOptions []remote.Option, imageRef string) error {
img, err := desc.Image()
if err != nil {
return fmt.Errorf("获取镜像失败: %w", err)
}
return is.streamImageLayers(ctx, img, writer, options, imageRef)
}
// streamImageLayers 处理镜像层
func (is *ImageStreamer) streamImageLayers(ctx context.Context, img v1.Image, writer io.Writer, options *StreamOptions, imageRef string) error {
var finalWriter io.Writer = writer
if options.Compression {
gzWriter := gzip.NewWriter(writer)
defer gzWriter.Close()
finalWriter = gzWriter
}
tarWriter := tar.NewWriter(finalWriter)
defer tarWriter.Close()
configFile, err := img.ConfigFile()
if err != nil {
return fmt.Errorf("获取镜像配置失败: %w", err)
}
layers, err := img.Layers()
if err != nil {
return fmt.Errorf("获取镜像层失败: %w", err)
}
log.Printf("镜像包含 %d 层", len(layers))
return is.streamDockerFormat(ctx, tarWriter, img, layers, configFile, imageRef, options)
}
// streamDockerFormat 生成Docker格式
func (is *ImageStreamer) streamDockerFormat(ctx context.Context, tarWriter *tar.Writer, img v1.Image, layers []v1.Layer, configFile *v1.ConfigFile, imageRef string, options *StreamOptions) error {
return is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, nil, nil, options)
}
// streamDockerFormatWithReturn 生成Docker格式并返回manifest和repositories信息
func (is *ImageStreamer) streamDockerFormatWithReturn(ctx context.Context, tarWriter *tar.Writer, img v1.Image, layers []v1.Layer, configFile *v1.ConfigFile, imageRef string, manifestOut *map[string]interface{}, repositoriesOut *map[string]map[string]string, options *StreamOptions) error {
configDigest, err := img.ConfigName()
if err != nil {
return err
}
configData, err := json.Marshal(configFile)
if err != nil {
return err
}
configHeader := &tar.Header{
Name: configDigest.String() + ".json",
Size: int64(len(configData)),
Mode: 0644,
}
if err := tarWriter.WriteHeader(configHeader); err != nil {
return err
}
if _, err := tarWriter.Write(configData); err != nil {
return err
}
layerDigests := make([]string, len(layers))
for i, layer := range layers {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if err := func() error {
digest, err := layer.Digest()
if err != nil {
return err
}
layerDigests[i] = digest.String()
layerDir := digest.String()
layerHeader := &tar.Header{
Name: layerDir + "/",
Typeflag: tar.TypeDir,
Mode: 0755,
}
if err := tarWriter.WriteHeader(layerHeader); err != nil {
return err
}
var layerSize int64
var layerReader io.ReadCloser
if options != nil && options.UseCompressedLayers {
layerSize, err = layer.Size()
if err != nil {
return err
}
layerReader, err = layer.Compressed()
} else {
layerSize, err = partial.UncompressedSize(layer)
if err != nil {
return err
}
layerReader, err = layer.Uncompressed()
}
if err != nil {
return err
}
defer layerReader.Close()
layerTarHeader := &tar.Header{
Name: layerDir + "/layer.tar",
Size: layerSize,
Mode: 0644,
}
if err := tarWriter.WriteHeader(layerTarHeader); err != nil {
return err
}
if _, err := io.Copy(tarWriter, layerReader); err != nil {
return err
}
return nil
}(); err != nil {
return err
}
log.Printf("已处理层 %d/%d", i+1, len(layers))
}
singleManifest := map[string]interface{}{
"Config": configDigest.String() + ".json",
"RepoTags": []string{imageRef},
"Layers": func() []string {
var layers []string
for _, digest := range layerDigests {
layers = append(layers, digest+"/layer.tar")
}
return layers
}(),
}
repositories := make(map[string]map[string]string)
parts := strings.Split(imageRef, ":")
if len(parts) == 2 {
repoName := parts[0]
tag := parts[1]
repositories[repoName] = map[string]string{tag: configDigest.String()}
}
if manifestOut != nil && repositoriesOut != nil {
*manifestOut = singleManifest
*repositoriesOut = repositories
return nil
}
manifest := []map[string]interface{}{singleManifest}
manifestData, err := json.Marshal(manifest)
if err != nil {
return err
}
manifestHeader := &tar.Header{
Name: "manifest.json",
Size: int64(len(manifestData)),
Mode: 0644,
}
if err := tarWriter.WriteHeader(manifestHeader); err != nil {
return err
}
if _, err := tarWriter.Write(manifestData); err != nil {
return err
}
repositoriesData, err := json.Marshal(repositories)
if err != nil {
return err
}
repositoriesHeader := &tar.Header{
Name: "repositories",
Size: int64(len(repositoriesData)),
Mode: 0644,
}
if err := tarWriter.WriteHeader(repositoriesHeader); err != nil {
return err
}
_, err = tarWriter.Write(repositoriesData)
return err
}
// processImageForBatch 处理镜像的公共逻辑
func (is *ImageStreamer) processImageForBatch(ctx context.Context, img v1.Image, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
layers, err := img.Layers()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像层失败: %w", err)
}
configFile, err := img.ConfigFile()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像配置失败: %w", err)
}
log.Printf("镜像包含 %d 层", len(layers))
var manifest map[string]interface{}
var repositories map[string]map[string]string
err = is.streamDockerFormatWithReturn(ctx, tarWriter, img, layers, configFile, imageRef, &manifest, &repositories, options)
if err != nil {
return nil, nil, err
}
return manifest, repositories, nil
}
func (is *ImageStreamer) streamSingleImageForBatch(ctx context.Context, tarWriter *tar.Writer, imageRef string, options *StreamOptions) (map[string]interface{}, map[string]map[string]string, error) {
ref, err := name.ParseReference(imageRef)
if err != nil {
return nil, nil, fmt.Errorf("解析镜像引用失败: %w", err)
}
contextOptions := append(is.remoteOptions, remote.WithContext(ctx))
desc, err := is.getImageDescriptorWithPlatform(ref, contextOptions, options.Platform)
if err != nil {
return nil, nil, fmt.Errorf("获取镜像描述失败: %w", err)
}
var img v1.Image
switch desc.MediaType {
case types.OCIImageIndex, types.DockerManifestList:
img, err = is.selectPlatformImage(desc, options)
if err != nil {
return nil, nil, fmt.Errorf("选择平台镜像失败: %w", err)
}
case types.OCIManifestSchema1, types.DockerManifestSchema2:
img, err = desc.Image()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像失败: %w", err)
}
default:
img, err = desc.Image()
if err != nil {
return nil, nil, fmt.Errorf("获取镜像失败: %w", err)
}
}
return is.processImageForBatch(ctx, img, tarWriter, imageRef, options)
}
// selectPlatformImage 从多架构镜像中选择合适的平台镜像
func (is *ImageStreamer) selectPlatformImage(desc *remote.Descriptor, options *StreamOptions) (v1.Image, error) {
index, err := desc.ImageIndex()
if err != nil {
return nil, fmt.Errorf("获取镜像索引失败: %w", err)
}
manifest, err := index.IndexManifest()
if err != nil {
return nil, fmt.Errorf("获取索引清单失败: %w", err)
}
var selectedDesc *v1.Descriptor
for _, m := range manifest.Manifests {
if m.Platform == nil {
continue
}
if options.Platform != "" {
platformParts := strings.Split(options.Platform, "/")
if len(platformParts) >= 2 {
targetOS := platformParts[0]
targetArch := platformParts[1]
targetVariant := ""
if len(platformParts) >= 3 {
targetVariant = platformParts[2]
}
if m.Platform.OS == targetOS &&
m.Platform.Architecture == targetArch &&
m.Platform.Variant == targetVariant {
selectedDesc = &m
break
}
}
} else if m.Platform.OS == "linux" && m.Platform.Architecture == "amd64" {
selectedDesc = &m
break
}
}
if selectedDesc == nil && len(manifest.Manifests) > 0 {
selectedDesc = &manifest.Manifests[0]
}
if selectedDesc == nil {
return nil, fmt.Errorf("未找到合适的平台镜像")
}
img, err := index.Image(selectedDesc.Digest)
if err != nil {
return nil, fmt.Errorf("获取选中镜像失败: %w", err)
}
return img, nil
}
var globalImageStreamer *ImageStreamer
// InitImageStreamer 初始化镜像下载器
func InitImageStreamer() {
globalImageStreamer = NewImageStreamer(nil)
}
// formatPlatformText 格式化平台文本
func formatPlatformText(platform string) string {
if platform == "" {
return "自动选择"
}
return platform
}
// InitImageTarRoutes 初始化镜像下载路由
func InitImageTarRoutes(router *gin.Engine) {
imageAPI := router.Group("/api/image")
{
imageAPI.GET("/download/:image", handleDirectImageDownload)
imageAPI.GET("/info/:image", handleImageInfo)
imageAPI.POST("/batch", handleSimpleBatchDownload)
}
}
// handleDirectImageDownload 处理单镜像下载
func handleDirectImageDownload(c *gin.Context) {
imageParam := c.Param("image")
if imageParam == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少镜像参数"})
return
}
imageRef := strings.ReplaceAll(imageParam, "_", "/")
platform := c.Query("platform")
tag := c.DefaultQuery("tag", "")
useCompressed := c.DefaultQuery("compressed", "true") == "true"
if tag != "" && !strings.Contains(imageRef, ":") && !strings.Contains(imageRef, "@") {
imageRef = imageRef + ":" + tag
} else if !strings.Contains(imageRef, ":") && !strings.Contains(imageRef, "@") {
imageRef = imageRef + ":latest"
}
if _, err := name.ParseReference(imageRef); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "镜像引用格式错误: " + err.Error()})
return
}
userID := getUserID(c)
contentKey := generateContentFingerprint([]string{imageRef}, platform)
if !singleImageDebouncer.ShouldAllow(userID, contentKey) {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "请求过于频繁,请稍后再试",
"retry_after": 5,
})
return
}
options := &StreamOptions{
Platform: platform,
Compression: false,
UseCompressedLayers: useCompressed,
}
ctx := c.Request.Context()
log.Printf("下载镜像: %s (平台: %s)", imageRef, formatPlatformText(platform))
if err := globalImageStreamer.StreamImageToGin(ctx, imageRef, c, options); err != nil {
log.Printf("镜像下载失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "镜像下载失败: " + err.Error()})
return
}
}
// handleSimpleBatchDownload 处理批量下载
func handleSimpleBatchDownload(c *gin.Context) {
var req struct {
Images []string `json:"images" binding:"required"`
Platform string `json:"platform"`
UseCompressedLayers *bool `json:"useCompressedLayers"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "请求参数错误: " + err.Error()})
return
}
if len(req.Images) == 0 {
c.JSON(http.StatusBadRequest, gin.H{"error": "镜像列表不能为空"})
return
}
for i, imageRef := range req.Images {
if !strings.Contains(imageRef, ":") && !strings.Contains(imageRef, "@") {
req.Images[i] = imageRef + ":latest"
}
}
cfg := config.GetConfig()
if len(req.Images) > cfg.Download.MaxImages {
c.JSON(http.StatusBadRequest, gin.H{
"error": fmt.Sprintf("镜像数量超过限制,最大允许: %d", cfg.Download.MaxImages),
})
return
}
userID := getUserID(c)
contentKey := generateContentFingerprint(req.Images, req.Platform)
if !batchImageDebouncer.ShouldAllow(userID, contentKey) {
c.JSON(http.StatusTooManyRequests, gin.H{
"error": "批量下载请求过于频繁,请稍后再试",
"retry_after": 60,
})
return
}
useCompressed := true
if req.UseCompressedLayers != nil {
useCompressed = *req.UseCompressedLayers
}
options := &StreamOptions{
Platform: req.Platform,
Compression: false,
UseCompressedLayers: useCompressed,
}
ctx := c.Request.Context()
log.Printf("批量下载 %d 个镜像 (平台: %s)", len(req.Images), formatPlatformText(req.Platform))
filename := fmt.Sprintf("batch_%d_images.tar", len(req.Images))
c.Header("Content-Type", "application/octet-stream")
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", filename))
if err := globalImageStreamer.StreamMultipleImages(ctx, req.Images, c.Writer, options); err != nil {
log.Printf("批量镜像下载失败: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "批量镜像下载失败: " + err.Error()})
return
}
}
// handleImageInfo 处理镜像信息查询
func handleImageInfo(c *gin.Context) {
imageParam := c.Param("image")
if imageParam == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "缺少镜像参数"})
return
}
imageRef := strings.ReplaceAll(imageParam, "_", "/")
tag := c.DefaultQuery("tag", "latest")
if !strings.Contains(imageRef, ":") && !strings.Contains(imageRef, "@") {
imageRef = imageRef + ":" + tag
}
ref, err := name.ParseReference(imageRef)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "镜像引用格式错误: " + err.Error()})
return
}
ctx := c.Request.Context()
contextOptions := append(globalImageStreamer.remoteOptions, remote.WithContext(ctx))
desc, err := globalImageStreamer.getImageDescriptor(ref, contextOptions)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "获取镜像信息失败: " + err.Error()})
return
}
info := gin.H{
"name": ref.String(),
"mediaType": desc.MediaType,
"digest": desc.Digest.String(),
"size": desc.Size,
}
if desc.MediaType == types.OCIImageIndex || desc.MediaType == types.DockerManifestList {
index, err := desc.ImageIndex()
if err == nil {
manifest, err := index.IndexManifest()
if err == nil {
var platforms []string
for _, m := range manifest.Manifests {
if m.Platform != nil {
platforms = append(platforms, m.Platform.OS+"/"+m.Platform.Architecture)
}
}
info["platforms"] = platforms
info["multiArch"] = true
}
}
} else {
info["multiArch"] = false
}
c.JSON(http.StatusOK, gin.H{"success": true, "data": info})
}
// StreamMultipleImages 批量下载多个镜像
func (is *ImageStreamer) StreamMultipleImages(ctx context.Context, imageRefs []string, writer io.Writer, options *StreamOptions) error {
if options == nil {
options = &StreamOptions{UseCompressedLayers: true}
}
var finalWriter io.Writer = writer
if options.Compression {
gzWriter := gzip.NewWriter(writer)
defer gzWriter.Close()
finalWriter = gzWriter
}
tarWriter := tar.NewWriter(finalWriter)
defer tarWriter.Close()
var allManifests []map[string]interface{}
var allRepositories = make(map[string]map[string]string)
for i, imageRef := range imageRefs {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
log.Printf("处理镜像 %d/%d: %s", i+1, len(imageRefs), imageRef)
timeoutCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
manifest, repositories, err := is.streamSingleImageForBatch(timeoutCtx, tarWriter, imageRef, options)
cancel()
if err != nil {
log.Printf("下载镜像 %s 失败: %v", imageRef, err)
return fmt.Errorf("下载镜像 %s 失败: %w", imageRef, err)
}
if manifest == nil {
return fmt.Errorf("镜像 %s manifest数据为空", imageRef)
}
allManifests = append(allManifests, manifest)
for repo, tags := range repositories {
if allRepositories[repo] == nil {
allRepositories[repo] = make(map[string]string)
}
for tag, digest := range tags {
allRepositories[repo][tag] = digest
}
}
}
manifestData, err := json.Marshal(allManifests)
if err != nil {
return fmt.Errorf("序列化manifest失败: %w", err)
}
manifestHeader := &tar.Header{
Name: "manifest.json",
Size: int64(len(manifestData)),
Mode: 0644,
}
if err := tarWriter.WriteHeader(manifestHeader); err != nil {
return fmt.Errorf("写入manifest header失败: %w", err)
}
if _, err := tarWriter.Write(manifestData); err != nil {
return fmt.Errorf("写入manifest数据失败: %w", err)
}
repositoriesData, err := json.Marshal(allRepositories)
if err != nil {
return fmt.Errorf("序列化repositories失败: %w", err)
}
repositoriesHeader := &tar.Header{
Name: "repositories",
Size: int64(len(repositoriesData)),
Mode: 0644,
}
if err := tarWriter.WriteHeader(repositoriesHeader); err != nil {
return fmt.Errorf("写入repositories header失败: %w", err)
}
if _, err := tarWriter.Write(repositoriesData); err != nil {
return fmt.Errorf("写入repositories数据失败: %w", err)
}
log.Printf("批量下载完成,共处理 %d 个镜像", len(imageRefs))
return nil
}

559
src/handlers/search.go Normal file
View File

@@ -0,0 +1,559 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"hubproxy/utils"
)
// SearchResult Docker Hub搜索结果
type SearchResult struct {
Count int `json:"count"`
Next string `json:"next"`
Previous string `json:"previous"`
Results []Repository `json:"results"`
}
// Repository 仓库信息
type Repository struct {
Name string `json:"repo_name"`
Description string `json:"short_description"`
IsOfficial bool `json:"is_official"`
IsAutomated bool `json:"is_automated"`
StarCount int `json:"star_count"`
PullCount int `json:"pull_count"`
RepoOwner string `json:"repo_owner"`
LastUpdated string `json:"last_updated"`
Status int `json:"status"`
Organization string `json:"affiliation"`
PullsLastWeek int `json:"pulls_last_week"`
Namespace string `json:"namespace"`
}
// TagInfo 标签信息
type TagInfo struct {
Name string `json:"name"`
FullSize int64 `json:"full_size"`
LastUpdated time.Time `json:"last_updated"`
LastPusher string `json:"last_pusher"`
Images []Image `json:"images"`
Vulnerabilities struct {
Critical int `json:"critical"`
High int `json:"high"`
Medium int `json:"medium"`
Low int `json:"low"`
Unknown int `json:"unknown"`
} `json:"vulnerabilities"`
}
// Image 镜像信息
type Image struct {
Architecture string `json:"architecture"`
Features string `json:"features"`
Variant string `json:"variant,omitempty"`
Digest string `json:"digest"`
OS string `json:"os"`
OSFeatures string `json:"os_features"`
Size int64 `json:"size"`
}
// TagPageResult 分页标签结果
type TagPageResult struct {
Tags []TagInfo `json:"tags"`
HasMore bool `json:"has_more"`
}
type cacheEntry struct {
data interface{}
expiresAt time.Time
}
const (
maxCacheSize = 1000
maxPaginationCache = 200
cacheTTL = 30 * time.Minute
)
type Cache struct {
data map[string]cacheEntry
mu sync.RWMutex
maxSize int
}
var (
searchCache = &Cache{
data: make(map[string]cacheEntry),
maxSize: maxCacheSize,
}
)
func (c *Cache) Get(key string) (interface{}, bool) {
c.mu.RLock()
entry, exists := c.data[key]
c.mu.RUnlock()
if !exists {
return nil, false
}
if time.Now().After(entry.expiresAt) {
c.mu.Lock()
delete(c.data, key)
c.mu.Unlock()
return nil, false
}
return entry.data, true
}
func (c *Cache) Set(key string, data interface{}) {
c.SetWithTTL(key, data, cacheTTL)
}
func (c *Cache) SetWithTTL(key string, data interface{}, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.data) >= c.maxSize {
c.cleanupExpiredLocked()
}
c.data[key] = cacheEntry{
data: data,
expiresAt: time.Now().Add(ttl),
}
}
func (c *Cache) Cleanup() {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanupExpiredLocked()
}
func (c *Cache) cleanupExpiredLocked() {
now := time.Now()
for key, entry := range c.data {
if now.After(entry.expiresAt) {
delete(c.data, key)
}
}
}
func init() {
go func() {
ticker := time.NewTicker(5 * time.Minute)
defer ticker.Stop()
for range ticker.C {
searchCache.Cleanup()
}
}()
}
func filterSearchResults(results []Repository, query string) []Repository {
searchTerm := strings.ToLower(strings.TrimPrefix(query, "library/"))
filtered := make([]Repository, 0)
for _, repo := range results {
repoName := strings.ToLower(repo.Name)
repoDesc := strings.ToLower(repo.Description)
score := 0
if repoName == searchTerm {
score += 100
}
if strings.HasPrefix(repoName, searchTerm) {
score += 50
}
if strings.Contains(repoName, searchTerm) {
score += 30
}
if strings.Contains(repoDesc, searchTerm) {
score += 10
}
if repo.IsOfficial {
score += 20
}
if score > 0 {
filtered = append(filtered, repo)
}
}
sort.Slice(filtered, func(i, j int) bool {
if filtered[i].IsOfficial != filtered[j].IsOfficial {
return filtered[i].IsOfficial
}
return filtered[i].PullCount > filtered[j].PullCount
})
return filtered
}
// normalizeRepository 统一规范化仓库信息
func normalizeRepository(repo *Repository) {
if repo.IsOfficial {
repo.Namespace = "library"
if !strings.Contains(repo.Name, "/") {
repo.Name = "library/" + repo.Name
}
} else {
if repo.Namespace == "" && repo.RepoOwner != "" {
repo.Namespace = repo.RepoOwner
}
if strings.Contains(repo.Name, "/") {
parts := strings.Split(repo.Name, "/")
if len(parts) > 1 {
if repo.Namespace == "" {
repo.Namespace = parts[0]
}
repo.Name = parts[len(parts)-1]
}
}
}
}
// searchDockerHub 搜索镜像
func searchDockerHub(ctx context.Context, query string, page, pageSize int) (*SearchResult, error) {
return searchDockerHubWithDepth(ctx, query, page, pageSize, 0)
}
func searchDockerHubWithDepth(ctx context.Context, query string, page, pageSize int, depth int) (*SearchResult, error) {
if depth > 1 {
return nil, fmt.Errorf("搜索请求过于复杂,请尝试更具体的关键词")
}
cacheKey := fmt.Sprintf("search:%s:%d:%d", query, page, pageSize)
if cached, ok := searchCache.Get(cacheKey); ok {
return cached.(*SearchResult), nil
}
isUserRepo := strings.Contains(query, "/")
var namespace, repoName string
if isUserRepo {
parts := strings.Split(query, "/")
if len(parts) == 2 {
namespace = parts[0]
repoName = parts[1]
}
}
baseURL := "https://registry.hub.docker.com/v2"
var fullURL string
var params url.Values
if isUserRepo && namespace != "" {
fullURL = fmt.Sprintf("%s/repositories/%s/", baseURL, namespace)
params = url.Values{
"page": {fmt.Sprintf("%d", page)},
"page_size": {fmt.Sprintf("%d", pageSize)},
}
} else {
fullURL = baseURL + "/search/repositories/"
params = url.Values{
"query": {query},
"page": {fmt.Sprintf("%d", page)},
"page_size": {fmt.Sprintf("%d", pageSize)},
}
}
fullURL = fullURL + "?" + params.Encode()
resp, err := utils.GetSearchHTTPClient().Get(fullURL)
if err != nil {
return nil, fmt.Errorf("请求Docker Hub API失败: %v", err)
}
defer safeCloseResponseBody(resp.Body, "搜索响应体")
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应失败: %v", err)
}
if resp.StatusCode != http.StatusOK {
switch resp.StatusCode {
case http.StatusTooManyRequests:
return nil, fmt.Errorf("请求过于频繁,请稍后重试")
case http.StatusNotFound:
if isUserRepo && namespace != "" {
return searchDockerHubWithDepth(ctx, repoName, page, pageSize, depth+1)
}
return nil, fmt.Errorf("未找到相关镜像")
case http.StatusBadGateway, http.StatusServiceUnavailable:
return nil, fmt.Errorf("Docker Hub服务暂时不可用请稍后重试")
default:
return nil, fmt.Errorf("请求失败: 状态码=%d, 响应=%s", resp.StatusCode, string(body))
}
}
var result *SearchResult
if isUserRepo && namespace != "" {
var userRepos struct {
Count int `json:"count"`
Next string `json:"next"`
Previous string `json:"previous"`
Results []Repository `json:"results"`
}
if err := json.Unmarshal(body, &userRepos); err != nil {
return nil, fmt.Errorf("解析响应失败: %v", err)
}
result = &SearchResult{
Count: userRepos.Count,
Next: userRepos.Next,
Previous: userRepos.Previous,
Results: make([]Repository, 0),
}
for _, repo := range userRepos.Results {
if repoName == "" || strings.Contains(strings.ToLower(repo.Name), strings.ToLower(repoName)) {
repo.Namespace = namespace
normalizeRepository(&repo)
result.Results = append(result.Results, repo)
}
}
if len(result.Results) == 0 {
return searchDockerHubWithDepth(ctx, repoName, page, pageSize, depth+1)
}
result.Count = len(result.Results)
} else {
result = &SearchResult{}
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %v", err)
}
for i := range result.Results {
normalizeRepository(&result.Results[i])
}
if isUserRepo && namespace != "" {
filteredResults := make([]Repository, 0)
for _, repo := range result.Results {
if strings.EqualFold(repo.Namespace, namespace) {
filteredResults = append(filteredResults, repo)
}
}
result.Results = filteredResults
result.Count = len(filteredResults)
}
}
searchCache.Set(cacheKey, result)
return result, nil
}
func isRetryableError(err error) bool {
if err == nil {
return false
}
if strings.Contains(err.Error(), "timeout") ||
strings.Contains(err.Error(), "connection refused") ||
strings.Contains(err.Error(), "no such host") ||
strings.Contains(err.Error(), "too many requests") {
return true
}
return false
}
// getRepositoryTags 获取仓库标签信息
func getRepositoryTags(ctx context.Context, namespace, name string, page, pageSize int) ([]TagInfo, bool, error) {
if namespace == "" || name == "" {
return nil, false, fmt.Errorf("无效输入:命名空间和名称不能为空")
}
if page <= 0 {
page = 1
}
if pageSize <= 0 || pageSize > 100 {
pageSize = 100
}
cacheKey := fmt.Sprintf("tags:%s:%s:page_%d", namespace, name, page)
if cached, ok := searchCache.Get(cacheKey); ok {
result := cached.(TagPageResult)
return result.Tags, result.HasMore, nil
}
baseURL := fmt.Sprintf("https://registry.hub.docker.com/v2/repositories/%s/%s/tags", namespace, name)
params := url.Values{}
params.Set("page", fmt.Sprintf("%d", page))
params.Set("page_size", fmt.Sprintf("%d", pageSize))
params.Set("ordering", "last_updated")
fullURL := baseURL + "?" + params.Encode()
pageResult, err := fetchTagPage(ctx, fullURL, 3)
if err != nil {
return nil, false, fmt.Errorf("获取标签失败: %v", err)
}
hasMore := pageResult.Next != ""
result := TagPageResult{Tags: pageResult.Results, HasMore: hasMore}
searchCache.SetWithTTL(cacheKey, result, 30*time.Minute)
return pageResult.Results, hasMore, nil
}
func fetchTagPage(ctx context.Context, url string, maxRetries int) (*struct {
Count int `json:"count"`
Next string `json:"next"`
Previous string `json:"previous"`
Results []TagInfo `json:"results"`
}, error) {
var lastErr error
for retry := 0; retry < maxRetries; retry++ {
if retry > 0 {
time.Sleep(time.Duration(retry) * 500 * time.Millisecond)
}
resp, err := utils.GetSearchHTTPClient().Get(url)
if err != nil {
lastErr = err
if isRetryableError(err) && retry < maxRetries-1 {
continue
}
return nil, fmt.Errorf("发送请求失败: %v", err)
}
body, err := func() ([]byte, error) {
defer safeCloseResponseBody(resp.Body, "标签响应体")
return io.ReadAll(resp.Body)
}()
if err != nil {
lastErr = err
if retry < maxRetries-1 {
continue
}
return nil, fmt.Errorf("读取响应失败: %v", err)
}
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("状态码=%d, 响应=%s", resp.StatusCode, string(body))
if resp.StatusCode >= 400 && resp.StatusCode < 500 && resp.StatusCode != 429 {
return nil, fmt.Errorf("请求失败: %v", lastErr)
}
if retry < maxRetries-1 {
continue
}
return nil, fmt.Errorf("请求失败: %v", lastErr)
}
var result struct {
Count int `json:"count"`
Next string `json:"next"`
Previous string `json:"previous"`
Results []TagInfo `json:"results"`
}
if err := json.Unmarshal(body, &result); err != nil {
lastErr = err
if retry < maxRetries-1 {
continue
}
return nil, fmt.Errorf("解析响应失败: %v", err)
}
return &result, nil
}
return nil, lastErr
}
func parsePaginationParams(c *gin.Context, defaultPageSize int) (page, pageSize int) {
page = 1
pageSize = defaultPageSize
if p := c.Query("page"); p != "" {
fmt.Sscanf(p, "%d", &page)
}
if ps := c.Query("page_size"); ps != "" {
fmt.Sscanf(ps, "%d", &pageSize)
}
return page, pageSize
}
func safeCloseResponseBody(body io.ReadCloser, context string) {
if body != nil {
if err := body.Close(); err != nil {
fmt.Printf("关闭%s失败: %v\n", context, err)
}
}
}
func sendErrorResponse(c *gin.Context, message string) {
c.JSON(http.StatusBadRequest, gin.H{"error": message})
}
// RegisterSearchRoute 注册搜索相关路由
func RegisterSearchRoute(r *gin.Engine) {
r.GET("/search", func(c *gin.Context) {
query := c.Query("q")
if query == "" {
sendErrorResponse(c, "搜索关键词不能为空")
return
}
page, pageSize := parsePaginationParams(c, 25)
result, err := searchDockerHub(c.Request.Context(), query, page, pageSize)
if err != nil {
sendErrorResponse(c, err.Error())
return
}
c.JSON(http.StatusOK, result)
})
r.GET("/tags/:namespace/:name", func(c *gin.Context) {
namespace := c.Param("namespace")
name := c.Param("name")
if namespace == "" || name == "" {
sendErrorResponse(c, "命名空间和名称不能为空")
return
}
page, pageSize := parsePaginationParams(c, 100)
tags, hasMore, err := getRepositoryTags(c.Request.Context(), namespace, name, page, pageSize)
if err != nil {
sendErrorResponse(c, err.Error())
return
}
if c.Query("page") != "" || c.Query("page_size") != "" {
c.JSON(http.StatusOK, gin.H{
"tags": tags,
"has_more": hasMore,
"page": page,
"page_size": pageSize,
})
} else {
c.JSON(http.StatusOK, tags)
}
})
}