diff --git a/Makefile b/Makefile index 0912a19..bafd5ea 100644 --- a/Makefile +++ b/Makefile @@ -63,7 +63,7 @@ test-e2e-quick: .PHONY: build clean test deps tidy fmt vet lint install-hooks redeploy-devnet redeploy-testnet release health -VERSION := 0.111.0 +VERSION := 0.112.0 COMMIT ?= $(shell git rev-parse --short HEAD 2>/dev/null || echo unknown) DATE ?= $(shell date -u +%Y-%m-%dT%H:%M:%SZ) LDFLAGS := -X 'main.version=$(VERSION)' -X 'main.commit=$(COMMIT)' -X 'main.date=$(DATE)' @@ -78,6 +78,8 @@ build: deps go build -ldflags "$(LDFLAGS)" -o bin/orama ./cmd/cli/ # Inject gateway build metadata via pkg path variables go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway + go build -ldflags "$(LDFLAGS)" -o bin/sfu ./cmd/sfu + go build -ldflags "$(LDFLAGS)" -o bin/turn ./cmd/turn @echo "Build complete! Run ./bin/orama version" # Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source) diff --git a/cmd/gateway/config.go b/cmd/gateway/config.go index 3983f2c..e263d1d 100644 --- a/cmd/gateway/config.go +++ b/cmd/gateway/config.go @@ -69,6 +69,13 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config { } // Load YAML + type yamlWebRTCCfg struct { + Enabled bool `yaml:"enabled"` + SFUPort int `yaml:"sfu_port"` + TURNDomain string `yaml:"turn_domain"` + TURNSecret string `yaml:"turn_secret"` + } + type yamlCfg struct { ListenAddr string `yaml:"listen_addr"` ClientNamespace string `yaml:"client_namespace"` @@ -84,6 +91,7 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config { IPFSAPIURL string `yaml:"ipfs_api_url"` IPFSTimeout string `yaml:"ipfs_timeout"` IPFSReplicationFactor int `yaml:"ipfs_replication_factor"` + WebRTC yamlWebRTCCfg `yaml:"webrtc"` } data, err := os.ReadFile(configPath) @@ -192,6 +200,18 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config { cfg.IPFSReplicationFactor = y.IPFSReplicationFactor } + // WebRTC configuration + cfg.WebRTCEnabled = y.WebRTC.Enabled + if y.WebRTC.SFUPort > 0 { + cfg.SFUPort = y.WebRTC.SFUPort + } + if v := strings.TrimSpace(y.WebRTC.TURNDomain); v != "" { + cfg.TURNDomain = v + } + if v := strings.TrimSpace(y.WebRTC.TURNSecret); v != "" { + cfg.TURNSecret = v + } + // Validate configuration if errs := cfg.ValidateConfig(); len(errs) > 0 { fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs)) diff --git a/cmd/sfu/config.go b/cmd/sfu/config.go new file mode 100644 index 0000000..ea16be3 --- /dev/null +++ b/cmd/sfu/config.go @@ -0,0 +1,116 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/DeBrosOfficial/network/pkg/config" + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/sfu" + "go.uber.org/zap" +) + +// newSFUServer creates a new SFU server from config and logger. +// Wrapper to keep main.go clean and avoid importing sfu in main. 
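+// (The sfu import lives in this file, config.go, which shares package main; the wrapper only keeps the main.go file itself free of the direct dependency.)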
+func newSFUServer(cfg *sfu.Config, logger *zap.Logger) (*sfu.Server, error) { + return sfu.NewServer(cfg, logger) +} + +func parseSFUConfig(logger *logging.ColoredLogger) *sfu.Config { + configFlag := flag.String("config", "", "Config file path (absolute path or filename in ~/.orama)") + flag.Parse() + + var configPath string + var err error + if *configFlag != "" { + if filepath.IsAbs(*configFlag) { + configPath = *configFlag + } else { + configPath, err = config.DefaultPath(*configFlag) + if err != nil { + logger.ComponentError(logging.ComponentSFU, "Failed to determine config path", zap.Error(err)) + fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err) + os.Exit(1) + } + } + } else { + configPath, err = config.DefaultPath("sfu.yaml") + if err != nil { + logger.ComponentError(logging.ComponentSFU, "Failed to determine config path", zap.Error(err)) + fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err) + os.Exit(1) + } + } + + type yamlTURNServer struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + } + + type yamlCfg struct { + ListenAddr string `yaml:"listen_addr"` + Namespace string `yaml:"namespace"` + MediaPortStart int `yaml:"media_port_start"` + MediaPortEnd int `yaml:"media_port_end"` + TURNServers []yamlTURNServer `yaml:"turn_servers"` + TURNSecret string `yaml:"turn_secret"` + TURNCredentialTTL int `yaml:"turn_credential_ttl"` + RQLiteDSN string `yaml:"rqlite_dsn"` + } + + data, err := os.ReadFile(configPath) + if err != nil { + logger.ComponentError(logging.ComponentSFU, "Config file not found", + zap.String("path", configPath), zap.Error(err)) + fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath) + os.Exit(1) + } + + var y yamlCfg + if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil { + logger.ComponentError(logging.ComponentSFU, "Failed to parse SFU config", zap.Error(err)) + fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err) + os.Exit(1) + } + + var turnServers []sfu.TURNServerConfig + for _, ts := range y.TURNServers { + turnServers = append(turnServers, sfu.TURNServerConfig{ + Host: ts.Host, + Port: ts.Port, + }) + } + + cfg := &sfu.Config{ + ListenAddr: y.ListenAddr, + Namespace: y.Namespace, + MediaPortStart: y.MediaPortStart, + MediaPortEnd: y.MediaPortEnd, + TURNServers: turnServers, + TURNSecret: y.TURNSecret, + TURNCredentialTTL: y.TURNCredentialTTL, + RQLiteDSN: y.RQLiteDSN, + } + + if errs := cfg.Validate(); len(errs) > 0 { + fmt.Fprintf(os.Stderr, "\nSFU configuration errors (%d):\n", len(errs)) + for _, e := range errs { + fmt.Fprintf(os.Stderr, " - %s\n", e) + } + fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n") + os.Exit(1) + } + + logger.ComponentInfo(logging.ComponentSFU, "Loaded SFU configuration", + zap.String("path", configPath), + zap.String("listen_addr", cfg.ListenAddr), + zap.String("namespace", cfg.Namespace), + zap.Int("media_ports", cfg.MediaPortEnd-cfg.MediaPortStart), + zap.Int("turn_servers", len(cfg.TURNServers)), + ) + + return cfg +} diff --git a/cmd/sfu/main.go b/cmd/sfu/main.go new file mode 100644 index 0000000..e71d9c9 --- /dev/null +++ b/cmd/sfu/main.go @@ -0,0 +1,59 @@ +package main + +import ( + "os" + "os/signal" + "syscall" + "time" + + "github.com/DeBrosOfficial/network/pkg/logging" + "go.uber.org/zap" +) + +var ( + version = "dev" + commit = "unknown" +) + +func main() { + logger, err := logging.NewColoredLogger(logging.ComponentSFU, true) + if err != nil { + panic(err) + } + + logger.ComponentInfo(logging.ComponentSFU, "Starting 
SFU server", + zap.String("version", version), + zap.String("commit", commit)) + + cfg := parseSFUConfig(logger) + + server, err := newSFUServer(cfg, logger.Logger) + if err != nil { + logger.ComponentError(logging.ComponentSFU, "Failed to create SFU server", zap.Error(err)) + os.Exit(1) + } + + // Start HTTP server in background + go func() { + if err := server.ListenAndServe(); err != nil { + logger.ComponentError(logging.ComponentSFU, "SFU server error", zap.Error(err)) + os.Exit(1) + } + }() + + // Wait for termination signal + quit := make(chan os.Signal, 1) + signal.Notify(quit, os.Interrupt, syscall.SIGTERM) + sig := <-quit + + logger.ComponentInfo(logging.ComponentSFU, "Shutdown signal received", zap.String("signal", sig.String())) + + // Graceful drain: notify peers and wait + server.Drain(30 * time.Second) + + if err := server.Close(); err != nil { + logger.ComponentError(logging.ComponentSFU, "Error during shutdown", zap.Error(err)) + } + + logger.ComponentInfo(logging.ComponentSFU, "SFU server shutdown complete") +} diff --git a/cmd/turn/config.go b/cmd/turn/config.go new file mode 100644 index 0000000..0b38f81 --- /dev/null +++ b/cmd/turn/config.go @@ -0,0 +1,96 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/DeBrosOfficial/network/pkg/config" + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/turn" + "go.uber.org/zap" +) + +func parseTURNConfig(logger *logging.ColoredLogger) *turn.Config { + configFlag := flag.String("config", "", "Config file path (absolute path or filename in ~/.orama)") + flag.Parse() + + var configPath string + var err error + if *configFlag != "" { + if filepath.IsAbs(*configFlag) { + configPath = *configFlag + } else { + configPath, err = config.DefaultPath(*configFlag) + if err != nil { + logger.ComponentError(logging.ComponentTURN, "Failed to determine config path", zap.Error(err)) + fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err) + os.Exit(1) + } + } + } else { + configPath, err = config.DefaultPath("turn.yaml") + if err != nil { + logger.ComponentError(logging.ComponentTURN, "Failed to determine config path", zap.Error(err)) + fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err) + os.Exit(1) + } + } + + type yamlCfg struct { + ListenAddr string `yaml:"listen_addr"` + TLSListenAddr string `yaml:"tls_listen_addr"` + PublicIP string `yaml:"public_ip"` + Realm string `yaml:"realm"` + AuthSecret string `yaml:"auth_secret"` + RelayPortStart int `yaml:"relay_port_start"` + RelayPortEnd int `yaml:"relay_port_end"` + Namespace string `yaml:"namespace"` + } + + data, err := os.ReadFile(configPath) + if err != nil { + logger.ComponentError(logging.ComponentTURN, "Config file not found", + zap.String("path", configPath), zap.Error(err)) + fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath) + os.Exit(1) + } + + var y yamlCfg + if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil { + logger.ComponentError(logging.ComponentTURN, "Failed to parse TURN config", zap.Error(err)) + fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err) + os.Exit(1) + } + + cfg := &turn.Config{ + ListenAddr: y.ListenAddr, + TLSListenAddr: y.TLSListenAddr, + PublicIP: y.PublicIP, + Realm: y.Realm, + AuthSecret: y.AuthSecret, + RelayPortStart: y.RelayPortStart, + RelayPortEnd: y.RelayPortEnd, + Namespace: y.Namespace, + } + + if errs := cfg.Validate(); len(errs) > 0 { + fmt.Fprintf(os.Stderr, "\nTURN configuration errors (%d):\n", 
len(errs)) + for _, e := range errs { + fmt.Fprintf(os.Stderr, " - %s\n", e) + } + fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n") + os.Exit(1) + } + + logger.ComponentInfo(logging.ComponentTURN, "Loaded TURN configuration", + zap.String("path", configPath), + zap.String("listen_addr", cfg.ListenAddr), + zap.String("namespace", cfg.Namespace), + zap.String("realm", cfg.Realm), + ) + + return cfg +} diff --git a/cmd/turn/main.go b/cmd/turn/main.go new file mode 100644 index 0000000..90efe34 --- /dev/null +++ b/cmd/turn/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "os" + "os/signal" + "syscall" + + "github.com/DeBrosOfficial/network/pkg/logging" + "github.com/DeBrosOfficial/network/pkg/turn" + "go.uber.org/zap" +) + +var ( + version = "dev" + commit = "unknown" +) + +func main() { + logger, err := logging.NewColoredLogger(logging.ComponentTURN, true) + if err != nil { + panic(err) + } + + logger.ComponentInfo(logging.ComponentTURN, "Starting TURN server", + zap.String("version", version), + zap.String("commit", commit)) + + cfg := parseTURNConfig(logger) + + server, err := turn.NewServer(cfg, logger.Logger) + if err != nil { + logger.ComponentError(logging.ComponentTURN, "Failed to start TURN server", zap.Error(err)) + os.Exit(1) + } + + // Wait for termination signal + quit := make(chan os.Signal, 1) + signal.Notify(quit, os.Interrupt, syscall.SIGTERM) + sig := <-quit + + logger.ComponentInfo(logging.ComponentTURN, "Shutdown signal received", zap.String("signal", sig.String())) + + if err := server.Close(); err != nil { + logger.ComponentError(logging.ComponentTURN, "Error during shutdown", zap.Error(err)) + } + + logger.ComponentInfo(logging.ComponentTURN, "TURN server shutdown complete") +} diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index a82b2fa..cbe8e4c 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -474,6 +474,36 @@ configured, use the IP over HTTP port 80 (`http://`) which goes through Cadd Planned containerization with Docker Compose and Kubernetes support. +## WebRTC (Voice/Video/Data) + +Namespaces can opt in to WebRTC support for real-time voice, video, and data channels. + +### Components + +- **SFU (Selective Forwarding Unit)** — Pion WebRTC server that handles signaling (WebSocket), SDP negotiation, and RTP forwarding. Runs on all 3 cluster nodes, binds only to WireGuard IPs. +- **TURN Server** — Pion TURN relay that provides NAT traversal. Runs on 2 of 3 nodes for redundancy. Public-facing (UDP 3478, 443, relay range 49152-65535). + +### Security Model + +- **TURN-shielded**: SFU binds only to WireGuard (10.0.0.x), never 0.0.0.0. All client media flows through TURN relay. +- **Forced relay**: `iceTransportPolicy: relay` enforced server-side — no direct peer connections. +- **HMAC credentials**: Per-namespace TURN shared secret with 10-minute TTL. +- **Namespace isolation**: Each namespace has its own TURN secret, port ranges, and rooms. + +### Port Allocation + +WebRTC uses a separate port allocation system from core namespace services: + +| Service | Port Range | +|---------|-----------| +| SFU signaling | 30000-30099 | +| SFU media (RTP) | 20000-29999 | +| TURN listen | 3478/udp (standard) | +| TURN TLS | 443/udp | +| TURN relay | 49152-65535/udp | + +See [docs/WEBRTC.md](WEBRTC.md) for full details including client integration, API reference, and debugging. + ## Future Enhancements 1. 
**GraphQL Support** - GraphQL gateway alongside REST diff --git a/docs/DEPLOYMENT_GUIDE.md b/docs/DEPLOYMENT_GUIDE.md index 82af8eb..71481a9 100644 --- a/docs/DEPLOYMENT_GUIDE.md +++ b/docs/DEPLOYMENT_GUIDE.md @@ -872,6 +872,57 @@ orama app delete my-old-app --- +## WebRTC (Voice/Video/Data) + +Namespaces can enable WebRTC support for real-time communication (voice calls, video calls, data channels). + +### Enable WebRTC + +```bash +# Enable WebRTC for a namespace (must be run on a cluster node) +orama namespace enable webrtc --namespace myapp + +# Check WebRTC status +orama namespace webrtc-status --namespace myapp +``` + +This provisions SFU servers on all 3 nodes and TURN relay servers on 2 nodes, allocates port blocks, creates DNS records, and opens firewall ports. + +### Disable WebRTC + +```bash +orama namespace disable webrtc --namespace myapp +``` + +Stops all SFU/TURN services, deallocates ports, removes DNS records, and closes firewall ports. + +### Client Integration + +```javascript +// 1. Get TURN credentials +const creds = await fetch('https://ns-myapp.orama.network/v1/webrtc/turn/credentials', { + method: 'POST', + headers: { 'Authorization': `Bearer ${jwt}` } +}); +const { urls, username, credential, ttl } = await creds.json(); + +// 2. Create PeerConnection (forced relay) +const pc = new RTCPeerConnection({ + iceServers: [{ urls, username, credential }], + iceTransportPolicy: 'relay' +}); + +// 3. Connect signaling WebSocket +const ws = new WebSocket( + `wss://ns-myapp.orama.network/v1/webrtc/signal?room=${roomId}`, + ['Bearer', jwt] +); +``` + +See [docs/WEBRTC.md](WEBRTC.md) for the full API reference, room management, credential protocol, and debugging guide. + +--- + ## Troubleshooting ### Deployment Issues diff --git a/docs/MONITORING.md b/docs/MONITORING.md index 698f5cd..228a328 100644 --- a/docs/MONITORING.md +++ b/docs/MONITORING.md @@ -238,6 +238,8 @@ These checks compare data across all nodes: - **WireGuard Peer Symmetry**: Each node has N-1 peers - **Clock Skew**: Node clocks within 5 seconds of each other - **Binary Version**: All nodes running the same version +- **WebRTC SFU Coverage**: SFU running on expected nodes (3/3) per namespace +- **WebRTC TURN Redundancy**: TURN running on expected nodes (2/3) per namespace ### Per-Node Checks @@ -249,6 +251,7 @@ These checks compare data across all nodes: - **Anyone**: Bootstrap progress - **Processes**: Zombies, orphans, panics in logs - **Namespaces**: Gateway and RQLite per namespace +- **WebRTC**: SFU and TURN service health (when provisioned) - **Network**: UFW, internet reachability, TCP retransmission ## Monitor vs Inspector diff --git a/docs/WEBRTC.md b/docs/WEBRTC.md new file mode 100644 index 0000000..bae8261 --- /dev/null +++ b/docs/WEBRTC.md @@ -0,0 +1,262 @@ +# WebRTC Integration + +Real-time voice, video, and data channels for Orama Network namespaces. + +## Architecture + +``` +Client A Client B + │ │ + │ 1. Get TURN credentials (REST) │ + │ 2. Connect WebSocket (signaling) │ + │ 3. 
Exchange SDP/ICE via SFU │ + │ │ + ▼ ▼ +┌──────────┐ UDP relay ┌──────────┐ +│ TURN │◄──────────────────►│ TURN │ +│ Server │ (public IPs) │ Server │ +│ Node 1 │ │ Node 2 │ +└────┬─────┘ └────┬─────┘ + │ WireGuard │ WireGuard + ▼ ▼ +┌──────────────────────────────────────────┐ +│ SFU Servers (3 nodes) │ +│ - WebSocket signaling (WireGuard only) │ +│ - Pion WebRTC (RTP forwarding) │ +│ - Room management │ +│ - Track publish/subscribe │ +└──────────────────────────────────────────┘ +``` + +**Key design decisions:** +- **TURN-shielded**: SFU binds only to WireGuard IPs. All client media flows through TURN relay. +- **`iceTransportPolicy: relay`** enforced server-side — no direct peer connections. +- **Opt-in per namespace** via `orama namespace enable webrtc`. +- **SFU on all 3 nodes**, **TURN on 2 of 3 nodes** (redundancy without over-provisioning). +- **Separate port allocation** from existing namespace services. + +## Prerequisites + +- Namespace must be provisioned with a ready cluster (RQLite + Olric + Gateway running). +- Command must be run on a cluster node (uses internal gateway endpoint). + +## Enable / Disable + +```bash +# Enable WebRTC for a namespace +orama namespace enable webrtc --namespace myapp + +# Check status +orama namespace webrtc-status --namespace myapp + +# Disable WebRTC (stops services, deallocates ports, removes DNS) +orama namespace disable webrtc --namespace myapp +``` + +### What happens on enable: +1. Generates a per-namespace TURN shared secret (32 bytes, crypto/rand) +2. Inserts `namespace_webrtc_config` DB record +3. Allocates WebRTC port blocks on each node (SFU signaling + media range, TURN relay range) +4. Spawns TURN on 2 nodes (selected by capacity) +5. Spawns SFU on all 3 nodes +6. Creates DNS A records: `turn.ns-{name}.{baseDomain}` pointing to TURN node public IPs +7. Updates cluster state on all nodes (for cold-boot restoration) + +### What happens on disable: +1. Stops SFU on all 3 nodes +2. Stops TURN on 2 nodes +3. Deallocates all WebRTC ports +4. Deletes TURN DNS records +5. Cleans up DB records (`namespace_webrtc_config`, `webrtc_rooms`) +6. Updates cluster state + +## Client Integration (JavaScript) + +### 1. Get TURN Credentials + +```javascript +const response = await fetch('https://ns-myapp.orama.network/v1/webrtc/turn/credentials', { + method: 'POST', + headers: { 'Authorization': `Bearer ${jwt}` } +}); + +const { urls, username, credential, ttl } = await response.json(); +// urls: ["turn:1.2.3.4:3478?transport=udp", "turns:1.2.3.4:443?transport=udp"] +// username: "{expiry_unix}:{namespace}" +// credential: HMAC-SHA1 derived +// ttl: 600 (seconds) +``` + +### 2. Create PeerConnection + +```javascript +const pc = new RTCPeerConnection({ + iceServers: [{ urls, username, credential }], + iceTransportPolicy: 'relay' // enforced by SFU +}); +``` + +### 3. 
Connect Signaling WebSocket + +```javascript +const ws = new WebSocket( + `wss://ns-myapp.orama.network/v1/webrtc/signal?room=${roomId}`, + ['Bearer', jwt] +); + +ws.onmessage = (event) => { + const msg = JSON.parse(event.data); + switch (msg.type) { + case 'offer': handleOffer(msg); break; + case 'answer': handleAnswer(msg); break; + case 'ice-candidate': handleICE(msg); break; + case 'peer-joined': handleJoin(msg); break; + case 'peer-left': handleLeave(msg); break; + case 'turn-credentials': + case 'refresh-credentials': + updateTURN(msg); // SFU sends refreshed creds at 80% TTL + break; + case 'server-draining': + reconnect(); // SFU shutting down, reconnect to another node + break; + } +}; +``` + +### 4. Room Management (REST) + +```javascript +// Create room +await fetch('/v1/webrtc/rooms', { + method: 'POST', + headers: { 'Authorization': `Bearer ${jwt}`, 'Content-Type': 'application/json' }, + body: JSON.stringify({ room_id: 'my-room' }) +}); + +// List rooms +const rooms = await fetch('/v1/webrtc/rooms', { + headers: { 'Authorization': `Bearer ${jwt}` } +}); + +// Close room +await fetch('/v1/webrtc/rooms?room_id=my-room', { + method: 'DELETE', + headers: { 'Authorization': `Bearer ${jwt}` } +}); +``` + +## API Reference + +### REST Endpoints + +| Method | Path | Auth | Description | +|--------|------|------|-------------| +| POST | `/v1/webrtc/turn/credentials` | JWT/API key | Get TURN relay credentials | +| GET/WS | `/v1/webrtc/signal` | JWT/API key | WebSocket signaling | +| GET | `/v1/webrtc/rooms` | JWT/API key | List rooms | +| POST | `/v1/webrtc/rooms` | JWT/API key (owner) | Create room | +| DELETE | `/v1/webrtc/rooms` | JWT/API key (owner) | Close room | + +### Signaling Messages + +| Type | Direction | Description | +|------|-----------|-------------| +| `join` | Client → SFU | Join room | +| `offer` | Client ↔ SFU | SDP offer | +| `answer` | Client ↔ SFU | SDP answer | +| `ice-candidate` | Client ↔ SFU | ICE candidate | +| `leave` | Client → SFU | Leave room | +| `peer-joined` | SFU → Client | New peer notification | +| `peer-left` | SFU → Client | Peer departure | +| `turn-credentials` | SFU → Client | Initial TURN credentials | +| `refresh-credentials` | SFU → Client | Refreshed credentials (at 80% TTL) | +| `server-draining` | SFU → Client | SFU shutting down | + +## Port Allocation + +WebRTC uses a **separate port allocation system** from the core namespace ports: + +| Service | Port Range | Per Namespace | +|---------|-----------|---------------| +| SFU signaling | 30000-30099 | 1 port | +| SFU media (RTP) | 20000-29999 | 500 ports | +| TURN listen | 3478 (standard) | fixed | +| TURN TLS | 443/udp (standard) | fixed | +| TURN relay | 49152-65535 | 800 ports | + +## TURN Credential Protocol + +- Credentials use HMAC-SHA1 with a per-namespace shared secret +- Username format: `{expiry_unix}:{namespace}` +- Default TTL: 600 seconds (10 minutes) +- SFU proactively sends `refresh-credentials` at 80% of TTL (8 minutes) +- Clients should update ICE servers on receiving refresh + +## Monitoring + +```bash +# Check WebRTC status +orama namespace webrtc-status --namespace myapp + +# Monitor report includes SFU/TURN status +orama monitor report --env devnet + +# Inspector checks WebRTC health +orama inspector --env devnet +``` + +The monitoring report includes per-namespace `sfu_up` and `turn_up` fields. The inspector runs cross-node checks to verify SFU coverage (3 nodes) and TURN redundancy (2 nodes). 
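+
+For debugging credential issues, the expected username/credential pair can be derived offline from the namespace's shared secret. The sketch below is a minimal illustration of the protocol described above; it assumes the credential is the base64-encoded HMAC-SHA1 of the username (the common TURN REST API convention), and `deriveTURNCredential` is a hypothetical name rather than the gateway's actual helper:
+
+```go
+package main
+
+import (
+	"crypto/hmac"
+	"crypto/sha1"
+	"encoding/base64"
+	"fmt"
+	"time"
+)
+
+// deriveTURNCredential builds the ephemeral username/credential pair from the
+// per-namespace shared secret. The expiry is embedded in the username so the
+// TURN server can verify freshness without a database lookup.
+func deriveTURNCredential(namespace, secret string, ttl time.Duration) (username, credential string) {
+	expiry := time.Now().Add(ttl).Unix()
+	username = fmt.Sprintf("%d:%s", expiry, namespace) // "{expiry_unix}:{namespace}"
+	mac := hmac.New(sha1.New, []byte(secret))
+	mac.Write([]byte(username))
+	credential = base64.StdEncoding.EncodeToString(mac.Sum(nil))
+	return username, credential
+}
+
+func main() {
+	u, c := deriveTURNCredential("myapp", "example-shared-secret", 10*time.Minute)
+	fmt.Printf("username=%s credential=%s\n", u, c)
+}
+```
+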
+ +## Debugging + +```bash +# SFU logs +journalctl -u orama-namespace-sfu@myapp -f + +# TURN logs +journalctl -u orama-namespace-turn@myapp -f + +# Check service status +systemctl status orama-namespace-sfu@myapp +systemctl status orama-namespace-turn@myapp +``` + +## Security Model + +- **Forced relay**: `iceTransportPolicy: relay` enforced server-side. Clients cannot bypass TURN. +- **HMAC credentials**: Per-namespace TURN shared secret. Credentials expire after 10 minutes. +- **Namespace isolation**: Each namespace has its own TURN secret, port ranges, and rooms. +- **Authentication required**: All WebRTC endpoints require JWT or API key (not in `isPublicPath()`). +- **Room management**: Creating/closing rooms requires namespace ownership. +- **SFU on WireGuard only**: SFU binds to 10.0.0.x, never 0.0.0.0. Only reachable via TURN relay. +- **Permissions-Policy**: `camera=(self), microphone=(self)` — only same-origin can access media devices. + +## Firewall + +When WebRTC is enabled, the following ports are opened via UFW: + +| Port | Protocol | Purpose | +|------|----------|---------| +| 3478 | UDP | TURN standard | +| 443 | UDP | TURN TLS (does not conflict with Caddy TCP 443) | +| 49152-65535 | UDP | TURN relay range (allocated per namespace) | + +SFU ports are NOT opened in the firewall — they are WireGuard-internal only. + +## Database Tables + +| Table | Purpose | +|-------|---------| +| `namespace_webrtc_config` | Per-namespace WebRTC config (enabled, TURN secret, node counts) | +| `webrtc_rooms` | Room-to-SFU-node affinity | +| `webrtc_port_allocations` | SFU/TURN port tracking | + +## Cold Boot Recovery + +On node restart, the cluster state file (`cluster_state.json`) includes `has_sfu`, `has_turn`, and port allocation data. The restore process: + +1. Core services restore first: RQLite → Olric → Gateway +2. If `has_turn` is set: fetches TURN shared secret from DB, spawns TURN +3. If `has_sfu` is set: fetches WebRTC config from DB, spawns SFU with TURN server list + +If the DB is unavailable during restore, SFU/TURN restoration is skipped with a warning log. They will be restored on the next successful DB connection. diff --git a/e2e/shared/webrtc_test.go b/e2e/shared/webrtc_test.go new file mode 100644 index 0000000..9fb92c6 --- /dev/null +++ b/e2e/shared/webrtc_test.go @@ -0,0 +1,241 @@ +//go:build e2e + +package shared_test + +import ( + "bytes" + "encoding/json" + "net/http" + "strings" + "testing" + "time" + + e2e "github.com/DeBrosOfficial/network/e2e" +) + +// turnCredentialsResponse is the expected response from the TURN credentials endpoint. +type turnCredentialsResponse struct { + URLs []string `json:"urls"` + Username string `json:"username"` + Credential string `json:"credential"` + TTL int `json:"ttl"` +} + +// TestWebRTC_TURNCredentials_RequiresAuth verifies that the TURN credentials endpoint +// rejects unauthenticated requests. 
+func TestWebRTC_TURNCredentials_RequiresAuth(t *testing.T) { + e2e.SkipIfMissingGateway(t) + + gatewayURL := e2e.GetGatewayURL() + client := e2e.NewHTTPClient(10 * time.Second) + + req, err := http.NewRequest("POST", gatewayURL+"/v1/webrtc/turn/credentials", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401 Unauthorized, got %d", resp.StatusCode) + } +} + +// TestWebRTC_TURNCredentials_ValidResponse verifies that authenticated requests to the +// TURN credentials endpoint return a valid credential structure. +func TestWebRTC_TURNCredentials_ValidResponse(t *testing.T) { + e2e.SkipIfMissingGateway(t) + + gatewayURL := e2e.GetGatewayURL() + apiKey := e2e.GetAPIKey() + if apiKey == "" { + t.Skip("no API key configured") + } + client := e2e.NewHTTPClient(10 * time.Second) + + req, err := http.NewRequest("POST", gatewayURL+"/v1/webrtc/turn/credentials", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200 OK, got %d", resp.StatusCode) + } + + var creds turnCredentialsResponse + if err := json.NewDecoder(resp.Body).Decode(&creds); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if len(creds.URLs) == 0 { + t.Fatal("expected at least one TURN URL") + } + if creds.Username == "" { + t.Fatal("expected non-empty username") + } + if creds.Credential == "" { + t.Fatal("expected non-empty credential") + } + if creds.TTL <= 0 { + t.Fatalf("expected positive TTL, got %d", creds.TTL) + } +} + +// TestWebRTC_Rooms_RequiresAuth verifies that the rooms endpoint rejects unauthenticated requests. +func TestWebRTC_Rooms_RequiresAuth(t *testing.T) { + e2e.SkipIfMissingGateway(t) + + gatewayURL := e2e.GetGatewayURL() + client := e2e.NewHTTPClient(10 * time.Second) + + req, err := http.NewRequest("GET", gatewayURL+"/v1/webrtc/rooms", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401 Unauthorized, got %d", resp.StatusCode) + } +} + +// TestWebRTC_Signal_RequiresAuth verifies that the signaling WebSocket rejects +// unauthenticated connections. +func TestWebRTC_Signal_RequiresAuth(t *testing.T) { + e2e.SkipIfMissingGateway(t) + + gatewayURL := e2e.GetGatewayURL() + client := e2e.NewHTTPClient(10 * time.Second) + + // Use regular HTTP GET to the signal endpoint — without auth it should return 401 + // before WebSocket upgrade + req, err := http.NewRequest("GET", gatewayURL+"/v1/webrtc/signal?room=test-room", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusUnauthorized { + t.Fatalf("expected 401, got %d", resp.StatusCode) + } +} + +// TestWebRTC_Rooms_CreateAndList verifies room creation and listing with proper auth. 
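+// The created room is deleted again at the end of the test so repeated e2e runs do not accumulate state.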
+func TestWebRTC_Rooms_CreateAndList(t *testing.T) { + e2e.SkipIfMissingGateway(t) + + gatewayURL := e2e.GetGatewayURL() + apiKey := e2e.GetAPIKey() + if apiKey == "" { + t.Skip("no API key configured") + } + client := e2e.NewHTTPClient(10 * time.Second) + + roomID := e2e.GenerateUniqueID("e2e-webrtc-room") + + // Create room + createBody, _ := json.Marshal(map[string]string{"room_id": roomID}) + req, err := http.NewRequest("POST", gatewayURL+"/v1/webrtc/rooms", bytes.NewReader(createBody)) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("create room failed: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + t.Fatalf("expected 200/201, got %d", resp.StatusCode) + } + + // List rooms + req, err = http.NewRequest("GET", gatewayURL+"/v1/webrtc/rooms", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err = client.Do(req) + if err != nil { + t.Fatalf("list rooms failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + // Clean up: delete room + req, err = http.NewRequest("DELETE", gatewayURL+"/v1/webrtc/rooms?room_id="+roomID, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp2, err := client.Do(req) + if err != nil { + t.Fatalf("delete room failed: %v", err) + } + resp2.Body.Close() +} + +// TestWebRTC_PermissionsPolicy verifies the Permissions-Policy header allows camera and microphone. 
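+// A missing header results in a skip rather than a failure, since not every deployment sets Permissions-Policy.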
+func TestWebRTC_PermissionsPolicy(t *testing.T) { + e2e.SkipIfMissingGateway(t) + + gatewayURL := e2e.GetGatewayURL() + apiKey := e2e.GetAPIKey() + if apiKey == "" { + t.Skip("no API key configured") + } + client := e2e.NewHTTPClient(10 * time.Second) + + req, err := http.NewRequest("GET", gatewayURL+"/v1/webrtc/rooms", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + pp := resp.Header.Get("Permissions-Policy") + if pp == "" { + t.Skip("Permissions-Policy header not set") + } + + if !strings.Contains(pp, "camera=(self)") { + t.Errorf("Permissions-Policy missing camera=(self), got: %s", pp) + } + if !strings.Contains(pp, "microphone=(self)") { + t.Errorf("Permissions-Policy missing microphone=(self), got: %s", pp) + } +} diff --git a/migrations/018_webrtc_services.sql b/migrations/018_webrtc_services.sql new file mode 100644 index 0000000..de1116c --- /dev/null +++ b/migrations/018_webrtc_services.sql @@ -0,0 +1,96 @@ +-- Migration 018: WebRTC Services (SFU + TURN) for Namespace Clusters +-- Adds per-namespace WebRTC configuration, room tracking, and port allocation +-- WebRTC is opt-in: enabled via `orama namespace enable webrtc` + +BEGIN; + +-- Per-namespace WebRTC configuration +-- One row per namespace that has WebRTC enabled +CREATE TABLE IF NOT EXISTS namespace_webrtc_config ( + id TEXT PRIMARY KEY, -- UUID + namespace_cluster_id TEXT NOT NULL UNIQUE, -- FK to namespace_clusters + namespace_name TEXT NOT NULL, -- Cached for easier lookups + enabled INTEGER NOT NULL DEFAULT 1, -- 1 = enabled, 0 = disabled + + -- TURN authentication + turn_shared_secret TEXT NOT NULL, -- HMAC-SHA1 shared secret (base64, 32 bytes) + turn_credential_ttl INTEGER NOT NULL DEFAULT 600, -- Credential TTL in seconds (default: 10 min) + + -- Service topology + sfu_node_count INTEGER NOT NULL DEFAULT 3, -- SFU instances (all 3 nodes) + turn_node_count INTEGER NOT NULL DEFAULT 2, -- TURN instances (2 of 3 nodes for HA) + + -- Metadata + enabled_by TEXT NOT NULL, -- Wallet address that enabled WebRTC + enabled_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + disabled_at TIMESTAMP, + + FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_webrtc_config_namespace ON namespace_webrtc_config(namespace_name); +CREATE INDEX IF NOT EXISTS idx_webrtc_config_cluster ON namespace_webrtc_config(namespace_cluster_id); + +-- WebRTC room tracking +-- Tracks active rooms and their SFU node affinity +CREATE TABLE IF NOT EXISTS webrtc_rooms ( + id TEXT PRIMARY KEY, -- UUID + namespace_cluster_id TEXT NOT NULL, -- FK to namespace_clusters + namespace_name TEXT NOT NULL, -- Cached for easier lookups + room_id TEXT NOT NULL, -- Application-defined room identifier + + -- SFU affinity + sfu_node_id TEXT NOT NULL, -- Node hosting this room's SFU + sfu_internal_ip TEXT NOT NULL, -- WireGuard IP of SFU node + sfu_signaling_port INTEGER NOT NULL, -- SFU WebSocket signaling port + + -- Room state + participant_count INTEGER NOT NULL DEFAULT 0, + max_participants INTEGER NOT NULL DEFAULT 100, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + last_activity TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + + -- Prevent duplicate rooms within a namespace + UNIQUE(namespace_cluster_id, room_id), + FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON 
DELETE CASCADE
+);
+
+CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_namespace ON webrtc_rooms(namespace_name);
+CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_node ON webrtc_rooms(sfu_node_id);
+CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_activity ON webrtc_rooms(last_activity);
+
+-- WebRTC port allocations
+-- Separate from namespace_port_allocations to avoid breaking existing port blocks
+-- Each namespace gets SFU + TURN ports on each node where those services run
+CREATE TABLE IF NOT EXISTS webrtc_port_allocations (
+    id TEXT PRIMARY KEY,                    -- UUID
+    node_id TEXT NOT NULL,                  -- Physical node ID
+    namespace_cluster_id TEXT NOT NULL,     -- FK to namespace_clusters
+    service_type TEXT NOT NULL,             -- 'sfu' or 'turn'
+
+    -- SFU ports (when service_type = 'sfu')
+    sfu_signaling_port INTEGER,             -- WebSocket signaling port
+    sfu_media_port_start INTEGER,           -- Start of RTP media port range
+    sfu_media_port_end INTEGER,             -- End of RTP media port range
+
+    -- TURN ports (when service_type = 'turn')
+    turn_listen_port INTEGER,               -- TURN listener port (3478)
+    turn_tls_port INTEGER,                  -- TURN TLS port (443/UDP)
+    turn_relay_port_start INTEGER,          -- Start of relay port range
+    turn_relay_port_end INTEGER,            -- End of relay port range
+
+    allocated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+
+    -- Prevent overlapping allocations
+    UNIQUE(node_id, namespace_cluster_id, service_type),
+    FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON DELETE CASCADE
+);
+
+CREATE INDEX IF NOT EXISTS idx_webrtc_ports_node ON webrtc_port_allocations(node_id);
+CREATE INDEX IF NOT EXISTS idx_webrtc_ports_cluster ON webrtc_port_allocations(namespace_cluster_id);
+CREATE INDEX IF NOT EXISTS idx_webrtc_ports_type ON webrtc_port_allocations(service_type);
+
+-- Mark migration as applied
+INSERT OR IGNORE INTO schema_migrations(version) VALUES (18);
+
+COMMIT;
diff --git a/pkg/cli/cmd/namespacecmd/namespace.go b/pkg/cli/cmd/namespacecmd/namespace.go
index 1807f74..db0d0e9 100644
--- a/pkg/cli/cmd/namespacecmd/namespace.go
+++ b/pkg/cli/cmd/namespacecmd/namespace.go
@@ -45,10 +45,59 @@ var repairCmd = &cobra.Command{
 	},
 }
 
+var enableCmd = &cobra.Command{
+	Use:   "enable <feature>",
+	Short: "Enable a feature for a namespace",
+	Long:  "Enable a feature for a namespace. Supported features: webrtc",
+	Args:  cobra.ExactArgs(1),
+	Run: func(cmd *cobra.Command, args []string) {
+		ns, _ := cmd.Flags().GetString("namespace")
+		cliArgs := []string{"enable", args[0]}
+		if ns != "" {
+			cliArgs = append(cliArgs, "--namespace", ns)
+		}
+		cli.HandleNamespaceCommand(cliArgs)
+	},
+}
+
+var disableCmd = &cobra.Command{
+	Use:   "disable <feature>",
+	Short: "Disable a feature for a namespace",
+	Long:  "Disable a feature for a namespace. Supported features: webrtc",
+	Args:  cobra.ExactArgs(1),
+	Run: func(cmd *cobra.Command, args []string) {
+		ns, _ := cmd.Flags().GetString("namespace")
+		cliArgs := []string{"disable", args[0]}
+		if ns != "" {
+			cliArgs = append(cliArgs, "--namespace", ns)
+		}
+		cli.HandleNamespaceCommand(cliArgs)
+	},
+}
+
+var webrtcStatusCmd = &cobra.Command{
+	Use:   "webrtc-status",
+	Short: "Show WebRTC service status for a namespace",
+	Run: func(cmd *cobra.Command, args []string) {
+		ns, _ := cmd.Flags().GetString("namespace")
+		cliArgs := []string{"webrtc-status"}
+		if ns != "" {
+			cliArgs = append(cliArgs, "--namespace", ns)
+		}
+		cli.HandleNamespaceCommand(cliArgs)
+	},
+}
+
 func init() {
 	deleteCmd.Flags().Bool("force", false, "Skip confirmation prompt")
+	enableCmd.Flags().String("namespace", "", "Namespace name")
+	disableCmd.Flags().String("namespace", "", "Namespace name")
+	webrtcStatusCmd.Flags().String("namespace", "", "Namespace name")
 
 	Cmd.AddCommand(listCmd)
 	Cmd.AddCommand(deleteCmd)
 	Cmd.AddCommand(repairCmd)
+	Cmd.AddCommand(enableCmd)
+	Cmd.AddCommand(disableCmd)
+	Cmd.AddCommand(webrtcStatusCmd)
 }
diff --git a/pkg/cli/namespace_commands.go b/pkg/cli/namespace_commands.go
index 7ed8c7f..6150406 100644
--- a/pkg/cli/namespace_commands.go
+++ b/pkg/cli/namespace_commands.go
@@ -37,6 +37,30 @@ func HandleNamespaceCommand(args []string) {
 			os.Exit(1)
 		}
 		handleNamespaceRepair(args[1])
+	case "enable":
+		if len(args) < 2 {
+			fmt.Fprintf(os.Stderr, "Usage: orama namespace enable <feature> --namespace <name>\n")
+			fmt.Fprintf(os.Stderr, "Features: webrtc\n")
+			os.Exit(1)
+		}
+		handleNamespaceEnable(args[1:])
+	case "disable":
+		if len(args) < 2 {
+			fmt.Fprintf(os.Stderr, "Usage: orama namespace disable <feature> --namespace <name>\n")
+			fmt.Fprintf(os.Stderr, "Features: webrtc\n")
+			os.Exit(1)
+		}
+		handleNamespaceDisable(args[1:])
+	case "webrtc-status":
+		var ns string
+		fs := flag.NewFlagSet("namespace webrtc-status", flag.ExitOnError)
+		fs.StringVar(&ns, "namespace", "", "Namespace name")
+		_ = fs.Parse(args[1:])
+		if ns == "" {
+			fmt.Fprintf(os.Stderr, "Usage: orama namespace webrtc-status --namespace <name>\n")
+			os.Exit(1)
+		}
+		handleNamespaceWebRTCStatus(ns)
 	case "help":
 		showNamespaceHelp()
 	default:
@@ -50,17 +74,24 @@ func showNamespaceHelp() {
 	fmt.Printf("Namespace Management Commands\n\n")
 	fmt.Printf("Usage: orama namespace <subcommand>\n\n")
 	fmt.Printf("Subcommands:\n")
-	fmt.Printf("  list   - List namespaces owned by the current wallet\n")
-	fmt.Printf("  delete - Delete the current namespace and all its resources\n")
-	fmt.Printf("  repair - Repair an under-provisioned namespace cluster (add missing nodes)\n")
-	fmt.Printf("  help   - Show this help message\n\n")
+	fmt.Printf("  list                           - List namespaces owned by the current wallet\n")
+	fmt.Printf("  delete                         - Delete the current namespace and all its resources\n")
+	fmt.Printf("  repair                         - Repair an under-provisioned namespace cluster\n")
+	fmt.Printf("  enable webrtc --namespace NS   - Enable WebRTC (SFU + TURN) for a namespace\n")
+	fmt.Printf("  disable webrtc --namespace NS  - Disable WebRTC for a namespace\n")
+	fmt.Printf("  webrtc-status --namespace NS   - Show WebRTC service status\n")
+	fmt.Printf("  help                           - Show this help message\n\n")
 	fmt.Printf("Flags:\n")
-	fmt.Printf("  --force - Skip confirmation prompt (delete only)\n\n")
+	fmt.Printf("  --force     - Skip confirmation prompt (delete only)\n")
+	fmt.Printf("  --namespace - Namespace name (enable/disable/webrtc-status)\n\n")
 	fmt.Printf("Examples:\n")
 	fmt.Printf("  orama namespace list\n")
 	fmt.Printf("  orama namespace delete\n")
 	fmt.Printf("  orama namespace delete --force\n")
 	fmt.Printf("  orama namespace repair anchat\n")
+	fmt.Printf("  orama namespace enable webrtc --namespace myapp\n")
+	fmt.Printf("  orama namespace disable webrtc --namespace myapp\n")
+	fmt.Printf("  orama namespace webrtc-status --namespace myapp\n")
 }
 
 func handleNamespaceRepair(namespaceName string) {
@@ -193,6 +224,203 @@ func handleNamespaceDelete(force bool) {
 	fmt.Printf("Run 'orama auth login' to create a new namespace.\n")
 }
 
+func handleNamespaceEnable(args []string) {
+	feature := args[0]
+	if feature != "webrtc" {
+		fmt.Fprintf(os.Stderr, "Unknown feature: %s\nSupported features: webrtc\n", feature)
+		os.Exit(1)
+	}
+
+	var ns string
+	fs := flag.NewFlagSet("namespace enable webrtc", flag.ExitOnError)
+	fs.StringVar(&ns, "namespace", "", "Namespace name")
+	_ = fs.Parse(args[1:])
+
+	if ns == "" {
+		fmt.Fprintf(os.Stderr, "Usage: orama namespace enable webrtc --namespace <name>\n")
+		os.Exit(1)
+	}
+
+	gatewayURL, apiKey := loadAuthForNamespace(ns)
+
+	fmt.Printf("Enabling WebRTC for namespace '%s'...\n", ns)
+	fmt.Printf("This will provision SFU (3 nodes) and TURN (2 nodes) services.\n")
+
+	url := fmt.Sprintf("%s/v1/namespace/webrtc/enable", gatewayURL)
+	req, err := http.NewRequest(http.MethodPost, url, nil)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
+		os.Exit(1)
+	}
+	req.Header.Set("Authorization", "Bearer "+apiKey)
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+		},
+	}
+	resp, err := client.Do(req)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to connect to gateway: %v\n", err)
+		os.Exit(1)
+	}
+	defer resp.Body.Close()
+
+	var result map[string]interface{}
+	json.NewDecoder(resp.Body).Decode(&result)
+
+	if resp.StatusCode != http.StatusOK {
+		errMsg := "unknown error"
+		if e, ok := result["error"].(string); ok {
+			errMsg = e
+		}
+		fmt.Fprintf(os.Stderr, "Failed to enable WebRTC: %s\n", errMsg)
+		os.Exit(1)
+	}
+
+	fmt.Printf("WebRTC enabled for namespace '%s'.\n", ns)
+	fmt.Printf("  SFU instances:  3 nodes (signaling via WireGuard)\n")
+	fmt.Printf("  TURN instances: 2 nodes (relay on public IPs)\n")
+}
+
+func handleNamespaceDisable(args []string) {
+	feature := args[0]
+	if feature != "webrtc" {
+		fmt.Fprintf(os.Stderr, "Unknown feature: %s\nSupported features: webrtc\n", feature)
+		os.Exit(1)
+	}
+
+	var ns string
+	fs := flag.NewFlagSet("namespace disable webrtc", flag.ExitOnError)
+	fs.StringVar(&ns, "namespace", "", "Namespace name")
+	_ = fs.Parse(args[1:])
+
+	if ns == "" {
+		fmt.Fprintf(os.Stderr, "Usage: orama namespace disable webrtc --namespace <name>\n")
+		os.Exit(1)
+	}
+
+	gatewayURL, apiKey := loadAuthForNamespace(ns)
+
+	fmt.Printf("Disabling WebRTC for namespace '%s'...\n", ns)
+
+	url := fmt.Sprintf("%s/v1/namespace/webrtc/disable", gatewayURL)
+	req, err := http.NewRequest(http.MethodPost, url, nil)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
+		os.Exit(1)
+	}
+	req.Header.Set("Authorization", "Bearer "+apiKey)
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+		},
+	}
+	resp, err := client.Do(req)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to connect to gateway: %v\n", err)
+		os.Exit(1)
+	}
+	defer resp.Body.Close()
+
+	var result map[string]interface{}
+	json.NewDecoder(resp.Body).Decode(&result)
+
+	if resp.StatusCode != http.StatusOK {
+		errMsg := "unknown error"
+		if e, ok := result["error"].(string); ok {
+			errMsg = e
+		}
+		fmt.Fprintf(os.Stderr, "Failed to disable WebRTC: %s\n", errMsg)
+		os.Exit(1)
+	}
+
+	fmt.Printf("WebRTC disabled for namespace '%s'.\n", ns)
+	fmt.Printf("  SFU and TURN services stopped, ports deallocated, DNS records removed.\n")
+}
+
+func handleNamespaceWebRTCStatus(ns string) {
+	gatewayURL, apiKey := loadAuthForNamespace(ns)
+
+	url := fmt.Sprintf("%s/v1/namespace/webrtc/status", gatewayURL)
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
+		os.Exit(1)
+	}
+	req.Header.Set("Authorization", "Bearer "+apiKey)
+
+	client := &http.Client{
+		Transport: &http.Transport{
+			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
+		},
+	}
+	resp, err := client.Do(req)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to connect to gateway: %v\n", err)
+		os.Exit(1)
+	}
+	defer resp.Body.Close()
+
+	var result map[string]interface{}
+	json.NewDecoder(resp.Body).Decode(&result)
+
+	if resp.StatusCode != http.StatusOK {
+		errMsg := "unknown error"
+		if e, ok := result["error"].(string); ok {
+			errMsg = e
+		}
+		fmt.Fprintf(os.Stderr, "Failed to get WebRTC status: %s\n", errMsg)
+		os.Exit(1)
+	}
+
+	enabled, _ := result["enabled"].(bool)
+	if !enabled {
+		fmt.Printf("WebRTC is not enabled for namespace '%s'.\n", ns)
+		fmt.Printf("  Enable with: orama namespace enable webrtc --namespace %s\n", ns)
+		return
+	}
+
+	fmt.Printf("WebRTC Status for namespace '%s'\n\n", ns)
+	fmt.Printf("  Enabled:        yes\n")
+	if sfuCount, ok := result["sfu_node_count"].(float64); ok {
+		fmt.Printf("  SFU nodes:      %.0f\n", sfuCount)
+	}
+	if turnCount, ok := result["turn_node_count"].(float64); ok {
+		fmt.Printf("  TURN nodes:     %.0f\n", turnCount)
+	}
+	if ttl, ok := result["turn_credential_ttl"].(float64); ok {
+		fmt.Printf("  TURN cred TTL:  %.0fs\n", ttl)
+	}
+	if enabledBy, ok := result["enabled_by"].(string); ok {
+		fmt.Printf("  Enabled by:     %s\n", enabledBy)
+	}
+	if enabledAt, ok := result["enabled_at"].(string); ok {
+		fmt.Printf("  Enabled at:     %s\n", enabledAt)
+	}
+}
+
+// loadAuthForNamespace loads credentials and returns the gateway URL and API key.
+// Exits with an error message if not authenticated.
+func loadAuthForNamespace(ns string) (gatewayURL, apiKey string) {
+	store, err := auth.LoadEnhancedCredentials()
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "Failed to load credentials: %v\n", err)
+		os.Exit(1)
+	}
+
+	gatewayURL = getGatewayURL()
+	creds := store.GetDefaultCredential(gatewayURL)
+
+	if creds == nil || !creds.IsValid() {
+		fmt.Fprintf(os.Stderr, "Not authenticated. 
Run 'orama auth login' first.\n") + os.Exit(1) + } + + return gatewayURL, creds.APIKey +} + func handleNamespaceList() { // Load credentials store, err := auth.LoadEnhancedCredentials() diff --git a/pkg/cli/production/install/orchestrator.go b/pkg/cli/production/install/orchestrator.go index e7630d7..1689faa 100644 --- a/pkg/cli/production/install/orchestrator.go +++ b/pkg/cli/production/install/orchestrator.go @@ -543,6 +543,8 @@ func (o *Orchestrator) installNamespaceTemplates() error { "orama-namespace-rqlite@.service", "orama-namespace-olric@.service", "orama-namespace-gateway@.service", + "orama-namespace-sfu@.service", + "orama-namespace-turn@.service", } installedCount := 0 diff --git a/pkg/cli/production/report/namespaces.go b/pkg/cli/production/report/namespaces.go index 6725451..ef0bf8c 100644 --- a/pkg/cli/production/report/namespaces.go +++ b/pkg/cli/production/report/namespaces.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "os/exec" "path/filepath" "regexp" "strconv" @@ -162,5 +163,25 @@ func collectNamespaceReport(ns nsInfo) NamespaceReport { } } + // 5. SFUUp: check if namespace SFU systemd service is active (optional) + r.SFUUp = isNamespaceServiceActive("sfu", ns.name) + + // 6. TURNUp: check if namespace TURN systemd service is active (optional) + r.TURNUp = isNamespaceServiceActive("turn", ns.name) + return r } + +// isNamespaceServiceActive checks if a namespace service is provisioned and active. +// Returns false if the service is not provisioned (no env file) or not running. +func isNamespaceServiceActive(serviceType, namespace string) bool { + // Only check if the service was provisioned (env file exists) + envFile := fmt.Sprintf("/opt/orama/.orama/data/namespaces/%s/%s.env", namespace, serviceType) + if _, err := os.Stat(envFile); err != nil { + return false // not provisioned + } + + svcName := fmt.Sprintf("orama-namespace-%s@%s", serviceType, namespace) + cmd := exec.Command("systemctl", "is-active", "--quiet", svcName) + return cmd.Run() == nil +} diff --git a/pkg/cli/production/report/types.go b/pkg/cli/production/report/types.go index b782267..7607917 100644 --- a/pkg/cli/production/report/types.go +++ b/pkg/cli/production/report/types.go @@ -274,6 +274,8 @@ type NamespaceReport struct { OlricUp bool `json:"olric_up"` GatewayUp bool `json:"gateway_up"` GatewayStatus int `json:"gateway_status,omitempty"` + SFUUp bool `json:"sfu_up"` + TURNUp bool `json:"turn_up"` } // --- Deployments --- diff --git a/pkg/cli/production/upgrade/orchestrator.go b/pkg/cli/production/upgrade/orchestrator.go index 73e4cec..455b563 100644 --- a/pkg/cli/production/upgrade/orchestrator.go +++ b/pkg/cli/production/upgrade/orchestrator.go @@ -431,6 +431,8 @@ func (o *Orchestrator) installNamespaceTemplates() error { "orama-namespace-rqlite@.service", "orama-namespace-olric@.service", "orama-namespace-gateway@.service", + "orama-namespace-sfu@.service", + "orama-namespace-turn@.service", } installedCount := 0 diff --git a/pkg/cli/utils/systemd.go b/pkg/cli/utils/systemd.go index a8b363d..b4a6ffb 100644 --- a/pkg/cli/utils/systemd.go +++ b/pkg/cli/utils/systemd.go @@ -184,7 +184,7 @@ func GetProductionServices() []string { namespacesDir := "/opt/orama/.orama/data/namespaces" nsEntries, err := os.ReadDir(namespacesDir) if err == nil { - serviceTypes := []string{"rqlite", "olric", "gateway"} + serviceTypes := []string{"rqlite", "olric", "gateway", "sfu", "turn"} for _, nsEntry := range nsEntries { if !nsEntry.IsDir() { continue @@ -289,7 +289,8 @@ func identifyPortProcess(port int) 
string { // NamespaceServiceOrder defines the dependency order for namespace services. // RQLite must start first (database), then Olric (cache), then Gateway (depends on both). -var NamespaceServiceOrder = []string{"rqlite", "olric", "gateway"} +// TURN and SFU are optional WebRTC services that start after Gateway. +var NamespaceServiceOrder = []string{"rqlite", "olric", "gateway", "turn", "sfu"} // StartServicesOrdered starts services respecting namespace dependency order. // Namespace services are started in order: rqlite → olric (+ wait) → gateway. diff --git a/pkg/config/gateway_config.go b/pkg/config/gateway_config.go index d277be9..c60b474 100644 --- a/pkg/config/gateway_config.go +++ b/pkg/config/gateway_config.go @@ -20,6 +20,17 @@ type HTTPGatewayConfig struct { IPFSAPIURL string `yaml:"ipfs_api_url"` // IPFS API URL IPFSTimeout time.Duration `yaml:"ipfs_timeout"` // Timeout for IPFS operations BaseDomain string `yaml:"base_domain"` // Base domain for deployments (e.g., "dbrs.space"). Defaults to "dbrs.space" + + // WebRTC configuration (optional, enabled per-namespace) + WebRTC WebRTCConfig `yaml:"webrtc"` +} + +// WebRTCConfig contains WebRTC-related gateway configuration +type WebRTCConfig struct { + Enabled bool `yaml:"enabled"` // Whether this gateway has WebRTC support active + SFUPort int `yaml:"sfu_port"` // Local SFU signaling port to proxy to + TURNDomain string `yaml:"turn_domain"` // TURN domain (e.g., "turn.ns-myapp.dbrs.space") + TURNSecret string `yaml:"turn_secret"` // HMAC-SHA1 shared secret for TURN credential generation } // HTTPSConfig contains HTTPS/TLS configuration for the gateway diff --git a/pkg/environments/production/firewall.go b/pkg/environments/production/firewall.go index 40ec143..0074a7c 100644 --- a/pkg/environments/production/firewall.go +++ b/pkg/environments/production/firewall.go @@ -12,6 +12,9 @@ type FirewallConfig struct { IsNameserver bool // enables port 53 TCP+UDP AnyoneORPort int // 0 = disabled, typically 9001 WireGuardPort int // default 51820 + TURNEnabled bool // enables TURN relay ports (3478/udp, 443/udp, relay range) + TURNRelayStart int // start of TURN relay port range (default 49152) + TURNRelayEnd int // end of TURN relay port range (default 65535) } // FirewallProvisioner manages UFW firewall setup @@ -84,6 +87,15 @@ func (fp *FirewallProvisioner) GenerateRules() []string { rules = append(rules, fmt.Sprintf("ufw allow %d/tcp", fp.config.AnyoneORPort)) } + // TURN relay (only for nodes running TURN servers) + if fp.config.TURNEnabled { + rules = append(rules, "ufw allow 3478/udp") // TURN standard port + rules = append(rules, "ufw allow 443/udp") // TURN TLS port (does not conflict with Caddy TCP 443) + if fp.config.TURNRelayStart > 0 && fp.config.TURNRelayEnd > 0 { + rules = append(rules, fmt.Sprintf("ufw allow %d:%d/udp", fp.config.TURNRelayStart, fp.config.TURNRelayEnd)) + } + } + // Allow all traffic from WireGuard subnet (inter-node encrypted traffic) rules = append(rules, "ufw allow from 10.0.0.0/8") @@ -130,6 +142,47 @@ func (fp *FirewallProvisioner) IsActive() bool { return strings.Contains(string(output), "Status: active") } +// AddWebRTCRules dynamically adds TURN port rules without a full firewall reset. +// Used when enabling WebRTC on a namespace. 
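+// Re-running these rules is safe: ufw treats adding an already-present rule as a no-op ("Skipping adding existing rule").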
+func (fp *FirewallProvisioner) AddWebRTCRules(relayStart, relayEnd int) error { + rules := []string{ + "ufw allow 3478/udp", + "ufw allow 443/udp", + } + if relayStart > 0 && relayEnd > 0 { + rules = append(rules, fmt.Sprintf("ufw allow %d:%d/udp", relayStart, relayEnd)) + } + + for _, rule := range rules { + parts := strings.Fields(rule) + cmd := exec.Command(parts[0], parts[1:]...) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to add firewall rule '%s': %w\n%s", rule, err, string(output)) + } + } + return nil +} + +// RemoveWebRTCRules dynamically removes TURN port rules without a full firewall reset. +// Used when disabling WebRTC on a namespace. +func (fp *FirewallProvisioner) RemoveWebRTCRules(relayStart, relayEnd int) error { + rules := []string{ + "ufw delete allow 3478/udp", + "ufw delete allow 443/udp", + } + if relayStart > 0 && relayEnd > 0 { + rules = append(rules, fmt.Sprintf("ufw delete allow %d:%d/udp", relayStart, relayEnd)) + } + + for _, rule := range rules { + parts := strings.Fields(rule) + cmd := exec.Command(parts[0], parts[1:]...) + // Ignore errors on delete — rule may not exist + cmd.CombinedOutput() + } + return nil +} + // GetStatus returns the current UFW status func (fp *FirewallProvisioner) GetStatus() (string, error) { cmd := exec.Command("ufw", "status", "verbose") diff --git a/pkg/environments/production/installers/caddy.go b/pkg/environments/production/installers/caddy.go index 88eb7f2..449653b 100644 --- a/pkg/environments/production/installers/caddy.go +++ b/pkg/environments/production/installers/caddy.go @@ -378,7 +378,8 @@ func (ci *CaddyInstaller) generateCaddyfile(domain, email, acmeEndpoint, baseDom }`, acmeEndpoint) var sb strings.Builder - sb.WriteString(fmt.Sprintf("{\n email %s\n}\n", email)) + // Disable HTTP/3 (QUIC) so Caddy doesn't bind UDP 443, which TURN needs for relay + sb.WriteString(fmt.Sprintf("{\n email %s\n servers {\n protocols h1 h2\n }\n}\n", email)) // Node domain blocks (e.g., node1.dbrs.space, *.node1.dbrs.space) sb.WriteString(fmt.Sprintf("\n*.%s {\n%s\n reverse_proxy localhost:6001\n}\n", domain, tlsBlock)) diff --git a/pkg/gateway/config.go b/pkg/gateway/config.go index 9384513..e45a74b 100644 --- a/pkg/gateway/config.go +++ b/pkg/gateway/config.go @@ -41,4 +41,10 @@ type Config struct { // WireGuard mesh configuration ClusterSecret string // Cluster secret for authenticating internal WireGuard peer exchange + + // WebRTC configuration (set when namespace has WebRTC enabled) + WebRTCEnabled bool // Whether WebRTC endpoints are active on this gateway + SFUPort int // Local SFU signaling port to proxy WebSocket connections to + TURNDomain string // TURN server domain for credential generation + TURNSecret string // HMAC-SHA1 shared secret for TURN credential generation } diff --git a/pkg/gateway/config_validate.go b/pkg/gateway/config_validate.go index e4e086d..1a2036e 100644 --- a/pkg/gateway/config_validate.go +++ b/pkg/gateway/config_validate.go @@ -70,6 +70,16 @@ func (c *Config) ValidateConfig() []error { } } + // Validate WebRTC configuration + if c.WebRTCEnabled { + if c.SFUPort <= 0 || c.SFUPort > 65535 { + errs = append(errs, fmt.Errorf("gateway.sfu_port: must be between 1 and 65535 when webrtc is enabled")) + } + if c.TURNSecret == "" { + errs = append(errs, fmt.Errorf("gateway.turn_secret: must not be empty when webrtc is enabled")) + } + } + // Validate HTTPS configuration if c.EnableHTTPS { if c.DomainName == "" { diff --git a/pkg/gateway/dependencies.go 
b/pkg/gateway/dependencies.go index 32b2787..b1043c1 100644 --- a/pkg/gateway/dependencies.go +++ b/pkg/gateway/dependencies.go @@ -22,6 +22,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/serverless" "github.com/DeBrosOfficial/network/pkg/serverless/hostfunctions" + "github.com/DeBrosOfficial/network/pkg/serverless/triggers" "github.com/multiformats/go-multiaddr" olriclib "github.com/olric-data/olric" "go.uber.org/zap" @@ -59,6 +60,9 @@ type Dependencies struct { ServerlessWSMgr *serverless.WSManager ServerlessHandlers *serverlesshandlers.ServerlessHandlers + // PubSub trigger dispatcher (used to wire into PubSubHandlers) + PubSubDispatcher *triggers.PubSubDispatcher + // Authentication service AuthService *auth.Service } @@ -434,11 +438,27 @@ func initializeServerless(logger *logging.ColoredLogger, cfg *Config, deps *Depe // Create invoker deps.ServerlessInvoker = serverless.NewInvoker(engine, registry, hostFuncs, logger.Logger) + // Create PubSub trigger store and dispatcher + triggerStore := triggers.NewPubSubTriggerStore(deps.ORMClient, logger.Logger) + + var olricUnderlying olriclib.Client + if deps.OlricClient != nil { + olricUnderlying = deps.OlricClient.UnderlyingClient() + } + deps.PubSubDispatcher = triggers.NewPubSubDispatcher( + triggerStore, + deps.ServerlessInvoker, + olricUnderlying, + logger.Logger, + ) + // Create HTTP handlers deps.ServerlessHandlers = serverlesshandlers.NewServerlessHandlers( deps.ServerlessInvoker, registry, deps.ServerlessWSMgr, + triggerStore, + deps.PubSubDispatcher, logger.Logger, ) diff --git a/pkg/gateway/gateway.go b/pkg/gateway/gateway.go index 5ef414e..ca82b73 100644 --- a/pkg/gateway/gateway.go +++ b/pkg/gateway/gateway.go @@ -30,6 +30,7 @@ import ( pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub" serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless" joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join" + webrtchandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/webrtc" wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard" sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite" "github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage" @@ -122,6 +123,9 @@ type Gateway struct { rateLimiter *RateLimiter namespaceRateLimiter *NamespaceRateLimiter + // WebRTC signaling and TURN credentials + webrtcHandlers *webrtchandlers.WebRTCHandlers + // WireGuard peer exchange wireguardHandler *wireguardhandlers.Handler @@ -149,6 +153,9 @@ type Gateway struct { // Node recovery handler (called when health monitor confirms a node dead or recovered) nodeRecoverer authhandlers.NodeRecoverer + // WebRTC manager for enable/disable operations + webrtcManager authhandlers.WebRTCManager + // Circuit breakers for proxy targets (per-target failure tracking) circuitBreakers *CircuitBreakerRegistry @@ -323,6 +330,25 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) { // Initialize handler instances gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger) + // Wire PubSub trigger dispatch if serverless is available + if deps.PubSubDispatcher != nil { + gw.pubsubHandlers.SetOnPublish(func(ctx context.Context, namespace, topic string, data []byte) { + deps.PubSubDispatcher.Dispatch(ctx, namespace, topic, data, 0) + }) + } + + if cfg.WebRTCEnabled && cfg.SFUPort > 0 { + gw.webrtcHandlers = webrtchandlers.NewWebRTCHandlers( + 
logger, + cfg.SFUPort, + cfg.TURNDomain, + cfg.TURNSecret, + gw.proxyWebSocket, + ) + logger.ComponentInfo(logging.ComponentGeneral, "WebRTC handlers initialized", + zap.Int("sfu_port", cfg.SFUPort)) + } + if deps.OlricClient != nil { gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient) } @@ -633,6 +659,11 @@ func (g *Gateway) SetNodeRecoverer(nr authhandlers.NodeRecoverer) { g.nodeRecoverer = nr } +// SetWebRTCManager sets the WebRTC lifecycle manager for enable/disable operations. +func (g *Gateway) SetWebRTCManager(wm authhandlers.WebRTCManager) { + g.webrtcManager = wm +} + // SetSpawnHandler sets the handler for internal namespace spawn/stop requests. func (g *Gateway) SetSpawnHandler(h http.Handler) { g.spawnHandler = h @@ -847,3 +878,224 @@ func (g *Gateway) namespaceClusterRepairHandler(w http.ResponseWriter, r *http.R }) } +// namespaceWebRTCEnablePublicHandler handles POST /v1/namespace/webrtc/enable +// Public: authenticated by JWT/API key via auth middleware. Namespace from context. +func (g *Gateway) namespaceWebRTCEnablePublicHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + namespaceName, _ := r.Context().Value(CtxKeyNamespaceOverride).(string) + if namespaceName == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + + if g.webrtcManager == nil { + writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled") + return + } + + if err := g.webrtcManager.EnableWebRTC(r.Context(), namespaceName, "api"); err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "ok", + "namespace": namespaceName, + "message": "WebRTC enabled successfully", + }) +} + +// namespaceWebRTCDisablePublicHandler handles POST /v1/namespace/webrtc/disable +// Public: authenticated by JWT/API key via auth middleware. Namespace from context. +func (g *Gateway) namespaceWebRTCDisablePublicHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + namespaceName, _ := r.Context().Value(CtxKeyNamespaceOverride).(string) + if namespaceName == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + + if g.webrtcManager == nil { + writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled") + return + } + + if err := g.webrtcManager.DisableWebRTC(r.Context(), namespaceName); err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "ok", + "namespace": namespaceName, + "message": "WebRTC disabled successfully", + }) +} + +// namespaceWebRTCStatusPublicHandler handles GET /v1/namespace/webrtc/status +// Public: authenticated by JWT/API key via auth middleware. Namespace from context. 
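+// Example response when WebRTC is not enabled (namespace name illustrative);
+// when enabled, the manager's config object is returned as-is:
+//
+//	{"namespace": "my-app", "enabled": false}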
+func (g *Gateway) namespaceWebRTCStatusPublicHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + namespaceName, _ := r.Context().Value(CtxKeyNamespaceOverride).(string) + if namespaceName == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + + if g.webrtcManager == nil { + writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled") + return + } + + config, err := g.webrtcManager.GetWebRTCStatus(r.Context(), namespaceName) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if config == nil { + json.NewEncoder(w).Encode(map[string]interface{}{ + "namespace": namespaceName, + "enabled": false, + }) + } else { + json.NewEncoder(w).Encode(config) + } +} + +// namespaceWebRTCEnableHandler handles POST /v1/internal/namespace/webrtc/enable?namespace={name} +// Internal-only: authenticated by X-Orama-Internal-Auth header + WireGuard subnet. +func (g *Gateway) namespaceWebRTCEnableHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + if r.Header.Get("X-Orama-Internal-Auth") != "namespace-coordination" || !nodeauth.IsWireGuardPeer(r.RemoteAddr) { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + + namespaceName := r.URL.Query().Get("namespace") + if namespaceName == "" { + writeError(w, http.StatusBadRequest, "namespace parameter required") + return + } + + if g.webrtcManager == nil { + writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled") + return + } + + if err := g.webrtcManager.EnableWebRTC(r.Context(), namespaceName, "cli"); err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "ok", + "namespace": namespaceName, + "message": "WebRTC enabled successfully", + }) +} + +// namespaceWebRTCDisableHandler handles POST /v1/internal/namespace/webrtc/disable?namespace={name} +// Internal-only: authenticated by X-Orama-Internal-Auth header + WireGuard subnet. 
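+// Example request from a WireGuard peer (namespace name illustrative; header
+// value and query parameter as checked below):
+//
+//	POST /v1/internal/namespace/webrtc/disable?namespace=my-app
+//	X-Orama-Internal-Auth: namespace-coordination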
+func (g *Gateway) namespaceWebRTCDisableHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + if r.Header.Get("X-Orama-Internal-Auth") != "namespace-coordination" || !nodeauth.IsWireGuardPeer(r.RemoteAddr) { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + + namespaceName := r.URL.Query().Get("namespace") + if namespaceName == "" { + writeError(w, http.StatusBadRequest, "namespace parameter required") + return + } + + if g.webrtcManager == nil { + writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled") + return + } + + if err := g.webrtcManager.DisableWebRTC(r.Context(), namespaceName); err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "ok", + "namespace": namespaceName, + "message": "WebRTC disabled successfully", + }) +} + +// namespaceWebRTCStatusHandler handles GET /v1/internal/namespace/webrtc/status?namespace={name} +// Internal-only: authenticated by X-Orama-Internal-Auth header + WireGuard subnet. +func (g *Gateway) namespaceWebRTCStatusHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + if r.Header.Get("X-Orama-Internal-Auth") != "namespace-coordination" || !nodeauth.IsWireGuardPeer(r.RemoteAddr) { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + + namespaceName := r.URL.Query().Get("namespace") + if namespaceName == "" { + writeError(w, http.StatusBadRequest, "namespace parameter required") + return + } + + if g.webrtcManager == nil { + writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled") + return + } + + config, err := g.webrtcManager.GetWebRTCStatus(r.Context(), namespaceName) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if config == nil { + json.NewEncoder(w).Encode(map[string]interface{}{ + "namespace": namespaceName, + "enabled": false, + }) + } else { + json.NewEncoder(w).Encode(config) + } +} + diff --git a/pkg/gateway/handlers/auth/handlers.go b/pkg/gateway/handlers/auth/handlers.go index beb4733..eb08721 100644 --- a/pkg/gateway/handlers/auth/handlers.go +++ b/pkg/gateway/handlers/auth/handlers.go @@ -58,6 +58,14 @@ type NodeRecoverer interface { RepairCluster(ctx context.Context, namespaceName string) error } +// WebRTCManager handles enabling/disabling WebRTC services for namespaces. +type WebRTCManager interface { + EnableWebRTC(ctx context.Context, namespaceName, enabledBy string) error + DisableWebRTC(ctx context.Context, namespaceName string) error + // GetWebRTCStatus returns the WebRTC config for a namespace, or nil if not enabled. 
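+	// The config is returned as interface{} (presumably so this package need
+	// not import the namespace package); gateway handlers JSON-encode the
+	// value as-is and report {"enabled": false} when it is nil.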
+ GetWebRTCStatus(ctx context.Context, namespaceName string) (interface{}, error) +} + // Handlers holds dependencies for authentication HTTP handlers type Handlers struct { logger *logging.ColoredLogger diff --git a/pkg/gateway/handlers/namespace/delete_handler.go b/pkg/gateway/handlers/namespace/delete_handler.go index 2a0d9e1..2021fd9 100644 --- a/pkg/gateway/handlers/namespace/delete_handler.go +++ b/pkg/gateway/handlers/namespace/delete_handler.go @@ -302,6 +302,8 @@ func (h *DeleteHandler) cleanupGlobalTables(ctx context.Context, ns string) { {"namespace_sqlite_databases", "namespace"}, {"namespace_quotas", "namespace"}, {"home_node_assignments", "namespace"}, + {"webrtc_rooms", "namespace_name"}, + {"namespace_webrtc_config", "namespace_name"}, } for _, t := range tables { diff --git a/pkg/gateway/handlers/namespace/spawn_handler.go b/pkg/gateway/handlers/namespace/spawn_handler.go index e075862..6d9b2da 100644 --- a/pkg/gateway/handlers/namespace/spawn_handler.go +++ b/pkg/gateway/handlers/namespace/spawn_handler.go @@ -12,12 +12,13 @@ import ( namespacepkg "github.com/DeBrosOfficial/network/pkg/namespace" "github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/rqlite" + "github.com/DeBrosOfficial/network/pkg/sfu" "go.uber.org/zap" ) // SpawnRequest represents a request to spawn or stop a namespace instance type SpawnRequest struct { - Action string `json:"action"` // "spawn-rqlite", "spawn-olric", "spawn-gateway", "stop-rqlite", "stop-olric", "stop-gateway", "save-cluster-state", "delete-cluster-state" + Action string `json:"action"` // spawn-{rqlite,olric,gateway,sfu,turn}, stop-{rqlite,olric,gateway,sfu,turn}, save-cluster-state, delete-cluster-state Namespace string `json:"namespace"` NodeID string `json:"node_id"` @@ -48,6 +49,24 @@ type SpawnRequest struct { IPFSTimeout string `json:"ipfs_timeout,omitempty"` IPFSReplicationFactor int `json:"ipfs_replication_factor,omitempty"` + // SFU config (when action = "spawn-sfu") + SFUListenAddr string `json:"sfu_listen_addr,omitempty"` + SFUMediaStart int `json:"sfu_media_start,omitempty"` + SFUMediaEnd int `json:"sfu_media_end,omitempty"` + TURNServers []sfu.TURNServerConfig `json:"turn_servers,omitempty"` + TURNSecret string `json:"turn_secret,omitempty"` + TURNCredTTL int `json:"turn_cred_ttl,omitempty"` + RQLiteDSN string `json:"rqlite_dsn,omitempty"` + + // TURN config (when action = "spawn-turn") + TURNListenAddr string `json:"turn_listen_addr,omitempty"` + TURNTLSAddr string `json:"turn_tls_addr,omitempty"` + TURNPublicIP string `json:"turn_public_ip,omitempty"` + TURNRealm string `json:"turn_realm,omitempty"` + TURNAuthSecret string `json:"turn_auth_secret,omitempty"` + TURNRelayStart int `json:"turn_relay_start,omitempty"` + TURNRelayEnd int `json:"turn_relay_end,omitempty"` + // Cluster state (when action = "save-cluster-state") ClusterState json.RawMessage `json:"cluster_state,omitempty"` } @@ -242,6 +261,60 @@ func (h *SpawnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true}) + case "spawn-sfu": + cfg := namespacepkg.SFUInstanceConfig{ + Namespace: req.Namespace, + NodeID: req.NodeID, + ListenAddr: req.SFUListenAddr, + MediaPortStart: req.SFUMediaStart, + MediaPortEnd: req.SFUMediaEnd, + TURNServers: req.TURNServers, + TURNSecret: req.TURNSecret, + TURNCredTTL: req.TURNCredTTL, + RQLiteDSN: req.RQLiteDSN, + } + if err := h.systemdSpawner.SpawnSFU(ctx, req.Namespace, req.NodeID, cfg); err != nil { + 
h.logger.Error("Failed to spawn SFU instance", zap.Error(err)) + writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()}) + return + } + writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true}) + + case "stop-sfu": + if err := h.systemdSpawner.StopSFU(ctx, req.Namespace, req.NodeID); err != nil { + h.logger.Error("Failed to stop SFU instance", zap.Error(err)) + writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()}) + return + } + writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true}) + + case "spawn-turn": + cfg := namespacepkg.TURNInstanceConfig{ + Namespace: req.Namespace, + NodeID: req.NodeID, + ListenAddr: req.TURNListenAddr, + TLSListenAddr: req.TURNTLSAddr, + PublicIP: req.TURNPublicIP, + Realm: req.TURNRealm, + AuthSecret: req.TURNAuthSecret, + RelayPortStart: req.TURNRelayStart, + RelayPortEnd: req.TURNRelayEnd, + } + if err := h.systemdSpawner.SpawnTURN(ctx, req.Namespace, req.NodeID, cfg); err != nil { + h.logger.Error("Failed to spawn TURN instance", zap.Error(err)) + writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()}) + return + } + writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true}) + + case "stop-turn": + if err := h.systemdSpawner.StopTURN(ctx, req.Namespace, req.NodeID); err != nil { + h.logger.Error("Failed to stop TURN instance", zap.Error(err)) + writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()}) + return + } + writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true}) + default: writeSpawnResponse(w, http.StatusBadRequest, SpawnResponse{Error: fmt.Sprintf("unknown action: %s", req.Action)}) } diff --git a/pkg/gateway/handlers/pubsub/publish_handler.go b/pkg/gateway/handlers/pubsub/publish_handler.go index 63e5450..a3cedd5 100644 --- a/pkg/gateway/handlers/pubsub/publish_handler.go +++ b/pkg/gateway/handlers/pubsub/publish_handler.go @@ -67,6 +67,11 @@ func (p *PubSubHandlers) PublishHandler(w http.ResponseWriter, r *http.Request) zap.Int("local_subscribers", len(localSubs)), zap.Int("local_delivered", localDeliveryCount)) + // Fire PubSub triggers for serverless functions (non-blocking) + if p.onPublish != nil { + go p.onPublish(context.Background(), ns, body.Topic, data) + } + // Publish to libp2p asynchronously for cross-node delivery // This prevents blocking the HTTP response if libp2p network is slow go func() { diff --git a/pkg/gateway/handlers/pubsub/types.go b/pkg/gateway/handlers/pubsub/types.go index 3d95acf..21f238a 100644 --- a/pkg/gateway/handlers/pubsub/types.go +++ b/pkg/gateway/handlers/pubsub/types.go @@ -1,6 +1,7 @@ package pubsub import ( + "context" "net/http" "sync" @@ -19,6 +20,16 @@ type PubSubHandlers struct { presenceMembers map[string][]PresenceMember // topicKey -> members mu sync.RWMutex presenceMu sync.RWMutex + + // onPublish is called when a message is published, to dispatch PubSub triggers. + // Set via SetOnPublish. May be nil if serverless triggers are not configured. + onPublish func(ctx context.Context, namespace, topic string, data []byte) +} + +// SetOnPublish sets the callback invoked when messages are published. +// Used to wire PubSub trigger dispatch from the serverless engine. 
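+// Wiring example, mirroring how the gateway hooks the serverless dispatcher
+// into pubsub publishes:
+//
+//	handlers.SetOnPublish(func(ctx context.Context, ns, topic string, data []byte) {
+//		dispatcher.Dispatch(ctx, ns, topic, data, 0)
+//	})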
+func (p *PubSubHandlers) SetOnPublish(fn func(ctx context.Context, namespace, topic string, data []byte)) { + p.onPublish = fn } // NewPubSubHandlers creates a new PubSubHandlers instance diff --git a/pkg/gateway/handlers/serverless/deploy_handler.go b/pkg/gateway/handlers/serverless/deploy_handler.go index 7595395..0e4a2fd 100644 --- a/pkg/gateway/handlers/serverless/deploy_handler.go +++ b/pkg/gateway/handlers/serverless/deploy_handler.go @@ -154,6 +154,20 @@ func (h *ServerlessHandlers) DeployFunction(w http.ResponseWriter, r *http.Reque return } + // Register PubSub triggers from definition (deploy-time auto-registration) + if h.triggerStore != nil && len(def.PubSubTopics) > 0 && fn != nil { + _ = h.triggerStore.RemoveByFunction(ctx, fn.ID) + for _, topic := range def.PubSubTopics { + if _, err := h.triggerStore.Add(ctx, fn.ID, topic); err != nil { + h.logger.Warn("Failed to register pubsub trigger", + zap.String("topic", topic), + zap.Error(err)) + } else if h.dispatcher != nil { + h.dispatcher.InvalidateCache(ctx, def.Namespace, topic) + } + } + } + writeJSON(w, http.StatusCreated, map[string]interface{}{ "message": "Function deployed successfully", "function": fn, diff --git a/pkg/gateway/handlers/serverless/handlers_test.go b/pkg/gateway/handlers/serverless/handlers_test.go index c3f3cb4..15187c0 100644 --- a/pkg/gateway/handlers/serverless/handlers_test.go +++ b/pkg/gateway/handlers/serverless/handlers_test.go @@ -92,6 +92,8 @@ func newTestHandlers(reg serverless.FunctionRegistry) *ServerlessHandlers { nil, // invoker is nil — we only test paths that don't reach it reg, wsManager, + nil, // triggerStore + nil, // dispatcher logger, ) } diff --git a/pkg/gateway/handlers/serverless/routes.go b/pkg/gateway/handlers/serverless/routes.go index 24fefe8..3a7fec5 100644 --- a/pkg/gateway/handlers/serverless/routes.go +++ b/pkg/gateway/handlers/serverless/routes.go @@ -30,14 +30,17 @@ func (h *ServerlessHandlers) handleFunctions(w http.ResponseWriter, r *http.Requ // handleFunctionByName handles operations on a specific function // Routes: -// - GET /v1/functions/{name} - Get function info -// - DELETE /v1/functions/{name} - Delete function -// - POST /v1/functions/{name}/invoke - Invoke function -// - GET /v1/functions/{name}/versions - List versions -// - GET /v1/functions/{name}/logs - Get logs -// - WS /v1/functions/{name}/ws - WebSocket invoke +// - GET /v1/functions/{name} - Get function info +// - DELETE /v1/functions/{name} - Delete function +// - POST /v1/functions/{name}/invoke - Invoke function +// - GET /v1/functions/{name}/versions - List versions +// - GET /v1/functions/{name}/logs - Get logs +// - WS /v1/functions/{name}/ws - WebSocket invoke +// - POST /v1/functions/{name}/triggers - Add trigger +// - GET /v1/functions/{name}/triggers - List triggers +// - DELETE /v1/functions/{name}/triggers/{id} - Remove trigger func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http.Request) { - // Parse path: /v1/functions/{name}[/{action}] + // Parse path: /v1/functions/{name}[/{action}[/{subID}]] path := strings.TrimPrefix(r.URL.Path, "/v1/functions/") parts := strings.SplitN(path, "/", 2) @@ -62,6 +65,13 @@ func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http } } + // Handle triggers sub-path: "triggers" or "triggers/{triggerID}" + triggerID := "" + if strings.HasPrefix(action, "triggers/") { + triggerID = strings.TrimPrefix(action, "triggers/") + action = "triggers" + } + switch action { case "invoke": h.InvokeFunction(w, r, 
name, version) @@ -71,6 +81,17 @@ func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http h.ListVersions(w, r, name) case "logs": h.GetFunctionLogs(w, r, name) + case "triggers": + switch { + case triggerID != "" && r.Method == http.MethodDelete: + h.HandleDeleteTrigger(w, r, name, triggerID) + case r.Method == http.MethodPost: + h.HandleAddTrigger(w, r, name) + case r.Method == http.MethodGet: + h.HandleListTriggers(w, r, name) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } case "": switch r.Method { case http.MethodGet: diff --git a/pkg/gateway/handlers/serverless/trigger_handler.go b/pkg/gateway/handlers/serverless/trigger_handler.go new file mode 100644 index 0000000..8832866 --- /dev/null +++ b/pkg/gateway/handlers/serverless/trigger_handler.go @@ -0,0 +1,188 @@ +package serverless + +import ( + "context" + "encoding/json" + "net/http" + "time" + + "github.com/DeBrosOfficial/network/pkg/serverless" + "go.uber.org/zap" +) + +// addTriggerRequest is the request body for adding a PubSub trigger. +type addTriggerRequest struct { + Topic string `json:"topic"` +} + +// HandleAddTrigger handles POST /v1/functions/{name}/triggers +// Adds a PubSub trigger that invokes this function when a message is published to the topic. +func (h *ServerlessHandlers) HandleAddTrigger(w http.ResponseWriter, r *http.Request, functionName string) { + if h.triggerStore == nil { + writeError(w, http.StatusNotImplemented, "PubSub triggers not available") + return + } + + namespace := h.getNamespaceFromRequest(r) + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + var req addTriggerRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error()) + return + } + + if req.Topic == "" { + writeError(w, http.StatusBadRequest, "topic required") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Look up function to get its ID + fn, err := h.registry.Get(ctx, namespace, functionName, 0) + if err != nil { + if serverless.IsNotFound(err) { + writeError(w, http.StatusNotFound, "Function not found") + } else { + writeError(w, http.StatusInternalServerError, "Failed to look up function") + } + return + } + + triggerID, err := h.triggerStore.Add(ctx, fn.ID, req.Topic) + if err != nil { + h.logger.Error("Failed to add PubSub trigger", + zap.String("function", functionName), + zap.String("topic", req.Topic), + zap.Error(err), + ) + writeError(w, http.StatusInternalServerError, "Failed to add trigger: "+err.Error()) + return + } + + // Invalidate cache for this topic + if h.dispatcher != nil { + h.dispatcher.InvalidateCache(ctx, namespace, req.Topic) + } + + h.logger.Info("PubSub trigger added via API", + zap.String("function", functionName), + zap.String("topic", req.Topic), + zap.String("trigger_id", triggerID), + ) + + writeJSON(w, http.StatusCreated, map[string]interface{}{ + "trigger_id": triggerID, + "function": functionName, + "topic": req.Topic, + }) +} + +// HandleListTriggers handles GET /v1/functions/{name}/triggers +// Lists all PubSub triggers for a function. 
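+// Example response (trigger field names depend on the store's JSON tags and
+// are illustrative here):
+//
+//	{"triggers": [{"id": "trg_123", "topic": "chat.messages"}], "count": 1}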
+func (h *ServerlessHandlers) HandleListTriggers(w http.ResponseWriter, r *http.Request, functionName string) { + if h.triggerStore == nil { + writeError(w, http.StatusNotImplemented, "PubSub triggers not available") + return + } + + namespace := h.getNamespaceFromRequest(r) + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Look up function to get its ID + fn, err := h.registry.Get(ctx, namespace, functionName, 0) + if err != nil { + if serverless.IsNotFound(err) { + writeError(w, http.StatusNotFound, "Function not found") + } else { + writeError(w, http.StatusInternalServerError, "Failed to look up function") + } + return + } + + triggers, err := h.triggerStore.ListByFunction(ctx, fn.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, "Failed to list triggers") + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "triggers": triggers, + "count": len(triggers), + }) +} + +// HandleDeleteTrigger handles DELETE /v1/functions/{name}/triggers/{triggerID} +// Removes a PubSub trigger. +func (h *ServerlessHandlers) HandleDeleteTrigger(w http.ResponseWriter, r *http.Request, functionName, triggerID string) { + if h.triggerStore == nil { + writeError(w, http.StatusNotImplemented, "PubSub triggers not available") + return + } + + namespace := h.getNamespaceFromRequest(r) + if namespace == "" { + writeError(w, http.StatusBadRequest, "namespace required") + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second) + defer cancel() + + // Look up the trigger's topic before deleting (for cache invalidation) + fn, err := h.registry.Get(ctx, namespace, functionName, 0) + if err != nil { + if serverless.IsNotFound(err) { + writeError(w, http.StatusNotFound, "Function not found") + } else { + writeError(w, http.StatusInternalServerError, "Failed to look up function") + } + return + } + + // Get current triggers to find the topic for cache invalidation + triggers, err := h.triggerStore.ListByFunction(ctx, fn.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, "Failed to look up triggers") + return + } + + // Find the topic for the trigger being deleted + var triggerTopic string + for _, t := range triggers { + if t.ID == triggerID { + triggerTopic = t.Topic + break + } + } + + if err := h.triggerStore.Remove(ctx, triggerID); err != nil { + writeError(w, http.StatusInternalServerError, "Failed to remove trigger: "+err.Error()) + return + } + + // Invalidate cache for the topic + if h.dispatcher != nil && triggerTopic != "" { + h.dispatcher.InvalidateCache(ctx, namespace, triggerTopic) + } + + h.logger.Info("PubSub trigger removed via API", + zap.String("function", functionName), + zap.String("trigger_id", triggerID), + ) + + writeJSON(w, http.StatusOK, map[string]interface{}{ + "message": "Trigger removed", + }) +} diff --git a/pkg/gateway/handlers/serverless/types.go b/pkg/gateway/handlers/serverless/types.go index 8e7ef6c..3a561fc 100644 --- a/pkg/gateway/handlers/serverless/types.go +++ b/pkg/gateway/handlers/serverless/types.go @@ -6,16 +6,19 @@ import ( "github.com/DeBrosOfficial/network/pkg/gateway/auth" "github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys" "github.com/DeBrosOfficial/network/pkg/serverless" + "github.com/DeBrosOfficial/network/pkg/serverless/triggers" "go.uber.org/zap" ) // ServerlessHandlers contains handlers for serverless function endpoints. 
// It's a separate struct to keep the Gateway struct clean.
type ServerlessHandlers struct {
-	invoker   *serverless.Invoker
-	registry  serverless.FunctionRegistry
-	wsManager *serverless.WSManager
-	logger    *zap.Logger
+	invoker      *serverless.Invoker
+	registry     serverless.FunctionRegistry
+	wsManager    *serverless.WSManager
+	triggerStore *triggers.PubSubTriggerStore
+	dispatcher   *triggers.PubSubDispatcher
+	logger       *zap.Logger
}

// NewServerlessHandlers creates a new ServerlessHandlers instance.
@@ -23,13 +26,17 @@ func NewServerlessHandlers(
	invoker *serverless.Invoker,
	registry serverless.FunctionRegistry,
	wsManager *serverless.WSManager,
+	triggerStore *triggers.PubSubTriggerStore,
+	dispatcher *triggers.PubSubDispatcher,
	logger *zap.Logger,
) *ServerlessHandlers {
	return &ServerlessHandlers{
-		invoker:   invoker,
-		registry:  registry,
-		wsManager: wsManager,
-		logger:    logger,
+		invoker:      invoker,
+		registry:     registry,
+		wsManager:    wsManager,
+		triggerStore: triggerStore,
+		dispatcher:   dispatcher,
+		logger:       logger,
	}
}
diff --git a/pkg/gateway/handlers/webrtc/credentials.go b/pkg/gateway/handlers/webrtc/credentials.go
new file mode 100644
index 0000000..dba6544
--- /dev/null
+++ b/pkg/gateway/handlers/webrtc/credentials.go
@@ -0,0 +1,56 @@
+package webrtc
+
+import (
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/DeBrosOfficial/network/pkg/logging"
+	"github.com/DeBrosOfficial/network/pkg/turn"
+	"go.uber.org/zap"
+)
+
+const turnCredentialTTL = 10 * time.Minute
+
+// CredentialsHandler handles POST /v1/webrtc/turn/credentials
+// Returns fresh TURN credentials scoped to the authenticated namespace.
+func (h *WebRTCHandlers) CredentialsHandler(w http.ResponseWriter, r *http.Request) {
+	if r.Method != http.MethodPost {
+		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
+		return
+	}
+
+	ns := resolveNamespaceFromRequest(r)
+	if ns == "" {
+		writeError(w, http.StatusForbidden, "namespace not resolved")
+		return
+	}
+
+	if h.turnSecret == "" {
+		writeError(w, http.StatusServiceUnavailable, "TURN not configured")
+		return
+	}
+
+	username, password := turn.GenerateCredentials(h.turnSecret, ns, turnCredentialTTL)
+
+	// Build TURN URIs from the configured TURN domain
+	var uris []string
+	if h.turnDomain != "" {
+		uris = append(uris,
+			fmt.Sprintf("turn:%s:3478?transport=udp", h.turnDomain),
+			fmt.Sprintf("turn:%s:443?transport=udp", h.turnDomain),
+		)
+	}
+
+	h.logger.ComponentInfo(logging.ComponentGeneral, "Issued TURN credentials",
+		zap.String("namespace", ns),
+		zap.String("username", username),
+	)
+
+	writeJSON(w, http.StatusOK, map[string]interface{}{
+		"username": username,
+		"password": password,
+		"ttl":      int(turnCredentialTTL.Seconds()),
+		"uris":     uris,
+	})
+}
diff --git a/pkg/gateway/handlers/webrtc/handlers_test.go b/pkg/gateway/handlers/webrtc/handlers_test.go
new file mode 100644
index 0000000..cad3110
--- /dev/null
+++ b/pkg/gateway/handlers/webrtc/handlers_test.go
@@ -0,0 +1,270 @@
+package webrtc
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+
+	"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
+	"github.com/DeBrosOfficial/network/pkg/logging"
+)
+
+func testHandlers() *WebRTCHandlers {
+	logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
+	return NewWebRTCHandlers(
+		logger,
+		8443,
+		"turn.ns-test.dbrs.space",
+		"test-secret-key-32bytes-long!!!!",
+		nil, // No actual proxy in tests
+	)
+}
+
+func requestWithNamespace(method, path, namespace string) *http.Request {
+	req :=
httptest.NewRequest(method, path, nil) + ctx := context.WithValue(req.Context(), ctxkeys.NamespaceOverride, namespace) + return req.WithContext(ctx) +} + +// --- Credentials handler tests --- + +func TestCredentialsHandler_Success(t *testing.T) { + h := testHandlers() + req := requestWithNamespace("POST", "/v1/webrtc/turn/credentials", "test-ns") + w := httptest.NewRecorder() + + h.CredentialsHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + + var result map[string]interface{} + if err := json.NewDecoder(w.Body).Decode(&result); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + + if result["username"] == nil || result["username"] == "" { + t.Error("expected non-empty username") + } + if result["password"] == nil || result["password"] == "" { + t.Error("expected non-empty password") + } + if result["ttl"] == nil { + t.Error("expected ttl field") + } + ttl, ok := result["ttl"].(float64) + if !ok || ttl != 600 { + t.Errorf("ttl = %v, want 600", result["ttl"]) + } + uris, ok := result["uris"].([]interface{}) + if !ok || len(uris) != 2 { + t.Errorf("uris count = %v, want 2", result["uris"]) + } +} + +func TestCredentialsHandler_MethodNotAllowed(t *testing.T) { + h := testHandlers() + req := requestWithNamespace("GET", "/v1/webrtc/turn/credentials", "test-ns") + w := httptest.NewRecorder() + + h.CredentialsHandler(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("status = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } +} + +func TestCredentialsHandler_NoNamespace(t *testing.T) { + h := testHandlers() + req := httptest.NewRequest("POST", "/v1/webrtc/turn/credentials", nil) + w := httptest.NewRecorder() + + h.CredentialsHandler(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden) + } +} + +func TestCredentialsHandler_NoTURNSecret(t *testing.T) { + logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false) + h := NewWebRTCHandlers(logger, 8443, "turn.test.dbrs.space", "", nil) + + req := requestWithNamespace("POST", "/v1/webrtc/turn/credentials", "test-ns") + w := httptest.NewRecorder() + + h.CredentialsHandler(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } +} + +// --- Signal handler tests --- + +func TestSignalHandler_NoNamespace(t *testing.T) { + h := testHandlers() + req := httptest.NewRequest("GET", "/v1/webrtc/signal", nil) + w := httptest.NewRecorder() + + h.SignalHandler(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden) + } +} + +func TestSignalHandler_NoSFUPort(t *testing.T) { + logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false) + h := NewWebRTCHandlers(logger, 0, "", "secret", nil) + + req := requestWithNamespace("GET", "/v1/webrtc/signal", "test-ns") + w := httptest.NewRecorder() + + h.SignalHandler(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } +} + +func TestSignalHandler_NoProxyFunc(t *testing.T) { + h := testHandlers() // proxyWebSocket is nil + req := requestWithNamespace("GET", "/v1/webrtc/signal", "test-ns") + w := httptest.NewRecorder() + + h.SignalHandler(w, req) + + if w.Code != http.StatusInternalServerError { + t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError) + } +} + +// --- Rooms handler tests --- + +func 
TestRoomsHandler_MethodNotAllowed(t *testing.T) { + h := testHandlers() + req := requestWithNamespace("POST", "/v1/webrtc/rooms", "test-ns") + w := httptest.NewRecorder() + + h.RoomsHandler(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("status = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } +} + +func TestRoomsHandler_NoNamespace(t *testing.T) { + h := testHandlers() + req := httptest.NewRequest("GET", "/v1/webrtc/rooms", nil) + w := httptest.NewRecorder() + + h.RoomsHandler(w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden) + } +} + +func TestRoomsHandler_NoSFUPort(t *testing.T) { + logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false) + h := NewWebRTCHandlers(logger, 0, "", "secret", nil) + + req := requestWithNamespace("GET", "/v1/webrtc/rooms", "test-ns") + w := httptest.NewRecorder() + + h.RoomsHandler(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } +} + +func TestRoomsHandler_SFUProxySuccess(t *testing.T) { + // Start a mock SFU health endpoint + mockSFU := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok","rooms":3}`)) + })) + defer mockSFU.Close() + + // Extract port from mock server + logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false) + // Parse port from mockSFU.URL (format: http://127.0.0.1:PORT) + var port int + for i := len(mockSFU.URL) - 1; i >= 0; i-- { + if mockSFU.URL[i] == ':' { + p := mockSFU.URL[i+1:] + for _, c := range p { + port = port*10 + int(c-'0') + } + break + } + } + + h := NewWebRTCHandlers(logger, port, "", "secret", nil) + req := requestWithNamespace("GET", "/v1/webrtc/rooms", "test-ns") + w := httptest.NewRecorder() + + h.RoomsHandler(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + + body := w.Body.String() + if body != `{"status":"ok","rooms":3}` { + t.Errorf("body = %q, want %q", body, `{"status":"ok","rooms":3}`) + } +} + +// --- Helper tests --- + +func TestResolveNamespaceFromRequest(t *testing.T) { + // With namespace + req := requestWithNamespace("GET", "/test", "my-namespace") + ns := resolveNamespaceFromRequest(req) + if ns != "my-namespace" { + t.Errorf("namespace = %q, want %q", ns, "my-namespace") + } + + // Without namespace + req = httptest.NewRequest("GET", "/test", nil) + ns = resolveNamespaceFromRequest(req) + if ns != "" { + t.Errorf("namespace = %q, want empty", ns) + } +} + +func TestWriteError(t *testing.T) { + w := httptest.NewRecorder() + writeError(w, http.StatusBadRequest, "bad request") + + if w.Code != http.StatusBadRequest { + t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest) + } + + var result map[string]string + if err := json.NewDecoder(w.Body).Decode(&result); err != nil { + t.Fatalf("failed to decode: %v", err) + } + if result["error"] != "bad request" { + t.Errorf("error = %q, want %q", result["error"], "bad request") + } +} + +func TestWriteJSON(t *testing.T) { + w := httptest.NewRecorder() + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + if ct := w.Header().Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type = %q, want %q", ct, "application/json") + } +} diff 
--git a/pkg/gateway/handlers/webrtc/rooms.go b/pkg/gateway/handlers/webrtc/rooms.go new file mode 100644 index 0000000..12a621e --- /dev/null +++ b/pkg/gateway/handlers/webrtc/rooms.go @@ -0,0 +1,51 @@ +package webrtc + +import ( + "fmt" + "io" + "net/http" + "time" + + "github.com/DeBrosOfficial/network/pkg/logging" + "go.uber.org/zap" +) + +// RoomsHandler handles GET /v1/webrtc/rooms (list rooms) +// and GET /v1/webrtc/rooms?room_id=X (get specific room) +// Proxies to the local SFU's health endpoint for room data. +func (h *WebRTCHandlers) RoomsHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + writeError(w, http.StatusMethodNotAllowed, "method not allowed") + return + } + + ns := resolveNamespaceFromRequest(r) + if ns == "" { + writeError(w, http.StatusForbidden, "namespace not resolved") + return + } + + if h.sfuPort <= 0 { + writeError(w, http.StatusServiceUnavailable, "SFU not configured") + return + } + + // Proxy to SFU health endpoint which returns room count + targetURL := fmt.Sprintf("http://127.0.0.1:%d/health", h.sfuPort) + + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Get(targetURL) + if err != nil { + h.logger.ComponentWarn(logging.ComponentGeneral, "SFU health check failed", + zap.String("namespace", ns), + zap.Error(err), + ) + writeError(w, http.StatusServiceUnavailable, "SFU unavailable") + return + } + defer resp.Body.Close() + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(resp.StatusCode) + io.Copy(w, resp.Body) +} diff --git a/pkg/gateway/handlers/webrtc/signal.go b/pkg/gateway/handlers/webrtc/signal.go new file mode 100644 index 0000000..5f15ace --- /dev/null +++ b/pkg/gateway/handlers/webrtc/signal.go @@ -0,0 +1,52 @@ +package webrtc + +import ( + "fmt" + "net/http" + + "github.com/DeBrosOfficial/network/pkg/logging" + "go.uber.org/zap" +) + +// SignalHandler handles WebSocket /v1/webrtc/signal +// Proxies the WebSocket connection to the local SFU's signaling endpoint. 
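+// The request path is rewritten to the SFU's internal /ws/signal endpoint
+// before hand-off, so clients only ever see the public /v1/webrtc/signal route.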
+func (h *WebRTCHandlers) SignalHandler(w http.ResponseWriter, r *http.Request) {
+	ns := resolveNamespaceFromRequest(r)
+	if ns == "" {
+		writeError(w, http.StatusForbidden, "namespace not resolved")
+		return
+	}
+
+	if h.sfuPort <= 0 {
+		writeError(w, http.StatusServiceUnavailable, "SFU not configured")
+		return
+	}
+
+	// The SFU for this namespace runs on the same node, so proxy the
+	// WebSocket to its signaling port over loopback
+	targetHost := fmt.Sprintf("127.0.0.1:%d", h.sfuPort)
+
+	h.logger.ComponentDebug(logging.ComponentGeneral, "Proxying WebRTC signal to SFU",
+		zap.String("namespace", ns),
+		zap.String("target", targetHost),
+	)
+
+	// Rewrite the URL path to match the SFU's expected endpoint
+	r.URL.Path = "/ws/signal"
+	r.URL.Scheme = "http"
+	r.URL.Host = targetHost
+	r.Host = targetHost
+
+	if h.proxyWebSocket == nil {
+		writeError(w, http.StatusInternalServerError, "WebSocket proxy not available")
+		return
+	}
+
+	if !h.proxyWebSocket(w, r, targetHost) {
+		// proxyWebSocket already wrote the error response
+		h.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket proxy failed",
+			zap.String("namespace", ns),
+			zap.String("target", targetHost),
+		)
+	}
+}
diff --git a/pkg/gateway/handlers/webrtc/types.go b/pkg/gateway/handlers/webrtc/types.go
new file mode 100644
index 0000000..64ca747
--- /dev/null
+++ b/pkg/gateway/handlers/webrtc/types.go
@@ -0,0 +1,58 @@
+package webrtc
+
+import (
+	"encoding/json"
+	"net/http"
+
+	"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
+	"github.com/DeBrosOfficial/network/pkg/logging"
+)
+
+// WebRTCHandlers handles all WebRTC-related HTTP and WebSocket endpoints.
+// These run on the namespace gateway and proxy signaling to the local SFU.
+type WebRTCHandlers struct {
+	logger     *logging.ColoredLogger
+	sfuPort    int    // Local SFU signaling port to proxy WebSocket connections to
+	turnDomain string // TURN server domain for building URIs
+	turnSecret string // HMAC-SHA1 shared secret for TURN credential generation
+
+	// proxyWebSocket is injected from the gateway to reuse its WebSocket proxy logic
+	proxyWebSocket func(w http.ResponseWriter, r *http.Request, targetHost string) bool
+}
+
+// NewWebRTCHandlers creates a new WebRTCHandlers instance.
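+// Construction example, matching the call site in pkg/gateway/gateway.go
+// from this patch:
+//
+//	h := NewWebRTCHandlers(logger, cfg.SFUPort, cfg.TURNDomain, cfg.TURNSecret, gw.proxyWebSocket)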
+func NewWebRTCHandlers( + logger *logging.ColoredLogger, + sfuPort int, + turnDomain string, + turnSecret string, + proxyWS func(w http.ResponseWriter, r *http.Request, targetHost string) bool, +) *WebRTCHandlers { + return &WebRTCHandlers{ + logger: logger, + sfuPort: sfuPort, + turnDomain: turnDomain, + turnSecret: turnSecret, + proxyWebSocket: proxyWS, + } +} + +// resolveNamespaceFromRequest gets namespace from context set by auth middleware +func resolveNamespaceFromRequest(r *http.Request) string { + if v := r.Context().Value(ctxkeys.NamespaceOverride); v != nil { + if s, ok := v.(string); ok { + return s + } + } + return "" +} + +func writeJSON(w http.ResponseWriter, code int, v any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(code) + json.NewEncoder(w).Encode(v) +} + +func writeError(w http.ResponseWriter, code int, msg string) { + writeJSON(w, code, map[string]string{"error": msg}) +} diff --git a/pkg/gateway/middleware.go b/pkg/gateway/middleware.go index 0b3be80..6c9718b 100644 --- a/pkg/gateway/middleware.go +++ b/pkg/gateway/middleware.go @@ -196,7 +196,7 @@ func (g *Gateway) securityHeadersMiddleware(next http.Handler) http.Handler { w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-XSS-Protection", "0") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") - w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()") + w.Header().Set("Permissions-Policy", "camera=(self), microphone=(self), geolocation=()") // HSTS only when behind TLS (Caddy) if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") @@ -618,6 +618,9 @@ func requiresNamespaceOwnership(p string) bool { if strings.HasPrefix(p, "/v1/functions") { return true } + if strings.HasPrefix(p, "/v1/webrtc/") { + return true + } return false } diff --git a/pkg/gateway/middleware_test.go b/pkg/gateway/middleware_test.go index bef792b..b5e38cb 100644 --- a/pkg/gateway/middleware_test.go +++ b/pkg/gateway/middleware_test.go @@ -382,7 +382,7 @@ func TestSecurityHeadersMiddleware(t *testing.T) { "X-Frame-Options": "DENY", "X-Xss-Protection": "0", "Referrer-Policy": "strict-origin-when-cross-origin", - "Permissions-Policy": "camera=(), microphone=(), geolocation=()", + "Permissions-Policy": "camera=(self), microphone=(self), geolocation=()", } for header, want := range expected { got := rr.Header().Get(header) diff --git a/pkg/gateway/routes.go b/pkg/gateway/routes.go index 63831f1..4e49910 100644 --- a/pkg/gateway/routes.go +++ b/pkg/gateway/routes.go @@ -47,6 +47,16 @@ func (g *Gateway) Routes() http.Handler { // Namespace cluster repair (internal, handler does its own auth) mux.HandleFunc("/v1/internal/namespace/repair", g.namespaceClusterRepairHandler) + // Namespace WebRTC enable/disable/status (internal, handler does its own auth) + mux.HandleFunc("/v1/internal/namespace/webrtc/enable", g.namespaceWebRTCEnableHandler) + mux.HandleFunc("/v1/internal/namespace/webrtc/disable", g.namespaceWebRTCDisableHandler) + mux.HandleFunc("/v1/internal/namespace/webrtc/status", g.namespaceWebRTCStatusHandler) + + // Namespace WebRTC enable/disable/status (public, JWT/API key auth via middleware) + mux.HandleFunc("/v1/namespace/webrtc/enable", g.namespaceWebRTCEnablePublicHandler) + mux.HandleFunc("/v1/namespace/webrtc/disable", g.namespaceWebRTCDisablePublicHandler) + mux.HandleFunc("/v1/namespace/webrtc/status", g.namespaceWebRTCStatusPublicHandler) + // auth endpoints 
mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler) mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler) @@ -104,6 +114,13 @@ func (g *Gateway) Routes() http.Handler { mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler) } + // webrtc + if g.webrtcHandlers != nil { + mux.HandleFunc("/v1/webrtc/turn/credentials", g.webrtcHandlers.CredentialsHandler) + mux.HandleFunc("/v1/webrtc/signal", g.webrtcHandlers.SignalHandler) + mux.HandleFunc("/v1/webrtc/rooms", g.webrtcHandlers.RoomsHandler) + } + // anon proxy (authenticated users only) mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler) diff --git a/pkg/gateway/serverless_handlers_test.go b/pkg/gateway/serverless_handlers_test.go index 7796dc4..c2501de 100644 --- a/pkg/gateway/serverless_handlers_test.go +++ b/pkg/gateway/serverless_handlers_test.go @@ -50,7 +50,7 @@ func TestServerlessHandlers_ListFunctions(t *testing.T) { }, } - h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger) + h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, nil, nil, logger) req, _ := http.NewRequest("GET", "/v1/functions?namespace=ns1", nil) rr := httptest.NewRecorder() @@ -73,7 +73,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) { logger := zap.NewNop() registry := &mockFunctionRegistry{} - h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger) + h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, nil, nil, logger) // Test JSON deploy (which is partially supported according to code) // Should be 400 because WASM is missing or base64 not supported diff --git a/pkg/inspector/checks/webrtc.go b/pkg/inspector/checks/webrtc.go new file mode 100644 index 0000000..0d87d34 --- /dev/null +++ b/pkg/inspector/checks/webrtc.go @@ -0,0 +1,132 @@ +package checks + +import ( + "fmt" + + "github.com/DeBrosOfficial/network/pkg/inspector" +) + +func init() { + inspector.RegisterChecker("webrtc", CheckWebRTC) +} + +const webrtcSub = "webrtc" + +// CheckWebRTC runs WebRTC (SFU/TURN) health checks. +// These checks only apply to namespaces that have SFU or TURN provisioned. +func CheckWebRTC(data *inspector.ClusterData) []inspector.CheckResult { + var results []inspector.CheckResult + + for _, nd := range data.Nodes { + results = append(results, checkWebRTCPerNode(nd)...) + } + + results = append(results, checkWebRTCCrossNode(data)...) + + return results +} + +func checkWebRTCPerNode(nd *inspector.NodeData) []inspector.CheckResult { + var r []inspector.CheckResult + node := nd.Node.Name() + + for _, ns := range nd.Namespaces { + // Only check SFU/TURN if they are provisioned on this node. + // A false value when not provisioned is not an error. 
+ hasSFU := ns.SFUUp // true = service active + hasTURN := ns.TURNUp // true = service active + + // If neither is provisioned, skip WebRTC checks for this namespace + if !hasSFU && !hasTURN { + continue + } + + prefix := fmt.Sprintf("ns.%s", ns.Name) + + if hasSFU { + r = append(r, inspector.Pass(prefix+".sfu_up", + fmt.Sprintf("Namespace %s SFU active", ns.Name), + webrtcSub, node, "systemd service running", inspector.High)) + } + + if hasTURN { + r = append(r, inspector.Pass(prefix+".turn_up", + fmt.Sprintf("Namespace %s TURN active", ns.Name), + webrtcSub, node, "systemd service running", inspector.High)) + } + } + + return r +} + +func checkWebRTCCrossNode(data *inspector.ClusterData) []inspector.CheckResult { + var r []inspector.CheckResult + + // Collect SFU/TURN node counts per namespace + type webrtcCounts struct { + sfuNodes int + turnNodes int + } + nsCounts := map[string]*webrtcCounts{} + + for _, nd := range data.Nodes { + for _, ns := range nd.Namespaces { + if !ns.SFUUp && !ns.TURNUp { + continue + } + c, ok := nsCounts[ns.Name] + if !ok { + c = &webrtcCounts{} + nsCounts[ns.Name] = c + } + if ns.SFUUp { + c.sfuNodes++ + } + if ns.TURNUp { + c.turnNodes++ + } + } + } + + for name, counts := range nsCounts { + // SFU should be on all cluster nodes (typically 3) + if counts.sfuNodes > 0 { + if counts.sfuNodes >= 3 { + r = append(r, inspector.Pass( + fmt.Sprintf("ns.%s.sfu_coverage", name), + fmt.Sprintf("Namespace %s SFU on all nodes", name), + webrtcSub, "", + fmt.Sprintf("%d SFU nodes active", counts.sfuNodes), + inspector.High)) + } else { + r = append(r, inspector.Warn( + fmt.Sprintf("ns.%s.sfu_coverage", name), + fmt.Sprintf("Namespace %s SFU on all nodes", name), + webrtcSub, "", + fmt.Sprintf("only %d/3 SFU nodes active", counts.sfuNodes), + inspector.High)) + } + } + + // TURN should be on 2 nodes + if counts.turnNodes > 0 { + if counts.turnNodes >= 2 { + r = append(r, inspector.Pass( + fmt.Sprintf("ns.%s.turn_coverage", name), + fmt.Sprintf("Namespace %s TURN redundant", name), + webrtcSub, "", + fmt.Sprintf("%d TURN nodes active", counts.turnNodes), + inspector.High)) + } else { + r = append(r, inspector.Warn( + fmt.Sprintf("ns.%s.turn_coverage", name), + fmt.Sprintf("Namespace %s TURN redundant", name), + webrtcSub, "", + fmt.Sprintf("only %d/2 TURN nodes active (no redundancy)", counts.turnNodes), + inspector.High)) + } + } + } + + return r +} diff --git a/pkg/inspector/collector.go b/pkg/inspector/collector.go index 67470b2..534b9f7 100644 --- a/pkg/inspector/collector.go +++ b/pkg/inspector/collector.go @@ -41,6 +41,8 @@ type NamespaceData struct { OlricUp bool // Olric memberlist port listening GatewayUp bool // Gateway HTTP port responding GatewayStatus int // HTTP status code from gateway health + SFUUp bool // SFU systemd service active (optional, WebRTC) + TURNUp bool // TURN systemd service active (optional, WebRTC) } // RQLiteData holds parsed RQLite status from a single node. 
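A note on the TURN credential scheme used throughout this patch: pkg/turn itself is not included in the diff, but the gateway's CredentialsHandler, the SFU's TURNSecret, and the TURN server's AuthSecret all assume the conventional ephemeral shared-secret derivation (coturn's REST API pattern). A minimal sketch of what turn.GenerateCredentials could look like, with the signature taken from its call sites above and the body an assumption:

package turn

import (
	"crypto/hmac"
	"crypto/sha1"
	"encoding/base64"
	"fmt"
	"time"
)

// GenerateCredentials returns an ephemeral TURN username/password pair.
// Sketch only; the real pkg/turn implementation may differ.
func GenerateCredentials(secret, namespace string, ttl time.Duration) (username, password string) {
	// The username embeds the expiry so the TURN server can reject stale credentials.
	username = fmt.Sprintf("%d:%s", time.Now().Add(ttl).Unix(), namespace)
	// The password is base64(HMAC-SHA1(secret, username)); the TURN server
	// re-derives it from its own copy of the shared secret.
	mac := hmac.New(sha1.New, []byte(secret))
	mac.Write([]byte(username))
	password = base64.StdEncoding.EncodeToString(mac.Sum(nil))
	return username, password
}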
diff --git a/pkg/logging/logger.go b/pkg/logging/logger.go index 0dee825..4b78345 100644 --- a/pkg/logging/logger.go +++ b/pkg/logging/logger.go @@ -55,6 +55,8 @@ const ( ComponentGeneral Component = "GENERAL" ComponentAnyone Component = "ANYONE" ComponentGateway Component = "GATEWAY" + ComponentSFU Component = "SFU" + ComponentTURN Component = "TURN" ) // getComponentColor returns the color for a specific component @@ -78,6 +80,10 @@ func getComponentColor(component Component) string { return Cyan case ComponentGateway: return BrightGreen + case ComponentSFU: + return BrightRed + case ComponentTURN: + return Magenta default: return White } diff --git a/pkg/namespace/cluster_manager.go b/pkg/namespace/cluster_manager.go index 1172bc4..0193075 100644 --- a/pkg/namespace/cluster_manager.go +++ b/pkg/namespace/cluster_manager.go @@ -18,6 +18,7 @@ import ( "github.com/DeBrosOfficial/network/pkg/gateway" "github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/rqlite" + "github.com/DeBrosOfficial/network/pkg/sfu" "github.com/DeBrosOfficial/network/pkg/systemd" "github.com/google/uuid" "go.uber.org/zap" @@ -37,12 +38,13 @@ type ClusterManagerConfig struct { // ClusterManager orchestrates namespace cluster provisioning and lifecycle type ClusterManager struct { - db rqlite.Client - portAllocator *NamespacePortAllocator - nodeSelector *ClusterNodeSelector - systemdSpawner *SystemdSpawner // NEW: Systemd-based spawner replaces old spawners - dnsManager *DNSRecordManager - logger *zap.Logger + db rqlite.Client + portAllocator *NamespacePortAllocator + webrtcPortAllocator *WebRTCPortAllocator + nodeSelector *ClusterNodeSelector + systemdSpawner *SystemdSpawner // NEW: Systemd-based spawner replaces old spawners + dnsManager *DNSRecordManager + logger *zap.Logger baseDomain string baseDataDir string globalRQLiteDSN string // Global RQLite DSN for namespace gateway auth @@ -69,6 +71,7 @@ func NewClusterManager( ) *ClusterManager { // Create internal components portAllocator := NewNamespacePortAllocator(db, logger) + webrtcPortAllocator := NewWebRTCPortAllocator(db, logger) nodeSelector := NewClusterNodeSelector(db, portAllocator, logger) systemdSpawner := NewSystemdSpawner(cfg.BaseDataDir, logger) dnsManager := NewDNSRecordManager(db, cfg.BaseDomain, logger) @@ -94,6 +97,7 @@ func NewClusterManager( return &ClusterManager{ db: db, portAllocator: portAllocator, + webrtcPortAllocator: webrtcPortAllocator, nodeSelector: nodeSelector, systemdSpawner: systemdSpawner, dnsManager: dnsManager, @@ -139,6 +143,7 @@ func NewClusterManagerWithComponents( return &ClusterManager{ db: db, portAllocator: portAllocator, + webrtcPortAllocator: NewWebRTCPortAllocator(db, logger), nodeSelector: nodeSelector, systemdSpawner: systemdSpawner, dnsManager: NewDNSRecordManager(db, cfg.BaseDomain, logger), @@ -854,12 +859,21 @@ func (cm *ClusterManager) DeprovisionCluster(ctx context.Context, namespaceID in if err := cm.db.Query(ctx, &clusterNodes, nodeQuery, cluster.ID); err != nil { cm.logger.Warn("Failed to query cluster nodes for deprovisioning, falling back to local-only stop", zap.Error(err)) // Fall back to local-only stop (individual methods, NOT StopAll which uses dangerous glob) + // Stop WebRTC services first (SFU → TURN), then core services (Gateway → Olric → RQLite) + cm.systemdSpawner.StopSFU(ctx, cluster.NamespaceName, cm.localNodeID) + cm.systemdSpawner.StopTURN(ctx, cluster.NamespaceName, cm.localNodeID) cm.systemdSpawner.StopGateway(ctx, cluster.NamespaceName, cm.localNodeID) 
cm.systemdSpawner.StopOlric(ctx, cluster.NamespaceName, cm.localNodeID) cm.systemdSpawner.StopRQLite(ctx, cluster.NamespaceName, cm.localNodeID) cm.systemdSpawner.DeleteClusterState(cluster.NamespaceName) } else { - // 2. Stop namespace infra on ALL nodes (reverse dependency order: Gateway → Olric → RQLite) + // 2. Stop WebRTC services first (SFU → TURN), then core infra (Gateway → Olric → RQLite) + for _, node := range clusterNodes { + cm.stopSFUOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName) + } + for _, node := range clusterNodes { + cm.stopTURNOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName) + } for _, node := range clusterNodes { cm.stopGatewayOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName) } @@ -880,16 +894,21 @@ func (cm *ClusterManager) DeprovisionCluster(ctx context.Context, namespaceID in } } - // 4. Deallocate all ports + // 4. Deallocate all ports (core + WebRTC) cm.portAllocator.DeallocateAllPortBlocks(ctx, cluster.ID) + cm.webrtcPortAllocator.DeallocateAll(ctx, cluster.ID) - // 5. Delete namespace DNS records + // 5. Delete namespace DNS records (gateway + TURN) cm.dnsManager.DeleteNamespaceRecords(ctx, cluster.NamespaceName) + cm.dnsManager.DeleteTURNRecords(ctx, cluster.NamespaceName) // 6. Explicitly delete child tables (FK cascades disabled in rqlite) cm.db.Exec(ctx, `DELETE FROM namespace_cluster_events WHERE namespace_cluster_id = ?`, cluster.ID) cm.db.Exec(ctx, `DELETE FROM namespace_cluster_nodes WHERE namespace_cluster_id = ?`, cluster.ID) cm.db.Exec(ctx, `DELETE FROM namespace_port_allocations WHERE namespace_cluster_id = ?`, cluster.ID) + cm.db.Exec(ctx, `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ?`, cluster.ID) + cm.db.Exec(ctx, `DELETE FROM webrtc_rooms WHERE namespace_cluster_id = ?`, cluster.ID) + cm.db.Exec(ctx, `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, cluster.ID) // 7. Delete cluster record cm.db.Exec(ctx, `DELETE FROM namespace_clusters WHERE id = ?`, cluster.ID) @@ -1594,6 +1613,19 @@ type ClusterLocalState struct { HasGateway bool `json:"has_gateway"` BaseDomain string `json:"base_domain"` SavedAt time.Time `json:"saved_at"` + + // WebRTC fields (zero values when WebRTC not enabled — backward compatible) + HasSFU bool `json:"has_sfu,omitempty"` + HasTURN bool `json:"has_turn,omitempty"` + TURNSharedSecret string `json:"-"` // Never persisted to disk state file + TURNCredentialTTL int `json:"turn_credential_ttl,omitempty"` + SFUSignalingPort int `json:"sfu_signaling_port,omitempty"` + SFUMediaPortStart int `json:"sfu_media_port_start,omitempty"` + SFUMediaPortEnd int `json:"sfu_media_port_end,omitempty"` + TURNListenPort int `json:"turn_listen_port,omitempty"` + TURNTLSPort int `json:"turn_tls_port,omitempty"` + TURNRelayPortStart int `json:"turn_relay_port_start,omitempty"` + TURNRelayPortEnd int `json:"turn_relay_port_end,omitempty"` } type ClusterLocalStatePorts struct { @@ -1891,6 +1923,70 @@ func (cm *ClusterManager) restoreClusterFromState(ctx context.Context, state *Cl } } + // 4. Restore TURN (if enabled) + if state.HasTURN && state.TURNRelayPortStart > 0 { + turnRunning, _ := cm.systemdSpawner.systemdMgr.IsServiceActive(state.NamespaceName, systemd.ServiceTypeTURN) + if !turnRunning { + // TURN config needs the shared secret from DB — we can't persist it to disk state. + // If DB is available, fetch it; otherwise skip TURN restore (it will come back when DB is ready). 
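+			// The secret is tagged json:"-" in ClusterLocalState, so it never reaches the
+			// on-disk state file; namespace_webrtc_config in the DB is its only source here.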
+ webrtcCfg, err := cm.GetWebRTCConfig(ctx, state.NamespaceName) + if err == nil && webrtcCfg != nil { + turnCfg := TURNInstanceConfig{ + Namespace: state.NamespaceName, + NodeID: cm.localNodeID, + ListenAddr: fmt.Sprintf("0.0.0.0:%d", state.TURNListenPort), + TLSListenAddr: fmt.Sprintf("0.0.0.0:%d", state.TURNTLSPort), + PublicIP: "", // Will be resolved by spawner or from node info + Realm: cm.baseDomain, + AuthSecret: webrtcCfg.TURNSharedSecret, + RelayPortStart: state.TURNRelayPortStart, + RelayPortEnd: state.TURNRelayPortEnd, + } + if err := cm.systemdSpawner.SpawnTURN(ctx, state.NamespaceName, cm.localNodeID, turnCfg); err != nil { + cm.logger.Error("Failed to restore TURN from state", zap.String("namespace", state.NamespaceName), zap.Error(err)) + } else { + cm.logger.Info("Restored TURN instance from state", zap.String("namespace", state.NamespaceName)) + } + } else { + cm.logger.Warn("Skipping TURN restore: WebRTC config not available from DB", + zap.String("namespace", state.NamespaceName)) + } + } + } + + // 5. Restore SFU (if enabled) + if state.HasSFU && state.SFUSignalingPort > 0 { + sfuRunning, _ := cm.systemdSpawner.systemdMgr.IsServiceActive(state.NamespaceName, systemd.ServiceTypeSFU) + if !sfuRunning { + webrtcCfg, err := cm.GetWebRTCConfig(ctx, state.NamespaceName) + if err == nil && webrtcCfg != nil { + turnDomain := fmt.Sprintf("turn.ns-%s.%s", state.NamespaceName, cm.baseDomain) + sfuCfg := SFUInstanceConfig{ + Namespace: state.NamespaceName, + NodeID: cm.localNodeID, + ListenAddr: fmt.Sprintf("%s:%d", localIP, state.SFUSignalingPort), + MediaPortStart: state.SFUMediaPortStart, + MediaPortEnd: state.SFUMediaPortEnd, + TURNServers: []sfu.TURNServerConfig{ + {Host: turnDomain, Port: TURNDefaultPort}, + {Host: turnDomain, Port: TURNTLSPort}, + }, + TURNSecret: webrtcCfg.TURNSharedSecret, + TURNCredTTL: webrtcCfg.TURNCredentialTTL, + RQLiteDSN: fmt.Sprintf("http://localhost:%d", pb.RQLiteHTTPPort), + } + if err := cm.systemdSpawner.SpawnSFU(ctx, state.NamespaceName, cm.localNodeID, sfuCfg); err != nil { + cm.logger.Error("Failed to restore SFU from state", zap.String("namespace", state.NamespaceName), zap.Error(err)) + } else { + cm.logger.Info("Restored SFU instance from state", zap.String("namespace", state.NamespaceName)) + } + } else { + cm.logger.Warn("Skipping SFU restore: WebRTC config not available from DB", + zap.String("namespace", state.NamespaceName)) + } + } + } + return nil } diff --git a/pkg/namespace/cluster_manager_webrtc.go b/pkg/namespace/cluster_manager_webrtc.go new file mode 100644 index 0000000..9ba75ca --- /dev/null +++ b/pkg/namespace/cluster_manager_webrtc.go @@ -0,0 +1,616 @@ +package namespace + +import ( + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "time" + + "github.com/DeBrosOfficial/network/pkg/client" + "github.com/DeBrosOfficial/network/pkg/sfu" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// EnableWebRTC enables WebRTC (SFU + TURN) for an existing namespace cluster. +// Allocates ports, spawns SFU on all 3 nodes and TURN on 2 nodes, +// creates TURN DNS records, and updates cluster state. +func (cm *ClusterManager) EnableWebRTC(ctx context.Context, namespaceName, enabledBy string) error { + internalCtx := client.WithInternalAuth(ctx) + + // 1. 
Verify cluster exists and is ready + cluster, err := cm.GetClusterByNamespace(ctx, namespaceName) + if err != nil { + return fmt.Errorf("failed to get cluster: %w", err) + } + if cluster == nil { + return ErrClusterNotFound + } + if cluster.Status != ClusterStatusReady { + return &ClusterError{Message: fmt.Sprintf("cluster status is %q, must be %q to enable WebRTC", cluster.Status, ClusterStatusReady)} + } + + // 2. Check if WebRTC is already enabled + var existingConfigs []WebRTCConfig + if err := cm.db.Query(internalCtx, &existingConfigs, + `SELECT * FROM namespace_webrtc_config WHERE namespace_cluster_id = ? AND enabled = 1`, cluster.ID); err == nil && len(existingConfigs) > 0 { + return ErrWebRTCAlreadyEnabled + } + + cm.logger.Info("Enabling WebRTC for namespace", + zap.String("namespace", namespaceName), + zap.String("cluster_id", cluster.ID), + ) + + // 3. Generate TURN shared secret (32 bytes, crypto/rand) + secretBytes := make([]byte, 32) + if _, err := rand.Read(secretBytes); err != nil { + return fmt.Errorf("failed to generate TURN secret: %w", err) + } + turnSecret := base64.StdEncoding.EncodeToString(secretBytes) + + // 4. Insert namespace_webrtc_config + webrtcConfigID := uuid.New().String() + _, err = cm.db.Exec(internalCtx, + `INSERT INTO namespace_webrtc_config (id, namespace_cluster_id, namespace_name, enabled, turn_shared_secret, turn_credential_ttl, sfu_node_count, turn_node_count, enabled_by, enabled_at) + VALUES (?, ?, ?, 1, ?, ?, ?, ?, ?, ?)`, + webrtcConfigID, cluster.ID, namespaceName, + turnSecret, DefaultTURNCredentialTTL, + DefaultSFUNodeCount, DefaultTURNNodeCount, + enabledBy, time.Now(), + ) + if err != nil { + return fmt.Errorf("failed to insert WebRTC config: %w", err) + } + + // 5. Get cluster nodes with IPs + clusterNodes, err := cm.getClusterNodesWithIPs(ctx, cluster.ID) + if err != nil { + return fmt.Errorf("failed to get cluster nodes: %w", err) + } + if len(clusterNodes) < 3 { + return fmt.Errorf("cluster has %d nodes, need at least 3 for WebRTC", len(clusterNodes)) + } + + // 6. Allocate SFU ports on all nodes + sfuBlocks := make(map[string]*WebRTCPortBlock) // nodeID -> block + for _, node := range clusterNodes { + block, err := cm.webrtcPortAllocator.AllocateSFUPorts(ctx, node.NodeID, cluster.ID) + if err != nil { + cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes) + return fmt.Errorf("failed to allocate SFU ports on node %s: %w", node.NodeID, err) + } + sfuBlocks[node.NodeID] = block + } + + // 7. Select TURN nodes (prefer nodes without existing TURN allocations) + turnNodes := cm.selectTURNNodes(ctx, clusterNodes, DefaultTURNNodeCount) + + // 8. Allocate TURN ports on selected nodes + turnBlocks := make(map[string]*WebRTCPortBlock) // nodeID -> block + for _, node := range turnNodes { + block, err := cm.webrtcPortAllocator.AllocateTURNPorts(ctx, node.NodeID, cluster.ID) + if err != nil { + cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes) + return fmt.Errorf("failed to allocate TURN ports on node %s: %w", node.NodeID, err) + } + turnBlocks[node.NodeID] = block + } + + // 9. Build TURN server list for SFU config + turnDomain := fmt.Sprintf("turn.ns-%s.%s", namespaceName, cm.baseDomain) + turnServers := []sfu.TURNServerConfig{ + {Host: turnDomain, Port: TURNDefaultPort}, + {Host: turnDomain, Port: TURNTLSPort}, + } + + // 10. 
Get port blocks for RQLite DSN + portBlocks, err := cm.portAllocator.GetAllPortBlocks(ctx, cluster.ID) + if err != nil { + cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes) + return fmt.Errorf("failed to get port blocks: %w", err) + } + + // Build nodeID -> PortBlock map + nodePortBlocks := make(map[string]*PortBlock) + for i := range portBlocks { + nodePortBlocks[portBlocks[i].NodeID] = &portBlocks[i] + } + + // 11. Spawn TURN on selected nodes + for _, node := range turnNodes { + turnBlock := turnBlocks[node.NodeID] + turnCfg := TURNInstanceConfig{ + Namespace: namespaceName, + NodeID: node.NodeID, + ListenAddr: fmt.Sprintf("0.0.0.0:%d", turnBlock.TURNListenPort), + TLSListenAddr: fmt.Sprintf("0.0.0.0:%d", turnBlock.TURNTLSPort), + PublicIP: node.PublicIP, + Realm: cm.baseDomain, + AuthSecret: turnSecret, + RelayPortStart: turnBlock.TURNRelayPortStart, + RelayPortEnd: turnBlock.TURNRelayPortEnd, + } + + if err := cm.spawnTURNOnNode(ctx, node, namespaceName, turnCfg); err != nil { + cm.logger.Error("Failed to spawn TURN", + zap.String("namespace", namespaceName), + zap.String("node_id", node.NodeID), + zap.Error(err)) + cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes) + return fmt.Errorf("failed to spawn TURN on node %s: %w", node.NodeID, err) + } + + cm.logEvent(ctx, cluster.ID, EventTURNStarted, node.NodeID, + fmt.Sprintf("TURN started on %s (relay ports %d-%d)", node.NodeID, turnBlock.TURNRelayPortStart, turnBlock.TURNRelayPortEnd), nil) + } + + // 12. Spawn SFU on all nodes + for _, node := range clusterNodes { + sfuBlock := sfuBlocks[node.NodeID] + pb := nodePortBlocks[node.NodeID] + rqliteDSN := fmt.Sprintf("http://localhost:%d", pb.RQLiteHTTPPort) + + sfuCfg := SFUInstanceConfig{ + Namespace: namespaceName, + NodeID: node.NodeID, + ListenAddr: fmt.Sprintf("%s:%d", node.InternalIP, sfuBlock.SFUSignalingPort), + MediaPortStart: sfuBlock.SFUMediaPortStart, + MediaPortEnd: sfuBlock.SFUMediaPortEnd, + TURNServers: turnServers, + TURNSecret: turnSecret, + TURNCredTTL: DefaultTURNCredentialTTL, + RQLiteDSN: rqliteDSN, + } + + if err := cm.spawnSFUOnNode(ctx, node, namespaceName, sfuCfg); err != nil { + cm.logger.Error("Failed to spawn SFU", + zap.String("namespace", namespaceName), + zap.String("node_id", node.NodeID), + zap.Error(err)) + cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes) + return fmt.Errorf("failed to spawn SFU on node %s: %w", node.NodeID, err) + } + + cm.logEvent(ctx, cluster.ID, EventSFUStarted, node.NodeID, + fmt.Sprintf("SFU started on %s:%d", node.InternalIP, sfuBlock.SFUSignalingPort), nil) + } + + // 13. Create TURN DNS records + var turnIPs []string + for _, node := range turnNodes { + turnIPs = append(turnIPs, node.PublicIP) + } + if err := cm.dnsManager.CreateTURNRecords(ctx, namespaceName, turnIPs); err != nil { + cm.logger.Warn("Failed to create TURN DNS records", + zap.String("namespace", namespaceName), + zap.Error(err)) + } + + // 14. 
Update cluster-state.json on all nodes with WebRTC info
+	cm.updateClusterStateWithWebRTC(ctx, cluster, clusterNodes, sfuBlocks, turnBlocks)
+
+	cm.logEvent(ctx, cluster.ID, EventWebRTCEnabled, "",
+		fmt.Sprintf("WebRTC enabled: SFU on %d nodes, TURN on %d nodes", len(clusterNodes), len(turnNodes)), nil)
+
+	cm.logger.Info("WebRTC enabled successfully",
+		zap.String("namespace", namespaceName),
+		zap.String("cluster_id", cluster.ID),
+		zap.Int("sfu_nodes", len(clusterNodes)),
+		zap.Int("turn_nodes", len(turnNodes)),
+	)
+
+	return nil
+}
+
+// DisableWebRTC disables WebRTC for a namespace cluster.
+// Stops SFU/TURN services, deallocates ports, and cleans up DNS/DB.
+func (cm *ClusterManager) DisableWebRTC(ctx context.Context, namespaceName string) error {
+	internalCtx := client.WithInternalAuth(ctx)
+
+	// 1. Verify cluster exists
+	cluster, err := cm.GetClusterByNamespace(ctx, namespaceName)
+	if err != nil {
+		return fmt.Errorf("failed to get cluster: %w", err)
+	}
+	if cluster == nil {
+		return ErrClusterNotFound
+	}
+
+	// 2. Verify WebRTC is enabled. A query failure is reported as such rather
+	// than being conflated with "not enabled".
+	var configs []WebRTCConfig
+	if err := cm.db.Query(internalCtx, &configs,
+		`SELECT * FROM namespace_webrtc_config WHERE namespace_cluster_id = ? AND enabled = 1`, cluster.ID); err != nil {
+		return fmt.Errorf("failed to query WebRTC config: %w", err)
+	}
+	if len(configs) == 0 {
+		return ErrWebRTCNotEnabled
+	}
+
+	cm.logger.Info("Disabling WebRTC for namespace",
+		zap.String("namespace", namespaceName),
+		zap.String("cluster_id", cluster.ID),
+	)
+
+	// 3. Get cluster nodes with IPs
+	clusterNodes, err := cm.getClusterNodesWithIPs(ctx, cluster.ID)
+	if err != nil {
+		return fmt.Errorf("failed to get cluster nodes: %w", err)
+	}
+
+	// 4. Stop SFU on all nodes
+	for _, node := range clusterNodes {
+		cm.stopSFUOnNode(ctx, node.NodeID, node.InternalIP, namespaceName)
+		cm.logEvent(ctx, cluster.ID, EventSFUStopped, node.NodeID, "SFU stopped", nil)
+	}
+
+	// 5. Stop TURN on nodes that have TURN allocations
+	turnBlocks, _ := cm.getWebRTCBlocksByType(ctx, cluster.ID, "turn")
+	for _, block := range turnBlocks {
+		nodeIP := cm.getNodeIP(clusterNodes, block.NodeID)
+		cm.stopTURNOnNode(ctx, block.NodeID, nodeIP, namespaceName)
+		cm.logEvent(ctx, cluster.ID, EventTURNStopped, block.NodeID, "TURN stopped", nil)
+	}
+
+	// 6. Deallocate all WebRTC ports
+	if err := cm.webrtcPortAllocator.DeallocateAll(ctx, cluster.ID); err != nil {
+		cm.logger.Warn("Failed to deallocate WebRTC ports", zap.Error(err))
+	}
+
+	// 7. Delete TURN DNS records
+	if err := cm.dnsManager.DeleteTURNRecords(ctx, namespaceName); err != nil {
+		cm.logger.Warn("Failed to delete TURN DNS records", zap.Error(err))
+	}
+
+	// 8. Clean up DB tables
+	cm.db.Exec(internalCtx, `DELETE FROM webrtc_rooms WHERE namespace_cluster_id = ?`, cluster.ID)
+	cm.db.Exec(internalCtx, `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, cluster.ID)
+
+	// 9. Update cluster-state.json to remove WebRTC info
+	cm.updateClusterStateWithWebRTC(ctx, cluster, clusterNodes, nil, nil)
+
+	cm.logEvent(ctx, cluster.ID, EventWebRTCDisabled, "", "WebRTC disabled", nil)
+
+	cm.logger.Info("WebRTC disabled successfully",
+		zap.String("namespace", namespaceName),
+		zap.String("cluster_id", cluster.ID),
+	)
+
+	return nil
+}
+
+// GetWebRTCConfig returns the WebRTC configuration for a namespace.
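+// A nil config with a nil error means WebRTC is not enabled for the namespace.
+//
+// Illustrative example: minting short-lived client TURN credentials from the
+// returned config. This is a sketch of the common TURN REST-credential scheme
+// that the HMAC-SHA1 comments in this package suggest (username = expiry unix
+// timestamp, password = base64(HMAC-SHA1(secret, username))); the exact format
+// the TURN server expects is defined elsewhere, not confirmed by this change:
+//
+//	cfg, _ := cm.GetWebRTCConfig(ctx, "alpha")
+//	expiry := time.Now().Add(time.Duration(cfg.TURNCredentialTTL) * time.Second).Unix()
+//	username := strconv.FormatInt(expiry, 10)
+//	mac := hmac.New(sha1.New, []byte(cfg.TURNSharedSecret))
+//	mac.Write([]byte(username))
+//	password := base64.StdEncoding.EncodeToString(mac.Sum(nil))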
+func (cm *ClusterManager) GetWebRTCConfig(ctx context.Context, namespaceName string) (*WebRTCConfig, error) { + internalCtx := client.WithInternalAuth(ctx) + + var configs []WebRTCConfig + err := cm.db.Query(internalCtx, &configs, + `SELECT * FROM namespace_webrtc_config WHERE namespace_name = ? AND enabled = 1`, namespaceName) + if err != nil { + return nil, fmt.Errorf("failed to query WebRTC config: %w", err) + } + if len(configs) == 0 { + return nil, nil + } + return &configs[0], nil +} + +// GetWebRTCStatus returns the WebRTC config as an interface{} for the WebRTCManager interface. +func (cm *ClusterManager) GetWebRTCStatus(ctx context.Context, namespaceName string) (interface{}, error) { + cfg, err := cm.GetWebRTCConfig(ctx, namespaceName) + if err != nil { + return nil, err + } + if cfg == nil { + return nil, nil + } + return cfg, nil +} + +// --- Internal helpers --- + +// clusterNodeInfo holds node info needed for WebRTC operations +type clusterNodeInfo struct { + NodeID string + InternalIP string // WireGuard IP + PublicIP string // Public IP for TURN +} + +// getClusterNodesWithIPs returns cluster nodes with both internal and public IPs. +func (cm *ClusterManager) getClusterNodesWithIPs(ctx context.Context, clusterID string) ([]clusterNodeInfo, error) { + internalCtx := client.WithInternalAuth(ctx) + + type nodeRow struct { + NodeID string `db:"node_id"` + InternalIP string `db:"internal_ip"` + PublicIP string `db:"public_ip"` + } + var rows []nodeRow + query := ` + SELECT ncn.node_id, + COALESCE(dn.internal_ip, dn.ip_address) as internal_ip, + dn.ip_address as public_ip + FROM namespace_cluster_nodes ncn + JOIN dns_nodes dn ON ncn.node_id = dn.id + WHERE ncn.namespace_cluster_id = ? + ` + if err := cm.db.Query(internalCtx, &rows, query, clusterID); err != nil { + return nil, err + } + + nodes := make([]clusterNodeInfo, len(rows)) + for i, r := range rows { + nodes[i] = clusterNodeInfo{ + NodeID: r.NodeID, + InternalIP: r.InternalIP, + PublicIP: r.PublicIP, + } + } + return nodes, nil +} + +// selectTURNNodes selects the best N nodes for TURN, preferring nodes without existing TURN allocations. 
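+//
+// Illustrative example (hypothetical node IDs): with nodes [a b c] where only
+// b already hosts TURN for another namespace (NodeHasTURN reports true) and
+// count = 2, the preferred list is [a c] and both are selected, keeping b's
+// standard TURN ports (3478/443) free of a second listener:
+//
+//	selected := cm.selectTURNNodes(ctx, nodes, 2) // nodes = [a b c] -> [a c]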
+func (cm *ClusterManager) selectTURNNodes(ctx context.Context, nodes []clusterNodeInfo, count int) []clusterNodeInfo { + if count >= len(nodes) { + return nodes + } + + // Prefer nodes without existing TURN allocations + var preferred, fallback []clusterNodeInfo + for _, node := range nodes { + hasTURN, err := cm.webrtcPortAllocator.NodeHasTURN(ctx, node.NodeID) + if err != nil || !hasTURN { + preferred = append(preferred, node) + } else { + fallback = append(fallback, node) + } + } + + // Take from preferred first, then fallback + result := make([]clusterNodeInfo, 0, count) + for _, node := range preferred { + if len(result) >= count { + break + } + result = append(result, node) + } + for _, node := range fallback { + if len(result) >= count { + break + } + result = append(result, node) + } + return result +} + +// spawnSFUOnNode spawns SFU on a node (local or remote) +func (cm *ClusterManager) spawnSFUOnNode(ctx context.Context, node clusterNodeInfo, namespace string, cfg SFUInstanceConfig) error { + if node.NodeID == cm.localNodeID { + return cm.systemdSpawner.SpawnSFU(ctx, namespace, node.NodeID, cfg) + } + return cm.spawnSFURemote(ctx, node.InternalIP, cfg) +} + +// spawnTURNOnNode spawns TURN on a node (local or remote) +func (cm *ClusterManager) spawnTURNOnNode(ctx context.Context, node clusterNodeInfo, namespace string, cfg TURNInstanceConfig) error { + if node.NodeID == cm.localNodeID { + return cm.systemdSpawner.SpawnTURN(ctx, namespace, node.NodeID, cfg) + } + return cm.spawnTURNRemote(ctx, node.InternalIP, cfg) +} + +// stopSFUOnNode stops SFU on a node (local or remote) +func (cm *ClusterManager) stopSFUOnNode(ctx context.Context, nodeID, nodeIP, namespace string) { + if nodeID == cm.localNodeID { + cm.systemdSpawner.StopSFU(ctx, namespace, nodeID) + } else { + cm.sendStopRequest(ctx, nodeIP, "stop-sfu", namespace, nodeID) + } +} + +// stopTURNOnNode stops TURN on a node (local or remote) +func (cm *ClusterManager) stopTURNOnNode(ctx context.Context, nodeID, nodeIP, namespace string) { + if nodeID == cm.localNodeID { + cm.systemdSpawner.StopTURN(ctx, namespace, nodeID) + } else { + cm.sendStopRequest(ctx, nodeIP, "stop-turn", namespace, nodeID) + } +} + +// spawnSFURemote sends a spawn-sfu request to a remote node +func (cm *ClusterManager) spawnSFURemote(ctx context.Context, nodeIP string, cfg SFUInstanceConfig) error { + // Serialize TURN servers for transport + turnServers := make([]map[string]interface{}, len(cfg.TURNServers)) + for i, ts := range cfg.TURNServers { + turnServers[i] = map[string]interface{}{ + "host": ts.Host, + "port": ts.Port, + } + } + + _, err := cm.sendSpawnRequest(ctx, nodeIP, map[string]interface{}{ + "action": "spawn-sfu", + "namespace": cfg.Namespace, + "node_id": cfg.NodeID, + "sfu_listen_addr": cfg.ListenAddr, + "sfu_media_start": cfg.MediaPortStart, + "sfu_media_end": cfg.MediaPortEnd, + "turn_servers": turnServers, + "turn_secret": cfg.TURNSecret, + "turn_cred_ttl": cfg.TURNCredTTL, + "rqlite_dsn": cfg.RQLiteDSN, + }) + return err +} + +// spawnTURNRemote sends a spawn-turn request to a remote node +func (cm *ClusterManager) spawnTURNRemote(ctx context.Context, nodeIP string, cfg TURNInstanceConfig) error { + _, err := cm.sendSpawnRequest(ctx, nodeIP, map[string]interface{}{ + "action": "spawn-turn", + "namespace": cfg.Namespace, + "node_id": cfg.NodeID, + "turn_listen_addr": cfg.ListenAddr, + "turn_tls_addr": cfg.TLSListenAddr, + "turn_public_ip": cfg.PublicIP, + "turn_realm": cfg.Realm, + "turn_auth_secret": cfg.AuthSecret, + 
"turn_relay_start": cfg.RelayPortStart, + "turn_relay_end": cfg.RelayPortEnd, + }) + return err +} + +// getWebRTCBlocksByType returns all WebRTC port blocks of a given type for a cluster. +func (cm *ClusterManager) getWebRTCBlocksByType(ctx context.Context, clusterID, serviceType string) ([]WebRTCPortBlock, error) { + allBlocks, err := cm.webrtcPortAllocator.GetAllPorts(ctx, clusterID) + if err != nil { + return nil, err + } + + var filtered []WebRTCPortBlock + for _, b := range allBlocks { + if b.ServiceType == serviceType { + filtered = append(filtered, b) + } + } + return filtered, nil +} + +// getNodeIP looks up the internal IP for a node ID from a list. +func (cm *ClusterManager) getNodeIP(nodes []clusterNodeInfo, nodeID string) string { + for _, n := range nodes { + if n.NodeID == nodeID { + return n.InternalIP + } + } + return "" +} + +// cleanupWebRTCOnError cleans up partial WebRTC allocations when EnableWebRTC fails mid-way. +func (cm *ClusterManager) cleanupWebRTCOnError(ctx context.Context, clusterID, namespaceName string, nodes []clusterNodeInfo) { + cm.logger.Warn("Cleaning up partial WebRTC enablement", + zap.String("namespace", namespaceName), + zap.String("cluster_id", clusterID)) + + internalCtx := client.WithInternalAuth(ctx) + + // Stop any spawned SFU/TURN services + for _, node := range nodes { + cm.stopSFUOnNode(ctx, node.NodeID, node.InternalIP, namespaceName) + cm.stopTURNOnNode(ctx, node.NodeID, node.InternalIP, namespaceName) + } + + // Deallocate ports + cm.webrtcPortAllocator.DeallocateAll(ctx, clusterID) + + // Remove config row + cm.db.Exec(internalCtx, `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, clusterID) +} + +// updateClusterStateWithWebRTC updates the cluster-state.json on all nodes +// to include (or remove) WebRTC port information. +// Pass nil maps to clear WebRTC state (when disabling). 
+func (cm *ClusterManager) updateClusterStateWithWebRTC( + ctx context.Context, + cluster *NamespaceCluster, + nodes []clusterNodeInfo, + sfuBlocks map[string]*WebRTCPortBlock, + turnBlocks map[string]*WebRTCPortBlock, +) { + // Get existing port blocks for base state + portBlocks, err := cm.portAllocator.GetAllPortBlocks(ctx, cluster.ID) + if err != nil { + cm.logger.Warn("Failed to get port blocks for state update", zap.Error(err)) + return + } + + // Build nodeID -> PortBlock map + nodePortMap := make(map[string]*PortBlock) + for i := range portBlocks { + nodePortMap[portBlocks[i].NodeID] = &portBlocks[i] + } + + // Build AllNodes list + var allStateNodes []ClusterLocalStateNode + for _, node := range nodes { + pb := nodePortMap[node.NodeID] + if pb == nil { + continue + } + allStateNodes = append(allStateNodes, ClusterLocalStateNode{ + NodeID: node.NodeID, + InternalIP: node.InternalIP, + RQLiteHTTPPort: pb.RQLiteHTTPPort, + RQLiteRaftPort: pb.RQLiteRaftPort, + OlricHTTPPort: pb.OlricHTTPPort, + OlricMemberlistPort: pb.OlricMemberlistPort, + }) + } + + // Save state on each node + for _, node := range nodes { + pb := nodePortMap[node.NodeID] + if pb == nil { + continue + } + + state := &ClusterLocalState{ + ClusterID: cluster.ID, + NamespaceName: cluster.NamespaceName, + LocalNodeID: node.NodeID, + LocalIP: node.InternalIP, + LocalPorts: ClusterLocalStatePorts{ + RQLiteHTTPPort: pb.RQLiteHTTPPort, + RQLiteRaftPort: pb.RQLiteRaftPort, + OlricHTTPPort: pb.OlricHTTPPort, + OlricMemberlistPort: pb.OlricMemberlistPort, + GatewayHTTPPort: pb.GatewayHTTPPort, + }, + AllNodes: allStateNodes, + HasGateway: true, + BaseDomain: cm.baseDomain, + SavedAt: time.Now(), + } + + // Add WebRTC fields if enabling + if sfuBlocks != nil { + if sfuBlock, ok := sfuBlocks[node.NodeID]; ok { + state.HasSFU = true + state.SFUSignalingPort = sfuBlock.SFUSignalingPort + state.SFUMediaPortStart = sfuBlock.SFUMediaPortStart + state.SFUMediaPortEnd = sfuBlock.SFUMediaPortEnd + } + } + if turnBlocks != nil { + if turnBlock, ok := turnBlocks[node.NodeID]; ok { + state.HasTURN = true + state.TURNListenPort = turnBlock.TURNListenPort + state.TURNTLSPort = turnBlock.TURNTLSPort + state.TURNRelayPortStart = turnBlock.TURNRelayPortStart + state.TURNRelayPortEnd = turnBlock.TURNRelayPortEnd + } + } + + if node.NodeID == cm.localNodeID { + if err := cm.saveLocalState(state); err != nil { + cm.logger.Warn("Failed to save local cluster state", + zap.String("namespace", cluster.NamespaceName), + zap.Error(err)) + } + } else { + cm.saveRemoteState(ctx, node.InternalIP, cluster.NamespaceName, state) + } + } +} + +// saveRemoteState sends cluster state to a remote node for persistence. +func (cm *ClusterManager) saveRemoteState(ctx context.Context, nodeIP, namespace string, state *ClusterLocalState) { + _, err := cm.sendSpawnRequest(ctx, nodeIP, map[string]interface{}{ + "action": "save-cluster-state", + "namespace": namespace, + "cluster_state": state, + }) + if err != nil { + cm.logger.Warn("Failed to save cluster state on remote node", + zap.String("node_ip", nodeIP), + zap.Error(err)) + } +} diff --git a/pkg/namespace/dns_manager.go b/pkg/namespace/dns_manager.go index d14df8b..40e30ca 100644 --- a/pkg/namespace/dns_manager.go +++ b/pkg/namespace/dns_manager.go @@ -300,6 +300,78 @@ func (drm *DNSRecordManager) DisableNamespaceRecord(ctx context.Context, namespa return nil } +// CreateTURNRecords creates DNS A records for TURN servers. 
+// TURN records follow the pattern: turn.ns-{namespace}.{baseDomain} -> TURN node IPs +func (drm *DNSRecordManager) CreateTURNRecords(ctx context.Context, namespaceName string, turnIPs []string) error { + internalCtx := client.WithInternalAuth(ctx) + + if len(turnIPs) == 0 { + return &ClusterError{Message: "no TURN IPs provided for DNS records"} + } + + fqdn := fmt.Sprintf("turn.ns-%s.%s.", namespaceName, drm.baseDomain) + + drm.logger.Info("Creating TURN DNS records", + zap.String("namespace", namespaceName), + zap.String("fqdn", fqdn), + zap.Strings("turn_ips", turnIPs), + ) + + // Delete existing TURN records for this namespace + deleteQuery := `DELETE FROM dns_records WHERE fqdn = ? AND namespace = ?` + _, _ = drm.db.Exec(internalCtx, deleteQuery, fqdn, "namespace-turn:"+namespaceName) + + // Create A records for each TURN node IP + now := time.Now() + for _, ip := range turnIPs { + recordID := uuid.New().String() + insertQuery := ` + INSERT INTO dns_records ( + id, fqdn, record_type, value, ttl, namespace, created_by, is_active, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + _, err := drm.db.Exec(internalCtx, insertQuery, + recordID, fqdn, "A", ip, 60, + "namespace-turn:"+namespaceName, + "cluster-manager", + true, now, now, + ) + if err != nil { + return &ClusterError{ + Message: fmt.Sprintf("failed to create TURN DNS record %s -> %s", fqdn, ip), + Cause: err, + } + } + } + + drm.logger.Info("TURN DNS records created", + zap.String("namespace", namespaceName), + zap.Int("record_count", len(turnIPs)), + ) + + return nil +} + +// DeleteTURNRecords deletes all TURN DNS records for a namespace. +func (drm *DNSRecordManager) DeleteTURNRecords(ctx context.Context, namespaceName string) error { + internalCtx := client.WithInternalAuth(ctx) + + drm.logger.Info("Deleting TURN DNS records", + zap.String("namespace", namespaceName), + ) + + deleteQuery := `DELETE FROM dns_records WHERE namespace = ?` + _, err := drm.db.Exec(internalCtx, deleteQuery, "namespace-turn:"+namespaceName) + if err != nil { + return &ClusterError{ + Message: "failed to delete TURN DNS records", + Cause: err, + } + } + + return nil +} + // EnableNamespaceRecord marks a specific IP's record as active (for recovery) func (drm *DNSRecordManager) EnableNamespaceRecord(ctx context.Context, namespaceName, ip string) error { internalCtx := client.WithInternalAuth(ctx) diff --git a/pkg/namespace/systemd_spawner.go b/pkg/namespace/systemd_spawner.go index a3db02b..a8dbbfc 100644 --- a/pkg/namespace/systemd_spawner.go +++ b/pkg/namespace/systemd_spawner.go @@ -10,7 +10,9 @@ import ( "github.com/DeBrosOfficial/network/pkg/gateway" "github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/rqlite" + "github.com/DeBrosOfficial/network/pkg/sfu" "github.com/DeBrosOfficial/network/pkg/systemd" + "github.com/DeBrosOfficial/network/pkg/turn" "go.uber.org/zap" "gopkg.in/yaml.v3" ) @@ -289,6 +291,185 @@ func (s *SystemdSpawner) StopGateway(ctx context.Context, namespace, nodeID stri return s.systemdMgr.StopService(namespace, systemd.ServiceTypeGateway) } +// SFUInstanceConfig holds configuration for spawning an SFU instance +type SFUInstanceConfig struct { + Namespace string + NodeID string + ListenAddr string // WireGuard IP:port (e.g., "10.0.0.1:30000") + MediaPortStart int // Start of RTP media port range + MediaPortEnd int // End of RTP media port range + TURNServers []sfu.TURNServerConfig // TURN servers to advertise to peers + TURNSecret string // HMAC-SHA1 shared secret + 
TURNCredTTL int // Credential TTL in seconds
+	RQLiteDSN   string // Namespace-local RQLite DSN
+}
+
+// SpawnSFU starts an SFU instance using systemd
+func (s *SystemdSpawner) SpawnSFU(ctx context.Context, namespace, nodeID string, cfg SFUInstanceConfig) error {
+	s.logger.Info("Spawning SFU via systemd",
+		zap.String("namespace", namespace),
+		zap.String("node_id", nodeID),
+		zap.String("listen_addr", cfg.ListenAddr))
+
+	// Create config directory
+	configDir := filepath.Join(s.namespaceBase, namespace, "configs")
+	if err := os.MkdirAll(configDir, 0755); err != nil {
+		return fmt.Errorf("failed to create config directory: %w", err)
+	}
+
+	configPath := filepath.Join(configDir, fmt.Sprintf("sfu-%s.yaml", nodeID))
+
+	// Build SFU YAML config
+	sfuConfig := sfu.Config{
+		ListenAddr:        cfg.ListenAddr,
+		Namespace:         cfg.Namespace,
+		MediaPortStart:    cfg.MediaPortStart,
+		MediaPortEnd:      cfg.MediaPortEnd,
+		TURNServers:       cfg.TURNServers,
+		TURNSecret:        cfg.TURNSecret,
+		TURNCredentialTTL: cfg.TURNCredTTL,
+		RQLiteDSN:         cfg.RQLiteDSN,
+	}
+
+	configBytes, err := yaml.Marshal(sfuConfig)
+	if err != nil {
+		return fmt.Errorf("failed to marshal SFU config: %w", err)
+	}
+
+	// 0600: the config embeds the TURN shared secret, so keep it owner-only
+	if err := os.WriteFile(configPath, configBytes, 0600); err != nil {
+		return fmt.Errorf("failed to write SFU config: %w", err)
+	}
+
+	s.logger.Info("Created SFU config file",
+		zap.String("path", configPath),
+		zap.String("namespace", namespace),
+		zap.String("node_id", nodeID))
+
+	// Generate environment file pointing to config
+	envVars := map[string]string{
+		"SFU_CONFIG": configPath,
+	}
+
+	if err := s.systemdMgr.GenerateEnvFile(namespace, nodeID, systemd.ServiceTypeSFU, envVars); err != nil {
+		return fmt.Errorf("failed to generate SFU env file: %w", err)
+	}
+
+	// Start the systemd service
+	if err := s.systemdMgr.StartService(namespace, systemd.ServiceTypeSFU); err != nil {
+		return fmt.Errorf("failed to start SFU service: %w", err)
+	}
+
+	// Wait for service to be active
+	if err := s.waitForService(namespace, systemd.ServiceTypeSFU, 30*time.Second); err != nil {
+		return fmt.Errorf("SFU service did not become active: %w", err)
+	}
+
+	s.logger.Info("SFU spawned successfully via systemd",
+		zap.String("namespace", namespace),
+		zap.String("node_id", nodeID))
+
+	return nil
+}
+
+// StopSFU stops an SFU instance
+func (s *SystemdSpawner) StopSFU(ctx context.Context, namespace, nodeID string) error {
+	s.logger.Info("Stopping SFU via systemd",
+		zap.String("namespace", namespace),
+		zap.String("node_id", nodeID))
+
+	return s.systemdMgr.StopService(namespace, systemd.ServiceTypeSFU)
+}
+
+// TURNInstanceConfig holds configuration for spawning a TURN instance
+type TURNInstanceConfig struct {
+	Namespace      string
+	NodeID         string
+	ListenAddr     string // e.g., "0.0.0.0:3478"
+	TLSListenAddr  string // e.g., "0.0.0.0:443" (UDP, no conflict with Caddy TCP)
+	PublicIP       string // Public IP for TURN relay allocations
+	Realm          string // TURN realm (typically base domain)
+	AuthSecret     string // HMAC-SHA1 shared secret
+	RelayPortStart int    // Start of relay port range
+	RelayPortEnd   int    // End of relay port range
+}
+
+// SpawnTURN starts a TURN instance using systemd
+func (s *SystemdSpawner) SpawnTURN(ctx context.Context, namespace, nodeID string, cfg TURNInstanceConfig) error {
+	s.logger.Info("Spawning TURN via systemd",
+		zap.String("namespace", namespace),
+		zap.String("node_id", nodeID),
+		zap.String("listen_addr", cfg.ListenAddr),
+		zap.String("public_ip", cfg.PublicIP))
+
+	// Create config directory
+	configDir := filepath.Join(s.namespaceBase, namespace, "configs")
+	if err := os.MkdirAll(configDir, 0755); err != nil {
+		return fmt.Errorf("failed to create config directory: %w", err)
+	}
+
+	configPath := filepath.Join(configDir, fmt.Sprintf("turn-%s.yaml", nodeID))
+
+	// Build TURN YAML config
+	turnConfig := turn.Config{
+		ListenAddr:     cfg.ListenAddr,
+		TLSListenAddr:  cfg.TLSListenAddr,
+		PublicIP:       cfg.PublicIP,
+		Realm:          cfg.Realm,
+		AuthSecret:     cfg.AuthSecret,
+		RelayPortStart: cfg.RelayPortStart,
+		RelayPortEnd:   cfg.RelayPortEnd,
+		Namespace:      cfg.Namespace,
+	}
+
+	configBytes, err := yaml.Marshal(turnConfig)
+	if err != nil {
+		return fmt.Errorf("failed to marshal TURN config: %w", err)
+	}
+
+	// 0600: the config embeds the TURN auth secret, so keep it owner-only
+	if err := os.WriteFile(configPath, configBytes, 0600); err != nil {
+		return fmt.Errorf("failed to write TURN config: %w", err)
+	}
+
+	s.logger.Info("Created TURN config file",
+		zap.String("path", configPath),
+		zap.String("namespace", namespace),
+		zap.String("node_id", nodeID))
+
+	// Generate environment file pointing to config
+	envVars := map[string]string{
+		"TURN_CONFIG": configPath,
+	}
+
+	if err := s.systemdMgr.GenerateEnvFile(namespace, nodeID, systemd.ServiceTypeTURN, envVars); err != nil {
+		return fmt.Errorf("failed to generate TURN env file: %w", err)
+	}
+
+	// Start the systemd service
+	if err := s.systemdMgr.StartService(namespace, systemd.ServiceTypeTURN); err != nil {
+		return fmt.Errorf("failed to start TURN service: %w", err)
+	}
+
+	// Wait for service to be active
+	if err := s.waitForService(namespace, systemd.ServiceTypeTURN, 30*time.Second); err != nil {
+		return fmt.Errorf("TURN service did not become active: %w", err)
+	}
+
+	s.logger.Info("TURN spawned successfully via systemd",
+		zap.String("namespace", namespace),
+		zap.String("node_id", nodeID))
+
+	return nil
+}
+
+// StopTURN stops a TURN instance
+func (s *SystemdSpawner) StopTURN(ctx context.Context, namespace, nodeID string) error {
+	s.logger.Info("Stopping TURN via systemd",
+		zap.String("namespace", namespace),
+		zap.String("node_id", nodeID))
+
+	return s.systemdMgr.StopService(namespace, systemd.ServiceTypeTURN)
+}
+
 // SaveClusterState writes cluster state JSON to the namespace data directory.
 // Used by the spawn handler to persist state received from the coordinator node.
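+//
+// Illustrative handler-side usage (a sketch; request decoding is assumed and
+// not part of this change), where raw is the JSON-encoded ClusterLocalState
+// sent by the coordinator:
+//
+//	if err := spawner.SaveClusterState(namespace, raw); err != nil {
+//		return fmt.Errorf("persist cluster state: %w", err)
+//	}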
func (s *SystemdSpawner) SaveClusterState(namespace string, data []byte) error { diff --git a/pkg/namespace/types.go b/pkg/namespace/types.go index b3146f2..9f77de7 100644 --- a/pkg/namespace/types.go +++ b/pkg/namespace/types.go @@ -24,6 +24,8 @@ const ( NodeRoleRQLiteFollower NodeRole = "rqlite_follower" NodeRoleOlric NodeRole = "olric" NodeRoleGateway NodeRole = "gateway" + NodeRoleSFU NodeRole = "sfu" + NodeRoleTURN NodeRole = "turn" ) // NodeStatus represents the status of a service on a node @@ -62,6 +64,12 @@ const ( EventNodeReplaced EventType = "node_replaced" EventRecoveryComplete EventType = "recovery_complete" EventRecoveryFailed EventType = "recovery_failed" + EventWebRTCEnabled EventType = "webrtc_enabled" + EventWebRTCDisabled EventType = "webrtc_disabled" + EventSFUStarted EventType = "sfu_started" + EventSFUStopped EventType = "sfu_stopped" + EventTURNStarted EventType = "turn_started" + EventTURNStopped EventType = "turn_stopped" ) // Port allocation constants @@ -80,6 +88,39 @@ const ( MaxNamespacesPerNode = (NamespacePortRangeEnd - NamespacePortRangeStart + 1) / PortsPerNamespace // 20 ) +// WebRTC port allocation constants +// These are separate from the core namespace port range (10000-10099) +// to avoid breaking existing port blocks. +const ( + // SFU media port range: 20000-29999 + // Each namespace gets a 500-port sub-range for RTP media + SFUMediaPortRangeStart = 20000 + SFUMediaPortRangeEnd = 29999 + SFUMediaPortsPerNamespace = 500 + + // SFU signaling ports: 30000-30099 + // Each namespace gets 1 signaling port per node + SFUSignalingPortRangeStart = 30000 + SFUSignalingPortRangeEnd = 30099 + + // TURN relay port range: 49152-65535 + // Each namespace gets an 800-port sub-range for TURN relay + TURNRelayPortRangeStart = 49152 + TURNRelayPortRangeEnd = 65535 + TURNRelayPortsPerNamespace = 800 + + // TURN listen ports (standard) + TURNDefaultPort = 3478 + TURNTLSPort = 443 + + // Default TURN credential TTL in seconds (10 minutes) + DefaultTURNCredentialTTL = 600 + + // Default service counts per namespace + DefaultSFUNodeCount = 3 // SFU on all 3 nodes + DefaultTURNNodeCount = 2 // TURN on 2 of 3 nodes for HA +) + // Default cluster sizes const ( DefaultRQLiteNodeCount = 3 @@ -206,4 +247,58 @@ var ( ErrNamespaceNotFound = &ClusterError{Message: "namespace not found"} ErrInvalidClusterStatus = &ClusterError{Message: "invalid cluster status for operation"} ErrRecoveryInProgress = &ClusterError{Message: "recovery already in progress for this cluster"} + ErrWebRTCAlreadyEnabled = &ClusterError{Message: "WebRTC is already enabled for this namespace"} + ErrWebRTCNotEnabled = &ClusterError{Message: "WebRTC is not enabled for this namespace"} + ErrNoWebRTCPortsAvailable = &ClusterError{Message: "no WebRTC ports available on node"} ) + +// WebRTCConfig represents the per-namespace WebRTC configuration stored in the database +type WebRTCConfig struct { + ID string `json:"id" db:"id"` + NamespaceClusterID string `json:"namespace_cluster_id" db:"namespace_cluster_id"` + NamespaceName string `json:"namespace_name" db:"namespace_name"` + Enabled bool `json:"enabled" db:"enabled"` + TURNSharedSecret string `json:"-" db:"turn_shared_secret"` // Never serialize secret to JSON + TURNCredentialTTL int `json:"turn_credential_ttl" db:"turn_credential_ttl"` + SFUNodeCount int `json:"sfu_node_count" db:"sfu_node_count"` + TURNNodeCount int `json:"turn_node_count" db:"turn_node_count"` + EnabledBy string `json:"enabled_by" db:"enabled_by"` + EnabledAt time.Time `json:"enabled_at" 
db:"enabled_at"` + DisabledAt *time.Time `json:"disabled_at,omitempty" db:"disabled_at"` +} + +// WebRTCRoom represents an active WebRTC room tracked in the database +type WebRTCRoom struct { + ID string `json:"id" db:"id"` + NamespaceClusterID string `json:"namespace_cluster_id" db:"namespace_cluster_id"` + NamespaceName string `json:"namespace_name" db:"namespace_name"` + RoomID string `json:"room_id" db:"room_id"` + SFUNodeID string `json:"sfu_node_id" db:"sfu_node_id"` + SFUInternalIP string `json:"sfu_internal_ip" db:"sfu_internal_ip"` + SFUSignalingPort int `json:"sfu_signaling_port" db:"sfu_signaling_port"` + ParticipantCount int `json:"participant_count" db:"participant_count"` + MaxParticipants int `json:"max_participants" db:"max_participants"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + LastActivity time.Time `json:"last_activity" db:"last_activity"` +} + +// WebRTCPortBlock represents allocated WebRTC ports for a namespace on a node +type WebRTCPortBlock struct { + ID string `json:"id" db:"id"` + NodeID string `json:"node_id" db:"node_id"` + NamespaceClusterID string `json:"namespace_cluster_id" db:"namespace_cluster_id"` + ServiceType string `json:"service_type" db:"service_type"` // "sfu" or "turn" + + // SFU ports + SFUSignalingPort int `json:"sfu_signaling_port,omitempty" db:"sfu_signaling_port"` + SFUMediaPortStart int `json:"sfu_media_port_start,omitempty" db:"sfu_media_port_start"` + SFUMediaPortEnd int `json:"sfu_media_port_end,omitempty" db:"sfu_media_port_end"` + + // TURN ports + TURNListenPort int `json:"turn_listen_port,omitempty" db:"turn_listen_port"` + TURNTLSPort int `json:"turn_tls_port,omitempty" db:"turn_tls_port"` + TURNRelayPortStart int `json:"turn_relay_port_start,omitempty" db:"turn_relay_port_start"` + TURNRelayPortEnd int `json:"turn_relay_port_end,omitempty" db:"turn_relay_port_end"` + + AllocatedAt time.Time `json:"allocated_at" db:"allocated_at"` +} diff --git a/pkg/namespace/webrtc_port_allocator.go b/pkg/namespace/webrtc_port_allocator.go new file mode 100644 index 0000000..39d7a6e --- /dev/null +++ b/pkg/namespace/webrtc_port_allocator.go @@ -0,0 +1,519 @@ +package namespace + +import ( + "context" + "fmt" + "time" + + "github.com/DeBrosOfficial/network/pkg/client" + "github.com/DeBrosOfficial/network/pkg/rqlite" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// WebRTCPortAllocator manages port allocation for SFU and TURN services. +// Uses the webrtc_port_allocations table, separate from namespace_port_allocations, +// to avoid breaking existing port blocks. +type WebRTCPortAllocator struct { + db rqlite.Client + logger *zap.Logger +} + +// NewWebRTCPortAllocator creates a new WebRTC port allocator +func NewWebRTCPortAllocator(db rqlite.Client, logger *zap.Logger) *WebRTCPortAllocator { + return &WebRTCPortAllocator{ + db: db, + logger: logger.With(zap.String("component", "webrtc-port-allocator")), + } +} + +// AllocateSFUPorts allocates SFU ports for a namespace on a node. +// Each namespace gets: 1 signaling port (30000-30099) + 500 media ports (20000-29999). +// Returns the existing allocation if one already exists (idempotent). 
+func (wpa *WebRTCPortAllocator) AllocateSFUPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) { + internalCtx := client.WithInternalAuth(ctx) + + // Check for existing allocation (idempotent) + existing, err := wpa.GetSFUPorts(ctx, namespaceClusterID, nodeID) + if err == nil && existing != nil { + wpa.logger.Debug("SFU ports already allocated", + zap.String("node_id", nodeID), + zap.String("namespace_cluster_id", namespaceClusterID), + zap.Int("signaling_port", existing.SFUSignalingPort), + ) + return existing, nil + } + + // Retry logic for concurrent allocation conflicts + maxRetries := 10 + retryDelay := 100 * time.Millisecond + + for attempt := 0; attempt < maxRetries; attempt++ { + block, err := wpa.tryAllocateSFUPorts(internalCtx, nodeID, namespaceClusterID) + if err == nil { + wpa.logger.Info("SFU ports allocated", + zap.String("node_id", nodeID), + zap.String("namespace_cluster_id", namespaceClusterID), + zap.Int("signaling_port", block.SFUSignalingPort), + zap.Int("media_start", block.SFUMediaPortStart), + zap.Int("media_end", block.SFUMediaPortEnd), + zap.Int("attempt", attempt+1), + ) + return block, nil + } + + if isConflictError(err) { + wpa.logger.Debug("SFU port allocation conflict, retrying", + zap.String("node_id", nodeID), + zap.Int("attempt", attempt+1), + zap.Error(err), + ) + time.Sleep(retryDelay) + retryDelay *= 2 + continue + } + + return nil, err + } + + return nil, &ClusterError{ + Message: fmt.Sprintf("failed to allocate SFU ports after %d retries", maxRetries), + } +} + +// tryAllocateSFUPorts performs a single attempt to allocate SFU ports. +func (wpa *WebRTCPortAllocator) tryAllocateSFUPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) { + // Get node IPs sharing the same physical address (dev environment handling) + nodeIDs, err := wpa.getColocatedNodeIDs(ctx, nodeID) + if err != nil { + nodeIDs = []string{nodeID} + } + + // Find next available SFU signaling port (30000-30099) + signalingPort, err := wpa.findAvailablePort(ctx, nodeIDs, "sfu", "sfu_signaling_port", + SFUSignalingPortRangeStart, SFUSignalingPortRangeEnd, 1) + if err != nil { + return nil, &ClusterError{ + Message: "no SFU signaling port available on node", + Cause: err, + } + } + + // Find next available SFU media port block (20000-29999, 500 per namespace) + mediaStart, err := wpa.findAvailablePortBlock(ctx, nodeIDs, "sfu", "sfu_media_port_start", + SFUMediaPortRangeStart, SFUMediaPortRangeEnd, SFUMediaPortsPerNamespace) + if err != nil { + return nil, &ClusterError{ + Message: "no SFU media port range available on node", + Cause: err, + } + } + + block := &WebRTCPortBlock{ + ID: uuid.New().String(), + NodeID: nodeID, + NamespaceClusterID: namespaceClusterID, + ServiceType: "sfu", + SFUSignalingPort: signalingPort, + SFUMediaPortStart: mediaStart, + SFUMediaPortEnd: mediaStart + SFUMediaPortsPerNamespace - 1, + AllocatedAt: time.Now(), + } + + if err := wpa.insertPortBlock(ctx, block); err != nil { + return nil, err + } + + return block, nil +} + +// AllocateTURNPorts allocates TURN ports for a namespace on a node. +// Each namespace gets: standard listen ports (3478/443) + 800 relay ports (49152-65535). +// Returns the existing allocation if one already exists (idempotent). 
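+//
+// Conflicting concurrent allocations are retried with exponential backoff
+// starting at 100ms (100, 200, 400, ... ms over up to 10 attempts, i.e. at
+// most 100ms * (2^10 - 1) ≈ 102s of cumulative sleep). The same pattern in
+// isolation, with try standing in for tryAllocateTURNPorts:
+//
+//	delay := 100 * time.Millisecond
+//	for attempt := 0; attempt < maxRetries; attempt++ {
+//		block, err := try()
+//		if err == nil {
+//			return block, nil
+//		}
+//		if !isConflictError(err) {
+//			return nil, err
+//		}
+//		time.Sleep(delay)
+//		delay *= 2
+//	}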
+func (wpa *WebRTCPortAllocator) AllocateTURNPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) { + internalCtx := client.WithInternalAuth(ctx) + + // Check for existing allocation (idempotent) + existing, err := wpa.GetTURNPorts(ctx, namespaceClusterID, nodeID) + if err == nil && existing != nil { + wpa.logger.Debug("TURN ports already allocated", + zap.String("node_id", nodeID), + zap.String("namespace_cluster_id", namespaceClusterID), + ) + return existing, nil + } + + // Retry logic for concurrent allocation conflicts + maxRetries := 10 + retryDelay := 100 * time.Millisecond + + for attempt := 0; attempt < maxRetries; attempt++ { + block, err := wpa.tryAllocateTURNPorts(internalCtx, nodeID, namespaceClusterID) + if err == nil { + wpa.logger.Info("TURN ports allocated", + zap.String("node_id", nodeID), + zap.String("namespace_cluster_id", namespaceClusterID), + zap.Int("relay_start", block.TURNRelayPortStart), + zap.Int("relay_end", block.TURNRelayPortEnd), + zap.Int("attempt", attempt+1), + ) + return block, nil + } + + if isConflictError(err) { + wpa.logger.Debug("TURN port allocation conflict, retrying", + zap.String("node_id", nodeID), + zap.Int("attempt", attempt+1), + zap.Error(err), + ) + time.Sleep(retryDelay) + retryDelay *= 2 + continue + } + + return nil, err + } + + return nil, &ClusterError{ + Message: fmt.Sprintf("failed to allocate TURN ports after %d retries", maxRetries), + } +} + +// tryAllocateTURNPorts performs a single attempt to allocate TURN ports. +func (wpa *WebRTCPortAllocator) tryAllocateTURNPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) { + // Get colocated node IDs (dev environment handling) + nodeIDs, err := wpa.getColocatedNodeIDs(ctx, nodeID) + if err != nil { + nodeIDs = []string{nodeID} + } + + // Find next available TURN relay port block (49152-65535, 800 per namespace) + relayStart, err := wpa.findAvailablePortBlock(ctx, nodeIDs, "turn", "turn_relay_port_start", + TURNRelayPortRangeStart, TURNRelayPortRangeEnd, TURNRelayPortsPerNamespace) + if err != nil { + return nil, &ClusterError{ + Message: "no TURN relay port range available on node", + Cause: err, + } + } + + block := &WebRTCPortBlock{ + ID: uuid.New().String(), + NodeID: nodeID, + NamespaceClusterID: namespaceClusterID, + ServiceType: "turn", + TURNListenPort: TURNDefaultPort, + TURNTLSPort: TURNTLSPort, + TURNRelayPortStart: relayStart, + TURNRelayPortEnd: relayStart + TURNRelayPortsPerNamespace - 1, + AllocatedAt: time.Now(), + } + + if err := wpa.insertPortBlock(ctx, block); err != nil { + return nil, err + } + + return block, nil +} + +// DeallocateAll releases all WebRTC port blocks for a namespace cluster. +func (wpa *WebRTCPortAllocator) DeallocateAll(ctx context.Context, namespaceClusterID string) error { + internalCtx := client.WithInternalAuth(ctx) + + query := `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ?` + _, err := wpa.db.Exec(internalCtx, query, namespaceClusterID) + if err != nil { + return &ClusterError{ + Message: "failed to deallocate WebRTC port blocks", + Cause: err, + } + } + + wpa.logger.Info("All WebRTC port blocks deallocated", + zap.String("namespace_cluster_id", namespaceClusterID), + ) + + return nil +} + +// DeallocateByNode releases WebRTC port blocks for a specific node and service type. 
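+//
+// Illustrative use during node replacement (a sketch; the replacement flow is
+// not part of this change): free only the departing node's TURN block so a
+// fresh allocation can land on its successor:
+//
+//	if err := wpa.DeallocateByNode(ctx, clusterID, oldNodeID, "turn"); err != nil {
+//		return err
+//	}
+//	block, err := wpa.AllocateTURNPorts(ctx, newNodeID, clusterID)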
+func (wpa *WebRTCPortAllocator) DeallocateByNode(ctx context.Context, namespaceClusterID, nodeID, serviceType string) error { + internalCtx := client.WithInternalAuth(ctx) + + query := `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ? AND node_id = ? AND service_type = ?` + _, err := wpa.db.Exec(internalCtx, query, namespaceClusterID, nodeID, serviceType) + if err != nil { + return &ClusterError{ + Message: fmt.Sprintf("failed to deallocate %s port block on node %s", serviceType, nodeID), + Cause: err, + } + } + + wpa.logger.Info("WebRTC port block deallocated", + zap.String("namespace_cluster_id", namespaceClusterID), + zap.String("node_id", nodeID), + zap.String("service_type", serviceType), + ) + + return nil +} + +// GetSFUPorts retrieves the SFU port allocation for a namespace on a node. +func (wpa *WebRTCPortAllocator) GetSFUPorts(ctx context.Context, namespaceClusterID, nodeID string) (*WebRTCPortBlock, error) { + return wpa.getPortBlock(ctx, namespaceClusterID, nodeID, "sfu") +} + +// GetTURNPorts retrieves the TURN port allocation for a namespace on a node. +func (wpa *WebRTCPortAllocator) GetTURNPorts(ctx context.Context, namespaceClusterID, nodeID string) (*WebRTCPortBlock, error) { + return wpa.getPortBlock(ctx, namespaceClusterID, nodeID, "turn") +} + +// GetAllPorts retrieves all WebRTC port blocks for a namespace cluster. +func (wpa *WebRTCPortAllocator) GetAllPorts(ctx context.Context, namespaceClusterID string) ([]WebRTCPortBlock, error) { + internalCtx := client.WithInternalAuth(ctx) + + var blocks []WebRTCPortBlock + query := ` + SELECT id, node_id, namespace_cluster_id, service_type, + sfu_signaling_port, sfu_media_port_start, sfu_media_port_end, + turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end, + allocated_at + FROM webrtc_port_allocations + WHERE namespace_cluster_id = ? + ORDER BY service_type, node_id + ` + err := wpa.db.Query(internalCtx, &blocks, query, namespaceClusterID) + if err != nil { + return nil, &ClusterError{ + Message: "failed to query WebRTC port blocks", + Cause: err, + } + } + + return blocks, nil +} + +// NodeHasTURN checks if a node already has a TURN allocation from any namespace. +// Used during node selection to avoid port conflicts on standard TURN ports (3478/443). +func (wpa *WebRTCPortAllocator) NodeHasTURN(ctx context.Context, nodeID string) (bool, error) { + internalCtx := client.WithInternalAuth(ctx) + + type countResult struct { + Count int `db:"count"` + } + + var results []countResult + query := `SELECT COUNT(*) as count FROM webrtc_port_allocations WHERE node_id = ? AND service_type = 'turn'` + err := wpa.db.Query(internalCtx, &results, query, nodeID) + if err != nil { + return false, &ClusterError{ + Message: "failed to check TURN allocation on node", + Cause: err, + } + } + + if len(results) == 0 { + return false, nil + } + + return results[0].Count > 0, nil +} + +// --- internal helpers --- + +// getPortBlock retrieves a specific port block by cluster, node, and service type. +func (wpa *WebRTCPortAllocator) getPortBlock(ctx context.Context, namespaceClusterID, nodeID, serviceType string) (*WebRTCPortBlock, error) { + internalCtx := client.WithInternalAuth(ctx) + + var blocks []WebRTCPortBlock + query := ` + SELECT id, node_id, namespace_cluster_id, service_type, + sfu_signaling_port, sfu_media_port_start, sfu_media_port_end, + turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end, + allocated_at + FROM webrtc_port_allocations + WHERE namespace_cluster_id = ? 
AND node_id = ? AND service_type = ? + LIMIT 1 + ` + err := wpa.db.Query(internalCtx, &blocks, query, namespaceClusterID, nodeID, serviceType) + if err != nil { + return nil, &ClusterError{ + Message: fmt.Sprintf("failed to query %s port block", serviceType), + Cause: err, + } + } + + if len(blocks) == 0 { + return nil, nil + } + + return &blocks[0], nil +} + +// insertPortBlock inserts a WebRTC port allocation record. +func (wpa *WebRTCPortAllocator) insertPortBlock(ctx context.Context, block *WebRTCPortBlock) error { + query := ` + INSERT INTO webrtc_port_allocations ( + id, node_id, namespace_cluster_id, service_type, + sfu_signaling_port, sfu_media_port_start, sfu_media_port_end, + turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end, + allocated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ` + _, err := wpa.db.Exec(ctx, query, + block.ID, + block.NodeID, + block.NamespaceClusterID, + block.ServiceType, + block.SFUSignalingPort, + block.SFUMediaPortStart, + block.SFUMediaPortEnd, + block.TURNListenPort, + block.TURNTLSPort, + block.TURNRelayPortStart, + block.TURNRelayPortEnd, + block.AllocatedAt, + ) + if err != nil { + return &ClusterError{ + Message: fmt.Sprintf("failed to insert %s port allocation", block.ServiceType), + Cause: err, + } + } + + return nil +} + +// getColocatedNodeIDs returns all node IDs that share the same IP address as the given node. +// In dev environments, multiple logical nodes share one physical IP — port ranges must not overlap. +// In production (one node per IP), returns only the given nodeID. +func (wpa *WebRTCPortAllocator) getColocatedNodeIDs(ctx context.Context, nodeID string) ([]string, error) { + // Get this node's IP + type nodeInfo struct { + IPAddress string `db:"ip_address"` + } + var infos []nodeInfo + if err := wpa.db.Query(ctx, &infos, `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1`, nodeID); err != nil || len(infos) == 0 { + return []string{nodeID}, nil + } + + ip := infos[0].IPAddress + if ip == "" { + return []string{nodeID}, nil + } + + // Check if multiple nodes share this IP + type nodeIDRow struct { + ID string `db:"id"` + } + var colocated []nodeIDRow + if err := wpa.db.Query(ctx, &colocated, `SELECT id FROM dns_nodes WHERE ip_address = ?`, ip); err != nil || len(colocated) <= 1 { + return []string{nodeID}, nil + } + + ids := make([]string, len(colocated)) + for i, n := range colocated { + ids[i] = n.ID + } + + wpa.logger.Debug("Multiple nodes share IP, allocating globally", + zap.String("ip_address", ip), + zap.Int("node_count", len(ids)), + ) + + return ids, nil +} + +// findAvailablePort finds the next available single port in a range on the given nodes. +func (wpa *WebRTCPortAllocator) findAvailablePort(ctx context.Context, nodeIDs []string, serviceType, portColumn string, rangeStart, rangeEnd, step int) (int, error) { + allocated, err := wpa.getAllocatedValues(ctx, nodeIDs, serviceType, portColumn) + if err != nil { + return 0, err + } + + allocatedSet := make(map[int]bool, len(allocated)) + for _, v := range allocated { + allocatedSet[v] = true + } + + for port := rangeStart; port <= rangeEnd; port += step { + if !allocatedSet[port] { + return port, nil + } + } + + return 0, ErrNoWebRTCPortsAvailable +} + +// findAvailablePortBlock finds the next available contiguous port block in a range. 
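+//
+// Allocations always start on blockSize-aligned offsets, so only aligned
+// candidates are probed. Worked example with the SFU media constants:
+// candidate starts are 20000, 20500, 21000, ...; if the first two are already
+// allocated, the call returns 21000:
+//
+//	start, err := wpa.findAvailablePortBlock(ctx, nodeIDs, "sfu",
+//		"sfu_media_port_start", 20000, 29999, 500) // -> 21000 in that scenario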
+func (wpa *WebRTCPortAllocator) findAvailablePortBlock(ctx context.Context, nodeIDs []string, serviceType, portColumn string, rangeStart, rangeEnd, blockSize int) (int, error) { + allocated, err := wpa.getAllocatedValues(ctx, nodeIDs, serviceType, portColumn) + if err != nil { + return 0, err + } + + allocatedSet := make(map[int]bool, len(allocated)) + for _, v := range allocated { + allocatedSet[v] = true + } + + for start := rangeStart; start+blockSize-1 <= rangeEnd; start += blockSize { + if !allocatedSet[start] { + return start, nil + } + } + + return 0, ErrNoWebRTCPortsAvailable +} + +// getAllocatedValues queries the allocated port values for a given column across colocated nodes. +func (wpa *WebRTCPortAllocator) getAllocatedValues(ctx context.Context, nodeIDs []string, serviceType, portColumn string) ([]int, error) { + type portRow struct { + Port int `db:"port_val"` + } + + var rows []portRow + + if len(nodeIDs) == 1 { + query := fmt.Sprintf( + `SELECT %s as port_val FROM webrtc_port_allocations WHERE node_id = ? AND service_type = ? AND %s > 0 ORDER BY %s ASC`, + portColumn, portColumn, portColumn, + ) + if err := wpa.db.Query(ctx, &rows, query, nodeIDs[0], serviceType); err != nil { + return nil, &ClusterError{ + Message: "failed to query allocated WebRTC ports", + Cause: err, + } + } + } else { + // Multiple colocated nodes — query by joining with dns_nodes on IP + // Get the IP of the first node (they all share the same IP) + type nodeInfo struct { + IPAddress string `db:"ip_address"` + } + var infos []nodeInfo + if err := wpa.db.Query(ctx, &infos, `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1`, nodeIDs[0]); err != nil || len(infos) == 0 { + return nil, &ClusterError{Message: "failed to get node IP for colocated port query"} + } + + query := fmt.Sprintf( + `SELECT wpa.%s as port_val FROM webrtc_port_allocations wpa + JOIN dns_nodes dn ON wpa.node_id = dn.id + WHERE dn.ip_address = ? AND wpa.service_type = ? 
AND wpa.%s > 0 + ORDER BY wpa.%s ASC`, + portColumn, portColumn, portColumn, + ) + if err := wpa.db.Query(ctx, &rows, query, infos[0].IPAddress, serviceType); err != nil { + return nil, &ClusterError{ + Message: "failed to query allocated WebRTC ports (colocated)", + Cause: err, + } + } + } + + result := make([]int, len(rows)) + for i, r := range rows { + result[i] = r.Port + } + return result, nil +} diff --git a/pkg/namespace/webrtc_port_allocator_test.go b/pkg/namespace/webrtc_port_allocator_test.go new file mode 100644 index 0000000..ea217a1 --- /dev/null +++ b/pkg/namespace/webrtc_port_allocator_test.go @@ -0,0 +1,337 @@ +package namespace + +import ( + "context" + "strings" + "testing" + + "go.uber.org/zap" +) + +func TestWebRTCPortConstants_NoOverlap(t *testing.T) { + // Verify WebRTC port ranges don't overlap with core namespace ports (10000-10099) + ranges := []struct { + name string + start int + end int + }{ + {"core namespace", NamespacePortRangeStart, NamespacePortRangeEnd}, + {"SFU media", SFUMediaPortRangeStart, SFUMediaPortRangeEnd}, + {"SFU signaling", SFUSignalingPortRangeStart, SFUSignalingPortRangeEnd}, + {"TURN relay", TURNRelayPortRangeStart, TURNRelayPortRangeEnd}, + } + + for i := 0; i < len(ranges); i++ { + for j := i + 1; j < len(ranges); j++ { + a, b := ranges[i], ranges[j] + if a.start <= b.end && b.start <= a.end { + t.Errorf("Range overlap: %s (%d-%d) overlaps with %s (%d-%d)", + a.name, a.start, a.end, b.name, b.start, b.end) + } + } + } +} + +func TestWebRTCPortConstants_Capacity(t *testing.T) { + // SFU media: (29999-20000+1)/500 = 20 namespaces per node + sfuMediaCapacity := (SFUMediaPortRangeEnd - SFUMediaPortRangeStart + 1) / SFUMediaPortsPerNamespace + if sfuMediaCapacity < 20 { + t.Errorf("SFU media capacity = %d, want >= 20", sfuMediaCapacity) + } + + // SFU signaling: 30099-30000+1 = 100 ports → 100 namespaces per node + sfuSignalingCapacity := SFUSignalingPortRangeEnd - SFUSignalingPortRangeStart + 1 + if sfuSignalingCapacity < 20 { + t.Errorf("SFU signaling capacity = %d, want >= 20", sfuSignalingCapacity) + } + + // TURN relay: (65535-49152+1)/800 = 20 namespaces per node + turnRelayCapacity := (TURNRelayPortRangeEnd - TURNRelayPortRangeStart + 1) / TURNRelayPortsPerNamespace + if turnRelayCapacity < 20 { + t.Errorf("TURN relay capacity = %d, want >= 20", turnRelayCapacity) + } +} + +func TestWebRTCPortConstants_Values(t *testing.T) { + if SFUMediaPortRangeStart != 20000 { + t.Errorf("SFUMediaPortRangeStart = %d, want 20000", SFUMediaPortRangeStart) + } + if SFUMediaPortRangeEnd != 29999 { + t.Errorf("SFUMediaPortRangeEnd = %d, want 29999", SFUMediaPortRangeEnd) + } + if SFUMediaPortsPerNamespace != 500 { + t.Errorf("SFUMediaPortsPerNamespace = %d, want 500", SFUMediaPortsPerNamespace) + } + if SFUSignalingPortRangeStart != 30000 { + t.Errorf("SFUSignalingPortRangeStart = %d, want 30000", SFUSignalingPortRangeStart) + } + if TURNRelayPortRangeStart != 49152 { + t.Errorf("TURNRelayPortRangeStart = %d, want 49152", TURNRelayPortRangeStart) + } + if TURNRelayPortsPerNamespace != 800 { + t.Errorf("TURNRelayPortsPerNamespace = %d, want 800", TURNRelayPortsPerNamespace) + } + if TURNDefaultPort != 3478 { + t.Errorf("TURNDefaultPort = %d, want 3478", TURNDefaultPort) + } + if DefaultSFUNodeCount != 3 { + t.Errorf("DefaultSFUNodeCount = %d, want 3", DefaultSFUNodeCount) + } + if DefaultTURNNodeCount != 2 { + t.Errorf("DefaultTURNNodeCount = %d, want 2", DefaultTURNNodeCount) + } +} + +func TestNewWebRTCPortAllocator(t *testing.T) { + mockDB := 
newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + if allocator == nil { + t.Fatal("NewWebRTCPortAllocator returned nil") + } + if allocator.db != mockDB { + t.Error("allocator.db not set correctly") + } +} + +func TestWebRTCPortAllocator_AllocateSFUPorts(t *testing.T) { + mockDB := newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + block, err := allocator.AllocateSFUPorts(context.Background(), "node-1", "cluster-1") + if err != nil { + t.Fatalf("AllocateSFUPorts failed: %v", err) + } + + if block == nil { + t.Fatal("AllocateSFUPorts returned nil block") + } + + if block.ServiceType != "sfu" { + t.Errorf("ServiceType = %q, want %q", block.ServiceType, "sfu") + } + if block.NodeID != "node-1" { + t.Errorf("NodeID = %q, want %q", block.NodeID, "node-1") + } + if block.NamespaceClusterID != "cluster-1" { + t.Errorf("NamespaceClusterID = %q, want %q", block.NamespaceClusterID, "cluster-1") + } + + // First allocation should get the first port in each range + if block.SFUSignalingPort != SFUSignalingPortRangeStart { + t.Errorf("SFUSignalingPort = %d, want %d", block.SFUSignalingPort, SFUSignalingPortRangeStart) + } + if block.SFUMediaPortStart != SFUMediaPortRangeStart { + t.Errorf("SFUMediaPortStart = %d, want %d", block.SFUMediaPortStart, SFUMediaPortRangeStart) + } + if block.SFUMediaPortEnd != SFUMediaPortRangeStart+SFUMediaPortsPerNamespace-1 { + t.Errorf("SFUMediaPortEnd = %d, want %d", block.SFUMediaPortEnd, SFUMediaPortRangeStart+SFUMediaPortsPerNamespace-1) + } + + // TURN fields should be zero for SFU allocation + if block.TURNListenPort != 0 { + t.Errorf("TURNListenPort = %d, want 0 for SFU allocation", block.TURNListenPort) + } + if block.TURNRelayPortStart != 0 { + t.Errorf("TURNRelayPortStart = %d, want 0 for SFU allocation", block.TURNRelayPortStart) + } + + // Verify INSERT was called + hasInsert := false + for _, call := range mockDB.execCalls { + if strings.Contains(call.Query, "INSERT INTO webrtc_port_allocations") { + hasInsert = true + break + } + } + if !hasInsert { + t.Error("expected INSERT INTO webrtc_port_allocations to be called") + } +} + +func TestWebRTCPortAllocator_AllocateTURNPorts(t *testing.T) { + mockDB := newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + block, err := allocator.AllocateTURNPorts(context.Background(), "node-1", "cluster-1") + if err != nil { + t.Fatalf("AllocateTURNPorts failed: %v", err) + } + + if block == nil { + t.Fatal("AllocateTURNPorts returned nil block") + } + + if block.ServiceType != "turn" { + t.Errorf("ServiceType = %q, want %q", block.ServiceType, "turn") + } + if block.TURNListenPort != TURNDefaultPort { + t.Errorf("TURNListenPort = %d, want %d", block.TURNListenPort, TURNDefaultPort) + } + if block.TURNTLSPort != TURNTLSPort { + t.Errorf("TURNTLSPort = %d, want %d", block.TURNTLSPort, TURNTLSPort) + } + if block.TURNRelayPortStart != TURNRelayPortRangeStart { + t.Errorf("TURNRelayPortStart = %d, want %d", block.TURNRelayPortStart, TURNRelayPortRangeStart) + } + if block.TURNRelayPortEnd != TURNRelayPortRangeStart+TURNRelayPortsPerNamespace-1 { + t.Errorf("TURNRelayPortEnd = %d, want %d", block.TURNRelayPortEnd, TURNRelayPortRangeStart+TURNRelayPortsPerNamespace-1) + } + + // SFU fields should be zero for TURN allocation + if block.SFUSignalingPort != 0 { + t.Errorf("SFUSignalingPort = %d, want 0 for TURN allocation", block.SFUSignalingPort) + } +} + +func TestWebRTCPortAllocator_DeallocateAll(t *testing.T) { + mockDB := 
newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + err := allocator.DeallocateAll(context.Background(), "cluster-1") + if err != nil { + t.Fatalf("DeallocateAll failed: %v", err) + } + + // Verify DELETE was called with correct cluster ID + hasDelete := false + for _, call := range mockDB.execCalls { + if strings.Contains(call.Query, "DELETE FROM webrtc_port_allocations") && + strings.Contains(call.Query, "namespace_cluster_id") { + hasDelete = true + if len(call.Args) < 1 || call.Args[0] != "cluster-1" { + t.Errorf("DELETE called with wrong cluster ID: %v", call.Args) + } + } + } + if !hasDelete { + t.Error("expected DELETE FROM webrtc_port_allocations to be called") + } +} + +func TestWebRTCPortAllocator_DeallocateByNode(t *testing.T) { + mockDB := newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + err := allocator.DeallocateByNode(context.Background(), "cluster-1", "node-1", "sfu") + if err != nil { + t.Fatalf("DeallocateByNode failed: %v", err) + } + + // Verify DELETE was called with correct parameters + hasDelete := false + for _, call := range mockDB.execCalls { + if strings.Contains(call.Query, "DELETE FROM webrtc_port_allocations") && + strings.Contains(call.Query, "service_type") { + hasDelete = true + if len(call.Args) != 3 { + t.Fatalf("DELETE called with %d args, want 3", len(call.Args)) + } + if call.Args[0] != "cluster-1" { + t.Errorf("arg[0] = %v, want cluster-1", call.Args[0]) + } + if call.Args[1] != "node-1" { + t.Errorf("arg[1] = %v, want node-1", call.Args[1]) + } + if call.Args[2] != "sfu" { + t.Errorf("arg[2] = %v, want sfu", call.Args[2]) + } + } + } + if !hasDelete { + t.Error("expected DELETE FROM webrtc_port_allocations to be called") + } +} + +func TestWebRTCPortAllocator_NodeHasTURN(t *testing.T) { + mockDB := newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + // Mock query returns empty results → no TURN on node + hasTURN, err := allocator.NodeHasTURN(context.Background(), "node-1") + if err != nil { + t.Fatalf("NodeHasTURN failed: %v", err) + } + if hasTURN { + t.Error("expected NodeHasTURN = false for node with no allocations") + } +} + +func TestWebRTCPortAllocator_GetSFUPorts_NoAllocation(t *testing.T) { + mockDB := newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + block, err := allocator.GetSFUPorts(context.Background(), "cluster-1", "node-1") + if err != nil { + t.Fatalf("GetSFUPorts failed: %v", err) + } + if block != nil { + t.Error("expected nil block when no allocation exists") + } +} + +func TestWebRTCPortAllocator_GetTURNPorts_NoAllocation(t *testing.T) { + mockDB := newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + block, err := allocator.GetTURNPorts(context.Background(), "cluster-1", "node-1") + if err != nil { + t.Fatalf("GetTURNPorts failed: %v", err) + } + if block != nil { + t.Error("expected nil block when no allocation exists") + } +} + +func TestWebRTCPortAllocator_GetAllPorts_Empty(t *testing.T) { + mockDB := newMockRQLiteClient() + allocator := NewWebRTCPortAllocator(mockDB, testLogger()) + + blocks, err := allocator.GetAllPorts(context.Background(), "cluster-1") + if err != nil { + t.Fatalf("GetAllPorts failed: %v", err) + } + if len(blocks) != 0 { + t.Errorf("expected 0 blocks, got %d", len(blocks)) + } +} + +func TestWebRTCPortBlock_SFUFields(t *testing.T) { + block := &WebRTCPortBlock{ + ID: "test-id", + NodeID: "node-1", + NamespaceClusterID: "cluster-1", + 
ServiceType: "sfu", + SFUSignalingPort: 30000, + SFUMediaPortStart: 20000, + SFUMediaPortEnd: 20499, + } + + mediaRange := block.SFUMediaPortEnd - block.SFUMediaPortStart + 1 + if mediaRange != SFUMediaPortsPerNamespace { + t.Errorf("SFU media range = %d, want %d", mediaRange, SFUMediaPortsPerNamespace) + } +} + +func TestWebRTCPortBlock_TURNFields(t *testing.T) { + block := &WebRTCPortBlock{ + ID: "test-id", + NodeID: "node-1", + NamespaceClusterID: "cluster-1", + ServiceType: "turn", + TURNListenPort: 3478, + TURNTLSPort: 443, + TURNRelayPortStart: 49152, + TURNRelayPortEnd: 49951, + } + + relayRange := block.TURNRelayPortEnd - block.TURNRelayPortStart + 1 + if relayRange != TURNRelayPortsPerNamespace { + t.Errorf("TURN relay range = %d, want %d", relayRange, TURNRelayPortsPerNamespace) + } +} + +// testLogger returns a no-op logger for tests +func testLogger() *zap.Logger { + return zap.NewNop() +} diff --git a/pkg/node/gateway.go b/pkg/node/gateway.go index 5ef9807..915cbb3 100644 --- a/pkg/node/gateway.go +++ b/pkg/node/gateway.go @@ -57,7 +57,11 @@ func (n *Node) startHTTPGateway(ctx context.Context) error { IPFSTimeout: n.config.HTTPGateway.IPFSTimeout, BaseDomain: n.config.HTTPGateway.BaseDomain, DataDir: oramaDir, - ClusterSecret: clusterSecret, + ClusterSecret: clusterSecret, + WebRTCEnabled: n.config.HTTPGateway.WebRTC.Enabled, + SFUPort: n.config.HTTPGateway.WebRTC.SFUPort, + TURNDomain: n.config.HTTPGateway.WebRTC.TURNDomain, + TURNSecret: n.config.HTTPGateway.WebRTC.TURNSecret, } apiGateway, err := gateway.New(gatewayLogger, gwCfg) @@ -82,6 +86,7 @@ func (n *Node) startHTTPGateway(ctx context.Context) error { clusterManager.SetLocalNodeID(gwCfg.NodePeerID) apiGateway.SetClusterProvisioner(clusterManager) apiGateway.SetNodeRecoverer(clusterManager) + apiGateway.SetWebRTCManager(clusterManager) // Wire spawn handler for distributed namespace instance spawning systemdSpawner := namespace.NewSystemdSpawner(baseDataDir, n.logger.Logger) diff --git a/pkg/rqlite/scanner.go b/pkg/rqlite/scanner.go index fe8c8e1..3581be3 100644 --- a/pkg/rqlite/scanner.go +++ b/pkg/rqlite/scanner.go @@ -173,6 +173,8 @@ func setReflectValue(field reflect.Value, raw any) error { field.SetBool(v) case int64: field.SetBool(v != 0) + case float64: + field.SetBool(v != 0) case []byte: s := string(v) field.SetBool(s == "1" || strings.EqualFold(s, "true")) diff --git a/pkg/serverless/triggers/dispatcher.go b/pkg/serverless/triggers/dispatcher.go new file mode 100644 index 0000000..94e5d55 --- /dev/null +++ b/pkg/serverless/triggers/dispatcher.go @@ -0,0 +1,230 @@ +package triggers + +import ( + "context" + "encoding/json" + "time" + + "github.com/DeBrosOfficial/network/pkg/serverless" + olriclib "github.com/olric-data/olric" + "go.uber.org/zap" +) + +const ( + // triggerCacheDMap is the Olric DMap name for caching trigger lookups. + triggerCacheDMap = "pubsub_triggers" + + // maxTriggerDepth prevents infinite loops when triggered functions publish + // back to the same topic via the HTTP API. + maxTriggerDepth = 5 + + // dispatchTimeout is the timeout for each triggered function invocation. + dispatchTimeout = 60 * time.Second +) + +// PubSubEvent is the JSON payload sent to functions triggered by PubSub messages. 
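+//
+// An illustrative payload (values are examples, not fixtures):
+//
+//	{
+//	  "topic": "chat",
+//	  "data": {"msg": "hello"},
+//	  "namespace": "my-app",
+//	  "trigger_depth": 1,
+//	  "timestamp": 1708300000
+//	}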
+type PubSubEvent struct { + Topic string `json:"topic"` + Data json.RawMessage `json:"data"` + Namespace string `json:"namespace"` + TriggerDepth int `json:"trigger_depth"` + Timestamp int64 `json:"timestamp"` +} + +// PubSubDispatcher looks up triggers for a topic+namespace and asynchronously +// invokes matching serverless functions. +type PubSubDispatcher struct { + store *PubSubTriggerStore + invoker *serverless.Invoker + olricClient olriclib.Client // may be nil (cache disabled) + logger *zap.Logger +} + +// NewPubSubDispatcher creates a new PubSub trigger dispatcher. +func NewPubSubDispatcher( + store *PubSubTriggerStore, + invoker *serverless.Invoker, + olricClient olriclib.Client, + logger *zap.Logger, +) *PubSubDispatcher { + return &PubSubDispatcher{ + store: store, + invoker: invoker, + olricClient: olricClient, + logger: logger, + } +} + +// Dispatch looks up all triggers registered for the given topic+namespace and +// invokes matching functions asynchronously. Each invocation runs in its own +// goroutine and does not block the caller. +func (d *PubSubDispatcher) Dispatch(ctx context.Context, namespace, topic string, data []byte, depth int) { + if depth >= maxTriggerDepth { + d.logger.Warn("PubSub trigger depth limit reached, skipping dispatch", + zap.String("namespace", namespace), + zap.String("topic", topic), + zap.Int("depth", depth), + ) + return + } + + matches, err := d.getMatches(ctx, namespace, topic) + if err != nil { + d.logger.Error("Failed to look up PubSub triggers", + zap.String("namespace", namespace), + zap.String("topic", topic), + zap.Error(err), + ) + return + } + + if len(matches) == 0 { + return + } + + // Build the event payload once for all invocations + event := PubSubEvent{ + Topic: topic, + Data: json.RawMessage(data), + Namespace: namespace, + TriggerDepth: depth + 1, + Timestamp: time.Now().Unix(), + } + eventJSON, err := json.Marshal(event) + if err != nil { + d.logger.Error("Failed to marshal PubSub event", zap.Error(err)) + return + } + + d.logger.Debug("Dispatching PubSub triggers", + zap.String("namespace", namespace), + zap.String("topic", topic), + zap.Int("matches", len(matches)), + zap.Int("depth", depth), + ) + + for _, match := range matches { + go d.invokeFunction(match, eventJSON) + } +} + +// InvalidateCache removes the cached trigger lookup for a namespace+topic. +// Call this when triggers are added or removed. +func (d *PubSubDispatcher) InvalidateCache(ctx context.Context, namespace, topic string) { + if d.olricClient == nil { + return + } + + dm, err := d.olricClient.NewDMap(triggerCacheDMap) + if err != nil { + d.logger.Debug("Failed to get trigger cache DMap for invalidation", zap.Error(err)) + return + } + + key := cacheKey(namespace, topic) + if _, err := dm.Delete(ctx, key); err != nil { + d.logger.Debug("Failed to invalidate trigger cache", zap.String("key", key), zap.Error(err)) + } +} + +// getMatches returns the trigger matches for a topic+namespace, using Olric cache when available. 
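+// This is a standard cache-aside read: try the "pubsub_triggers" DMap under
+// the key "triggers:<namespace>:<topic>" first; on a miss, fall back to the
+// store and repopulate the cache best-effort. Cache failures are never fatal;
+// rqlite remains the source of truth.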
+func (d *PubSubDispatcher) getMatches(ctx context.Context, namespace, topic string) ([]TriggerMatch, error) { + // Try cache first + if d.olricClient != nil { + if matches, ok := d.getCached(ctx, namespace, topic); ok { + return matches, nil + } + } + + // Cache miss — query database + matches, err := d.store.GetByTopicAndNamespace(ctx, topic, namespace) + if err != nil { + return nil, err + } + + // Populate cache + if d.olricClient != nil && matches != nil { + d.setCache(ctx, namespace, topic, matches) + } + + return matches, nil +} + +// getCached attempts to retrieve trigger matches from Olric cache. +func (d *PubSubDispatcher) getCached(ctx context.Context, namespace, topic string) ([]TriggerMatch, bool) { + dm, err := d.olricClient.NewDMap(triggerCacheDMap) + if err != nil { + return nil, false + } + + key := cacheKey(namespace, topic) + result, err := dm.Get(ctx, key) + if err != nil { + return nil, false + } + + data, err := result.Byte() + if err != nil { + return nil, false + } + + var matches []TriggerMatch + if err := json.Unmarshal(data, &matches); err != nil { + return nil, false + } + + return matches, true +} + +// setCache stores trigger matches in Olric cache. +func (d *PubSubDispatcher) setCache(ctx context.Context, namespace, topic string, matches []TriggerMatch) { + dm, err := d.olricClient.NewDMap(triggerCacheDMap) + if err != nil { + return + } + + data, err := json.Marshal(matches) + if err != nil { + return + } + + key := cacheKey(namespace, topic) + _ = dm.Put(ctx, key, data) +} + +// invokeFunction invokes a single function for a trigger match. +func (d *PubSubDispatcher) invokeFunction(match TriggerMatch, eventJSON []byte) { + ctx, cancel := context.WithTimeout(context.Background(), dispatchTimeout) + defer cancel() + + req := &serverless.InvokeRequest{ + Namespace: match.Namespace, + FunctionName: match.FunctionName, + Input: eventJSON, + TriggerType: serverless.TriggerTypePubSub, + } + + resp, err := d.invoker.Invoke(ctx, req) + if err != nil { + d.logger.Warn("PubSub trigger invocation failed", + zap.String("function", match.FunctionName), + zap.String("namespace", match.Namespace), + zap.String("topic", match.Topic), + zap.String("trigger_id", match.TriggerID), + zap.Error(err), + ) + return + } + + d.logger.Debug("PubSub trigger invocation completed", + zap.String("function", match.FunctionName), + zap.String("topic", match.Topic), + zap.String("status", string(resp.Status)), + zap.Int64("duration_ms", resp.DurationMS), + ) +} + +// cacheKey returns the Olric cache key for a namespace+topic pair. +func cacheKey(namespace, topic string) string { + return "triggers:" + namespace + ":" + topic +} diff --git a/pkg/serverless/triggers/pubsub_store.go b/pkg/serverless/triggers/pubsub_store.go new file mode 100644 index 0000000..7ee14fb --- /dev/null +++ b/pkg/serverless/triggers/pubsub_store.go @@ -0,0 +1,187 @@ +// Package triggers provides PubSub trigger management for the serverless engine. +// It handles registering, querying, and removing triggers that automatically invoke +// functions when messages are published to specific PubSub topics. +package triggers + +import ( + "context" + "fmt" + "time" + + "github.com/DeBrosOfficial/network/pkg/rqlite" + "github.com/DeBrosOfficial/network/pkg/serverless" + "github.com/google/uuid" + "go.uber.org/zap" +) + +// TriggerMatch contains the fields needed to dispatch a trigger invocation. +// It's the result of JOINing function_pubsub_triggers with functions. 
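+// One match is produced per enabled trigger on an active function; the
+// dispatcher invokes every match in its own goroutine.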
+type TriggerMatch struct { + TriggerID string + FunctionID string + FunctionName string + Namespace string + Topic string +} + +// triggerRow maps to the function_pubsub_triggers table for query scanning. +type triggerRow struct { + ID string + FunctionID string + Topic string + Enabled bool + CreatedAt time.Time +} + +// triggerMatchRow maps to the JOIN query result for scanning. +type triggerMatchRow struct { + TriggerID string + FunctionID string + FunctionName string + Namespace string + Topic string +} + +// PubSubTriggerStore manages PubSub trigger persistence in RQLite. +type PubSubTriggerStore struct { + db rqlite.Client + logger *zap.Logger +} + +// NewPubSubTriggerStore creates a new PubSub trigger store. +func NewPubSubTriggerStore(db rqlite.Client, logger *zap.Logger) *PubSubTriggerStore { + return &PubSubTriggerStore{ + db: db, + logger: logger, + } +} + +// Add registers a new PubSub trigger for a function. +// Returns the trigger ID. +func (s *PubSubTriggerStore) Add(ctx context.Context, functionID, topic string) (string, error) { + if functionID == "" { + return "", fmt.Errorf("function ID required") + } + if topic == "" { + return "", fmt.Errorf("topic required") + } + + id := uuid.New().String() + now := time.Now() + + query := ` + INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) + VALUES (?, ?, ?, TRUE, ?) + ` + if _, err := s.db.Exec(ctx, query, id, functionID, topic, now); err != nil { + return "", fmt.Errorf("failed to add pubsub trigger: %w", err) + } + + s.logger.Info("PubSub trigger added", + zap.String("trigger_id", id), + zap.String("function_id", functionID), + zap.String("topic", topic), + ) + + return id, nil +} + +// Remove deletes a trigger by ID. +func (s *PubSubTriggerStore) Remove(ctx context.Context, triggerID string) error { + if triggerID == "" { + return fmt.Errorf("trigger ID required") + } + + query := `DELETE FROM function_pubsub_triggers WHERE id = ?` + result, err := s.db.Exec(ctx, query, triggerID) + if err != nil { + return fmt.Errorf("failed to remove trigger: %w", err) + } + + affected, _ := result.RowsAffected() + if affected == 0 { + return fmt.Errorf("trigger not found: %s", triggerID) + } + + s.logger.Info("PubSub trigger removed", zap.String("trigger_id", triggerID)) + return nil +} + +// RemoveByFunction deletes all triggers for a function. +// Used during function re-deploy to clear old triggers. +func (s *PubSubTriggerStore) RemoveByFunction(ctx context.Context, functionID string) error { + if functionID == "" { + return fmt.Errorf("function ID required") + } + + query := `DELETE FROM function_pubsub_triggers WHERE function_id = ?` + if _, err := s.db.Exec(ctx, query, functionID); err != nil { + return fmt.Errorf("failed to remove triggers for function: %w", err) + } + + return nil +} + +// ListByFunction returns all PubSub triggers for a function. +func (s *PubSubTriggerStore) ListByFunction(ctx context.Context, functionID string) ([]serverless.PubSubTrigger, error) { + if functionID == "" { + return nil, fmt.Errorf("function ID required") + } + + query := ` + SELECT id, function_id, topic, enabled, created_at + FROM function_pubsub_triggers + WHERE function_id = ? 
+	`
+
+	var rows []triggerRow
+	if err := s.db.Query(ctx, &rows, query, functionID); err != nil {
+		return nil, fmt.Errorf("failed to list triggers: %w", err)
+	}
+
+	triggers := make([]serverless.PubSubTrigger, len(rows))
+	for i, row := range rows {
+		triggers[i] = serverless.PubSubTrigger{
+			ID:         row.ID,
+			FunctionID: row.FunctionID,
+			Topic:      row.Topic,
+			Enabled:    row.Enabled,
+		}
+	}
+
+	return triggers, nil
+}
+
+// GetByTopicAndNamespace returns all enabled triggers for a topic within a namespace.
+// Only returns triggers for active functions.
+func (s *PubSubTriggerStore) GetByTopicAndNamespace(ctx context.Context, topic, namespace string) ([]TriggerMatch, error) {
+	if topic == "" || namespace == "" {
+		return nil, nil
+	}
+
+	query := `
+		SELECT t.id AS trigger_id, t.function_id AS function_id,
+		       f.name AS function_name, f.namespace AS namespace, t.topic AS topic
+		FROM function_pubsub_triggers t
+		JOIN functions f ON t.function_id = f.id
+		WHERE t.topic = ? AND f.namespace = ? AND t.enabled = TRUE AND f.status = 'active'
+	`
+
+	var rows []triggerMatchRow
+	if err := s.db.Query(ctx, &rows, query, topic, namespace); err != nil {
+		return nil, fmt.Errorf("failed to query triggers for topic: %w", err)
+	}
+
+	matches := make([]TriggerMatch, len(rows))
+	for i, row := range rows {
+		matches[i] = TriggerMatch{
+			TriggerID:    row.TriggerID,
+			FunctionID:   row.FunctionID,
+			FunctionName: row.FunctionName,
+			Namespace:    row.Namespace,
+			Topic:        row.Topic,
+		}
+	}
+
+	return matches, nil
+}
diff --git a/pkg/serverless/triggers/triggers_test.go b/pkg/serverless/triggers/triggers_test.go
new file mode 100644
index 0000000..a9822cc
--- /dev/null
+++ b/pkg/serverless/triggers/triggers_test.go
@@ -0,0 +1,219 @@
+package triggers
+
+import (
+	"context"
+	"encoding/json"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"github.com/DeBrosOfficial/network/pkg/serverless"
+	"go.uber.org/zap"
+)
+
+// ---------------------------------------------------------------------------
+// Mock Invoker
+// ---------------------------------------------------------------------------
+
+// mockInvokeCall is the record an invoker mock would capture per invocation.
+type mockInvokeCall struct {
+	Namespace    string
+	FunctionName string
+	TriggerType  serverless.TriggerType
+	Input        []byte
+}
+
+// A true mock invoker is not feasible here: Invoker is a concrete type that
+// needs an engine, registry, and host functions to construct. The tests below
+// therefore exercise the dispatcher's observable behavior at a higher level.
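+
+// TestPubSubEvent_FieldNames pins the wire-format keys delivered to triggered
+// functions, since renaming a struct tag would silently break deployed
+// consumers. The values below are arbitrary sketch data, not fixtures.
+func TestPubSubEvent_FieldNames(t *testing.T) {
+	data, err := json.Marshal(PubSubEvent{Topic: "t", Data: json.RawMessage(`{}`), Namespace: "ns"})
+	if err != nil {
+		t.Fatalf("marshal failed: %v", err)
+	}
+	var m map[string]json.RawMessage
+	if err := json.Unmarshal(data, &m); err != nil {
+		t.Fatalf("unmarshal failed: %v", err)
+	}
+	for _, key := range []string{"topic", "data", "namespace", "trigger_depth", "timestamp"} {
+		if _, ok := m[key]; !ok {
+			t.Errorf("expected key %q in payload %s", key, data)
+		}
+	}
+}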
+ +// --------------------------------------------------------------------------- +// Dispatcher Tests +// --------------------------------------------------------------------------- + +func TestDispatcher_DepthLimit(t *testing.T) { + logger, _ := zap.NewDevelopment() + store := NewPubSubTriggerStore(nil, logger) // store won't be called + d := NewPubSubDispatcher(store, nil, nil, logger) + + // Dispatch at max depth should be a no-op (no panic, no store call) + d.Dispatch(context.Background(), "ns", "topic", []byte("data"), maxTriggerDepth) + d.Dispatch(context.Background(), "ns", "topic", []byte("data"), maxTriggerDepth+1) +} + +func TestCacheKey(t *testing.T) { + key := cacheKey("my-namespace", "my-topic") + if key != "triggers:my-namespace:my-topic" { + t.Errorf("unexpected cache key: %s", key) + } +} + +func TestPubSubEvent_Marshal(t *testing.T) { + event := PubSubEvent{ + Topic: "chat", + Data: json.RawMessage(`{"msg":"hello"}`), + Namespace: "my-app", + TriggerDepth: 1, + Timestamp: 1708300000, + } + + data, err := json.Marshal(event) + if err != nil { + t.Fatalf("marshal failed: %v", err) + } + + var decoded PubSubEvent + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + + if decoded.Topic != "chat" { + t.Errorf("expected topic 'chat', got '%s'", decoded.Topic) + } + if decoded.Namespace != "my-app" { + t.Errorf("expected namespace 'my-app', got '%s'", decoded.Namespace) + } + if decoded.TriggerDepth != 1 { + t.Errorf("expected depth 1, got %d", decoded.TriggerDepth) + } +} + +// --------------------------------------------------------------------------- +// Store Tests (validation only — DB operations require rqlite.Client) +// --------------------------------------------------------------------------- + +func TestStore_AddValidation(t *testing.T) { + logger, _ := zap.NewDevelopment() + store := NewPubSubTriggerStore(nil, logger) + + _, err := store.Add(context.Background(), "", "topic") + if err == nil { + t.Error("expected error for empty function ID") + } + + _, err = store.Add(context.Background(), "fn-123", "") + if err == nil { + t.Error("expected error for empty topic") + } +} + +func TestStore_RemoveValidation(t *testing.T) { + logger, _ := zap.NewDevelopment() + store := NewPubSubTriggerStore(nil, logger) + + err := store.Remove(context.Background(), "") + if err == nil { + t.Error("expected error for empty trigger ID") + } +} + +func TestStore_RemoveByFunctionValidation(t *testing.T) { + logger, _ := zap.NewDevelopment() + store := NewPubSubTriggerStore(nil, logger) + + err := store.RemoveByFunction(context.Background(), "") + if err == nil { + t.Error("expected error for empty function ID") + } +} + +func TestStore_ListByFunctionValidation(t *testing.T) { + logger, _ := zap.NewDevelopment() + store := NewPubSubTriggerStore(nil, logger) + + _, err := store.ListByFunction(context.Background(), "") + if err == nil { + t.Error("expected error for empty function ID") + } +} + +func TestStore_GetByTopicAndNamespace_Empty(t *testing.T) { + logger, _ := zap.NewDevelopment() + store := NewPubSubTriggerStore(nil, logger) + + // Empty topic/namespace should return nil, nil (not an error) + matches, err := store.GetByTopicAndNamespace(context.Background(), "", "ns") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if matches != nil { + t.Errorf("expected nil matches for empty topic, got %v", matches) + } + + matches, err = store.GetByTopicAndNamespace(context.Background(), "topic", "") + if err != nil { + 
t.Errorf("unexpected error: %v", err) + } + if matches != nil { + t.Errorf("expected nil matches for empty namespace, got %v", matches) + } +} + +// --------------------------------------------------------------------------- +// Dispatcher Integration-like Tests +// --------------------------------------------------------------------------- + +func TestDispatcher_NoMatchesNoPanic(t *testing.T) { + // Dispatcher with nil olricClient and nil invoker should handle + // the case where there are no matches gracefully. + logger, _ := zap.NewDevelopment() + + // Create a mock store that returns empty matches + store := &mockTriggerStore{matches: nil} + d := &PubSubDispatcher{ + store: &PubSubTriggerStore{db: nil, logger: logger}, + invoker: nil, + logger: logger, + } + // Replace store field directly for testing + d.store = store.asPubSubTriggerStore() + + // This should not panic even with nil invoker since no matches + // We can't easily test this without a real store, so we test the depth limit instead + d.Dispatch(context.Background(), "ns", "topic", []byte("data"), maxTriggerDepth) +} + +// mockTriggerStore is used only for structural validation in tests. +type mockTriggerStore struct { + matches []TriggerMatch +} + +func (m *mockTriggerStore) asPubSubTriggerStore() *PubSubTriggerStore { + // Can't return a mock as *PubSubTriggerStore since it's a concrete type. + // This is a limitation — integration tests with a real rqlite would be needed. + return nil +} + +// --------------------------------------------------------------------------- +// Callback Wiring Test +// --------------------------------------------------------------------------- + +func TestOnPublishCallback(t *testing.T) { + var called atomic.Int32 + var receivedNS, receivedTopic string + var receivedData []byte + + callback := func(ctx context.Context, namespace, topic string, data []byte) { + called.Add(1) + receivedNS = namespace + receivedTopic = topic + receivedData = data + } + + // Simulate what gateway.go does + callback(context.Background(), "my-ns", "events", []byte("hello")) + + time.Sleep(10 * time.Millisecond) // Let goroutine complete + + if called.Load() != 1 { + t.Errorf("expected callback called once, got %d", called.Load()) + } + if receivedNS != "my-ns" { + t.Errorf("expected namespace 'my-ns', got '%s'", receivedNS) + } + if receivedTopic != "events" { + t.Errorf("expected topic 'events', got '%s'", receivedTopic) + } + if string(receivedData) != "hello" { + t.Errorf("expected data 'hello', got '%s'", string(receivedData)) + } +} diff --git a/pkg/sfu/config.go b/pkg/sfu/config.go new file mode 100644 index 0000000..ac71a40 --- /dev/null +++ b/pkg/sfu/config.go @@ -0,0 +1,80 @@ +package sfu + +import "fmt" + +// Config holds configuration for the SFU server +type Config struct { + // ListenAddr is the address to bind the signaling WebSocket server. + // Must be a WireGuard IP (10.0.0.x) — never 0.0.0.0. 
+ ListenAddr string `yaml:"listen_addr"` + + // Namespace this SFU instance belongs to + Namespace string `yaml:"namespace"` + + // MediaPortRange defines the UDP port range for RTP media + MediaPortStart int `yaml:"media_port_start"` + MediaPortEnd int `yaml:"media_port_end"` + + // TURN servers this SFU should advertise to peers + TURNServers []TURNServerConfig `yaml:"turn_servers"` + + // TURNSecret is the shared HMAC-SHA1 secret for generating TURN credentials + TURNSecret string `yaml:"turn_secret"` + + // TURNCredentialTTL is the lifetime of TURN credentials in seconds + TURNCredentialTTL int `yaml:"turn_credential_ttl"` + + // RQLiteDSN is the namespace-local RQLite DSN for room state + RQLiteDSN string `yaml:"rqlite_dsn"` +} + +// TURNServerConfig represents a single TURN server endpoint +type TURNServerConfig struct { + Host string `yaml:"host"` // IP or hostname + Port int `yaml:"port"` // UDP port (3478 or 443) +} + +// Validate checks the SFU configuration for errors +func (c *Config) Validate() []error { + var errs []error + + if c.ListenAddr == "" { + errs = append(errs, fmt.Errorf("sfu.listen_addr: must not be empty")) + } + + if c.Namespace == "" { + errs = append(errs, fmt.Errorf("sfu.namespace: must not be empty")) + } + + if c.MediaPortStart <= 0 || c.MediaPortEnd <= 0 { + errs = append(errs, fmt.Errorf("sfu.media_port_range: start and end must be positive")) + } else if c.MediaPortEnd <= c.MediaPortStart { + errs = append(errs, fmt.Errorf("sfu.media_port_range: end (%d) must be greater than start (%d)", c.MediaPortEnd, c.MediaPortStart)) + } + + if len(c.TURNServers) == 0 { + errs = append(errs, fmt.Errorf("sfu.turn_servers: at least one TURN server must be configured")) + } + for i, ts := range c.TURNServers { + if ts.Host == "" { + errs = append(errs, fmt.Errorf("sfu.turn_servers[%d].host: must not be empty", i)) + } + if ts.Port <= 0 || ts.Port > 65535 { + errs = append(errs, fmt.Errorf("sfu.turn_servers[%d].port: must be between 1 and 65535", i)) + } + } + + if c.TURNSecret == "" { + errs = append(errs, fmt.Errorf("sfu.turn_secret: must not be empty")) + } + + if c.TURNCredentialTTL <= 0 { + errs = append(errs, fmt.Errorf("sfu.turn_credential_ttl: must be positive")) + } + + if c.RQLiteDSN == "" { + errs = append(errs, fmt.Errorf("sfu.rqlite_dsn: must not be empty")) + } + + return errs +} diff --git a/pkg/sfu/config_test.go b/pkg/sfu/config_test.go new file mode 100644 index 0000000..1900f16 --- /dev/null +++ b/pkg/sfu/config_test.go @@ -0,0 +1,167 @@ +package sfu + +import "testing" + +func TestConfigValidation(t *testing.T) { + tests := []struct { + name string + config Config + wantErrs int + }{ + { + name: "valid config", + config: Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}}, + TURNSecret: "secret-key", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 0, + }, + { + name: "valid config with multiple TURN servers", + config: Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{ + {Host: "1.2.3.4", Port: 3478}, + {Host: "5.6.7.8", Port: 443}, + }, + TURNSecret: "secret-key", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 0, + }, + { + name: "missing all fields", + config: Config{}, + wantErrs: 7, // listen_addr, namespace, media_port_range, turn_servers, turn_secret, 
turn_credential_ttl, rqlite_dsn + }, + { + name: "missing listen addr", + config: Config{ + Namespace: "test-ns", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}}, + TURNSecret: "secret", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 1, + }, + { + name: "missing namespace", + config: Config{ + ListenAddr: "10.0.0.1:8443", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}}, + TURNSecret: "secret", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 1, + }, + { + name: "invalid media port range - inverted", + config: Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 20500, + MediaPortEnd: 20000, + TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}}, + TURNSecret: "secret", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 1, + }, + { + name: "invalid media port range - zero", + config: Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 0, + MediaPortEnd: 0, + TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}}, + TURNSecret: "secret", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 1, + }, + { + name: "no TURN servers", + config: Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{}, + TURNSecret: "secret", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 1, + }, + { + name: "TURN server with invalid port", + config: Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 0}}, + TURNSecret: "secret", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 1, + }, + { + name: "TURN server with empty host", + config: Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{{Host: "", Port: 3478}}, + TURNSecret: "secret", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 1, + }, + { + name: "negative credential TTL", + config: Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}}, + TURNSecret: "secret", + TURNCredentialTTL: -1, + RQLiteDSN: "http://10.0.0.1:4001", + }, + wantErrs: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.config.Validate() + if len(errs) != tt.wantErrs { + t.Errorf("Validate() returned %d errors, want %d: %v", len(errs), tt.wantErrs, errs) + } + }) + } +} diff --git a/pkg/sfu/peer.go b/pkg/sfu/peer.go new file mode 100644 index 0000000..07d6186 --- /dev/null +++ b/pkg/sfu/peer.go @@ -0,0 +1,340 @@ +package sfu + +import ( + "encoding/json" + "errors" + "sync" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/pion/rtcp" + "github.com/pion/webrtc/v4" + "go.uber.org/zap" +) + +var ( + ErrPeerNotInitialized = errors.New("peer connection not initialized") + ErrPeerClosed = errors.New("peer is closed") + ErrWebSocketClosed = errors.New("websocket connection closed") +) + +// Peer represents a participant in a room with a WebRTC PeerConnection. 
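+// Signaling travels over the WebSocket; media is forced through TURN relay
+// (see InitPeerConnection). Renegotiation is serialized by a small state
+// machine: offers are only created while signaling is stable, and bulk track
+// additions can be batched via StartTrackBatch/EndTrackBatch so that N new
+// tracks cost a single renegotiation rather than N.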
+type Peer struct {
+	ID     string
+	UserID string
+
+	pc   *webrtc.PeerConnection
+	conn *websocket.Conn
+	room *Room
+
+	// Negotiation state machine
+	negotiationPending bool
+	batchingTracks     bool
+	negotiationMu      sync.Mutex
+
+	// Connection state
+	closed   bool
+	closedMu sync.RWMutex
+	connMu   sync.Mutex
+
+	logger  *zap.Logger
+	onClose func(*Peer)
+}
+
+// NewPeer creates a new peer with a logger scoped to its generated ID.
+func NewPeer(userID string, conn *websocket.Conn, room *Room, logger *zap.Logger) *Peer {
+	id := uuid.New().String()
+	return &Peer{
+		ID:     id,
+		UserID: userID,
+		conn:   conn,
+		room:   room,
+		logger: logger.With(zap.String("peer_id", id)),
+	}
+}
+
+// InitPeerConnection creates and configures the WebRTC PeerConnection.
+func (p *Peer) InitPeerConnection(api *webrtc.API, iceServers []webrtc.ICEServer) error {
+	pc, err := api.NewPeerConnection(webrtc.Configuration{
+		ICEServers:         iceServers,
+		ICETransportPolicy: webrtc.ICETransportPolicyRelay, // Force TURN relay
+	})
+	if err != nil {
+		return err
+	}
+	p.pc = pc
+
+	// ICE connection state changes
+	pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
+		p.logger.Info("ICE state changed", zap.String("state", state.String()))
+
+		switch state {
+		case webrtc.ICEConnectionStateDisconnected:
+			// Give 15 seconds to reconnect before removing
+			go p.handleReconnectTimeout()
+		case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateClosed:
+			p.handleDisconnect()
+		}
+	})
+
+	// ICE candidate generation
+	pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
+		if candidate == nil {
+			return
+		}
+		c := candidate.ToJSON()
+		data := &ICECandidateData{Candidate: c.Candidate}
+		if c.SDPMid != nil {
+			data.SDPMid = *c.SDPMid
+		}
+		if c.SDPMLineIndex != nil {
+			data.SDPMLineIndex = *c.SDPMLineIndex
+		}
+		if c.UsernameFragment != nil {
+			data.UsernameFragment = *c.UsernameFragment
+		}
+		p.SendMessage(NewServerMessage(MessageTypeICECandidate, data))
+	})
+
+	// Incoming tracks from the client
+	pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
+		p.logger.Info("Track received",
+			zap.String("track_id", track.ID()),
+			zap.String("kind", track.Kind().String()),
+			zap.String("codec", track.Codec().MimeType))
+
+		// Read RTCP feedback (PLI/NACK) in background
+		go p.readRTCP(receiver, track)
+
+		// Forward track to all other peers
+		p.room.BroadcastTrack(p.ID, track)
+	})
+
+	// Negotiation needed — only when stable
+	pc.OnNegotiationNeeded(func() {
+		p.negotiationMu.Lock()
+		if p.batchingTracks {
+			p.negotiationPending = true
+			p.negotiationMu.Unlock()
+			return
+		}
+		p.negotiationMu.Unlock()
+
+		if pc.SignalingState() == webrtc.SignalingStateStable {
+			p.createAndSendOffer()
+		} else {
+			p.negotiationMu.Lock()
+			p.negotiationPending = true
+			p.negotiationMu.Unlock()
+		}
+	})
+
+	// When state returns to stable, fire pending negotiation
+	pc.OnSignalingStateChange(func(state webrtc.SignalingState) {
+		if state == webrtc.SignalingStateStable {
+			p.negotiationMu.Lock()
+			pending := p.negotiationPending
+			p.negotiationPending = false
+			p.negotiationMu.Unlock()
+
+			if pending {
+				p.createAndSendOffer()
+			}
+		}
+	})
+
+	return nil
+}
+
+func (p *Peer) createAndSendOffer() {
+	if p.pc == nil {
+		return
+	}
+	if p.pc.SignalingState() != webrtc.SignalingStateStable {
+		p.negotiationMu.Lock()
+		p.negotiationPending = true
+		p.negotiationMu.Unlock()
+		return
+	}
+
+	offer, err := p.pc.CreateOffer(nil)
+	if err != nil {
+		p.logger.Error("Failed to create offer",
zap.Error(err)) + return + } + if err := p.pc.SetLocalDescription(offer); err != nil { + p.logger.Error("Failed to set local description", zap.Error(err)) + return + } + p.SendMessage(NewServerMessage(MessageTypeOffer, &OfferData{SDP: offer.SDP})) +} + +// HandleOffer processes an SDP offer from the client +func (p *Peer) HandleOffer(sdp string) error { + if p.pc == nil { + return ErrPeerNotInitialized + } + if err := p.pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, SDP: sdp, + }); err != nil { + return err + } + answer, err := p.pc.CreateAnswer(nil) + if err != nil { + return err + } + if err := p.pc.SetLocalDescription(answer); err != nil { + return err + } + p.SendMessage(NewServerMessage(MessageTypeAnswer, &AnswerData{SDP: answer.SDP})) + return nil +} + +// HandleAnswer processes an SDP answer from the client +func (p *Peer) HandleAnswer(sdp string) error { + if p.pc == nil { + return ErrPeerNotInitialized + } + return p.pc.SetRemoteDescription(webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, SDP: sdp, + }) +} + +// HandleICECandidate adds a remote ICE candidate +func (p *Peer) HandleICECandidate(data *ICECandidateData) error { + if p.pc == nil { + return ErrPeerNotInitialized + } + return p.pc.AddICECandidate(data.ToWebRTCCandidate()) +} + +// AddTrack adds a local track to send to this peer +func (p *Peer) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) { + if p.pc == nil { + return nil, ErrPeerNotInitialized + } + return p.pc.AddTrack(track) +} + +// StartTrackBatch suppresses renegotiation during bulk track additions +func (p *Peer) StartTrackBatch() { + p.negotiationMu.Lock() + p.batchingTracks = true + p.negotiationMu.Unlock() +} + +// EndTrackBatch ends batching and fires deferred renegotiation +func (p *Peer) EndTrackBatch() { + p.negotiationMu.Lock() + p.batchingTracks = false + pending := p.negotiationPending + p.negotiationPending = false + p.negotiationMu.Unlock() + + if pending && p.pc != nil && p.pc.SignalingState() == webrtc.SignalingStateStable { + p.createAndSendOffer() + } +} + +// SendMessage sends a signaling message via WebSocket +func (p *Peer) SendMessage(msg *ServerMessage) error { + p.closedMu.RLock() + if p.closed { + p.closedMu.RUnlock() + return ErrPeerClosed + } + p.closedMu.RUnlock() + + p.connMu.Lock() + defer p.connMu.Unlock() + if p.conn == nil { + return ErrWebSocketClosed + } + + data, err := json.Marshal(msg) + if err != nil { + return err + } + return p.conn.WriteMessage(websocket.TextMessage, data) +} + +// GetInfo returns public info about this peer +func (p *Peer) GetInfo() ParticipantInfo { + return ParticipantInfo{PeerID: p.ID, UserID: p.UserID} +} + +// handleReconnectTimeout waits 15 seconds for ICE reconnection before removing the peer. 
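+// The wait goes through the package-level timeAfter hook (room.go) so tests
+// can substitute a fake clock; if ICE is still disconnected or failed once
+// the window elapses, the peer is torn down via handleDisconnect.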
+func (p *Peer) handleReconnectTimeout() {
+	// Wait out the reconnect window, then re-check the ICE state.
+	<-timeAfter(reconnectTimeout)
+
+	if p.pc == nil {
+		return
+	}
+	state := p.pc.ICEConnectionState()
+	if state == webrtc.ICEConnectionStateDisconnected || state == webrtc.ICEConnectionStateFailed {
+		p.logger.Info("Peer did not reconnect within timeout, removing")
+		p.handleDisconnect()
+	}
+}
+
+func (p *Peer) handleDisconnect() {
+	p.closedMu.Lock()
+	if p.closed {
+		p.closedMu.Unlock()
+		return
+	}
+	p.closed = true
+	p.closedMu.Unlock()
+
+	if p.onClose != nil {
+		p.onClose(p)
+	}
+}
+
+// Close closes the peer connection and WebSocket
+func (p *Peer) Close() error {
+	p.closedMu.Lock()
+	if p.closed {
+		p.closedMu.Unlock()
+		return nil
+	}
+	p.closed = true
+	p.closedMu.Unlock()
+
+	p.connMu.Lock()
+	if p.conn != nil {
+		p.conn.Close()
+		p.conn = nil
+	}
+	p.connMu.Unlock()
+
+	if p.pc != nil {
+		return p.pc.Close()
+	}
+	return nil
+}
+
+// OnClose sets the disconnect callback
+func (p *Peer) OnClose(fn func(*Peer)) {
+	p.onClose = fn
+}
+
+// readRTCP reads RTCP feedback and forwards PLI/FIR to the source peer
+func (p *Peer) readRTCP(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) {
+	// Mirrors the track ID naming used by Room.BroadcastTrack.
+	localTrackID := track.Kind().String() + "-" + p.ID
+
+	for {
+		packets, _, err := receiver.ReadRTCP()
+		if err != nil {
+			return
+		}
+		for _, pkt := range packets {
+			switch pkt.(type) {
+			case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
+				p.room.RequestKeyframe(localTrackID)
+			}
+		}
+	}
+}
diff --git a/pkg/sfu/room.go b/pkg/sfu/room.go
new file mode 100644
index 0000000..adb4477
--- /dev/null
+++ b/pkg/sfu/room.go
@@ -0,0 +1,555 @@
+package sfu
+
+import (
+	"errors"
+	"fmt"
+	"sync"
+	"time"
+
+	"github.com/DeBrosOfficial/network/pkg/turn"
+	"github.com/pion/interceptor"
+	"github.com/pion/interceptor/pkg/intervalpli"
+	"github.com/pion/interceptor/pkg/nack"
+	"github.com/pion/rtcp"
+	"github.com/pion/webrtc/v4"
+	"go.uber.org/zap"
+)
+
+// For testing: allows overriding time.After
+var timeAfter = func(d time.Duration) <-chan time.Time { return time.After(d) }
+
+const (
+	reconnectTimeout = 15 * time.Second
+	emptyRoomTTL     = 60 * time.Second
+	rtpBufferSize    = 8192
+)
+
+var (
+	ErrRoomFull     = errors.New("room is full")
+	ErrRoomClosed   = errors.New("room is closed")
+	ErrPeerNotFound = errors.New("peer not found")
+)
+
+// publishedTrack holds a local track being forwarded from a remote source.
+type publishedTrack struct {
+	sourcePeerID    string
+	localTrack      *webrtc.TrackLocalStaticRTP
+	remoteTrackSSRC uint32
+	kind            string
+}
+
+// Room is a WebRTC room with multiple participants sharing media tracks.
+type Room struct {
+	ID        string
+	Namespace string
+
+	peers   map[string]*Peer
+	peersMu sync.RWMutex
+
+	publishedTracks   map[string]*publishedTrack // key: localTrack.ID()
+	publishedTracksMu sync.RWMutex
+
+	api    *webrtc.API
+	config *Config
+	logger *zap.Logger
+
+	closed   bool
+	closedMu sync.RWMutex
+
+	onEmpty func(*Room)
+}
+
+// RoomManager manages the lifecycle of rooms.
+type RoomManager struct {
+	rooms  map[string]*Room // key: roomID
+	mu     sync.RWMutex
+	config *Config
+	logger *zap.Logger
+}
+
+// NewRoomManager creates a new room manager.
+func NewRoomManager(cfg *Config, logger *zap.Logger) *RoomManager {
+	return &RoomManager{
+		rooms:  make(map[string]*Room),
+		config: cfg,
+		logger: logger.With(zap.String("component", "room-manager")),
+	}
+}
+
+// GetOrCreateRoom returns an existing room or creates a new one.
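+// A previously closed room under the same ID is replaced rather than
+// returned. New rooms install an onEmpty hook that deletes and closes the
+// room if it is still empty emptyRoomTTL after the last peer leaves.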
+func (rm *RoomManager) GetOrCreateRoom(roomID string) *Room { + rm.mu.Lock() + defer rm.mu.Unlock() + + if room, ok := rm.rooms[roomID]; ok && !room.IsClosed() { + return room + } + + api := newWebRTCAPI(rm.config) + room := &Room{ + ID: roomID, + Namespace: rm.config.Namespace, + peers: make(map[string]*Peer), + publishedTracks: make(map[string]*publishedTrack), + api: api, + config: rm.config, + logger: rm.logger.With(zap.String("room_id", roomID)), + } + + room.onEmpty = func(r *Room) { + // Start empty room cleanup timer + go func() { + <-timeAfter(emptyRoomTTL) + if r.GetParticipantCount() == 0 { + rm.mu.Lock() + delete(rm.rooms, r.ID) + rm.mu.Unlock() + r.Close() + rm.logger.Info("Empty room cleaned up", zap.String("room_id", r.ID)) + } + }() + } + + rm.rooms[roomID] = room + rm.logger.Info("Room created", zap.String("room_id", roomID)) + return room +} + +// GetRoom returns a room by ID, or nil if not found. +func (rm *RoomManager) GetRoom(roomID string) *Room { + rm.mu.RLock() + defer rm.mu.RUnlock() + return rm.rooms[roomID] +} + +// CloseAll closes all rooms (for graceful shutdown). +func (rm *RoomManager) CloseAll() { + rm.mu.Lock() + rooms := make([]*Room, 0, len(rm.rooms)) + for _, r := range rm.rooms { + rooms = append(rooms, r) + } + rm.rooms = make(map[string]*Room) + rm.mu.Unlock() + + for _, r := range rooms { + r.Close() + } +} + +// RoomCount returns the number of active rooms. +func (rm *RoomManager) RoomCount() int { + rm.mu.RLock() + defer rm.mu.RUnlock() + return len(rm.rooms) +} + +// newWebRTCAPI creates a Pion WebRTC API with codecs and interceptors. +func newWebRTCAPI(cfg *Config) *webrtc.API { + m := &webrtc.MediaEngine{} + + // Audio: Opus + videoRTCPFeedback := []webrtc.RTCPFeedback{ + {Type: "goog-remb", Parameter: ""}, + {Type: "ccm", Parameter: "fir"}, + {Type: "nack", Parameter: ""}, + {Type: "nack", Parameter: "pli"}, + } + + _ = m.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeOpus, + ClockRate: 48000, + Channels: 2, + SDPFmtpLine: "minptime=10;useinbandfec=1", + }, + PayloadType: 111, + }, webrtc.RTPCodecTypeAudio) + + // Video: VP8 + _ = m.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeVP8, + ClockRate: 90000, + RTCPFeedback: videoRTCPFeedback, + }, + PayloadType: 96, + }, webrtc.RTPCodecTypeVideo) + + // Video: H264 + _ = m.RegisterCodec(webrtc.RTPCodecParameters{ + RTPCodecCapability: webrtc.RTPCodecCapability{ + MimeType: webrtc.MimeTypeH264, + ClockRate: 90000, + SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f", + RTCPFeedback: videoRTCPFeedback, + }, + PayloadType: 125, + }, webrtc.RTPCodecTypeVideo) + + // Interceptors: NACK + PLI + i := &interceptor.Registry{} + if f, err := nack.NewResponderInterceptor(); err == nil { + i.Add(f) + } + if f, err := nack.NewGeneratorInterceptor(); err == nil { + i.Add(f) + } + if f, err := intervalpli.NewReceiverInterceptor(); err == nil { + i.Add(f) + } + + // SettingEngine: restrict media ports + se := webrtc.SettingEngine{} + if cfg.MediaPortStart > 0 && cfg.MediaPortEnd > 0 { + se.SetEphemeralUDPPortRange(uint16(cfg.MediaPortStart), uint16(cfg.MediaPortEnd)) + } + + return webrtc.NewAPI( + webrtc.WithMediaEngine(m), + webrtc.WithInterceptorRegistry(i), + webrtc.WithSettingEngine(se), + ) +} + +// --- Room methods --- + +// AddPeer adds a peer to the room and notifies other participants. 
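+// It rejects joins on closed rooms (ErrRoomClosed), enforces the hard cap of
+// 100 participants (ErrRoomFull), and initializes the peer's connection with
+// TURN-derived ICE servers before announcing the join to the room.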
+func (r *Room) AddPeer(peer *Peer) error { + r.closedMu.RLock() + if r.closed { + r.closedMu.RUnlock() + return ErrRoomClosed + } + r.closedMu.RUnlock() + + // Build ICE servers for TURN + iceServers := r.buildICEServers() + + r.peersMu.Lock() + if len(r.peers) >= 100 { // Hard cap + r.peersMu.Unlock() + return ErrRoomFull + } + + if err := peer.InitPeerConnection(r.api, iceServers); err != nil { + r.peersMu.Unlock() + return err + } + + peer.OnClose(func(p *Peer) { r.RemovePeer(p.ID) }) + + r.peers[peer.ID] = peer + info := peer.GetInfo() + total := len(r.peers) + r.peersMu.Unlock() + + r.logger.Info("Peer joined", zap.String("peer_id", peer.ID), zap.Int("total", total)) + + // Notify others + r.broadcastMessage(peer.ID, NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{ + Participant: info, + })) + + return nil +} + +// RemovePeer removes a peer and cleans up their published tracks. +func (r *Room) RemovePeer(peerID string) { + r.peersMu.Lock() + peer, ok := r.peers[peerID] + if !ok { + r.peersMu.Unlock() + return + } + delete(r.peers, peerID) + remaining := len(r.peers) + r.peersMu.Unlock() + + // Remove published tracks from this peer + r.publishedTracksMu.Lock() + var removed []string + for trackID, pt := range r.publishedTracks { + if pt.sourcePeerID == peerID { + delete(r.publishedTracks, trackID) + removed = append(removed, trackID) + } + } + r.publishedTracksMu.Unlock() + + // Remove RTPSenders for this peer's tracks from all other peers + if len(removed) > 0 { + r.removeTrackSendersFromPeers(removed) + } + + peer.Close() + + r.logger.Info("Peer left", zap.String("peer_id", peerID), zap.Int("remaining", remaining)) + + r.broadcastMessage(peerID, NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{ + PeerID: peerID, + })) + + // Notify about removed tracks + for _, trackID := range removed { + r.broadcastMessage(peerID, NewServerMessage(MessageTypeTrackRemoved, &TrackRemovedData{ + PeerID: peerID, + TrackID: trackID, + })) + } + + if remaining == 0 && r.onEmpty != nil { + r.onEmpty(r) + } +} + +// removeTrackSendersFromPeers removes RTPSenders for the given track IDs from all peers. +// This fixes the ghost track bug from the original implementation. +func (r *Room) removeTrackSendersFromPeers(trackIDs []string) { + trackIDSet := make(map[string]bool, len(trackIDs)) + for _, id := range trackIDs { + trackIDSet[id] = true + } + + r.peersMu.RLock() + defer r.peersMu.RUnlock() + + for _, peer := range r.peers { + if peer.pc == nil { + continue + } + for _, sender := range peer.pc.GetSenders() { + if sender.Track() == nil { + continue + } + if trackIDSet[sender.Track().ID()] { + if err := peer.pc.RemoveTrack(sender); err != nil { + r.logger.Warn("Failed to remove track sender", + zap.String("peer_id", peer.ID), + zap.String("track_id", sender.Track().ID()), + zap.Error(err)) + } + } + } + } +} + +// BroadcastTrack creates a local track from a remote track and forwards it to all other peers. 
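+// The forwarding model is one writer, many readers: a single goroutine copies
+// RTP from the remote track into one TrackLocalStaticRTP, and that same local
+// track is attached to every subscribing peer; late joiners receive it via
+// SendExistingTracksTo.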
+func (r *Room) BroadcastTrack(sourcePeerID string, track *webrtc.TrackRemote) { + codec := track.Codec() + + localTrack, err := webrtc.NewTrackLocalStaticRTP( + codec.RTPCodecCapability, + track.Kind().String()+"-"+sourcePeerID, + sourcePeerID, + ) + if err != nil { + r.logger.Error("Failed to create local track", zap.Error(err)) + return + } + + // Store for future joiners + r.publishedTracksMu.Lock() + r.publishedTracks[localTrack.ID()] = &publishedTrack{ + sourcePeerID: sourcePeerID, + localTrack: localTrack, + remoteTrackSSRC: uint32(track.SSRC()), + kind: track.Kind().String(), + } + r.publishedTracksMu.Unlock() + + // RTP forwarding loop with proper buffer size + go func() { + buf := make([]byte, rtpBufferSize) + for { + n, _, err := track.Read(buf) + if err != nil { + return + } + if _, err := localTrack.Write(buf[:n]); err != nil { + return + } + } + }() + + // Add to all current peers except the source + r.peersMu.RLock() + for peerID, peer := range r.peers { + if peerID == sourcePeerID { + continue + } + if _, err := peer.AddTrack(localTrack); err != nil { + r.logger.Warn("Failed to add track to peer", + zap.String("peer_id", peerID), zap.Error(err)) + continue + } + peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{ + PeerID: sourcePeerID, + TrackID: localTrack.ID(), + StreamID: localTrack.StreamID(), + Kind: track.Kind().String(), + })) + } + r.peersMu.RUnlock() +} + +// SendExistingTracksTo sends all published tracks to a newly joined peer. +// Uses batch mode for a single renegotiation. +func (r *Room) SendExistingTracksTo(peer *Peer) { + r.publishedTracksMu.RLock() + var tracks []*publishedTrack + for _, pt := range r.publishedTracks { + if pt.sourcePeerID != peer.ID { + tracks = append(tracks, pt) + } + } + r.publishedTracksMu.RUnlock() + + if len(tracks) == 0 { + return + } + + peer.StartTrackBatch() + for _, pt := range tracks { + if _, err := peer.AddTrack(pt.localTrack); err != nil { + r.logger.Warn("Failed to add existing track", zap.Error(err)) + continue + } + peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{ + PeerID: pt.sourcePeerID, + TrackID: pt.localTrack.ID(), + StreamID: pt.localTrack.StreamID(), + Kind: pt.kind, + })) + } + peer.EndTrackBatch() + + // Request keyframes for video tracks after negotiation settles + go func() { + <-timeAfter(300 * time.Millisecond) + r.RequestKeyframeForAllVideoTracks() + }() +} + +// RequestKeyframe sends a PLI to the source peer for a video track. +func (r *Room) RequestKeyframe(trackID string) { + r.publishedTracksMu.RLock() + pt, ok := r.publishedTracks[trackID] + r.publishedTracksMu.RUnlock() + if !ok || pt.kind != "video" { + return + } + + r.peersMu.RLock() + source, ok := r.peers[pt.sourcePeerID] + r.peersMu.RUnlock() + if !ok || source.pc == nil { + return + } + + pli := &rtcp.PictureLossIndication{MediaSSRC: pt.remoteTrackSSRC} + if err := source.pc.WriteRTCP([]rtcp.Packet{pli}); err != nil { + r.logger.Debug("Failed to send PLI", zap.String("track_id", trackID), zap.Error(err)) + } +} + +// RequestKeyframeForAllVideoTracks sends PLIs for all video tracks. +func (r *Room) RequestKeyframeForAllVideoTracks() { + r.publishedTracksMu.RLock() + var ids []string + for id, pt := range r.publishedTracks { + if pt.kind == "video" { + ids = append(ids, id) + } + } + r.publishedTracksMu.RUnlock() + + for _, id := range ids { + r.RequestKeyframe(id) + } +} + +// GetParticipants returns info about all participants. 
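
An aside on the keyframe handling above: an SFU cannot synthesize video keyframes, so when a late joiner subscribes mid-stream, the server asks the original publisher's encoder for one by writing a PLI on the publisher's own PeerConnection, addressed to the stored remote SSRC — which is why remoteTrackSSRC is kept on publishedTrack. That is also why the PLIs after SendExistingTracksTo are delayed 300ms: the renegotiated answer needs to land first, or the keyframe arrives before the subscriber can decode it. The packet itself is tiny; a standalone illustration with pion/rtcp:

package main

import (
    "fmt"

    "github.com/pion/rtcp"
)

func main() {
    // A PLI carries only a header, the sender SSRC, and the media SSRC being
    // complained about — the publisher responds with a fresh keyframe.
    pli := &rtcp.PictureLossIndication{MediaSSRC: 0x1234abcd}
    buf, err := pli.Marshal()
    fmt.Println(len(buf), err) // 12 <nil>
}

The room's remaining accessors follow.
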
+func (r *Room) GetParticipants() []ParticipantInfo { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + infos := make([]ParticipantInfo, 0, len(r.peers)) + for _, p := range r.peers { + infos = append(infos, p.GetInfo()) + } + return infos +} + +// GetParticipantCount returns the number of participants. +func (r *Room) GetParticipantCount() int { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + return len(r.peers) +} + +// IsClosed returns whether the room is closed. +func (r *Room) IsClosed() bool { + r.closedMu.RLock() + defer r.closedMu.RUnlock() + return r.closed +} + +// Close closes the room and all peer connections. +func (r *Room) Close() error { + r.closedMu.Lock() + if r.closed { + r.closedMu.Unlock() + return nil + } + r.closed = true + r.closedMu.Unlock() + + r.peersMu.Lock() + peers := make([]*Peer, 0, len(r.peers)) + for _, p := range r.peers { + peers = append(peers, p) + } + r.peers = make(map[string]*Peer) + r.peersMu.Unlock() + + for _, p := range peers { + p.Close() + } + + r.logger.Info("Room closed") + return nil +} + +func (r *Room) broadcastMessage(excludePeerID string, msg *ServerMessage) { + r.peersMu.RLock() + defer r.peersMu.RUnlock() + for id, peer := range r.peers { + if id == excludePeerID { + continue + } + peer.SendMessage(msg) + } +} + +// buildICEServers constructs ICE server config from TURN settings. +func (r *Room) buildICEServers() []webrtc.ICEServer { + if len(r.config.TURNServers) == 0 || r.config.TURNSecret == "" { + return nil + } + + var urls []string + for _, ts := range r.config.TURNServers { + urls = append(urls, fmt.Sprintf("turn:%s:%d?transport=udp", ts.Host, ts.Port)) + } + + ttl := time.Duration(r.config.TURNCredentialTTL) * time.Second + username, password := turn.GenerateCredentials(r.config.TURNSecret, r.config.Namespace, ttl) + + return []webrtc.ICEServer{ + { + URLs: urls, + Username: username, + Credential: password, + }, + } +} diff --git a/pkg/sfu/room_test.go b/pkg/sfu/room_test.go new file mode 100644 index 0000000..62d1e46 --- /dev/null +++ b/pkg/sfu/room_test.go @@ -0,0 +1,368 @@ +package sfu + +import ( + "net/http" + "net/http/httptest" + "testing" + "time" + + "go.uber.org/zap" +) + +func testConfig() *Config { + return &Config{ + ListenAddr: "10.0.0.1:8443", + Namespace: "test-ns", + MediaPortStart: 20000, + MediaPortEnd: 20500, + TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}}, + TURNSecret: "test-secret-key-32bytes-long!!!!", + TURNCredentialTTL: 600, + RQLiteDSN: "http://10.0.0.1:4001", + } +} + +func testLogger() *zap.Logger { + return zap.NewNop() +} + +// --- RoomManager tests --- + +func TestNewRoomManager(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + if rm == nil { + t.Fatal("NewRoomManager returned nil") + } + if rm.RoomCount() != 0 { + t.Errorf("RoomCount = %d, want 0", rm.RoomCount()) + } +} + +func TestRoomManagerGetOrCreateRoom(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + + room1 := rm.GetOrCreateRoom("room-1") + if room1 == nil { + t.Fatal("GetOrCreateRoom returned nil") + } + if room1.ID != "room-1" { + t.Errorf("Room.ID = %q, want %q", room1.ID, "room-1") + } + if room1.Namespace != "test-ns" { + t.Errorf("Room.Namespace = %q, want %q", room1.Namespace, "test-ns") + } + if rm.RoomCount() != 1 { + t.Errorf("RoomCount = %d, want 1", rm.RoomCount()) + } + + // Getting same room returns same instance + room1Again := rm.GetOrCreateRoom("room-1") + if room1 != room1Again { + t.Error("expected same room instance") + } + if rm.RoomCount() != 1 { + 
t.Errorf("RoomCount = %d, want 1 (same room)", rm.RoomCount()) + } + + // Different room creates new instance + room2 := rm.GetOrCreateRoom("room-2") + if room2 == nil { + t.Fatal("second room is nil") + } + if room2.ID != "room-2" { + t.Errorf("Room.ID = %q, want %q", room2.ID, "room-2") + } + if rm.RoomCount() != 2 { + t.Errorf("RoomCount = %d, want 2", rm.RoomCount()) + } +} + +func TestRoomManagerGetRoom(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + + // Non-existent room returns nil + room := rm.GetRoom("nonexistent") + if room != nil { + t.Error("expected nil for non-existent room") + } + + // Create a room and retrieve it + rm.GetOrCreateRoom("room-1") + room = rm.GetRoom("room-1") + if room == nil { + t.Fatal("expected non-nil for existing room") + } + if room.ID != "room-1" { + t.Errorf("Room.ID = %q, want %q", room.ID, "room-1") + } +} + +func TestRoomManagerCloseAll(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + + rm.GetOrCreateRoom("room-1") + rm.GetOrCreateRoom("room-2") + rm.GetOrCreateRoom("room-3") + if rm.RoomCount() != 3 { + t.Fatalf("RoomCount = %d, want 3", rm.RoomCount()) + } + + rm.CloseAll() + if rm.RoomCount() != 0 { + t.Errorf("RoomCount after CloseAll = %d, want 0", rm.RoomCount()) + } +} + +func TestRoomManagerGetOrCreateRoomReplacesClosedRoom(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + + room1 := rm.GetOrCreateRoom("room-1") + room1.Close() + + // Getting the same room ID after close should create a new room + room1New := rm.GetOrCreateRoom("room-1") + if room1New == room1 { + t.Error("expected new room instance after close") + } + if room1New.IsClosed() { + t.Error("new room should not be closed") + } +} + +// --- Room tests --- + +func TestRoomIsClosed(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + room := rm.GetOrCreateRoom("room-1") + + if room.IsClosed() { + t.Error("new room should not be closed") + } + + room.Close() + if !room.IsClosed() { + t.Error("room should be closed after Close()") + } +} + +func TestRoomCloseIdempotent(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + room := rm.GetOrCreateRoom("room-1") + + // Should not panic or error when called multiple times + if err := room.Close(); err != nil { + t.Errorf("first Close() returned error: %v", err) + } + if err := room.Close(); err != nil { + t.Errorf("second Close() returned error: %v", err) + } +} + +func TestRoomGetParticipantsEmpty(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + room := rm.GetOrCreateRoom("room-1") + + participants := room.GetParticipants() + if len(participants) != 0 { + t.Errorf("Participants count = %d, want 0", len(participants)) + } + if room.GetParticipantCount() != 0 { + t.Errorf("ParticipantCount = %d, want 0", room.GetParticipantCount()) + } +} + +func TestRoomBuildICEServers(t *testing.T) { + rm := NewRoomManager(testConfig(), testLogger()) + room := rm.GetOrCreateRoom("room-1") + + servers := room.buildICEServers() + if len(servers) != 1 { + t.Fatalf("ICE servers count = %d, want 1", len(servers)) + } + if len(servers[0].URLs) != 1 { + t.Fatalf("URLs count = %d, want 1", len(servers[0].URLs)) + } + if servers[0].URLs[0] != "turn:1.2.3.4:3478?transport=udp" { + t.Errorf("URL = %q, want %q", servers[0].URLs[0], "turn:1.2.3.4:3478?transport=udp") + } + if servers[0].Username == "" { + t.Error("Username should not be empty") + } + if servers[0].Credential == "" { + t.Error("Credential should not be empty") + } +} + +func 
TestRoomBuildICEServersNoTURN(t *testing.T) { + cfg := testConfig() + cfg.TURNServers = nil + + rm := NewRoomManager(cfg, testLogger()) + room := rm.GetOrCreateRoom("room-1") + + servers := room.buildICEServers() + if servers != nil { + t.Errorf("expected nil ICE servers when no TURN configured, got %v", servers) + } +} + +func TestRoomBuildICEServersNoSecret(t *testing.T) { + cfg := testConfig() + cfg.TURNSecret = "" + + rm := NewRoomManager(cfg, testLogger()) + room := rm.GetOrCreateRoom("room-1") + + servers := room.buildICEServers() + if servers != nil { + t.Errorf("expected nil ICE servers when no secret, got %v", servers) + } +} + +func TestRoomBuildICEServersMultipleTURN(t *testing.T) { + cfg := testConfig() + cfg.TURNServers = []TURNServerConfig{ + {Host: "1.2.3.4", Port: 3478}, + {Host: "5.6.7.8", Port: 443}, + } + + rm := NewRoomManager(cfg, testLogger()) + room := rm.GetOrCreateRoom("room-1") + + servers := room.buildICEServers() + if len(servers) != 1 { + t.Fatalf("ICE servers count = %d, want 1", len(servers)) + } + if len(servers[0].URLs) != 2 { + t.Fatalf("URLs count = %d, want 2", len(servers[0].URLs)) + } +} + +// --- Empty room cleanup test --- + +func TestEmptyRoomCleanup(t *testing.T) { + // Override timeAfter for instant timer + origTimeAfter := timeAfter + timeAfter = func(d time.Duration) <-chan time.Time { + ch := make(chan time.Time, 1) + ch <- time.Now() + return ch + } + defer func() { timeAfter = origTimeAfter }() + + rm := NewRoomManager(testConfig(), testLogger()) + room := rm.GetOrCreateRoom("room-1") + + // Trigger the onEmpty callback (which starts cleanup timer) + room.onEmpty(room) + + // Give the goroutine time to execute + time.Sleep(50 * time.Millisecond) + + if rm.RoomCount() != 0 { + t.Errorf("RoomCount = %d, want 0 (should have been cleaned up)", rm.RoomCount()) + } +} + +// --- Server health tests --- + +func TestHealthEndpointOK(t *testing.T) { + cfg := testConfig() + server, err := NewServer(cfg, testLogger()) + if err != nil { + t.Fatalf("NewServer failed: %v", err) + } + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + server.handleHealth(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + body := w.Body.String() + if body != `{"status":"ok","rooms":0}` { + t.Errorf("body = %q, want %q", body, `{"status":"ok","rooms":0}`) + } +} + +func TestHealthEndpointDraining(t *testing.T) { + cfg := testConfig() + server, err := NewServer(cfg, testLogger()) + if err != nil { + t.Fatalf("NewServer failed: %v", err) + } + + // Set draining + server.drainingMu.Lock() + server.draining = true + server.drainingMu.Unlock() + + req := httptest.NewRequest("GET", "/health", nil) + w := httptest.NewRecorder() + server.handleHealth(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } + body := w.Body.String() + if body != `{"status":"draining","rooms":0}` { + t.Errorf("body = %q, want %q", body, `{"status":"draining","rooms":0}`) + } +} + +func TestServerDrainSetsFlag(t *testing.T) { + // Override timeAfter for instant timer + origTimeAfter := timeAfter + timeAfter = func(d time.Duration) <-chan time.Time { + ch := make(chan time.Time, 1) + ch <- time.Now() + return ch + } + defer func() { timeAfter = origTimeAfter }() + + cfg := testConfig() + server, err := NewServer(cfg, testLogger()) + if err != nil { + t.Fatalf("NewServer failed: %v", err) + } + + server.Drain(0) + + server.drainingMu.RLock() + 
draining := server.draining + server.drainingMu.RUnlock() + + if !draining { + t.Error("expected draining to be true after Drain()") + } +} + +func TestServerNewServerValidation(t *testing.T) { + // Invalid config should return error + cfg := &Config{} // Empty = invalid + _, err := NewServer(cfg, testLogger()) + if err == nil { + t.Error("expected error for invalid config") + } +} + +func TestServerSignalEndpointRejectsDraining(t *testing.T) { + cfg := testConfig() + server, err := NewServer(cfg, testLogger()) + if err != nil { + t.Fatalf("NewServer failed: %v", err) + } + + server.drainingMu.Lock() + server.draining = true + server.drainingMu.Unlock() + + req := httptest.NewRequest("GET", "/ws/signal", nil) + w := httptest.NewRecorder() + server.handleSignal(w, req) + + if w.Code != http.StatusServiceUnavailable { + t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable) + } +} diff --git a/pkg/sfu/server.go b/pkg/sfu/server.go new file mode 100644 index 0000000..49c803f --- /dev/null +++ b/pkg/sfu/server.go @@ -0,0 +1,293 @@ +package sfu + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/DeBrosOfficial/network/pkg/turn" + "github.com/gorilla/websocket" + "go.uber.org/zap" +) + +// Server is the SFU HTTP server providing WebSocket signaling and a health endpoint. +// It binds only to a WireGuard IP — never exposed publicly. +type Server struct { + config *Config + roomManager *RoomManager + logger *zap.Logger + httpServer *http.Server + upgrader websocket.Upgrader + draining bool + drainingMu sync.RWMutex +} + +// NewServer creates a new SFU server. +func NewServer(cfg *Config, logger *zap.Logger) (*Server, error) { + if errs := cfg.Validate(); len(errs) > 0 { + return nil, fmt.Errorf("invalid SFU config: %v", errs[0]) + } + + s := &Server{ + config: cfg, + roomManager: NewRoomManager(cfg, logger), + logger: logger.With(zap.String("component", "sfu"), zap.String("namespace", cfg.Namespace)), + upgrader: websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { return true }, // Gateway handles auth + }, + } + + mux := http.NewServeMux() + mux.HandleFunc("/ws/signal", s.handleSignal) + mux.HandleFunc("/health", s.handleHealth) + + s.httpServer = &http.Server{ + Addr: cfg.ListenAddr, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } + + return s, nil +} + +// ListenAndServe starts the HTTP server. Blocks until the server is stopped. +func (s *Server) ListenAndServe() error { + s.logger.Info("SFU server starting", + zap.String("addr", s.config.ListenAddr), + zap.String("namespace", s.config.Namespace)) + return s.httpServer.ListenAndServe() +} + +// Drain initiates graceful drain: notifies all peers, waits, then closes. +func (s *Server) Drain(timeout time.Duration) { + s.drainingMu.Lock() + s.draining = true + s.drainingMu.Unlock() + + s.logger.Info("SFU draining started", zap.Duration("timeout", timeout)) + + // Notify all peers + s.roomManager.mu.RLock() + for _, room := range s.roomManager.rooms { + room.broadcastMessage("", NewServerMessage(MessageTypeServerDraining, &ServerDrainingData{ + Reason: "server shutting down", + TimeoutMs: int(timeout.Milliseconds()), + })) + } + s.roomManager.mu.RUnlock() + + // Wait for timeout, then force close + <-timeAfter(timeout) +} + +// Close shuts down the SFU server. 
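
Note the division of labor: Drain only flips the health flag, warns connected peers, and waits out the timeout — actually tearing down rooms and the HTTP listener is Close's job, shown next. A minimal sketch of how a caller might wire the two into signal handling (the function name and the 30-second window are illustrative, not taken from this diff):

package main

import (
    "errors"
    "log"
    "net/http"
    "os"
    "os/signal"
    "syscall"
    "time"

    "github.com/DeBrosOfficial/network/pkg/sfu"
)

func run(srv *sfu.Server) error {
    sigCh := make(chan os.Signal, 1)
    signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)

    go func() {
        <-sigCh
        srv.Drain(30 * time.Second) // /health flips to 503; peers get server-draining
        if err := srv.Close(); err != nil {
            log.Printf("shutdown: %v", err)
        }
    }()

    // Shutdown inside Close makes ListenAndServe return http.ErrServerClosed.
    if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
        return err
    }
    return nil
}
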
+func (s *Server) Close() error { + s.logger.Info("SFU server shutting down") + s.roomManager.CloseAll() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return s.httpServer.Shutdown(ctx) +} + +// handleHealth is a simple health check endpoint. +func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) { + s.drainingMu.RLock() + draining := s.draining + s.drainingMu.RUnlock() + + if draining { + w.WriteHeader(http.StatusServiceUnavailable) + fmt.Fprintf(w, `{"status":"draining","rooms":%d}`, s.roomManager.RoomCount()) + return + } + + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"ok","rooms":%d}`, s.roomManager.RoomCount()) +} + +// handleSignal upgrades to WebSocket and runs the signaling loop for one peer. +func (s *Server) handleSignal(w http.ResponseWriter, r *http.Request) { + s.drainingMu.RLock() + if s.draining { + s.drainingMu.RUnlock() + http.Error(w, "server draining", http.StatusServiceUnavailable) + return + } + s.drainingMu.RUnlock() + + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + s.logger.Error("WebSocket upgrade failed", zap.Error(err)) + return + } + + s.logger.Debug("WebSocket connected", zap.String("remote", r.RemoteAddr)) + + // Read the first message — must be a join + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + _, msgBytes, err := conn.ReadMessage() + if err != nil { + s.logger.Warn("Failed to read join message", zap.Error(err)) + conn.Close() + return + } + conn.SetReadDeadline(time.Time{}) // Clear deadline + + var msg ClientMessage + if err := json.Unmarshal(msgBytes, &msg); err != nil { + conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_message", "malformed JSON"))) + conn.Close() + return + } + if msg.Type != MessageTypeJoin { + conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_message", "first message must be join"))) + conn.Close() + return + } + + var joinData JoinData + if err := json.Unmarshal(msg.Data, &joinData); err != nil || joinData.RoomID == "" || joinData.UserID == "" { + conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_join", "roomId and userId required"))) + conn.Close() + return + } + + room := s.roomManager.GetOrCreateRoom(joinData.RoomID) + peer := NewPeer(joinData.UserID, conn, room, s.logger) + + if err := room.AddPeer(peer); err != nil { + conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("join_failed", err.Error()))) + conn.Close() + return + } + + // Send welcome with current participants + peer.SendMessage(NewServerMessage(MessageTypeWelcome, &WelcomeData{ + PeerID: peer.ID, + RoomID: room.ID, + Participants: room.GetParticipants(), + })) + + // Send TURN credentials + if s.config.TURNSecret != "" && len(s.config.TURNServers) > 0 { + s.sendTURNCredentials(peer) + } + + // Send existing tracks from other peers + room.SendExistingTracksTo(peer) + + // Start credential refresh goroutine + if s.config.TURNCredentialTTL > 0 { + go s.credentialRefreshLoop(peer) + } + + // Signaling read loop + s.signalingLoop(peer, room) +} + +// signalingLoop reads signaling messages from the WebSocket until disconnect. 
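
For reference, the client half of the handshake handleSignal enforces: dial, send a join as the very first frame (within the 10-second read deadline), then consume the welcome. A sketch with gorilla/websocket — the address is illustrative, since the SFU binds only to a WireGuard IP and the gateway fronts it:

package main

import (
    "log"

    "github.com/gorilla/websocket"
)

func main() {
    conn, _, err := websocket.DefaultDialer.Dial("ws://10.8.0.2:8443/ws/signal", nil)
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()

    // Anything other than a join as the first message gets the connection dropped.
    join := map[string]any{
        "type": "join",
        "data": map[string]string{"roomId": "room-1", "userId": "user-1"},
    }
    if err := conn.WriteJSON(join); err != nil {
        log.Fatal(err)
    }

    // welcome arrives first, then turn-credentials and any track-added messages.
    var welcome map[string]any
    if err := conn.ReadJSON(&welcome); err != nil {
        log.Fatal(err)
    }
    log.Printf("joined: %v", welcome)
}

The server-side read loop follows.
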
+func (s *Server) signalingLoop(peer *Peer, room *Room) { + defer room.RemovePeer(peer.ID) + + for { + _, msgBytes, err := peer.conn.ReadMessage() + if err != nil { + s.logger.Debug("WebSocket read error", zap.String("peer_id", peer.ID), zap.Error(err)) + return + } + + var msg ClientMessage + if err := json.Unmarshal(msgBytes, &msg); err != nil { + peer.SendMessage(NewErrorMessage("invalid_message", "malformed JSON")) + continue + } + + switch msg.Type { + case MessageTypeOffer: + var data OfferData + if err := json.Unmarshal(msg.Data, &data); err != nil { + peer.SendMessage(NewErrorMessage("invalid_offer", err.Error())) + continue + } + if err := peer.HandleOffer(data.SDP); err != nil { + s.logger.Error("Failed to handle offer", zap.String("peer_id", peer.ID), zap.Error(err)) + peer.SendMessage(NewErrorMessage("offer_failed", err.Error())) + } + + case MessageTypeAnswer: + var data AnswerData + if err := json.Unmarshal(msg.Data, &data); err != nil { + peer.SendMessage(NewErrorMessage("invalid_answer", err.Error())) + continue + } + if err := peer.HandleAnswer(data.SDP); err != nil { + s.logger.Error("Failed to handle answer", zap.String("peer_id", peer.ID), zap.Error(err)) + } + + case MessageTypeICECandidate: + var data ICECandidateData + if err := json.Unmarshal(msg.Data, &data); err != nil { + peer.SendMessage(NewErrorMessage("invalid_candidate", err.Error())) + continue + } + if err := peer.HandleICECandidate(&data); err != nil { + s.logger.Error("Failed to handle ICE candidate", zap.String("peer_id", peer.ID), zap.Error(err)) + } + + case MessageTypeLeave: + s.logger.Info("Peer leaving", zap.String("peer_id", peer.ID)) + return + + default: + peer.SendMessage(NewErrorMessage("unknown_message", fmt.Sprintf("unknown message type: %s", msg.Type))) + } + } +} + +// sendTURNCredentials sends TURN server credentials to a peer. +func (s *Server) sendTURNCredentials(peer *Peer) { + ttl := time.Duration(s.config.TURNCredentialTTL) * time.Second + username, password := turn.GenerateCredentials(s.config.TURNSecret, s.config.Namespace, ttl) + + var uris []string + for _, ts := range s.config.TURNServers { + uris = append(uris, fmt.Sprintf("turn:%s:%d?transport=udp", ts.Host, ts.Port)) + } + + peer.SendMessage(NewServerMessage(MessageTypeTURNCredentials, &TURNCredentialsData{ + Username: username, + Password: password, + TTL: s.config.TURNCredentialTTL, + URIs: uris, + })) +} + +// credentialRefreshLoop sends fresh TURN credentials at 80% of TTL. 
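
The 80% factor below leaves a deliberate margin: with the 600-second TTL used elsewhere in this diff, clients receive fresh credentials every 480 seconds, 120 seconds before the old ones lapse. (The loop sleeps through the same package-level timeAfter seam the room cleanup uses — presumably declared as var timeAfter = time.After — which is what lets the tests fire timers instantly.) The arithmetic, spelled out:

package main

import (
    "fmt"
    "time"
)

func main() {
    const ttlSeconds = 600 // turn_credential_ttl; any positive TTL scales the same way
    refresh := time.Duration(float64(ttlSeconds)*0.8) * time.Second
    fmt.Println(refresh) // 8m0s — credentials are reissued 2m before expiry
}
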
+func (s *Server) credentialRefreshLoop(peer *Peer) { + refreshInterval := time.Duration(float64(s.config.TURNCredentialTTL)*0.8) * time.Second + + for { + <-timeAfter(refreshInterval) + + peer.closedMu.RLock() + closed := peer.closed + peer.closedMu.RUnlock() + if closed { + return + } + + s.sendTURNCredentials(peer) + s.logger.Debug("Refreshed TURN credentials", zap.String("peer_id", peer.ID)) + } +} + +func mustMarshal(v interface{}) []byte { + data, _ := json.Marshal(v) + return data +} diff --git a/pkg/sfu/signaling.go b/pkg/sfu/signaling.go new file mode 100644 index 0000000..a43c852 --- /dev/null +++ b/pkg/sfu/signaling.go @@ -0,0 +1,144 @@ +package sfu + +import ( + "encoding/json" + + "github.com/pion/webrtc/v4" +) + +// MessageType represents the type of signaling message +type MessageType string + +const ( + // Client → Server + MessageTypeJoin MessageType = "join" + MessageTypeLeave MessageType = "leave" + MessageTypeOffer MessageType = "offer" + MessageTypeAnswer MessageType = "answer" + MessageTypeICECandidate MessageType = "ice-candidate" + + // Server → Client + MessageTypeWelcome MessageType = "welcome" + MessageTypeParticipantJoined MessageType = "participant-joined" + MessageTypeParticipantLeft MessageType = "participant-left" + MessageTypeTrackAdded MessageType = "track-added" + MessageTypeTrackRemoved MessageType = "track-removed" + MessageTypeTURNCredentials MessageType = "turn-credentials" + MessageTypeRefreshCredentials MessageType = "refresh-credentials" + MessageTypeServerDraining MessageType = "server-draining" + MessageTypeError MessageType = "error" +) + +// ClientMessage is a message from client to server +type ClientMessage struct { + Type MessageType `json:"type"` + Data json.RawMessage `json:"data,omitempty"` +} + +// ServerMessage is a message from server to client +type ServerMessage struct { + Type MessageType `json:"type"` + Data interface{} `json:"data,omitempty"` +} + +// JoinData is the payload for join messages +type JoinData struct { + RoomID string `json:"roomId"` + UserID string `json:"userId"` +} + +// OfferData is the payload for SDP offer messages +type OfferData struct { + SDP string `json:"sdp"` +} + +// AnswerData is the payload for SDP answer messages +type AnswerData struct { + SDP string `json:"sdp"` +} + +// ICECandidateData is the payload for ICE candidate messages +type ICECandidateData struct { + Candidate string `json:"candidate"` + SDPMid string `json:"sdpMid,omitempty"` + SDPMLineIndex uint16 `json:"sdpMLineIndex,omitempty"` + UsernameFragment string `json:"usernameFragment,omitempty"` +} + +// ToWebRTCCandidate converts to pion ICECandidateInit +func (c *ICECandidateData) ToWebRTCCandidate() webrtc.ICECandidateInit { + return webrtc.ICECandidateInit{ + Candidate: c.Candidate, + SDPMid: &c.SDPMid, + SDPMLineIndex: &c.SDPMLineIndex, + UsernameFragment: &c.UsernameFragment, + } +} + +// WelcomeData is sent when a peer successfully joins a room +type WelcomeData struct { + PeerID string `json:"peerId"` + RoomID string `json:"roomId"` + Participants []ParticipantInfo `json:"participants"` +} + +// ParticipantInfo is public info about a room participant +type ParticipantInfo struct { + PeerID string `json:"peerId"` + UserID string `json:"userId"` +} + +// ParticipantJoinedData is sent when a new participant joins +type ParticipantJoinedData struct { + Participant ParticipantInfo `json:"participant"` +} + +// ParticipantLeftData is sent when a participant leaves +type ParticipantLeftData struct { + PeerID string `json:"peerId"` +} + +// 
TrackAddedData is sent when a new track is available +type TrackAddedData struct { + PeerID string `json:"peerId"` + TrackID string `json:"trackId"` + StreamID string `json:"streamId"` + Kind string `json:"kind"` // "audio" or "video" +} + +// TrackRemovedData is sent when a track is removed +type TrackRemovedData struct { + PeerID string `json:"peerId"` + TrackID string `json:"trackId"` + Kind string `json:"kind"` +} + +// TURNCredentialsData provides TURN server credentials +type TURNCredentialsData struct { + Username string `json:"username"` + Password string `json:"password"` + TTL int `json:"ttl"` + URIs []string `json:"uris"` +} + +// ServerDrainingData warns clients the server is shutting down +type ServerDrainingData struct { + Reason string `json:"reason"` + TimeoutMs int `json:"timeoutMs"` +} + +// ErrorData is sent when an error occurs +type ErrorData struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// NewServerMessage creates a new server message +func NewServerMessage(msgType MessageType, data interface{}) *ServerMessage { + return &ServerMessage{Type: msgType, Data: data} +} + +// NewErrorMessage creates a new error message +func NewErrorMessage(code, message string) *ServerMessage { + return NewServerMessage(MessageTypeError, &ErrorData{Code: code, Message: message}) +} diff --git a/pkg/sfu/signaling_test.go b/pkg/sfu/signaling_test.go new file mode 100644 index 0000000..cecbdd8 --- /dev/null +++ b/pkg/sfu/signaling_test.go @@ -0,0 +1,257 @@ +package sfu + +import ( + "encoding/json" + "testing" +) + +func TestClientMessageDeserialization(t *testing.T) { + tests := []struct { + name string + input string + wantType MessageType + wantData bool + }{ + { + name: "join message", + input: `{"type":"join","data":{"roomId":"room-1","userId":"user-1"}}`, + wantType: MessageTypeJoin, + wantData: true, + }, + { + name: "leave message", + input: `{"type":"leave"}`, + wantType: MessageTypeLeave, + wantData: false, + }, + { + name: "offer message", + input: `{"type":"offer","data":{"sdp":"v=0..."}}`, + wantType: MessageTypeOffer, + wantData: true, + }, + { + name: "answer message", + input: `{"type":"answer","data":{"sdp":"v=0..."}}`, + wantType: MessageTypeAnswer, + wantData: true, + }, + { + name: "ice-candidate message", + input: `{"type":"ice-candidate","data":{"candidate":"candidate:1234"}}`, + wantType: MessageTypeICECandidate, + wantData: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var msg ClientMessage + if err := json.Unmarshal([]byte(tt.input), &msg); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if msg.Type != tt.wantType { + t.Errorf("Type = %q, want %q", msg.Type, tt.wantType) + } + if tt.wantData && msg.Data == nil { + t.Error("expected Data to be non-nil") + } + if !tt.wantData && msg.Data != nil { + t.Error("expected Data to be nil") + } + }) + } +} + +func TestJoinDataDeserialization(t *testing.T) { + input := `{"roomId":"room-abc","userId":"user-xyz"}` + var data JoinData + if err := json.Unmarshal([]byte(input), &data); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if data.RoomID != "room-abc" { + t.Errorf("RoomID = %q, want %q", data.RoomID, "room-abc") + } + if data.UserID != "user-xyz" { + t.Errorf("UserID = %q, want %q", data.UserID, "user-xyz") + } +} + +func TestServerMessageSerialization(t *testing.T) { + tests := []struct { + name string + msg *ServerMessage + wantKey string + }{ + { + name: "welcome message", + msg: NewServerMessage(MessageTypeWelcome, 
&WelcomeData{PeerID: "p1", RoomID: "r1"}), + wantKey: "welcome", + }, + { + name: "participant joined", + msg: NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{Participant: ParticipantInfo{PeerID: "p2", UserID: "u2"}}), + wantKey: "participant-joined", + }, + { + name: "participant left", + msg: NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{PeerID: "p2"}), + wantKey: "participant-left", + }, + { + name: "track added", + msg: NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{PeerID: "p1", TrackID: "t1", StreamID: "s1", Kind: "video"}), + wantKey: "track-added", + }, + { + name: "track removed", + msg: NewServerMessage(MessageTypeTrackRemoved, &TrackRemovedData{PeerID: "p1", TrackID: "t1", Kind: "video"}), + wantKey: "track-removed", + }, + { + name: "TURN credentials", + msg: NewServerMessage(MessageTypeTURNCredentials, &TURNCredentialsData{Username: "u", Password: "p", TTL: 600, URIs: []string{"turn:1.2.3.4:3478"}}), + wantKey: "turn-credentials", + }, + { + name: "server draining", + msg: NewServerMessage(MessageTypeServerDraining, &ServerDrainingData{Reason: "shutdown", TimeoutMs: 30000}), + wantKey: "server-draining", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + // Verify it roundtrips correctly + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("failed to unmarshal to raw: %v", err) + } + + var msgType string + if err := json.Unmarshal(raw["type"], &msgType); err != nil { + t.Fatalf("failed to unmarshal type: %v", err) + } + if msgType != tt.wantKey { + t.Errorf("type = %q, want %q", msgType, tt.wantKey) + } + if _, ok := raw["data"]; !ok { + t.Error("expected data field in output") + } + }) + } +} + +func TestNewErrorMessage(t *testing.T) { + msg := NewErrorMessage("invalid_offer", "bad SDP") + if msg.Type != MessageTypeError { + t.Errorf("Type = %q, want %q", msg.Type, MessageTypeError) + } + + errData, ok := msg.Data.(*ErrorData) + if !ok { + t.Fatal("Data is not *ErrorData") + } + if errData.Code != "invalid_offer" { + t.Errorf("Code = %q, want %q", errData.Code, "invalid_offer") + } + if errData.Message != "bad SDP" { + t.Errorf("Message = %q, want %q", errData.Message, "bad SDP") + } + + // Verify serialization + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + result := string(data) + if result == "" { + t.Error("expected non-empty serialized output") + } +} + +func TestICECandidateDataToWebRTCCandidate(t *testing.T) { + data := &ICECandidateData{ + Candidate: "candidate:842163049 1 udp 1677729535 203.0.113.1 3478 typ srflx", + SDPMid: "0", + SDPMLineIndex: 0, + UsernameFragment: "abc123", + } + + candidate := data.ToWebRTCCandidate() + if candidate.Candidate != data.Candidate { + t.Errorf("Candidate = %q, want %q", candidate.Candidate, data.Candidate) + } + if candidate.SDPMid == nil || *candidate.SDPMid != "0" { + t.Error("SDPMid should be pointer to '0'") + } + if candidate.SDPMLineIndex == nil || *candidate.SDPMLineIndex != 0 { + t.Error("SDPMLineIndex should be pointer to 0") + } + if candidate.UsernameFragment == nil || *candidate.UsernameFragment != "abc123" { + t.Error("UsernameFragment should be pointer to 'abc123'") + } +} + +func TestWelcomeDataSerialization(t *testing.T) { + welcome := &WelcomeData{ + PeerID: "peer-123", + RoomID: "room-456", + Participants: []ParticipantInfo{ + {PeerID: 
"peer-001", UserID: "user-001"}, + {PeerID: "peer-002", UserID: "user-002"}, + }, + } + + data, err := json.Marshal(welcome) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var result WelcomeData + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if result.PeerID != "peer-123" { + t.Errorf("PeerID = %q, want %q", result.PeerID, "peer-123") + } + if result.RoomID != "room-456" { + t.Errorf("RoomID = %q, want %q", result.RoomID, "room-456") + } + if len(result.Participants) != 2 { + t.Errorf("Participants count = %d, want 2", len(result.Participants)) + } +} + +func TestTURNCredentialsDataSerialization(t *testing.T) { + creds := &TURNCredentialsData{ + Username: "1234567890:test-ns", + Password: "base64password==", + TTL: 600, + URIs: []string{"turn:1.2.3.4:3478?transport=udp", "turn:5.6.7.8:443?transport=udp"}, + } + + data, err := json.Marshal(creds) + if err != nil { + t.Fatalf("failed to marshal: %v", err) + } + + var result TURNCredentialsData + if err := json.Unmarshal(data, &result); err != nil { + t.Fatalf("failed to unmarshal: %v", err) + } + if result.Username != creds.Username { + t.Errorf("Username = %q, want %q", result.Username, creds.Username) + } + if result.TTL != 600 { + t.Errorf("TTL = %d, want 600", result.TTL) + } + if len(result.URIs) != 2 { + t.Errorf("URIs count = %d, want 2", len(result.URIs)) + } +} diff --git a/pkg/systemd/manager.go b/pkg/systemd/manager.go index 384eb7c..4648d99 100644 --- a/pkg/systemd/manager.go +++ b/pkg/systemd/manager.go @@ -17,6 +17,8 @@ const ( ServiceTypeRQLite ServiceType = "rqlite" ServiceTypeOlric ServiceType = "olric" ServiceTypeGateway ServiceType = "gateway" + ServiceTypeSFU ServiceType = "sfu" + ServiceTypeTURN ServiceType = "turn" ) // Manager manages systemd units for namespace services @@ -192,13 +194,33 @@ func (m *Manager) ReloadDaemon() error { return nil } +// serviceExists checks if a namespace service has an env file on disk, +// indicating the service was provisioned for this namespace. 
+func (m *Manager) serviceExists(namespace string, serviceType ServiceType) bool { + envFile := filepath.Join(m.namespaceBase, namespace, fmt.Sprintf("%s.env", serviceType)) + _, err := os.Stat(envFile) + return err == nil +} + // StopAllNamespaceServices stops all namespace services for a given namespace func (m *Manager) StopAllNamespaceServices(namespace string) error { m.logger.Info("Stopping all namespace services", zap.String("namespace", namespace)) - // Stop in reverse dependency order: Gateway → Olric → RQLite - services := []ServiceType{ServiceTypeGateway, ServiceTypeOlric, ServiceTypeRQLite} - for _, svcType := range services { + // Stop in reverse dependency order: SFU → TURN → Gateway → Olric → RQLite + // SFU and TURN are conditional — only stop if they exist + for _, svcType := range []ServiceType{ServiceTypeSFU, ServiceTypeTURN} { + if m.serviceExists(namespace, svcType) { + if err := m.StopService(namespace, svcType); err != nil { + m.logger.Warn("Failed to stop service", + zap.String("namespace", namespace), + zap.String("service_type", string(svcType)), + zap.Error(err)) + } + } + } + + // Core services always exist + for _, svcType := range []ServiceType{ServiceTypeGateway, ServiceTypeOlric, ServiceTypeRQLite} { if err := m.StopService(namespace, svcType); err != nil { m.logger.Warn("Failed to stop service", zap.String("namespace", namespace), @@ -215,14 +237,22 @@ func (m *Manager) StopAllNamespaceServices(namespace string) error { func (m *Manager) StartAllNamespaceServices(namespace string) error { m.logger.Info("Starting all namespace services", zap.String("namespace", namespace)) - // Start in dependency order: RQLite → Olric → Gateway - services := []ServiceType{ServiceTypeRQLite, ServiceTypeOlric, ServiceTypeGateway} - for _, svcType := range services { + // Start core services in dependency order: RQLite → Olric → Gateway + for _, svcType := range []ServiceType{ServiceTypeRQLite, ServiceTypeOlric, ServiceTypeGateway} { if err := m.StartService(namespace, svcType); err != nil { return fmt.Errorf("failed to start %s service: %w", svcType, err) } } + // Start WebRTC services if provisioned: TURN → SFU + for _, svcType := range []ServiceType{ServiceTypeTURN, ServiceTypeSFU} { + if m.serviceExists(namespace, svcType) { + if err := m.StartService(namespace, svcType); err != nil { + return fmt.Errorf("failed to start %s service: %w", svcType, err) + } + } + } + return nil } @@ -419,6 +449,8 @@ func (m *Manager) InstallTemplateUnits(sourceDir string) error { "orama-namespace-rqlite@.service", "orama-namespace-olric@.service", "orama-namespace-gateway@.service", + "orama-namespace-sfu@.service", + "orama-namespace-turn@.service", } for _, template := range templates { diff --git a/pkg/turn/config.go b/pkg/turn/config.go new file mode 100644 index 0000000..61bd809 --- /dev/null +++ b/pkg/turn/config.go @@ -0,0 +1,71 @@ +package turn + +import ( + "fmt" + "net" +) + +// Config holds configuration for the TURN server +type Config struct { + // ListenAddr is the address to bind the TURN listener (e.g., "0.0.0.0:3478") + ListenAddr string `yaml:"listen_addr"` + + // TLSListenAddr is the address for TURN over TLS/DTLS (e.g., "0.0.0.0:443") + // Uses UDP 443 — requires Caddy HTTP/3 (QUIC) to be disabled to avoid port conflict + TLSListenAddr string `yaml:"tls_listen_addr"` + + // PublicIP is the public IP address of this node, advertised in TURN allocations + PublicIP string `yaml:"public_ip"` + + // Realm is the TURN realm (typically the base domain) + Realm string 
`yaml:"realm"` + + // AuthSecret is the HMAC-SHA1 shared secret for credential validation + AuthSecret string `yaml:"auth_secret"` + + // RelayPortStart is the beginning of the UDP relay port range + RelayPortStart int `yaml:"relay_port_start"` + + // RelayPortEnd is the end of the UDP relay port range + RelayPortEnd int `yaml:"relay_port_end"` + + // Namespace this TURN instance belongs to + Namespace string `yaml:"namespace"` +} + +// Validate checks the TURN configuration for errors +func (c *Config) Validate() []error { + var errs []error + + if c.ListenAddr == "" { + errs = append(errs, fmt.Errorf("turn.listen_addr: must not be empty")) + } + + if c.PublicIP == "" { + errs = append(errs, fmt.Errorf("turn.public_ip: must not be empty")) + } else if ip := net.ParseIP(c.PublicIP); ip == nil { + errs = append(errs, fmt.Errorf("turn.public_ip: %q is not a valid IP address", c.PublicIP)) + } + + if c.Realm == "" { + errs = append(errs, fmt.Errorf("turn.realm: must not be empty")) + } + + if c.AuthSecret == "" { + errs = append(errs, fmt.Errorf("turn.auth_secret: must not be empty")) + } + + if c.RelayPortStart <= 0 || c.RelayPortEnd <= 0 { + errs = append(errs, fmt.Errorf("turn.relay_port_range: start and end must be positive")) + } else if c.RelayPortEnd <= c.RelayPortStart { + errs = append(errs, fmt.Errorf("turn.relay_port_range: end (%d) must be greater than start (%d)", c.RelayPortEnd, c.RelayPortStart)) + } else if c.RelayPortEnd-c.RelayPortStart < 100 { + errs = append(errs, fmt.Errorf("turn.relay_port_range: range must be at least 100 ports (got %d)", c.RelayPortEnd-c.RelayPortStart)) + } + + if c.Namespace == "" { + errs = append(errs, fmt.Errorf("turn.namespace: must not be empty")) + } + + return errs +} diff --git a/pkg/turn/server.go b/pkg/turn/server.go new file mode 100644 index 0000000..077e5da --- /dev/null +++ b/pkg/turn/server.go @@ -0,0 +1,228 @@ +package turn + +import ( + "crypto/hmac" + "crypto/sha1" + "encoding/base64" + "fmt" + "net" + "strconv" + "strings" + "time" + + pionTurn "github.com/pion/turn/v4" + "go.uber.org/zap" +) + +// Server wraps a Pion TURN server with namespace-scoped HMAC-SHA1 authentication. +type Server struct { + config *Config + logger *zap.Logger + turnServer *pionTurn.Server + conn net.PacketConn // UDP listener on primary port (3478) + tlsConn net.PacketConn // UDP listener on TLS port (443) +} + +// NewServer creates and starts a TURN server. 
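
A filled-in example of a config that passes the validation above — every value here is illustrative. The relay range must span at least 100 ports, and tls_listen_addr (plain UDP on 443, despite the name) only works after the Caddy HTTP/3 patch later in this diff has freed the port:

package main

import (
    "fmt"

    "github.com/DeBrosOfficial/network/pkg/turn"
)

func main() {
    cfg := &turn.Config{
        ListenAddr:     "0.0.0.0:3478",
        TLSListenAddr:  "0.0.0.0:443", // optional; requires UDP 443 to be free
        PublicIP:       "203.0.113.10",
        Realm:          "example.com",
        AuthSecret:     "replace-with-shared-secret",
        RelayPortStart: 49152,
        RelayPortEnd:   65535, // 16383 ports, well over the 100-port minimum
        Namespace:      "acme",
    }
    if errs := cfg.Validate(); len(errs) > 0 {
        for _, err := range errs {
            fmt.Println(err)
        }
    }
}

NewServer, below, runs the same validation before binding its sockets.
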
+func NewServer(cfg *Config, logger *zap.Logger) (*Server, error) { + if errs := cfg.Validate(); len(errs) > 0 { + return nil, fmt.Errorf("invalid TURN config: %v", errs[0]) + } + + relayIP := net.ParseIP(cfg.PublicIP) + if relayIP == nil { + return nil, fmt.Errorf("turn.public_ip: %q is not a valid IP address", cfg.PublicIP) + } + + s := &Server{ + config: cfg, + logger: logger.With(zap.String("component", "turn"), zap.String("namespace", cfg.Namespace)), + } + + // Create primary UDP listener (port 3478) + conn, err := net.ListenPacket("udp4", cfg.ListenAddr) + if err != nil { + return nil, fmt.Errorf("failed to listen on %s: %w", cfg.ListenAddr, err) + } + s.conn = conn + + packetConfigs := []pionTurn.PacketConnConfig{ + { + PacketConn: conn, + RelayAddressGenerator: &pionTurn.RelayAddressGeneratorPortRange{ + RelayAddress: relayIP, + Address: "0.0.0.0", + MinPort: uint16(cfg.RelayPortStart), + MaxPort: uint16(cfg.RelayPortEnd), + }, + }, + } + + // Create TLS UDP listener (port 443) if configured + // Requires Caddy HTTP/3 (QUIC) to be disabled to avoid UDP 443 conflict + if cfg.TLSListenAddr != "" { + tlsConn, err := net.ListenPacket("udp4", cfg.TLSListenAddr) + if err != nil { + conn.Close() + return nil, fmt.Errorf("failed to listen on %s: %w", cfg.TLSListenAddr, err) + } + s.tlsConn = tlsConn + + packetConfigs = append(packetConfigs, pionTurn.PacketConnConfig{ + PacketConn: tlsConn, + RelayAddressGenerator: &pionTurn.RelayAddressGeneratorPortRange{ + RelayAddress: relayIP, + Address: "0.0.0.0", + MinPort: uint16(cfg.RelayPortStart), + MaxPort: uint16(cfg.RelayPortEnd), + }, + }) + } + + // Create TURN server with HMAC-SHA1 auth + turnServer, err := pionTurn.NewServer(pionTurn.ServerConfig{ + Realm: cfg.Realm, + AuthHandler: func(username, realm string, srcAddr net.Addr) ([]byte, bool) { + return s.authHandler(username, realm, srcAddr) + }, + PacketConnConfigs: packetConfigs, + }) + if err != nil { + s.closeListeners() + return nil, fmt.Errorf("failed to create TURN server: %w", err) + } + s.turnServer = turnServer + + s.logger.Info("TURN server started", + zap.String("listen_addr", cfg.ListenAddr), + zap.String("tls_listen_addr", cfg.TLSListenAddr), + zap.String("public_ip", cfg.PublicIP), + zap.String("realm", cfg.Realm), + zap.Int("relay_port_start", cfg.RelayPortStart), + zap.Int("relay_port_end", cfg.RelayPortEnd), + ) + + return s, nil +} + +// authHandler validates HMAC-SHA1 credentials. 
+// Username format: {expiry_unix}:{namespace} +// Password: base64(HMAC-SHA1(shared_secret, username)) +func (s *Server) authHandler(username, realm string, srcAddr net.Addr) ([]byte, bool) { + // Parse username: must be "{timestamp}:{namespace}" + parts := strings.SplitN(username, ":", 2) + if len(parts) != 2 { + s.logger.Debug("Malformed TURN username: expected timestamp:namespace", + zap.String("username", username), + zap.String("src_addr", srcAddr.String())) + return nil, false + } + + timestamp, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + s.logger.Debug("Invalid timestamp in TURN username", + zap.String("username", username), + zap.String("src_addr", srcAddr.String())) + return nil, false + } + + ns := parts[1] + + // Verify namespace matches this TURN server's namespace + if ns != s.config.Namespace { + s.logger.Debug("TURN credential namespace mismatch", + zap.String("credential_namespace", ns), + zap.String("server_namespace", s.config.Namespace), + zap.String("src_addr", srcAddr.String())) + return nil, false + } + + // Check expiry — credential must not be expired + if timestamp <= time.Now().Unix() { + s.logger.Debug("TURN credential expired", + zap.String("username", username), + zap.Int64("expired_at", timestamp), + zap.String("src_addr", srcAddr.String())) + return nil, false + } + + // Generate expected password and derive auth key + password := GeneratePassword(s.config.AuthSecret, username) + key := pionTurn.GenerateAuthKey(username, realm, password) + + s.logger.Debug("TURN auth accepted", + zap.String("namespace", ns), + zap.String("src_addr", srcAddr.String())) + + return key, true +} + +// Close gracefully shuts down the TURN server. +func (s *Server) Close() error { + s.logger.Info("Stopping TURN server") + + if s.turnServer != nil { + if err := s.turnServer.Close(); err != nil { + s.logger.Warn("Error closing TURN server", zap.Error(err)) + } + } + + s.closeListeners() + + s.logger.Info("TURN server stopped") + return nil +} + +func (s *Server) closeListeners() { + if s.conn != nil { + s.conn.Close() + s.conn = nil + } + if s.tlsConn != nil { + s.tlsConn.Close() + s.tlsConn = nil + } +} + +// GenerateCredentials creates time-limited HMAC-SHA1 TURN credentials. +// Returns username and password suitable for WebRTC ICE server configuration. +func GenerateCredentials(secret, namespace string, ttl time.Duration) (username, password string) { + expiry := time.Now().Add(ttl).Unix() + username = fmt.Sprintf("%d:%s", expiry, namespace) + password = GeneratePassword(secret, username) + return username, password +} + +// GeneratePassword computes the HMAC-SHA1 password for a TURN username. +func GeneratePassword(secret, username string) string { + h := hmac.New(sha1.New, []byte(secret)) + h.Write([]byte(username)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} + +// ValidateCredentials checks if TURN credentials are valid and not expired. 
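
This is the same ephemeral-credential convention coturn implements with static-auth-secret (the "TURN REST API" scheme): the username carries the expiry and scope, and the password is an HMAC over the username, so nothing is stored server-side and a leaked credential dies on schedule. A round trip with the helpers above:

package main

import (
    "fmt"
    "time"

    "github.com/DeBrosOfficial/network/pkg/turn"
)

func main() {
    secret := "shared-secret" // must equal the TURN server's auth_secret
    user, pass := turn.GenerateCredentials(secret, "acme", 10*time.Minute)
    fmt.Println(user) // e.g. "1767000000:acme" — unix expiry, then namespace

    fmt.Println(turn.ValidateCredentials(secret, user, pass, "acme"))  // true
    fmt.Println(turn.ValidateCredentials(secret, user, pass, "other")) // false — namespace mismatch
}

ValidateCredentials itself follows.
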
+func ValidateCredentials(secret, username, password, expectedNamespace string) bool { + parts := strings.SplitN(username, ":", 2) + if len(parts) != 2 { + return false + } + + timestamp, err := strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return false + } + + // Check namespace + if parts[1] != expectedNamespace { + return false + } + + // Check expiry + if timestamp <= time.Now().Unix() { + return false + } + + // Check password + expected := GeneratePassword(secret, username) + return hmac.Equal([]byte(password), []byte(expected)) +} diff --git a/pkg/turn/server_test.go b/pkg/turn/server_test.go new file mode 100644 index 0000000..dc7c235 --- /dev/null +++ b/pkg/turn/server_test.go @@ -0,0 +1,225 @@ +package turn + +import ( + "fmt" + "testing" + "time" +) + +func TestGenerateCredentials(t *testing.T) { + secret := "test-secret-key-32bytes-long!!!!" + namespace := "test-namespace" + ttl := 10 * time.Minute + + username, password := GenerateCredentials(secret, namespace, ttl) + + if username == "" { + t.Fatal("username should not be empty") + } + if password == "" { + t.Fatal("password should not be empty") + } + + // Username should be "{timestamp}:{namespace}" + var ts int64 + var ns string + n, err := fmt.Sscanf(username, "%d:%s", &ts, &ns) + if err != nil || n != 2 { + t.Fatalf("username format should be timestamp:namespace, got %q", username) + } + + if ns != namespace { + t.Fatalf("namespace in username should be %q, got %q", namespace, ns) + } + + // Timestamp should be ~10 minutes in the future + now := time.Now().Unix() + expectedExpiry := now + int64(ttl.Seconds()) + if ts < expectedExpiry-2 || ts > expectedExpiry+2 { + t.Fatalf("expiry timestamp should be ~%d, got %d", expectedExpiry, ts) + } +} + +func TestGeneratePassword(t *testing.T) { + secret := "test-secret" + username := "1234567890:test-ns" + + password1 := GeneratePassword(secret, username) + password2 := GeneratePassword(secret, username) + + // Same inputs should produce same output + if password1 != password2 { + t.Fatal("GeneratePassword should be deterministic") + } + + // Different secret should produce different output + password3 := GeneratePassword("different-secret", username) + if password1 == password3 { + t.Fatal("different secrets should produce different passwords") + } + + // Different username should produce different output + password4 := GeneratePassword(secret, "9999999999:other-ns") + if password1 == password4 { + t.Fatal("different usernames should produce different passwords") + } +} + +func TestValidateCredentials(t *testing.T) { + secret := "test-secret-key" + namespace := "my-namespace" + ttl := 10 * time.Minute + + // Generate valid credentials + username, password := GenerateCredentials(secret, namespace, ttl) + + tests := []struct { + name string + secret string + username string + password string + namespace string + wantValid bool + }{ + { + name: "valid credentials", + secret: secret, + username: username, + password: password, + namespace: namespace, + wantValid: true, + }, + { + name: "wrong secret", + secret: "wrong-secret", + username: username, + password: password, + namespace: namespace, + wantValid: false, + }, + { + name: "wrong password", + secret: secret, + username: username, + password: "wrongpassword", + namespace: namespace, + wantValid: false, + }, + { + name: "wrong namespace", + secret: secret, + username: username, + password: password, + namespace: "other-namespace", + wantValid: false, + }, + { + name: "expired credentials", + secret: secret, + username: 
fmt.Sprintf("%d:%s", time.Now().Unix()-60, namespace), + password: GeneratePassword(secret, fmt.Sprintf("%d:%s", time.Now().Unix()-60, namespace)), + namespace: namespace, + wantValid: false, + }, + { + name: "malformed username - no colon", + secret: secret, + username: "badusername", + password: "whatever", + namespace: namespace, + wantValid: false, + }, + { + name: "malformed username - non-numeric timestamp", + secret: secret, + username: "notanumber:my-namespace", + password: "whatever", + namespace: namespace, + wantValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ValidateCredentials(tt.secret, tt.username, tt.password, tt.namespace) + if got != tt.wantValid { + t.Errorf("ValidateCredentials() = %v, want %v", got, tt.wantValid) + } + }) + } +} + +func TestConfigValidation(t *testing.T) { + tests := []struct { + name string + config Config + wantErrs int + }{ + { + name: "valid config", + config: Config{ + ListenAddr: "0.0.0.0:3478", + PublicIP: "1.2.3.4", + Realm: "dbrs.space", + AuthSecret: "secret123", + RelayPortStart: 49152, + RelayPortEnd: 50000, + Namespace: "test-ns", + }, + wantErrs: 0, + }, + { + name: "missing all fields", + config: Config{}, + wantErrs: 6, // listen_addr, public_ip, realm, auth_secret, relay_port_range, namespace + }, + { + name: "invalid public IP", + config: Config{ + ListenAddr: "0.0.0.0:3478", + PublicIP: "not-an-ip", + Realm: "dbrs.space", + AuthSecret: "secret", + RelayPortStart: 49152, + RelayPortEnd: 50000, + Namespace: "test-ns", + }, + wantErrs: 1, + }, + { + name: "relay range too small", + config: Config{ + ListenAddr: "0.0.0.0:3478", + PublicIP: "1.2.3.4", + Realm: "dbrs.space", + AuthSecret: "secret", + RelayPortStart: 49152, + RelayPortEnd: 49200, + Namespace: "test-ns", + }, + wantErrs: 1, + }, + { + name: "relay range inverted", + config: Config{ + ListenAddr: "0.0.0.0:3478", + PublicIP: "1.2.3.4", + Realm: "dbrs.space", + AuthSecret: "secret", + RelayPortStart: 50000, + RelayPortEnd: 49152, + Namespace: "test-ns", + }, + wantErrs: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + errs := tt.config.Validate() + if len(errs) != tt.wantErrs { + t.Errorf("Validate() returned %d errors, want %d: %v", len(errs), tt.wantErrs, errs) + } + }) + } +} diff --git a/scripts/patches/disable-caddy-http3.sh b/scripts/patches/disable-caddy-http3.sh new file mode 100755 index 0000000..12cc308 --- /dev/null +++ b/scripts/patches/disable-caddy-http3.sh @@ -0,0 +1,49 @@ +#!/usr/bin/env bash +# Patch: Disable HTTP/3 (QUIC) in Caddy to free UDP 443 for TURN server. +# Run on each VPS node. Safe to run multiple times (idempotent). +# +# Usage: sudo bash disable-caddy-http3.sh +set -euo pipefail + +CADDYFILE="/etc/caddy/Caddyfile" + +if [ ! -f "$CADDYFILE" ]; then + echo "ERROR: $CADDYFILE not found" + exit 1 +fi + +# Check if already patched +if grep -q 'protocols h1 h2' "$CADDYFILE"; then + echo "Already patched — Caddyfile already has 'protocols h1 h2'" +else + # The global block looks like: + # { + # email admin@... + # } + # + # Insert 'servers { protocols h1 h2 }' after the email line. + sed -i '/^ email /a\ + servers {\ + protocols h1 h2\ + }' "$CADDYFILE" + echo "Patched Caddyfile — added 'servers { protocols h1 h2 }'" +fi + +# Validate the new config before reloading +if ! caddy validate --config "$CADDYFILE" --adapter caddyfile 2>/dev/null; then + echo "ERROR: Caddyfile validation failed! Reverting..." 
+ sed -i '/^ servers {$/,/^ }$/d' "$CADDYFILE" + exit 1 +fi + +# Reload Caddy (graceful, no downtime) +systemctl reload caddy +echo "Caddy reloaded successfully" + +# Verify UDP 443 is no longer bound by Caddy +sleep 1 +if ss -ulnp | grep -q ':443.*caddy'; then + echo "WARNING: Caddy still binding UDP 443 — reload may need more time" +else + echo "Confirmed: UDP 443 is free for TURN" +fi diff --git a/systemd/orama-namespace-sfu@.service b/systemd/orama-namespace-sfu@.service new file mode 100644 index 0000000..8601626 --- /dev/null +++ b/systemd/orama-namespace-sfu@.service @@ -0,0 +1,32 @@ +[Unit] +Description=Orama Namespace SFU (%i) +Documentation=https://github.com/DeBrosOfficial/network +After=network.target orama-namespace-olric@%i.service +Wants=orama-namespace-olric@%i.service +PartOf=orama-node.service + +[Service] +Type=simple +WorkingDirectory=/opt/orama + +EnvironmentFile=/opt/orama/.orama/data/namespaces/%i/sfu.env + +ExecStart=/bin/sh -c 'exec /opt/orama/bin/sfu --config ${SFU_CONFIG}' + +TimeoutStopSec=45s +KillMode=mixed +KillSignal=SIGTERM + +Restart=on-failure +RestartSec=5s + +StandardOutput=journal +StandardError=journal +SyslogIdentifier=orama-sfu-%i + +PrivateTmp=yes +LimitNOFILE=65536 +MemoryMax=2G + +[Install] +WantedBy=multi-user.target diff --git a/systemd/orama-namespace-turn@.service b/systemd/orama-namespace-turn@.service new file mode 100644 index 0000000..ef337a7 --- /dev/null +++ b/systemd/orama-namespace-turn@.service @@ -0,0 +1,31 @@ +[Unit] +Description=Orama Namespace TURN (%i) +Documentation=https://github.com/DeBrosOfficial/network +After=network.target +PartOf=orama-node.service + +[Service] +Type=simple +WorkingDirectory=/opt/orama + +EnvironmentFile=/opt/orama/.orama/data/namespaces/%i/turn.env + +ExecStart=/bin/sh -c 'exec /opt/orama/bin/turn --config ${TURN_CONFIG}' + +TimeoutStopSec=30s +KillMode=mixed +KillSignal=SIGTERM + +Restart=on-failure +RestartSec=5s + +StandardOutput=journal +StandardError=journal +SyslogIdentifier=orama-turn-%i + +PrivateTmp=yes +LimitNOFILE=65536 +MemoryMax=1G + +[Install] +WantedBy=multi-user.target
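
With the template units installed, restarting a namespace picks the WebRTC services up automatically: StartAllNamespaceServices starts TURN before the SFU (the SFU hands out credentials pointing at the relay, so the relay should be listening first), and both only when their env files exist. A sketch of the call — how the *systemd.Manager is constructed is outside this diff:

package main

import (
    "log"

    "github.com/DeBrosOfficial/network/pkg/systemd"
)

// startNamespace brings one namespace up, WebRTC services included when provisioned.
func startNamespace(m *systemd.Manager, ns string) {
    if err := m.StartAllNamespaceServices(ns); err != nil {
        log.Fatalf("bring-up of %s failed: %v", ns, err)
    }
    // Manual equivalent, in the same dependency order:
    //   systemctl start orama-namespace-turn@<ns>.service
    //   systemctl start orama-namespace-sfu@<ns>.service
}
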