拆分包结构

This commit is contained in:
user123456
2025-07-27 05:50:34 +08:00
parent badafd2899
commit 187e842445
15 changed files with 956 additions and 1148 deletions

205
src/utils/access_control.go Normal file
View File

@@ -0,0 +1,205 @@
package utils
import (
"strings"
"sync"
"hubproxy/config"
)
// ResourceType 资源类型
type ResourceType string
const (
ResourceTypeGitHub ResourceType = "github"
ResourceTypeDocker ResourceType = "docker"
)
// AccessController 统一访问控制器
type AccessController struct {
mu sync.RWMutex
}
// DockerImageInfo Docker镜像信息
type DockerImageInfo struct {
Namespace string
Repository string
Tag string
FullName string
}
// GlobalAccessController 全局访问控制器实例
var GlobalAccessController = &AccessController{}
// ParseDockerImage 解析Docker镜像名称
func (ac *AccessController) ParseDockerImage(image string) DockerImageInfo {
image = strings.TrimPrefix(image, "docker://")
var tag string
if idx := strings.LastIndex(image, ":"); idx != -1 {
part := image[idx+1:]
if !strings.Contains(part, "/") {
tag = part
image = image[:idx]
}
}
if tag == "" {
tag = "latest"
}
var namespace, repository string
if strings.Contains(image, "/") {
parts := strings.Split(image, "/")
if len(parts) >= 2 {
if strings.Contains(parts[0], ".") {
if len(parts) >= 3 {
namespace = parts[1]
repository = parts[2]
} else {
namespace = "library"
repository = parts[1]
}
} else {
namespace = parts[0]
repository = parts[1]
}
}
} else {
namespace = "library"
repository = image
}
fullName := namespace + "/" + repository
return DockerImageInfo{
Namespace: namespace,
Repository: repository,
Tag: tag,
FullName: fullName,
}
}
// CheckDockerAccess 检查Docker镜像访问权限
func (ac *AccessController) CheckDockerAccess(image string) (allowed bool, reason string) {
cfg := config.GetConfig()
imageInfo := ac.ParseDockerImage(image)
if len(cfg.Access.WhiteList) > 0 {
if !ac.matchImageInList(imageInfo, cfg.Access.WhiteList) {
return false, "不在Docker镜像白名单内"
}
}
if len(cfg.Access.BlackList) > 0 {
if ac.matchImageInList(imageInfo, cfg.Access.BlackList) {
return false, "Docker镜像在黑名单内"
}
}
return true, ""
}
// CheckGitHubAccess 检查GitHub仓库访问权限
func (ac *AccessController) CheckGitHubAccess(matches []string) (allowed bool, reason string) {
if len(matches) < 2 {
return false, "无效的GitHub仓库格式"
}
cfg := config.GetConfig()
if len(cfg.Access.WhiteList) > 0 && !ac.checkList(matches, cfg.Access.WhiteList) {
return false, "不在GitHub仓库白名单内"
}
if len(cfg.Access.BlackList) > 0 && ac.checkList(matches, cfg.Access.BlackList) {
return false, "GitHub仓库在黑名单内"
}
return true, ""
}
// matchImageInList 检查Docker镜像是否在指定列表中
func (ac *AccessController) matchImageInList(imageInfo DockerImageInfo, list []string) bool {
fullName := strings.ToLower(imageInfo.FullName)
namespace := strings.ToLower(imageInfo.Namespace)
for _, item := range list {
item = strings.ToLower(strings.TrimSpace(item))
if item == "" {
continue
}
if fullName == item {
return true
}
if item == namespace || item == namespace+"/*" {
return true
}
if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullName, prefix) {
return true
}
}
if strings.HasPrefix(item, "*/") {
repoPattern := strings.TrimPrefix(item, "*/")
if strings.HasSuffix(repoPattern, "*") {
repoPrefix := strings.TrimSuffix(repoPattern, "*")
if strings.HasPrefix(imageInfo.Repository, repoPrefix) {
return true
}
} else {
if strings.ToLower(imageInfo.Repository) == repoPattern {
return true
}
}
}
if strings.HasPrefix(fullName, item+"/") {
return true
}
}
return false
}
// checkList GitHub仓库检查逻辑
func (ac *AccessController) checkList(matches, list []string) bool {
if len(matches) < 2 {
return false
}
username := strings.ToLower(strings.TrimSpace(matches[0]))
repoName := strings.ToLower(strings.TrimSpace(strings.TrimSuffix(matches[1], ".git")))
fullRepo := username + "/" + repoName
for _, item := range list {
item = strings.ToLower(strings.TrimSpace(item))
if item == "" {
continue
}
if fullRepo == item {
return true
}
if item == username || item == username+"/*" {
return true
}
if strings.HasSuffix(item, "*") {
prefix := strings.TrimSuffix(item, "*")
if strings.HasPrefix(fullRepo, prefix) {
return true
}
}
if strings.HasPrefix(fullRepo, item+"/") {
return true
}
}
return false
}

164
src/utils/cache.go Normal file
View File

@@ -0,0 +1,164 @@
package utils
import (
"crypto/md5"
"encoding/json"
"fmt"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"hubproxy/config"
)
// CachedItem 通用缓存项
type CachedItem struct {
Data []byte
ContentType string
Headers map[string]string
ExpiresAt time.Time
}
// UniversalCache 通用缓存
type UniversalCache struct {
cache sync.Map
}
var GlobalCache = &UniversalCache{}
// Get 获取缓存项
func (c *UniversalCache) Get(key string) *CachedItem {
if v, ok := c.cache.Load(key); ok {
if cached := v.(*CachedItem); time.Now().Before(cached.ExpiresAt) {
return cached
}
c.cache.Delete(key)
}
return nil
}
func (c *UniversalCache) Set(key string, data []byte, contentType string, headers map[string]string, ttl time.Duration) {
c.cache.Store(key, &CachedItem{
Data: data,
ContentType: contentType,
Headers: headers,
ExpiresAt: time.Now().Add(ttl),
})
}
func (c *UniversalCache) GetToken(key string) string {
if item := c.Get(key); item != nil {
return string(item.Data)
}
return ""
}
func (c *UniversalCache) SetToken(key, token string, ttl time.Duration) {
c.Set(key, []byte(token), "application/json", nil, ttl)
}
// BuildCacheKey 构建稳定的缓存key
func BuildCacheKey(prefix, query string) string {
return fmt.Sprintf("%s:%x", prefix, md5.Sum([]byte(query)))
}
func BuildTokenCacheKey(query string) string {
return BuildCacheKey("token", query)
}
func BuildManifestCacheKey(imageRef, reference string) string {
key := fmt.Sprintf("%s:%s", imageRef, reference)
return BuildCacheKey("manifest", key)
}
func GetManifestTTL(reference string) time.Duration {
cfg := config.GetConfig()
defaultTTL := 30 * time.Minute
if cfg.TokenCache.DefaultTTL != "" {
if parsed, err := time.ParseDuration(cfg.TokenCache.DefaultTTL); err == nil {
defaultTTL = parsed
}
}
if strings.HasPrefix(reference, "sha256:") {
return 24 * time.Hour
}
if reference == "latest" || reference == "main" || reference == "master" ||
reference == "dev" || reference == "develop" {
return 10 * time.Minute
}
return defaultTTL
}
// ExtractTTLFromResponse 从响应中智能提取TTL
func ExtractTTLFromResponse(responseBody []byte) time.Duration {
var tokenResp struct {
ExpiresIn int `json:"expires_in"`
}
defaultTTL := 30 * time.Minute
if json.Unmarshal(responseBody, &tokenResp) == nil && tokenResp.ExpiresIn > 0 {
safeTTL := time.Duration(tokenResp.ExpiresIn-300) * time.Second
if safeTTL > 5*time.Minute {
return safeTTL
}
}
return defaultTTL
}
func WriteTokenResponse(c *gin.Context, cachedBody string) {
c.Header("Content-Type", "application/json")
c.String(200, cachedBody)
}
func WriteCachedResponse(c *gin.Context, item *CachedItem) {
if item.ContentType != "" {
c.Header("Content-Type", item.ContentType)
}
for key, value := range item.Headers {
c.Header(key, value)
}
c.Data(200, item.ContentType, item.Data)
}
// IsCacheEnabled 检查缓存是否启用
func IsCacheEnabled() bool {
cfg := config.GetConfig()
return cfg.TokenCache.Enabled
}
// IsTokenCacheEnabled 检查token缓存是否启用
func IsTokenCacheEnabled() bool {
return IsCacheEnabled()
}
// 定期清理过期缓存
func init() {
go func() {
ticker := time.NewTicker(20 * time.Minute)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
expiredKeys := make([]string, 0)
GlobalCache.cache.Range(func(key, value interface{}) bool {
if cached := value.(*CachedItem); now.After(cached.ExpiresAt) {
expiredKeys = append(expiredKeys, key.(string))
}
return true
})
for _, key := range expiredKeys {
GlobalCache.cache.Delete(key)
}
}
}()
}

67
src/utils/http_client.go Normal file
View File

@@ -0,0 +1,67 @@
package utils
import (
"net"
"net/http"
"os"
"time"
"hubproxy/config"
)
var (
globalHTTPClient *http.Client
searchHTTPClient *http.Client
)
// InitHTTPClients 初始化HTTP客户端
func InitHTTPClients() {
cfg := config.GetConfig()
if p := cfg.Access.Proxy; p != "" {
os.Setenv("HTTP_PROXY", p)
os.Setenv("HTTPS_PROXY", p)
}
globalHTTPClient = &http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 1000,
MaxIdleConnsPerHost: 1000,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
ResponseHeaderTimeout: 300 * time.Second,
},
}
searchHTTPClient = &http.Client{
Timeout: 10 * time.Second,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 5 * time.Second,
KeepAlive: 30 * time.Second,
}).DialContext,
MaxIdleConns: 100,
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
DisableCompression: false,
},
}
}
// GetGlobalHTTPClient 获取全局HTTP客户端
func GetGlobalHTTPClient() *http.Client {
return globalHTTPClient
}
// GetSearchHTTPClient 获取搜索HTTP客户端
func GetSearchHTTPClient() *http.Client {
return searchHTTPClient
}

94
src/utils/proxy_shell.go Normal file
View File

@@ -0,0 +1,94 @@
package utils
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"regexp"
"strings"
)
// GitHub URL正则表达式
var githubRegex = regexp.MustCompile(`https?://(?:github\.com|raw\.githubusercontent\.com|raw\.github\.com|gist\.githubusercontent\.com|gist\.github\.com|api\.github\.com)[^\s'"]+`)
// ProcessSmart Shell脚本智能处理函数
func ProcessSmart(input io.ReadCloser, isCompressed bool, host string) (io.Reader, int64, error) {
defer input.Close()
content, err := readShellContent(input, isCompressed)
if err != nil {
return nil, 0, fmt.Errorf("内容读取失败: %v", err)
}
if len(content) == 0 {
return strings.NewReader(""), 0, nil
}
if len(content) > 10*1024*1024 {
return strings.NewReader(content), int64(len(content)), nil
}
if !strings.Contains(content, "github.com") && !strings.Contains(content, "githubusercontent.com") {
return strings.NewReader(content), int64(len(content)), nil
}
processed := processGitHubURLs(content, host)
return strings.NewReader(processed), int64(len(processed)), nil
}
func readShellContent(input io.ReadCloser, isCompressed bool) (string, error) {
var reader io.Reader = input
if isCompressed {
peek := make([]byte, 2)
n, err := input.Read(peek)
if err != nil && err != io.EOF {
return "", fmt.Errorf("读取数据失败: %v", err)
}
if n >= 2 && peek[0] == 0x1f && peek[1] == 0x8b {
combinedReader := io.MultiReader(bytes.NewReader(peek[:n]), input)
gzReader, err := gzip.NewReader(combinedReader)
if err != nil {
return "", fmt.Errorf("gzip解压失败: %v", err)
}
defer gzReader.Close()
reader = gzReader
} else {
reader = io.MultiReader(bytes.NewReader(peek[:n]), input)
}
}
data, err := io.ReadAll(reader)
if err != nil {
return "", fmt.Errorf("读取内容失败: %v", err)
}
return string(data), nil
}
func processGitHubURLs(content, host string) string {
return githubRegex.ReplaceAllStringFunc(content, func(url string) string {
return transformURL(url, host)
})
}
// transformURL URL转换函数
func transformURL(url, host string) string {
if strings.Contains(url, host) {
return url
}
if strings.HasPrefix(url, "http://") {
url = "https" + url[4:]
} else if !strings.HasPrefix(url, "https://") && !strings.HasPrefix(url, "//") {
url = "https://" + url
}
cleanHost := strings.TrimPrefix(host, "https://")
cleanHost = strings.TrimPrefix(cleanHost, "http://")
cleanHost = strings.TrimSuffix(cleanHost, "/")
return cleanHost + "/" + url
}

270
src/utils/ratelimiter.go Normal file
View File

@@ -0,0 +1,270 @@
package utils
import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
"hubproxy/config"
)
const (
CleanupInterval = 10 * time.Minute
MaxIPCacheSize = 10000
)
// IPRateLimiter IP限流器结构体
type IPRateLimiter struct {
ips map[string]*rateLimiterEntry
mu *sync.RWMutex
r rate.Limit
b int
whitelist []*net.IPNet
blacklist []*net.IPNet
}
// rateLimiterEntry 限流器条目
type rateLimiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// InitGlobalLimiter 初始化全局限流器
func InitGlobalLimiter() *IPRateLimiter {
cfg := config.GetConfig()
whitelist := make([]*net.IPNet, 0, len(cfg.Security.WhiteList))
for _, item := range cfg.Security.WhiteList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32"
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
whitelist = append(whitelist, ipnet)
} else {
fmt.Printf("警告: 无效的白名单IP格式: %s\n", item)
}
}
}
blacklist := make([]*net.IPNet, 0, len(cfg.Security.BlackList))
for _, item := range cfg.Security.BlackList {
if item = strings.TrimSpace(item); item != "" {
if !strings.Contains(item, "/") {
item = item + "/32"
}
_, ipnet, err := net.ParseCIDR(item)
if err == nil {
blacklist = append(blacklist, ipnet)
} else {
fmt.Printf("警告: 无效的黑名单IP格式: %s\n", item)
}
}
}
ratePerSecond := rate.Limit(float64(cfg.RateLimit.RequestLimit) / (cfg.RateLimit.PeriodHours * 3600))
burstSize := cfg.RateLimit.RequestLimit
if burstSize < 1 {
burstSize = 1
}
limiter := &IPRateLimiter{
ips: make(map[string]*rateLimiterEntry),
mu: &sync.RWMutex{},
r: ratePerSecond,
b: burstSize,
whitelist: whitelist,
blacklist: blacklist,
}
go limiter.cleanupRoutine()
return limiter
}
// cleanupRoutine 定期清理过期的限流器
func (i *IPRateLimiter) cleanupRoutine() {
ticker := time.NewTicker(CleanupInterval)
defer ticker.Stop()
for range ticker.C {
now := time.Now()
expired := make([]string, 0)
i.mu.RLock()
for ip, entry := range i.ips {
if now.Sub(entry.lastAccess) > 1*time.Hour {
expired = append(expired, ip)
}
}
i.mu.RUnlock()
if len(expired) > 0 || len(i.ips) > MaxIPCacheSize {
i.mu.Lock()
for _, ip := range expired {
delete(i.ips, ip)
}
if len(i.ips) > MaxIPCacheSize {
i.ips = make(map[string]*rateLimiterEntry)
}
i.mu.Unlock()
}
}
}
// extractIPFromAddress 从地址中提取纯IP
func extractIPFromAddress(address string) string {
if host, _, err := net.SplitHostPort(address); err == nil {
return host
}
return address
}
// normalizeIPForRateLimit 标准化IP地址用于限流
func normalizeIPForRateLimit(ipStr string) string {
ip := net.ParseIP(ipStr)
if ip == nil {
return ipStr
}
if ip.To4() != nil {
return ipStr
}
ipv6 := ip.To16()
for i := 8; i < 16; i++ {
ipv6[i] = 0
}
return ipv6.String() + "/64"
}
// isIPInCIDRList 检查IP是否在CIDR列表中
func isIPInCIDRList(ip string, cidrList []*net.IPNet) bool {
cleanIP := extractIPFromAddress(ip)
parsedIP := net.ParseIP(cleanIP)
if parsedIP == nil {
return false
}
for _, cidr := range cidrList {
if cidr.Contains(parsedIP) {
return true
}
}
return false
}
// GetLimiter 获取指定IP的限流器
func (i *IPRateLimiter) GetLimiter(ip string) (*rate.Limiter, bool) {
cleanIP := extractIPFromAddress(ip)
if isIPInCIDRList(cleanIP, i.blacklist) {
return nil, false
}
if isIPInCIDRList(cleanIP, i.whitelist) {
return rate.NewLimiter(rate.Inf, i.b), true
}
normalizedIP := normalizeIPForRateLimit(cleanIP)
now := time.Now()
i.mu.RLock()
entry, exists := i.ips[normalizedIP]
i.mu.RUnlock()
if exists {
i.mu.Lock()
if entry, stillExists := i.ips[normalizedIP]; stillExists {
entry.lastAccess = now
i.mu.Unlock()
return entry.limiter, true
}
i.mu.Unlock()
}
i.mu.Lock()
if entry, exists := i.ips[normalizedIP]; exists {
entry.lastAccess = now
i.mu.Unlock()
return entry.limiter, true
}
entry = &rateLimiterEntry{
limiter: rate.NewLimiter(i.r, i.b),
lastAccess: now,
}
i.ips[normalizedIP] = entry
i.mu.Unlock()
return entry.limiter, true
}
// RateLimitMiddleware 速率限制中间件
func RateLimitMiddleware(limiter *IPRateLimiter) gin.HandlerFunc {
return func(c *gin.Context) {
path := c.Request.URL.Path
if path == "/" || path == "/favicon.ico" || path == "/images.html" || path == "/search.html" ||
strings.HasPrefix(path, "/public/") {
c.Next()
return
}
var ip string
if forwarded := c.GetHeader("X-Forwarded-For"); forwarded != "" {
ips := strings.Split(forwarded, ",")
ip = strings.TrimSpace(ips[0])
} else if realIP := c.GetHeader("X-Real-IP"); realIP != "" {
ip = realIP
} else if remoteIP := c.GetHeader("X-Original-Forwarded-For"); remoteIP != "" {
ips := strings.Split(remoteIP, ",")
ip = strings.TrimSpace(ips[0])
} else {
ip = c.ClientIP()
}
cleanIP := extractIPFromAddress(ip)
normalizedIP := normalizeIPForRateLimit(cleanIP)
if cleanIP != normalizedIP {
fmt.Printf("请求IP: %s (提纯后: %s, 限流段: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
ip, cleanIP, normalizedIP,
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"))
} else {
fmt.Printf("请求IP: %s (提纯后: %s), X-Forwarded-For: %s, X-Real-IP: %s\n",
ip, cleanIP,
c.GetHeader("X-Forwarded-For"),
c.GetHeader("X-Real-IP"))
}
ipLimiter, allowed := limiter.GetLimiter(cleanIP)
if !allowed {
c.JSON(403, gin.H{
"error": "您已被限制访问",
})
c.Abort()
return
}
if !ipLimiter.Allow() {
c.JSON(429, gin.H{
"error": "请求频率过快,暂时限制访问",
})
c.Abort()
return
}
c.Next()
}
}