feat: implement SFU and TURN server functionality

- Add signaling package with message types and structures for SFU communication.
- Implement client and server message serialization/deserialization tests.
- Enhance systemd manager to handle SFU and TURN services, including start/stop logic.
- Create TURN server configuration and main server logic with HMAC-SHA1 authentication.
- Add tests for TURN server credential generation and validation.
- Define systemd service files for SFU and TURN services.
This commit is contained in:
anonpenguin23 2026-02-21 11:17:13 +02:00
parent 58ea896cb0
commit 8ee606bfb1
49 changed files with 6162 additions and 26 deletions

View File

@ -78,6 +78,8 @@ build: deps
go build -ldflags "$(LDFLAGS)" -o bin/orama ./cmd/cli/ go build -ldflags "$(LDFLAGS)" -o bin/orama ./cmd/cli/
# Inject gateway build metadata via pkg path variables # Inject gateway build metadata via pkg path variables
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
go build -ldflags "$(LDFLAGS)" -o bin/sfu ./cmd/sfu
go build -ldflags "$(LDFLAGS)" -o bin/turn ./cmd/turn
@echo "Build complete! Run ./bin/orama version" @echo "Build complete! Run ./bin/orama version"
# Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source) # Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source)

View File

@ -69,6 +69,13 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
} }
// Load YAML // Load YAML
type yamlWebRTCCfg struct {
Enabled bool `yaml:"enabled"`
SFUPort int `yaml:"sfu_port"`
TURNDomain string `yaml:"turn_domain"`
TURNSecret string `yaml:"turn_secret"`
}
type yamlCfg struct { type yamlCfg struct {
ListenAddr string `yaml:"listen_addr"` ListenAddr string `yaml:"listen_addr"`
ClientNamespace string `yaml:"client_namespace"` ClientNamespace string `yaml:"client_namespace"`
@ -84,6 +91,7 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
IPFSAPIURL string `yaml:"ipfs_api_url"` IPFSAPIURL string `yaml:"ipfs_api_url"`
IPFSTimeout string `yaml:"ipfs_timeout"` IPFSTimeout string `yaml:"ipfs_timeout"`
IPFSReplicationFactor int `yaml:"ipfs_replication_factor"` IPFSReplicationFactor int `yaml:"ipfs_replication_factor"`
WebRTC yamlWebRTCCfg `yaml:"webrtc"`
} }
data, err := os.ReadFile(configPath) data, err := os.ReadFile(configPath)
@ -192,6 +200,18 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
cfg.IPFSReplicationFactor = y.IPFSReplicationFactor cfg.IPFSReplicationFactor = y.IPFSReplicationFactor
} }
// WebRTC configuration
cfg.WebRTCEnabled = y.WebRTC.Enabled
if y.WebRTC.SFUPort > 0 {
cfg.SFUPort = y.WebRTC.SFUPort
}
if v := strings.TrimSpace(y.WebRTC.TURNDomain); v != "" {
cfg.TURNDomain = v
}
if v := strings.TrimSpace(y.WebRTC.TURNSecret); v != "" {
cfg.TURNSecret = v
}
// Validate configuration // Validate configuration
if errs := cfg.ValidateConfig(); len(errs) > 0 { if errs := cfg.ValidateConfig(); len(errs) > 0 {
fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs)) fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs))

116
cmd/sfu/config.go Normal file
View File

@ -0,0 +1,116 @@
package main
import (
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/DeBrosOfficial/network/pkg/config"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/sfu"
"go.uber.org/zap"
)
// newSFUServer creates a new SFU server from config and logger.
// Wrapper to keep main.go clean and avoid importing sfu in main.
//
// It is a thin pass-through to sfu.NewServer; any constructor error is
// returned unchanged for the caller (main) to report.
func newSFUServer(cfg *sfu.Config, logger *zap.Logger) (*sfu.Server, error) {
	return sfu.NewServer(cfg, logger)
}
// parseSFUConfig loads, strictly decodes, and validates the SFU YAML
// configuration and returns the resulting *sfu.Config.
//
// The config file comes from the -config flag (absolute path, or a filename
// resolved inside ~/.orama) and defaults to ~/.orama/sfu.yaml. Any failure —
// unresolvable path, missing file, parse error, or validation errors — is
// logged, reported on stderr, and terminates the process with exit code 1,
// so the function never returns an invalid config.
func parseSFUConfig(logger *logging.ColoredLogger) *sfu.Config {
	configFlag := flag.String("config", "", "Config file path (absolute path or filename in ~/.orama)")
	flag.Parse()

	// Resolve the config path. The three cases previously duplicated the
	// same log-print-exit error handling; the switch funnels them into a
	// single check below.
	var configPath string
	var err error
	switch {
	case *configFlag != "" && filepath.IsAbs(*configFlag):
		configPath = *configFlag
	case *configFlag != "":
		configPath, err = config.DefaultPath(*configFlag)
	default:
		configPath, err = config.DefaultPath("sfu.yaml")
	}
	if err != nil {
		logger.ComponentError(logging.ComponentSFU, "Failed to determine config path", zap.Error(err))
		fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
		os.Exit(1)
	}

	// Local YAML mirror of sfu.Config; decoded strictly so unknown keys
	// (typos) are rejected instead of silently ignored.
	type yamlTURNServer struct {
		Host string `yaml:"host"`
		Port int    `yaml:"port"`
	}
	type yamlCfg struct {
		ListenAddr        string           `yaml:"listen_addr"`
		Namespace         string           `yaml:"namespace"`
		MediaPortStart    int              `yaml:"media_port_start"`
		MediaPortEnd      int              `yaml:"media_port_end"`
		TURNServers       []yamlTURNServer `yaml:"turn_servers"`
		TURNSecret        string           `yaml:"turn_secret"`
		TURNCredentialTTL int              `yaml:"turn_credential_ttl"`
		RQLiteDSN         string           `yaml:"rqlite_dsn"`
	}

	data, err := os.ReadFile(configPath)
	if err != nil {
		logger.ComponentError(logging.ComponentSFU, "Config file not found",
			zap.String("path", configPath), zap.Error(err))
		fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath)
		os.Exit(1)
	}
	var y yamlCfg
	if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
		logger.ComponentError(logging.ComponentSFU, "Failed to parse SFU config", zap.Error(err))
		fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err)
		os.Exit(1)
	}

	// Pre-size the TURN server list when non-empty; an empty input keeps
	// the slice nil, matching the previous behavior exactly.
	var turnServers []sfu.TURNServerConfig
	if n := len(y.TURNServers); n > 0 {
		turnServers = make([]sfu.TURNServerConfig, 0, n)
	}
	for _, ts := range y.TURNServers {
		turnServers = append(turnServers, sfu.TURNServerConfig{
			Host: ts.Host,
			Port: ts.Port,
		})
	}

	cfg := &sfu.Config{
		ListenAddr:        y.ListenAddr,
		Namespace:         y.Namespace,
		MediaPortStart:    y.MediaPortStart,
		MediaPortEnd:      y.MediaPortEnd,
		TURNServers:       turnServers,
		TURNSecret:        y.TURNSecret,
		TURNCredentialTTL: y.TURNCredentialTTL,
		RQLiteDSN:         y.RQLiteDSN,
	}
	if errs := cfg.Validate(); len(errs) > 0 {
		fmt.Fprintf(os.Stderr, "\nSFU configuration errors (%d):\n", len(errs))
		for _, e := range errs {
			fmt.Fprintf(os.Stderr, " - %s\n", e)
		}
		fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n")
		os.Exit(1)
	}

	logger.ComponentInfo(logging.ComponentSFU, "Loaded SFU configuration",
		zap.String("path", configPath),
		zap.String("listen_addr", cfg.ListenAddr),
		zap.String("namespace", cfg.Namespace),
		// NOTE(review): this logs End-Start; if the port range is inclusive
		// it undercounts by one — confirm against sfu.Config semantics.
		zap.Int("media_ports", cfg.MediaPortEnd-cfg.MediaPortStart),
		zap.Int("turn_servers", len(cfg.TURNServers)),
	)
	return cfg
}

59
cmd/sfu/main.go Normal file
View File

@ -0,0 +1,59 @@
package main
import (
"os"
"os/signal"
"syscall"
"time"
"github.com/DeBrosOfficial/network/pkg/logging"
"go.uber.org/zap"
)
// Build-time metadata, shown at startup. Defaults are for ad-hoc `go build`.
// NOTE(review): presumably overridden via -ldflags -X at release build time;
// the Makefile's sfu target passes only "$(LDFLAGS)" — confirm LDFLAGS
// injects these symbols.
var (
	version = "dev"
	commit  = "unknown"
)
// main boots the SFU server: logger, config, server construction, then a
// background serve loop. It blocks on SIGINT/SIGTERM, drains peers for up to
// 30 seconds, and closes the server.
func main() {
	logger, err := logging.NewColoredLogger(logging.ComponentSFU, true)
	if err != nil {
		// No logger means no way to report anything else.
		panic(err)
	}
	logger.ComponentInfo(logging.ComponentSFU, "Starting SFU server",
		zap.String("version", version),
		zap.String("commit", commit))

	cfg := parseSFUConfig(logger) // exits the process itself on config errors
	server, err := newSFUServer(cfg, logger.Logger)
	if err != nil {
		logger.ComponentError(logging.ComponentSFU, "Failed to create SFU server", zap.Error(err))
		os.Exit(1)
	}

	// shuttingDown is closed once graceful shutdown begins, so the serve
	// goroutine can distinguish the expected error returned when Close()
	// unblocks ListenAndServe from a genuine serve failure.
	// BUGFIX: previously the goroutine called os.Exit(1) unconditionally on
	// any error, racing the graceful-drain path below and aborting the
	// process with a failure code during a normal shutdown.
	shuttingDown := make(chan struct{})

	// Start HTTP server in background.
	go func() {
		if err := server.ListenAndServe(); err != nil {
			select {
			case <-shuttingDown:
				// Expected: shutdown in progress; let main finish cleanly.
				return
			default:
			}
			logger.ComponentError(logging.ComponentSFU, "SFU server error", zap.Error(err))
			os.Exit(1)
		}
	}()

	// Wait for termination signal.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
	sig := <-quit
	logger.ComponentInfo(logging.ComponentSFU, "Shutdown signal received", zap.String("signal", sig.String()))

	// Graceful drain: notify peers and wait.
	close(shuttingDown)
	server.Drain(30 * time.Second)
	if err := server.Close(); err != nil {
		logger.ComponentError(logging.ComponentSFU, "Error during shutdown", zap.Error(err))
	}
	logger.ComponentInfo(logging.ComponentSFU, "SFU server shutdown complete")
}

96
cmd/turn/config.go Normal file
View File

@ -0,0 +1,96 @@
package main
import (
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/DeBrosOfficial/network/pkg/config"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/turn"
"go.uber.org/zap"
)
// parseTURNConfig loads, strictly decodes, and validates the TURN YAML
// configuration and returns the resulting *turn.Config.
//
// The config file comes from the -config flag (absolute path, or a filename
// resolved inside ~/.orama) and defaults to ~/.orama/turn.yaml. Any failure —
// unresolvable path, missing file, parse error, or validation errors — is
// logged, reported on stderr, and terminates the process with exit code 1,
// so the function never returns an invalid config.
func parseTURNConfig(logger *logging.ColoredLogger) *turn.Config {
	configFlag := flag.String("config", "", "Config file path (absolute path or filename in ~/.orama)")
	flag.Parse()

	// Resolve the config path. The three cases previously duplicated the
	// same log-print-exit error handling; the switch funnels them into a
	// single check below (mirrors parseSFUConfig in cmd/sfu).
	var configPath string
	var err error
	switch {
	case *configFlag != "" && filepath.IsAbs(*configFlag):
		configPath = *configFlag
	case *configFlag != "":
		configPath, err = config.DefaultPath(*configFlag)
	default:
		configPath, err = config.DefaultPath("turn.yaml")
	}
	if err != nil {
		logger.ComponentError(logging.ComponentTURN, "Failed to determine config path", zap.Error(err))
		fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
		os.Exit(1)
	}

	// Local YAML mirror of turn.Config; decoded strictly so unknown keys
	// (typos) are rejected instead of silently ignored.
	type yamlCfg struct {
		ListenAddr     string `yaml:"listen_addr"`
		TLSListenAddr  string `yaml:"tls_listen_addr"`
		PublicIP       string `yaml:"public_ip"`
		Realm          string `yaml:"realm"`
		AuthSecret     string `yaml:"auth_secret"`
		RelayPortStart int    `yaml:"relay_port_start"`
		RelayPortEnd   int    `yaml:"relay_port_end"`
		Namespace      string `yaml:"namespace"`
	}

	data, err := os.ReadFile(configPath)
	if err != nil {
		logger.ComponentError(logging.ComponentTURN, "Config file not found",
			zap.String("path", configPath), zap.Error(err))
		fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath)
		os.Exit(1)
	}
	var y yamlCfg
	if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
		logger.ComponentError(logging.ComponentTURN, "Failed to parse TURN config", zap.Error(err))
		fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err)
		os.Exit(1)
	}

	cfg := &turn.Config{
		ListenAddr:     y.ListenAddr,
		TLSListenAddr:  y.TLSListenAddr,
		PublicIP:       y.PublicIP,
		Realm:          y.Realm,
		AuthSecret:     y.AuthSecret,
		RelayPortStart: y.RelayPortStart,
		RelayPortEnd:   y.RelayPortEnd,
		Namespace:      y.Namespace,
	}
	if errs := cfg.Validate(); len(errs) > 0 {
		fmt.Fprintf(os.Stderr, "\nTURN configuration errors (%d):\n", len(errs))
		for _, e := range errs {
			fmt.Fprintf(os.Stderr, " - %s\n", e)
		}
		fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n")
		os.Exit(1)
	}

	logger.ComponentInfo(logging.ComponentTURN, "Loaded TURN configuration",
		zap.String("path", configPath),
		zap.String("listen_addr", cfg.ListenAddr),
		zap.String("namespace", cfg.Namespace),
		zap.String("realm", cfg.Realm),
	)
	return cfg
}

48
cmd/turn/main.go Normal file
View File

@ -0,0 +1,48 @@
package main
import (
"os"
"os/signal"
"syscall"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/turn"
"go.uber.org/zap"
)
// Build-time metadata, shown at startup. Defaults are for ad-hoc `go build`.
// NOTE(review): presumably overridden via -ldflags -X at release build time;
// the Makefile's turn target passes only "$(LDFLAGS)" — confirm LDFLAGS
// injects these symbols.
var (
	version = "dev"
	commit  = "unknown"
)
// main boots the TURN server: logger, config, server construction, then
// blocks until SIGINT/SIGTERM and closes the server.
func main() {
	logger, err := logging.NewColoredLogger(logging.ComponentTURN, true)
	if err != nil {
		// No logger means no way to report anything else.
		panic(err)
	}
	logger.ComponentInfo(logging.ComponentTURN, "Starting TURN server",
		zap.String("version", version),
		zap.String("commit", commit))
	cfg := parseTURNConfig(logger) // exits the process itself on config errors
	// NOTE(review): turn.NewServer appears to both construct and start the
	// listeners (there is no separate Serve/ListenAndServe call here) —
	// confirm against pkg/turn.
	server, err := turn.NewServer(cfg, logger.Logger)
	if err != nil {
		logger.ComponentError(logging.ComponentTURN, "Failed to start TURN server", zap.Error(err))
		os.Exit(1)
	}
	// Wait for termination signal
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
	sig := <-quit
	logger.ComponentInfo(logging.ComponentTURN, "Shutdown signal received", zap.String("signal", sig.String()))
	if err := server.Close(); err != nil {
		logger.ComponentError(logging.ComponentTURN, "Error during shutdown", zap.Error(err))
	}
	logger.ComponentInfo(logging.ComponentTURN, "TURN server shutdown complete")
}

View File

@ -0,0 +1,96 @@
-- Migration 018: WebRTC Services (SFU + TURN) for Namespace Clusters
-- Adds per-namespace WebRTC configuration, room tracking, and port allocation
-- WebRTC is opt-in: enabled via `orama namespace enable webrtc`
-- All statements are idempotent (IF NOT EXISTS / INSERT OR IGNORE), and the
-- whole migration runs in one transaction.
BEGIN;
-- Per-namespace WebRTC configuration
-- One row per namespace that has WebRTC enabled
CREATE TABLE IF NOT EXISTS namespace_webrtc_config (
    id TEXT PRIMARY KEY,                               -- UUID
    namespace_cluster_id TEXT NOT NULL UNIQUE,         -- FK to namespace_clusters; UNIQUE enforces one config row per cluster
    namespace_name TEXT NOT NULL,                      -- Cached for easier lookups (denormalized copy of the cluster's name)
    enabled INTEGER NOT NULL DEFAULT 1,                -- 1 = enabled, 0 = disabled
    -- TURN authentication
    turn_shared_secret TEXT NOT NULL,                  -- HMAC-SHA1 shared secret (base64, 32 bytes)
    turn_credential_ttl INTEGER NOT NULL DEFAULT 600,  -- Credential TTL in seconds (default: 10 min)
    -- Service topology
    sfu_node_count INTEGER NOT NULL DEFAULT 3,         -- SFU instances (all 3 nodes)
    turn_node_count INTEGER NOT NULL DEFAULT 2,        -- TURN instances (2 of 3 nodes for HA)
    -- Metadata
    enabled_by TEXT NOT NULL,                          -- Wallet address that enabled WebRTC
    enabled_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    disabled_at TIMESTAMP,                             -- NULL while enabled; set on disable
    FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_webrtc_config_namespace ON namespace_webrtc_config(namespace_name);
CREATE INDEX IF NOT EXISTS idx_webrtc_config_cluster ON namespace_webrtc_config(namespace_cluster_id);
-- WebRTC room tracking
-- Tracks active rooms and their SFU node affinity
CREATE TABLE IF NOT EXISTS webrtc_rooms (
    id TEXT PRIMARY KEY,                               -- UUID
    namespace_cluster_id TEXT NOT NULL,                -- FK to namespace_clusters
    namespace_name TEXT NOT NULL,                      -- Cached for easier lookups
    room_id TEXT NOT NULL,                             -- Application-defined room identifier
    -- SFU affinity
    sfu_node_id TEXT NOT NULL,                         -- Node hosting this room's SFU
    sfu_internal_ip TEXT NOT NULL,                     -- WireGuard IP of SFU node
    sfu_signaling_port INTEGER NOT NULL,               -- SFU WebSocket signaling port
    -- Room state
    participant_count INTEGER NOT NULL DEFAULT 0,
    max_participants INTEGER NOT NULL DEFAULT 100,
    created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    last_activity TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,  -- updated on activity; indexed below for stale-room sweeps
    -- Prevent duplicate rooms within a namespace
    UNIQUE(namespace_cluster_id, room_id),
    FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_namespace ON webrtc_rooms(namespace_name);
CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_node ON webrtc_rooms(sfu_node_id);
CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_activity ON webrtc_rooms(last_activity);
-- WebRTC port allocations
-- Separate from namespace_port_allocations to avoid breaking existing port blocks
-- Each namespace gets SFU + TURN ports on each node where those services run
-- Exactly one of the sfu_* / turn_* column groups is populated per row,
-- selected by service_type; the other group stays NULL.
CREATE TABLE IF NOT EXISTS webrtc_port_allocations (
    id TEXT PRIMARY KEY,                               -- UUID
    node_id TEXT NOT NULL,                             -- Physical node ID
    namespace_cluster_id TEXT NOT NULL,                -- FK to namespace_clusters
    service_type TEXT NOT NULL,                        -- 'sfu' or 'turn'
    -- SFU ports (when service_type = 'sfu')
    sfu_signaling_port INTEGER,                        -- WebSocket signaling port
    sfu_media_port_start INTEGER,                      -- Start of RTP media port range
    sfu_media_port_end INTEGER,                        -- End of RTP media port range
    -- TURN ports (when service_type = 'turn')
    turn_listen_port INTEGER,                          -- TURN listener port (3478)
    turn_tls_port INTEGER,                             -- TURNS (TLS) port; NOTE(review): original comment said "443/UDP" — TLS runs over TCP (DTLS over UDP); confirm intended transport
    turn_relay_port_start INTEGER,                     -- Start of relay port range
    turn_relay_port_end INTEGER,                       -- End of relay port range
    allocated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
    -- Prevent overlapping allocations
    UNIQUE(node_id, namespace_cluster_id, service_type),
    FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_webrtc_ports_node ON webrtc_port_allocations(node_id);
CREATE INDEX IF NOT EXISTS idx_webrtc_ports_cluster ON webrtc_port_allocations(namespace_cluster_id);
CREATE INDEX IF NOT EXISTS idx_webrtc_ports_type ON webrtc_port_allocations(service_type);
-- Mark migration as applied (version 18 matches the 018 filename)
INSERT OR IGNORE INTO schema_migrations(version) VALUES (18);
COMMIT;

View File

@ -37,6 +37,30 @@ func HandleNamespaceCommand(args []string) {
os.Exit(1) os.Exit(1)
} }
handleNamespaceRepair(args[1]) handleNamespaceRepair(args[1])
case "enable":
if len(args) < 2 {
fmt.Fprintf(os.Stderr, "Usage: orama namespace enable <feature> --namespace <name>\n")
fmt.Fprintf(os.Stderr, "Features: webrtc\n")
os.Exit(1)
}
handleNamespaceEnable(args[1:])
case "disable":
if len(args) < 2 {
fmt.Fprintf(os.Stderr, "Usage: orama namespace disable <feature> --namespace <name>\n")
fmt.Fprintf(os.Stderr, "Features: webrtc\n")
os.Exit(1)
}
handleNamespaceDisable(args[1:])
case "webrtc-status":
var ns string
fs := flag.NewFlagSet("namespace webrtc-status", flag.ExitOnError)
fs.StringVar(&ns, "namespace", "", "Namespace name")
_ = fs.Parse(args[1:])
if ns == "" {
fmt.Fprintf(os.Stderr, "Usage: orama namespace webrtc-status --namespace <name>\n")
os.Exit(1)
}
handleNamespaceWebRTCStatus(ns)
case "help": case "help":
showNamespaceHelp() showNamespaceHelp()
default: default:
@ -50,17 +74,24 @@ func showNamespaceHelp() {
fmt.Printf("Namespace Management Commands\n\n") fmt.Printf("Namespace Management Commands\n\n")
fmt.Printf("Usage: orama namespace <subcommand>\n\n") fmt.Printf("Usage: orama namespace <subcommand>\n\n")
fmt.Printf("Subcommands:\n") fmt.Printf("Subcommands:\n")
fmt.Printf(" list - List namespaces owned by the current wallet\n") fmt.Printf(" list - List namespaces owned by the current wallet\n")
fmt.Printf(" delete - Delete the current namespace and all its resources\n") fmt.Printf(" delete - Delete the current namespace and all its resources\n")
fmt.Printf(" repair <namespace> - Repair an under-provisioned namespace cluster (add missing nodes)\n") fmt.Printf(" repair <namespace> - Repair an under-provisioned namespace cluster\n")
fmt.Printf(" help - Show this help message\n\n") fmt.Printf(" enable webrtc --namespace NS - Enable WebRTC (SFU + TURN) for a namespace\n")
fmt.Printf(" disable webrtc --namespace NS - Disable WebRTC for a namespace\n")
fmt.Printf(" webrtc-status --namespace NS - Show WebRTC service status\n")
fmt.Printf(" help - Show this help message\n\n")
fmt.Printf("Flags:\n") fmt.Printf("Flags:\n")
fmt.Printf(" --force - Skip confirmation prompt (delete only)\n\n") fmt.Printf(" --force - Skip confirmation prompt (delete only)\n")
fmt.Printf(" --namespace - Namespace name (enable/disable/webrtc-status)\n\n")
fmt.Printf("Examples:\n") fmt.Printf("Examples:\n")
fmt.Printf(" orama namespace list\n") fmt.Printf(" orama namespace list\n")
fmt.Printf(" orama namespace delete\n") fmt.Printf(" orama namespace delete\n")
fmt.Printf(" orama namespace delete --force\n") fmt.Printf(" orama namespace delete --force\n")
fmt.Printf(" orama namespace repair anchat\n") fmt.Printf(" orama namespace repair anchat\n")
fmt.Printf(" orama namespace enable webrtc --namespace myapp\n")
fmt.Printf(" orama namespace disable webrtc --namespace myapp\n")
fmt.Printf(" orama namespace webrtc-status --namespace myapp\n")
} }
func handleNamespaceRepair(namespaceName string) { func handleNamespaceRepair(namespaceName string) {
@ -193,6 +224,165 @@ func handleNamespaceDelete(force bool) {
fmt.Printf("Run 'orama auth login' to create a new namespace.\n") fmt.Printf("Run 'orama auth login' to create a new namespace.\n")
} }
// handleNamespaceEnable enables an opt-in feature for a namespace via the
// local gateway's internal API. args[0] is the feature name (only "webrtc"
// is supported); the remaining args carry flags (--namespace <name>).
// Exits the process with status 1 on any error.
func handleNamespaceEnable(args []string) {
	feature := args[0]
	if feature != "webrtc" {
		fmt.Fprintf(os.Stderr, "Unknown feature: %s\nSupported features: webrtc\n", feature)
		os.Exit(1)
	}
	var ns string
	fs := flag.NewFlagSet("namespace enable webrtc", flag.ExitOnError)
	fs.StringVar(&ns, "namespace", "", "Namespace name")
	_ = fs.Parse(args[1:]) // ExitOnError: Parse exits itself on bad flags
	if ns == "" {
		fmt.Fprintf(os.Stderr, "Usage: orama namespace enable webrtc --namespace <name>\n")
		os.Exit(1)
	}

	fmt.Printf("Enabling WebRTC for namespace '%s'...\n", ns)
	fmt.Printf("This will provision SFU (3 nodes) and TURN (2 nodes) services.\n")

	// NOTE(review): ns is interpolated unescaped into the query string;
	// namespace names containing '&', '=', etc. would break the request —
	// confirm names are validated upstream or use net/url escaping.
	url := fmt.Sprintf("http://localhost:%d/v1/internal/namespace/webrtc/enable?namespace=%s", constants.GatewayAPIPort, ns)
	req, err := http.NewRequest(http.MethodPost, url, nil)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
		os.Exit(1)
	}
	// Internal coordination token; the gateway also checks the caller is on
	// the WireGuard subnet.
	req.Header.Set("X-Orama-Internal-Auth", "namespace-coordination")
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to connect to local gateway (is the node running?): %v\n", err)
		os.Exit(1)
	}
	defer resp.Body.Close()

	// Best-effort decode: error responses carry a JSON {"error": ...} body.
	var result map[string]interface{}
	decodeErr := json.NewDecoder(resp.Body).Decode(&result)
	if resp.StatusCode != http.StatusOK {
		errMsg := "unknown error"
		if e, ok := result["error"].(string); ok {
			errMsg = e
		} else if decodeErr != nil {
			// BUGFIX: the decode error was previously discarded, leaving
			// only "unknown error" with no diagnostic; surface the status.
			errMsg = fmt.Sprintf("HTTP %s (unreadable response body: %v)", resp.Status, decodeErr)
		}
		fmt.Fprintf(os.Stderr, "Failed to enable WebRTC: %s\n", errMsg)
		os.Exit(1)
	}

	fmt.Printf("WebRTC enabled for namespace '%s'.\n", ns)
	fmt.Printf(" SFU instances: 3 nodes (signaling via WireGuard)\n")
	fmt.Printf(" TURN instances: 2 nodes (relay on public IPs)\n")
}
// handleNamespaceDisable disables an opt-in feature for a namespace via the
// local gateway's internal API. args[0] is the feature name (only "webrtc"
// is supported); the remaining args carry flags (--namespace <name>).
// Exits the process with status 1 on any error.
func handleNamespaceDisable(args []string) {
	feature := args[0]
	if feature != "webrtc" {
		fmt.Fprintf(os.Stderr, "Unknown feature: %s\nSupported features: webrtc\n", feature)
		os.Exit(1)
	}
	var ns string
	fs := flag.NewFlagSet("namespace disable webrtc", flag.ExitOnError)
	fs.StringVar(&ns, "namespace", "", "Namespace name")
	_ = fs.Parse(args[1:]) // ExitOnError: Parse exits itself on bad flags
	if ns == "" {
		fmt.Fprintf(os.Stderr, "Usage: orama namespace disable webrtc --namespace <name>\n")
		os.Exit(1)
	}

	fmt.Printf("Disabling WebRTC for namespace '%s'...\n", ns)

	// NOTE(review): ns is interpolated unescaped into the query string;
	// confirm namespace names are validated upstream or use net/url escaping.
	url := fmt.Sprintf("http://localhost:%d/v1/internal/namespace/webrtc/disable?namespace=%s", constants.GatewayAPIPort, ns)
	req, err := http.NewRequest(http.MethodPost, url, nil)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
		os.Exit(1)
	}
	// Internal coordination token; the gateway also checks the caller is on
	// the WireGuard subnet.
	req.Header.Set("X-Orama-Internal-Auth", "namespace-coordination")
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to connect to local gateway (is the node running?): %v\n", err)
		os.Exit(1)
	}
	defer resp.Body.Close()

	// Best-effort decode: error responses carry a JSON {"error": ...} body.
	var result map[string]interface{}
	decodeErr := json.NewDecoder(resp.Body).Decode(&result)
	if resp.StatusCode != http.StatusOK {
		errMsg := "unknown error"
		if e, ok := result["error"].(string); ok {
			errMsg = e
		} else if decodeErr != nil {
			// BUGFIX: the decode error was previously discarded, leaving
			// only "unknown error" with no diagnostic; surface the status.
			errMsg = fmt.Sprintf("HTTP %s (unreadable response body: %v)", resp.Status, decodeErr)
		}
		fmt.Fprintf(os.Stderr, "Failed to disable WebRTC: %s\n", errMsg)
		os.Exit(1)
	}

	fmt.Printf("WebRTC disabled for namespace '%s'.\n", ns)
	fmt.Printf(" SFU and TURN services stopped, ports deallocated, DNS records removed.\n")
}
// handleNamespaceWebRTCStatus queries the local gateway's internal API for
// the WebRTC status of namespace ns and prints a human-readable summary.
// Exits the process with status 1 on any error.
func handleNamespaceWebRTCStatus(ns string) {
	// NOTE(review): ns is interpolated unescaped into the query string;
	// confirm namespace names are validated upstream or use net/url escaping.
	url := fmt.Sprintf("http://localhost:%d/v1/internal/namespace/webrtc/status?namespace=%s", constants.GatewayAPIPort, ns)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
		os.Exit(1)
	}
	// Internal coordination token; the gateway also checks the caller is on
	// the WireGuard subnet.
	req.Header.Set("X-Orama-Internal-Auth", "namespace-coordination")
	client := &http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to connect to local gateway (is the node running?): %v\n", err)
		os.Exit(1)
	}
	defer resp.Body.Close()

	var result map[string]interface{}
	decodeErr := json.NewDecoder(resp.Body).Decode(&result)
	if resp.StatusCode != http.StatusOK {
		errMsg := "unknown error"
		if e, ok := result["error"].(string); ok {
			errMsg = e
		} else if decodeErr != nil {
			errMsg = fmt.Sprintf("HTTP %s (unreadable response body: %v)", resp.Status, decodeErr)
		}
		fmt.Fprintf(os.Stderr, "Failed to get WebRTC status: %s\n", errMsg)
		os.Exit(1)
	}
	if decodeErr != nil {
		// BUGFIX: previously a malformed 200 body was silently decoded into
		// an empty map, and the command misreported "WebRTC is not enabled".
		fmt.Fprintf(os.Stderr, "Failed to parse WebRTC status response: %v\n", decodeErr)
		os.Exit(1)
	}

	enabled, _ := result["enabled"].(bool)
	if !enabled {
		fmt.Printf("WebRTC is not enabled for namespace '%s'.\n", ns)
		fmt.Printf(" Enable with: orama namespace enable webrtc --namespace %s\n", ns)
		return
	}

	// JSON numbers decode as float64; print them as integers.
	fmt.Printf("WebRTC Status for namespace '%s'\n\n", ns)
	fmt.Printf(" Enabled: yes\n")
	if sfuCount, ok := result["sfu_node_count"].(float64); ok {
		fmt.Printf(" SFU nodes: %.0f\n", sfuCount)
	}
	if turnCount, ok := result["turn_node_count"].(float64); ok {
		fmt.Printf(" TURN nodes: %.0f\n", turnCount)
	}
	if ttl, ok := result["turn_credential_ttl"].(float64); ok {
		fmt.Printf(" TURN cred TTL: %.0fs\n", ttl)
	}
	if enabledBy, ok := result["enabled_by"].(string); ok {
		fmt.Printf(" Enabled by: %s\n", enabledBy)
	}
	if enabledAt, ok := result["enabled_at"].(string); ok {
		fmt.Printf(" Enabled at: %s\n", enabledAt)
	}
}
func handleNamespaceList() { func handleNamespaceList() {
// Load credentials // Load credentials
store, err := auth.LoadEnhancedCredentials() store, err := auth.LoadEnhancedCredentials()

View File

@ -543,6 +543,8 @@ func (o *Orchestrator) installNamespaceTemplates() error {
"orama-namespace-rqlite@.service", "orama-namespace-rqlite@.service",
"orama-namespace-olric@.service", "orama-namespace-olric@.service",
"orama-namespace-gateway@.service", "orama-namespace-gateway@.service",
"orama-namespace-sfu@.service",
"orama-namespace-turn@.service",
} }
installedCount := 0 installedCount := 0

View File

@ -431,6 +431,8 @@ func (o *Orchestrator) installNamespaceTemplates() error {
"orama-namespace-rqlite@.service", "orama-namespace-rqlite@.service",
"orama-namespace-olric@.service", "orama-namespace-olric@.service",
"orama-namespace-gateway@.service", "orama-namespace-gateway@.service",
"orama-namespace-sfu@.service",
"orama-namespace-turn@.service",
} }
installedCount := 0 installedCount := 0

View File

@ -184,7 +184,7 @@ func GetProductionServices() []string {
namespacesDir := "/opt/orama/.orama/data/namespaces" namespacesDir := "/opt/orama/.orama/data/namespaces"
nsEntries, err := os.ReadDir(namespacesDir) nsEntries, err := os.ReadDir(namespacesDir)
if err == nil { if err == nil {
serviceTypes := []string{"rqlite", "olric", "gateway"} serviceTypes := []string{"rqlite", "olric", "gateway", "sfu", "turn"}
for _, nsEntry := range nsEntries { for _, nsEntry := range nsEntries {
if !nsEntry.IsDir() { if !nsEntry.IsDir() {
continue continue
@ -289,7 +289,8 @@ func identifyPortProcess(port int) string {
// NamespaceServiceOrder defines the dependency order for namespace services. // NamespaceServiceOrder defines the dependency order for namespace services.
// RQLite must start first (database), then Olric (cache), then Gateway (depends on both). // RQLite must start first (database), then Olric (cache), then Gateway (depends on both).
var NamespaceServiceOrder = []string{"rqlite", "olric", "gateway"} // TURN and SFU are optional WebRTC services that start after Gateway.
var NamespaceServiceOrder = []string{"rqlite", "olric", "gateway", "turn", "sfu"}
// StartServicesOrdered starts services respecting namespace dependency order. // StartServicesOrdered starts services respecting namespace dependency order.
// Namespace services are started in order: rqlite → olric (+ wait) → gateway. // Namespace services are started in order: rqlite → olric (+ wait) → gateway.

View File

@ -20,6 +20,17 @@ type HTTPGatewayConfig struct {
IPFSAPIURL string `yaml:"ipfs_api_url"` // IPFS API URL IPFSAPIURL string `yaml:"ipfs_api_url"` // IPFS API URL
IPFSTimeout time.Duration `yaml:"ipfs_timeout"` // Timeout for IPFS operations IPFSTimeout time.Duration `yaml:"ipfs_timeout"` // Timeout for IPFS operations
BaseDomain string `yaml:"base_domain"` // Base domain for deployments (e.g., "dbrs.space"). Defaults to "dbrs.space" BaseDomain string `yaml:"base_domain"` // Base domain for deployments (e.g., "dbrs.space"). Defaults to "dbrs.space"
// WebRTC configuration (optional, enabled per-namespace)
WebRTC WebRTCConfig `yaml:"webrtc"`
}
// WebRTCConfig contains WebRTC-related gateway configuration
type WebRTCConfig struct {
Enabled bool `yaml:"enabled"` // Whether this gateway has WebRTC support active
SFUPort int `yaml:"sfu_port"` // Local SFU signaling port to proxy to
TURNDomain string `yaml:"turn_domain"` // TURN domain (e.g., "turn.ns-myapp.dbrs.space")
TURNSecret string `yaml:"turn_secret"` // HMAC-SHA1 shared secret for TURN credential generation
} }
// HTTPSConfig contains HTTPS/TLS configuration for the gateway // HTTPSConfig contains HTTPS/TLS configuration for the gateway

View File

@ -41,4 +41,10 @@ type Config struct {
// WireGuard mesh configuration // WireGuard mesh configuration
ClusterSecret string // Cluster secret for authenticating internal WireGuard peer exchange ClusterSecret string // Cluster secret for authenticating internal WireGuard peer exchange
// WebRTC configuration (set when namespace has WebRTC enabled)
WebRTCEnabled bool // Whether WebRTC endpoints are active on this gateway
SFUPort int // Local SFU signaling port to proxy WebSocket connections to
TURNDomain string // TURN server domain for credential generation
TURNSecret string // HMAC-SHA1 shared secret for TURN credential generation
} }

View File

@ -70,6 +70,16 @@ func (c *Config) ValidateConfig() []error {
} }
} }
// Validate WebRTC configuration
if c.WebRTCEnabled {
if c.SFUPort <= 0 || c.SFUPort > 65535 {
errs = append(errs, fmt.Errorf("gateway.sfu_port: must be between 1 and 65535 when webrtc is enabled"))
}
if c.TURNSecret == "" {
errs = append(errs, fmt.Errorf("gateway.turn_secret: must not be empty when webrtc is enabled"))
}
}
// Validate HTTPS configuration // Validate HTTPS configuration
if c.EnableHTTPS { if c.EnableHTTPS {
if c.DomainName == "" { if c.DomainName == "" {

View File

@ -30,6 +30,7 @@ import (
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub" pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless" serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join" joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
webrtchandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/webrtc"
wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard" wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard"
sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite" sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite"
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage" "github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
@ -122,6 +123,9 @@ type Gateway struct {
rateLimiter *RateLimiter rateLimiter *RateLimiter
namespaceRateLimiter *NamespaceRateLimiter namespaceRateLimiter *NamespaceRateLimiter
// WebRTC signaling and TURN credentials
webrtcHandlers *webrtchandlers.WebRTCHandlers
// WireGuard peer exchange // WireGuard peer exchange
wireguardHandler *wireguardhandlers.Handler wireguardHandler *wireguardhandlers.Handler
@ -149,6 +153,9 @@ type Gateway struct {
// Node recovery handler (called when health monitor confirms a node dead or recovered) // Node recovery handler (called when health monitor confirms a node dead or recovered)
nodeRecoverer authhandlers.NodeRecoverer nodeRecoverer authhandlers.NodeRecoverer
// WebRTC manager for enable/disable operations
webrtcManager authhandlers.WebRTCManager
// Circuit breakers for proxy targets (per-target failure tracking) // Circuit breakers for proxy targets (per-target failure tracking)
circuitBreakers *CircuitBreakerRegistry circuitBreakers *CircuitBreakerRegistry
@ -323,6 +330,18 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
// Initialize handler instances // Initialize handler instances
gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger) gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger)
if cfg.WebRTCEnabled && cfg.SFUPort > 0 {
gw.webrtcHandlers = webrtchandlers.NewWebRTCHandlers(
logger,
cfg.SFUPort,
cfg.TURNDomain,
cfg.TURNSecret,
gw.proxyWebSocket,
)
logger.ComponentInfo(logging.ComponentGeneral, "WebRTC handlers initialized",
zap.Int("sfu_port", cfg.SFUPort))
}
if deps.OlricClient != nil { if deps.OlricClient != nil {
gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient) gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient)
} }
@ -633,6 +652,11 @@ func (g *Gateway) SetNodeRecoverer(nr authhandlers.NodeRecoverer) {
g.nodeRecoverer = nr g.nodeRecoverer = nr
} }
// SetWebRTCManager sets the WebRTC lifecycle manager for enable/disable operations.
func (g *Gateway) SetWebRTCManager(wm authhandlers.WebRTCManager) {
g.webrtcManager = wm
}
// SetSpawnHandler sets the handler for internal namespace spawn/stop requests. // SetSpawnHandler sets the handler for internal namespace spawn/stop requests.
func (g *Gateway) SetSpawnHandler(h http.Handler) { func (g *Gateway) SetSpawnHandler(h http.Handler) {
g.spawnHandler = h g.spawnHandler = h
@ -847,3 +871,121 @@ func (g *Gateway) namespaceClusterRepairHandler(w http.ResponseWriter, r *http.R
}) })
} }
// namespaceWebRTCEnableHandler handles POST /v1/internal/namespace/webrtc/enable?namespace={name}
// Internal-only: authenticated by the X-Orama-Internal-Auth header plus a
// WireGuard-subnet source-address check. Responds with a small JSON status object.
func (g *Gateway) namespaceWebRTCEnableHandler(w http.ResponseWriter, r *http.Request) {
	// Enable is a state change, so POST only.
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// Both the internal coordination token and the WireGuard peer check must pass.
	authed := r.Header.Get("X-Orama-Internal-Auth") == "namespace-coordination" && nodeauth.IsWireGuardPeer(r.RemoteAddr)
	if !authed {
		writeError(w, http.StatusUnauthorized, "unauthorized")
		return
	}
	ns := r.URL.Query().Get("namespace")
	if ns == "" {
		writeError(w, http.StatusBadRequest, "namespace parameter required")
		return
	}
	if g.webrtcManager == nil {
		writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled")
		return
	}
	// "cli" is recorded as the enabledBy value by the manager.
	if err := g.webrtcManager.EnableWebRTC(r.Context(), ns, "cli"); err != nil {
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"status":    "ok",
		"namespace": ns,
		"message":   "WebRTC enabled successfully",
	})
}
// namespaceWebRTCDisableHandler handles POST /v1/internal/namespace/webrtc/disable?namespace={name}
// Internal-only: authenticated by the X-Orama-Internal-Auth header plus a
// WireGuard-subnet source-address check. Responds with a small JSON status object.
func (g *Gateway) namespaceWebRTCDisableHandler(w http.ResponseWriter, r *http.Request) {
	// Disable is a state change, so POST only.
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// Require the internal coordination token AND a WireGuard source address.
	if ok := r.Header.Get("X-Orama-Internal-Auth") == "namespace-coordination" && nodeauth.IsWireGuardPeer(r.RemoteAddr); !ok {
		writeError(w, http.StatusUnauthorized, "unauthorized")
		return
	}
	ns := r.URL.Query().Get("namespace")
	switch {
	case ns == "":
		writeError(w, http.StatusBadRequest, "namespace parameter required")
		return
	case g.webrtcManager == nil:
		writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled")
		return
	}
	if err := g.webrtcManager.DisableWebRTC(r.Context(), ns); err != nil {
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	json.NewEncoder(w).Encode(map[string]interface{}{
		"status":    "ok",
		"namespace": ns,
		"message":   "WebRTC disabled successfully",
	})
}
// namespaceWebRTCStatusHandler handles GET /v1/internal/namespace/webrtc/status?namespace={name}
// Internal-only: authenticated by the X-Orama-Internal-Auth header plus a
// WireGuard-subnet source-address check. Returns the manager's config as JSON,
// or {namespace, enabled:false} when the manager reports no config.
func (g *Gateway) namespaceWebRTCStatusHandler(w http.ResponseWriter, r *http.Request) {
	// Read-only endpoint: GET only.
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// Require the internal coordination token AND a WireGuard source address.
	authed := r.Header.Get("X-Orama-Internal-Auth") == "namespace-coordination" && nodeauth.IsWireGuardPeer(r.RemoteAddr)
	if !authed {
		writeError(w, http.StatusUnauthorized, "unauthorized")
		return
	}
	ns := r.URL.Query().Get("namespace")
	if ns == "" {
		writeError(w, http.StatusBadRequest, "namespace parameter required")
		return
	}
	if g.webrtcManager == nil {
		writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled")
		return
	}
	cfg, err := g.webrtcManager.GetWebRTCStatus(r.Context(), ns)
	if err != nil {
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// A nil config means WebRTC is not enabled for this namespace.
	var payload interface{} = cfg
	if cfg == nil {
		payload = map[string]interface{}{
			"namespace": ns,
			"enabled":   false,
		}
	}
	json.NewEncoder(w).Encode(payload)
}

View File

@ -58,6 +58,14 @@ type NodeRecoverer interface {
RepairCluster(ctx context.Context, namespaceName string) error RepairCluster(ctx context.Context, namespaceName string) error
} }
// WebRTCManager handles enabling/disabling WebRTC services for namespaces.
type WebRTCManager interface {
	// EnableWebRTC enables WebRTC for the namespace; enabledBy identifies
	// the caller that requested the enable (e.g. "cli").
	EnableWebRTC(ctx context.Context, namespaceName, enabledBy string) error
	// DisableWebRTC disables WebRTC for the namespace.
	DisableWebRTC(ctx context.Context, namespaceName string) error
	// GetWebRTCStatus returns the WebRTC config for a namespace, or nil if not enabled.
	GetWebRTCStatus(ctx context.Context, namespaceName string) (interface{}, error)
}
// Handlers holds dependencies for authentication HTTP handlers // Handlers holds dependencies for authentication HTTP handlers
type Handlers struct { type Handlers struct {
logger *logging.ColoredLogger logger *logging.ColoredLogger

View File

@ -302,6 +302,8 @@ func (h *DeleteHandler) cleanupGlobalTables(ctx context.Context, ns string) {
{"namespace_sqlite_databases", "namespace"}, {"namespace_sqlite_databases", "namespace"},
{"namespace_quotas", "namespace"}, {"namespace_quotas", "namespace"},
{"home_node_assignments", "namespace"}, {"home_node_assignments", "namespace"},
{"webrtc_rooms", "namespace_name"},
{"namespace_webrtc_config", "namespace_name"},
} }
for _, t := range tables { for _, t := range tables {

View File

@ -12,12 +12,13 @@ import (
namespacepkg "github.com/DeBrosOfficial/network/pkg/namespace" namespacepkg "github.com/DeBrosOfficial/network/pkg/namespace"
"github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/sfu"
"go.uber.org/zap" "go.uber.org/zap"
) )
// SpawnRequest represents a request to spawn or stop a namespace instance // SpawnRequest represents a request to spawn or stop a namespace instance
type SpawnRequest struct { type SpawnRequest struct {
Action string `json:"action"` // "spawn-rqlite", "spawn-olric", "spawn-gateway", "stop-rqlite", "stop-olric", "stop-gateway", "save-cluster-state", "delete-cluster-state" Action string `json:"action"` // spawn-{rqlite,olric,gateway,sfu,turn}, stop-{rqlite,olric,gateway,sfu,turn}, save-cluster-state, delete-cluster-state
Namespace string `json:"namespace"` Namespace string `json:"namespace"`
NodeID string `json:"node_id"` NodeID string `json:"node_id"`
@ -48,6 +49,24 @@ type SpawnRequest struct {
IPFSTimeout string `json:"ipfs_timeout,omitempty"` IPFSTimeout string `json:"ipfs_timeout,omitempty"`
IPFSReplicationFactor int `json:"ipfs_replication_factor,omitempty"` IPFSReplicationFactor int `json:"ipfs_replication_factor,omitempty"`
// SFU config (when action = "spawn-sfu")
SFUListenAddr string `json:"sfu_listen_addr,omitempty"`
SFUMediaStart int `json:"sfu_media_start,omitempty"`
SFUMediaEnd int `json:"sfu_media_end,omitempty"`
TURNServers []sfu.TURNServerConfig `json:"turn_servers,omitempty"`
TURNSecret string `json:"turn_secret,omitempty"`
TURNCredTTL int `json:"turn_cred_ttl,omitempty"`
RQLiteDSN string `json:"rqlite_dsn,omitempty"`
// TURN config (when action = "spawn-turn")
TURNListenAddr string `json:"turn_listen_addr,omitempty"`
TURNTLSAddr string `json:"turn_tls_addr,omitempty"`
TURNPublicIP string `json:"turn_public_ip,omitempty"`
TURNRealm string `json:"turn_realm,omitempty"`
TURNAuthSecret string `json:"turn_auth_secret,omitempty"`
TURNRelayStart int `json:"turn_relay_start,omitempty"`
TURNRelayEnd int `json:"turn_relay_end,omitempty"`
// Cluster state (when action = "save-cluster-state") // Cluster state (when action = "save-cluster-state")
ClusterState json.RawMessage `json:"cluster_state,omitempty"` ClusterState json.RawMessage `json:"cluster_state,omitempty"`
} }
@ -242,6 +261,60 @@ func (h *SpawnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true}) writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
case "spawn-sfu":
cfg := namespacepkg.SFUInstanceConfig{
Namespace: req.Namespace,
NodeID: req.NodeID,
ListenAddr: req.SFUListenAddr,
MediaPortStart: req.SFUMediaStart,
MediaPortEnd: req.SFUMediaEnd,
TURNServers: req.TURNServers,
TURNSecret: req.TURNSecret,
TURNCredTTL: req.TURNCredTTL,
RQLiteDSN: req.RQLiteDSN,
}
if err := h.systemdSpawner.SpawnSFU(ctx, req.Namespace, req.NodeID, cfg); err != nil {
h.logger.Error("Failed to spawn SFU instance", zap.Error(err))
writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()})
return
}
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
case "stop-sfu":
if err := h.systemdSpawner.StopSFU(ctx, req.Namespace, req.NodeID); err != nil {
h.logger.Error("Failed to stop SFU instance", zap.Error(err))
writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()})
return
}
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
case "spawn-turn":
cfg := namespacepkg.TURNInstanceConfig{
Namespace: req.Namespace,
NodeID: req.NodeID,
ListenAddr: req.TURNListenAddr,
TLSListenAddr: req.TURNTLSAddr,
PublicIP: req.TURNPublicIP,
Realm: req.TURNRealm,
AuthSecret: req.TURNAuthSecret,
RelayPortStart: req.TURNRelayStart,
RelayPortEnd: req.TURNRelayEnd,
}
if err := h.systemdSpawner.SpawnTURN(ctx, req.Namespace, req.NodeID, cfg); err != nil {
h.logger.Error("Failed to spawn TURN instance", zap.Error(err))
writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()})
return
}
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
case "stop-turn":
if err := h.systemdSpawner.StopTURN(ctx, req.Namespace, req.NodeID); err != nil {
h.logger.Error("Failed to stop TURN instance", zap.Error(err))
writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()})
return
}
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
default: default:
writeSpawnResponse(w, http.StatusBadRequest, SpawnResponse{Error: fmt.Sprintf("unknown action: %s", req.Action)}) writeSpawnResponse(w, http.StatusBadRequest, SpawnResponse{Error: fmt.Sprintf("unknown action: %s", req.Action)})
} }

View File

@ -0,0 +1,56 @@
package webrtc
import (
"fmt"
"net/http"
"time"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/turn"
"go.uber.org/zap"
)
const turnCredentialTTL = 10 * time.Minute
// CredentialsHandler handles POST /v1/webrtc/turn/credentials.
// It mints fresh, time-limited TURN credentials scoped to the authenticated
// namespace and returns them together with the TURN URIs clients should use.
func (h *WebRTCHandlers) CredentialsHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	if h.turnSecret == "" {
		writeError(w, http.StatusServiceUnavailable, "TURN not configured")
		return
	}

	user, pass := turn.GenerateCredentials(h.turnSecret, ns, turnCredentialTTL)

	// Advertise both the standard TURN port and 443 (helps behind restrictive
	// firewalls). When no TURN domain is configured, uris stays nil.
	var uris []string
	if h.turnDomain != "" {
		uris = []string{
			fmt.Sprintf("turn:%s:3478?transport=udp", h.turnDomain),
			fmt.Sprintf("turn:%s:443?transport=udp", h.turnDomain),
		}
	}

	h.logger.ComponentInfo(logging.ComponentGeneral, "Issued TURN credentials",
		zap.String("namespace", ns),
		zap.String("username", user),
	)
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"username": user,
		"password": pass,
		"ttl":      int(turnCredentialTTL.Seconds()),
		"uris":     uris,
	})
}

View File

@ -0,0 +1,270 @@
package webrtc
import (
	"context"
	"encoding/json"
	"net"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
	"github.com/DeBrosOfficial/network/pkg/logging"
)
// testHandlers builds a WebRTCHandlers wired with fixed test values and no
// WebSocket proxy function.
func testHandlers() *WebRTCHandlers {
	log, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
	const (
		sfuPort = 8443
		domain  = "turn.ns-test.dbrs.space"
		secret  = "test-secret-key-32bytes-long!!!!"
	)
	return NewWebRTCHandlers(log, sfuPort, domain, secret, nil)
}
// requestWithNamespace builds a request whose context carries the namespace
// override, mimicking what the auth middleware injects.
func requestWithNamespace(method, path, namespace string) *http.Request {
	r := httptest.NewRequest(method, path, nil)
	return r.WithContext(context.WithValue(r.Context(), ctxkeys.NamespaceOverride, namespace))
}
// --- Credentials handler tests ---
// Happy path: a namespaced POST yields credentials, a 600s TTL, and two URIs.
func TestCredentialsHandler_Success(t *testing.T) {
	rec := httptest.NewRecorder()
	testHandlers().CredentialsHandler(rec, requestWithNamespace("POST", "/v1/webrtc/turn/credentials", "test-ns"))

	if rec.Code != http.StatusOK {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
	}
	var body map[string]interface{}
	if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
		t.Fatalf("failed to decode response: %v", err)
	}
	if body["username"] == nil || body["username"] == "" {
		t.Error("expected non-empty username")
	}
	if body["password"] == nil || body["password"] == "" {
		t.Error("expected non-empty password")
	}
	if body["ttl"] == nil {
		t.Error("expected ttl field")
	}
	if ttl, ok := body["ttl"].(float64); !ok || ttl != 600 {
		t.Errorf("ttl = %v, want 600", body["ttl"])
	}
	if uris, ok := body["uris"].([]interface{}); !ok || len(uris) != 2 {
		t.Errorf("uris count = %v, want 2", body["uris"])
	}
}
// A GET to the credentials endpoint must be rejected with 405.
func TestCredentialsHandler_MethodNotAllowed(t *testing.T) {
	rec := httptest.NewRecorder()
	testHandlers().CredentialsHandler(rec, requestWithNamespace("GET", "/v1/webrtc/turn/credentials", "test-ns"))
	if rec.Code != http.StatusMethodNotAllowed {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
	}
}
// Without a namespace in the request context, credentials are forbidden.
func TestCredentialsHandler_NoNamespace(t *testing.T) {
	rec := httptest.NewRecorder()
	testHandlers().CredentialsHandler(rec, httptest.NewRequest("POST", "/v1/webrtc/turn/credentials", nil))
	if rec.Code != http.StatusForbidden {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusForbidden)
	}
}
// An empty TURN secret means TURN is unconfigured: expect 503.
func TestCredentialsHandler_NoTURNSecret(t *testing.T) {
	log, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
	h := NewWebRTCHandlers(log, 8443, "turn.test.dbrs.space", "", nil)

	rec := httptest.NewRecorder()
	h.CredentialsHandler(rec, requestWithNamespace("POST", "/v1/webrtc/turn/credentials", "test-ns"))
	if rec.Code != http.StatusServiceUnavailable {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
	}
}
// --- Signal handler tests ---
// Signaling without a resolved namespace is forbidden.
func TestSignalHandler_NoNamespace(t *testing.T) {
	rec := httptest.NewRecorder()
	testHandlers().SignalHandler(rec, httptest.NewRequest("GET", "/v1/webrtc/signal", nil))
	if rec.Code != http.StatusForbidden {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusForbidden)
	}
}
// A zero SFU port means the SFU is unconfigured: expect 503.
func TestSignalHandler_NoSFUPort(t *testing.T) {
	log, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
	h := NewWebRTCHandlers(log, 0, "", "secret", nil)

	rec := httptest.NewRecorder()
	h.SignalHandler(rec, requestWithNamespace("GET", "/v1/webrtc/signal", "test-ns"))
	if rec.Code != http.StatusServiceUnavailable {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
	}
}
// With no WebSocket proxy function injected, signaling must fail with 500.
func TestSignalHandler_NoProxyFunc(t *testing.T) {
	rec := httptest.NewRecorder()
	testHandlers().SignalHandler(rec, requestWithNamespace("GET", "/v1/webrtc/signal", "test-ns"))
	if rec.Code != http.StatusInternalServerError {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusInternalServerError)
	}
}
// --- Rooms handler tests ---
// Rooms is read-only; POST must be rejected with 405.
func TestRoomsHandler_MethodNotAllowed(t *testing.T) {
	rec := httptest.NewRecorder()
	testHandlers().RoomsHandler(rec, requestWithNamespace("POST", "/v1/webrtc/rooms", "test-ns"))
	if rec.Code != http.StatusMethodNotAllowed {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusMethodNotAllowed)
	}
}
// Listing rooms without a resolved namespace is forbidden.
func TestRoomsHandler_NoNamespace(t *testing.T) {
	rec := httptest.NewRecorder()
	testHandlers().RoomsHandler(rec, httptest.NewRequest("GET", "/v1/webrtc/rooms", nil))
	if rec.Code != http.StatusForbidden {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusForbidden)
	}
}
// A zero SFU port means the SFU is unconfigured: expect 503.
func TestRoomsHandler_NoSFUPort(t *testing.T) {
	log, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
	h := NewWebRTCHandlers(log, 0, "", "secret", nil)

	rec := httptest.NewRecorder()
	h.RoomsHandler(rec, requestWithNamespace("GET", "/v1/webrtc/rooms", "test-ns"))
	if rec.Code != http.StatusServiceUnavailable {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusServiceUnavailable)
	}
}
// TestRoomsHandler_SFUProxySuccess verifies that RoomsHandler proxies the SFU
// health response (status code and body) through unchanged.
func TestRoomsHandler_SFUProxySuccess(t *testing.T) {
	// Start a mock SFU health endpoint returning a fixed JSON payload.
	mockSFU := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"status":"ok","rooms":3}`))
	}))
	defer mockSFU.Close()

	// Read the port straight from the test server's listener instead of
	// hand-parsing digits out of mockSFU.URL (fragile and error-prone).
	port := mockSFU.Listener.Addr().(*net.TCPAddr).Port

	logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
	h := NewWebRTCHandlers(logger, port, "", "secret", nil)
	req := requestWithNamespace("GET", "/v1/webrtc/rooms", "test-ns")
	w := httptest.NewRecorder()
	h.RoomsHandler(w, req)
	if w.Code != http.StatusOK {
		t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
	}
	body := w.Body.String()
	if body != `{"status":"ok","rooms":3}` {
		t.Errorf("body = %q, want %q", body, `{"status":"ok","rooms":3}`)
	}
}
// --- Helper tests ---
func TestResolveNamespaceFromRequest(t *testing.T) {
// With namespace
req := requestWithNamespace("GET", "/test", "my-namespace")
ns := resolveNamespaceFromRequest(req)
if ns != "my-namespace" {
t.Errorf("namespace = %q, want %q", ns, "my-namespace")
}
// Without namespace
req = httptest.NewRequest("GET", "/test", nil)
ns = resolveNamespaceFromRequest(req)
if ns != "" {
t.Errorf("namespace = %q, want empty", ns)
}
}
// writeError must set the status code and emit {"error": msg} as JSON.
func TestWriteError(t *testing.T) {
	rec := httptest.NewRecorder()
	writeError(rec, http.StatusBadRequest, "bad request")

	if rec.Code != http.StatusBadRequest {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusBadRequest)
	}
	var body map[string]string
	if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
		t.Fatalf("failed to decode: %v", err)
	}
	if body["error"] != "bad request" {
		t.Errorf("error = %q, want %q", body["error"], "bad request")
	}
}
// writeJSON must set the status code and a JSON content type.
func TestWriteJSON(t *testing.T) {
	rec := httptest.NewRecorder()
	writeJSON(rec, http.StatusOK, map[string]string{"status": "ok"})

	if rec.Code != http.StatusOK {
		t.Errorf("status = %d, want %d", rec.Code, http.StatusOK)
	}
	if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
		t.Errorf("Content-Type = %q, want %q", ct, "application/json")
	}
}

View File

@ -0,0 +1,51 @@
package webrtc
import (
"fmt"
"io"
"net/http"
"time"
"github.com/DeBrosOfficial/network/pkg/logging"
"go.uber.org/zap"
)
// RoomsHandler handles GET /v1/webrtc/rooms (list rooms)
// and GET /v1/webrtc/rooms?room_id=X (get specific room).
// Proxies to the local SFU's health endpoint for room data.
func (h *WebRTCHandlers) RoomsHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	if h.sfuPort <= 0 {
		writeError(w, http.StatusServiceUnavailable, "SFU not configured")
		return
	}

	// Proxy to the SFU health endpoint (which returns room data), propagating
	// the caller's context so the upstream request is cancelled if the client
	// disconnects before the 5s timeout fires.
	targetURL := fmt.Sprintf("http://127.0.0.1:%d/health", h.sfuPort)
	req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL, nil)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "failed to build SFU request")
		return
	}
	client := &http.Client{Timeout: 5 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		h.logger.ComponentWarn(logging.ComponentGeneral, "SFU health check failed",
			zap.String("namespace", ns),
			zap.Error(err),
		)
		writeError(w, http.StatusServiceUnavailable, "SFU unavailable")
		return
	}
	defer resp.Body.Close()

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(resp.StatusCode)
	// Headers are already written, so a copy error can only be logged.
	if _, err := io.Copy(w, resp.Body); err != nil {
		h.logger.ComponentWarn(logging.ComponentGeneral, "failed to stream SFU response",
			zap.String("namespace", ns),
			zap.Error(err),
		)
	}
}

View File

@ -0,0 +1,52 @@
package webrtc
import (
"fmt"
"net/http"
"github.com/DeBrosOfficial/network/pkg/logging"
"go.uber.org/zap"
)
// SignalHandler handles WebSocket /v1/webrtc/signal.
// Proxies the WebSocket connection to the local SFU's signaling endpoint.
func (h *WebRTCHandlers) SignalHandler(w http.ResponseWriter, r *http.Request) {
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	if h.sfuPort <= 0 {
		writeError(w, http.StatusServiceUnavailable, "SFU not configured")
		return
	}
	// Fail fast BEFORE mutating the request: the original checked this only
	// after rewriting r.URL, leaving the request altered on the error path.
	if h.proxyWebSocket == nil {
		writeError(w, http.StatusInternalServerError, "WebSocket proxy not available")
		return
	}

	// The SFU runs on the same node, so target loopback directly.
	targetHost := fmt.Sprintf("127.0.0.1:%d", h.sfuPort)

	h.logger.ComponentDebug(logging.ComponentGeneral, "Proxying WebRTC signal to SFU",
		zap.String("namespace", ns),
		zap.String("target", targetHost),
	)

	// Rewrite the URL path to match the SFU's expected endpoint.
	r.URL.Path = "/ws/signal"
	r.URL.Scheme = "http"
	r.URL.Host = targetHost
	r.Host = targetHost

	if !h.proxyWebSocket(w, r, targetHost) {
		// proxyWebSocket already wrote the error response.
		h.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket proxy failed",
			zap.String("namespace", ns),
			zap.String("target", targetHost),
		)
	}
}

View File

@ -0,0 +1,58 @@
package webrtc
import (
"encoding/json"
"net/http"
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
"github.com/DeBrosOfficial/network/pkg/logging"
)
// WebRTCHandlers handles all WebRTC-related HTTP and WebSocket endpoints.
// These run on the namespace gateway and proxy signaling to the local SFU.
type WebRTCHandlers struct {
	logger     *logging.ColoredLogger
	sfuPort    int    // Local SFU signaling port to proxy WebSocket connections to
	turnDomain string // TURN server domain for building URIs
	turnSecret string // HMAC-SHA1 shared secret for TURN credential generation
	// proxyWebSocket is injected from the gateway to reuse its WebSocket proxy
	// logic; its return value reports whether proxying succeeded (false means
	// the proxy already wrote an error response).
	proxyWebSocket func(w http.ResponseWriter, r *http.Request, targetHost string) bool
}
// NewWebRTCHandlers wires up a WebRTCHandlers with its logger, local SFU
// signaling port, TURN settings, and the gateway's WebSocket proxy function.
func NewWebRTCHandlers(
	logger *logging.ColoredLogger,
	sfuPort int,
	turnDomain string,
	turnSecret string,
	proxyWS func(w http.ResponseWriter, r *http.Request, targetHost string) bool,
) *WebRTCHandlers {
	return &WebRTCHandlers{
		proxyWebSocket: proxyWS,
		logger:         logger,
		sfuPort:        sfuPort,
		turnSecret:     turnSecret,
		turnDomain:     turnDomain,
	}
}
// resolveNamespaceFromRequest returns the namespace the auth middleware stored
// in the request context, or "" when none (or a non-string value) is present.
func resolveNamespaceFromRequest(r *http.Request) string {
	// Comma-ok assertion handles both a missing key and a non-string value.
	ns, _ := r.Context().Value(ctxkeys.NamespaceOverride).(string)
	return ns
}
func writeJSON(w http.ResponseWriter, code int, v any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)
json.NewEncoder(w).Encode(v)
}
func writeError(w http.ResponseWriter, code int, msg string) {
writeJSON(w, code, map[string]string{"error": msg})
}

View File

@ -196,7 +196,7 @@ func (g *Gateway) securityHeadersMiddleware(next http.Handler) http.Handler {
w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("X-XSS-Protection", "0") w.Header().Set("X-XSS-Protection", "0")
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()") w.Header().Set("Permissions-Policy", "camera=(self), microphone=(self), geolocation=()")
// HSTS only when behind TLS (Caddy) // HSTS only when behind TLS (Caddy)
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" { if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains") w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
@ -618,6 +618,9 @@ func requiresNamespaceOwnership(p string) bool {
if strings.HasPrefix(p, "/v1/functions") { if strings.HasPrefix(p, "/v1/functions") {
return true return true
} }
if strings.HasPrefix(p, "/v1/webrtc/") {
return true
}
return false return false
} }

View File

@ -382,7 +382,7 @@ func TestSecurityHeadersMiddleware(t *testing.T) {
"X-Frame-Options": "DENY", "X-Frame-Options": "DENY",
"X-Xss-Protection": "0", "X-Xss-Protection": "0",
"Referrer-Policy": "strict-origin-when-cross-origin", "Referrer-Policy": "strict-origin-when-cross-origin",
"Permissions-Policy": "camera=(), microphone=(), geolocation=()", "Permissions-Policy": "camera=(self), microphone=(self), geolocation=()",
} }
for header, want := range expected { for header, want := range expected {
got := rr.Header().Get(header) got := rr.Header().Get(header)

View File

@ -47,6 +47,11 @@ func (g *Gateway) Routes() http.Handler {
// Namespace cluster repair (internal, handler does its own auth) // Namespace cluster repair (internal, handler does its own auth)
mux.HandleFunc("/v1/internal/namespace/repair", g.namespaceClusterRepairHandler) mux.HandleFunc("/v1/internal/namespace/repair", g.namespaceClusterRepairHandler)
// Namespace WebRTC enable/disable/status (internal, handler does its own auth)
mux.HandleFunc("/v1/internal/namespace/webrtc/enable", g.namespaceWebRTCEnableHandler)
mux.HandleFunc("/v1/internal/namespace/webrtc/disable", g.namespaceWebRTCDisableHandler)
mux.HandleFunc("/v1/internal/namespace/webrtc/status", g.namespaceWebRTCStatusHandler)
// auth endpoints // auth endpoints
mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler) mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler)
mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler) mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler)
@ -104,6 +109,13 @@ func (g *Gateway) Routes() http.Handler {
mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler) mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler)
} }
// webrtc
if g.webrtcHandlers != nil {
mux.HandleFunc("/v1/webrtc/turn/credentials", g.webrtcHandlers.CredentialsHandler)
mux.HandleFunc("/v1/webrtc/signal", g.webrtcHandlers.SignalHandler)
mux.HandleFunc("/v1/webrtc/rooms", g.webrtcHandlers.RoomsHandler)
}
// anon proxy (authenticated users only) // anon proxy (authenticated users only)
mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler) mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler)

View File

@ -55,6 +55,8 @@ const (
ComponentGeneral Component = "GENERAL" ComponentGeneral Component = "GENERAL"
ComponentAnyone Component = "ANYONE" ComponentAnyone Component = "ANYONE"
ComponentGateway Component = "GATEWAY" ComponentGateway Component = "GATEWAY"
ComponentSFU Component = "SFU"
ComponentTURN Component = "TURN"
) )
// getComponentColor returns the color for a specific component // getComponentColor returns the color for a specific component
@ -78,6 +80,10 @@ func getComponentColor(component Component) string {
return Cyan return Cyan
case ComponentGateway: case ComponentGateway:
return BrightGreen return BrightGreen
case ComponentSFU:
return BrightRed
case ComponentTURN:
return Magenta
default: default:
return White return White
} }

View File

@ -18,6 +18,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/gateway" "github.com/DeBrosOfficial/network/pkg/gateway"
"github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/sfu"
"github.com/DeBrosOfficial/network/pkg/systemd" "github.com/DeBrosOfficial/network/pkg/systemd"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap" "go.uber.org/zap"
@ -37,12 +38,13 @@ type ClusterManagerConfig struct {
// ClusterManager orchestrates namespace cluster provisioning and lifecycle // ClusterManager orchestrates namespace cluster provisioning and lifecycle
type ClusterManager struct { type ClusterManager struct {
db rqlite.Client db rqlite.Client
portAllocator *NamespacePortAllocator portAllocator *NamespacePortAllocator
nodeSelector *ClusterNodeSelector webrtcPortAllocator *WebRTCPortAllocator
systemdSpawner *SystemdSpawner // NEW: Systemd-based spawner replaces old spawners nodeSelector *ClusterNodeSelector
dnsManager *DNSRecordManager systemdSpawner *SystemdSpawner // NEW: Systemd-based spawner replaces old spawners
logger *zap.Logger dnsManager *DNSRecordManager
logger *zap.Logger
baseDomain string baseDomain string
baseDataDir string baseDataDir string
globalRQLiteDSN string // Global RQLite DSN for namespace gateway auth globalRQLiteDSN string // Global RQLite DSN for namespace gateway auth
@ -69,6 +71,7 @@ func NewClusterManager(
) *ClusterManager { ) *ClusterManager {
// Create internal components // Create internal components
portAllocator := NewNamespacePortAllocator(db, logger) portAllocator := NewNamespacePortAllocator(db, logger)
webrtcPortAllocator := NewWebRTCPortAllocator(db, logger)
nodeSelector := NewClusterNodeSelector(db, portAllocator, logger) nodeSelector := NewClusterNodeSelector(db, portAllocator, logger)
systemdSpawner := NewSystemdSpawner(cfg.BaseDataDir, logger) systemdSpawner := NewSystemdSpawner(cfg.BaseDataDir, logger)
dnsManager := NewDNSRecordManager(db, cfg.BaseDomain, logger) dnsManager := NewDNSRecordManager(db, cfg.BaseDomain, logger)
@ -94,6 +97,7 @@ func NewClusterManager(
return &ClusterManager{ return &ClusterManager{
db: db, db: db,
portAllocator: portAllocator, portAllocator: portAllocator,
webrtcPortAllocator: webrtcPortAllocator,
nodeSelector: nodeSelector, nodeSelector: nodeSelector,
systemdSpawner: systemdSpawner, systemdSpawner: systemdSpawner,
dnsManager: dnsManager, dnsManager: dnsManager,
@ -139,6 +143,7 @@ func NewClusterManagerWithComponents(
return &ClusterManager{ return &ClusterManager{
db: db, db: db,
portAllocator: portAllocator, portAllocator: portAllocator,
webrtcPortAllocator: NewWebRTCPortAllocator(db, logger),
nodeSelector: nodeSelector, nodeSelector: nodeSelector,
systemdSpawner: systemdSpawner, systemdSpawner: systemdSpawner,
dnsManager: NewDNSRecordManager(db, cfg.BaseDomain, logger), dnsManager: NewDNSRecordManager(db, cfg.BaseDomain, logger),
@ -854,12 +859,21 @@ func (cm *ClusterManager) DeprovisionCluster(ctx context.Context, namespaceID in
if err := cm.db.Query(ctx, &clusterNodes, nodeQuery, cluster.ID); err != nil { if err := cm.db.Query(ctx, &clusterNodes, nodeQuery, cluster.ID); err != nil {
cm.logger.Warn("Failed to query cluster nodes for deprovisioning, falling back to local-only stop", zap.Error(err)) cm.logger.Warn("Failed to query cluster nodes for deprovisioning, falling back to local-only stop", zap.Error(err))
// Fall back to local-only stop (individual methods, NOT StopAll which uses dangerous glob) // Fall back to local-only stop (individual methods, NOT StopAll which uses dangerous glob)
// Stop WebRTC services first (SFU → TURN), then core services (Gateway → Olric → RQLite)
cm.systemdSpawner.StopSFU(ctx, cluster.NamespaceName, cm.localNodeID)
cm.systemdSpawner.StopTURN(ctx, cluster.NamespaceName, cm.localNodeID)
cm.systemdSpawner.StopGateway(ctx, cluster.NamespaceName, cm.localNodeID) cm.systemdSpawner.StopGateway(ctx, cluster.NamespaceName, cm.localNodeID)
cm.systemdSpawner.StopOlric(ctx, cluster.NamespaceName, cm.localNodeID) cm.systemdSpawner.StopOlric(ctx, cluster.NamespaceName, cm.localNodeID)
cm.systemdSpawner.StopRQLite(ctx, cluster.NamespaceName, cm.localNodeID) cm.systemdSpawner.StopRQLite(ctx, cluster.NamespaceName, cm.localNodeID)
cm.systemdSpawner.DeleteClusterState(cluster.NamespaceName) cm.systemdSpawner.DeleteClusterState(cluster.NamespaceName)
} else { } else {
// 2. Stop namespace infra on ALL nodes (reverse dependency order: Gateway → Olric → RQLite) // 2. Stop WebRTC services first (SFU → TURN), then core infra (Gateway → Olric → RQLite)
for _, node := range clusterNodes {
cm.stopSFUOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName)
}
for _, node := range clusterNodes {
cm.stopTURNOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName)
}
for _, node := range clusterNodes { for _, node := range clusterNodes {
cm.stopGatewayOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName) cm.stopGatewayOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName)
} }
@ -880,16 +894,21 @@ func (cm *ClusterManager) DeprovisionCluster(ctx context.Context, namespaceID in
} }
} }
// 4. Deallocate all ports // 4. Deallocate all ports (core + WebRTC)
cm.portAllocator.DeallocateAllPortBlocks(ctx, cluster.ID) cm.portAllocator.DeallocateAllPortBlocks(ctx, cluster.ID)
cm.webrtcPortAllocator.DeallocateAll(ctx, cluster.ID)
// 5. Delete namespace DNS records // 5. Delete namespace DNS records (gateway + TURN)
cm.dnsManager.DeleteNamespaceRecords(ctx, cluster.NamespaceName) cm.dnsManager.DeleteNamespaceRecords(ctx, cluster.NamespaceName)
cm.dnsManager.DeleteTURNRecords(ctx, cluster.NamespaceName)
// 6. Explicitly delete child tables (FK cascades disabled in rqlite) // 6. Explicitly delete child tables (FK cascades disabled in rqlite)
cm.db.Exec(ctx, `DELETE FROM namespace_cluster_events WHERE namespace_cluster_id = ?`, cluster.ID) cm.db.Exec(ctx, `DELETE FROM namespace_cluster_events WHERE namespace_cluster_id = ?`, cluster.ID)
cm.db.Exec(ctx, `DELETE FROM namespace_cluster_nodes WHERE namespace_cluster_id = ?`, cluster.ID) cm.db.Exec(ctx, `DELETE FROM namespace_cluster_nodes WHERE namespace_cluster_id = ?`, cluster.ID)
cm.db.Exec(ctx, `DELETE FROM namespace_port_allocations WHERE namespace_cluster_id = ?`, cluster.ID) cm.db.Exec(ctx, `DELETE FROM namespace_port_allocations WHERE namespace_cluster_id = ?`, cluster.ID)
cm.db.Exec(ctx, `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ?`, cluster.ID)
cm.db.Exec(ctx, `DELETE FROM webrtc_rooms WHERE namespace_cluster_id = ?`, cluster.ID)
cm.db.Exec(ctx, `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, cluster.ID)
// 7. Delete cluster record // 7. Delete cluster record
cm.db.Exec(ctx, `DELETE FROM namespace_clusters WHERE id = ?`, cluster.ID) cm.db.Exec(ctx, `DELETE FROM namespace_clusters WHERE id = ?`, cluster.ID)
@ -1594,6 +1613,19 @@ type ClusterLocalState struct {
HasGateway bool `json:"has_gateway"` HasGateway bool `json:"has_gateway"`
BaseDomain string `json:"base_domain"` BaseDomain string `json:"base_domain"`
SavedAt time.Time `json:"saved_at"` SavedAt time.Time `json:"saved_at"`
// WebRTC fields (zero values when WebRTC not enabled — backward compatible)
HasSFU bool `json:"has_sfu,omitempty"`
HasTURN bool `json:"has_turn,omitempty"`
TURNSharedSecret string `json:"-"` // Never persisted to disk state file
TURNCredentialTTL int `json:"turn_credential_ttl,omitempty"`
SFUSignalingPort int `json:"sfu_signaling_port,omitempty"`
SFUMediaPortStart int `json:"sfu_media_port_start,omitempty"`
SFUMediaPortEnd int `json:"sfu_media_port_end,omitempty"`
TURNListenPort int `json:"turn_listen_port,omitempty"`
TURNTLSPort int `json:"turn_tls_port,omitempty"`
TURNRelayPortStart int `json:"turn_relay_port_start,omitempty"`
TURNRelayPortEnd int `json:"turn_relay_port_end,omitempty"`
} }
type ClusterLocalStatePorts struct { type ClusterLocalStatePorts struct {
@ -1891,6 +1923,70 @@ func (cm *ClusterManager) restoreClusterFromState(ctx context.Context, state *Cl
} }
} }
// 4. Restore TURN (if enabled)
if state.HasTURN && state.TURNRelayPortStart > 0 {
turnRunning, _ := cm.systemdSpawner.systemdMgr.IsServiceActive(state.NamespaceName, systemd.ServiceTypeTURN)
if !turnRunning {
// TURN config needs the shared secret from DB — we can't persist it to disk state.
// If DB is available, fetch it; otherwise skip TURN restore (it will come back when DB is ready).
webrtcCfg, err := cm.GetWebRTCConfig(ctx, state.NamespaceName)
if err == nil && webrtcCfg != nil {
turnCfg := TURNInstanceConfig{
Namespace: state.NamespaceName,
NodeID: cm.localNodeID,
ListenAddr: fmt.Sprintf("0.0.0.0:%d", state.TURNListenPort),
TLSListenAddr: fmt.Sprintf("0.0.0.0:%d", state.TURNTLSPort),
PublicIP: "", // Will be resolved by spawner or from node info
Realm: cm.baseDomain,
AuthSecret: webrtcCfg.TURNSharedSecret,
RelayPortStart: state.TURNRelayPortStart,
RelayPortEnd: state.TURNRelayPortEnd,
}
if err := cm.systemdSpawner.SpawnTURN(ctx, state.NamespaceName, cm.localNodeID, turnCfg); err != nil {
cm.logger.Error("Failed to restore TURN from state", zap.String("namespace", state.NamespaceName), zap.Error(err))
} else {
cm.logger.Info("Restored TURN instance from state", zap.String("namespace", state.NamespaceName))
}
} else {
cm.logger.Warn("Skipping TURN restore: WebRTC config not available from DB",
zap.String("namespace", state.NamespaceName))
}
}
}
// 5. Restore SFU (if enabled)
if state.HasSFU && state.SFUSignalingPort > 0 {
sfuRunning, _ := cm.systemdSpawner.systemdMgr.IsServiceActive(state.NamespaceName, systemd.ServiceTypeSFU)
if !sfuRunning {
webrtcCfg, err := cm.GetWebRTCConfig(ctx, state.NamespaceName)
if err == nil && webrtcCfg != nil {
turnDomain := fmt.Sprintf("turn.ns-%s.%s", state.NamespaceName, cm.baseDomain)
sfuCfg := SFUInstanceConfig{
Namespace: state.NamespaceName,
NodeID: cm.localNodeID,
ListenAddr: fmt.Sprintf("%s:%d", localIP, state.SFUSignalingPort),
MediaPortStart: state.SFUMediaPortStart,
MediaPortEnd: state.SFUMediaPortEnd,
TURNServers: []sfu.TURNServerConfig{
{Host: turnDomain, Port: TURNDefaultPort},
{Host: turnDomain, Port: TURNTLSPort},
},
TURNSecret: webrtcCfg.TURNSharedSecret,
TURNCredTTL: webrtcCfg.TURNCredentialTTL,
RQLiteDSN: fmt.Sprintf("http://localhost:%d", pb.RQLiteHTTPPort),
}
if err := cm.systemdSpawner.SpawnSFU(ctx, state.NamespaceName, cm.localNodeID, sfuCfg); err != nil {
cm.logger.Error("Failed to restore SFU from state", zap.String("namespace", state.NamespaceName), zap.Error(err))
} else {
cm.logger.Info("Restored SFU instance from state", zap.String("namespace", state.NamespaceName))
}
} else {
cm.logger.Warn("Skipping SFU restore: WebRTC config not available from DB",
zap.String("namespace", state.NamespaceName))
}
}
}
return nil return nil
} }

View File

@ -0,0 +1,616 @@
package namespace
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/sfu"
"github.com/google/uuid"
"go.uber.org/zap"
)
// EnableWebRTC enables WebRTC (SFU + TURN) for an existing namespace cluster.
// Allocates ports, spawns SFU on all nodes and TURN on a subset of nodes,
// creates TURN DNS records, and updates cluster state.
//
// Once the namespace_webrtc_config row has been inserted (step 4), every
// subsequent failure rolls back via cleanupWebRTCOnError (stop services,
// deallocate ports, delete the config row). Without that rollback a failed
// enable would leave the namespace marked enabled — step 2 would then refuse
// all retries even though no services are running.
//
// Returns ErrClusterNotFound, ErrWebRTCAlreadyEnabled, or a wrapped error
// naming the first step that failed.
func (cm *ClusterManager) EnableWebRTC(ctx context.Context, namespaceName, enabledBy string) error {
	internalCtx := client.WithInternalAuth(ctx)

	// 1. Verify cluster exists and is ready
	cluster, err := cm.GetClusterByNamespace(ctx, namespaceName)
	if err != nil {
		return fmt.Errorf("failed to get cluster: %w", err)
	}
	if cluster == nil {
		return ErrClusterNotFound
	}
	if cluster.Status != ClusterStatusReady {
		return &ClusterError{Message: fmt.Sprintf("cluster status is %q, must be %q to enable WebRTC", cluster.Status, ClusterStatusReady)}
	}

	// 2. Check if WebRTC is already enabled (a query error is treated as
	// "not enabled" — best-effort guard, the INSERT below is authoritative)
	var existingConfigs []WebRTCConfig
	if err := cm.db.Query(internalCtx, &existingConfigs,
		`SELECT * FROM namespace_webrtc_config WHERE namespace_cluster_id = ? AND enabled = 1`, cluster.ID); err == nil && len(existingConfigs) > 0 {
		return ErrWebRTCAlreadyEnabled
	}

	cm.logger.Info("Enabling WebRTC for namespace",
		zap.String("namespace", namespaceName),
		zap.String("cluster_id", cluster.ID),
	)

	// 3. Generate TURN shared secret (32 bytes, crypto/rand)
	secretBytes := make([]byte, 32)
	if _, err := rand.Read(secretBytes); err != nil {
		return fmt.Errorf("failed to generate TURN secret: %w", err)
	}
	turnSecret := base64.StdEncoding.EncodeToString(secretBytes)

	// 4. Insert namespace_webrtc_config
	webrtcConfigID := uuid.New().String()
	_, err = cm.db.Exec(internalCtx,
		`INSERT INTO namespace_webrtc_config (id, namespace_cluster_id, namespace_name, enabled, turn_shared_secret, turn_credential_ttl, sfu_node_count, turn_node_count, enabled_by, enabled_at)
		 VALUES (?, ?, ?, 1, ?, ?, ?, ?, ?, ?)`,
		webrtcConfigID, cluster.ID, namespaceName,
		turnSecret, DefaultTURNCredentialTTL,
		DefaultSFUNodeCount, DefaultTURNNodeCount,
		enabledBy, time.Now(),
	)
	if err != nil {
		return fmt.Errorf("failed to insert WebRTC config: %w", err)
	}

	// 5. Get cluster nodes with IPs.
	// The config row now exists, so these error paths must clean up too —
	// otherwise the namespace is stuck "enabled" with nothing running.
	clusterNodes, err := cm.getClusterNodesWithIPs(ctx, cluster.ID)
	if err != nil {
		cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, nil)
		return fmt.Errorf("failed to get cluster nodes: %w", err)
	}
	if len(clusterNodes) < 3 {
		cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, nil)
		return fmt.Errorf("cluster has %d nodes, need at least 3 for WebRTC", len(clusterNodes))
	}

	// 6. Allocate SFU ports on all nodes
	sfuBlocks := make(map[string]*WebRTCPortBlock) // nodeID -> block
	for _, node := range clusterNodes {
		block, err := cm.webrtcPortAllocator.AllocateSFUPorts(ctx, node.NodeID, cluster.ID)
		if err != nil {
			cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
			return fmt.Errorf("failed to allocate SFU ports on node %s: %w", node.NodeID, err)
		}
		sfuBlocks[node.NodeID] = block
	}

	// 7. Select TURN nodes (prefer nodes without existing TURN allocations)
	turnNodes := cm.selectTURNNodes(ctx, clusterNodes, DefaultTURNNodeCount)

	// 8. Allocate TURN ports on selected nodes
	turnBlocks := make(map[string]*WebRTCPortBlock) // nodeID -> block
	for _, node := range turnNodes {
		block, err := cm.webrtcPortAllocator.AllocateTURNPorts(ctx, node.NodeID, cluster.ID)
		if err != nil {
			cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
			return fmt.Errorf("failed to allocate TURN ports on node %s: %w", node.NodeID, err)
		}
		turnBlocks[node.NodeID] = block
	}

	// 9. Build TURN server list for SFU config (plain + TLS-port listeners,
	// both behind the shared turn.ns-{namespace} DNS name)
	turnDomain := fmt.Sprintf("turn.ns-%s.%s", namespaceName, cm.baseDomain)
	turnServers := []sfu.TURNServerConfig{
		{Host: turnDomain, Port: TURNDefaultPort},
		{Host: turnDomain, Port: TURNTLSPort},
	}

	// 10. Get port blocks for RQLite DSN
	portBlocks, err := cm.portAllocator.GetAllPortBlocks(ctx, cluster.ID)
	if err != nil {
		cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
		return fmt.Errorf("failed to get port blocks: %w", err)
	}

	// Build nodeID -> PortBlock map
	nodePortBlocks := make(map[string]*PortBlock)
	for i := range portBlocks {
		nodePortBlocks[portBlocks[i].NodeID] = &portBlocks[i]
	}

	// 11. Spawn TURN on selected nodes (before SFU, so SFU can advertise
	// working TURN servers from the start)
	for _, node := range turnNodes {
		turnBlock := turnBlocks[node.NodeID]
		turnCfg := TURNInstanceConfig{
			Namespace:      namespaceName,
			NodeID:         node.NodeID,
			ListenAddr:     fmt.Sprintf("0.0.0.0:%d", turnBlock.TURNListenPort),
			TLSListenAddr:  fmt.Sprintf("0.0.0.0:%d", turnBlock.TURNTLSPort),
			PublicIP:       node.PublicIP,
			Realm:          cm.baseDomain,
			AuthSecret:     turnSecret,
			RelayPortStart: turnBlock.TURNRelayPortStart,
			RelayPortEnd:   turnBlock.TURNRelayPortEnd,
		}
		if err := cm.spawnTURNOnNode(ctx, node, namespaceName, turnCfg); err != nil {
			cm.logger.Error("Failed to spawn TURN",
				zap.String("namespace", namespaceName),
				zap.String("node_id", node.NodeID),
				zap.Error(err))
			cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
			return fmt.Errorf("failed to spawn TURN on node %s: %w", node.NodeID, err)
		}
		cm.logEvent(ctx, cluster.ID, EventTURNStarted, node.NodeID,
			fmt.Sprintf("TURN started on %s (relay ports %d-%d)", node.NodeID, turnBlock.TURNRelayPortStart, turnBlock.TURNRelayPortEnd), nil)
	}

	// 12. Spawn SFU on all nodes
	for _, node := range clusterNodes {
		sfuBlock := sfuBlocks[node.NodeID]
		pb := nodePortBlocks[node.NodeID]
		rqliteDSN := fmt.Sprintf("http://localhost:%d", pb.RQLiteHTTPPort)

		sfuCfg := SFUInstanceConfig{
			Namespace:      namespaceName,
			NodeID:         node.NodeID,
			ListenAddr:     fmt.Sprintf("%s:%d", node.InternalIP, sfuBlock.SFUSignalingPort),
			MediaPortStart: sfuBlock.SFUMediaPortStart,
			MediaPortEnd:   sfuBlock.SFUMediaPortEnd,
			TURNServers:    turnServers,
			TURNSecret:     turnSecret,
			TURNCredTTL:    DefaultTURNCredentialTTL,
			RQLiteDSN:      rqliteDSN,
		}
		if err := cm.spawnSFUOnNode(ctx, node, namespaceName, sfuCfg); err != nil {
			cm.logger.Error("Failed to spawn SFU",
				zap.String("namespace", namespaceName),
				zap.String("node_id", node.NodeID),
				zap.Error(err))
			cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
			return fmt.Errorf("failed to spawn SFU on node %s: %w", node.NodeID, err)
		}
		cm.logEvent(ctx, cluster.ID, EventSFUStarted, node.NodeID,
			fmt.Sprintf("SFU started on %s:%d", node.InternalIP, sfuBlock.SFUSignalingPort), nil)
	}

	// 13. Create TURN DNS records (non-fatal: services are up; DNS can be
	// retried/repaired out of band)
	var turnIPs []string
	for _, node := range turnNodes {
		turnIPs = append(turnIPs, node.PublicIP)
	}
	if err := cm.dnsManager.CreateTURNRecords(ctx, namespaceName, turnIPs); err != nil {
		cm.logger.Warn("Failed to create TURN DNS records",
			zap.String("namespace", namespaceName),
			zap.Error(err))
	}

	// 14. Update cluster-state.json on all nodes with WebRTC info
	cm.updateClusterStateWithWebRTC(ctx, cluster, clusterNodes, sfuBlocks, turnBlocks)

	cm.logEvent(ctx, cluster.ID, EventWebRTCEnabled, "",
		fmt.Sprintf("WebRTC enabled: SFU on %d nodes, TURN on %d nodes", len(clusterNodes), len(turnNodes)), nil)

	cm.logger.Info("WebRTC enabled successfully",
		zap.String("namespace", namespaceName),
		zap.String("cluster_id", cluster.ID),
		zap.Int("sfu_nodes", len(clusterNodes)),
		zap.Int("turn_nodes", len(turnNodes)),
	)
	return nil
}
// DisableWebRTC disables WebRTC for a namespace cluster.
// Stops SFU/TURN services, deallocates ports, and cleans up DNS/DB.
//
// Steps 4-9 are best-effort: individual stop/cleanup failures are logged
// (or ignored for the final DB deletes) rather than aborting, so a
// partially-failed disable still removes as much state as possible.
// Returns ErrClusterNotFound or ErrWebRTCNotEnabled from the precondition
// checks; otherwise nil.
func (cm *ClusterManager) DisableWebRTC(ctx context.Context, namespaceName string) error {
	internalCtx := client.WithInternalAuth(ctx)

	// 1. Verify cluster exists
	cluster, err := cm.GetClusterByNamespace(ctx, namespaceName)
	if err != nil {
		return fmt.Errorf("failed to get cluster: %w", err)
	}
	if cluster == nil {
		return ErrClusterNotFound
	}

	// 2. Verify WebRTC is enabled (a query error is treated the same as
	// "not enabled" — there is nothing reliable to tear down)
	var configs []WebRTCConfig
	if err := cm.db.Query(internalCtx, &configs,
		`SELECT * FROM namespace_webrtc_config WHERE namespace_cluster_id = ? AND enabled = 1`, cluster.ID); err != nil || len(configs) == 0 {
		return ErrWebRTCNotEnabled
	}

	cm.logger.Info("Disabling WebRTC for namespace",
		zap.String("namespace", namespaceName),
		zap.String("cluster_id", cluster.ID),
	)

	// 3. Get cluster nodes with IPs
	clusterNodes, err := cm.getClusterNodesWithIPs(ctx, cluster.ID)
	if err != nil {
		return fmt.Errorf("failed to get cluster nodes: %w", err)
	}

	// 4. Stop SFU on all nodes (SFU runs on every cluster node)
	for _, node := range clusterNodes {
		cm.stopSFUOnNode(ctx, node.NodeID, node.InternalIP, namespaceName)
		cm.logEvent(ctx, cluster.ID, EventSFUStopped, node.NodeID, "SFU stopped", nil)
	}

	// 5. Stop TURN on nodes that have TURN allocations (TURN runs on a
	// subset of nodes, so the allocation table is the source of truth)
	turnBlocks, _ := cm.getWebRTCBlocksByType(ctx, cluster.ID, "turn")
	for _, block := range turnBlocks {
		nodeIP := cm.getNodeIP(clusterNodes, block.NodeID)
		cm.stopTURNOnNode(ctx, block.NodeID, nodeIP, namespaceName)
		cm.logEvent(ctx, cluster.ID, EventTURNStopped, block.NodeID, "TURN stopped", nil)
	}

	// 6. Deallocate all WebRTC ports
	if err := cm.webrtcPortAllocator.DeallocateAll(ctx, cluster.ID); err != nil {
		cm.logger.Warn("Failed to deallocate WebRTC ports", zap.Error(err))
	}

	// 7. Delete TURN DNS records
	if err := cm.dnsManager.DeleteTURNRecords(ctx, namespaceName); err != nil {
		cm.logger.Warn("Failed to delete TURN DNS records", zap.Error(err))
	}

	// 8. Clean up DB tables (errors intentionally ignored — best-effort)
	cm.db.Exec(internalCtx, `DELETE FROM webrtc_rooms WHERE namespace_cluster_id = ?`, cluster.ID)
	cm.db.Exec(internalCtx, `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, cluster.ID)

	// 9. Update cluster-state.json to remove WebRTC info (nil maps = clear)
	cm.updateClusterStateWithWebRTC(ctx, cluster, clusterNodes, nil, nil)

	cm.logEvent(ctx, cluster.ID, EventWebRTCDisabled, "", "WebRTC disabled", nil)

	cm.logger.Info("WebRTC disabled successfully",
		zap.String("namespace", namespaceName),
		zap.String("cluster_id", cluster.ID),
	)
	return nil
}
// GetWebRTCConfig returns the WebRTC configuration for a namespace.
// Returns (nil, nil) when WebRTC is not enabled for the namespace.
func (cm *ClusterManager) GetWebRTCConfig(ctx context.Context, namespaceName string) (*WebRTCConfig, error) {
	var configs []WebRTCConfig
	queryErr := cm.db.Query(client.WithInternalAuth(ctx), &configs,
		`SELECT * FROM namespace_webrtc_config WHERE namespace_name = ? AND enabled = 1`, namespaceName)
	if queryErr != nil {
		return nil, fmt.Errorf("failed to query WebRTC config: %w", queryErr)
	}
	if len(configs) == 0 {
		return nil, nil
	}
	return &configs[0], nil
}
// GetWebRTCStatus returns the WebRTC config as an interface{} for the WebRTCManager interface.
func (cm *ClusterManager) GetWebRTCStatus(ctx context.Context, namespaceName string) (interface{}, error) {
	cfg, err := cm.GetWebRTCConfig(ctx, namespaceName)
	switch {
	case err != nil:
		return nil, err
	case cfg == nil:
		// Return an untyped nil explicitly: boxing a nil *WebRTCConfig into
		// interface{} would look non-nil to callers (typed-nil trap).
		return nil, nil
	default:
		return cfg, nil
	}
}
// --- Internal helpers ---
// clusterNodeInfo holds node info needed for WebRTC operations
// (spawning/stopping SFU and TURN, building DNS records, state updates).
type clusterNodeInfo struct {
	NodeID     string // cluster node identifier (namespace_cluster_nodes.node_id)
	InternalIP string // WireGuard IP
	PublicIP   string // Public IP for TURN
}
// getClusterNodesWithIPs returns cluster nodes with both internal and public IPs.
// The internal IP falls back to the public address when no WireGuard IP is set.
func (cm *ClusterManager) getClusterNodesWithIPs(ctx context.Context, clusterID string) ([]clusterNodeInfo, error) {
	type nodeRow struct {
		NodeID     string `db:"node_id"`
		InternalIP string `db:"internal_ip"`
		PublicIP   string `db:"public_ip"`
	}

	query := `
	SELECT ncn.node_id,
	       COALESCE(dn.internal_ip, dn.ip_address) as internal_ip,
	       dn.ip_address as public_ip
	FROM namespace_cluster_nodes ncn
	JOIN dns_nodes dn ON ncn.node_id = dn.id
	WHERE ncn.namespace_cluster_id = ?
	`

	var rows []nodeRow
	if err := cm.db.Query(client.WithInternalAuth(ctx), &rows, query, clusterID); err != nil {
		return nil, err
	}

	nodes := make([]clusterNodeInfo, 0, len(rows))
	for _, row := range rows {
		nodes = append(nodes, clusterNodeInfo{
			NodeID:     row.NodeID,
			InternalIP: row.InternalIP,
			PublicIP:   row.PublicIP,
		})
	}
	return nodes, nil
}
// selectTURNNodes selects the best N nodes for TURN, preferring nodes without
// existing TURN allocations. When count covers every node, all are returned.
func (cm *ClusterManager) selectTURNNodes(ctx context.Context, nodes []clusterNodeInfo, count int) []clusterNodeInfo {
	if count >= len(nodes) {
		return nodes
	}

	// Rank nodes: those without an existing TURN allocation come first.
	// An allocator error is treated as "no TURN" so selection still proceeds.
	var withoutTURN, withTURN []clusterNodeInfo
	for _, n := range nodes {
		if hasTURN, err := cm.webrtcPortAllocator.NodeHasTURN(ctx, n.NodeID); err == nil && hasTURN {
			withTURN = append(withTURN, n)
		} else {
			withoutTURN = append(withoutTURN, n)
		}
	}

	// Take the first count from the ranked list (preferred first).
	ranked := append(withoutTURN, withTURN...)
	return ranked[:count]
}
// spawnSFUOnNode spawns SFU on a node (local via systemd, remote via spawn request)
func (cm *ClusterManager) spawnSFUOnNode(ctx context.Context, node clusterNodeInfo, namespace string, cfg SFUInstanceConfig) error {
	if node.NodeID != cm.localNodeID {
		return cm.spawnSFURemote(ctx, node.InternalIP, cfg)
	}
	return cm.systemdSpawner.SpawnSFU(ctx, namespace, node.NodeID, cfg)
}
// spawnTURNOnNode spawns TURN on a node (local via systemd, remote via spawn request)
func (cm *ClusterManager) spawnTURNOnNode(ctx context.Context, node clusterNodeInfo, namespace string, cfg TURNInstanceConfig) error {
	if node.NodeID != cm.localNodeID {
		return cm.spawnTURNRemote(ctx, node.InternalIP, cfg)
	}
	return cm.systemdSpawner.SpawnTURN(ctx, namespace, node.NodeID, cfg)
}
// stopSFUOnNode stops SFU on a node (local via systemd, remote via stop request).
// Best-effort: no error is returned.
func (cm *ClusterManager) stopSFUOnNode(ctx context.Context, nodeID, nodeIP, namespace string) {
	if nodeID != cm.localNodeID {
		cm.sendStopRequest(ctx, nodeIP, "stop-sfu", namespace, nodeID)
		return
	}
	cm.systemdSpawner.StopSFU(ctx, namespace, nodeID)
}
// stopTURNOnNode stops TURN on a node (local via systemd, remote via stop request).
// Best-effort: no error is returned.
func (cm *ClusterManager) stopTURNOnNode(ctx context.Context, nodeID, nodeIP, namespace string) {
	if nodeID != cm.localNodeID {
		cm.sendStopRequest(ctx, nodeIP, "stop-turn", namespace, nodeID)
		return
	}
	cm.systemdSpawner.StopTURN(ctx, namespace, nodeID)
}
// spawnSFURemote sends a spawn-sfu request to a remote node
func (cm *ClusterManager) spawnSFURemote(ctx context.Context, nodeIP string, cfg SFUInstanceConfig) error {
	// TURN servers are serialized to generic maps for JSON transport.
	turnServers := make([]map[string]interface{}, 0, len(cfg.TURNServers))
	for _, ts := range cfg.TURNServers {
		turnServers = append(turnServers, map[string]interface{}{
			"host": ts.Host,
			"port": ts.Port,
		})
	}

	payload := map[string]interface{}{
		"action":          "spawn-sfu",
		"namespace":       cfg.Namespace,
		"node_id":         cfg.NodeID,
		"sfu_listen_addr": cfg.ListenAddr,
		"sfu_media_start": cfg.MediaPortStart,
		"sfu_media_end":   cfg.MediaPortEnd,
		"turn_servers":    turnServers,
		"turn_secret":     cfg.TURNSecret,
		"turn_cred_ttl":   cfg.TURNCredTTL,
		"rqlite_dsn":      cfg.RQLiteDSN,
	}
	_, err := cm.sendSpawnRequest(ctx, nodeIP, payload)
	return err
}
// spawnTURNRemote sends a spawn-turn request to a remote node
func (cm *ClusterManager) spawnTURNRemote(ctx context.Context, nodeIP string, cfg TURNInstanceConfig) error {
	payload := map[string]interface{}{
		"action":           "spawn-turn",
		"namespace":        cfg.Namespace,
		"node_id":          cfg.NodeID,
		"turn_listen_addr": cfg.ListenAddr,
		"turn_tls_addr":    cfg.TLSListenAddr,
		"turn_public_ip":   cfg.PublicIP,
		"turn_realm":       cfg.Realm,
		"turn_auth_secret": cfg.AuthSecret,
		"turn_relay_start": cfg.RelayPortStart,
		"turn_relay_end":   cfg.RelayPortEnd,
	}
	_, err := cm.sendSpawnRequest(ctx, nodeIP, payload)
	return err
}
// getWebRTCBlocksByType returns all WebRTC port blocks of a given type
// ("sfu" or "turn") for a cluster.
func (cm *ClusterManager) getWebRTCBlocksByType(ctx context.Context, clusterID, serviceType string) ([]WebRTCPortBlock, error) {
	allBlocks, err := cm.webrtcPortAllocator.GetAllPorts(ctx, clusterID)
	if err != nil {
		return nil, err
	}
	var matching []WebRTCPortBlock
	for i := range allBlocks {
		if allBlocks[i].ServiceType == serviceType {
			matching = append(matching, allBlocks[i])
		}
	}
	return matching, nil
}
// getNodeIP looks up the internal IP for a node ID from a list.
// Returns the empty string when the node is not in the list.
func (cm *ClusterManager) getNodeIP(nodes []clusterNodeInfo, nodeID string) string {
	for i := range nodes {
		if nodes[i].NodeID == nodeID {
			return nodes[i].InternalIP
		}
	}
	return ""
}
// cleanupWebRTCOnError cleans up partial WebRTC allocations when EnableWebRTC
// fails mid-way: stops any spawned SFU/TURN services, releases port
// allocations, and deletes the namespace_webrtc_config row.
// All steps are best-effort; errors are intentionally ignored.
func (cm *ClusterManager) cleanupWebRTCOnError(ctx context.Context, clusterID, namespaceName string, nodes []clusterNodeInfo) {
	cm.logger.Warn("Cleaning up partial WebRTC enablement",
		zap.String("namespace", namespaceName),
		zap.String("cluster_id", clusterID))

	// Stop any spawned SFU/TURN services on every known node.
	for _, n := range nodes {
		cm.stopSFUOnNode(ctx, n.NodeID, n.InternalIP, namespaceName)
		cm.stopTURNOnNode(ctx, n.NodeID, n.InternalIP, namespaceName)
	}

	// Release port allocations and drop the config row.
	cm.webrtcPortAllocator.DeallocateAll(ctx, clusterID)
	cm.db.Exec(client.WithInternalAuth(ctx), `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, clusterID)
}
// updateClusterStateWithWebRTC updates the cluster-state.json on all nodes
// to include (or remove) WebRTC port information.
// Pass nil maps to clear WebRTC state (when disabling).
//
// The full state file is rebuilt from the core port allocations, then
// per-node WebRTC fields are overlaid. Failures are logged, never returned:
// the state file is a recovery aid, not the source of truth.
func (cm *ClusterManager) updateClusterStateWithWebRTC(
	ctx context.Context,
	cluster *NamespaceCluster,
	nodes []clusterNodeInfo,
	sfuBlocks map[string]*WebRTCPortBlock,
	turnBlocks map[string]*WebRTCPortBlock,
) {
	// Get existing port blocks for base state
	portBlocks, err := cm.portAllocator.GetAllPortBlocks(ctx, cluster.ID)
	if err != nil {
		cm.logger.Warn("Failed to get port blocks for state update", zap.Error(err))
		return
	}

	// Build nodeID -> PortBlock map
	nodePortMap := make(map[string]*PortBlock)
	for i := range portBlocks {
		nodePortMap[portBlocks[i].NodeID] = &portBlocks[i]
	}

	// Build AllNodes list (nodes without a core port block are skipped —
	// they have no services this state file needs to describe)
	var allStateNodes []ClusterLocalStateNode
	for _, node := range nodes {
		pb := nodePortMap[node.NodeID]
		if pb == nil {
			continue
		}
		allStateNodes = append(allStateNodes, ClusterLocalStateNode{
			NodeID:              node.NodeID,
			InternalIP:          node.InternalIP,
			RQLiteHTTPPort:      pb.RQLiteHTTPPort,
			RQLiteRaftPort:      pb.RQLiteRaftPort,
			OlricHTTPPort:       pb.OlricHTTPPort,
			OlricMemberlistPort: pb.OlricMemberlistPort,
		})
	}

	// Save state on each node
	for _, node := range nodes {
		pb := nodePortMap[node.NodeID]
		if pb == nil {
			continue
		}

		state := &ClusterLocalState{
			ClusterID:     cluster.ID,
			NamespaceName: cluster.NamespaceName,
			LocalNodeID:   node.NodeID,
			LocalIP:       node.InternalIP,
			LocalPorts: ClusterLocalStatePorts{
				RQLiteHTTPPort:      pb.RQLiteHTTPPort,
				RQLiteRaftPort:      pb.RQLiteRaftPort,
				OlricHTTPPort:       pb.OlricHTTPPort,
				OlricMemberlistPort: pb.OlricMemberlistPort,
				GatewayHTTPPort:     pb.GatewayHTTPPort,
			},
			AllNodes:   allStateNodes,
			HasGateway: true,
			BaseDomain: cm.baseDomain,
			SavedAt:    time.Now(),
		}

		// Add WebRTC fields if enabling (nil maps leave the WebRTC fields at
		// their zero values, which effectively clears them on disable)
		if sfuBlocks != nil {
			if sfuBlock, ok := sfuBlocks[node.NodeID]; ok {
				state.HasSFU = true
				state.SFUSignalingPort = sfuBlock.SFUSignalingPort
				state.SFUMediaPortStart = sfuBlock.SFUMediaPortStart
				state.SFUMediaPortEnd = sfuBlock.SFUMediaPortEnd
			}
		}
		if turnBlocks != nil {
			if turnBlock, ok := turnBlocks[node.NodeID]; ok {
				state.HasTURN = true
				state.TURNListenPort = turnBlock.TURNListenPort
				state.TURNTLSPort = turnBlock.TURNTLSPort
				state.TURNRelayPortStart = turnBlock.TURNRelayPortStart
				state.TURNRelayPortEnd = turnBlock.TURNRelayPortEnd
			}
		}

		// Local node writes directly; remote nodes get the state pushed over
		// the spawn-request channel.
		if node.NodeID == cm.localNodeID {
			if err := cm.saveLocalState(state); err != nil {
				cm.logger.Warn("Failed to save local cluster state",
					zap.String("namespace", cluster.NamespaceName),
					zap.Error(err))
			}
		} else {
			cm.saveRemoteState(ctx, node.InternalIP, cluster.NamespaceName, state)
		}
	}
}
// saveRemoteState sends cluster state to a remote node for persistence.
// Failures are logged, not returned — state persistence is best-effort.
func (cm *ClusterManager) saveRemoteState(ctx context.Context, nodeIP, namespace string, state *ClusterLocalState) {
	payload := map[string]interface{}{
		"action":        "save-cluster-state",
		"namespace":     namespace,
		"cluster_state": state,
	}
	if _, err := cm.sendSpawnRequest(ctx, nodeIP, payload); err != nil {
		cm.logger.Warn("Failed to save cluster state on remote node",
			zap.String("node_ip", nodeIP),
			zap.Error(err))
	}
}

View File

@ -300,6 +300,78 @@ func (drm *DNSRecordManager) DisableNamespaceRecord(ctx context.Context, namespa
return nil return nil
} }
// CreateTURNRecords creates DNS A records for TURN servers.
// TURN records follow the pattern: turn.ns-{namespace}.{baseDomain} -> TURN node IPs
func (drm *DNSRecordManager) CreateTURNRecords(ctx context.Context, namespaceName string, turnIPs []string) error {
	if len(turnIPs) == 0 {
		return &ClusterError{Message: "no TURN IPs provided for DNS records"}
	}

	internalCtx := client.WithInternalAuth(ctx)
	fqdn := fmt.Sprintf("turn.ns-%s.%s.", namespaceName, drm.baseDomain)
	nsTag := "namespace-turn:" + namespaceName

	drm.logger.Info("Creating TURN DNS records",
		zap.String("namespace", namespaceName),
		zap.String("fqdn", fqdn),
		zap.Strings("turn_ips", turnIPs),
	)

	// Replace any existing TURN records for this namespace (best-effort delete).
	_, _ = drm.db.Exec(internalCtx, `DELETE FROM dns_records WHERE fqdn = ? AND namespace = ?`, fqdn, nsTag)

	// One A record per TURN node IP, all under the same FQDN (round-robin).
	insertQuery := `
		INSERT INTO dns_records (
			id, fqdn, record_type, value, ttl, namespace, created_by, is_active, created_at, updated_at
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`
	now := time.Now()
	for _, turnIP := range turnIPs {
		_, err := drm.db.Exec(internalCtx, insertQuery,
			uuid.New().String(), fqdn, "A", turnIP, 60,
			nsTag,
			"cluster-manager",
			true, now, now,
		)
		if err != nil {
			return &ClusterError{
				Message: fmt.Sprintf("failed to create TURN DNS record %s -> %s", fqdn, turnIP),
				Cause:   err,
			}
		}
	}

	drm.logger.Info("TURN DNS records created",
		zap.String("namespace", namespaceName),
		zap.Int("record_count", len(turnIPs)),
	)
	return nil
}
// DeleteTURNRecords deletes all TURN DNS records for a namespace.
func (drm *DNSRecordManager) DeleteTURNRecords(ctx context.Context, namespaceName string) error {
	drm.logger.Info("Deleting TURN DNS records",
		zap.String("namespace", namespaceName),
	)

	_, err := drm.db.Exec(client.WithInternalAuth(ctx),
		`DELETE FROM dns_records WHERE namespace = ?`, "namespace-turn:"+namespaceName)
	if err != nil {
		return &ClusterError{
			Message: "failed to delete TURN DNS records",
			Cause:   err,
		}
	}
	return nil
}
// EnableNamespaceRecord marks a specific IP's record as active (for recovery) // EnableNamespaceRecord marks a specific IP's record as active (for recovery)
func (drm *DNSRecordManager) EnableNamespaceRecord(ctx context.Context, namespaceName, ip string) error { func (drm *DNSRecordManager) EnableNamespaceRecord(ctx context.Context, namespaceName, ip string) error {
internalCtx := client.WithInternalAuth(ctx) internalCtx := client.WithInternalAuth(ctx)

View File

@ -10,7 +10,9 @@ import (
"github.com/DeBrosOfficial/network/pkg/gateway" "github.com/DeBrosOfficial/network/pkg/gateway"
"github.com/DeBrosOfficial/network/pkg/olric" "github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/rqlite" "github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/sfu"
"github.com/DeBrosOfficial/network/pkg/systemd" "github.com/DeBrosOfficial/network/pkg/systemd"
"github.com/DeBrosOfficial/network/pkg/turn"
"go.uber.org/zap" "go.uber.org/zap"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@ -289,6 +291,185 @@ func (s *SystemdSpawner) StopGateway(ctx context.Context, namespace, nodeID stri
return s.systemdMgr.StopService(namespace, systemd.ServiceTypeGateway) return s.systemdMgr.StopService(namespace, systemd.ServiceTypeGateway)
} }
// SFUInstanceConfig holds configuration for spawning an SFU instance
type SFUInstanceConfig struct {
	Namespace      string                 // namespace the SFU belongs to
	NodeID         string                 // cluster node this instance runs on
	ListenAddr     string                 // WireGuard IP:port (e.g., "10.0.0.1:30000")
	MediaPortStart int                    // Start of RTP media port range
	MediaPortEnd   int                    // End of RTP media port range
	TURNServers    []sfu.TURNServerConfig // TURN servers to advertise to peers
	TURNSecret     string                 // HMAC-SHA1 shared secret
	TURNCredTTL    int                    // Credential TTL in seconds
	RQLiteDSN      string                 // Namespace-local RQLite DSN
}
// SpawnSFU starts an SFU instance using systemd.
//
// It renders the SFU YAML config, writes an env file pointing the unit at
// that config, starts the service, and waits up to 30s for it to report
// active. The config file embeds the TURN shared secret, so it is written
// with 0600 permissions rather than world-readable.
func (s *SystemdSpawner) SpawnSFU(ctx context.Context, namespace, nodeID string, cfg SFUInstanceConfig) error {
	s.logger.Info("Spawning SFU via systemd",
		zap.String("namespace", namespace),
		zap.String("node_id", nodeID),
		zap.String("listen_addr", cfg.ListenAddr))

	// Create config directory
	configDir := filepath.Join(s.namespaceBase, namespace, "configs")
	if err := os.MkdirAll(configDir, 0755); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}

	configPath := filepath.Join(configDir, fmt.Sprintf("sfu-%s.yaml", nodeID))

	// Build SFU YAML config
	sfuConfig := sfu.Config{
		ListenAddr:        cfg.ListenAddr,
		Namespace:         cfg.Namespace,
		MediaPortStart:    cfg.MediaPortStart,
		MediaPortEnd:      cfg.MediaPortEnd,
		TURNServers:       cfg.TURNServers,
		TURNSecret:        cfg.TURNSecret,
		TURNCredentialTTL: cfg.TURNCredTTL,
		RQLiteDSN:         cfg.RQLiteDSN,
	}

	configBytes, err := yaml.Marshal(sfuConfig)
	if err != nil {
		return fmt.Errorf("failed to marshal SFU config: %w", err)
	}

	// 0600: config contains the TURN HMAC shared secret — must not be
	// world-readable. NOTE(review): assumes the SFU unit runs as (or can
	// read files owned by) the same user writing this file — confirm
	// against the sfu systemd unit definition.
	if err := os.WriteFile(configPath, configBytes, 0600); err != nil {
		return fmt.Errorf("failed to write SFU config: %w", err)
	}

	s.logger.Info("Created SFU config file",
		zap.String("path", configPath),
		zap.String("namespace", namespace),
		zap.String("node_id", nodeID))

	// Generate environment file pointing to config
	envVars := map[string]string{
		"SFU_CONFIG": configPath,
	}
	if err := s.systemdMgr.GenerateEnvFile(namespace, nodeID, systemd.ServiceTypeSFU, envVars); err != nil {
		return fmt.Errorf("failed to generate SFU env file: %w", err)
	}

	// Start the systemd service
	if err := s.systemdMgr.StartService(namespace, systemd.ServiceTypeSFU); err != nil {
		return fmt.Errorf("failed to start SFU service: %w", err)
	}

	// Wait for service to be active
	if err := s.waitForService(namespace, systemd.ServiceTypeSFU, 30*time.Second); err != nil {
		return fmt.Errorf("SFU service did not become active: %w", err)
	}

	s.logger.Info("SFU spawned successfully via systemd",
		zap.String("namespace", namespace),
		zap.String("node_id", nodeID))
	return nil
}
// StopSFU stops the namespace's SFU instance by stopping its systemd unit.
// The nodeID is logged for traceability but the unit is addressed per-namespace.
func (s *SystemdSpawner) StopSFU(ctx context.Context, namespace, nodeID string) error {
	fields := []zap.Field{
		zap.String("namespace", namespace),
		zap.String("node_id", nodeID),
	}
	s.logger.Info("Stopping SFU via systemd", fields...)
	return s.systemdMgr.StopService(namespace, systemd.ServiceTypeSFU)
}
// TURNInstanceConfig holds configuration for spawning a TURN instance.
// Values here are marshaled into the per-node turn-<nodeID>.yaml consumed by
// the TURN systemd service.
type TURNInstanceConfig struct {
	Namespace      string // Namespace this TURN instance serves
	NodeID         string // Node identifier used in config/env file names
	ListenAddr     string // e.g., "0.0.0.0:3478"
	TLSListenAddr  string // e.g., "0.0.0.0:443" (UDP, no conflict with Caddy TCP)
	PublicIP       string // Public IP for TURN relay allocations
	Realm          string // TURN realm (typically base domain)
	AuthSecret     string // HMAC-SHA1 shared secret
	RelayPortStart int    // Start of relay port range
	RelayPortEnd   int    // End of relay port range
}
// SpawnTURN starts a TURN instance using systemd.
//
// It renders the TURN YAML config into <namespaceBase>/<namespace>/configs/,
// writes a systemd env file pointing the unit at that config, starts the
// unit, and blocks until it reports active (30s timeout).
//
// The rendered config embeds the HMAC-SHA1 auth secret, so the file is
// written with 0600 permissions instead of world-readable 0644.
func (s *SystemdSpawner) SpawnTURN(ctx context.Context, namespace, nodeID string, cfg TURNInstanceConfig) error {
	s.logger.Info("Spawning TURN via systemd",
		zap.String("namespace", namespace),
		zap.String("node_id", nodeID),
		zap.String("listen_addr", cfg.ListenAddr),
		zap.String("public_ip", cfg.PublicIP))

	// Create config directory
	configDir := filepath.Join(s.namespaceBase, namespace, "configs")
	if err := os.MkdirAll(configDir, 0755); err != nil {
		return fmt.Errorf("failed to create config directory: %w", err)
	}
	configPath := filepath.Join(configDir, fmt.Sprintf("turn-%s.yaml", nodeID))

	// Build TURN YAML config
	turnConfig := turn.Config{
		ListenAddr:     cfg.ListenAddr,
		TLSListenAddr:  cfg.TLSListenAddr,
		PublicIP:       cfg.PublicIP,
		Realm:          cfg.Realm,
		AuthSecret:     cfg.AuthSecret,
		RelayPortStart: cfg.RelayPortStart,
		RelayPortEnd:   cfg.RelayPortEnd,
		Namespace:      cfg.Namespace,
	}
	configBytes, err := yaml.Marshal(turnConfig)
	if err != nil {
		return fmt.Errorf("failed to marshal TURN config: %w", err)
	}
	// 0600: the config contains the TURN auth secret — keep it unreadable to
	// other local users.
	if err := os.WriteFile(configPath, configBytes, 0600); err != nil {
		return fmt.Errorf("failed to write TURN config: %w", err)
	}
	s.logger.Info("Created TURN config file",
		zap.String("path", configPath),
		zap.String("namespace", namespace),
		zap.String("node_id", nodeID))

	// Generate environment file pointing to config
	envVars := map[string]string{
		"TURN_CONFIG": configPath,
	}
	if err := s.systemdMgr.GenerateEnvFile(namespace, nodeID, systemd.ServiceTypeTURN, envVars); err != nil {
		return fmt.Errorf("failed to generate TURN env file: %w", err)
	}

	// Start the systemd service
	if err := s.systemdMgr.StartService(namespace, systemd.ServiceTypeTURN); err != nil {
		return fmt.Errorf("failed to start TURN service: %w", err)
	}

	// Wait for service to be active
	if err := s.waitForService(namespace, systemd.ServiceTypeTURN, 30*time.Second); err != nil {
		return fmt.Errorf("TURN service did not become active: %w", err)
	}

	s.logger.Info("TURN spawned successfully via systemd",
		zap.String("namespace", namespace),
		zap.String("node_id", nodeID))
	return nil
}
// StopTURN stops the namespace's TURN instance by stopping its systemd unit.
// The nodeID is logged for traceability but the unit is addressed per-namespace.
func (s *SystemdSpawner) StopTURN(ctx context.Context, namespace, nodeID string) error {
	fields := []zap.Field{
		zap.String("namespace", namespace),
		zap.String("node_id", nodeID),
	}
	s.logger.Info("Stopping TURN via systemd", fields...)
	return s.systemdMgr.StopService(namespace, systemd.ServiceTypeTURN)
}
// SaveClusterState writes cluster state JSON to the namespace data directory. // SaveClusterState writes cluster state JSON to the namespace data directory.
// Used by the spawn handler to persist state received from the coordinator node. // Used by the spawn handler to persist state received from the coordinator node.
func (s *SystemdSpawner) SaveClusterState(namespace string, data []byte) error { func (s *SystemdSpawner) SaveClusterState(namespace string, data []byte) error {

View File

@ -24,6 +24,8 @@ const (
NodeRoleRQLiteFollower NodeRole = "rqlite_follower" NodeRoleRQLiteFollower NodeRole = "rqlite_follower"
NodeRoleOlric NodeRole = "olric" NodeRoleOlric NodeRole = "olric"
NodeRoleGateway NodeRole = "gateway" NodeRoleGateway NodeRole = "gateway"
NodeRoleSFU NodeRole = "sfu"
NodeRoleTURN NodeRole = "turn"
) )
// NodeStatus represents the status of a service on a node // NodeStatus represents the status of a service on a node
@ -62,6 +64,12 @@ const (
EventNodeReplaced EventType = "node_replaced" EventNodeReplaced EventType = "node_replaced"
EventRecoveryComplete EventType = "recovery_complete" EventRecoveryComplete EventType = "recovery_complete"
EventRecoveryFailed EventType = "recovery_failed" EventRecoveryFailed EventType = "recovery_failed"
EventWebRTCEnabled EventType = "webrtc_enabled"
EventWebRTCDisabled EventType = "webrtc_disabled"
EventSFUStarted EventType = "sfu_started"
EventSFUStopped EventType = "sfu_stopped"
EventTURNStarted EventType = "turn_started"
EventTURNStopped EventType = "turn_stopped"
) )
// Port allocation constants // Port allocation constants
@ -80,6 +88,39 @@ const (
MaxNamespacesPerNode = (NamespacePortRangeEnd - NamespacePortRangeStart + 1) / PortsPerNamespace // 20 MaxNamespacesPerNode = (NamespacePortRangeEnd - NamespacePortRangeStart + 1) / PortsPerNamespace // 20
) )
// WebRTC port allocation constants
// These are separate from the core namespace port range (10000-10099)
// to avoid breaking existing port blocks.
const (
	// SFU media port range: 20000-29999
	// Each namespace gets a 500-port sub-range for RTP media
	SFUMediaPortRangeStart    = 20000
	SFUMediaPortRangeEnd      = 29999
	SFUMediaPortsPerNamespace = 500

	// SFU signaling ports: 30000-30099
	// Each namespace gets 1 signaling port per node
	SFUSignalingPortRangeStart = 30000
	SFUSignalingPortRangeEnd   = 30099

	// TURN relay port range: 49152-65535 (the IANA ephemeral/dynamic range)
	// Each namespace gets an 800-port sub-range for TURN relay
	TURNRelayPortRangeStart   = 49152
	TURNRelayPortRangeEnd     = 65535
	TURNRelayPortsPerNamespace = 800

	// TURN listen ports (standard): 3478 is the IANA-assigned STUN/TURN port;
	// 443 is used for TLS-looking traffic to traverse restrictive firewalls.
	TURNDefaultPort = 3478
	TURNTLSPort     = 443

	// Default TURN credential TTL in seconds (10 minutes)
	DefaultTURNCredentialTTL = 600

	// Default service counts per namespace
	DefaultSFUNodeCount  = 3 // SFU on all 3 nodes
	DefaultTURNNodeCount = 2 // TURN on 2 of 3 nodes for HA
)
// Default cluster sizes // Default cluster sizes
const ( const (
DefaultRQLiteNodeCount = 3 DefaultRQLiteNodeCount = 3
@ -206,4 +247,58 @@ var (
ErrNamespaceNotFound = &ClusterError{Message: "namespace not found"} ErrNamespaceNotFound = &ClusterError{Message: "namespace not found"}
ErrInvalidClusterStatus = &ClusterError{Message: "invalid cluster status for operation"} ErrInvalidClusterStatus = &ClusterError{Message: "invalid cluster status for operation"}
ErrRecoveryInProgress = &ClusterError{Message: "recovery already in progress for this cluster"} ErrRecoveryInProgress = &ClusterError{Message: "recovery already in progress for this cluster"}
ErrWebRTCAlreadyEnabled = &ClusterError{Message: "WebRTC is already enabled for this namespace"}
ErrWebRTCNotEnabled = &ClusterError{Message: "WebRTC is not enabled for this namespace"}
ErrNoWebRTCPortsAvailable = &ClusterError{Message: "no WebRTC ports available on node"}
) )
// WebRTCConfig represents the per-namespace WebRTC configuration stored in the database.
// One row exists per namespace cluster; Enabled/DisabledAt track the feature's
// lifecycle rather than deleting the row.
type WebRTCConfig struct {
	ID                 string     `json:"id" db:"id"`
	NamespaceClusterID string     `json:"namespace_cluster_id" db:"namespace_cluster_id"`
	NamespaceName      string     `json:"namespace_name" db:"namespace_name"`
	Enabled            bool       `json:"enabled" db:"enabled"`
	TURNSharedSecret   string     `json:"-" db:"turn_shared_secret"` // Never serialize secret to JSON
	TURNCredentialTTL  int        `json:"turn_credential_ttl" db:"turn_credential_ttl"`
	SFUNodeCount       int        `json:"sfu_node_count" db:"sfu_node_count"`
	TURNNodeCount      int        `json:"turn_node_count" db:"turn_node_count"`
	EnabledBy          string     `json:"enabled_by" db:"enabled_by"`
	EnabledAt          time.Time  `json:"enabled_at" db:"enabled_at"`
	DisabledAt         *time.Time `json:"disabled_at,omitempty" db:"disabled_at"` // nil while enabled
}
// WebRTCRoom represents an active WebRTC room tracked in the database.
// Rooms are pinned to a single SFU node; clients connect to the SFU at
// SFUInternalIP:SFUSignalingPort.
type WebRTCRoom struct {
	ID                 string    `json:"id" db:"id"`
	NamespaceClusterID string    `json:"namespace_cluster_id" db:"namespace_cluster_id"`
	NamespaceName      string    `json:"namespace_name" db:"namespace_name"`
	RoomID             string    `json:"room_id" db:"room_id"`
	SFUNodeID          string    `json:"sfu_node_id" db:"sfu_node_id"`
	SFUInternalIP      string    `json:"sfu_internal_ip" db:"sfu_internal_ip"`
	SFUSignalingPort   int       `json:"sfu_signaling_port" db:"sfu_signaling_port"`
	ParticipantCount   int       `json:"participant_count" db:"participant_count"`
	MaxParticipants    int       `json:"max_participants" db:"max_participants"`
	CreatedAt          time.Time `json:"created_at" db:"created_at"`
	LastActivity       time.Time `json:"last_activity" db:"last_activity"` // used for idle-room cleanup — TODO confirm against reaper
}
// WebRTCPortBlock represents allocated WebRTC ports for a namespace on a node.
// A row describes either an SFU allocation or a TURN allocation (see
// ServiceType); the fields for the other service remain zero.
type WebRTCPortBlock struct {
	ID                 string `json:"id" db:"id"`
	NodeID             string `json:"node_id" db:"node_id"`
	NamespaceClusterID string `json:"namespace_cluster_id" db:"namespace_cluster_id"`
	ServiceType        string `json:"service_type" db:"service_type"` // "sfu" or "turn"
	// SFU ports (zero when ServiceType == "turn")
	SFUSignalingPort  int `json:"sfu_signaling_port,omitempty" db:"sfu_signaling_port"`
	SFUMediaPortStart int `json:"sfu_media_port_start,omitempty" db:"sfu_media_port_start"`
	SFUMediaPortEnd   int `json:"sfu_media_port_end,omitempty" db:"sfu_media_port_end"`
	// TURN ports (zero when ServiceType == "sfu")
	TURNListenPort     int `json:"turn_listen_port,omitempty" db:"turn_listen_port"`
	TURNTLSPort        int `json:"turn_tls_port,omitempty" db:"turn_tls_port"`
	TURNRelayPortStart int `json:"turn_relay_port_start,omitempty" db:"turn_relay_port_start"`
	TURNRelayPortEnd   int `json:"turn_relay_port_end,omitempty" db:"turn_relay_port_end"`
	AllocatedAt        time.Time `json:"allocated_at" db:"allocated_at"`
}

View File

@ -0,0 +1,519 @@
package namespace
import (
"context"
"fmt"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/google/uuid"
"go.uber.org/zap"
)
// WebRTCPortAllocator manages port allocation for SFU and TURN services.
// Uses the webrtc_port_allocations table, separate from namespace_port_allocations,
// to avoid breaking existing port blocks.
//
// All methods are safe to call concurrently across processes: uniqueness is
// enforced by the database, and allocation retries on conflict.
type WebRTCPortAllocator struct {
	db     rqlite.Client // cluster-wide RQLite store backing allocations
	logger *zap.Logger   // tagged with component=webrtc-port-allocator
}
// NewWebRTCPortAllocator creates a new WebRTC port allocator
func NewWebRTCPortAllocator(db rqlite.Client, logger *zap.Logger) *WebRTCPortAllocator {
return &WebRTCPortAllocator{
db: db,
logger: logger.With(zap.String("component", "webrtc-port-allocator")),
}
}
// AllocateSFUPorts allocates SFU ports for a namespace on a node.
// Each namespace gets: 1 signaling port (30000-30099) + 500 media ports (20000-29999).
// Returns the existing allocation if one already exists (idempotent).
//
// Concurrent allocation conflicts (as classified by isConflictError) are
// retried up to 10 times with exponential backoff starting at 100ms.
func (wpa *WebRTCPortAllocator) AllocateSFUPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) {
	internalCtx := client.WithInternalAuth(ctx)

	// Check for existing allocation (idempotent)
	existing, err := wpa.GetSFUPorts(ctx, namespaceClusterID, nodeID)
	if err == nil && existing != nil {
		wpa.logger.Debug("SFU ports already allocated",
			zap.String("node_id", nodeID),
			zap.String("namespace_cluster_id", namespaceClusterID),
			zap.Int("signaling_port", existing.SFUSignalingPort),
		)
		return existing, nil
	}

	// Retry logic for concurrent allocation conflicts
	maxRetries := 10
	retryDelay := 100 * time.Millisecond
	for attempt := 0; attempt < maxRetries; attempt++ {
		block, err := wpa.tryAllocateSFUPorts(internalCtx, nodeID, namespaceClusterID)
		if err == nil {
			wpa.logger.Info("SFU ports allocated",
				zap.String("node_id", nodeID),
				zap.String("namespace_cluster_id", namespaceClusterID),
				zap.Int("signaling_port", block.SFUSignalingPort),
				zap.Int("media_start", block.SFUMediaPortStart),
				zap.Int("media_end", block.SFUMediaPortEnd),
				zap.Int("attempt", attempt+1),
			)
			return block, nil
		}
		if isConflictError(err) {
			wpa.logger.Debug("SFU port allocation conflict, retrying",
				zap.String("node_id", nodeID),
				zap.Int("attempt", attempt+1),
				zap.Error(err),
			)
			// Exponential backoff: 100ms, 200ms, 400ms, ...
			time.Sleep(retryDelay)
			retryDelay *= 2
			continue
		}
		// Non-conflict errors are fatal; surface immediately.
		return nil, err
	}
	return nil, &ClusterError{
		Message: fmt.Sprintf("failed to allocate SFU ports after %d retries", maxRetries),
	}
}
// tryAllocateSFUPorts performs a single attempt to allocate SFU ports.
// It picks the lowest free signaling port and the lowest free 500-port media
// block, then inserts the allocation row; the insert may fail with a conflict
// if another allocator raced us (caller retries).
func (wpa *WebRTCPortAllocator) tryAllocateSFUPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) {
	// Get node IPs sharing the same physical address (dev environment handling).
	// Best-effort: on lookup failure, fall back to treating the node as alone.
	nodeIDs, err := wpa.getColocatedNodeIDs(ctx, nodeID)
	if err != nil {
		nodeIDs = []string{nodeID}
	}

	// Find next available SFU signaling port (30000-30099)
	signalingPort, err := wpa.findAvailablePort(ctx, nodeIDs, "sfu", "sfu_signaling_port",
		SFUSignalingPortRangeStart, SFUSignalingPortRangeEnd, 1)
	if err != nil {
		return nil, &ClusterError{
			Message: "no SFU signaling port available on node",
			Cause:   err,
		}
	}

	// Find next available SFU media port block (20000-29999, 500 per namespace)
	mediaStart, err := wpa.findAvailablePortBlock(ctx, nodeIDs, "sfu", "sfu_media_port_start",
		SFUMediaPortRangeStart, SFUMediaPortRangeEnd, SFUMediaPortsPerNamespace)
	if err != nil {
		return nil, &ClusterError{
			Message: "no SFU media port range available on node",
			Cause:   err,
		}
	}

	block := &WebRTCPortBlock{
		ID:                 uuid.New().String(),
		NodeID:             nodeID,
		NamespaceClusterID: namespaceClusterID,
		ServiceType:        "sfu",
		SFUSignalingPort:   signalingPort,
		SFUMediaPortStart:  mediaStart,
		SFUMediaPortEnd:    mediaStart + SFUMediaPortsPerNamespace - 1,
		AllocatedAt:        time.Now(),
	}
	if err := wpa.insertPortBlock(ctx, block); err != nil {
		return nil, err
	}
	return block, nil
}
// AllocateTURNPorts allocates TURN ports for a namespace on a node.
// Each namespace gets: standard listen ports (3478/443) + 800 relay ports (49152-65535).
// Returns the existing allocation if one already exists (idempotent).
//
// Concurrent allocation conflicts (as classified by isConflictError) are
// retried up to 10 times with exponential backoff starting at 100ms.
func (wpa *WebRTCPortAllocator) AllocateTURNPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) {
	internalCtx := client.WithInternalAuth(ctx)

	// Check for existing allocation (idempotent)
	existing, err := wpa.GetTURNPorts(ctx, namespaceClusterID, nodeID)
	if err == nil && existing != nil {
		wpa.logger.Debug("TURN ports already allocated",
			zap.String("node_id", nodeID),
			zap.String("namespace_cluster_id", namespaceClusterID),
		)
		return existing, nil
	}

	// Retry logic for concurrent allocation conflicts
	maxRetries := 10
	retryDelay := 100 * time.Millisecond
	for attempt := 0; attempt < maxRetries; attempt++ {
		block, err := wpa.tryAllocateTURNPorts(internalCtx, nodeID, namespaceClusterID)
		if err == nil {
			wpa.logger.Info("TURN ports allocated",
				zap.String("node_id", nodeID),
				zap.String("namespace_cluster_id", namespaceClusterID),
				zap.Int("relay_start", block.TURNRelayPortStart),
				zap.Int("relay_end", block.TURNRelayPortEnd),
				zap.Int("attempt", attempt+1),
			)
			return block, nil
		}
		if isConflictError(err) {
			wpa.logger.Debug("TURN port allocation conflict, retrying",
				zap.String("node_id", nodeID),
				zap.Int("attempt", attempt+1),
				zap.Error(err),
			)
			// Exponential backoff: 100ms, 200ms, 400ms, ...
			time.Sleep(retryDelay)
			retryDelay *= 2
			continue
		}
		// Non-conflict errors are fatal; surface immediately.
		return nil, err
	}
	return nil, &ClusterError{
		Message: fmt.Sprintf("failed to allocate TURN ports after %d retries", maxRetries),
	}
}
// tryAllocateTURNPorts performs a single attempt to allocate TURN ports.
// Listen ports are fixed (3478/443 per node); only the relay port block is
// dynamically chosen. The insert may fail with a conflict if another
// allocator raced us (caller retries).
func (wpa *WebRTCPortAllocator) tryAllocateTURNPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) {
	// Get colocated node IDs (dev environment handling).
	// Best-effort: on lookup failure, fall back to treating the node as alone.
	nodeIDs, err := wpa.getColocatedNodeIDs(ctx, nodeID)
	if err != nil {
		nodeIDs = []string{nodeID}
	}

	// Find next available TURN relay port block (49152-65535, 800 per namespace)
	relayStart, err := wpa.findAvailablePortBlock(ctx, nodeIDs, "turn", "turn_relay_port_start",
		TURNRelayPortRangeStart, TURNRelayPortRangeEnd, TURNRelayPortsPerNamespace)
	if err != nil {
		return nil, &ClusterError{
			Message: "no TURN relay port range available on node",
			Cause:   err,
		}
	}

	block := &WebRTCPortBlock{
		ID:                 uuid.New().String(),
		NodeID:             nodeID,
		NamespaceClusterID: namespaceClusterID,
		ServiceType:        "turn",
		TURNListenPort:     TURNDefaultPort,
		TURNTLSPort:        TURNTLSPort,
		TURNRelayPortStart: relayStart,
		TURNRelayPortEnd:   relayStart + TURNRelayPortsPerNamespace - 1,
		AllocatedAt:        time.Now(),
	}
	if err := wpa.insertPortBlock(ctx, block); err != nil {
		return nil, err
	}
	return block, nil
}
// DeallocateAll releases every WebRTC port block (SFU and TURN, on all nodes)
// recorded for the given namespace cluster.
func (wpa *WebRTCPortAllocator) DeallocateAll(ctx context.Context, namespaceClusterID string) error {
	const query = `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ?`
	if _, err := wpa.db.Exec(client.WithInternalAuth(ctx), query, namespaceClusterID); err != nil {
		return &ClusterError{
			Message: "failed to deallocate WebRTC port blocks",
			Cause:   err,
		}
	}
	wpa.logger.Info("All WebRTC port blocks deallocated",
		zap.String("namespace_cluster_id", namespaceClusterID),
	)
	return nil
}
// DeallocateByNode releases the WebRTC port block recorded for one node and
// service type ("sfu" or "turn") within a namespace cluster.
func (wpa *WebRTCPortAllocator) DeallocateByNode(ctx context.Context, namespaceClusterID, nodeID, serviceType string) error {
	const query = `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ? AND node_id = ? AND service_type = ?`
	if _, err := wpa.db.Exec(client.WithInternalAuth(ctx), query, namespaceClusterID, nodeID, serviceType); err != nil {
		return &ClusterError{
			Message: fmt.Sprintf("failed to deallocate %s port block on node %s", serviceType, nodeID),
			Cause:   err,
		}
	}
	wpa.logger.Info("WebRTC port block deallocated",
		zap.String("namespace_cluster_id", namespaceClusterID),
		zap.String("node_id", nodeID),
		zap.String("service_type", serviceType),
	)
	return nil
}
// GetSFUPorts retrieves the SFU port allocation for a namespace on a node.
// Returns (nil, nil) when no allocation exists.
func (wpa *WebRTCPortAllocator) GetSFUPorts(ctx context.Context, namespaceClusterID, nodeID string) (*WebRTCPortBlock, error) {
	return wpa.getPortBlock(ctx, namespaceClusterID, nodeID, "sfu")
}
// GetTURNPorts retrieves the TURN port allocation for a namespace on a node.
// Returns (nil, nil) when no allocation exists.
func (wpa *WebRTCPortAllocator) GetTURNPorts(ctx context.Context, namespaceClusterID, nodeID string) (*WebRTCPortBlock, error) {
	return wpa.getPortBlock(ctx, namespaceClusterID, nodeID, "turn")
}
// GetAllPorts retrieves every WebRTC port block recorded for a namespace
// cluster, ordered by service type then node.
func (wpa *WebRTCPortAllocator) GetAllPorts(ctx context.Context, namespaceClusterID string) ([]WebRTCPortBlock, error) {
	const query = `
		SELECT id, node_id, namespace_cluster_id, service_type,
			sfu_signaling_port, sfu_media_port_start, sfu_media_port_end,
			turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end,
			allocated_at
		FROM webrtc_port_allocations
		WHERE namespace_cluster_id = ?
		ORDER BY service_type, node_id
	`
	var blocks []WebRTCPortBlock
	if err := wpa.db.Query(client.WithInternalAuth(ctx), &blocks, query, namespaceClusterID); err != nil {
		return nil, &ClusterError{
			Message: "failed to query WebRTC port blocks",
			Cause:   err,
		}
	}
	return blocks, nil
}
// NodeHasTURN checks if a node already has a TURN allocation from any namespace.
// Used during node selection to avoid port conflicts on standard TURN ports (3478/443).
func (wpa *WebRTCPortAllocator) NodeHasTURN(ctx context.Context, nodeID string) (bool, error) {
	type countResult struct {
		Count int `db:"count"`
	}
	const query = `SELECT COUNT(*) as count FROM webrtc_port_allocations WHERE node_id = ? AND service_type = 'turn'`

	var rows []countResult
	if err := wpa.db.Query(client.WithInternalAuth(ctx), &rows, query, nodeID); err != nil {
		return false, &ClusterError{
			Message: "failed to check TURN allocation on node",
			Cause:   err,
		}
	}
	// An empty result set means no allocation; otherwise the count decides.
	return len(rows) > 0 && rows[0].Count > 0, nil
}
// --- internal helpers ---
// getPortBlock retrieves a specific port block by cluster, node, and service type.
// Returns (nil, nil) — not an error — when no matching row exists, so callers
// can use it as an existence check.
func (wpa *WebRTCPortAllocator) getPortBlock(ctx context.Context, namespaceClusterID, nodeID, serviceType string) (*WebRTCPortBlock, error) {
	internalCtx := client.WithInternalAuth(ctx)
	var blocks []WebRTCPortBlock
	query := `
		SELECT id, node_id, namespace_cluster_id, service_type,
			sfu_signaling_port, sfu_media_port_start, sfu_media_port_end,
			turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end,
			allocated_at
		FROM webrtc_port_allocations
		WHERE namespace_cluster_id = ? AND node_id = ? AND service_type = ?
		LIMIT 1
	`
	err := wpa.db.Query(internalCtx, &blocks, query, namespaceClusterID, nodeID, serviceType)
	if err != nil {
		return nil, &ClusterError{
			Message: fmt.Sprintf("failed to query %s port block", serviceType),
			Cause:   err,
		}
	}
	if len(blocks) == 0 {
		return nil, nil
	}
	return &blocks[0], nil
}
// insertPortBlock inserts a WebRTC port allocation record.
// The caller supplies a ctx that already carries internal auth. Database
// uniqueness constraints are relied on to reject duplicate/conflicting
// allocations; the resulting error surfaces to the retry loop in the callers.
func (wpa *WebRTCPortAllocator) insertPortBlock(ctx context.Context, block *WebRTCPortBlock) error {
	query := `
		INSERT INTO webrtc_port_allocations (
			id, node_id, namespace_cluster_id, service_type,
			sfu_signaling_port, sfu_media_port_start, sfu_media_port_end,
			turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end,
			allocated_at
		) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
	`
	_, err := wpa.db.Exec(ctx, query,
		block.ID,
		block.NodeID,
		block.NamespaceClusterID,
		block.ServiceType,
		block.SFUSignalingPort,
		block.SFUMediaPortStart,
		block.SFUMediaPortEnd,
		block.TURNListenPort,
		block.TURNTLSPort,
		block.TURNRelayPortStart,
		block.TURNRelayPortEnd,
		block.AllocatedAt,
	)
	if err != nil {
		return &ClusterError{
			Message: fmt.Sprintf("failed to insert %s port allocation", block.ServiceType),
			Cause:   err,
		}
	}
	return nil
}
// getColocatedNodeIDs returns all node IDs that share the same IP address as the given node.
// In dev environments, multiple logical nodes share one physical IP — port ranges must not overlap.
// In production (one node per IP), returns only the given nodeID.
//
// This lookup is deliberately best-effort: any query failure or empty result
// degrades to the single-node answer rather than failing allocation.
func (wpa *WebRTCPortAllocator) getColocatedNodeIDs(ctx context.Context, nodeID string) ([]string, error) {
	// Get this node's IP
	type nodeInfo struct {
		IPAddress string `db:"ip_address"`
	}
	var infos []nodeInfo
	if err := wpa.db.Query(ctx, &infos, `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1`, nodeID); err != nil || len(infos) == 0 {
		return []string{nodeID}, nil
	}
	ip := infos[0].IPAddress
	if ip == "" {
		return []string{nodeID}, nil
	}

	// Check if multiple nodes share this IP
	type nodeIDRow struct {
		ID string `db:"id"`
	}
	var colocated []nodeIDRow
	if err := wpa.db.Query(ctx, &colocated, `SELECT id FROM dns_nodes WHERE ip_address = ?`, ip); err != nil || len(colocated) <= 1 {
		return []string{nodeID}, nil
	}

	ids := make([]string, len(colocated))
	for i, n := range colocated {
		ids[i] = n.ID
	}
	wpa.logger.Debug("Multiple nodes share IP, allocating globally",
		zap.String("ip_address", ip),
		zap.Int("node_count", len(ids)),
	)
	return ids, nil
}
// findAvailablePort finds the next available single port in a range on the
// given nodes, scanning from rangeStart in increments of step and returning
// the first value not already recorded in the allocations table.
func (wpa *WebRTCPortAllocator) findAvailablePort(ctx context.Context, nodeIDs []string, serviceType, portColumn string, rangeStart, rangeEnd, step int) (int, error) {
	taken, err := wpa.getAllocatedValues(ctx, nodeIDs, serviceType, portColumn)
	if err != nil {
		return 0, err
	}
	// Build a membership set of already-allocated ports.
	inUse := make(map[int]bool, len(taken))
	for _, p := range taken {
		inUse[p] = true
	}
	for candidate := rangeStart; candidate <= rangeEnd; candidate += step {
		if inUse[candidate] {
			continue
		}
		return candidate, nil
	}
	return 0, ErrNoWebRTCPortsAvailable
}
// findAvailablePortBlock finds the next available contiguous port block in a
// range. Blocks are aligned on multiples of blockSize from rangeStart, and a
// block is free when its start value has not been recorded for any namespace.
func (wpa *WebRTCPortAllocator) findAvailablePortBlock(ctx context.Context, nodeIDs []string, serviceType, portColumn string, rangeStart, rangeEnd, blockSize int) (int, error) {
	taken, err := wpa.getAllocatedValues(ctx, nodeIDs, serviceType, portColumn)
	if err != nil {
		return 0, err
	}
	// Build a membership set of already-allocated block starts.
	inUse := make(map[int]bool, len(taken))
	for _, p := range taken {
		inUse[p] = true
	}
	for start := rangeStart; start+blockSize-1 <= rangeEnd; start += blockSize {
		if inUse[start] {
			continue
		}
		return start, nil
	}
	return 0, ErrNoWebRTCPortsAvailable
}
// getAllocatedValues queries the allocated port values for a given column across colocated nodes.
//
// NOTE: portColumn is interpolated into the SQL via Sprintf — it must only
// ever come from internal callers passing known column-name constants, never
// from user input.
func (wpa *WebRTCPortAllocator) getAllocatedValues(ctx context.Context, nodeIDs []string, serviceType, portColumn string) ([]int, error) {
	type portRow struct {
		Port int `db:"port_val"`
	}
	var rows []portRow
	if len(nodeIDs) == 1 {
		// Common production case: one node per IP, query by node ID directly.
		query := fmt.Sprintf(
			`SELECT %s as port_val FROM webrtc_port_allocations WHERE node_id = ? AND service_type = ? AND %s > 0 ORDER BY %s ASC`,
			portColumn, portColumn, portColumn,
		)
		if err := wpa.db.Query(ctx, &rows, query, nodeIDs[0], serviceType); err != nil {
			return nil, &ClusterError{
				Message: "failed to query allocated WebRTC ports",
				Cause:   err,
			}
		}
	} else {
		// Multiple colocated nodes — query by joining with dns_nodes on IP
		// Get the IP of the first node (they all share the same IP)
		type nodeInfo struct {
			IPAddress string `db:"ip_address"`
		}
		var infos []nodeInfo
		if err := wpa.db.Query(ctx, &infos, `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1`, nodeIDs[0]); err != nil || len(infos) == 0 {
			return nil, &ClusterError{Message: "failed to get node IP for colocated port query"}
		}
		query := fmt.Sprintf(
			`SELECT wpa.%s as port_val FROM webrtc_port_allocations wpa
			JOIN dns_nodes dn ON wpa.node_id = dn.id
			WHERE dn.ip_address = ? AND wpa.service_type = ? AND wpa.%s > 0
			ORDER BY wpa.%s ASC`,
			portColumn, portColumn, portColumn,
		)
		if err := wpa.db.Query(ctx, &rows, query, infos[0].IPAddress, serviceType); err != nil {
			return nil, &ClusterError{
				Message: "failed to query allocated WebRTC ports (colocated)",
				Cause:   err,
			}
		}
	}
	result := make([]int, len(rows))
	for i, r := range rows {
		result[i] = r.Port
	}
	return result, nil
}

View File

@ -0,0 +1,337 @@
package namespace
import (
"context"
"strings"
"testing"
"go.uber.org/zap"
)
// TestWebRTCPortConstants_NoOverlap asserts, pairwise, that none of the
// configured port ranges intersect (standard interval-overlap check).
func TestWebRTCPortConstants_NoOverlap(t *testing.T) {
	// Verify WebRTC port ranges don't overlap with core namespace ports (10000-10099)
	ranges := []struct {
		name  string
		start int
		end   int
	}{
		{"core namespace", NamespacePortRangeStart, NamespacePortRangeEnd},
		{"SFU media", SFUMediaPortRangeStart, SFUMediaPortRangeEnd},
		{"SFU signaling", SFUSignalingPortRangeStart, SFUSignalingPortRangeEnd},
		{"TURN relay", TURNRelayPortRangeStart, TURNRelayPortRangeEnd},
	}
	for i := 0; i < len(ranges); i++ {
		for j := i + 1; j < len(ranges); j++ {
			a, b := ranges[i], ranges[j]
			// Two intervals [a.start,a.end] and [b.start,b.end] overlap iff
			// each starts at or before the other ends.
			if a.start <= b.end && b.start <= a.end {
				t.Errorf("Range overlap: %s (%d-%d) overlaps with %s (%d-%d)",
					a.name, a.start, a.end, b.name, b.start, b.end)
			}
		}
	}
}
// TestWebRTCPortConstants_Capacity asserts each port range can host at least
// 20 namespaces per node, matching the core namespace capacity.
func TestWebRTCPortConstants_Capacity(t *testing.T) {
	// SFU media: (29999-20000+1)/500 = 20 namespaces per node
	sfuMediaCapacity := (SFUMediaPortRangeEnd - SFUMediaPortRangeStart + 1) / SFUMediaPortsPerNamespace
	if sfuMediaCapacity < 20 {
		t.Errorf("SFU media capacity = %d, want >= 20", sfuMediaCapacity)
	}
	// SFU signaling: 30099-30000+1 = 100 ports → 100 namespaces per node
	sfuSignalingCapacity := SFUSignalingPortRangeEnd - SFUSignalingPortRangeStart + 1
	if sfuSignalingCapacity < 20 {
		t.Errorf("SFU signaling capacity = %d, want >= 20", sfuSignalingCapacity)
	}
	// TURN relay: (65535-49152+1)/800 = 20 namespaces per node
	turnRelayCapacity := (TURNRelayPortRangeEnd - TURNRelayPortRangeStart + 1) / TURNRelayPortsPerNamespace
	if turnRelayCapacity < 20 {
		t.Errorf("TURN relay capacity = %d, want >= 20", turnRelayCapacity)
	}
}
// TestWebRTCPortConstants_Values pins the exact constant values so an
// accidental change (which would break existing deployments' allocations)
// fails loudly.
func TestWebRTCPortConstants_Values(t *testing.T) {
	if SFUMediaPortRangeStart != 20000 {
		t.Errorf("SFUMediaPortRangeStart = %d, want 20000", SFUMediaPortRangeStart)
	}
	if SFUMediaPortRangeEnd != 29999 {
		t.Errorf("SFUMediaPortRangeEnd = %d, want 29999", SFUMediaPortRangeEnd)
	}
	if SFUMediaPortsPerNamespace != 500 {
		t.Errorf("SFUMediaPortsPerNamespace = %d, want 500", SFUMediaPortsPerNamespace)
	}
	if SFUSignalingPortRangeStart != 30000 {
		t.Errorf("SFUSignalingPortRangeStart = %d, want 30000", SFUSignalingPortRangeStart)
	}
	if TURNRelayPortRangeStart != 49152 {
		t.Errorf("TURNRelayPortRangeStart = %d, want 49152", TURNRelayPortRangeStart)
	}
	if TURNRelayPortsPerNamespace != 800 {
		t.Errorf("TURNRelayPortsPerNamespace = %d, want 800", TURNRelayPortsPerNamespace)
	}
	if TURNDefaultPort != 3478 {
		t.Errorf("TURNDefaultPort = %d, want 3478", TURNDefaultPort)
	}
	if DefaultSFUNodeCount != 3 {
		t.Errorf("DefaultSFUNodeCount = %d, want 3", DefaultSFUNodeCount)
	}
	if DefaultTURNNodeCount != 2 {
		t.Errorf("DefaultTURNNodeCount = %d, want 2", DefaultTURNNodeCount)
	}
}
// TestNewWebRTCPortAllocator verifies the constructor returns a non-nil
// allocator wired to the provided database client.
func TestNewWebRTCPortAllocator(t *testing.T) {
	mockDB := newMockRQLiteClient()
	allocator := NewWebRTCPortAllocator(mockDB, testLogger())
	if allocator == nil {
		t.Fatal("NewWebRTCPortAllocator returned nil")
	}
	if allocator.db != mockDB {
		t.Error("allocator.db not set correctly")
	}
}
// TestWebRTCPortAllocator_AllocateSFUPorts verifies that, against an empty
// mock store, the first SFU allocation gets the first port in each range,
// leaves TURN fields zeroed, and persists via an INSERT.
func TestWebRTCPortAllocator_AllocateSFUPorts(t *testing.T) {
	mockDB := newMockRQLiteClient()
	allocator := NewWebRTCPortAllocator(mockDB, testLogger())
	block, err := allocator.AllocateSFUPorts(context.Background(), "node-1", "cluster-1")
	if err != nil {
		t.Fatalf("AllocateSFUPorts failed: %v", err)
	}
	if block == nil {
		t.Fatal("AllocateSFUPorts returned nil block")
	}
	if block.ServiceType != "sfu" {
		t.Errorf("ServiceType = %q, want %q", block.ServiceType, "sfu")
	}
	if block.NodeID != "node-1" {
		t.Errorf("NodeID = %q, want %q", block.NodeID, "node-1")
	}
	if block.NamespaceClusterID != "cluster-1" {
		t.Errorf("NamespaceClusterID = %q, want %q", block.NamespaceClusterID, "cluster-1")
	}
	// First allocation should get the first port in each range
	if block.SFUSignalingPort != SFUSignalingPortRangeStart {
		t.Errorf("SFUSignalingPort = %d, want %d", block.SFUSignalingPort, SFUSignalingPortRangeStart)
	}
	if block.SFUMediaPortStart != SFUMediaPortRangeStart {
		t.Errorf("SFUMediaPortStart = %d, want %d", block.SFUMediaPortStart, SFUMediaPortRangeStart)
	}
	if block.SFUMediaPortEnd != SFUMediaPortRangeStart+SFUMediaPortsPerNamespace-1 {
		t.Errorf("SFUMediaPortEnd = %d, want %d", block.SFUMediaPortEnd, SFUMediaPortRangeStart+SFUMediaPortsPerNamespace-1)
	}
	// TURN fields should be zero for SFU allocation
	if block.TURNListenPort != 0 {
		t.Errorf("TURNListenPort = %d, want 0 for SFU allocation", block.TURNListenPort)
	}
	if block.TURNRelayPortStart != 0 {
		t.Errorf("TURNRelayPortStart = %d, want 0 for SFU allocation", block.TURNRelayPortStart)
	}
	// Verify INSERT was called
	hasInsert := false
	for _, call := range mockDB.execCalls {
		if strings.Contains(call.Query, "INSERT INTO webrtc_port_allocations") {
			hasInsert = true
			break
		}
	}
	if !hasInsert {
		t.Error("expected INSERT INTO webrtc_port_allocations to be called")
	}
}
// TestWebRTCPortAllocator_AllocateTURNPorts verifies that, against an empty
// mock store, the first TURN allocation gets the standard listen ports plus
// the first relay block, and leaves SFU fields zeroed.
func TestWebRTCPortAllocator_AllocateTURNPorts(t *testing.T) {
	mockDB := newMockRQLiteClient()
	allocator := NewWebRTCPortAllocator(mockDB, testLogger())
	block, err := allocator.AllocateTURNPorts(context.Background(), "node-1", "cluster-1")
	if err != nil {
		t.Fatalf("AllocateTURNPorts failed: %v", err)
	}
	if block == nil {
		t.Fatal("AllocateTURNPorts returned nil block")
	}
	if block.ServiceType != "turn" {
		t.Errorf("ServiceType = %q, want %q", block.ServiceType, "turn")
	}
	if block.TURNListenPort != TURNDefaultPort {
		t.Errorf("TURNListenPort = %d, want %d", block.TURNListenPort, TURNDefaultPort)
	}
	if block.TURNTLSPort != TURNTLSPort {
		t.Errorf("TURNTLSPort = %d, want %d", block.TURNTLSPort, TURNTLSPort)
	}
	if block.TURNRelayPortStart != TURNRelayPortRangeStart {
		t.Errorf("TURNRelayPortStart = %d, want %d", block.TURNRelayPortStart, TURNRelayPortRangeStart)
	}
	if block.TURNRelayPortEnd != TURNRelayPortRangeStart+TURNRelayPortsPerNamespace-1 {
		t.Errorf("TURNRelayPortEnd = %d, want %d", block.TURNRelayPortEnd, TURNRelayPortRangeStart+TURNRelayPortsPerNamespace-1)
	}
	// SFU fields should be zero for TURN allocation
	if block.SFUSignalingPort != 0 {
		t.Errorf("SFUSignalingPort = %d, want 0 for TURN allocation", block.SFUSignalingPort)
	}
}
// TestWebRTCPortAllocator_DeallocateAll checks that cluster-wide deallocation
// issues a DELETE keyed by the namespace cluster ID.
func TestWebRTCPortAllocator_DeallocateAll(t *testing.T) {
	db := newMockRQLiteClient()
	alloc := NewWebRTCPortAllocator(db, testLogger())
	if err := alloc.DeallocateAll(context.Background(), "cluster-1"); err != nil {
		t.Fatalf("DeallocateAll failed: %v", err)
	}
	// Verify DELETE was called with correct cluster ID
	found := false
	for _, call := range db.execCalls {
		if !strings.Contains(call.Query, "DELETE FROM webrtc_port_allocations") ||
			!strings.Contains(call.Query, "namespace_cluster_id") {
			continue
		}
		found = true
		if len(call.Args) < 1 || call.Args[0] != "cluster-1" {
			t.Errorf("DELETE called with wrong cluster ID: %v", call.Args)
		}
	}
	if !found {
		t.Error("expected DELETE FROM webrtc_port_allocations to be called")
	}
}
// TestWebRTCPortAllocator_DeallocateByNode checks that node-scoped
// deallocation issues a DELETE keyed by cluster, node, and service type.
func TestWebRTCPortAllocator_DeallocateByNode(t *testing.T) {
	db := newMockRQLiteClient()
	alloc := NewWebRTCPortAllocator(db, testLogger())
	if err := alloc.DeallocateByNode(context.Background(), "cluster-1", "node-1", "sfu"); err != nil {
		t.Fatalf("DeallocateByNode failed: %v", err)
	}
	// Verify DELETE was called with correct parameters
	found := false
	for _, call := range db.execCalls {
		if !strings.Contains(call.Query, "DELETE FROM webrtc_port_allocations") ||
			!strings.Contains(call.Query, "service_type") {
			continue
		}
		found = true
		if len(call.Args) != 3 {
			t.Fatalf("DELETE called with %d args, want 3", len(call.Args))
		}
		if call.Args[0] != "cluster-1" {
			t.Errorf("arg[0] = %v, want cluster-1", call.Args[0])
		}
		if call.Args[1] != "node-1" {
			t.Errorf("arg[1] = %v, want node-1", call.Args[1])
		}
		if call.Args[2] != "sfu" {
			t.Errorf("arg[2] = %v, want sfu", call.Args[2])
		}
	}
	if !found {
		t.Error("expected DELETE FROM webrtc_port_allocations to be called")
	}
}
func TestWebRTCPortAllocator_NodeHasTURN(t *testing.T) {
mockDB := newMockRQLiteClient()
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
// Mock query returns empty results → no TURN on node
hasTURN, err := allocator.NodeHasTURN(context.Background(), "node-1")
if err != nil {
t.Fatalf("NodeHasTURN failed: %v", err)
}
if hasTURN {
t.Error("expected NodeHasTURN = false for node with no allocations")
}
}
func TestWebRTCPortAllocator_GetSFUPorts_NoAllocation(t *testing.T) {
mockDB := newMockRQLiteClient()
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
block, err := allocator.GetSFUPorts(context.Background(), "cluster-1", "node-1")
if err != nil {
t.Fatalf("GetSFUPorts failed: %v", err)
}
if block != nil {
t.Error("expected nil block when no allocation exists")
}
}
func TestWebRTCPortAllocator_GetTURNPorts_NoAllocation(t *testing.T) {
mockDB := newMockRQLiteClient()
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
block, err := allocator.GetTURNPorts(context.Background(), "cluster-1", "node-1")
if err != nil {
t.Fatalf("GetTURNPorts failed: %v", err)
}
if block != nil {
t.Error("expected nil block when no allocation exists")
}
}
func TestWebRTCPortAllocator_GetAllPorts_Empty(t *testing.T) {
mockDB := newMockRQLiteClient()
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
blocks, err := allocator.GetAllPorts(context.Background(), "cluster-1")
if err != nil {
t.Fatalf("GetAllPorts failed: %v", err)
}
if len(blocks) != 0 {
t.Errorf("expected 0 blocks, got %d", len(blocks))
}
}
// TestWebRTCPortBlock_SFUFields confirms an SFU block's media range spans
// exactly SFUMediaPortsPerNamespace ports (inclusive bounds).
func TestWebRTCPortBlock_SFUFields(t *testing.T) {
	b := WebRTCPortBlock{
		ID:                 "test-id",
		NodeID:             "node-1",
		NamespaceClusterID: "cluster-1",
		ServiceType:        "sfu",
		SFUSignalingPort:   30000,
		SFUMediaPortStart:  20000,
		SFUMediaPortEnd:    20499,
	}
	if got := b.SFUMediaPortEnd - b.SFUMediaPortStart + 1; got != SFUMediaPortsPerNamespace {
		t.Errorf("SFU media range = %d, want %d", got, SFUMediaPortsPerNamespace)
	}
}
// TestWebRTCPortBlock_TURNFields confirms a TURN block's relay range spans
// exactly TURNRelayPortsPerNamespace ports (inclusive bounds).
func TestWebRTCPortBlock_TURNFields(t *testing.T) {
	b := WebRTCPortBlock{
		ID:                 "test-id",
		NodeID:             "node-1",
		NamespaceClusterID: "cluster-1",
		ServiceType:        "turn",
		TURNListenPort:     3478,
		TURNTLSPort:        443,
		TURNRelayPortStart: 49152,
		TURNRelayPortEnd:   49951,
	}
	if got := b.TURNRelayPortEnd - b.TURNRelayPortStart + 1; got != TURNRelayPortsPerNamespace {
		t.Errorf("TURN relay range = %d, want %d", got, TURNRelayPortsPerNamespace)
	}
}
// testLogger returns a no-op zap logger so allocator tests produce no
// log output.
func testLogger() *zap.Logger {
	return zap.NewNop()
}

View File

@ -57,7 +57,11 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout, IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
BaseDomain: n.config.HTTPGateway.BaseDomain, BaseDomain: n.config.HTTPGateway.BaseDomain,
DataDir: oramaDir, DataDir: oramaDir,
ClusterSecret: clusterSecret, ClusterSecret: clusterSecret,
WebRTCEnabled: n.config.HTTPGateway.WebRTC.Enabled,
SFUPort: n.config.HTTPGateway.WebRTC.SFUPort,
TURNDomain: n.config.HTTPGateway.WebRTC.TURNDomain,
TURNSecret: n.config.HTTPGateway.WebRTC.TURNSecret,
} }
apiGateway, err := gateway.New(gatewayLogger, gwCfg) apiGateway, err := gateway.New(gatewayLogger, gwCfg)
@ -82,6 +86,7 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
clusterManager.SetLocalNodeID(gwCfg.NodePeerID) clusterManager.SetLocalNodeID(gwCfg.NodePeerID)
apiGateway.SetClusterProvisioner(clusterManager) apiGateway.SetClusterProvisioner(clusterManager)
apiGateway.SetNodeRecoverer(clusterManager) apiGateway.SetNodeRecoverer(clusterManager)
apiGateway.SetWebRTCManager(clusterManager)
// Wire spawn handler for distributed namespace instance spawning // Wire spawn handler for distributed namespace instance spawning
systemdSpawner := namespace.NewSystemdSpawner(baseDataDir, n.logger.Logger) systemdSpawner := namespace.NewSystemdSpawner(baseDataDir, n.logger.Logger)

80
pkg/sfu/config.go Normal file
View File

@ -0,0 +1,80 @@
package sfu
import "fmt"
// Config holds configuration for the SFU server. It is loaded from YAML and
// validated with Validate before use.
type Config struct {
	// ListenAddr is the address to bind the signaling WebSocket server.
	// Must be a WireGuard IP (10.0.0.x) — never 0.0.0.0.
	ListenAddr string `yaml:"listen_addr"`
	// Namespace this SFU instance belongs to
	Namespace string `yaml:"namespace"`
	// MediaPortRange defines the UDP port range for RTP media
	// (inclusive start/end; end must be greater than start).
	MediaPortStart int `yaml:"media_port_start"`
	MediaPortEnd   int `yaml:"media_port_end"`
	// TURN servers this SFU should advertise to peers
	TURNServers []TURNServerConfig `yaml:"turn_servers"`
	// TURNSecret is the shared HMAC-SHA1 secret for generating TURN credentials
	TURNSecret string `yaml:"turn_secret"`
	// TURNCredentialTTL is the lifetime of TURN credentials in seconds
	TURNCredentialTTL int `yaml:"turn_credential_ttl"`
	// RQLiteDSN is the namespace-local RQLite DSN for room state
	RQLiteDSN string `yaml:"rqlite_dsn"`
}
// TURNServerConfig represents a single TURN server endpoint advertised to
// connecting peers.
type TURNServerConfig struct {
	Host string `yaml:"host"` // IP or hostname
	Port int    `yaml:"port"` // UDP port (3478 or 443)
}
// Validate checks the SFU configuration and returns every problem found.
// An empty slice means the configuration is usable.
func (c *Config) Validate() []error {
	var errs []error
	add := func(format string, args ...any) {
		errs = append(errs, fmt.Errorf(format, args...))
	}
	if c.ListenAddr == "" {
		add("sfu.listen_addr: must not be empty")
	}
	if c.Namespace == "" {
		add("sfu.namespace: must not be empty")
	}
	switch {
	case c.MediaPortStart <= 0 || c.MediaPortEnd <= 0:
		add("sfu.media_port_range: start and end must be positive")
	case c.MediaPortEnd <= c.MediaPortStart:
		add("sfu.media_port_range: end (%d) must be greater than start (%d)", c.MediaPortEnd, c.MediaPortStart)
	}
	if len(c.TURNServers) == 0 {
		add("sfu.turn_servers: at least one TURN server must be configured")
	}
	for i, srv := range c.TURNServers {
		if srv.Host == "" {
			add("sfu.turn_servers[%d].host: must not be empty", i)
		}
		if srv.Port <= 0 || srv.Port > 65535 {
			add("sfu.turn_servers[%d].port: must be between 1 and 65535", i)
		}
	}
	if c.TURNSecret == "" {
		add("sfu.turn_secret: must not be empty")
	}
	if c.TURNCredentialTTL <= 0 {
		add("sfu.turn_credential_ttl: must be positive")
	}
	if c.RQLiteDSN == "" {
		add("sfu.rqlite_dsn: must not be empty")
	}
	return errs
}

167
pkg/sfu/config_test.go Normal file
View File

@ -0,0 +1,167 @@
package sfu
import "testing"
// TestConfigValidation exercises Config.Validate with a table of valid and
// invalid configurations, asserting only the number of errors returned
// (each case flips exactly one field group away from a valid baseline).
func TestConfigValidation(t *testing.T) {
	tests := []struct {
		name     string
		config   Config
		wantErrs int
	}{
		{
			name: "valid config",
			config: Config{
				ListenAddr:        "10.0.0.1:8443",
				Namespace:         "test-ns",
				MediaPortStart:    20000,
				MediaPortEnd:      20500,
				TURNServers:       []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
				TURNSecret:        "secret-key",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 0,
		},
		{
			name: "valid config with multiple TURN servers",
			config: Config{
				ListenAddr:     "10.0.0.1:8443",
				Namespace:      "test-ns",
				MediaPortStart: 20000,
				MediaPortEnd:   20500,
				TURNServers: []TURNServerConfig{
					{Host: "1.2.3.4", Port: 3478},
					{Host: "5.6.7.8", Port: 443},
				},
				TURNSecret:        "secret-key",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 0,
		},
		{
			name:     "missing all fields",
			config:   Config{},
			wantErrs: 7, // listen_addr, namespace, media_port_range, turn_servers, turn_secret, turn_credential_ttl, rqlite_dsn
		},
		{
			name: "missing listen addr",
			config: Config{
				Namespace:         "test-ns",
				MediaPortStart:    20000,
				MediaPortEnd:      20500,
				TURNServers:       []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
				TURNSecret:        "secret",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 1,
		},
		{
			name: "missing namespace",
			config: Config{
				ListenAddr:        "10.0.0.1:8443",
				MediaPortStart:    20000,
				MediaPortEnd:      20500,
				TURNServers:       []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
				TURNSecret:        "secret",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 1,
		},
		{
			// Inverted range still yields a single media_port_range error.
			name: "invalid media port range - inverted",
			config: Config{
				ListenAddr:        "10.0.0.1:8443",
				Namespace:         "test-ns",
				MediaPortStart:    20500,
				MediaPortEnd:      20000,
				TURNServers:       []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
				TURNSecret:        "secret",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 1,
		},
		{
			name: "invalid media port range - zero",
			config: Config{
				ListenAddr:        "10.0.0.1:8443",
				Namespace:         "test-ns",
				MediaPortStart:    0,
				MediaPortEnd:      0,
				TURNServers:       []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
				TURNSecret:        "secret",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 1,
		},
		{
			name: "no TURN servers",
			config: Config{
				ListenAddr:        "10.0.0.1:8443",
				Namespace:         "test-ns",
				MediaPortStart:    20000,
				MediaPortEnd:      20500,
				TURNServers:       []TURNServerConfig{},
				TURNSecret:        "secret",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 1,
		},
		{
			name: "TURN server with invalid port",
			config: Config{
				ListenAddr:        "10.0.0.1:8443",
				Namespace:         "test-ns",
				MediaPortStart:    20000,
				MediaPortEnd:      20500,
				TURNServers:       []TURNServerConfig{{Host: "1.2.3.4", Port: 0}},
				TURNSecret:        "secret",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 1,
		},
		{
			name: "TURN server with empty host",
			config: Config{
				ListenAddr:        "10.0.0.1:8443",
				Namespace:         "test-ns",
				MediaPortStart:    20000,
				MediaPortEnd:      20500,
				TURNServers:       []TURNServerConfig{{Host: "", Port: 3478}},
				TURNSecret:        "secret",
				TURNCredentialTTL: 600,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 1,
		},
		{
			name: "negative credential TTL",
			config: Config{
				ListenAddr:        "10.0.0.1:8443",
				Namespace:         "test-ns",
				MediaPortStart:    20000,
				MediaPortEnd:      20500,
				TURNServers:       []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
				TURNSecret:        "secret",
				TURNCredentialTTL: -1,
				RQLiteDSN:         "http://10.0.0.1:4001",
			},
			wantErrs: 1,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			errs := tt.config.Validate()
			if len(errs) != tt.wantErrs {
				t.Errorf("Validate() returned %d errors, want %d: %v", len(errs), tt.wantErrs, errs)
			}
		})
	}
}

340
pkg/sfu/peer.go Normal file
View File

@ -0,0 +1,340 @@
package sfu
import (
"encoding/json"
"errors"
"sync"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v4"
"go.uber.org/zap"
)
var (
	// ErrPeerNotInitialized is returned when a signaling operation arrives
	// before InitPeerConnection has created the PeerConnection.
	ErrPeerNotInitialized = errors.New("peer connection not initialized")
	// ErrPeerClosed is returned when sending to a peer that has shut down.
	ErrPeerClosed = errors.New("peer is closed")
	// ErrWebSocketClosed is returned when the signaling socket is gone.
	ErrWebSocketClosed = errors.New("websocket connection closed")
)
// Peer represents a participant in a room with a WebRTC PeerConnection.
type Peer struct {
	ID     string // unique peer ID (UUID), assigned at construction
	UserID string // application-level user identity
	pc     *webrtc.PeerConnection
	conn   *websocket.Conn // signaling WebSocket
	room   *Room
	// Negotiation state machine: offers are only produced while the
	// signaling state is stable, otherwise a renegotiation is marked pending.
	negotiationPending bool
	batchingTracks     bool // true while bulk track additions are being batched
	negotiationMu      sync.Mutex
	// Connection state
	closed   bool
	closedMu sync.RWMutex
	connMu   sync.Mutex // serializes WebSocket writes
	logger   *zap.Logger
	onClose  func(*Peer) // invoked once when the peer disconnects
}
// NewPeer creates a new peer for the given user, bound to the supplied
// WebSocket connection and room.
//
// The peer ID is generated here, so the logger is tagged with the real
// peer_id immediately (the original tagged it with "" and relied on
// InitPeerConnection to re-tag, which left early log lines unidentified and
// appended a duplicate peer_id field).
func NewPeer(userID string, conn *websocket.Conn, room *Room, logger *zap.Logger) *Peer {
	id := uuid.New().String()
	return &Peer{
		ID:     id,
		UserID: userID,
		conn:   conn,
		room:   room,
		logger: logger.With(zap.String("peer_id", id)),
	}
}
// InitPeerConnection creates and configures the WebRTC PeerConnection.
//
// All media is forced through TURN (ICETransportPolicyRelay), so clients
// never connect to the SFU's media ports directly. This wires up:
//   - ICE state handling (grace period on disconnect, teardown on fail/close)
//   - trickle-ICE candidate forwarding to the client
//   - incoming track fan-out to the room
//   - a renegotiation state machine that only emits offers when the
//     signaling state is stable (and coalesces offers during track batches).
func (p *Peer) InitPeerConnection(api *webrtc.API, iceServers []webrtc.ICEServer) error {
	pc, err := api.NewPeerConnection(webrtc.Configuration{
		ICEServers:         iceServers,
		ICETransportPolicy: webrtc.ICETransportPolicyRelay, // Force TURN relay
	})
	if err != nil {
		return err
	}
	p.pc = pc
	p.logger = p.logger.With(zap.String("peer_id", p.ID))
	// ICE connection state changes
	pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
		p.logger.Info("ICE state changed", zap.String("state", state.String()))
		switch state {
		case webrtc.ICEConnectionStateDisconnected:
			// Give 15 seconds to reconnect before removing
			go p.handleReconnectTimeout()
		case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateClosed:
			p.handleDisconnect()
		}
	})
	// ICE candidate generation: trickle each local candidate to the client.
	pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		if candidate == nil {
			// nil signals end of gathering; nothing to forward.
			return
		}
		c := candidate.ToJSON()
		data := &ICECandidateData{Candidate: c.Candidate}
		if c.SDPMid != nil {
			data.SDPMid = *c.SDPMid
		}
		if c.SDPMLineIndex != nil {
			data.SDPMLineIndex = *c.SDPMLineIndex
		}
		if c.UsernameFragment != nil {
			data.UsernameFragment = *c.UsernameFragment
		}
		p.SendMessage(NewServerMessage(MessageTypeICECandidate, data))
	})
	// Incoming tracks from the client
	pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
		p.logger.Info("Track received",
			zap.String("track_id", track.ID()),
			zap.String("kind", track.Kind().String()),
			zap.String("codec", track.Codec().MimeType))
		// Read RTCP feedback (PLI/NACK) in background
		go p.readRTCP(receiver, track)
		// Forward track to all other peers
		p.room.BroadcastTrack(p.ID, track)
	})
	// Negotiation needed — only when stable
	pc.OnNegotiationNeeded(func() {
		p.negotiationMu.Lock()
		if p.batchingTracks {
			// During a batch, coalesce into a single deferred offer.
			p.negotiationPending = true
			p.negotiationMu.Unlock()
			return
		}
		p.negotiationMu.Unlock()
		if pc.SignalingState() == webrtc.SignalingStateStable {
			p.createAndSendOffer()
		} else {
			p.negotiationMu.Lock()
			p.negotiationPending = true
			p.negotiationMu.Unlock()
		}
	})
	// When state returns to stable, fire pending negotiation
	pc.OnSignalingStateChange(func(state webrtc.SignalingState) {
		if state == webrtc.SignalingStateStable {
			p.negotiationMu.Lock()
			pending := p.negotiationPending
			p.negotiationPending = false
			p.negotiationMu.Unlock()
			if pending {
				p.createAndSendOffer()
			}
		}
	})
	return nil
}
// createAndSendOffer generates a local SDP offer and pushes it to the client.
// If the connection is not in the stable signaling state, the attempt is
// deferred by flagging negotiationPending (picked up on the next
// OnSignalingStateChange to stable).
func (p *Peer) createAndSendOffer() {
	pc := p.pc
	if pc == nil {
		return
	}
	if pc.SignalingState() != webrtc.SignalingStateStable {
		p.negotiationMu.Lock()
		p.negotiationPending = true
		p.negotiationMu.Unlock()
		return
	}
	offer, err := pc.CreateOffer(nil)
	if err != nil {
		p.logger.Error("Failed to create offer", zap.Error(err))
		return
	}
	if err = pc.SetLocalDescription(offer); err != nil {
		p.logger.Error("Failed to set local description", zap.Error(err))
		return
	}
	p.SendMessage(NewServerMessage(MessageTypeOffer, &OfferData{SDP: offer.SDP}))
}
// HandleOffer applies a client SDP offer and replies with an SDP answer over
// the signaling channel.
func (p *Peer) HandleOffer(sdp string) error {
	pc := p.pc
	if pc == nil {
		return ErrPeerNotInitialized
	}
	remote := webrtc.SessionDescription{Type: webrtc.SDPTypeOffer, SDP: sdp}
	if err := pc.SetRemoteDescription(remote); err != nil {
		return err
	}
	answer, err := pc.CreateAnswer(nil)
	if err != nil {
		return err
	}
	if err = pc.SetLocalDescription(answer); err != nil {
		return err
	}
	p.SendMessage(NewServerMessage(MessageTypeAnswer, &AnswerData{SDP: answer.SDP}))
	return nil
}
// HandleAnswer applies a client SDP answer to the pending local offer.
func (p *Peer) HandleAnswer(sdp string) error {
	pc := p.pc
	if pc == nil {
		return ErrPeerNotInitialized
	}
	desc := webrtc.SessionDescription{Type: webrtc.SDPTypeAnswer, SDP: sdp}
	return pc.SetRemoteDescription(desc)
}
// HandleICECandidate feeds a remote ICE candidate into the peer connection.
func (p *Peer) HandleICECandidate(data *ICECandidateData) error {
	pc := p.pc
	if pc == nil {
		return ErrPeerNotInitialized
	}
	return pc.AddICECandidate(data.ToWebRTCCandidate())
}
// AddTrack attaches a local outbound track to this peer's connection and
// returns the resulting RTP sender.
func (p *Peer) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
	pc := p.pc
	if pc == nil {
		return nil, ErrPeerNotInitialized
	}
	return pc.AddTrack(track)
}
// StartTrackBatch suppresses renegotiation while multiple tracks are added;
// pair with EndTrackBatch to emit a single offer for the whole batch.
func (p *Peer) StartTrackBatch() {
	p.negotiationMu.Lock()
	defer p.negotiationMu.Unlock()
	p.batchingTracks = true
}
// EndTrackBatch re-enables renegotiation and, if any renegotiation was
// deferred during the batch, kicks off one offer now (only while stable).
func (p *Peer) EndTrackBatch() {
	p.negotiationMu.Lock()
	p.batchingTracks = false
	deferred := p.negotiationPending
	p.negotiationPending = false
	p.negotiationMu.Unlock()
	if !deferred {
		return
	}
	if pc := p.pc; pc != nil && pc.SignalingState() == webrtc.SignalingStateStable {
		p.createAndSendOffer()
	}
}
// SendMessage serializes msg as JSON and writes it to the peer's WebSocket.
// Returns ErrPeerClosed after the peer shut down and ErrWebSocketClosed when
// the socket has already been torn down.
func (p *Peer) SendMessage(msg *ServerMessage) error {
	p.closedMu.RLock()
	isClosed := p.closed
	p.closedMu.RUnlock()
	if isClosed {
		return ErrPeerClosed
	}
	// connMu serializes writes on the shared WebSocket.
	p.connMu.Lock()
	defer p.connMu.Unlock()
	if p.conn == nil {
		return ErrWebSocketClosed
	}
	payload, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	return p.conn.WriteMessage(websocket.TextMessage, payload)
}
// GetInfo reports the peer's public identity (peer and user IDs).
func (p *Peer) GetInfo() ParticipantInfo {
	return ParticipantInfo{
		PeerID: p.ID,
		UserID: p.UserID,
	}
}
// handleReconnectTimeout gives a disconnected peer reconnectTimeout to
// recover its ICE connection; if it is still disconnected (or has failed)
// afterwards, the peer is removed.
func (p *Peer) handleReconnectTimeout() {
	<-timeAfter(reconnectTimeout)
	pc := p.pc
	if pc == nil {
		return
	}
	switch pc.ICEConnectionState() {
	case webrtc.ICEConnectionStateDisconnected, webrtc.ICEConnectionStateFailed:
		p.logger.Info("Peer did not reconnect within timeout, removing")
		p.handleDisconnect()
	}
}
// handleDisconnect marks the peer closed exactly once, releases its
// transport resources, and notifies the owner via the onClose callback.
//
// The teardown must happen here: once closed is set, Close() early-returns,
// so the onClose path (Room.RemovePeer → peer.Close) would otherwise leak
// both the WebSocket and the PeerConnection on ICE failure.
func (p *Peer) handleDisconnect() {
	p.closedMu.Lock()
	if p.closed {
		p.closedMu.Unlock()
		return
	}
	p.closed = true
	p.closedMu.Unlock()
	// Release the transports before invoking the callback.
	p.connMu.Lock()
	if p.conn != nil {
		p.conn.Close()
		p.conn = nil
	}
	p.connMu.Unlock()
	if p.pc != nil {
		// Re-entry via the resulting ICE "closed" callback is harmless:
		// the closed guard above makes this a no-op the second time.
		p.pc.Close()
	}
	if p.onClose != nil {
		p.onClose(p)
	}
}
// Close tears down the WebSocket and the underlying PeerConnection.
// Idempotent: only the first call does any work; later calls return nil.
func (p *Peer) Close() error {
	p.closedMu.Lock()
	alreadyClosed := p.closed
	p.closed = true
	p.closedMu.Unlock()
	if alreadyClosed {
		return nil
	}
	p.connMu.Lock()
	if c := p.conn; c != nil {
		c.Close()
		p.conn = nil
	}
	p.connMu.Unlock()
	if p.pc == nil {
		return nil
	}
	return p.pc.Close()
}
// OnClose sets the disconnect callback, invoked once when the peer is
// removed (explicit Close is not routed through it; handleDisconnect is).
// NOTE(review): the write is unsynchronized — set this before the peer is
// shared across goroutines; confirm callers do so.
func (p *Peer) OnClose(fn func(*Peer)) {
	p.onClose = fn
}
// readRTCP drains RTCP feedback for an incoming track and converts PLI/FIR
// requests into keyframe requests against the forwarded track. Runs until
// the receiver errors (track ended / peer gone).
func (p *Peer) readRTCP(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) {
	// Forwarded tracks are named "<kind>-<peerID>" (see Room.BroadcastTrack).
	forwardedID := track.Kind().String() + "-" + p.ID
	for {
		pkts, _, err := receiver.ReadRTCP()
		if err != nil {
			return
		}
		for _, pkt := range pkts {
			switch pkt.(type) {
			case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
				p.room.RequestKeyframe(forwardedID)
			}
		}
	}
}

555
pkg/sfu/room.go Normal file
View File

@ -0,0 +1,555 @@
package sfu
import (
"errors"
"fmt"
"sync"
"time"
"github.com/DeBrosOfficial/network/pkg/turn"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/intervalpli"
"github.com/pion/interceptor/pkg/nack"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v4"
"go.uber.org/zap"
)
// For testing: allows overriding time.After
var timeAfter = func(d time.Duration) <-chan time.Time { return time.After(d) }
const (
	// reconnectTimeout is how long a peer may stay ICE-disconnected
	// before it is removed from its room.
	reconnectTimeout = 15 * time.Second
	// emptyRoomTTL is how long an empty room lingers before cleanup.
	emptyRoomTTL = 60 * time.Second
	// rtpBufferSize is the read buffer size (bytes) for RTP forwarding.
	rtpBufferSize = 8192
)
var (
	// ErrRoomFull is returned when a room has reached its participant cap.
	ErrRoomFull = errors.New("room is full")
	// ErrRoomClosed is returned when joining a room that has been closed.
	ErrRoomClosed = errors.New("room is closed")
	// ErrPeerNotFound is returned for operations on an unknown peer.
	ErrPeerNotFound = errors.New("peer not found")
)
// publishedTrack holds a local track being forwarded from a remote source.
type publishedTrack struct {
	sourcePeerID    string                       // peer that published the media
	localTrack      *webrtc.TrackLocalStaticRTP  // fan-out track sent to other peers
	remoteTrackSSRC uint32                       // SSRC of the source, used to target PLIs
	kind            string                       // "audio" or "video"
}
// Room is a WebRTC room with multiple participants sharing media tracks.
type Room struct {
	ID                string
	Namespace         string
	peers             map[string]*Peer
	peersMu           sync.RWMutex
	publishedTracks   map[string]*publishedTrack // key: localTrack.ID()
	publishedTracksMu sync.RWMutex
	api               *webrtc.API
	config            *Config
	logger            *zap.Logger
	closed            bool
	closedMu          sync.RWMutex
	onEmpty           func(*Room) // fired when the last peer leaves
}
// RoomManager manages the lifecycle of rooms.
type RoomManager struct {
	rooms  map[string]*Room // key: roomID
	mu     sync.RWMutex
	config *Config
	logger *zap.Logger
}
// NewRoomManager constructs an empty room manager that applies cfg to every
// room it creates.
func NewRoomManager(cfg *Config, logger *zap.Logger) *RoomManager {
	rm := &RoomManager{
		rooms:  make(map[string]*Room),
		config: cfg,
	}
	rm.logger = logger.With(zap.String("component", "room-manager"))
	return rm
}
// GetOrCreateRoom returns an existing open room or creates (and registers)
// a new one. A closed room under the same ID is replaced.
func (rm *RoomManager) GetOrCreateRoom(roomID string) *Room {
	rm.mu.Lock()
	defer rm.mu.Unlock()
	if room, ok := rm.rooms[roomID]; ok && !room.IsClosed() {
		return room
	}
	api := newWebRTCAPI(rm.config)
	room := &Room{
		ID:              roomID,
		Namespace:       rm.config.Namespace,
		peers:           make(map[string]*Peer),
		publishedTracks: make(map[string]*publishedTrack),
		api:             api,
		config:          rm.config,
		logger:          rm.logger.With(zap.String("room_id", roomID)),
	}
	room.onEmpty = func(r *Room) {
		// Start empty room cleanup timer
		go func() {
			<-timeAfter(emptyRoomTTL)
			if r.GetParticipantCount() == 0 {
				rm.mu.Lock()
				// Only evict if this exact instance is still registered:
				// the room may have been closed and replaced by a newer
				// instance under the same ID, which must not be deleted.
				if rm.rooms[r.ID] == r {
					delete(rm.rooms, r.ID)
				}
				rm.mu.Unlock()
				r.Close()
				rm.logger.Info("Empty room cleaned up", zap.String("room_id", r.ID))
			}
		}()
	}
	rm.rooms[roomID] = room
	rm.logger.Info("Room created", zap.String("room_id", roomID))
	return room
}
// GetRoom looks up a room by ID; nil when no such room exists.
func (rm *RoomManager) GetRoom(roomID string) *Room {
	rm.mu.RLock()
	room := rm.rooms[roomID]
	rm.mu.RUnlock()
	return room
}
// CloseAll detaches every room from the manager and closes them — used for
// graceful shutdown. Closing happens outside the manager lock.
func (rm *RoomManager) CloseAll() {
	rm.mu.Lock()
	toClose := make([]*Room, 0, len(rm.rooms))
	for _, room := range rm.rooms {
		toClose = append(toClose, room)
	}
	rm.rooms = make(map[string]*Room)
	rm.mu.Unlock()
	for _, room := range toClose {
		room.Close()
	}
}
// RoomCount reports how many rooms are currently registered.
func (rm *RoomManager) RoomCount() int {
	rm.mu.RLock()
	n := len(rm.rooms)
	rm.mu.RUnlock()
	return n
}
// newWebRTCAPI creates a Pion WebRTC API with codecs and interceptors.
//
// Registered codecs: Opus (audio), VP8 and H264 (video, with REMB/FIR/NACK/
// PLI feedback). Interceptors provide NACK retransmission in both directions
// plus periodic PLIs on the receive side. If a media port range is
// configured, ICE/UDP is restricted to it so firewalls can be scoped per
// namespace.
func newWebRTCAPI(cfg *Config) *webrtc.API {
	m := &webrtc.MediaEngine{}
	// Audio: Opus
	videoRTCPFeedback := []webrtc.RTCPFeedback{
		{Type: "goog-remb", Parameter: ""},
		{Type: "ccm", Parameter: "fir"},
		{Type: "nack", Parameter: ""},
		{Type: "nack", Parameter: "pli"},
	}
	_ = m.RegisterCodec(webrtc.RTPCodecParameters{
		RTPCodecCapability: webrtc.RTPCodecCapability{
			MimeType:    webrtc.MimeTypeOpus,
			ClockRate:   48000,
			Channels:    2,
			SDPFmtpLine: "minptime=10;useinbandfec=1",
		},
		PayloadType: 111,
	}, webrtc.RTPCodecTypeAudio)
	// Video: VP8
	_ = m.RegisterCodec(webrtc.RTPCodecParameters{
		RTPCodecCapability: webrtc.RTPCodecCapability{
			MimeType:     webrtc.MimeTypeVP8,
			ClockRate:    90000,
			RTCPFeedback: videoRTCPFeedback,
		},
		PayloadType: 96,
	}, webrtc.RTPCodecTypeVideo)
	// Video: H264
	_ = m.RegisterCodec(webrtc.RTPCodecParameters{
		RTPCodecCapability: webrtc.RTPCodecCapability{
			MimeType:     webrtc.MimeTypeH264,
			ClockRate:    90000,
			SDPFmtpLine:  "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f",
			RTCPFeedback: videoRTCPFeedback,
		},
		PayloadType: 125,
	}, webrtc.RTPCodecTypeVideo)
	// Interceptors: NACK + PLI
	// (construction errors are silently skipped; the API still works,
	// just without that interceptor)
	i := &interceptor.Registry{}
	if f, err := nack.NewResponderInterceptor(); err == nil {
		i.Add(f)
	}
	if f, err := nack.NewGeneratorInterceptor(); err == nil {
		i.Add(f)
	}
	if f, err := intervalpli.NewReceiverInterceptor(); err == nil {
		i.Add(f)
	}
	// SettingEngine: restrict media ports
	se := webrtc.SettingEngine{}
	if cfg.MediaPortStart > 0 && cfg.MediaPortEnd > 0 {
		se.SetEphemeralUDPPortRange(uint16(cfg.MediaPortStart), uint16(cfg.MediaPortEnd))
	}
	return webrtc.NewAPI(
		webrtc.WithMediaEngine(m),
		webrtc.WithInterceptorRegistry(i),
		webrtc.WithSettingEngine(se),
	)
}
// --- Room methods ---
// AddPeer adds a peer to the room and notifies other participants.
//
// The peer's PeerConnection is initialized with the room's TURN-backed ICE
// servers. Returns ErrRoomClosed if the room was closed and ErrRoomFull when
// the participant cap is reached.
func (r *Room) AddPeer(peer *Peer) error {
	// maxRoomPeers is a hard safety cap on participants per room
	// (named here instead of the previous inline magic number).
	const maxRoomPeers = 100
	r.closedMu.RLock()
	if r.closed {
		r.closedMu.RUnlock()
		return ErrRoomClosed
	}
	r.closedMu.RUnlock()
	// Build ICE servers for TURN
	iceServers := r.buildICEServers()
	r.peersMu.Lock()
	if len(r.peers) >= maxRoomPeers {
		r.peersMu.Unlock()
		return ErrRoomFull
	}
	if err := peer.InitPeerConnection(r.api, iceServers); err != nil {
		r.peersMu.Unlock()
		return err
	}
	peer.OnClose(func(p *Peer) { r.RemovePeer(p.ID) })
	r.peers[peer.ID] = peer
	info := peer.GetInfo()
	total := len(r.peers)
	r.peersMu.Unlock()
	r.logger.Info("Peer joined", zap.String("peer_id", peer.ID), zap.Int("total", total))
	// Notify others
	r.broadcastMessage(peer.ID, NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{
		Participant: info,
	}))
	return nil
}
// RemovePeer removes a peer and cleans up their published tracks.
//
// Order matters: the peer is unregistered first, then its forwarded tracks
// are dropped (from the published set and from every other peer's senders),
// the peer is closed, and only then are leave/track-removed notifications
// broadcast. If the room becomes empty the onEmpty hook fires so the
// manager can schedule cleanup.
func (r *Room) RemovePeer(peerID string) {
	r.peersMu.Lock()
	peer, ok := r.peers[peerID]
	if !ok {
		// Unknown or already-removed peer; nothing to do.
		r.peersMu.Unlock()
		return
	}
	delete(r.peers, peerID)
	remaining := len(r.peers)
	r.peersMu.Unlock()
	// Remove published tracks from this peer
	r.publishedTracksMu.Lock()
	var removed []string
	for trackID, pt := range r.publishedTracks {
		if pt.sourcePeerID == peerID {
			delete(r.publishedTracks, trackID)
			removed = append(removed, trackID)
		}
	}
	r.publishedTracksMu.Unlock()
	// Remove RTPSenders for this peer's tracks from all other peers
	if len(removed) > 0 {
		r.removeTrackSendersFromPeers(removed)
	}
	peer.Close()
	r.logger.Info("Peer left", zap.String("peer_id", peerID), zap.Int("remaining", remaining))
	r.broadcastMessage(peerID, NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{
		PeerID: peerID,
	}))
	// Notify about removed tracks
	for _, trackID := range removed {
		r.broadcastMessage(peerID, NewServerMessage(MessageTypeTrackRemoved, &TrackRemovedData{
			PeerID:  peerID,
			TrackID: trackID,
		}))
	}
	if remaining == 0 && r.onEmpty != nil {
		r.onEmpty(r)
	}
}
// removeTrackSendersFromPeers detaches the outbound RTP senders for the
// given track IDs from every peer, so departed publishers do not leave
// stale ("ghost") tracks in later SDP negotiations.
func (r *Room) removeTrackSendersFromPeers(trackIDs []string) {
	doomed := make(map[string]bool, len(trackIDs))
	for _, id := range trackIDs {
		doomed[id] = true
	}
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	for _, peer := range r.peers {
		pc := peer.pc
		if pc == nil {
			continue
		}
		for _, sender := range pc.GetSenders() {
			tr := sender.Track()
			if tr == nil || !doomed[tr.ID()] {
				continue
			}
			if err := pc.RemoveTrack(sender); err != nil {
				r.logger.Warn("Failed to remove track sender",
					zap.String("peer_id", peer.ID),
					zap.String("track_id", tr.ID()),
					zap.Error(err))
			}
		}
	}
}
// BroadcastTrack creates a local track from a remote track and forwards it to all other peers.
//
// The forwarded track is named "<kind>-<sourcePeerID>" and registered in
// publishedTracks so late joiners can be wired up via SendExistingTracksTo.
// A background goroutine copies RTP packets from the remote track into the
// local one until the source track errors (publisher left / track ended).
//
// NOTE(review): the local track ID is derived from kind+peer only, so a
// peer publishing two tracks of the same kind would collide in the
// publishedTracks map — confirm clients never do that.
func (r *Room) BroadcastTrack(sourcePeerID string, track *webrtc.TrackRemote) {
	codec := track.Codec()
	localTrack, err := webrtc.NewTrackLocalStaticRTP(
		codec.RTPCodecCapability,
		track.Kind().String()+"-"+sourcePeerID,
		sourcePeerID,
	)
	if err != nil {
		r.logger.Error("Failed to create local track", zap.Error(err))
		return
	}
	// Store for future joiners
	r.publishedTracksMu.Lock()
	r.publishedTracks[localTrack.ID()] = &publishedTrack{
		sourcePeerID:    sourcePeerID,
		localTrack:      localTrack,
		remoteTrackSSRC: uint32(track.SSRC()),
		kind:            track.Kind().String(),
	}
	r.publishedTracksMu.Unlock()
	// RTP forwarding loop with proper buffer size
	go func() {
		buf := make([]byte, rtpBufferSize)
		for {
			n, _, err := track.Read(buf)
			if err != nil {
				// Source ended; forwarding stops here.
				return
			}
			if _, err := localTrack.Write(buf[:n]); err != nil {
				return
			}
		}
	}()
	// Add to all current peers except the source
	r.peersMu.RLock()
	for peerID, peer := range r.peers {
		if peerID == sourcePeerID {
			continue
		}
		if _, err := peer.AddTrack(localTrack); err != nil {
			r.logger.Warn("Failed to add track to peer",
				zap.String("peer_id", peerID), zap.Error(err))
			continue
		}
		peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
			PeerID:   sourcePeerID,
			TrackID:  localTrack.ID(),
			StreamID: localTrack.StreamID(),
			Kind:     track.Kind().String(),
		}))
	}
	r.peersMu.RUnlock()
}
// SendExistingTracksTo sends all published tracks to a newly joined peer.
// Uses batch mode for a single renegotiation.
//
// After the batch settles, keyframes for all video tracks are requested
// ~300ms later so the new joiner gets decodable video quickly rather than
// waiting for the publishers' next natural keyframe.
func (r *Room) SendExistingTracksTo(peer *Peer) {
	r.publishedTracksMu.RLock()
	var tracks []*publishedTrack
	for _, pt := range r.publishedTracks {
		// Never echo a peer's own media back at it.
		if pt.sourcePeerID != peer.ID {
			tracks = append(tracks, pt)
		}
	}
	r.publishedTracksMu.RUnlock()
	if len(tracks) == 0 {
		return
	}
	peer.StartTrackBatch()
	for _, pt := range tracks {
		if _, err := peer.AddTrack(pt.localTrack); err != nil {
			r.logger.Warn("Failed to add existing track", zap.Error(err))
			continue
		}
		peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
			PeerID:   pt.sourcePeerID,
			TrackID:  pt.localTrack.ID(),
			StreamID: pt.localTrack.StreamID(),
			Kind:     pt.kind,
		}))
	}
	peer.EndTrackBatch()
	// Request keyframes for video tracks after negotiation settles
	go func() {
		<-timeAfter(300 * time.Millisecond)
		r.RequestKeyframeForAllVideoTracks()
	}()
}
// RequestKeyframe asks the publisher of a forwarded video track for a fresh
// keyframe by sending it a PLI. No-ops for audio or unknown track IDs.
func (r *Room) RequestKeyframe(trackID string) {
	r.publishedTracksMu.RLock()
	pt, found := r.publishedTracks[trackID]
	r.publishedTracksMu.RUnlock()
	if !found || pt.kind != "video" {
		return
	}
	r.peersMu.RLock()
	source, present := r.peers[pt.sourcePeerID]
	r.peersMu.RUnlock()
	if !present || source.pc == nil {
		return
	}
	packets := []rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: pt.remoteTrackSSRC}}
	if err := source.pc.WriteRTCP(packets); err != nil {
		r.logger.Debug("Failed to send PLI", zap.String("track_id", trackID), zap.Error(err))
	}
}
// RequestKeyframeForAllVideoTracks issues a PLI for every published video
// track in the room.
func (r *Room) RequestKeyframeForAllVideoTracks() {
	r.publishedTracksMu.RLock()
	videoIDs := make([]string, 0, len(r.publishedTracks))
	for id, pt := range r.publishedTracks {
		if pt.kind == "video" {
			videoIDs = append(videoIDs, id)
		}
	}
	r.publishedTracksMu.RUnlock()
	for _, id := range videoIDs {
		r.RequestKeyframe(id)
	}
}
// GetParticipants returns a snapshot of public info for every peer.
func (r *Room) GetParticipants() []ParticipantInfo {
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	out := make([]ParticipantInfo, 0, len(r.peers))
	for _, peer := range r.peers {
		out = append(out, peer.GetInfo())
	}
	return out
}
// GetParticipantCount reports the current number of peers in the room.
func (r *Room) GetParticipantCount() int {
	r.peersMu.RLock()
	n := len(r.peers)
	r.peersMu.RUnlock()
	return n
}
// IsClosed reports whether Close has been called on the room.
func (r *Room) IsClosed() bool {
	r.closedMu.RLock()
	closed := r.closed
	r.closedMu.RUnlock()
	return closed
}
// Close marks the room closed and shuts down every remaining peer.
// Idempotent: repeated calls return nil without side effects.
func (r *Room) Close() error {
	r.closedMu.Lock()
	if r.closed {
		r.closedMu.Unlock()
		return nil
	}
	r.closed = true
	r.closedMu.Unlock()
	// Detach all peers under the lock, then close them outside it.
	r.peersMu.Lock()
	remaining := make([]*Peer, 0, len(r.peers))
	for _, peer := range r.peers {
		remaining = append(remaining, peer)
	}
	r.peers = make(map[string]*Peer)
	r.peersMu.Unlock()
	for _, peer := range remaining {
		peer.Close()
	}
	r.logger.Info("Room closed")
	return nil
}
// broadcastMessage fans msg out to every peer except excludePeerID.
// Individual send errors are ignored.
func (r *Room) broadcastMessage(excludePeerID string, msg *ServerMessage) {
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	for peerID, peer := range r.peers {
		if peerID != excludePeerID {
			peer.SendMessage(msg)
		}
	}
}
// buildICEServers derives the ICE server list (TURN-only) advertised to new
// peers, minting fresh short-lived HMAC credentials on every call.
// Returns nil when no TURN servers or no shared secret are configured.
func (r *Room) buildICEServers() []webrtc.ICEServer {
	cfg := r.config
	if len(cfg.TURNServers) == 0 || cfg.TURNSecret == "" {
		return nil
	}
	urls := make([]string, 0, len(cfg.TURNServers))
	for _, srv := range cfg.TURNServers {
		urls = append(urls, fmt.Sprintf("turn:%s:%d?transport=udp", srv.Host, srv.Port))
	}
	ttl := time.Duration(cfg.TURNCredentialTTL) * time.Second
	user, pass := turn.GenerateCredentials(cfg.TURNSecret, cfg.Namespace, ttl)
	return []webrtc.ICEServer{{
		URLs:       urls,
		Username:   user,
		Credential: pass,
	}}
}

368
pkg/sfu/room_test.go Normal file
View File

@ -0,0 +1,368 @@
package sfu
import (
"net/http"
"net/http/httptest"
"testing"
"time"
"go.uber.org/zap"
)
func testConfig() *Config {
return &Config{
ListenAddr: "10.0.0.1:8443",
Namespace: "test-ns",
MediaPortStart: 20000,
MediaPortEnd: 20500,
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
TURNSecret: "test-secret-key-32bytes-long!!!!",
TURNCredentialTTL: 600,
RQLiteDSN: "http://10.0.0.1:4001",
}
}
func testLogger() *zap.Logger {
return zap.NewNop()
}
// --- RoomManager tests ---
func TestNewRoomManager(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
if rm == nil {
t.Fatal("NewRoomManager returned nil")
}
if rm.RoomCount() != 0 {
t.Errorf("RoomCount = %d, want 0", rm.RoomCount())
}
}
func TestRoomManagerGetOrCreateRoom(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
room1 := rm.GetOrCreateRoom("room-1")
if room1 == nil {
t.Fatal("GetOrCreateRoom returned nil")
}
if room1.ID != "room-1" {
t.Errorf("Room.ID = %q, want %q", room1.ID, "room-1")
}
if room1.Namespace != "test-ns" {
t.Errorf("Room.Namespace = %q, want %q", room1.Namespace, "test-ns")
}
if rm.RoomCount() != 1 {
t.Errorf("RoomCount = %d, want 1", rm.RoomCount())
}
// Getting same room returns same instance
room1Again := rm.GetOrCreateRoom("room-1")
if room1 != room1Again {
t.Error("expected same room instance")
}
if rm.RoomCount() != 1 {
t.Errorf("RoomCount = %d, want 1 (same room)", rm.RoomCount())
}
// Different room creates new instance
room2 := rm.GetOrCreateRoom("room-2")
if room2 == nil {
t.Fatal("second room is nil")
}
if room2.ID != "room-2" {
t.Errorf("Room.ID = %q, want %q", room2.ID, "room-2")
}
if rm.RoomCount() != 2 {
t.Errorf("RoomCount = %d, want 2", rm.RoomCount())
}
}
func TestRoomManagerGetRoom(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
// Non-existent room returns nil
room := rm.GetRoom("nonexistent")
if room != nil {
t.Error("expected nil for non-existent room")
}
// Create a room and retrieve it
rm.GetOrCreateRoom("room-1")
room = rm.GetRoom("room-1")
if room == nil {
t.Fatal("expected non-nil for existing room")
}
if room.ID != "room-1" {
t.Errorf("Room.ID = %q, want %q", room.ID, "room-1")
}
}
func TestRoomManagerCloseAll(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
rm.GetOrCreateRoom("room-1")
rm.GetOrCreateRoom("room-2")
rm.GetOrCreateRoom("room-3")
if rm.RoomCount() != 3 {
t.Fatalf("RoomCount = %d, want 3", rm.RoomCount())
}
rm.CloseAll()
if rm.RoomCount() != 0 {
t.Errorf("RoomCount after CloseAll = %d, want 0", rm.RoomCount())
}
}
func TestRoomManagerGetOrCreateRoomReplacesClosedRoom(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
room1 := rm.GetOrCreateRoom("room-1")
room1.Close()
// Getting the same room ID after close should create a new room
room1New := rm.GetOrCreateRoom("room-1")
if room1New == room1 {
t.Error("expected new room instance after close")
}
if room1New.IsClosed() {
t.Error("new room should not be closed")
}
}
// --- Room tests ---
func TestRoomIsClosed(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
room := rm.GetOrCreateRoom("room-1")
if room.IsClosed() {
t.Error("new room should not be closed")
}
room.Close()
if !room.IsClosed() {
t.Error("room should be closed after Close()")
}
}
func TestRoomCloseIdempotent(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
room := rm.GetOrCreateRoom("room-1")
// Should not panic or error when called multiple times
if err := room.Close(); err != nil {
t.Errorf("first Close() returned error: %v", err)
}
if err := room.Close(); err != nil {
t.Errorf("second Close() returned error: %v", err)
}
}
func TestRoomGetParticipantsEmpty(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
room := rm.GetOrCreateRoom("room-1")
participants := room.GetParticipants()
if len(participants) != 0 {
t.Errorf("Participants count = %d, want 0", len(participants))
}
if room.GetParticipantCount() != 0 {
t.Errorf("ParticipantCount = %d, want 0", room.GetParticipantCount())
}
}
func TestRoomBuildICEServers(t *testing.T) {
rm := NewRoomManager(testConfig(), testLogger())
room := rm.GetOrCreateRoom("room-1")
servers := room.buildICEServers()
if len(servers) != 1 {
t.Fatalf("ICE servers count = %d, want 1", len(servers))
}
if len(servers[0].URLs) != 1 {
t.Fatalf("URLs count = %d, want 1", len(servers[0].URLs))
}
if servers[0].URLs[0] != "turn:1.2.3.4:3478?transport=udp" {
t.Errorf("URL = %q, want %q", servers[0].URLs[0], "turn:1.2.3.4:3478?transport=udp")
}
if servers[0].Username == "" {
t.Error("Username should not be empty")
}
if servers[0].Credential == "" {
t.Error("Credential should not be empty")
}
}
func TestRoomBuildICEServersNoTURN(t *testing.T) {
cfg := testConfig()
cfg.TURNServers = nil
rm := NewRoomManager(cfg, testLogger())
room := rm.GetOrCreateRoom("room-1")
servers := room.buildICEServers()
if servers != nil {
t.Errorf("expected nil ICE servers when no TURN configured, got %v", servers)
}
}
func TestRoomBuildICEServersNoSecret(t *testing.T) {
cfg := testConfig()
cfg.TURNSecret = ""
rm := NewRoomManager(cfg, testLogger())
room := rm.GetOrCreateRoom("room-1")
servers := room.buildICEServers()
if servers != nil {
t.Errorf("expected nil ICE servers when no secret, got %v", servers)
}
}
func TestRoomBuildICEServersMultipleTURN(t *testing.T) {
cfg := testConfig()
cfg.TURNServers = []TURNServerConfig{
{Host: "1.2.3.4", Port: 3478},
{Host: "5.6.7.8", Port: 443},
}
rm := NewRoomManager(cfg, testLogger())
room := rm.GetOrCreateRoom("room-1")
servers := room.buildICEServers()
if len(servers) != 1 {
t.Fatalf("ICE servers count = %d, want 1", len(servers))
}
if len(servers[0].URLs) != 2 {
t.Fatalf("URLs count = %d, want 2", len(servers[0].URLs))
}
}
// --- Empty room cleanup test ---
// TestEmptyRoomCleanup verifies that a room whose onEmpty callback fires
// is removed from the manager once its cleanup timer expires. The
// package-level timeAfter hook is stubbed so the timer fires instantly.
func TestEmptyRoomCleanup(t *testing.T) {
	// Override timeAfter for instant timer
	origTimeAfter := timeAfter
	timeAfter = func(d time.Duration) <-chan time.Time {
		ch := make(chan time.Time, 1)
		ch <- time.Now()
		return ch
	}
	defer func() { timeAfter = origTimeAfter }()
	rm := NewRoomManager(testConfig(), testLogger())
	room := rm.GetOrCreateRoom("room-1")
	// Trigger the onEmpty callback (which starts cleanup timer)
	room.onEmpty(room)
	// Give the goroutine time to execute
	// NOTE(review): sleep-based synchronization can be flaky under heavy
	// CI load; a polling loop with a deadline would be more robust.
	time.Sleep(50 * time.Millisecond)
	if rm.RoomCount() != 0 {
		t.Errorf("RoomCount = %d, want 0 (should have been cleaned up)", rm.RoomCount())
	}
}
// --- Server health tests ---
func TestHealthEndpointOK(t *testing.T) {
cfg := testConfig()
server, err := NewServer(cfg, testLogger())
if err != nil {
t.Fatalf("NewServer failed: %v", err)
}
req := httptest.NewRequest("GET", "/health", nil)
w := httptest.NewRecorder()
server.handleHealth(w, req)
if w.Code != http.StatusOK {
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
}
body := w.Body.String()
if body != `{"status":"ok","rooms":0}` {
t.Errorf("body = %q, want %q", body, `{"status":"ok","rooms":0}`)
}
}
func TestHealthEndpointDraining(t *testing.T) {
cfg := testConfig()
server, err := NewServer(cfg, testLogger())
if err != nil {
t.Fatalf("NewServer failed: %v", err)
}
// Set draining
server.drainingMu.Lock()
server.draining = true
server.drainingMu.Unlock()
req := httptest.NewRequest("GET", "/health", nil)
w := httptest.NewRecorder()
server.handleHealth(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
}
body := w.Body.String()
if body != `{"status":"draining","rooms":0}` {
t.Errorf("body = %q, want %q", body, `{"status":"draining","rooms":0}`)
}
}
func TestServerDrainSetsFlag(t *testing.T) {
// Override timeAfter for instant timer
origTimeAfter := timeAfter
timeAfter = func(d time.Duration) <-chan time.Time {
ch := make(chan time.Time, 1)
ch <- time.Now()
return ch
}
defer func() { timeAfter = origTimeAfter }()
cfg := testConfig()
server, err := NewServer(cfg, testLogger())
if err != nil {
t.Fatalf("NewServer failed: %v", err)
}
server.Drain(0)
server.drainingMu.RLock()
draining := server.draining
server.drainingMu.RUnlock()
if !draining {
t.Error("expected draining to be true after Drain()")
}
}
func TestServerNewServerValidation(t *testing.T) {
// Invalid config should return error
cfg := &Config{} // Empty = invalid
_, err := NewServer(cfg, testLogger())
if err == nil {
t.Error("expected error for invalid config")
}
}
func TestServerSignalEndpointRejectsDraining(t *testing.T) {
cfg := testConfig()
server, err := NewServer(cfg, testLogger())
if err != nil {
t.Fatalf("NewServer failed: %v", err)
}
server.drainingMu.Lock()
server.draining = true
server.drainingMu.Unlock()
req := httptest.NewRequest("GET", "/ws/signal", nil)
w := httptest.NewRecorder()
server.handleSignal(w, req)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
}
}

293
pkg/sfu/server.go Normal file
View File

@ -0,0 +1,293 @@
package sfu
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/DeBrosOfficial/network/pkg/turn"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// Server is the SFU HTTP server providing WebSocket signaling and a health endpoint.
// It binds only to a WireGuard IP — never exposed publicly.
type Server struct {
	config      *Config
	roomManager *RoomManager
	logger      *zap.Logger
	httpServer  *http.Server
	upgrader    websocket.Upgrader
	// draining, guarded by drainingMu, makes /health report 503 and
	// handleSignal reject new peers once Drain has been called.
	draining   bool
	drainingMu sync.RWMutex
}
// NewServer creates a new SFU server.
//
// The config is validated up front; only the first validation error is
// returned. The WebSocket upgrader deliberately accepts any Origin
// because the listener binds to a private (WireGuard) address and the
// gateway authenticates requests before proxying here.
func NewServer(cfg *Config, logger *zap.Logger) (*Server, error) {
	if errs := cfg.Validate(); len(errs) > 0 {
		return nil, fmt.Errorf("invalid SFU config: %v", errs[0])
	}
	s := &Server{
		config:      cfg,
		roomManager: NewRoomManager(cfg, logger),
		logger:      logger.With(zap.String("component", "sfu"), zap.String("namespace", cfg.Namespace)),
		upgrader: websocket.Upgrader{
			ReadBufferSize:  4096,
			WriteBufferSize: 4096,
			CheckOrigin:     func(r *http.Request) bool { return true }, // Gateway handles auth
		},
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/ws/signal", s.handleSignal)
	mux.HandleFunc("/health", s.handleHealth)
	s.httpServer = &http.Server{
		Addr:              cfg.ListenAddr,
		Handler:           mux,
		ReadHeaderTimeout: 10 * time.Second,
	}
	return s, nil
}
// ListenAndServe starts the HTTP server. Blocks until the server is
// stopped; like net/http, it always returns a non-nil error
// (http.ErrServerClosed after a clean Shutdown).
func (s *Server) ListenAndServe() error {
	s.logger.Info("SFU server starting",
		zap.String("addr", s.config.ListenAddr),
		zap.String("namespace", s.config.Namespace))
	return s.httpServer.ListenAndServe()
}
// Drain initiates graceful drain: notifies all peers, waits, then closes.
//
// Flips the draining flag (so /health reports 503 and handleSignal
// rejects new connections), broadcasts a server-draining warning to
// every room so clients can migrate, then blocks for the full timeout.
// Actual teardown is left to the caller (see Close). timeAfter is a
// package hook so tests can skip the wait.
func (s *Server) Drain(timeout time.Duration) {
	s.drainingMu.Lock()
	s.draining = true
	s.drainingMu.Unlock()
	s.logger.Info("SFU draining started", zap.Duration("timeout", timeout))
	// Notify all peers
	s.roomManager.mu.RLock()
	for _, room := range s.roomManager.rooms {
		room.broadcastMessage("", NewServerMessage(MessageTypeServerDraining, &ServerDrainingData{
			Reason:    "server shutting down",
			TimeoutMs: int(timeout.Milliseconds()),
		}))
	}
	s.roomManager.mu.RUnlock()
	// Wait for timeout, then force close
	<-timeAfter(timeout)
}
// Close shuts down the SFU server: all rooms (and their peers) are
// closed first, then the HTTP listener is shut down gracefully with a
// 5-second deadline. Returns the Shutdown error, if any.
func (s *Server) Close() error {
	s.logger.Info("SFU server shutting down")
	s.roomManager.CloseAll()
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	return s.httpServer.Shutdown(ctx)
}
// handleHealth is a simple health check endpoint.
//
// Responds 200 with {"status":"ok","rooms":N} normally, or 503 with
// {"status":"draining","rooms":N} once Drain has been initiated, so load
// balancers stop routing new sessions here.
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
	s.drainingMu.RLock()
	isDraining := s.draining
	s.drainingMu.RUnlock()
	status, code := "ok", http.StatusOK
	if isDraining {
		status, code = "draining", http.StatusServiceUnavailable
	}
	w.WriteHeader(code)
	fmt.Fprintf(w, `{"status":%q,"rooms":%d}`, status, s.roomManager.RoomCount())
}
// handleSignal upgrades to WebSocket and runs the signaling loop for one peer.
//
// Protocol: the first frame must arrive within 10 seconds and must be a
// "join" message carrying roomId and userId; anything else gets an error
// frame and the socket is closed. On a successful join the handler
// sends, in order: welcome (with the current participant list), TURN
// credentials (if configured), and the existing tracks of other peers.
// It then blocks in signalingLoop until the peer disconnects.
func (s *Server) handleSignal(w http.ResponseWriter, r *http.Request) {
	// Reject new peers once draining has begun; existing sessions keep
	// running until Drain's timeout elapses.
	s.drainingMu.RLock()
	if s.draining {
		s.drainingMu.RUnlock()
		http.Error(w, "server draining", http.StatusServiceUnavailable)
		return
	}
	s.drainingMu.RUnlock()
	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		s.logger.Error("WebSocket upgrade failed", zap.Error(err))
		return
	}
	s.logger.Debug("WebSocket connected", zap.String("remote", r.RemoteAddr))
	// Read the first message — must be a join
	conn.SetReadDeadline(time.Now().Add(10 * time.Second))
	_, msgBytes, err := conn.ReadMessage()
	if err != nil {
		s.logger.Warn("Failed to read join message", zap.Error(err))
		conn.Close()
		return
	}
	conn.SetReadDeadline(time.Time{}) // Clear deadline
	var msg ClientMessage
	if err := json.Unmarshal(msgBytes, &msg); err != nil {
		conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_message", "malformed JSON")))
		conn.Close()
		return
	}
	if msg.Type != MessageTypeJoin {
		conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_message", "first message must be join")))
		conn.Close()
		return
	}
	var joinData JoinData
	if err := json.Unmarshal(msg.Data, &joinData); err != nil || joinData.RoomID == "" || joinData.UserID == "" {
		conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_join", "roomId and userId required")))
		conn.Close()
		return
	}
	room := s.roomManager.GetOrCreateRoom(joinData.RoomID)
	peer := NewPeer(joinData.UserID, conn, room, s.logger)
	if err := room.AddPeer(peer); err != nil {
		conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("join_failed", err.Error())))
		conn.Close()
		return
	}
	// Send welcome with current participants
	peer.SendMessage(NewServerMessage(MessageTypeWelcome, &WelcomeData{
		PeerID:       peer.ID,
		RoomID:       room.ID,
		Participants: room.GetParticipants(),
	}))
	// Send TURN credentials
	if s.config.TURNSecret != "" && len(s.config.TURNServers) > 0 {
		s.sendTURNCredentials(peer)
	}
	// Send existing tracks from other peers
	room.SendExistingTracksTo(peer)
	// Start credential refresh goroutine
	if s.config.TURNCredentialTTL > 0 {
		go s.credentialRefreshLoop(peer)
	}
	// Signaling read loop
	s.signalingLoop(peer, room)
}
// signalingLoop reads signaling messages from the WebSocket until disconnect.
//
// Runs on the handler goroutine. The deferred RemovePeer guarantees the
// peer is detached from the room on any exit path (read error, explicit
// leave). Malformed payloads produce an error frame but keep the session
// alive; only answer/candidate handler failures are logged without
// replying, since the client cannot act on them.
func (s *Server) signalingLoop(peer *Peer, room *Room) {
	defer room.RemovePeer(peer.ID)
	for {
		_, msgBytes, err := peer.conn.ReadMessage()
		if err != nil {
			s.logger.Debug("WebSocket read error", zap.String("peer_id", peer.ID), zap.Error(err))
			return
		}
		var msg ClientMessage
		if err := json.Unmarshal(msgBytes, &msg); err != nil {
			peer.SendMessage(NewErrorMessage("invalid_message", "malformed JSON"))
			continue
		}
		switch msg.Type {
		case MessageTypeOffer:
			var data OfferData
			if err := json.Unmarshal(msg.Data, &data); err != nil {
				peer.SendMessage(NewErrorMessage("invalid_offer", err.Error()))
				continue
			}
			if err := peer.HandleOffer(data.SDP); err != nil {
				s.logger.Error("Failed to handle offer", zap.String("peer_id", peer.ID), zap.Error(err))
				peer.SendMessage(NewErrorMessage("offer_failed", err.Error()))
			}
		case MessageTypeAnswer:
			var data AnswerData
			if err := json.Unmarshal(msg.Data, &data); err != nil {
				peer.SendMessage(NewErrorMessage("invalid_answer", err.Error()))
				continue
			}
			if err := peer.HandleAnswer(data.SDP); err != nil {
				s.logger.Error("Failed to handle answer", zap.String("peer_id", peer.ID), zap.Error(err))
			}
		case MessageTypeICECandidate:
			var data ICECandidateData
			if err := json.Unmarshal(msg.Data, &data); err != nil {
				peer.SendMessage(NewErrorMessage("invalid_candidate", err.Error()))
				continue
			}
			if err := peer.HandleICECandidate(&data); err != nil {
				s.logger.Error("Failed to handle ICE candidate", zap.String("peer_id", peer.ID), zap.Error(err))
			}
		case MessageTypeLeave:
			s.logger.Info("Peer leaving", zap.String("peer_id", peer.ID))
			return
		default:
			peer.SendMessage(NewErrorMessage("unknown_message", fmt.Sprintf("unknown message type: %s", msg.Type)))
		}
	}
}
// sendTURNCredentials sends TURN server credentials to a peer.
//
// An ephemeral HMAC credential pair is minted from the shared secret and
// namespace, and delivered together with every configured TURN URI.
func (s *Server) sendTURNCredentials(peer *Peer) {
	ttl := time.Duration(s.config.TURNCredentialTTL) * time.Second
	username, password := turn.GenerateCredentials(s.config.TURNSecret, s.config.Namespace, ttl)
	uris := make([]string, 0, len(s.config.TURNServers))
	for _, server := range s.config.TURNServers {
		uris = append(uris, fmt.Sprintf("turn:%s:%d?transport=udp", server.Host, server.Port))
	}
	msg := NewServerMessage(MessageTypeTURNCredentials, &TURNCredentialsData{
		Username: username,
		Password: password,
		TTL:      s.config.TURNCredentialTTL,
		URIs:     uris,
	})
	peer.SendMessage(msg)
}
// credentialRefreshLoop sends fresh TURN credentials at 80% of TTL.
//
// Runs as one goroutine per peer (started from handleSignal). Waking at
// 80% of the TTL keeps clients holding a valid credential at all times.
// The loop exits once the peer is marked closed; timeAfter is a package
// hook so tests can make the timer fire immediately.
func (s *Server) credentialRefreshLoop(peer *Peer) {
	refreshInterval := time.Duration(float64(s.config.TURNCredentialTTL)*0.8) * time.Second
	for {
		<-timeAfter(refreshInterval)
		// Re-check liveness after each sleep so the goroutine never
		// outlives its peer.
		peer.closedMu.RLock()
		closed := peer.closed
		peer.closedMu.RUnlock()
		if closed {
			return
		}
		s.sendTURNCredentials(peer)
		s.logger.Debug("Refreshed TURN credentials", zap.String("peer_id", peer.ID))
	}
}
// mustMarshal serializes v for transmission on the signaling socket.
//
// The values passed here are plain message structs whose marshaling
// cannot fail in practice. Previously the error was silently discarded,
// which on failure would have written a nil (empty) WebSocket frame;
// now a static, well-formed error payload is returned instead, so the
// client always receives valid JSON.
func mustMarshal(v interface{}) []byte {
	data, err := json.Marshal(v)
	if err != nil {
		return []byte(`{"type":"error","data":{"code":"internal","message":"marshal failed"}}`)
	}
	return data
}

144
pkg/sfu/signaling.go Normal file
View File

@ -0,0 +1,144 @@
package sfu
import (
"encoding/json"
"github.com/pion/webrtc/v4"
)
// MessageType represents the type of signaling message exchanged over
// the /ws/signal WebSocket. The string value is the JSON "type" field.
type MessageType string

const (
	// Client → Server
	MessageTypeJoin         MessageType = "join"
	MessageTypeLeave        MessageType = "leave"
	MessageTypeOffer        MessageType = "offer"
	MessageTypeAnswer       MessageType = "answer"
	MessageTypeICECandidate MessageType = "ice-candidate"
	// Server → Client
	MessageTypeWelcome            MessageType = "welcome"
	MessageTypeParticipantJoined  MessageType = "participant-joined"
	MessageTypeParticipantLeft    MessageType = "participant-left"
	MessageTypeTrackAdded         MessageType = "track-added"
	MessageTypeTrackRemoved       MessageType = "track-removed"
	MessageTypeTURNCredentials    MessageType = "turn-credentials"
	MessageTypeRefreshCredentials MessageType = "refresh-credentials"
	MessageTypeServerDraining     MessageType = "server-draining"
	MessageTypeError              MessageType = "error"
)
// ClientMessage is a message from client to server. Data is kept as
// json.RawMessage so the payload can be decoded into the concrete
// struct once Type has been inspected.
type ClientMessage struct {
	Type MessageType     `json:"type"`
	Data json.RawMessage `json:"data,omitempty"`
}

// ServerMessage is a message from server to client. Data holds the
// concrete payload struct and is marshaled directly.
type ServerMessage struct {
	Type MessageType `json:"type"`
	Data interface{} `json:"data,omitempty"`
}
// JoinData is the payload for join messages. Both fields are required;
// the server rejects a join with either one empty.
type JoinData struct {
	RoomID string `json:"roomId"`
	UserID string `json:"userId"`
}

// OfferData is the payload for SDP offer messages.
type OfferData struct {
	SDP string `json:"sdp"`
}

// AnswerData is the payload for SDP answer messages.
type AnswerData struct {
	SDP string `json:"sdp"`
}
type ICECandidateData struct {
Candidate string `json:"candidate"`
SDPMid string `json:"sdpMid,omitempty"`
SDPMLineIndex uint16 `json:"sdpMLineIndex,omitempty"`
UsernameFragment string `json:"usernameFragment,omitempty"`
}
// ToWebRTCCandidate converts to pion ICECandidateInit.
//
// Optional string fields map to nil when empty: in ICECandidateInit a
// nil pointer means "not supplied", whereas a pointer to "" asserts an
// (invalid) empty value — the previous implementation always set the
// pointers and could therefore hand pion an empty sdpMid or ufrag.
// SDPMLineIndex is always set, since 0 is a legitimate index for the
// first m-line.
func (c *ICECandidateData) ToWebRTCCandidate() webrtc.ICECandidateInit {
	init := webrtc.ICECandidateInit{
		Candidate:     c.Candidate,
		SDPMLineIndex: &c.SDPMLineIndex,
	}
	if c.SDPMid != "" {
		init.SDPMid = &c.SDPMid
	}
	if c.UsernameFragment != "" {
		init.UsernameFragment = &c.UsernameFragment
	}
	return init
}
// WelcomeData is sent when a peer successfully joins a room. It carries
// the peer's own server-assigned ID and a snapshot of current members.
type WelcomeData struct {
	PeerID       string            `json:"peerId"`
	RoomID       string            `json:"roomId"`
	Participants []ParticipantInfo `json:"participants"`
}

// ParticipantInfo is public info about a room participant.
type ParticipantInfo struct {
	PeerID string `json:"peerId"`
	UserID string `json:"userId"`
}

// ParticipantJoinedData is sent when a new participant joins.
type ParticipantJoinedData struct {
	Participant ParticipantInfo `json:"participant"`
}

// ParticipantLeftData is sent when a participant leaves.
type ParticipantLeftData struct {
	PeerID string `json:"peerId"`
}

// TrackAddedData is sent when a new track is available for subscription.
type TrackAddedData struct {
	PeerID   string `json:"peerId"`
	TrackID  string `json:"trackId"`
	StreamID string `json:"streamId"`
	Kind     string `json:"kind"` // "audio" or "video"
}

// TrackRemovedData is sent when a track is removed.
type TrackRemovedData struct {
	PeerID  string `json:"peerId"`
	TrackID string `json:"trackId"`
	Kind    string `json:"kind"`
}

// TURNCredentialsData provides TURN server credentials. TTL is in
// seconds (it is multiplied by time.Second where credentials are built).
type TURNCredentialsData struct {
	Username string   `json:"username"`
	Password string   `json:"password"`
	TTL      int      `json:"ttl"`
	URIs     []string `json:"uris"`
}

// ServerDrainingData warns clients the server is shutting down so they
// can reconnect elsewhere within TimeoutMs milliseconds.
type ServerDrainingData struct {
	Reason    string `json:"reason"`
	TimeoutMs int    `json:"timeoutMs"`
}

// ErrorData is sent when an error occurs; Code is a stable machine-
// readable identifier, Message is human-readable detail.
type ErrorData struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
// NewServerMessage creates a new server message wrapping data as the
// payload for the given message type.
func NewServerMessage(msgType MessageType, data interface{}) *ServerMessage {
	return &ServerMessage{Type: msgType, Data: data}
}

// NewErrorMessage creates a new error message with the given stable
// error code and human-readable message.
func NewErrorMessage(code, message string) *ServerMessage {
	return NewServerMessage(MessageTypeError, &ErrorData{Code: code, Message: message})
}

257
pkg/sfu/signaling_test.go Normal file
View File

@ -0,0 +1,257 @@
package sfu
import (
"encoding/json"
"testing"
)
func TestClientMessageDeserialization(t *testing.T) {
tests := []struct {
name string
input string
wantType MessageType
wantData bool
}{
{
name: "join message",
input: `{"type":"join","data":{"roomId":"room-1","userId":"user-1"}}`,
wantType: MessageTypeJoin,
wantData: true,
},
{
name: "leave message",
input: `{"type":"leave"}`,
wantType: MessageTypeLeave,
wantData: false,
},
{
name: "offer message",
input: `{"type":"offer","data":{"sdp":"v=0..."}}`,
wantType: MessageTypeOffer,
wantData: true,
},
{
name: "answer message",
input: `{"type":"answer","data":{"sdp":"v=0..."}}`,
wantType: MessageTypeAnswer,
wantData: true,
},
{
name: "ice-candidate message",
input: `{"type":"ice-candidate","data":{"candidate":"candidate:1234"}}`,
wantType: MessageTypeICECandidate,
wantData: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var msg ClientMessage
if err := json.Unmarshal([]byte(tt.input), &msg); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if msg.Type != tt.wantType {
t.Errorf("Type = %q, want %q", msg.Type, tt.wantType)
}
if tt.wantData && msg.Data == nil {
t.Error("expected Data to be non-nil")
}
if !tt.wantData && msg.Data != nil {
t.Error("expected Data to be nil")
}
})
}
}
func TestJoinDataDeserialization(t *testing.T) {
input := `{"roomId":"room-abc","userId":"user-xyz"}`
var data JoinData
if err := json.Unmarshal([]byte(input), &data); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if data.RoomID != "room-abc" {
t.Errorf("RoomID = %q, want %q", data.RoomID, "room-abc")
}
if data.UserID != "user-xyz" {
t.Errorf("UserID = %q, want %q", data.UserID, "user-xyz")
}
}
func TestServerMessageSerialization(t *testing.T) {
tests := []struct {
name string
msg *ServerMessage
wantKey string
}{
{
name: "welcome message",
msg: NewServerMessage(MessageTypeWelcome, &WelcomeData{PeerID: "p1", RoomID: "r1"}),
wantKey: "welcome",
},
{
name: "participant joined",
msg: NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{Participant: ParticipantInfo{PeerID: "p2", UserID: "u2"}}),
wantKey: "participant-joined",
},
{
name: "participant left",
msg: NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{PeerID: "p2"}),
wantKey: "participant-left",
},
{
name: "track added",
msg: NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{PeerID: "p1", TrackID: "t1", StreamID: "s1", Kind: "video"}),
wantKey: "track-added",
},
{
name: "track removed",
msg: NewServerMessage(MessageTypeTrackRemoved, &TrackRemovedData{PeerID: "p1", TrackID: "t1", Kind: "video"}),
wantKey: "track-removed",
},
{
name: "TURN credentials",
msg: NewServerMessage(MessageTypeTURNCredentials, &TURNCredentialsData{Username: "u", Password: "p", TTL: 600, URIs: []string{"turn:1.2.3.4:3478"}}),
wantKey: "turn-credentials",
},
{
name: "server draining",
msg: NewServerMessage(MessageTypeServerDraining, &ServerDrainingData{Reason: "shutdown", TimeoutMs: 30000}),
wantKey: "server-draining",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
data, err := json.Marshal(tt.msg)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
// Verify it roundtrips correctly
var raw map[string]json.RawMessage
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatalf("failed to unmarshal to raw: %v", err)
}
var msgType string
if err := json.Unmarshal(raw["type"], &msgType); err != nil {
t.Fatalf("failed to unmarshal type: %v", err)
}
if msgType != tt.wantKey {
t.Errorf("type = %q, want %q", msgType, tt.wantKey)
}
if _, ok := raw["data"]; !ok {
t.Error("expected data field in output")
}
})
}
}
func TestNewErrorMessage(t *testing.T) {
msg := NewErrorMessage("invalid_offer", "bad SDP")
if msg.Type != MessageTypeError {
t.Errorf("Type = %q, want %q", msg.Type, MessageTypeError)
}
errData, ok := msg.Data.(*ErrorData)
if !ok {
t.Fatal("Data is not *ErrorData")
}
if errData.Code != "invalid_offer" {
t.Errorf("Code = %q, want %q", errData.Code, "invalid_offer")
}
if errData.Message != "bad SDP" {
t.Errorf("Message = %q, want %q", errData.Message, "bad SDP")
}
// Verify serialization
data, err := json.Marshal(msg)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
result := string(data)
if result == "" {
t.Error("expected non-empty serialized output")
}
}
func TestICECandidateDataToWebRTCCandidate(t *testing.T) {
data := &ICECandidateData{
Candidate: "candidate:842163049 1 udp 1677729535 203.0.113.1 3478 typ srflx",
SDPMid: "0",
SDPMLineIndex: 0,
UsernameFragment: "abc123",
}
candidate := data.ToWebRTCCandidate()
if candidate.Candidate != data.Candidate {
t.Errorf("Candidate = %q, want %q", candidate.Candidate, data.Candidate)
}
if candidate.SDPMid == nil || *candidate.SDPMid != "0" {
t.Error("SDPMid should be pointer to '0'")
}
if candidate.SDPMLineIndex == nil || *candidate.SDPMLineIndex != 0 {
t.Error("SDPMLineIndex should be pointer to 0")
}
if candidate.UsernameFragment == nil || *candidate.UsernameFragment != "abc123" {
t.Error("UsernameFragment should be pointer to 'abc123'")
}
}
func TestWelcomeDataSerialization(t *testing.T) {
welcome := &WelcomeData{
PeerID: "peer-123",
RoomID: "room-456",
Participants: []ParticipantInfo{
{PeerID: "peer-001", UserID: "user-001"},
{PeerID: "peer-002", UserID: "user-002"},
},
}
data, err := json.Marshal(welcome)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result WelcomeData
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if result.PeerID != "peer-123" {
t.Errorf("PeerID = %q, want %q", result.PeerID, "peer-123")
}
if result.RoomID != "room-456" {
t.Errorf("RoomID = %q, want %q", result.RoomID, "room-456")
}
if len(result.Participants) != 2 {
t.Errorf("Participants count = %d, want 2", len(result.Participants))
}
}
func TestTURNCredentialsDataSerialization(t *testing.T) {
creds := &TURNCredentialsData{
Username: "1234567890:test-ns",
Password: "base64password==",
TTL: 600,
URIs: []string{"turn:1.2.3.4:3478?transport=udp", "turn:5.6.7.8:443?transport=udp"},
}
data, err := json.Marshal(creds)
if err != nil {
t.Fatalf("failed to marshal: %v", err)
}
var result TURNCredentialsData
if err := json.Unmarshal(data, &result); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if result.Username != creds.Username {
t.Errorf("Username = %q, want %q", result.Username, creds.Username)
}
if result.TTL != 600 {
t.Errorf("TTL = %d, want 600", result.TTL)
}
if len(result.URIs) != 2 {
t.Errorf("URIs count = %d, want 2", len(result.URIs))
}
}

View File

@ -17,6 +17,8 @@ const (
ServiceTypeRQLite ServiceType = "rqlite" ServiceTypeRQLite ServiceType = "rqlite"
ServiceTypeOlric ServiceType = "olric" ServiceTypeOlric ServiceType = "olric"
ServiceTypeGateway ServiceType = "gateway" ServiceTypeGateway ServiceType = "gateway"
ServiceTypeSFU ServiceType = "sfu"
ServiceTypeTURN ServiceType = "turn"
) )
// Manager manages systemd units for namespace services // Manager manages systemd units for namespace services
@ -192,13 +194,33 @@ func (m *Manager) ReloadDaemon() error {
return nil return nil
} }
// serviceExists checks if a namespace service has an env file on disk,
// indicating the service was provisioned for this namespace.
func (m *Manager) serviceExists(namespace string, serviceType ServiceType) bool {
envFile := filepath.Join(m.namespaceBase, namespace, fmt.Sprintf("%s.env", serviceType))
_, err := os.Stat(envFile)
return err == nil
}
// StopAllNamespaceServices stops all namespace services for a given namespace // StopAllNamespaceServices stops all namespace services for a given namespace
func (m *Manager) StopAllNamespaceServices(namespace string) error { func (m *Manager) StopAllNamespaceServices(namespace string) error {
m.logger.Info("Stopping all namespace services", zap.String("namespace", namespace)) m.logger.Info("Stopping all namespace services", zap.String("namespace", namespace))
// Stop in reverse dependency order: Gateway → Olric → RQLite // Stop in reverse dependency order: SFU → TURN → Gateway → Olric → RQLite
services := []ServiceType{ServiceTypeGateway, ServiceTypeOlric, ServiceTypeRQLite} // SFU and TURN are conditional — only stop if they exist
for _, svcType := range services { for _, svcType := range []ServiceType{ServiceTypeSFU, ServiceTypeTURN} {
if m.serviceExists(namespace, svcType) {
if err := m.StopService(namespace, svcType); err != nil {
m.logger.Warn("Failed to stop service",
zap.String("namespace", namespace),
zap.String("service_type", string(svcType)),
zap.Error(err))
}
}
}
// Core services always exist
for _, svcType := range []ServiceType{ServiceTypeGateway, ServiceTypeOlric, ServiceTypeRQLite} {
if err := m.StopService(namespace, svcType); err != nil { if err := m.StopService(namespace, svcType); err != nil {
m.logger.Warn("Failed to stop service", m.logger.Warn("Failed to stop service",
zap.String("namespace", namespace), zap.String("namespace", namespace),
@ -215,14 +237,22 @@ func (m *Manager) StopAllNamespaceServices(namespace string) error {
func (m *Manager) StartAllNamespaceServices(namespace string) error { func (m *Manager) StartAllNamespaceServices(namespace string) error {
m.logger.Info("Starting all namespace services", zap.String("namespace", namespace)) m.logger.Info("Starting all namespace services", zap.String("namespace", namespace))
// Start in dependency order: RQLite → Olric → Gateway // Start core services in dependency order: RQLite → Olric → Gateway
services := []ServiceType{ServiceTypeRQLite, ServiceTypeOlric, ServiceTypeGateway} for _, svcType := range []ServiceType{ServiceTypeRQLite, ServiceTypeOlric, ServiceTypeGateway} {
for _, svcType := range services {
if err := m.StartService(namespace, svcType); err != nil { if err := m.StartService(namespace, svcType); err != nil {
return fmt.Errorf("failed to start %s service: %w", svcType, err) return fmt.Errorf("failed to start %s service: %w", svcType, err)
} }
} }
// Start WebRTC services if provisioned: TURN → SFU
for _, svcType := range []ServiceType{ServiceTypeTURN, ServiceTypeSFU} {
if m.serviceExists(namespace, svcType) {
if err := m.StartService(namespace, svcType); err != nil {
return fmt.Errorf("failed to start %s service: %w", svcType, err)
}
}
}
return nil return nil
} }
@ -419,6 +449,8 @@ func (m *Manager) InstallTemplateUnits(sourceDir string) error {
		"orama-namespace-rqlite@.service",
		"orama-namespace-olric@.service",
		"orama-namespace-gateway@.service",
"orama-namespace-sfu@.service",
"orama-namespace-turn@.service",
	}
	for _, template := range templates {

71
pkg/turn/config.go Normal file
View File

@ -0,0 +1,71 @@
package turn
import (
"fmt"
"net"
)
// Config holds configuration for the TURN server.
type Config struct {
	// ListenAddr is the address to bind the TURN listener (e.g., "0.0.0.0:3478").
	ListenAddr string `yaml:"listen_addr"`
	// TLSListenAddr is the address for TURN over TLS/DTLS (e.g., "0.0.0.0:443").
	// Uses UDP 443 which does not conflict with Caddy's TCP 443.
	TLSListenAddr string `yaml:"tls_listen_addr"`
	// PublicIP is the public IP address of this node, advertised in TURN allocations.
	PublicIP string `yaml:"public_ip"`
	// Realm is the TURN realm (typically the base domain).
	Realm string `yaml:"realm"`
	// AuthSecret is the HMAC-SHA1 shared secret for credential validation.
	AuthSecret string `yaml:"auth_secret"`
	// RelayPortStart is the beginning of the UDP relay port range (1-65535).
	RelayPortStart int `yaml:"relay_port_start"`
	// RelayPortEnd is the end of the UDP relay port range (1-65535, > start).
	RelayPortEnd int `yaml:"relay_port_end"`
	// Namespace this TURN instance belongs to.
	Namespace string `yaml:"namespace"`
}

// Validate checks the TURN configuration and returns one error per problem
// found. An empty slice means the configuration is usable.
func (c *Config) Validate() []error {
	var errs []error
	if c.ListenAddr == "" {
		errs = append(errs, fmt.Errorf("turn.listen_addr: must not be empty"))
	}
	if c.PublicIP == "" {
		errs = append(errs, fmt.Errorf("turn.public_ip: must not be empty"))
	} else if ip := net.ParseIP(c.PublicIP); ip == nil {
		errs = append(errs, fmt.Errorf("turn.public_ip: %q is not a valid IP address", c.PublicIP))
	}
	if c.Realm == "" {
		errs = append(errs, fmt.Errorf("turn.realm: must not be empty"))
	}
	if c.AuthSecret == "" {
		errs = append(errs, fmt.Errorf("turn.auth_secret: must not be empty"))
	}
	// Relay-range checks are mutually exclusive; report the most fundamental
	// problem first so a broken range yields exactly one error.
	switch {
	case c.RelayPortStart <= 0 || c.RelayPortEnd <= 0:
		errs = append(errs, fmt.Errorf("turn.relay_port_range: start and end must be positive"))
	case c.RelayPortStart > 65535 || c.RelayPortEnd > 65535:
		// The ports are narrowed to uint16 when handed to the TURN relay
		// generator; reject values that would silently wrap.
		errs = append(errs, fmt.Errorf("turn.relay_port_range: ports must be <= 65535 (got start=%d end=%d)", c.RelayPortStart, c.RelayPortEnd))
	case c.RelayPortEnd <= c.RelayPortStart:
		errs = append(errs, fmt.Errorf("turn.relay_port_range: end (%d) must be greater than start (%d)", c.RelayPortEnd, c.RelayPortStart))
	case c.RelayPortEnd-c.RelayPortStart < 100:
		errs = append(errs, fmt.Errorf("turn.relay_port_range: range must be at least 100 ports (got %d)", c.RelayPortEnd-c.RelayPortStart))
	}
	if c.Namespace == "" {
		errs = append(errs, fmt.Errorf("turn.namespace: must not be empty"))
	}
	return errs
}

228
pkg/turn/server.go Normal file
View File

@ -0,0 +1,228 @@
package turn
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"net"
"strconv"
"strings"
"time"
pionTurn "github.com/pion/turn/v4"
"go.uber.org/zap"
)
// Server wraps a Pion TURN server with namespace-scoped HMAC-SHA1 authentication.
type Server struct {
	config     *Config          // validated configuration this server was started with
	logger     *zap.Logger      // tagged with component=turn and the namespace
	turnServer *pionTurn.Server // underlying Pion server; set once NewServer succeeds
	conn       net.PacketConn   // UDP listener on primary port (3478)
	tlsConn    net.PacketConn   // UDP listener on TLS port (443); nil when tls_listen_addr is unset
}
// NewServer creates and starts a TURN server.
//
// It validates cfg, binds a UDP listener on cfg.ListenAddr (and, when
// cfg.TLSListenAddr is set, a second UDP listener on that address), then
// starts a Pion TURN server that authenticates clients with time-limited
// HMAC-SHA1 credentials (see authHandler). Any listeners opened before a
// failure are closed before the error is returned.
func NewServer(cfg *Config, logger *zap.Logger) (*Server, error) {
	// Fail fast on configuration problems; surface only the first error
	// since Validate reports one per problem.
	if errs := cfg.Validate(); len(errs) > 0 {
		return nil, fmt.Errorf("invalid TURN config: %v", errs[0])
	}

	// Validate already checked PublicIP, but the relay address generator
	// needs the parsed net.IP, so parse it (defensively) here.
	relayIP := net.ParseIP(cfg.PublicIP)
	if relayIP == nil {
		return nil, fmt.Errorf("turn.public_ip: %q is not a valid IP address", cfg.PublicIP)
	}

	s := &Server{
		config: cfg,
		logger: logger.With(zap.String("component", "turn"), zap.String("namespace", cfg.Namespace)),
	}

	// Create primary UDP listener (port 3478)
	conn, err := net.ListenPacket("udp4", cfg.ListenAddr)
	if err != nil {
		return nil, fmt.Errorf("failed to listen on %s: %w", cfg.ListenAddr, err)
	}
	s.conn = conn

	// Each listener relays through UDP ports in [RelayPortStart, RelayPortEnd],
	// advertising the node's public IP in allocations.
	packetConfigs := []pionTurn.PacketConnConfig{
		{
			PacketConn: conn,
			RelayAddressGenerator: &pionTurn.RelayAddressGeneratorPortRange{
				RelayAddress: relayIP,
				Address:      "0.0.0.0",
				MinPort:      uint16(cfg.RelayPortStart),
				MaxPort:      uint16(cfg.RelayPortEnd),
			},
		},
	}

	// Create TLS UDP listener (port 443) if configured
	// UDP 443 does not conflict with Caddy's TCP 443
	if cfg.TLSListenAddr != "" {
		tlsConn, err := net.ListenPacket("udp4", cfg.TLSListenAddr)
		if err != nil {
			// Don't leak the primary listener on partial failure.
			conn.Close()
			return nil, fmt.Errorf("failed to listen on %s: %w", cfg.TLSListenAddr, err)
		}
		s.tlsConn = tlsConn
		packetConfigs = append(packetConfigs, pionTurn.PacketConnConfig{
			PacketConn: tlsConn,
			RelayAddressGenerator: &pionTurn.RelayAddressGeneratorPortRange{
				RelayAddress: relayIP,
				Address:      "0.0.0.0",
				MinPort:      uint16(cfg.RelayPortStart),
				MaxPort:      uint16(cfg.RelayPortEnd),
			},
		})
	}

	// Create TURN server with HMAC-SHA1 auth
	turnServer, err := pionTurn.NewServer(pionTurn.ServerConfig{
		Realm: cfg.Realm,
		AuthHandler: func(username, realm string, srcAddr net.Addr) ([]byte, bool) {
			return s.authHandler(username, realm, srcAddr)
		},
		PacketConnConfigs: packetConfigs,
	})
	if err != nil {
		s.closeListeners()
		return nil, fmt.Errorf("failed to create TURN server: %w", err)
	}
	s.turnServer = turnServer

	s.logger.Info("TURN server started",
		zap.String("listen_addr", cfg.ListenAddr),
		zap.String("tls_listen_addr", cfg.TLSListenAddr),
		zap.String("public_ip", cfg.PublicIP),
		zap.String("realm", cfg.Realm),
		zap.Int("relay_port_start", cfg.RelayPortStart),
		zap.Int("relay_port_end", cfg.RelayPortEnd),
	)
	return s, nil
}
// authHandler implements the Pion TURN auth callback for this namespace.
//
// Credentials follow the time-limited HMAC scheme:
//   - username: "{expiry_unix}:{namespace}"
//   - password: base64(HMAC-SHA1(shared_secret, username))
//
// It returns the derived long-term auth key and true only when the credential
// is well-formed, targets this server's namespace, and has not expired.
func (s *Server) authHandler(username, realm string, srcAddr net.Addr) ([]byte, bool) {
	// reject logs the reason at debug level and denies the request.
	reject := func(msg string, fields ...zap.Field) ([]byte, bool) {
		s.logger.Debug(msg, fields...)
		return nil, false
	}

	expiryField, ns, found := strings.Cut(username, ":")
	if !found {
		return reject("Malformed TURN username: expected timestamp:namespace",
			zap.String("username", username),
			zap.String("src_addr", srcAddr.String()))
	}

	expiresAt, err := strconv.ParseInt(expiryField, 10, 64)
	if err != nil {
		return reject("Invalid timestamp in TURN username",
			zap.String("username", username),
			zap.String("src_addr", srcAddr.String()))
	}

	// Only credentials minted for this server's namespace are accepted.
	if ns != s.config.Namespace {
		return reject("TURN credential namespace mismatch",
			zap.String("credential_namespace", ns),
			zap.String("server_namespace", s.config.Namespace),
			zap.String("src_addr", srcAddr.String()))
	}

	// Reject anything at or past its expiry instant.
	if expiresAt <= time.Now().Unix() {
		return reject("TURN credential expired",
			zap.String("username", username),
			zap.Int64("expired_at", expiresAt),
			zap.String("src_addr", srcAddr.String()))
	}

	// Re-derive the password from the shared secret and hand Pion the
	// corresponding long-term-credential auth key.
	expectedPassword := GeneratePassword(s.config.AuthSecret, username)
	key := pionTurn.GenerateAuthKey(username, realm, expectedPassword)
	s.logger.Debug("TURN auth accepted",
		zap.String("namespace", ns),
		zap.String("src_addr", srcAddr.String()))
	return key, true
}
// Close gracefully shuts down the TURN server and releases its UDP listeners.
// It always returns nil; shutdown problems are logged rather than propagated.
func (s *Server) Close() error {
	s.logger.Info("Stopping TURN server")
	if s.turnServer != nil {
		err := s.turnServer.Close()
		if err != nil {
			s.logger.Warn("Error closing TURN server", zap.Error(err))
		}
	}
	s.closeListeners()
	s.logger.Info("TURN server stopped")
	return nil
}
// closeListeners closes and clears both UDP listeners. Safe to call more
// than once; nil listeners are skipped and errors from Close are ignored.
func (s *Server) closeListeners() {
	for _, pc := range []*net.PacketConn{&s.conn, &s.tlsConn} {
		if *pc != nil {
			(*pc).Close()
			*pc = nil
		}
	}
}
// GenerateCredentials creates time-limited HMAC-SHA1 TURN credentials.
// Returns username and password suitable for WebRTC ICE server configuration.
// The username encodes the expiry instant as "{expiry_unix}:{namespace}".
func GenerateCredentials(secret, namespace string, ttl time.Duration) (username, password string) {
	expiresAt := time.Now().Add(ttl).Unix()
	username = strconv.FormatInt(expiresAt, 10) + ":" + namespace
	password = GeneratePassword(secret, username)
	return username, password
}

// GeneratePassword computes the HMAC-SHA1 password for a TURN username:
// base64(HMAC-SHA1(secret, username)).
func GeneratePassword(secret, username string) string {
	mac := hmac.New(sha1.New, []byte(secret))
	mac.Write([]byte(username))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil))
}

// ValidateCredentials checks if TURN credentials are valid and not expired.
func ValidateCredentials(secret, username, password, expectedNamespace string) bool {
	expiryField, ns, found := strings.Cut(username, ":")
	if !found {
		return false
	}
	expiresAt, err := strconv.ParseInt(expiryField, 10, 64)
	if err != nil {
		return false
	}
	// Credential must target the expected namespace...
	if ns != expectedNamespace {
		return false
	}
	// ...and still be inside its validity window.
	if expiresAt <= time.Now().Unix() {
		return false
	}
	// Constant-time comparison of the presented password against the
	// freshly derived one.
	expected := GeneratePassword(secret, username)
	return hmac.Equal([]byte(password), []byte(expected))
}

225
pkg/turn/server_test.go Normal file
View File

@ -0,0 +1,225 @@
package turn
import (
"fmt"
"testing"
"time"
)
// TestGenerateCredentials checks the username layout and expiry placement
// of freshly minted credentials.
func TestGenerateCredentials(t *testing.T) {
	const secret = "test-secret-key-32bytes-long!!!!"
	const namespace = "test-namespace"
	const ttl = 10 * time.Minute

	username, password := GenerateCredentials(secret, namespace, ttl)
	if username == "" {
		t.Fatal("username should not be empty")
	}
	if password == "" {
		t.Fatal("password should not be empty")
	}

	// Username must look like "{expiry_unix}:{namespace}".
	var expiry int64
	var ns string
	if n, err := fmt.Sscanf(username, "%d:%s", &expiry, &ns); err != nil || n != 2 {
		t.Fatalf("username format should be timestamp:namespace, got %q", username)
	}
	if ns != namespace {
		t.Fatalf("namespace in username should be %q, got %q", namespace, ns)
	}

	// The embedded expiry should land roughly ttl in the future (±2s slack).
	want := time.Now().Unix() + int64(ttl.Seconds())
	if expiry < want-2 || expiry > want+2 {
		t.Fatalf("expiry timestamp should be ~%d, got %d", want, expiry)
	}
}
// TestGeneratePassword checks that password derivation is deterministic and
// sensitive to both the secret and the username.
func TestGeneratePassword(t *testing.T) {
	const secret = "test-secret"
	const username = "1234567890:test-ns"

	// Determinism: identical inputs yield identical output.
	first := GeneratePassword(secret, username)
	if second := GeneratePassword(secret, username); first != second {
		t.Fatal("GeneratePassword should be deterministic")
	}

	// Sensitivity: changing either input must change the output.
	if first == GeneratePassword("different-secret", username) {
		t.Fatal("different secrets should produce different passwords")
	}
	if first == GeneratePassword(secret, "9999999999:other-ns") {
		t.Fatal("different usernames should produce different passwords")
	}
}
// TestValidateCredentials exercises the accept path and every rejection
// path of ValidateCredentials: wrong secret, wrong password, wrong
// namespace, expired credential, and malformed usernames.
func TestValidateCredentials(t *testing.T) {
	secret := "test-secret-key"
	namespace := "my-namespace"
	ttl := 10 * time.Minute

	// Generate valid credentials
	username, password := GenerateCredentials(secret, namespace, ttl)

	// Build the expired pair from a single captured timestamp so the
	// username and its password always agree. (The previous version called
	// time.Now() once per field; straddling a second boundary would
	// desynchronize them and make the case fail for the wrong reason.)
	expiredUsername := fmt.Sprintf("%d:%s", time.Now().Unix()-60, namespace)
	expiredPassword := GeneratePassword(secret, expiredUsername)

	tests := []struct {
		name      string
		secret    string
		username  string
		password  string
		namespace string
		wantValid bool
	}{
		{
			name:      "valid credentials",
			secret:    secret,
			username:  username,
			password:  password,
			namespace: namespace,
			wantValid: true,
		},
		{
			name:      "wrong secret",
			secret:    "wrong-secret",
			username:  username,
			password:  password,
			namespace: namespace,
			wantValid: false,
		},
		{
			name:      "wrong password",
			secret:    secret,
			username:  username,
			password:  "wrongpassword",
			namespace: namespace,
			wantValid: false,
		},
		{
			name:      "wrong namespace",
			secret:    secret,
			username:  username,
			password:  password,
			namespace: "other-namespace",
			wantValid: false,
		},
		{
			name:      "expired credentials",
			secret:    secret,
			username:  expiredUsername,
			password:  expiredPassword,
			namespace: namespace,
			wantValid: false,
		},
		{
			name:      "malformed username - no colon",
			secret:    secret,
			username:  "badusername",
			password:  "whatever",
			namespace: namespace,
			wantValid: false,
		},
		{
			name:      "malformed username - non-numeric timestamp",
			secret:    secret,
			username:  "notanumber:my-namespace",
			password:  "whatever",
			namespace: namespace,
			wantValid: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := ValidateCredentials(tt.secret, tt.username, tt.password, tt.namespace)
			if got != tt.wantValid {
				t.Errorf("ValidateCredentials() = %v, want %v", got, tt.wantValid)
			}
		})
	}
}
// TestConfigValidation checks that Validate reports the expected number of
// errors for representative good and bad configurations.
func TestConfigValidation(t *testing.T) {
	// newValid builds a configuration that passes every check; each bad case
	// below derives from it by breaking one aspect.
	newValid := func() Config {
		return Config{
			ListenAddr:     "0.0.0.0:3478",
			PublicIP:       "1.2.3.4",
			Realm:          "dbrs.space",
			AuthSecret:     "secret123",
			RelayPortStart: 49152,
			RelayPortEnd:   50000,
			Namespace:      "test-ns",
		}
	}

	invalidIP := newValid()
	invalidIP.PublicIP = "not-an-ip"

	smallRange := newValid()
	smallRange.RelayPortEnd = 49200 // only 48 ports; minimum is 100

	invertedRange := newValid()
	invertedRange.RelayPortStart = 50000
	invertedRange.RelayPortEnd = 49152

	cases := []struct {
		name     string
		config   Config
		wantErrs int
	}{
		{name: "valid config", config: newValid(), wantErrs: 0},
		// listen_addr, public_ip, realm, auth_secret, relay_port_range, namespace
		{name: "missing all fields", config: Config{}, wantErrs: 6},
		{name: "invalid public IP", config: invalidIP, wantErrs: 1},
		{name: "relay range too small", config: smallRange, wantErrs: 1},
		{name: "relay range inverted", config: invertedRange, wantErrs: 1},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			errs := tc.config.Validate()
			if len(errs) != tc.wantErrs {
				t.Errorf("Validate() returned %d errors, want %d: %v", len(errs), tc.wantErrs, errs)
			}
		})
	}
}

View File

@ -0,0 +1,32 @@
# Template unit for the per-namespace SFU (instantiated as
# orama-namespace-sfu@<namespace>.service; %i expands to the namespace).
[Unit]
Description=Orama Namespace SFU (%i)
Documentation=https://github.com/DeBrosOfficial/network
# Start after the namespace's Olric instance; Wants= (not Requires=) so the
# SFU keeps running if Olric restarts independently.
After=network.target orama-namespace-olric@%i.service
Wants=orama-namespace-olric@%i.service
# Stops together with the node-level service.
PartOf=orama-node.service

[Service]
Type=simple
WorkingDirectory=/opt/orama
# Per-namespace environment file; must define SFU_CONFIG (config file path).
EnvironmentFile=/opt/orama/.orama/data/namespaces/%i/sfu.env
# Shell wrapper expands ${SFU_CONFIG}; exec replaces the shell so systemd
# supervises the sfu binary directly.
ExecStart=/bin/sh -c 'exec /opt/orama/bin/sfu --config ${SFU_CONFIG}'
TimeoutStopSec=45s
KillMode=mixed
KillSignal=SIGTERM
Restart=on-failure
RestartSec=5s
StandardOutput=journal
StandardError=journal
SyslogIdentifier=orama-sfu-%i
PrivateTmp=yes
LimitNOFILE=65536
MemoryMax=2G

[Install]
WantedBy=multi-user.target

View File

@ -0,0 +1,31 @@
# Template unit for the per-namespace TURN relay (instantiated as
# orama-namespace-turn@<namespace>.service; %i expands to the namespace).
[Unit]
Description=Orama Namespace TURN (%i)
Documentation=https://github.com/DeBrosOfficial/network
After=network.target
# Stops together with the node-level service.
PartOf=orama-node.service

[Service]
Type=simple
WorkingDirectory=/opt/orama
# Per-namespace environment file; must define TURN_CONFIG (config file path).
EnvironmentFile=/opt/orama/.orama/data/namespaces/%i/turn.env
# Shell wrapper expands ${TURN_CONFIG}; exec replaces the shell so systemd
# supervises the turn binary directly.
ExecStart=/bin/sh -c 'exec /opt/orama/bin/turn --config ${TURN_CONFIG}'
TimeoutStopSec=30s
KillMode=mixed
KillSignal=SIGTERM
Restart=on-failure
RestartSec=5s
StandardOutput=journal
StandardError=journal
SyslogIdentifier=orama-turn-%i
PrivateTmp=yes
LimitNOFILE=65536
MemoryMax=1G

[Install]
WantedBy=multi-user.target