Update main.go

This commit is contained in:
NewName
2025-03-17 16:08:19 +08:00
parent b4f516e25f
commit 156d8a37ba

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net" "net"
"net/http" "net/http"
@@ -12,22 +13,14 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/gin-gonic/gin"
) )
// 常量定义
const ( const (
sizeLimit = 1024 * 1024 * 1024 * 10 // 允许的文件大小默认10GB sizeLimit = 1024 * 1024 * 1024 * 10 // 允许的文件大小默认10GB
host = "0.0.0.0" // 监听地址 host = "0.0.0.0" // 监听地址
port = 5000 // 监听端口 port = 5000 // 监听端口
) )
type Config struct {
WhiteList []string `json:"whiteList"`
BlackList []string `json:"blackList"`
}
var ( var (
exps = []*regexp.Regexp{ exps = []*regexp.Regexp{
regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`), regexp.MustCompile(`^(?:https?://)?github\.com/([^/]+)/([^/]+)/(?:releases|archive)/.*$`),
@@ -45,6 +38,11 @@ var (
configLock sync.RWMutex configLock sync.RWMutex
) )
type Config struct {
WhiteList []string `json:"whiteList"`
BlackList []string `json:"blackList"`
}
func main() { func main() {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
router := gin.Default() router := gin.Default()
@@ -65,38 +63,28 @@ func main() {
} }
loadConfig() loadConfig()
// 每60分钟热重载黑白名单
go func() { go func() {
ticker := time.NewTicker(60 * time.Minute) for {
defer ticker.Stop() time.Sleep(10 * time.Minute)
for range ticker.C {
loadConfig() loadConfig()
} }
}() }()
// 前端访问路径,默认根路径 // 前端访问路径,默认根路径
router.Static("/", "./public") router.Static("/", "./public")
router.NoRoute(handler) router.NoRoute(handler)
serverAddr := fmt.Sprintf("%s:%d", host, port) err := router.Run(fmt.Sprintf("%s:%d", host, port))
if err := router.Run(serverAddr); err != nil { if err != nil {
fmt.Printf("Error starting server: %v\n", err) fmt.Printf("Error starting server: %v\n", err)
} }
} }
func handler(c *gin.Context) { func handler(c *gin.Context) {
rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/") rawPath := strings.TrimPrefix(c.Request.URL.RequestURI(), "/")
for strings.HasPrefix(rawPath, "/") { for strings.HasPrefix(rawPath, "/") {
rawPath = strings.TrimPrefix(rawPath, "/") rawPath = strings.TrimPrefix(rawPath, "/")
} }
// 脚本嵌套路径处理
if rawPath == "perl-pe-para" {
handlePerlPePara(c)
return
}
if !strings.HasPrefix(rawPath, "http") { if !strings.HasPrefix(rawPath, "http") {
c.String(http.StatusForbidden, "无效输入") c.String(http.StatusForbidden, "无效输入")
@@ -105,9 +93,6 @@ func handler(c *gin.Context) {
matches := checkURL(rawPath) matches := checkURL(rawPath)
if matches != nil { if matches != nil {
configLock.RLock()
defer configLock.RUnlock()
if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) { if len(config.WhiteList) > 0 && !checkList(matches, config.WhiteList) {
c.String(http.StatusForbidden, "不在白名单内,限制访问。") c.String(http.StatusForbidden, "不在白名单内,限制访问。")
return return
@@ -127,22 +112,8 @@ func handler(c *gin.Context) {
proxy(c, rawPath) proxy(c, rawPath)
} }
// 处理脚本嵌套相关函数
func handlePerlPePara(c *gin.Context) {
perlstr := "perl -pe"
responseText := fmt.Sprintf(`s#(bash.*?\.sh)([^/\w\d])#\1 | %s "$(curl -L %s/perl-pe-para)" \2#g; s# (git)# https://\1#g; s#(http.*?git[^/]*?/)#%s/\1#g`, perlstr, c.Request.URL.String(), c.Request.URL.String())
c.Header("Content-Type", "text/plain")
c.Header("Cache-Control", "max-age=300")
c.String(http.StatusOK, responseText)
}
func proxy(c *gin.Context, u string) { func proxy(c *gin.Context, u string) {
// 检查是否脚本嵌套路径
if strings.HasSuffix(u, "perl-pe-para") {
handlePerlPePara(c)
return
}
req, err := http.NewRequest(c.Request.Method, u, c.Request.Body) req, err := http.NewRequest(c.Request.Method, u, c.Request.Body)
if err != nil { if err != nil {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
@@ -161,11 +132,12 @@ func proxy(c *gin.Context, u string) {
c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err)) c.String(http.StatusInternalServerError, fmt.Sprintf("server error %v", err))
return return
} }
defer func() { defer func(Body io.ReadCloser) {
if err := resp.Body.Close(); err != nil { err := Body.Close()
fmt.Printf("Error closing response body: %v\n", err) if err != nil {
} }
}() }(resp.Body)
if contentLength, ok := resp.Header["Content-Length"]; ok { if contentLength, ok := resp.Header["Content-Length"]; ok {
if size, err := strconv.Atoi(contentLength[0]); err == nil && size > sizeLimit { if size, err := strconv.Atoi(contentLength[0]); err == nil && size > sizeLimit {
@@ -195,7 +167,6 @@ func proxy(c *gin.Context, u string) {
c.Status(resp.StatusCode) c.Status(resp.StatusCode)
if _, err := io.Copy(c.Writer, resp.Body); err != nil { if _, err := io.Copy(c.Writer, resp.Body); err != nil {
fmt.Printf("Error copying response body: %v\n", err)
return return
} }
} }
@@ -206,11 +177,12 @@ func loadConfig() {
fmt.Printf("Error loading config: %v\n", err) fmt.Printf("Error loading config: %v\n", err)
return return
} }
defer func() { defer func(file *os.File) {
if err := file.Close(); err != nil { err := file.Close()
fmt.Printf("Error closing config file: %v\n", err) if err != nil {
} }
}() }(file)
var newConfig Config var newConfig Config
decoder := json.NewDecoder(file) decoder := json.NewDecoder(file)