diff --git a/cmd/gateway/config.go b/cmd/gateway/config.go index 639a84b..9e720bb 100644 --- a/cmd/gateway/config.go +++ b/cmd/gateway/config.go @@ -73,6 +73,27 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config { } // Load YAML + type yamlICEServer struct { + URLs []string `yaml:"urls"` + Username string `yaml:"username,omitempty"` + Credential string `yaml:"credential,omitempty"` + } + + type yamlTURN struct { + SharedSecret string `yaml:"shared_secret"` + TTL string `yaml:"ttl"` + ExternalHost string `yaml:"external_host"` + STUNURLs []string `yaml:"stun_urls"` + TURNURLs []string `yaml:"turn_urls"` + } + + type yamlSFU struct { + Enabled bool `yaml:"enabled"` + MaxParticipants int `yaml:"max_participants"` + MediaTimeout string `yaml:"media_timeout"` + ICEServers []yamlICEServer `yaml:"ice_servers"` + } + type yamlCfg struct { ListenAddr string `yaml:"listen_addr"` ClientNamespace string `yaml:"client_namespace"` @@ -87,6 +108,8 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config { IPFSAPIURL string `yaml:"ipfs_api_url"` IPFSTimeout string `yaml:"ipfs_timeout"` IPFSReplicationFactor int `yaml:"ipfs_replication_factor"` + TURN yamlTURN `yaml:"turn"` + SFU yamlSFU `yaml:"sfu"` } data, err := os.ReadFile(configPath) @@ -191,6 +214,64 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config { cfg.IPFSReplicationFactor = y.IPFSReplicationFactor } + // TURN configuration + if y.TURN.SharedSecret != "" || len(y.TURN.STUNURLs) > 0 || len(y.TURN.TURNURLs) > 0 { + turnCfg := &config.TURNConfig{ + SharedSecret: y.TURN.SharedSecret, + ExternalHost: y.TURN.ExternalHost, + STUNURLs: y.TURN.STUNURLs, + TURNURLs: y.TURN.TURNURLs, + } + // Check for environment variable overrides + if envSecret := os.Getenv("TURN_SHARED_SECRET"); envSecret != "" { + turnCfg.SharedSecret = envSecret + } + if envHost := os.Getenv("TURN_EXTERNAL_HOST"); envHost != "" { + turnCfg.ExternalHost = envHost + } + if v := 
strings.TrimSpace(y.TURN.TTL); v != "" { + if parsed, err := time.ParseDuration(v); err == nil { + turnCfg.TTL = parsed + } else { + logger.ComponentWarn(logging.ComponentGeneral, "invalid turn.ttl, using default", zap.String("value", v), zap.Error(err)) + } + } + cfg.TURN = turnCfg + logger.ComponentInfo(logging.ComponentGeneral, "TURN configuration loaded", + zap.Int("stun_urls", len(turnCfg.STUNURLs)), + zap.Int("turn_urls", len(turnCfg.TURNURLs)), + zap.String("external_host", turnCfg.ExternalHost), + ) + } + + // SFU configuration + if y.SFU.Enabled { + sfuCfg := &config.SFUConfig{ + Enabled: true, + MaxParticipants: y.SFU.MaxParticipants, + } + if v := strings.TrimSpace(y.SFU.MediaTimeout); v != "" { + if parsed, err := time.ParseDuration(v); err == nil { + sfuCfg.MediaTimeout = parsed + } else { + logger.ComponentWarn(logging.ComponentGeneral, "invalid sfu.media_timeout, using default", zap.String("value", v), zap.Error(err)) + } + } + // Parse ICE servers + for _, iceServer := range y.SFU.ICEServers { + sfuCfg.ICEServers = append(sfuCfg.ICEServers, config.ICEServerConfig{ + URLs: iceServer.URLs, + Username: iceServer.Username, + Credential: iceServer.Credential, + }) + } + cfg.SFU = sfuCfg + logger.ComponentInfo(logging.ComponentGeneral, "SFU configuration loaded", + zap.Int("max_participants", sfuCfg.MaxParticipants), + zap.Int("ice_servers", len(sfuCfg.ICEServers)), + ) + } + // Validate configuration if errs := cfg.ValidateConfig(); len(errs) > 0 { fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs)) diff --git a/go.mod b/go.mod index 977bb54..6a83b2b 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,10 @@ require ( github.com/mattn/go-sqlite3 v1.14.32 github.com/multiformats/go-multiaddr v0.15.0 github.com/olric-data/olric v0.7.0 + github.com/pion/interceptor v0.1.37 + github.com/pion/rtcp v1.2.15 + github.com/pion/turn/v4 v4.0.0 + github.com/pion/webrtc/v4 v4.0.10 github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8 
github.com/tetratelabs/wazero v1.11.0 go.uber.org/zap v1.27.0 @@ -113,11 +117,9 @@ require ( github.com/pion/dtls/v2 v2.2.12 // indirect github.com/pion/dtls/v3 v3.0.4 // indirect github.com/pion/ice/v4 v4.0.8 // indirect - github.com/pion/interceptor v0.1.37 // indirect github.com/pion/logging v0.2.3 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/randutil v0.1.0 // indirect - github.com/pion/rtcp v1.2.15 // indirect github.com/pion/rtp v1.8.11 // indirect github.com/pion/sctp v1.8.37 // indirect github.com/pion/sdp/v3 v3.0.10 // indirect @@ -126,8 +128,6 @@ require ( github.com/pion/stun/v3 v3.0.0 // indirect github.com/pion/transport/v2 v2.2.10 // indirect github.com/pion/transport/v3 v3.0.7 // indirect - github.com/pion/turn/v4 v4.0.0 // indirect - github.com/pion/webrtc/v4 v4.0.10 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/prometheus/client_golang v1.22.0 // indirect github.com/prometheus/client_model v0.6.2 // indirect diff --git a/pkg/config/config.go b/pkg/config/config.go index e1881d3..e678e51 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -3,7 +3,6 @@ package config import ( "time" - "github.com/DeBrosOfficial/network/pkg/config/validate" "github.com/multiformats/go-multiaddr" ) @@ -15,69 +14,248 @@ type Config struct { Security SecurityConfig `yaml:"security"` Logging LoggingConfig `yaml:"logging"` HTTPGateway HTTPGatewayConfig `yaml:"http_gateway"` + TURNServer TURNServerConfig `yaml:"turn_server"` // Built-in TURN server } -// ValidationError represents a single validation error with context. -// This is exported from the validate subpackage for backward compatibility. -type ValidationError = validate.ValidationError - -// ValidateSwarmKey validates that a swarm key is 64 hex characters. -// This is exported from the validate subpackage for backward compatibility. 
-func ValidateSwarmKey(key string) error { - return validate.ValidateSwarmKey(key) +// NodeConfig contains node-specific configuration +type NodeConfig struct { + ID string `yaml:"id"` // Auto-generated if empty + ListenAddresses []string `yaml:"listen_addresses"` // LibP2P listen addresses + DataDir string `yaml:"data_dir"` // Data directory + MaxConnections int `yaml:"max_connections"` // Maximum peer connections + Domain string `yaml:"domain"` // Domain for this node (e.g., node-1.orama.network) } -// Validate performs comprehensive validation of the entire config. -// It aggregates all errors and returns them, allowing the caller to print all issues at once. -func (c *Config) Validate() []error { - var errs []error +// DatabaseConfig contains database-related configuration +type DatabaseConfig struct { + DataDir string `yaml:"data_dir"` + ReplicationFactor int `yaml:"replication_factor"` + ShardCount int `yaml:"shard_count"` + MaxDatabaseSize int64 `yaml:"max_database_size"` // In bytes + BackupInterval time.Duration `yaml:"backup_interval"` - // Validate node config - errs = append(errs, validate.ValidateNode(validate.NodeConfig{ - ID: c.Node.ID, - ListenAddresses: c.Node.ListenAddresses, - DataDir: c.Node.DataDir, - MaxConnections: c.Node.MaxConnections, - })...) 
+ // RQLite-specific configuration + RQLitePort int `yaml:"rqlite_port"` // RQLite HTTP API port + RQLiteRaftPort int `yaml:"rqlite_raft_port"` // RQLite Raft consensus port + RQLiteJoinAddress string `yaml:"rqlite_join_address"` // Address to join RQLite cluster - // Validate database config - errs = append(errs, validate.ValidateDatabase(validate.DatabaseConfig{ - DataDir: c.Database.DataDir, - ReplicationFactor: c.Database.ReplicationFactor, - ShardCount: c.Database.ShardCount, - MaxDatabaseSize: c.Database.MaxDatabaseSize, - RQLitePort: c.Database.RQLitePort, - RQLiteRaftPort: c.Database.RQLiteRaftPort, - RQLiteJoinAddress: c.Database.RQLiteJoinAddress, - ClusterSyncInterval: c.Database.ClusterSyncInterval, - PeerInactivityLimit: c.Database.PeerInactivityLimit, - MinClusterSize: c.Database.MinClusterSize, - })...) + // RQLite node-to-node TLS encryption (for inter-node Raft communication) + // See: https://rqlite.io/docs/guides/security/#encrypting-node-to-node-communication + NodeCert string `yaml:"node_cert"` // Path to X.509 certificate for node-to-node communication + NodeKey string `yaml:"node_key"` // Path to X.509 private key for node-to-node communication + NodeCACert string `yaml:"node_ca_cert"` // Path to CA certificate (optional, uses system CA if not set) + NodeNoVerify bool `yaml:"node_no_verify"` // Skip certificate verification (for testing/self-signed certs) - // Validate discovery config - errs = append(errs, validate.ValidateDiscovery(validate.DiscoveryConfig{ - BootstrapPeers: c.Discovery.BootstrapPeers, - DiscoveryInterval: c.Discovery.DiscoveryInterval, - BootstrapPort: c.Discovery.BootstrapPort, - HttpAdvAddress: c.Discovery.HttpAdvAddress, - RaftAdvAddress: c.Discovery.RaftAdvAddress, - })...) 
+ // Dynamic discovery configuration (always enabled) + ClusterSyncInterval time.Duration `yaml:"cluster_sync_interval"` // default: 30s + PeerInactivityLimit time.Duration `yaml:"peer_inactivity_limit"` // default: 24h + MinClusterSize int `yaml:"min_cluster_size"` // default: 1 - // Validate security config - errs = append(errs, validate.ValidateSecurity(validate.SecurityConfig{ - EnableTLS: c.Security.EnableTLS, - PrivateKeyFile: c.Security.PrivateKeyFile, - CertificateFile: c.Security.CertificateFile, - })...) + // Olric cache configuration + OlricHTTPPort int `yaml:"olric_http_port"` // Olric HTTP API port (default: 3320) + OlricMemberlistPort int `yaml:"olric_memberlist_port"` // Olric memberlist port (default: 3322) - // Validate logging config - errs = append(errs, validate.ValidateLogging(validate.LoggingConfig{ - Level: c.Logging.Level, - Format: c.Logging.Format, - OutputFile: c.Logging.OutputFile, - })...) + // IPFS storage configuration + IPFS IPFSConfig `yaml:"ipfs"` +} - return errs +// IPFSConfig contains IPFS storage configuration +type IPFSConfig struct { + // ClusterAPIURL is the IPFS Cluster HTTP API URL (e.g., "http://localhost:9094") + // If empty, IPFS storage is disabled for this node + ClusterAPIURL string `yaml:"cluster_api_url"` + + // APIURL is the IPFS HTTP API URL for content retrieval (e.g., "http://localhost:5001") + // If empty, defaults to "http://localhost:5001" + APIURL string `yaml:"api_url"` + + // Timeout for IPFS operations + // If zero, defaults to 60 seconds + Timeout time.Duration `yaml:"timeout"` + + // ReplicationFactor is the replication factor for pinned content + // If zero, defaults to 3 + ReplicationFactor int `yaml:"replication_factor"` + + // EnableEncryption enables client-side encryption before upload + // Defaults to true + EnableEncryption bool `yaml:"enable_encryption"` +} + +// DiscoveryConfig contains peer discovery configuration +type DiscoveryConfig struct { + BootstrapPeers []string 
`yaml:"bootstrap_peers"` // Peer addresses to connect to + DiscoveryInterval time.Duration `yaml:"discovery_interval"` // Discovery announcement interval + BootstrapPort int `yaml:"bootstrap_port"` // Default port for peer discovery + HttpAdvAddress string `yaml:"http_adv_address"` // HTTP advertisement address + RaftAdvAddress string `yaml:"raft_adv_address"` // Raft advertisement + NodeNamespace string `yaml:"node_namespace"` // Namespace for node identifiers +} + +// SecurityConfig contains security-related configuration +type SecurityConfig struct { + EnableTLS bool `yaml:"enable_tls"` + PrivateKeyFile string `yaml:"private_key_file"` + CertificateFile string `yaml:"certificate_file"` +} + +// LoggingConfig contains logging configuration +type LoggingConfig struct { + Level string `yaml:"level"` // debug, info, warn, error + Format string `yaml:"format"` // json, console + OutputFile string `yaml:"output_file"` // Empty for stdout +} + +// HTTPGatewayConfig contains HTTP reverse proxy gateway configuration +type HTTPGatewayConfig struct { + Enabled bool `yaml:"enabled"` // Enable HTTP gateway + ListenAddr string `yaml:"listen_addr"` // Address to listen on (e.g., ":8080") + NodeName string `yaml:"node_name"` // Node name for routing + Routes map[string]RouteConfig `yaml:"routes"` // Service routes + HTTPS HTTPSConfig `yaml:"https"` // HTTPS/TLS configuration + SNI SNIConfig `yaml:"sni"` // SNI-based TCP routing configuration + + // Full gateway configuration (for API, auth, pubsub) + ClientNamespace string `yaml:"client_namespace"` // Namespace for network client + RQLiteDSN string `yaml:"rqlite_dsn"` // RQLite database DSN + OlricServers []string `yaml:"olric_servers"` // List of Olric server addresses + OlricTimeout time.Duration `yaml:"olric_timeout"` // Timeout for Olric operations + IPFSClusterAPIURL string `yaml:"ipfs_cluster_api_url"` // IPFS Cluster API URL + IPFSAPIURL string `yaml:"ipfs_api_url"` // IPFS API URL + IPFSTimeout time.Duration 
`yaml:"ipfs_timeout"` // Timeout for IPFS operations + + // WebRTC configuration for video/audio calls + TURN *TURNConfig `yaml:"turn"` // TURN/STUN server configuration + SFU *SFUConfig `yaml:"sfu"` // SFU (Selective Forwarding Unit) configuration +} + +// HTTPSConfig contains HTTPS/TLS configuration for the gateway +type HTTPSConfig struct { + Enabled bool `yaml:"enabled"` // Enable HTTPS (port 443) + Domain string `yaml:"domain"` // Primary domain (e.g., node-123.orama.network) + AutoCert bool `yaml:"auto_cert"` // Use Let's Encrypt for automatic certificate + UseSelfSigned bool `yaml:"use_self_signed"` // Use self-signed certificates (pre-generated) + CertFile string `yaml:"cert_file"` // Path to certificate file (if not using auto_cert) + KeyFile string `yaml:"key_file"` // Path to key file (if not using auto_cert) + CacheDir string `yaml:"cache_dir"` // Directory for Let's Encrypt certificate cache + HTTPPort int `yaml:"http_port"` // HTTP port for ACME challenge (default: 80) + HTTPSPort int `yaml:"https_port"` // HTTPS port (default: 443) + Email string `yaml:"email"` // Email for Let's Encrypt account +} + +// SNIConfig contains SNI-based TCP routing configuration for port 7001 +type SNIConfig struct { + Enabled bool `yaml:"enabled"` // Enable SNI-based TCP routing + ListenAddr string `yaml:"listen_addr"` // Address to listen on (e.g., ":7001") + Routes map[string]string `yaml:"routes"` // SNI hostname -> backend address mapping + CertFile string `yaml:"cert_file"` // Path to certificate file + KeyFile string `yaml:"key_file"` // Path to key file +} + +// RouteConfig defines a single reverse proxy route +type RouteConfig struct { + PathPrefix string `yaml:"path_prefix"` // URL path prefix (e.g., "/rqlite/http") + BackendURL string `yaml:"backend_url"` // Backend service URL + Timeout time.Duration `yaml:"timeout"` // Request timeout + WebSocket bool `yaml:"websocket"` // Support WebSocket upgrades +} + +// ClientConfig represents configuration for network 
clients +type ClientConfig struct { + AppName string `yaml:"app_name"` + DatabaseName string `yaml:"database_name"` + BootstrapPeers []string `yaml:"bootstrap_peers"` + ConnectTimeout time.Duration `yaml:"connect_timeout"` + RetryAttempts int `yaml:"retry_attempts"` +} + +// TURNConfig contains TURN/STUN server credential configuration +type TURNConfig struct { + // SharedSecret is the shared secret for TURN credential generation (HMAC-SHA1) + // Should be set via TURN_SHARED_SECRET environment variable + SharedSecret string `yaml:"shared_secret"` + + // TTL is the time-to-live for generated credentials + // Default: 24 hours + TTL time.Duration `yaml:"ttl"` + + // ExternalHost is the external hostname or IP address for STUN/TURN URLs + // - Production: Set to your public domain (e.g., "turn.example.com") + // - Development: Leave empty for auto-detection of LAN IP + // Can also be set via TURN_EXTERNAL_HOST environment variable + ExternalHost string `yaml:"external_host"` + + // STUNURLs are the STUN server URLs to return to clients + // Use "::" as placeholder for ExternalHost (e.g., "stun:::3478" -> "stun:turn.example.com:3478") + // e.g., ["stun:::3478"] or ["stun:gateway.orama.com:3478"] + STUNURLs []string `yaml:"stun_urls"` + + // TURNURLs are the TURN server URLs to return to clients + // Use "::" as placeholder for ExternalHost (e.g., "turn:::3478" -> "turn:turn.example.com:3478") + // e.g., ["turn:::3478?transport=udp"] or ["turn:gateway.orama.com:3478?transport=udp"] + TURNURLs []string `yaml:"turn_urls"` + + // TLSEnabled indicates whether TURNS (TURN over TLS) is available + // When true, turns:// URLs will be included in the response + TLSEnabled bool `yaml:"tls_enabled"` +} + +// SFUConfig contains WebRTC SFU (Selective Forwarding Unit) configuration +type SFUConfig struct { + // Enabled enables the SFU service + Enabled bool `yaml:"enabled"` + + // MaxParticipants is the maximum number of participants per room + // Default: 10 + MaxParticipants int 
`yaml:"max_participants"` + + // MediaTimeout is the timeout for media operations + // Default: 30 seconds + MediaTimeout time.Duration `yaml:"media_timeout"` + + // ICEServers are additional ICE servers for WebRTC connections + // These are used in addition to the TURN servers from TURNConfig + ICEServers []ICEServerConfig `yaml:"ice_servers"` +} + +// ICEServerConfig represents a single ICE server configuration +type ICEServerConfig struct { + URLs []string `yaml:"urls"` + Username string `yaml:"username,omitempty"` + Credential string `yaml:"credential,omitempty"` +} + +// TURNServerConfig contains built-in TURN server configuration +type TURNServerConfig struct { + // Enabled enables the built-in TURN server + Enabled bool `yaml:"enabled"` + + // ListenAddr is the UDP address to listen on (e.g., "0.0.0.0:3478") + ListenAddr string `yaml:"listen_addr"` + + // PublicIP is the public IP address to advertise for relay + // If empty, will try to auto-detect + PublicIP string `yaml:"public_ip"` + + // Realm is the TURN realm (e.g., "orama.network") + Realm string `yaml:"realm"` + + // MinPort and MaxPort define the relay port range + MinPort uint16 `yaml:"min_port"` + MaxPort uint16 `yaml:"max_port"` + + // TLS Configuration for TURNS (TURN over TLS) + // TLSEnabled enables TURNS listener + TLSEnabled bool `yaml:"tls_enabled"` + + // TLSListenAddr is the TCP/TLS address to listen on (e.g., "0.0.0.0:443") + TLSListenAddr string `yaml:"tls_listen_addr"` + + // TLSCertFile is the path to the TLS certificate file + TLSCertFile string `yaml:"tls_cert_file"` + + // TLSKeyFile is the path to the TLS private key file + TLSKeyFile string `yaml:"tls_key_file"` } // ParseMultiaddrs converts string addresses to multiaddr objects diff --git a/pkg/environments/templates/node.yaml b/pkg/environments/templates/node.yaml index 2024f5c..58524de 100644 --- a/pkg/environments/templates/node.yaml +++ b/pkg/environments/templates/node.yaml @@ -47,6 +47,15 @@ logging: level: "info" format: 
"console" +# Built-in TURN server for WebRTC NAT traversal +turn_server: + enabled: true + listen_addr: "0.0.0.0:3478" + public_ip: "" # Auto-detect if empty, or set to your public IP + realm: "orama.network" + min_port: 49152 + max_port: 65535 + http_gateway: enabled: true listen_addr: "{{if .EnableHTTPS}}:{{.HTTPSPort}}{{else}}:{{.UnifiedGatewayPort}}{{end}}" @@ -86,3 +95,19 @@ http_gateway: # Routes for internal service reverse proxy (kept for backwards compatibility but not used by full gateway) routes: {} + + # TURN/STUN URLs returned to clients (points to built-in TURN server) + turn: + shared_secret: "dev-secret-12345" # INSECURE dev-only default; override via TURN_SHARED_SECRET env var in production + ttl: "24h" + stun_urls: + - "stun:::3478" + turn_urls: + - "turn:::3478?transport=udp" + + # SFU (Selective Forwarding Unit) configuration for WebRTC group calls + sfu: + enabled: true + max_participants: 10 + media_timeout: "30s" + ice_servers: [] # Additional ICE servers beyond TURN config (optional) diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index fce6bac..41a3e9c 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -1,31 +1,75 @@ -// Package gateway provides the main API Gateway for the Orama Network. -// It orchestrates traffic between clients and various backend services including -// distributed caching (Olric), decentralized storage (IPFS), and serverless -// WebAssembly (WASM) execution. The gateway implements robust security through -// wallet-based cryptographic authentication and JWT lifecycle management.
package gateway import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "database/sql" + "encoding/pem" + "fmt" + "net" + "os" + "path/filepath" + "strings" "sync" "time" "github.com/DeBrosOfficial/network/pkg/client" + "github.com/DeBrosOfficial/network/pkg/config" "github.com/DeBrosOfficial/network/pkg/gateway/auth" - authhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/auth" - "github.com/DeBrosOfficial/network/pkg/gateway/handlers/cache" - pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub" - serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless" - "github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage" "github.com/DeBrosOfficial/network/pkg/ipfs" "github.com/DeBrosOfficial/network/pkg/logging" "github.com/DeBrosOfficial/network/pkg/olric" + "github.com/DeBrosOfficial/network/pkg/pubsub" "github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/serverless" + "github.com/multiformats/go-multiaddr" + olriclib "github.com/olric-data/olric" "go.uber.org/zap" + + _ "github.com/rqlite/gorqlite/stdlib" ) +const ( + olricInitMaxAttempts = 5 + olricInitInitialBackoff = 500 * time.Millisecond + olricInitMaxBackoff = 5 * time.Second +) + +// Config holds configuration for the gateway server +type Config struct { + ListenAddr string + ClientNamespace string + BootstrapPeers []string + NodePeerID string // The node's actual peer ID from its identity file + + // Optional DSN for rqlite database/sql driver, e.g. "http://localhost:4001" + // If empty, defaults to "http://localhost:4001". + RQLiteDSN string + + // HTTPS configuration + EnableHTTPS bool // Enable HTTPS with ACME (Let's Encrypt) + DomainName string // Domain name for HTTPS certificate + TLSCacheDir string // Directory to cache TLS certificates (default: ~/.orama/tls-cache) + + // Olric cache configuration + OlricServers []string // List of Olric server addresses (e.g., ["localhost:3320"]). 
If empty, defaults to ["localhost:3320"] + OlricTimeout time.Duration // Timeout for Olric operations (default: 10s) + + // IPFS Cluster configuration + IPFSClusterAPIURL string // IPFS Cluster HTTP API URL (e.g., "http://localhost:9094"). If empty, gateway will discover from node configs + IPFSAPIURL string // IPFS HTTP API URL for content retrieval (e.g., "http://localhost:5001"). If empty, gateway will discover from node configs + IPFSTimeout time.Duration // Timeout for IPFS operations (default: 60s) + IPFSReplicationFactor int // Replication factor for pins (default: 3) + IPFSEnableEncryption bool // Enable client-side encryption before upload (default: true, discovered from node configs) + + // TURN/STUN configuration for WebRTC + TURN *config.TURNConfig + + // SFU configuration for WebRTC group calls + SFU *config.SFUConfig +} type Gateway struct { logger *logging.ColoredLogger @@ -42,29 +86,28 @@ type Gateway struct { // Olric cache client olricClient *olric.Client olricMu sync.RWMutex - cacheHandlers *cache.CacheHandlers // IPFS storage client - ipfsClient ipfs.IPFSClient - storageHandlers *storage.Handlers + ipfsClient ipfs.IPFSClient // Local pub/sub bypass for same-gateway subscribers localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers presenceMembers map[string][]PresenceMember // topicKey -> members mu sync.RWMutex presenceMu sync.RWMutex - pubsubHandlers *pubsubhandlers.PubSubHandlers // Serverless function engine serverlessEngine *serverless.Engine serverlessRegistry *serverless.Registry serverlessInvoker *serverless.Invoker serverlessWSMgr *serverless.WSManager - serverlessHandlers *serverlesshandlers.ServerlessHandlers + serverlessHandlers *ServerlessHandlers // Authentication service - authService *auth.Service - authHandlers *authhandlers.Handlers + authService *auth.Service + + // SFU manager for WebRTC group calls + sfuManager *SFUManager } // localSubscriber represents a WebSocket subscriber for local message 
delivery @@ -81,113 +124,359 @@ type PresenceMember struct { ConnID string `json:"-"` // Internal: for tracking which connection } -// authClientAdapter adapts client.NetworkClient to authhandlers.NetworkClient -type authClientAdapter struct { - client client.NetworkClient -} - -func (a *authClientAdapter) Database() authhandlers.DatabaseClient { - return &authDatabaseAdapter{db: a.client.Database()} -} - -// authDatabaseAdapter adapts client.DatabaseClient to authhandlers.DatabaseClient -type authDatabaseAdapter struct { - db client.DatabaseClient -} - -func (a *authDatabaseAdapter) Query(ctx context.Context, sql string, args ...interface{}) (*authhandlers.QueryResult, error) { - result, err := a.db.Query(ctx, sql, args...) - if err != nil { - return nil, err - } - // Convert client.QueryResult to authhandlers.QueryResult - // The auth handlers expect []interface{} but client returns [][]interface{} - convertedRows := make([]interface{}, len(result.Rows)) - for i, row := range result.Rows { - convertedRows[i] = row - } - return &authhandlers.QueryResult{ - Count: int(result.Count), - Rows: convertedRows, - }, nil -} - -// New creates and initializes a new Gateway instance. -// It establishes all necessary service connections and dependencies. 
+// New creates and initializes a new Gateway instance func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) { - logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway dependencies...") + logger.ComponentInfo(logging.ComponentGeneral, "Building client config...") - // Initialize all dependencies (network client, database, cache, storage, serverless) - deps, err := NewDependencies(logger, cfg) + // Build client config from gateway cfg + cliCfg := client.DefaultClientConfig(cfg.ClientNamespace) + if len(cfg.BootstrapPeers) > 0 { + cliCfg.BootstrapPeers = cfg.BootstrapPeers + } + + logger.ComponentInfo(logging.ComponentGeneral, "Creating network client...") + c, err := client.NewClient(cliCfg) if err != nil { - logger.ComponentError(logging.ComponentGeneral, "failed to create dependencies", zap.Error(err)) + logger.ComponentError(logging.ComponentClient, "failed to create network client", zap.Error(err)) return nil, err } + logger.ComponentInfo(logging.ComponentGeneral, "Connecting network client...") + if err := c.Connect(); err != nil { + logger.ComponentError(logging.ComponentClient, "failed to connect network client", zap.Error(err)) + return nil, err + } + + logger.ComponentInfo(logging.ComponentClient, "Network client connected", + zap.String("namespace", cliCfg.AppName), + zap.Int("peer_count", len(cliCfg.BootstrapPeers)), + ) + logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway instance...") gw := &Gateway{ - logger: logger, - cfg: cfg, - client: deps.Client, - nodePeerID: cfg.NodePeerID, - startedAt: time.Now(), - sqlDB: deps.SQLDB, - ormClient: deps.ORMClient, - ormHTTP: deps.ORMHTTP, - olricClient: deps.OlricClient, - ipfsClient: deps.IPFSClient, - serverlessEngine: deps.ServerlessEngine, - serverlessRegistry: deps.ServerlessRegistry, - serverlessInvoker: deps.ServerlessInvoker, - serverlessWSMgr: deps.ServerlessWSMgr, - serverlessHandlers: deps.ServerlessHandlers, - authService: deps.AuthService, - localSubscribers: 
make(map[string][]*localSubscriber), - presenceMembers: make(map[string][]PresenceMember), + logger: logger, + cfg: cfg, + client: c, + nodePeerID: cfg.NodePeerID, + startedAt: time.Now(), + localSubscribers: make(map[string][]*localSubscriber), + presenceMembers: make(map[string][]PresenceMember), } - // Initialize handler instances - gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger) - - if deps.OlricClient != nil { - gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient) + logger.ComponentInfo(logging.ComponentGeneral, "Initializing RQLite ORM HTTP gateway...") + dsn := cfg.RQLiteDSN + if dsn == "" { + dsn = "http://localhost:4001" } + db, dbErr := sql.Open("rqlite", dsn) + if dbErr != nil { + logger.ComponentWarn(logging.ComponentGeneral, "failed to open rqlite sql db; http orm gateway disabled", zap.Error(dbErr)) + } else { + // Configure connection pool with proper timeouts and limits + db.SetMaxOpenConns(25) // Maximum number of open connections + db.SetMaxIdleConns(5) // Maximum number of idle connections + db.SetConnMaxLifetime(5 * time.Minute) // Maximum lifetime of a connection + db.SetConnMaxIdleTime(2 * time.Minute) // Maximum idle time before closing - if deps.IPFSClient != nil { - gw.storageHandlers = storage.New(deps.IPFSClient, logger, storage.Config{ - IPFSReplicationFactor: cfg.IPFSReplicationFactor, - IPFSAPIURL: cfg.IPFSAPIURL, - }) - } - - if deps.AuthService != nil { - // Create adapter for auth handlers to use the client - authClientAdapter := &authClientAdapter{client: deps.Client} - gw.authHandlers = authhandlers.NewHandlers( - logger, - deps.AuthService, - authClientAdapter, - cfg.ClientNamespace, - gw.withInternalAuth, + gw.sqlDB = db + orm := rqlite.NewClient(db) + gw.ormClient = orm + gw.ormHTTP = rqlite.NewHTTPGateway(orm, "/v1/db") + // Set a reasonable timeout for HTTP requests (30 seconds) + gw.ormHTTP.Timeout = 30 * time.Second + logger.ComponentInfo(logging.ComponentGeneral, "RQLite ORM HTTP 
gateway ready", + zap.String("dsn", dsn), + zap.String("base_path", "/v1/db"), + zap.Duration("timeout", gw.ormHTTP.Timeout), ) } - // Start background Olric reconnection if initial connection failed - if deps.OlricClient == nil { - olricCfg := olric.Config{ - Servers: cfg.OlricServers, - Timeout: cfg.OlricTimeout, + logger.ComponentInfo(logging.ComponentGeneral, "Initializing Olric cache client...") + + // Discover Olric servers dynamically from LibP2P peers if not explicitly configured + olricServers := cfg.OlricServers + if len(olricServers) == 0 { + logger.ComponentInfo(logging.ComponentGeneral, "Olric servers not configured, discovering from LibP2P peers...") + discovered := discoverOlricServers(c, logger.Logger) + if len(discovered) > 0 { + olricServers = discovered + logger.ComponentInfo(logging.ComponentGeneral, "Discovered Olric servers from LibP2P peers", + zap.Strings("servers", olricServers)) + } else { + // Fallback to localhost for local development + olricServers = []string{"localhost:3320"} + logger.ComponentInfo(logging.ComponentGeneral, "No Olric servers discovered, using localhost fallback") } - if len(olricCfg.Servers) == 0 { - olricCfg.Servers = []string{"localhost:3320"} - } - gw.startOlricReconnectLoop(olricCfg) + } else { + logger.ComponentInfo(logging.ComponentGeneral, "Using explicitly configured Olric servers", + zap.Strings("servers", olricServers)) } - logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed") + olricCfg := olric.Config{ + Servers: olricServers, + Timeout: cfg.OlricTimeout, + } + olricClient, olricErr := initializeOlricClientWithRetry(olricCfg, logger) + if olricErr != nil { + logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize Olric cache client; cache endpoints disabled", zap.Error(olricErr)) + gw.startOlricReconnectLoop(olricCfg) + } else { + gw.setOlricClient(olricClient) + logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client ready", + zap.Strings("servers", 
olricCfg.Servers), + zap.Duration("timeout", olricCfg.Timeout), + ) + } + + logger.ComponentInfo(logging.ComponentGeneral, "Initializing IPFS Cluster client...") + + // Discover IPFS endpoints from node configs if not explicitly configured + ipfsClusterURL := cfg.IPFSClusterAPIURL + ipfsAPIURL := cfg.IPFSAPIURL + ipfsTimeout := cfg.IPFSTimeout + ipfsReplicationFactor := cfg.IPFSReplicationFactor + ipfsEnableEncryption := cfg.IPFSEnableEncryption + + if ipfsClusterURL == "" { + logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster URL not configured, discovering from node configs...") + discovered := discoverIPFSFromNodeConfigs(logger.Logger) + if discovered.clusterURL != "" { + ipfsClusterURL = discovered.clusterURL + ipfsAPIURL = discovered.apiURL + if discovered.timeout > 0 { + ipfsTimeout = discovered.timeout + } + if discovered.replicationFactor > 0 { + ipfsReplicationFactor = discovered.replicationFactor + } + ipfsEnableEncryption = discovered.enableEncryption + logger.ComponentInfo(logging.ComponentGeneral, "Discovered IPFS endpoints from node configs", + zap.String("cluster_url", ipfsClusterURL), + zap.String("api_url", ipfsAPIURL), + zap.Bool("encryption_enabled", ipfsEnableEncryption)) + } else { + // Fallback to localhost defaults + ipfsClusterURL = "http://localhost:9094" + ipfsAPIURL = "http://localhost:5001" + ipfsEnableEncryption = true // Default to true + logger.ComponentInfo(logging.ComponentGeneral, "No IPFS config found in node configs, using localhost defaults") + } + } + + if ipfsAPIURL == "" { + ipfsAPIURL = "http://localhost:5001" + } + if ipfsTimeout == 0 { + ipfsTimeout = 60 * time.Second + } + if ipfsReplicationFactor == 0 { + ipfsReplicationFactor = 3 + } + if !cfg.IPFSEnableEncryption && !ipfsEnableEncryption { + // Only disable if explicitly set to false in both places + ipfsEnableEncryption = false + } else { + // Default to true if not explicitly disabled + ipfsEnableEncryption = true + } + + ipfsCfg := ipfs.Config{ + 
ClusterAPIURL: ipfsClusterURL, + Timeout: ipfsTimeout, + } + ipfsClient, ipfsErr := ipfs.NewClient(ipfsCfg, logger.Logger) + if ipfsErr != nil { + logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize IPFS Cluster client; storage endpoints disabled", zap.Error(ipfsErr)) + } else { + gw.ipfsClient = ipfsClient + + // Check peer count and warn if insufficient (use background context to avoid blocking) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if peerCount, err := ipfsClient.GetPeerCount(ctx); err == nil { + if peerCount < ipfsReplicationFactor { + logger.ComponentWarn(logging.ComponentGeneral, "insufficient cluster peers for replication factor", + zap.Int("peer_count", peerCount), + zap.Int("replication_factor", ipfsReplicationFactor), + zap.String("message", "Some pin operations may fail until more peers join the cluster")) + } else { + logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster peer count sufficient", + zap.Int("peer_count", peerCount), + zap.Int("replication_factor", ipfsReplicationFactor)) + } + } else { + logger.ComponentWarn(logging.ComponentGeneral, "failed to get cluster peer count", zap.Error(err)) + } + + logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster client ready", + zap.String("cluster_api_url", ipfsCfg.ClusterAPIURL), + zap.String("ipfs_api_url", ipfsAPIURL), + zap.Duration("timeout", ipfsCfg.Timeout), + zap.Int("replication_factor", ipfsReplicationFactor), + zap.Bool("encryption_enabled", ipfsEnableEncryption), + ) + } + // Store IPFS settings in gateway for use by handlers + gw.cfg.IPFSAPIURL = ipfsAPIURL + gw.cfg.IPFSReplicationFactor = ipfsReplicationFactor + gw.cfg.IPFSEnableEncryption = ipfsEnableEncryption + + // Initialize serverless function engine + logger.ComponentInfo(logging.ComponentGeneral, "Initializing serverless function engine...") + if gw.ormClient != nil && gw.ipfsClient != nil { + // Create serverless registry (stores functions in 
RQLite + IPFS) + registryCfg := serverless.RegistryConfig{ + IPFSAPIURL: ipfsAPIURL, + } + registry := serverless.NewRegistry(gw.ormClient, gw.ipfsClient, registryCfg, logger.Logger) + gw.serverlessRegistry = registry + + // Create WebSocket manager for function streaming + gw.serverlessWSMgr = serverless.NewWSManager(logger.Logger) + + // Get underlying Olric client if available + var olricClient olriclib.Client + if oc := gw.getOlricClient(); oc != nil { + olricClient = oc.UnderlyingClient() + } + + // Create host functions provider (allows functions to call Orama services) + // Get pubsub adapter from client for serverless functions + var pubsubAdapter *pubsub.ClientAdapter + if gw.client != nil { + if concreteClient, ok := gw.client.(*client.Client); ok { + pubsubAdapter = concreteClient.PubSubAdapter() + if pubsubAdapter != nil { + logger.ComponentInfo(logging.ComponentGeneral, "pubsub adapter available for serverless functions") + } else { + logger.ComponentWarn(logging.ComponentGeneral, "pubsub adapter is nil - serverless pubsub will be unavailable") + } + } + } + + hostFuncsCfg := serverless.HostFunctionsConfig{ + IPFSAPIURL: ipfsAPIURL, + HTTPTimeout: 30 * time.Second, + } + hostFuncs := serverless.NewHostFunctions( + gw.ormClient, + olricClient, + gw.ipfsClient, + pubsubAdapter, // pubsub adapter for serverless functions + gw.serverlessWSMgr, + nil, // secrets manager - TODO: implement + hostFuncsCfg, + logger.Logger, + ) + + // Create WASM engine configuration + engineCfg := serverless.DefaultConfig() + engineCfg.DefaultMemoryLimitMB = 128 + engineCfg.MaxMemoryLimitMB = 256 + engineCfg.DefaultTimeoutSeconds = 30 + engineCfg.MaxTimeoutSeconds = 60 + engineCfg.ModuleCacheSize = 100 + + // Create WASM engine + engine, engineErr := serverless.NewEngine(engineCfg, registry, hostFuncs, logger.Logger, serverless.WithInvocationLogger(registry)) + if engineErr != nil { + logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize serverless engine; 
// withInternalAuth creates a context for internal gateway operations that
// bypass authentication. It delegates to client.WithInternalAuth, which marks
// the context so downstream client calls skip auth checks; used wherever the
// gateway calls its own client on behalf of the system rather than a user.
func (g *Gateway) withInternalAuth(ctx context.Context) context.Context {
	return client.WithInternalAuth(ctx)
}
gateway client +func (g *Gateway) Close() { + // Close SFU manager first + if g.sfuManager != nil { + if err := g.sfuManager.Close(); err != nil { + g.logger.ComponentWarn(logging.ComponentGeneral, "error during SFU manager close", zap.Error(err)) + } + } + // Close serverless engine + if g.serverlessEngine != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + if err := g.serverlessEngine.Close(ctx); err != nil { + g.logger.ComponentWarn(logging.ComponentGeneral, "error during serverless engine close", zap.Error(err)) + } + cancel() + } + if g.client != nil { + if err := g.client.Disconnect(); err != nil { + g.logger.ComponentWarn(logging.ComponentClient, "error during client disconnect", zap.Error(err)) + } + } + if g.sqlDB != nil { + _ = g.sqlDB.Close() + } + if client := g.getOlricClient(); client != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := client.Close(ctx); err != nil { + g.logger.ComponentWarn(logging.ComponentGeneral, "error during Olric client close", zap.Error(err)) + } + } + if g.ipfsClient != nil { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := g.ipfsClient.Close(ctx); err != nil { + g.logger.ComponentWarn(logging.ComponentGeneral, "error during IPFS client close", zap.Error(err)) + } + } +} + // getLocalSubscribers returns all local subscribers for a given topic and namespace func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscriber { topicKey := namespace + "." + topic @@ -197,32 +486,23 @@ func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscribe return nil } -// setOlricClient atomically sets the Olric client and reinitializes cache handlers. 
func (g *Gateway) setOlricClient(client *olric.Client) { g.olricMu.Lock() defer g.olricMu.Unlock() g.olricClient = client - if client != nil { - g.cacheHandlers = cache.NewCacheHandlers(g.logger, client) - } } -// getOlricClient atomically retrieves the current Olric client. func (g *Gateway) getOlricClient() *olric.Client { g.olricMu.RLock() defer g.olricMu.RUnlock() return g.olricClient } -// startOlricReconnectLoop starts a background goroutine that continuously attempts -// to reconnect to the Olric cluster with exponential backoff. func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) { go func() { retryDelay := 5 * time.Second - maxBackoff := 30 * time.Second - for { - client, err := olric.NewClient(cfg, g.logger.Logger) + client, err := initializeOlricClientWithRetry(cfg, g.logger) if err == nil { g.setOlricClient(client) g.logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client connected after background retries", @@ -236,13 +516,211 @@ func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) { zap.Error(err)) time.Sleep(retryDelay) - if retryDelay < maxBackoff { + if retryDelay < olricInitMaxBackoff { retryDelay *= 2 - if retryDelay > maxBackoff { - retryDelay = maxBackoff + if retryDelay > olricInitMaxBackoff { + retryDelay = olricInitMaxBackoff } } } }() } +func initializeOlricClientWithRetry(cfg olric.Config, logger *logging.ColoredLogger) (*olric.Client, error) { + backoff := olricInitInitialBackoff + + for attempt := 1; attempt <= olricInitMaxAttempts; attempt++ { + client, err := olric.NewClient(cfg, logger.Logger) + if err == nil { + if attempt > 1 { + logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client initialized after retries", + zap.Int("attempts", attempt)) + } + return client, nil + } + + logger.ComponentWarn(logging.ComponentGeneral, "Olric cache client init attempt failed", + zap.Int("attempt", attempt), + zap.Duration("retry_in", backoff), + zap.Error(err)) + + if attempt == olricInitMaxAttempts { + return 
nil, fmt.Errorf("failed to initialize Olric cache client after %d attempts: %w", attempt, err) + } + + time.Sleep(backoff) + backoff *= 2 + if backoff > olricInitMaxBackoff { + backoff = olricInitMaxBackoff + } + } + + return nil, fmt.Errorf("failed to initialize Olric cache client") +} + +// discoverOlricServers discovers Olric server addresses from LibP2P peers +// Returns a list of IP:port addresses where Olric servers are expected to run (port 3320) +func discoverOlricServers(networkClient client.NetworkClient, logger *zap.Logger) []string { + // Get network info to access peer information + networkInfo := networkClient.Network() + if networkInfo == nil { + logger.Debug("Network info not available for Olric discovery") + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + peers, err := networkInfo.GetPeers(ctx) + if err != nil { + logger.Debug("Failed to get peers for Olric discovery", zap.Error(err)) + return nil + } + + olricServers := make([]string, 0) + seen := make(map[string]bool) + + for _, peer := range peers { + for _, addrStr := range peer.Addresses { + // Parse multiaddr + ma, err := multiaddr.NewMultiaddr(addrStr) + if err != nil { + continue + } + + // Extract IP address + var ip string + if ipv4, err := ma.ValueForProtocol(multiaddr.P_IP4); err == nil && ipv4 != "" { + ip = ipv4 + } else if ipv6, err := ma.ValueForProtocol(multiaddr.P_IP6); err == nil && ipv6 != "" { + ip = ipv6 + } else { + continue + } + + // Skip localhost loopback addresses (we'll use localhost:3320 as fallback) + if ip == "localhost" || ip == "::1" { + continue + } + + // Build Olric server address (standard port 3320) + olricAddr := net.JoinHostPort(ip, "3320") + if !seen[olricAddr] { + olricServers = append(olricServers, olricAddr) + seen[olricAddr] = true + } + } + } + + // Also check peers from config + if cfg := networkClient.Config(); cfg != nil { + for _, peerAddr := range cfg.BootstrapPeers { + ma, err := 
multiaddr.NewMultiaddr(peerAddr) + if err != nil { + continue + } + + var ip string + if ipv4, err := ma.ValueForProtocol(multiaddr.P_IP4); err == nil && ipv4 != "" { + ip = ipv4 + } else if ipv6, err := ma.ValueForProtocol(multiaddr.P_IP6); err == nil && ipv6 != "" { + ip = ipv6 + } else { + continue + } + + // Skip localhost + if ip == "localhost" || ip == "::1" { + continue + } + + olricAddr := net.JoinHostPort(ip, "3320") + if !seen[olricAddr] { + olricServers = append(olricServers, olricAddr) + seen[olricAddr] = true + } + } + } + + // If we found servers, log them + if len(olricServers) > 0 { + logger.Info("Discovered Olric servers from LibP2P network", + zap.Strings("servers", olricServers)) + } + + return olricServers +} + +// ipfsDiscoveryResult holds discovered IPFS configuration +type ipfsDiscoveryResult struct { + clusterURL string + apiURL string + timeout time.Duration + replicationFactor int + enableEncryption bool +} + +// discoverIPFSFromNodeConfigs discovers IPFS configuration from node.yaml files +// Checks node-1.yaml through node-5.yaml for IPFS configuration +func discoverIPFSFromNodeConfigs(logger *zap.Logger) ipfsDiscoveryResult { + homeDir, err := os.UserHomeDir() + if err != nil { + logger.Debug("Failed to get home directory for IPFS discovery", zap.Error(err)) + return ipfsDiscoveryResult{} + } + + configDir := filepath.Join(homeDir, ".orama") + + // Try all node config files for IPFS settings + configFiles := []string{"node-1.yaml", "node-2.yaml", "node-3.yaml", "node-4.yaml", "node-5.yaml"} + + for _, filename := range configFiles { + configPath := filepath.Join(configDir, filename) + data, err := os.ReadFile(configPath) + if err != nil { + continue + } + + var nodeCfg config.Config + if err := config.DecodeStrict(strings.NewReader(string(data)), &nodeCfg); err != nil { + logger.Debug("Failed to parse node config for IPFS discovery", + zap.String("file", filename), zap.Error(err)) + continue + } + + // Check if IPFS is configured + if 
nodeCfg.Database.IPFS.ClusterAPIURL != "" { + result := ipfsDiscoveryResult{ + clusterURL: nodeCfg.Database.IPFS.ClusterAPIURL, + apiURL: nodeCfg.Database.IPFS.APIURL, + timeout: nodeCfg.Database.IPFS.Timeout, + replicationFactor: nodeCfg.Database.IPFS.ReplicationFactor, + enableEncryption: nodeCfg.Database.IPFS.EnableEncryption, + } + + if result.apiURL == "" { + result.apiURL = "http://localhost:5001" + } + if result.timeout == 0 { + result.timeout = 60 * time.Second + } + if result.replicationFactor == 0 { + result.replicationFactor = 3 + } + // Default encryption to true if not set + if !result.enableEncryption { + result.enableEncryption = true + } + + logger.Info("Discovered IPFS config from node config", + zap.String("file", filename), + zap.String("cluster_url", result.clusterURL), + zap.String("api_url", result.apiURL), + zap.Bool("encryption_enabled", result.enableEncryption)) + + return result + } + } + + return ipfsDiscoveryResult{} +} diff --git a/pkg/gateway/middleware.go b/pkg/gateway/middleware.go index 2dcd8aa..53ca4bb 100644 --- a/pkg/gateway/middleware.go +++ b/pkg/gateway/middleware.go @@ -191,6 +191,16 @@ func isPublicPath(p string) bool { return true } + // TURN credentials (public for development - requires secret for actual use) + if strings.HasPrefix(p, "/v1/turn/") { + return true + } + + // SFU endpoints (public for development) + if strings.HasPrefix(p, "/v1/sfu/") { + return true + } + switch p { case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers": return true diff --git a/pkg/gateway/pubsub_handlers.go b/pkg/gateway/pubsub_handlers.go new file mode 100644 index 0000000..5f695b9 --- /dev/null +++ b/pkg/gateway/pubsub_handlers.go @@ -0,0 +1,475 @@ +package gateway + +import ( + 
// pubsubWebsocketHandler upgrades to WS, subscribes to a namespaced topic, and
// forwards received PubSub messages to the client. Messages sent by the client
// are published to the same namespaced topic.
//
// Lifecycle: the connection is registered as a local subscriber (for
// same-node direct delivery), optionally joins presence, starts a writer
// goroutine that drains msgs to the socket, starts a background libp2p
// subscription feeding msgs, then runs the reader loop on this goroutine
// until the socket closes. done is closed only by the writer goroutine, so
// there is a single closer.
func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request) {
	if g.client == nil {
		g.logger.ComponentWarn("gateway", "pubsub ws: client not initialized")
		writeError(w, http.StatusServiceUnavailable, "client not initialized")
		return
	}
	if r.Method != http.MethodGet {
		g.logger.ComponentWarn("gateway", "pubsub ws: method not allowed")
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}

	// Resolve namespace from auth context
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		g.logger.ComponentWarn("gateway", "pubsub ws: namespace not resolved")
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}

	topic := r.URL.Query().Get("topic")
	if topic == "" {
		g.logger.ComponentWarn("gateway", "pubsub ws: missing topic")
		writeError(w, http.StatusBadRequest, "missing 'topic'")
		return
	}

	// Presence handling: opt-in via ?presence=true, requires a member_id.
	// member_meta is a best-effort JSON blob; unmarshal errors are ignored.
	enablePresence := r.URL.Query().Get("presence") == "true"
	memberID := r.URL.Query().Get("member_id")
	memberMetaStr := r.URL.Query().Get("member_meta")
	var memberMeta map[string]interface{}
	if memberMetaStr != "" {
		_ = json.Unmarshal([]byte(memberMetaStr), &memberMeta)
	}

	if enablePresence && memberID == "" {
		g.logger.ComponentWarn("gateway", "pubsub ws: presence enabled but missing member_id")
		writeError(w, http.StatusBadRequest, "missing 'member_id' for presence")
		return
	}

	conn, err := wsUpgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade already wrote an HTTP error response; nothing more to send.
		g.logger.ComponentWarn("gateway", "pubsub ws: upgrade failed")
		return
	}
	defer conn.Close()

	// Channel to deliver PubSub messages to WS writer
	msgs := make(chan []byte, 128)

	// NEW: Register as local subscriber for direct message delivery
	localSub := &localSubscriber{
		msgChan:   msgs,
		namespace: ns,
	}
	topicKey := fmt.Sprintf("%s.%s", ns, topic)

	g.mu.Lock()
	g.localSubscribers[topicKey] = append(g.localSubscribers[topicKey], localSub)
	subscriberCount := len(g.localSubscribers[topicKey])
	g.mu.Unlock()

	// connID uniquely identifies this socket within the presence roster so
	// the matching entry can be removed on disconnect.
	connID := uuid.New().String()
	if enablePresence {
		member := PresenceMember{
			MemberID: memberID,
			JoinedAt: time.Now().Unix(),
			Meta:     memberMeta,
			ConnID:   connID,
		}

		g.presenceMu.Lock()
		g.presenceMembers[topicKey] = append(g.presenceMembers[topicKey], member)
		g.presenceMu.Unlock()

		// Broadcast join event
		joinEvent := map[string]interface{}{
			"type":      "presence.join",
			"member_id": memberID,
			"meta":      memberMeta,
			"timestamp": member.JoinedAt,
		}
		eventData, _ := json.Marshal(joinEvent)
		// Use a background context for the broadcast to ensure it finishes even if the connection closes immediately
		broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
		_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)

		g.logger.ComponentInfo("gateway", "pubsub ws: member joined presence",
			zap.String("topic", topic),
			zap.String("member_id", memberID))
	}

	g.logger.ComponentInfo("gateway", "pubsub ws: registered local subscriber",
		zap.String("topic", topic),
		zap.String("namespace", ns),
		zap.Int("total_subscribers", subscriberCount))

	// Unregister on close: remove the local subscriber and, if presence was
	// enabled, remove the roster entry and broadcast a leave event.
	defer func() {
		g.mu.Lock()
		subs := g.localSubscribers[topicKey]
		for i, sub := range subs {
			if sub == localSub {
				g.localSubscribers[topicKey] = append(subs[:i], subs[i+1:]...)
				break
			}
		}
		remainingCount := len(g.localSubscribers[topicKey])
		if remainingCount == 0 {
			delete(g.localSubscribers, topicKey)
		}
		g.mu.Unlock()

		if enablePresence {
			g.presenceMu.Lock()
			members := g.presenceMembers[topicKey]
			for i, m := range members {
				if m.ConnID == connID {
					g.presenceMembers[topicKey] = append(members[:i], members[i+1:]...)
					break
				}
			}
			if len(g.presenceMembers[topicKey]) == 0 {
				delete(g.presenceMembers, topicKey)
			}
			g.presenceMu.Unlock()

			// Broadcast leave event
			leaveEvent := map[string]interface{}{
				"type":      "presence.leave",
				"member_id": memberID,
				"timestamp": time.Now().Unix(),
			}
			eventData, _ := json.Marshal(leaveEvent)
			broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
			_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)

			g.logger.ComponentInfo("gateway", "pubsub ws: member left presence",
				zap.String("topic", topic),
				zap.String("member_id", memberID))
		}

		g.logger.ComponentInfo("gateway", "pubsub ws: unregistered local subscriber",
			zap.String("topic", topic),
			zap.Int("remaining_subscribers", remainingCount))
	}()

	// Use internal auth context when interacting with client to avoid circular auth requirements
	ctx := client.WithInternalAuth(r.Context())
	// Apply namespace isolation
	ctx = pubsub.WithNamespace(ctx, ns)

	// Writer loop - START THIS FIRST before libp2p subscription
	done := make(chan struct{})
	go func() {
		g.logger.ComponentInfo("gateway", "pubsub ws: writer goroutine started",
			zap.String("topic", topic))
		defer g.logger.ComponentInfo("gateway", "pubsub ws: writer goroutine exiting",
			zap.String("topic", topic))
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case b, ok := <-msgs:
				if !ok {
					g.logger.ComponentWarn("gateway", "pubsub ws: message channel closed",
						zap.String("topic", topic))
					_ = conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(5*time.Second))
					close(done)
					return
				}

				g.logger.ComponentInfo("gateway", "pubsub ws: sending message to client",
					zap.String("topic", topic),
					zap.Int("data_len", len(b)))

				// Format message as JSON envelope with data (base64 encoded), timestamp, and topic
				// This matches the SDK's Message interface: {data: string, timestamp: number, topic: string}
				envelope := map[string]interface{}{
					"data":      base64.StdEncoding.EncodeToString(b),
					"timestamp": time.Now().UnixMilli(),
					"topic":     topic,
				}
				envelopeJSON, err := json.Marshal(envelope)
				if err != nil {
					g.logger.ComponentWarn("gateway", "pubsub ws: failed to marshal envelope",
						zap.String("topic", topic),
						zap.Error(err))
					continue
				}

				g.logger.ComponentDebug("gateway", "pubsub ws: envelope created",
					zap.String("topic", topic),
					zap.Int("envelope_len", len(envelopeJSON)))

				// NOTE(review): SetWriteDeadline's error is ignored here; the
				// subsequent WriteMessage surfaces any broken-connection error.
				conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
				if err := conn.WriteMessage(websocket.TextMessage, envelopeJSON); err != nil {
					g.logger.ComponentWarn("gateway", "pubsub ws: failed to write to websocket",
						zap.String("topic", topic),
						zap.Error(err))
					close(done)
					return
				}

				g.logger.ComponentInfo("gateway", "pubsub ws: message sent successfully",
					zap.String("topic", topic))
			case <-ticker.C:
				// Ping keepalive
				_ = conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(5*time.Second))
			case <-ctx.Done():
				close(done)
				return
			}
		}
	}()

	// Subscribe to libp2p for cross-node messages (in background, non-blocking)
	go func() {
		h := func(_ string, data []byte) error {
			g.logger.ComponentInfo("gateway", "pubsub ws: received message from libp2p",
				zap.String("topic", topic),
				zap.Int("data_len", len(data)))

			select {
			case msgs <- data:
				g.logger.ComponentInfo("gateway", "pubsub ws: forwarded to client",
					zap.String("topic", topic),
					zap.String("source", "libp2p"))
				return nil
			default:
				// Drop if client is slow to avoid blocking network
				g.logger.ComponentWarn("gateway", "pubsub ws: client slow, dropping message",
					zap.String("topic", topic))
				return nil
			}
		}
		if err := g.client.PubSub().Subscribe(ctx, topic, h); err != nil {
			g.logger.ComponentWarn("gateway", "pubsub ws: libp2p subscribe failed (will use local-only)",
				zap.String("topic", topic),
				zap.Error(err))
			return
		}
		g.logger.ComponentInfo("gateway", "pubsub ws: libp2p subscription established",
			zap.String("topic", topic))

		// Keep subscription alive until done
		<-done
		_ = g.client.PubSub().Unsubscribe(ctx, topic)
		g.logger.ComponentInfo("gateway", "pubsub ws: libp2p subscription closed",
			zap.String("topic", topic))
	}()

	// Reader loop: treat any client message as publish to the same topic
	for {
		mt, data, err := conn.ReadMessage()
		if err != nil {
			break
		}
		if mt != websocket.TextMessage && mt != websocket.BinaryMessage {
			continue
		}

		// Filter out WebSocket heartbeat messages
		// Don't publish them to the topic
		var msg map[string]interface{}
		if err := json.Unmarshal(data, &msg); err == nil {
			if msgType, ok := msg["type"].(string); ok && msgType == "ping" {
				g.logger.ComponentInfo("gateway", "pubsub ws: filtering out heartbeat ping")
				continue
			}
		}

		if err := g.client.PubSub().Publish(ctx, topic, data); err != nil {
			// Best-effort notify client
			_ = conn.WriteMessage(websocket.TextMessage, []byte("publish_error"))
		}
	}
	// NOTE(review): this blocks until the writer goroutine closes done; if the
	// writer exits via ctx.Done() this returns promptly, otherwise it waits
	// for the channel-closed or write-failure paths.
	<-done
}
// pubsubPublishHandler handles POST /v1/pubsub/publish {topic, data_base64}.
// It decodes the payload, delivers it synchronously to local WebSocket
// subscribers on this node, then fires the cross-node libp2p publish
// asynchronously and returns 200 immediately.
func (g *Gateway) pubsubPublishHandler(w http.ResponseWriter, r *http.Request) {
	if g.client == nil {
		writeError(w, http.StatusServiceUnavailable, "client not initialized")
		return
	}
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	var body struct {
		Topic   string `json:"topic"`
		DataB64 string `json:"data_base64"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Topic == "" || body.DataB64 == "" {
		writeError(w, http.StatusBadRequest, "invalid body: expected {topic,data_base64}")
		return
	}
	data, err := base64.StdEncoding.DecodeString(body.DataB64)
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid base64 data")
		return
	}

	// NEW: Check for local websocket subscribers FIRST and deliver directly
	// NOTE(review): getLocalSubscribers is called under g.mu.RLock here —
	// confirm it does not itself acquire g.mu, or this can deadlock if a
	// writer is queued.
	g.mu.RLock()
	localSubs := g.getLocalSubscribers(body.Topic, ns)
	g.mu.RUnlock()

	localDeliveryCount := 0
	if len(localSubs) > 0 {
		for _, sub := range localSubs {
			// Non-blocking send: a full subscriber buffer drops the message
			// rather than stalling the HTTP request.
			select {
			case sub.msgChan <- data:
				localDeliveryCount++
				g.logger.ComponentDebug("gateway", "delivered to local subscriber",
					zap.String("topic", body.Topic))
			default:
				// Drop if buffer full
				g.logger.ComponentWarn("gateway", "local subscriber buffer full, dropping message",
					zap.String("topic", body.Topic))
			}
		}
	}

	g.logger.ComponentInfo("gateway", "pubsub publish: processing message",
		zap.String("topic", body.Topic),
		zap.String("namespace", ns),
		zap.Int("data_len", len(data)),
		zap.Int("local_subscribers", len(localSubs)),
		zap.Int("local_delivered", localDeliveryCount))

	// Publish to libp2p asynchronously for cross-node delivery
	// This prevents blocking the HTTP response if libp2p network is slow
	go func() {
		publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
		if err := g.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
			g.logger.ComponentWarn("gateway", "async libp2p publish failed",
				zap.String("topic", body.Topic),
				zap.Error(err))
		} else {
			g.logger.ComponentDebug("gateway", "async libp2p publish succeeded",
				zap.String("topic", body.Topic))
		}
	}()

	// Return immediately after local delivery
	// Local WebSocket subscribers already received the message
	writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
}
immediately after local delivery + // Local WebSocket subscribers already received the message + writeJSON(w, http.StatusOK, map[string]any{"status": "ok"}) +} + +// pubsubTopicsHandler lists topics within the caller's namespace +func (g *Gateway) pubsubTopicsHandler(w http.ResponseWriter, r *http.Request) { + if g.client == nil { + writeError(w, http.StatusServiceUnavailable, "client not initialized") + return + } + ns := resolveNamespaceFromRequest(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + // Apply namespace isolation + ctx := pubsub.WithNamespace(client.WithInternalAuth(r.Context()), ns) + all, err := g.client.PubSub().ListTopics(ctx) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + // Client returns topics already trimmed to its namespace; return as-is + writeJSON(w, http.StatusOK, map[string]any{"topics": all}) +} + +// resolveNamespaceFromRequest gets namespace from context set by auth middleware +// Falls back to query parameter "namespace" for development/testing +func resolveNamespaceFromRequest(r *http.Request) string { + if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil { + if s, ok := v.(string); ok { + return s + } + } + // Fallback: check query parameter for development + if ns := strings.TrimSpace(r.URL.Query().Get("namespace")); ns != "" { + return ns + } + return "" +} + +func namespacePrefix(ns string) string { + return "ns::" + ns + "::" +} + +func namespacedTopic(ns, topic string) string { + return namespacePrefix(ns) + topic +} + +// pubsubPresenceHandler handles GET /v1/pubsub/presence?topic=mytopic +func (g *Gateway) pubsubPresenceHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + ns := resolveNamespaceFromRequest(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + + topic 
:= r.URL.Query().Get("topic") + if topic == "" { + writeError(w, http.StatusBadRequest, "missing 'topic'") + return + } + + topicKey := fmt.Sprintf("%s.%s", ns, topic) + + g.presenceMu.RLock() + members, ok := g.presenceMembers[topicKey] + g.presenceMu.RUnlock() + + if !ok { + writeJSON(w, http.StatusOK, map[string]any{ + "topic": topic, + "members": []PresenceMember{}, + "count": 0, + }) + return + } + + writeJSON(w, http.StatusOK, map[string]any{ + "topic": topic, + "members": members, + "count": len(members), + }) +} diff --git a/pkg/gateway/routes.go b/pkg/gateway/routes.go index a6aa1e4..e542a79 100644 --- a/pkg/gateway/routes.go +++ b/pkg/gateway/routes.go @@ -79,5 +79,14 @@ func (g *Gateway) Routes() http.Handler { g.serverlessHandlers.RegisterRoutes(mux) } + // TURN credentials for WebRTC + mux.HandleFunc("/v1/turn/credentials", g.turnCredentialsHandler) + + // SFU endpoints for WebRTC group calls (if enabled) + if g.sfuManager != nil { + mux.HandleFunc("/v1/sfu/room", g.sfuCreateRoomHandler) + mux.HandleFunc("/v1/sfu/room/", g.sfuRoomHandler) // Handles :roomId/* paths + } + return g.withMiddleware(mux) } diff --git a/pkg/gateway/serverless_handlers.go b/pkg/gateway/serverless_handlers.go new file mode 100644 index 0000000..295be9e --- /dev/null +++ b/pkg/gateway/serverless_handlers.go @@ -0,0 +1,933 @@ +package gateway + +import ( + "context" + "encoding/json" + "io" + "net/http" + "strconv" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/gateway/auth" + "github.com/DeBrosOfficial/network/pkg/serverless" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +// ServerlessHandlers contains handlers for serverless function endpoints. +// It's a separate struct to keep the Gateway struct clean. 
+type ServerlessHandlers struct { + invoker *serverless.Invoker + registry serverless.FunctionRegistry + wsManager *serverless.WSManager + triggerManager serverless.TriggerManager + logger *zap.Logger +} + +// NewServerlessHandlers creates a new ServerlessHandlers instance. +func NewServerlessHandlers( + invoker *serverless.Invoker, + registry serverless.FunctionRegistry, + wsManager *serverless.WSManager, + triggerManager serverless.TriggerManager, + logger *zap.Logger, +) *ServerlessHandlers { + return &ServerlessHandlers{ + invoker: invoker, + registry: registry, + wsManager: wsManager, + triggerManager: triggerManager, + logger: logger, + } +} + +// RegisterRoutes registers all serverless routes on the given mux. +func (h *ServerlessHandlers) RegisterRoutes(mux *http.ServeMux) { + // Function management + mux.HandleFunc("/v1/functions", h.handleFunctions) + mux.HandleFunc("/v1/functions/", h.handleFunctionByName) + + // Direct invoke endpoint + mux.HandleFunc("/v1/invoke/", h.handleInvoke) +} + +// handleFunctions handles GET /v1/functions (list) and POST /v1/functions (deploy) +func (h *ServerlessHandlers) handleFunctions(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + h.listFunctions(w, r) + case http.MethodPost: + h.deployFunction(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleFunctionByName handles operations on a specific function +// Routes: +// - GET /v1/functions/{name} - Get function info +// - DELETE /v1/functions/{name} - Delete function +// - POST /v1/functions/{name}/invoke - Invoke function +// - GET /v1/functions/{name}/versions - List versions +// - GET /v1/functions/{name}/logs - Get logs +// - WS /v1/functions/{name}/ws - WebSocket invoke +func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http.Request) { + // Parse path: /v1/functions/{name}[/{action}] + path := strings.TrimPrefix(r.URL.Path, "/v1/functions/") + parts := 
strings.SplitN(path, "/", 2) + + if len(parts) == 0 || parts[0] == "" { + http.Error(w, "Function name required", http.StatusBadRequest) + return + } + + name := parts[0] + action := "" + if len(parts) > 1 { + action = parts[1] + } + + // Parse version from name if present (e.g., "myfunction@2") + version := 0 + if idx := strings.Index(name, "@"); idx > 0 { + vStr := name[idx+1:] + name = name[:idx] + if v, err := strconv.Atoi(vStr); err == nil { + version = v + } + } + + switch action { + case "invoke": + h.invokeFunction(w, r, name, version) + case "ws": + h.handleWebSocket(w, r, name, version) + case "versions": + h.listVersions(w, r, name) + case "logs": + h.getFunctionLogs(w, r, name) + case "triggers": + h.handleFunctionTriggers(w, r, name) + case "": + switch r.Method { + case http.MethodGet: + h.getFunctionInfo(w, r, name, version) + case http.MethodDelete: + h.deleteFunction(w, r, name, version) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + default: + http.Error(w, "Unknown action", http.StatusNotFound) + } +} + +// handleInvoke handles POST /v1/invoke/{namespace}/{name}[@version] +func (h *ServerlessHandlers) handleInvoke(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse path: /v1/invoke/{namespace}/{name}[@version] + path := strings.TrimPrefix(r.URL.Path, "/v1/invoke/") + parts := strings.SplitN(path, "/", 2) + + if len(parts) < 2 { + http.Error(w, "Path must be /v1/invoke/{namespace}/{name}", http.StatusBadRequest) + return + } + + namespace := parts[0] + name := parts[1] + + // Parse version if present + version := 0 + if idx := strings.Index(name, "@"); idx > 0 { + vStr := name[idx+1:] + name = name[:idx] + if v, err := strconv.Atoi(vStr); err == nil { + version = v + } + } + + h.invokeFunction(w, r, namespace+"/"+name, version) +} + +// listFunctions handles GET /v1/functions +func (h 
*ServerlessHandlers) listFunctions(w http.ResponseWriter, r *http.Request) { + namespace := r.URL.Query().Get("namespace") + if namespace == "" { + // Get namespace from JWT if available + namespace = h.getNamespaceFromRequest(r) + } + + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + functions, err := h.registry.List(ctx, namespace) + if err != nil { + h.logger.Error("Failed to list functions", zap.Error(err)) + writeError(w, http.StatusInternalServerError, "Failed to list functions") + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "functions": functions, + "count": len(functions), + }) +} + +// deployFunction handles POST /v1/functions +func (h *ServerlessHandlers) deployFunction(w http.ResponseWriter, r *http.Request) { + // Parse multipart form (for WASM upload) or JSON + contentType := r.Header.Get("Content-Type") + + var def serverless.FunctionDefinition + var wasmBytes []byte + + if strings.HasPrefix(contentType, "multipart/form-data") { + // Parse multipart form + if err := r.ParseMultipartForm(32 << 20); err != nil { // 32MB max + writeError(w, http.StatusBadRequest, "Failed to parse form: "+err.Error()) + return + } + + // Get metadata from form field + metadataStr := r.FormValue("metadata") + if metadataStr != "" { + if err := json.Unmarshal([]byte(metadataStr), &def); err != nil { + writeError(w, http.StatusBadRequest, "Invalid metadata JSON: "+err.Error()) + return + } + } + + // Get name from form if not in metadata + if def.Name == "" { + def.Name = r.FormValue("name") + } + + // Get namespace from form if not in metadata + if def.Namespace == "" { + def.Namespace = r.FormValue("namespace") + } + + // Get other configuration fields from form + if v := r.FormValue("is_public"); v != "" { + def.IsPublic, _ = strconv.ParseBool(v) + } + if v := r.FormValue("memory_limit_mb"); v != "" { + 
def.MemoryLimitMB, _ = strconv.Atoi(v) + } + if v := r.FormValue("timeout_seconds"); v != "" { + def.TimeoutSeconds, _ = strconv.Atoi(v) + } + if v := r.FormValue("retry_count"); v != "" { + def.RetryCount, _ = strconv.Atoi(v) + } + if v := r.FormValue("retry_delay_seconds"); v != "" { + def.RetryDelaySeconds, _ = strconv.Atoi(v) + } + + // Get WASM file + file, _, err := r.FormFile("wasm") + if err != nil { + writeError(w, http.StatusBadRequest, "WASM file required") + return + } + defer file.Close() + + wasmBytes, err = io.ReadAll(file) + if err != nil { + writeError(w, http.StatusBadRequest, "Failed to read WASM file: "+err.Error()) + return + } + } else { + // JSON body with base64-encoded WASM + var req struct { + serverless.FunctionDefinition + WASMBase64 string `json:"wasm_base64"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error()) + return + } + + def = req.FunctionDefinition + + if req.WASMBase64 != "" { + // Decode base64 WASM - for now, just reject this method + writeError(w, http.StatusBadRequest, "Base64 WASM upload not supported, use multipart/form-data") + return + } + } + + // Get namespace from JWT if not provided + if def.Namespace == "" { + def.Namespace = h.getNamespaceFromRequest(r) + } + + if def.Name == "" { + writeError(w, http.StatusBadRequest, "Function name required") + return + } + if def.Namespace == "" { + writeError(w, http.StatusBadRequest, "Namespace required") + return + } + if len(wasmBytes) == 0 { + writeError(w, http.StatusBadRequest, "WASM bytecode required") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + defer cancel() + + oldFn, err := h.registry.Register(ctx, &def, wasmBytes) + if err != nil { + h.logger.Error("Failed to deploy function", + zap.String("name", def.Name), + zap.Error(err), + ) + writeError(w, http.StatusInternalServerError, "Failed to deploy: "+err.Error()) + return + } + + // Invalidate 
cache for the old version to ensure the new one is loaded + if oldFn != nil { + h.invoker.InvalidateCache(oldFn.WASMCID) + h.logger.Debug("Invalidated function cache", + zap.String("name", def.Name), + zap.String("old_wasm_cid", oldFn.WASMCID), + ) + } + + h.logger.Info("Function deployed", + zap.String("name", def.Name), + zap.String("namespace", def.Namespace), + ) + + // Fetch the deployed function to return + fn, err := h.registry.Get(ctx, def.Namespace, def.Name, def.Version) + if err != nil { + writeJSON(w, http.StatusCreated, map[string]interface{}{ + "message": "Function deployed successfully", + "name": def.Name, + }) + return + } + + // Register PubSub triggers if provided in metadata + var triggersAdded []string + if len(def.PubSubTopics) > 0 && h.triggerManager != nil { + for _, topic := range def.PubSubTopics { + if err := h.triggerManager.AddPubSubTrigger(ctx, fn.ID, topic); err != nil { + // Log but don't fail deployment + h.logger.Warn("Failed to add pubsub trigger during deployment", + zap.String("function", def.Name), + zap.String("topic", topic), + zap.Error(err), + ) + } else { + triggersAdded = append(triggersAdded, topic) + h.logger.Info("PubSub trigger added during deployment", + zap.String("function", def.Name), + zap.String("topic", topic), + ) + } + } + } + + response := map[string]interface{}{ + "message": "Function deployed successfully", + "function": fn, + } + if len(triggersAdded) > 0 { + response["triggers_added"] = triggersAdded + } + + writeJSON(w, http.StatusCreated, response) +} + +// getFunctionInfo handles GET /v1/functions/{name} +func (h *ServerlessHandlers) getFunctionInfo(w http.ResponseWriter, r *http.Request, name string, version int) { + namespace := r.URL.Query().Get("namespace") + if namespace == "" { + namespace = h.getNamespaceFromRequest(r) + } + + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer 
cancel() + + fn, err := h.registry.Get(ctx, namespace, name, version) + if err != nil { + if serverless.IsNotFound(err) { + writeError(w, http.StatusNotFound, "Function not found") + } else { + writeError(w, http.StatusInternalServerError, "Failed to get function") + } + return + } + + writeJSON(w, http.StatusOK, fn) +} + +// deleteFunction handles DELETE /v1/functions/{name} +func (h *ServerlessHandlers) deleteFunction(w http.ResponseWriter, r *http.Request, name string, version int) { + namespace := r.URL.Query().Get("namespace") + if namespace == "" { + namespace = h.getNamespaceFromRequest(r) + } + + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + if err := h.registry.Delete(ctx, namespace, name, version); err != nil { + if serverless.IsNotFound(err) { + writeError(w, http.StatusNotFound, "Function not found") + } else { + writeError(w, http.StatusInternalServerError, "Failed to delete function") + } + return + } + + writeJSON(w, http.StatusOK, map[string]string{ + "message": "Function deleted successfully", + }) +} + +// invokeFunction handles POST /v1/functions/{name}/invoke +func (h *ServerlessHandlers) invokeFunction(w http.ResponseWriter, r *http.Request, nameWithNS string, version int) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse namespace and name + var namespace, name string + if idx := strings.Index(nameWithNS, "/"); idx > 0 { + namespace = nameWithNS[:idx] + name = nameWithNS[idx+1:] + } else { + name = nameWithNS + namespace = r.URL.Query().Get("namespace") + if namespace == "" { + namespace = h.getNamespaceFromRequest(r) + } + } + + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + // Read input body + input, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB max + if err != nil 
{ + writeError(w, http.StatusBadRequest, "Failed to read request body") + return + } + + // Get caller wallet from JWT + callerWallet := h.getWalletFromRequest(r) + + ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + defer cancel() + + req := &serverless.InvokeRequest{ + Namespace: namespace, + FunctionName: name, + Version: version, + Input: input, + TriggerType: serverless.TriggerTypeHTTP, + CallerWallet: callerWallet, + } + + resp, err := h.invoker.Invoke(ctx, req) + if err != nil { + statusCode := http.StatusInternalServerError + if serverless.IsNotFound(err) { + statusCode = http.StatusNotFound + } else if serverless.IsResourceExhausted(err) { + statusCode = http.StatusTooManyRequests + } else if serverless.IsUnauthorized(err) { + statusCode = http.StatusUnauthorized + } + + writeJSON(w, statusCode, map[string]interface{}{ + "request_id": resp.RequestID, + "status": resp.Status, + "error": resp.Error, + "duration_ms": resp.DurationMS, + }) + return + } + + // Return the function's output directly if it's JSON + w.Header().Set("X-Request-ID", resp.RequestID) + w.Header().Set("X-Duration-Ms", strconv.FormatInt(resp.DurationMS, 10)) + + // Try to detect if output is JSON + if len(resp.Output) > 0 && (resp.Output[0] == '{' || resp.Output[0] == '[') { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(resp.Output) + } else { + writeJSON(w, http.StatusOK, map[string]interface{}{ + "request_id": resp.RequestID, + "output": string(resp.Output), + "status": resp.Status, + "duration_ms": resp.DurationMS, + }) + } +} + +// handleWebSocket handles WebSocket connections for function streaming +func (h *ServerlessHandlers) handleWebSocket(w http.ResponseWriter, r *http.Request, name string, version int) { + namespace := r.URL.Query().Get("namespace") + if namespace == "" { + namespace = h.getNamespaceFromRequest(r) + } + + if namespace == "" { + http.Error(w, "namespace required", http.StatusBadRequest) + return + } 
+ + // Upgrade to WebSocket + upgrader := websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + } + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + h.logger.Error("WebSocket upgrade failed", zap.Error(err)) + return + } + + clientID := uuid.New().String() + wsConn := &serverless.GorillaWSConn{Conn: conn} + + // Register connection + h.wsManager.Register(clientID, wsConn) + defer h.wsManager.Unregister(clientID) + + h.logger.Info("WebSocket connected", + zap.String("client_id", clientID), + zap.String("function", name), + ) + + callerWallet := h.getWalletFromRequest(r) + + // Message loop + for { + _, message, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + h.logger.Warn("WebSocket error", zap.Error(err)) + } + break + } + + // Invoke function with WebSocket context + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + + req := &serverless.InvokeRequest{ + Namespace: namespace, + FunctionName: name, + Version: version, + Input: message, + TriggerType: serverless.TriggerTypeWebSocket, + CallerWallet: callerWallet, + WSClientID: clientID, + } + + resp, err := h.invoker.Invoke(ctx, req) + cancel() + + // Send response back + response := map[string]interface{}{ + "request_id": resp.RequestID, + "status": resp.Status, + "duration_ms": resp.DurationMS, + } + + if err != nil { + response["error"] = resp.Error + } else if len(resp.Output) > 0 { + // Try to parse output as JSON + var output interface{} + if json.Unmarshal(resp.Output, &output) == nil { + response["output"] = output + } else { + response["output"] = string(resp.Output) + } + } + + respBytes, _ := json.Marshal(response) + if err := conn.WriteMessage(websocket.TextMessage, respBytes); err != nil { + break + } + } +} + +// listVersions handles GET /v1/functions/{name}/versions +func (h *ServerlessHandlers) listVersions(w http.ResponseWriter, r 
*http.Request, name string) { + namespace := r.URL.Query().Get("namespace") + if namespace == "" { + namespace = h.getNamespaceFromRequest(r) + } + + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Get registry with extended methods + reg, ok := h.registry.(*serverless.Registry) + if !ok { + writeError(w, http.StatusNotImplemented, "Version listing not supported") + return + } + + versions, err := reg.ListVersions(ctx, namespace, name) + if err != nil { + writeError(w, http.StatusInternalServerError, "Failed to list versions") + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "versions": versions, + "count": len(versions), + }) +} + +// getFunctionLogs handles GET /v1/functions/{name}/logs +func (h *ServerlessHandlers) getFunctionLogs(w http.ResponseWriter, r *http.Request, name string) { + namespace := r.URL.Query().Get("namespace") + if namespace == "" { + namespace = h.getNamespaceFromRequest(r) + } + + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + limit := 100 + if lStr := r.URL.Query().Get("limit"); lStr != "" { + if l, err := strconv.Atoi(lStr); err == nil { + limit = l + } + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + logs, err := h.registry.GetLogs(ctx, namespace, name, limit) + if err != nil { + h.logger.Error("Failed to get function logs", + zap.String("name", name), + zap.String("namespace", namespace), + zap.Error(err), + ) + writeError(w, http.StatusInternalServerError, "Failed to get logs") + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "name": name, + "namespace": namespace, + "logs": logs, + "count": len(logs), + }) +} + +// getNamespaceFromRequest extracts namespace from JWT or query param +func (h *ServerlessHandlers) getNamespaceFromRequest(r *http.Request) string { + 
// Try context first (set by auth middleware) - most secure + if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil { + if ns, ok := v.(string); ok && ns != "" { + return ns + } + } + + // Try query param as fallback (e.g. for public access or admin) + if ns := r.URL.Query().Get("namespace"); ns != "" { + return ns + } + + // Try header as fallback + if ns := r.Header.Get("X-Namespace"); ns != "" { + return ns + } + + return "default" +} + +// getWalletFromRequest extracts wallet address from JWT +func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string { + // 1. Try X-Wallet header (legacy/direct bypass) + if wallet := r.Header.Get("X-Wallet"); wallet != "" { + return wallet + } + + // 2. Try JWT claims from context + if v := r.Context().Value(ctxKeyJWT); v != nil { + if claims, ok := v.(*auth.JWTClaims); ok && claims != nil { + subj := strings.TrimSpace(claims.Sub) + // Ensure it's not an API key (standard Orama logic) + if !strings.HasPrefix(strings.ToLower(subj), "ak_") && !strings.Contains(subj, ":") { + return subj + } + } + } + + // 3. 
Fallback to API key identity (namespace) + if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil { + if ns, ok := v.(string); ok && ns != "" { + return ns + } + } + + return "" +} + +// HealthStatus returns the health status of the serverless engine +func (h *ServerlessHandlers) HealthStatus() map[string]interface{} { + stats := h.wsManager.GetStats() + return map[string]interface{}{ + "status": "ok", + "connections": stats.ConnectionCount, + "topics": stats.TopicCount, + } +} + +// handleFunctionTriggers handles trigger operations for a function +// Routes: +// - GET /v1/functions/{name}/triggers - List all triggers +// - POST /v1/functions/{name}/triggers/pubsub - Add pubsub trigger +// - DELETE /v1/functions/{name}/triggers/{id} - Remove trigger +func (h *ServerlessHandlers) handleFunctionTriggers(w http.ResponseWriter, r *http.Request, name string) { + if h.triggerManager == nil { + writeError(w, http.StatusServiceUnavailable, "Trigger management not available") + return + } + + namespace := r.URL.Query().Get("namespace") + if namespace == "" { + namespace = h.getNamespaceFromRequest(r) + } + + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + // Parse sub-path for trigger type or ID + // Path after "triggers" could be: "", "pubsub", or "{trigger_id}" + fullPath := r.URL.Path + triggersIdx := strings.Index(fullPath, "/triggers") + subPath := "" + if triggersIdx > 0 { + subPath = strings.TrimPrefix(fullPath[triggersIdx:], "/triggers") + subPath = strings.TrimPrefix(subPath, "/") + } + + switch r.Method { + case http.MethodGet: + h.listFunctionTriggers(w, r, namespace, name) + case http.MethodPost: + if subPath == "pubsub" { + h.addPubSubTrigger(w, r, namespace, name) + } else { + writeError(w, http.StatusBadRequest, "Invalid trigger type. 
Use /triggers/pubsub") + } + case http.MethodDelete: + if subPath == "" { + writeError(w, http.StatusBadRequest, "Trigger ID required") + return + } + h.removeFunctionTrigger(w, r, namespace, name, subPath) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// listFunctionTriggers handles GET /v1/functions/{name}/triggers +func (h *ServerlessHandlers) listFunctionTriggers(w http.ResponseWriter, r *http.Request, namespace, name string) { + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Get function to verify it exists and get its ID + fn, err := h.registry.Get(ctx, namespace, name, 0) + if err != nil { + if serverless.IsNotFound(err) { + writeError(w, http.StatusNotFound, "Function not found") + return + } + h.logger.Error("Failed to get function", + zap.String("name", name), + zap.String("namespace", namespace), + zap.Error(err), + ) + writeError(w, http.StatusInternalServerError, "Failed to get function") + return + } + + // Get pubsub triggers + pubsubTriggers, err := h.triggerManager.ListPubSubTriggers(ctx, fn.ID) + if err != nil { + h.logger.Error("Failed to list triggers", + zap.String("function_id", fn.ID), + zap.Error(err), + ) + writeError(w, http.StatusInternalServerError, "Failed to list triggers") + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "name": name, + "namespace": namespace, + "function_id": fn.ID, + "pubsub_triggers": pubsubTriggers, + "count": len(pubsubTriggers), + }) +} + +// addPubSubTrigger handles POST /v1/functions/{name}/triggers/pubsub +func (h *ServerlessHandlers) addPubSubTrigger(w http.ResponseWriter, r *http.Request, namespace, name string) { + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Parse request body + var req struct { + Topic string `json:"topic"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error()) + 
return + } + + if req.Topic == "" { + writeError(w, http.StatusBadRequest, "topic is required") + return + } + + // Get function to verify it exists and get its ID + fn, err := h.registry.Get(ctx, namespace, name, 0) + if err != nil { + if serverless.IsNotFound(err) { + writeError(w, http.StatusNotFound, "Function not found") + return + } + h.logger.Error("Failed to get function", + zap.String("name", name), + zap.String("namespace", namespace), + zap.Error(err), + ) + writeError(w, http.StatusInternalServerError, "Failed to get function") + return + } + + // Add the trigger + err = h.triggerManager.AddPubSubTrigger(ctx, fn.ID, req.Topic) + if err != nil { + if serverless.IsValidationError(err) { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + h.logger.Error("Failed to add pubsub trigger", + zap.String("function_id", fn.ID), + zap.String("topic", req.Topic), + zap.Error(err), + ) + writeError(w, http.StatusInternalServerError, "Failed to add trigger") + return + } + + // Get the triggers to return the newly created one + triggers, _ := h.triggerManager.ListPubSubTriggers(ctx, fn.ID) + var newTrigger *serverless.PubSubTrigger + for i := range triggers { + if triggers[i].Topic == req.Topic { + newTrigger = &triggers[i] + break + } + } + + h.logger.Info("PubSub trigger added", + zap.String("function", name), + zap.String("namespace", namespace), + zap.String("topic", req.Topic), + ) + + writeJSON(w, http.StatusCreated, map[string]interface{}{ + "message": "Trigger added successfully", + "trigger": newTrigger, + }) +} + +// removeFunctionTrigger handles DELETE /v1/functions/{name}/triggers/{id} +func (h *ServerlessHandlers) removeFunctionTrigger(w http.ResponseWriter, r *http.Request, namespace, name, triggerID string) { + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Verify function exists + fn, err := h.registry.Get(ctx, namespace, name, 0) + if err != nil { + if serverless.IsNotFound(err) { + writeError(w, 
http.StatusNotFound, "Function not found") + return + } + writeError(w, http.StatusInternalServerError, "Failed to get function") + return + } + + // Remove the trigger + err = h.triggerManager.RemoveTrigger(ctx, triggerID) + if err != nil { + if err == serverless.ErrTriggerNotFound { + writeError(w, http.StatusNotFound, "Trigger not found") + return + } + h.logger.Error("Failed to remove trigger", + zap.String("trigger_id", triggerID), + zap.Error(err), + ) + writeError(w, http.StatusInternalServerError, "Failed to remove trigger") + return + } + + h.logger.Info("Trigger removed", + zap.String("function", name), + zap.String("function_id", fn.ID), + zap.String("trigger_id", triggerID), + ) + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Trigger removed successfully", + "trigger_id": triggerID, + }) +} diff --git a/pkg/gateway/serverless_handlers_test.go b/pkg/gateway/serverless_handlers_test.go index 7796dc4..3bc9760 100644 --- a/pkg/gateway/serverless_handlers_test.go +++ b/pkg/gateway/serverless_handlers_test.go @@ -8,7 +8,6 @@ import ( "net/http/httptest" "testing" - serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless" "github.com/DeBrosOfficial/network/pkg/serverless" "go.uber.org/zap" ) @@ -50,12 +49,12 @@ func TestServerlessHandlers_ListFunctions(t *testing.T) { }, } - h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger) + h := NewServerlessHandlers(nil, registry, nil, nil, logger) req, _ := http.NewRequest("GET", "/v1/functions?namespace=ns1", nil) rr := httptest.NewRecorder() - h.ListFunctions(rr, req) + h.handleFunctions(rr, req) if rr.Code != http.StatusOK { t.Errorf("expected status 200, got %d", rr.Code) @@ -73,7 +72,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) { logger := zap.NewNop() registry := &mockFunctionRegistry{} - h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger) + h := NewServerlessHandlers(nil, registry, nil, nil, logger) // 
Test JSON deploy (which is partially supported according to code) // Should be 400 because WASM is missing or base64 not supported @@ -81,7 +80,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) { req, _ := http.NewRequest("POST", "/v1/functions", bytes.NewBufferString(`{"name": "test"}`)) req.Header.Set("Content-Type", "application/json") - h.DeployFunction(writer, req) + h.handleFunctions(writer, req) if writer.Code != http.StatusBadRequest { t.Errorf("expected status 400, got %d", writer.Code) diff --git a/pkg/gateway/sfu/config.go b/pkg/gateway/sfu/config.go new file mode 100644 index 0000000..4388235 --- /dev/null +++ b/pkg/gateway/sfu/config.go @@ -0,0 +1,181 @@ +package sfu + +import ( + "time" + + "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/intervalpli" + "github.com/pion/interceptor/pkg/nack" + "github.com/pion/webrtc/v4" +) + +// Config holds SFU configuration +type Config struct { + // MaxParticipants is the maximum number of participants per room + MaxParticipants int + + // MediaTimeout is the timeout for media operations + MediaTimeout time.Duration + + // ICEServers are the ICE servers for WebRTC connections + ICEServers []webrtc.ICEServer +} + +// DefaultConfig returns a default SFU configuration +func DefaultConfig() *Config { + return &Config{ + MaxParticipants: 10, + MediaTimeout: 30 * time.Second, + ICEServers: []webrtc.ICEServer{ + { + URLs: []string{"stun:stun.l.google.com:19302"}, + }, + }, + } +} + +// NewMediaEngine creates a MediaEngine with supported codecs for the SFU +func NewMediaEngine() (*webrtc.MediaEngine, error) { + m := &webrtc.MediaEngine{} + + // RTCP feedback for video codecs - enables NACK, PLI, FIR + videoRTCPFeedback := []webrtc.RTCPFeedback{ + {Type: "goog-remb", Parameter: ""}, // Bandwidth estimation + {Type: "ccm", Parameter: "fir"}, // Full Intra Request + {Type: "nack", Parameter: ""}, // Generic NACK + {Type: "nack", Parameter: "pli"}, // Picture Loss Indication + } + + // Register 
Opus codec for audio with NACK support + if err := m.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "minptime=10;useinbandfec=1", + }, + PayloadType: 111, + }, webrtc.RTPCodecTypeAudio); err != nil { + return nil, err + } + + // Register VP8 codec for video with full RTCP feedback + if err := m.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeVP8, + ClockRate: 90000, + RTCPFeedback: videoRTCPFeedback, + }, + PayloadType: 96, + }, webrtc.RTPCodecTypeVideo); err != nil { + return nil, err + } + + // Register RTX for VP8 (retransmission) + if err := m.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: "video/rtx", + ClockRate: 90000, + SDPFmtpLine: "apt=96", + }, + PayloadType: 97, + }, webrtc.RTPCodecTypeVideo); err != nil { + return nil, err + } + + // Register H264 codec for video with full RTCP feedback (fallback) + if err := m.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeH264, + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f", + RTCPFeedback: videoRTCPFeedback, + }, + PayloadType: 102, + }, webrtc.RTPCodecTypeVideo); err != nil { + return nil, err + } + + // Register RTX for H264 (retransmission) + if err := m.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: "video/rtx", + ClockRate: 90000, + SDPFmtpLine: "apt=102", + }, + PayloadType: 103, + }, webrtc.RTPCodecTypeVideo); err != nil { + return nil, err + } + + return m, nil +} + +// NewWebRTCAPI creates a new WebRTC API with the configured MediaEngine +func NewWebRTCAPI() (*webrtc.API, error) { + mediaEngine, err := NewMediaEngine() + if err != nil { + return nil, err + } + + // Create 
interceptor registry for RTCP feedback + i := &interceptor.Registry{} + + // Register NACK responder - handles retransmission requests from receivers + // This is critical for video quality: when a receiver loses a packet, + // it sends NACK and the sender retransmits the lost packet + nackResponderFactory, err := nack.NewResponderInterceptor() + if err != nil { + return nil, err + } + i.Add(nackResponderFactory) + + // Register NACK generator - sends NACK when we detect packet loss as receiver + nackGeneratorFactory, err := nack.NewGeneratorInterceptor() + if err != nil { + return nil, err + } + i.Add(nackGeneratorFactory) + + // Register interval PLI - automatically sends PLI periodically for video tracks + // This helps new receivers get keyframes faster + intervalPLIFactory, err := intervalpli.NewReceiverInterceptor() + if err != nil { + return nil, err + } + i.Add(intervalPLIFactory) + + // Configure settings for better media performance + settingEngine := webrtc.SettingEngine{} + + return webrtc.NewAPI( + webrtc.WithMediaEngine(mediaEngine), + webrtc.WithInterceptorRegistry(i), + webrtc.WithSettingEngine(settingEngine), + ), nil +} + +// GetRTPCapabilities returns the RTP capabilities of the SFU +// This is used by clients to negotiate codecs +func GetRTPCapabilities() map[string]interface{} { + return map[string]interface{}{ + "codecs": []map[string]interface{}{ + { + "kind": "audio", + "mimeType": "audio/opus", + "clockRate": 48000, + "channels": 2, + }, + { + "kind": "video", + "mimeType": "video/VP8", + "clockRate": 90000, + }, + { + "kind": "video", + "mimeType": "video/H264", + "clockRate": 90000, + }, + }, + } +} diff --git a/pkg/gateway/sfu/manager.go b/pkg/gateway/sfu/manager.go new file mode 100644 index 0000000..b998bc2 --- /dev/null +++ b/pkg/gateway/sfu/manager.go @@ -0,0 +1,228 @@ +package sfu + +import ( + "errors" + "sync" + + "github.com/pion/webrtc/v4" + "go.uber.org/zap" +) + +var ( + ErrRoomNotFound = errors.New("room not found") +) + +// 
RoomManager manages all rooms in the SFU +type RoomManager struct { + rooms map[string]*Room // key: namespace:roomId + roomsMu sync.RWMutex + + api *webrtc.API + config *Config + logger *zap.Logger +} + +// NewRoomManager creates a new room manager +func NewRoomManager(config *Config, logger *zap.Logger) (*RoomManager, error) { + if config == nil { + config = DefaultConfig() + } + + api, err := NewWebRTCAPI() + if err != nil { + return nil, err + } + + return &RoomManager{ + rooms: make(map[string]*Room), + api: api, + config: config, + logger: logger.With(zap.String("component", "sfu_manager")), + }, nil +} + +// roomKey creates a unique key for namespace:roomId +func roomKey(namespace, roomID string) string { + return namespace + ":" + roomID +} + +// GetOrCreateRoom gets an existing room or creates a new one +func (m *RoomManager) GetOrCreateRoom(namespace, roomID string) (*Room, bool) { + key := roomKey(namespace, roomID) + + m.roomsMu.Lock() + defer m.roomsMu.Unlock() + + // Check if room already exists + if room, ok := m.rooms[key]; ok { + if !room.IsClosed() { + return room, false + } + // Room is closed, remove it and create a new one + delete(m.rooms, key) + } + + // Create new room + room := NewRoom(roomID, namespace, m.api, m.config, m.logger) + + // Set up empty room handler + room.OnEmpty(func(r *Room) { + m.logger.Info("Room is empty, will be garbage collected", + zap.String("room_id", r.ID), + zap.String("namespace", r.Namespace), + ) + // Optionally close the room immediately + // m.CloseRoom(r.Namespace, r.ID) + }) + + m.rooms[key] = room + + m.logger.Info("Room created", + zap.String("room_id", roomID), + zap.String("namespace", namespace), + ) + + return room, true +} + +// GetRoom gets an existing room +func (m *RoomManager) GetRoom(namespace, roomID string) (*Room, error) { + key := roomKey(namespace, roomID) + + m.roomsMu.RLock() + defer m.roomsMu.RUnlock() + + room, ok := m.rooms[key] + if !ok { + return nil, ErrRoomNotFound + } + + if 
room.IsClosed() { + return nil, ErrRoomClosed + } + + return room, nil +} + +// CloseRoom closes and removes a room +func (m *RoomManager) CloseRoom(namespace, roomID string) error { + key := roomKey(namespace, roomID) + + m.roomsMu.Lock() + room, ok := m.rooms[key] + if !ok { + m.roomsMu.Unlock() + return ErrRoomNotFound + } + delete(m.rooms, key) + m.roomsMu.Unlock() + + if err := room.Close(); err != nil { + m.logger.Warn("Error closing room", + zap.String("room_id", roomID), + zap.Error(err), + ) + return err + } + + m.logger.Info("Room closed", + zap.String("room_id", roomID), + zap.String("namespace", namespace), + ) + + return nil +} + +// ListRooms returns all rooms for a namespace +func (m *RoomManager) ListRooms(namespace string) []*Room { + m.roomsMu.RLock() + defer m.roomsMu.RUnlock() + + prefix := namespace + ":" + rooms := make([]*Room, 0) + + for key, room := range m.rooms { + if len(key) > len(prefix) && key[:len(prefix)] == prefix && !room.IsClosed() { + rooms = append(rooms, room) + } + } + + return rooms +} + +// RoomInfo contains public information about a room +type RoomInfo struct { + ID string `json:"id"` + Namespace string `json:"namespace"` + ParticipantCount int `json:"participantCount"` + Participants []ParticipantInfo `json:"participants"` +} + +// GetRoomInfo returns public information about a room +func (m *RoomManager) GetRoomInfo(namespace, roomID string) (*RoomInfo, error) { + room, err := m.GetRoom(namespace, roomID) + if err != nil { + return nil, err + } + + return &RoomInfo{ + ID: room.ID, + Namespace: room.Namespace, + ParticipantCount: room.GetParticipantCount(), + Participants: room.GetParticipants(), + }, nil +} + +// GetStats returns statistics about the room manager +func (m *RoomManager) GetStats() map[string]interface{} { + m.roomsMu.RLock() + defer m.roomsMu.RUnlock() + + totalRooms := 0 + totalParticipants := 0 + + for _, room := range m.rooms { + if !room.IsClosed() { + totalRooms++ + totalParticipants += 
room.GetParticipantCount() + } + } + + return map[string]interface{}{ + "totalRooms": totalRooms, + "totalParticipants": totalParticipants, + } +} + +// Close closes all rooms and cleans up resources +func (m *RoomManager) Close() error { + m.roomsMu.Lock() + rooms := make([]*Room, 0, len(m.rooms)) + for _, room := range m.rooms { + rooms = append(rooms, room) + } + m.rooms = make(map[string]*Room) + m.roomsMu.Unlock() + + for _, room := range rooms { + if err := room.Close(); err != nil { + m.logger.Warn("Error closing room during shutdown", + zap.String("room_id", room.ID), + zap.Error(err), + ) + } + } + + m.logger.Info("Room manager closed", zap.Int("rooms_closed", len(rooms))) + return nil +} + +// GetConfig returns the SFU configuration +func (m *RoomManager) GetConfig() *Config { + return m.config +} + +// UpdateICEServers updates the ICE servers configuration +func (m *RoomManager) UpdateICEServers(servers []webrtc.ICEServer) { + m.config.ICEServers = servers +} diff --git a/pkg/gateway/sfu/peer.go b/pkg/gateway/sfu/peer.go new file mode 100644 index 0000000..26054e2 --- /dev/null +++ b/pkg/gateway/sfu/peer.go @@ -0,0 +1,505 @@ +package sfu + +import ( + "encoding/json" + "sync" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + "go.uber.org/zap" +) + +// Peer represents a participant in a room +type Peer struct { + ID string + UserID string + DisplayName string + + // WebRTC connection + pc *webrtc.PeerConnection + + // Tracks published by this peer (local tracks that others receive) + localTracks map[string]*webrtc.TrackLocalStaticRTP + localTracksMu sync.RWMutex + + // Track receivers for consuming other peers' tracks + trackReceivers map[string]*webrtc.RTPReceiver + trackReceiversMu sync.RWMutex + + // WebSocket connection for signaling + conn *websocket.Conn + connMu sync.Mutex + + // State + audioMuted bool + videoMuted bool + closed bool + closedMu sync.RWMutex + negotiationPending 
bool + negotiationPendingMu sync.Mutex + initialOfferHandled bool + initialOfferMu sync.Mutex + batchingTracks bool // When true, suppress automatic negotiation + batchingTracksMu sync.Mutex + + // Room reference + room *Room + logger *zap.Logger + + // Callbacks + onClose func(*Peer) +} + +// NewPeer creates a new peer +func NewPeer(userID, displayName string, conn *websocket.Conn, room *Room, logger *zap.Logger) *Peer { + return &Peer{ + ID: uuid.New().String(), + UserID: userID, + DisplayName: displayName, + localTracks: make(map[string]*webrtc.TrackLocalStaticRTP), + trackReceivers: make(map[string]*webrtc.RTPReceiver), + conn: conn, + room: room, + logger: logger, + } +} + +// InitPeerConnection initializes the WebRTC peer connection +func (p *Peer) InitPeerConnection(api *webrtc.API, config webrtc.Configuration) error { + pc, err := api.NewPeerConnection(config) + if err != nil { + return err + } + p.pc = pc + + // Handle ICE connection state changes + pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) { + p.logger.Info("ICE connection state changed", + zap.String("peer_id", p.ID), + zap.String("state", state.String()), + ) + + if state == webrtc.ICEConnectionStateFailed || + state == webrtc.ICEConnectionStateDisconnected || + state == webrtc.ICEConnectionStateClosed { + p.handleDisconnect() + } + }) + + // Handle ICE candidates + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { + if candidate == nil { + return + } + + p.logger.Debug("ICE candidate generated", + zap.String("peer_id", p.ID), + zap.String("candidate", candidate.String()), + ) + + candidateJSON := candidate.ToJSON() + data := &ICECandidateData{ + Candidate: candidateJSON.Candidate, + } + if candidateJSON.SDPMid != nil { + data.SDPMid = *candidateJSON.SDPMid + } + if candidateJSON.SDPMLineIndex != nil { + data.SDPMLineIndex = *candidateJSON.SDPMLineIndex + } + if candidateJSON.UsernameFragment != nil { + data.UsernameFragment = *candidateJSON.UsernameFragment + } + + 
p.SendMessage(NewServerMessage(MessageTypeICECandidate, data)) + }) + + // Handle incoming tracks from remote peers + pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + codec := track.Codec() + p.logger.Info("Track received from client", + zap.String("peer_id", p.ID), + zap.String("user_id", p.UserID), + zap.String("track_id", track.ID()), + zap.String("stream_id", track.StreamID()), + zap.String("kind", track.Kind().String()), + zap.String("codec_mime", codec.MimeType), + zap.Uint32("codec_clock_rate", codec.ClockRate), + zap.Uint8("codec_payload_type", uint8(codec.PayloadType)), + ) + + p.trackReceiversMu.Lock() + p.trackReceivers[track.ID()] = receiver + p.trackReceiversMu.Unlock() + + // Start RTCP reader to monitor for packet loss (NACK) and PLI requests + go p.readRTCP(receiver, track) + + // Forward track to other peers in the room + p.room.BroadcastTrack(p.ID, track) + }) + + // Handle negotiation needed - only trigger when in stable state + pc.OnNegotiationNeeded(func() { + p.logger.Debug("Negotiation needed", + zap.String("peer_id", p.ID), + zap.String("signaling_state", pc.SignalingState().String()), + ) + + // Check if we're batching tracks - if so, just mark as pending + p.batchingTracksMu.Lock() + batching := p.batchingTracks + p.batchingTracksMu.Unlock() + + if batching { + p.negotiationPendingMu.Lock() + p.negotiationPending = true + p.negotiationPendingMu.Unlock() + p.logger.Debug("Negotiation deferred - batching tracks", + zap.String("peer_id", p.ID), + ) + return + } + + // Only create offer if we're in stable state + // Otherwise, mark negotiation as pending + if pc.SignalingState() == webrtc.SignalingStateStable { + p.createAndSendOffer() + } else { + p.negotiationPendingMu.Lock() + p.negotiationPending = true + p.negotiationPendingMu.Unlock() + p.logger.Debug("Negotiation queued - not in stable state", + zap.String("peer_id", p.ID), + zap.String("signaling_state", pc.SignalingState().String()), + ) + } + }) + + // 
Handle signaling state changes to process pending negotiations + pc.OnSignalingStateChange(func(state webrtc.SignalingState) { + p.logger.Debug("Signaling state changed", + zap.String("peer_id", p.ID), + zap.String("state", state.String()), + ) + + // When we return to stable state, check if negotiation was pending + if state == webrtc.SignalingStateStable { + p.negotiationPendingMu.Lock() + pending := p.negotiationPending + p.negotiationPending = false + p.negotiationPendingMu.Unlock() + + if pending { + p.logger.Debug("Processing pending negotiation", zap.String("peer_id", p.ID)) + p.createAndSendOffer() + } + } + }) + + return nil +} + +// createAndSendOffer creates an SDP offer and sends it to the peer +func (p *Peer) createAndSendOffer() { + if p.pc == nil { + return + } + + // Double-check signaling state before creating offer + if p.pc.SignalingState() != webrtc.SignalingStateStable { + p.logger.Debug("Skipping offer - not in stable state", + zap.String("peer_id", p.ID), + zap.String("signaling_state", p.pc.SignalingState().String()), + ) + // Mark as pending so it will be retried when state becomes stable + p.negotiationPendingMu.Lock() + p.negotiationPending = true + p.negotiationPendingMu.Unlock() + return + } + + p.logger.Info("Creating SDP offer", zap.String("peer_id", p.ID)) + + offer, err := p.pc.CreateOffer(nil) + if err != nil { + p.logger.Error("Failed to create offer", + zap.String("peer_id", p.ID), + zap.Error(err), + ) + return + } + + if err := p.pc.SetLocalDescription(offer); err != nil { + p.logger.Error("Failed to set local description", + zap.String("peer_id", p.ID), + zap.Error(err), + ) + return + } + + p.logger.Info("Sending SDP offer", zap.String("peer_id", p.ID)) + p.SendMessage(NewServerMessage(MessageTypeOffer, &OfferData{ + SDP: offer.SDP, + })) +} + +// HandleOffer processes an SDP offer from the client +func (p *Peer) HandleOffer(sdp string) error { + if p.pc == nil { + return ErrPeerNotInitialized + } + + offer := 
webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: sdp, + } + + if err := p.pc.SetRemoteDescription(offer); err != nil { + return err + } + + answer, err := p.pc.CreateAnswer(nil) + if err != nil { + return err + } + + if err := p.pc.SetLocalDescription(answer); err != nil { + return err + } + + p.SendMessage(NewServerMessage(MessageTypeAnswer, &AnswerData{ + SDP: answer.SDP, + })) + + return nil +} + +// HandleAnswer processes an SDP answer from the client +func (p *Peer) HandleAnswer(sdp string) error { + if p.pc == nil { + return ErrPeerNotInitialized + } + + answer := webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: sdp, + } + + return p.pc.SetRemoteDescription(answer) +} + +// HandleICECandidate processes an ICE candidate from the client +func (p *Peer) HandleICECandidate(data *ICECandidateData) error { + if p.pc == nil { + return ErrPeerNotInitialized + } + + return p.pc.AddICECandidate(data.ToWebRTCCandidate()) +} + +// AddTrack adds a track to send to this peer (from another peer) +func (p *Peer) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { + if p.pc == nil { + return nil, ErrPeerNotInitialized + } + + return p.pc.AddTrack(track) +} + +// StartTrackBatch starts batching track additions. +// Call EndTrackBatch when done to trigger a single renegotiation. +func (p *Peer) StartTrackBatch() { + p.batchingTracksMu.Lock() + p.batchingTracks = true + p.batchingTracksMu.Unlock() + p.logger.Debug("Started track batching", zap.String("peer_id", p.ID)) +} + +// EndTrackBatch ends track batching and triggers renegotiation if needed. 
+func (p *Peer) EndTrackBatch() { + p.batchingTracksMu.Lock() + p.batchingTracks = false + p.batchingTracksMu.Unlock() + + // Check if negotiation was pending during batching + p.negotiationPendingMu.Lock() + pending := p.negotiationPending + p.negotiationPending = false + p.negotiationPendingMu.Unlock() + + if pending && p.pc != nil && p.pc.SignalingState() == webrtc.SignalingStateStable { + p.logger.Debug("Processing batched negotiation", zap.String("peer_id", p.ID)) + p.createAndSendOffer() + } +} + +// SendMessage sends a signaling message to the peer via WebSocket +func (p *Peer) SendMessage(msg *ServerMessage) error { + p.closedMu.RLock() + if p.closed { + p.closedMu.RUnlock() + return ErrPeerClosed + } + p.closedMu.RUnlock() + + p.connMu.Lock() + defer p.connMu.Unlock() + + if p.conn == nil { + return ErrWebSocketClosed + } + + data, err := json.Marshal(msg) + if err != nil { + return err + } + + return p.conn.WriteMessage(websocket.TextMessage, data) +} + +// GetInfo returns public information about this peer +func (p *Peer) GetInfo() ParticipantInfo { + p.localTracksMu.RLock() + hasAudio := false + hasVideo := false + for _, track := range p.localTracks { + if track.Kind() == webrtc.RTPCodecTypeAudio { + hasAudio = true + } else if track.Kind() == webrtc.RTPCodecTypeVideo { + hasVideo = true + } + } + p.localTracksMu.RUnlock() + + return ParticipantInfo{ + ID: p.ID, + UserID: p.UserID, + DisplayName: p.DisplayName, + HasAudio: hasAudio, + HasVideo: hasVideo, + AudioMuted: p.audioMuted, + VideoMuted: p.videoMuted, + } +} + +// SetAudioMuted sets the audio mute state +func (p *Peer) SetAudioMuted(muted bool) { + p.audioMuted = muted +} + +// SetVideoMuted sets the video mute state +func (p *Peer) SetVideoMuted(muted bool) { + p.videoMuted = muted +} + +// MarkInitialOfferHandled marks that the initial offer has been processed. +// Returns true if this is the first time it's being marked (i.e., this was the first offer). 
+func (p *Peer) MarkInitialOfferHandled() bool { + p.initialOfferMu.Lock() + defer p.initialOfferMu.Unlock() + + if p.initialOfferHandled { + return false + } + p.initialOfferHandled = true + return true +} + +// handleDisconnect handles peer disconnection +func (p *Peer) handleDisconnect() { + p.closedMu.Lock() + if p.closed { + p.closedMu.Unlock() + return + } + p.closed = true + p.closedMu.Unlock() + + if p.onClose != nil { + p.onClose(p) + } +} + +// Close closes the peer connection and cleans up resources +func (p *Peer) Close() error { + p.closedMu.Lock() + if p.closed { + p.closedMu.Unlock() + return nil + } + p.closed = true + p.closedMu.Unlock() + + p.logger.Info("Closing peer", zap.String("peer_id", p.ID)) + + // Close WebSocket + p.connMu.Lock() + if p.conn != nil { + p.conn.Close() + p.conn = nil + } + p.connMu.Unlock() + + // Close peer connection + if p.pc != nil { + return p.pc.Close() + } + + return nil +} + +// OnClose sets a callback for when the peer is closed +func (p *Peer) OnClose(fn func(*Peer)) { + p.onClose = fn +} + +// readRTCP reads RTCP packets from a receiver to monitor feedback +// This helps detect packet loss (via NACK) for adaptive quality adjustments +func (p *Peer) readRTCP(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) { + localTrackID := track.Kind().String() + "-" + p.ID + + for { + packets, _, err := receiver.ReadRTCP() + if err != nil { + // Connection closed, exit gracefully + return + } + + for _, pkt := range packets { + switch rtcpPkt := pkt.(type) { + case *rtcp.TransportLayerNack: + // NACK received - indicates packet loss from receivers + // Increment the NACK counter for adaptive keyframe logic + p.room.IncrementNackCount(localTrackID) + + p.logger.Debug("NACK received", + zap.String("peer_id", p.ID), + zap.String("track_id", localTrackID), + zap.Uint32("sender_ssrc", rtcpPkt.SenderSSRC), + zap.Int("nack_pairs", len(rtcpPkt.Nacks)), + ) + + case *rtcp.PictureLossIndication: + // PLI received - receiver 
needs a keyframe + p.logger.Debug("PLI received from receiver", + zap.String("peer_id", p.ID), + zap.String("track_id", localTrackID), + ) + // Request keyframe from source + p.room.RequestKeyframe(localTrackID) + + case *rtcp.FullIntraRequest: + // FIR received - receiver needs a full keyframe + p.logger.Debug("FIR received from receiver", + zap.String("peer_id", p.ID), + zap.String("track_id", localTrackID), + ) + // Request keyframe from source + p.room.RequestKeyframe(localTrackID) + } + } + } +} diff --git a/pkg/gateway/sfu/room.go b/pkg/gateway/sfu/room.go new file mode 100644 index 0000000..37648e1 --- /dev/null +++ b/pkg/gateway/sfu/room.go @@ -0,0 +1,775 @@ +package sfu + +import ( + "errors" + "sync" + "sync/atomic" + "time" + + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + "go.uber.org/zap" +) + +// Common errors +var ( + ErrRoomFull = errors.New("room is full") + ErrRoomClosed = errors.New("room is closed") + ErrPeerNotFound = errors.New("peer not found") + ErrPeerNotInitialized = errors.New("peer not initialized") + ErrPeerClosed = errors.New("peer is closed") + ErrWebSocketClosed = errors.New("websocket connection closed") +) + +// publishedTrack holds information about a track published to the room +type publishedTrack struct { + sourcePeerID string + sourceUserID string + localTrack *webrtc.TrackLocalStaticRTP + remoteTrackSSRC uint32 // SSRC of the remote track (for PLI requests) + remoteTrack *webrtc.TrackRemote // Reference to the remote track + kind string + forwarderActive bool // tracks if the RTP forwarder goroutine is still running + packetCount int // number of packets forwarded (for debugging) + + // Packet loss tracking for adaptive keyframe requests + lastKeyframeRequest time.Time + nackCount atomic.Int64 // Number of NACK packets received (indicates packet loss) +} + +// Room represents a WebRTC room with multiple participants +type Room struct { + ID string + Namespace string + + // Participants in the room + peers 
map[string]*Peer + peersMu sync.RWMutex + + // Published tracks in the room (for sending to new joiners) + publishedTracks map[string]*publishedTrack // key: trackID + publishedTracksMu sync.RWMutex + + // WebRTC API for creating peer connections + api *webrtc.API + + // Configuration + config *Config + logger *zap.Logger + + // State + closed bool + closedMu sync.RWMutex + + // Callbacks + onEmpty func(*Room) +} + +// NewRoom creates a new room +func NewRoom(id, namespace string, api *webrtc.API, config *Config, logger *zap.Logger) *Room { + return &Room{ + ID: id, + Namespace: namespace, + peers: make(map[string]*Peer), + publishedTracks: make(map[string]*publishedTrack), + api: api, + config: config, + logger: logger.With(zap.String("room_id", id)), + } +} + +// AddPeer adds a new peer to the room +func (r *Room) AddPeer(peer *Peer) error { + r.closedMu.RLock() + if r.closed { + r.closedMu.RUnlock() + return ErrRoomClosed + } + r.closedMu.RUnlock() + + r.peersMu.Lock() + + // Check max participants + if r.config.MaxParticipants > 0 && len(r.peers) >= r.config.MaxParticipants { + r.peersMu.Unlock() + return ErrRoomFull + } + + // Initialize peer connection + pcConfig := webrtc.Configuration{ + ICEServers: r.config.ICEServers, + } + + if err := peer.InitPeerConnection(r.api, pcConfig); err != nil { + r.peersMu.Unlock() + return err + } + + // Set up peer close handler + peer.OnClose(func(p *Peer) { + r.RemovePeer(p.ID) + }) + + r.peers[peer.ID] = peer + peerInfo := peer.GetInfo() // Get info while holding lock + totalPeers := len(r.peers) + + // Release lock BEFORE broadcasting to avoid deadlock + // (broadcastMessage also acquires the lock) + r.peersMu.Unlock() + + r.logger.Info("Peer added to room", + zap.String("peer_id", peer.ID), + zap.String("user_id", peer.UserID), + zap.Int("total_peers", totalPeers), + ) + + // Notify other peers (now safe since we released the lock) + r.broadcastMessage(peer.ID, NewServerMessage(MessageTypeParticipantJoined, 
&ParticipantJoinedData{ + Participant: peerInfo, + })) + + return nil +} + +// RemovePeer removes a peer from the room +func (r *Room) RemovePeer(peerID string) error { + r.peersMu.Lock() + peer, ok := r.peers[peerID] + if !ok { + r.peersMu.Unlock() + return ErrPeerNotFound + } + + delete(r.peers, peerID) + remainingPeers := len(r.peers) + r.peersMu.Unlock() + + // Remove tracks published by this peer + r.publishedTracksMu.Lock() + removedTracks := make([]string, 0) + for trackID, track := range r.publishedTracks { + if track.sourcePeerID == peerID { + delete(r.publishedTracks, trackID) + removedTracks = append(removedTracks, trackID) + } + } + r.publishedTracksMu.Unlock() + + if len(removedTracks) > 0 { + r.logger.Info("Removed tracks for departing peer", + zap.String("peer_id", peerID), + zap.Strings("track_ids", removedTracks), + ) + } + + // Close the peer + if err := peer.Close(); err != nil { + r.logger.Warn("Error closing peer", + zap.String("peer_id", peerID), + zap.Error(err), + ) + } + + r.logger.Info("Peer removed from room", + zap.String("peer_id", peerID), + zap.Int("remaining_peers", remainingPeers), + ) + + // Notify other peers + r.broadcastMessage(peerID, NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{ + ParticipantID: peerID, + })) + + // Check if room is empty + if remainingPeers == 0 && r.onEmpty != nil { + r.onEmpty(r) + } + + return nil +} + +// GetPeer returns a peer by ID +func (r *Room) GetPeer(peerID string) (*Peer, error) { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + + peer, ok := r.peers[peerID] + if !ok { + return nil, ErrPeerNotFound + } + + return peer, nil +} + +// GetPeers returns all peers in the room +func (r *Room) GetPeers() []*Peer { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + + peers := make([]*Peer, 0, len(r.peers)) + for _, peer := range r.peers { + peers = append(peers, peer) + } + return peers +} + +// GetParticipants returns info about all participants +func (r *Room) GetParticipants() 
[]ParticipantInfo { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + + participants := make([]ParticipantInfo, 0, len(r.peers)) + for _, peer := range r.peers { + participants = append(participants, peer.GetInfo()) + } + return participants +} + +// GetParticipantCount returns the number of participants +func (r *Room) GetParticipantCount() int { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + return len(r.peers) +} + +// SendExistingTracksTo sends all existing tracks from other participants to the specified peer. +// This should be called AFTER the welcome message is sent to ensure the client is ready. +// Uses batch mode to send all tracks with a single renegotiation for faster joins. +func (r *Room) SendExistingTracksTo(peer *Peer) { + r.publishedTracksMu.RLock() + existingTracks := make([]*publishedTrack, 0, len(r.publishedTracks)) + for _, track := range r.publishedTracks { + // Don't send peer's own tracks back to them + if track.sourcePeerID != peer.ID { + existingTracks = append(existingTracks, track) + } + } + r.publishedTracksMu.RUnlock() + + if len(existingTracks) == 0 { + r.logger.Info("No existing tracks to send to new peer", + zap.String("peer_id", peer.ID), + ) + return + } + + r.logger.Info("Sending existing tracks to new peer (batch mode)", + zap.String("peer_id", peer.ID), + zap.Int("track_count", len(existingTracks)), + ) + + videoTrackIDs := make([]string, 0) + + // Start batch mode - suppresses individual renegotiations + peer.StartTrackBatch() + + for _, track := range existingTracks { + // Log forwarder status to help diagnose video issues + r.logger.Info("Adding existing track to new peer", + zap.String("new_peer_id", peer.ID), + zap.String("source_peer_id", track.sourcePeerID), + zap.String("source_user_id", track.sourceUserID), + zap.String("track_id", track.localTrack.ID()), + zap.String("kind", track.kind), + zap.Bool("forwarder_active", track.forwarderActive), + zap.Int("packets_forwarded", track.packetCount), + ) + + // Warn if 
forwarder is no longer active + if !track.forwarderActive { + r.logger.Warn("WARNING: Track forwarder is NOT active - track may not receive data", + zap.String("track_id", track.localTrack.ID()), + zap.String("kind", track.kind), + zap.Int("final_packet_count", track.packetCount), + ) + } + + if _, err := peer.AddTrack(track.localTrack); err != nil { + r.logger.Warn("Failed to add existing track to new peer", + zap.String("peer_id", peer.ID), + zap.String("track_id", track.localTrack.ID()), + zap.Error(err), + ) + continue + } + + // Track video tracks for keyframe requests + if track.kind == "video" { + videoTrackIDs = append(videoTrackIDs, track.localTrack.ID()) + } + + // Notify new peer about the existing track + peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{ + ParticipantID: track.sourcePeerID, + UserID: track.sourceUserID, + TrackID: track.localTrack.ID(), + StreamID: track.localTrack.StreamID(), + Kind: track.kind, + })) + } + + // End batch mode - triggers single renegotiation for all tracks + peer.EndTrackBatch() + + r.logger.Info("Batch track addition complete - single renegotiation triggered", + zap.String("peer_id", peer.ID), + zap.Int("total_tracks", len(existingTracks)), + ) + + // Request keyframes for video tracks after a short delay + // This ensures the receiver has time to set up the track before receiving the keyframe + if len(videoTrackIDs) > 0 { + go func() { + // Wait for negotiation to complete + time.Sleep(300 * time.Millisecond) + r.logger.Info("Requesting keyframes for new peer", + zap.String("peer_id", peer.ID), + zap.Int("video_track_count", len(videoTrackIDs)), + ) + for _, trackID := range videoTrackIDs { + r.RequestKeyframe(trackID) + } + // Request again after 500ms in case the first was too early + time.Sleep(500 * time.Millisecond) + for _, trackID := range videoTrackIDs { + r.RequestKeyframe(trackID) + } + }() + } +} + +// BroadcastTrack broadcasts a track from one peer to all other peers +func (r 
*Room) BroadcastTrack(sourcePeerID string, track *webrtc.TrackRemote) { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + + // Get source peer's user ID for the track-added message + sourceUserID := "" + if sourcePeer, ok := r.peers[sourcePeerID]; ok { + sourceUserID = sourcePeer.UserID + } + + // Create a local track from the remote track + // Use participant ID as the stream ID so clients can identify the source + // Format: trackId stays the same, streamId = sourcePeerID (or sourceUserID if available) + streamID := sourcePeerID + if sourceUserID != "" { + streamID = sourceUserID // Use userID for easier client-side mapping + } + + r.logger.Info("Creating local track for broadcast", + zap.String("source_peer_id", sourcePeerID), + zap.String("source_user_id", sourceUserID), + zap.String("original_track_id", track.ID()), + zap.String("original_stream_id", track.StreamID()), + zap.String("new_stream_id", streamID), + ) + + // Log codec information for debugging + codec := track.Codec() + r.logger.Info("Track codec info", + zap.String("source_peer_id", sourcePeerID), + zap.String("track_kind", track.Kind().String()), + zap.String("mime_type", codec.MimeType), + zap.Uint32("clock_rate", codec.ClockRate), + zap.Uint16("channels", codec.Channels), + zap.String("sdp_fmtp_line", codec.SDPFmtpLine), + ) + + localTrack, err := webrtc.NewTrackLocalStaticRTP( + codec.RTPCodecCapability, + track.Kind().String()+"-"+sourcePeerID, // Include peer ID in track ID + streamID, // Use participant/user ID as stream ID + ) + if err != nil { + r.logger.Error("Failed to create local track", + zap.String("source_peer", sourcePeerID), + zap.String("track_id", track.ID()), + zap.Error(err), + ) + return + } + + // Store the track for new joiners + pubTrack := &publishedTrack{ + sourcePeerID: sourcePeerID, + sourceUserID: sourceUserID, + localTrack: localTrack, + remoteTrackSSRC: uint32(track.SSRC()), + remoteTrack: track, + kind: track.Kind().String(), + forwarderActive: true, + 
packetCount: 0, + lastKeyframeRequest: time.Now(), + } + r.publishedTracksMu.Lock() + r.publishedTracks[localTrack.ID()] = pubTrack + r.publishedTracksMu.Unlock() + + r.logger.Info("Track stored for new joiners", + zap.String("track_id", localTrack.ID()), + zap.String("source_peer_id", sourcePeerID), + zap.Int("total_published_tracks", len(r.publishedTracks)), + ) + + // Forward RTP packets from remote track to local track + go func() { + trackID := track.ID() + localTrackID := localTrack.ID() + trackKind := track.Kind().String() + buf := make([]byte, 1600) // Slightly larger than MTU to handle RTP extensions + packetCount := 0 + byteCount := 0 + startTime := time.Now() + firstPacketReceived := false + + r.logger.Info("RTP forwarder started", + zap.String("track_id", trackID), + zap.String("local_track_id", localTrackID), + zap.String("kind", trackKind), + zap.String("source_peer_id", sourcePeerID), + ) + + // Start a goroutine to log warning if no packets received after 5 seconds + go func() { + time.Sleep(5 * time.Second) + if !firstPacketReceived { + r.logger.Warn("RTP forwarder WARNING: No packets received after 5 seconds - host may not be sending", + zap.String("track_id", trackID), + zap.String("local_track_id", localTrackID), + zap.String("kind", trackKind), + zap.String("source_peer_id", sourcePeerID), + ) + } + }() + + // For video tracks, use adaptive keyframe requests based on packet loss + if trackKind == "video" { + go func() { + // Use a faster ticker for checking, but only send keyframes when needed + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + var lastNackCount int64 + var consecutiveLossDetections int + baseInterval := 3 * time.Second + minInterval := 500 * time.Millisecond // Minimum interval between keyframes + lastKeyframeTime := time.Now() + + for range ticker.C { + // Check if room is closed + if r.IsClosed() { + return + } + + // Check if forwarder is still active + r.publishedTracksMu.RLock() + pt, ok := 
r.publishedTracks[localTrackID] + if !ok || !pt.forwarderActive { + r.publishedTracksMu.RUnlock() + return + } + + // Get current NACK count to detect packet loss + currentNackCount := pt.nackCount.Load() + r.publishedTracksMu.RUnlock() + + timeSinceLastKeyframe := time.Since(lastKeyframeTime) + + // Detect if packet loss is happening (NACKs increasing) + if currentNackCount > lastNackCount { + consecutiveLossDetections++ + lastNackCount = currentNackCount + + // If we detect packet loss and haven't requested a keyframe recently, + // request one immediately to help receivers recover + if timeSinceLastKeyframe >= minInterval { + r.logger.Debug("Adaptive keyframe request due to packet loss", + zap.String("track_id", localTrackID), + zap.Int64("nack_count", currentNackCount), + zap.Int("consecutive_loss_detections", consecutiveLossDetections), + ) + r.RequestKeyframe(localTrackID) + lastKeyframeTime = time.Now() + } + } else { + // Reset consecutive loss counter when no new NACKs + if consecutiveLossDetections > 0 { + consecutiveLossDetections-- + } + } + + // Regular keyframe request at base interval (regardless of loss) + if timeSinceLastKeyframe >= baseInterval { + r.RequestKeyframe(localTrackID) + lastKeyframeTime = time.Now() + } + } + }() + } + + // Mark forwarder as stopped when we exit + defer func() { + r.publishedTracksMu.Lock() + if pt, ok := r.publishedTracks[localTrackID]; ok { + pt.forwarderActive = false + pt.packetCount = packetCount + } + r.publishedTracksMu.Unlock() + + r.logger.Info("RTP forwarder exiting", + zap.String("track_id", trackID), + zap.String("local_track_id", localTrackID), + zap.String("kind", trackKind), + zap.Duration("lifetime", time.Since(startTime)), + zap.Int("total_packets", packetCount), + zap.Int("total_bytes", byteCount), + ) + }() + + for { + n, _, readErr := track.Read(buf) + if readErr != nil { + r.logger.Info("RTP forwarder stopped - read error", + zap.String("track_id", trackID), + zap.String("local_track_id", 
localTrackID), + zap.String("kind", trackKind), + zap.Int("packets_forwarded", packetCount), + zap.Int("bytes_forwarded", byteCount), + zap.Error(readErr), + ) + return + } + + // Log first packet to confirm data is flowing + if packetCount == 0 { + firstPacketReceived = true + r.logger.Info("RTP forwarder received FIRST packet", + zap.String("track_id", trackID), + zap.String("local_track_id", localTrackID), + zap.String("kind", trackKind), + zap.Int("packet_size", n), + zap.Duration("time_to_first_packet", time.Since(startTime)), + ) + } + + if _, writeErr := localTrack.Write(buf[:n]); writeErr != nil { + r.logger.Info("RTP forwarder stopped - write error", + zap.String("track_id", trackID), + zap.String("local_track_id", localTrackID), + zap.String("kind", trackKind), + zap.Int("packets_forwarded", packetCount), + zap.Int("bytes_forwarded", byteCount), + zap.Error(writeErr), + ) + return + } + + packetCount++ + byteCount += n + + // Update packet count periodically (not every packet to reduce lock contention) + if packetCount%50 == 0 { + r.publishedTracksMu.Lock() + if pt, ok := r.publishedTracks[localTrackID]; ok { + pt.packetCount = packetCount + } + r.publishedTracksMu.Unlock() + } + + // Log progress every 100 packets for video, 500 for audio + logInterval := 500 + if trackKind == "video" { + logInterval = 100 + } + if packetCount%logInterval == 0 { + r.logger.Info("RTP forwarder progress", + zap.String("track_id", trackID), + zap.String("kind", trackKind), + zap.Int("packets", packetCount), + zap.Int("bytes", byteCount), + ) + } + } + }() + + // Add track to all other peers + for peerID, peer := range r.peers { + if peerID == sourcePeerID { + continue + } + + r.logger.Info("Adding track to peer", + zap.String("target_peer", peerID), + zap.String("source_peer", sourcePeerID), + zap.String("source_user", sourceUserID), + zap.String("track_id", localTrack.ID()), + zap.String("stream_id", localTrack.StreamID()), + ) + + if _, err := peer.AddTrack(localTrack); 
err != nil { + r.logger.Warn("Failed to add track to peer", + zap.String("target_peer", peerID), + zap.String("track_id", localTrack.ID()), + zap.Error(err), + ) + continue + } + + // Notify peer about new track + // Use consistent IDs: participantId=sourcePeerID, streamId matches the track + peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{ + ParticipantID: sourcePeerID, + UserID: sourceUserID, + TrackID: localTrack.ID(), // Format: "{kind}-{participantId}" + StreamID: localTrack.StreamID(), // Same as userId for easy matching + Kind: track.Kind().String(), + })) + } + + r.logger.Info("Track broadcast to room", + zap.String("source_peer", sourcePeerID), + zap.String("source_user", sourceUserID), + zap.String("track_id", track.ID()), + zap.String("kind", track.Kind().String()), + ) +} + +// broadcastMessage sends a message to all peers except the specified one +func (r *Room) broadcastMessage(excludePeerID string, msg *ServerMessage) { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + + for peerID, peer := range r.peers { + if peerID == excludePeerID { + continue + } + + if err := peer.SendMessage(msg); err != nil { + r.logger.Warn("Failed to send message to peer", + zap.String("peer_id", peerID), + zap.Error(err), + ) + } + } +} + +// Close closes the room and all peer connections +func (r *Room) Close() error { + r.closedMu.Lock() + if r.closed { + r.closedMu.Unlock() + return nil + } + r.closed = true + r.closedMu.Unlock() + + r.logger.Info("Closing room") + + r.peersMu.Lock() + peers := make([]*Peer, 0, len(r.peers)) + for _, peer := range r.peers { + peers = append(peers, peer) + } + r.peers = make(map[string]*Peer) + r.peersMu.Unlock() + + // Close all peers + for _, peer := range peers { + if err := peer.Close(); err != nil { + r.logger.Warn("Error closing peer", + zap.String("peer_id", peer.ID), + zap.Error(err), + ) + } + } + + return nil +} + +// OnEmpty sets a callback for when the room becomes empty +func (r *Room) OnEmpty(fn 
func(*Room)) { + r.onEmpty = fn +} + +// IsClosed returns whether the room is closed +func (r *Room) IsClosed() bool { + r.closedMu.RLock() + defer r.closedMu.RUnlock() + return r.closed +} + +// RequestKeyframe sends a PLI (Picture Loss Indication) to the source peer for a video track. +// This causes the source to send a keyframe, which is needed for new receivers to start decoding. +func (r *Room) RequestKeyframe(trackID string) { + r.publishedTracksMu.RLock() + track, ok := r.publishedTracks[trackID] + r.publishedTracksMu.RUnlock() + + if !ok || track.kind != "video" { + return + } + + r.peersMu.RLock() + sourcePeer, ok := r.peers[track.sourcePeerID] + r.peersMu.RUnlock() + + if !ok || sourcePeer.pc == nil { + r.logger.Debug("Cannot request keyframe - source peer not found", + zap.String("track_id", trackID), + zap.String("source_peer_id", track.sourcePeerID), + ) + return + } + + // Create a PLI packet + pli := &rtcp.PictureLossIndication{ + MediaSSRC: track.remoteTrackSSRC, + } + + // Send the PLI to the source peer + if err := sourcePeer.pc.WriteRTCP([]rtcp.Packet{pli}); err != nil { + r.logger.Debug("Failed to send PLI", + zap.String("track_id", trackID), + zap.String("source_peer_id", track.sourcePeerID), + zap.Error(err), + ) + return + } + + r.logger.Debug("PLI keyframe request sent", + zap.String("track_id", trackID), + zap.String("source_peer_id", track.sourcePeerID), + zap.Uint32("ssrc", track.remoteTrackSSRC), + ) +} + +// RequestKeyframeForAllVideoTracks sends PLI requests for all video tracks in the room. +// This is useful when a new peer joins to ensure they get keyframes quickly. 
+func (r *Room) RequestKeyframeForAllVideoTracks() { + r.publishedTracksMu.RLock() + videoTrackIDs := make([]string, 0) + for trackID, track := range r.publishedTracks { + if track.kind == "video" { + videoTrackIDs = append(videoTrackIDs, trackID) + } + } + r.publishedTracksMu.RUnlock() + + for _, trackID := range videoTrackIDs { + r.RequestKeyframe(trackID) + } +} + +// IncrementNackCount increments the NACK counter for a track. +// This is called when we receive NACK feedback indicating packet loss. +func (r *Room) IncrementNackCount(trackID string) { + r.publishedTracksMu.RLock() + track, ok := r.publishedTracks[trackID] + r.publishedTracksMu.RUnlock() + + if ok { + track.nackCount.Add(1) + } +} diff --git a/pkg/gateway/sfu/signaling.go b/pkg/gateway/sfu/signaling.go new file mode 100644 index 0000000..4e2187d --- /dev/null +++ b/pkg/gateway/sfu/signaling.go @@ -0,0 +1,145 @@ +package sfu + +import ( + "encoding/json" + + "github.com/pion/webrtc/v4" +) + +// MessageType represents the type of signaling message +type MessageType string + +const ( + // Client -> Server message types + MessageTypeJoin MessageType = "join" + MessageTypeLeave MessageType = "leave" + MessageTypeOffer MessageType = "offer" + MessageTypeAnswer MessageType = "answer" + MessageTypeICECandidate MessageType = "ice-candidate" + MessageTypeMute MessageType = "mute" + MessageTypeUnmute MessageType = "unmute" + MessageTypeStartVideo MessageType = "start-video" + MessageTypeStopVideo MessageType = "stop-video" + + // Server -> Client message types + MessageTypeWelcome MessageType = "welcome" + MessageTypeParticipantJoined MessageType = "participant-joined" + MessageTypeParticipantLeft MessageType = "participant-left" + MessageTypeTrackAdded MessageType = "track-added" + MessageTypeTrackRemoved MessageType = "track-removed" + MessageTypeError MessageType = "error" +) + +// ClientMessage represents a message from client to server +type ClientMessage struct { + Type MessageType `json:"type"` + Data 
json.RawMessage `json:"data,omitempty"` +} + +// ServerMessage represents a message from server to client +type ServerMessage struct { + Type MessageType `json:"type"` + Data interface{} `json:"data,omitempty"` +} + +// JoinData is the payload for join messages +type JoinData struct { + DisplayName string `json:"displayName"` + AudioOnly bool `json:"audioOnly,omitempty"` +} + +// OfferData is the payload for offer messages +type OfferData struct { + SDP string `json:"sdp"` +} + +// AnswerData is the payload for answer messages +type AnswerData struct { + SDP string `json:"sdp"` +} + +// ICECandidateData is the payload for ICE candidate messages +type ICECandidateData struct { + Candidate string `json:"candidate"` + SDPMid string `json:"sdpMid,omitempty"` + SDPMLineIndex uint16 `json:"sdpMLineIndex,omitempty"` + UsernameFragment string `json:"usernameFragment,omitempty"` +} + +// ToWebRTCCandidate converts ICECandidateData to webrtc.ICECandidateInit +func (c *ICECandidateData) ToWebRTCCandidate() webrtc.ICECandidateInit { + return webrtc.ICECandidateInit{ + Candidate: c.Candidate, + SDPMid: &c.SDPMid, + SDPMLineIndex: &c.SDPMLineIndex, + UsernameFragment: &c.UsernameFragment, + } +} + +// WelcomeData is sent to a client when they successfully join +type WelcomeData struct { + ParticipantID string `json:"participantId"` + RoomID string `json:"roomId"` + Participants []ParticipantInfo `json:"participants"` +} + +// ParticipantInfo contains public information about a participant +type ParticipantInfo struct { + ID string `json:"id"` + UserID string `json:"userId"` + DisplayName string `json:"displayName"` + HasAudio bool `json:"hasAudio"` + HasVideo bool `json:"hasVideo"` + AudioMuted bool `json:"audioMuted"` + VideoMuted bool `json:"videoMuted"` +} + +// ParticipantJoinedData is sent when a new participant joins +type ParticipantJoinedData struct { + Participant ParticipantInfo `json:"participant"` +} + +// ParticipantLeftData is sent when a participant leaves +type 
ParticipantLeftData struct { + ParticipantID string `json:"participantId"` +} + +// TrackAddedData is sent when a new track is available +type TrackAddedData struct { + ParticipantID string `json:"participantId"` // Internal SFU peer ID + UserID string `json:"userId"` // The user's actual ID for easier mapping + TrackID string `json:"trackId"` // Format: "{kind}-{participantId}" + StreamID string `json:"streamId"` // Same as userId (for WebRTC stream matching) + Kind string `json:"kind"` // "audio" or "video" +} + +// TrackRemovedData is sent when a track is removed +type TrackRemovedData struct { + ParticipantID string `json:"participantId"` + UserID string `json:"userId"` // The user's actual ID for easier mapping + TrackID string `json:"trackId"` + StreamID string `json:"streamId"` + Kind string `json:"kind"` +} + +// ErrorData is sent when an error occurs +type ErrorData struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// NewServerMessage creates a new server message +func NewServerMessage(msgType MessageType, data interface{}) *ServerMessage { + return &ServerMessage{ + Type: msgType, + Data: data, + } +} + +// NewErrorMessage creates a new error message +func NewErrorMessage(code, message string) *ServerMessage { + return NewServerMessage(MessageTypeError, &ErrorData{ + Code: code, + Message: message, + }) +} diff --git a/pkg/gateway/sfu_handlers.go b/pkg/gateway/sfu_handlers.go new file mode 100644 index 0000000..5b4ab80 --- /dev/null +++ b/pkg/gateway/sfu_handlers.go @@ -0,0 +1,573 @@ +package gateway + +import ( + "encoding/json" + "net/http" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/gateway/sfu" + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/gorilla/websocket" + "github.com/pion/webrtc/v4" + "go.uber.org/zap" +) + +// SFUManager is a type alias for the SFU room manager +type SFUManager = sfu.RoomManager + +// CreateRoomRequest is the request body for creating a room +type CreateRoomRequest 
struct { + RoomID string `json:"roomId"` +} + +// CreateRoomResponse is the response for creating/joining a room +type CreateRoomResponse struct { + RoomID string `json:"roomId"` + Created bool `json:"created"` + RTPCapabilities map[string]interface{} `json:"rtpCapabilities"` +} + +// JoinRoomRequest is the request body for joining a room +type JoinRoomRequest struct { + DisplayName string `json:"displayName"` +} + +// JoinRoomResponse is the response for joining a room +type JoinRoomResponse struct { + ParticipantID string `json:"participantId"` + Participants []sfu.ParticipantInfo `json:"participants"` +} + +// sfuCreateRoomHandler handles POST /v1/sfu/room +func (g *Gateway) sfuCreateRoomHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + // Check if SFU is enabled + if g.sfuManager == nil { + writeError(w, http.StatusServiceUnavailable, "SFU service not enabled") + return + } + + // Get namespace from auth context + ns := resolveNamespaceFromRequest(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + + // Parse request + var req CreateRoomRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid request body") + return + } + + if req.RoomID == "" { + writeError(w, http.StatusBadRequest, "roomId is required") + return + } + + // Create or get room + room, created := g.sfuManager.GetOrCreateRoom(ns, req.RoomID) + + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU room request", + zap.String("room_id", req.RoomID), + zap.String("namespace", ns), + zap.Bool("created", created), + zap.Int("participants", room.GetParticipantCount()), + ) + + writeJSON(w, http.StatusOK, &CreateRoomResponse{ + RoomID: req.RoomID, + Created: created, + RTPCapabilities: sfu.GetRTPCapabilities(), + }) +} + +// sfuRoomHandler handles all /v1/sfu/room/:roomId/* 
endpoints +func (g *Gateway) sfuRoomHandler(w http.ResponseWriter, r *http.Request) { + // Check if SFU is enabled + if g.sfuManager == nil { + writeError(w, http.StatusServiceUnavailable, "SFU service not enabled") + return + } + + // Get namespace from auth context + ns := resolveNamespaceFromRequest(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + + // Parse room ID and action from path + // Path format: /v1/sfu/room/{roomId}/{action} + path := strings.TrimPrefix(r.URL.Path, "/v1/sfu/room/") + parts := strings.SplitN(path, "/", 2) + + if len(parts) == 0 || parts[0] == "" { + writeError(w, http.StatusBadRequest, "room ID required") + return + } + + roomID := parts[0] + action := "" + if len(parts) > 1 { + action = parts[1] + } + + // Route to appropriate handler + switch action { + case "": + // GET /v1/sfu/room/:roomId - Get room info + g.sfuGetRoomHandler(w, r, ns, roomID) + case "join": + // POST /v1/sfu/room/:roomId/join - Join room + g.sfuJoinRoomHandler(w, r, ns, roomID) + case "leave": + // POST /v1/sfu/room/:roomId/leave - Leave room + g.sfuLeaveRoomHandler(w, r, ns, roomID) + case "participants": + // GET /v1/sfu/room/:roomId/participants - List participants + g.sfuParticipantsHandler(w, r, ns, roomID) + case "ws": + // GET /v1/sfu/room/:roomId/ws - WebSocket signaling + g.sfuWebSocketHandler(w, r, ns, roomID) + default: + writeError(w, http.StatusNotFound, "unknown action") + } +} + +// sfuGetRoomHandler handles GET /v1/sfu/room/:roomId +func (g *Gateway) sfuGetRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + info, err := g.sfuManager.GetRoomInfo(ns, roomID) + if err != nil { + if err == sfu.ErrRoomNotFound { + writeError(w, http.StatusNotFound, "room not found") + } else { + writeError(w, http.StatusInternalServerError, err.Error()) + } + return + } + + 
writeJSON(w, http.StatusOK, info) +} + +// sfuJoinRoomHandler handles POST /v1/sfu/room/:roomId/join +func (g *Gateway) sfuJoinRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + // Parse request + var req JoinRoomRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + // Default display name if not provided + req.DisplayName = "Anonymous" + } + + // Get user ID from query param or auth context + userID := r.URL.Query().Get("userId") + if userID == "" { + userID = g.extractUserID(r) + } + if userID == "" { + userID = "anonymous" + } + + // Get or create room + room, _ := g.sfuManager.GetOrCreateRoom(ns, roomID) + + // This endpoint just returns room info + // The actual peer connection is established via WebSocket + writeJSON(w, http.StatusOK, &JoinRoomResponse{ + ParticipantID: "", // Will be assigned when WebSocket connects + Participants: room.GetParticipants(), + }) +} + +// sfuLeaveRoomHandler handles POST /v1/sfu/room/:roomId/leave +func (g *Gateway) sfuLeaveRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + // Parse participant ID from request body + var req struct { + ParticipantID string `json:"participantId"` + } + if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.ParticipantID == "" { + writeError(w, http.StatusBadRequest, "participantId required") + return + } + + room, err := g.sfuManager.GetRoom(ns, roomID) + if err != nil { + if err == sfu.ErrRoomNotFound { + writeError(w, http.StatusNotFound, "room not found") + } else { + writeError(w, http.StatusInternalServerError, err.Error()) + } + return + } + + if err := room.RemovePeer(req.ParticipantID); err != nil { + if err == sfu.ErrPeerNotFound { + writeError(w, http.StatusNotFound, "participant 
not found") + } else { + writeError(w, http.StatusInternalServerError, err.Error()) + } + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "success": true, + }) +} + +// sfuParticipantsHandler handles GET /v1/sfu/room/:roomId/participants +func (g *Gateway) sfuParticipantsHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + room, err := g.sfuManager.GetRoom(ns, roomID) + if err != nil { + if err == sfu.ErrRoomNotFound { + writeError(w, http.StatusNotFound, "room not found") + } else { + writeError(w, http.StatusInternalServerError, err.Error()) + } + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "participants": room.GetParticipants(), + }) +} + +// sfuWebSocketHandler handles WebSocket signaling for a room +func (g *Gateway) sfuWebSocketHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + // Get user ID and display name from query parameters + // Priority: 1) userId query param, 2) JWT/API key auth, 3) "anonymous" + userID := r.URL.Query().Get("userId") + if userID == "" { + // Fall back to authentication-based user ID + userID = g.extractUserID(r) + } + if userID == "" { + userID = "anonymous" + } + + displayName := r.URL.Query().Get("displayName") + if displayName == "" { + displayName = "Anonymous" + } + + // Upgrade to WebSocket + conn, err := wsUpgrader.Upgrade(w, r, nil) + if err != nil { + g.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket upgrade failed", + zap.Error(err), + ) + return + } + + // Recover from panics to avoid silent crashes + defer func() { + if r := recover(); r != nil { + g.logger.ComponentError(logging.ComponentGeneral, "SFU WebSocket handler panic", + zap.Any("panic", r), + zap.String("room_id", roomID), + 
zap.String("user_id", userID), + ) + conn.Close() + } + }() + + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU WebSocket connected", + zap.String("room_id", roomID), + zap.String("namespace", ns), + zap.String("user_id", userID), + zap.String("display_name", displayName), + ) + + // Get or create room + g.logger.ComponentDebug(logging.ComponentGeneral, "SFU getting/creating room", + zap.String("room_id", roomID), + zap.String("namespace", ns), + ) + room, created := g.sfuManager.GetOrCreateRoom(ns, roomID) + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU room ready", + zap.String("room_id", roomID), + zap.Bool("created", created), + ) + + // Create peer + peer := sfu.NewPeer(userID, displayName, conn, room, g.logger.Logger) + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU peer created", + zap.String("peer_id", peer.ID), + ) + + // Add peer to room + g.logger.ComponentDebug(logging.ComponentGeneral, "SFU adding peer to room (will init peer connection)", + zap.String("peer_id", peer.ID), + zap.String("room_id", roomID), + ) + if err := room.AddPeer(peer); err != nil { + g.logger.ComponentError(logging.ComponentGeneral, "Failed to add peer to room", + zap.String("room_id", roomID), + zap.Error(err), + ) + peer.SendMessage(sfu.NewErrorMessage("join_failed", err.Error())) + conn.Close() + return + } + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU peer added to room successfully", + zap.String("peer_id", peer.ID), + zap.String("room_id", roomID), + ) + + // Send welcome message + g.logger.ComponentDebug(logging.ComponentGeneral, "SFU preparing welcome message", + zap.String("peer_id", peer.ID), + zap.Int("num_participants", len(room.GetParticipants())), + ) + welcomeMsg := sfu.NewServerMessage(sfu.MessageTypeWelcome, &sfu.WelcomeData{ + ParticipantID: peer.ID, + RoomID: roomID, + Participants: room.GetParticipants(), + }) + g.logger.ComponentDebug(logging.ComponentGeneral, "SFU sending welcome message via SendMessage", + zap.String("peer_id", 
peer.ID), + ) + if err := peer.SendMessage(welcomeMsg); err != nil { + g.logger.ComponentError(logging.ComponentGeneral, "Failed to send welcome message", + zap.String("peer_id", peer.ID), + zap.Error(err), + ) + conn.Close() + return + } + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU welcome message sent successfully", + zap.String("peer_id", peer.ID), + zap.String("room_id", roomID), + ) + + // Handle signaling messages + // Note: existing tracks are sent AFTER the first offer/answer exchange completes + g.handleSFUSignaling(conn, peer, room) +} + +// handleSFUSignaling handles WebSocket signaling messages for a peer +func (g *Gateway) handleSFUSignaling(conn *websocket.Conn, peer *sfu.Peer, room *sfu.Room) { + defer func() { + room.RemovePeer(peer.ID) + conn.Close() + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU WebSocket disconnected", + zap.String("peer_id", peer.ID), + zap.String("room_id", room.ID), + ) + }() + + // Set up ping/pong for keepalive + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + conn.SetPongHandler(func(string) error { + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + + // Start ping ticker + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + go func() { + for range ticker.C { + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil { + return + } + } + }() + + // Read messages + for { + _, data, err := conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + g.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket read error", + zap.String("peer_id", peer.ID), + zap.Error(err), + ) + } + return + } + + // Reset read deadline + conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + // Parse message + var msg sfu.ClientMessage + if err := json.Unmarshal(data, &msg); err != nil { + 
peer.SendMessage(sfu.NewErrorMessage("invalid_message", "failed to parse message")) + continue + } + + // Handle message + g.handleSFUMessage(peer, room, &msg) + } +} + +// handleSFUMessage handles a single signaling message +func (g *Gateway) handleSFUMessage(peer *sfu.Peer, room *sfu.Room, msg *sfu.ClientMessage) { + switch msg.Type { + case sfu.MessageTypeOffer: + var data sfu.OfferData + if err := json.Unmarshal(msg.Data, &data); err != nil { + peer.SendMessage(sfu.NewErrorMessage("invalid_offer", err.Error())) + return + } + if err := peer.HandleOffer(data.SDP); err != nil { + peer.SendMessage(sfu.NewErrorMessage("offer_failed", err.Error())) + return + } + + // After successfully handling the FIRST offer, send existing tracks + // This ensures the WebRTC connection is established before adding more tracks + if peer.MarkInitialOfferHandled() { + g.logger.ComponentInfo(logging.ComponentGeneral, "First offer handled, sending existing tracks", + zap.String("peer_id", peer.ID), + ) + room.SendExistingTracksTo(peer) + } + + case sfu.MessageTypeAnswer: + var data sfu.AnswerData + if err := json.Unmarshal(msg.Data, &data); err != nil { + peer.SendMessage(sfu.NewErrorMessage("invalid_answer", err.Error())) + return + } + if err := peer.HandleAnswer(data.SDP); err != nil { + peer.SendMessage(sfu.NewErrorMessage("answer_failed", err.Error())) + } + + case sfu.MessageTypeICECandidate: + var data sfu.ICECandidateData + if err := json.Unmarshal(msg.Data, &data); err != nil { + peer.SendMessage(sfu.NewErrorMessage("invalid_candidate", err.Error())) + return + } + if err := peer.HandleICECandidate(&data); err != nil { + peer.SendMessage(sfu.NewErrorMessage("candidate_failed", err.Error())) + } + + case sfu.MessageTypeMute: + peer.SetAudioMuted(true) + g.logger.ComponentDebug(logging.ComponentGeneral, "Peer muted audio", zap.String("peer_id", peer.ID)) + + case sfu.MessageTypeUnmute: + peer.SetAudioMuted(false) + g.logger.ComponentDebug(logging.ComponentGeneral, "Peer unmuted 
audio", zap.String("peer_id", peer.ID)) + + case sfu.MessageTypeStartVideo: + peer.SetVideoMuted(false) + g.logger.ComponentDebug(logging.ComponentGeneral, "Peer started video", zap.String("peer_id", peer.ID)) + + case sfu.MessageTypeStopVideo: + peer.SetVideoMuted(true) + g.logger.ComponentDebug(logging.ComponentGeneral, "Peer stopped video", zap.String("peer_id", peer.ID)) + + case sfu.MessageTypeLeave: + // Will be handled by deferred cleanup + g.logger.ComponentInfo(logging.ComponentGeneral, "Peer leaving room", zap.String("peer_id", peer.ID)) + + default: + peer.SendMessage(sfu.NewErrorMessage("unknown_type", "unknown message type")) + } +} + +// initializeSFUManager initializes the SFU manager with the gateway's TURN config +func (g *Gateway) initializeSFUManager() error { + if g.cfg.SFU == nil || !g.cfg.SFU.Enabled { + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU service disabled") + return nil + } + + // Build ICE servers from config + iceServers := make([]webrtc.ICEServer, 0) + + // Add configured ICE servers + if g.cfg.SFU.ICEServers != nil { + for _, server := range g.cfg.SFU.ICEServers { + iceServers = append(iceServers, webrtc.ICEServer{ + URLs: server.URLs, + Username: server.Username, + Credential: server.Credential, + }) + } + } + + // Add TURN servers if configured (credentials will be generated dynamically) + if g.cfg.TURN != nil { + // Determine hostname for ICE server URLs + // Use configured domain, or fallback to localhost for local development + iceHost := g.cfg.DomainName + if iceHost == "" { + iceHost = "localhost" + } + + if len(g.cfg.TURN.STUNURLs) > 0 { + // Process URLs to replace empty hostnames (e.g., "stun::3478" -> "stun:localhost:3478") + processedURLs := processURLsWithHost(g.cfg.TURN.STUNURLs, iceHost) + iceServers = append(iceServers, webrtc.ICEServer{ + URLs: processedURLs, + }) + } + // Note: TURN credentials are time-limited, so clients should fetch them + // via the /v1/turn/credentials endpoint before joining a 
room + } + + // Create SFU config + sfuConfig := &sfu.Config{ + MaxParticipants: g.cfg.SFU.MaxParticipants, + MediaTimeout: g.cfg.SFU.MediaTimeout, + ICEServers: iceServers, + } + + if sfuConfig.MaxParticipants == 0 { + sfuConfig.MaxParticipants = 10 + } + if sfuConfig.MediaTimeout == 0 { + sfuConfig.MediaTimeout = 30 * time.Second + } + + manager, err := sfu.NewRoomManager(sfuConfig, g.logger.Logger) + if err != nil { + return err + } + + g.sfuManager = manager + + g.logger.ComponentInfo(logging.ComponentGeneral, "SFU manager initialized", + zap.Int("max_participants", sfuConfig.MaxParticipants), + zap.Duration("media_timeout", sfuConfig.MediaTimeout), + zap.Int("ice_servers", len(iceServers)), + ) + + return nil +} diff --git a/pkg/gateway/turn_handlers.go b/pkg/gateway/turn_handlers.go new file mode 100644 index 0000000..1ca8e38 --- /dev/null +++ b/pkg/gateway/turn_handlers.go @@ -0,0 +1,192 @@ +package gateway + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "net" + "net/http" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/gateway/auth" + "github.com/DeBrosOfficial/network/pkg/logging" + "go.uber.org/zap" +) + +// TURNCredentialsResponse is the response for TURN credential requests +type TURNCredentialsResponse struct { + Username string `json:"username"` // Format: "timestamp:userId" + Credential string `json:"credential"` // HMAC-SHA1(username, shared_secret) base64 encoded + TTL int64 `json:"ttl"` // Time-to-live in seconds + STUNURLs []string `json:"stun_urls"` // STUN server URLs + TURNURLs []string `json:"turn_urls"` // TURN server URLs +} + +// turnCredentialsHandler handles POST /v1/turn/credentials +// Returns time-limited TURN credentials for WebRTC connections +func (g *Gateway) turnCredentialsHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost && r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + // Check if TURN is 
configured + if g.cfg.TURN == nil || g.cfg.TURN.SharedSecret == "" { + g.logger.ComponentWarn(logging.ComponentGeneral, "TURN credentials requested but not configured") + writeError(w, http.StatusServiceUnavailable, "TURN service not configured") + return + } + + // Get user ID from JWT claims or API key + userID := g.extractUserID(r) + if userID == "" { + userID = "anonymous" + } + + // Get gateway hostname from request + gatewayHost := r.Host + if idx := strings.Index(gatewayHost, ":"); idx != -1 { + gatewayHost = gatewayHost[:idx] // Remove port + } + if gatewayHost == "" { + gatewayHost = "localhost" + } + + // Generate credentials + credentials := g.generateTURNCredentials(userID, gatewayHost) + + g.logger.ComponentInfo(logging.ComponentGeneral, "TURN credentials generated", + zap.String("user_id", userID), + zap.String("gateway_host", gatewayHost), + zap.Int64("ttl", credentials.TTL), + ) + + writeJSON(w, http.StatusOK, credentials) +} + +// generateTURNCredentials creates time-limited TURN credentials using HMAC-SHA1 +func (g *Gateway) generateTURNCredentials(userID, gatewayHost string) *TURNCredentialsResponse { + cfg := g.cfg.TURN + + // Default TTL to 24 hours if not configured + ttl := cfg.TTL + if ttl == 0 { + ttl = 24 * time.Hour + } + + // Calculate expiry timestamp + timestamp := time.Now().Unix() + int64(ttl.Seconds()) + + // Format: "timestamp:userId" (coturn format) + username := fmt.Sprintf("%d:%s", timestamp, userID) + + // Generate HMAC-SHA1 credential + h := hmac.New(sha1.New, []byte(cfg.SharedSecret)) + h.Write([]byte(username)) + credential := base64.StdEncoding.EncodeToString(h.Sum(nil)) + + // Determine the host to use for STUN/TURN URLs + // Priority: 1) ExternalHost from config, 2) Auto-detect LAN IP, 3) Gateway host from request + host := cfg.ExternalHost + if host == "" { + // Auto-detect LAN IP for development + host = detectLANIP() + if host == "" { + // Fallback to gateway host from request (may be localhost) + host = gatewayHost + 
} + } + + // Process URLs - replace empty hostnames (::) with determined host + stunURLs := processURLsWithHost(cfg.STUNURLs, host) + turnURLs := processURLsWithHost(cfg.TURNURLs, host) + + // If TLS is enabled, ensure we have turns:// URLs + if cfg.TLSEnabled { + hasTurns := false + for _, url := range turnURLs { + if strings.HasPrefix(url, "turns:") { + hasTurns = true + break + } + } + // Auto-add turns:// URL if not already configured + if !hasTurns { + turnsURL := fmt.Sprintf("turns:%s:443?transport=tcp", host) + turnURLs = append(turnURLs, turnsURL) + } + } + + return &TURNCredentialsResponse{ + Username: username, + Credential: credential, + TTL: int64(ttl.Seconds()), + STUNURLs: stunURLs, + TURNURLs: turnURLs, + } +} + +// detectLANIP returns the first non-loopback IPv4 address found on the system. +// Returns empty string if no suitable address is found. +func detectLANIP() string { + addrs, err := net.InterfaceAddrs() + if err != nil { + return "" + } + + for _, addr := range addrs { + if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + return ipNet.IP.String() + } + } + } + return "" +} + +// processURLsWithHost replaces empty hostnames in URLs with the given host +// Supports two patterns: +// - "stun:::3478" (triple colon) -> "stun:host:3478" +// - "stun::3478" (double colon) -> "stun:host:3478" +func processURLsWithHost(urls []string, host string) []string { + result := make([]string, 0, len(urls)) + for _, url := range urls { + // Check for triple colon pattern first (e.g., "stun:::3478") + // This is the preferred format: protocol:::port + if strings.Contains(url, ":::") { + url = strings.Replace(url, ":::", ":"+host+":", 1) + } else if strings.Contains(url, "::") { + // Fallback for double colon pattern (e.g., "stun::3478") + url = strings.Replace(url, "::", ":"+host+":", 1) + } + result = append(result, url) + } + return result +} + +// extractUserID extracts the user ID from the request context +func 
(g *Gateway) extractUserID(r *http.Request) string { + ctx := r.Context() + + // Try JWT claims first + if v := ctx.Value(ctxKeyJWT); v != nil { + if claims, ok := v.(*auth.JWTClaims); ok && claims != nil { + if claims.Sub != "" { + return claims.Sub + } + } + } + + // Fallback to API key + if v := ctx.Value(ctxKeyAPIKey); v != nil { + if key, ok := v.(string); ok && key != "" { + // Use a hash of the API key as the user ID for privacy + return fmt.Sprintf("ak_%s", key[:min(8, len(key))]) + } + } + + return "" +} + diff --git a/pkg/node/gateway.go b/pkg/node/gateway.go index 9bada62..0fa2273 100644 --- a/pkg/node/gateway.go +++ b/pkg/node/gateway.go @@ -33,19 +33,21 @@ func (n *Node) startHTTPGateway(ctx context.Context) error { } gwCfg := &gateway.Config{ - ListenAddr: n.config.HTTPGateway.ListenAddr, - ClientNamespace: n.config.HTTPGateway.ClientNamespace, - BootstrapPeers: n.config.Discovery.BootstrapPeers, - NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir), - RQLiteDSN: n.config.HTTPGateway.RQLiteDSN, - OlricServers: n.config.HTTPGateway.OlricServers, - OlricTimeout: n.config.HTTPGateway.OlricTimeout, + ListenAddr: n.config.HTTPGateway.ListenAddr, + ClientNamespace: n.config.HTTPGateway.ClientNamespace, + BootstrapPeers: n.config.Discovery.BootstrapPeers, + NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir), + RQLiteDSN: n.config.HTTPGateway.RQLiteDSN, + OlricServers: n.config.HTTPGateway.OlricServers, + OlricTimeout: n.config.HTTPGateway.OlricTimeout, IPFSClusterAPIURL: n.config.HTTPGateway.IPFSClusterAPIURL, - IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL, - IPFSTimeout: n.config.HTTPGateway.IPFSTimeout, - EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled, - DomainName: n.config.HTTPGateway.HTTPS.Domain, - TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir, + IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL, + IPFSTimeout: n.config.HTTPGateway.IPFSTimeout, + EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled, + DomainName: 
n.config.HTTPGateway.HTTPS.Domain, + TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir, + TURN: n.config.HTTPGateway.TURN, + SFU: n.config.HTTPGateway.SFU, } apiGateway, err := gateway.New(gatewayLogger, gwCfg) diff --git a/pkg/node/node.go b/pkg/node/node.go index eeb4d3b..465238e 100644 --- a/pkg/node/node.go +++ b/pkg/node/node.go @@ -16,6 +16,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/logging" "github.com/DeBrosOfficial/network/pkg/pubsub" database "github.com/DeBrosOfficial/network/pkg/rqlite" + "github.com/DeBrosOfficial/network/pkg/turn" "github.com/libp2p/go-libp2p/core/host" "go.uber.org/zap" "golang.org/x/crypto/acme/autocert" @@ -55,6 +56,9 @@ type Node struct { // Certificate ready signal - closed when TLS certificates are extracted and ready for use certReady chan struct{} + + // Built-in TURN server for WebRTC NAT traversal + turnServer *turn.Server } // NewNode creates a new network node @@ -96,6 +100,11 @@ func (n *Node) Start(ctx context.Context) error { n.logger.ComponentWarn(logging.ComponentNode, "Failed to start HTTP Gateway", zap.Error(err)) } + // Start built-in TURN server if enabled + if err := n.startTURNServer(); err != nil { + n.logger.ComponentWarn(logging.ComponentNode, "Failed to start TURN server", zap.Error(err)) + } + // Start LibP2P host first (needed for cluster discovery) if err := n.startLibP2P(); err != nil { return fmt.Errorf("failed to start LibP2P: %w", err) @@ -135,6 +144,11 @@ func (n *Node) Start(ctx context.Context) error { func (n *Node) Stop() error { n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node") + // Stop TURN server + if n.turnServer != nil { + _ = n.turnServer.Stop() + } + // Stop HTTP Gateway server if n.apiGatewayServer != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) diff --git a/pkg/node/turn.go b/pkg/node/turn.go new file mode 100644 index 0000000..4792cb7 --- /dev/null +++ b/pkg/node/turn.go @@ -0,0 +1,90 @@ +package node + +import ( + "os" + + 
"github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/turn" + "go.uber.org/zap" +) + +// startTURNServer initializes and starts the built-in TURN server +func (n *Node) startTURNServer() error { + if !n.config.TURNServer.Enabled { + n.logger.ComponentInfo(logging.ComponentNode, "Built-in TURN server disabled") + return nil + } + + n.logger.ComponentInfo(logging.ComponentNode, "Starting built-in TURN server") + + // Get shared secret - env var takes priority over config file (for production) + sharedSecret := os.Getenv("TURN_SHARED_SECRET") + if sharedSecret == "" && n.config.HTTPGateway.TURN != nil && n.config.HTTPGateway.TURN.SharedSecret != "" { + sharedSecret = n.config.HTTPGateway.TURN.SharedSecret + } + + if sharedSecret == "" { + n.logger.ComponentWarn(logging.ComponentNode, "TURN server enabled but no shared_secret configured in http_gateway.turn") + return nil + } + + // Get public IP - env var takes priority over config file (for production) + publicIP := os.Getenv("TURN_PUBLIC_IP") + if publicIP == "" { + publicIP = n.config.TURNServer.PublicIP + } + + // Build TURN server config + turnCfg := &turn.Config{ + Enabled: true, + ListenAddr: n.config.TURNServer.ListenAddr, + PublicIP: publicIP, + Realm: n.config.TURNServer.Realm, + SharedSecret: sharedSecret, + CredentialTTL: 24 * 60 * 60, // 24 hours in seconds (will be converted) + MinPort: n.config.TURNServer.MinPort, + MaxPort: n.config.TURNServer.MaxPort, + // TLS configuration for TURNS + TLSEnabled: n.config.TURNServer.TLSEnabled, + TLSListenAddr: n.config.TURNServer.TLSListenAddr, + TLSCertFile: n.config.TURNServer.TLSCertFile, + TLSKeyFile: n.config.TURNServer.TLSKeyFile, + } + + // Apply defaults + if turnCfg.ListenAddr == "" { + turnCfg.ListenAddr = "0.0.0.0:3478" + } + if turnCfg.Realm == "" { + turnCfg.Realm = "orama.network" + } + if turnCfg.MinPort == 0 { + turnCfg.MinPort = 49152 + } + if turnCfg.MaxPort == 0 { + turnCfg.MaxPort = 65535 + } + if 
turnCfg.TLSListenAddr == "" && turnCfg.TLSEnabled { + turnCfg.TLSListenAddr = "0.0.0.0:443" + } + + // Create and start TURN server + server, err := turn.NewServer(turnCfg, n.logger.Logger) + if err != nil { + return err + } + + if err := server.Start(); err != nil { + return err + } + + n.turnServer = server + + n.logger.ComponentInfo(logging.ComponentNode, "Built-in TURN server started", + zap.String("listen_addr", turnCfg.ListenAddr), + zap.String("realm", turnCfg.Realm), + zap.Bool("turns_enabled", turnCfg.TLSEnabled), + ) + + return nil +} diff --git a/pkg/serverless/errors.go b/pkg/serverless/errors.go index 135dd6a..cc91317 100644 --- a/pkg/serverless/errors.go +++ b/pkg/serverless/errors.go @@ -214,3 +214,9 @@ func IsServiceUnavailable(err error) bool { errors.Is(err, ErrDatabaseUnavailable) || errors.Is(err, ErrCacheUnavailable) } + +// IsValidationError checks if an error is a validation error. +func IsValidationError(err error) bool { + var validationErr *ValidationError + return errors.As(err, &validationErr) +} diff --git a/pkg/serverless/triggers.go b/pkg/serverless/triggers.go new file mode 100644 index 0000000..65f6b0a --- /dev/null +++ b/pkg/serverless/triggers.go @@ -0,0 +1,239 @@ +package serverless + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/DeBrosOfficial/network/pkg/rqlite" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// Ensure DBTriggerManager implements TriggerManager interface. +var _ TriggerManager = (*DBTriggerManager)(nil) + +// DBTriggerManager manages function triggers using RQLite database. +type DBTriggerManager struct { + db rqlite.Client + logger *zap.Logger +} + +// NewDBTriggerManager creates a new trigger manager. +func NewDBTriggerManager(db rqlite.Client, logger *zap.Logger) *DBTriggerManager { + return &DBTriggerManager{ + db: db, + logger: logger, + } +} + +// pubsubTriggerRow represents a row from the function_pubsub_triggers table. 
+type pubsubTriggerRow struct { + ID string `db:"id"` + FunctionID string `db:"function_id"` + Topic string `db:"topic"` + Enabled bool `db:"enabled"` + CreatedAt time.Time `db:"created_at"` +} + +// AddPubSubTrigger adds a pubsub trigger to a function. +func (m *DBTriggerManager) AddPubSubTrigger(ctx context.Context, functionID, topic string) error { + functionID = strings.TrimSpace(functionID) + topic = strings.TrimSpace(topic) + + if functionID == "" { + return &ValidationError{Field: "function_id", Message: "cannot be empty"} + } + if topic == "" { + return &ValidationError{Field: "topic", Message: "cannot be empty"} + } + + // Check if trigger already exists for this function and topic + var existing []pubsubTriggerRow + checkQuery := `SELECT id FROM function_pubsub_triggers WHERE function_id = ? AND topic = ? AND enabled = TRUE` + if err := m.db.Query(ctx, &existing, checkQuery, functionID, topic); err != nil { + return fmt.Errorf("failed to check existing trigger: %w", err) + } + if len(existing) > 0 { + return &ValidationError{Field: "topic", Message: "trigger already exists for this topic"} + } + + // Generate trigger ID + triggerID := "trig_" + uuid.New().String()[:8] + + // Insert trigger + query := `INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) VALUES (?, ?, ?, TRUE, ?)` + _, err := m.db.Exec(ctx, query, triggerID, functionID, topic, time.Now()) + if err != nil { + return fmt.Errorf("failed to create pubsub trigger: %w", err) + } + + m.logger.Info("PubSub trigger created", + zap.String("trigger_id", triggerID), + zap.String("function_id", functionID), + zap.String("topic", topic), + ) + + return nil +} + +// AddPubSubTriggerWithID adds a pubsub trigger and returns the trigger ID. 
+func (m *DBTriggerManager) AddPubSubTriggerWithID(ctx context.Context, functionID, topic string) (string, error) { + functionID = strings.TrimSpace(functionID) + topic = strings.TrimSpace(topic) + + if functionID == "" { + return "", &ValidationError{Field: "function_id", Message: "cannot be empty"} + } + if topic == "" { + return "", &ValidationError{Field: "topic", Message: "cannot be empty"} + } + + // Check if trigger already exists for this function and topic + var existing []pubsubTriggerRow + checkQuery := `SELECT id FROM function_pubsub_triggers WHERE function_id = ? AND topic = ? AND enabled = TRUE` + if err := m.db.Query(ctx, &existing, checkQuery, functionID, topic); err != nil { + return "", fmt.Errorf("failed to check existing trigger: %w", err) + } + if len(existing) > 0 { + return "", &ValidationError{Field: "topic", Message: "trigger already exists for this topic"} + } + + // Generate trigger ID + triggerID := "trig_" + uuid.New().String()[:8] + + // Insert trigger + query := `INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) VALUES (?, ?, ?, TRUE, ?)` + _, err := m.db.Exec(ctx, query, triggerID, functionID, topic, time.Now()) + if err != nil { + return "", fmt.Errorf("failed to create pubsub trigger: %w", err) + } + + m.logger.Info("PubSub trigger created", + zap.String("trigger_id", triggerID), + zap.String("function_id", functionID), + zap.String("topic", topic), + ) + + return triggerID, nil +} + +// AddCronTrigger adds a cron-based trigger to a function. +func (m *DBTriggerManager) AddCronTrigger(ctx context.Context, functionID, cronExpr string) error { + // TODO: Implement cron trigger support + return fmt.Errorf("cron triggers not yet implemented") +} + +// AddDBTrigger adds a database trigger to a function. 
+func (m *DBTriggerManager) AddDBTrigger(ctx context.Context, functionID, tableName string, operation DBOperation, condition string) error { + // TODO: Implement database trigger support + return fmt.Errorf("database triggers not yet implemented") +} + +// ScheduleOnce schedules a one-time execution. +func (m *DBTriggerManager) ScheduleOnce(ctx context.Context, functionID string, runAt time.Time, payload []byte) (string, error) { + // TODO: Implement one-time timer support + return "", fmt.Errorf("one-time timers not yet implemented") +} + +// RemoveTrigger removes a trigger by ID. +func (m *DBTriggerManager) RemoveTrigger(ctx context.Context, triggerID string) error { + triggerID = strings.TrimSpace(triggerID) + if triggerID == "" { + return &ValidationError{Field: "trigger_id", Message: "cannot be empty"} + } + + // Try to delete from pubsub triggers first + query := `DELETE FROM function_pubsub_triggers WHERE id = ?` + result, err := m.db.Exec(ctx, query, triggerID) + if err != nil { + return fmt.Errorf("failed to remove trigger: %w", err) + } + + rowsAffected, _ := result.RowsAffected() + if rowsAffected == 0 { + return ErrTriggerNotFound + } + + m.logger.Info("Trigger removed", zap.String("trigger_id", triggerID)) + return nil +} + +// ListPubSubTriggers returns all pubsub triggers for a function. +func (m *DBTriggerManager) ListPubSubTriggers(ctx context.Context, functionID string) ([]PubSubTrigger, error) { + functionID = strings.TrimSpace(functionID) + if functionID == "" { + return nil, &ValidationError{Field: "function_id", Message: "cannot be empty"} + } + + query := `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE function_id = ? 
AND enabled = TRUE`
	var rows []pubsubTriggerRow
	if err := m.db.Query(ctx, &rows, query, functionID); err != nil {
		return nil, fmt.Errorf("failed to list pubsub triggers: %w", err)
	}

	// Map DB rows to the public trigger type.
	triggers := make([]PubSubTrigger, 0, len(rows))
	for _, r := range rows {
		triggers = append(triggers, PubSubTrigger{
			ID:         r.ID,
			FunctionID: r.FunctionID,
			Topic:      r.Topic,
			Enabled:    r.Enabled,
		})
	}

	return triggers, nil
}

// GetTriggersByTopic returns all enabled triggers for a specific topic.
func (m *DBTriggerManager) GetTriggersByTopic(ctx context.Context, topic string) ([]PubSubTrigger, error) {
	topic = strings.TrimSpace(topic)
	if topic == "" {
		return nil, &ValidationError{Field: "topic", Message: "cannot be empty"}
	}

	const byTopicQuery = `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE topic = ? AND enabled = TRUE`
	var rows []pubsubTriggerRow
	if err := m.db.Query(ctx, &rows, byTopicQuery, topic); err != nil {
		return nil, fmt.Errorf("failed to get triggers by topic: %w", err)
	}

	triggers := make([]PubSubTrigger, 0, len(rows))
	for _, r := range rows {
		triggers = append(triggers, PubSubTrigger{
			ID:         r.ID,
			FunctionID: r.FunctionID,
			Topic:      r.Topic,
			Enabled:    r.Enabled,
		})
	}

	return triggers, nil
}

// GetPubSubTrigger returns a specific pubsub trigger by ID. 
+func (m *DBTriggerManager) GetPubSubTrigger(ctx context.Context, triggerID string) (*PubSubTrigger, error) { + triggerID = strings.TrimSpace(triggerID) + if triggerID == "" { + return nil, &ValidationError{Field: "trigger_id", Message: "cannot be empty"} + } + + query := `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE id = ?` + var rows []pubsubTriggerRow + if err := m.db.Query(ctx, &rows, query, triggerID); err != nil { + return nil, fmt.Errorf("failed to get trigger: %w", err) + } + + if len(rows) == 0 { + return nil, ErrTriggerNotFound + } + + row := rows[0] + return &PubSubTrigger{ + ID: row.ID, + FunctionID: row.FunctionID, + Topic: row.Topic, + Enabled: row.Enabled, + }, nil +} diff --git a/pkg/serverless/types.go b/pkg/serverless/types.go index 66a13f7..2a3e4c4 100644 --- a/pkg/serverless/types.go +++ b/pkg/serverless/types.go @@ -131,6 +131,12 @@ type TriggerManager interface { // RemoveTrigger removes a trigger by ID. RemoveTrigger(ctx context.Context, triggerID string) error + + // ListPubSubTriggers returns all pubsub triggers for a function. + ListPubSubTriggers(ctx context.Context, functionID string) ([]PubSubTrigger, error) + + // GetTriggersByTopic returns all enabled triggers for a specific topic. + GetTriggersByTopic(ctx context.Context, topic string) ([]PubSubTrigger, error) } // JobManager manages background job execution. diff --git a/pkg/turn/server.go b/pkg/turn/server.go new file mode 100644 index 0000000..5ebd4c4 --- /dev/null +++ b/pkg/turn/server.go @@ -0,0 +1,343 @@ +// Package turn provides a built-in TURN/STUN server using Pion. 
package turn

import (
	"crypto/hmac"
	"crypto/sha1"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"net"
	"strconv"
	"sync"
	"time"

	"github.com/pion/turn/v4"
	"go.uber.org/zap"
)

// Config contains TURN server configuration
type Config struct {
	// Enabled enables the built-in TURN server
	Enabled bool `yaml:"enabled"`

	// ListenAddr is the UDP address to listen on (e.g., "0.0.0.0:3478")
	ListenAddr string `yaml:"listen_addr"`

	// PublicIP is the public IP address to advertise for relay
	// If empty, will try to auto-detect
	PublicIP string `yaml:"public_ip"`

	// Realm is the TURN realm (e.g., "orama.network")
	Realm string `yaml:"realm"`

	// SharedSecret is the secret for HMAC-SHA1 credential generation
	// Should match the gateway's TURN_SHARED_SECRET
	SharedSecret string `yaml:"shared_secret"`

	// CredentialTTL is the lifetime of generated credentials
	// NOTE(review): nothing in this file consults CredentialTTL — the auth
	// handler derives expiry solely from the timestamp embedded in the
	// username. Confirm whether this field is still intended to be used.
	CredentialTTL time.Duration `yaml:"credential_ttl"`

	// MinPort and MaxPort define the relay port range
	MinPort uint16 `yaml:"min_port"`
	MaxPort uint16 `yaml:"max_port"`

	// TLS Configuration for TURNS (TURN over TLS)
	// TLSEnabled enables TURNS listener on TLSListenAddr
	TLSEnabled bool `yaml:"tls_enabled"`

	// TLSListenAddr is the TCP/TLS address to listen on (e.g., "0.0.0.0:443")
	TLSListenAddr string `yaml:"tls_listen_addr"`

	// TLSCertFile is the path to the TLS certificate file
	TLSCertFile string `yaml:"tls_cert_file"`

	// TLSKeyFile is the path to the TLS private key file
	TLSKeyFile string `yaml:"tls_key_file"`
}

// DefaultConfig returns a default TURN server configuration
// (disabled, standard STUN/TURN port 3478, IANA ephemeral relay port range).
func DefaultConfig() *Config {
	return &Config{
		Enabled:       false,
		ListenAddr:    "0.0.0.0:3478",
		Realm:         "orama.network",
		CredentialTTL: 24 * time.Hour,
		MinPort:       49152,
		MaxPort:       65535,
	}
}

// Server is a built-in TURN/STUN server
type Server struct {
	config     *Config
	logger     *zap.Logger
	turnServer *turn.Server
	conn       net.PacketConn // UDP listener
	tlsListener net.Listener // TLS listener for TURNS
	mu          sync.RWMutex // guards running and the lifecycle fields above
	running     bool
}

// NewServer creates a new TURN server
// A nil config is replaced with DefaultConfig(). The returned error is
// currently always nil; the signature leaves room for future validation.
func NewServer(config *Config, logger *zap.Logger) (*Server, error) {
	if config == nil {
		config = DefaultConfig()
	}

	return &Server{
		config: config,
		logger: logger,
	}, nil
}

// Start starts the TURN server
// Start is idempotent (a second call while running is a no-op) and returns
// nil without doing anything when the server is disabled via config.
func (s *Server) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.running {
		return nil
	}

	if !s.config.Enabled {
		s.logger.Info("TURN server disabled")
		return nil
	}

	if s.config.SharedSecret == "" {
		return fmt.Errorf("TURN shared secret is required")
	}

	// Create UDP listener
	conn, err := net.ListenPacket("udp4", s.config.ListenAddr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", s.config.ListenAddr, err)
	}
	s.conn = conn

	// Determine public IP
	publicIP := s.config.PublicIP
	if publicIP == "" {
		// Try to auto-detect.
		// NOTE(review): getPublicIP returns the preferred outbound interface
		// address, which behind NAT is a private LAN IP rather than the true
		// public one — set Config.PublicIP explicitly in production.
		publicIP, err = getPublicIP()
		if err != nil {
			s.logger.Warn("Failed to auto-detect public IP, using listener address", zap.Error(err))
			host, _, _ := net.SplitHostPort(s.config.ListenAddr)
			if host == "0.0.0.0" || host == "" {
				host = "127.0.0.1"
			}
			publicIP = host
		}
	}

	relayIP := net.ParseIP(publicIP)
	if relayIP == nil {
		return fmt.Errorf("invalid public IP: %s", publicIP)
	}

	s.logger.Info("Starting TURN server",
		zap.String("listen_addr", s.config.ListenAddr),
		zap.String("public_ip", publicIP),
		zap.String("realm", s.config.Realm),
		zap.Uint16("min_port", s.config.MinPort),
		zap.Uint16("max_port", s.config.MaxPort),
		zap.Bool("tls_enabled", s.config.TLSEnabled),
	)

	// Prepare listener configs for TLS (TURNS)
	var listenerConfigs []turn.ListenerConfig

	// TURNS is only attempted when both cert and key paths are set; if
	// TLSEnabled is true but files are missing, the TLS listener is
	// silently skipped and only plain UDP TURN is served.
	if s.config.TLSEnabled && s.config.TLSCertFile != "" && s.config.TLSKeyFile != "" {
		// Load TLS certificate
		cert, err := tls.LoadX509KeyPair(s.config.TLSCertFile, s.config.TLSKeyFile)
		if err != nil {
			// Clean up the already-open UDP listener before bailing out.
			conn.Close()
			return 
fmt.Errorf("failed to load TLS certificate: %w", err)
		}

		tlsConfig := &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   tls.VersionTLS12,
		}

		// Determine TLS listen address
		tlsListenAddr := s.config.TLSListenAddr
		if tlsListenAddr == "" {
			tlsListenAddr = "0.0.0.0:443"
		}

		// Create TLS listener
		tlsListener, err := tls.Listen("tcp", tlsListenAddr, tlsConfig)
		if err != nil {
			conn.Close()
			return fmt.Errorf("failed to start TLS listener on %s: %w", tlsListenAddr, err)
		}
		s.tlsListener = tlsListener

		listenerConfigs = append(listenerConfigs, turn.ListenerConfig{
			Listener: tlsListener,
			RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{
				RelayAddress: relayIP,
				Address:      "0.0.0.0",
				MinPort:      s.config.MinPort,
				MaxPort:      s.config.MaxPort,
			},
		})

		s.logger.Info("TURNS (TLS) listener started",
			zap.String("tls_addr", tlsListenAddr),
		)
	}

	// Create TURN server with HMAC-SHA1 auth
	turnServer, err := turn.NewServer(turn.ServerConfig{
		Realm: s.config.Realm,
		AuthHandler: func(username, realm string, srcAddr net.Addr) ([]byte, bool) {
			return s.authHandler(username, realm, srcAddr)
		},
		PacketConnConfigs: []turn.PacketConnConfig{
			{
				PacketConn: conn,
				RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{
					RelayAddress: relayIP,
					Address:      "0.0.0.0",
					MinPort:      s.config.MinPort,
					MaxPort:      s.config.MaxPort,
				},
			},
		},
		ListenerConfigs: listenerConfigs,
	})
	if err != nil {
		// Release both listeners; pion takes ownership only on success.
		conn.Close()
		if s.tlsListener != nil {
			s.tlsListener.Close()
		}
		return fmt.Errorf("failed to create TURN server: %w", err)
	}

	s.turnServer = turnServer
	s.running = true

	s.logger.Info("TURN server started successfully",
		zap.String("addr", s.config.ListenAddr),
		zap.String("realm", s.config.Realm),
		zap.Bool("turns_enabled", s.config.TLSEnabled),
	)

	return nil
}

// authHandler validates HMAC-SHA1 credentials (coturn-compatible format)
// Username format: timestamp:userID
(e.g., "1234567890:user123") +func (s *Server) authHandler(username, realm string, srcAddr net.Addr) ([]byte, bool) { + // Parse timestamp from username + // Format: timestamp:userID + var timestamp int64 + for i, c := range username { + if c == ':' { + ts, err := strconv.ParseInt(username[:i], 10, 64) + if err != nil { + s.logger.Debug("Invalid timestamp in username", zap.String("username", username)) + return nil, false + } + timestamp = ts + break + } + } + + // Check if credential has expired + now := time.Now().Unix() + if timestamp > 0 && timestamp < now { + s.logger.Debug("Credential expired", + zap.String("username", username), + zap.Int64("expired_at", timestamp), + zap.Int64("now", now), + ) + return nil, false + } + + // Generate expected password using HMAC-SHA1 + // This matches the gateway's generateTURNCredentials function + h := hmac.New(sha1.New, []byte(s.config.SharedSecret)) + h.Write([]byte(username)) + password := base64.StdEncoding.EncodeToString(h.Sum(nil)) + + s.logger.Debug("TURN auth request", + zap.String("username", username), + zap.String("realm", realm), + zap.String("src_addr", srcAddr.String()), + ) + + return turn.GenerateAuthKey(username, realm, password), true +} + +// Stop stops the TURN server +func (s *Server) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.running { + return nil + } + + s.logger.Info("Stopping TURN server") + + if s.turnServer != nil { + if err := s.turnServer.Close(); err != nil { + s.logger.Warn("Error closing TURN server", zap.Error(err)) + } + s.turnServer = nil + } + + if s.conn != nil { + s.conn.Close() + s.conn = nil + } + + // Close TLS listener + if s.tlsListener != nil { + s.tlsListener.Close() + s.tlsListener = nil + } + + s.running = false + s.logger.Info("TURN server stopped") + + return nil +} + +// IsRunning returns whether the server is running +func (s *Server) IsRunning() bool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.running +} + +// GetListenAddr returns the listen 
address +func (s *Server) GetListenAddr() string { + return s.config.ListenAddr +} + +// GetPublicAddr returns the public address for clients +func (s *Server) GetPublicAddr() string { + if s.config.PublicIP != "" { + _, port, _ := net.SplitHostPort(s.config.ListenAddr) + return net.JoinHostPort(s.config.PublicIP, port) + } + return s.config.ListenAddr +} + +// getPublicIP tries to determine the public IP address +func getPublicIP() (string, error) { + // Try to get outbound IP by connecting to a public address + conn, err := net.Dial("udp4", "8.8.8.8:80") + if err != nil { + return "", err + } + defer conn.Close() + + localAddr := conn.LocalAddr().(*net.UDPAddr) + return localAddr.IP.String(), nil +}