mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-27 17:34:29 +00:00
333 lines
8.2 KiB
Go
333 lines
8.2 KiB
Go
package serverless
|
|
|
|
import (
|
|
"sync"
|
|
|
|
"github.com/gorilla/websocket"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// Ensure WSManager implements WebSocketManager interface.
|
|
var _ WebSocketManager = (*WSManager)(nil)
|
|
|
|
// WSManager manages WebSocket connections for serverless functions.
|
|
// It handles connection registration, message routing, and topic subscriptions.
|
|
type WSManager struct {
|
|
// connections maps client IDs to their WebSocket connections
|
|
connections map[string]*wsConnection
|
|
connectionsMu sync.RWMutex
|
|
|
|
// subscriptions maps topic names to sets of client IDs
|
|
subscriptions map[string]map[string]struct{}
|
|
subscriptionsMu sync.RWMutex
|
|
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// wsConnection wraps a WebSocket connection with metadata.
|
|
type wsConnection struct {
|
|
conn WebSocketConn
|
|
clientID string
|
|
topics map[string]struct{} // Topics this client is subscribed to
|
|
mu sync.Mutex
|
|
}
|
|
|
|
// GorillaWSConn wraps a gorilla/websocket.Conn to implement WebSocketConn.
|
|
type GorillaWSConn struct {
|
|
*websocket.Conn
|
|
}
|
|
|
|
// Ensure GorillaWSConn implements WebSocketConn.
|
|
var _ WebSocketConn = (*GorillaWSConn)(nil)
|
|
|
|
// WriteMessage writes a message to the WebSocket connection.
|
|
func (c *GorillaWSConn) WriteMessage(messageType int, data []byte) error {
|
|
return c.Conn.WriteMessage(messageType, data)
|
|
}
|
|
|
|
// ReadMessage reads a message from the WebSocket connection.
|
|
func (c *GorillaWSConn) ReadMessage() (messageType int, p []byte, err error) {
|
|
return c.Conn.ReadMessage()
|
|
}
|
|
|
|
// Close closes the WebSocket connection.
|
|
func (c *GorillaWSConn) Close() error {
|
|
return c.Conn.Close()
|
|
}
|
|
|
|
// NewWSManager creates a new WebSocket manager.
|
|
func NewWSManager(logger *zap.Logger) *WSManager {
|
|
return &WSManager{
|
|
connections: make(map[string]*wsConnection),
|
|
subscriptions: make(map[string]map[string]struct{}),
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Register registers a new WebSocket connection.
|
|
func (m *WSManager) Register(clientID string, conn WebSocketConn) {
|
|
m.connectionsMu.Lock()
|
|
defer m.connectionsMu.Unlock()
|
|
|
|
// Close existing connection if any
|
|
if existing, exists := m.connections[clientID]; exists {
|
|
_ = existing.conn.Close()
|
|
m.logger.Debug("Closed existing connection", zap.String("client_id", clientID))
|
|
}
|
|
|
|
m.connections[clientID] = &wsConnection{
|
|
conn: conn,
|
|
clientID: clientID,
|
|
topics: make(map[string]struct{}),
|
|
}
|
|
|
|
m.logger.Debug("Registered WebSocket connection",
|
|
zap.String("client_id", clientID),
|
|
zap.Int("total_connections", len(m.connections)),
|
|
)
|
|
}
|
|
|
|
// Unregister removes a WebSocket connection and its subscriptions.
|
|
func (m *WSManager) Unregister(clientID string) {
|
|
m.connectionsMu.Lock()
|
|
conn, exists := m.connections[clientID]
|
|
if exists {
|
|
delete(m.connections, clientID)
|
|
}
|
|
m.connectionsMu.Unlock()
|
|
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
// Remove from all subscriptions
|
|
m.subscriptionsMu.Lock()
|
|
for topic := range conn.topics {
|
|
if clients, ok := m.subscriptions[topic]; ok {
|
|
delete(clients, clientID)
|
|
if len(clients) == 0 {
|
|
delete(m.subscriptions, topic)
|
|
}
|
|
}
|
|
}
|
|
m.subscriptionsMu.Unlock()
|
|
|
|
// Close the connection
|
|
_ = conn.conn.Close()
|
|
|
|
m.logger.Debug("Unregistered WebSocket connection",
|
|
zap.String("client_id", clientID),
|
|
zap.Int("remaining_connections", m.GetConnectionCount()),
|
|
)
|
|
}
|
|
|
|
// Send sends data to a specific client.
|
|
func (m *WSManager) Send(clientID string, data []byte) error {
|
|
m.connectionsMu.RLock()
|
|
conn, exists := m.connections[clientID]
|
|
m.connectionsMu.RUnlock()
|
|
|
|
if !exists {
|
|
return ErrWSClientNotFound
|
|
}
|
|
|
|
conn.mu.Lock()
|
|
defer conn.mu.Unlock()
|
|
|
|
if err := conn.conn.WriteMessage(websocket.TextMessage, data); err != nil {
|
|
m.logger.Warn("Failed to send WebSocket message",
|
|
zap.String("client_id", clientID),
|
|
zap.Error(err),
|
|
)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Broadcast sends data to all clients subscribed to a topic.
|
|
func (m *WSManager) Broadcast(topic string, data []byte) error {
|
|
m.subscriptionsMu.RLock()
|
|
clients, exists := m.subscriptions[topic]
|
|
if !exists || len(clients) == 0 {
|
|
m.subscriptionsMu.RUnlock()
|
|
return nil // No subscribers, not an error
|
|
}
|
|
|
|
// Copy client IDs to avoid holding lock during send
|
|
clientIDs := make([]string, 0, len(clients))
|
|
for clientID := range clients {
|
|
clientIDs = append(clientIDs, clientID)
|
|
}
|
|
m.subscriptionsMu.RUnlock()
|
|
|
|
// Send to all subscribers
|
|
var sendErrors int
|
|
for _, clientID := range clientIDs {
|
|
if err := m.Send(clientID, data); err != nil {
|
|
sendErrors++
|
|
}
|
|
}
|
|
|
|
m.logger.Debug("Broadcast message",
|
|
zap.String("topic", topic),
|
|
zap.Int("recipients", len(clientIDs)),
|
|
zap.Int("errors", sendErrors),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
// Subscribe adds a client to a topic.
|
|
func (m *WSManager) Subscribe(clientID, topic string) {
|
|
// Add to connection's topic list
|
|
m.connectionsMu.RLock()
|
|
conn, exists := m.connections[clientID]
|
|
m.connectionsMu.RUnlock()
|
|
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
conn.mu.Lock()
|
|
conn.topics[topic] = struct{}{}
|
|
conn.mu.Unlock()
|
|
|
|
// Add to topic's client list
|
|
m.subscriptionsMu.Lock()
|
|
if m.subscriptions[topic] == nil {
|
|
m.subscriptions[topic] = make(map[string]struct{})
|
|
}
|
|
m.subscriptions[topic][clientID] = struct{}{}
|
|
m.subscriptionsMu.Unlock()
|
|
|
|
m.logger.Debug("Client subscribed to topic",
|
|
zap.String("client_id", clientID),
|
|
zap.String("topic", topic),
|
|
)
|
|
}
|
|
|
|
// Unsubscribe removes a client from a topic.
|
|
func (m *WSManager) Unsubscribe(clientID, topic string) {
|
|
// Remove from connection's topic list
|
|
m.connectionsMu.RLock()
|
|
conn, exists := m.connections[clientID]
|
|
m.connectionsMu.RUnlock()
|
|
|
|
if exists {
|
|
conn.mu.Lock()
|
|
delete(conn.topics, topic)
|
|
conn.mu.Unlock()
|
|
}
|
|
|
|
// Remove from topic's client list
|
|
m.subscriptionsMu.Lock()
|
|
if clients, ok := m.subscriptions[topic]; ok {
|
|
delete(clients, clientID)
|
|
if len(clients) == 0 {
|
|
delete(m.subscriptions, topic)
|
|
}
|
|
}
|
|
m.subscriptionsMu.Unlock()
|
|
|
|
m.logger.Debug("Client unsubscribed from topic",
|
|
zap.String("client_id", clientID),
|
|
zap.String("topic", topic),
|
|
)
|
|
}
|
|
|
|
// GetConnectionCount returns the number of active connections.
|
|
func (m *WSManager) GetConnectionCount() int {
|
|
m.connectionsMu.RLock()
|
|
defer m.connectionsMu.RUnlock()
|
|
return len(m.connections)
|
|
}
|
|
|
|
// GetTopicSubscriberCount returns the number of subscribers for a topic.
|
|
func (m *WSManager) GetTopicSubscriberCount(topic string) int {
|
|
m.subscriptionsMu.RLock()
|
|
defer m.subscriptionsMu.RUnlock()
|
|
if clients, exists := m.subscriptions[topic]; exists {
|
|
return len(clients)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// GetClientTopics returns all topics a client is subscribed to.
|
|
func (m *WSManager) GetClientTopics(clientID string) []string {
|
|
m.connectionsMu.RLock()
|
|
conn, exists := m.connections[clientID]
|
|
m.connectionsMu.RUnlock()
|
|
|
|
if !exists {
|
|
return nil
|
|
}
|
|
|
|
conn.mu.Lock()
|
|
defer conn.mu.Unlock()
|
|
|
|
topics := make([]string, 0, len(conn.topics))
|
|
for topic := range conn.topics {
|
|
topics = append(topics, topic)
|
|
}
|
|
return topics
|
|
}
|
|
|
|
// IsConnected checks if a client is connected.
|
|
func (m *WSManager) IsConnected(clientID string) bool {
|
|
m.connectionsMu.RLock()
|
|
defer m.connectionsMu.RUnlock()
|
|
_, exists := m.connections[clientID]
|
|
return exists
|
|
}
|
|
|
|
// Close closes all connections and cleans up resources.
|
|
func (m *WSManager) Close() {
|
|
m.connectionsMu.Lock()
|
|
defer m.connectionsMu.Unlock()
|
|
|
|
for clientID, conn := range m.connections {
|
|
_ = conn.conn.Close()
|
|
delete(m.connections, clientID)
|
|
}
|
|
|
|
m.subscriptionsMu.Lock()
|
|
m.subscriptions = make(map[string]map[string]struct{})
|
|
m.subscriptionsMu.Unlock()
|
|
|
|
m.logger.Info("WebSocket manager closed")
|
|
}
|
|
|
|
// Stats returns statistics about the WebSocket manager.
|
|
type WSStats struct {
|
|
ConnectionCount int `json:"connection_count"`
|
|
TopicCount int `json:"topic_count"`
|
|
SubscriptionCount int `json:"subscription_count"`
|
|
TopicStats map[string]int `json:"topic_stats"` // topic -> subscriber count
|
|
}
|
|
|
|
// GetStats returns current statistics.
|
|
func (m *WSManager) GetStats() *WSStats {
|
|
m.connectionsMu.RLock()
|
|
connCount := len(m.connections)
|
|
m.connectionsMu.RUnlock()
|
|
|
|
m.subscriptionsMu.RLock()
|
|
topicCount := len(m.subscriptions)
|
|
topicStats := make(map[string]int, topicCount)
|
|
totalSubs := 0
|
|
for topic, clients := range m.subscriptions {
|
|
topicStats[topic] = len(clients)
|
|
totalSubs += len(clients)
|
|
}
|
|
m.subscriptionsMu.RUnlock()
|
|
|
|
return &WSStats{
|
|
ConnectionCount: connCount,
|
|
TopicCount: topicCount,
|
|
SubscriptionCount: totalSubs,
|
|
TopicStats: topicStats,
|
|
}
|
|
}
|
|
|