orama/core/pkg/serverless/websocket.go
anonpenguin23 f41242538e feat(serverless): add raw http response mode and secrets encryption
- Add `raw_http_response` configuration to functions to allow verbatim HTTP responses
- Implement cluster-wide secrets encryption key generation and distribution for serverless functions
- Update documentation with UnifiedPush support for ntfy on Android/GrapheneOS
2026-06-09 13:01:02 +03:00

447 lines
12 KiB
Go

package serverless
import (
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// Compile-time assertion that *WSManager satisfies the WebSocketManager
// interface; the blank identifier means no value is kept at runtime.
var _ WebSocketManager = (*WSManager)(nil)
// WSManager manages WebSocket connections for serverless functions.
// It handles connection registration, message routing, and topic subscriptions.
//
// The three pieces of shared state (connections, subscriptions, disconnect
// hooks) each have their own mutex, declared directly after the field it
// protects.
type WSManager struct {
// connections maps client IDs to their WebSocket connections.
// Accessed under connectionsMu.
connections map[string]*wsConnection
connectionsMu sync.RWMutex
// subscriptions maps topic names to sets of client IDs.
// Accessed under subscriptionsMu. A topic entry is deleted once its
// client set becomes empty.
subscriptions map[string]map[string]struct{}
subscriptionsMu sync.RWMutex
// disconnectHooks run (synchronously) on Unregister for each client,
// AFTER the connection + subscriptions are torn down. Used by the
// ephemeral-state store (bugboard #710) to auto-clear a client's owned
// state on disconnect. Both the stateless and persistent WS handlers
// call Unregister, so a single hook covers both paths.
// Accessed under disconnectHooksMu.
disconnectHooks []func(clientID string)
disconnectHooksMu sync.RWMutex
// logger receives the manager's debug, warn and info messages.
logger *zap.Logger
}
// wsConnection wraps a WebSocket connection with metadata + counters.
// Metric counters are atomic so callers (read loop, send loop) can update
// without acquiring the per-connection mutex.
type wsConnection struct {
// conn is the underlying transport used for writes and for Close.
conn WebSocketConn
// clientID is the key this connection is stored under in the manager.
clientID string
topics map[string]struct{} // Topics this client is subscribed to
// mu is held by Send around WriteMessage, and by Subscribe, Unsubscribe
// and GetClientTopics around accesses to topics.
mu sync.Mutex
// Metrics — read via GetConnStats / ListConnStats.
// framesOut/bytesOut are bumped by Send after a successful write;
// framesIn/bytesIn are bumped by RecordInbound.
framesIn atomic.Int64
framesOut atomic.Int64
bytesIn atomic.Int64
bytesOut atomic.Int64
// connectedAt is the wall-clock time captured in Register.
connectedAt time.Time
// lastActiveAt is seeded with connectedAt and refreshed by Send and
// RecordInbound.
lastActiveAt atomic.Int64 // unix nano
}
// GorillaWSConn wraps a gorilla/websocket.Conn to implement WebSocketConn.
// The connection is embedded, so its other methods are promoted onto the
// wrapper as well.
type GorillaWSConn struct {
*websocket.Conn
}
// Compile-time assertion that *GorillaWSConn satisfies WebSocketConn.
var _ WebSocketConn = (*GorillaWSConn)(nil)
// WriteMessage forwards a single message of the given type to the embedded
// gorilla connection and returns the error reported by that write, if any.
func (c *GorillaWSConn) WriteMessage(messageType int, data []byte) error {
	err := c.Conn.WriteMessage(messageType, data)
	return err
}
// ReadMessage reads the next message from the embedded gorilla connection
// and returns its type, its payload and any read error unchanged.
func (c *GorillaWSConn) ReadMessage() (messageType int, p []byte, err error) {
	messageType, p, err = c.Conn.ReadMessage()
	return messageType, p, err
}
// Close closes the embedded gorilla connection and returns the error that
// close reports, if any.
func (c *GorillaWSConn) Close() error {
	err := c.Conn.Close()
	return err
}
// NewWSManager creates a WebSocket manager with empty connection and
// subscription tables.
//
// logger receives the manager's debug, warn and info output. A nil logger is
// replaced with a no-op logger, because the manager's methods call
// m.logger unconditionally and would otherwise dereference a nil pointer.
func NewWSManager(logger *zap.Logger) *WSManager {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &WSManager{
		connections:   make(map[string]*wsConnection),
		subscriptions: make(map[string]map[string]struct{}),
		logger:        logger,
	}
}
// Register registers a new WebSocket connection under clientID.
//
// If clientID already has a connection, that connection is closed and
// replaced. The replaced connection's topic set is copied onto the new one:
// the manager's topic index is keyed by client ID and is not modified here,
// so the client keeps receiving broadcasts for those topics. Carrying the set
// over keeps GetClientTopics in agreement with that index and lets a later
// Unregister remove every index entry instead of leaving stale ones behind.
func (m *WSManager) Register(clientID string, conn WebSocketConn) {
	m.connectionsMu.Lock()
	defer m.connectionsMu.Unlock()

	wc := &wsConnection{
		conn:        conn,
		clientID:    clientID,
		topics:      make(map[string]struct{}),
		connectedAt: time.Now(),
	}
	wc.lastActiveAt.Store(wc.connectedAt.UnixNano())

	if existing, exists := m.connections[clientID]; exists {
		// Close before taking existing.mu: Send holds that mutex for the
		// duration of WriteMessage, and closing the transport is what lets
		// an in-flight write fail and release it.
		// NOTE(review): relies on Close unblocking pending writes, which
		// gorilla/websocket does; confirm for other WebSocketConn types.
		_ = existing.conn.Close() // best effort; the old connection is discarded either way
		existing.mu.Lock()
		for topic := range existing.topics {
			wc.topics[topic] = struct{}{}
		}
		existing.mu.Unlock()
		m.logger.Debug("Closed existing connection", zap.String("client_id", clientID))
	}

	m.connections[clientID] = wc
	m.logger.Debug("Registered WebSocket connection",
		zap.String("client_id", clientID),
		zap.Int("total_connections", len(m.connections)),
	)
}
// AddDisconnectHook appends hook to the list of callbacks that Unregister
// invokes, synchronously and in registration order, for each client it
// removes — after that client's connection and subscriptions are torn down.
// It exists to auto-clear WS-subscribe-tracked ephemeral state on disconnect
// (bugboard #710). Hooks must be cheap and non-blocking because they run
// inline on the WS read loop's teardown path. Register once at gateway init.
// A nil hook is ignored.
func (m *WSManager) AddDisconnectHook(hook func(clientID string)) {
	if hook != nil {
		m.disconnectHooksMu.Lock()
		defer m.disconnectHooksMu.Unlock()
		m.disconnectHooks = append(m.disconnectHooks, hook)
	}
}
// Unregister removes a WebSocket connection and its subscriptions, closes
// the connection, and then runs every registered disconnect hook for
// clientID. It is a no-op when clientID has no registered connection.
func (m *WSManager) Unregister(clientID string) {
	m.connectionsMu.Lock()
	conn, exists := m.connections[clientID]
	if exists {
		delete(m.connections, clientID)
	}
	m.connectionsMu.Unlock()
	if !exists {
		return
	}

	// Close first. Send holds conn.mu for the duration of WriteMessage, so
	// closing the transport lets an in-flight write fail and release the
	// mutex before it is needed below. The client is already gone from
	// m.connections, so no new Send can reach this connection.
	// NOTE(review): relies on Close unblocking pending writes, which
	// gorilla/websocket does; confirm for other WebSocketConn types.
	_ = conn.conn.Close() // best effort; the connection is discarded either way

	// Snapshot the topic set under conn.mu. Subscribe and Unsubscribe mutate
	// conn.topics under that mutex, so ranging over the map without it is a
	// concurrent map iteration and write.
	conn.mu.Lock()
	topics := make([]string, 0, len(conn.topics))
	for topic := range conn.topics {
		topics = append(topics, topic)
	}
	conn.mu.Unlock()

	// Remove from all subscriptions, dropping topics that become empty.
	m.subscriptionsMu.Lock()
	for _, topic := range topics {
		if clients, ok := m.subscriptions[topic]; ok {
			delete(clients, clientID)
			if len(clients) == 0 {
				delete(m.subscriptions, topic)
			}
		}
	}
	m.subscriptionsMu.Unlock()

	// Fire disconnect hooks (ephemeral-state auto-clear, bugboard #710).
	// The slice header is copied under the lock so hooks run without it.
	m.disconnectHooksMu.RLock()
	hooks := m.disconnectHooks
	m.disconnectHooksMu.RUnlock()
	for _, hook := range hooks {
		hook(clientID)
	}

	m.logger.Debug("Unregistered WebSocket connection",
		zap.String("client_id", clientID),
		zap.Int("remaining_connections", m.GetConnectionCount()),
	)
}
// Send writes data as a single text message to the connection registered
// for clientID.
//
// It returns ErrWSClientNotFound when no such connection exists, or the
// write error (after logging it as a warning) when the write fails. On
// success the connection's outbound frame and byte counters and its
// last-active timestamp are updated. Writes to one connection are serialized
// by the connection's mutex.
func (m *WSManager) Send(clientID string, data []byte) error {
	m.connectionsMu.RLock()
	target, found := m.connections[clientID]
	m.connectionsMu.RUnlock()
	if !found {
		return ErrWSClientNotFound
	}

	target.mu.Lock()
	defer target.mu.Unlock()

	writeErr := target.conn.WriteMessage(websocket.TextMessage, data)
	if writeErr != nil {
		m.logger.Warn("Failed to send WebSocket message",
			zap.String("client_id", clientID),
			zap.Error(writeErr),
		)
		return writeErr
	}

	// Metrics are only recorded for writes that succeeded.
	target.framesOut.Add(1)
	target.bytesOut.Add(int64(len(data)))
	target.lastActiveAt.Store(time.Now().UnixNano())
	return nil
}
// RecordInbound accounts for one inbound frame of byteLen bytes on the
// connection registered for clientID and refreshes its last-active
// timestamp. The WS handler calls it for every frame it reads, which keeps
// per-connection metrics accurate without the manager owning the read loop.
// Unknown client IDs are ignored.
func (m *WSManager) RecordInbound(clientID string, byteLen int) {
	m.connectionsMu.RLock()
	wc, found := m.connections[clientID]
	m.connectionsMu.RUnlock()

	if found {
		wc.framesIn.Add(1)
		wc.bytesIn.Add(int64(byteLen))
		wc.lastActiveAt.Store(time.Now().UnixNano())
	}
}
// ConnStats is the per-connection metrics snapshot produced by
// snapshotConn and returned from GetConnStats / ListConnStats. Counter
// fields are copies taken at snapshot time; they do not update afterwards.
type ConnStats struct {
ClientID string `json:"client_id"`
FramesIn int64 `json:"frames_in"` // frames recorded via RecordInbound
FramesOut int64 `json:"frames_out"` // frames written successfully by Send
BytesIn int64 `json:"bytes_in"` // bytes recorded via RecordInbound
BytesOut int64 `json:"bytes_out"` // payload bytes written successfully by Send
ConnectedAt int64 `json:"connected_at"` // unix seconds
LastActiveAt int64 `json:"last_active_at"` // unix seconds
DurationSec int64 `json:"duration_seconds"` // server-now - connected_at
TopicsCount int `json:"topics_count"` // size of the connection's topic set
}
// GetConnStats looks up clientID and returns a metrics snapshot for its
// connection together with true. When the client is not registered it
// returns (nil, false).
func (m *WSManager) GetConnStats(clientID string) (*ConnStats, bool) {
	m.connectionsMu.RLock()
	wc, found := m.connections[clientID]
	m.connectionsMu.RUnlock()

	if found {
		return snapshotConn(wc), true
	}
	return nil, false
}
// ListConnStats returns metrics snapshots for all active clients. The result
// is always non-nil and is in no particular order (map iteration order).
//
// The set of connections is copied under connectionsMu and the snapshots are
// taken after the lock is released: snapshotConn acquires each connection's
// mutex, which Send holds for the duration of a write, and waiting on that
// while holding connectionsMu would let one slow client stall the manager.
func (m *WSManager) ListConnStats() []ConnStats {
	m.connectionsMu.RLock()
	conns := make([]*wsConnection, 0, len(m.connections))
	for _, c := range m.connections {
		conns = append(conns, c)
	}
	m.connectionsMu.RUnlock()

	out := make([]ConnStats, 0, len(conns))
	for _, c := range conns {
		out = append(out, *snapshotConn(c))
	}
	return out
}

// snapshotConn builds a ConnStats copy of c's current counters.
//
// Timestamps are reported in unix seconds: lastActiveAt is stored in
// nanoseconds and is divided down here. DurationSec is the difference between
// the current time and connectedAt, both truncated to whole seconds.
func snapshotConn(c *wsConnection) *ConnStats {
	// topics is mutated by Subscribe/Unsubscribe under c.mu, so even reading
	// its length needs the mutex to avoid a data race.
	c.mu.Lock()
	topicsCount := len(c.topics)
	c.mu.Unlock()

	now := time.Now()
	connSec := c.connectedAt.Unix()
	return &ConnStats{
		ClientID:     c.clientID,
		FramesIn:     c.framesIn.Load(),
		FramesOut:    c.framesOut.Load(),
		BytesIn:      c.bytesIn.Load(),
		BytesOut:     c.bytesOut.Load(),
		ConnectedAt:  connSec,
		LastActiveAt: c.lastActiveAt.Load() / int64(time.Second),
		DurationSec:  now.Unix() - connSec,
		TopicsCount:  topicsCount,
	}
}
// Broadcast delivers data to every client currently subscribed to topic.
//
// The subscriber IDs are copied while holding the subscriptions lock and the
// sends happen after it is released. Individual send failures are counted
// and reported in the debug log only; the returned error is always nil,
// including when the topic has no subscribers.
func (m *WSManager) Broadcast(topic string, data []byte) error {
	m.subscriptionsMu.RLock()
	subscribers := m.subscriptions[topic]
	if len(subscribers) == 0 {
		// Unknown topic or empty set: nothing to deliver, not an error.
		m.subscriptionsMu.RUnlock()
		return nil
	}
	recipients := make([]string, 0, len(subscribers))
	for id := range subscribers {
		recipients = append(recipients, id)
	}
	m.subscriptionsMu.RUnlock()

	failed := 0
	for _, id := range recipients {
		if m.Send(id, data) != nil {
			failed++
		}
	}

	m.logger.Debug("Broadcast message",
		zap.String("topic", topic),
		zap.Int("recipients", len(recipients)),
		zap.Int("errors", failed),
	)
	return nil
}
// Subscribe records that clientID wants messages for topic. The topic is
// added to the connection's own topic set and the client is added to the
// manager's per-topic index, creating the index entry on first use.
// Nothing happens when clientID has no registered connection.
func (m *WSManager) Subscribe(clientID, topic string) {
	m.connectionsMu.RLock()
	wc, known := m.connections[clientID]
	m.connectionsMu.RUnlock()
	if !known {
		return
	}

	// Per-connection view of the subscription.
	wc.mu.Lock()
	wc.topics[topic] = struct{}{}
	wc.mu.Unlock()

	// Per-topic view of the subscription.
	func() {
		m.subscriptionsMu.Lock()
		defer m.subscriptionsMu.Unlock()
		members := m.subscriptions[topic]
		if members == nil {
			members = make(map[string]struct{})
			m.subscriptions[topic] = members
		}
		members[clientID] = struct{}{}
	}()

	m.logger.Debug("Client subscribed to topic",
		zap.String("client_id", clientID),
		zap.String("topic", topic),
	)
}
// Unsubscribe drops clientID's subscription to topic. The topic is removed
// from the connection's own topic set when the client is still registered,
// and the client is removed from the per-topic index regardless; a topic
// whose index entry becomes empty is deleted. The debug log line is emitted
// even if nothing was removed.
func (m *WSManager) Unsubscribe(clientID, topic string) {
	m.connectionsMu.RLock()
	wc, known := m.connections[clientID]
	m.connectionsMu.RUnlock()

	// Per-connection view of the subscription.
	if known {
		wc.mu.Lock()
		delete(wc.topics, topic)
		wc.mu.Unlock()
	}

	// Per-topic view of the subscription.
	func() {
		m.subscriptionsMu.Lock()
		defer m.subscriptionsMu.Unlock()
		members, tracked := m.subscriptions[topic]
		if !tracked {
			return
		}
		delete(members, clientID)
		if len(members) == 0 {
			delete(m.subscriptions, topic)
		}
	}()

	m.logger.Debug("Client unsubscribed from topic",
		zap.String("client_id", clientID),
		zap.String("topic", topic),
	)
}
// GetConnectionCount reports how many connections are currently registered.
func (m *WSManager) GetConnectionCount() int {
	m.connectionsMu.RLock()
	count := len(m.connections)
	m.connectionsMu.RUnlock()
	return count
}
// GetTopicSubscriberCount reports how many clients are subscribed to topic.
// A topic with no index entry yields 0.
func (m *WSManager) GetTopicSubscriberCount(topic string) int {
	m.subscriptionsMu.RLock()
	// len of a missing (nil) entry is 0, which covers the unknown-topic case.
	count := len(m.subscriptions[topic])
	m.subscriptionsMu.RUnlock()
	return count
}
// GetClientTopics lists the topics recorded on clientID's connection, in no
// particular order. It returns nil when the client is not registered and a
// non-nil (possibly empty) slice otherwise.
func (m *WSManager) GetClientTopics(clientID string) []string {
	m.connectionsMu.RLock()
	wc, found := m.connections[clientID]
	m.connectionsMu.RUnlock()
	if !found {
		return nil
	}

	wc.mu.Lock()
	defer wc.mu.Unlock()

	names := make([]string, len(wc.topics))
	i := 0
	for name := range wc.topics {
		names[i] = name
		i++
	}
	return names
}
// IsConnected reports whether a connection is currently registered for
// clientID.
func (m *WSManager) IsConnected(clientID string) bool {
	m.connectionsMu.RLock()
	_, found := m.connections[clientID]
	m.connectionsMu.RUnlock()
	return found
}
// Close shuts the manager down: every registered connection is closed and
// removed, and the topic index is replaced with an empty one. Close errors
// from individual connections are ignored. Disconnect hooks are not invoked
// on this path.
func (m *WSManager) Close() {
	m.connectionsMu.Lock()
	defer m.connectionsMu.Unlock()

	for id, wc := range m.connections {
		delete(m.connections, id)
		_ = wc.conn.Close()
	}

	// The subscriptions lock is taken while connectionsMu is still held.
	func() {
		m.subscriptionsMu.Lock()
		defer m.subscriptionsMu.Unlock()
		m.subscriptions = make(map[string]map[string]struct{})
	}()

	m.logger.Info("WebSocket manager closed")
}
// WSStats is the manager-wide statistics snapshot returned by GetStats.
type WSStats struct {
ConnectionCount int `json:"connection_count"` // registered connections
TopicCount int `json:"topic_count"` // topics with an index entry
SubscriptionCount int `json:"subscription_count"` // sum of subscriber counts over all topics
TopicStats map[string]int `json:"topic_stats"` // topic -> subscriber count
}
// GetStats builds a snapshot of the manager's current size: the number of
// registered connections, the number of topics, the total number of
// (topic, client) subscriptions, and the subscriber count per topic.
// The connection count and the topic figures are read under separate locks,
// so they are not captured at a single instant.
func (m *WSManager) GetStats() *WSStats {
	stats := &WSStats{}

	m.connectionsMu.RLock()
	stats.ConnectionCount = len(m.connections)
	m.connectionsMu.RUnlock()

	m.subscriptionsMu.RLock()
	defer m.subscriptionsMu.RUnlock()

	stats.TopicCount = len(m.subscriptions)
	stats.TopicStats = make(map[string]int, stats.TopicCount)
	for name, members := range m.subscriptions {
		size := len(members)
		stats.TopicStats[name] = size
		stats.SubscriptionCount += size
	}
	return stats
}