network/pkg/serverless/websocket.go
anonpenguin23 ee80be15d8 feat: add network MCP rules and documentation
- Introduced a new `network.mdc` file containing comprehensive guidelines for utilizing the network Model Context Protocol (MCP).
- Documented available MCP tools for code understanding, skill learning, and recommended workflows to enhance developer efficiency.
- Provided detailed instructions on the collaborative skill learning process and user override commands for better interaction with the MCP.
2025-12-29 14:08:58 +02:00

333 lines
8.2 KiB
Go

package serverless
import (
"sync"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// Ensure WSManager implements WebSocketManager interface.
var _ WebSocketManager = (*WSManager)(nil)
// WSManager manages WebSocket connections for serverless functions.
// It handles connection registration, message routing, and topic subscriptions.
type WSManager struct {
// connections maps client IDs to their WebSocket connections
connections map[string]*wsConnection
connectionsMu sync.RWMutex
// subscriptions maps topic names to sets of client IDs
subscriptions map[string]map[string]struct{}
subscriptionsMu sync.RWMutex
logger *zap.Logger
}
// wsConnection wraps a WebSocket connection with metadata.
type wsConnection struct {
conn WebSocketConn
clientID string
topics map[string]struct{} // Topics this client is subscribed to
mu sync.Mutex
}
// GorillaWSConn wraps a gorilla/websocket.Conn to implement WebSocketConn.
type GorillaWSConn struct {
*websocket.Conn
}
// Ensure GorillaWSConn implements WebSocketConn.
var _ WebSocketConn = (*GorillaWSConn)(nil)
// WriteMessage writes a message to the WebSocket connection.
func (c *GorillaWSConn) WriteMessage(messageType int, data []byte) error {
return c.Conn.WriteMessage(messageType, data)
}
// ReadMessage reads a message from the WebSocket connection.
func (c *GorillaWSConn) ReadMessage() (messageType int, p []byte, err error) {
return c.Conn.ReadMessage()
}
// Close closes the WebSocket connection.
func (c *GorillaWSConn) Close() error {
return c.Conn.Close()
}
// NewWSManager creates a new WebSocket manager.
func NewWSManager(logger *zap.Logger) *WSManager {
return &WSManager{
connections: make(map[string]*wsConnection),
subscriptions: make(map[string]map[string]struct{}),
logger: logger,
}
}
// Register registers a new WebSocket connection.
func (m *WSManager) Register(clientID string, conn WebSocketConn) {
m.connectionsMu.Lock()
defer m.connectionsMu.Unlock()
// Close existing connection if any
if existing, exists := m.connections[clientID]; exists {
_ = existing.conn.Close()
m.logger.Debug("Closed existing connection", zap.String("client_id", clientID))
}
m.connections[clientID] = &wsConnection{
conn: conn,
clientID: clientID,
topics: make(map[string]struct{}),
}
m.logger.Debug("Registered WebSocket connection",
zap.String("client_id", clientID),
zap.Int("total_connections", len(m.connections)),
)
}
// Unregister removes a WebSocket connection and its subscriptions.
func (m *WSManager) Unregister(clientID string) {
m.connectionsMu.Lock()
conn, exists := m.connections[clientID]
if exists {
delete(m.connections, clientID)
}
m.connectionsMu.Unlock()
if !exists {
return
}
// Remove from all subscriptions
m.subscriptionsMu.Lock()
for topic := range conn.topics {
if clients, ok := m.subscriptions[topic]; ok {
delete(clients, clientID)
if len(clients) == 0 {
delete(m.subscriptions, topic)
}
}
}
m.subscriptionsMu.Unlock()
// Close the connection
_ = conn.conn.Close()
m.logger.Debug("Unregistered WebSocket connection",
zap.String("client_id", clientID),
zap.Int("remaining_connections", m.GetConnectionCount()),
)
}
// Send sends data to a specific client.
func (m *WSManager) Send(clientID string, data []byte) error {
m.connectionsMu.RLock()
conn, exists := m.connections[clientID]
m.connectionsMu.RUnlock()
if !exists {
return ErrWSClientNotFound
}
conn.mu.Lock()
defer conn.mu.Unlock()
if err := conn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
m.logger.Warn("Failed to send WebSocket message",
zap.String("client_id", clientID),
zap.Error(err),
)
return err
}
return nil
}
// Broadcast sends data to all clients subscribed to a topic.
func (m *WSManager) Broadcast(topic string, data []byte) error {
m.subscriptionsMu.RLock()
clients, exists := m.subscriptions[topic]
if !exists || len(clients) == 0 {
m.subscriptionsMu.RUnlock()
return nil // No subscribers, not an error
}
// Copy client IDs to avoid holding lock during send
clientIDs := make([]string, 0, len(clients))
for clientID := range clients {
clientIDs = append(clientIDs, clientID)
}
m.subscriptionsMu.RUnlock()
// Send to all subscribers
var sendErrors int
for _, clientID := range clientIDs {
if err := m.Send(clientID, data); err != nil {
sendErrors++
}
}
m.logger.Debug("Broadcast message",
zap.String("topic", topic),
zap.Int("recipients", len(clientIDs)),
zap.Int("errors", sendErrors),
)
return nil
}
// Subscribe adds a client to a topic.
func (m *WSManager) Subscribe(clientID, topic string) {
// Add to connection's topic list
m.connectionsMu.RLock()
conn, exists := m.connections[clientID]
m.connectionsMu.RUnlock()
if !exists {
return
}
conn.mu.Lock()
conn.topics[topic] = struct{}{}
conn.mu.Unlock()
// Add to topic's client list
m.subscriptionsMu.Lock()
if m.subscriptions[topic] == nil {
m.subscriptions[topic] = make(map[string]struct{})
}
m.subscriptions[topic][clientID] = struct{}{}
m.subscriptionsMu.Unlock()
m.logger.Debug("Client subscribed to topic",
zap.String("client_id", clientID),
zap.String("topic", topic),
)
}
// Unsubscribe removes a client from a topic.
func (m *WSManager) Unsubscribe(clientID, topic string) {
// Remove from connection's topic list
m.connectionsMu.RLock()
conn, exists := m.connections[clientID]
m.connectionsMu.RUnlock()
if exists {
conn.mu.Lock()
delete(conn.topics, topic)
conn.mu.Unlock()
}
// Remove from topic's client list
m.subscriptionsMu.Lock()
if clients, ok := m.subscriptions[topic]; ok {
delete(clients, clientID)
if len(clients) == 0 {
delete(m.subscriptions, topic)
}
}
m.subscriptionsMu.Unlock()
m.logger.Debug("Client unsubscribed from topic",
zap.String("client_id", clientID),
zap.String("topic", topic),
)
}
// GetConnectionCount returns the number of active connections.
func (m *WSManager) GetConnectionCount() int {
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
return len(m.connections)
}
// GetTopicSubscriberCount returns the number of subscribers for a topic.
func (m *WSManager) GetTopicSubscriberCount(topic string) int {
m.subscriptionsMu.RLock()
defer m.subscriptionsMu.RUnlock()
if clients, exists := m.subscriptions[topic]; exists {
return len(clients)
}
return 0
}
// GetClientTopics returns all topics a client is subscribed to.
func (m *WSManager) GetClientTopics(clientID string) []string {
m.connectionsMu.RLock()
conn, exists := m.connections[clientID]
m.connectionsMu.RUnlock()
if !exists {
return nil
}
conn.mu.Lock()
defer conn.mu.Unlock()
topics := make([]string, 0, len(conn.topics))
for topic := range conn.topics {
topics = append(topics, topic)
}
return topics
}
// IsConnected checks if a client is connected.
func (m *WSManager) IsConnected(clientID string) bool {
m.connectionsMu.RLock()
defer m.connectionsMu.RUnlock()
_, exists := m.connections[clientID]
return exists
}
// Close closes all connections and cleans up resources.
func (m *WSManager) Close() {
m.connectionsMu.Lock()
defer m.connectionsMu.Unlock()
for clientID, conn := range m.connections {
_ = conn.conn.Close()
delete(m.connections, clientID)
}
m.subscriptionsMu.Lock()
m.subscriptions = make(map[string]map[string]struct{})
m.subscriptionsMu.Unlock()
m.logger.Info("WebSocket manager closed")
}
// Stats returns statistics about the WebSocket manager.
type WSStats struct {
ConnectionCount int `json:"connection_count"`
TopicCount int `json:"topic_count"`
SubscriptionCount int `json:"subscription_count"`
TopicStats map[string]int `json:"topic_stats"` // topic -> subscriber count
}
// GetStats returns current statistics.
func (m *WSManager) GetStats() *WSStats {
m.connectionsMu.RLock()
connCount := len(m.connections)
m.connectionsMu.RUnlock()
m.subscriptionsMu.RLock()
topicCount := len(m.subscriptions)
topicStats := make(map[string]int, topicCount)
totalSubs := 0
for topic, clients := range m.subscriptions {
topicStats[topic] = len(clients)
totalSubs += len(clients)
}
m.subscriptionsMu.RUnlock()
return &WSStats{
ConnectionCount: connCount,
TopicCount: topicCount,
SubscriptionCount: totalSubs,
TopicStats: topicStats,
}
}