Merge pull request #7 from awkj/main

添加 本地 cache 功能,并适当重构代码
This commit is contained in:
NewName
2025-03-16 15:14:06 +08:00

View File

@@ -3,26 +3,68 @@ package main
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"log"
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath"
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"time" "time"
"crypto/sha256"
"encoding/hex"
"github.com/gin-gonic/gin"
) )
const ( const (
sizeLimit = 1024 * 1024 * 1024 * 10 // 允许的文件大小默认10GB MaxFileSize = 10 * 1024 * 1024 * 1024 // 允许的文件大小默认10GB
host = "0.0.0.0" // 监听地址 ListenHost = "0.0.0.0" // 监听地址
port = 5000 // 监听端口 ListenPort = 5000 // 监听端口
CacheDir = "cache"
// 是否开启缓存
CacheExpiry = 0 * time.Minute // 默认不缓存
) )
var ( var (
exps = []*regexp.Regexp{ cache = sync.Map{}
exps = initRegexps()
httpClient = initHTTPClient()
config *Config
configLock sync.RWMutex
)
type Config struct {
WhiteList []string `json:"whiteList"`
BlackList []string `json:"blackList"`
}
type CachedResponse struct {
Header http.Header
StatusCode int
Body []byte
Timestamp time.Time
}
func init() {
if err := os.MkdirAll(CacheDir, 0755); err != nil {
log.Fatalf("Failed to create cache directory: %v", err)
}
go func() {
for {
time.Sleep(10 * time.Minute)
loadConfig()
}
}()
loadConfig()
}
func initRegexps() []*regexp.Regexp {
return []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:blob|raw)/.*$`),
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:info|git-).*$`),
@@ -33,21 +75,10 @@ var (
regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`), regexp.MustCompile(`^(?:https?://)?cdn-lfs\.hf\.co(?:/spaces)?/([^/]+)/([^/]+)(?:/(.*))?$`),
regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`), regexp.MustCompile(`^(?:https?://)?download\.docker\.com/([^/]+)/.*\.(tgz|zip)$`),
} }
httpClient *http.Client
config *Config
configLock sync.RWMutex
)
type Config struct {
WhiteList []string `json:"whiteList"`
BlackList []string `json:"blackList"`
} }
func main() { func initHTTPClient() *http.Client {
gin.SetMode(gin.ReleaseMode) return &http.Client{
router := gin.Default()
httpClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: (&net.Dialer{ DialContext: (&net.Dialer{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
@@ -61,30 +92,23 @@ func main() {
ResponseHeaderTimeout: 300 * time.Second, ResponseHeaderTimeout: 300 * time.Second,
}, },
} }
}
loadConfig() func main() {
go func() { gin.SetMode(gin.ReleaseMode)
for { router := gin.Default()
time.Sleep(10 * time.Minute)
loadConfig()
}
}()
// 前端访问路径,默认根路径
router.Static("/", "./public") router.Static("/", "./public")
router.NoRoute(handler) router.NoRoute(handler)
err := router.Run(fmt.Sprintf("%s:%d", host, port)) addr := fmt.Sprintf("%s:%d", ListenHost, ListenPort)
if err != nil { if err := router.Run(addr); err != nil {
fmt.Printf("Error starting server: %v\n", err) log.Fatalf("Error starting server: %v", err)
} }
} }
func handler(c *gin.Context) { func handler(c *gin.Context) {
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
rawPath = strings.TrimPrefix(rawPath, "/")
for strings.HasPrefix(rawPath, "/") {
rawPath = strings.TrimPrefix(rawPath, "/")
}
if !strings.HasPrefix(rawPath, "http") { if !strings.HasPrefix(rawPath, "http") {
c.String(http.StatusForbidden, "无效输入") c.String(http.StatusForbidden, "无效输入")
@@ -92,20 +116,20 @@ func handler(c *gin.Context) {
} }
matches := checkURL(rawPath) matches := checkURL(rawPath)
if matches != nil { if matches == nil {
if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) {
c.String(http.StatusForbidden, "不在白名单内,限制访问。")
return
}
if len(config.BlackList) > 0 && checkList(matches, config.BlackList) {
c.String(http.StatusForbidden, "黑名单限制访问")
return
}
} else {
c.String(http.StatusForbidden, "无效输入") c.String(http.StatusForbidden, "无效输入")
return return
} }
if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) {
c.String(http.StatusForbidden, "不在白名单内,限制访问。")
return
}
if len(config.BlackList) > 0 && checkList(matches, config.BlackList) {
c.String(http.StatusForbidden, "黑名单限制访问")
return
}
if exps[1].MatchString(rawPath) { if exps[1].MatchString(rawPath) {
rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1) rawPath = strings.Replace(rawPath, "/blob/", "/raw/", 1)
} }
@@ -114,17 +138,30 @@ func handler(c *gin.Context) {
} }
func proxy(c *gin.Context, u string) { func proxy(c *gin.Context, u string) {
cacheKey := generateCacheKey(u)
// 当 CacheExpiry 为 0 时,不使用缓存
if CacheExpiry != 0 {
if cachedData, ok := cache.Load(cacheKey); ok {
log.Printf("Using cached response for %s", u)
cached := cachedData.(*CachedResponse)
if time.Since(cached.Timestamp) < CacheExpiry {
setHeaders(c, cached.Header)
c.Status(cached.StatusCode)
c.Writer.Write(cached.Body)
return
}
}
}
log.Printf("use proxy response for %s", u)
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil { if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return return
} }
for key, values := range c.Request.Header { copyHeaders(req.Header, c.Request.Header)
for _, value := range values {
req.Header.Add(key, value)
}
}
req.Header.Del("Host") req.Header.Del("Host")
resp, err := httpClient.Do(req) resp, err := httpClient.Do(req)
@@ -132,29 +169,17 @@ func proxy(c *gin.Context, u string) {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return return
} }
defer func(Body io.ReadCloser) { defer closeWithLog(resp.Body)
err := Body.Close()
if err != nil {
}
}(resp.Body)
if contentLength, ok := resp.Header["Content-Length"]; ok { if contentLength, ok := resp.Header["Content-Length"]; ok {
if size, err := strconv.Atoi(contentLength[0]); err == nil && size > sizeLimit { if size, err := strconv.Atoi(contentLength[0]); err == nil && size > MaxFileSize {
c.String(http.StatusRequestEntityTooLarge, "File too large.") c.String(http.StatusRequestEntityTooLarge, "File too large.")
return return
} }
} }
resp.Header.Del("Content-Security-Policy") removeHeaders(resp.Header, "Content-Security-Policy", "Referrer-Policy", "Strict-Transport-Security")
resp.Header.Del("Referrer-Policy") setHeaders(c, resp.Header)
resp.Header.Del("Strict-Transport-Security")
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
if location := resp.Header.Get("Location"); location != "" { if location := resp.Header.Get("Location"); location != "" {
if checkURL(location) != nil { if checkURL(location) != nil {
@@ -166,28 +191,51 @@ func proxy(c *gin.Context, u string) {
} }
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil { body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Failed to read response body: %v", err)
return return
} }
if _, err := c.Writer.Write(body); err != nil {
log.Printf("Failed to write response body: %v", err)
return
}
// 当 CacheExpiry 不为 0 时,保存到缓存
if CacheExpiry != 0 {
// Save to cache
cached := &CachedResponse{
Header: resp.Header,
StatusCode: resp.StatusCode,
Body: body,
Timestamp: time.Now(),
}
cache.Store(cacheKey, cached)
cacheFilePath := filepath.Join(CacheDir, cacheKey)
// 修改 ioutil.WriteFile 为 os.WriteFile
if err := os.WriteFile(cacheFilePath, body, 0644); err != nil {
log.Printf("Failed to write cache file: %v", err)
}
}
}
func generateCacheKey(u string) string {
hash := sha256.Sum256([]byte(u))
return hex.EncodeToString(hash[:])
} }
func loadConfig() { func loadConfig() {
file, err := os.Open("config.json") file, err := os.Open("config.json")
if err != nil { if err != nil {
fmt.Printf("Error loading config: %v\n", err) log.Printf("Error loading config: %v", err)
return return
} }
defer func(file *os.File) { defer closeWithLog(file)
err := file.Close()
if err != nil {
}
}(file)
var newConfig Config var newConfig Config
decoder := json.NewDecoder(file) decoder := json.NewDecoder(file)
if err := decoder.Decode(&newConfig); err != nil { if err := decoder.Decode(&newConfig); err != nil {
fmt.Printf("Error decoding config: %v\n", err) log.Printf("Error decoding config: %v", err)
return return
} }
@@ -213,3 +261,31 @@ func checkList(matches, list []string) bool {
} }
return false return false
} }
func setHeaders(c *gin.Context, headers http.Header) {
for key, values := range headers {
for _, value := range values {
c.Header(key, value)
}
}
}
func copyHeaders(dst, src http.Header) {
for key, values := range src {
for _, value := range values {
dst.Add(key, value)
}
}
}
func removeHeaders(headers http.Header, keys ...string) {
for _, key := range keys {
headers.Del(key)
}
}
func closeWithLog(c io.Closer) {
if err := c.Close(); err != nil {
log.Printf("Failed to close: %v", err)
}
}