Files
hubproxy/ghproxy/skopeo_service.go
2025-05-17 13:02:27 +08:00

676 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"archive/zip"
"bufio"
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// TaskStatus describes where a DownloadTask is in its lifecycle. ImageTask
// stores the same values converted to a plain string.
type TaskStatus string

const (
	// StatusPending: created but not started yet.
	StatusPending TaskStatus = "pending"
	// StatusRunning: work is in progress.
	StatusRunning TaskStatus = "running"
	// StatusCompleted: finished successfully.
	StatusCompleted TaskStatus = "completed"
	// StatusFailed: finished with an error.
	StatusFailed TaskStatus = "failed"
)
// ImageTask tracks the download of one image inside a DownloadTask.
// It is serialized to clients as JSON; OutputPath stays server-side.
type ImageTask struct {
// Image is the image reference requested by the client.
Image string `json:"image"`
// Progress is the percentage complete, between 0 and 100.
Progress float64 `json:"progress"`
// Status holds a TaskStatus value converted to string.
Status string `json:"status"`
// Error describes why the download failed; omitted from JSON when empty.
Error string `json:"error,omitempty"`
OutputPath string `json:"-"` // output file path; not sent to clients
}
// DownloadTask is one client request covering one or more images. It is
// registered in tasks under its ID and serialized to clients as JSON; fields
// tagged `json:"-"` stay server-side.
type DownloadTask struct {
// ID is the random hex identifier returned to the client.
ID string `json:"id"`
// Images holds per-image state, in request order.
Images []*ImageTask `json:"images"`
// TotalProgress is the mean of the images' Progress values.
TotalProgress float64 `json:"totalProgress"`
Status TaskStatus `json:"status"`
OutputFile string `json:"-"` // final output file (a single .tar or images.zip)
TempDir string `json:"-"` // per-task temporary directory
// Lock guards mutation of the task and its Images. sync.Mutex is not
// reentrant, so helpers that lock it (e.g. updateTaskProgress) must not be
// called while it is held.
Lock sync.Mutex `json:"-"`
}
// Client is a WebSocket subscriber receiving updates for one task.
type Client struct {
Conn *websocket.Conn // underlying WebSocket connection
TaskID string // ID of the task this client follows
Send chan []byte // queued outgoing JSON messages, written out by writePump
CloseOnce sync.Once // makes the client's teardown run only once
}
// upgrader upgrades /ws/:taskId requests to WebSocket connections.
// CheckOrigin accepts every Origin header, so a page on any site can open a
// progress stream for a task whose ID it knows.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(*http.Request) bool {
		return true // allow all origins
	},
}

// Process-wide registries of active download tasks and WebSocket
// subscribers, both keyed by task ID. Each map must only be accessed while
// holding the mutex declared next to it.
var (
	tasks     = map[string]*DownloadTask{}
	tasksLock sync.Mutex

	clients    = map[string]*Client{}
	clientLock sync.Mutex
)
// initSkopeoRoutes registers the image download API on router and starts the
// periodic cleanup of ./temp in the background:
//
//   - GET /ws/:taskId: WebSocket stream of task progress
//   - POST /api/download: create a download task
//   - GET /api/task/:taskId: current task state as JSON
//   - GET /api/files/:filename: download a finished task's archive
func initSkopeoRoutes(router *gin.Engine) {
	// Every task creates its working directory under ./temp. Report (but do
	// not abort on) a failure so a misconfigured deployment is diagnosable.
	if err := os.MkdirAll("./temp", 0755); err != nil {
		fmt.Printf("创建临时目录失败: %v\n", err)
	}

	router.GET("/ws/:taskId", handleWebSocket)
	router.POST("/api/download", handleDownload)
	router.GET("/api/task/:taskId", getTaskStatus)
	router.GET("/api/files/:filename", serveFile)

	go cleanupTempFiles()
}
// handleWebSocket upgrades GET /ws/:taskId to a WebSocket and streams task
// updates to it.
//
// Registration:
//   - The connection is registered in clients under the task ID. A newer
//     connection for the same task replaces an older one.
//   - If the task already exists, its current state is queued immediately.
//
// Read loop: the handler then reads (and discards) incoming messages until
// the connection fails or the peer closes it. gorilla/websocket only processes
// close and ping frames, and only notices a dead peer, while the application is
// reading. Without this loop the client would never be unregistered.
//
// Teardown: on exit the client is removed from clients (if still registered),
// and the connection is closed. writePump is then woken with a nil message so
// it returns.
//
// Send is deliberately never closed here. sendTaskUpdate may still hold the
// client, and sending on a closed channel panics.
func handleWebSocket(c *gin.Context) {
	// Incoming messages are ignored; cap their size so a client cannot make
	// ReadMessage buffer arbitrarily large payloads.
	const maxIncomingMessage = 4096

	taskID := c.Param("taskId")
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		fmt.Printf("WebSocket升级失败: %v\n", err)
		return
	}
	conn.SetReadLimit(maxIncomingMessage)

	client := &Client{
		Conn:   conn,
		TaskID: taskID,
		Send:   make(chan []byte, 256),
	}

	clientLock.Lock()
	clients[taskID] = client
	clientLock.Unlock()

	go client.writePump()

	// Push the current state right away so the client need not wait for the
	// next change.
	tasksLock.Lock()
	task, exists := tasks[taskID]
	tasksLock.Unlock()
	if exists {
		if taskJSON, err := json.Marshal(task); err != nil {
			fmt.Printf("序列化任务失败: %v\n", err)
		} else {
			select {
			case client.Send <- taskJSON:
			default:
			}
		}
	}

	for {
		if _, _, err := conn.ReadMessage(); err != nil {
			break
		}
	}

	client.CloseOnce.Do(func() {
		clientLock.Lock()
		if clients[taskID] == client {
			delete(clients, taskID)
		}
		clientLock.Unlock()

		conn.Close()

		// Wake writePump if it is idle. If the buffer is full, writePump
		// still has messages to write and will fail on the closed connection.
		select {
		case client.Send <- nil:
		default:
		}
	})
}
// writePump forwards queued messages from c.Send to the WebSocket as text
// frames. gorilla/websocket allows only one concurrent writer, and within this
// file data frames are written only here.
//
// It returns when any of these happens, closing the connection on the way out:
//   - Send is closed;
//   - a nil message arrives (a stop signal; marshaled JSON is never nil);
//   - a write fails.
//
// Each write is bounded by writeWait, so a peer that stops reading cannot
// block this goroutine indefinitely.
func (c *Client) writePump() {
	const writeWait = 10 * time.Second

	defer c.Conn.Close()

	for message := range c.Send {
		if message == nil {
			return
		}
		// Best effort: if the connection is already broken, WriteMessage
		// below reports it.
		_ = c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
		if err := c.Conn.WriteMessage(websocket.TextMessage, message); err != nil {
			fmt.Printf("发送WS消息失败: %v\n", err)
			return
		}
	}
}
// getTaskStatus serves GET /api/task/:taskId.
//
// It responds with the task's JSON representation. When no task with that ID
// is registered it responds 404 with {"error": "任务不存在"}.
func getTaskStatus(c *gin.Context) {
	id := c.Param("taskId")

	tasksLock.Lock()
	task, found := tasks[id]
	tasksLock.Unlock()

	if found {
		c.JSON(http.StatusOK, task)
		return
	}
	c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"})
}
// generateTaskID returns a random 128-bit task identifier as 32 lowercase hex
// characters.
//
// Anyone who knows a task ID can poll the task and download its result, so
// the ID must come from crypto/rand. If the system randomness source fails
// there is no safe fallback. The function therefore panics instead of
// returning a predictable (all-zero) ID.
func generateTaskID() string {
	var b [16]byte
	if _, err := rand.Read(b[:]); err != nil {
		panic(fmt.Sprintf("generateTaskID: crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(b[:])
}
// handleDownload serves POST /api/download.
//
// Request body: {"images": ["nginx:latest", ...], "platform": "arm64"}.
//
// Validation:
//   - Image references are trimmed, and blank entries are skipped.
//   - Both images and platform end up on skopeo's command line, so any
//     character that legitimate image references or platform strings never
//     contain is rejected with 400.
//
// On success:
//   - A task is registered in tasks.
//   - Its working directory ./temp/<taskID> is created.
//   - processDownloadTask continues in the background.
//   - The response is {"taskId": ..., "status": "started"}.
func handleDownload(c *gin.Context) {
	type DownloadRequest struct {
		Images   []string `json:"images"`
		Platform string   `json:"platform"` // target platform, e.g. amd64, arm64, arm/v7
	}

	var req DownloadRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的请求参数"})
		return
	}

	images := make([]string, 0, len(req.Images))
	for _, image := range req.Images {
		image = strings.TrimSpace(image)
		if image == "" {
			continue
		}
		// Registry hosts (incl. bracketed IPv6), paths, tags and digests.
		if !onlyRefChars(image, "._-/:@[]") {
			c.JSON(http.StatusBadRequest, gin.H{"error": "无效的镜像名称: " + image})
			return
		}
		images = append(images, image)
	}
	if len(images) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "请提供至少一个镜像"})
		return
	}

	// Architectures, "arch/variant" forms and explicit "--flag value" strings.
	platform := strings.TrimSpace(req.Platform)
	if !onlyRefChars(platform, "._-/= ") {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的平台参数"})
		return
	}

	taskID := generateTaskID()
	tempDir := filepath.Join("./temp", taskID)
	if err := os.MkdirAll(tempDir, 0755); err != nil {
		fmt.Printf("创建任务目录失败: %v\n", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "无法创建任务目录"})
		return
	}

	imageTasks := make([]*ImageTask, len(images))
	for i, image := range images {
		imageTasks[i] = &ImageTask{
			Image:    image,
			Progress: 0,
			Status:   string(StatusPending),
		}
	}
	task := &DownloadTask{
		ID:            taskID,
		Images:        imageTasks,
		TotalProgress: 0,
		Status:        StatusPending,
		TempDir:       tempDir,
	}

	tasksLock.Lock()
	tasks[taskID] = task
	tasksLock.Unlock()

	go processDownloadTask(task, platform)

	c.JSON(http.StatusOK, gin.H{
		"taskId": taskID,
		"status": "started",
	})
}

// onlyRefChars reports whether s consists solely of ASCII letters, ASCII
// digits and the characters listed in extra. An empty s is accepted.
func onlyRefChars(s, extra string) bool {
	for _, r := range s {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
		case strings.ContainsRune(extra, r):
		default:
			return false
		}
	}
	return true
}
// processDownloadTask runs a task created by handleDownload to completion.
//
// Steps:
//  1. Mark the task running.
//  2. Download every image concurrently via downloadImage.
//  3. Produce the task's OutputFile. A single image's archive is renamed to
//     "<image>.tar" (with '/' and ':' replaced by '_'); several images are
//     bundled by createZipArchive.
//
// If any image did not complete, or bundling fails, the task is marked failed.
// Every state change is pushed with sendTaskUpdate, always after task.Lock has
// been released.
func processDownloadTask(task *DownloadTask, platform string) {
	task.Lock.Lock()
	task.Status = StatusRunning
	task.Lock.Unlock()
	sendTaskUpdate(task)

	var wg sync.WaitGroup
	for i, imgTask := range task.Images {
		wg.Add(1)
		go func(idx int, img *ImageTask) {
			defer wg.Done()
			downloadImage(task, idx, img, platform)
		}(i, imgTask)
	}
	wg.Wait()

	task.Lock.Lock()
	allCompleted := true
	for _, img := range task.Images {
		if img.Status != string(StatusCompleted) {
			allCompleted = false
			break
		}
	}
	task.Lock.Unlock()

	if !allCompleted {
		markTaskFailed(task)
		return
	}

	// The archive is built without holding task.Lock. All downloads have
	// finished, so the fields read here no longer change. Holding the lock
	// through a potentially huge zip would block every reader that takes it.
	var finalPath string
	if len(task.Images) == 1 {
		finalPath = renameSingleArchive(task.TempDir, task.Images[0])
	} else {
		zipPath, err := createZipArchive(task)
		if err != nil {
			fmt.Printf("打包任务 %s 失败: %v\n", task.ID, err)
			markTaskFailed(task)
			return
		}
		finalPath = zipPath
	}

	task.Lock.Lock()
	task.OutputFile = finalPath
	task.Status = StatusCompleted
	task.TotalProgress = 100
	task.Lock.Unlock()
	sendTaskUpdate(task)
}

// markTaskFailed sets the task's status to failed and notifies its client.
func markTaskFailed(task *DownloadTask) {
	task.Lock.Lock()
	task.Status = StatusFailed
	task.Lock.Unlock()
	sendTaskUpdate(task)
}

// renameSingleArchive renames img's archive to "<image>.tar" inside dir
// ('/' and ':' in the image reference become '_') and returns the path now
// holding the archive. If the rename fails, the failure is logged and the
// original path is returned, so the task still points at an existing file.
func renameSingleArchive(dir string, img *ImageTask) string {
	name := strings.NewReplacer("/", "_", ":", "_").Replace(img.Image) + ".tar"
	target := filepath.Join(dir, name)
	if err := os.Rename(img.OutputPath, target); err != nil {
		fmt.Printf("重命名文件失败: %v\n", err)
		return img.OutputPath
	}
	return target
}
// downloadImage copies one image with `skopeo copy` into
// <task.TempDir>/image_<index>.tar (docker-archive format) and records the
// outcome on imgTask. Status ends as "completed" (Progress 100) or "failed"
// (Error set). Progress changes are pushed to the task's subscriber as they
// happen.
//
// Execution and locking:
//   - skopeo is executed directly, without a shell, so neither the image
//     reference nor the platform string can inject shell syntax.
//     skopeoCopyArgs describes how the platform string is translated.
//   - task.Lock is never held while calling updateTaskProgress or
//     sendImageUpdate, because updateTaskProgress locks it itself and
//     sync.Mutex is not reentrant.
//   - All helper goroutines have exited by the time downloadImage returns.
func downloadImage(task *DownloadTask, index int, imgTask *ImageTask, platform string) {
	outputPath := filepath.Join(task.TempDir, fmt.Sprintf("image_%d.tar", index))

	task.Lock.Lock()
	imgTask.Status = string(StatusRunning)
	imgTask.OutputPath = outputPath
	task.Lock.Unlock()
	sendImageUpdate(task, index)

	// fail marks the image as failed with msg and notifies the client.
	fail := func(msg string) {
		task.Lock.Lock()
		imgTask.Status = string(StatusFailed)
		imgTask.Error = msg
		task.Lock.Unlock()
		sendImageUpdate(task, index)
	}

	args, err := skopeoCopyArgs(platform, imgTask.Image, outputPath)
	if err != nil {
		fail(fmt.Sprintf("无效的平台参数: %v", err))
		return
	}
	fmt.Printf("执行命令: skopeo %s\n", strings.Join(args, " "))
	command := exec.Command("skopeo", args...)

	stderr, err := command.StderrPipe()
	if err != nil {
		fail(fmt.Sprintf("无法创建输出管道: %v", err))
		return
	}
	stdout, err := command.StdoutPipe()
	if err != nil {
		fail(fmt.Sprintf("无法创建标准输出管道: %v", err))
		return
	}
	if err := command.Start(); err != nil {
		fail(fmt.Sprintf("启动命令失败: %v", err))
		return
	}

	stopNudge := make(chan struct{})
	nudgeExited := make(chan struct{})
	go func() {
		defer close(nudgeExited)
		nudgeStalledProgress(task, index, imgTask, stopNudge)
	}()

	// lastStderr is written by the stderr reader and read only after
	// readers.Wait, which orders the accesses.
	var lastStderr string
	var readers sync.WaitGroup
	readers.Add(2)
	go func() {
		defer readers.Done()
		scanner := bufio.NewScanner(stderr)
		for scanner.Scan() {
			line := scanner.Text()
			fmt.Printf("镜像 %s 进度输出: %s\n", imgTask.Image, line)
			if trimmed := strings.TrimSpace(line); trimmed != "" {
				lastStderr = trimmed
			}
			if progress, ok := parseSkopeoProgress(line); ok {
				task.Lock.Lock()
				imgTask.Progress = progress
				task.Lock.Unlock()
				updateTaskProgress(task)
				sendImageUpdate(task, index)
			}
		}
		// Keep draining (e.g. after an over-long line stopped the scanner) so
		// skopeo never blocks on a full pipe.
		_, _ = io.Copy(io.Discard, stderr)
	}()
	go func() {
		defer readers.Done()
		scanner := bufio.NewScanner(stdout)
		for scanner.Scan() {
			fmt.Printf("镜像 %s 标准输出: %s\n", imgTask.Image, scanner.Text())
		}
		_, _ = io.Copy(io.Discard, stdout)
	}()

	// os/exec: Wait closes the pipes, so every read must finish first.
	readers.Wait()
	waitErr := command.Wait()
	close(stopNudge)
	<-nudgeExited

	if waitErr != nil {
		msg := fmt.Sprintf("命令执行失败: %v", waitErr)
		if lastStderr != "" {
			msg += ": " + lastStderr
		}
		fail(msg)
		return
	}
	if _, err := os.Stat(outputPath); err != nil {
		fail("文件未成功创建")
		return
	}

	task.Lock.Lock()
	imgTask.Status = string(StatusCompleted)
	imgTask.Progress = 100
	task.Lock.Unlock()
	updateTaskProgress(task)
	sendImageUpdate(task, index)
}

// skopeoCopyArgs builds the argument list (without the program name) for
//
//	skopeo copy [platform flags] docker://<image> docker-archive:<outputPath>
//
// platform (surrounding whitespace ignored) may be:
//   - "": no override flags;
//   - explicit flags, i.e. anything containing "--", validated by
//     platformFlagArgs;
//   - otherwise a platform string interpreted by splitPlatform, which yields
//     --override-os, --override-arch and, if present, --override-variant.
func skopeoCopyArgs(platform, image, outputPath string) ([]string, error) {
	args := []string{"copy"}

	platform = strings.TrimSpace(platform)
	switch {
	case platform == "":
	case strings.Contains(platform, "--"):
		flags, err := platformFlagArgs(platform)
		if err != nil {
			return nil, err
		}
		args = append(args, flags...)
	default:
		osName, arch, variant, err := splitPlatform(platform)
		if err != nil {
			return nil, err
		}
		args = append(args, "--override-os", osName, "--override-arch", arch)
		if variant != "" {
			args = append(args, "--override-variant", variant)
		}
	}

	return append(args, "docker://"+image, "docker-archive:"+outputPath), nil
}

// splitPlatform interprets a platform string without explicit flags. It
// accepts these forms; the OS defaults to linux:
//   - "arch";
//   - "arch/variant" (e.g. arm/v7);
//   - "os/arch", when the first segment is a known OS name (e.g. linux/amd64);
//   - "os/arch/variant".
//
// Empty segments and more than three segments are rejected.
func splitPlatform(platform string) (osName, arch, variant string, err error) {
	parts := strings.Split(platform, "/")
	for _, p := range parts {
		if p == "" {
			return "", "", "", fmt.Errorf("invalid platform %q", platform)
		}
	}
	switch len(parts) {
	case 1:
		return "linux", parts[0], "", nil
	case 2:
		switch parts[0] {
		case "linux", "windows", "darwin", "freebsd":
			return parts[0], parts[1], "", nil
		}
		return "linux", parts[0], parts[1], nil
	case 3:
		return parts[0], parts[1], parts[2], nil
	}
	return "", "", "", fmt.Errorf("invalid platform %q", platform)
}

// platformFlagArgs validates an explicit flag string such as
// "--override-arch arm64 --override-variant=v8" and returns it as separate
// arguments in "--flag value" form.
//
// Only --override-os, --override-arch and --override-variant are accepted.
// The string comes from the client, and other skopeo flags (e.g. --authfile or
// --digestfile, which name files) must not be reachable.
func platformFlagArgs(platform string) ([]string, error) {
	fields := strings.Fields(platform)
	args := make([]string, 0, 2*len(fields))
	for i := 0; i < len(fields); i++ {
		name, value, inline := strings.Cut(fields[i], "=")
		switch name {
		case "--override-os", "--override-arch", "--override-variant":
		default:
			return nil, fmt.Errorf("unsupported platform flag %q", fields[i])
		}
		if !inline {
			i++
			if i == len(fields) {
				return nil, fmt.Errorf("missing value for %s", name)
			}
			value = fields[i]
		}
		if value == "" || strings.HasPrefix(value, "-") {
			return nil, fmt.Errorf("invalid value %q for %s", value, name)
		}
		args = append(args, name, value)
	}
	return args, nil
}

// parseSkopeoProgress extracts a percentage from one line of skopeo output.
// The text before the first '%' has trailing spaces and slashes removed. Its
// last whitespace-separated token is then parsed with %f. ok is true only for
// values in (0, 100].
func parseSkopeoProgress(line string) (progress float64, ok bool) {
	before, _, found := strings.Cut(line, "%")
	if !found {
		return 0, false
	}
	fields := strings.Fields(strings.TrimRight(before, " /"))
	if len(fields) == 0 {
		return 0, false
	}
	if _, err := fmt.Sscanf(fields[len(fields)-1], "%f", &progress); err != nil {
		return 0, false
	}
	if progress <= 0 || progress > 100 {
		return 0, false
	}
	return progress, true
}

// nudgeStalledProgress runs until stop is closed and keeps the UI moving
// while skopeo reports nothing.
//
// Every 500ms it reads imgTask.Progress. Once the value has stayed the same
// for more than 5 consecutive checks (about 3s) and is below 90, it adds 0.5
// percentage points, capped at 95. The task total is then recomputed and the
// client notified, outside task.Lock.
func nudgeStalledProgress(task *DownloadTask, index int, imgTask *ImageTask, stop <-chan struct{}) {
	ticker := time.NewTicker(500 * time.Millisecond)
	defer ticker.Stop()

	lastProgress := 0.0
	stagnant := 0
	for {
		select {
		case <-stop:
			return
		case <-ticker.C:
		}

		bumped := false
		task.Lock.Lock()
		current := imgTask.Progress
		if current == lastProgress {
			stagnant++
			if stagnant > 5 && current < 90 {
				next := current + 0.5
				if next > 95 {
					next = 95
				}
				imgTask.Progress = next
				bumped = true
			}
		} else {
			stagnant = 0
			lastProgress = current
		}
		task.Lock.Unlock()

		if bumped {
			updateTaskProgress(task)
			sendImageUpdate(task, index)
		}
	}
}
// updateTaskProgress recomputes task.TotalProgress as the mean of the
// images' Progress values.
//
// It acquires task.Lock itself, so callers must not hold it (sync.Mutex is
// not reentrant). A task without images gets TotalProgress 0 instead of a
// division by zero. That would produce NaN, which encoding/json refuses to
// marshal.
func updateTaskProgress(task *DownloadTask) {
	task.Lock.Lock()
	defer task.Lock.Unlock()

	if len(task.Images) == 0 {
		task.TotalProgress = 0
		return
	}

	sum := 0.0
	for _, img := range task.Images {
		sum += img.Progress
	}
	task.TotalProgress = sum / float64(len(task.Images))
}
// createZipArchive bundles every completed image archive of task into
// <task.TempDir>/images.zip and returns that path.
//
// Entry naming: each entry is named after its image reference, with '/' and
// ':' replaced by '_', plus ".tar". A repeated name gets a "_<n>" suffix so
// entries stay unique. Images that are not completed, or that have no
// OutputPath, are skipped.
//
// Errors from finishing the archive (writing the zip central directory,
// closing the file) are returned. Any partially written images.zip is
// removed on failure.
func createZipArchive(task *DownloadTask) (string, error) {
	zipFilePath := filepath.Join(task.TempDir, "images.zip")
	zipFile, err := os.Create(zipFilePath)
	if err != nil {
		return "", err
	}

	if err := writeImagesZip(zipFile, task.Images); err != nil {
		// Best-effort cleanup; the write error is what gets reported.
		_ = zipFile.Close()
		_ = os.Remove(zipFilePath)
		return "", err
	}
	if err := zipFile.Close(); err != nil {
		_ = os.Remove(zipFilePath)
		return "", err
	}
	return zipFilePath, nil
}

// writeImagesZip writes the completed images as Deflate-compressed entries
// to w and finalizes the zip stream.
func writeImagesZip(w io.Writer, images []*ImageTask) error {
	zw := zip.NewWriter(w)
	toFileName := strings.NewReplacer("/", "_", ":", "_")
	seen := make(map[string]int)

	for _, img := range images {
		if img.Status != string(StatusCompleted) || img.OutputPath == "" {
			continue
		}
		base := toFileName.Replace(img.Image)
		name := base + ".tar"
		if n := seen[base]; n > 0 {
			name = fmt.Sprintf("%s_%d.tar", base, n)
		}
		seen[base]++

		if err := addFileToZip(zw, img.OutputPath, name); err != nil {
			return err
		}
	}

	// Close writes the central directory; without it the archive is unreadable.
	return zw.Close()
}

// addFileToZip copies the file at path into zw as an entry called name. The
// entry is compressed with Deflate and carries the file's metadata as
// produced by zip.FileInfoHeader.
func addFileToZip(zw *zip.Writer, path, name string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	info, err := f.Stat()
	if err != nil {
		return err
	}
	header, err := zip.FileInfoHeader(info)
	if err != nil {
		return err
	}
	header.Name = name
	header.Method = zip.Deflate

	entry, err := zw.CreateHeader(header)
	if err != nil {
		return err
	}
	_, err = io.Copy(entry, f)
	return err
}
// sendTaskUpdate pushes a JSON snapshot of task to the WebSocket client
// subscribed to task.ID, if there is one.
//
// The snapshot is marshaled under task.Lock so it does not race with
// concurrent updates. Callers must therefore not hold task.Lock.
//
// Delivery is best effort: if the client's buffer is full, this snapshot is
// dropped. clientLock is held across the lookup and the non-blocking send. Any
// teardown that removes the client and closes Send under clientLock therefore
// cannot interleave with the send (a send on a closed channel would panic).
func sendTaskUpdate(task *DownloadTask) {
	task.Lock.Lock()
	taskJSON, err := json.Marshal(task)
	task.Lock.Unlock()
	if err != nil {
		fmt.Printf("序列化任务失败: %v\n", err)
		return
	}

	clientLock.Lock()
	defer clientLock.Unlock()

	client, exists := clients[task.ID]
	if !exists {
		return
	}
	select {
	case client.Send <- taskJSON:
	default:
		// Buffer full: drop this snapshot rather than block the caller.
	}
}
// sendImageUpdate notifies the task's WebSocket subscriber that the image at
// imageIndex changed. Clients always receive the whole task, so the index is
// not used and the call is forwarded to sendTaskUpdate.
func sendImageUpdate(task *DownloadTask, imageIndex int) {
	_ = imageIndex // no per-image message format; the full task is sent
	sendTaskUpdate(task)
}
// serveFile serves GET /api/files/:filename.
//
// filename has the form "<taskID>_<anything>". Only the part before the first
// '_' is used, to look up the task. The task's OutputFile is then sent as an
// attachment named after its base name.
//
// Responses:
//   - 403 if filename contains "..";
//   - 400 if it has no '_';
//   - 404 if the task or its file does not exist;
//   - 500 if the file cannot be stat'ed.
func serveFile(c *gin.Context) {
	filename := c.Param("filename")
	// Defense in depth: the served path comes from the task, not from filename.
	if strings.Contains(filename, "..") {
		c.JSON(http.StatusForbidden, gin.H{"error": "无效的文件名"})
		return
	}

	taskID, _, found := strings.Cut(filename, "_")
	if !found {
		c.JSON(http.StatusBadRequest, gin.H{"error": "无效的文件名格式"})
		return
	}

	tasksLock.Lock()
	task, exists := tasks[taskID]
	tasksLock.Unlock()
	if !exists {
		c.JSON(http.StatusNotFound, gin.H{"error": "任务不存在"})
		return
	}

	// NOTE: OutputFile is read without task.Lock; it is only assigned once,
	// when the task completes.
	filePath := task.OutputFile
	if filePath == "" {
		c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
		return
	}
	if _, err := os.Stat(filePath); err != nil {
		if os.IsNotExist(err) {
			c.JSON(http.StatusNotFound, gin.H{"error": "文件不存在"})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"error": "无法获取文件信息"})
		}
		return
	}

	// FileAttachment quotes the name in Content-Disposition; unquoted values
	// must be HTTP tokens, which exclude characters such as '@' that appear
	// in image references. The underlying http.ServeFile sets Content-Length
	// and handles Range requests.
	c.FileAttachment(filePath, filepath.Base(filePath))
}
// fileExists reports false only when os.Stat says path does not exist. Any
// other outcome, including permission or I/O errors, counts as existing.
// Callers that need the file's metadata must still check os.Stat's error.
func fileExists(path string) bool {
	if _, err := os.Stat(path); os.IsNotExist(err) {
		return false
	}
	return true
}
// 清理过期临时文件
func cleanupTempFiles() {
for {
time.Sleep(1 * time.Hour)
// 遍历temp目录
err := filepath.Walk("./temp", func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// 跳过根目录
if path == "./temp" {
return nil
}
// 如果文件或目录超过24小时未修改则删除
if time.Since(info.ModTime()) > 24*time.Hour {
if info.IsDir() {
os.RemoveAll(path)
return filepath.SkipDir
}
os.Remove(path)
}
return nil
})
if err != nil {
fmt.Printf("清理临时文件失败: %v\n", err)
}
}
}