优化ws连接和心跳保活

This commit is contained in:
user123456
2025-06-12 16:18:25 +08:00
parent 35916cfdc0
commit 543d17dd25

View File

@@ -75,6 +75,10 @@ type Client struct {
TaskID string TaskID string
Send chan []byte Send chan []byte
CloseOnce sync.Once CloseOnce sync.Once
heartbeat *time.Ticker // 心跳定时器
isActive bool // 连接是否活跃
lastPing time.Time // 最后一次ping时间
mu sync.RWMutex // 保护isActive和lastPing
} }
var ( var (
@@ -132,9 +136,11 @@ func handleWebSocket(c *gin.Context) {
} }
client := &Client{ client := &Client{
Conn: conn, Conn: conn,
TaskID: taskID, TaskID: taskID,
Send: make(chan []byte, 256), Send: make(chan []byte, 256),
isActive: true,
lastPing: time.Now(),
} }
// 注册客户端 // 注册客户端
@@ -142,6 +148,9 @@ func handleWebSocket(c *gin.Context) {
clients[taskID] = client clients[taskID] = client
clientLock.Unlock() clientLock.Unlock()
// 启动心跳保活机制
client.startHeartbeat()
// 启动goroutine处理消息发送 // 启动goroutine处理消息发送
go client.writePump() go client.writePump()
@@ -150,34 +159,133 @@ func handleWebSocket(c *gin.Context) {
if task, exists := tasks[taskID]; exists { if task, exists := tasks[taskID]; exists {
tasksLock.Unlock() tasksLock.Unlock()
taskJSON, _ := json.Marshal(task) taskJSON, _ := json.Marshal(task)
client.Send <- taskJSON select {
case client.Send <- taskJSON:
default:
// 通道满时不阻塞
}
} else { } else {
tasksLock.Unlock() tasksLock.Unlock()
} }
// 设置WebSocket超时 // 设置更宽松的读取超时,主要用于检测客户端断开
conn.SetReadDeadline(time.Now().Add(120 * time.Second)) conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
conn.SetWriteDeadline(time.Now().Add(60 * time.Second))
// 不设置写入超时,让心跳机制处理连接活跃性
// 处理WebSocket关闭 // 处理WebSocket关闭
conn.SetCloseHandler(func(code int, text string) error { conn.SetCloseHandler(func(code int, text string) error {
client.CloseOnce.Do(func() { client.close()
close(client.Send)
clientLock.Lock()
delete(clients, taskID)
clientLock.Unlock()
})
return nil return nil
}) })
// 处理pong消息以确认连接活跃
conn.SetPongHandler(func(appData string) error {
client.mu.Lock()
client.lastPing = time.Now()
client.mu.Unlock()
conn.SetReadDeadline(time.Now().Add(5 * time.Minute))
return nil
})
// 启动读取循环主要用于处理pong和检测断开
go client.readPump()
}
// 启动心跳保活机制
func (c *Client) startHeartbeat() {
c.heartbeat = time.NewTicker(30 * time.Second)
go func() {
defer c.heartbeat.Stop()
for {
select {
case <-c.heartbeat.C:
c.mu.RLock()
if !c.isActive {
c.mu.RUnlock()
return
}
c.mu.RUnlock()
// 发送ping消息保持连接活跃
if err := c.Conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
fmt.Printf("发送心跳失败: %v\n", err)
c.close()
return
}
// 检查上次pong响应时间
c.mu.RLock()
timeSinceLastPong := time.Since(c.lastPing)
c.mu.RUnlock()
// 如果超过2分钟没有收到pong响应认为连接已断开
if timeSinceLastPong > 2*time.Minute {
fmt.Printf("客户端 %s 心跳超时,关闭连接\n", c.TaskID)
c.close()
return
}
}
}
}()
}
// 读取循环主要处理pong消息和检测连接断开
func (c *Client) readPump() {
defer c.close()
for {
c.mu.RLock()
if !c.isActive {
c.mu.RUnlock()
break
}
c.mu.RUnlock()
// 读取消息,主要是为了检测连接状态
_, _, err := c.Conn.ReadMessage()
if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
fmt.Printf("WebSocket意外关闭: %v\n", err)
}
break
}
}
}
// 安全关闭客户端连接
func (c *Client) close() {
c.CloseOnce.Do(func() {
c.mu.Lock()
c.isActive = false
c.mu.Unlock()
if c.heartbeat != nil {
c.heartbeat.Stop()
}
close(c.Send)
c.Conn.Close()
clientLock.Lock()
delete(clients, c.TaskID)
clientLock.Unlock()
})
} }
// 客户端消息发送loop // 客户端消息发送loop
func (c *Client) writePump() { func (c *Client) writePump() {
defer func() { defer c.close()
c.Conn.Close()
}()
for message := range c.Send { for message := range c.Send {
c.mu.RLock()
if !c.isActive {
c.mu.RUnlock()
break
}
c.mu.RUnlock()
// 不设置写入超时,依赖心跳机制检测连接状态
err := c.Conn.WriteMessage(websocket.TextMessage, message) err := c.Conn.WriteMessage(websocket.TextMessage, message)
if err != nil { if err != nil {
fmt.Printf("发送WS消息失败: %v\n", err) fmt.Printf("发送WS消息失败: %v\n", err)
@@ -249,6 +357,22 @@ func initTask(task *DownloadTask) {
task.done = make(chan struct{}) task.done = make(chan struct{})
task.createTime = time.Now() task.createTime = time.Now()
// 启动定期状态更新保持WebSocket连接活跃
go func() {
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// 定期发送当前状态保持连接活跃
sendTaskUpdate(task)
case <-task.done:
return
}
}
}()
// 启动进度处理goroutine // 启动进度处理goroutine
go func() { go func() {
defer func() { defer func() {
@@ -1050,11 +1174,19 @@ func sendTaskUpdate(task *DownloadTask) {
clientLock.Unlock() clientLock.Unlock()
if exists { if exists {
select { // 检查客户端是否活跃
case client.Send <- taskJSON: client.mu.RLock()
// 成功发送 isActive := client.isActive
default: client.mu.RUnlock()
// 通道已满或关闭,忽略
if isActive {
select {
case client.Send <- taskJSON:
// 成功发送
case <-time.After(5 * time.Second):
// 发送超时,可能客户端处理慢或连接有问题
fmt.Printf("发送消息到客户端 %s 超时\n", task.ID)
}
} }
} }
} }
@@ -1267,7 +1399,7 @@ func checkForCompletionMarkers(output string) bool {
// cleanupWebSocketConnections 定期清理无效的WebSocket连接 // cleanupWebSocketConnections 定期清理无效的WebSocket连接
func cleanupWebSocketConnections() { func cleanupWebSocketConnections() {
ticker := time.NewTicker(5 * time.Minute) ticker := time.NewTicker(2 * time.Minute) // 增加清理频率
defer ticker.Stop() defer ticker.Stop()
for range ticker.C { for range ticker.C {
@@ -1275,9 +1407,28 @@ func cleanupWebSocketConnections() {
disconnectedClients := make([]string, 0) disconnectedClients := make([]string, 0)
for taskID, client := range clients { for taskID, client := range clients {
// 检查连接是否仍然活跃 client.mu.RLock()
if err := client.Conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { isActive := client.isActive
// 连接已断开,标记待清理 lastPing := client.lastPing
client.mu.RUnlock()
// 检查连接是否还活跃
shouldRemove := false
if !isActive {
shouldRemove = true
} else if time.Since(lastPing) > 3*time.Minute {
// 超过3分钟没有心跳响应认为连接已断开
shouldRemove = true
fmt.Printf("客户端 %s 心跳超时,标记为断开\n", taskID)
} else {
// 尝试发送ping测试连接
if err := client.Conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
shouldRemove = true
}
}
if shouldRemove {
disconnectedClients = append(disconnectedClients, taskID) disconnectedClients = append(disconnectedClients, taskID)
} }
} }
@@ -1285,11 +1436,7 @@ func cleanupWebSocketConnections() {
// 清理断开的连接 // 清理断开的连接
for _, taskID := range disconnectedClients { for _, taskID := range disconnectedClients {
if client, exists := clients[taskID]; exists { if client, exists := clients[taskID]; exists {
client.CloseOnce.Do(func() { client.close()
close(client.Send)
client.Conn.Close()
})
delete(clients, taskID)
} }
} }