mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-17 21:46:57 +00:00
feat: implement SFU and TURN server functionality
- Add signaling package with message types and structures for SFU communication. - Implement client and server message serialization/deserialization tests. - Enhance systemd manager to handle SFU and TURN services, including start/stop logic. - Create TURN server configuration and main server logic with HMAC-SHA1 authentication. - Add tests for TURN server credential generation and validation. - Define systemd service files for SFU and TURN services.
This commit is contained in:
parent
58ea896cb0
commit
8ee606bfb1
2
Makefile
2
Makefile
@ -78,6 +78,8 @@ build: deps
|
|||||||
go build -ldflags "$(LDFLAGS)" -o bin/orama ./cmd/cli/
|
go build -ldflags "$(LDFLAGS)" -o bin/orama ./cmd/cli/
|
||||||
# Inject gateway build metadata via pkg path variables
|
# Inject gateway build metadata via pkg path variables
|
||||||
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
|
go build -ldflags "$(LDFLAGS) -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildVersion=$(VERSION)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildCommit=$(COMMIT)' -X 'github.com/DeBrosOfficial/network/pkg/gateway.BuildTime=$(DATE)'" -o bin/gateway ./cmd/gateway
|
||||||
|
go build -ldflags "$(LDFLAGS)" -o bin/sfu ./cmd/sfu
|
||||||
|
go build -ldflags "$(LDFLAGS)" -o bin/turn ./cmd/turn
|
||||||
@echo "Build complete! Run ./bin/orama version"
|
@echo "Build complete! Run ./bin/orama version"
|
||||||
|
|
||||||
# Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source)
|
# Cross-compile CLI for Linux (only binary needed locally; VPS builds everything else from source)
|
||||||
|
|||||||
@ -69,6 +69,13 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load YAML
|
// Load YAML
|
||||||
|
type yamlWebRTCCfg struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
SFUPort int `yaml:"sfu_port"`
|
||||||
|
TURNDomain string `yaml:"turn_domain"`
|
||||||
|
TURNSecret string `yaml:"turn_secret"`
|
||||||
|
}
|
||||||
|
|
||||||
type yamlCfg struct {
|
type yamlCfg struct {
|
||||||
ListenAddr string `yaml:"listen_addr"`
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
ClientNamespace string `yaml:"client_namespace"`
|
ClientNamespace string `yaml:"client_namespace"`
|
||||||
@ -84,6 +91,7 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
|||||||
IPFSAPIURL string `yaml:"ipfs_api_url"`
|
IPFSAPIURL string `yaml:"ipfs_api_url"`
|
||||||
IPFSTimeout string `yaml:"ipfs_timeout"`
|
IPFSTimeout string `yaml:"ipfs_timeout"`
|
||||||
IPFSReplicationFactor int `yaml:"ipfs_replication_factor"`
|
IPFSReplicationFactor int `yaml:"ipfs_replication_factor"`
|
||||||
|
WebRTC yamlWebRTCCfg `yaml:"webrtc"`
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := os.ReadFile(configPath)
|
data, err := os.ReadFile(configPath)
|
||||||
@ -192,6 +200,18 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
|||||||
cfg.IPFSReplicationFactor = y.IPFSReplicationFactor
|
cfg.IPFSReplicationFactor = y.IPFSReplicationFactor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebRTC configuration
|
||||||
|
cfg.WebRTCEnabled = y.WebRTC.Enabled
|
||||||
|
if y.WebRTC.SFUPort > 0 {
|
||||||
|
cfg.SFUPort = y.WebRTC.SFUPort
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(y.WebRTC.TURNDomain); v != "" {
|
||||||
|
cfg.TURNDomain = v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(y.WebRTC.TURNSecret); v != "" {
|
||||||
|
cfg.TURNSecret = v
|
||||||
|
}
|
||||||
|
|
||||||
// Validate configuration
|
// Validate configuration
|
||||||
if errs := cfg.ValidateConfig(); len(errs) > 0 {
|
if errs := cfg.ValidateConfig(); len(errs) > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs))
|
fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs))
|
||||||
|
|||||||
116
cmd/sfu/config.go
Normal file
116
cmd/sfu/config.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/config"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/sfu"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newSFUServer creates a new SFU server from config and logger.
|
||||||
|
// Wrapper to keep main.go clean and avoid importing sfu in main.
|
||||||
|
func newSFUServer(cfg *sfu.Config, logger *zap.Logger) (*sfu.Server, error) {
|
||||||
|
return sfu.NewServer(cfg, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSFUConfig(logger *logging.ColoredLogger) *sfu.Config {
|
||||||
|
configFlag := flag.String("config", "", "Config file path (absolute path or filename in ~/.orama)")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
var configPath string
|
||||||
|
var err error
|
||||||
|
if *configFlag != "" {
|
||||||
|
if filepath.IsAbs(*configFlag) {
|
||||||
|
configPath = *configFlag
|
||||||
|
} else {
|
||||||
|
configPath, err = config.DefaultPath(*configFlag)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentSFU, "Failed to determine config path", zap.Error(err))
|
||||||
|
fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
configPath, err = config.DefaultPath("sfu.yaml")
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentSFU, "Failed to determine config path", zap.Error(err))
|
||||||
|
fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type yamlTURNServer struct {
|
||||||
|
Host string `yaml:"host"`
|
||||||
|
Port int `yaml:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type yamlCfg struct {
|
||||||
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
|
Namespace string `yaml:"namespace"`
|
||||||
|
MediaPortStart int `yaml:"media_port_start"`
|
||||||
|
MediaPortEnd int `yaml:"media_port_end"`
|
||||||
|
TURNServers []yamlTURNServer `yaml:"turn_servers"`
|
||||||
|
TURNSecret string `yaml:"turn_secret"`
|
||||||
|
TURNCredentialTTL int `yaml:"turn_credential_ttl"`
|
||||||
|
RQLiteDSN string `yaml:"rqlite_dsn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentSFU, "Config file not found",
|
||||||
|
zap.String("path", configPath), zap.Error(err))
|
||||||
|
fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var y yamlCfg
|
||||||
|
if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentSFU, "Failed to parse SFU config", zap.Error(err))
|
||||||
|
fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var turnServers []sfu.TURNServerConfig
|
||||||
|
for _, ts := range y.TURNServers {
|
||||||
|
turnServers = append(turnServers, sfu.TURNServerConfig{
|
||||||
|
Host: ts.Host,
|
||||||
|
Port: ts.Port,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &sfu.Config{
|
||||||
|
ListenAddr: y.ListenAddr,
|
||||||
|
Namespace: y.Namespace,
|
||||||
|
MediaPortStart: y.MediaPortStart,
|
||||||
|
MediaPortEnd: y.MediaPortEnd,
|
||||||
|
TURNServers: turnServers,
|
||||||
|
TURNSecret: y.TURNSecret,
|
||||||
|
TURNCredentialTTL: y.TURNCredentialTTL,
|
||||||
|
RQLiteDSN: y.RQLiteDSN,
|
||||||
|
}
|
||||||
|
|
||||||
|
if errs := cfg.Validate(); len(errs) > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\nSFU configuration errors (%d):\n", len(errs))
|
||||||
|
for _, e := range errs {
|
||||||
|
fmt.Fprintf(os.Stderr, " - %s\n", e)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentSFU, "Loaded SFU configuration",
|
||||||
|
zap.String("path", configPath),
|
||||||
|
zap.String("listen_addr", cfg.ListenAddr),
|
||||||
|
zap.String("namespace", cfg.Namespace),
|
||||||
|
zap.Int("media_ports", cfg.MediaPortEnd-cfg.MediaPortStart),
|
||||||
|
zap.Int("turn_servers", len(cfg.TURNServers)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
59
cmd/sfu/main.go
Normal file
59
cmd/sfu/main.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
version = "dev"
|
||||||
|
commit = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
logger, err := logging.NewColoredLogger(logging.ComponentSFU, true)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentSFU, "Starting SFU server",
|
||||||
|
zap.String("version", version),
|
||||||
|
zap.String("commit", commit))
|
||||||
|
|
||||||
|
cfg := parseSFUConfig(logger)
|
||||||
|
|
||||||
|
server, err := newSFUServer(cfg, logger.Logger)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentSFU, "Failed to create SFU server", zap.Error(err))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start HTTP server in background
|
||||||
|
go func() {
|
||||||
|
if err := server.ListenAndServe(); err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentSFU, "SFU server error", zap.Error(err))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Wait for termination signal
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||||
|
sig := <-quit
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentSFU, "Shutdown signal received", zap.String("signal", sig.String()))
|
||||||
|
|
||||||
|
// Graceful drain: notify peers and wait
|
||||||
|
server.Drain(30 * time.Second)
|
||||||
|
|
||||||
|
if err := server.Close(); err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentSFU, "Error during shutdown", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentSFU, "SFU server shutdown complete")
|
||||||
|
}
|
||||||
96
cmd/turn/config.go
Normal file
96
cmd/turn/config.go
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/config"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func parseTURNConfig(logger *logging.ColoredLogger) *turn.Config {
|
||||||
|
configFlag := flag.String("config", "", "Config file path (absolute path or filename in ~/.orama)")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
var configPath string
|
||||||
|
var err error
|
||||||
|
if *configFlag != "" {
|
||||||
|
if filepath.IsAbs(*configFlag) {
|
||||||
|
configPath = *configFlag
|
||||||
|
} else {
|
||||||
|
configPath, err = config.DefaultPath(*configFlag)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentTURN, "Failed to determine config path", zap.Error(err))
|
||||||
|
fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
configPath, err = config.DefaultPath("turn.yaml")
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentTURN, "Failed to determine config path", zap.Error(err))
|
||||||
|
fmt.Fprintf(os.Stderr, "Configuration error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type yamlCfg struct {
|
||||||
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
|
TLSListenAddr string `yaml:"tls_listen_addr"`
|
||||||
|
PublicIP string `yaml:"public_ip"`
|
||||||
|
Realm string `yaml:"realm"`
|
||||||
|
AuthSecret string `yaml:"auth_secret"`
|
||||||
|
RelayPortStart int `yaml:"relay_port_start"`
|
||||||
|
RelayPortEnd int `yaml:"relay_port_end"`
|
||||||
|
Namespace string `yaml:"namespace"`
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentTURN, "Config file not found",
|
||||||
|
zap.String("path", configPath), zap.Error(err))
|
||||||
|
fmt.Fprintf(os.Stderr, "\nConfig file not found at %s\n", configPath)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var y yamlCfg
|
||||||
|
if err := config.DecodeStrict(strings.NewReader(string(data)), &y); err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentTURN, "Failed to parse TURN config", zap.Error(err))
|
||||||
|
fmt.Fprintf(os.Stderr, "Configuration parse error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &turn.Config{
|
||||||
|
ListenAddr: y.ListenAddr,
|
||||||
|
TLSListenAddr: y.TLSListenAddr,
|
||||||
|
PublicIP: y.PublicIP,
|
||||||
|
Realm: y.Realm,
|
||||||
|
AuthSecret: y.AuthSecret,
|
||||||
|
RelayPortStart: y.RelayPortStart,
|
||||||
|
RelayPortEnd: y.RelayPortEnd,
|
||||||
|
Namespace: y.Namespace,
|
||||||
|
}
|
||||||
|
|
||||||
|
if errs := cfg.Validate(); len(errs) > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "\nTURN configuration errors (%d):\n", len(errs))
|
||||||
|
for _, e := range errs {
|
||||||
|
fmt.Fprintf(os.Stderr, " - %s\n", e)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "\nPlease fix the configuration and try again.\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentTURN, "Loaded TURN configuration",
|
||||||
|
zap.String("path", configPath),
|
||||||
|
zap.String("listen_addr", cfg.ListenAddr),
|
||||||
|
zap.String("namespace", cfg.Namespace),
|
||||||
|
zap.String("realm", cfg.Realm),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
48
cmd/turn/main.go
Normal file
48
cmd/turn/main.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
version = "dev"
|
||||||
|
commit = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
logger, err := logging.NewColoredLogger(logging.ComponentTURN, true)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentTURN, "Starting TURN server",
|
||||||
|
zap.String("version", version),
|
||||||
|
zap.String("commit", commit))
|
||||||
|
|
||||||
|
cfg := parseTURNConfig(logger)
|
||||||
|
|
||||||
|
server, err := turn.NewServer(cfg, logger.Logger)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentTURN, "Failed to start TURN server", zap.Error(err))
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for termination signal
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, os.Interrupt, syscall.SIGTERM)
|
||||||
|
sig := <-quit
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentTURN, "Shutdown signal received", zap.String("signal", sig.String()))
|
||||||
|
|
||||||
|
if err := server.Close(); err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentTURN, "Error during shutdown", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentTURN, "TURN server shutdown complete")
|
||||||
|
}
|
||||||
96
migrations/018_webrtc_services.sql
Normal file
96
migrations/018_webrtc_services.sql
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
-- Migration 018: WebRTC Services (SFU + TURN) for Namespace Clusters
|
||||||
|
-- Adds per-namespace WebRTC configuration, room tracking, and port allocation
|
||||||
|
-- WebRTC is opt-in: enabled via `orama namespace enable webrtc`
|
||||||
|
|
||||||
|
BEGIN;
|
||||||
|
|
||||||
|
-- Per-namespace WebRTC configuration
|
||||||
|
-- One row per namespace that has WebRTC enabled
|
||||||
|
CREATE TABLE IF NOT EXISTS namespace_webrtc_config (
|
||||||
|
id TEXT PRIMARY KEY, -- UUID
|
||||||
|
namespace_cluster_id TEXT NOT NULL UNIQUE, -- FK to namespace_clusters
|
||||||
|
namespace_name TEXT NOT NULL, -- Cached for easier lookups
|
||||||
|
enabled INTEGER NOT NULL DEFAULT 1, -- 1 = enabled, 0 = disabled
|
||||||
|
|
||||||
|
-- TURN authentication
|
||||||
|
turn_shared_secret TEXT NOT NULL, -- HMAC-SHA1 shared secret (base64, 32 bytes)
|
||||||
|
turn_credential_ttl INTEGER NOT NULL DEFAULT 600, -- Credential TTL in seconds (default: 10 min)
|
||||||
|
|
||||||
|
-- Service topology
|
||||||
|
sfu_node_count INTEGER NOT NULL DEFAULT 3, -- SFU instances (all 3 nodes)
|
||||||
|
turn_node_count INTEGER NOT NULL DEFAULT 2, -- TURN instances (2 of 3 nodes for HA)
|
||||||
|
|
||||||
|
-- Metadata
|
||||||
|
enabled_by TEXT NOT NULL, -- Wallet address that enabled WebRTC
|
||||||
|
enabled_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
disabled_at TIMESTAMP,
|
||||||
|
|
||||||
|
FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webrtc_config_namespace ON namespace_webrtc_config(namespace_name);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webrtc_config_cluster ON namespace_webrtc_config(namespace_cluster_id);
|
||||||
|
|
||||||
|
-- WebRTC room tracking
|
||||||
|
-- Tracks active rooms and their SFU node affinity
|
||||||
|
CREATE TABLE IF NOT EXISTS webrtc_rooms (
|
||||||
|
id TEXT PRIMARY KEY, -- UUID
|
||||||
|
namespace_cluster_id TEXT NOT NULL, -- FK to namespace_clusters
|
||||||
|
namespace_name TEXT NOT NULL, -- Cached for easier lookups
|
||||||
|
room_id TEXT NOT NULL, -- Application-defined room identifier
|
||||||
|
|
||||||
|
-- SFU affinity
|
||||||
|
sfu_node_id TEXT NOT NULL, -- Node hosting this room's SFU
|
||||||
|
sfu_internal_ip TEXT NOT NULL, -- WireGuard IP of SFU node
|
||||||
|
sfu_signaling_port INTEGER NOT NULL, -- SFU WebSocket signaling port
|
||||||
|
|
||||||
|
-- Room state
|
||||||
|
participant_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
max_participants INTEGER NOT NULL DEFAULT 100,
|
||||||
|
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
last_activity TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
|
||||||
|
-- Prevent duplicate rooms within a namespace
|
||||||
|
UNIQUE(namespace_cluster_id, room_id),
|
||||||
|
FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_namespace ON webrtc_rooms(namespace_name);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_node ON webrtc_rooms(sfu_node_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webrtc_rooms_activity ON webrtc_rooms(last_activity);
|
||||||
|
|
||||||
|
-- WebRTC port allocations
|
||||||
|
-- Separate from namespace_port_allocations to avoid breaking existing port blocks
|
||||||
|
-- Each namespace gets SFU + TURN ports on each node where those services run
|
||||||
|
CREATE TABLE IF NOT EXISTS webrtc_port_allocations (
|
||||||
|
id TEXT PRIMARY KEY, -- UUID
|
||||||
|
node_id TEXT NOT NULL, -- Physical node ID
|
||||||
|
namespace_cluster_id TEXT NOT NULL, -- FK to namespace_clusters
|
||||||
|
service_type TEXT NOT NULL, -- 'sfu' or 'turn'
|
||||||
|
|
||||||
|
-- SFU ports (when service_type = 'sfu')
|
||||||
|
sfu_signaling_port INTEGER, -- WebSocket signaling port
|
||||||
|
sfu_media_port_start INTEGER, -- Start of RTP media port range
|
||||||
|
sfu_media_port_end INTEGER, -- End of RTP media port range
|
||||||
|
|
||||||
|
-- TURN ports (when service_type = 'turn')
|
||||||
|
turn_listen_port INTEGER, -- TURN listener port (3478)
|
||||||
|
turn_tls_port INTEGER, -- TURN TLS port (443/UDP)
|
||||||
|
turn_relay_port_start INTEGER, -- Start of relay port range
|
||||||
|
turn_relay_port_end INTEGER, -- End of relay port range
|
||||||
|
|
||||||
|
allocated_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
|
||||||
|
-- Prevent overlapping allocations
|
||||||
|
UNIQUE(node_id, namespace_cluster_id, service_type),
|
||||||
|
FOREIGN KEY (namespace_cluster_id) REFERENCES namespace_clusters(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webrtc_ports_node ON webrtc_port_allocations(node_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webrtc_ports_cluster ON webrtc_port_allocations(namespace_cluster_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_webrtc_ports_type ON webrtc_port_allocations(service_type);
|
||||||
|
|
||||||
|
-- Mark migration as applied
|
||||||
|
INSERT OR IGNORE INTO schema_migrations(version) VALUES (18);
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
@ -37,6 +37,30 @@ func HandleNamespaceCommand(args []string) {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
handleNamespaceRepair(args[1])
|
handleNamespaceRepair(args[1])
|
||||||
|
case "enable":
|
||||||
|
if len(args) < 2 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Usage: orama namespace enable <feature> --namespace <name>\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "Features: webrtc\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
handleNamespaceEnable(args[1:])
|
||||||
|
case "disable":
|
||||||
|
if len(args) < 2 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Usage: orama namespace disable <feature> --namespace <name>\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "Features: webrtc\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
handleNamespaceDisable(args[1:])
|
||||||
|
case "webrtc-status":
|
||||||
|
var ns string
|
||||||
|
fs := flag.NewFlagSet("namespace webrtc-status", flag.ExitOnError)
|
||||||
|
fs.StringVar(&ns, "namespace", "", "Namespace name")
|
||||||
|
_ = fs.Parse(args[1:])
|
||||||
|
if ns == "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "Usage: orama namespace webrtc-status --namespace <name>\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
handleNamespaceWebRTCStatus(ns)
|
||||||
case "help":
|
case "help":
|
||||||
showNamespaceHelp()
|
showNamespaceHelp()
|
||||||
default:
|
default:
|
||||||
@ -50,17 +74,24 @@ func showNamespaceHelp() {
|
|||||||
fmt.Printf("Namespace Management Commands\n\n")
|
fmt.Printf("Namespace Management Commands\n\n")
|
||||||
fmt.Printf("Usage: orama namespace <subcommand>\n\n")
|
fmt.Printf("Usage: orama namespace <subcommand>\n\n")
|
||||||
fmt.Printf("Subcommands:\n")
|
fmt.Printf("Subcommands:\n")
|
||||||
fmt.Printf(" list - List namespaces owned by the current wallet\n")
|
fmt.Printf(" list - List namespaces owned by the current wallet\n")
|
||||||
fmt.Printf(" delete - Delete the current namespace and all its resources\n")
|
fmt.Printf(" delete - Delete the current namespace and all its resources\n")
|
||||||
fmt.Printf(" repair <namespace> - Repair an under-provisioned namespace cluster (add missing nodes)\n")
|
fmt.Printf(" repair <namespace> - Repair an under-provisioned namespace cluster\n")
|
||||||
fmt.Printf(" help - Show this help message\n\n")
|
fmt.Printf(" enable webrtc --namespace NS - Enable WebRTC (SFU + TURN) for a namespace\n")
|
||||||
|
fmt.Printf(" disable webrtc --namespace NS - Disable WebRTC for a namespace\n")
|
||||||
|
fmt.Printf(" webrtc-status --namespace NS - Show WebRTC service status\n")
|
||||||
|
fmt.Printf(" help - Show this help message\n\n")
|
||||||
fmt.Printf("Flags:\n")
|
fmt.Printf("Flags:\n")
|
||||||
fmt.Printf(" --force - Skip confirmation prompt (delete only)\n\n")
|
fmt.Printf(" --force - Skip confirmation prompt (delete only)\n")
|
||||||
|
fmt.Printf(" --namespace - Namespace name (enable/disable/webrtc-status)\n\n")
|
||||||
fmt.Printf("Examples:\n")
|
fmt.Printf("Examples:\n")
|
||||||
fmt.Printf(" orama namespace list\n")
|
fmt.Printf(" orama namespace list\n")
|
||||||
fmt.Printf(" orama namespace delete\n")
|
fmt.Printf(" orama namespace delete\n")
|
||||||
fmt.Printf(" orama namespace delete --force\n")
|
fmt.Printf(" orama namespace delete --force\n")
|
||||||
fmt.Printf(" orama namespace repair anchat\n")
|
fmt.Printf(" orama namespace repair anchat\n")
|
||||||
|
fmt.Printf(" orama namespace enable webrtc --namespace myapp\n")
|
||||||
|
fmt.Printf(" orama namespace disable webrtc --namespace myapp\n")
|
||||||
|
fmt.Printf(" orama namespace webrtc-status --namespace myapp\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleNamespaceRepair(namespaceName string) {
|
func handleNamespaceRepair(namespaceName string) {
|
||||||
@ -193,6 +224,165 @@ func handleNamespaceDelete(force bool) {
|
|||||||
fmt.Printf("Run 'orama auth login' to create a new namespace.\n")
|
fmt.Printf("Run 'orama auth login' to create a new namespace.\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func handleNamespaceEnable(args []string) {
|
||||||
|
feature := args[0]
|
||||||
|
if feature != "webrtc" {
|
||||||
|
fmt.Fprintf(os.Stderr, "Unknown feature: %s\nSupported features: webrtc\n", feature)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ns string
|
||||||
|
fs := flag.NewFlagSet("namespace enable webrtc", flag.ExitOnError)
|
||||||
|
fs.StringVar(&ns, "namespace", "", "Namespace name")
|
||||||
|
_ = fs.Parse(args[1:])
|
||||||
|
|
||||||
|
if ns == "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "Usage: orama namespace enable webrtc --namespace <name>\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Enabling WebRTC for namespace '%s'...\n", ns)
|
||||||
|
fmt.Printf("This will provision SFU (3 nodes) and TURN (2 nodes) services.\n")
|
||||||
|
|
||||||
|
url := fmt.Sprintf("http://localhost:%d/v1/internal/namespace/webrtc/enable?namespace=%s", constants.GatewayAPIPort, ns)
|
||||||
|
req, err := http.NewRequest(http.MethodPost, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Header.Set("X-Orama-Internal-Auth", "namespace-coordination")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to connect to local gateway (is the node running?): %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
errMsg := "unknown error"
|
||||||
|
if e, ok := result["error"].(string); ok {
|
||||||
|
errMsg = e
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to enable WebRTC: %s\n", errMsg)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("WebRTC enabled for namespace '%s'.\n", ns)
|
||||||
|
fmt.Printf(" SFU instances: 3 nodes (signaling via WireGuard)\n")
|
||||||
|
fmt.Printf(" TURN instances: 2 nodes (relay on public IPs)\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleNamespaceDisable(args []string) {
|
||||||
|
feature := args[0]
|
||||||
|
if feature != "webrtc" {
|
||||||
|
fmt.Fprintf(os.Stderr, "Unknown feature: %s\nSupported features: webrtc\n", feature)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var ns string
|
||||||
|
fs := flag.NewFlagSet("namespace disable webrtc", flag.ExitOnError)
|
||||||
|
fs.StringVar(&ns, "namespace", "", "Namespace name")
|
||||||
|
_ = fs.Parse(args[1:])
|
||||||
|
|
||||||
|
if ns == "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "Usage: orama namespace disable webrtc --namespace <name>\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Disabling WebRTC for namespace '%s'...\n", ns)
|
||||||
|
|
||||||
|
url := fmt.Sprintf("http://localhost:%d/v1/internal/namespace/webrtc/disable?namespace=%s", constants.GatewayAPIPort, ns)
|
||||||
|
req, err := http.NewRequest(http.MethodPost, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Header.Set("X-Orama-Internal-Auth", "namespace-coordination")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to connect to local gateway (is the node running?): %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
errMsg := "unknown error"
|
||||||
|
if e, ok := result["error"].(string); ok {
|
||||||
|
errMsg = e
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to disable WebRTC: %s\n", errMsg)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("WebRTC disabled for namespace '%s'.\n", ns)
|
||||||
|
fmt.Printf(" SFU and TURN services stopped, ports deallocated, DNS records removed.\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleNamespaceWebRTCStatus(ns string) {
|
||||||
|
url := fmt.Sprintf("http://localhost:%d/v1/internal/namespace/webrtc/status?namespace=%s", constants.GatewayAPIPort, ns)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to create request: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
req.Header.Set("X-Orama-Internal-Auth", "namespace-coordination")
|
||||||
|
|
||||||
|
client := &http.Client{}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to connect to local gateway (is the node running?): %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
json.NewDecoder(resp.Body).Decode(&result)
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
errMsg := "unknown error"
|
||||||
|
if e, ok := result["error"].(string); ok {
|
||||||
|
errMsg = e
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Failed to get WebRTC status: %s\n", errMsg)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled, _ := result["enabled"].(bool)
|
||||||
|
if !enabled {
|
||||||
|
fmt.Printf("WebRTC is not enabled for namespace '%s'.\n", ns)
|
||||||
|
fmt.Printf(" Enable with: orama namespace enable webrtc --namespace %s\n", ns)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("WebRTC Status for namespace '%s'\n\n", ns)
|
||||||
|
fmt.Printf(" Enabled: yes\n")
|
||||||
|
if sfuCount, ok := result["sfu_node_count"].(float64); ok {
|
||||||
|
fmt.Printf(" SFU nodes: %.0f\n", sfuCount)
|
||||||
|
}
|
||||||
|
if turnCount, ok := result["turn_node_count"].(float64); ok {
|
||||||
|
fmt.Printf(" TURN nodes: %.0f\n", turnCount)
|
||||||
|
}
|
||||||
|
if ttl, ok := result["turn_credential_ttl"].(float64); ok {
|
||||||
|
fmt.Printf(" TURN cred TTL: %.0fs\n", ttl)
|
||||||
|
}
|
||||||
|
if enabledBy, ok := result["enabled_by"].(string); ok {
|
||||||
|
fmt.Printf(" Enabled by: %s\n", enabledBy)
|
||||||
|
}
|
||||||
|
if enabledAt, ok := result["enabled_at"].(string); ok {
|
||||||
|
fmt.Printf(" Enabled at: %s\n", enabledAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func handleNamespaceList() {
|
func handleNamespaceList() {
|
||||||
// Load credentials
|
// Load credentials
|
||||||
store, err := auth.LoadEnhancedCredentials()
|
store, err := auth.LoadEnhancedCredentials()
|
||||||
|
|||||||
@ -543,6 +543,8 @@ func (o *Orchestrator) installNamespaceTemplates() error {
|
|||||||
"orama-namespace-rqlite@.service",
|
"orama-namespace-rqlite@.service",
|
||||||
"orama-namespace-olric@.service",
|
"orama-namespace-olric@.service",
|
||||||
"orama-namespace-gateway@.service",
|
"orama-namespace-gateway@.service",
|
||||||
|
"orama-namespace-sfu@.service",
|
||||||
|
"orama-namespace-turn@.service",
|
||||||
}
|
}
|
||||||
|
|
||||||
installedCount := 0
|
installedCount := 0
|
||||||
|
|||||||
@ -431,6 +431,8 @@ func (o *Orchestrator) installNamespaceTemplates() error {
|
|||||||
"orama-namespace-rqlite@.service",
|
"orama-namespace-rqlite@.service",
|
||||||
"orama-namespace-olric@.service",
|
"orama-namespace-olric@.service",
|
||||||
"orama-namespace-gateway@.service",
|
"orama-namespace-gateway@.service",
|
||||||
|
"orama-namespace-sfu@.service",
|
||||||
|
"orama-namespace-turn@.service",
|
||||||
}
|
}
|
||||||
|
|
||||||
installedCount := 0
|
installedCount := 0
|
||||||
|
|||||||
@ -184,7 +184,7 @@ func GetProductionServices() []string {
|
|||||||
namespacesDir := "/opt/orama/.orama/data/namespaces"
|
namespacesDir := "/opt/orama/.orama/data/namespaces"
|
||||||
nsEntries, err := os.ReadDir(namespacesDir)
|
nsEntries, err := os.ReadDir(namespacesDir)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
serviceTypes := []string{"rqlite", "olric", "gateway"}
|
serviceTypes := []string{"rqlite", "olric", "gateway", "sfu", "turn"}
|
||||||
for _, nsEntry := range nsEntries {
|
for _, nsEntry := range nsEntries {
|
||||||
if !nsEntry.IsDir() {
|
if !nsEntry.IsDir() {
|
||||||
continue
|
continue
|
||||||
@ -289,7 +289,8 @@ func identifyPortProcess(port int) string {
|
|||||||
|
|
||||||
// NamespaceServiceOrder defines the dependency order for namespace services.
|
// NamespaceServiceOrder defines the dependency order for namespace services.
|
||||||
// RQLite must start first (database), then Olric (cache), then Gateway (depends on both).
|
// RQLite must start first (database), then Olric (cache), then Gateway (depends on both).
|
||||||
var NamespaceServiceOrder = []string{"rqlite", "olric", "gateway"}
|
// TURN and SFU are optional WebRTC services that start after Gateway.
|
||||||
|
var NamespaceServiceOrder = []string{"rqlite", "olric", "gateway", "turn", "sfu"}
|
||||||
|
|
||||||
// StartServicesOrdered starts services respecting namespace dependency order.
|
// StartServicesOrdered starts services respecting namespace dependency order.
|
||||||
// Namespace services are started in order: rqlite → olric (+ wait) → gateway.
|
// Namespace services are started in order: rqlite → olric (+ wait) → gateway.
|
||||||
|
|||||||
@ -20,6 +20,17 @@ type HTTPGatewayConfig struct {
|
|||||||
IPFSAPIURL string `yaml:"ipfs_api_url"` // IPFS API URL
|
IPFSAPIURL string `yaml:"ipfs_api_url"` // IPFS API URL
|
||||||
IPFSTimeout time.Duration `yaml:"ipfs_timeout"` // Timeout for IPFS operations
|
IPFSTimeout time.Duration `yaml:"ipfs_timeout"` // Timeout for IPFS operations
|
||||||
BaseDomain string `yaml:"base_domain"` // Base domain for deployments (e.g., "dbrs.space"). Defaults to "dbrs.space"
|
BaseDomain string `yaml:"base_domain"` // Base domain for deployments (e.g., "dbrs.space"). Defaults to "dbrs.space"
|
||||||
|
|
||||||
|
// WebRTC configuration (optional, enabled per-namespace)
|
||||||
|
WebRTC WebRTCConfig `yaml:"webrtc"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebRTCConfig contains WebRTC-related gateway configuration
|
||||||
|
type WebRTCConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"` // Whether this gateway has WebRTC support active
|
||||||
|
SFUPort int `yaml:"sfu_port"` // Local SFU signaling port to proxy to
|
||||||
|
TURNDomain string `yaml:"turn_domain"` // TURN domain (e.g., "turn.ns-myapp.dbrs.space")
|
||||||
|
TURNSecret string `yaml:"turn_secret"` // HMAC-SHA1 shared secret for TURN credential generation
|
||||||
}
|
}
|
||||||
|
|
||||||
// HTTPSConfig contains HTTPS/TLS configuration for the gateway
|
// HTTPSConfig contains HTTPS/TLS configuration for the gateway
|
||||||
|
|||||||
@ -41,4 +41,10 @@ type Config struct {
|
|||||||
|
|
||||||
// WireGuard mesh configuration
|
// WireGuard mesh configuration
|
||||||
ClusterSecret string // Cluster secret for authenticating internal WireGuard peer exchange
|
ClusterSecret string // Cluster secret for authenticating internal WireGuard peer exchange
|
||||||
|
|
||||||
|
// WebRTC configuration (set when namespace has WebRTC enabled)
|
||||||
|
WebRTCEnabled bool // Whether WebRTC endpoints are active on this gateway
|
||||||
|
SFUPort int // Local SFU signaling port to proxy WebSocket connections to
|
||||||
|
TURNDomain string // TURN server domain for credential generation
|
||||||
|
TURNSecret string // HMAC-SHA1 shared secret for TURN credential generation
|
||||||
}
|
}
|
||||||
|
|||||||
@ -70,6 +70,16 @@ func (c *Config) ValidateConfig() []error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate WebRTC configuration
|
||||||
|
if c.WebRTCEnabled {
|
||||||
|
if c.SFUPort <= 0 || c.SFUPort > 65535 {
|
||||||
|
errs = append(errs, fmt.Errorf("gateway.sfu_port: must be between 1 and 65535 when webrtc is enabled"))
|
||||||
|
}
|
||||||
|
if c.TURNSecret == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("gateway.turn_secret: must not be empty when webrtc is enabled"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Validate HTTPS configuration
|
// Validate HTTPS configuration
|
||||||
if c.EnableHTTPS {
|
if c.EnableHTTPS {
|
||||||
if c.DomainName == "" {
|
if c.DomainName == "" {
|
||||||
|
|||||||
@ -30,6 +30,7 @@ import (
|
|||||||
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
|
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
|
||||||
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
|
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
|
||||||
joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
|
joinhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/join"
|
||||||
|
webrtchandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/webrtc"
|
||||||
wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard"
|
wireguardhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/wireguard"
|
||||||
sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite"
|
sqlitehandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/sqlite"
|
||||||
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
|
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
|
||||||
@ -122,6 +123,9 @@ type Gateway struct {
|
|||||||
rateLimiter *RateLimiter
|
rateLimiter *RateLimiter
|
||||||
namespaceRateLimiter *NamespaceRateLimiter
|
namespaceRateLimiter *NamespaceRateLimiter
|
||||||
|
|
||||||
|
// WebRTC signaling and TURN credentials
|
||||||
|
webrtcHandlers *webrtchandlers.WebRTCHandlers
|
||||||
|
|
||||||
// WireGuard peer exchange
|
// WireGuard peer exchange
|
||||||
wireguardHandler *wireguardhandlers.Handler
|
wireguardHandler *wireguardhandlers.Handler
|
||||||
|
|
||||||
@ -149,6 +153,9 @@ type Gateway struct {
|
|||||||
// Node recovery handler (called when health monitor confirms a node dead or recovered)
|
// Node recovery handler (called when health monitor confirms a node dead or recovered)
|
||||||
nodeRecoverer authhandlers.NodeRecoverer
|
nodeRecoverer authhandlers.NodeRecoverer
|
||||||
|
|
||||||
|
// WebRTC manager for enable/disable operations
|
||||||
|
webrtcManager authhandlers.WebRTCManager
|
||||||
|
|
||||||
// Circuit breakers for proxy targets (per-target failure tracking)
|
// Circuit breakers for proxy targets (per-target failure tracking)
|
||||||
circuitBreakers *CircuitBreakerRegistry
|
circuitBreakers *CircuitBreakerRegistry
|
||||||
|
|
||||||
@ -323,6 +330,18 @@ func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
|||||||
// Initialize handler instances
|
// Initialize handler instances
|
||||||
gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger)
|
gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger)
|
||||||
|
|
||||||
|
if cfg.WebRTCEnabled && cfg.SFUPort > 0 {
|
||||||
|
gw.webrtcHandlers = webrtchandlers.NewWebRTCHandlers(
|
||||||
|
logger,
|
||||||
|
cfg.SFUPort,
|
||||||
|
cfg.TURNDomain,
|
||||||
|
cfg.TURNSecret,
|
||||||
|
gw.proxyWebSocket,
|
||||||
|
)
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "WebRTC handlers initialized",
|
||||||
|
zap.Int("sfu_port", cfg.SFUPort))
|
||||||
|
}
|
||||||
|
|
||||||
if deps.OlricClient != nil {
|
if deps.OlricClient != nil {
|
||||||
gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient)
|
gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient)
|
||||||
}
|
}
|
||||||
@ -633,6 +652,11 @@ func (g *Gateway) SetNodeRecoverer(nr authhandlers.NodeRecoverer) {
|
|||||||
g.nodeRecoverer = nr
|
g.nodeRecoverer = nr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetWebRTCManager sets the WebRTC lifecycle manager for enable/disable operations.
|
||||||
|
func (g *Gateway) SetWebRTCManager(wm authhandlers.WebRTCManager) {
|
||||||
|
g.webrtcManager = wm
|
||||||
|
}
|
||||||
|
|
||||||
// SetSpawnHandler sets the handler for internal namespace spawn/stop requests.
|
// SetSpawnHandler sets the handler for internal namespace spawn/stop requests.
|
||||||
func (g *Gateway) SetSpawnHandler(h http.Handler) {
|
func (g *Gateway) SetSpawnHandler(h http.Handler) {
|
||||||
g.spawnHandler = h
|
g.spawnHandler = h
|
||||||
@ -847,3 +871,121 @@ func (g *Gateway) namespaceClusterRepairHandler(w http.ResponseWriter, r *http.R
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// namespaceWebRTCEnableHandler handles POST /v1/internal/namespace/webrtc/enable?namespace={name}
|
||||||
|
// Internal-only: authenticated by X-Orama-Internal-Auth header + WireGuard subnet.
|
||||||
|
func (g *Gateway) namespaceWebRTCEnableHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Header.Get("X-Orama-Internal-Auth") != "namespace-coordination" || !nodeauth.IsWireGuardPeer(r.RemoteAddr) {
|
||||||
|
writeError(w, http.StatusUnauthorized, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
namespaceName := r.URL.Query().Get("namespace")
|
||||||
|
if namespaceName == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace parameter required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.webrtcManager == nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.webrtcManager.EnableWebRTC(r.Context(), namespaceName, "cli"); err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"status": "ok",
|
||||||
|
"namespace": namespaceName,
|
||||||
|
"message": "WebRTC enabled successfully",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// namespaceWebRTCDisableHandler handles POST /v1/internal/namespace/webrtc/disable?namespace={name}
|
||||||
|
// Internal-only: authenticated by X-Orama-Internal-Auth header + WireGuard subnet.
|
||||||
|
func (g *Gateway) namespaceWebRTCDisableHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Header.Get("X-Orama-Internal-Auth") != "namespace-coordination" || !nodeauth.IsWireGuardPeer(r.RemoteAddr) {
|
||||||
|
writeError(w, http.StatusUnauthorized, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
namespaceName := r.URL.Query().Get("namespace")
|
||||||
|
if namespaceName == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace parameter required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.webrtcManager == nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.webrtcManager.DisableWebRTC(r.Context(), namespaceName); err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"status": "ok",
|
||||||
|
"namespace": namespaceName,
|
||||||
|
"message": "WebRTC disabled successfully",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// namespaceWebRTCStatusHandler handles GET /v1/internal/namespace/webrtc/status?namespace={name}
|
||||||
|
// Internal-only: authenticated by X-Orama-Internal-Auth header + WireGuard subnet.
|
||||||
|
func (g *Gateway) namespaceWebRTCStatusHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Header.Get("X-Orama-Internal-Auth") != "namespace-coordination" || !nodeauth.IsWireGuardPeer(r.RemoteAddr) {
|
||||||
|
writeError(w, http.StatusUnauthorized, "unauthorized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
namespaceName := r.URL.Query().Get("namespace")
|
||||||
|
if namespaceName == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace parameter required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.webrtcManager == nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "WebRTC management not enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := g.webrtcManager.GetWebRTCStatus(r.Context(), namespaceName)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
if config == nil {
|
||||||
|
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||||
|
"namespace": namespaceName,
|
||||||
|
"enabled": false,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
json.NewEncoder(w).Encode(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@ -58,6 +58,14 @@ type NodeRecoverer interface {
|
|||||||
RepairCluster(ctx context.Context, namespaceName string) error
|
RepairCluster(ctx context.Context, namespaceName string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WebRTCManager handles enabling/disabling WebRTC services for namespaces.
|
||||||
|
type WebRTCManager interface {
|
||||||
|
EnableWebRTC(ctx context.Context, namespaceName, enabledBy string) error
|
||||||
|
DisableWebRTC(ctx context.Context, namespaceName string) error
|
||||||
|
// GetWebRTCStatus returns the WebRTC config for a namespace, or nil if not enabled.
|
||||||
|
GetWebRTCStatus(ctx context.Context, namespaceName string) (interface{}, error)
|
||||||
|
}
|
||||||
|
|
||||||
// Handlers holds dependencies for authentication HTTP handlers
|
// Handlers holds dependencies for authentication HTTP handlers
|
||||||
type Handlers struct {
|
type Handlers struct {
|
||||||
logger *logging.ColoredLogger
|
logger *logging.ColoredLogger
|
||||||
|
|||||||
@ -302,6 +302,8 @@ func (h *DeleteHandler) cleanupGlobalTables(ctx context.Context, ns string) {
|
|||||||
{"namespace_sqlite_databases", "namespace"},
|
{"namespace_sqlite_databases", "namespace"},
|
||||||
{"namespace_quotas", "namespace"},
|
{"namespace_quotas", "namespace"},
|
||||||
{"home_node_assignments", "namespace"},
|
{"home_node_assignments", "namespace"},
|
||||||
|
{"webrtc_rooms", "namespace_name"},
|
||||||
|
{"namespace_webrtc_config", "namespace_name"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, t := range tables {
|
for _, t := range tables {
|
||||||
|
|||||||
@ -12,12 +12,13 @@ import (
|
|||||||
namespacepkg "github.com/DeBrosOfficial/network/pkg/namespace"
|
namespacepkg "github.com/DeBrosOfficial/network/pkg/namespace"
|
||||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/sfu"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SpawnRequest represents a request to spawn or stop a namespace instance
|
// SpawnRequest represents a request to spawn or stop a namespace instance
|
||||||
type SpawnRequest struct {
|
type SpawnRequest struct {
|
||||||
Action string `json:"action"` // "spawn-rqlite", "spawn-olric", "spawn-gateway", "stop-rqlite", "stop-olric", "stop-gateway", "save-cluster-state", "delete-cluster-state"
|
Action string `json:"action"` // spawn-{rqlite,olric,gateway,sfu,turn}, stop-{rqlite,olric,gateway,sfu,turn}, save-cluster-state, delete-cluster-state
|
||||||
Namespace string `json:"namespace"`
|
Namespace string `json:"namespace"`
|
||||||
NodeID string `json:"node_id"`
|
NodeID string `json:"node_id"`
|
||||||
|
|
||||||
@ -48,6 +49,24 @@ type SpawnRequest struct {
|
|||||||
IPFSTimeout string `json:"ipfs_timeout,omitempty"`
|
IPFSTimeout string `json:"ipfs_timeout,omitempty"`
|
||||||
IPFSReplicationFactor int `json:"ipfs_replication_factor,omitempty"`
|
IPFSReplicationFactor int `json:"ipfs_replication_factor,omitempty"`
|
||||||
|
|
||||||
|
// SFU config (when action = "spawn-sfu")
|
||||||
|
SFUListenAddr string `json:"sfu_listen_addr,omitempty"`
|
||||||
|
SFUMediaStart int `json:"sfu_media_start,omitempty"`
|
||||||
|
SFUMediaEnd int `json:"sfu_media_end,omitempty"`
|
||||||
|
TURNServers []sfu.TURNServerConfig `json:"turn_servers,omitempty"`
|
||||||
|
TURNSecret string `json:"turn_secret,omitempty"`
|
||||||
|
TURNCredTTL int `json:"turn_cred_ttl,omitempty"`
|
||||||
|
RQLiteDSN string `json:"rqlite_dsn,omitempty"`
|
||||||
|
|
||||||
|
// TURN config (when action = "spawn-turn")
|
||||||
|
TURNListenAddr string `json:"turn_listen_addr,omitempty"`
|
||||||
|
TURNTLSAddr string `json:"turn_tls_addr,omitempty"`
|
||||||
|
TURNPublicIP string `json:"turn_public_ip,omitempty"`
|
||||||
|
TURNRealm string `json:"turn_realm,omitempty"`
|
||||||
|
TURNAuthSecret string `json:"turn_auth_secret,omitempty"`
|
||||||
|
TURNRelayStart int `json:"turn_relay_start,omitempty"`
|
||||||
|
TURNRelayEnd int `json:"turn_relay_end,omitempty"`
|
||||||
|
|
||||||
// Cluster state (when action = "save-cluster-state")
|
// Cluster state (when action = "save-cluster-state")
|
||||||
ClusterState json.RawMessage `json:"cluster_state,omitempty"`
|
ClusterState json.RawMessage `json:"cluster_state,omitempty"`
|
||||||
}
|
}
|
||||||
@ -242,6 +261,60 @@ func (h *SpawnHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
|
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
|
||||||
|
|
||||||
|
case "spawn-sfu":
|
||||||
|
cfg := namespacepkg.SFUInstanceConfig{
|
||||||
|
Namespace: req.Namespace,
|
||||||
|
NodeID: req.NodeID,
|
||||||
|
ListenAddr: req.SFUListenAddr,
|
||||||
|
MediaPortStart: req.SFUMediaStart,
|
||||||
|
MediaPortEnd: req.SFUMediaEnd,
|
||||||
|
TURNServers: req.TURNServers,
|
||||||
|
TURNSecret: req.TURNSecret,
|
||||||
|
TURNCredTTL: req.TURNCredTTL,
|
||||||
|
RQLiteDSN: req.RQLiteDSN,
|
||||||
|
}
|
||||||
|
if err := h.systemdSpawner.SpawnSFU(ctx, req.Namespace, req.NodeID, cfg); err != nil {
|
||||||
|
h.logger.Error("Failed to spawn SFU instance", zap.Error(err))
|
||||||
|
writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
|
||||||
|
|
||||||
|
case "stop-sfu":
|
||||||
|
if err := h.systemdSpawner.StopSFU(ctx, req.Namespace, req.NodeID); err != nil {
|
||||||
|
h.logger.Error("Failed to stop SFU instance", zap.Error(err))
|
||||||
|
writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
|
||||||
|
|
||||||
|
case "spawn-turn":
|
||||||
|
cfg := namespacepkg.TURNInstanceConfig{
|
||||||
|
Namespace: req.Namespace,
|
||||||
|
NodeID: req.NodeID,
|
||||||
|
ListenAddr: req.TURNListenAddr,
|
||||||
|
TLSListenAddr: req.TURNTLSAddr,
|
||||||
|
PublicIP: req.TURNPublicIP,
|
||||||
|
Realm: req.TURNRealm,
|
||||||
|
AuthSecret: req.TURNAuthSecret,
|
||||||
|
RelayPortStart: req.TURNRelayStart,
|
||||||
|
RelayPortEnd: req.TURNRelayEnd,
|
||||||
|
}
|
||||||
|
if err := h.systemdSpawner.SpawnTURN(ctx, req.Namespace, req.NodeID, cfg); err != nil {
|
||||||
|
h.logger.Error("Failed to spawn TURN instance", zap.Error(err))
|
||||||
|
writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
|
||||||
|
|
||||||
|
case "stop-turn":
|
||||||
|
if err := h.systemdSpawner.StopTURN(ctx, req.Namespace, req.NodeID); err != nil {
|
||||||
|
h.logger.Error("Failed to stop TURN instance", zap.Error(err))
|
||||||
|
writeSpawnResponse(w, http.StatusInternalServerError, SpawnResponse{Error: err.Error()})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeSpawnResponse(w, http.StatusOK, SpawnResponse{Success: true})
|
||||||
|
|
||||||
default:
|
default:
|
||||||
writeSpawnResponse(w, http.StatusBadRequest, SpawnResponse{Error: fmt.Sprintf("unknown action: %s", req.Action)})
|
writeSpawnResponse(w, http.StatusBadRequest, SpawnResponse{Error: fmt.Sprintf("unknown action: %s", req.Action)})
|
||||||
}
|
}
|
||||||
|
|||||||
56
pkg/gateway/handlers/webrtc/credentials.go
Normal file
56
pkg/gateway/handlers/webrtc/credentials.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package webrtc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const turnCredentialTTL = 10 * time.Minute
|
||||||
|
|
||||||
|
// CredentialsHandler handles POST /v1/webrtc/turn/credentials
|
||||||
|
// Returns fresh TURN credentials scoped to the authenticated namespace.
|
||||||
|
func (h *WebRTCHandlers) CredentialsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.turnSecret == "" {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "TURN not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
username, password := turn.GenerateCredentials(h.turnSecret, ns, turnCredentialTTL)
|
||||||
|
|
||||||
|
// Build TURN URIs — use IPs to bypass DNS propagation delays
|
||||||
|
var uris []string
|
||||||
|
if h.turnDomain != "" {
|
||||||
|
uris = append(uris,
|
||||||
|
fmt.Sprintf("turn:%s:3478?transport=udp", h.turnDomain),
|
||||||
|
fmt.Sprintf("turn:%s:443?transport=udp", h.turnDomain),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.ComponentInfo(logging.ComponentGeneral, "Issued TURN credentials",
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.String("username", username),
|
||||||
|
)
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
"ttl": int(turnCredentialTTL.Seconds()),
|
||||||
|
"uris": uris,
|
||||||
|
})
|
||||||
|
}
|
||||||
270
pkg/gateway/handlers/webrtc/handlers_test.go
Normal file
270
pkg/gateway/handlers/webrtc/handlers_test.go
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
package webrtc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testHandlers() *WebRTCHandlers {
|
||||||
|
logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
|
||||||
|
return NewWebRTCHandlers(
|
||||||
|
logger,
|
||||||
|
8443,
|
||||||
|
"turn.ns-test.dbrs.space",
|
||||||
|
"test-secret-key-32bytes-long!!!!",
|
||||||
|
nil, // No actual proxy in tests
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func requestWithNamespace(method, path, namespace string) *http.Request {
|
||||||
|
req := httptest.NewRequest(method, path, nil)
|
||||||
|
ctx := context.WithValue(req.Context(), ctxkeys.NamespaceOverride, namespace)
|
||||||
|
return req.WithContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Credentials handler tests ---
|
||||||
|
|
||||||
|
func TestCredentialsHandler_Success(t *testing.T) {
|
||||||
|
h := testHandlers()
|
||||||
|
req := requestWithNamespace("POST", "/v1/webrtc/turn/credentials", "test-ns")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.CredentialsHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]interface{}
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
|
||||||
|
t.Fatalf("failed to decode response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["username"] == nil || result["username"] == "" {
|
||||||
|
t.Error("expected non-empty username")
|
||||||
|
}
|
||||||
|
if result["password"] == nil || result["password"] == "" {
|
||||||
|
t.Error("expected non-empty password")
|
||||||
|
}
|
||||||
|
if result["ttl"] == nil {
|
||||||
|
t.Error("expected ttl field")
|
||||||
|
}
|
||||||
|
ttl, ok := result["ttl"].(float64)
|
||||||
|
if !ok || ttl != 600 {
|
||||||
|
t.Errorf("ttl = %v, want 600", result["ttl"])
|
||||||
|
}
|
||||||
|
uris, ok := result["uris"].([]interface{})
|
||||||
|
if !ok || len(uris) != 2 {
|
||||||
|
t.Errorf("uris count = %v, want 2", result["uris"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCredentialsHandler_MethodNotAllowed(t *testing.T) {
|
||||||
|
h := testHandlers()
|
||||||
|
req := requestWithNamespace("GET", "/v1/webrtc/turn/credentials", "test-ns")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.CredentialsHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCredentialsHandler_NoNamespace(t *testing.T) {
|
||||||
|
h := testHandlers()
|
||||||
|
req := httptest.NewRequest("POST", "/v1/webrtc/turn/credentials", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.CredentialsHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCredentialsHandler_NoTURNSecret(t *testing.T) {
|
||||||
|
logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
|
||||||
|
h := NewWebRTCHandlers(logger, 8443, "turn.test.dbrs.space", "", nil)
|
||||||
|
|
||||||
|
req := requestWithNamespace("POST", "/v1/webrtc/turn/credentials", "test-ns")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.CredentialsHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Signal handler tests ---
|
||||||
|
|
||||||
|
func TestSignalHandler_NoNamespace(t *testing.T) {
|
||||||
|
h := testHandlers()
|
||||||
|
req := httptest.NewRequest("GET", "/v1/webrtc/signal", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.SignalHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignalHandler_NoSFUPort(t *testing.T) {
|
||||||
|
logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
|
||||||
|
h := NewWebRTCHandlers(logger, 0, "", "secret", nil)
|
||||||
|
|
||||||
|
req := requestWithNamespace("GET", "/v1/webrtc/signal", "test-ns")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.SignalHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignalHandler_NoProxyFunc(t *testing.T) {
|
||||||
|
h := testHandlers() // proxyWebSocket is nil
|
||||||
|
req := requestWithNamespace("GET", "/v1/webrtc/signal", "test-ns")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.SignalHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusInternalServerError {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Rooms handler tests ---
|
||||||
|
|
||||||
|
func TestRoomsHandler_MethodNotAllowed(t *testing.T) {
|
||||||
|
h := testHandlers()
|
||||||
|
req := requestWithNamespace("POST", "/v1/webrtc/rooms", "test-ns")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.RoomsHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusMethodNotAllowed {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomsHandler_NoNamespace(t *testing.T) {
|
||||||
|
h := testHandlers()
|
||||||
|
req := httptest.NewRequest("GET", "/v1/webrtc/rooms", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.RoomsHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusForbidden {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusForbidden)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomsHandler_NoSFUPort(t *testing.T) {
|
||||||
|
logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
|
||||||
|
h := NewWebRTCHandlers(logger, 0, "", "secret", nil)
|
||||||
|
|
||||||
|
req := requestWithNamespace("GET", "/v1/webrtc/rooms", "test-ns")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.RoomsHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomsHandler_SFUProxySuccess(t *testing.T) {
|
||||||
|
// Start a mock SFU health endpoint
|
||||||
|
mockSFU := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(`{"status":"ok","rooms":3}`))
|
||||||
|
}))
|
||||||
|
defer mockSFU.Close()
|
||||||
|
|
||||||
|
// Extract port from mock server
|
||||||
|
logger, _ := logging.NewColoredLogger(logging.ComponentGeneral, false)
|
||||||
|
// Parse port from mockSFU.URL (format: http://127.0.0.1:PORT)
|
||||||
|
var port int
|
||||||
|
for i := len(mockSFU.URL) - 1; i >= 0; i-- {
|
||||||
|
if mockSFU.URL[i] == ':' {
|
||||||
|
p := mockSFU.URL[i+1:]
|
||||||
|
for _, c := range p {
|
||||||
|
port = port*10 + int(c-'0')
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewWebRTCHandlers(logger, port, "", "secret", nil)
|
||||||
|
req := requestWithNamespace("GET", "/v1/webrtc/rooms", "test-ns")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
h.RoomsHandler(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
if body != `{"status":"ok","rooms":3}` {
|
||||||
|
t.Errorf("body = %q, want %q", body, `{"status":"ok","rooms":3}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Helper tests ---
|
||||||
|
|
||||||
|
func TestResolveNamespaceFromRequest(t *testing.T) {
|
||||||
|
// With namespace
|
||||||
|
req := requestWithNamespace("GET", "/test", "my-namespace")
|
||||||
|
ns := resolveNamespaceFromRequest(req)
|
||||||
|
if ns != "my-namespace" {
|
||||||
|
t.Errorf("namespace = %q, want %q", ns, "my-namespace")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Without namespace
|
||||||
|
req = httptest.NewRequest("GET", "/test", nil)
|
||||||
|
ns = resolveNamespaceFromRequest(req)
|
||||||
|
if ns != "" {
|
||||||
|
t.Errorf("namespace = %q, want empty", ns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteError(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
writeError(w, http.StatusBadRequest, "bad request")
|
||||||
|
|
||||||
|
if w.Code != http.StatusBadRequest {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result map[string]string
|
||||||
|
if err := json.NewDecoder(w.Body).Decode(&result); err != nil {
|
||||||
|
t.Fatalf("failed to decode: %v", err)
|
||||||
|
}
|
||||||
|
if result["error"] != "bad request" {
|
||||||
|
t.Errorf("error = %q, want %q", result["error"], "bad request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteJSON(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if ct := w.Header().Get("Content-Type"); ct != "application/json" {
|
||||||
|
t.Errorf("Content-Type = %q, want %q", ct, "application/json")
|
||||||
|
}
|
||||||
|
}
|
||||||
51
pkg/gateway/handlers/webrtc/rooms.go
Normal file
51
pkg/gateway/handlers/webrtc/rooms.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package webrtc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RoomsHandler handles GET /v1/webrtc/rooms (list rooms)
|
||||||
|
// and GET /v1/webrtc/rooms?room_id=X (get specific room)
|
||||||
|
// Proxies to the local SFU's health endpoint for room data.
|
||||||
|
func (h *WebRTCHandlers) RoomsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.sfuPort <= 0 {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "SFU not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy to SFU health endpoint which returns room count
|
||||||
|
targetURL := fmt.Sprintf("http://127.0.0.1:%d/health", h.sfuPort)
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: 5 * time.Second}
|
||||||
|
resp, err := client.Get(targetURL)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.ComponentWarn(logging.ComponentGeneral, "SFU health check failed",
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "SFU unavailable")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
io.Copy(w, resp.Body)
|
||||||
|
}
|
||||||
52
pkg/gateway/handlers/webrtc/signal.go
Normal file
52
pkg/gateway/handlers/webrtc/signal.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package webrtc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SignalHandler handles WebSocket /v1/webrtc/signal
|
||||||
|
// Proxies the WebSocket connection to the local SFU's signaling endpoint.
|
||||||
|
func (h *WebRTCHandlers) SignalHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if h.sfuPort <= 0 {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "SFU not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy WebSocket to local SFU on WireGuard IP
|
||||||
|
// SFU binds to WireGuard IP, so we use 127.0.0.1 since we're on the same node
|
||||||
|
targetHost := fmt.Sprintf("127.0.0.1:%d", h.sfuPort)
|
||||||
|
|
||||||
|
h.logger.ComponentDebug(logging.ComponentGeneral, "Proxying WebRTC signal to SFU",
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.String("target", targetHost),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Rewrite the URL path to match the SFU's expected endpoint
|
||||||
|
r.URL.Path = "/ws/signal"
|
||||||
|
r.URL.Scheme = "http"
|
||||||
|
r.URL.Host = targetHost
|
||||||
|
r.Host = targetHost
|
||||||
|
|
||||||
|
if h.proxyWebSocket == nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, "WebSocket proxy not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !h.proxyWebSocket(w, r, targetHost) {
|
||||||
|
// proxyWebSocket already wrote the error response
|
||||||
|
h.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket proxy failed",
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.String("target", targetHost),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
58
pkg/gateway/handlers/webrtc/types.go
Normal file
58
pkg/gateway/handlers/webrtc/types.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package webrtc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/ctxkeys"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebRTCHandlers handles all WebRTC-related HTTP and WebSocket endpoints.
|
||||||
|
// These run on the namespace gateway and proxy signaling to the local SFU.
|
||||||
|
type WebRTCHandlers struct {
|
||||||
|
logger *logging.ColoredLogger
|
||||||
|
sfuPort int // Local SFU signaling port to proxy WebSocket connections to
|
||||||
|
turnDomain string // TURN server domain for building URIs
|
||||||
|
turnSecret string // HMAC-SHA1 shared secret for TURN credential generation
|
||||||
|
|
||||||
|
// proxyWebSocket is injected from the gateway to reuse its WebSocket proxy logic
|
||||||
|
proxyWebSocket func(w http.ResponseWriter, r *http.Request, targetHost string) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebRTCHandlers creates a new WebRTCHandlers instance.
|
||||||
|
func NewWebRTCHandlers(
|
||||||
|
logger *logging.ColoredLogger,
|
||||||
|
sfuPort int,
|
||||||
|
turnDomain string,
|
||||||
|
turnSecret string,
|
||||||
|
proxyWS func(w http.ResponseWriter, r *http.Request, targetHost string) bool,
|
||||||
|
) *WebRTCHandlers {
|
||||||
|
return &WebRTCHandlers{
|
||||||
|
logger: logger,
|
||||||
|
sfuPort: sfuPort,
|
||||||
|
turnDomain: turnDomain,
|
||||||
|
turnSecret: turnSecret,
|
||||||
|
proxyWebSocket: proxyWS,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveNamespaceFromRequest gets namespace from context set by auth middleware
|
||||||
|
func resolveNamespaceFromRequest(r *http.Request) string {
|
||||||
|
if v := r.Context().Value(ctxkeys.NamespaceOverride); v != nil {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, code int, v any) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(code)
|
||||||
|
json.NewEncoder(w).Encode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeError(w http.ResponseWriter, code int, msg string) {
|
||||||
|
writeJSON(w, code, map[string]string{"error": msg})
|
||||||
|
}
|
||||||
@ -196,7 +196,7 @@ func (g *Gateway) securityHeadersMiddleware(next http.Handler) http.Handler {
|
|||||||
w.Header().Set("X-Frame-Options", "DENY")
|
w.Header().Set("X-Frame-Options", "DENY")
|
||||||
w.Header().Set("X-XSS-Protection", "0")
|
w.Header().Set("X-XSS-Protection", "0")
|
||||||
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||||
w.Header().Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
|
w.Header().Set("Permissions-Policy", "camera=(self), microphone=(self), geolocation=()")
|
||||||
// HSTS only when behind TLS (Caddy)
|
// HSTS only when behind TLS (Caddy)
|
||||||
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
|
||||||
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
w.Header().Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||||
@ -618,6 +618,9 @@ func requiresNamespaceOwnership(p string) bool {
|
|||||||
if strings.HasPrefix(p, "/v1/functions") {
|
if strings.HasPrefix(p, "/v1/functions") {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(p, "/v1/webrtc/") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -382,7 +382,7 @@ func TestSecurityHeadersMiddleware(t *testing.T) {
|
|||||||
"X-Frame-Options": "DENY",
|
"X-Frame-Options": "DENY",
|
||||||
"X-Xss-Protection": "0",
|
"X-Xss-Protection": "0",
|
||||||
"Referrer-Policy": "strict-origin-when-cross-origin",
|
"Referrer-Policy": "strict-origin-when-cross-origin",
|
||||||
"Permissions-Policy": "camera=(), microphone=(), geolocation=()",
|
"Permissions-Policy": "camera=(self), microphone=(self), geolocation=()",
|
||||||
}
|
}
|
||||||
for header, want := range expected {
|
for header, want := range expected {
|
||||||
got := rr.Header().Get(header)
|
got := rr.Header().Get(header)
|
||||||
|
|||||||
@ -47,6 +47,11 @@ func (g *Gateway) Routes() http.Handler {
|
|||||||
// Namespace cluster repair (internal, handler does its own auth)
|
// Namespace cluster repair (internal, handler does its own auth)
|
||||||
mux.HandleFunc("/v1/internal/namespace/repair", g.namespaceClusterRepairHandler)
|
mux.HandleFunc("/v1/internal/namespace/repair", g.namespaceClusterRepairHandler)
|
||||||
|
|
||||||
|
// Namespace WebRTC enable/disable/status (internal, handler does its own auth)
|
||||||
|
mux.HandleFunc("/v1/internal/namespace/webrtc/enable", g.namespaceWebRTCEnableHandler)
|
||||||
|
mux.HandleFunc("/v1/internal/namespace/webrtc/disable", g.namespaceWebRTCDisableHandler)
|
||||||
|
mux.HandleFunc("/v1/internal/namespace/webrtc/status", g.namespaceWebRTCStatusHandler)
|
||||||
|
|
||||||
// auth endpoints
|
// auth endpoints
|
||||||
mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler)
|
mux.HandleFunc("/v1/auth/jwks", g.authService.JWKSHandler)
|
||||||
mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler)
|
mux.HandleFunc("/.well-known/jwks.json", g.authService.JWKSHandler)
|
||||||
@ -104,6 +109,13 @@ func (g *Gateway) Routes() http.Handler {
|
|||||||
mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler)
|
mux.HandleFunc("/v1/pubsub/presence", g.pubsubHandlers.PresenceHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// webrtc
|
||||||
|
if g.webrtcHandlers != nil {
|
||||||
|
mux.HandleFunc("/v1/webrtc/turn/credentials", g.webrtcHandlers.CredentialsHandler)
|
||||||
|
mux.HandleFunc("/v1/webrtc/signal", g.webrtcHandlers.SignalHandler)
|
||||||
|
mux.HandleFunc("/v1/webrtc/rooms", g.webrtcHandlers.RoomsHandler)
|
||||||
|
}
|
||||||
|
|
||||||
// anon proxy (authenticated users only)
|
// anon proxy (authenticated users only)
|
||||||
mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler)
|
mux.HandleFunc("/v1/proxy/anon", g.anonProxyHandler)
|
||||||
|
|
||||||
|
|||||||
@ -55,6 +55,8 @@ const (
|
|||||||
ComponentGeneral Component = "GENERAL"
|
ComponentGeneral Component = "GENERAL"
|
||||||
ComponentAnyone Component = "ANYONE"
|
ComponentAnyone Component = "ANYONE"
|
||||||
ComponentGateway Component = "GATEWAY"
|
ComponentGateway Component = "GATEWAY"
|
||||||
|
ComponentSFU Component = "SFU"
|
||||||
|
ComponentTURN Component = "TURN"
|
||||||
)
|
)
|
||||||
|
|
||||||
// getComponentColor returns the color for a specific component
|
// getComponentColor returns the color for a specific component
|
||||||
@ -78,6 +80,10 @@ func getComponentColor(component Component) string {
|
|||||||
return Cyan
|
return Cyan
|
||||||
case ComponentGateway:
|
case ComponentGateway:
|
||||||
return BrightGreen
|
return BrightGreen
|
||||||
|
case ComponentSFU:
|
||||||
|
return BrightRed
|
||||||
|
case ComponentTURN:
|
||||||
|
return Magenta
|
||||||
default:
|
default:
|
||||||
return White
|
return White
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/DeBrosOfficial/network/pkg/gateway"
|
"github.com/DeBrosOfficial/network/pkg/gateway"
|
||||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/sfu"
|
||||||
"github.com/DeBrosOfficial/network/pkg/systemd"
|
"github.com/DeBrosOfficial/network/pkg/systemd"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
@ -37,12 +38,13 @@ type ClusterManagerConfig struct {
|
|||||||
|
|
||||||
// ClusterManager orchestrates namespace cluster provisioning and lifecycle
|
// ClusterManager orchestrates namespace cluster provisioning and lifecycle
|
||||||
type ClusterManager struct {
|
type ClusterManager struct {
|
||||||
db rqlite.Client
|
db rqlite.Client
|
||||||
portAllocator *NamespacePortAllocator
|
portAllocator *NamespacePortAllocator
|
||||||
nodeSelector *ClusterNodeSelector
|
webrtcPortAllocator *WebRTCPortAllocator
|
||||||
systemdSpawner *SystemdSpawner // NEW: Systemd-based spawner replaces old spawners
|
nodeSelector *ClusterNodeSelector
|
||||||
dnsManager *DNSRecordManager
|
systemdSpawner *SystemdSpawner // NEW: Systemd-based spawner replaces old spawners
|
||||||
logger *zap.Logger
|
dnsManager *DNSRecordManager
|
||||||
|
logger *zap.Logger
|
||||||
baseDomain string
|
baseDomain string
|
||||||
baseDataDir string
|
baseDataDir string
|
||||||
globalRQLiteDSN string // Global RQLite DSN for namespace gateway auth
|
globalRQLiteDSN string // Global RQLite DSN for namespace gateway auth
|
||||||
@ -69,6 +71,7 @@ func NewClusterManager(
|
|||||||
) *ClusterManager {
|
) *ClusterManager {
|
||||||
// Create internal components
|
// Create internal components
|
||||||
portAllocator := NewNamespacePortAllocator(db, logger)
|
portAllocator := NewNamespacePortAllocator(db, logger)
|
||||||
|
webrtcPortAllocator := NewWebRTCPortAllocator(db, logger)
|
||||||
nodeSelector := NewClusterNodeSelector(db, portAllocator, logger)
|
nodeSelector := NewClusterNodeSelector(db, portAllocator, logger)
|
||||||
systemdSpawner := NewSystemdSpawner(cfg.BaseDataDir, logger)
|
systemdSpawner := NewSystemdSpawner(cfg.BaseDataDir, logger)
|
||||||
dnsManager := NewDNSRecordManager(db, cfg.BaseDomain, logger)
|
dnsManager := NewDNSRecordManager(db, cfg.BaseDomain, logger)
|
||||||
@ -94,6 +97,7 @@ func NewClusterManager(
|
|||||||
return &ClusterManager{
|
return &ClusterManager{
|
||||||
db: db,
|
db: db,
|
||||||
portAllocator: portAllocator,
|
portAllocator: portAllocator,
|
||||||
|
webrtcPortAllocator: webrtcPortAllocator,
|
||||||
nodeSelector: nodeSelector,
|
nodeSelector: nodeSelector,
|
||||||
systemdSpawner: systemdSpawner,
|
systemdSpawner: systemdSpawner,
|
||||||
dnsManager: dnsManager,
|
dnsManager: dnsManager,
|
||||||
@ -139,6 +143,7 @@ func NewClusterManagerWithComponents(
|
|||||||
return &ClusterManager{
|
return &ClusterManager{
|
||||||
db: db,
|
db: db,
|
||||||
portAllocator: portAllocator,
|
portAllocator: portAllocator,
|
||||||
|
webrtcPortAllocator: NewWebRTCPortAllocator(db, logger),
|
||||||
nodeSelector: nodeSelector,
|
nodeSelector: nodeSelector,
|
||||||
systemdSpawner: systemdSpawner,
|
systemdSpawner: systemdSpawner,
|
||||||
dnsManager: NewDNSRecordManager(db, cfg.BaseDomain, logger),
|
dnsManager: NewDNSRecordManager(db, cfg.BaseDomain, logger),
|
||||||
@ -854,12 +859,21 @@ func (cm *ClusterManager) DeprovisionCluster(ctx context.Context, namespaceID in
|
|||||||
if err := cm.db.Query(ctx, &clusterNodes, nodeQuery, cluster.ID); err != nil {
|
if err := cm.db.Query(ctx, &clusterNodes, nodeQuery, cluster.ID); err != nil {
|
||||||
cm.logger.Warn("Failed to query cluster nodes for deprovisioning, falling back to local-only stop", zap.Error(err))
|
cm.logger.Warn("Failed to query cluster nodes for deprovisioning, falling back to local-only stop", zap.Error(err))
|
||||||
// Fall back to local-only stop (individual methods, NOT StopAll which uses dangerous glob)
|
// Fall back to local-only stop (individual methods, NOT StopAll which uses dangerous glob)
|
||||||
|
// Stop WebRTC services first (SFU → TURN), then core services (Gateway → Olric → RQLite)
|
||||||
|
cm.systemdSpawner.StopSFU(ctx, cluster.NamespaceName, cm.localNodeID)
|
||||||
|
cm.systemdSpawner.StopTURN(ctx, cluster.NamespaceName, cm.localNodeID)
|
||||||
cm.systemdSpawner.StopGateway(ctx, cluster.NamespaceName, cm.localNodeID)
|
cm.systemdSpawner.StopGateway(ctx, cluster.NamespaceName, cm.localNodeID)
|
||||||
cm.systemdSpawner.StopOlric(ctx, cluster.NamespaceName, cm.localNodeID)
|
cm.systemdSpawner.StopOlric(ctx, cluster.NamespaceName, cm.localNodeID)
|
||||||
cm.systemdSpawner.StopRQLite(ctx, cluster.NamespaceName, cm.localNodeID)
|
cm.systemdSpawner.StopRQLite(ctx, cluster.NamespaceName, cm.localNodeID)
|
||||||
cm.systemdSpawner.DeleteClusterState(cluster.NamespaceName)
|
cm.systemdSpawner.DeleteClusterState(cluster.NamespaceName)
|
||||||
} else {
|
} else {
|
||||||
// 2. Stop namespace infra on ALL nodes (reverse dependency order: Gateway → Olric → RQLite)
|
// 2. Stop WebRTC services first (SFU → TURN), then core infra (Gateway → Olric → RQLite)
|
||||||
|
for _, node := range clusterNodes {
|
||||||
|
cm.stopSFUOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName)
|
||||||
|
}
|
||||||
|
for _, node := range clusterNodes {
|
||||||
|
cm.stopTURNOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName)
|
||||||
|
}
|
||||||
for _, node := range clusterNodes {
|
for _, node := range clusterNodes {
|
||||||
cm.stopGatewayOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName)
|
cm.stopGatewayOnNode(ctx, node.NodeID, node.InternalIP, cluster.NamespaceName)
|
||||||
}
|
}
|
||||||
@ -880,16 +894,21 @@ func (cm *ClusterManager) DeprovisionCluster(ctx context.Context, namespaceID in
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Deallocate all ports
|
// 4. Deallocate all ports (core + WebRTC)
|
||||||
cm.portAllocator.DeallocateAllPortBlocks(ctx, cluster.ID)
|
cm.portAllocator.DeallocateAllPortBlocks(ctx, cluster.ID)
|
||||||
|
cm.webrtcPortAllocator.DeallocateAll(ctx, cluster.ID)
|
||||||
|
|
||||||
// 5. Delete namespace DNS records
|
// 5. Delete namespace DNS records (gateway + TURN)
|
||||||
cm.dnsManager.DeleteNamespaceRecords(ctx, cluster.NamespaceName)
|
cm.dnsManager.DeleteNamespaceRecords(ctx, cluster.NamespaceName)
|
||||||
|
cm.dnsManager.DeleteTURNRecords(ctx, cluster.NamespaceName)
|
||||||
|
|
||||||
// 6. Explicitly delete child tables (FK cascades disabled in rqlite)
|
// 6. Explicitly delete child tables (FK cascades disabled in rqlite)
|
||||||
cm.db.Exec(ctx, `DELETE FROM namespace_cluster_events WHERE namespace_cluster_id = ?`, cluster.ID)
|
cm.db.Exec(ctx, `DELETE FROM namespace_cluster_events WHERE namespace_cluster_id = ?`, cluster.ID)
|
||||||
cm.db.Exec(ctx, `DELETE FROM namespace_cluster_nodes WHERE namespace_cluster_id = ?`, cluster.ID)
|
cm.db.Exec(ctx, `DELETE FROM namespace_cluster_nodes WHERE namespace_cluster_id = ?`, cluster.ID)
|
||||||
cm.db.Exec(ctx, `DELETE FROM namespace_port_allocations WHERE namespace_cluster_id = ?`, cluster.ID)
|
cm.db.Exec(ctx, `DELETE FROM namespace_port_allocations WHERE namespace_cluster_id = ?`, cluster.ID)
|
||||||
|
cm.db.Exec(ctx, `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ?`, cluster.ID)
|
||||||
|
cm.db.Exec(ctx, `DELETE FROM webrtc_rooms WHERE namespace_cluster_id = ?`, cluster.ID)
|
||||||
|
cm.db.Exec(ctx, `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, cluster.ID)
|
||||||
|
|
||||||
// 7. Delete cluster record
|
// 7. Delete cluster record
|
||||||
cm.db.Exec(ctx, `DELETE FROM namespace_clusters WHERE id = ?`, cluster.ID)
|
cm.db.Exec(ctx, `DELETE FROM namespace_clusters WHERE id = ?`, cluster.ID)
|
||||||
@ -1594,6 +1613,19 @@ type ClusterLocalState struct {
|
|||||||
HasGateway bool `json:"has_gateway"`
|
HasGateway bool `json:"has_gateway"`
|
||||||
BaseDomain string `json:"base_domain"`
|
BaseDomain string `json:"base_domain"`
|
||||||
SavedAt time.Time `json:"saved_at"`
|
SavedAt time.Time `json:"saved_at"`
|
||||||
|
|
||||||
|
// WebRTC fields (zero values when WebRTC not enabled — backward compatible)
|
||||||
|
HasSFU bool `json:"has_sfu,omitempty"`
|
||||||
|
HasTURN bool `json:"has_turn,omitempty"`
|
||||||
|
TURNSharedSecret string `json:"-"` // Never persisted to disk state file
|
||||||
|
TURNCredentialTTL int `json:"turn_credential_ttl,omitempty"`
|
||||||
|
SFUSignalingPort int `json:"sfu_signaling_port,omitempty"`
|
||||||
|
SFUMediaPortStart int `json:"sfu_media_port_start,omitempty"`
|
||||||
|
SFUMediaPortEnd int `json:"sfu_media_port_end,omitempty"`
|
||||||
|
TURNListenPort int `json:"turn_listen_port,omitempty"`
|
||||||
|
TURNTLSPort int `json:"turn_tls_port,omitempty"`
|
||||||
|
TURNRelayPortStart int `json:"turn_relay_port_start,omitempty"`
|
||||||
|
TURNRelayPortEnd int `json:"turn_relay_port_end,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClusterLocalStatePorts struct {
|
type ClusterLocalStatePorts struct {
|
||||||
@ -1891,6 +1923,70 @@ func (cm *ClusterManager) restoreClusterFromState(ctx context.Context, state *Cl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 4. Restore TURN (if enabled)
|
||||||
|
if state.HasTURN && state.TURNRelayPortStart > 0 {
|
||||||
|
turnRunning, _ := cm.systemdSpawner.systemdMgr.IsServiceActive(state.NamespaceName, systemd.ServiceTypeTURN)
|
||||||
|
if !turnRunning {
|
||||||
|
// TURN config needs the shared secret from DB — we can't persist it to disk state.
|
||||||
|
// If DB is available, fetch it; otherwise skip TURN restore (it will come back when DB is ready).
|
||||||
|
webrtcCfg, err := cm.GetWebRTCConfig(ctx, state.NamespaceName)
|
||||||
|
if err == nil && webrtcCfg != nil {
|
||||||
|
turnCfg := TURNInstanceConfig{
|
||||||
|
Namespace: state.NamespaceName,
|
||||||
|
NodeID: cm.localNodeID,
|
||||||
|
ListenAddr: fmt.Sprintf("0.0.0.0:%d", state.TURNListenPort),
|
||||||
|
TLSListenAddr: fmt.Sprintf("0.0.0.0:%d", state.TURNTLSPort),
|
||||||
|
PublicIP: "", // Will be resolved by spawner or from node info
|
||||||
|
Realm: cm.baseDomain,
|
||||||
|
AuthSecret: webrtcCfg.TURNSharedSecret,
|
||||||
|
RelayPortStart: state.TURNRelayPortStart,
|
||||||
|
RelayPortEnd: state.TURNRelayPortEnd,
|
||||||
|
}
|
||||||
|
if err := cm.systemdSpawner.SpawnTURN(ctx, state.NamespaceName, cm.localNodeID, turnCfg); err != nil {
|
||||||
|
cm.logger.Error("Failed to restore TURN from state", zap.String("namespace", state.NamespaceName), zap.Error(err))
|
||||||
|
} else {
|
||||||
|
cm.logger.Info("Restored TURN instance from state", zap.String("namespace", state.NamespaceName))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cm.logger.Warn("Skipping TURN restore: WebRTC config not available from DB",
|
||||||
|
zap.String("namespace", state.NamespaceName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Restore SFU (if enabled)
|
||||||
|
if state.HasSFU && state.SFUSignalingPort > 0 {
|
||||||
|
sfuRunning, _ := cm.systemdSpawner.systemdMgr.IsServiceActive(state.NamespaceName, systemd.ServiceTypeSFU)
|
||||||
|
if !sfuRunning {
|
||||||
|
webrtcCfg, err := cm.GetWebRTCConfig(ctx, state.NamespaceName)
|
||||||
|
if err == nil && webrtcCfg != nil {
|
||||||
|
turnDomain := fmt.Sprintf("turn.ns-%s.%s", state.NamespaceName, cm.baseDomain)
|
||||||
|
sfuCfg := SFUInstanceConfig{
|
||||||
|
Namespace: state.NamespaceName,
|
||||||
|
NodeID: cm.localNodeID,
|
||||||
|
ListenAddr: fmt.Sprintf("%s:%d", localIP, state.SFUSignalingPort),
|
||||||
|
MediaPortStart: state.SFUMediaPortStart,
|
||||||
|
MediaPortEnd: state.SFUMediaPortEnd,
|
||||||
|
TURNServers: []sfu.TURNServerConfig{
|
||||||
|
{Host: turnDomain, Port: TURNDefaultPort},
|
||||||
|
{Host: turnDomain, Port: TURNTLSPort},
|
||||||
|
},
|
||||||
|
TURNSecret: webrtcCfg.TURNSharedSecret,
|
||||||
|
TURNCredTTL: webrtcCfg.TURNCredentialTTL,
|
||||||
|
RQLiteDSN: fmt.Sprintf("http://localhost:%d", pb.RQLiteHTTPPort),
|
||||||
|
}
|
||||||
|
if err := cm.systemdSpawner.SpawnSFU(ctx, state.NamespaceName, cm.localNodeID, sfuCfg); err != nil {
|
||||||
|
cm.logger.Error("Failed to restore SFU from state", zap.String("namespace", state.NamespaceName), zap.Error(err))
|
||||||
|
} else {
|
||||||
|
cm.logger.Info("Restored SFU instance from state", zap.String("namespace", state.NamespaceName))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cm.logger.Warn("Skipping SFU restore: WebRTC config not available from DB",
|
||||||
|
zap.String("namespace", state.NamespaceName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
616
pkg/namespace/cluster_manager_webrtc.go
Normal file
616
pkg/namespace/cluster_manager_webrtc.go
Normal file
@ -0,0 +1,616 @@
|
|||||||
|
package namespace
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/sfu"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnableWebRTC enables WebRTC (SFU + TURN) for an existing namespace cluster.
|
||||||
|
// Allocates ports, spawns SFU on all 3 nodes and TURN on 2 nodes,
|
||||||
|
// creates TURN DNS records, and updates cluster state.
|
||||||
|
func (cm *ClusterManager) EnableWebRTC(ctx context.Context, namespaceName, enabledBy string) error {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
// 1. Verify cluster exists and is ready
|
||||||
|
cluster, err := cm.GetClusterByNamespace(ctx, namespaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get cluster: %w", err)
|
||||||
|
}
|
||||||
|
if cluster == nil {
|
||||||
|
return ErrClusterNotFound
|
||||||
|
}
|
||||||
|
if cluster.Status != ClusterStatusReady {
|
||||||
|
return &ClusterError{Message: fmt.Sprintf("cluster status is %q, must be %q to enable WebRTC", cluster.Status, ClusterStatusReady)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check if WebRTC is already enabled
|
||||||
|
var existingConfigs []WebRTCConfig
|
||||||
|
if err := cm.db.Query(internalCtx, &existingConfigs,
|
||||||
|
`SELECT * FROM namespace_webrtc_config WHERE namespace_cluster_id = ? AND enabled = 1`, cluster.ID); err == nil && len(existingConfigs) > 0 {
|
||||||
|
return ErrWebRTCAlreadyEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.logger.Info("Enabling WebRTC for namespace",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.String("cluster_id", cluster.ID),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 3. Generate TURN shared secret (32 bytes, crypto/rand)
|
||||||
|
secretBytes := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(secretBytes); err != nil {
|
||||||
|
return fmt.Errorf("failed to generate TURN secret: %w", err)
|
||||||
|
}
|
||||||
|
turnSecret := base64.StdEncoding.EncodeToString(secretBytes)
|
||||||
|
|
||||||
|
// 4. Insert namespace_webrtc_config
|
||||||
|
webrtcConfigID := uuid.New().String()
|
||||||
|
_, err = cm.db.Exec(internalCtx,
|
||||||
|
`INSERT INTO namespace_webrtc_config (id, namespace_cluster_id, namespace_name, enabled, turn_shared_secret, turn_credential_ttl, sfu_node_count, turn_node_count, enabled_by, enabled_at)
|
||||||
|
VALUES (?, ?, ?, 1, ?, ?, ?, ?, ?, ?)`,
|
||||||
|
webrtcConfigID, cluster.ID, namespaceName,
|
||||||
|
turnSecret, DefaultTURNCredentialTTL,
|
||||||
|
DefaultSFUNodeCount, DefaultTURNNodeCount,
|
||||||
|
enabledBy, time.Now(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to insert WebRTC config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Get cluster nodes with IPs
|
||||||
|
clusterNodes, err := cm.getClusterNodesWithIPs(ctx, cluster.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get cluster nodes: %w", err)
|
||||||
|
}
|
||||||
|
if len(clusterNodes) < 3 {
|
||||||
|
return fmt.Errorf("cluster has %d nodes, need at least 3 for WebRTC", len(clusterNodes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Allocate SFU ports on all nodes
|
||||||
|
sfuBlocks := make(map[string]*WebRTCPortBlock) // nodeID -> block
|
||||||
|
for _, node := range clusterNodes {
|
||||||
|
block, err := cm.webrtcPortAllocator.AllocateSFUPorts(ctx, node.NodeID, cluster.ID)
|
||||||
|
if err != nil {
|
||||||
|
cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
|
||||||
|
return fmt.Errorf("failed to allocate SFU ports on node %s: %w", node.NodeID, err)
|
||||||
|
}
|
||||||
|
sfuBlocks[node.NodeID] = block
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Select TURN nodes (prefer nodes without existing TURN allocations)
|
||||||
|
turnNodes := cm.selectTURNNodes(ctx, clusterNodes, DefaultTURNNodeCount)
|
||||||
|
|
||||||
|
// 8. Allocate TURN ports on selected nodes
|
||||||
|
turnBlocks := make(map[string]*WebRTCPortBlock) // nodeID -> block
|
||||||
|
for _, node := range turnNodes {
|
||||||
|
block, err := cm.webrtcPortAllocator.AllocateTURNPorts(ctx, node.NodeID, cluster.ID)
|
||||||
|
if err != nil {
|
||||||
|
cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
|
||||||
|
return fmt.Errorf("failed to allocate TURN ports on node %s: %w", node.NodeID, err)
|
||||||
|
}
|
||||||
|
turnBlocks[node.NodeID] = block
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9. Build TURN server list for SFU config
|
||||||
|
turnDomain := fmt.Sprintf("turn.ns-%s.%s", namespaceName, cm.baseDomain)
|
||||||
|
turnServers := []sfu.TURNServerConfig{
|
||||||
|
{Host: turnDomain, Port: TURNDefaultPort},
|
||||||
|
{Host: turnDomain, Port: TURNTLSPort},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 10. Get port blocks for RQLite DSN
|
||||||
|
portBlocks, err := cm.portAllocator.GetAllPortBlocks(ctx, cluster.ID)
|
||||||
|
if err != nil {
|
||||||
|
cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
|
||||||
|
return fmt.Errorf("failed to get port blocks: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build nodeID -> PortBlock map
|
||||||
|
nodePortBlocks := make(map[string]*PortBlock)
|
||||||
|
for i := range portBlocks {
|
||||||
|
nodePortBlocks[portBlocks[i].NodeID] = &portBlocks[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 11. Spawn TURN on selected nodes
|
||||||
|
for _, node := range turnNodes {
|
||||||
|
turnBlock := turnBlocks[node.NodeID]
|
||||||
|
turnCfg := TURNInstanceConfig{
|
||||||
|
Namespace: namespaceName,
|
||||||
|
NodeID: node.NodeID,
|
||||||
|
ListenAddr: fmt.Sprintf("0.0.0.0:%d", turnBlock.TURNListenPort),
|
||||||
|
TLSListenAddr: fmt.Sprintf("0.0.0.0:%d", turnBlock.TURNTLSPort),
|
||||||
|
PublicIP: node.PublicIP,
|
||||||
|
Realm: cm.baseDomain,
|
||||||
|
AuthSecret: turnSecret,
|
||||||
|
RelayPortStart: turnBlock.TURNRelayPortStart,
|
||||||
|
RelayPortEnd: turnBlock.TURNRelayPortEnd,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cm.spawnTURNOnNode(ctx, node, namespaceName, turnCfg); err != nil {
|
||||||
|
cm.logger.Error("Failed to spawn TURN",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.String("node_id", node.NodeID),
|
||||||
|
zap.Error(err))
|
||||||
|
cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
|
||||||
|
return fmt.Errorf("failed to spawn TURN on node %s: %w", node.NodeID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.logEvent(ctx, cluster.ID, EventTURNStarted, node.NodeID,
|
||||||
|
fmt.Sprintf("TURN started on %s (relay ports %d-%d)", node.NodeID, turnBlock.TURNRelayPortStart, turnBlock.TURNRelayPortEnd), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 12. Spawn SFU on all nodes
|
||||||
|
for _, node := range clusterNodes {
|
||||||
|
sfuBlock := sfuBlocks[node.NodeID]
|
||||||
|
pb := nodePortBlocks[node.NodeID]
|
||||||
|
rqliteDSN := fmt.Sprintf("http://localhost:%d", pb.RQLiteHTTPPort)
|
||||||
|
|
||||||
|
sfuCfg := SFUInstanceConfig{
|
||||||
|
Namespace: namespaceName,
|
||||||
|
NodeID: node.NodeID,
|
||||||
|
ListenAddr: fmt.Sprintf("%s:%d", node.InternalIP, sfuBlock.SFUSignalingPort),
|
||||||
|
MediaPortStart: sfuBlock.SFUMediaPortStart,
|
||||||
|
MediaPortEnd: sfuBlock.SFUMediaPortEnd,
|
||||||
|
TURNServers: turnServers,
|
||||||
|
TURNSecret: turnSecret,
|
||||||
|
TURNCredTTL: DefaultTURNCredentialTTL,
|
||||||
|
RQLiteDSN: rqliteDSN,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cm.spawnSFUOnNode(ctx, node, namespaceName, sfuCfg); err != nil {
|
||||||
|
cm.logger.Error("Failed to spawn SFU",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.String("node_id", node.NodeID),
|
||||||
|
zap.Error(err))
|
||||||
|
cm.cleanupWebRTCOnError(ctx, cluster.ID, namespaceName, clusterNodes)
|
||||||
|
return fmt.Errorf("failed to spawn SFU on node %s: %w", node.NodeID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.logEvent(ctx, cluster.ID, EventSFUStarted, node.NodeID,
|
||||||
|
fmt.Sprintf("SFU started on %s:%d", node.InternalIP, sfuBlock.SFUSignalingPort), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 13. Create TURN DNS records
|
||||||
|
var turnIPs []string
|
||||||
|
for _, node := range turnNodes {
|
||||||
|
turnIPs = append(turnIPs, node.PublicIP)
|
||||||
|
}
|
||||||
|
if err := cm.dnsManager.CreateTURNRecords(ctx, namespaceName, turnIPs); err != nil {
|
||||||
|
cm.logger.Warn("Failed to create TURN DNS records",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 14. Update cluster-state.json on all nodes with WebRTC info
|
||||||
|
cm.updateClusterStateWithWebRTC(ctx, cluster, clusterNodes, sfuBlocks, turnBlocks)
|
||||||
|
|
||||||
|
cm.logEvent(ctx, cluster.ID, EventWebRTCEnabled, "",
|
||||||
|
fmt.Sprintf("WebRTC enabled: SFU on %d nodes, TURN on %d nodes", len(clusterNodes), len(turnNodes)), nil)
|
||||||
|
|
||||||
|
cm.logger.Info("WebRTC enabled successfully",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.String("cluster_id", cluster.ID),
|
||||||
|
zap.Int("sfu_nodes", len(clusterNodes)),
|
||||||
|
zap.Int("turn_nodes", len(turnNodes)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DisableWebRTC disables WebRTC for a namespace cluster.
|
||||||
|
// Stops SFU/TURN services, deallocates ports, and cleans up DNS/DB.
|
||||||
|
func (cm *ClusterManager) DisableWebRTC(ctx context.Context, namespaceName string) error {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
// 1. Verify cluster exists
|
||||||
|
cluster, err := cm.GetClusterByNamespace(ctx, namespaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get cluster: %w", err)
|
||||||
|
}
|
||||||
|
if cluster == nil {
|
||||||
|
return ErrClusterNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Verify WebRTC is enabled
|
||||||
|
var configs []WebRTCConfig
|
||||||
|
if err := cm.db.Query(internalCtx, &configs,
|
||||||
|
`SELECT * FROM namespace_webrtc_config WHERE namespace_cluster_id = ? AND enabled = 1`, cluster.ID); err != nil || len(configs) == 0 {
|
||||||
|
return ErrWebRTCNotEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
cm.logger.Info("Disabling WebRTC for namespace",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.String("cluster_id", cluster.ID),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 3. Get cluster nodes with IPs
|
||||||
|
clusterNodes, err := cm.getClusterNodesWithIPs(ctx, cluster.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get cluster nodes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Stop SFU on all nodes
|
||||||
|
for _, node := range clusterNodes {
|
||||||
|
cm.stopSFUOnNode(ctx, node.NodeID, node.InternalIP, namespaceName)
|
||||||
|
cm.logEvent(ctx, cluster.ID, EventSFUStopped, node.NodeID, "SFU stopped", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Stop TURN on nodes that have TURN allocations
|
||||||
|
turnBlocks, _ := cm.getWebRTCBlocksByType(ctx, cluster.ID, "turn")
|
||||||
|
for _, block := range turnBlocks {
|
||||||
|
nodeIP := cm.getNodeIP(clusterNodes, block.NodeID)
|
||||||
|
cm.stopTURNOnNode(ctx, block.NodeID, nodeIP, namespaceName)
|
||||||
|
cm.logEvent(ctx, cluster.ID, EventTURNStopped, block.NodeID, "TURN stopped", nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Deallocate all WebRTC ports
|
||||||
|
if err := cm.webrtcPortAllocator.DeallocateAll(ctx, cluster.ID); err != nil {
|
||||||
|
cm.logger.Warn("Failed to deallocate WebRTC ports", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 7. Delete TURN DNS records
|
||||||
|
if err := cm.dnsManager.DeleteTURNRecords(ctx, namespaceName); err != nil {
|
||||||
|
cm.logger.Warn("Failed to delete TURN DNS records", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8. Clean up DB tables
|
||||||
|
cm.db.Exec(internalCtx, `DELETE FROM webrtc_rooms WHERE namespace_cluster_id = ?`, cluster.ID)
|
||||||
|
cm.db.Exec(internalCtx, `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, cluster.ID)
|
||||||
|
|
||||||
|
// 9. Update cluster-state.json to remove WebRTC info
|
||||||
|
cm.updateClusterStateWithWebRTC(ctx, cluster, clusterNodes, nil, nil)
|
||||||
|
|
||||||
|
cm.logEvent(ctx, cluster.ID, EventWebRTCDisabled, "", "WebRTC disabled", nil)
|
||||||
|
|
||||||
|
cm.logger.Info("WebRTC disabled successfully",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.String("cluster_id", cluster.ID),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWebRTCConfig returns the WebRTC configuration for a namespace.
|
||||||
|
func (cm *ClusterManager) GetWebRTCConfig(ctx context.Context, namespaceName string) (*WebRTCConfig, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
var configs []WebRTCConfig
|
||||||
|
err := cm.db.Query(internalCtx, &configs,
|
||||||
|
`SELECT * FROM namespace_webrtc_config WHERE namespace_name = ? AND enabled = 1`, namespaceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query WebRTC config: %w", err)
|
||||||
|
}
|
||||||
|
if len(configs) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return &configs[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetWebRTCStatus returns the WebRTC config as an interface{} for the WebRTCManager interface.
|
||||||
|
func (cm *ClusterManager) GetWebRTCStatus(ctx context.Context, namespaceName string) (interface{}, error) {
|
||||||
|
cfg, err := cm.GetWebRTCConfig(ctx, namespaceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if cfg == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Internal helpers ---
|
||||||
|
|
||||||
|
// clusterNodeInfo holds node info needed for WebRTC operations
|
||||||
|
type clusterNodeInfo struct {
|
||||||
|
NodeID string
|
||||||
|
InternalIP string // WireGuard IP
|
||||||
|
PublicIP string // Public IP for TURN
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClusterNodesWithIPs returns cluster nodes with both internal and public IPs.
|
||||||
|
func (cm *ClusterManager) getClusterNodesWithIPs(ctx context.Context, clusterID string) ([]clusterNodeInfo, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
type nodeRow struct {
|
||||||
|
NodeID string `db:"node_id"`
|
||||||
|
InternalIP string `db:"internal_ip"`
|
||||||
|
PublicIP string `db:"public_ip"`
|
||||||
|
}
|
||||||
|
var rows []nodeRow
|
||||||
|
query := `
|
||||||
|
SELECT ncn.node_id,
|
||||||
|
COALESCE(dn.internal_ip, dn.ip_address) as internal_ip,
|
||||||
|
dn.ip_address as public_ip
|
||||||
|
FROM namespace_cluster_nodes ncn
|
||||||
|
JOIN dns_nodes dn ON ncn.node_id = dn.id
|
||||||
|
WHERE ncn.namespace_cluster_id = ?
|
||||||
|
`
|
||||||
|
if err := cm.db.Query(internalCtx, &rows, query, clusterID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes := make([]clusterNodeInfo, len(rows))
|
||||||
|
for i, r := range rows {
|
||||||
|
nodes[i] = clusterNodeInfo{
|
||||||
|
NodeID: r.NodeID,
|
||||||
|
InternalIP: r.InternalIP,
|
||||||
|
PublicIP: r.PublicIP,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nodes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// selectTURNNodes selects the best N nodes for TURN, preferring nodes without existing TURN allocations.
|
||||||
|
func (cm *ClusterManager) selectTURNNodes(ctx context.Context, nodes []clusterNodeInfo, count int) []clusterNodeInfo {
|
||||||
|
if count >= len(nodes) {
|
||||||
|
return nodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prefer nodes without existing TURN allocations
|
||||||
|
var preferred, fallback []clusterNodeInfo
|
||||||
|
for _, node := range nodes {
|
||||||
|
hasTURN, err := cm.webrtcPortAllocator.NodeHasTURN(ctx, node.NodeID)
|
||||||
|
if err != nil || !hasTURN {
|
||||||
|
preferred = append(preferred, node)
|
||||||
|
} else {
|
||||||
|
fallback = append(fallback, node)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Take from preferred first, then fallback
|
||||||
|
result := make([]clusterNodeInfo, 0, count)
|
||||||
|
for _, node := range preferred {
|
||||||
|
if len(result) >= count {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, node)
|
||||||
|
}
|
||||||
|
for _, node := range fallback {
|
||||||
|
if len(result) >= count {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
result = append(result, node)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// spawnSFUOnNode spawns SFU on a node (local or remote)
|
||||||
|
func (cm *ClusterManager) spawnSFUOnNode(ctx context.Context, node clusterNodeInfo, namespace string, cfg SFUInstanceConfig) error {
|
||||||
|
if node.NodeID == cm.localNodeID {
|
||||||
|
return cm.systemdSpawner.SpawnSFU(ctx, namespace, node.NodeID, cfg)
|
||||||
|
}
|
||||||
|
return cm.spawnSFURemote(ctx, node.InternalIP, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// spawnTURNOnNode spawns TURN on a node (local or remote)
|
||||||
|
func (cm *ClusterManager) spawnTURNOnNode(ctx context.Context, node clusterNodeInfo, namespace string, cfg TURNInstanceConfig) error {
|
||||||
|
if node.NodeID == cm.localNodeID {
|
||||||
|
return cm.systemdSpawner.SpawnTURN(ctx, namespace, node.NodeID, cfg)
|
||||||
|
}
|
||||||
|
return cm.spawnTURNRemote(ctx, node.InternalIP, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopSFUOnNode stops SFU on a node (local or remote)
|
||||||
|
func (cm *ClusterManager) stopSFUOnNode(ctx context.Context, nodeID, nodeIP, namespace string) {
|
||||||
|
if nodeID == cm.localNodeID {
|
||||||
|
cm.systemdSpawner.StopSFU(ctx, namespace, nodeID)
|
||||||
|
} else {
|
||||||
|
cm.sendStopRequest(ctx, nodeIP, "stop-sfu", namespace, nodeID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopTURNOnNode stops TURN on a node (local or remote)
|
||||||
|
func (cm *ClusterManager) stopTURNOnNode(ctx context.Context, nodeID, nodeIP, namespace string) {
|
||||||
|
if nodeID == cm.localNodeID {
|
||||||
|
cm.systemdSpawner.StopTURN(ctx, namespace, nodeID)
|
||||||
|
} else {
|
||||||
|
cm.sendStopRequest(ctx, nodeIP, "stop-turn", namespace, nodeID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// spawnSFURemote sends a spawn-sfu request to a remote node
|
||||||
|
func (cm *ClusterManager) spawnSFURemote(ctx context.Context, nodeIP string, cfg SFUInstanceConfig) error {
|
||||||
|
// Serialize TURN servers for transport
|
||||||
|
turnServers := make([]map[string]interface{}, len(cfg.TURNServers))
|
||||||
|
for i, ts := range cfg.TURNServers {
|
||||||
|
turnServers[i] = map[string]interface{}{
|
||||||
|
"host": ts.Host,
|
||||||
|
"port": ts.Port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := cm.sendSpawnRequest(ctx, nodeIP, map[string]interface{}{
|
||||||
|
"action": "spawn-sfu",
|
||||||
|
"namespace": cfg.Namespace,
|
||||||
|
"node_id": cfg.NodeID,
|
||||||
|
"sfu_listen_addr": cfg.ListenAddr,
|
||||||
|
"sfu_media_start": cfg.MediaPortStart,
|
||||||
|
"sfu_media_end": cfg.MediaPortEnd,
|
||||||
|
"turn_servers": turnServers,
|
||||||
|
"turn_secret": cfg.TURNSecret,
|
||||||
|
"turn_cred_ttl": cfg.TURNCredTTL,
|
||||||
|
"rqlite_dsn": cfg.RQLiteDSN,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// spawnTURNRemote sends a spawn-turn request to a remote node
|
||||||
|
func (cm *ClusterManager) spawnTURNRemote(ctx context.Context, nodeIP string, cfg TURNInstanceConfig) error {
|
||||||
|
_, err := cm.sendSpawnRequest(ctx, nodeIP, map[string]interface{}{
|
||||||
|
"action": "spawn-turn",
|
||||||
|
"namespace": cfg.Namespace,
|
||||||
|
"node_id": cfg.NodeID,
|
||||||
|
"turn_listen_addr": cfg.ListenAddr,
|
||||||
|
"turn_tls_addr": cfg.TLSListenAddr,
|
||||||
|
"turn_public_ip": cfg.PublicIP,
|
||||||
|
"turn_realm": cfg.Realm,
|
||||||
|
"turn_auth_secret": cfg.AuthSecret,
|
||||||
|
"turn_relay_start": cfg.RelayPortStart,
|
||||||
|
"turn_relay_end": cfg.RelayPortEnd,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWebRTCBlocksByType returns all WebRTC port blocks of a given type for a cluster.
|
||||||
|
func (cm *ClusterManager) getWebRTCBlocksByType(ctx context.Context, clusterID, serviceType string) ([]WebRTCPortBlock, error) {
|
||||||
|
allBlocks, err := cm.webrtcPortAllocator.GetAllPorts(ctx, clusterID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var filtered []WebRTCPortBlock
|
||||||
|
for _, b := range allBlocks {
|
||||||
|
if b.ServiceType == serviceType {
|
||||||
|
filtered = append(filtered, b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNodeIP looks up the internal IP for a node ID from a list.
|
||||||
|
func (cm *ClusterManager) getNodeIP(nodes []clusterNodeInfo, nodeID string) string {
|
||||||
|
for _, n := range nodes {
|
||||||
|
if n.NodeID == nodeID {
|
||||||
|
return n.InternalIP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupWebRTCOnError cleans up partial WebRTC allocations when EnableWebRTC fails mid-way.
|
||||||
|
func (cm *ClusterManager) cleanupWebRTCOnError(ctx context.Context, clusterID, namespaceName string, nodes []clusterNodeInfo) {
|
||||||
|
cm.logger.Warn("Cleaning up partial WebRTC enablement",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.String("cluster_id", clusterID))
|
||||||
|
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
// Stop any spawned SFU/TURN services
|
||||||
|
for _, node := range nodes {
|
||||||
|
cm.stopSFUOnNode(ctx, node.NodeID, node.InternalIP, namespaceName)
|
||||||
|
cm.stopTURNOnNode(ctx, node.NodeID, node.InternalIP, namespaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deallocate ports
|
||||||
|
cm.webrtcPortAllocator.DeallocateAll(ctx, clusterID)
|
||||||
|
|
||||||
|
// Remove config row
|
||||||
|
cm.db.Exec(internalCtx, `DELETE FROM namespace_webrtc_config WHERE namespace_cluster_id = ?`, clusterID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateClusterStateWithWebRTC updates the cluster-state.json on all nodes
|
||||||
|
// to include (or remove) WebRTC port information.
|
||||||
|
// Pass nil maps to clear WebRTC state (when disabling).
|
||||||
|
func (cm *ClusterManager) updateClusterStateWithWebRTC(
|
||||||
|
ctx context.Context,
|
||||||
|
cluster *NamespaceCluster,
|
||||||
|
nodes []clusterNodeInfo,
|
||||||
|
sfuBlocks map[string]*WebRTCPortBlock,
|
||||||
|
turnBlocks map[string]*WebRTCPortBlock,
|
||||||
|
) {
|
||||||
|
// Get existing port blocks for base state
|
||||||
|
portBlocks, err := cm.portAllocator.GetAllPortBlocks(ctx, cluster.ID)
|
||||||
|
if err != nil {
|
||||||
|
cm.logger.Warn("Failed to get port blocks for state update", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build nodeID -> PortBlock map
|
||||||
|
nodePortMap := make(map[string]*PortBlock)
|
||||||
|
for i := range portBlocks {
|
||||||
|
nodePortMap[portBlocks[i].NodeID] = &portBlocks[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build AllNodes list
|
||||||
|
var allStateNodes []ClusterLocalStateNode
|
||||||
|
for _, node := range nodes {
|
||||||
|
pb := nodePortMap[node.NodeID]
|
||||||
|
if pb == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
allStateNodes = append(allStateNodes, ClusterLocalStateNode{
|
||||||
|
NodeID: node.NodeID,
|
||||||
|
InternalIP: node.InternalIP,
|
||||||
|
RQLiteHTTPPort: pb.RQLiteHTTPPort,
|
||||||
|
RQLiteRaftPort: pb.RQLiteRaftPort,
|
||||||
|
OlricHTTPPort: pb.OlricHTTPPort,
|
||||||
|
OlricMemberlistPort: pb.OlricMemberlistPort,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save state on each node
|
||||||
|
for _, node := range nodes {
|
||||||
|
pb := nodePortMap[node.NodeID]
|
||||||
|
if pb == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
state := &ClusterLocalState{
|
||||||
|
ClusterID: cluster.ID,
|
||||||
|
NamespaceName: cluster.NamespaceName,
|
||||||
|
LocalNodeID: node.NodeID,
|
||||||
|
LocalIP: node.InternalIP,
|
||||||
|
LocalPorts: ClusterLocalStatePorts{
|
||||||
|
RQLiteHTTPPort: pb.RQLiteHTTPPort,
|
||||||
|
RQLiteRaftPort: pb.RQLiteRaftPort,
|
||||||
|
OlricHTTPPort: pb.OlricHTTPPort,
|
||||||
|
OlricMemberlistPort: pb.OlricMemberlistPort,
|
||||||
|
GatewayHTTPPort: pb.GatewayHTTPPort,
|
||||||
|
},
|
||||||
|
AllNodes: allStateNodes,
|
||||||
|
HasGateway: true,
|
||||||
|
BaseDomain: cm.baseDomain,
|
||||||
|
SavedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add WebRTC fields if enabling
|
||||||
|
if sfuBlocks != nil {
|
||||||
|
if sfuBlock, ok := sfuBlocks[node.NodeID]; ok {
|
||||||
|
state.HasSFU = true
|
||||||
|
state.SFUSignalingPort = sfuBlock.SFUSignalingPort
|
||||||
|
state.SFUMediaPortStart = sfuBlock.SFUMediaPortStart
|
||||||
|
state.SFUMediaPortEnd = sfuBlock.SFUMediaPortEnd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if turnBlocks != nil {
|
||||||
|
if turnBlock, ok := turnBlocks[node.NodeID]; ok {
|
||||||
|
state.HasTURN = true
|
||||||
|
state.TURNListenPort = turnBlock.TURNListenPort
|
||||||
|
state.TURNTLSPort = turnBlock.TURNTLSPort
|
||||||
|
state.TURNRelayPortStart = turnBlock.TURNRelayPortStart
|
||||||
|
state.TURNRelayPortEnd = turnBlock.TURNRelayPortEnd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if node.NodeID == cm.localNodeID {
|
||||||
|
if err := cm.saveLocalState(state); err != nil {
|
||||||
|
cm.logger.Warn("Failed to save local cluster state",
|
||||||
|
zap.String("namespace", cluster.NamespaceName),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cm.saveRemoteState(ctx, node.InternalIP, cluster.NamespaceName, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveRemoteState sends cluster state to a remote node for persistence.
|
||||||
|
func (cm *ClusterManager) saveRemoteState(ctx context.Context, nodeIP, namespace string, state *ClusterLocalState) {
|
||||||
|
_, err := cm.sendSpawnRequest(ctx, nodeIP, map[string]interface{}{
|
||||||
|
"action": "save-cluster-state",
|
||||||
|
"namespace": namespace,
|
||||||
|
"cluster_state": state,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
cm.logger.Warn("Failed to save cluster state on remote node",
|
||||||
|
zap.String("node_ip", nodeIP),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -300,6 +300,78 @@ func (drm *DNSRecordManager) DisableNamespaceRecord(ctx context.Context, namespa
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateTURNRecords creates DNS A records for TURN servers.
|
||||||
|
// TURN records follow the pattern: turn.ns-{namespace}.{baseDomain} -> TURN node IPs
|
||||||
|
func (drm *DNSRecordManager) CreateTURNRecords(ctx context.Context, namespaceName string, turnIPs []string) error {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
if len(turnIPs) == 0 {
|
||||||
|
return &ClusterError{Message: "no TURN IPs provided for DNS records"}
|
||||||
|
}
|
||||||
|
|
||||||
|
fqdn := fmt.Sprintf("turn.ns-%s.%s.", namespaceName, drm.baseDomain)
|
||||||
|
|
||||||
|
drm.logger.Info("Creating TURN DNS records",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.String("fqdn", fqdn),
|
||||||
|
zap.Strings("turn_ips", turnIPs),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Delete existing TURN records for this namespace
|
||||||
|
deleteQuery := `DELETE FROM dns_records WHERE fqdn = ? AND namespace = ?`
|
||||||
|
_, _ = drm.db.Exec(internalCtx, deleteQuery, fqdn, "namespace-turn:"+namespaceName)
|
||||||
|
|
||||||
|
// Create A records for each TURN node IP
|
||||||
|
now := time.Now()
|
||||||
|
for _, ip := range turnIPs {
|
||||||
|
recordID := uuid.New().String()
|
||||||
|
insertQuery := `
|
||||||
|
INSERT INTO dns_records (
|
||||||
|
id, fqdn, record_type, value, ttl, namespace, created_by, is_active, created_at, updated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
_, err := drm.db.Exec(internalCtx, insertQuery,
|
||||||
|
recordID, fqdn, "A", ip, 60,
|
||||||
|
"namespace-turn:"+namespaceName,
|
||||||
|
"cluster-manager",
|
||||||
|
true, now, now,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return &ClusterError{
|
||||||
|
Message: fmt.Sprintf("failed to create TURN DNS record %s -> %s", fqdn, ip),
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
drm.logger.Info("TURN DNS records created",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
zap.Int("record_count", len(turnIPs)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteTURNRecords deletes all TURN DNS records for a namespace.
|
||||||
|
func (drm *DNSRecordManager) DeleteTURNRecords(ctx context.Context, namespaceName string) error {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
drm.logger.Info("Deleting TURN DNS records",
|
||||||
|
zap.String("namespace", namespaceName),
|
||||||
|
)
|
||||||
|
|
||||||
|
deleteQuery := `DELETE FROM dns_records WHERE namespace = ?`
|
||||||
|
_, err := drm.db.Exec(internalCtx, deleteQuery, "namespace-turn:"+namespaceName)
|
||||||
|
if err != nil {
|
||||||
|
return &ClusterError{
|
||||||
|
Message: "failed to delete TURN DNS records",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// EnableNamespaceRecord marks a specific IP's record as active (for recovery)
|
// EnableNamespaceRecord marks a specific IP's record as active (for recovery)
|
||||||
func (drm *DNSRecordManager) EnableNamespaceRecord(ctx context.Context, namespaceName, ip string) error {
|
func (drm *DNSRecordManager) EnableNamespaceRecord(ctx context.Context, namespaceName, ip string) error {
|
||||||
internalCtx := client.WithInternalAuth(ctx)
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|||||||
@ -10,7 +10,9 @@ import (
|
|||||||
"github.com/DeBrosOfficial/network/pkg/gateway"
|
"github.com/DeBrosOfficial/network/pkg/gateway"
|
||||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/sfu"
|
||||||
"github.com/DeBrosOfficial/network/pkg/systemd"
|
"github.com/DeBrosOfficial/network/pkg/systemd"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
@ -289,6 +291,185 @@ func (s *SystemdSpawner) StopGateway(ctx context.Context, namespace, nodeID stri
|
|||||||
return s.systemdMgr.StopService(namespace, systemd.ServiceTypeGateway)
|
return s.systemdMgr.StopService(namespace, systemd.ServiceTypeGateway)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SFUInstanceConfig holds configuration for spawning an SFU instance
|
||||||
|
type SFUInstanceConfig struct {
|
||||||
|
Namespace string
|
||||||
|
NodeID string
|
||||||
|
ListenAddr string // WireGuard IP:port (e.g., "10.0.0.1:30000")
|
||||||
|
MediaPortStart int // Start of RTP media port range
|
||||||
|
MediaPortEnd int // End of RTP media port range
|
||||||
|
TURNServers []sfu.TURNServerConfig // TURN servers to advertise to peers
|
||||||
|
TURNSecret string // HMAC-SHA1 shared secret
|
||||||
|
TURNCredTTL int // Credential TTL in seconds
|
||||||
|
RQLiteDSN string // Namespace-local RQLite DSN
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpawnSFU starts an SFU instance using systemd
|
||||||
|
func (s *SystemdSpawner) SpawnSFU(ctx context.Context, namespace, nodeID string, cfg SFUInstanceConfig) error {
|
||||||
|
s.logger.Info("Spawning SFU via systemd",
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.String("listen_addr", cfg.ListenAddr))
|
||||||
|
|
||||||
|
// Create config directory
|
||||||
|
configDir := filepath.Join(s.namespaceBase, namespace, "configs")
|
||||||
|
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create config directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(configDir, fmt.Sprintf("sfu-%s.yaml", nodeID))
|
||||||
|
|
||||||
|
// Build SFU YAML config
|
||||||
|
sfuConfig := sfu.Config{
|
||||||
|
ListenAddr: cfg.ListenAddr,
|
||||||
|
Namespace: cfg.Namespace,
|
||||||
|
MediaPortStart: cfg.MediaPortStart,
|
||||||
|
MediaPortEnd: cfg.MediaPortEnd,
|
||||||
|
TURNServers: cfg.TURNServers,
|
||||||
|
TURNSecret: cfg.TURNSecret,
|
||||||
|
TURNCredentialTTL: cfg.TURNCredTTL,
|
||||||
|
RQLiteDSN: cfg.RQLiteDSN,
|
||||||
|
}
|
||||||
|
|
||||||
|
configBytes, err := yaml.Marshal(sfuConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal SFU config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(configPath, configBytes, 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write SFU config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("Created SFU config file",
|
||||||
|
zap.String("path", configPath),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("node_id", nodeID))
|
||||||
|
|
||||||
|
// Generate environment file pointing to config
|
||||||
|
envVars := map[string]string{
|
||||||
|
"SFU_CONFIG": configPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.systemdMgr.GenerateEnvFile(namespace, nodeID, systemd.ServiceTypeSFU, envVars); err != nil {
|
||||||
|
return fmt.Errorf("failed to generate SFU env file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the systemd service
|
||||||
|
if err := s.systemdMgr.StartService(namespace, systemd.ServiceTypeSFU); err != nil {
|
||||||
|
return fmt.Errorf("failed to start SFU service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for service to be active
|
||||||
|
if err := s.waitForService(namespace, systemd.ServiceTypeSFU, 30*time.Second); err != nil {
|
||||||
|
return fmt.Errorf("SFU service did not become active: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("SFU spawned successfully via systemd",
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("node_id", nodeID))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopSFU stops an SFU instance
|
||||||
|
func (s *SystemdSpawner) StopSFU(ctx context.Context, namespace, nodeID string) error {
|
||||||
|
s.logger.Info("Stopping SFU via systemd",
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("node_id", nodeID))
|
||||||
|
|
||||||
|
return s.systemdMgr.StopService(namespace, systemd.ServiceTypeSFU)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TURNInstanceConfig holds configuration for spawning a TURN instance
|
||||||
|
type TURNInstanceConfig struct {
|
||||||
|
Namespace string
|
||||||
|
NodeID string
|
||||||
|
ListenAddr string // e.g., "0.0.0.0:3478"
|
||||||
|
TLSListenAddr string // e.g., "0.0.0.0:443" (UDP, no conflict with Caddy TCP)
|
||||||
|
PublicIP string // Public IP for TURN relay allocations
|
||||||
|
Realm string // TURN realm (typically base domain)
|
||||||
|
AuthSecret string // HMAC-SHA1 shared secret
|
||||||
|
RelayPortStart int // Start of relay port range
|
||||||
|
RelayPortEnd int // End of relay port range
|
||||||
|
}
|
||||||
|
|
||||||
|
// SpawnTURN starts a TURN instance using systemd
|
||||||
|
func (s *SystemdSpawner) SpawnTURN(ctx context.Context, namespace, nodeID string, cfg TURNInstanceConfig) error {
|
||||||
|
s.logger.Info("Spawning TURN via systemd",
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.String("listen_addr", cfg.ListenAddr),
|
||||||
|
zap.String("public_ip", cfg.PublicIP))
|
||||||
|
|
||||||
|
// Create config directory
|
||||||
|
configDir := filepath.Join(s.namespaceBase, namespace, "configs")
|
||||||
|
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create config directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configPath := filepath.Join(configDir, fmt.Sprintf("turn-%s.yaml", nodeID))
|
||||||
|
|
||||||
|
// Build TURN YAML config
|
||||||
|
turnConfig := turn.Config{
|
||||||
|
ListenAddr: cfg.ListenAddr,
|
||||||
|
TLSListenAddr: cfg.TLSListenAddr,
|
||||||
|
PublicIP: cfg.PublicIP,
|
||||||
|
Realm: cfg.Realm,
|
||||||
|
AuthSecret: cfg.AuthSecret,
|
||||||
|
RelayPortStart: cfg.RelayPortStart,
|
||||||
|
RelayPortEnd: cfg.RelayPortEnd,
|
||||||
|
Namespace: cfg.Namespace,
|
||||||
|
}
|
||||||
|
|
||||||
|
configBytes, err := yaml.Marshal(turnConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal TURN config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(configPath, configBytes, 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write TURN config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("Created TURN config file",
|
||||||
|
zap.String("path", configPath),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("node_id", nodeID))
|
||||||
|
|
||||||
|
// Generate environment file pointing to config
|
||||||
|
envVars := map[string]string{
|
||||||
|
"TURN_CONFIG": configPath,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.systemdMgr.GenerateEnvFile(namespace, nodeID, systemd.ServiceTypeTURN, envVars); err != nil {
|
||||||
|
return fmt.Errorf("failed to generate TURN env file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the systemd service
|
||||||
|
if err := s.systemdMgr.StartService(namespace, systemd.ServiceTypeTURN); err != nil {
|
||||||
|
return fmt.Errorf("failed to start TURN service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for service to be active
|
||||||
|
if err := s.waitForService(namespace, systemd.ServiceTypeTURN, 30*time.Second); err != nil {
|
||||||
|
return fmt.Errorf("TURN service did not become active: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("TURN spawned successfully via systemd",
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("node_id", nodeID))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StopTURN stops a TURN instance
|
||||||
|
func (s *SystemdSpawner) StopTURN(ctx context.Context, namespace, nodeID string) error {
|
||||||
|
s.logger.Info("Stopping TURN via systemd",
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("node_id", nodeID))
|
||||||
|
|
||||||
|
return s.systemdMgr.StopService(namespace, systemd.ServiceTypeTURN)
|
||||||
|
}
|
||||||
|
|
||||||
// SaveClusterState writes cluster state JSON to the namespace data directory.
|
// SaveClusterState writes cluster state JSON to the namespace data directory.
|
||||||
// Used by the spawn handler to persist state received from the coordinator node.
|
// Used by the spawn handler to persist state received from the coordinator node.
|
||||||
func (s *SystemdSpawner) SaveClusterState(namespace string, data []byte) error {
|
func (s *SystemdSpawner) SaveClusterState(namespace string, data []byte) error {
|
||||||
|
|||||||
@ -24,6 +24,8 @@ const (
|
|||||||
NodeRoleRQLiteFollower NodeRole = "rqlite_follower"
|
NodeRoleRQLiteFollower NodeRole = "rqlite_follower"
|
||||||
NodeRoleOlric NodeRole = "olric"
|
NodeRoleOlric NodeRole = "olric"
|
||||||
NodeRoleGateway NodeRole = "gateway"
|
NodeRoleGateway NodeRole = "gateway"
|
||||||
|
NodeRoleSFU NodeRole = "sfu"
|
||||||
|
NodeRoleTURN NodeRole = "turn"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NodeStatus represents the status of a service on a node
|
// NodeStatus represents the status of a service on a node
|
||||||
@ -62,6 +64,12 @@ const (
|
|||||||
EventNodeReplaced EventType = "node_replaced"
|
EventNodeReplaced EventType = "node_replaced"
|
||||||
EventRecoveryComplete EventType = "recovery_complete"
|
EventRecoveryComplete EventType = "recovery_complete"
|
||||||
EventRecoveryFailed EventType = "recovery_failed"
|
EventRecoveryFailed EventType = "recovery_failed"
|
||||||
|
EventWebRTCEnabled EventType = "webrtc_enabled"
|
||||||
|
EventWebRTCDisabled EventType = "webrtc_disabled"
|
||||||
|
EventSFUStarted EventType = "sfu_started"
|
||||||
|
EventSFUStopped EventType = "sfu_stopped"
|
||||||
|
EventTURNStarted EventType = "turn_started"
|
||||||
|
EventTURNStopped EventType = "turn_stopped"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Port allocation constants
|
// Port allocation constants
|
||||||
@ -80,6 +88,39 @@ const (
|
|||||||
MaxNamespacesPerNode = (NamespacePortRangeEnd - NamespacePortRangeStart + 1) / PortsPerNamespace // 20
|
MaxNamespacesPerNode = (NamespacePortRangeEnd - NamespacePortRangeStart + 1) / PortsPerNamespace // 20
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// WebRTC port allocation constants
|
||||||
|
// These are separate from the core namespace port range (10000-10099)
|
||||||
|
// to avoid breaking existing port blocks.
|
||||||
|
const (
|
||||||
|
// SFU media port range: 20000-29999
|
||||||
|
// Each namespace gets a 500-port sub-range for RTP media
|
||||||
|
SFUMediaPortRangeStart = 20000
|
||||||
|
SFUMediaPortRangeEnd = 29999
|
||||||
|
SFUMediaPortsPerNamespace = 500
|
||||||
|
|
||||||
|
// SFU signaling ports: 30000-30099
|
||||||
|
// Each namespace gets 1 signaling port per node
|
||||||
|
SFUSignalingPortRangeStart = 30000
|
||||||
|
SFUSignalingPortRangeEnd = 30099
|
||||||
|
|
||||||
|
// TURN relay port range: 49152-65535
|
||||||
|
// Each namespace gets an 800-port sub-range for TURN relay
|
||||||
|
TURNRelayPortRangeStart = 49152
|
||||||
|
TURNRelayPortRangeEnd = 65535
|
||||||
|
TURNRelayPortsPerNamespace = 800
|
||||||
|
|
||||||
|
// TURN listen ports (standard)
|
||||||
|
TURNDefaultPort = 3478
|
||||||
|
TURNTLSPort = 443
|
||||||
|
|
||||||
|
// Default TURN credential TTL in seconds (10 minutes)
|
||||||
|
DefaultTURNCredentialTTL = 600
|
||||||
|
|
||||||
|
// Default service counts per namespace
|
||||||
|
DefaultSFUNodeCount = 3 // SFU on all 3 nodes
|
||||||
|
DefaultTURNNodeCount = 2 // TURN on 2 of 3 nodes for HA
|
||||||
|
)
|
||||||
|
|
||||||
// Default cluster sizes
|
// Default cluster sizes
|
||||||
const (
|
const (
|
||||||
DefaultRQLiteNodeCount = 3
|
DefaultRQLiteNodeCount = 3
|
||||||
@ -206,4 +247,58 @@ var (
|
|||||||
ErrNamespaceNotFound = &ClusterError{Message: "namespace not found"}
|
ErrNamespaceNotFound = &ClusterError{Message: "namespace not found"}
|
||||||
ErrInvalidClusterStatus = &ClusterError{Message: "invalid cluster status for operation"}
|
ErrInvalidClusterStatus = &ClusterError{Message: "invalid cluster status for operation"}
|
||||||
ErrRecoveryInProgress = &ClusterError{Message: "recovery already in progress for this cluster"}
|
ErrRecoveryInProgress = &ClusterError{Message: "recovery already in progress for this cluster"}
|
||||||
|
ErrWebRTCAlreadyEnabled = &ClusterError{Message: "WebRTC is already enabled for this namespace"}
|
||||||
|
ErrWebRTCNotEnabled = &ClusterError{Message: "WebRTC is not enabled for this namespace"}
|
||||||
|
ErrNoWebRTCPortsAvailable = &ClusterError{Message: "no WebRTC ports available on node"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// WebRTCConfig represents the per-namespace WebRTC configuration stored in the database
|
||||||
|
type WebRTCConfig struct {
|
||||||
|
ID string `json:"id" db:"id"`
|
||||||
|
NamespaceClusterID string `json:"namespace_cluster_id" db:"namespace_cluster_id"`
|
||||||
|
NamespaceName string `json:"namespace_name" db:"namespace_name"`
|
||||||
|
Enabled bool `json:"enabled" db:"enabled"`
|
||||||
|
TURNSharedSecret string `json:"-" db:"turn_shared_secret"` // Never serialize secret to JSON
|
||||||
|
TURNCredentialTTL int `json:"turn_credential_ttl" db:"turn_credential_ttl"`
|
||||||
|
SFUNodeCount int `json:"sfu_node_count" db:"sfu_node_count"`
|
||||||
|
TURNNodeCount int `json:"turn_node_count" db:"turn_node_count"`
|
||||||
|
EnabledBy string `json:"enabled_by" db:"enabled_by"`
|
||||||
|
EnabledAt time.Time `json:"enabled_at" db:"enabled_at"`
|
||||||
|
DisabledAt *time.Time `json:"disabled_at,omitempty" db:"disabled_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebRTCRoom represents an active WebRTC room tracked in the database
|
||||||
|
type WebRTCRoom struct {
|
||||||
|
ID string `json:"id" db:"id"`
|
||||||
|
NamespaceClusterID string `json:"namespace_cluster_id" db:"namespace_cluster_id"`
|
||||||
|
NamespaceName string `json:"namespace_name" db:"namespace_name"`
|
||||||
|
RoomID string `json:"room_id" db:"room_id"`
|
||||||
|
SFUNodeID string `json:"sfu_node_id" db:"sfu_node_id"`
|
||||||
|
SFUInternalIP string `json:"sfu_internal_ip" db:"sfu_internal_ip"`
|
||||||
|
SFUSignalingPort int `json:"sfu_signaling_port" db:"sfu_signaling_port"`
|
||||||
|
ParticipantCount int `json:"participant_count" db:"participant_count"`
|
||||||
|
MaxParticipants int `json:"max_participants" db:"max_participants"`
|
||||||
|
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
||||||
|
LastActivity time.Time `json:"last_activity" db:"last_activity"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WebRTCPortBlock represents allocated WebRTC ports for a namespace on a node
|
||||||
|
type WebRTCPortBlock struct {
|
||||||
|
ID string `json:"id" db:"id"`
|
||||||
|
NodeID string `json:"node_id" db:"node_id"`
|
||||||
|
NamespaceClusterID string `json:"namespace_cluster_id" db:"namespace_cluster_id"`
|
||||||
|
ServiceType string `json:"service_type" db:"service_type"` // "sfu" or "turn"
|
||||||
|
|
||||||
|
// SFU ports
|
||||||
|
SFUSignalingPort int `json:"sfu_signaling_port,omitempty" db:"sfu_signaling_port"`
|
||||||
|
SFUMediaPortStart int `json:"sfu_media_port_start,omitempty" db:"sfu_media_port_start"`
|
||||||
|
SFUMediaPortEnd int `json:"sfu_media_port_end,omitempty" db:"sfu_media_port_end"`
|
||||||
|
|
||||||
|
// TURN ports
|
||||||
|
TURNListenPort int `json:"turn_listen_port,omitempty" db:"turn_listen_port"`
|
||||||
|
TURNTLSPort int `json:"turn_tls_port,omitempty" db:"turn_tls_port"`
|
||||||
|
TURNRelayPortStart int `json:"turn_relay_port_start,omitempty" db:"turn_relay_port_start"`
|
||||||
|
TURNRelayPortEnd int `json:"turn_relay_port_end,omitempty" db:"turn_relay_port_end"`
|
||||||
|
|
||||||
|
AllocatedAt time.Time `json:"allocated_at" db:"allocated_at"`
|
||||||
|
}
|
||||||
|
|||||||
519
pkg/namespace/webrtc_port_allocator.go
Normal file
519
pkg/namespace/webrtc_port_allocator.go
Normal file
@ -0,0 +1,519 @@
|
|||||||
|
package namespace
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WebRTCPortAllocator manages port allocation for SFU and TURN services.
|
||||||
|
// Uses the webrtc_port_allocations table, separate from namespace_port_allocations,
|
||||||
|
// to avoid breaking existing port blocks.
|
||||||
|
type WebRTCPortAllocator struct {
|
||||||
|
db rqlite.Client
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebRTCPortAllocator creates a new WebRTC port allocator
|
||||||
|
func NewWebRTCPortAllocator(db rqlite.Client, logger *zap.Logger) *WebRTCPortAllocator {
|
||||||
|
return &WebRTCPortAllocator{
|
||||||
|
db: db,
|
||||||
|
logger: logger.With(zap.String("component", "webrtc-port-allocator")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllocateSFUPorts allocates SFU ports for a namespace on a node.
|
||||||
|
// Each namespace gets: 1 signaling port (30000-30099) + 500 media ports (20000-29999).
|
||||||
|
// Returns the existing allocation if one already exists (idempotent).
|
||||||
|
func (wpa *WebRTCPortAllocator) AllocateSFUPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
// Check for existing allocation (idempotent)
|
||||||
|
existing, err := wpa.GetSFUPorts(ctx, namespaceClusterID, nodeID)
|
||||||
|
if err == nil && existing != nil {
|
||||||
|
wpa.logger.Debug("SFU ports already allocated",
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.String("namespace_cluster_id", namespaceClusterID),
|
||||||
|
zap.Int("signaling_port", existing.SFUSignalingPort),
|
||||||
|
)
|
||||||
|
return existing, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry logic for concurrent allocation conflicts
|
||||||
|
maxRetries := 10
|
||||||
|
retryDelay := 100 * time.Millisecond
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
block, err := wpa.tryAllocateSFUPorts(internalCtx, nodeID, namespaceClusterID)
|
||||||
|
if err == nil {
|
||||||
|
wpa.logger.Info("SFU ports allocated",
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.String("namespace_cluster_id", namespaceClusterID),
|
||||||
|
zap.Int("signaling_port", block.SFUSignalingPort),
|
||||||
|
zap.Int("media_start", block.SFUMediaPortStart),
|
||||||
|
zap.Int("media_end", block.SFUMediaPortEnd),
|
||||||
|
zap.Int("attempt", attempt+1),
|
||||||
|
)
|
||||||
|
return block, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isConflictError(err) {
|
||||||
|
wpa.logger.Debug("SFU port allocation conflict, retrying",
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.Int("attempt", attempt+1),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
time.Sleep(retryDelay)
|
||||||
|
retryDelay *= 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: fmt.Sprintf("failed to allocate SFU ports after %d retries", maxRetries),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryAllocateSFUPorts performs a single attempt to allocate SFU ports.
|
||||||
|
func (wpa *WebRTCPortAllocator) tryAllocateSFUPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) {
|
||||||
|
// Get node IPs sharing the same physical address (dev environment handling)
|
||||||
|
nodeIDs, err := wpa.getColocatedNodeIDs(ctx, nodeID)
|
||||||
|
if err != nil {
|
||||||
|
nodeIDs = []string{nodeID}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find next available SFU signaling port (30000-30099)
|
||||||
|
signalingPort, err := wpa.findAvailablePort(ctx, nodeIDs, "sfu", "sfu_signaling_port",
|
||||||
|
SFUSignalingPortRangeStart, SFUSignalingPortRangeEnd, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: "no SFU signaling port available on node",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find next available SFU media port block (20000-29999, 500 per namespace)
|
||||||
|
mediaStart, err := wpa.findAvailablePortBlock(ctx, nodeIDs, "sfu", "sfu_media_port_start",
|
||||||
|
SFUMediaPortRangeStart, SFUMediaPortRangeEnd, SFUMediaPortsPerNamespace)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: "no SFU media port range available on node",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
block := &WebRTCPortBlock{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
NodeID: nodeID,
|
||||||
|
NamespaceClusterID: namespaceClusterID,
|
||||||
|
ServiceType: "sfu",
|
||||||
|
SFUSignalingPort: signalingPort,
|
||||||
|
SFUMediaPortStart: mediaStart,
|
||||||
|
SFUMediaPortEnd: mediaStart + SFUMediaPortsPerNamespace - 1,
|
||||||
|
AllocatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := wpa.insertPortBlock(ctx, block); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return block, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllocateTURNPorts allocates TURN ports for a namespace on a node.
|
||||||
|
// Each namespace gets: standard listen ports (3478/443) + 800 relay ports (49152-65535).
|
||||||
|
// Returns the existing allocation if one already exists (idempotent).
|
||||||
|
func (wpa *WebRTCPortAllocator) AllocateTURNPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
// Check for existing allocation (idempotent)
|
||||||
|
existing, err := wpa.GetTURNPorts(ctx, namespaceClusterID, nodeID)
|
||||||
|
if err == nil && existing != nil {
|
||||||
|
wpa.logger.Debug("TURN ports already allocated",
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.String("namespace_cluster_id", namespaceClusterID),
|
||||||
|
)
|
||||||
|
return existing, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retry logic for concurrent allocation conflicts
|
||||||
|
maxRetries := 10
|
||||||
|
retryDelay := 100 * time.Millisecond
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
block, err := wpa.tryAllocateTURNPorts(internalCtx, nodeID, namespaceClusterID)
|
||||||
|
if err == nil {
|
||||||
|
wpa.logger.Info("TURN ports allocated",
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.String("namespace_cluster_id", namespaceClusterID),
|
||||||
|
zap.Int("relay_start", block.TURNRelayPortStart),
|
||||||
|
zap.Int("relay_end", block.TURNRelayPortEnd),
|
||||||
|
zap.Int("attempt", attempt+1),
|
||||||
|
)
|
||||||
|
return block, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if isConflictError(err) {
|
||||||
|
wpa.logger.Debug("TURN port allocation conflict, retrying",
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.Int("attempt", attempt+1),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
time.Sleep(retryDelay)
|
||||||
|
retryDelay *= 2
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: fmt.Sprintf("failed to allocate TURN ports after %d retries", maxRetries),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryAllocateTURNPorts performs a single attempt to allocate TURN ports.
|
||||||
|
func (wpa *WebRTCPortAllocator) tryAllocateTURNPorts(ctx context.Context, nodeID, namespaceClusterID string) (*WebRTCPortBlock, error) {
|
||||||
|
// Get colocated node IDs (dev environment handling)
|
||||||
|
nodeIDs, err := wpa.getColocatedNodeIDs(ctx, nodeID)
|
||||||
|
if err != nil {
|
||||||
|
nodeIDs = []string{nodeID}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find next available TURN relay port block (49152-65535, 800 per namespace)
|
||||||
|
relayStart, err := wpa.findAvailablePortBlock(ctx, nodeIDs, "turn", "turn_relay_port_start",
|
||||||
|
TURNRelayPortRangeStart, TURNRelayPortRangeEnd, TURNRelayPortsPerNamespace)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: "no TURN relay port range available on node",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
block := &WebRTCPortBlock{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
NodeID: nodeID,
|
||||||
|
NamespaceClusterID: namespaceClusterID,
|
||||||
|
ServiceType: "turn",
|
||||||
|
TURNListenPort: TURNDefaultPort,
|
||||||
|
TURNTLSPort: TURNTLSPort,
|
||||||
|
TURNRelayPortStart: relayStart,
|
||||||
|
TURNRelayPortEnd: relayStart + TURNRelayPortsPerNamespace - 1,
|
||||||
|
AllocatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := wpa.insertPortBlock(ctx, block); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return block, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeallocateAll releases all WebRTC port blocks for a namespace cluster.
|
||||||
|
func (wpa *WebRTCPortAllocator) DeallocateAll(ctx context.Context, namespaceClusterID string) error {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
query := `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ?`
|
||||||
|
_, err := wpa.db.Exec(internalCtx, query, namespaceClusterID)
|
||||||
|
if err != nil {
|
||||||
|
return &ClusterError{
|
||||||
|
Message: "failed to deallocate WebRTC port blocks",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wpa.logger.Info("All WebRTC port blocks deallocated",
|
||||||
|
zap.String("namespace_cluster_id", namespaceClusterID),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeallocateByNode releases WebRTC port blocks for a specific node and service type.
|
||||||
|
func (wpa *WebRTCPortAllocator) DeallocateByNode(ctx context.Context, namespaceClusterID, nodeID, serviceType string) error {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
query := `DELETE FROM webrtc_port_allocations WHERE namespace_cluster_id = ? AND node_id = ? AND service_type = ?`
|
||||||
|
_, err := wpa.db.Exec(internalCtx, query, namespaceClusterID, nodeID, serviceType)
|
||||||
|
if err != nil {
|
||||||
|
return &ClusterError{
|
||||||
|
Message: fmt.Sprintf("failed to deallocate %s port block on node %s", serviceType, nodeID),
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wpa.logger.Info("WebRTC port block deallocated",
|
||||||
|
zap.String("namespace_cluster_id", namespaceClusterID),
|
||||||
|
zap.String("node_id", nodeID),
|
||||||
|
zap.String("service_type", serviceType),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSFUPorts retrieves the SFU port allocation for a namespace on a node.
|
||||||
|
func (wpa *WebRTCPortAllocator) GetSFUPorts(ctx context.Context, namespaceClusterID, nodeID string) (*WebRTCPortBlock, error) {
|
||||||
|
return wpa.getPortBlock(ctx, namespaceClusterID, nodeID, "sfu")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTURNPorts retrieves the TURN port allocation for a namespace on a node.
|
||||||
|
func (wpa *WebRTCPortAllocator) GetTURNPorts(ctx context.Context, namespaceClusterID, nodeID string) (*WebRTCPortBlock, error) {
|
||||||
|
return wpa.getPortBlock(ctx, namespaceClusterID, nodeID, "turn")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllPorts retrieves all WebRTC port blocks for a namespace cluster.
|
||||||
|
func (wpa *WebRTCPortAllocator) GetAllPorts(ctx context.Context, namespaceClusterID string) ([]WebRTCPortBlock, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
var blocks []WebRTCPortBlock
|
||||||
|
query := `
|
||||||
|
SELECT id, node_id, namespace_cluster_id, service_type,
|
||||||
|
sfu_signaling_port, sfu_media_port_start, sfu_media_port_end,
|
||||||
|
turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end,
|
||||||
|
allocated_at
|
||||||
|
FROM webrtc_port_allocations
|
||||||
|
WHERE namespace_cluster_id = ?
|
||||||
|
ORDER BY service_type, node_id
|
||||||
|
`
|
||||||
|
err := wpa.db.Query(internalCtx, &blocks, query, namespaceClusterID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: "failed to query WebRTC port blocks",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return blocks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NodeHasTURN checks if a node already has a TURN allocation from any namespace.
|
||||||
|
// Used during node selection to avoid port conflicts on standard TURN ports (3478/443).
|
||||||
|
func (wpa *WebRTCPortAllocator) NodeHasTURN(ctx context.Context, nodeID string) (bool, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
type countResult struct {
|
||||||
|
Count int `db:"count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var results []countResult
|
||||||
|
query := `SELECT COUNT(*) as count FROM webrtc_port_allocations WHERE node_id = ? AND service_type = 'turn'`
|
||||||
|
err := wpa.db.Query(internalCtx, &results, query, nodeID)
|
||||||
|
if err != nil {
|
||||||
|
return false, &ClusterError{
|
||||||
|
Message: "failed to check TURN allocation on node",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(results) == 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return results[0].Count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- internal helpers ---
|
||||||
|
|
||||||
|
// getPortBlock retrieves a specific port block by cluster, node, and service type.
|
||||||
|
func (wpa *WebRTCPortAllocator) getPortBlock(ctx context.Context, namespaceClusterID, nodeID, serviceType string) (*WebRTCPortBlock, error) {
|
||||||
|
internalCtx := client.WithInternalAuth(ctx)
|
||||||
|
|
||||||
|
var blocks []WebRTCPortBlock
|
||||||
|
query := `
|
||||||
|
SELECT id, node_id, namespace_cluster_id, service_type,
|
||||||
|
sfu_signaling_port, sfu_media_port_start, sfu_media_port_end,
|
||||||
|
turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end,
|
||||||
|
allocated_at
|
||||||
|
FROM webrtc_port_allocations
|
||||||
|
WHERE namespace_cluster_id = ? AND node_id = ? AND service_type = ?
|
||||||
|
LIMIT 1
|
||||||
|
`
|
||||||
|
err := wpa.db.Query(internalCtx, &blocks, query, namespaceClusterID, nodeID, serviceType)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: fmt.Sprintf("failed to query %s port block", serviceType),
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(blocks) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &blocks[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// insertPortBlock inserts a WebRTC port allocation record.
|
||||||
|
func (wpa *WebRTCPortAllocator) insertPortBlock(ctx context.Context, block *WebRTCPortBlock) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO webrtc_port_allocations (
|
||||||
|
id, node_id, namespace_cluster_id, service_type,
|
||||||
|
sfu_signaling_port, sfu_media_port_start, sfu_media_port_end,
|
||||||
|
turn_listen_port, turn_tls_port, turn_relay_port_start, turn_relay_port_end,
|
||||||
|
allocated_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
`
|
||||||
|
_, err := wpa.db.Exec(ctx, query,
|
||||||
|
block.ID,
|
||||||
|
block.NodeID,
|
||||||
|
block.NamespaceClusterID,
|
||||||
|
block.ServiceType,
|
||||||
|
block.SFUSignalingPort,
|
||||||
|
block.SFUMediaPortStart,
|
||||||
|
block.SFUMediaPortEnd,
|
||||||
|
block.TURNListenPort,
|
||||||
|
block.TURNTLSPort,
|
||||||
|
block.TURNRelayPortStart,
|
||||||
|
block.TURNRelayPortEnd,
|
||||||
|
block.AllocatedAt,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return &ClusterError{
|
||||||
|
Message: fmt.Sprintf("failed to insert %s port allocation", block.ServiceType),
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getColocatedNodeIDs returns all node IDs that share the same IP address as the given node.
|
||||||
|
// In dev environments, multiple logical nodes share one physical IP — port ranges must not overlap.
|
||||||
|
// In production (one node per IP), returns only the given nodeID.
|
||||||
|
func (wpa *WebRTCPortAllocator) getColocatedNodeIDs(ctx context.Context, nodeID string) ([]string, error) {
|
||||||
|
// Get this node's IP
|
||||||
|
type nodeInfo struct {
|
||||||
|
IPAddress string `db:"ip_address"`
|
||||||
|
}
|
||||||
|
var infos []nodeInfo
|
||||||
|
if err := wpa.db.Query(ctx, &infos, `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1`, nodeID); err != nil || len(infos) == 0 {
|
||||||
|
return []string{nodeID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ip := infos[0].IPAddress
|
||||||
|
if ip == "" {
|
||||||
|
return []string{nodeID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if multiple nodes share this IP
|
||||||
|
type nodeIDRow struct {
|
||||||
|
ID string `db:"id"`
|
||||||
|
}
|
||||||
|
var colocated []nodeIDRow
|
||||||
|
if err := wpa.db.Query(ctx, &colocated, `SELECT id FROM dns_nodes WHERE ip_address = ?`, ip); err != nil || len(colocated) <= 1 {
|
||||||
|
return []string{nodeID}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make([]string, len(colocated))
|
||||||
|
for i, n := range colocated {
|
||||||
|
ids[i] = n.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
wpa.logger.Debug("Multiple nodes share IP, allocating globally",
|
||||||
|
zap.String("ip_address", ip),
|
||||||
|
zap.Int("node_count", len(ids)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// findAvailablePort finds the next available single port in a range on the given nodes.
|
||||||
|
func (wpa *WebRTCPortAllocator) findAvailablePort(ctx context.Context, nodeIDs []string, serviceType, portColumn string, rangeStart, rangeEnd, step int) (int, error) {
|
||||||
|
allocated, err := wpa.getAllocatedValues(ctx, nodeIDs, serviceType, portColumn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
allocatedSet := make(map[int]bool, len(allocated))
|
||||||
|
for _, v := range allocated {
|
||||||
|
allocatedSet[v] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for port := rangeStart; port <= rangeEnd; port += step {
|
||||||
|
if !allocatedSet[port] {
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, ErrNoWebRTCPortsAvailable
|
||||||
|
}
|
||||||
|
|
||||||
|
// findAvailablePortBlock finds the next available contiguous port block in a range.
|
||||||
|
func (wpa *WebRTCPortAllocator) findAvailablePortBlock(ctx context.Context, nodeIDs []string, serviceType, portColumn string, rangeStart, rangeEnd, blockSize int) (int, error) {
|
||||||
|
allocated, err := wpa.getAllocatedValues(ctx, nodeIDs, serviceType, portColumn)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
allocatedSet := make(map[int]bool, len(allocated))
|
||||||
|
for _, v := range allocated {
|
||||||
|
allocatedSet[v] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for start := rangeStart; start+blockSize-1 <= rangeEnd; start += blockSize {
|
||||||
|
if !allocatedSet[start] {
|
||||||
|
return start, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, ErrNoWebRTCPortsAvailable
|
||||||
|
}
|
||||||
|
|
||||||
|
// getAllocatedValues queries the allocated port values for a given column across colocated nodes.
|
||||||
|
func (wpa *WebRTCPortAllocator) getAllocatedValues(ctx context.Context, nodeIDs []string, serviceType, portColumn string) ([]int, error) {
|
||||||
|
type portRow struct {
|
||||||
|
Port int `db:"port_val"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var rows []portRow
|
||||||
|
|
||||||
|
if len(nodeIDs) == 1 {
|
||||||
|
query := fmt.Sprintf(
|
||||||
|
`SELECT %s as port_val FROM webrtc_port_allocations WHERE node_id = ? AND service_type = ? AND %s > 0 ORDER BY %s ASC`,
|
||||||
|
portColumn, portColumn, portColumn,
|
||||||
|
)
|
||||||
|
if err := wpa.db.Query(ctx, &rows, query, nodeIDs[0], serviceType); err != nil {
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: "failed to query allocated WebRTC ports",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Multiple colocated nodes — query by joining with dns_nodes on IP
|
||||||
|
// Get the IP of the first node (they all share the same IP)
|
||||||
|
type nodeInfo struct {
|
||||||
|
IPAddress string `db:"ip_address"`
|
||||||
|
}
|
||||||
|
var infos []nodeInfo
|
||||||
|
if err := wpa.db.Query(ctx, &infos, `SELECT ip_address FROM dns_nodes WHERE id = ? LIMIT 1`, nodeIDs[0]); err != nil || len(infos) == 0 {
|
||||||
|
return nil, &ClusterError{Message: "failed to get node IP for colocated port query"}
|
||||||
|
}
|
||||||
|
|
||||||
|
query := fmt.Sprintf(
|
||||||
|
`SELECT wpa.%s as port_val FROM webrtc_port_allocations wpa
|
||||||
|
JOIN dns_nodes dn ON wpa.node_id = dn.id
|
||||||
|
WHERE dn.ip_address = ? AND wpa.service_type = ? AND wpa.%s > 0
|
||||||
|
ORDER BY wpa.%s ASC`,
|
||||||
|
portColumn, portColumn, portColumn,
|
||||||
|
)
|
||||||
|
if err := wpa.db.Query(ctx, &rows, query, infos[0].IPAddress, serviceType); err != nil {
|
||||||
|
return nil, &ClusterError{
|
||||||
|
Message: "failed to query allocated WebRTC ports (colocated)",
|
||||||
|
Cause: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]int, len(rows))
|
||||||
|
for i, r := range rows {
|
||||||
|
result[i] = r.Port
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
337
pkg/namespace/webrtc_port_allocator_test.go
Normal file
337
pkg/namespace/webrtc_port_allocator_test.go
Normal file
@ -0,0 +1,337 @@
|
|||||||
|
package namespace
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestWebRTCPortConstants_NoOverlap(t *testing.T) {
|
||||||
|
// Verify WebRTC port ranges don't overlap with core namespace ports (10000-10099)
|
||||||
|
ranges := []struct {
|
||||||
|
name string
|
||||||
|
start int
|
||||||
|
end int
|
||||||
|
}{
|
||||||
|
{"core namespace", NamespacePortRangeStart, NamespacePortRangeEnd},
|
||||||
|
{"SFU media", SFUMediaPortRangeStart, SFUMediaPortRangeEnd},
|
||||||
|
{"SFU signaling", SFUSignalingPortRangeStart, SFUSignalingPortRangeEnd},
|
||||||
|
{"TURN relay", TURNRelayPortRangeStart, TURNRelayPortRangeEnd},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(ranges); i++ {
|
||||||
|
for j := i + 1; j < len(ranges); j++ {
|
||||||
|
a, b := ranges[i], ranges[j]
|
||||||
|
if a.start <= b.end && b.start <= a.end {
|
||||||
|
t.Errorf("Range overlap: %s (%d-%d) overlaps with %s (%d-%d)",
|
||||||
|
a.name, a.start, a.end, b.name, b.start, b.end)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortConstants_Capacity(t *testing.T) {
|
||||||
|
// SFU media: (29999-20000+1)/500 = 20 namespaces per node
|
||||||
|
sfuMediaCapacity := (SFUMediaPortRangeEnd - SFUMediaPortRangeStart + 1) / SFUMediaPortsPerNamespace
|
||||||
|
if sfuMediaCapacity < 20 {
|
||||||
|
t.Errorf("SFU media capacity = %d, want >= 20", sfuMediaCapacity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SFU signaling: 30099-30000+1 = 100 ports → 100 namespaces per node
|
||||||
|
sfuSignalingCapacity := SFUSignalingPortRangeEnd - SFUSignalingPortRangeStart + 1
|
||||||
|
if sfuSignalingCapacity < 20 {
|
||||||
|
t.Errorf("SFU signaling capacity = %d, want >= 20", sfuSignalingCapacity)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TURN relay: (65535-49152+1)/800 = 20 namespaces per node
|
||||||
|
turnRelayCapacity := (TURNRelayPortRangeEnd - TURNRelayPortRangeStart + 1) / TURNRelayPortsPerNamespace
|
||||||
|
if turnRelayCapacity < 20 {
|
||||||
|
t.Errorf("TURN relay capacity = %d, want >= 20", turnRelayCapacity)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortConstants_Values(t *testing.T) {
|
||||||
|
if SFUMediaPortRangeStart != 20000 {
|
||||||
|
t.Errorf("SFUMediaPortRangeStart = %d, want 20000", SFUMediaPortRangeStart)
|
||||||
|
}
|
||||||
|
if SFUMediaPortRangeEnd != 29999 {
|
||||||
|
t.Errorf("SFUMediaPortRangeEnd = %d, want 29999", SFUMediaPortRangeEnd)
|
||||||
|
}
|
||||||
|
if SFUMediaPortsPerNamespace != 500 {
|
||||||
|
t.Errorf("SFUMediaPortsPerNamespace = %d, want 500", SFUMediaPortsPerNamespace)
|
||||||
|
}
|
||||||
|
if SFUSignalingPortRangeStart != 30000 {
|
||||||
|
t.Errorf("SFUSignalingPortRangeStart = %d, want 30000", SFUSignalingPortRangeStart)
|
||||||
|
}
|
||||||
|
if TURNRelayPortRangeStart != 49152 {
|
||||||
|
t.Errorf("TURNRelayPortRangeStart = %d, want 49152", TURNRelayPortRangeStart)
|
||||||
|
}
|
||||||
|
if TURNRelayPortsPerNamespace != 800 {
|
||||||
|
t.Errorf("TURNRelayPortsPerNamespace = %d, want 800", TURNRelayPortsPerNamespace)
|
||||||
|
}
|
||||||
|
if TURNDefaultPort != 3478 {
|
||||||
|
t.Errorf("TURNDefaultPort = %d, want 3478", TURNDefaultPort)
|
||||||
|
}
|
||||||
|
if DefaultSFUNodeCount != 3 {
|
||||||
|
t.Errorf("DefaultSFUNodeCount = %d, want 3", DefaultSFUNodeCount)
|
||||||
|
}
|
||||||
|
if DefaultTURNNodeCount != 2 {
|
||||||
|
t.Errorf("DefaultTURNNodeCount = %d, want 2", DefaultTURNNodeCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewWebRTCPortAllocator(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
if allocator == nil {
|
||||||
|
t.Fatal("NewWebRTCPortAllocator returned nil")
|
||||||
|
}
|
||||||
|
if allocator.db != mockDB {
|
||||||
|
t.Error("allocator.db not set correctly")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortAllocator_AllocateSFUPorts(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
block, err := allocator.AllocateSFUPorts(context.Background(), "node-1", "cluster-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AllocateSFUPorts failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if block == nil {
|
||||||
|
t.Fatal("AllocateSFUPorts returned nil block")
|
||||||
|
}
|
||||||
|
|
||||||
|
if block.ServiceType != "sfu" {
|
||||||
|
t.Errorf("ServiceType = %q, want %q", block.ServiceType, "sfu")
|
||||||
|
}
|
||||||
|
if block.NodeID != "node-1" {
|
||||||
|
t.Errorf("NodeID = %q, want %q", block.NodeID, "node-1")
|
||||||
|
}
|
||||||
|
if block.NamespaceClusterID != "cluster-1" {
|
||||||
|
t.Errorf("NamespaceClusterID = %q, want %q", block.NamespaceClusterID, "cluster-1")
|
||||||
|
}
|
||||||
|
|
||||||
|
// First allocation should get the first port in each range
|
||||||
|
if block.SFUSignalingPort != SFUSignalingPortRangeStart {
|
||||||
|
t.Errorf("SFUSignalingPort = %d, want %d", block.SFUSignalingPort, SFUSignalingPortRangeStart)
|
||||||
|
}
|
||||||
|
if block.SFUMediaPortStart != SFUMediaPortRangeStart {
|
||||||
|
t.Errorf("SFUMediaPortStart = %d, want %d", block.SFUMediaPortStart, SFUMediaPortRangeStart)
|
||||||
|
}
|
||||||
|
if block.SFUMediaPortEnd != SFUMediaPortRangeStart+SFUMediaPortsPerNamespace-1 {
|
||||||
|
t.Errorf("SFUMediaPortEnd = %d, want %d", block.SFUMediaPortEnd, SFUMediaPortRangeStart+SFUMediaPortsPerNamespace-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TURN fields should be zero for SFU allocation
|
||||||
|
if block.TURNListenPort != 0 {
|
||||||
|
t.Errorf("TURNListenPort = %d, want 0 for SFU allocation", block.TURNListenPort)
|
||||||
|
}
|
||||||
|
if block.TURNRelayPortStart != 0 {
|
||||||
|
t.Errorf("TURNRelayPortStart = %d, want 0 for SFU allocation", block.TURNRelayPortStart)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify INSERT was called
|
||||||
|
hasInsert := false
|
||||||
|
for _, call := range mockDB.execCalls {
|
||||||
|
if strings.Contains(call.Query, "INSERT INTO webrtc_port_allocations") {
|
||||||
|
hasInsert = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasInsert {
|
||||||
|
t.Error("expected INSERT INTO webrtc_port_allocations to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortAllocator_AllocateTURNPorts(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
block, err := allocator.AllocateTURNPorts(context.Background(), "node-1", "cluster-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AllocateTURNPorts failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if block == nil {
|
||||||
|
t.Fatal("AllocateTURNPorts returned nil block")
|
||||||
|
}
|
||||||
|
|
||||||
|
if block.ServiceType != "turn" {
|
||||||
|
t.Errorf("ServiceType = %q, want %q", block.ServiceType, "turn")
|
||||||
|
}
|
||||||
|
if block.TURNListenPort != TURNDefaultPort {
|
||||||
|
t.Errorf("TURNListenPort = %d, want %d", block.TURNListenPort, TURNDefaultPort)
|
||||||
|
}
|
||||||
|
if block.TURNTLSPort != TURNTLSPort {
|
||||||
|
t.Errorf("TURNTLSPort = %d, want %d", block.TURNTLSPort, TURNTLSPort)
|
||||||
|
}
|
||||||
|
if block.TURNRelayPortStart != TURNRelayPortRangeStart {
|
||||||
|
t.Errorf("TURNRelayPortStart = %d, want %d", block.TURNRelayPortStart, TURNRelayPortRangeStart)
|
||||||
|
}
|
||||||
|
if block.TURNRelayPortEnd != TURNRelayPortRangeStart+TURNRelayPortsPerNamespace-1 {
|
||||||
|
t.Errorf("TURNRelayPortEnd = %d, want %d", block.TURNRelayPortEnd, TURNRelayPortRangeStart+TURNRelayPortsPerNamespace-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SFU fields should be zero for TURN allocation
|
||||||
|
if block.SFUSignalingPort != 0 {
|
||||||
|
t.Errorf("SFUSignalingPort = %d, want 0 for TURN allocation", block.SFUSignalingPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortAllocator_DeallocateAll(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
err := allocator.DeallocateAll(context.Background(), "cluster-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeallocateAll failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify DELETE was called with correct cluster ID
|
||||||
|
hasDelete := false
|
||||||
|
for _, call := range mockDB.execCalls {
|
||||||
|
if strings.Contains(call.Query, "DELETE FROM webrtc_port_allocations") &&
|
||||||
|
strings.Contains(call.Query, "namespace_cluster_id") {
|
||||||
|
hasDelete = true
|
||||||
|
if len(call.Args) < 1 || call.Args[0] != "cluster-1" {
|
||||||
|
t.Errorf("DELETE called with wrong cluster ID: %v", call.Args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasDelete {
|
||||||
|
t.Error("expected DELETE FROM webrtc_port_allocations to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortAllocator_DeallocateByNode(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
err := allocator.DeallocateByNode(context.Background(), "cluster-1", "node-1", "sfu")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("DeallocateByNode failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify DELETE was called with correct parameters
|
||||||
|
hasDelete := false
|
||||||
|
for _, call := range mockDB.execCalls {
|
||||||
|
if strings.Contains(call.Query, "DELETE FROM webrtc_port_allocations") &&
|
||||||
|
strings.Contains(call.Query, "service_type") {
|
||||||
|
hasDelete = true
|
||||||
|
if len(call.Args) != 3 {
|
||||||
|
t.Fatalf("DELETE called with %d args, want 3", len(call.Args))
|
||||||
|
}
|
||||||
|
if call.Args[0] != "cluster-1" {
|
||||||
|
t.Errorf("arg[0] = %v, want cluster-1", call.Args[0])
|
||||||
|
}
|
||||||
|
if call.Args[1] != "node-1" {
|
||||||
|
t.Errorf("arg[1] = %v, want node-1", call.Args[1])
|
||||||
|
}
|
||||||
|
if call.Args[2] != "sfu" {
|
||||||
|
t.Errorf("arg[2] = %v, want sfu", call.Args[2])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasDelete {
|
||||||
|
t.Error("expected DELETE FROM webrtc_port_allocations to be called")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortAllocator_NodeHasTURN(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
// Mock query returns empty results → no TURN on node
|
||||||
|
hasTURN, err := allocator.NodeHasTURN(context.Background(), "node-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NodeHasTURN failed: %v", err)
|
||||||
|
}
|
||||||
|
if hasTURN {
|
||||||
|
t.Error("expected NodeHasTURN = false for node with no allocations")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortAllocator_GetSFUPorts_NoAllocation(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
block, err := allocator.GetSFUPorts(context.Background(), "cluster-1", "node-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetSFUPorts failed: %v", err)
|
||||||
|
}
|
||||||
|
if block != nil {
|
||||||
|
t.Error("expected nil block when no allocation exists")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortAllocator_GetTURNPorts_NoAllocation(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
block, err := allocator.GetTURNPorts(context.Background(), "cluster-1", "node-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetTURNPorts failed: %v", err)
|
||||||
|
}
|
||||||
|
if block != nil {
|
||||||
|
t.Error("expected nil block when no allocation exists")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortAllocator_GetAllPorts_Empty(t *testing.T) {
|
||||||
|
mockDB := newMockRQLiteClient()
|
||||||
|
allocator := NewWebRTCPortAllocator(mockDB, testLogger())
|
||||||
|
|
||||||
|
blocks, err := allocator.GetAllPorts(context.Background(), "cluster-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetAllPorts failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(blocks) != 0 {
|
||||||
|
t.Errorf("expected 0 blocks, got %d", len(blocks))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortBlock_SFUFields(t *testing.T) {
|
||||||
|
block := &WebRTCPortBlock{
|
||||||
|
ID: "test-id",
|
||||||
|
NodeID: "node-1",
|
||||||
|
NamespaceClusterID: "cluster-1",
|
||||||
|
ServiceType: "sfu",
|
||||||
|
SFUSignalingPort: 30000,
|
||||||
|
SFUMediaPortStart: 20000,
|
||||||
|
SFUMediaPortEnd: 20499,
|
||||||
|
}
|
||||||
|
|
||||||
|
mediaRange := block.SFUMediaPortEnd - block.SFUMediaPortStart + 1
|
||||||
|
if mediaRange != SFUMediaPortsPerNamespace {
|
||||||
|
t.Errorf("SFU media range = %d, want %d", mediaRange, SFUMediaPortsPerNamespace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebRTCPortBlock_TURNFields(t *testing.T) {
|
||||||
|
block := &WebRTCPortBlock{
|
||||||
|
ID: "test-id",
|
||||||
|
NodeID: "node-1",
|
||||||
|
NamespaceClusterID: "cluster-1",
|
||||||
|
ServiceType: "turn",
|
||||||
|
TURNListenPort: 3478,
|
||||||
|
TURNTLSPort: 443,
|
||||||
|
TURNRelayPortStart: 49152,
|
||||||
|
TURNRelayPortEnd: 49951,
|
||||||
|
}
|
||||||
|
|
||||||
|
relayRange := block.TURNRelayPortEnd - block.TURNRelayPortStart + 1
|
||||||
|
if relayRange != TURNRelayPortsPerNamespace {
|
||||||
|
t.Errorf("TURN relay range = %d, want %d", relayRange, TURNRelayPortsPerNamespace)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testLogger returns a no-op logger for tests
|
||||||
|
func testLogger() *zap.Logger {
|
||||||
|
return zap.NewNop()
|
||||||
|
}
|
||||||
@ -57,7 +57,11 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
|
|||||||
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
|
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
|
||||||
BaseDomain: n.config.HTTPGateway.BaseDomain,
|
BaseDomain: n.config.HTTPGateway.BaseDomain,
|
||||||
DataDir: oramaDir,
|
DataDir: oramaDir,
|
||||||
ClusterSecret: clusterSecret,
|
ClusterSecret: clusterSecret,
|
||||||
|
WebRTCEnabled: n.config.HTTPGateway.WebRTC.Enabled,
|
||||||
|
SFUPort: n.config.HTTPGateway.WebRTC.SFUPort,
|
||||||
|
TURNDomain: n.config.HTTPGateway.WebRTC.TURNDomain,
|
||||||
|
TURNSecret: n.config.HTTPGateway.WebRTC.TURNSecret,
|
||||||
}
|
}
|
||||||
|
|
||||||
apiGateway, err := gateway.New(gatewayLogger, gwCfg)
|
apiGateway, err := gateway.New(gatewayLogger, gwCfg)
|
||||||
@ -82,6 +86,7 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
|
|||||||
clusterManager.SetLocalNodeID(gwCfg.NodePeerID)
|
clusterManager.SetLocalNodeID(gwCfg.NodePeerID)
|
||||||
apiGateway.SetClusterProvisioner(clusterManager)
|
apiGateway.SetClusterProvisioner(clusterManager)
|
||||||
apiGateway.SetNodeRecoverer(clusterManager)
|
apiGateway.SetNodeRecoverer(clusterManager)
|
||||||
|
apiGateway.SetWebRTCManager(clusterManager)
|
||||||
|
|
||||||
// Wire spawn handler for distributed namespace instance spawning
|
// Wire spawn handler for distributed namespace instance spawning
|
||||||
systemdSpawner := namespace.NewSystemdSpawner(baseDataDir, n.logger.Logger)
|
systemdSpawner := namespace.NewSystemdSpawner(baseDataDir, n.logger.Logger)
|
||||||
|
|||||||
80
pkg/sfu/config.go
Normal file
80
pkg/sfu/config.go
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// Config holds configuration for the SFU server
|
||||||
|
type Config struct {
|
||||||
|
// ListenAddr is the address to bind the signaling WebSocket server.
|
||||||
|
// Must be a WireGuard IP (10.0.0.x) — never 0.0.0.0.
|
||||||
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
|
|
||||||
|
// Namespace this SFU instance belongs to
|
||||||
|
Namespace string `yaml:"namespace"`
|
||||||
|
|
||||||
|
// MediaPortRange defines the UDP port range for RTP media
|
||||||
|
MediaPortStart int `yaml:"media_port_start"`
|
||||||
|
MediaPortEnd int `yaml:"media_port_end"`
|
||||||
|
|
||||||
|
// TURN servers this SFU should advertise to peers
|
||||||
|
TURNServers []TURNServerConfig `yaml:"turn_servers"`
|
||||||
|
|
||||||
|
// TURNSecret is the shared HMAC-SHA1 secret for generating TURN credentials
|
||||||
|
TURNSecret string `yaml:"turn_secret"`
|
||||||
|
|
||||||
|
// TURNCredentialTTL is the lifetime of TURN credentials in seconds
|
||||||
|
TURNCredentialTTL int `yaml:"turn_credential_ttl"`
|
||||||
|
|
||||||
|
// RQLiteDSN is the namespace-local RQLite DSN for room state
|
||||||
|
RQLiteDSN string `yaml:"rqlite_dsn"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TURNServerConfig represents a single TURN server endpoint
|
||||||
|
type TURNServerConfig struct {
|
||||||
|
Host string `yaml:"host"` // IP or hostname
|
||||||
|
Port int `yaml:"port"` // UDP port (3478 or 443)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks the SFU configuration for errors
|
||||||
|
func (c *Config) Validate() []error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if c.ListenAddr == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.listen_addr: must not be empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Namespace == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.namespace: must not be empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.MediaPortStart <= 0 || c.MediaPortEnd <= 0 {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.media_port_range: start and end must be positive"))
|
||||||
|
} else if c.MediaPortEnd <= c.MediaPortStart {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.media_port_range: end (%d) must be greater than start (%d)", c.MediaPortEnd, c.MediaPortStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(c.TURNServers) == 0 {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.turn_servers: at least one TURN server must be configured"))
|
||||||
|
}
|
||||||
|
for i, ts := range c.TURNServers {
|
||||||
|
if ts.Host == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.turn_servers[%d].host: must not be empty", i))
|
||||||
|
}
|
||||||
|
if ts.Port <= 0 || ts.Port > 65535 {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.turn_servers[%d].port: must be between 1 and 65535", i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.TURNSecret == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.turn_secret: must not be empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.TURNCredentialTTL <= 0 {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.turn_credential_ttl: must be positive"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.RQLiteDSN == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("sfu.rqlite_dsn: must not be empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
||||||
167
pkg/sfu/config_test.go
Normal file
167
pkg/sfu/config_test.go
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestConfigValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantErrs int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
|
||||||
|
TURNSecret: "secret-key",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid config with multiple TURN servers",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{
|
||||||
|
{Host: "1.2.3.4", Port: 3478},
|
||||||
|
{Host: "5.6.7.8", Port: 443},
|
||||||
|
},
|
||||||
|
TURNSecret: "secret-key",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing all fields",
|
||||||
|
config: Config{},
|
||||||
|
wantErrs: 7, // listen_addr, namespace, media_port_range, turn_servers, turn_secret, turn_credential_ttl, rqlite_dsn
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing listen addr",
|
||||||
|
config: Config{
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
|
||||||
|
TURNSecret: "secret",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing namespace",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
|
||||||
|
TURNSecret: "secret",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid media port range - inverted",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20500,
|
||||||
|
MediaPortEnd: 20000,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
|
||||||
|
TURNSecret: "secret",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid media port range - zero",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 0,
|
||||||
|
MediaPortEnd: 0,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
|
||||||
|
TURNSecret: "secret",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no TURN servers",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{},
|
||||||
|
TURNSecret: "secret",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TURN server with invalid port",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 0}},
|
||||||
|
TURNSecret: "secret",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TURN server with empty host",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "", Port: 3478}},
|
||||||
|
TURNSecret: "secret",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative credential TTL",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
|
||||||
|
TURNSecret: "secret",
|
||||||
|
TURNCredentialTTL: -1,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
errs := tt.config.Validate()
|
||||||
|
if len(errs) != tt.wantErrs {
|
||||||
|
t.Errorf("Validate() returned %d errors, want %d: %v", len(errs), tt.wantErrs, errs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
340
pkg/sfu/peer.go
Normal file
340
pkg/sfu/peer.go
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/pion/rtcp"
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrPeerNotInitialized = errors.New("peer connection not initialized")
|
||||||
|
ErrPeerClosed = errors.New("peer is closed")
|
||||||
|
ErrWebSocketClosed = errors.New("websocket connection closed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Peer represents a participant in a room with a WebRTC PeerConnection.
|
||||||
|
type Peer struct {
|
||||||
|
ID string
|
||||||
|
UserID string
|
||||||
|
|
||||||
|
pc *webrtc.PeerConnection
|
||||||
|
conn *websocket.Conn
|
||||||
|
room *Room
|
||||||
|
|
||||||
|
// Negotiation state machine
|
||||||
|
negotiationPending bool
|
||||||
|
batchingTracks bool
|
||||||
|
negotiationMu sync.Mutex
|
||||||
|
|
||||||
|
// Connection state
|
||||||
|
closed bool
|
||||||
|
closedMu sync.RWMutex
|
||||||
|
connMu sync.Mutex
|
||||||
|
|
||||||
|
logger *zap.Logger
|
||||||
|
onClose func(*Peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPeer creates a new peer
|
||||||
|
func NewPeer(userID string, conn *websocket.Conn, room *Room, logger *zap.Logger) *Peer {
|
||||||
|
return &Peer{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: userID,
|
||||||
|
conn: conn,
|
||||||
|
room: room,
|
||||||
|
logger: logger.With(zap.String("peer_id", "")), // Updated after ID assigned
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitPeerConnection creates and configures the WebRTC PeerConnection.
|
||||||
|
func (p *Peer) InitPeerConnection(api *webrtc.API, iceServers []webrtc.ICEServer) error {
|
||||||
|
pc, err := api.NewPeerConnection(webrtc.Configuration{
|
||||||
|
ICEServers: iceServers,
|
||||||
|
ICETransportPolicy: webrtc.ICETransportPolicyRelay, // Force TURN relay
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.pc = pc
|
||||||
|
p.logger = p.logger.With(zap.String("peer_id", p.ID))
|
||||||
|
|
||||||
|
// ICE connection state changes
|
||||||
|
pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
|
||||||
|
p.logger.Info("ICE state changed", zap.String("state", state.String()))
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case webrtc.ICEConnectionStateDisconnected:
|
||||||
|
// Give 15 seconds to reconnect before removing
|
||||||
|
go p.handleReconnectTimeout()
|
||||||
|
case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateClosed:
|
||||||
|
p.handleDisconnect()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// ICE candidate generation
|
||||||
|
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||||
|
if candidate == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c := candidate.ToJSON()
|
||||||
|
data := &ICECandidateData{Candidate: c.Candidate}
|
||||||
|
if c.SDPMid != nil {
|
||||||
|
data.SDPMid = *c.SDPMid
|
||||||
|
}
|
||||||
|
if c.SDPMLineIndex != nil {
|
||||||
|
data.SDPMLineIndex = *c.SDPMLineIndex
|
||||||
|
}
|
||||||
|
if c.UsernameFragment != nil {
|
||||||
|
data.UsernameFragment = *c.UsernameFragment
|
||||||
|
}
|
||||||
|
p.SendMessage(NewServerMessage(MessageTypeICECandidate, data))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Incoming tracks from the client
|
||||||
|
pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||||
|
p.logger.Info("Track received",
|
||||||
|
zap.String("track_id", track.ID()),
|
||||||
|
zap.String("kind", track.Kind().String()),
|
||||||
|
zap.String("codec", track.Codec().MimeType))
|
||||||
|
|
||||||
|
// Read RTCP feedback (PLI/NACK) in background
|
||||||
|
go p.readRTCP(receiver, track)
|
||||||
|
|
||||||
|
// Forward track to all other peers
|
||||||
|
p.room.BroadcastTrack(p.ID, track)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Negotiation needed — only when stable
|
||||||
|
pc.OnNegotiationNeeded(func() {
|
||||||
|
p.negotiationMu.Lock()
|
||||||
|
if p.batchingTracks {
|
||||||
|
p.negotiationPending = true
|
||||||
|
p.negotiationMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.negotiationMu.Unlock()
|
||||||
|
|
||||||
|
if pc.SignalingState() == webrtc.SignalingStateStable {
|
||||||
|
p.createAndSendOffer()
|
||||||
|
} else {
|
||||||
|
p.negotiationMu.Lock()
|
||||||
|
p.negotiationPending = true
|
||||||
|
p.negotiationMu.Unlock()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// When state returns to stable, fire pending negotiation
|
||||||
|
pc.OnSignalingStateChange(func(state webrtc.SignalingState) {
|
||||||
|
if state == webrtc.SignalingStateStable {
|
||||||
|
p.negotiationMu.Lock()
|
||||||
|
pending := p.negotiationPending
|
||||||
|
p.negotiationPending = false
|
||||||
|
p.negotiationMu.Unlock()
|
||||||
|
|
||||||
|
if pending {
|
||||||
|
p.createAndSendOffer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) createAndSendOffer() {
|
||||||
|
if p.pc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if p.pc.SignalingState() != webrtc.SignalingStateStable {
|
||||||
|
p.negotiationMu.Lock()
|
||||||
|
p.negotiationPending = true
|
||||||
|
p.negotiationMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
offer, err := p.pc.CreateOffer(nil)
|
||||||
|
if err != nil {
|
||||||
|
p.logger.Error("Failed to create offer", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := p.pc.SetLocalDescription(offer); err != nil {
|
||||||
|
p.logger.Error("Failed to set local description", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.SendMessage(NewServerMessage(MessageTypeOffer, &OfferData{SDP: offer.SDP}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleOffer processes an SDP offer from the client
|
||||||
|
func (p *Peer) HandleOffer(sdp string) error {
|
||||||
|
if p.pc == nil {
|
||||||
|
return ErrPeerNotInitialized
|
||||||
|
}
|
||||||
|
if err := p.pc.SetRemoteDescription(webrtc.SessionDescription{
|
||||||
|
Type: webrtc.SDPTypeOffer, SDP: sdp,
|
||||||
|
}); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
answer, err := p.pc.CreateAnswer(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := p.pc.SetLocalDescription(answer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.SendMessage(NewServerMessage(MessageTypeAnswer, &AnswerData{SDP: answer.SDP}))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleAnswer processes an SDP answer from the client
|
||||||
|
func (p *Peer) HandleAnswer(sdp string) error {
|
||||||
|
if p.pc == nil {
|
||||||
|
return ErrPeerNotInitialized
|
||||||
|
}
|
||||||
|
return p.pc.SetRemoteDescription(webrtc.SessionDescription{
|
||||||
|
Type: webrtc.SDPTypeAnswer, SDP: sdp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleICECandidate adds a remote ICE candidate
|
||||||
|
func (p *Peer) HandleICECandidate(data *ICECandidateData) error {
|
||||||
|
if p.pc == nil {
|
||||||
|
return ErrPeerNotInitialized
|
||||||
|
}
|
||||||
|
return p.pc.AddICECandidate(data.ToWebRTCCandidate())
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTrack adds a local track to send to this peer
|
||||||
|
func (p *Peer) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||||
|
if p.pc == nil {
|
||||||
|
return nil, ErrPeerNotInitialized
|
||||||
|
}
|
||||||
|
return p.pc.AddTrack(track)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartTrackBatch suppresses renegotiation during bulk track additions
|
||||||
|
func (p *Peer) StartTrackBatch() {
|
||||||
|
p.negotiationMu.Lock()
|
||||||
|
p.batchingTracks = true
|
||||||
|
p.negotiationMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndTrackBatch ends batching and fires deferred renegotiation
|
||||||
|
func (p *Peer) EndTrackBatch() {
|
||||||
|
p.negotiationMu.Lock()
|
||||||
|
p.batchingTracks = false
|
||||||
|
pending := p.negotiationPending
|
||||||
|
p.negotiationPending = false
|
||||||
|
p.negotiationMu.Unlock()
|
||||||
|
|
||||||
|
if pending && p.pc != nil && p.pc.SignalingState() == webrtc.SignalingStateStable {
|
||||||
|
p.createAndSendOffer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage sends a signaling message via WebSocket
|
||||||
|
func (p *Peer) SendMessage(msg *ServerMessage) error {
|
||||||
|
p.closedMu.RLock()
|
||||||
|
if p.closed {
|
||||||
|
p.closedMu.RUnlock()
|
||||||
|
return ErrPeerClosed
|
||||||
|
}
|
||||||
|
p.closedMu.RUnlock()
|
||||||
|
|
||||||
|
p.connMu.Lock()
|
||||||
|
defer p.connMu.Unlock()
|
||||||
|
if p.conn == nil {
|
||||||
|
return ErrWebSocketClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return p.conn.WriteMessage(websocket.TextMessage, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInfo returns public info about this peer
|
||||||
|
func (p *Peer) GetInfo() ParticipantInfo {
|
||||||
|
return ParticipantInfo{PeerID: p.ID, UserID: p.UserID}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleReconnectTimeout waits 15 seconds for ICE reconnection before removing the peer.
|
||||||
|
func (p *Peer) handleReconnectTimeout() {
|
||||||
|
// Use a channel that closes when peer state changes
|
||||||
|
// Check after 15 seconds if still disconnected
|
||||||
|
<-timeAfter(reconnectTimeout)
|
||||||
|
|
||||||
|
if p.pc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := p.pc.ICEConnectionState()
|
||||||
|
if state == webrtc.ICEConnectionStateDisconnected || state == webrtc.ICEConnectionStateFailed {
|
||||||
|
p.logger.Info("Peer did not reconnect within timeout, removing")
|
||||||
|
p.handleDisconnect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Peer) handleDisconnect() {
|
||||||
|
p.closedMu.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.closed = true
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
|
||||||
|
if p.onClose != nil {
|
||||||
|
p.onClose(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the peer connection and WebSocket
|
||||||
|
func (p *Peer) Close() error {
|
||||||
|
p.closedMu.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.closed = true
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
|
||||||
|
p.connMu.Lock()
|
||||||
|
if p.conn != nil {
|
||||||
|
p.conn.Close()
|
||||||
|
p.conn = nil
|
||||||
|
}
|
||||||
|
p.connMu.Unlock()
|
||||||
|
|
||||||
|
if p.pc != nil {
|
||||||
|
return p.pc.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnClose sets the disconnect callback
|
||||||
|
func (p *Peer) OnClose(fn func(*Peer)) {
|
||||||
|
p.onClose = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// readRTCP reads RTCP feedback and forwards PLI/FIR to the source peer
|
||||||
|
func (p *Peer) readRTCP(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) {
|
||||||
|
localTrackID := track.Kind().String() + "-" + p.ID
|
||||||
|
|
||||||
|
for {
|
||||||
|
packets, _, err := receiver.ReadRTCP()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, pkt := range packets {
|
||||||
|
switch pkt.(type) {
|
||||||
|
case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
|
||||||
|
p.room.RequestKeyframe(localTrackID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
555
pkg/sfu/room.go
Normal file
555
pkg/sfu/room.go
Normal file
@ -0,0 +1,555 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||||
|
"github.com/pion/interceptor"
|
||||||
|
"github.com/pion/interceptor/pkg/intervalpli"
|
||||||
|
"github.com/pion/interceptor/pkg/nack"
|
||||||
|
"github.com/pion/rtcp"
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// For testing: allows overriding time.After
|
||||||
|
var timeAfter = func(d time.Duration) <-chan time.Time { return time.After(d) }
|
||||||
|
|
||||||
|
const (
|
||||||
|
reconnectTimeout = 15 * time.Second
|
||||||
|
emptyRoomTTL = 60 * time.Second
|
||||||
|
rtpBufferSize = 8192
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrRoomFull = errors.New("room is full")
|
||||||
|
ErrRoomClosed = errors.New("room is closed")
|
||||||
|
ErrPeerNotFound = errors.New("peer not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// publishedTrack holds a local track being forwarded from a remote source.
|
||||||
|
type publishedTrack struct {
|
||||||
|
sourcePeerID string
|
||||||
|
localTrack *webrtc.TrackLocalStaticRTP
|
||||||
|
remoteTrackSSRC uint32
|
||||||
|
kind string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Room is a WebRTC room with multiple participants sharing media tracks.
|
||||||
|
type Room struct {
|
||||||
|
ID string
|
||||||
|
Namespace string
|
||||||
|
|
||||||
|
peers map[string]*Peer
|
||||||
|
peersMu sync.RWMutex
|
||||||
|
|
||||||
|
publishedTracks map[string]*publishedTrack // key: localTrack.ID()
|
||||||
|
publishedTracksMu sync.RWMutex
|
||||||
|
|
||||||
|
api *webrtc.API
|
||||||
|
config *Config
|
||||||
|
logger *zap.Logger
|
||||||
|
|
||||||
|
closed bool
|
||||||
|
closedMu sync.RWMutex
|
||||||
|
|
||||||
|
onEmpty func(*Room)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomManager manages the lifecycle of rooms.
|
||||||
|
type RoomManager struct {
|
||||||
|
rooms map[string]*Room // key: roomID
|
||||||
|
mu sync.RWMutex
|
||||||
|
config *Config
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoomManager creates a new room manager.
|
||||||
|
func NewRoomManager(cfg *Config, logger *zap.Logger) *RoomManager {
|
||||||
|
return &RoomManager{
|
||||||
|
rooms: make(map[string]*Room),
|
||||||
|
config: cfg,
|
||||||
|
logger: logger.With(zap.String("component", "room-manager")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateRoom returns an existing room or creates a new one.
|
||||||
|
func (rm *RoomManager) GetOrCreateRoom(roomID string) *Room {
|
||||||
|
rm.mu.Lock()
|
||||||
|
defer rm.mu.Unlock()
|
||||||
|
|
||||||
|
if room, ok := rm.rooms[roomID]; ok && !room.IsClosed() {
|
||||||
|
return room
|
||||||
|
}
|
||||||
|
|
||||||
|
api := newWebRTCAPI(rm.config)
|
||||||
|
room := &Room{
|
||||||
|
ID: roomID,
|
||||||
|
Namespace: rm.config.Namespace,
|
||||||
|
peers: make(map[string]*Peer),
|
||||||
|
publishedTracks: make(map[string]*publishedTrack),
|
||||||
|
api: api,
|
||||||
|
config: rm.config,
|
||||||
|
logger: rm.logger.With(zap.String("room_id", roomID)),
|
||||||
|
}
|
||||||
|
|
||||||
|
room.onEmpty = func(r *Room) {
|
||||||
|
// Start empty room cleanup timer
|
||||||
|
go func() {
|
||||||
|
<-timeAfter(emptyRoomTTL)
|
||||||
|
if r.GetParticipantCount() == 0 {
|
||||||
|
rm.mu.Lock()
|
||||||
|
delete(rm.rooms, r.ID)
|
||||||
|
rm.mu.Unlock()
|
||||||
|
r.Close()
|
||||||
|
rm.logger.Info("Empty room cleaned up", zap.String("room_id", r.ID))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
rm.rooms[roomID] = room
|
||||||
|
rm.logger.Info("Room created", zap.String("room_id", roomID))
|
||||||
|
return room
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoom returns a room by ID, or nil if not found.
|
||||||
|
func (rm *RoomManager) GetRoom(roomID string) *Room {
|
||||||
|
rm.mu.RLock()
|
||||||
|
defer rm.mu.RUnlock()
|
||||||
|
return rm.rooms[roomID]
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseAll closes all rooms (for graceful shutdown).
|
||||||
|
func (rm *RoomManager) CloseAll() {
|
||||||
|
rm.mu.Lock()
|
||||||
|
rooms := make([]*Room, 0, len(rm.rooms))
|
||||||
|
for _, r := range rm.rooms {
|
||||||
|
rooms = append(rooms, r)
|
||||||
|
}
|
||||||
|
rm.rooms = make(map[string]*Room)
|
||||||
|
rm.mu.Unlock()
|
||||||
|
|
||||||
|
for _, r := range rooms {
|
||||||
|
r.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomCount returns the number of active rooms.
|
||||||
|
func (rm *RoomManager) RoomCount() int {
|
||||||
|
rm.mu.RLock()
|
||||||
|
defer rm.mu.RUnlock()
|
||||||
|
return len(rm.rooms)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newWebRTCAPI creates a Pion WebRTC API with codecs and interceptors.
|
||||||
|
func newWebRTCAPI(cfg *Config) *webrtc.API {
|
||||||
|
m := &webrtc.MediaEngine{}
|
||||||
|
|
||||||
|
// Audio: Opus
|
||||||
|
videoRTCPFeedback := []webrtc.RTCPFeedback{
|
||||||
|
{Type: "goog-remb", Parameter: ""},
|
||||||
|
{Type: "ccm", Parameter: "fir"},
|
||||||
|
{Type: "nack", Parameter: ""},
|
||||||
|
{Type: "nack", Parameter: "pli"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||||
|
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||||
|
MimeType: webrtc.MimeTypeOpus,
|
||||||
|
ClockRate: 48000,
|
||||||
|
Channels: 2,
|
||||||
|
SDPFmtpLine: "minptime=10;useinbandfec=1",
|
||||||
|
},
|
||||||
|
PayloadType: 111,
|
||||||
|
}, webrtc.RTPCodecTypeAudio)
|
||||||
|
|
||||||
|
// Video: VP8
|
||||||
|
_ = m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||||
|
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||||
|
MimeType: webrtc.MimeTypeVP8,
|
||||||
|
ClockRate: 90000,
|
||||||
|
RTCPFeedback: videoRTCPFeedback,
|
||||||
|
},
|
||||||
|
PayloadType: 96,
|
||||||
|
}, webrtc.RTPCodecTypeVideo)
|
||||||
|
|
||||||
|
// Video: H264
|
||||||
|
_ = m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||||
|
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||||
|
MimeType: webrtc.MimeTypeH264,
|
||||||
|
ClockRate: 90000,
|
||||||
|
SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42001f",
|
||||||
|
RTCPFeedback: videoRTCPFeedback,
|
||||||
|
},
|
||||||
|
PayloadType: 125,
|
||||||
|
}, webrtc.RTPCodecTypeVideo)
|
||||||
|
|
||||||
|
// Interceptors: NACK + PLI
|
||||||
|
i := &interceptor.Registry{}
|
||||||
|
if f, err := nack.NewResponderInterceptor(); err == nil {
|
||||||
|
i.Add(f)
|
||||||
|
}
|
||||||
|
if f, err := nack.NewGeneratorInterceptor(); err == nil {
|
||||||
|
i.Add(f)
|
||||||
|
}
|
||||||
|
if f, err := intervalpli.NewReceiverInterceptor(); err == nil {
|
||||||
|
i.Add(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SettingEngine: restrict media ports
|
||||||
|
se := webrtc.SettingEngine{}
|
||||||
|
if cfg.MediaPortStart > 0 && cfg.MediaPortEnd > 0 {
|
||||||
|
se.SetEphemeralUDPPortRange(uint16(cfg.MediaPortStart), uint16(cfg.MediaPortEnd))
|
||||||
|
}
|
||||||
|
|
||||||
|
return webrtc.NewAPI(
|
||||||
|
webrtc.WithMediaEngine(m),
|
||||||
|
webrtc.WithInterceptorRegistry(i),
|
||||||
|
webrtc.WithSettingEngine(se),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Room methods ---
|
||||||
|
|
||||||
|
// AddPeer adds a peer to the room and notifies other participants.
|
||||||
|
func (r *Room) AddPeer(peer *Peer) error {
|
||||||
|
r.closedMu.RLock()
|
||||||
|
if r.closed {
|
||||||
|
r.closedMu.RUnlock()
|
||||||
|
return ErrRoomClosed
|
||||||
|
}
|
||||||
|
r.closedMu.RUnlock()
|
||||||
|
|
||||||
|
// Build ICE servers for TURN
|
||||||
|
iceServers := r.buildICEServers()
|
||||||
|
|
||||||
|
r.peersMu.Lock()
|
||||||
|
if len(r.peers) >= 100 { // Hard cap
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
return ErrRoomFull
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := peer.InitPeerConnection(r.api, iceServers); err != nil {
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.OnClose(func(p *Peer) { r.RemovePeer(p.ID) })
|
||||||
|
|
||||||
|
r.peers[peer.ID] = peer
|
||||||
|
info := peer.GetInfo()
|
||||||
|
total := len(r.peers)
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
|
||||||
|
r.logger.Info("Peer joined", zap.String("peer_id", peer.ID), zap.Int("total", total))
|
||||||
|
|
||||||
|
// Notify others
|
||||||
|
r.broadcastMessage(peer.ID, NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{
|
||||||
|
Participant: info,
|
||||||
|
}))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes a peer and cleans up their published tracks.
|
||||||
|
func (r *Room) RemovePeer(peerID string) {
|
||||||
|
r.peersMu.Lock()
|
||||||
|
peer, ok := r.peers[peerID]
|
||||||
|
if !ok {
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(r.peers, peerID)
|
||||||
|
remaining := len(r.peers)
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
|
||||||
|
// Remove published tracks from this peer
|
||||||
|
r.publishedTracksMu.Lock()
|
||||||
|
var removed []string
|
||||||
|
for trackID, pt := range r.publishedTracks {
|
||||||
|
if pt.sourcePeerID == peerID {
|
||||||
|
delete(r.publishedTracks, trackID)
|
||||||
|
removed = append(removed, trackID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.Unlock()
|
||||||
|
|
||||||
|
// Remove RTPSenders for this peer's tracks from all other peers
|
||||||
|
if len(removed) > 0 {
|
||||||
|
r.removeTrackSendersFromPeers(removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.Close()
|
||||||
|
|
||||||
|
r.logger.Info("Peer left", zap.String("peer_id", peerID), zap.Int("remaining", remaining))
|
||||||
|
|
||||||
|
r.broadcastMessage(peerID, NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{
|
||||||
|
PeerID: peerID,
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Notify about removed tracks
|
||||||
|
for _, trackID := range removed {
|
||||||
|
r.broadcastMessage(peerID, NewServerMessage(MessageTypeTrackRemoved, &TrackRemovedData{
|
||||||
|
PeerID: peerID,
|
||||||
|
TrackID: trackID,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
if remaining == 0 && r.onEmpty != nil {
|
||||||
|
r.onEmpty(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeTrackSendersFromPeers removes RTPSenders for the given track IDs from all peers.
|
||||||
|
// This fixes the ghost track bug from the original implementation.
|
||||||
|
func (r *Room) removeTrackSendersFromPeers(trackIDs []string) {
|
||||||
|
trackIDSet := make(map[string]bool, len(trackIDs))
|
||||||
|
for _, id := range trackIDs {
|
||||||
|
trackIDSet[id] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
|
||||||
|
for _, peer := range r.peers {
|
||||||
|
if peer.pc == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, sender := range peer.pc.GetSenders() {
|
||||||
|
if sender.Track() == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if trackIDSet[sender.Track().ID()] {
|
||||||
|
if err := peer.pc.RemoveTrack(sender); err != nil {
|
||||||
|
r.logger.Warn("Failed to remove track sender",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.String("track_id", sender.Track().ID()),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastTrack creates a local track from a remote track and forwards it to all other peers.
|
||||||
|
func (r *Room) BroadcastTrack(sourcePeerID string, track *webrtc.TrackRemote) {
|
||||||
|
codec := track.Codec()
|
||||||
|
|
||||||
|
localTrack, err := webrtc.NewTrackLocalStaticRTP(
|
||||||
|
codec.RTPCodecCapability,
|
||||||
|
track.Kind().String()+"-"+sourcePeerID,
|
||||||
|
sourcePeerID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Error("Failed to create local track", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store for future joiners
|
||||||
|
r.publishedTracksMu.Lock()
|
||||||
|
r.publishedTracks[localTrack.ID()] = &publishedTrack{
|
||||||
|
sourcePeerID: sourcePeerID,
|
||||||
|
localTrack: localTrack,
|
||||||
|
remoteTrackSSRC: uint32(track.SSRC()),
|
||||||
|
kind: track.Kind().String(),
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.Unlock()
|
||||||
|
|
||||||
|
// RTP forwarding loop with proper buffer size
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, rtpBufferSize)
|
||||||
|
for {
|
||||||
|
n, _, err := track.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, err := localTrack.Write(buf[:n]); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Add to all current peers except the source
|
||||||
|
r.peersMu.RLock()
|
||||||
|
for peerID, peer := range r.peers {
|
||||||
|
if peerID == sourcePeerID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, err := peer.AddTrack(localTrack); err != nil {
|
||||||
|
r.logger.Warn("Failed to add track to peer",
|
||||||
|
zap.String("peer_id", peerID), zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
|
||||||
|
PeerID: sourcePeerID,
|
||||||
|
TrackID: localTrack.ID(),
|
||||||
|
StreamID: localTrack.StreamID(),
|
||||||
|
Kind: track.Kind().String(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
r.peersMu.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendExistingTracksTo sends all published tracks to a newly joined peer.
|
||||||
|
// Uses batch mode for a single renegotiation.
|
||||||
|
func (r *Room) SendExistingTracksTo(peer *Peer) {
|
||||||
|
r.publishedTracksMu.RLock()
|
||||||
|
var tracks []*publishedTrack
|
||||||
|
for _, pt := range r.publishedTracks {
|
||||||
|
if pt.sourcePeerID != peer.ID {
|
||||||
|
tracks = append(tracks, pt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
|
||||||
|
if len(tracks) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.StartTrackBatch()
|
||||||
|
for _, pt := range tracks {
|
||||||
|
if _, err := peer.AddTrack(pt.localTrack); err != nil {
|
||||||
|
r.logger.Warn("Failed to add existing track", zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
|
||||||
|
PeerID: pt.sourcePeerID,
|
||||||
|
TrackID: pt.localTrack.ID(),
|
||||||
|
StreamID: pt.localTrack.StreamID(),
|
||||||
|
Kind: pt.kind,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
peer.EndTrackBatch()
|
||||||
|
|
||||||
|
// Request keyframes for video tracks after negotiation settles
|
||||||
|
go func() {
|
||||||
|
<-timeAfter(300 * time.Millisecond)
|
||||||
|
r.RequestKeyframeForAllVideoTracks()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestKeyframe sends a PLI to the source peer for a video track.
|
||||||
|
func (r *Room) RequestKeyframe(trackID string) {
|
||||||
|
r.publishedTracksMu.RLock()
|
||||||
|
pt, ok := r.publishedTracks[trackID]
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
if !ok || pt.kind != "video" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.peersMu.RLock()
|
||||||
|
source, ok := r.peers[pt.sourcePeerID]
|
||||||
|
r.peersMu.RUnlock()
|
||||||
|
if !ok || source.pc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pli := &rtcp.PictureLossIndication{MediaSSRC: pt.remoteTrackSSRC}
|
||||||
|
if err := source.pc.WriteRTCP([]rtcp.Packet{pli}); err != nil {
|
||||||
|
r.logger.Debug("Failed to send PLI", zap.String("track_id", trackID), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestKeyframeForAllVideoTracks sends PLIs for all video tracks.
|
||||||
|
func (r *Room) RequestKeyframeForAllVideoTracks() {
|
||||||
|
r.publishedTracksMu.RLock()
|
||||||
|
var ids []string
|
||||||
|
for id, pt := range r.publishedTracks {
|
||||||
|
if pt.kind == "video" {
|
||||||
|
ids = append(ids, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
r.RequestKeyframe(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipants returns info about all participants.
|
||||||
|
func (r *Room) GetParticipants() []ParticipantInfo {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
infos := make([]ParticipantInfo, 0, len(r.peers))
|
||||||
|
for _, p := range r.peers {
|
||||||
|
infos = append(infos, p.GetInfo())
|
||||||
|
}
|
||||||
|
return infos
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipantCount returns the number of participants.
|
||||||
|
func (r *Room) GetParticipantCount() int {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
return len(r.peers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed returns whether the room is closed.
|
||||||
|
func (r *Room) IsClosed() bool {
|
||||||
|
r.closedMu.RLock()
|
||||||
|
defer r.closedMu.RUnlock()
|
||||||
|
return r.closed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the room and all peer connections.
|
||||||
|
func (r *Room) Close() error {
|
||||||
|
r.closedMu.Lock()
|
||||||
|
if r.closed {
|
||||||
|
r.closedMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.closed = true
|
||||||
|
r.closedMu.Unlock()
|
||||||
|
|
||||||
|
r.peersMu.Lock()
|
||||||
|
peers := make([]*Peer, 0, len(r.peers))
|
||||||
|
for _, p := range r.peers {
|
||||||
|
peers = append(peers, p)
|
||||||
|
}
|
||||||
|
r.peers = make(map[string]*Peer)
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
|
||||||
|
for _, p := range peers {
|
||||||
|
p.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("Room closed")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Room) broadcastMessage(excludePeerID string, msg *ServerMessage) {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
for id, peer := range r.peers {
|
||||||
|
if id == excludePeerID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peer.SendMessage(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildICEServers constructs ICE server config from TURN settings.
|
||||||
|
func (r *Room) buildICEServers() []webrtc.ICEServer {
|
||||||
|
if len(r.config.TURNServers) == 0 || r.config.TURNSecret == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var urls []string
|
||||||
|
for _, ts := range r.config.TURNServers {
|
||||||
|
urls = append(urls, fmt.Sprintf("turn:%s:%d?transport=udp", ts.Host, ts.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := time.Duration(r.config.TURNCredentialTTL) * time.Second
|
||||||
|
username, password := turn.GenerateCredentials(r.config.TURNSecret, r.config.Namespace, ttl)
|
||||||
|
|
||||||
|
return []webrtc.ICEServer{
|
||||||
|
{
|
||||||
|
URLs: urls,
|
||||||
|
Username: username,
|
||||||
|
Credential: password,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
368
pkg/sfu/room_test.go
Normal file
368
pkg/sfu/room_test.go
Normal file
@ -0,0 +1,368 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func testConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
ListenAddr: "10.0.0.1:8443",
|
||||||
|
Namespace: "test-ns",
|
||||||
|
MediaPortStart: 20000,
|
||||||
|
MediaPortEnd: 20500,
|
||||||
|
TURNServers: []TURNServerConfig{{Host: "1.2.3.4", Port: 3478}},
|
||||||
|
TURNSecret: "test-secret-key-32bytes-long!!!!",
|
||||||
|
TURNCredentialTTL: 600,
|
||||||
|
RQLiteDSN: "http://10.0.0.1:4001",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testLogger() *zap.Logger {
|
||||||
|
return zap.NewNop()
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- RoomManager tests ---
|
||||||
|
|
||||||
|
func TestNewRoomManager(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
if rm == nil {
|
||||||
|
t.Fatal("NewRoomManager returned nil")
|
||||||
|
}
|
||||||
|
if rm.RoomCount() != 0 {
|
||||||
|
t.Errorf("RoomCount = %d, want 0", rm.RoomCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomManagerGetOrCreateRoom(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
|
||||||
|
room1 := rm.GetOrCreateRoom("room-1")
|
||||||
|
if room1 == nil {
|
||||||
|
t.Fatal("GetOrCreateRoom returned nil")
|
||||||
|
}
|
||||||
|
if room1.ID != "room-1" {
|
||||||
|
t.Errorf("Room.ID = %q, want %q", room1.ID, "room-1")
|
||||||
|
}
|
||||||
|
if room1.Namespace != "test-ns" {
|
||||||
|
t.Errorf("Room.Namespace = %q, want %q", room1.Namespace, "test-ns")
|
||||||
|
}
|
||||||
|
if rm.RoomCount() != 1 {
|
||||||
|
t.Errorf("RoomCount = %d, want 1", rm.RoomCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Getting same room returns same instance
|
||||||
|
room1Again := rm.GetOrCreateRoom("room-1")
|
||||||
|
if room1 != room1Again {
|
||||||
|
t.Error("expected same room instance")
|
||||||
|
}
|
||||||
|
if rm.RoomCount() != 1 {
|
||||||
|
t.Errorf("RoomCount = %d, want 1 (same room)", rm.RoomCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different room creates new instance
|
||||||
|
room2 := rm.GetOrCreateRoom("room-2")
|
||||||
|
if room2 == nil {
|
||||||
|
t.Fatal("second room is nil")
|
||||||
|
}
|
||||||
|
if room2.ID != "room-2" {
|
||||||
|
t.Errorf("Room.ID = %q, want %q", room2.ID, "room-2")
|
||||||
|
}
|
||||||
|
if rm.RoomCount() != 2 {
|
||||||
|
t.Errorf("RoomCount = %d, want 2", rm.RoomCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomManagerGetRoom(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
|
||||||
|
// Non-existent room returns nil
|
||||||
|
room := rm.GetRoom("nonexistent")
|
||||||
|
if room != nil {
|
||||||
|
t.Error("expected nil for non-existent room")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a room and retrieve it
|
||||||
|
rm.GetOrCreateRoom("room-1")
|
||||||
|
room = rm.GetRoom("room-1")
|
||||||
|
if room == nil {
|
||||||
|
t.Fatal("expected non-nil for existing room")
|
||||||
|
}
|
||||||
|
if room.ID != "room-1" {
|
||||||
|
t.Errorf("Room.ID = %q, want %q", room.ID, "room-1")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomManagerCloseAll(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
|
||||||
|
rm.GetOrCreateRoom("room-1")
|
||||||
|
rm.GetOrCreateRoom("room-2")
|
||||||
|
rm.GetOrCreateRoom("room-3")
|
||||||
|
if rm.RoomCount() != 3 {
|
||||||
|
t.Fatalf("RoomCount = %d, want 3", rm.RoomCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
rm.CloseAll()
|
||||||
|
if rm.RoomCount() != 0 {
|
||||||
|
t.Errorf("RoomCount after CloseAll = %d, want 0", rm.RoomCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomManagerGetOrCreateRoomReplacesClosedRoom(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
|
||||||
|
room1 := rm.GetOrCreateRoom("room-1")
|
||||||
|
room1.Close()
|
||||||
|
|
||||||
|
// Getting the same room ID after close should create a new room
|
||||||
|
room1New := rm.GetOrCreateRoom("room-1")
|
||||||
|
if room1New == room1 {
|
||||||
|
t.Error("expected new room instance after close")
|
||||||
|
}
|
||||||
|
if room1New.IsClosed() {
|
||||||
|
t.Error("new room should not be closed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Room tests ---
|
||||||
|
|
||||||
|
func TestRoomIsClosed(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
room := rm.GetOrCreateRoom("room-1")
|
||||||
|
|
||||||
|
if room.IsClosed() {
|
||||||
|
t.Error("new room should not be closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
room.Close()
|
||||||
|
if !room.IsClosed() {
|
||||||
|
t.Error("room should be closed after Close()")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomCloseIdempotent(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
room := rm.GetOrCreateRoom("room-1")
|
||||||
|
|
||||||
|
// Should not panic or error when called multiple times
|
||||||
|
if err := room.Close(); err != nil {
|
||||||
|
t.Errorf("first Close() returned error: %v", err)
|
||||||
|
}
|
||||||
|
if err := room.Close(); err != nil {
|
||||||
|
t.Errorf("second Close() returned error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomGetParticipantsEmpty(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
room := rm.GetOrCreateRoom("room-1")
|
||||||
|
|
||||||
|
participants := room.GetParticipants()
|
||||||
|
if len(participants) != 0 {
|
||||||
|
t.Errorf("Participants count = %d, want 0", len(participants))
|
||||||
|
}
|
||||||
|
if room.GetParticipantCount() != 0 {
|
||||||
|
t.Errorf("ParticipantCount = %d, want 0", room.GetParticipantCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomBuildICEServers(t *testing.T) {
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
room := rm.GetOrCreateRoom("room-1")
|
||||||
|
|
||||||
|
servers := room.buildICEServers()
|
||||||
|
if len(servers) != 1 {
|
||||||
|
t.Fatalf("ICE servers count = %d, want 1", len(servers))
|
||||||
|
}
|
||||||
|
if len(servers[0].URLs) != 1 {
|
||||||
|
t.Fatalf("URLs count = %d, want 1", len(servers[0].URLs))
|
||||||
|
}
|
||||||
|
if servers[0].URLs[0] != "turn:1.2.3.4:3478?transport=udp" {
|
||||||
|
t.Errorf("URL = %q, want %q", servers[0].URLs[0], "turn:1.2.3.4:3478?transport=udp")
|
||||||
|
}
|
||||||
|
if servers[0].Username == "" {
|
||||||
|
t.Error("Username should not be empty")
|
||||||
|
}
|
||||||
|
if servers[0].Credential == "" {
|
||||||
|
t.Error("Credential should not be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomBuildICEServersNoTURN(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.TURNServers = nil
|
||||||
|
|
||||||
|
rm := NewRoomManager(cfg, testLogger())
|
||||||
|
room := rm.GetOrCreateRoom("room-1")
|
||||||
|
|
||||||
|
servers := room.buildICEServers()
|
||||||
|
if servers != nil {
|
||||||
|
t.Errorf("expected nil ICE servers when no TURN configured, got %v", servers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomBuildICEServersNoSecret(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.TURNSecret = ""
|
||||||
|
|
||||||
|
rm := NewRoomManager(cfg, testLogger())
|
||||||
|
room := rm.GetOrCreateRoom("room-1")
|
||||||
|
|
||||||
|
servers := room.buildICEServers()
|
||||||
|
if servers != nil {
|
||||||
|
t.Errorf("expected nil ICE servers when no secret, got %v", servers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoomBuildICEServersMultipleTURN(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
cfg.TURNServers = []TURNServerConfig{
|
||||||
|
{Host: "1.2.3.4", Port: 3478},
|
||||||
|
{Host: "5.6.7.8", Port: 443},
|
||||||
|
}
|
||||||
|
|
||||||
|
rm := NewRoomManager(cfg, testLogger())
|
||||||
|
room := rm.GetOrCreateRoom("room-1")
|
||||||
|
|
||||||
|
servers := room.buildICEServers()
|
||||||
|
if len(servers) != 1 {
|
||||||
|
t.Fatalf("ICE servers count = %d, want 1", len(servers))
|
||||||
|
}
|
||||||
|
if len(servers[0].URLs) != 2 {
|
||||||
|
t.Fatalf("URLs count = %d, want 2", len(servers[0].URLs))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Empty room cleanup test ---
|
||||||
|
|
||||||
|
func TestEmptyRoomCleanup(t *testing.T) {
|
||||||
|
// Override timeAfter for instant timer
|
||||||
|
origTimeAfter := timeAfter
|
||||||
|
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||||
|
ch := make(chan time.Time, 1)
|
||||||
|
ch <- time.Now()
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
defer func() { timeAfter = origTimeAfter }()
|
||||||
|
|
||||||
|
rm := NewRoomManager(testConfig(), testLogger())
|
||||||
|
room := rm.GetOrCreateRoom("room-1")
|
||||||
|
|
||||||
|
// Trigger the onEmpty callback (which starts cleanup timer)
|
||||||
|
room.onEmpty(room)
|
||||||
|
|
||||||
|
// Give the goroutine time to execute
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
if rm.RoomCount() != 0 {
|
||||||
|
t.Errorf("RoomCount = %d, want 0 (should have been cleaned up)", rm.RoomCount())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Server health tests ---
|
||||||
|
|
||||||
|
func TestHealthEndpointOK(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
server, err := NewServer(cfg, testLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
server.handleHealth(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
body := w.Body.String()
|
||||||
|
if body != `{"status":"ok","rooms":0}` {
|
||||||
|
t.Errorf("body = %q, want %q", body, `{"status":"ok","rooms":0}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHealthEndpointDraining(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
server, err := NewServer(cfg, testLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set draining
|
||||||
|
server.drainingMu.Lock()
|
||||||
|
server.draining = true
|
||||||
|
server.drainingMu.Unlock()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
server.handleHealth(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
body := w.Body.String()
|
||||||
|
if body != `{"status":"draining","rooms":0}` {
|
||||||
|
t.Errorf("body = %q, want %q", body, `{"status":"draining","rooms":0}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerDrainSetsFlag(t *testing.T) {
|
||||||
|
// Override timeAfter for instant timer
|
||||||
|
origTimeAfter := timeAfter
|
||||||
|
timeAfter = func(d time.Duration) <-chan time.Time {
|
||||||
|
ch := make(chan time.Time, 1)
|
||||||
|
ch <- time.Now()
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
defer func() { timeAfter = origTimeAfter }()
|
||||||
|
|
||||||
|
cfg := testConfig()
|
||||||
|
server, err := NewServer(cfg, testLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.Drain(0)
|
||||||
|
|
||||||
|
server.drainingMu.RLock()
|
||||||
|
draining := server.draining
|
||||||
|
server.drainingMu.RUnlock()
|
||||||
|
|
||||||
|
if !draining {
|
||||||
|
t.Error("expected draining to be true after Drain()")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerNewServerValidation(t *testing.T) {
|
||||||
|
// Invalid config should return error
|
||||||
|
cfg := &Config{} // Empty = invalid
|
||||||
|
_, err := NewServer(cfg, testLogger())
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for invalid config")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerSignalEndpointRejectsDraining(t *testing.T) {
|
||||||
|
cfg := testConfig()
|
||||||
|
server, err := NewServer(cfg, testLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server.drainingMu.Lock()
|
||||||
|
server.draining = true
|
||||||
|
server.drainingMu.Unlock()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/ws/signal", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
server.handleSignal(w, req)
|
||||||
|
|
||||||
|
if w.Code != http.StatusServiceUnavailable {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusServiceUnavailable)
|
||||||
|
}
|
||||||
|
}
|
||||||
293
pkg/sfu/server.go
Normal file
293
pkg/sfu/server.go
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server is the SFU HTTP server providing WebSocket signaling and a health endpoint.
|
||||||
|
// It binds only to a WireGuard IP — never exposed publicly.
|
||||||
|
type Server struct {
|
||||||
|
config *Config
|
||||||
|
roomManager *RoomManager
|
||||||
|
logger *zap.Logger
|
||||||
|
httpServer *http.Server
|
||||||
|
upgrader websocket.Upgrader
|
||||||
|
draining bool
|
||||||
|
drainingMu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new SFU server.
|
||||||
|
func NewServer(cfg *Config, logger *zap.Logger) (*Server, error) {
|
||||||
|
if errs := cfg.Validate(); len(errs) > 0 {
|
||||||
|
return nil, fmt.Errorf("invalid SFU config: %v", errs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Server{
|
||||||
|
config: cfg,
|
||||||
|
roomManager: NewRoomManager(cfg, logger),
|
||||||
|
logger: logger.With(zap.String("component", "sfu"), zap.String("namespace", cfg.Namespace)),
|
||||||
|
upgrader: websocket.Upgrader{
|
||||||
|
ReadBufferSize: 4096,
|
||||||
|
WriteBufferSize: 4096,
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true }, // Gateway handles auth
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/ws/signal", s.handleSignal)
|
||||||
|
mux.HandleFunc("/health", s.handleHealth)
|
||||||
|
|
||||||
|
s.httpServer = &http.Server{
|
||||||
|
Addr: cfg.ListenAddr,
|
||||||
|
Handler: mux,
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServe starts the HTTP server. Blocks until the server is stopped.
|
||||||
|
func (s *Server) ListenAndServe() error {
|
||||||
|
s.logger.Info("SFU server starting",
|
||||||
|
zap.String("addr", s.config.ListenAddr),
|
||||||
|
zap.String("namespace", s.config.Namespace))
|
||||||
|
return s.httpServer.ListenAndServe()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain initiates graceful drain: notifies all peers, waits, then closes.
|
||||||
|
func (s *Server) Drain(timeout time.Duration) {
|
||||||
|
s.drainingMu.Lock()
|
||||||
|
s.draining = true
|
||||||
|
s.drainingMu.Unlock()
|
||||||
|
|
||||||
|
s.logger.Info("SFU draining started", zap.Duration("timeout", timeout))
|
||||||
|
|
||||||
|
// Notify all peers
|
||||||
|
s.roomManager.mu.RLock()
|
||||||
|
for _, room := range s.roomManager.rooms {
|
||||||
|
room.broadcastMessage("", NewServerMessage(MessageTypeServerDraining, &ServerDrainingData{
|
||||||
|
Reason: "server shutting down",
|
||||||
|
TimeoutMs: int(timeout.Milliseconds()),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
s.roomManager.mu.RUnlock()
|
||||||
|
|
||||||
|
// Wait for timeout, then force close
|
||||||
|
<-timeAfter(timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close shuts down the SFU server.
|
||||||
|
func (s *Server) Close() error {
|
||||||
|
s.logger.Info("SFU server shutting down")
|
||||||
|
s.roomManager.CloseAll()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
return s.httpServer.Shutdown(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHealth is a simple health check endpoint.
|
||||||
|
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.drainingMu.RLock()
|
||||||
|
draining := s.draining
|
||||||
|
s.drainingMu.RUnlock()
|
||||||
|
|
||||||
|
if draining {
|
||||||
|
w.WriteHeader(http.StatusServiceUnavailable)
|
||||||
|
fmt.Fprintf(w, `{"status":"draining","rooms":%d}`, s.roomManager.RoomCount())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, `{"status":"ok","rooms":%d}`, s.roomManager.RoomCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSignal upgrades to WebSocket and runs the signaling loop for one peer.
|
||||||
|
func (s *Server) handleSignal(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.drainingMu.RLock()
|
||||||
|
if s.draining {
|
||||||
|
s.drainingMu.RUnlock()
|
||||||
|
http.Error(w, "server draining", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.drainingMu.RUnlock()
|
||||||
|
|
||||||
|
conn, err := s.upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Debug("WebSocket connected", zap.String("remote", r.RemoteAddr))
|
||||||
|
|
||||||
|
// Read the first message — must be a join
|
||||||
|
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||||
|
_, msgBytes, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("Failed to read join message", zap.Error(err))
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.SetReadDeadline(time.Time{}) // Clear deadline
|
||||||
|
|
||||||
|
var msg ClientMessage
|
||||||
|
if err := json.Unmarshal(msgBytes, &msg); err != nil {
|
||||||
|
conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_message", "malformed JSON")))
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.Type != MessageTypeJoin {
|
||||||
|
conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_message", "first message must be join")))
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var joinData JoinData
|
||||||
|
if err := json.Unmarshal(msg.Data, &joinData); err != nil || joinData.RoomID == "" || joinData.UserID == "" {
|
||||||
|
conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("invalid_join", "roomId and userId required")))
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
room := s.roomManager.GetOrCreateRoom(joinData.RoomID)
|
||||||
|
peer := NewPeer(joinData.UserID, conn, room, s.logger)
|
||||||
|
|
||||||
|
if err := room.AddPeer(peer); err != nil {
|
||||||
|
conn.WriteMessage(websocket.TextMessage, mustMarshal(NewErrorMessage("join_failed", err.Error())))
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send welcome with current participants
|
||||||
|
peer.SendMessage(NewServerMessage(MessageTypeWelcome, &WelcomeData{
|
||||||
|
PeerID: peer.ID,
|
||||||
|
RoomID: room.ID,
|
||||||
|
Participants: room.GetParticipants(),
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Send TURN credentials
|
||||||
|
if s.config.TURNSecret != "" && len(s.config.TURNServers) > 0 {
|
||||||
|
s.sendTURNCredentials(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send existing tracks from other peers
|
||||||
|
room.SendExistingTracksTo(peer)
|
||||||
|
|
||||||
|
// Start credential refresh goroutine
|
||||||
|
if s.config.TURNCredentialTTL > 0 {
|
||||||
|
go s.credentialRefreshLoop(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signaling read loop
|
||||||
|
s.signalingLoop(peer, room)
|
||||||
|
}
|
||||||
|
|
||||||
|
// signalingLoop reads signaling messages from the WebSocket until disconnect.
|
||||||
|
func (s *Server) signalingLoop(peer *Peer, room *Room) {
|
||||||
|
defer room.RemovePeer(peer.ID)
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, msgBytes, err := peer.conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Debug("WebSocket read error", zap.String("peer_id", peer.ID), zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var msg ClientMessage
|
||||||
|
if err := json.Unmarshal(msgBytes, &msg); err != nil {
|
||||||
|
peer.SendMessage(NewErrorMessage("invalid_message", "malformed JSON"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch msg.Type {
|
||||||
|
case MessageTypeOffer:
|
||||||
|
var data OfferData
|
||||||
|
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||||
|
peer.SendMessage(NewErrorMessage("invalid_offer", err.Error()))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := peer.HandleOffer(data.SDP); err != nil {
|
||||||
|
s.logger.Error("Failed to handle offer", zap.String("peer_id", peer.ID), zap.Error(err))
|
||||||
|
peer.SendMessage(NewErrorMessage("offer_failed", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
case MessageTypeAnswer:
|
||||||
|
var data AnswerData
|
||||||
|
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||||
|
peer.SendMessage(NewErrorMessage("invalid_answer", err.Error()))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := peer.HandleAnswer(data.SDP); err != nil {
|
||||||
|
s.logger.Error("Failed to handle answer", zap.String("peer_id", peer.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
case MessageTypeICECandidate:
|
||||||
|
var data ICECandidateData
|
||||||
|
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||||
|
peer.SendMessage(NewErrorMessage("invalid_candidate", err.Error()))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := peer.HandleICECandidate(&data); err != nil {
|
||||||
|
s.logger.Error("Failed to handle ICE candidate", zap.String("peer_id", peer.ID), zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
case MessageTypeLeave:
|
||||||
|
s.logger.Info("Peer leaving", zap.String("peer_id", peer.ID))
|
||||||
|
return
|
||||||
|
|
||||||
|
default:
|
||||||
|
peer.SendMessage(NewErrorMessage("unknown_message", fmt.Sprintf("unknown message type: %s", msg.Type)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendTURNCredentials sends TURN server credentials to a peer.
|
||||||
|
func (s *Server) sendTURNCredentials(peer *Peer) {
|
||||||
|
ttl := time.Duration(s.config.TURNCredentialTTL) * time.Second
|
||||||
|
username, password := turn.GenerateCredentials(s.config.TURNSecret, s.config.Namespace, ttl)
|
||||||
|
|
||||||
|
var uris []string
|
||||||
|
for _, ts := range s.config.TURNServers {
|
||||||
|
uris = append(uris, fmt.Sprintf("turn:%s:%d?transport=udp", ts.Host, ts.Port))
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.SendMessage(NewServerMessage(MessageTypeTURNCredentials, &TURNCredentialsData{
|
||||||
|
Username: username,
|
||||||
|
Password: password,
|
||||||
|
TTL: s.config.TURNCredentialTTL,
|
||||||
|
URIs: uris,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// credentialRefreshLoop sends fresh TURN credentials at 80% of TTL.
|
||||||
|
func (s *Server) credentialRefreshLoop(peer *Peer) {
|
||||||
|
refreshInterval := time.Duration(float64(s.config.TURNCredentialTTL)*0.8) * time.Second
|
||||||
|
|
||||||
|
for {
|
||||||
|
<-timeAfter(refreshInterval)
|
||||||
|
|
||||||
|
peer.closedMu.RLock()
|
||||||
|
closed := peer.closed
|
||||||
|
peer.closedMu.RUnlock()
|
||||||
|
if closed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.sendTURNCredentials(peer)
|
||||||
|
s.logger.Debug("Refreshed TURN credentials", zap.String("peer_id", peer.ID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMarshal(v interface{}) []byte {
|
||||||
|
data, _ := json.Marshal(v)
|
||||||
|
return data
|
||||||
|
}
|
||||||
144
pkg/sfu/signaling.go
Normal file
144
pkg/sfu/signaling.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageType represents the type of signaling message
|
||||||
|
type MessageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Client → Server
|
||||||
|
MessageTypeJoin MessageType = "join"
|
||||||
|
MessageTypeLeave MessageType = "leave"
|
||||||
|
MessageTypeOffer MessageType = "offer"
|
||||||
|
MessageTypeAnswer MessageType = "answer"
|
||||||
|
MessageTypeICECandidate MessageType = "ice-candidate"
|
||||||
|
|
||||||
|
// Server → Client
|
||||||
|
MessageTypeWelcome MessageType = "welcome"
|
||||||
|
MessageTypeParticipantJoined MessageType = "participant-joined"
|
||||||
|
MessageTypeParticipantLeft MessageType = "participant-left"
|
||||||
|
MessageTypeTrackAdded MessageType = "track-added"
|
||||||
|
MessageTypeTrackRemoved MessageType = "track-removed"
|
||||||
|
MessageTypeTURNCredentials MessageType = "turn-credentials"
|
||||||
|
MessageTypeRefreshCredentials MessageType = "refresh-credentials"
|
||||||
|
MessageTypeServerDraining MessageType = "server-draining"
|
||||||
|
MessageTypeError MessageType = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientMessage is a message from client to server
|
||||||
|
type ClientMessage struct {
|
||||||
|
Type MessageType `json:"type"`
|
||||||
|
Data json.RawMessage `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerMessage is a message from server to client
|
||||||
|
type ServerMessage struct {
|
||||||
|
Type MessageType `json:"type"`
|
||||||
|
Data interface{} `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinData is the payload for join messages
|
||||||
|
type JoinData struct {
|
||||||
|
RoomID string `json:"roomId"`
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OfferData is the payload for SDP offer messages
|
||||||
|
type OfferData struct {
|
||||||
|
SDP string `json:"sdp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AnswerData is the payload for SDP answer messages
|
||||||
|
type AnswerData struct {
|
||||||
|
SDP string `json:"sdp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICECandidateData is the payload for ICE candidate messages
|
||||||
|
type ICECandidateData struct {
|
||||||
|
Candidate string `json:"candidate"`
|
||||||
|
SDPMid string `json:"sdpMid,omitempty"`
|
||||||
|
SDPMLineIndex uint16 `json:"sdpMLineIndex,omitempty"`
|
||||||
|
UsernameFragment string `json:"usernameFragment,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToWebRTCCandidate converts to pion ICECandidateInit
|
||||||
|
func (c *ICECandidateData) ToWebRTCCandidate() webrtc.ICECandidateInit {
|
||||||
|
return webrtc.ICECandidateInit{
|
||||||
|
Candidate: c.Candidate,
|
||||||
|
SDPMid: &c.SDPMid,
|
||||||
|
SDPMLineIndex: &c.SDPMLineIndex,
|
||||||
|
UsernameFragment: &c.UsernameFragment,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WelcomeData is sent when a peer successfully joins a room
|
||||||
|
type WelcomeData struct {
|
||||||
|
PeerID string `json:"peerId"`
|
||||||
|
RoomID string `json:"roomId"`
|
||||||
|
Participants []ParticipantInfo `json:"participants"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantInfo is public info about a room participant
|
||||||
|
type ParticipantInfo struct {
|
||||||
|
PeerID string `json:"peerId"`
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantJoinedData is sent when a new participant joins
|
||||||
|
type ParticipantJoinedData struct {
|
||||||
|
Participant ParticipantInfo `json:"participant"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantLeftData is sent when a participant leaves
|
||||||
|
type ParticipantLeftData struct {
|
||||||
|
PeerID string `json:"peerId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackAddedData is sent when a new track is available
|
||||||
|
type TrackAddedData struct {
|
||||||
|
PeerID string `json:"peerId"`
|
||||||
|
TrackID string `json:"trackId"`
|
||||||
|
StreamID string `json:"streamId"`
|
||||||
|
Kind string `json:"kind"` // "audio" or "video"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackRemovedData is sent when a track is removed
|
||||||
|
type TrackRemovedData struct {
|
||||||
|
PeerID string `json:"peerId"`
|
||||||
|
TrackID string `json:"trackId"`
|
||||||
|
Kind string `json:"kind"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TURNCredentialsData provides TURN server credentials
|
||||||
|
type TURNCredentialsData struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
TTL int `json:"ttl"`
|
||||||
|
URIs []string `json:"uris"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerDrainingData warns clients the server is shutting down
|
||||||
|
type ServerDrainingData struct {
|
||||||
|
Reason string `json:"reason"`
|
||||||
|
TimeoutMs int `json:"timeoutMs"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorData is sent when an error occurs
|
||||||
|
type ErrorData struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerMessage creates a new server message
|
||||||
|
func NewServerMessage(msgType MessageType, data interface{}) *ServerMessage {
|
||||||
|
return &ServerMessage{Type: msgType, Data: data}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrorMessage creates a new error message
|
||||||
|
func NewErrorMessage(code, message string) *ServerMessage {
|
||||||
|
return NewServerMessage(MessageTypeError, &ErrorData{Code: code, Message: message})
|
||||||
|
}
|
||||||
257
pkg/sfu/signaling_test.go
Normal file
257
pkg/sfu/signaling_test.go
Normal file
@ -0,0 +1,257 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientMessageDeserialization(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantType MessageType
|
||||||
|
wantData bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "join message",
|
||||||
|
input: `{"type":"join","data":{"roomId":"room-1","userId":"user-1"}}`,
|
||||||
|
wantType: MessageTypeJoin,
|
||||||
|
wantData: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "leave message",
|
||||||
|
input: `{"type":"leave"}`,
|
||||||
|
wantType: MessageTypeLeave,
|
||||||
|
wantData: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "offer message",
|
||||||
|
input: `{"type":"offer","data":{"sdp":"v=0..."}}`,
|
||||||
|
wantType: MessageTypeOffer,
|
||||||
|
wantData: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "answer message",
|
||||||
|
input: `{"type":"answer","data":{"sdp":"v=0..."}}`,
|
||||||
|
wantType: MessageTypeAnswer,
|
||||||
|
wantData: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ice-candidate message",
|
||||||
|
input: `{"type":"ice-candidate","data":{"candidate":"candidate:1234"}}`,
|
||||||
|
wantType: MessageTypeICECandidate,
|
||||||
|
wantData: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var msg ClientMessage
|
||||||
|
if err := json.Unmarshal([]byte(tt.input), &msg); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if msg.Type != tt.wantType {
|
||||||
|
t.Errorf("Type = %q, want %q", msg.Type, tt.wantType)
|
||||||
|
}
|
||||||
|
if tt.wantData && msg.Data == nil {
|
||||||
|
t.Error("expected Data to be non-nil")
|
||||||
|
}
|
||||||
|
if !tt.wantData && msg.Data != nil {
|
||||||
|
t.Error("expected Data to be nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJoinDataDeserialization(t *testing.T) {
|
||||||
|
input := `{"roomId":"room-abc","userId":"user-xyz"}`
|
||||||
|
var data JoinData
|
||||||
|
if err := json.Unmarshal([]byte(input), &data); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if data.RoomID != "room-abc" {
|
||||||
|
t.Errorf("RoomID = %q, want %q", data.RoomID, "room-abc")
|
||||||
|
}
|
||||||
|
if data.UserID != "user-xyz" {
|
||||||
|
t.Errorf("UserID = %q, want %q", data.UserID, "user-xyz")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerMessageSerialization(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
msg *ServerMessage
|
||||||
|
wantKey string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "welcome message",
|
||||||
|
msg: NewServerMessage(MessageTypeWelcome, &WelcomeData{PeerID: "p1", RoomID: "r1"}),
|
||||||
|
wantKey: "welcome",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "participant joined",
|
||||||
|
msg: NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{Participant: ParticipantInfo{PeerID: "p2", UserID: "u2"}}),
|
||||||
|
wantKey: "participant-joined",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "participant left",
|
||||||
|
msg: NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{PeerID: "p2"}),
|
||||||
|
wantKey: "participant-left",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "track added",
|
||||||
|
msg: NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{PeerID: "p1", TrackID: "t1", StreamID: "s1", Kind: "video"}),
|
||||||
|
wantKey: "track-added",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "track removed",
|
||||||
|
msg: NewServerMessage(MessageTypeTrackRemoved, &TrackRemovedData{PeerID: "p1", TrackID: "t1", Kind: "video"}),
|
||||||
|
wantKey: "track-removed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TURN credentials",
|
||||||
|
msg: NewServerMessage(MessageTypeTURNCredentials, &TURNCredentialsData{Username: "u", Password: "p", TTL: 600, URIs: []string{"turn:1.2.3.4:3478"}}),
|
||||||
|
wantKey: "turn-credentials",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server draining",
|
||||||
|
msg: NewServerMessage(MessageTypeServerDraining, &ServerDrainingData{Reason: "shutdown", TimeoutMs: 30000}),
|
||||||
|
wantKey: "server-draining",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
data, err := json.Marshal(tt.msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it roundtrips correctly
|
||||||
|
var raw map[string]json.RawMessage
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal to raw: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var msgType string
|
||||||
|
if err := json.Unmarshal(raw["type"], &msgType); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal type: %v", err)
|
||||||
|
}
|
||||||
|
if msgType != tt.wantKey {
|
||||||
|
t.Errorf("type = %q, want %q", msgType, tt.wantKey)
|
||||||
|
}
|
||||||
|
if _, ok := raw["data"]; !ok {
|
||||||
|
t.Error("expected data field in output")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewErrorMessage(t *testing.T) {
|
||||||
|
msg := NewErrorMessage("invalid_offer", "bad SDP")
|
||||||
|
if msg.Type != MessageTypeError {
|
||||||
|
t.Errorf("Type = %q, want %q", msg.Type, MessageTypeError)
|
||||||
|
}
|
||||||
|
|
||||||
|
errData, ok := msg.Data.(*ErrorData)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("Data is not *ErrorData")
|
||||||
|
}
|
||||||
|
if errData.Code != "invalid_offer" {
|
||||||
|
t.Errorf("Code = %q, want %q", errData.Code, "invalid_offer")
|
||||||
|
}
|
||||||
|
if errData.Message != "bad SDP" {
|
||||||
|
t.Errorf("Message = %q, want %q", errData.Message, "bad SDP")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify serialization
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal: %v", err)
|
||||||
|
}
|
||||||
|
result := string(data)
|
||||||
|
if result == "" {
|
||||||
|
t.Error("expected non-empty serialized output")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestICECandidateDataToWebRTCCandidate(t *testing.T) {
|
||||||
|
data := &ICECandidateData{
|
||||||
|
Candidate: "candidate:842163049 1 udp 1677729535 203.0.113.1 3478 typ srflx",
|
||||||
|
SDPMid: "0",
|
||||||
|
SDPMLineIndex: 0,
|
||||||
|
UsernameFragment: "abc123",
|
||||||
|
}
|
||||||
|
|
||||||
|
candidate := data.ToWebRTCCandidate()
|
||||||
|
if candidate.Candidate != data.Candidate {
|
||||||
|
t.Errorf("Candidate = %q, want %q", candidate.Candidate, data.Candidate)
|
||||||
|
}
|
||||||
|
if candidate.SDPMid == nil || *candidate.SDPMid != "0" {
|
||||||
|
t.Error("SDPMid should be pointer to '0'")
|
||||||
|
}
|
||||||
|
if candidate.SDPMLineIndex == nil || *candidate.SDPMLineIndex != 0 {
|
||||||
|
t.Error("SDPMLineIndex should be pointer to 0")
|
||||||
|
}
|
||||||
|
if candidate.UsernameFragment == nil || *candidate.UsernameFragment != "abc123" {
|
||||||
|
t.Error("UsernameFragment should be pointer to 'abc123'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWelcomeDataSerialization(t *testing.T) {
|
||||||
|
welcome := &WelcomeData{
|
||||||
|
PeerID: "peer-123",
|
||||||
|
RoomID: "room-456",
|
||||||
|
Participants: []ParticipantInfo{
|
||||||
|
{PeerID: "peer-001", UserID: "user-001"},
|
||||||
|
{PeerID: "peer-002", UserID: "user-002"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(welcome)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result WelcomeData
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if result.PeerID != "peer-123" {
|
||||||
|
t.Errorf("PeerID = %q, want %q", result.PeerID, "peer-123")
|
||||||
|
}
|
||||||
|
if result.RoomID != "room-456" {
|
||||||
|
t.Errorf("RoomID = %q, want %q", result.RoomID, "room-456")
|
||||||
|
}
|
||||||
|
if len(result.Participants) != 2 {
|
||||||
|
t.Errorf("Participants count = %d, want 2", len(result.Participants))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTURNCredentialsDataSerialization(t *testing.T) {
|
||||||
|
creds := &TURNCredentialsData{
|
||||||
|
Username: "1234567890:test-ns",
|
||||||
|
Password: "base64password==",
|
||||||
|
TTL: 600,
|
||||||
|
URIs: []string{"turn:1.2.3.4:3478?transport=udp", "turn:5.6.7.8:443?transport=udp"},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(creds)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to marshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var result TURNCredentialsData
|
||||||
|
if err := json.Unmarshal(data, &result); err != nil {
|
||||||
|
t.Fatalf("failed to unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if result.Username != creds.Username {
|
||||||
|
t.Errorf("Username = %q, want %q", result.Username, creds.Username)
|
||||||
|
}
|
||||||
|
if result.TTL != 600 {
|
||||||
|
t.Errorf("TTL = %d, want 600", result.TTL)
|
||||||
|
}
|
||||||
|
if len(result.URIs) != 2 {
|
||||||
|
t.Errorf("URIs count = %d, want 2", len(result.URIs))
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -17,6 +17,8 @@ const (
|
|||||||
ServiceTypeRQLite ServiceType = "rqlite"
|
ServiceTypeRQLite ServiceType = "rqlite"
|
||||||
ServiceTypeOlric ServiceType = "olric"
|
ServiceTypeOlric ServiceType = "olric"
|
||||||
ServiceTypeGateway ServiceType = "gateway"
|
ServiceTypeGateway ServiceType = "gateway"
|
||||||
|
ServiceTypeSFU ServiceType = "sfu"
|
||||||
|
ServiceTypeTURN ServiceType = "turn"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Manager manages systemd units for namespace services
|
// Manager manages systemd units for namespace services
|
||||||
@ -192,13 +194,33 @@ func (m *Manager) ReloadDaemon() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// serviceExists checks if a namespace service has an env file on disk,
|
||||||
|
// indicating the service was provisioned for this namespace.
|
||||||
|
func (m *Manager) serviceExists(namespace string, serviceType ServiceType) bool {
|
||||||
|
envFile := filepath.Join(m.namespaceBase, namespace, fmt.Sprintf("%s.env", serviceType))
|
||||||
|
_, err := os.Stat(envFile)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
// StopAllNamespaceServices stops all namespace services for a given namespace
|
// StopAllNamespaceServices stops all namespace services for a given namespace
|
||||||
func (m *Manager) StopAllNamespaceServices(namespace string) error {
|
func (m *Manager) StopAllNamespaceServices(namespace string) error {
|
||||||
m.logger.Info("Stopping all namespace services", zap.String("namespace", namespace))
|
m.logger.Info("Stopping all namespace services", zap.String("namespace", namespace))
|
||||||
|
|
||||||
// Stop in reverse dependency order: Gateway → Olric → RQLite
|
// Stop in reverse dependency order: SFU → TURN → Gateway → Olric → RQLite
|
||||||
services := []ServiceType{ServiceTypeGateway, ServiceTypeOlric, ServiceTypeRQLite}
|
// SFU and TURN are conditional — only stop if they exist
|
||||||
for _, svcType := range services {
|
for _, svcType := range []ServiceType{ServiceTypeSFU, ServiceTypeTURN} {
|
||||||
|
if m.serviceExists(namespace, svcType) {
|
||||||
|
if err := m.StopService(namespace, svcType); err != nil {
|
||||||
|
m.logger.Warn("Failed to stop service",
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("service_type", string(svcType)),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Core services always exist
|
||||||
|
for _, svcType := range []ServiceType{ServiceTypeGateway, ServiceTypeOlric, ServiceTypeRQLite} {
|
||||||
if err := m.StopService(namespace, svcType); err != nil {
|
if err := m.StopService(namespace, svcType); err != nil {
|
||||||
m.logger.Warn("Failed to stop service",
|
m.logger.Warn("Failed to stop service",
|
||||||
zap.String("namespace", namespace),
|
zap.String("namespace", namespace),
|
||||||
@ -215,14 +237,22 @@ func (m *Manager) StopAllNamespaceServices(namespace string) error {
|
|||||||
func (m *Manager) StartAllNamespaceServices(namespace string) error {
|
func (m *Manager) StartAllNamespaceServices(namespace string) error {
|
||||||
m.logger.Info("Starting all namespace services", zap.String("namespace", namespace))
|
m.logger.Info("Starting all namespace services", zap.String("namespace", namespace))
|
||||||
|
|
||||||
// Start in dependency order: RQLite → Olric → Gateway
|
// Start core services in dependency order: RQLite → Olric → Gateway
|
||||||
services := []ServiceType{ServiceTypeRQLite, ServiceTypeOlric, ServiceTypeGateway}
|
for _, svcType := range []ServiceType{ServiceTypeRQLite, ServiceTypeOlric, ServiceTypeGateway} {
|
||||||
for _, svcType := range services {
|
|
||||||
if err := m.StartService(namespace, svcType); err != nil {
|
if err := m.StartService(namespace, svcType); err != nil {
|
||||||
return fmt.Errorf("failed to start %s service: %w", svcType, err)
|
return fmt.Errorf("failed to start %s service: %w", svcType, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start WebRTC services if provisioned: TURN → SFU
|
||||||
|
for _, svcType := range []ServiceType{ServiceTypeTURN, ServiceTypeSFU} {
|
||||||
|
if m.serviceExists(namespace, svcType) {
|
||||||
|
if err := m.StartService(namespace, svcType); err != nil {
|
||||||
|
return fmt.Errorf("failed to start %s service: %w", svcType, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -419,6 +449,8 @@ func (m *Manager) InstallTemplateUnits(sourceDir string) error {
|
|||||||
"orama-namespace-rqlite@.service",
|
"orama-namespace-rqlite@.service",
|
||||||
"orama-namespace-olric@.service",
|
"orama-namespace-olric@.service",
|
||||||
"orama-namespace-gateway@.service",
|
"orama-namespace-gateway@.service",
|
||||||
|
"orama-namespace-sfu@.service",
|
||||||
|
"orama-namespace-turn@.service",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, template := range templates {
|
for _, template := range templates {
|
||||||
|
|||||||
71
pkg/turn/config.go
Normal file
71
pkg/turn/config.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package turn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds configuration for the TURN server
|
||||||
|
type Config struct {
|
||||||
|
// ListenAddr is the address to bind the TURN listener (e.g., "0.0.0.0:3478")
|
||||||
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
|
|
||||||
|
// TLSListenAddr is the address for TURN over TLS/DTLS (e.g., "0.0.0.0:443")
|
||||||
|
// Uses UDP 443 which does not conflict with Caddy's TCP 443
|
||||||
|
TLSListenAddr string `yaml:"tls_listen_addr"`
|
||||||
|
|
||||||
|
// PublicIP is the public IP address of this node, advertised in TURN allocations
|
||||||
|
PublicIP string `yaml:"public_ip"`
|
||||||
|
|
||||||
|
// Realm is the TURN realm (typically the base domain)
|
||||||
|
Realm string `yaml:"realm"`
|
||||||
|
|
||||||
|
// AuthSecret is the HMAC-SHA1 shared secret for credential validation
|
||||||
|
AuthSecret string `yaml:"auth_secret"`
|
||||||
|
|
||||||
|
// RelayPortStart is the beginning of the UDP relay port range
|
||||||
|
RelayPortStart int `yaml:"relay_port_start"`
|
||||||
|
|
||||||
|
// RelayPortEnd is the end of the UDP relay port range
|
||||||
|
RelayPortEnd int `yaml:"relay_port_end"`
|
||||||
|
|
||||||
|
// Namespace this TURN instance belongs to
|
||||||
|
Namespace string `yaml:"namespace"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks the TURN configuration for errors
|
||||||
|
func (c *Config) Validate() []error {
|
||||||
|
var errs []error
|
||||||
|
|
||||||
|
if c.ListenAddr == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.listen_addr: must not be empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.PublicIP == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.public_ip: must not be empty"))
|
||||||
|
} else if ip := net.ParseIP(c.PublicIP); ip == nil {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.public_ip: %q is not a valid IP address", c.PublicIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Realm == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.realm: must not be empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.AuthSecret == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.auth_secret: must not be empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.RelayPortStart <= 0 || c.RelayPortEnd <= 0 {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.relay_port_range: start and end must be positive"))
|
||||||
|
} else if c.RelayPortEnd <= c.RelayPortStart {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.relay_port_range: end (%d) must be greater than start (%d)", c.RelayPortEnd, c.RelayPortStart))
|
||||||
|
} else if c.RelayPortEnd-c.RelayPortStart < 100 {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.relay_port_range: range must be at least 100 ports (got %d)", c.RelayPortEnd-c.RelayPortStart))
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Namespace == "" {
|
||||||
|
errs = append(errs, fmt.Errorf("turn.namespace: must not be empty"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return errs
|
||||||
|
}
|
||||||
228
pkg/turn/server.go
Normal file
228
pkg/turn/server.go
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
package turn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
pionTurn "github.com/pion/turn/v4"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server wraps a Pion TURN server with namespace-scoped HMAC-SHA1 authentication.
|
||||||
|
type Server struct {
|
||||||
|
config *Config
|
||||||
|
logger *zap.Logger
|
||||||
|
turnServer *pionTurn.Server
|
||||||
|
conn net.PacketConn // UDP listener on primary port (3478)
|
||||||
|
tlsConn net.PacketConn // UDP listener on TLS port (443)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates and starts a TURN server.
|
||||||
|
func NewServer(cfg *Config, logger *zap.Logger) (*Server, error) {
|
||||||
|
if errs := cfg.Validate(); len(errs) > 0 {
|
||||||
|
return nil, fmt.Errorf("invalid TURN config: %v", errs[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
relayIP := net.ParseIP(cfg.PublicIP)
|
||||||
|
if relayIP == nil {
|
||||||
|
return nil, fmt.Errorf("turn.public_ip: %q is not a valid IP address", cfg.PublicIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Server{
|
||||||
|
config: cfg,
|
||||||
|
logger: logger.With(zap.String("component", "turn"), zap.String("namespace", cfg.Namespace)),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create primary UDP listener (port 3478)
|
||||||
|
conn, err := net.ListenPacket("udp4", cfg.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to listen on %s: %w", cfg.ListenAddr, err)
|
||||||
|
}
|
||||||
|
s.conn = conn
|
||||||
|
|
||||||
|
packetConfigs := []pionTurn.PacketConnConfig{
|
||||||
|
{
|
||||||
|
PacketConn: conn,
|
||||||
|
RelayAddressGenerator: &pionTurn.RelayAddressGeneratorPortRange{
|
||||||
|
RelayAddress: relayIP,
|
||||||
|
Address: "0.0.0.0",
|
||||||
|
MinPort: uint16(cfg.RelayPortStart),
|
||||||
|
MaxPort: uint16(cfg.RelayPortEnd),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS UDP listener (port 443) if configured
|
||||||
|
// UDP 443 does not conflict with Caddy's TCP 443
|
||||||
|
if cfg.TLSListenAddr != "" {
|
||||||
|
tlsConn, err := net.ListenPacket("udp4", cfg.TLSListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("failed to listen on %s: %w", cfg.TLSListenAddr, err)
|
||||||
|
}
|
||||||
|
s.tlsConn = tlsConn
|
||||||
|
|
||||||
|
packetConfigs = append(packetConfigs, pionTurn.PacketConnConfig{
|
||||||
|
PacketConn: tlsConn,
|
||||||
|
RelayAddressGenerator: &pionTurn.RelayAddressGeneratorPortRange{
|
||||||
|
RelayAddress: relayIP,
|
||||||
|
Address: "0.0.0.0",
|
||||||
|
MinPort: uint16(cfg.RelayPortStart),
|
||||||
|
MaxPort: uint16(cfg.RelayPortEnd),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TURN server with HMAC-SHA1 auth
|
||||||
|
turnServer, err := pionTurn.NewServer(pionTurn.ServerConfig{
|
||||||
|
Realm: cfg.Realm,
|
||||||
|
AuthHandler: func(username, realm string, srcAddr net.Addr) ([]byte, bool) {
|
||||||
|
return s.authHandler(username, realm, srcAddr)
|
||||||
|
},
|
||||||
|
PacketConnConfigs: packetConfigs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
s.closeListeners()
|
||||||
|
return nil, fmt.Errorf("failed to create TURN server: %w", err)
|
||||||
|
}
|
||||||
|
s.turnServer = turnServer
|
||||||
|
|
||||||
|
s.logger.Info("TURN server started",
|
||||||
|
zap.String("listen_addr", cfg.ListenAddr),
|
||||||
|
zap.String("tls_listen_addr", cfg.TLSListenAddr),
|
||||||
|
zap.String("public_ip", cfg.PublicIP),
|
||||||
|
zap.String("realm", cfg.Realm),
|
||||||
|
zap.Int("relay_port_start", cfg.RelayPortStart),
|
||||||
|
zap.Int("relay_port_end", cfg.RelayPortEnd),
|
||||||
|
)
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authHandler validates HMAC-SHA1 credentials.
|
||||||
|
// Username format: {expiry_unix}:{namespace}
|
||||||
|
// Password: base64(HMAC-SHA1(shared_secret, username))
|
||||||
|
func (s *Server) authHandler(username, realm string, srcAddr net.Addr) ([]byte, bool) {
|
||||||
|
// Parse username: must be "{timestamp}:{namespace}"
|
||||||
|
parts := strings.SplitN(username, ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
s.logger.Debug("Malformed TURN username: expected timestamp:namespace",
|
||||||
|
zap.String("username", username),
|
||||||
|
zap.String("src_addr", srcAddr.String()))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp, err := strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Debug("Invalid timestamp in TURN username",
|
||||||
|
zap.String("username", username),
|
||||||
|
zap.String("src_addr", srcAddr.String()))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := parts[1]
|
||||||
|
|
||||||
|
// Verify namespace matches this TURN server's namespace
|
||||||
|
if ns != s.config.Namespace {
|
||||||
|
s.logger.Debug("TURN credential namespace mismatch",
|
||||||
|
zap.String("credential_namespace", ns),
|
||||||
|
zap.String("server_namespace", s.config.Namespace),
|
||||||
|
zap.String("src_addr", srcAddr.String()))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiry — credential must not be expired
|
||||||
|
if timestamp <= time.Now().Unix() {
|
||||||
|
s.logger.Debug("TURN credential expired",
|
||||||
|
zap.String("username", username),
|
||||||
|
zap.Int64("expired_at", timestamp),
|
||||||
|
zap.String("src_addr", srcAddr.String()))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate expected password and derive auth key
|
||||||
|
password := GeneratePassword(s.config.AuthSecret, username)
|
||||||
|
key := pionTurn.GenerateAuthKey(username, realm, password)
|
||||||
|
|
||||||
|
s.logger.Debug("TURN auth accepted",
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.String("src_addr", srcAddr.String()))
|
||||||
|
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close gracefully shuts down the TURN server.
|
||||||
|
func (s *Server) Close() error {
|
||||||
|
s.logger.Info("Stopping TURN server")
|
||||||
|
|
||||||
|
if s.turnServer != nil {
|
||||||
|
if err := s.turnServer.Close(); err != nil {
|
||||||
|
s.logger.Warn("Error closing TURN server", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.closeListeners()
|
||||||
|
|
||||||
|
s.logger.Info("TURN server stopped")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) closeListeners() {
|
||||||
|
if s.conn != nil {
|
||||||
|
s.conn.Close()
|
||||||
|
s.conn = nil
|
||||||
|
}
|
||||||
|
if s.tlsConn != nil {
|
||||||
|
s.tlsConn.Close()
|
||||||
|
s.tlsConn = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateCredentials creates time-limited HMAC-SHA1 TURN credentials.
|
||||||
|
// Returns username and password suitable for WebRTC ICE server configuration.
|
||||||
|
func GenerateCredentials(secret, namespace string, ttl time.Duration) (username, password string) {
|
||||||
|
expiry := time.Now().Add(ttl).Unix()
|
||||||
|
username = fmt.Sprintf("%d:%s", expiry, namespace)
|
||||||
|
password = GeneratePassword(secret, username)
|
||||||
|
return username, password
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeneratePassword computes the HMAC-SHA1 password for a TURN username.
|
||||||
|
func GeneratePassword(secret, username string) string {
|
||||||
|
h := hmac.New(sha1.New, []byte(secret))
|
||||||
|
h.Write([]byte(username))
|
||||||
|
return base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateCredentials checks if TURN credentials are valid and not expired.
|
||||||
|
func ValidateCredentials(secret, username, password, expectedNamespace string) bool {
|
||||||
|
parts := strings.SplitN(username, ":", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp, err := strconv.ParseInt(parts[0], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check namespace
|
||||||
|
if parts[1] != expectedNamespace {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check expiry
|
||||||
|
if timestamp <= time.Now().Unix() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check password
|
||||||
|
expected := GeneratePassword(secret, username)
|
||||||
|
return hmac.Equal([]byte(password), []byte(expected))
|
||||||
|
}
|
||||||
225
pkg/turn/server_test.go
Normal file
225
pkg/turn/server_test.go
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
package turn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGenerateCredentials(t *testing.T) {
|
||||||
|
secret := "test-secret-key-32bytes-long!!!!"
|
||||||
|
namespace := "test-namespace"
|
||||||
|
ttl := 10 * time.Minute
|
||||||
|
|
||||||
|
username, password := GenerateCredentials(secret, namespace, ttl)
|
||||||
|
|
||||||
|
if username == "" {
|
||||||
|
t.Fatal("username should not be empty")
|
||||||
|
}
|
||||||
|
if password == "" {
|
||||||
|
t.Fatal("password should not be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Username should be "{timestamp}:{namespace}"
|
||||||
|
var ts int64
|
||||||
|
var ns string
|
||||||
|
n, err := fmt.Sscanf(username, "%d:%s", &ts, &ns)
|
||||||
|
if err != nil || n != 2 {
|
||||||
|
t.Fatalf("username format should be timestamp:namespace, got %q", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ns != namespace {
|
||||||
|
t.Fatalf("namespace in username should be %q, got %q", namespace, ns)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timestamp should be ~10 minutes in the future
|
||||||
|
now := time.Now().Unix()
|
||||||
|
expectedExpiry := now + int64(ttl.Seconds())
|
||||||
|
if ts < expectedExpiry-2 || ts > expectedExpiry+2 {
|
||||||
|
t.Fatalf("expiry timestamp should be ~%d, got %d", expectedExpiry, ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeneratePassword(t *testing.T) {
|
||||||
|
secret := "test-secret"
|
||||||
|
username := "1234567890:test-ns"
|
||||||
|
|
||||||
|
password1 := GeneratePassword(secret, username)
|
||||||
|
password2 := GeneratePassword(secret, username)
|
||||||
|
|
||||||
|
// Same inputs should produce same output
|
||||||
|
if password1 != password2 {
|
||||||
|
t.Fatal("GeneratePassword should be deterministic")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different secret should produce different output
|
||||||
|
password3 := GeneratePassword("different-secret", username)
|
||||||
|
if password1 == password3 {
|
||||||
|
t.Fatal("different secrets should produce different passwords")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Different username should produce different output
|
||||||
|
password4 := GeneratePassword(secret, "9999999999:other-ns")
|
||||||
|
if password1 == password4 {
|
||||||
|
t.Fatal("different usernames should produce different passwords")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateCredentials(t *testing.T) {
|
||||||
|
secret := "test-secret-key"
|
||||||
|
namespace := "my-namespace"
|
||||||
|
ttl := 10 * time.Minute
|
||||||
|
|
||||||
|
// Generate valid credentials
|
||||||
|
username, password := GenerateCredentials(secret, namespace, ttl)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
secret string
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
namespace string
|
||||||
|
wantValid bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid credentials",
|
||||||
|
secret: secret,
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
namespace: namespace,
|
||||||
|
wantValid: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong secret",
|
||||||
|
secret: "wrong-secret",
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
namespace: namespace,
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong password",
|
||||||
|
secret: secret,
|
||||||
|
username: username,
|
||||||
|
password: "wrongpassword",
|
||||||
|
namespace: namespace,
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong namespace",
|
||||||
|
secret: secret,
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
namespace: "other-namespace",
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired credentials",
|
||||||
|
secret: secret,
|
||||||
|
username: fmt.Sprintf("%d:%s", time.Now().Unix()-60, namespace),
|
||||||
|
password: GeneratePassword(secret, fmt.Sprintf("%d:%s", time.Now().Unix()-60, namespace)),
|
||||||
|
namespace: namespace,
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed username - no colon",
|
||||||
|
secret: secret,
|
||||||
|
username: "badusername",
|
||||||
|
password: "whatever",
|
||||||
|
namespace: namespace,
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "malformed username - non-numeric timestamp",
|
||||||
|
secret: secret,
|
||||||
|
username: "notanumber:my-namespace",
|
||||||
|
password: "whatever",
|
||||||
|
namespace: namespace,
|
||||||
|
wantValid: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ValidateCredentials(tt.secret, tt.username, tt.password, tt.namespace)
|
||||||
|
if got != tt.wantValid {
|
||||||
|
t.Errorf("ValidateCredentials() = %v, want %v", got, tt.wantValid)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfigValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config Config
|
||||||
|
wantErrs int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "0.0.0.0:3478",
|
||||||
|
PublicIP: "1.2.3.4",
|
||||||
|
Realm: "dbrs.space",
|
||||||
|
AuthSecret: "secret123",
|
||||||
|
RelayPortStart: 49152,
|
||||||
|
RelayPortEnd: 50000,
|
||||||
|
Namespace: "test-ns",
|
||||||
|
},
|
||||||
|
wantErrs: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing all fields",
|
||||||
|
config: Config{},
|
||||||
|
wantErrs: 6, // listen_addr, public_ip, realm, auth_secret, relay_port_range, namespace
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid public IP",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "0.0.0.0:3478",
|
||||||
|
PublicIP: "not-an-ip",
|
||||||
|
Realm: "dbrs.space",
|
||||||
|
AuthSecret: "secret",
|
||||||
|
RelayPortStart: 49152,
|
||||||
|
RelayPortEnd: 50000,
|
||||||
|
Namespace: "test-ns",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "relay range too small",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "0.0.0.0:3478",
|
||||||
|
PublicIP: "1.2.3.4",
|
||||||
|
Realm: "dbrs.space",
|
||||||
|
AuthSecret: "secret",
|
||||||
|
RelayPortStart: 49152,
|
||||||
|
RelayPortEnd: 49200,
|
||||||
|
Namespace: "test-ns",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "relay range inverted",
|
||||||
|
config: Config{
|
||||||
|
ListenAddr: "0.0.0.0:3478",
|
||||||
|
PublicIP: "1.2.3.4",
|
||||||
|
Realm: "dbrs.space",
|
||||||
|
AuthSecret: "secret",
|
||||||
|
RelayPortStart: 50000,
|
||||||
|
RelayPortEnd: 49152,
|
||||||
|
Namespace: "test-ns",
|
||||||
|
},
|
||||||
|
wantErrs: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
errs := tt.config.Validate()
|
||||||
|
if len(errs) != tt.wantErrs {
|
||||||
|
t.Errorf("Validate() returned %d errors, want %d: %v", len(errs), tt.wantErrs, errs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
32
systemd/orama-namespace-sfu@.service
Normal file
32
systemd/orama-namespace-sfu@.service
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=Orama Namespace SFU (%i)
|
||||||
|
Documentation=https://github.com/DeBrosOfficial/network
|
||||||
|
After=network.target orama-namespace-olric@%i.service
|
||||||
|
Wants=orama-namespace-olric@%i.service
|
||||||
|
PartOf=orama-node.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
WorkingDirectory=/opt/orama
|
||||||
|
|
||||||
|
EnvironmentFile=/opt/orama/.orama/data/namespaces/%i/sfu.env
|
||||||
|
|
||||||
|
ExecStart=/bin/sh -c 'exec /opt/orama/bin/sfu --config ${SFU_CONFIG}'
|
||||||
|
|
||||||
|
TimeoutStopSec=45s
|
||||||
|
KillMode=mixed
|
||||||
|
KillSignal=SIGTERM
|
||||||
|
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=5s
|
||||||
|
|
||||||
|
StandardOutput=journal
|
||||||
|
StandardError=journal
|
||||||
|
SyslogIdentifier=orama-sfu-%i
|
||||||
|
|
||||||
|
PrivateTmp=yes
|
||||||
|
LimitNOFILE=65536
|
||||||
|
MemoryMax=2G
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
31
systemd/orama-namespace-turn@.service
Normal file
31
systemd/orama-namespace-turn@.service
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
[Unit]
|
||||||
|
Description=Orama Namespace TURN (%i)
|
||||||
|
Documentation=https://github.com/DeBrosOfficial/network
|
||||||
|
After=network.target
|
||||||
|
PartOf=orama-node.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
WorkingDirectory=/opt/orama
|
||||||
|
|
||||||
|
EnvironmentFile=/opt/orama/.orama/data/namespaces/%i/turn.env
|
||||||
|
|
||||||
|
ExecStart=/bin/sh -c 'exec /opt/orama/bin/turn --config ${TURN_CONFIG}'
|
||||||
|
|
||||||
|
TimeoutStopSec=30s
|
||||||
|
KillMode=mixed
|
||||||
|
KillSignal=SIGTERM
|
||||||
|
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=5s
|
||||||
|
|
||||||
|
StandardOutput=journal
|
||||||
|
StandardError=journal
|
||||||
|
SyslogIdentifier=orama-turn-%i
|
||||||
|
|
||||||
|
PrivateTmp=yes
|
||||||
|
LimitNOFILE=65536
|
||||||
|
MemoryMax=1G
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
Loading…
x
Reference in New Issue
Block a user