This commit is contained in:
JohnySigma 2026-02-20 18:24:32 +02:00
parent ade6241357
commit ea48a21ae4
24 changed files with 5678 additions and 191 deletions

View File

@ -73,6 +73,27 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
}
// Load YAML
// yamlICEServer mirrors one entry of the `sfu.ice_servers` list in the YAML
// config file: the ICE server URLs plus optional static TURN credentials.
type yamlICEServer struct {
	URLs       []string `yaml:"urls"`
	Username   string   `yaml:"username,omitempty"`
	Credential string   `yaml:"credential,omitempty"`
}
// yamlTURN mirrors the `turn:` section of the YAML config file.
// TTL is deliberately a string here; it is parsed with time.ParseDuration
// when the gateway config is built.
type yamlTURN struct {
	SharedSecret string   `yaml:"shared_secret"`
	TTL          string   `yaml:"ttl"`
	ExternalHost string   `yaml:"external_host"`
	STUNURLs     []string `yaml:"stun_urls"`
	TURNURLs     []string `yaml:"turn_urls"`
}
// yamlSFU mirrors the `sfu:` section of the YAML config file.
// MediaTimeout is a duration string (e.g., "30s") parsed with
// time.ParseDuration when the gateway config is built.
type yamlSFU struct {
	Enabled         bool            `yaml:"enabled"`
	MaxParticipants int             `yaml:"max_participants"`
	MediaTimeout    string          `yaml:"media_timeout"`
	ICEServers      []yamlICEServer `yaml:"ice_servers"`
}
type yamlCfg struct {
ListenAddr string `yaml:"listen_addr"`
ClientNamespace string `yaml:"client_namespace"`
@ -87,6 +108,8 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
IPFSAPIURL string `yaml:"ipfs_api_url"`
IPFSTimeout string `yaml:"ipfs_timeout"`
IPFSReplicationFactor int `yaml:"ipfs_replication_factor"`
TURN yamlTURN `yaml:"turn"`
SFU yamlSFU `yaml:"sfu"`
}
data, err := os.ReadFile(configPath)
@ -191,6 +214,64 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
cfg.IPFSReplicationFactor = y.IPFSReplicationFactor
}
// TURN configuration
if y.TURN.SharedSecret != "" || len(y.TURN.STUNURLs) > 0 || len(y.TURN.TURNURLs) > 0 {
turnCfg := &config.TURNConfig{
SharedSecret: y.TURN.SharedSecret,
ExternalHost: y.TURN.ExternalHost,
STUNURLs: y.TURN.STUNURLs,
TURNURLs: y.TURN.TURNURLs,
}
// Check for environment variable overrides
if envSecret := os.Getenv("TURN_SHARED_SECRET"); envSecret != "" {
turnCfg.SharedSecret = envSecret
}
if envHost := os.Getenv("TURN_EXTERNAL_HOST"); envHost != "" {
turnCfg.ExternalHost = envHost
}
if v := strings.TrimSpace(y.TURN.TTL); v != "" {
if parsed, err := time.ParseDuration(v); err == nil {
turnCfg.TTL = parsed
} else {
logger.ComponentWarn(logging.ComponentGeneral, "invalid turn.ttl, using default", zap.String("value", v), zap.Error(err))
}
}
cfg.TURN = turnCfg
logger.ComponentInfo(logging.ComponentGeneral, "TURN configuration loaded",
zap.Int("stun_urls", len(turnCfg.STUNURLs)),
zap.Int("turn_urls", len(turnCfg.TURNURLs)),
zap.String("external_host", turnCfg.ExternalHost),
)
}
// SFU configuration
if y.SFU.Enabled {
sfuCfg := &config.SFUConfig{
Enabled: true,
MaxParticipants: y.SFU.MaxParticipants,
}
if v := strings.TrimSpace(y.SFU.MediaTimeout); v != "" {
if parsed, err := time.ParseDuration(v); err == nil {
sfuCfg.MediaTimeout = parsed
} else {
logger.ComponentWarn(logging.ComponentGeneral, "invalid sfu.media_timeout, using default", zap.String("value", v), zap.Error(err))
}
}
// Parse ICE servers
for _, iceServer := range y.SFU.ICEServers {
sfuCfg.ICEServers = append(sfuCfg.ICEServers, config.ICEServerConfig{
URLs: iceServer.URLs,
Username: iceServer.Username,
Credential: iceServer.Credential,
})
}
cfg.SFU = sfuCfg
logger.ComponentInfo(logging.ComponentGeneral, "SFU configuration loaded",
zap.Int("max_participants", sfuCfg.MaxParticipants),
zap.Int("ice_servers", len(sfuCfg.ICEServers)),
)
}
// Validate configuration
if errs := cfg.ValidateConfig(); len(errs) > 0 {
fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs))

8
go.mod
View File

@ -18,6 +18,10 @@ require (
github.com/mattn/go-sqlite3 v1.14.32
github.com/multiformats/go-multiaddr v0.15.0
github.com/olric-data/olric v0.7.0
github.com/pion/interceptor v0.1.37
github.com/pion/rtcp v1.2.15
github.com/pion/turn/v4 v4.0.0
github.com/pion/webrtc/v4 v4.0.10
github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8
github.com/tetratelabs/wazero v1.11.0
go.uber.org/zap v1.27.0
@ -113,11 +117,9 @@ require (
github.com/pion/dtls/v2 v2.2.12 // indirect
github.com/pion/dtls/v3 v3.0.4 // indirect
github.com/pion/ice/v4 v4.0.8 // indirect
github.com/pion/interceptor v0.1.37 // indirect
github.com/pion/logging v0.2.3 // indirect
github.com/pion/mdns/v2 v2.0.7 // indirect
github.com/pion/randutil v0.1.0 // indirect
github.com/pion/rtcp v1.2.15 // indirect
github.com/pion/rtp v1.8.11 // indirect
github.com/pion/sctp v1.8.37 // indirect
github.com/pion/sdp/v3 v3.0.10 // indirect
@ -126,8 +128,6 @@ require (
github.com/pion/stun/v3 v3.0.0 // indirect
github.com/pion/transport/v2 v2.2.10 // indirect
github.com/pion/transport/v3 v3.0.7 // indirect
github.com/pion/turn/v4 v4.0.0 // indirect
github.com/pion/webrtc/v4 v4.0.10 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/prometheus/client_golang v1.22.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect

View File

@ -3,7 +3,6 @@ package config
import (
"time"
"github.com/DeBrosOfficial/network/pkg/config/validate"
"github.com/multiformats/go-multiaddr"
)
@ -15,69 +14,248 @@ type Config struct {
Security SecurityConfig `yaml:"security"`
Logging LoggingConfig `yaml:"logging"`
HTTPGateway HTTPGatewayConfig `yaml:"http_gateway"`
TURNServer TURNServerConfig `yaml:"turn_server"` // Built-in TURN server
}
// ValidationError represents a single validation error with context.
// This is exported from the validate subpackage for backward compatibility.
type ValidationError = validate.ValidationError
// ValidateSwarmKey validates that a swarm key is 64 hex characters.
// This is exported from the validate subpackage for backward compatibility.
func ValidateSwarmKey(key string) error {
return validate.ValidateSwarmKey(key)
// NodeConfig contains node-specific configuration.
type NodeConfig struct {
	ID              string   `yaml:"id"`               // Auto-generated if empty
	ListenAddresses []string `yaml:"listen_addresses"` // LibP2P listen addresses
	DataDir         string   `yaml:"data_dir"`         // Data directory
	MaxConnections  int      `yaml:"max_connections"`  // Maximum peer connections
	Domain          string   `yaml:"domain"`           // Domain for this node (e.g., node-1.orama.network)
}
// Validate performs comprehensive validation of the entire config.
// It aggregates all errors and returns them, allowing the caller to print all issues at once.
func (c *Config) Validate() []error {
var errs []error
// DatabaseConfig contains database-related configuration
type DatabaseConfig struct {
DataDir string `yaml:"data_dir"`
ReplicationFactor int `yaml:"replication_factor"`
ShardCount int `yaml:"shard_count"`
MaxDatabaseSize int64 `yaml:"max_database_size"` // In bytes
BackupInterval time.Duration `yaml:"backup_interval"`
// Validate node config
errs = append(errs, validate.ValidateNode(validate.NodeConfig{
ID: c.Node.ID,
ListenAddresses: c.Node.ListenAddresses,
DataDir: c.Node.DataDir,
MaxConnections: c.Node.MaxConnections,
})...)
// RQLite-specific configuration
RQLitePort int `yaml:"rqlite_port"` // RQLite HTTP API port
RQLiteRaftPort int `yaml:"rqlite_raft_port"` // RQLite Raft consensus port
RQLiteJoinAddress string `yaml:"rqlite_join_address"` // Address to join RQLite cluster
// Validate database config
errs = append(errs, validate.ValidateDatabase(validate.DatabaseConfig{
DataDir: c.Database.DataDir,
ReplicationFactor: c.Database.ReplicationFactor,
ShardCount: c.Database.ShardCount,
MaxDatabaseSize: c.Database.MaxDatabaseSize,
RQLitePort: c.Database.RQLitePort,
RQLiteRaftPort: c.Database.RQLiteRaftPort,
RQLiteJoinAddress: c.Database.RQLiteJoinAddress,
ClusterSyncInterval: c.Database.ClusterSyncInterval,
PeerInactivityLimit: c.Database.PeerInactivityLimit,
MinClusterSize: c.Database.MinClusterSize,
})...)
// RQLite node-to-node TLS encryption (for inter-node Raft communication)
// See: https://rqlite.io/docs/guides/security/#encrypting-node-to-node-communication
NodeCert string `yaml:"node_cert"` // Path to X.509 certificate for node-to-node communication
NodeKey string `yaml:"node_key"` // Path to X.509 private key for node-to-node communication
NodeCACert string `yaml:"node_ca_cert"` // Path to CA certificate (optional, uses system CA if not set)
NodeNoVerify bool `yaml:"node_no_verify"` // Skip certificate verification (for testing/self-signed certs)
// Validate discovery config
errs = append(errs, validate.ValidateDiscovery(validate.DiscoveryConfig{
BootstrapPeers: c.Discovery.BootstrapPeers,
DiscoveryInterval: c.Discovery.DiscoveryInterval,
BootstrapPort: c.Discovery.BootstrapPort,
HttpAdvAddress: c.Discovery.HttpAdvAddress,
RaftAdvAddress: c.Discovery.RaftAdvAddress,
})...)
// Dynamic discovery configuration (always enabled)
ClusterSyncInterval time.Duration `yaml:"cluster_sync_interval"` // default: 30s
PeerInactivityLimit time.Duration `yaml:"peer_inactivity_limit"` // default: 24h
MinClusterSize int `yaml:"min_cluster_size"` // default: 1
// Validate security config
errs = append(errs, validate.ValidateSecurity(validate.SecurityConfig{
EnableTLS: c.Security.EnableTLS,
PrivateKeyFile: c.Security.PrivateKeyFile,
CertificateFile: c.Security.CertificateFile,
})...)
// Olric cache configuration
OlricHTTPPort int `yaml:"olric_http_port"` // Olric HTTP API port (default: 3320)
OlricMemberlistPort int `yaml:"olric_memberlist_port"` // Olric memberlist port (default: 3322)
// Validate logging config
errs = append(errs, validate.ValidateLogging(validate.LoggingConfig{
Level: c.Logging.Level,
Format: c.Logging.Format,
OutputFile: c.Logging.OutputFile,
})...)
// IPFS storage configuration
IPFS IPFSConfig `yaml:"ipfs"`
}
return errs
// IPFSConfig contains IPFS storage configuration for a node.
type IPFSConfig struct {
	// ClusterAPIURL is the IPFS Cluster HTTP API URL (e.g., "http://localhost:9094").
	// If empty, IPFS storage is disabled for this node.
	ClusterAPIURL string `yaml:"cluster_api_url"`

	// APIURL is the IPFS HTTP API URL for content retrieval (e.g., "http://localhost:5001").
	// If empty, defaults to "http://localhost:5001".
	APIURL string `yaml:"api_url"`

	// Timeout for IPFS operations.
	// If zero, defaults to 60 seconds.
	Timeout time.Duration `yaml:"timeout"`

	// ReplicationFactor is the replication factor for pinned content.
	// If zero, defaults to 3.
	ReplicationFactor int `yaml:"replication_factor"`

	// EnableEncryption enables client-side encryption before upload.
	// Defaults to true.
	EnableEncryption bool `yaml:"enable_encryption"`
}
// DiscoveryConfig contains peer discovery configuration.
type DiscoveryConfig struct {
	BootstrapPeers    []string      `yaml:"bootstrap_peers"`    // Peer addresses to connect to
	DiscoveryInterval time.Duration `yaml:"discovery_interval"` // Discovery announcement interval
	BootstrapPort     int           `yaml:"bootstrap_port"`     // Default port for peer discovery
	HttpAdvAddress    string        `yaml:"http_adv_address"`   // HTTP advertisement address
	RaftAdvAddress    string        `yaml:"raft_adv_address"`   // Raft advertisement address
	NodeNamespace     string        `yaml:"node_namespace"`     // Namespace for node identifiers
}
// SecurityConfig contains security-related configuration (TLS key material).
type SecurityConfig struct {
	EnableTLS       bool   `yaml:"enable_tls"`
	PrivateKeyFile  string `yaml:"private_key_file"`
	CertificateFile string `yaml:"certificate_file"`
}
// LoggingConfig contains logging configuration.
type LoggingConfig struct {
	Level      string `yaml:"level"`       // debug, info, warn, error
	Format     string `yaml:"format"`      // json, console
	OutputFile string `yaml:"output_file"` // Empty for stdout
}
// HTTPGatewayConfig contains HTTP reverse proxy gateway configuration.
type HTTPGatewayConfig struct {
	Enabled    bool                   `yaml:"enabled"`     // Enable HTTP gateway
	ListenAddr string                 `yaml:"listen_addr"` // Address to listen on (e.g., ":8080")
	NodeName   string                 `yaml:"node_name"`   // Node name for routing
	Routes     map[string]RouteConfig `yaml:"routes"`      // Service routes
	HTTPS      HTTPSConfig            `yaml:"https"`       // HTTPS/TLS configuration
	SNI        SNIConfig              `yaml:"sni"`         // SNI-based TCP routing configuration

	// Full gateway configuration (for API, auth, pubsub)
	ClientNamespace   string        `yaml:"client_namespace"`     // Namespace for network client
	RQLiteDSN         string        `yaml:"rqlite_dsn"`           // RQLite database DSN
	OlricServers      []string      `yaml:"olric_servers"`        // List of Olric server addresses
	OlricTimeout      time.Duration `yaml:"olric_timeout"`        // Timeout for Olric operations
	IPFSClusterAPIURL string        `yaml:"ipfs_cluster_api_url"` // IPFS Cluster API URL
	IPFSAPIURL        string        `yaml:"ipfs_api_url"`         // IPFS API URL
	IPFSTimeout       time.Duration `yaml:"ipfs_timeout"`         // Timeout for IPFS operations

	// WebRTC configuration for video/audio calls
	TURN *TURNConfig `yaml:"turn"` // TURN/STUN server configuration
	SFU  *SFUConfig  `yaml:"sfu"`  // SFU (Selective Forwarding Unit) configuration
}
// HTTPSConfig contains HTTPS/TLS configuration for the gateway.
type HTTPSConfig struct {
	Enabled       bool   `yaml:"enabled"`         // Enable HTTPS (port 443)
	Domain        string `yaml:"domain"`          // Primary domain (e.g., node-123.orama.network)
	AutoCert      bool   `yaml:"auto_cert"`       // Use Let's Encrypt for automatic certificate
	UseSelfSigned bool   `yaml:"use_self_signed"` // Use self-signed certificates (pre-generated)
	CertFile      string `yaml:"cert_file"`       // Path to certificate file (if not using auto_cert)
	KeyFile       string `yaml:"key_file"`        // Path to key file (if not using auto_cert)
	CacheDir      string `yaml:"cache_dir"`       // Directory for Let's Encrypt certificate cache
	HTTPPort      int    `yaml:"http_port"`       // HTTP port for ACME challenge (default: 80)
	HTTPSPort     int    `yaml:"https_port"`      // HTTPS port (default: 443)
	Email         string `yaml:"email"`           // Email for Let's Encrypt account
}
// SNIConfig contains SNI-based TCP routing configuration for port 7001.
type SNIConfig struct {
	Enabled    bool              `yaml:"enabled"`     // Enable SNI-based TCP routing
	ListenAddr string            `yaml:"listen_addr"` // Address to listen on (e.g., ":7001")
	Routes     map[string]string `yaml:"routes"`      // SNI hostname -> backend address mapping
	CertFile   string            `yaml:"cert_file"`   // Path to certificate file
	KeyFile    string            `yaml:"key_file"`    // Path to key file
}
// RouteConfig defines a single reverse proxy route.
type RouteConfig struct {
	PathPrefix string        `yaml:"path_prefix"` // URL path prefix (e.g., "/rqlite/http")
	BackendURL string        `yaml:"backend_url"` // Backend service URL
	Timeout    time.Duration `yaml:"timeout"`     // Request timeout
	WebSocket  bool          `yaml:"websocket"`   // Support WebSocket upgrades
}
// ClientConfig represents configuration for network clients.
type ClientConfig struct {
	AppName        string        `yaml:"app_name"`
	DatabaseName   string        `yaml:"database_name"`
	BootstrapPeers []string      `yaml:"bootstrap_peers"`
	ConnectTimeout time.Duration `yaml:"connect_timeout"`
	RetryAttempts  int           `yaml:"retry_attempts"`
}
// TURNConfig contains TURN/STUN server credential configuration returned to
// WebRTC clients.
type TURNConfig struct {
	// SharedSecret is the shared secret for TURN credential generation (HMAC-SHA1).
	// Should be set via the TURN_SHARED_SECRET environment variable.
	SharedSecret string `yaml:"shared_secret"`

	// TTL is the time-to-live for generated credentials.
	// Default: 24 hours.
	TTL time.Duration `yaml:"ttl"`

	// ExternalHost is the external hostname or IP address for STUN/TURN URLs.
	//   - Production: Set to your public domain (e.g., "turn.example.com")
	//   - Development: Leave empty for auto-detection of LAN IP
	// Can also be set via the TURN_EXTERNAL_HOST environment variable.
	ExternalHost string `yaml:"external_host"`

	// STUNURLs are the STUN server URLs to return to clients.
	// Use "::" as a placeholder for ExternalHost (e.g., "stun:::3478" -> "stun:turn.example.com:3478").
	// e.g., ["stun:::3478"] or ["stun:gateway.orama.com:3478"]
	STUNURLs []string `yaml:"stun_urls"`

	// TURNURLs are the TURN server URLs to return to clients.
	// Use "::" as a placeholder for ExternalHost (e.g., "turn:::3478" -> "turn:turn.example.com:3478").
	// e.g., ["turn:::3478?transport=udp"] or ["turn:gateway.orama.com:3478?transport=udp"]
	TURNURLs []string `yaml:"turn_urls"`

	// TLSEnabled indicates whether TURNS (TURN over TLS) is available.
	// When true, turns:// URLs will be included in the response.
	TLSEnabled bool `yaml:"tls_enabled"`
}
// SFUConfig contains WebRTC SFU (Selective Forwarding Unit) configuration.
type SFUConfig struct {
	// Enabled enables the SFU service.
	Enabled bool `yaml:"enabled"`

	// MaxParticipants is the maximum number of participants per room.
	// Default: 10.
	MaxParticipants int `yaml:"max_participants"`

	// MediaTimeout is the timeout for media operations.
	// Default: 30 seconds.
	MediaTimeout time.Duration `yaml:"media_timeout"`

	// ICEServers are additional ICE servers for WebRTC connections.
	// These are used in addition to the TURN servers from TURNConfig.
	ICEServers []ICEServerConfig `yaml:"ice_servers"`
}
// ICEServerConfig represents a single ICE server configuration
// (URLs plus optional static TURN credentials).
type ICEServerConfig struct {
	URLs       []string `yaml:"urls"`
	Username   string   `yaml:"username,omitempty"`
	Credential string   `yaml:"credential,omitempty"`
}
// TURNServerConfig contains built-in TURN server configuration.
type TURNServerConfig struct {
	// Enabled enables the built-in TURN server.
	Enabled bool `yaml:"enabled"`

	// ListenAddr is the UDP address to listen on (e.g., "0.0.0.0:3478").
	ListenAddr string `yaml:"listen_addr"`

	// PublicIP is the public IP address to advertise for relay.
	// If empty, will try to auto-detect.
	PublicIP string `yaml:"public_ip"`

	// Realm is the TURN realm (e.g., "orama.network").
	Realm string `yaml:"realm"`

	// MinPort and MaxPort define the relay port range.
	MinPort uint16 `yaml:"min_port"`
	MaxPort uint16 `yaml:"max_port"`

	// TLS Configuration for TURNS (TURN over TLS).
	// TLSEnabled enables the TURNS listener.
	TLSEnabled bool `yaml:"tls_enabled"`
	// TLSListenAddr is the TCP/TLS address to listen on (e.g., "0.0.0.0:443").
	TLSListenAddr string `yaml:"tls_listen_addr"`
	// TLSCertFile is the path to the TLS certificate file.
	TLSCertFile string `yaml:"tls_cert_file"`
	// TLSKeyFile is the path to the TLS private key file.
	TLSKeyFile string `yaml:"tls_key_file"`
}
// ParseMultiaddrs converts string addresses to multiaddr objects

View File

@ -47,6 +47,15 @@ logging:
level: "info"
format: "console"
# Built-in TURN server for WebRTC NAT traversal
turn_server:
enabled: true
listen_addr: "0.0.0.0:3478"
public_ip: "" # Auto-detect if empty, or set to your public IP
realm: "orama.network"
min_port: 49152
max_port: 65535
http_gateway:
enabled: true
listen_addr: "{{if .EnableHTTPS}}:{{.HTTPSPort}}{{else}}:{{.UnifiedGatewayPort}}{{end}}"
@ -86,3 +95,19 @@ http_gateway:
# Routes for internal service reverse proxy (kept for backwards compatibility but not used by full gateway)
routes: {}
# TURN/STUN URLs returned to clients (points to built-in TURN server)
turn:
shared_secret: "dev-secret-12345" # Development-only default; override via the TURN_SHARED_SECRET environment variable in production
ttl: "24h"
stun_urls:
- "stun:::3478"
turn_urls:
- "turn:::3478?transport=udp"
# SFU (Selective Forwarding Unit) configuration for WebRTC group calls
sfu:
enabled: true
max_participants: 10
media_timeout: "30s"
ice_servers: [] # Additional ICE servers beyond TURN config (optional)

View File

@ -1,31 +1,75 @@
// Package gateway provides the main API Gateway for the Orama Network.
// It orchestrates traffic between clients and various backend services including
// distributed caching (Olric), decentralized storage (IPFS), and serverless
// WebAssembly (WASM) execution. The gateway implements robust security through
// wallet-based cryptographic authentication and JWT lifecycle management.
package gateway
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/pem"
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/config"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
authhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/auth"
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/cache"
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
"github.com/DeBrosOfficial/network/pkg/ipfs"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/olric"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/multiformats/go-multiaddr"
olriclib "github.com/olric-data/olric"
"go.uber.org/zap"
_ "github.com/rqlite/gorqlite/stdlib"
)
// Tuning constants for the Olric cache client initialization retry
// (see initializeOlricClientWithRetry): maximum attempt count plus the
// initial and maximum backoff delays between attempts.
const (
	olricInitMaxAttempts    = 5
	olricInitInitialBackoff = 500 * time.Millisecond
	olricInitMaxBackoff     = 5 * time.Second
)
// Config holds configuration for the gateway server.
type Config struct {
	ListenAddr      string
	ClientNamespace string
	BootstrapPeers  []string
	NodePeerID      string // The node's actual peer ID from its identity file

	// Optional DSN for rqlite database/sql driver, e.g. "http://localhost:4001".
	// If empty, defaults to "http://localhost:4001".
	// NOTE(review): New() currently falls back to "http://localhost:5001" when
	// this is empty, which contradicts the default documented here (and 5001 is
	// the usual IPFS API port) — confirm which default is intended.
	RQLiteDSN string

	// HTTPS configuration
	EnableHTTPS bool   // Enable HTTPS with ACME (Let's Encrypt)
	DomainName  string // Domain name for HTTPS certificate
	TLSCacheDir string // Directory to cache TLS certificates (default: ~/.orama/tls-cache)

	// Olric cache configuration
	OlricServers []string      // List of Olric server addresses (e.g., ["localhost:3320"]). If empty, defaults to ["localhost:3320"]
	OlricTimeout time.Duration // Timeout for Olric operations (default: 10s)

	// IPFS Cluster configuration
	IPFSClusterAPIURL     string        // IPFS Cluster HTTP API URL (e.g., "http://localhost:9094"). If empty, gateway will discover from node configs
	IPFSAPIURL            string        // IPFS HTTP API URL for content retrieval (e.g., "http://localhost:5001"). If empty, gateway will discover from node configs
	IPFSTimeout           time.Duration // Timeout for IPFS operations (default: 60s)
	IPFSReplicationFactor int           // Replication factor for pins (default: 3)
	IPFSEnableEncryption  bool          // Enable client-side encryption before upload (default: true, discovered from node configs)

	// TURN/STUN configuration for WebRTC
	TURN *config.TURNConfig
	// SFU configuration for WebRTC group calls
	SFU *config.SFUConfig
}
type Gateway struct {
logger *logging.ColoredLogger
@ -42,29 +86,28 @@ type Gateway struct {
// Olric cache client
olricClient *olric.Client
olricMu sync.RWMutex
cacheHandlers *cache.CacheHandlers
// IPFS storage client
ipfsClient ipfs.IPFSClient
storageHandlers *storage.Handlers
ipfsClient ipfs.IPFSClient
// Local pub/sub bypass for same-gateway subscribers
localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers
presenceMembers map[string][]PresenceMember // topicKey -> members
mu sync.RWMutex
presenceMu sync.RWMutex
pubsubHandlers *pubsubhandlers.PubSubHandlers
// Serverless function engine
serverlessEngine *serverless.Engine
serverlessRegistry *serverless.Registry
serverlessInvoker *serverless.Invoker
serverlessWSMgr *serverless.WSManager
serverlessHandlers *serverlesshandlers.ServerlessHandlers
serverlessHandlers *ServerlessHandlers
// Authentication service
authService *auth.Service
authHandlers *authhandlers.Handlers
authService *auth.Service
// SFU manager for WebRTC group calls
sfuManager *SFUManager
}
// localSubscriber represents a WebSocket subscriber for local message delivery
@ -81,113 +124,359 @@ type PresenceMember struct {
ConnID string `json:"-"` // Internal: for tracking which connection
}
// authClientAdapter adapts client.NetworkClient to authhandlers.NetworkClient
// so the auth handlers can depend on their own small interface rather than on
// the full network client.
type authClientAdapter struct {
	client client.NetworkClient
}
// Database returns the wrapped client's database handle, adapted to the
// authhandlers.DatabaseClient interface.
func (a *authClientAdapter) Database() authhandlers.DatabaseClient {
	return &authDatabaseAdapter{db: a.client.Database()}
}
// authDatabaseAdapter adapts client.DatabaseClient to
// authhandlers.DatabaseClient (see Query for the row-shape conversion).
type authDatabaseAdapter struct {
	db client.DatabaseClient
}
// Query runs a SQL statement through the underlying client and converts the
// result into the shape the auth handlers expect.
//
// The client reports rows as [][]interface{}, while authhandlers.QueryResult
// stores []interface{}; each row slice is boxed into a single interface{}
// element during the conversion.
func (a *authDatabaseAdapter) Query(ctx context.Context, sql string, args ...interface{}) (*authhandlers.QueryResult, error) {
	res, err := a.db.Query(ctx, sql, args...)
	if err != nil {
		return nil, err
	}
	rows := make([]interface{}, 0, len(res.Rows))
	for _, row := range res.Rows {
		rows = append(rows, row)
	}
	return &authhandlers.QueryResult{
		Count: int(res.Count),
		Rows:  rows,
	}, nil
}
// New creates and initializes a new Gateway instance.
// It establishes all necessary service connections and dependencies.
// New creates and initializes a new Gateway instance
func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway dependencies...")
logger.ComponentInfo(logging.ComponentGeneral, "Building client config...")
// Initialize all dependencies (network client, database, cache, storage, serverless)
deps, err := NewDependencies(logger, cfg)
// Build client config from gateway cfg
cliCfg := client.DefaultClientConfig(cfg.ClientNamespace)
if len(cfg.BootstrapPeers) > 0 {
cliCfg.BootstrapPeers = cfg.BootstrapPeers
}
logger.ComponentInfo(logging.ComponentGeneral, "Creating network client...")
c, err := client.NewClient(cliCfg)
if err != nil {
logger.ComponentError(logging.ComponentGeneral, "failed to create dependencies", zap.Error(err))
logger.ComponentError(logging.ComponentClient, "failed to create network client", zap.Error(err))
return nil, err
}
logger.ComponentInfo(logging.ComponentGeneral, "Connecting network client...")
if err := c.Connect(); err != nil {
logger.ComponentError(logging.ComponentClient, "failed to connect network client", zap.Error(err))
return nil, err
}
logger.ComponentInfo(logging.ComponentClient, "Network client connected",
zap.String("namespace", cliCfg.AppName),
zap.Int("peer_count", len(cliCfg.BootstrapPeers)),
)
logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway instance...")
gw := &Gateway{
logger: logger,
cfg: cfg,
client: deps.Client,
nodePeerID: cfg.NodePeerID,
startedAt: time.Now(),
sqlDB: deps.SQLDB,
ormClient: deps.ORMClient,
ormHTTP: deps.ORMHTTP,
olricClient: deps.OlricClient,
ipfsClient: deps.IPFSClient,
serverlessEngine: deps.ServerlessEngine,
serverlessRegistry: deps.ServerlessRegistry,
serverlessInvoker: deps.ServerlessInvoker,
serverlessWSMgr: deps.ServerlessWSMgr,
serverlessHandlers: deps.ServerlessHandlers,
authService: deps.AuthService,
localSubscribers: make(map[string][]*localSubscriber),
presenceMembers: make(map[string][]PresenceMember),
logger: logger,
cfg: cfg,
client: c,
nodePeerID: cfg.NodePeerID,
startedAt: time.Now(),
localSubscribers: make(map[string][]*localSubscriber),
presenceMembers: make(map[string][]PresenceMember),
}
// Initialize handler instances
gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger)
if deps.OlricClient != nil {
gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient)
logger.ComponentInfo(logging.ComponentGeneral, "Initializing RQLite ORM HTTP gateway...")
dsn := cfg.RQLiteDSN
if dsn == "" {
dsn = "http://localhost:5001"
}
db, dbErr := sql.Open("rqlite", dsn)
if dbErr != nil {
logger.ComponentWarn(logging.ComponentGeneral, "failed to open rqlite sql db; http orm gateway disabled", zap.Error(dbErr))
} else {
// Configure connection pool with proper timeouts and limits
db.SetMaxOpenConns(25) // Maximum number of open connections
db.SetMaxIdleConns(5) // Maximum number of idle connections
db.SetConnMaxLifetime(5 * time.Minute) // Maximum lifetime of a connection
db.SetConnMaxIdleTime(2 * time.Minute) // Maximum idle time before closing
if deps.IPFSClient != nil {
gw.storageHandlers = storage.New(deps.IPFSClient, logger, storage.Config{
IPFSReplicationFactor: cfg.IPFSReplicationFactor,
IPFSAPIURL: cfg.IPFSAPIURL,
})
}
if deps.AuthService != nil {
// Create adapter for auth handlers to use the client
authClientAdapter := &authClientAdapter{client: deps.Client}
gw.authHandlers = authhandlers.NewHandlers(
logger,
deps.AuthService,
authClientAdapter,
cfg.ClientNamespace,
gw.withInternalAuth,
gw.sqlDB = db
orm := rqlite.NewClient(db)
gw.ormClient = orm
gw.ormHTTP = rqlite.NewHTTPGateway(orm, "/v1/db")
// Set a reasonable timeout for HTTP requests (30 seconds)
gw.ormHTTP.Timeout = 30 * time.Second
logger.ComponentInfo(logging.ComponentGeneral, "RQLite ORM HTTP gateway ready",
zap.String("dsn", dsn),
zap.String("base_path", "/v1/db"),
zap.Duration("timeout", gw.ormHTTP.Timeout),
)
}
// Start background Olric reconnection if initial connection failed
if deps.OlricClient == nil {
olricCfg := olric.Config{
Servers: cfg.OlricServers,
Timeout: cfg.OlricTimeout,
logger.ComponentInfo(logging.ComponentGeneral, "Initializing Olric cache client...")
// Discover Olric servers dynamically from LibP2P peers if not explicitly configured
olricServers := cfg.OlricServers
if len(olricServers) == 0 {
logger.ComponentInfo(logging.ComponentGeneral, "Olric servers not configured, discovering from LibP2P peers...")
discovered := discoverOlricServers(c, logger.Logger)
if len(discovered) > 0 {
olricServers = discovered
logger.ComponentInfo(logging.ComponentGeneral, "Discovered Olric servers from LibP2P peers",
zap.Strings("servers", olricServers))
} else {
// Fallback to localhost for local development
olricServers = []string{"localhost:3320"}
logger.ComponentInfo(logging.ComponentGeneral, "No Olric servers discovered, using localhost fallback")
}
if len(olricCfg.Servers) == 0 {
olricCfg.Servers = []string{"localhost:3320"}
}
gw.startOlricReconnectLoop(olricCfg)
} else {
logger.ComponentInfo(logging.ComponentGeneral, "Using explicitly configured Olric servers",
zap.Strings("servers", olricServers))
}
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed")
olricCfg := olric.Config{
Servers: olricServers,
Timeout: cfg.OlricTimeout,
}
olricClient, olricErr := initializeOlricClientWithRetry(olricCfg, logger)
if olricErr != nil {
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize Olric cache client; cache endpoints disabled", zap.Error(olricErr))
gw.startOlricReconnectLoop(olricCfg)
} else {
gw.setOlricClient(olricClient)
logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client ready",
zap.Strings("servers", olricCfg.Servers),
zap.Duration("timeout", olricCfg.Timeout),
)
}
logger.ComponentInfo(logging.ComponentGeneral, "Initializing IPFS Cluster client...")
// Discover IPFS endpoints from node configs if not explicitly configured
ipfsClusterURL := cfg.IPFSClusterAPIURL
ipfsAPIURL := cfg.IPFSAPIURL
ipfsTimeout := cfg.IPFSTimeout
ipfsReplicationFactor := cfg.IPFSReplicationFactor
ipfsEnableEncryption := cfg.IPFSEnableEncryption
if ipfsClusterURL == "" {
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster URL not configured, discovering from node configs...")
discovered := discoverIPFSFromNodeConfigs(logger.Logger)
if discovered.clusterURL != "" {
ipfsClusterURL = discovered.clusterURL
ipfsAPIURL = discovered.apiURL
if discovered.timeout > 0 {
ipfsTimeout = discovered.timeout
}
if discovered.replicationFactor > 0 {
ipfsReplicationFactor = discovered.replicationFactor
}
ipfsEnableEncryption = discovered.enableEncryption
logger.ComponentInfo(logging.ComponentGeneral, "Discovered IPFS endpoints from node configs",
zap.String("cluster_url", ipfsClusterURL),
zap.String("api_url", ipfsAPIURL),
zap.Bool("encryption_enabled", ipfsEnableEncryption))
} else {
// Fallback to localhost defaults
ipfsClusterURL = "http://localhost:9094"
ipfsAPIURL = "http://localhost:5001"
ipfsEnableEncryption = true // Default to true
logger.ComponentInfo(logging.ComponentGeneral, "No IPFS config found in node configs, using localhost defaults")
}
}
if ipfsAPIURL == "" {
ipfsAPIURL = "http://localhost:5001"
}
if ipfsTimeout == 0 {
ipfsTimeout = 60 * time.Second
}
if ipfsReplicationFactor == 0 {
ipfsReplicationFactor = 3
}
if !cfg.IPFSEnableEncryption && !ipfsEnableEncryption {
// Only disable if explicitly set to false in both places
ipfsEnableEncryption = false
} else {
// Default to true if not explicitly disabled
ipfsEnableEncryption = true
}
ipfsCfg := ipfs.Config{
ClusterAPIURL: ipfsClusterURL,
Timeout: ipfsTimeout,
}
ipfsClient, ipfsErr := ipfs.NewClient(ipfsCfg, logger.Logger)
if ipfsErr != nil {
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize IPFS Cluster client; storage endpoints disabled", zap.Error(ipfsErr))
} else {
gw.ipfsClient = ipfsClient
// Check peer count and warn if insufficient (use background context to avoid blocking)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if peerCount, err := ipfsClient.GetPeerCount(ctx); err == nil {
if peerCount < ipfsReplicationFactor {
logger.ComponentWarn(logging.ComponentGeneral, "insufficient cluster peers for replication factor",
zap.Int("peer_count", peerCount),
zap.Int("replication_factor", ipfsReplicationFactor),
zap.String("message", "Some pin operations may fail until more peers join the cluster"))
} else {
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster peer count sufficient",
zap.Int("peer_count", peerCount),
zap.Int("replication_factor", ipfsReplicationFactor))
}
} else {
logger.ComponentWarn(logging.ComponentGeneral, "failed to get cluster peer count", zap.Error(err))
}
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster client ready",
zap.String("cluster_api_url", ipfsCfg.ClusterAPIURL),
zap.String("ipfs_api_url", ipfsAPIURL),
zap.Duration("timeout", ipfsCfg.Timeout),
zap.Int("replication_factor", ipfsReplicationFactor),
zap.Bool("encryption_enabled", ipfsEnableEncryption),
)
}
// Store IPFS settings in gateway for use by handlers
gw.cfg.IPFSAPIURL = ipfsAPIURL
gw.cfg.IPFSReplicationFactor = ipfsReplicationFactor
gw.cfg.IPFSEnableEncryption = ipfsEnableEncryption
// Initialize serverless function engine
logger.ComponentInfo(logging.ComponentGeneral, "Initializing serverless function engine...")
if gw.ormClient != nil && gw.ipfsClient != nil {
// Create serverless registry (stores functions in RQLite + IPFS)
registryCfg := serverless.RegistryConfig{
IPFSAPIURL: ipfsAPIURL,
}
registry := serverless.NewRegistry(gw.ormClient, gw.ipfsClient, registryCfg, logger.Logger)
gw.serverlessRegistry = registry
// Create WebSocket manager for function streaming
gw.serverlessWSMgr = serverless.NewWSManager(logger.Logger)
// Get underlying Olric client if available
var olricClient olriclib.Client
if oc := gw.getOlricClient(); oc != nil {
olricClient = oc.UnderlyingClient()
}
// Create host functions provider (allows functions to call Orama services)
// Get pubsub adapter from client for serverless functions
var pubsubAdapter *pubsub.ClientAdapter
if gw.client != nil {
if concreteClient, ok := gw.client.(*client.Client); ok {
pubsubAdapter = concreteClient.PubSubAdapter()
if pubsubAdapter != nil {
logger.ComponentInfo(logging.ComponentGeneral, "pubsub adapter available for serverless functions")
} else {
logger.ComponentWarn(logging.ComponentGeneral, "pubsub adapter is nil - serverless pubsub will be unavailable")
}
}
}
hostFuncsCfg := serverless.HostFunctionsConfig{
IPFSAPIURL: ipfsAPIURL,
HTTPTimeout: 30 * time.Second,
}
hostFuncs := serverless.NewHostFunctions(
gw.ormClient,
olricClient,
gw.ipfsClient,
pubsubAdapter, // pubsub adapter for serverless functions
gw.serverlessWSMgr,
nil, // secrets manager - TODO: implement
hostFuncsCfg,
logger.Logger,
)
// Create WASM engine configuration
engineCfg := serverless.DefaultConfig()
engineCfg.DefaultMemoryLimitMB = 128
engineCfg.MaxMemoryLimitMB = 256
engineCfg.DefaultTimeoutSeconds = 30
engineCfg.MaxTimeoutSeconds = 60
engineCfg.ModuleCacheSize = 100
// Create WASM engine
engine, engineErr := serverless.NewEngine(engineCfg, registry, hostFuncs, logger.Logger, serverless.WithInvocationLogger(registry))
if engineErr != nil {
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize serverless engine; functions disabled", zap.Error(engineErr))
} else {
gw.serverlessEngine = engine
// Create invoker
gw.serverlessInvoker = serverless.NewInvoker(engine, registry, hostFuncs, logger.Logger)
// Create trigger manager
triggerManager := serverless.NewDBTriggerManager(gw.ormClient, logger.Logger)
// Create HTTP handlers
gw.serverlessHandlers = NewServerlessHandlers(
gw.serverlessInvoker,
registry,
gw.serverlessWSMgr,
triggerManager,
logger.Logger,
)
// Initialize auth service
// For now using ephemeral key, can be loaded from config later
key, _ := rsa.GenerateKey(rand.Reader, 2048)
keyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
})
authService, err := auth.NewService(logger, c, string(keyPEM), cfg.ClientNamespace)
if err != nil {
logger.ComponentError(logging.ComponentGeneral, "failed to initialize auth service", zap.Error(err))
} else {
gw.authService = authService
}
logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready",
zap.Int("default_memory_mb", engineCfg.DefaultMemoryLimitMB),
zap.Int("default_timeout_sec", engineCfg.DefaultTimeoutSeconds),
zap.Int("module_cache_size", engineCfg.ModuleCacheSize),
)
}
} else {
logger.ComponentWarn(logging.ComponentGeneral, "serverless engine requires RQLite and IPFS; functions disabled")
}
// Initialize SFU manager for WebRTC calls
if err := gw.initializeSFUManager(); err != nil {
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize SFU manager; WebRTC calls disabled", zap.Error(err))
}
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed, returning...")
return gw, nil
}
// withInternalAuth creates a context for internal gateway operations that
// bypass authentication. It simply delegates to client.WithInternalAuth;
// callers use it when the gateway invokes its own client on behalf of a
// request that was already authorized at the HTTP layer.
func (g *Gateway) withInternalAuth(ctx context.Context) context.Context {
	return client.WithInternalAuth(ctx)
}
// Close disconnects the gateway client and tears down all attached resources
// (SFU manager, serverless engine, libp2p client, SQL handle, Olric cache
// client, and IPFS client). Shutdown is best-effort: errors are logged and
// never returned, and each subsystem gets its own bounded grace period.
func (g *Gateway) Close() {
	// Close SFU manager first so active WebRTC sessions stop before the
	// transports they depend on go away.
	if g.sfuManager != nil {
		if err := g.sfuManager.Close(); err != nil {
			g.logger.ComponentWarn(logging.ComponentGeneral, "error during SFU manager close", zap.Error(err))
		}
	}
	// Close serverless engine with a 5s deadline.
	if g.serverlessEngine != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		if err := g.serverlessEngine.Close(ctx); err != nil {
			g.logger.ComponentWarn(logging.ComponentGeneral, "error during serverless engine close", zap.Error(err))
		}
		cancel()
	}
	if g.client != nil {
		if err := g.client.Disconnect(); err != nil {
			g.logger.ComponentWarn(logging.ComponentClient, "error during client disconnect", zap.Error(err))
		}
	}
	if g.sqlDB != nil {
		_ = g.sqlDB.Close() // best-effort; nothing useful to do with the error here
	}
	// Named oc (not client) so the local does not shadow the imported
	// client package, which the rest of this file uses.
	if oc := g.getOlricClient(); oc != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := oc.Close(ctx); err != nil {
			g.logger.ComponentWarn(logging.ComponentGeneral, "error during Olric client close", zap.Error(err))
		}
	}
	if g.ipfsClient != nil {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := g.ipfsClient.Close(ctx); err != nil {
			g.logger.ComponentWarn(logging.ComponentGeneral, "error during IPFS client close", zap.Error(err))
		}
	}
}
// getLocalSubscribers returns all local subscribers for a given topic and namespace
func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscriber {
topicKey := namespace + "." + topic
@ -197,32 +486,23 @@ func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscribe
return nil
}
// setOlricClient atomically sets the Olric client and reinitializes cache
// handlers. Passing nil clears the client; the previous cache handlers are
// intentionally left in place in that case (they are only rebuilt for a
// non-nil client). The parameter is named c so it does not shadow the
// imported client package.
func (g *Gateway) setOlricClient(c *olric.Client) {
	g.olricMu.Lock()
	defer g.olricMu.Unlock()
	g.olricClient = c
	if c != nil {
		g.cacheHandlers = cache.NewCacheHandlers(g.logger, c)
	}
}
// getOlricClient atomically retrieves the current Olric client (may be nil
// while the background reconnect loop is still trying to connect).
func (g *Gateway) getOlricClient() *olric.Client {
	g.olricMu.RLock()
	c := g.olricClient
	g.olricMu.RUnlock()
	return c
}
// startOlricReconnectLoop starts a background goroutine that continuously attempts
// to reconnect to the Olric cluster with exponential backoff.
func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) {
go func() {
retryDelay := 5 * time.Second
maxBackoff := 30 * time.Second
for {
client, err := olric.NewClient(cfg, g.logger.Logger)
client, err := initializeOlricClientWithRetry(cfg, g.logger)
if err == nil {
g.setOlricClient(client)
g.logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client connected after background retries",
@ -236,13 +516,211 @@ func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) {
zap.Error(err))
time.Sleep(retryDelay)
if retryDelay < maxBackoff {
if retryDelay < olricInitMaxBackoff {
retryDelay *= 2
if retryDelay > maxBackoff {
retryDelay = maxBackoff
if retryDelay > olricInitMaxBackoff {
retryDelay = olricInitMaxBackoff
}
}
}
}()
}
// initializeOlricClientWithRetry attempts to create an Olric cache client,
// retrying up to olricInitMaxAttempts times with exponential backoff capped
// at olricInitMaxBackoff. Returns the connected client, or the last creation
// error (wrapped) once all attempts are exhausted.
func initializeOlricClientWithRetry(cfg olric.Config, logger *logging.ColoredLogger) (*olric.Client, error) {
	var lastErr error
	backoff := olricInitInitialBackoff
	for attempt := 1; attempt <= olricInitMaxAttempts; attempt++ {
		c, err := olric.NewClient(cfg, logger.Logger)
		if err == nil {
			if attempt > 1 {
				logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client initialized after retries",
					zap.Int("attempts", attempt))
			}
			return c, nil
		}
		lastErr = err
		logger.ComponentWarn(logging.ComponentGeneral, "Olric cache client init attempt failed",
			zap.Int("attempt", attempt),
			zap.Duration("retry_in", backoff),
			zap.Error(err))
		// Sleep only when another attempt remains; double the backoff up to the cap.
		if attempt < olricInitMaxAttempts {
			time.Sleep(backoff)
			backoff *= 2
			if backoff > olricInitMaxBackoff {
				backoff = olricInitMaxBackoff
			}
		}
	}
	return nil, fmt.Errorf("failed to initialize Olric cache client after %d attempts: %w", olricInitMaxAttempts, lastErr)
}
// olricHostFromMultiaddr extracts the IPv4/IPv6 host from a multiaddr string.
// Returns ok=false when the address cannot be parsed, carries no IP
// component, or points at loopback (loopback peers are covered by the
// localhost:3320 fallback applied by callers).
func olricHostFromMultiaddr(addrStr string) (string, bool) {
	ma, err := multiaddr.NewMultiaddr(addrStr)
	if err != nil {
		return "", false
	}
	var ip string
	if ipv4, err := ma.ValueForProtocol(multiaddr.P_IP4); err == nil && ipv4 != "" {
		ip = ipv4
	} else if ipv6, err := ma.ValueForProtocol(multiaddr.P_IP6); err == nil && ipv6 != "" {
		ip = ipv6
	} else {
		return "", false
	}
	// Skip loopback. The previous check only matched the literals
	// "localhost" and "::1"; net.ParseIP additionally covers 127.0.0.0/8,
	// matching the stated intent of skipping loopback addresses.
	if ip == "localhost" {
		return "", false
	}
	if parsed := net.ParseIP(ip); parsed != nil && parsed.IsLoopback() {
		return "", false
	}
	return ip, true
}

// discoverOlricServers discovers Olric server addresses from LibP2P peers.
// Returns a deduplicated list of IP:port addresses where Olric servers are
// expected to run (standard port 3320), drawn from both the currently
// connected peers and the configured bootstrap peers.
func discoverOlricServers(networkClient client.NetworkClient, logger *zap.Logger) []string {
	// Get network info to access peer information.
	networkInfo := networkClient.Network()
	if networkInfo == nil {
		logger.Debug("Network info not available for Olric discovery")
		return nil
	}
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	peers, err := networkInfo.GetPeers(ctx)
	if err != nil {
		logger.Debug("Failed to get peers for Olric discovery", zap.Error(err))
		return nil
	}
	olricServers := make([]string, 0)
	seen := make(map[string]bool)
	// addCandidate converts a multiaddr into an Olric address and records it
	// once; the same logic previously appeared twice (live peers + bootstrap).
	addCandidate := func(addrStr string) {
		ip, ok := olricHostFromMultiaddr(addrStr)
		if !ok {
			return
		}
		olricAddr := net.JoinHostPort(ip, "3320")
		if !seen[olricAddr] {
			olricServers = append(olricServers, olricAddr)
			seen[olricAddr] = true
		}
	}
	for _, peer := range peers {
		for _, addrStr := range peer.Addresses {
			addCandidate(addrStr)
		}
	}
	// Also check bootstrap peers from config.
	if cfg := networkClient.Config(); cfg != nil {
		for _, peerAddr := range cfg.BootstrapPeers {
			addCandidate(peerAddr)
		}
	}
	if len(olricServers) > 0 {
		logger.Info("Discovered Olric servers from LibP2P network",
			zap.Strings("servers", olricServers))
	}
	return olricServers
}
// ipfsDiscoveryResult holds discovered IPFS configuration read from a node's
// YAML config file. A zero value means no configuration was found.
type ipfsDiscoveryResult struct {
	clusterURL        string        // IPFS Cluster API base URL
	apiURL            string        // IPFS node API base URL
	timeout           time.Duration // request timeout for IPFS operations
	replicationFactor int           // desired pin replication factor
	enableEncryption  bool          // whether content encryption is enabled
}
// discoverIPFSFromNodeConfigs discovers IPFS configuration from node YAML
// files under ~/.orama, checking node-1.yaml through node-5.yaml in order and
// returning the first config that declares a cluster API URL. Missing optional
// fields are filled with defaults (api http://localhost:5001, timeout 60s,
// replication factor 3). Returns the zero value when nothing is found.
func discoverIPFSFromNodeConfigs(logger *zap.Logger) ipfsDiscoveryResult {
	homeDir, err := os.UserHomeDir()
	if err != nil {
		logger.Debug("Failed to get home directory for IPFS discovery", zap.Error(err))
		return ipfsDiscoveryResult{}
	}
	configDir := filepath.Join(homeDir, ".orama")
	// Try all node config files for IPFS settings.
	configFiles := []string{"node-1.yaml", "node-2.yaml", "node-3.yaml", "node-4.yaml", "node-5.yaml"}
	for _, filename := range configFiles {
		configPath := filepath.Join(configDir, filename)
		data, err := os.ReadFile(configPath)
		if err != nil {
			continue // file absent or unreadable; try the next node config
		}
		var nodeCfg config.Config
		if err := config.DecodeStrict(strings.NewReader(string(data)), &nodeCfg); err != nil {
			logger.Debug("Failed to parse node config for IPFS discovery",
				zap.String("file", filename), zap.Error(err))
			continue
		}
		// A cluster API URL is the signal that IPFS is configured.
		if nodeCfg.Database.IPFS.ClusterAPIURL != "" {
			result := ipfsDiscoveryResult{
				clusterURL:        nodeCfg.Database.IPFS.ClusterAPIURL,
				apiURL:            nodeCfg.Database.IPFS.APIURL,
				timeout:           nodeCfg.Database.IPFS.Timeout,
				replicationFactor: nodeCfg.Database.IPFS.ReplicationFactor,
				enableEncryption:  nodeCfg.Database.IPFS.EnableEncryption,
			}
			if result.apiURL == "" {
				result.apiURL = "http://localhost:5001"
			}
			if result.timeout == 0 {
				result.timeout = 60 * time.Second
			}
			if result.replicationFactor == 0 {
				result.replicationFactor = 3
			}
			// Encryption is always reported as enabled. The previous code
			// (`if !enableEncryption { enableEncryption = true }`) had the
			// same effect: a plain bool cannot distinguish "unset" from an
			// explicit false, so false was always overwritten.
			// NOTE(review): to honor an explicit `enable_encryption: false`
			// in node configs, the config field would need to be a *bool.
			result.enableEncryption = true
			logger.Info("Discovered IPFS config from node config",
				zap.String("file", filename),
				zap.String("cluster_url", result.clusterURL),
				zap.String("api_url", result.apiURL),
				zap.Bool("encryption_enabled", result.enableEncryption))
			return result
		}
	}
	return ipfsDiscoveryResult{}
}

View File

@ -191,6 +191,16 @@ func isPublicPath(p string) bool {
return true
}
// TURN credentials (public for development - requires secret for actual use)
if strings.HasPrefix(p, "/v1/turn/") {
return true
}
// SFU endpoints (public for development)
if strings.HasPrefix(p, "/v1/sfu/") {
return true
}
switch p {
case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers":
return true

View File

@ -0,0 +1,475 @@
package gateway
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/client"
"github.com/DeBrosOfficial/network/pkg/pubsub"
"github.com/google/uuid"
"go.uber.org/zap"
"github.com/gorilla/websocket"
)
// wsUpgrader upgrades plain HTTP requests to WebSocket connections for the
// pubsub endpoints.
var wsUpgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	// For early development we accept any origin; tighten later.
	// NOTE(review): a permissive CheckOrigin allows cross-site WebSocket
	// hijacking from any web page; restrict to known origins before production.
	CheckOrigin: func(r *http.Request) bool { return true },
}
// pubsubWebsocketHandler upgrades to WS, subscribes to a namespaced topic, and
// forwards received PubSub messages to the client. Messages sent by the client
// are published to the same namespaced topic.
//
// Query parameters:
//   - topic (required): topic name within the caller's namespace
//   - presence=true (optional): track this connection in the presence list
//   - member_id (required when presence=true): caller-chosen member identity
//   - member_meta (optional): JSON object with arbitrary member metadata
//
// Outbound frames are JSON envelopes {data: base64, timestamp: unix-millis,
// topic: string}. The connection receives messages both from in-process
// publishers (via the localSubscribers registry) and, best-effort, from the
// libp2p network subscription.
func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request) {
	if g.client == nil {
		g.logger.ComponentWarn("gateway", "pubsub ws: client not initialized")
		writeError(w, http.StatusServiceUnavailable, "client not initialized")
		return
	}
	if r.Method != http.MethodGet {
		g.logger.ComponentWarn("gateway", "pubsub ws: method not allowed")
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// Resolve namespace from auth context
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		g.logger.ComponentWarn("gateway", "pubsub ws: namespace not resolved")
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	topic := r.URL.Query().Get("topic")
	if topic == "" {
		g.logger.ComponentWarn("gateway", "pubsub ws: missing topic")
		writeError(w, http.StatusBadRequest, "missing 'topic'")
		return
	}
	// Presence handling. member_meta is best-effort JSON: parse errors are
	// deliberately ignored and leave memberMeta nil.
	enablePresence := r.URL.Query().Get("presence") == "true"
	memberID := r.URL.Query().Get("member_id")
	memberMetaStr := r.URL.Query().Get("member_meta")
	var memberMeta map[string]interface{}
	if memberMetaStr != "" {
		_ = json.Unmarshal([]byte(memberMetaStr), &memberMeta)
	}
	if enablePresence && memberID == "" {
		g.logger.ComponentWarn("gateway", "pubsub ws: presence enabled but missing member_id")
		writeError(w, http.StatusBadRequest, "missing 'member_id' for presence")
		return
	}
	conn, err := wsUpgrader.Upgrade(w, r, nil)
	if err != nil {
		// Upgrade writes its own HTTP error response; nothing more to send.
		g.logger.ComponentWarn("gateway", "pubsub ws: upgrade failed")
		return
	}
	defer conn.Close()
	// Channel to deliver PubSub messages to WS writer. Buffered so slow
	// clients do not block publishers; overflow is dropped (see send sites).
	msgs := make(chan []byte, 128)
	// NEW: Register as local subscriber for direct message delivery
	localSub := &localSubscriber{
		msgChan:   msgs,
		namespace: ns,
	}
	topicKey := fmt.Sprintf("%s.%s", ns, topic)
	g.mu.Lock()
	g.localSubscribers[topicKey] = append(g.localSubscribers[topicKey], localSub)
	subscriberCount := len(g.localSubscribers[topicKey])
	g.mu.Unlock()
	// connID uniquely identifies this connection so the deferred cleanup can
	// remove exactly this member from the presence list.
	connID := uuid.New().String()
	if enablePresence {
		member := PresenceMember{
			MemberID: memberID,
			JoinedAt: time.Now().Unix(),
			Meta:     memberMeta,
			ConnID:   connID,
		}
		g.presenceMu.Lock()
		g.presenceMembers[topicKey] = append(g.presenceMembers[topicKey], member)
		g.presenceMu.Unlock()
		// Broadcast join event
		joinEvent := map[string]interface{}{
			"type":      "presence.join",
			"member_id": memberID,
			"meta":      memberMeta,
			"timestamp": member.JoinedAt,
		}
		eventData, _ := json.Marshal(joinEvent)
		// Use a background context for the broadcast to ensure it finishes even if the connection closes immediately
		broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
		_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
		g.logger.ComponentInfo("gateway", "pubsub ws: member joined presence",
			zap.String("topic", topic),
			zap.String("member_id", memberID))
	}
	g.logger.ComponentInfo("gateway", "pubsub ws: registered local subscriber",
		zap.String("topic", topic),
		zap.String("namespace", ns),
		zap.Int("total_subscribers", subscriberCount))
	// Unregister on close: remove this connection from the local-subscriber
	// registry and the presence list, then broadcast a leave event.
	defer func() {
		g.mu.Lock()
		subs := g.localSubscribers[topicKey]
		for i, sub := range subs {
			if sub == localSub {
				g.localSubscribers[topicKey] = append(subs[:i], subs[i+1:]...)
				break
			}
		}
		remainingCount := len(g.localSubscribers[topicKey])
		if remainingCount == 0 {
			delete(g.localSubscribers, topicKey)
		}
		g.mu.Unlock()
		if enablePresence {
			g.presenceMu.Lock()
			members := g.presenceMembers[topicKey]
			for i, m := range members {
				if m.ConnID == connID {
					// In-place element shift; readers must copy under the lock.
					g.presenceMembers[topicKey] = append(members[:i], members[i+1:]...)
					break
				}
			}
			if len(g.presenceMembers[topicKey]) == 0 {
				delete(g.presenceMembers, topicKey)
			}
			g.presenceMu.Unlock()
			// Broadcast leave event
			leaveEvent := map[string]interface{}{
				"type":      "presence.leave",
				"member_id": memberID,
				"timestamp": time.Now().Unix(),
			}
			eventData, _ := json.Marshal(leaveEvent)
			broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
			_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
			g.logger.ComponentInfo("gateway", "pubsub ws: member left presence",
				zap.String("topic", topic),
				zap.String("member_id", memberID))
		}
		g.logger.ComponentInfo("gateway", "pubsub ws: unregistered local subscriber",
			zap.String("topic", topic),
			zap.Int("remaining_subscribers", remainingCount))
	}()
	// Use internal auth context when interacting with client to avoid circular auth requirements
	ctx := client.WithInternalAuth(r.Context())
	// Apply namespace isolation
	ctx = pubsub.WithNamespace(ctx, ns)
	// Writer loop - START THIS FIRST before libp2p subscription
	done := make(chan struct{})
	go func() {
		g.logger.ComponentInfo("gateway", "pubsub ws: writer goroutine started",
			zap.String("topic", topic))
		defer g.logger.ComponentInfo("gateway", "pubsub ws: writer goroutine exiting",
			zap.String("topic", topic))
		ticker := time.NewTicker(30 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case b, ok := <-msgs:
				if !ok {
					g.logger.ComponentWarn("gateway", "pubsub ws: message channel closed",
						zap.String("topic", topic))
					_ = conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(5*time.Second))
					close(done)
					return
				}
				g.logger.ComponentInfo("gateway", "pubsub ws: sending message to client",
					zap.String("topic", topic),
					zap.Int("data_len", len(b)))
				// Format message as JSON envelope with data (base64 encoded), timestamp, and topic
				// This matches the SDK's Message interface: {data: string, timestamp: number, topic: string}
				envelope := map[string]interface{}{
					"data":      base64.StdEncoding.EncodeToString(b),
					"timestamp": time.Now().UnixMilli(),
					"topic":     topic,
				}
				envelopeJSON, err := json.Marshal(envelope)
				if err != nil {
					g.logger.ComponentWarn("gateway", "pubsub ws: failed to marshal envelope",
						zap.String("topic", topic),
						zap.Error(err))
					continue
				}
				g.logger.ComponentDebug("gateway", "pubsub ws: envelope created",
					zap.String("topic", topic),
					zap.Int("envelope_len", len(envelopeJSON)))
				// NOTE(review): SetWriteDeadline's error is ignored here; a
				// failed deadline set surfaces as a WriteMessage error below.
				conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
				if err := conn.WriteMessage(websocket.TextMessage, envelopeJSON); err != nil {
					g.logger.ComponentWarn("gateway", "pubsub ws: failed to write to websocket",
						zap.String("topic", topic),
						zap.Error(err))
					close(done)
					return
				}
				g.logger.ComponentInfo("gateway", "pubsub ws: message sent successfully",
					zap.String("topic", topic))
			case <-ticker.C:
				// Ping keepalive
				_ = conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(5*time.Second))
			case <-ctx.Done():
				close(done)
				return
			}
		}
	}()
	// Subscribe to libp2p for cross-node messages (in background, non-blocking)
	go func() {
		h := func(_ string, data []byte) error {
			g.logger.ComponentInfo("gateway", "pubsub ws: received message from libp2p",
				zap.String("topic", topic),
				zap.Int("data_len", len(data)))
			select {
			case msgs <- data:
				g.logger.ComponentInfo("gateway", "pubsub ws: forwarded to client",
					zap.String("topic", topic),
					zap.String("source", "libp2p"))
				return nil
			default:
				// Drop if client is slow to avoid blocking network
				g.logger.ComponentWarn("gateway", "pubsub ws: client slow, dropping message",
					zap.String("topic", topic))
				return nil
			}
		}
		if err := g.client.PubSub().Subscribe(ctx, topic, h); err != nil {
			g.logger.ComponentWarn("gateway", "pubsub ws: libp2p subscribe failed (will use local-only)",
				zap.String("topic", topic),
				zap.Error(err))
			return
		}
		g.logger.ComponentInfo("gateway", "pubsub ws: libp2p subscription established",
			zap.String("topic", topic))
		// Keep subscription alive until done
		<-done
		_ = g.client.PubSub().Unsubscribe(ctx, topic)
		g.logger.ComponentInfo("gateway", "pubsub ws: libp2p subscription closed",
			zap.String("topic", topic))
	}()
	// Reader loop: treat any client message as publish to the same topic
	for {
		mt, data, err := conn.ReadMessage()
		if err != nil {
			break
		}
		if mt != websocket.TextMessage && mt != websocket.BinaryMessage {
			continue
		}
		// Filter out WebSocket heartbeat messages
		// Don't publish them to the topic
		var msg map[string]interface{}
		if err := json.Unmarshal(data, &msg); err == nil {
			if msgType, ok := msg["type"].(string); ok && msgType == "ping" {
				g.logger.ComponentInfo("gateway", "pubsub ws: filtering out heartbeat ping")
				continue
			}
		}
		if err := g.client.PubSub().Publish(ctx, topic, data); err != nil {
			// Best-effort notify client
			_ = conn.WriteMessage(websocket.TextMessage, []byte("publish_error"))
		}
	}
	// Wait for the writer goroutine to finish before the deferred cleanup
	// runs. NOTE(review): if the writer is idle, this only unblocks once the
	// next keepalive write fails (up to ~30s) or ctx is cancelled — confirm
	// this delay is acceptable.
	<-done
}
// pubsubPublishHandler handles POST /v1/pubsub/publish {topic, data_base64}.
// Delivery is two-phase: the decoded payload is first handed synchronously to
// any local WebSocket subscribers of the namespaced topic, then published to
// libp2p asynchronously for cross-node delivery. The HTTP response does not
// wait for (or report) the libp2p publish result.
func (g *Gateway) pubsubPublishHandler(w http.ResponseWriter, r *http.Request) {
	if g.client == nil {
		writeError(w, http.StatusServiceUnavailable, "client not initialized")
		return
	}
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	var body struct {
		Topic   string `json:"topic"`
		DataB64 string `json:"data_base64"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Topic == "" || body.DataB64 == "" {
		writeError(w, http.StatusBadRequest, "invalid body: expected {topic,data_base64}")
		return
	}
	data, err := base64.StdEncoding.DecodeString(body.DataB64)
	if err != nil {
		writeError(w, http.StatusBadRequest, "invalid base64 data")
		return
	}
	// NEW: Check for local websocket subscribers FIRST and deliver directly.
	// NOTE(review): this assumes getLocalSubscribers expects the caller to
	// hold g.mu (the WS handler accesses the map under g.mu directly) —
	// confirm it does not itself take g.mu, which would recursively RLock.
	g.mu.RLock()
	localSubs := g.getLocalSubscribers(body.Topic, ns)
	g.mu.RUnlock()
	localDeliveryCount := 0
	if len(localSubs) > 0 {
		for _, sub := range localSubs {
			// Non-blocking send: a subscriber with a full buffer is skipped
			// rather than stalling the HTTP request.
			select {
			case sub.msgChan <- data:
				localDeliveryCount++
				g.logger.ComponentDebug("gateway", "delivered to local subscriber",
					zap.String("topic", body.Topic))
			default:
				// Drop if buffer full
				g.logger.ComponentWarn("gateway", "local subscriber buffer full, dropping message",
					zap.String("topic", body.Topic))
			}
		}
	}
	g.logger.ComponentInfo("gateway", "pubsub publish: processing message",
		zap.String("topic", body.Topic),
		zap.String("namespace", ns),
		zap.Int("data_len", len(data)),
		zap.Int("local_subscribers", len(localSubs)),
		zap.Int("local_delivered", localDeliveryCount))
	// Publish to libp2p asynchronously for cross-node delivery
	// This prevents blocking the HTTP response if libp2p network is slow
	go func() {
		publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
		if err := g.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
			g.logger.ComponentWarn("gateway", "async libp2p publish failed",
				zap.String("topic", body.Topic),
				zap.Error(err))
		} else {
			g.logger.ComponentDebug("gateway", "async libp2p publish succeeded",
				zap.String("topic", body.Topic))
		}
	}()
	// Return immediately after local delivery
	// Local WebSocket subscribers already received the message
	writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
}
// pubsubTopicsHandler lists topics within the caller's namespace.
// Only GET is accepted, matching the method checks of the other pubsub
// handlers (previously any method reached the client).
func (g *Gateway) pubsubTopicsHandler(w http.ResponseWriter, r *http.Request) {
	if g.client == nil {
		writeError(w, http.StatusServiceUnavailable, "client not initialized")
		return
	}
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	// Apply namespace isolation
	ctx := pubsub.WithNamespace(client.WithInternalAuth(r.Context()), ns)
	all, err := g.client.PubSub().ListTopics(ctx)
	if err != nil {
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	// Client returns topics already trimmed to its namespace; return as-is
	writeJSON(w, http.StatusOK, map[string]any{"topics": all})
}
// resolveNamespaceFromRequest gets the namespace from the request context set
// by the auth middleware, falling back to the "namespace" query parameter for
// development/testing. Returns "" when neither source yields a namespace.
func resolveNamespaceFromRequest(r *http.Request) string {
	v := r.Context().Value(ctxKeyNamespaceOverride)
	if s, ok := v.(string); ok {
		return s
	}
	// Development fallback: explicit query parameter.
	return strings.TrimSpace(r.URL.Query().Get("namespace"))
}
// namespacePrefix returns the key prefix that scopes topics to a namespace.
func namespacePrefix(ns string) string {
	return fmt.Sprintf("ns::%s::", ns)
}
// namespacedTopic builds the fully-qualified topic key for a namespace:
// "ns::<namespace>::<topic>".
func namespacedTopic(ns, topic string) string {
	return "ns::" + ns + "::" + topic
}
// pubsubPresenceHandler handles GET /v1/pubsub/presence?topic=mytopic and
// returns the current presence members for the namespaced topic.
func (g *Gateway) pubsubPresenceHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}
	topic := r.URL.Query().Get("topic")
	if topic == "" {
		writeError(w, http.StatusBadRequest, "missing 'topic'")
		return
	}
	topicKey := fmt.Sprintf("%s.%s", ns, topic)
	// Snapshot the member list while holding the lock: slices stored in
	// presenceMembers are mutated in place (element-shifting removal on
	// disconnect), so handing the live slice to the JSON encoder after
	// unlocking would race with concurrent writers.
	g.presenceMu.RLock()
	members := append([]PresenceMember(nil), g.presenceMembers[topicKey]...)
	g.presenceMu.RUnlock()
	if len(members) == 0 {
		// Emit an explicit empty list (never null) when the topic is unknown.
		writeJSON(w, http.StatusOK, map[string]any{
			"topic":   topic,
			"members": []PresenceMember{},
			"count":   0,
		})
		return
	}
	writeJSON(w, http.StatusOK, map[string]any{
		"topic":   topic,
		"members": members,
		"count":   len(members),
	})
}

View File

@ -79,5 +79,14 @@ func (g *Gateway) Routes() http.Handler {
g.serverlessHandlers.RegisterRoutes(mux)
}
// TURN credentials for WebRTC
mux.HandleFunc("/v1/turn/credentials", g.turnCredentialsHandler)
// SFU endpoints for WebRTC group calls (if enabled)
if g.sfuManager != nil {
mux.HandleFunc("/v1/sfu/room", g.sfuCreateRoomHandler)
mux.HandleFunc("/v1/sfu/room/", g.sfuRoomHandler) // Handles :roomId/* paths
}
return g.withMiddleware(mux)
}

View File

@ -0,0 +1,933 @@
package gateway
import (
"context"
"encoding/json"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
"github.com/DeBrosOfficial/network/pkg/serverless"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"go.uber.org/zap"
)
// ServerlessHandlers contains handlers for serverless function endpoints.
// It's a separate struct to keep the Gateway struct clean.
type ServerlessHandlers struct {
	invoker        *serverless.Invoker         // executes function invocations
	registry       serverless.FunctionRegistry // stores/loads function definitions
	wsManager      *serverless.WSManager       // manages WebSocket function streams
	triggerManager serverless.TriggerManager   // manages DB-event triggers
	logger         *zap.Logger
}
// NewServerlessHandlers creates a new ServerlessHandlers instance wired to
// the given invoker, registry, WebSocket manager, trigger manager, and logger.
func NewServerlessHandlers(
	invoker *serverless.Invoker,
	registry serverless.FunctionRegistry,
	wsManager *serverless.WSManager,
	triggerManager serverless.TriggerManager,
	logger *zap.Logger,
) *ServerlessHandlers {
	h := new(ServerlessHandlers)
	h.invoker = invoker
	h.registry = registry
	h.wsManager = wsManager
	h.triggerManager = triggerManager
	h.logger = logger
	return h
}
// RegisterRoutes registers all serverless routes on the given mux:
// function management under /v1/functions and the direct invoke endpoint
// under /v1/invoke/.
func (h *ServerlessHandlers) RegisterRoutes(mux *http.ServeMux) {
	routes := map[string]http.HandlerFunc{
		"/v1/functions":  h.handleFunctions,
		"/v1/functions/": h.handleFunctionByName,
		"/v1/invoke/":    h.handleInvoke,
	}
	for pattern, handler := range routes {
		mux.HandleFunc(pattern, handler)
	}
}
// handleFunctions handles GET /v1/functions (list) and POST /v1/functions
// (deploy); every other method is rejected with 405.
func (h *ServerlessHandlers) handleFunctions(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		h.listFunctions(w, r)
		return
	}
	if r.Method == http.MethodPost {
		h.deployFunction(w, r)
		return
	}
	http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
// handleFunctionByName handles operations on a specific function
// Routes:
//   - GET    /v1/functions/{name}          - Get function info
//   - DELETE /v1/functions/{name}          - Delete function
//   - POST   /v1/functions/{name}/invoke   - Invoke function
//   - GET    /v1/functions/{name}/versions - List versions
//   - GET    /v1/functions/{name}/logs     - Get logs
//   - WS     /v1/functions/{name}/ws       - WebSocket invoke
//
// The name may carry an "@N" suffix (e.g. "myfunction@2") selecting a
// specific version; an unparsable or missing version resolves to 0 (latest).
func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http.Request) {
	// Path shape: /v1/functions/{name}[/{action...}]
	rest := strings.TrimPrefix(r.URL.Path, "/v1/functions/")
	name, action, _ := strings.Cut(rest, "/")
	if name == "" {
		http.Error(w, "Function name required", http.StatusBadRequest)
		return
	}
	// Strip an optional "@version" suffix; "@" at position 0 is not treated
	// as a version separator.
	version := 0
	if base, vStr, found := strings.Cut(name, "@"); found && base != "" {
		name = base
		if v, err := strconv.Atoi(vStr); err == nil {
			version = v
		}
	}
	switch action {
	case "invoke":
		h.invokeFunction(w, r, name, version)
	case "ws":
		h.handleWebSocket(w, r, name, version)
	case "versions":
		h.listVersions(w, r, name)
	case "logs":
		h.getFunctionLogs(w, r, name)
	case "triggers":
		h.handleFunctionTriggers(w, r, name)
	case "":
		switch r.Method {
		case http.MethodGet:
			h.getFunctionInfo(w, r, name, version)
		case http.MethodDelete:
			h.deleteFunction(w, r, name, version)
		default:
			http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		}
	default:
		http.Error(w, "Unknown action", http.StatusNotFound)
	}
}
// handleInvoke handles POST /v1/invoke/{namespace}/{name}[@version].
// The optional "@version" suffix selects a specific version; otherwise
// version 0 (latest) is used.
func (h *ServerlessHandlers) handleInvoke(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// Path shape: /v1/invoke/{namespace}/{name}[@version]
	rest := strings.TrimPrefix(r.URL.Path, "/v1/invoke/")
	namespace, name, found := strings.Cut(rest, "/")
	if !found {
		http.Error(w, "Path must be /v1/invoke/{namespace}/{name}", http.StatusBadRequest)
		return
	}
	// Strip an optional "@version" suffix ("@" at position 0 is ignored).
	version := 0
	if base, vStr, ok := strings.Cut(name, "@"); ok && base != "" {
		name = base
		if v, err := strconv.Atoi(vStr); err == nil {
			version = v
		}
	}
	h.invokeFunction(w, r, namespace+"/"+name, version)
}
// listFunctions handles GET /v1/functions.
// The namespace comes from ?namespace, falling back to the request context.
func (h *ServerlessHandlers) listFunctions(w http.ResponseWriter, r *http.Request) {
	ns := r.URL.Query().Get("namespace")
	if ns == "" {
		// Fall back to the namespace carried by the request (JWT/header).
		ns = h.getNamespaceFromRequest(r)
	}
	if ns == "" {
		writeError(w, http.StatusBadRequest, "namespace required")
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
	defer cancel()
	fns, err := h.registry.List(ctx, ns)
	if err != nil {
		h.logger.Error("Failed to list functions", zap.Error(err))
		writeError(w, http.StatusInternalServerError, "Failed to list functions")
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"functions": fns,
		"count":     len(fns),
	})
}
// deployFunction handles POST /v1/functions.
//
// Two upload modes:
//   - multipart/form-data: a "metadata" JSON field describing the function
//     plus a required "wasm" file part (the supported path; 32MB form cap).
//   - JSON: the body is parsed, but base64 WASM upload is rejected, so a
//     JSON deploy always fails validation before reaching the registry.
//
// On success the definition is registered, the previous version's cached
// WASM module (if any) is invalidated, and any PubSub topics listed in the
// definition are registered as triggers best-effort (failures are logged,
// not fatal).
func (h *ServerlessHandlers) deployFunction(w http.ResponseWriter, r *http.Request) {
	// Parse multipart form (for WASM upload) or JSON
	contentType := r.Header.Get("Content-Type")
	var def serverless.FunctionDefinition
	var wasmBytes []byte
	if strings.HasPrefix(contentType, "multipart/form-data") {
		// Parse multipart form
		if err := r.ParseMultipartForm(32 << 20); err != nil { // 32MB max
			writeError(w, http.StatusBadRequest, "Failed to parse form: "+err.Error())
			return
		}
		// Get metadata from form field
		metadataStr := r.FormValue("metadata")
		if metadataStr != "" {
			if err := json.Unmarshal([]byte(metadataStr), &def); err != nil {
				writeError(w, http.StatusBadRequest, "Invalid metadata JSON: "+err.Error())
				return
			}
		}
		// Individual form fields fill in anything the metadata JSON omitted.
		// Get name from form if not in metadata
		if def.Name == "" {
			def.Name = r.FormValue("name")
		}
		// Get namespace from form if not in metadata
		if def.Namespace == "" {
			def.Namespace = r.FormValue("namespace")
		}
		// Get other configuration fields from form. Parse errors are
		// deliberately ignored (best-effort): a malformed value leaves
		// the zero value in place.
		if v := r.FormValue("is_public"); v != "" {
			def.IsPublic, _ = strconv.ParseBool(v)
		}
		if v := r.FormValue("memory_limit_mb"); v != "" {
			def.MemoryLimitMB, _ = strconv.Atoi(v)
		}
		if v := r.FormValue("timeout_seconds"); v != "" {
			def.TimeoutSeconds, _ = strconv.Atoi(v)
		}
		if v := r.FormValue("retry_count"); v != "" {
			def.RetryCount, _ = strconv.Atoi(v)
		}
		if v := r.FormValue("retry_delay_seconds"); v != "" {
			def.RetryDelaySeconds, _ = strconv.Atoi(v)
		}
		// Get WASM file
		file, _, err := r.FormFile("wasm")
		if err != nil {
			writeError(w, http.StatusBadRequest, "WASM file required")
			return
		}
		defer file.Close()
		wasmBytes, err = io.ReadAll(file)
		if err != nil {
			writeError(w, http.StatusBadRequest, "Failed to read WASM file: "+err.Error())
			return
		}
	} else {
		// JSON body with base64-encoded WASM
		var req struct {
			serverless.FunctionDefinition
			WASMBase64 string `json:"wasm_base64"`
		}
		if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
			writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error())
			return
		}
		def = req.FunctionDefinition
		if req.WASMBase64 != "" {
			// Decode base64 WASM - for now, just reject this method
			writeError(w, http.StatusBadRequest, "Base64 WASM upload not supported, use multipart/form-data")
			return
		}
	}
	// Get namespace from JWT if not provided
	if def.Namespace == "" {
		def.Namespace = h.getNamespaceFromRequest(r)
	}
	// Validate the minimum required fields before touching the registry.
	if def.Name == "" {
		writeError(w, http.StatusBadRequest, "Function name required")
		return
	}
	if def.Namespace == "" {
		writeError(w, http.StatusBadRequest, "Namespace required")
		return
	}
	if len(wasmBytes) == 0 {
		writeError(w, http.StatusBadRequest, "WASM bytecode required")
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
	defer cancel()
	// Register returns the previously-deployed function (if any) so its
	// cached module can be invalidated below.
	oldFn, err := h.registry.Register(ctx, &def, wasmBytes)
	if err != nil {
		h.logger.Error("Failed to deploy function",
			zap.String("name", def.Name),
			zap.Error(err),
		)
		writeError(w, http.StatusInternalServerError, "Failed to deploy: "+err.Error())
		return
	}
	// Invalidate cache for the old version to ensure the new one is loaded
	if oldFn != nil {
		h.invoker.InvalidateCache(oldFn.WASMCID)
		h.logger.Debug("Invalidated function cache",
			zap.String("name", def.Name),
			zap.String("old_wasm_cid", oldFn.WASMCID),
		)
	}
	h.logger.Info("Function deployed",
		zap.String("name", def.Name),
		zap.String("namespace", def.Namespace),
	)
	// Fetch the deployed function to return. If the read-back fails we
	// still report success with a minimal payload (the deployment itself
	// already succeeded). NOTE(review): PubSub trigger registration below
	// is skipped on this early-return path — confirm that is intended.
	fn, err := h.registry.Get(ctx, def.Namespace, def.Name, def.Version)
	if err != nil {
		writeJSON(w, http.StatusCreated, map[string]interface{}{
			"message": "Function deployed successfully",
			"name":    def.Name,
		})
		return
	}
	// Register PubSub triggers if provided in metadata
	var triggersAdded []string
	if len(def.PubSubTopics) > 0 && h.triggerManager != nil {
		for _, topic := range def.PubSubTopics {
			if err := h.triggerManager.AddPubSubTrigger(ctx, fn.ID, topic); err != nil {
				// Log but don't fail deployment
				h.logger.Warn("Failed to add pubsub trigger during deployment",
					zap.String("function", def.Name),
					zap.String("topic", topic),
					zap.Error(err),
				)
			} else {
				triggersAdded = append(triggersAdded, topic)
				h.logger.Info("PubSub trigger added during deployment",
					zap.String("function", def.Name),
					zap.String("topic", topic),
				)
			}
		}
	}
	response := map[string]interface{}{
		"message":  "Function deployed successfully",
		"function": fn,
	}
	if len(triggersAdded) > 0 {
		response["triggers_added"] = triggersAdded
	}
	writeJSON(w, http.StatusCreated, response)
}
// getFunctionInfo handles GET /v1/functions/{name}.
// Namespace comes from ?namespace, falling back to the request context;
// version 0 selects the latest version.
func (h *ServerlessHandlers) getFunctionInfo(w http.ResponseWriter, r *http.Request, name string, version int) {
	ns := r.URL.Query().Get("namespace")
	if ns == "" {
		ns = h.getNamespaceFromRequest(r)
	}
	if ns == "" {
		writeError(w, http.StatusBadRequest, "namespace required")
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	fn, err := h.registry.Get(ctx, ns, name, version)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, fn)
	case serverless.IsNotFound(err):
		writeError(w, http.StatusNotFound, "Function not found")
	default:
		writeError(w, http.StatusInternalServerError, "Failed to get function")
	}
}
// deleteFunction handles DELETE /v1/functions/{name}.
// version 0 targets the latest version.
func (h *ServerlessHandlers) deleteFunction(w http.ResponseWriter, r *http.Request, name string, version int) {
	ns := r.URL.Query().Get("namespace")
	if ns == "" {
		ns = h.getNamespaceFromRequest(r)
	}
	if ns == "" {
		writeError(w, http.StatusBadRequest, "namespace required")
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	err := h.registry.Delete(ctx, ns, name, version)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, map[string]string{
			"message": "Function deleted successfully",
		})
	case serverless.IsNotFound(err):
		writeError(w, http.StatusNotFound, "Function not found")
	default:
		writeError(w, http.StatusInternalServerError, "Failed to delete function")
	}
}
// invokeFunction handles POST /v1/functions/{name}/invoke.
//
// nameWithNS is either "namespace/name" or a bare "name"; in the latter
// case the namespace is resolved from the query string or the request
// context. version 0 invokes the latest version.
//
// Fix: the error path previously dereferenced resp unconditionally, which
// panics if the invoker returns (nil, err). The response is now built
// defensively.
func (h *ServerlessHandlers) invokeFunction(w http.ResponseWriter, r *http.Request, nameWithNS string, version int) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// Split "namespace/name" if a namespace was embedded in the path.
	var namespace, name string
	if idx := strings.Index(nameWithNS, "/"); idx > 0 {
		namespace = nameWithNS[:idx]
		name = nameWithNS[idx+1:]
	} else {
		name = nameWithNS
		namespace = r.URL.Query().Get("namespace")
		if namespace == "" {
			namespace = h.getNamespaceFromRequest(r)
		}
	}
	if namespace == "" {
		writeError(w, http.StatusBadRequest, "namespace required")
		return
	}
	// Cap the input payload at 1MB to bound per-invocation memory.
	input, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
	if err != nil {
		writeError(w, http.StatusBadRequest, "Failed to read request body")
		return
	}
	// Caller identity (wallet) from the JWT, if present.
	callerWallet := h.getWalletFromRequest(r)
	ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
	defer cancel()
	req := &serverless.InvokeRequest{
		Namespace:    namespace,
		FunctionName: name,
		Version:      version,
		Input:        input,
		TriggerType:  serverless.TriggerTypeHTTP,
		CallerWallet: callerWallet,
	}
	resp, err := h.invoker.Invoke(ctx, req)
	if err != nil {
		statusCode := http.StatusInternalServerError
		if serverless.IsNotFound(err) {
			statusCode = http.StatusNotFound
		} else if serverless.IsResourceExhausted(err) {
			statusCode = http.StatusTooManyRequests
		} else if serverless.IsUnauthorized(err) {
			statusCode = http.StatusUnauthorized
		}
		// Guard against a nil response: dereferencing resp here would
		// panic if the invoker returns (nil, err).
		body := map[string]interface{}{"error": err.Error()}
		if resp != nil {
			body = map[string]interface{}{
				"request_id":  resp.RequestID,
				"status":      resp.Status,
				"error":       resp.Error,
				"duration_ms": resp.DurationMS,
			}
		}
		writeJSON(w, statusCode, body)
		return
	}
	// Surface invocation metadata as headers in both output shapes.
	w.Header().Set("X-Request-ID", resp.RequestID)
	w.Header().Set("X-Duration-Ms", strconv.FormatInt(resp.DurationMS, 10))
	// Pass JSON-looking output through verbatim; wrap anything else.
	if len(resp.Output) > 0 && (resp.Output[0] == '{' || resp.Output[0] == '[') {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write(resp.Output)
	} else {
		writeJSON(w, http.StatusOK, map[string]interface{}{
			"request_id":  resp.RequestID,
			"output":      string(resp.Output),
			"status":      resp.Status,
			"duration_ms": resp.DurationMS,
		})
	}
}
// handleWebSocket handles WebSocket connections for function streaming.
//
// Each inbound message is treated as one invocation payload; the JSON
// result (or error) is written back on the same socket. The loop ends when
// the client disconnects or a write fails.
//
// Fix: after Invoke, resp was dereferenced unconditionally even when err
// was non-nil — a (nil, err) return would panic and tear down the
// connection goroutine. The reply is now built defensively.
func (h *ServerlessHandlers) handleWebSocket(w http.ResponseWriter, r *http.Request, name string, version int) {
	namespace := r.URL.Query().Get("namespace")
	if namespace == "" {
		namespace = h.getNamespaceFromRequest(r)
	}
	if namespace == "" {
		http.Error(w, "namespace required", http.StatusBadRequest)
		return
	}
	// Upgrade to WebSocket.
	// NOTE(review): CheckOrigin accepts every origin, so any website can
	// open this socket from a browser — confirm this is intentional.
	upgrader := websocket.Upgrader{
		CheckOrigin: func(r *http.Request) bool { return true },
	}
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		h.logger.Error("WebSocket upgrade failed", zap.Error(err))
		return
	}
	clientID := uuid.New().String()
	wsConn := &serverless.GorillaWSConn{Conn: conn}
	// Register so invoked functions can push messages to this client.
	h.wsManager.Register(clientID, wsConn)
	defer h.wsManager.Unregister(clientID)
	h.logger.Info("WebSocket connected",
		zap.String("client_id", clientID),
		zap.String("function", name),
	)
	callerWallet := h.getWalletFromRequest(r)
	// Message loop: one invocation per received message.
	for {
		_, message, err := conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				h.logger.Warn("WebSocket error", zap.Error(err))
			}
			break
		}
		// Each invocation gets its own timeout independent of the socket.
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		req := &serverless.InvokeRequest{
			Namespace:    namespace,
			FunctionName: name,
			Version:      version,
			Input:        message,
			TriggerType:  serverless.TriggerTypeWebSocket,
			CallerWallet: callerWallet,
			WSClientID:   clientID,
		}
		resp, err := h.invoker.Invoke(ctx, req)
		cancel()
		// Build the reply; tolerate a nil response on error.
		response := map[string]interface{}{}
		if resp != nil {
			response["request_id"] = resp.RequestID
			response["status"] = resp.Status
			response["duration_ms"] = resp.DurationMS
		}
		if err != nil {
			if resp != nil && resp.Error != "" {
				response["error"] = resp.Error
			} else {
				response["error"] = err.Error()
			}
		} else if resp != nil && len(resp.Output) > 0 {
			// Prefer structured JSON output; fall back to a raw string.
			var output interface{}
			if json.Unmarshal(resp.Output, &output) == nil {
				response["output"] = output
			} else {
				response["output"] = string(resp.Output)
			}
		}
		respBytes, _ := json.Marshal(response)
		if err := conn.WriteMessage(websocket.TextMessage, respBytes); err != nil {
			break
		}
	}
}
// listVersions handles GET /v1/functions/{name}/versions.
func (h *ServerlessHandlers) listVersions(w http.ResponseWriter, r *http.Request, name string) {
	ns := r.URL.Query().Get("namespace")
	if ns == "" {
		ns = h.getNamespaceFromRequest(r)
	}
	if ns == "" {
		writeError(w, http.StatusBadRequest, "namespace required")
		return
	}
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	// Version history is only available on the concrete Registry type.
	reg, isConcrete := h.registry.(*serverless.Registry)
	if !isConcrete {
		writeError(w, http.StatusNotImplemented, "Version listing not supported")
		return
	}
	history, err := reg.ListVersions(ctx, ns, name)
	if err != nil {
		writeError(w, http.StatusInternalServerError, "Failed to list versions")
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"versions": history,
		"count":    len(history),
	})
}
// getFunctionLogs handles GET /v1/functions/{name}/logs.
//
// Optional query params: namespace (falls back to the request context) and
// limit (default 100).
//
// Fix: the limit is now validated — non-positive values previously passed
// straight through to the registry; they are now ignored in favor of the
// default.
func (h *ServerlessHandlers) getFunctionLogs(w http.ResponseWriter, r *http.Request, name string) {
	namespace := r.URL.Query().Get("namespace")
	if namespace == "" {
		namespace = h.getNamespaceFromRequest(r)
	}
	if namespace == "" {
		writeError(w, http.StatusBadRequest, "namespace required")
		return
	}
	limit := 100
	if lStr := r.URL.Query().Get("limit"); lStr != "" {
		// Only accept strictly positive limits.
		if l, err := strconv.Atoi(lStr); err == nil && l > 0 {
			limit = l
		}
	}
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	logs, err := h.registry.GetLogs(ctx, namespace, name, limit)
	if err != nil {
		h.logger.Error("Failed to get function logs",
			zap.String("name", name),
			zap.String("namespace", namespace),
			zap.Error(err),
		)
		writeError(w, http.StatusInternalServerError, "Failed to get logs")
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"name":      name,
		"namespace": namespace,
		"logs":      logs,
		"count":     len(logs),
	})
}
// getNamespaceFromRequest resolves the effective namespace for a request.
// Precedence: auth-middleware context value, then the ?namespace query
// param, then the X-Namespace header, then the literal "default".
func (h *ServerlessHandlers) getNamespaceFromRequest(r *http.Request) string {
	// Context value set by the auth middleware is the most trusted source.
	if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
		if ns, ok := v.(string); ok && ns != "" {
			return ns
		}
	}
	// Fallbacks for public access or admin tooling.
	for _, candidate := range []string{
		r.URL.Query().Get("namespace"),
		r.Header.Get("X-Namespace"),
	} {
		if candidate != "" {
			return candidate
		}
	}
	return "default"
}
// getWalletFromRequest extracts the caller's wallet address.
// Order: X-Wallet header, JWT subject (skipping API-key style subjects),
// then the namespace override as a last-resort identity; "" if none apply.
func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string {
	// 1. Legacy/direct bypass via header.
	if hdr := r.Header.Get("X-Wallet"); hdr != "" {
		return hdr
	}
	// 2. JWT claims stored on the context by the auth middleware.
	if v := r.Context().Value(ctxKeyJWT); v != nil {
		claims, ok := v.(*auth.JWTClaims)
		if ok && claims != nil {
			subject := strings.TrimSpace(claims.Sub)
			// API keys ("ak_*" or colon-separated) are not wallet addresses.
			isAPIKey := strings.HasPrefix(strings.ToLower(subject), "ak_") || strings.Contains(subject, ":")
			if !isAPIKey {
				return subject
			}
		}
	}
	// 3. Fall back to the API key identity (its namespace).
	if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
		if ns, ok := v.(string); ok && ns != "" {
			return ns
		}
	}
	return ""
}
// HealthStatus reports the serverless engine's health together with live
// WebSocket connection statistics.
func (h *ServerlessHandlers) HealthStatus() map[string]interface{} {
	wsStats := h.wsManager.GetStats()
	status := map[string]interface{}{"status": "ok"}
	status["connections"] = wsStats.ConnectionCount
	status["topics"] = wsStats.TopicCount
	return status
}
// handleFunctionTriggers routes trigger operations for a function.
// Routes:
//   - GET    /v1/functions/{name}/triggers        - List all triggers
//   - POST   /v1/functions/{name}/triggers/pubsub - Add pubsub trigger
//   - DELETE /v1/functions/{name}/triggers/{id}   - Remove trigger
func (h *ServerlessHandlers) handleFunctionTriggers(w http.ResponseWriter, r *http.Request, name string) {
	if h.triggerManager == nil {
		writeError(w, http.StatusServiceUnavailable, "Trigger management not available")
		return
	}
	ns := r.URL.Query().Get("namespace")
	if ns == "" {
		ns = h.getNamespaceFromRequest(r)
	}
	if ns == "" {
		writeError(w, http.StatusBadRequest, "namespace required")
		return
	}
	// Everything after "/triggers" selects a trigger type ("pubsub") or a
	// concrete trigger ID; empty means the collection itself.
	remainder := ""
	if at := strings.Index(r.URL.Path, "/triggers"); at > 0 {
		remainder = strings.TrimPrefix(r.URL.Path[at:], "/triggers")
		remainder = strings.TrimPrefix(remainder, "/")
	}
	switch r.Method {
	case http.MethodGet:
		h.listFunctionTriggers(w, r, ns, name)
	case http.MethodPost:
		if remainder != "pubsub" {
			writeError(w, http.StatusBadRequest, "Invalid trigger type. Use /triggers/pubsub")
			return
		}
		h.addPubSubTrigger(w, r, ns, name)
	case http.MethodDelete:
		if remainder == "" {
			writeError(w, http.StatusBadRequest, "Trigger ID required")
			return
		}
		h.removeFunctionTrigger(w, r, ns, name, remainder)
	default:
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
	}
}
// listFunctionTriggers handles GET /v1/functions/{name}/triggers.
func (h *ServerlessHandlers) listFunctionTriggers(w http.ResponseWriter, r *http.Request, namespace, name string) {
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	// Resolve the function (version 0 = latest) to obtain its ID.
	fn, err := h.registry.Get(ctx, namespace, name, 0)
	if err != nil {
		if serverless.IsNotFound(err) {
			writeError(w, http.StatusNotFound, "Function not found")
			return
		}
		h.logger.Error("Failed to get function",
			zap.String("name", name),
			zap.String("namespace", namespace),
			zap.Error(err),
		)
		writeError(w, http.StatusInternalServerError, "Failed to get function")
		return
	}
	// Currently only PubSub triggers exist.
	triggers, err := h.triggerManager.ListPubSubTriggers(ctx, fn.ID)
	if err != nil {
		h.logger.Error("Failed to list triggers",
			zap.String("function_id", fn.ID),
			zap.Error(err),
		)
		writeError(w, http.StatusInternalServerError, "Failed to list triggers")
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"name":            name,
		"namespace":       namespace,
		"function_id":     fn.ID,
		"pubsub_triggers": triggers,
		"count":           len(triggers),
	})
}
// addPubSubTrigger handles POST /v1/functions/{name}/triggers/pubsub.
// The body is a small JSON document: {"topic": "..."}.
func (h *ServerlessHandlers) addPubSubTrigger(w http.ResponseWriter, r *http.Request, namespace, name string) {
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	var payload struct {
		Topic string `json:"topic"`
	}
	if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
		writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error())
		return
	}
	if payload.Topic == "" {
		writeError(w, http.StatusBadRequest, "topic is required")
		return
	}
	// Resolve the function (version 0 = latest) to obtain its ID.
	fn, err := h.registry.Get(ctx, namespace, name, 0)
	if err != nil {
		if serverless.IsNotFound(err) {
			writeError(w, http.StatusNotFound, "Function not found")
			return
		}
		h.logger.Error("Failed to get function",
			zap.String("name", name),
			zap.String("namespace", namespace),
			zap.Error(err),
		)
		writeError(w, http.StatusInternalServerError, "Failed to get function")
		return
	}
	if err := h.triggerManager.AddPubSubTrigger(ctx, fn.ID, payload.Topic); err != nil {
		if serverless.IsValidationError(err) {
			writeError(w, http.StatusBadRequest, err.Error())
			return
		}
		h.logger.Error("Failed to add pubsub trigger",
			zap.String("function_id", fn.ID),
			zap.String("topic", payload.Topic),
			zap.Error(err),
		)
		writeError(w, http.StatusInternalServerError, "Failed to add trigger")
		return
	}
	// Re-list to locate the trigger we just created so it can be echoed
	// back to the caller (best-effort; nil on lookup failure).
	var created *serverless.PubSubTrigger
	if triggers, lerr := h.triggerManager.ListPubSubTriggers(ctx, fn.ID); lerr == nil {
		for i := range triggers {
			if triggers[i].Topic == payload.Topic {
				created = &triggers[i]
				break
			}
		}
	}
	h.logger.Info("PubSub trigger added",
		zap.String("function", name),
		zap.String("namespace", namespace),
		zap.String("topic", payload.Topic),
	)
	writeJSON(w, http.StatusCreated, map[string]interface{}{
		"message": "Trigger added successfully",
		"trigger": created,
	})
}
// removeFunctionTrigger handles DELETE /v1/functions/{name}/triggers/{id}.
func (h *ServerlessHandlers) removeFunctionTrigger(w http.ResponseWriter, r *http.Request, namespace, name, triggerID string) {
	ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
	defer cancel()
	// The function must exist; its ID is only used for logging below.
	fn, err := h.registry.Get(ctx, namespace, name, 0)
	if err != nil {
		if serverless.IsNotFound(err) {
			writeError(w, http.StatusNotFound, "Function not found")
			return
		}
		writeError(w, http.StatusInternalServerError, "Failed to get function")
		return
	}
	if err := h.triggerManager.RemoveTrigger(ctx, triggerID); err != nil {
		if err == serverless.ErrTriggerNotFound {
			writeError(w, http.StatusNotFound, "Trigger not found")
			return
		}
		h.logger.Error("Failed to remove trigger",
			zap.String("trigger_id", triggerID),
			zap.Error(err),
		)
		writeError(w, http.StatusInternalServerError, "Failed to remove trigger")
		return
	}
	h.logger.Info("Trigger removed",
		zap.String("function", name),
		zap.String("function_id", fn.ID),
		zap.String("trigger_id", triggerID),
	)
	writeJSON(w, http.StatusOK, map[string]interface{}{
		"message":    "Trigger removed successfully",
		"trigger_id": triggerID,
	})
}

View File

@ -8,7 +8,6 @@ import (
"net/http/httptest"
"testing"
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
"github.com/DeBrosOfficial/network/pkg/serverless"
"go.uber.org/zap"
)
@ -50,12 +49,12 @@ func TestServerlessHandlers_ListFunctions(t *testing.T) {
},
}
h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger)
h := NewServerlessHandlers(nil, registry, nil, nil, logger)
req, _ := http.NewRequest("GET", "/v1/functions?namespace=ns1", nil)
rr := httptest.NewRecorder()
h.ListFunctions(rr, req)
h.handleFunctions(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", rr.Code)
@ -73,7 +72,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) {
logger := zap.NewNop()
registry := &mockFunctionRegistry{}
h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger)
h := NewServerlessHandlers(nil, registry, nil, nil, logger)
// Test JSON deploy (which is partially supported according to code)
// Should be 400 because WASM is missing or base64 not supported
@ -81,7 +80,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) {
req, _ := http.NewRequest("POST", "/v1/functions", bytes.NewBufferString(`{"name": "test"}`))
req.Header.Set("Content-Type", "application/json")
h.DeployFunction(writer, req)
h.handleFunctions(writer, req)
if writer.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", writer.Code)

181
pkg/gateway/sfu/config.go Normal file
View File

@ -0,0 +1,181 @@
package sfu
import (
"time"
"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/intervalpli"
"github.com/pion/interceptor/pkg/nack"
"github.com/pion/webrtc/v4"
)
// Config holds SFU configuration shared by every room the manager creates.
type Config struct {
	// MaxParticipants is the maximum number of participants per room.
	MaxParticipants int
	// MediaTimeout is the timeout for media operations.
	MediaTimeout time.Duration
	// ICEServers are the STUN/TURN servers handed to each WebRTC
	// peer connection for NAT traversal.
	ICEServers []webrtc.ICEServer
}
// DefaultConfig returns the SFU defaults: 10 participants per room, a 30s
// media timeout, and Google's public STUN server for ICE.
func DefaultConfig() *Config {
	cfg := &Config{
		MaxParticipants: 10,
		MediaTimeout:    30 * time.Second,
	}
	cfg.ICEServers = []webrtc.ICEServer{
		{URLs: []string{"stun:stun.l.google.com:19302"}},
	}
	return cfg
}
// NewMediaEngine creates a MediaEngine registered with the codecs the SFU
// supports: Opus (audio), VP8 and H264 (video) each paired with an RTX
// retransmission payload, and RTCP feedback (REMB, FIR, NACK, PLI) on the
// primary video codecs. Registration order and payload types match the
// original hand-written sequence.
func NewMediaEngine() (*webrtc.MediaEngine, error) {
	engine := &webrtc.MediaEngine{}
	// RTCP feedback attached to every primary video codec.
	videoFeedback := []webrtc.RTCPFeedback{
		{Type: "goog-remb", Parameter: ""}, // bandwidth estimation
		{Type: "ccm", Parameter: "fir"},    // Full Intra Request
		{Type: "nack", Parameter: ""},      // generic NACK
		{Type: "nack", Parameter: "pli"},   // Picture Loss Indication
	}
	type registration struct {
		kind   webrtc.RTPCodecType
		params webrtc.RTPCodecParameters
	}
	registrations := []registration{
		{
			// Opus audio with in-band FEC.
			kind: webrtc.RTPCodecTypeAudio,
			params: webrtc.RTPCodecParameters{
				RTPCodecCapability: webrtc.RTPCodecCapability{
					MimeType:    webrtc.MimeTypeOpus,
					ClockRate:   48000,
					Channels:    2,
					SDPFmtpLine: "minptime=10;useinbandfec=1",
				},
				PayloadType: 111,
			},
		},
		{
			// VP8 video with full RTCP feedback.
			kind: webrtc.RTPCodecTypeVideo,
			params: webrtc.RTPCodecParameters{
				RTPCodecCapability: webrtc.RTPCodecCapability{
					MimeType:     webrtc.MimeTypeVP8,
					ClockRate:    90000,
					RTCPFeedback: videoFeedback,
				},
				PayloadType: 96,
			},
		},
		{
			// RTX retransmission stream paired with VP8 (apt=96).
			kind: webrtc.RTPCodecTypeVideo,
			params: webrtc.RTPCodecParameters{
				RTPCodecCapability: webrtc.RTPCodecCapability{
					MimeType:    "video/rtx",
					ClockRate:   90000,
					SDPFmtpLine: "apt=96",
				},
				PayloadType: 97,
			},
		},
		{
			// H264 fallback video with full RTCP feedback.
			kind: webrtc.RTPCodecTypeVideo,
			params: webrtc.RTPCodecParameters{
				RTPCodecCapability: webrtc.RTPCodecCapability{
					MimeType:     webrtc.MimeTypeH264,
					ClockRate:    90000,
					SDPFmtpLine:  "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f",
					RTCPFeedback: videoFeedback,
				},
				PayloadType: 102,
			},
		},
		{
			// RTX retransmission stream paired with H264 (apt=102).
			kind: webrtc.RTPCodecTypeVideo,
			params: webrtc.RTPCodecParameters{
				RTPCodecCapability: webrtc.RTPCodecCapability{
					MimeType:    "video/rtx",
					ClockRate:   90000,
					SDPFmtpLine: "apt=102",
				},
				PayloadType: 103,
			},
		},
	}
	for _, reg := range registrations {
		if err := engine.RegisterCodec(reg.params, reg.kind); err != nil {
			return nil, err
		}
	}
	return engine, nil
}
// NewWebRTCAPI builds a WebRTC API wired with the SFU MediaEngine and the
// interceptors needed for loss recovery (NACK responder and generator) and
// fast keyframe delivery for late joiners (periodic PLI).
func NewWebRTCAPI() (*webrtc.API, error) {
	engine, err := NewMediaEngine()
	if err != nil {
		return nil, err
	}
	registry := &interceptor.Registry{}
	// NACK responder: retransmits packets that receivers report as lost —
	// critical for video quality.
	responder, err := nack.NewResponderInterceptor()
	if err != nil {
		return nil, err
	}
	registry.Add(responder)
	// NACK generator: emits NACKs when we detect loss as a receiver.
	generator, err := nack.NewGeneratorInterceptor()
	if err != nil {
		return nil, err
	}
	registry.Add(generator)
	// Interval PLI: periodically requests keyframes on video tracks so new
	// receivers can start decoding quickly.
	pli, err := intervalpli.NewReceiverInterceptor()
	if err != nil {
		return nil, err
	}
	registry.Add(pli)
	return webrtc.NewAPI(
		webrtc.WithMediaEngine(engine),
		webrtc.WithInterceptorRegistry(registry),
		webrtc.WithSettingEngine(webrtc.SettingEngine{}),
	), nil
}
// GetRTPCapabilities returns the RTP capabilities of the SFU in a
// JSON-ready form; clients use this to negotiate codecs.
func GetRTPCapabilities() map[string]interface{} {
	opus := map[string]interface{}{
		"kind":      "audio",
		"mimeType":  "audio/opus",
		"clockRate": 48000,
		"channels":  2,
	}
	vp8 := map[string]interface{}{
		"kind":      "video",
		"mimeType":  "video/VP8",
		"clockRate": 90000,
	}
	h264 := map[string]interface{}{
		"kind":      "video",
		"mimeType":  "video/H264",
		"clockRate": 90000,
	}
	return map[string]interface{}{
		"codecs": []map[string]interface{}{opus, vp8, h264},
	}
}

228
pkg/gateway/sfu/manager.go Normal file
View File

@ -0,0 +1,228 @@
package sfu
import (
"errors"
"sync"
"github.com/pion/webrtc/v4"
"go.uber.org/zap"
)
// Sentinel errors returned by RoomManager lookups.
var (
	// ErrRoomNotFound is returned when no room exists for the requested
	// namespace/roomID pair (or the entry was already removed).
	ErrRoomNotFound = errors.New("room not found")
)
// RoomManager manages all rooms in the SFU, keyed by "namespace:roomId".
// All access to the rooms map is guarded by roomsMu; the webrtc.API and
// Config are shared by every room the manager creates.
type RoomManager struct {
	rooms   map[string]*Room // key: namespace:roomId
	roomsMu sync.RWMutex     // guards rooms
	api     *webrtc.API      // shared API (media engine + interceptors)
	config  *Config          // shared SFU configuration
	logger  *zap.Logger
}
// NewRoomManager creates a room manager backed by a shared WebRTC API
// instance. A nil config falls back to DefaultConfig().
func NewRoomManager(config *Config, logger *zap.Logger) (*RoomManager, error) {
	if config == nil {
		config = DefaultConfig()
	}
	webrtcAPI, err := NewWebRTCAPI()
	if err != nil {
		return nil, err
	}
	mgr := &RoomManager{
		rooms:  make(map[string]*Room),
		api:    webrtcAPI,
		config: config,
		logger: logger.With(zap.String("component", "sfu_manager")),
	}
	return mgr, nil
}
// roomKey builds the composite map key "namespace:roomID" used to index
// rooms in the manager.
func roomKey(namespace, roomID string) string {
	key := namespace + ":" + roomID
	return key
}
// GetOrCreateRoom returns the live room for namespace/roomID, creating one
// if missing (or if the stored entry has already been closed). The second
// return value reports whether a new room was created.
func (m *RoomManager) GetOrCreateRoom(namespace, roomID string) (*Room, bool) {
	key := roomKey(namespace, roomID)
	m.roomsMu.Lock()
	defer m.roomsMu.Unlock()
	if existing, ok := m.rooms[key]; ok {
		if !existing.IsClosed() {
			return existing, false
		}
		// Stale closed room: drop it and fall through to create a new one.
		delete(m.rooms, key)
	}
	created := NewRoom(roomID, namespace, m.api, m.config, m.logger)
	// Log when the room drains; actual cleanup happens lazily on reuse.
	created.OnEmpty(func(r *Room) {
		m.logger.Info("Room is empty, will be garbage collected",
			zap.String("room_id", r.ID),
			zap.String("namespace", r.Namespace),
		)
		// Optionally close the room immediately
		// m.CloseRoom(r.Namespace, r.ID)
	})
	m.rooms[key] = created
	m.logger.Info("Room created",
		zap.String("room_id", roomID),
		zap.String("namespace", namespace),
	)
	return created, true
}
// GetRoom returns the live room for namespace/roomID, or ErrRoomNotFound /
// ErrRoomClosed when it is absent or already shut down.
func (m *RoomManager) GetRoom(namespace, roomID string) (*Room, error) {
	m.roomsMu.RLock()
	defer m.roomsMu.RUnlock()
	room, found := m.rooms[roomKey(namespace, roomID)]
	switch {
	case !found:
		return nil, ErrRoomNotFound
	case room.IsClosed():
		return nil, ErrRoomClosed
	default:
		return room, nil
	}
}
// CloseRoom removes a room from the manager and shuts it down. The map
// entry is removed before Close runs (and outside the lock) so new joins
// cannot race onto a dying room.
func (m *RoomManager) CloseRoom(namespace, roomID string) error {
	key := roomKey(namespace, roomID)
	m.roomsMu.Lock()
	room, found := m.rooms[key]
	if found {
		delete(m.rooms, key)
	}
	m.roomsMu.Unlock()
	if !found {
		return ErrRoomNotFound
	}
	if err := room.Close(); err != nil {
		m.logger.Warn("Error closing room",
			zap.String("room_id", roomID),
			zap.Error(err),
		)
		return err
	}
	m.logger.Info("Room closed",
		zap.String("room_id", roomID),
		zap.String("namespace", namespace),
	)
	return nil
}
// ListRooms returns every open room whose key belongs to the namespace.
func (m *RoomManager) ListRooms(namespace string) []*Room {
	prefix := namespace + ":"
	m.roomsMu.RLock()
	defer m.roomsMu.RUnlock()
	result := make([]*Room, 0)
	for key, room := range m.rooms {
		if room.IsClosed() {
			continue
		}
		// Manual prefix match; requires a non-empty room ID after ":".
		if len(key) > len(prefix) && key[:len(prefix)] == prefix {
			result = append(result, room)
		}
	}
	return result
}
// RoomInfo is a JSON-serializable public snapshot of a room, returned by
// GetRoomInfo.
type RoomInfo struct {
	ID               string            `json:"id"`
	Namespace        string            `json:"namespace"`
	ParticipantCount int               `json:"participantCount"`
	Participants     []ParticipantInfo `json:"participants"`
}
// GetRoomInfo returns a public snapshot of a room's state, or the lookup
// error from GetRoom.
func (m *RoomManager) GetRoomInfo(namespace, roomID string) (*RoomInfo, error) {
	room, err := m.GetRoom(namespace, roomID)
	if err != nil {
		return nil, err
	}
	info := &RoomInfo{
		ID:               room.ID,
		Namespace:        room.Namespace,
		ParticipantCount: room.GetParticipantCount(),
		Participants:     room.GetParticipants(),
	}
	return info, nil
}
// GetStats reports aggregate counts of open rooms and their participants.
func (m *RoomManager) GetStats() map[string]interface{} {
	m.roomsMu.RLock()
	defer m.roomsMu.RUnlock()
	var roomCount, participantCount int
	for _, room := range m.rooms {
		if room.IsClosed() {
			continue
		}
		roomCount++
		participantCount += room.GetParticipantCount()
	}
	return map[string]interface{}{
		"totalRooms":        roomCount,
		"totalParticipants": participantCount,
	}
}
// Close shuts down every room and empties the registry. The room map is
// swapped out under the lock, then rooms are closed without holding it so
// per-room callbacks cannot deadlock against the manager.
func (m *RoomManager) Close() error {
	m.roomsMu.Lock()
	draining := make([]*Room, 0, len(m.rooms))
	for _, room := range m.rooms {
		draining = append(draining, room)
	}
	m.rooms = make(map[string]*Room)
	m.roomsMu.Unlock()
	for _, room := range draining {
		if err := room.Close(); err != nil {
			m.logger.Warn("Error closing room during shutdown",
				zap.String("room_id", room.ID),
				zap.Error(err),
			)
		}
	}
	m.logger.Info("Room manager closed", zap.Int("rooms_closed", len(draining)))
	return nil
}
// GetConfig returns the SFU configuration.
// NOTE(review): this returns the live *Config shared with every room, not
// a copy — mutating the result affects the whole manager.
func (m *RoomManager) GetConfig() *Config {
	return m.config
}
// UpdateICEServers updates the ICE servers configuration.
// NOTE(review): this writes m.config.ICEServers without holding any lock;
// if rooms can read the config concurrently this is a data race — confirm
// callers only invoke it during startup, or add synchronization.
func (m *RoomManager) UpdateICEServers(servers []webrtc.ICEServer) {
	m.config.ICEServers = servers
}

505
pkg/gateway/sfu/peer.go Normal file
View File

@ -0,0 +1,505 @@
package sfu
import (
"encoding/json"
"sync"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v4"
"go.uber.org/zap"
)
// Peer represents a participant in a room.
//
// A peer owns one WebRTC PeerConnection plus a WebSocket used for
// signaling. Tracks the peer publishes are held as local tracks (which
// other peers consume); tracks it consumes from others are held as
// receivers. Each mutable field group is guarded by its companion mutex.
type Peer struct {
	ID          string // server-assigned UUID (see NewPeer)
	UserID      string
	DisplayName string
	// WebRTC connection
	pc *webrtc.PeerConnection
	// Tracks published by this peer (local tracks that others receive)
	localTracks   map[string]*webrtc.TrackLocalStaticRTP
	localTracksMu sync.RWMutex
	// Track receivers for consuming other peers' tracks
	trackReceivers   map[string]*webrtc.RTPReceiver
	trackReceiversMu sync.RWMutex
	// WebSocket connection for signaling; connMu serializes access.
	conn   *websocket.Conn
	connMu sync.Mutex
	// State
	audioMuted           bool
	videoMuted           bool
	closed               bool
	closedMu             sync.RWMutex
	negotiationPending   bool
	negotiationPendingMu sync.Mutex
	initialOfferHandled  bool
	initialOfferMu       sync.Mutex
	batchingTracks       bool // When true, suppress automatic negotiation
	batchingTracksMu     sync.Mutex
	// Room reference
	room   *Room
	logger *zap.Logger
	// Callbacks
	onClose func(*Peer)
}
// NewPeer creates a new peer with a freshly generated unique ID. The WebRTC
// connection is not created here; call InitPeerConnection afterwards.
func NewPeer(userID, displayName string, conn *websocket.Conn, room *Room, logger *zap.Logger) *Peer {
	p := &Peer{
		ID:          uuid.New().String(),
		UserID:      userID,
		DisplayName: displayName,
		conn:        conn,
		room:        room,
		logger:      logger,
	}
	p.localTracks = make(map[string]*webrtc.TrackLocalStaticRTP)
	p.trackReceivers = make(map[string]*webrtc.RTPReceiver)
	return p
}
// InitPeerConnection initializes the WebRTC peer connection and wires up all
// of its lifecycle callbacks:
//
//   - ICE connection state: the peer is torn down on Failed/Disconnected/
//     Closed. NOTE(review): Disconnected can be transient; confirm that
//     dropping the peer immediately (rather than waiting for Failed) is
//     intended.
//   - ICE candidates: trickled to the client over the signaling WebSocket.
//   - Incoming tracks: registered, RTCP-monitored, and forwarded to the room.
//   - Negotiation-needed / signaling-state: offer creation is serialized so
//     an offer is only produced in the stable state; otherwise (or while
//     batching tracks) it is queued via negotiationPending.
func (p *Peer) InitPeerConnection(api *webrtc.API, config webrtc.Configuration) error {
	pc, err := api.NewPeerConnection(config)
	if err != nil {
		return err
	}
	p.pc = pc
	// Handle ICE connection state changes
	pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
		p.logger.Info("ICE connection state changed",
			zap.String("peer_id", p.ID),
			zap.String("state", state.String()),
		)
		if state == webrtc.ICEConnectionStateFailed ||
			state == webrtc.ICEConnectionStateDisconnected ||
			state == webrtc.ICEConnectionStateClosed {
			p.handleDisconnect()
		}
	})
	// Handle ICE candidates
	pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		if candidate == nil {
			// nil marks the end of candidate gathering; nothing to forward.
			return
		}
		p.logger.Debug("ICE candidate generated",
			zap.String("peer_id", p.ID),
			zap.String("candidate", candidate.String()),
		)
		candidateJSON := candidate.ToJSON()
		data := &ICECandidateData{
			Candidate: candidateJSON.Candidate,
		}
		// The optional fields are pointers in pion's JSON form; copy only
		// the ones that are actually set.
		if candidateJSON.SDPMid != nil {
			data.SDPMid = *candidateJSON.SDPMid
		}
		if candidateJSON.SDPMLineIndex != nil {
			data.SDPMLineIndex = *candidateJSON.SDPMLineIndex
		}
		if candidateJSON.UsernameFragment != nil {
			data.UsernameFragment = *candidateJSON.UsernameFragment
		}
		p.SendMessage(NewServerMessage(MessageTypeICECandidate, data))
	})
	// Handle incoming tracks from remote peers
	pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
		codec := track.Codec()
		p.logger.Info("Track received from client",
			zap.String("peer_id", p.ID),
			zap.String("user_id", p.UserID),
			zap.String("track_id", track.ID()),
			zap.String("stream_id", track.StreamID()),
			zap.String("kind", track.Kind().String()),
			zap.String("codec_mime", codec.MimeType),
			zap.Uint32("codec_clock_rate", codec.ClockRate),
			zap.Uint8("codec_payload_type", uint8(codec.PayloadType)),
		)
		p.trackReceivers[track.ID()] = receiver
		p.trackReceiversMu.Unlock()
		// Start RTCP reader to monitor for packet loss (NACK) and PLI requests
		go p.readRTCP(receiver, track)
		// Forward track to other peers in the room
		p.room.BroadcastTrack(p.ID, track)
	})
	// Handle negotiation needed - only trigger when in stable state
	pc.OnNegotiationNeeded(func() {
		p.logger.Debug("Negotiation needed",
			zap.String("peer_id", p.ID),
			zap.String("signaling_state", pc.SignalingState().String()),
		)
		// Check if we're batching tracks - if so, just mark as pending
		p.batchingTracksMu.Lock()
		batching := p.batchingTracks
		p.batchingTracksMu.Unlock()
		if batching {
			p.negotiationPendingMu.Lock()
			p.negotiationPending = true
			p.negotiationPendingMu.Unlock()
			p.logger.Debug("Negotiation deferred - batching tracks",
				zap.String("peer_id", p.ID),
			)
			return
		}
		// Only create offer if we're in stable state
		// Otherwise, mark negotiation as pending
		if pc.SignalingState() == webrtc.SignalingStateStable {
			p.createAndSendOffer()
		} else {
			p.negotiationPendingMu.Lock()
			p.negotiationPending = true
			p.negotiationPendingMu.Unlock()
			p.logger.Debug("Negotiation queued - not in stable state",
				zap.String("peer_id", p.ID),
				zap.String("signaling_state", pc.SignalingState().String()),
			)
		}
	})
	// Handle signaling state changes to process pending negotiations
	pc.OnSignalingStateChange(func(state webrtc.SignalingState) {
		p.logger.Debug("Signaling state changed",
			zap.String("peer_id", p.ID),
			zap.String("state", state.String()),
		)
		// When we return to stable state, check if negotiation was pending
		if state == webrtc.SignalingStateStable {
			p.negotiationPendingMu.Lock()
			pending := p.negotiationPending
			p.negotiationPending = false
			p.negotiationPendingMu.Unlock()
			if pending {
				p.logger.Debug("Processing pending negotiation", zap.String("peer_id", p.ID))
				p.createAndSendOffer()
			}
		}
	})
	return nil
}
// createAndSendOffer creates an SDP offer, installs it as the local
// description, and sends it to the client over the signaling WebSocket.
//
// The offer is only created when the signaling state is stable; otherwise
// negotiationPending is set so OnSignalingStateChange retries once the
// state returns to stable. Errors are logged, not returned, because this
// runs from pion callbacks with no caller to report to.
func (p *Peer) createAndSendOffer() {
	if p.pc == nil {
		return
	}
	// Double-check signaling state before creating offer
	if p.pc.SignalingState() != webrtc.SignalingStateStable {
		p.logger.Debug("Skipping offer - not in stable state",
			zap.String("peer_id", p.ID),
			zap.String("signaling_state", p.pc.SignalingState().String()),
		)
		// Mark as pending so it will be retried when state becomes stable
		p.negotiationPendingMu.Lock()
		p.negotiationPending = true
		p.negotiationPendingMu.Unlock()
		return
	}
	p.logger.Info("Creating SDP offer", zap.String("peer_id", p.ID))
	offer, err := p.pc.CreateOffer(nil)
	if err != nil {
		p.logger.Error("Failed to create offer",
			zap.String("peer_id", p.ID),
			zap.Error(err),
		)
		return
	}
	if err := p.pc.SetLocalDescription(offer); err != nil {
		p.logger.Error("Failed to set local description",
			zap.String("peer_id", p.ID),
			zap.Error(err),
		)
		return
	}
	p.logger.Info("Sending SDP offer", zap.String("peer_id", p.ID))
	p.SendMessage(NewServerMessage(MessageTypeOffer, &OfferData{
		SDP: offer.SDP,
	}))
}
// HandleOffer processes an SDP offer from the client.
//
// It applies the offer as the remote description, creates and installs an
// answer as the local description, and sends the answer back over the
// signaling WebSocket. Returns ErrPeerNotInitialized if InitPeerConnection
// has not run yet, or the first error from any signaling step.
func (p *Peer) HandleOffer(sdp string) error {
	if p.pc == nil {
		return ErrPeerNotInitialized
	}
	offer := webrtc.SessionDescription{
		Type: webrtc.SDPTypeOffer,
		SDP:  sdp,
	}
	if err := p.pc.SetRemoteDescription(offer); err != nil {
		return err
	}
	answer, err := p.pc.CreateAnswer(nil)
	if err != nil {
		return err
	}
	if err := p.pc.SetLocalDescription(answer); err != nil {
		return err
	}
	// Fix: the send error used to be silently dropped. A failed answer
	// delivery leaves the client unable to complete negotiation, so
	// surface it to the caller like every other step above.
	return p.SendMessage(NewServerMessage(MessageTypeAnswer, &AnswerData{
		SDP: answer.SDP,
	}))
}
// HandleAnswer processes an SDP answer from the client by applying it as
// the remote description of the peer connection.
func (p *Peer) HandleAnswer(sdp string) error {
	if p.pc == nil {
		return ErrPeerNotInitialized
	}
	desc := webrtc.SessionDescription{
		SDP:  sdp,
		Type: webrtc.SDPTypeAnswer,
	}
	return p.pc.SetRemoteDescription(desc)
}
// HandleICECandidate processes an ICE candidate received from the client
// and adds it to the peer connection.
func (p *Peer) HandleICECandidate(data *ICECandidateData) error {
	if p.pc == nil {
		return ErrPeerNotInitialized
	}
	candidate := data.ToWebRTCCandidate()
	return p.pc.AddICECandidate(candidate)
}
// AddTrack adds a track to send to this peer (from another peer).
//
// Returns the RTPSender for the new track, or ErrPeerNotInitialized if the
// peer connection has not been created. Adding a track triggers the
// negotiation-needed logic wired up in InitPeerConnection (which may be
// deferred while StartTrackBatch is active).
func (p *Peer) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
	if p.pc == nil {
		return nil, ErrPeerNotInitialized
	}
	return p.pc.AddTrack(track)
}
// StartTrackBatch starts batching track additions.
// Call EndTrackBatch when done to trigger a single renegotiation.
func (p *Peer) StartTrackBatch() {
	p.batchingTracksMu.Lock()
	defer p.batchingTracksMu.Unlock()
	p.batchingTracks = true
	p.logger.Debug("Started track batching", zap.String("peer_id", p.ID))
}
// EndTrackBatch ends track batching and triggers a single renegotiation if
// any track addition requested one while the batch was open.
func (p *Peer) EndTrackBatch() {
	p.batchingTracksMu.Lock()
	p.batchingTracks = false
	p.batchingTracksMu.Unlock()
	// Consume any negotiation request that arrived during the batch.
	p.negotiationPendingMu.Lock()
	wasPending := p.negotiationPending
	p.negotiationPending = false
	p.negotiationPendingMu.Unlock()
	if !wasPending || p.pc == nil {
		return
	}
	if p.pc.SignalingState() != webrtc.SignalingStateStable {
		return
	}
	p.logger.Debug("Processing batched negotiation", zap.String("peer_id", p.ID))
	p.createAndSendOffer()
}
// SendMessage sends a signaling message to the peer via WebSocket.
//
// Returns ErrPeerClosed once the peer has been shut down and
// ErrWebSocketClosed when the socket has already been torn down. The JSON
// marshal and the frame write both happen under connMu, so concurrent
// callers cannot interleave WebSocket frames.
func (p *Peer) SendMessage(msg *ServerMessage) error {
	p.closedMu.RLock()
	if p.closed {
		p.closedMu.RUnlock()
		return ErrPeerClosed
	}
	p.closedMu.RUnlock()
	p.connMu.Lock()
	defer p.connMu.Unlock()
	if p.conn == nil {
		return ErrWebSocketClosed
	}
	data, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	return p.conn.WriteMessage(websocket.TextMessage, data)
}
// GetInfo returns public information about this peer.
//
// Audio/video presence is derived from the peer's published local tracks.
// NOTE(review): audioMuted/videoMuted are read here without a lock while
// SetAudioMuted/SetVideoMuted write them unguarded — a data race under
// -race; confirm or add synchronization.
func (p *Peer) GetInfo() ParticipantInfo {
	p.localTracksMu.RLock()
	hasAudio := false
	hasVideo := false
	for _, track := range p.localTracks {
		if track.Kind() == webrtc.RTPCodecTypeAudio {
			hasAudio = true
		} else if track.Kind() == webrtc.RTPCodecTypeVideo {
			hasVideo = true
		}
	}
	p.localTracksMu.RUnlock()
	return ParticipantInfo{
		ID:          p.ID,
		UserID:      p.UserID,
		DisplayName: p.DisplayName,
		HasAudio:    hasAudio,
		HasVideo:    hasVideo,
		AudioMuted:  p.audioMuted,
		VideoMuted:  p.videoMuted,
	}
}
// SetAudioMuted sets the audio mute state.
// NOTE(review): unsynchronized write; GetInfo reads this field concurrently.
func (p *Peer) SetAudioMuted(muted bool) {
	p.audioMuted = muted
}
// SetVideoMuted sets the video mute state.
// NOTE(review): unsynchronized write; GetInfo reads this field concurrently.
func (p *Peer) SetVideoMuted(muted bool) {
	p.videoMuted = muted
}
// MarkInitialOfferHandled marks that the initial offer has been processed.
// Returns true if this is the first time it's being marked (i.e., this was the first offer).
func (p *Peer) MarkInitialOfferHandled() bool {
	p.initialOfferMu.Lock()
	defer p.initialOfferMu.Unlock()
	first := !p.initialOfferHandled
	p.initialOfferHandled = true
	return first
}
// handleDisconnect handles peer disconnection, flipping the closed flag and
// firing the onClose callback exactly once.
func (p *Peer) handleDisconnect() {
	p.closedMu.Lock()
	alreadyClosed := p.closed
	p.closed = true
	p.closedMu.Unlock()
	if alreadyClosed {
		return
	}
	if cb := p.onClose; cb != nil {
		cb(p)
	}
}
// Close closes the peer connection and cleans up resources. It is safe to
// call more than once; subsequent calls are no-ops.
func (p *Peer) Close() error {
	p.closedMu.Lock()
	if p.closed {
		p.closedMu.Unlock()
		return nil
	}
	p.closed = true
	p.closedMu.Unlock()
	p.logger.Info("Closing peer", zap.String("peer_id", p.ID))
	// Tear down the signaling WebSocket first.
	p.connMu.Lock()
	if c := p.conn; c != nil {
		c.Close()
		p.conn = nil
	}
	p.connMu.Unlock()
	// Then the WebRTC peer connection, if it was ever created.
	if p.pc == nil {
		return nil
	}
	return p.pc.Close()
}
// OnClose sets a callback for when the peer is closed.
//
// NOTE(review): Room.AddPeer installs this after InitPeerConnection, so an
// ICE failure in the window between the two calls would run handleDisconnect
// with a nil callback and the peer would never be removed from the room —
// verify whether that window matters in practice.
func (p *Peer) OnClose(fn func(*Peer)) {
	p.onClose = fn
}
// readRTCP reads RTCP packets from a receiver to monitor feedback
// This helps detect packet loss (via NACK) for adaptive quality adjustments
//
// Runs as one goroutine per received track (started from the OnTrack
// handler) and exits when ReadRTCP fails, i.e. when the receiver closes.
func (p *Peer) readRTCP(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) {
	// Mirror the forwarded-track naming used by Room.BroadcastTrack
	// ("{kind}-{peerID}") so NACK counters and keyframe requests resolve to
	// the same publishedTrack entry.
	localTrackID := track.Kind().String() + "-" + p.ID
	for {
		packets, _, err := receiver.ReadRTCP()
		if err != nil {
			// Connection closed, exit gracefully
			return
		}
		for _, pkt := range packets {
			switch rtcpPkt := pkt.(type) {
			case *rtcp.TransportLayerNack:
				// NACK received - indicates packet loss from receivers
				// Increment the NACK counter for adaptive keyframe logic
				p.room.IncrementNackCount(localTrackID)
				p.logger.Debug("NACK received",
					zap.String("peer_id", p.ID),
					zap.String("track_id", localTrackID),
					zap.Uint32("sender_ssrc", rtcpPkt.SenderSSRC),
					zap.Int("nack_pairs", len(rtcpPkt.Nacks)),
				)
			case *rtcp.PictureLossIndication:
				// PLI received - receiver needs a keyframe
				p.logger.Debug("PLI received from receiver",
					zap.String("peer_id", p.ID),
					zap.String("track_id", localTrackID),
				)
				// Request keyframe from source
				p.room.RequestKeyframe(localTrackID)
			case *rtcp.FullIntraRequest:
				// FIR received - receiver needs a full keyframe
				p.logger.Debug("FIR received from receiver",
					zap.String("peer_id", p.ID),
					zap.String("track_id", localTrackID),
				)
				// Request keyframe from source
				p.room.RequestKeyframe(localTrackID)
			}
		}
	}
}

775
pkg/gateway/sfu/room.go Normal file
View File

@ -0,0 +1,775 @@
package sfu
import (
"errors"
"sync"
"sync/atomic"
"time"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v4"
"go.uber.org/zap"
)
// Common errors returned by the SFU room/peer layer.
var (
	// ErrRoomFull is returned by AddPeer when MaxParticipants is reached.
	ErrRoomFull = errors.New("room is full")
	// ErrRoomClosed is returned when joining a room after Close.
	ErrRoomClosed = errors.New("room is closed")
	// ErrPeerNotFound is returned when a peer ID is not in the room.
	ErrPeerNotFound = errors.New("peer not found")
	// ErrPeerNotInitialized is returned by signaling methods called before
	// InitPeerConnection.
	ErrPeerNotInitialized = errors.New("peer not initialized")
	// ErrPeerClosed is returned when sending to a closed peer.
	ErrPeerClosed = errors.New("peer is closed")
	// ErrWebSocketClosed is returned when the signaling socket is gone.
	ErrWebSocketClosed = errors.New("websocket connection closed")
)
// publishedTrack holds information about a track published to the room
//
// NOTE(review): forwarderActive and packetCount are mutated by the RTP
// forwarder goroutine under publishedTracksMu, but some readers (e.g.
// SendExistingTracksTo) access them after releasing that lock — verify.
type publishedTrack struct {
	sourcePeerID    string
	sourceUserID    string
	localTrack      *webrtc.TrackLocalStaticRTP
	remoteTrackSSRC uint32              // SSRC of the remote track (for PLI requests)
	remoteTrack     *webrtc.TrackRemote // Reference to the remote track
	kind            string
	forwarderActive bool // tracks if the RTP forwarder goroutine is still running
	packetCount     int  // number of packets forwarded (for debugging)
	// Packet loss tracking for adaptive keyframe requests
	lastKeyframeRequest time.Time
	nackCount           atomic.Int64 // Number of NACK packets received (indicates packet loss)
}
// Room represents a WebRTC room with multiple participants
type Room struct {
	ID        string
	Namespace string
	// Participants in the room, keyed by peer ID.
	peers   map[string]*Peer
	peersMu sync.RWMutex
	// Published tracks in the room (for sending to new joiners)
	publishedTracks   map[string]*publishedTrack // key: trackID
	publishedTracksMu sync.RWMutex
	// WebRTC API for creating peer connections
	api *webrtc.API
	// Configuration
	config *Config
	logger *zap.Logger
	// State
	closed   bool
	closedMu sync.RWMutex
	// Callbacks
	// onEmpty fires from RemovePeer when the last participant leaves.
	onEmpty func(*Room)
}
// NewRoom creates a new room whose logger is tagged with the room ID.
func NewRoom(id, namespace string, api *webrtc.API, config *Config, logger *zap.Logger) *Room {
	r := &Room{
		ID:        id,
		Namespace: namespace,
		api:       api,
		config:    config,
		logger:    logger.With(zap.String("room_id", id)),
	}
	r.peers = make(map[string]*Peer)
	r.publishedTracks = make(map[string]*publishedTrack)
	return r
}
// AddPeer adds a new peer to the room.
//
// Fails with ErrRoomClosed after Close, and ErrRoomFull once
// config.MaxParticipants is reached. On success the peer's WebRTC
// connection is initialized with the room's ICE servers and the other
// participants are notified with a participant-joined message.
func (r *Room) AddPeer(peer *Peer) error {
	r.closedMu.RLock()
	if r.closed {
		r.closedMu.RUnlock()
		return ErrRoomClosed
	}
	r.closedMu.RUnlock()
	r.peersMu.Lock()
	// Check max participants
	if r.config.MaxParticipants > 0 && len(r.peers) >= r.config.MaxParticipants {
		r.peersMu.Unlock()
		return ErrRoomFull
	}
	// Initialize peer connection
	pcConfig := webrtc.Configuration{
		ICEServers: r.config.ICEServers,
	}
	if err := peer.InitPeerConnection(r.api, pcConfig); err != nil {
		r.peersMu.Unlock()
		return err
	}
	// Set up peer close handler
	// NOTE(review): installed after InitPeerConnection, so an ICE failure in
	// the window before this line would run handleDisconnect with no
	// callback and the peer would never be removed — verify.
	peer.OnClose(func(p *Peer) {
		r.RemovePeer(p.ID)
	})
	r.peers[peer.ID] = peer
	peerInfo := peer.GetInfo() // Get info while holding lock
	totalPeers := len(r.peers)
	// Release lock BEFORE broadcasting to avoid deadlock
	// (broadcastMessage also acquires the lock)
	r.peersMu.Unlock()
	r.logger.Info("Peer added to room",
		zap.String("peer_id", peer.ID),
		zap.String("user_id", peer.UserID),
		zap.Int("total_peers", totalPeers),
	)
	// Notify other peers (now safe since we released the lock)
	r.broadcastMessage(peer.ID, NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{
		Participant: peerInfo,
	}))
	return nil
}
// RemovePeer removes a peer from the room.
//
// It deletes the peer from the map, drops every published track the peer
// was the source of, closes the peer, notifies the remaining participants,
// and fires the onEmpty callback when the room becomes empty. Returns
// ErrPeerNotFound if the peer was not in the room (e.g. already removed —
// this also makes the method idempotent with Peer.handleDisconnect).
func (r *Room) RemovePeer(peerID string) error {
	r.peersMu.Lock()
	peer, ok := r.peers[peerID]
	if !ok {
		r.peersMu.Unlock()
		return ErrPeerNotFound
	}
	delete(r.peers, peerID)
	remainingPeers := len(r.peers)
	r.peersMu.Unlock()
	// Remove tracks published by this peer
	r.publishedTracksMu.Lock()
	removedTracks := make([]string, 0)
	for trackID, track := range r.publishedTracks {
		if track.sourcePeerID == peerID {
			delete(r.publishedTracks, trackID)
			removedTracks = append(removedTracks, trackID)
		}
	}
	r.publishedTracksMu.Unlock()
	if len(removedTracks) > 0 {
		r.logger.Info("Removed tracks for departing peer",
			zap.String("peer_id", peerID),
			zap.Strings("track_ids", removedTracks),
		)
	}
	// Close the peer
	if err := peer.Close(); err != nil {
		r.logger.Warn("Error closing peer",
			zap.String("peer_id", peerID),
			zap.Error(err),
		)
	}
	r.logger.Info("Peer removed from room",
		zap.String("peer_id", peerID),
		zap.Int("remaining_peers", remainingPeers),
	)
	// Notify other peers
	r.broadcastMessage(peerID, NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{
		ParticipantID: peerID,
	}))
	// Check if room is empty
	if remainingPeers == 0 && r.onEmpty != nil {
		r.onEmpty(r)
	}
	return nil
}
// GetPeer returns the peer with the given ID, or ErrPeerNotFound.
func (r *Room) GetPeer(peerID string) (*Peer, error) {
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	if peer, ok := r.peers[peerID]; ok {
		return peer, nil
	}
	return nil, ErrPeerNotFound
}
// GetPeers returns a snapshot slice of all peers currently in the room.
func (r *Room) GetPeers() []*Peer {
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	out := make([]*Peer, 0, len(r.peers))
	for _, p := range r.peers {
		out = append(out, p)
	}
	return out
}
// GetParticipants returns public info about every participant in the room.
func (r *Room) GetParticipants() []ParticipantInfo {
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	out := make([]ParticipantInfo, 0, len(r.peers))
	for _, p := range r.peers {
		out = append(out, p.GetInfo())
	}
	return out
}
// GetParticipantCount returns the number of participants currently in the
// room.
func (r *Room) GetParticipantCount() int {
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	return len(r.peers)
}
// SendExistingTracksTo sends all existing tracks from other participants to the specified peer.
// This should be called AFTER the welcome message is sent to ensure the client is ready.
// Uses batch mode to send all tracks with a single renegotiation for faster joins.
func (r *Room) SendExistingTracksTo(peer *Peer) {
	// Snapshot the published tracks under the read lock. NOTE(review): the
	// forwarderActive/packetCount fields of these entries are read below
	// AFTER the lock is released while the forwarder goroutines mutate them
	// under the lock — racy debug reads; verify or copy the fields here.
	r.publishedTracksMu.RLock()
	existingTracks := make([]*publishedTrack, 0, len(r.publishedTracks))
	for _, track := range r.publishedTracks {
		// Don't send peer's own tracks back to them
		if track.sourcePeerID != peer.ID {
			existingTracks = append(existingTracks, track)
		}
	}
	r.publishedTracksMu.RUnlock()
	if len(existingTracks) == 0 {
		r.logger.Info("No existing tracks to send to new peer",
			zap.String("peer_id", peer.ID),
		)
		return
	}
	r.logger.Info("Sending existing tracks to new peer (batch mode)",
		zap.String("peer_id", peer.ID),
		zap.Int("track_count", len(existingTracks)),
	)
	videoTrackIDs := make([]string, 0)
	// Start batch mode - suppresses individual renegotiations
	peer.StartTrackBatch()
	for _, track := range existingTracks {
		// Log forwarder status to help diagnose video issues
		r.logger.Info("Adding existing track to new peer",
			zap.String("new_peer_id", peer.ID),
			zap.String("source_peer_id", track.sourcePeerID),
			zap.String("source_user_id", track.sourceUserID),
			zap.String("track_id", track.localTrack.ID()),
			zap.String("kind", track.kind),
			zap.Bool("forwarder_active", track.forwarderActive),
			zap.Int("packets_forwarded", track.packetCount),
		)
		// Warn if forwarder is no longer active
		if !track.forwarderActive {
			r.logger.Warn("WARNING: Track forwarder is NOT active - track may not receive data",
				zap.String("track_id", track.localTrack.ID()),
				zap.String("kind", track.kind),
				zap.Int("final_packet_count", track.packetCount),
			)
		}
		if _, err := peer.AddTrack(track.localTrack); err != nil {
			r.logger.Warn("Failed to add existing track to new peer",
				zap.String("peer_id", peer.ID),
				zap.String("track_id", track.localTrack.ID()),
				zap.Error(err),
			)
			continue
		}
		// Track video tracks for keyframe requests
		if track.kind == "video" {
			videoTrackIDs = append(videoTrackIDs, track.localTrack.ID())
		}
		// Notify new peer about the existing track
		peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
			ParticipantID: track.sourcePeerID,
			UserID:        track.sourceUserID,
			TrackID:       track.localTrack.ID(),
			StreamID:      track.localTrack.StreamID(),
			Kind:          track.kind,
		}))
	}
	// End batch mode - triggers single renegotiation for all tracks
	peer.EndTrackBatch()
	r.logger.Info("Batch track addition complete - single renegotiation triggered",
		zap.String("peer_id", peer.ID),
		zap.Int("total_tracks", len(existingTracks)),
	)
	// Request keyframes for video tracks after a short delay
	// This ensures the receiver has time to set up the track before receiving the keyframe
	if len(videoTrackIDs) > 0 {
		go func() {
			// Wait for negotiation to complete
			time.Sleep(300 * time.Millisecond)
			r.logger.Info("Requesting keyframes for new peer",
				zap.String("peer_id", peer.ID),
				zap.Int("video_track_count", len(videoTrackIDs)),
			)
			for _, trackID := range videoTrackIDs {
				r.RequestKeyframe(trackID)
			}
			// Request again after 500ms in case the first was too early
			time.Sleep(500 * time.Millisecond)
			for _, trackID := range videoTrackIDs {
				r.RequestKeyframe(trackID)
			}
		}()
	}
}
// BroadcastTrack broadcasts a track from one peer to all other peers.
//
// It creates a local static-RTP track mirroring the remote one, registers
// it in publishedTracks so late joiners can receive it, starts a forwarder
// goroutine that pumps RTP from the remote track into the local one (with
// an adaptive keyframe-request loop for video), and attaches the track to
// every other peer currently in the room.
//
// Fixes in this revision:
//   - the "total_published_tracks" log value is captured under
//     publishedTracksMu (it was previously read after unlock — a data race);
//   - firstPacketReceived is an atomic.Bool (it was a plain bool shared
//     between the forwarder loop and the 5-second watchdog goroutine).
func (r *Room) BroadcastTrack(sourcePeerID string, track *webrtc.TrackRemote) {
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	// Get source peer's user ID for the track-added message
	sourceUserID := ""
	if sourcePeer, ok := r.peers[sourcePeerID]; ok {
		sourceUserID = sourcePeer.UserID
	}
	// Create a local track from the remote track
	// Use participant ID as the stream ID so clients can identify the source
	// Format: trackId stays the same, streamId = sourcePeerID (or sourceUserID if available)
	streamID := sourcePeerID
	if sourceUserID != "" {
		streamID = sourceUserID // Use userID for easier client-side mapping
	}
	r.logger.Info("Creating local track for broadcast",
		zap.String("source_peer_id", sourcePeerID),
		zap.String("source_user_id", sourceUserID),
		zap.String("original_track_id", track.ID()),
		zap.String("original_stream_id", track.StreamID()),
		zap.String("new_stream_id", streamID),
	)
	// Log codec information for debugging
	codec := track.Codec()
	r.logger.Info("Track codec info",
		zap.String("source_peer_id", sourcePeerID),
		zap.String("track_kind", track.Kind().String()),
		zap.String("mime_type", codec.MimeType),
		zap.Uint32("clock_rate", codec.ClockRate),
		zap.Uint16("channels", codec.Channels),
		zap.String("sdp_fmtp_line", codec.SDPFmtpLine),
	)
	localTrack, err := webrtc.NewTrackLocalStaticRTP(
		codec.RTPCodecCapability,
		track.Kind().String()+"-"+sourcePeerID, // Include peer ID in track ID
		streamID,                               // Use participant/user ID as stream ID
	)
	if err != nil {
		r.logger.Error("Failed to create local track",
			zap.String("source_peer", sourcePeerID),
			zap.String("track_id", track.ID()),
			zap.Error(err),
		)
		return
	}
	// Store the track for new joiners
	pubTrack := &publishedTrack{
		sourcePeerID:        sourcePeerID,
		sourceUserID:        sourceUserID,
		localTrack:          localTrack,
		remoteTrackSSRC:     uint32(track.SSRC()),
		remoteTrack:         track,
		kind:                track.Kind().String(),
		forwarderActive:     true,
		packetCount:         0,
		lastKeyframeRequest: time.Now(),
	}
	r.publishedTracksMu.Lock()
	r.publishedTracks[localTrack.ID()] = pubTrack
	totalPublished := len(r.publishedTracks) // capture under lock (was a racy read)
	r.publishedTracksMu.Unlock()
	r.logger.Info("Track stored for new joiners",
		zap.String("track_id", localTrack.ID()),
		zap.String("source_peer_id", sourcePeerID),
		zap.Int("total_published_tracks", totalPublished),
	)
	// Forward RTP packets from remote track to local track
	go func() {
		trackID := track.ID()
		localTrackID := localTrack.ID()
		trackKind := track.Kind().String()
		buf := make([]byte, 1600) // Slightly larger than MTU to handle RTP extensions
		packetCount := 0
		byteCount := 0
		startTime := time.Now()
		// Shared with the watchdog goroutine below, hence atomic.
		var firstPacketReceived atomic.Bool
		r.logger.Info("RTP forwarder started",
			zap.String("track_id", trackID),
			zap.String("local_track_id", localTrackID),
			zap.String("kind", trackKind),
			zap.String("source_peer_id", sourcePeerID),
		)
		// Start a goroutine to log warning if no packets received after 5 seconds
		go func() {
			time.Sleep(5 * time.Second)
			if !firstPacketReceived.Load() {
				r.logger.Warn("RTP forwarder WARNING: No packets received after 5 seconds - host may not be sending",
					zap.String("track_id", trackID),
					zap.String("local_track_id", localTrackID),
					zap.String("kind", trackKind),
					zap.String("source_peer_id", sourcePeerID),
				)
			}
		}()
		// For video tracks, use adaptive keyframe requests based on packet loss
		if trackKind == "video" {
			go func() {
				// Use a faster ticker for checking, but only send keyframes when needed
				ticker := time.NewTicker(500 * time.Millisecond)
				defer ticker.Stop()
				var lastNackCount int64
				var consecutiveLossDetections int
				baseInterval := 3 * time.Second
				minInterval := 500 * time.Millisecond // Minimum interval between keyframes
				lastKeyframeTime := time.Now()
				for range ticker.C {
					// Check if room is closed
					if r.IsClosed() {
						return
					}
					// Check if forwarder is still active
					r.publishedTracksMu.RLock()
					pt, ok := r.publishedTracks[localTrackID]
					if !ok || !pt.forwarderActive {
						r.publishedTracksMu.RUnlock()
						return
					}
					// Get current NACK count to detect packet loss
					currentNackCount := pt.nackCount.Load()
					r.publishedTracksMu.RUnlock()
					timeSinceLastKeyframe := time.Since(lastKeyframeTime)
					// Detect if packet loss is happening (NACKs increasing)
					if currentNackCount > lastNackCount {
						consecutiveLossDetections++
						lastNackCount = currentNackCount
						// If we detect packet loss and haven't requested a keyframe recently,
						// request one immediately to help receivers recover
						if timeSinceLastKeyframe >= minInterval {
							r.logger.Debug("Adaptive keyframe request due to packet loss",
								zap.String("track_id", localTrackID),
								zap.Int64("nack_count", currentNackCount),
								zap.Int("consecutive_loss_detections", consecutiveLossDetections),
							)
							r.RequestKeyframe(localTrackID)
							lastKeyframeTime = time.Now()
						}
					} else {
						// Reset consecutive loss counter when no new NACKs
						if consecutiveLossDetections > 0 {
							consecutiveLossDetections--
						}
					}
					// Regular keyframe request at base interval (regardless of loss)
					if timeSinceLastKeyframe >= baseInterval {
						r.RequestKeyframe(localTrackID)
						lastKeyframeTime = time.Now()
					}
				}
			}()
		}
		// Mark forwarder as stopped when we exit
		defer func() {
			r.publishedTracksMu.Lock()
			if pt, ok := r.publishedTracks[localTrackID]; ok {
				pt.forwarderActive = false
				pt.packetCount = packetCount
			}
			r.publishedTracksMu.Unlock()
			r.logger.Info("RTP forwarder exiting",
				zap.String("track_id", trackID),
				zap.String("local_track_id", localTrackID),
				zap.String("kind", trackKind),
				zap.Duration("lifetime", time.Since(startTime)),
				zap.Int("total_packets", packetCount),
				zap.Int("total_bytes", byteCount),
			)
		}()
		for {
			n, _, readErr := track.Read(buf)
			if readErr != nil {
				r.logger.Info("RTP forwarder stopped - read error",
					zap.String("track_id", trackID),
					zap.String("local_track_id", localTrackID),
					zap.String("kind", trackKind),
					zap.Int("packets_forwarded", packetCount),
					zap.Int("bytes_forwarded", byteCount),
					zap.Error(readErr),
				)
				return
			}
			// Log first packet to confirm data is flowing
			if packetCount == 0 {
				firstPacketReceived.Store(true)
				r.logger.Info("RTP forwarder received FIRST packet",
					zap.String("track_id", trackID),
					zap.String("local_track_id", localTrackID),
					zap.String("kind", trackKind),
					zap.Int("packet_size", n),
					zap.Duration("time_to_first_packet", time.Since(startTime)),
				)
			}
			if _, writeErr := localTrack.Write(buf[:n]); writeErr != nil {
				r.logger.Info("RTP forwarder stopped - write error",
					zap.String("track_id", trackID),
					zap.String("local_track_id", localTrackID),
					zap.String("kind", trackKind),
					zap.Int("packets_forwarded", packetCount),
					zap.Int("bytes_forwarded", byteCount),
					zap.Error(writeErr),
				)
				return
			}
			packetCount++
			byteCount += n
			// Update packet count periodically (not every packet to reduce lock contention)
			if packetCount%50 == 0 {
				r.publishedTracksMu.Lock()
				if pt, ok := r.publishedTracks[localTrackID]; ok {
					pt.packetCount = packetCount
				}
				r.publishedTracksMu.Unlock()
			}
			// Log progress every 100 packets for video, 500 for audio
			logInterval := 500
			if trackKind == "video" {
				logInterval = 100
			}
			if packetCount%logInterval == 0 {
				r.logger.Info("RTP forwarder progress",
					zap.String("track_id", trackID),
					zap.String("kind", trackKind),
					zap.Int("packets", packetCount),
					zap.Int("bytes", byteCount),
				)
			}
		}
	}()
	// Add track to all other peers
	for peerID, peer := range r.peers {
		if peerID == sourcePeerID {
			continue
		}
		r.logger.Info("Adding track to peer",
			zap.String("target_peer", peerID),
			zap.String("source_peer", sourcePeerID),
			zap.String("source_user", sourceUserID),
			zap.String("track_id", localTrack.ID()),
			zap.String("stream_id", localTrack.StreamID()),
		)
		if _, err := peer.AddTrack(localTrack); err != nil {
			r.logger.Warn("Failed to add track to peer",
				zap.String("target_peer", peerID),
				zap.String("track_id", localTrack.ID()),
				zap.Error(err),
			)
			continue
		}
		// Notify peer about new track
		// Use consistent IDs: participantId=sourcePeerID, streamId matches the track
		peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
			ParticipantID: sourcePeerID,
			UserID:        sourceUserID,
			TrackID:       localTrack.ID(),       // Format: "{kind}-{participantId}"
			StreamID:      localTrack.StreamID(), // Same as userId for easy matching
			Kind:          track.Kind().String(),
		}))
	}
	r.logger.Info("Track broadcast to room",
		zap.String("source_peer", sourcePeerID),
		zap.String("source_user", sourceUserID),
		zap.String("track_id", track.ID()),
		zap.String("kind", track.Kind().String()),
	)
}
// broadcastMessage sends a message to every peer in the room except the one
// identified by excludePeerID; send failures are logged and skipped.
func (r *Room) broadcastMessage(excludePeerID string, msg *ServerMessage) {
	r.peersMu.RLock()
	defer r.peersMu.RUnlock()
	for id, p := range r.peers {
		if id == excludePeerID {
			continue
		}
		err := p.SendMessage(msg)
		if err == nil {
			continue
		}
		r.logger.Warn("Failed to send message to peer",
			zap.String("peer_id", id),
			zap.Error(err),
		)
	}
}
// Close closes the room and all peer connections. It is idempotent: once
// the closed flag is set, later calls return immediately.
func (r *Room) Close() error {
	r.closedMu.Lock()
	if r.closed {
		r.closedMu.Unlock()
		return nil
	}
	r.closed = true
	r.closedMu.Unlock()
	r.logger.Info("Closing room")
	// Detach the peer map under the lock, then close peers outside it.
	r.peersMu.Lock()
	snapshot := make([]*Peer, 0, len(r.peers))
	for _, p := range r.peers {
		snapshot = append(snapshot, p)
	}
	r.peers = make(map[string]*Peer)
	r.peersMu.Unlock()
	for _, p := range snapshot {
		err := p.Close()
		if err == nil {
			continue
		}
		r.logger.Warn("Error closing peer",
			zap.String("peer_id", p.ID),
			zap.Error(err),
		)
	}
	return nil
}
// OnEmpty sets a callback for when the room becomes empty.
// The callback is invoked from RemovePeer when the last participant leaves.
func (r *Room) OnEmpty(fn func(*Room)) {
	r.onEmpty = fn
}
// IsClosed reports whether the room has been closed via Close.
func (r *Room) IsClosed() bool {
	r.closedMu.RLock()
	defer r.closedMu.RUnlock()
	return r.closed
}
// RequestKeyframe sends a PLI (Picture Loss Indication) to the source peer for a video track.
// This causes the source to send a keyframe, which is needed for new receivers to start decoding.
//
// Non-video tracks, unknown track IDs, and missing source peers are silently
// ignored; PLI write failures are only logged at debug level (keyframe
// requests are best-effort and retried by the adaptive loop).
func (r *Room) RequestKeyframe(trackID string) {
	r.publishedTracksMu.RLock()
	track, ok := r.publishedTracks[trackID]
	r.publishedTracksMu.RUnlock()
	if !ok || track.kind != "video" {
		return
	}
	r.peersMu.RLock()
	sourcePeer, ok := r.peers[track.sourcePeerID]
	r.peersMu.RUnlock()
	if !ok || sourcePeer.pc == nil {
		r.logger.Debug("Cannot request keyframe - source peer not found",
			zap.String("track_id", trackID),
			zap.String("source_peer_id", track.sourcePeerID),
		)
		return
	}
	// Create a PLI packet targeting the original (remote) SSRC.
	pli := &rtcp.PictureLossIndication{
		MediaSSRC: track.remoteTrackSSRC,
	}
	// Send the PLI to the source peer
	if err := sourcePeer.pc.WriteRTCP([]rtcp.Packet{pli}); err != nil {
		r.logger.Debug("Failed to send PLI",
			zap.String("track_id", trackID),
			zap.String("source_peer_id", track.sourcePeerID),
			zap.Error(err),
		)
		return
	}
	r.logger.Debug("PLI keyframe request sent",
		zap.String("track_id", trackID),
		zap.String("source_peer_id", track.sourcePeerID),
		zap.Uint32("ssrc", track.remoteTrackSSRC),
	)
}
// RequestKeyframeForAllVideoTracks sends PLI requests for all video tracks in the room.
// This is useful when a new peer joins to ensure they get keyframes quickly.
func (r *Room) RequestKeyframeForAllVideoTracks() {
	// Collect IDs under the read lock, then issue requests without it
	// (RequestKeyframe takes the same lock itself).
	r.publishedTracksMu.RLock()
	videoIDs := make([]string, 0)
	for id, t := range r.publishedTracks {
		if t.kind != "video" {
			continue
		}
		videoIDs = append(videoIDs, id)
	}
	r.publishedTracksMu.RUnlock()
	for _, id := range videoIDs {
		r.RequestKeyframe(id)
	}
}
// IncrementNackCount increments the NACK counter for a track.
// This is called when we receive NACK feedback indicating packet loss.
func (r *Room) IncrementNackCount(trackID string) {
	r.publishedTracksMu.RLock()
	track := r.publishedTracks[trackID]
	r.publishedTracksMu.RUnlock()
	if track != nil {
		track.nackCount.Add(1)
	}
}

View File

@ -0,0 +1,145 @@
package sfu
import (
"encoding/json"
"github.com/pion/webrtc/v4"
)
// MessageType represents the type of signaling message
type MessageType string

// Signaling message types exchanged over the WebSocket.
const (
	// Client -> Server message types
	MessageTypeJoin         MessageType = "join"
	MessageTypeLeave        MessageType = "leave"
	MessageTypeOffer        MessageType = "offer"
	MessageTypeAnswer       MessageType = "answer"
	MessageTypeICECandidate MessageType = "ice-candidate"
	MessageTypeMute         MessageType = "mute"
	MessageTypeUnmute       MessageType = "unmute"
	MessageTypeStartVideo   MessageType = "start-video"
	MessageTypeStopVideo    MessageType = "stop-video"
	// Server -> Client message types
	// (offer/answer/ice-candidate above are used in both directions)
	MessageTypeWelcome           MessageType = "welcome"
	MessageTypeParticipantJoined MessageType = "participant-joined"
	MessageTypeParticipantLeft   MessageType = "participant-left"
	MessageTypeTrackAdded        MessageType = "track-added"
	MessageTypeTrackRemoved      MessageType = "track-removed"
	MessageTypeError             MessageType = "error"
)
// ClientMessage represents a message from client to server.
// Data is kept as raw JSON so it can be decoded into the payload struct
// matching Type only once the type is known (see handleSFUMessage).
type ClientMessage struct {
	Type MessageType     `json:"type"`
	Data json.RawMessage `json:"data,omitempty"`
}

// ServerMessage represents a message from server to client.
// Data holds an already-built payload value and is serialized directly.
type ServerMessage struct {
	Type MessageType `json:"type"`
	Data interface{} `json:"data,omitempty"`
}
// JoinData is the payload for join messages.
type JoinData struct {
	DisplayName string `json:"displayName"`
	// AudioOnly presumably signals the client will not publish video —
	// TODO(review): confirm how the peer setup consumes this flag.
	AudioOnly bool `json:"audioOnly,omitempty"`
}

// OfferData is the payload for offer messages (a raw SDP offer).
type OfferData struct {
	SDP string `json:"sdp"`
}

// AnswerData is the payload for answer messages (a raw SDP answer).
type AnswerData struct {
	SDP string `json:"sdp"`
}

// ICECandidateData is the payload for ICE candidate messages. It mirrors
// webrtc.ICECandidateInit but with plain (non-pointer) fields for JSON
// convenience; see ToWebRTCCandidate for the conversion.
type ICECandidateData struct {
	Candidate        string `json:"candidate"`
	SDPMid           string `json:"sdpMid,omitempty"`
	SDPMLineIndex    uint16 `json:"sdpMLineIndex,omitempty"`
	UsernameFragment string `json:"usernameFragment,omitempty"`
}
// ToWebRTCCandidate converts ICECandidateData to webrtc.ICECandidateInit.
//
// The wire format uses omitempty, so optional fields that were absent in the
// JSON arrive here as zero values. Mapping an empty SDPMid/UsernameFragment
// to a nil pointer (instead of a pointer to "") preserves the distinction
// between "not provided" and "provided but empty" that the ICECandidateInit
// pointer fields exist to express.
//
// SDPMLineIndex is always forwarded: 0 is a valid index (the first m-line)
// and, under omitempty, indistinguishable from "absent".
func (c *ICECandidateData) ToWebRTCCandidate() webrtc.ICECandidateInit {
	init := webrtc.ICECandidateInit{
		Candidate:     c.Candidate,
		SDPMLineIndex: &c.SDPMLineIndex,
	}
	if c.SDPMid != "" {
		init.SDPMid = &c.SDPMid
	}
	if c.UsernameFragment != "" {
		init.UsernameFragment = &c.UsernameFragment
	}
	return init
}
// WelcomeData is sent to a client when they successfully join a room.
type WelcomeData struct {
	ParticipantID string            `json:"participantId"` // the new peer's own ID
	RoomID        string            `json:"roomId"`
	Participants  []ParticipantInfo `json:"participants"` // everyone currently in the room
}

// ParticipantInfo contains public information about a participant.
type ParticipantInfo struct {
	ID          string `json:"id"`     // internal SFU peer ID
	UserID      string `json:"userId"` // caller-supplied/authenticated user ID
	DisplayName string `json:"displayName"`
	HasAudio    bool   `json:"hasAudio"`
	HasVideo    bool   `json:"hasVideo"`
	AudioMuted  bool   `json:"audioMuted"`
	VideoMuted  bool   `json:"videoMuted"`
}
// ParticipantJoinedData is sent when a new participant joins.
type ParticipantJoinedData struct {
	Participant ParticipantInfo `json:"participant"`
}

// ParticipantLeftData is sent when a participant leaves.
type ParticipantLeftData struct {
	ParticipantID string `json:"participantId"`
}

// TrackAddedData is sent when a new track becomes available for subscription.
type TrackAddedData struct {
	ParticipantID string `json:"participantId"` // Internal SFU peer ID
	UserID        string `json:"userId"`        // The user's actual ID for easier mapping
	TrackID       string `json:"trackId"`       // Format: "{kind}-{participantId}"
	StreamID      string `json:"streamId"`      // Same as userId (for WebRTC stream matching)
	Kind          string `json:"kind"`          // "audio" or "video"
}

// TrackRemovedData is sent when a track is removed. Fields mirror
// TrackAddedData so clients can match the original announcement.
type TrackRemovedData struct {
	ParticipantID string `json:"participantId"`
	UserID        string `json:"userId"` // The user's actual ID for easier mapping
	TrackID       string `json:"trackId"`
	StreamID      string `json:"streamId"`
	Kind          string `json:"kind"`
}

// ErrorData is sent when an error occurs. Code is a short machine-readable
// identifier; Message is the human-readable description.
type ErrorData struct {
	Code    string `json:"code"`
	Message string `json:"message"`
}
// NewServerMessage builds a ServerMessage envelope of the given type
// carrying the provided payload.
func NewServerMessage(msgType MessageType, data interface{}) *ServerMessage {
	msg := &ServerMessage{Type: msgType}
	msg.Data = data
	return msg
}

// NewErrorMessage builds a server "error" message with the given
// machine-readable code and human-readable description.
func NewErrorMessage(code, message string) *ServerMessage {
	payload := &ErrorData{
		Code:    code,
		Message: message,
	}
	return NewServerMessage(MessageTypeError, payload)
}

573
pkg/gateway/sfu_handlers.go Normal file
View File

@ -0,0 +1,573 @@
package gateway
import (
"encoding/json"
"net/http"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/gateway/sfu"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/gorilla/websocket"
"github.com/pion/webrtc/v4"
"go.uber.org/zap"
)
// SFUManager is a type alias for the SFU room manager.
type SFUManager = sfu.RoomManager

// CreateRoomRequest is the request body for creating a room.
type CreateRoomRequest struct {
	RoomID string `json:"roomId"`
}

// CreateRoomResponse is the response for creating/joining a room.
type CreateRoomResponse struct {
	RoomID  string `json:"roomId"`
	Created bool   `json:"created"` // true when the room did not already exist
	// RTPCapabilities advertises the SFU's RTP capabilities (from
	// sfu.GetRTPCapabilities) so clients can configure their side.
	RTPCapabilities map[string]interface{} `json:"rtpCapabilities"`
}

// JoinRoomRequest is the request body for joining a room.
type JoinRoomRequest struct {
	DisplayName string `json:"displayName"`
}

// JoinRoomResponse is the response for joining a room.
type JoinRoomResponse struct {
	// ParticipantID is empty at this stage; it is assigned when the
	// WebSocket signaling connection is established.
	ParticipantID string                `json:"participantId"`
	Participants  []sfu.ParticipantInfo `json:"participants"`
}
// sfuCreateRoomHandler handles POST /v1/sfu/room.
// It validates the request, resolves the caller's namespace, creates the
// room (or returns the existing one), and responds with the SFU's RTP
// capabilities.
func (g *Gateway) sfuCreateRoomHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// SFU support is optional; reject if it was never initialized.
	if g.sfuManager == nil {
		writeError(w, http.StatusServiceUnavailable, "SFU service not enabled")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}

	var body CreateRoomRequest
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		writeError(w, http.StatusBadRequest, "invalid request body")
		return
	}
	if body.RoomID == "" {
		writeError(w, http.StatusBadRequest, "roomId is required")
		return
	}

	// Idempotent: an existing room is returned rather than an error.
	room, created := g.sfuManager.GetOrCreateRoom(ns, body.RoomID)
	g.logger.ComponentInfo(logging.ComponentGeneral, "SFU room request",
		zap.String("room_id", body.RoomID),
		zap.String("namespace", ns),
		zap.Bool("created", created),
		zap.Int("participants", room.GetParticipantCount()),
	)

	writeJSON(w, http.StatusOK, &CreateRoomResponse{
		RoomID:          body.RoomID,
		Created:         created,
		RTPCapabilities: sfu.GetRTPCapabilities(),
	})
}
// sfuRoomHandler routes all /v1/sfu/room/:roomId/* endpoints.
// Expected path shape: /v1/sfu/room/{roomId}[/{action}].
func (g *Gateway) sfuRoomHandler(w http.ResponseWriter, r *http.Request) {
	if g.sfuManager == nil {
		writeError(w, http.StatusServiceUnavailable, "SFU service not enabled")
		return
	}
	ns := resolveNamespaceFromRequest(r)
	if ns == "" {
		writeError(w, http.StatusForbidden, "namespace not resolved")
		return
	}

	// Split "{roomId}/{action}" off the route prefix.
	rest := strings.TrimPrefix(r.URL.Path, "/v1/sfu/room/")
	segments := strings.SplitN(rest, "/", 2)
	if len(segments) == 0 || segments[0] == "" {
		writeError(w, http.StatusBadRequest, "room ID required")
		return
	}
	roomID := segments[0]
	var action string
	if len(segments) == 2 {
		action = segments[1]
	}

	switch action {
	case "":
		g.sfuGetRoomHandler(w, r, ns, roomID) // GET room info
	case "join":
		g.sfuJoinRoomHandler(w, r, ns, roomID) // POST join
	case "leave":
		g.sfuLeaveRoomHandler(w, r, ns, roomID) // POST leave
	case "participants":
		g.sfuParticipantsHandler(w, r, ns, roomID) // GET participant list
	case "ws":
		g.sfuWebSocketHandler(w, r, ns, roomID) // WebSocket signaling
	default:
		writeError(w, http.StatusNotFound, "unknown action")
	}
}
// sfuGetRoomHandler handles GET /v1/sfu/room/:roomId and returns room info.
func (g *Gateway) sfuGetRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	info, err := g.sfuManager.GetRoomInfo(ns, roomID)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, info)
	case err == sfu.ErrRoomNotFound:
		writeError(w, http.StatusNotFound, "room not found")
	default:
		writeError(w, http.StatusInternalServerError, err.Error())
	}
}
// sfuJoinRoomHandler handles POST /v1/sfu/room/:roomId/join.
//
// This endpoint only ensures the room exists and returns the current
// participant list; the actual peer (and its participant ID) is created
// later over the WebSocket signaling endpoint.
func (g *Gateway) sfuJoinRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// Parse request.
	// NOTE(review): req.DisplayName is never used below — the display name
	// is supplied again via WebSocket query parameters. Confirm whether
	// this decode is still needed.
	var req JoinRoomRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// Default display name if not provided
		req.DisplayName = "Anonymous"
	}
	// Get user ID from query param or auth context.
	// NOTE(review): userID is computed but not used by this handler either;
	// confirm whether it was meant to be logged or returned.
	userID := r.URL.Query().Get("userId")
	if userID == "" {
		userID = g.extractUserID(r)
	}
	if userID == "" {
		userID = "anonymous"
	}
	// Get or create room
	room, _ := g.sfuManager.GetOrCreateRoom(ns, roomID)
	// This endpoint just returns room info
	// The actual peer connection is established via WebSocket
	writeJSON(w, http.StatusOK, &JoinRoomResponse{
		ParticipantID: "", // Will be assigned when WebSocket connects
		Participants:  room.GetParticipants(),
	})
}
// sfuLeaveRoomHandler handles POST /v1/sfu/room/:roomId/leave.
// The body must carry the participantId to remove from the room.
func (g *Gateway) sfuLeaveRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
	if r.Method != http.MethodPost {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	var body struct {
		ParticipantID string `json:"participantId"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.ParticipantID == "" {
		writeError(w, http.StatusBadRequest, "participantId required")
		return
	}

	room, err := g.sfuManager.GetRoom(ns, roomID)
	if err != nil {
		switch err {
		case sfu.ErrRoomNotFound:
			writeError(w, http.StatusNotFound, "room not found")
		default:
			writeError(w, http.StatusInternalServerError, err.Error())
		}
		return
	}

	if err := room.RemovePeer(body.ParticipantID); err != nil {
		switch err {
		case sfu.ErrPeerNotFound:
			writeError(w, http.StatusNotFound, "participant not found")
		default:
			writeError(w, http.StatusInternalServerError, err.Error())
		}
		return
	}

	writeJSON(w, http.StatusOK, map[string]interface{}{
		"success": true,
	})
}
// sfuParticipantsHandler handles GET /v1/sfu/room/:roomId/participants and
// returns the room's current participant list.
func (g *Gateway) sfuParticipantsHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	room, err := g.sfuManager.GetRoom(ns, roomID)
	switch {
	case err == nil:
		writeJSON(w, http.StatusOK, map[string]interface{}{
			"participants": room.GetParticipants(),
		})
	case err == sfu.ErrRoomNotFound:
		writeError(w, http.StatusNotFound, "room not found")
	default:
		writeError(w, http.StatusInternalServerError, err.Error())
	}
}
// sfuWebSocketHandler handles GET /v1/sfu/room/:roomId/ws — the WebSocket
// signaling endpoint for a room.
//
// Identity: the user ID comes from the "userId" query parameter, then the
// authenticated JWT/API-key context, then the literal "anonymous"; the
// display name comes from the "displayName" query parameter ("Anonymous"
// by default).
//
// Flow: upgrade the connection, add a new peer to the room, send the
// "welcome" message, then block in handleSFUSignaling until the socket
// closes. Existing tracks are sent only after the first offer/answer
// exchange completes (see handleSFUMessage).
func (g *Gateway) sfuWebSocketHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
	if r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// Get user ID and display name from query parameters
	// Priority: 1) userId query param, 2) JWT/API key auth, 3) "anonymous"
	userID := r.URL.Query().Get("userId")
	if userID == "" {
		// Fall back to authentication-based user ID
		userID = g.extractUserID(r)
	}
	if userID == "" {
		userID = "anonymous"
	}
	displayName := r.URL.Query().Get("displayName")
	if displayName == "" {
		displayName = "Anonymous"
	}
	// Upgrade to WebSocket
	conn, err := wsUpgrader.Upgrade(w, r, nil)
	if err != nil {
		g.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket upgrade failed",
			zap.Error(err),
		)
		return
	}
	// Recover from panics to avoid silent crashes. The recovered value is
	// named rec so it does not shadow the *http.Request r.
	defer func() {
		if rec := recover(); rec != nil {
			g.logger.ComponentError(logging.ComponentGeneral, "SFU WebSocket handler panic",
				zap.Any("panic", rec),
				zap.String("room_id", roomID),
				zap.String("user_id", userID),
			)
			conn.Close()
		}
	}()
	g.logger.ComponentInfo(logging.ComponentGeneral, "SFU WebSocket connected",
		zap.String("room_id", roomID),
		zap.String("namespace", ns),
		zap.String("user_id", userID),
		zap.String("display_name", displayName),
	)
	// Get or create room
	g.logger.ComponentDebug(logging.ComponentGeneral, "SFU getting/creating room",
		zap.String("room_id", roomID),
		zap.String("namespace", ns),
	)
	room, created := g.sfuManager.GetOrCreateRoom(ns, roomID)
	g.logger.ComponentInfo(logging.ComponentGeneral, "SFU room ready",
		zap.String("room_id", roomID),
		zap.Bool("created", created),
	)
	// Create peer
	peer := sfu.NewPeer(userID, displayName, conn, room, g.logger.Logger)
	g.logger.ComponentInfo(logging.ComponentGeneral, "SFU peer created",
		zap.String("peer_id", peer.ID),
	)
	// Add peer to room (this also initializes its peer connection)
	g.logger.ComponentDebug(logging.ComponentGeneral, "SFU adding peer to room (will init peer connection)",
		zap.String("peer_id", peer.ID),
		zap.String("room_id", roomID),
	)
	if err := room.AddPeer(peer); err != nil {
		g.logger.ComponentError(logging.ComponentGeneral, "Failed to add peer to room",
			zap.String("room_id", roomID),
			zap.Error(err),
		)
		peer.SendMessage(sfu.NewErrorMessage("join_failed", err.Error()))
		conn.Close()
		return
	}
	g.logger.ComponentInfo(logging.ComponentGeneral, "SFU peer added to room successfully",
		zap.String("peer_id", peer.ID),
		zap.String("room_id", roomID),
	)
	// Send welcome message
	g.logger.ComponentDebug(logging.ComponentGeneral, "SFU preparing welcome message",
		zap.String("peer_id", peer.ID),
		zap.Int("num_participants", len(room.GetParticipants())),
	)
	welcomeMsg := sfu.NewServerMessage(sfu.MessageTypeWelcome, &sfu.WelcomeData{
		ParticipantID: peer.ID,
		RoomID:        roomID,
		Participants:  room.GetParticipants(),
	})
	g.logger.ComponentDebug(logging.ComponentGeneral, "SFU sending welcome message via SendMessage",
		zap.String("peer_id", peer.ID),
	)
	if err := peer.SendMessage(welcomeMsg); err != nil {
		g.logger.ComponentError(logging.ComponentGeneral, "Failed to send welcome message",
			zap.String("peer_id", peer.ID),
			zap.Error(err),
		)
		// Fix: the peer was already added to the room above, and the
		// deferred cleanup inside handleSFUSignaling is never reached on
		// this path — without this removal it would linger as a ghost
		// participant visible to everyone else.
		room.RemovePeer(peer.ID)
		conn.Close()
		return
	}
	g.logger.ComponentInfo(logging.ComponentGeneral, "SFU welcome message sent successfully",
		zap.String("peer_id", peer.ID),
		zap.String("room_id", roomID),
	)
	// Handle signaling messages
	// Note: existing tracks are sent AFTER the first offer/answer exchange completes
	g.handleSFUSignaling(conn, peer, room)
}
// handleSFUSignaling runs the WebSocket read loop for a peer, dispatching
// each parsed message to handleSFUMessage. It blocks until the connection
// closes, then removes the peer from the room and closes the socket.
//
// Keepalive: a ping is written every 30s; the read deadline is pushed 60s
// into the future on every pong and every received message.
func (g *Gateway) handleSFUSignaling(conn *websocket.Conn, peer *sfu.Peer, room *sfu.Room) {
	// done is closed when this function returns so the ping goroutine below
	// terminates. Fix: previously the goroutine ranged over ticker.C, and
	// ticker.Stop() does not close that channel — the goroutine would block
	// on the stopped ticker forever, leaking one goroutine per connection.
	done := make(chan struct{})
	defer func() {
		close(done)
		room.RemovePeer(peer.ID)
		conn.Close()
		g.logger.ComponentInfo(logging.ComponentGeneral, "SFU WebSocket disconnected",
			zap.String("peer_id", peer.ID),
			zap.String("room_id", room.ID),
		)
	}()
	// Set up ping/pong for keepalive
	conn.SetReadDeadline(time.Now().Add(60 * time.Second))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(60 * time.Second))
		return nil
	})
	// Start ping ticker
	ticker := time.NewTicker(30 * time.Second)
	defer ticker.Stop()
	go func() {
		for {
			select {
			case <-done:
				return
			case <-ticker.C:
				if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
					return
				}
			}
		}
	}()
	// Read messages until the connection errors or closes.
	for {
		_, data, err := conn.ReadMessage()
		if err != nil {
			// Normal closes are expected; only log unexpected ones.
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				g.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket read error",
					zap.String("peer_id", peer.ID),
					zap.Error(err),
				)
			}
			return
		}
		// Reset read deadline
		conn.SetReadDeadline(time.Now().Add(60 * time.Second))
		// Parse message; a malformed frame gets an error reply but does not
		// terminate the session.
		var msg sfu.ClientMessage
		if err := json.Unmarshal(data, &msg); err != nil {
			peer.SendMessage(sfu.NewErrorMessage("invalid_message", "failed to parse message"))
			continue
		}
		// Handle message
		g.handleSFUMessage(peer, room, &msg)
	}
}
// handleSFUMessage dispatches a single signaling message from a peer.
//
// Offer/answer/ICE payloads are decoded from msg.Data and forwarded to the
// peer; malformed payloads and handler failures are reported back to the
// client as typed error messages rather than closing the socket.
// Mute/unmute and start/stop-video only flip the peer's state flags.
func (g *Gateway) handleSFUMessage(peer *sfu.Peer, room *sfu.Room, msg *sfu.ClientMessage) {
	switch msg.Type {
	case sfu.MessageTypeOffer:
		var data sfu.OfferData
		if err := json.Unmarshal(msg.Data, &data); err != nil {
			peer.SendMessage(sfu.NewErrorMessage("invalid_offer", err.Error()))
			return
		}
		if err := peer.HandleOffer(data.SDP); err != nil {
			peer.SendMessage(sfu.NewErrorMessage("offer_failed", err.Error()))
			return
		}
		// After successfully handling the FIRST offer, send existing tracks.
		// This ensures the WebRTC connection is established before adding
		// more tracks.
		if peer.MarkInitialOfferHandled() {
			g.logger.ComponentInfo(logging.ComponentGeneral, "First offer handled, sending existing tracks",
				zap.String("peer_id", peer.ID),
			)
			room.SendExistingTracksTo(peer)
		}
	case sfu.MessageTypeAnswer:
		var data sfu.AnswerData
		if err := json.Unmarshal(msg.Data, &data); err != nil {
			peer.SendMessage(sfu.NewErrorMessage("invalid_answer", err.Error()))
			return
		}
		if err := peer.HandleAnswer(data.SDP); err != nil {
			peer.SendMessage(sfu.NewErrorMessage("answer_failed", err.Error()))
		}
	case sfu.MessageTypeICECandidate:
		var data sfu.ICECandidateData
		if err := json.Unmarshal(msg.Data, &data); err != nil {
			peer.SendMessage(sfu.NewErrorMessage("invalid_candidate", err.Error()))
			return
		}
		if err := peer.HandleICECandidate(&data); err != nil {
			peer.SendMessage(sfu.NewErrorMessage("candidate_failed", err.Error()))
		}
	case sfu.MessageTypeMute:
		peer.SetAudioMuted(true)
		g.logger.ComponentDebug(logging.ComponentGeneral, "Peer muted audio", zap.String("peer_id", peer.ID))
	case sfu.MessageTypeUnmute:
		peer.SetAudioMuted(false)
		g.logger.ComponentDebug(logging.ComponentGeneral, "Peer unmuted audio", zap.String("peer_id", peer.ID))
	case sfu.MessageTypeStartVideo:
		peer.SetVideoMuted(false)
		g.logger.ComponentDebug(logging.ComponentGeneral, "Peer started video", zap.String("peer_id", peer.ID))
	case sfu.MessageTypeStopVideo:
		peer.SetVideoMuted(true)
		g.logger.ComponentDebug(logging.ComponentGeneral, "Peer stopped video", zap.String("peer_id", peer.ID))
	case sfu.MessageTypeLeave:
		// Will be handled by deferred cleanup in handleSFUSignaling once the
		// client closes the socket.
		g.logger.ComponentInfo(logging.ComponentGeneral, "Peer leaving room", zap.String("peer_id", peer.ID))
	default:
		peer.SendMessage(sfu.NewErrorMessage("unknown_type", "unknown message type"))
	}
}
// initializeSFUManager wires up the SFU room manager from the gateway's
// configuration. It is a no-op when the SFU feature is absent or disabled.
func (g *Gateway) initializeSFUManager() error {
	if g.cfg.SFU == nil || !g.cfg.SFU.Enabled {
		g.logger.ComponentInfo(logging.ComponentGeneral, "SFU service disabled")
		return nil
	}

	// Statically configured ICE servers come first. Ranging over a nil
	// slice iterates zero times, so no explicit nil check is needed.
	iceServers := make([]webrtc.ICEServer, 0)
	for _, server := range g.cfg.SFU.ICEServers {
		iceServers = append(iceServers, webrtc.ICEServer{
			URLs:       server.URLs,
			Username:   server.Username,
			Credential: server.Credential,
		})
	}

	// STUN servers from the TURN section. TURN credentials themselves are
	// time-limited, so clients fetch them via the /v1/turn/credentials
	// endpoint before joining a room.
	if g.cfg.TURN != nil {
		// Hostname for ICE server URLs: the configured domain, or localhost
		// as a local-development fallback.
		iceHost := g.cfg.DomainName
		if iceHost == "" {
			iceHost = "localhost"
		}
		if len(g.cfg.TURN.STUNURLs) > 0 {
			// Fill in empty hostnames, e.g. "stun::3478" -> "stun:localhost:3478".
			iceServers = append(iceServers, webrtc.ICEServer{
				URLs: processURLsWithHost(g.cfg.TURN.STUNURLs, iceHost),
			})
		}
	}

	sfuConfig := &sfu.Config{
		MaxParticipants: g.cfg.SFU.MaxParticipants,
		MediaTimeout:    g.cfg.SFU.MediaTimeout,
		ICEServers:      iceServers,
	}
	// Apply defaults for unset values.
	if sfuConfig.MaxParticipants == 0 {
		sfuConfig.MaxParticipants = 10
	}
	if sfuConfig.MediaTimeout == 0 {
		sfuConfig.MediaTimeout = 30 * time.Second
	}

	manager, err := sfu.NewRoomManager(sfuConfig, g.logger.Logger)
	if err != nil {
		return err
	}
	g.sfuManager = manager

	g.logger.ComponentInfo(logging.ComponentGeneral, "SFU manager initialized",
		zap.Int("max_participants", sfuConfig.MaxParticipants),
		zap.Duration("media_timeout", sfuConfig.MediaTimeout),
		zap.Int("ice_servers", len(iceServers)),
	)
	return nil
}

View File

@ -0,0 +1,192 @@
package gateway
import (
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
"github.com/DeBrosOfficial/network/pkg/logging"
"go.uber.org/zap"
)
// TURNCredentialsResponse is the response for TURN credential requests.
// The username/credential pair follows the coturn "REST API" convention:
// username is "expiry-timestamp:userId" and credential is the base64 of
// HMAC-SHA1(username, shared_secret); see generateTURNCredentials.
type TURNCredentialsResponse struct {
	Username   string   `json:"username"`   // Format: "timestamp:userId"
	Credential string   `json:"credential"` // HMAC-SHA1(username, shared_secret) base64 encoded
	TTL        int64    `json:"ttl"`        // Time-to-live in seconds
	STUNURLs   []string `json:"stun_urls"`  // STUN server URLs
	TURNURLs   []string `json:"turn_urls"`  // TURN server URLs
}
// turnCredentialsHandler handles GET/POST /v1/turn/credentials.
// It returns time-limited TURN credentials for WebRTC connections; the
// turn.shared_secret must be configured or the endpoint answers 503.
func (g *Gateway) turnCredentialsHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost && r.Method != http.MethodGet {
		writeError(w, http.StatusMethodNotAllowed, "method not allowed")
		return
	}
	// Check if TURN is configured
	if g.cfg.TURN == nil || g.cfg.TURN.SharedSecret == "" {
		g.logger.ComponentWarn(logging.ComponentGeneral, "TURN credentials requested but not configured")
		writeError(w, http.StatusServiceUnavailable, "TURN service not configured")
		return
	}
	// Get user ID from JWT claims or API key; anonymous callers are allowed.
	userID := g.extractUserID(r)
	if userID == "" {
		userID = "anonymous"
	}
	// Derive the gateway hostname from the request, stripping any port.
	// Fix: use net.SplitHostPort instead of cutting at the first ':' so a
	// bracketed IPv6 Host such as "[::1]:8080" yields "::1" rather than "[".
	// Hosts without a port make SplitHostPort error and are kept as-is.
	gatewayHost := r.Host
	if host, _, err := net.SplitHostPort(gatewayHost); err == nil {
		gatewayHost = host
	}
	if gatewayHost == "" {
		gatewayHost = "localhost"
	}
	// Generate credentials
	credentials := g.generateTURNCredentials(userID, gatewayHost)
	g.logger.ComponentInfo(logging.ComponentGeneral, "TURN credentials generated",
		zap.String("user_id", userID),
		zap.String("gateway_host", gatewayHost),
		zap.Int64("ttl", credentials.TTL),
	)
	writeJSON(w, http.StatusOK, credentials)
}
// generateTURNCredentials creates time-limited TURN credentials using
// HMAC-SHA1, following the coturn shared-secret ("REST API") convention:
// the username embeds the expiry timestamp, and the credential is the
// base64 HMAC of the username under the shared secret. SHA-1 here is
// dictated by that convention, not a local security choice.
func (g *Gateway) generateTURNCredentials(userID, gatewayHost string) *TURNCredentialsResponse {
	cfg := g.cfg.TURN
	// Default TTL to 24 hours if not configured
	ttl := cfg.TTL
	if ttl == 0 {
		ttl = 24 * time.Hour
	}
	// Calculate expiry timestamp
	timestamp := time.Now().Unix() + int64(ttl.Seconds())
	// Format: "timestamp:userId" (coturn format)
	username := fmt.Sprintf("%d:%s", timestamp, userID)
	// Generate HMAC-SHA1 credential
	h := hmac.New(sha1.New, []byte(cfg.SharedSecret))
	h.Write([]byte(username))
	credential := base64.StdEncoding.EncodeToString(h.Sum(nil))
	// Determine the host to use for STUN/TURN URLs
	// Priority: 1) ExternalHost from config, 2) Auto-detect LAN IP, 3) Gateway host from request
	host := cfg.ExternalHost
	if host == "" {
		// Auto-detect LAN IP for development
		host = detectLANIP()
		if host == "" {
			// Fallback to gateway host from request (may be localhost)
			host = gatewayHost
		}
	}
	// Process URLs - replace empty hostnames (::) with determined host
	stunURLs := processURLsWithHost(cfg.STUNURLs, host)
	turnURLs := processURLsWithHost(cfg.TURNURLs, host)
	// If TLS is enabled, ensure at least one turns:// URL is advertised so
	// clients behind restrictive firewalls can relay over TCP/443.
	if cfg.TLSEnabled {
		hasTurns := false
		for _, url := range turnURLs {
			if strings.HasPrefix(url, "turns:") {
				hasTurns = true
				break
			}
		}
		// Auto-add turns:// URL if not already configured
		if !hasTurns {
			turnsURL := fmt.Sprintf("turns:%s:443?transport=tcp", host)
			turnURLs = append(turnURLs, turnsURL)
		}
	}
	return &TURNCredentialsResponse{
		Username:   username,
		Credential: credential,
		TTL:        int64(ttl.Seconds()),
		STUNURLs:   stunURLs,
		TURNURLs:   turnURLs,
	}
}
// detectLANIP returns the first non-loopback, non-link-local IPv4 address
// found on the system, or "" if none is found.
//
// Fix: link-local 169.254.0.0/16 addresses (APIPA, common on interfaces
// that failed DHCP) are now skipped — they are not reachable by remote
// peers and would previously have been advertised as the STUN/TURN host.
func detectLANIP() string {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return ""
	}
	for _, addr := range addrs {
		ipNet, ok := addr.(*net.IPNet)
		if !ok {
			continue
		}
		ip4 := ipNet.IP.To4()
		if ip4 == nil || ip4.IsLoopback() || ip4.IsLinkLocalUnicast() {
			continue
		}
		return ip4.String()
	}
	return ""
}
// processURLsWithHost fills in empty hostnames in STUN/TURN URLs with the
// given host. Two shorthand patterns are supported, both meaning "host not
// configured, substitute the runtime-determined one":
//
//   - "stun:::3478" (triple colon, preferred) -> "stun:host:3478"
//   - "stun::3478"  (double colon, fallback)  -> "stun:host:3478"
//
// URLs that already carry a hostname are returned unchanged. Fix: the
// previous blanket strings.Replace of "::"/":::" corrupted bracketed IPv6
// literals (e.g. "turn:[::1]:3478" became "turn:[:host:1]:3478"); now a
// URL is only rewritten when the portion right after the scheme's colon is
// genuinely empty.
func processURLsWithHost(urls []string, host string) []string {
	result := make([]string, 0, len(urls))
	for _, url := range urls {
		if scheme, rest, ok := strings.Cut(url, ":"); ok {
			if strings.HasPrefix(rest, "::") {
				// "scheme:::port" form.
				url = scheme + ":" + host + ":" + strings.TrimPrefix(rest, "::")
			} else if strings.HasPrefix(rest, ":") {
				// "scheme::port" form.
				url = scheme + ":" + host + ":" + strings.TrimPrefix(rest, ":")
			}
		}
		result = append(result, url)
	}
	return result
}
// extractUserID extracts the caller's user ID from the request context.
// It returns the JWT subject when present, otherwise "ak_" plus the first
// 8 characters of the API key, otherwise "".
func (g *Gateway) extractUserID(r *http.Request) string {
	ctx := r.Context()
	// Try JWT claims first
	if v := ctx.Value(ctxKeyJWT); v != nil {
		if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
			if claims.Sub != "" {
				return claims.Sub
			}
		}
	}
	// Fallback to API key
	if v := ctx.Value(ctxKeyAPIKey); v != nil {
		if key, ok := v.(string); ok && key != "" {
			// A truncated prefix of the API key serves as a pseudonymous ID.
			// NOTE(review): despite the original comment, this is NOT a hash
			// — it exposes the first 8 raw characters of the key; consider
			// an actual hash if key prefixes are sensitive.
			return fmt.Sprintf("ak_%s", key[:min(8, len(key))])
		}
	}
	return ""
}

View File

@ -33,19 +33,21 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
}
gwCfg := &gateway.Config{
ListenAddr: n.config.HTTPGateway.ListenAddr,
ClientNamespace: n.config.HTTPGateway.ClientNamespace,
BootstrapPeers: n.config.Discovery.BootstrapPeers,
NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir),
RQLiteDSN: n.config.HTTPGateway.RQLiteDSN,
OlricServers: n.config.HTTPGateway.OlricServers,
OlricTimeout: n.config.HTTPGateway.OlricTimeout,
ListenAddr: n.config.HTTPGateway.ListenAddr,
ClientNamespace: n.config.HTTPGateway.ClientNamespace,
BootstrapPeers: n.config.Discovery.BootstrapPeers,
NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir),
RQLiteDSN: n.config.HTTPGateway.RQLiteDSN,
OlricServers: n.config.HTTPGateway.OlricServers,
OlricTimeout: n.config.HTTPGateway.OlricTimeout,
IPFSClusterAPIURL: n.config.HTTPGateway.IPFSClusterAPIURL,
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled,
DomainName: n.config.HTTPGateway.HTTPS.Domain,
TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir,
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled,
DomainName: n.config.HTTPGateway.HTTPS.Domain,
TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir,
TURN: n.config.HTTPGateway.TURN,
SFU: n.config.HTTPGateway.SFU,
}
apiGateway, err := gateway.New(gatewayLogger, gwCfg)

View File

@ -16,6 +16,7 @@ import (
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/pubsub"
database "github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/DeBrosOfficial/network/pkg/turn"
"github.com/libp2p/go-libp2p/core/host"
"go.uber.org/zap"
"golang.org/x/crypto/acme/autocert"
@ -55,6 +56,9 @@ type Node struct {
// Certificate ready signal - closed when TLS certificates are extracted and ready for use
certReady chan struct{}
// Built-in TURN server for WebRTC NAT traversal
turnServer *turn.Server
}
// NewNode creates a new network node
@ -96,6 +100,11 @@ func (n *Node) Start(ctx context.Context) error {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to start HTTP Gateway", zap.Error(err))
}
// Start built-in TURN server if enabled
if err := n.startTURNServer(); err != nil {
n.logger.ComponentWarn(logging.ComponentNode, "Failed to start TURN server", zap.Error(err))
}
// Start LibP2P host first (needed for cluster discovery)
if err := n.startLibP2P(); err != nil {
return fmt.Errorf("failed to start LibP2P: %w", err)
@ -135,6 +144,11 @@ func (n *Node) Start(ctx context.Context) error {
func (n *Node) Stop() error {
n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node")
// Stop TURN server
if n.turnServer != nil {
_ = n.turnServer.Stop()
}
// Stop HTTP Gateway server
if n.apiGatewayServer != nil {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)

90
pkg/node/turn.go Normal file
View File

@ -0,0 +1,90 @@
package node
import (
"os"
"github.com/DeBrosOfficial/network/pkg/logging"
"github.com/DeBrosOfficial/network/pkg/turn"
"go.uber.org/zap"
)
// startTURNServer initializes and starts the built-in TURN server.
//
// The server is optional: when disabled, or when no shared secret can be
// found, the method logs and returns nil rather than failing node startup.
// Environment variables (TURN_SHARED_SECRET, TURN_PUBLIC_IP) take priority
// over the config file so production deployments can avoid committing
// secrets.
func (n *Node) startTURNServer() error {
	if !n.config.TURNServer.Enabled {
		n.logger.ComponentInfo(logging.ComponentNode, "Built-in TURN server disabled")
		return nil
	}
	n.logger.ComponentInfo(logging.ComponentNode, "Starting built-in TURN server")
	// Get shared secret - env var takes priority over config file (for production)
	sharedSecret := os.Getenv("TURN_SHARED_SECRET")
	if sharedSecret == "" && n.config.HTTPGateway.TURN != nil && n.config.HTTPGateway.TURN.SharedSecret != "" {
		sharedSecret = n.config.HTTPGateway.TURN.SharedSecret
	}
	if sharedSecret == "" {
		// Without a secret the time-limited credential scheme cannot work,
		// so the server is simply not started (not treated as fatal).
		n.logger.ComponentWarn(logging.ComponentNode, "TURN server enabled but no shared_secret configured in http_gateway.turn")
		return nil
	}
	// Get public IP - env var takes priority over config file (for production)
	publicIP := os.Getenv("TURN_PUBLIC_IP")
	if publicIP == "" {
		publicIP = n.config.TURNServer.PublicIP
	}
	// Build TURN server config
	turnCfg := &turn.Config{
		Enabled:      true,
		ListenAddr:   n.config.TURNServer.ListenAddr,
		PublicIP:     publicIP,
		Realm:        n.config.TURNServer.Realm,
		SharedSecret: sharedSecret,
		// NOTE(review): the TTL is hard-coded to 24h here rather than read
		// from the gateway's turn.ttl setting — confirm this is intended.
		CredentialTTL: 24 * 60 * 60, // 24 hours in seconds (will be converted)
		MinPort:       n.config.TURNServer.MinPort,
		MaxPort:       n.config.TURNServer.MaxPort,
		// TLS configuration for TURNS
		TLSEnabled:    n.config.TURNServer.TLSEnabled,
		TLSListenAddr: n.config.TURNServer.TLSListenAddr,
		TLSCertFile:   n.config.TURNServer.TLSCertFile,
		TLSKeyFile:    n.config.TURNServer.TLSKeyFile,
	}
	// Apply defaults for any unset listen/realm/port settings.
	if turnCfg.ListenAddr == "" {
		turnCfg.ListenAddr = "0.0.0.0:3478"
	}
	if turnCfg.Realm == "" {
		turnCfg.Realm = "orama.network"
	}
	if turnCfg.MinPort == 0 {
		turnCfg.MinPort = 49152
	}
	if turnCfg.MaxPort == 0 {
		turnCfg.MaxPort = 65535
	}
	if turnCfg.TLSListenAddr == "" && turnCfg.TLSEnabled {
		turnCfg.TLSListenAddr = "0.0.0.0:443"
	}
	// Create and start TURN server
	server, err := turn.NewServer(turnCfg, n.logger.Logger)
	if err != nil {
		return err
	}
	if err := server.Start(); err != nil {
		return err
	}
	// Keep a handle for shutdown (see Node.Stop).
	n.turnServer = server
	n.logger.ComponentInfo(logging.ComponentNode, "Built-in TURN server started",
		zap.String("listen_addr", turnCfg.ListenAddr),
		zap.String("realm", turnCfg.Realm),
		zap.Bool("turns_enabled", turnCfg.TLSEnabled),
	)
	return nil
}

View File

@ -214,3 +214,9 @@ func IsServiceUnavailable(err error) bool {
errors.Is(err, ErrDatabaseUnavailable) ||
errors.Is(err, ErrCacheUnavailable)
}
// IsValidationError reports whether err is, or wraps, a *ValidationError.
func IsValidationError(err error) bool {
	var target *ValidationError
	return errors.As(err, &target)
}

239
pkg/serverless/triggers.go Normal file
View File

@ -0,0 +1,239 @@
package serverless
import (
"context"
"fmt"
"strings"
"time"
"github.com/DeBrosOfficial/network/pkg/rqlite"
"github.com/google/uuid"
"go.uber.org/zap"
)
// Ensure DBTriggerManager implements TriggerManager interface.
var _ TriggerManager = (*DBTriggerManager)(nil)

// DBTriggerManager manages function triggers using the RQLite database as
// the backing store.
type DBTriggerManager struct {
	db     rqlite.Client // query/exec interface to the cluster
	logger *zap.Logger
}
// NewDBTriggerManager creates a new trigger manager backed by db.
func NewDBTriggerManager(db rqlite.Client, logger *zap.Logger) *DBTriggerManager {
	return &DBTriggerManager{
		db:     db,
		logger: logger,
	}
}
// pubsubTriggerRow represents a row from the function_pubsub_triggers table.
type pubsubTriggerRow struct {
	ID         string    `db:"id"`          // "trig_"-prefixed identifier
	FunctionID string    `db:"function_id"` // owning function
	Topic      string    `db:"topic"`       // pubsub topic that fires the function
	Enabled    bool      `db:"enabled"`
	CreatedAt  time.Time `db:"created_at"`
}
// AddPubSubTrigger adds a pubsub trigger to a function.
//
// It validates both identifiers, rejects a duplicate enabled trigger on the
// same (function, topic) pair, and persists a new row. The generated trigger
// ID is discarded; callers that need it should use AddPubSubTriggerWithID.
func (m *DBTriggerManager) AddPubSubTrigger(ctx context.Context, functionID, topic string) error {
	// Delegate to the ID-returning variant so validation, duplicate
	// detection, insertion, and logging live in exactly one place
	// (previously this body was a verbatim copy of AddPubSubTriggerWithID).
	_, err := m.AddPubSubTriggerWithID(ctx, functionID, topic)
	return err
}
// AddPubSubTriggerWithID adds a pubsub trigger and returns the trigger ID.
//
// Validation failures surface as *ValidationError; storage failures are
// wrapped. NOTE(review): the duplicate check and the insert are separate
// statements, so concurrent calls could race — confirm the table carries a
// uniqueness constraint on (function_id, topic).
func (m *DBTriggerManager) AddPubSubTriggerWithID(ctx context.Context, functionID, topic string) (string, error) {
	functionID, topic = strings.TrimSpace(functionID), strings.TrimSpace(topic)
	switch {
	case functionID == "":
		return "", &ValidationError{Field: "function_id", Message: "cannot be empty"}
	case topic == "":
		return "", &ValidationError{Field: "topic", Message: "cannot be empty"}
	}

	// Reject a second enabled trigger on the same (function, topic) pair.
	checkQuery := `SELECT id FROM function_pubsub_triggers WHERE function_id = ? AND topic = ? AND enabled = TRUE`
	var existing []pubsubTriggerRow
	if err := m.db.Query(ctx, &existing, checkQuery, functionID, topic); err != nil {
		return "", fmt.Errorf("failed to check existing trigger: %w", err)
	}
	if len(existing) != 0 {
		return "", &ValidationError{Field: "topic", Message: "trigger already exists for this topic"}
	}

	// Mint a short, prefixed identifier for the new row.
	triggerID := "trig_" + uuid.New().String()[:8]
	insertQuery := `INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) VALUES (?, ?, ?, TRUE, ?)`
	if _, err := m.db.Exec(ctx, insertQuery, triggerID, functionID, topic, time.Now()); err != nil {
		return "", fmt.Errorf("failed to create pubsub trigger: %w", err)
	}

	m.logger.Info("PubSub trigger created",
		zap.String("trigger_id", triggerID),
		zap.String("function_id", functionID),
		zap.String("topic", topic),
	)
	return triggerID, nil
}
// AddCronTrigger adds a cron-based trigger to a function.
//
// Not yet implemented: always returns an error and ignores both arguments.
func (m *DBTriggerManager) AddCronTrigger(ctx context.Context, functionID, cronExpr string) error {
	// TODO: Implement cron trigger support
	return fmt.Errorf("cron triggers not yet implemented")
}
// AddDBTrigger adds a database trigger to a function.
//
// Not yet implemented: always returns an error and ignores all arguments.
func (m *DBTriggerManager) AddDBTrigger(ctx context.Context, functionID, tableName string, operation DBOperation, condition string) error {
	// TODO: Implement database trigger support
	return fmt.Errorf("database triggers not yet implemented")
}
// ScheduleOnce schedules a one-time execution.
//
// Not yet implemented: always returns ("", error) and ignores all arguments.
func (m *DBTriggerManager) ScheduleOnce(ctx context.Context, functionID string, runAt time.Time, payload []byte) (string, error) {
	// TODO: Implement one-time timer support
	return "", fmt.Errorf("one-time timers not yet implemented")
}
// RemoveTrigger removes a trigger by ID.
//
// Returns *ValidationError for a blank ID, ErrTriggerNotFound when no row
// matched, and a wrapped storage error otherwise.
func (m *DBTriggerManager) RemoveTrigger(ctx context.Context, triggerID string) error {
	triggerID = strings.TrimSpace(triggerID)
	if triggerID == "" {
		return &ValidationError{Field: "trigger_id", Message: "cannot be empty"}
	}
	// Only pubsub triggers exist today; extend the delete when cron/DB
	// trigger tables are added.
	query := `DELETE FROM function_pubsub_triggers WHERE id = ?`
	result, err := m.db.Exec(ctx, query, triggerID)
	if err != nil {
		return fmt.Errorf("failed to remove trigger: %w", err)
	}
	// Fix: the RowsAffected error was previously discarded with `_`, which
	// could misreport a driver failure as a successful delete.
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to determine rows affected: %w", err)
	}
	if rowsAffected == 0 {
		return ErrTriggerNotFound
	}
	m.logger.Info("Trigger removed", zap.String("trigger_id", triggerID))
	return nil
}
// ListPubSubTriggers returns all pubsub triggers for a function.
// Only rows with enabled = TRUE are included; an empty (non-nil) slice is
// returned when the function has no triggers.
func (m *DBTriggerManager) ListPubSubTriggers(ctx context.Context, functionID string) ([]PubSubTrigger, error) {
	if functionID = strings.TrimSpace(functionID); functionID == "" {
		return nil, &ValidationError{Field: "function_id", Message: "cannot be empty"}
	}
	const query = `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE function_id = ? AND enabled = TRUE`
	var rows []pubsubTriggerRow
	if err := m.db.Query(ctx, &rows, query, functionID); err != nil {
		return nil, fmt.Errorf("failed to list pubsub triggers: %w", err)
	}
	// Map storage rows onto the public trigger type.
	triggers := make([]PubSubTrigger, 0, len(rows))
	for _, r := range rows {
		triggers = append(triggers, PubSubTrigger{
			ID:         r.ID,
			FunctionID: r.FunctionID,
			Topic:      r.Topic,
			Enabled:    r.Enabled,
		})
	}
	return triggers, nil
}
// GetTriggersByTopic returns all enabled triggers for a specific topic.
// Used to fan a pubsub message out to every function subscribed to it.
func (m *DBTriggerManager) GetTriggersByTopic(ctx context.Context, topic string) ([]PubSubTrigger, error) {
	if topic = strings.TrimSpace(topic); topic == "" {
		return nil, &ValidationError{Field: "topic", Message: "cannot be empty"}
	}
	const query = `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE topic = ? AND enabled = TRUE`
	var rows []pubsubTriggerRow
	if err := m.db.Query(ctx, &rows, query, topic); err != nil {
		return nil, fmt.Errorf("failed to get triggers by topic: %w", err)
	}
	// Map storage rows onto the public trigger type.
	out := make([]PubSubTrigger, 0, len(rows))
	for _, r := range rows {
		out = append(out, PubSubTrigger{
			ID:         r.ID,
			FunctionID: r.FunctionID,
			Topic:      r.Topic,
			Enabled:    r.Enabled,
		})
	}
	return out, nil
}
// GetPubSubTrigger returns a specific pubsub trigger by ID.
// Returns ErrTriggerNotFound when no row matches (enabled or not).
func (m *DBTriggerManager) GetPubSubTrigger(ctx context.Context, triggerID string) (*PubSubTrigger, error) {
	if triggerID = strings.TrimSpace(triggerID); triggerID == "" {
		return nil, &ValidationError{Field: "trigger_id", Message: "cannot be empty"}
	}
	const query = `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE id = ?`
	var rows []pubsubTriggerRow
	if err := m.db.Query(ctx, &rows, query, triggerID); err != nil {
		return nil, fmt.Errorf("failed to get trigger: %w", err)
	}
	if len(rows) == 0 {
		return nil, ErrTriggerNotFound
	}
	first := rows[0]
	trigger := &PubSubTrigger{
		ID:         first.ID,
		FunctionID: first.FunctionID,
		Topic:      first.Topic,
		Enabled:    first.Enabled,
	}
	return trigger, nil
}

View File

@ -131,6 +131,12 @@ type TriggerManager interface {
// RemoveTrigger removes a trigger by ID.
RemoveTrigger(ctx context.Context, triggerID string) error
// ListPubSubTriggers returns all pubsub triggers for a function.
ListPubSubTriggers(ctx context.Context, functionID string) ([]PubSubTrigger, error)
// GetTriggersByTopic returns all enabled triggers for a specific topic.
GetTriggersByTopic(ctx context.Context, topic string) ([]PubSubTrigger, error)
}
// JobManager manages background job execution.

343
pkg/turn/server.go Normal file
View File

@ -0,0 +1,343 @@
// Package turn provides a built-in TURN/STUN server using Pion.
package turn
import (
"crypto/hmac"
"crypto/sha1"
"crypto/tls"
"encoding/base64"
"fmt"
"net"
"strconv"
"sync"
"time"
"github.com/pion/turn/v4"
"go.uber.org/zap"
)
// Config contains TURN server configuration
type Config struct {
	// Enabled enables the built-in TURN server
	Enabled bool `yaml:"enabled"`
	// ListenAddr is the UDP address to listen on (e.g., "0.0.0.0:3478")
	ListenAddr string `yaml:"listen_addr"`
	// PublicIP is the public IP address to advertise for relay
	// If empty, will try to auto-detect
	PublicIP string `yaml:"public_ip"`
	// Realm is the TURN realm (e.g., "orama.network")
	Realm string `yaml:"realm"`
	// SharedSecret is the secret for HMAC-SHA1 credential generation
	// Should match the gateway's TURN_SHARED_SECRET
	SharedSecret string `yaml:"shared_secret"`
	// CredentialTTL is the lifetime of generated credentials
	// NOTE(review): not read by this server's auth path (expiry comes from
	// the timestamp embedded in the username); presumably consumed by the
	// credential issuer — confirm before relying on it here.
	CredentialTTL time.Duration `yaml:"credential_ttl"`
	// MinPort and MaxPort define the relay port range
	MinPort uint16 `yaml:"min_port"`
	MaxPort uint16 `yaml:"max_port"`
	// TLS Configuration for TURNS (TURN over TLS)
	// TLSEnabled enables TURNS listener on TLSListenAddr
	TLSEnabled bool `yaml:"tls_enabled"`
	// TLSListenAddr is the TCP/TLS address to listen on (e.g., "0.0.0.0:443")
	TLSListenAddr string `yaml:"tls_listen_addr"`
	// TLSCertFile is the path to the TLS certificate file
	TLSCertFile string `yaml:"tls_cert_file"`
	// TLSKeyFile is the path to the TLS private key file
	TLSKeyFile string `yaml:"tls_key_file"`
}
// DefaultConfig returns a default TURN server configuration: disabled,
// UDP on 0.0.0.0:3478, realm "orama.network", 24h credential lifetime,
// and the conventional ephemeral relay port range 49152-65535.
func DefaultConfig() *Config {
	cfg := &Config{
		Enabled:       false,
		ListenAddr:    "0.0.0.0:3478",
		Realm:         "orama.network",
		CredentialTTL: 24 * time.Hour,
		MinPort:       49152,
		MaxPort:       65535,
	}
	return cfg
}
// Server is a built-in TURN/STUN server
type Server struct {
	config      *Config
	logger      *zap.Logger
	turnServer  *turn.Server // underlying pion server; nil until Start succeeds
	conn        net.PacketConn // UDP listener
	tlsListener net.Listener // TLS listener for TURNS
	mu          sync.RWMutex // guards running and the listener/server fields during Start/Stop
	running     bool
}
// NewServer creates a new TURN server. A nil config is replaced with
// DefaultConfig(); the error return is reserved for future validation and
// is currently always nil.
func NewServer(config *Config, logger *zap.Logger) (*Server, error) {
	cfg := config
	if cfg == nil {
		cfg = DefaultConfig()
	}
	srv := &Server{config: cfg, logger: logger}
	return srv, nil
}
// Start starts the TURN server.
//
// It is idempotent: calling Start on an already-running server is a no-op,
// and a disabled server (config.Enabled == false) just logs and returns nil.
// On any error every listener opened so far is closed and the corresponding
// fields cleared, so a failed Start leaves the server fully stopped.
func (s *Server) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.running {
		return nil
	}
	if !s.config.Enabled {
		s.logger.Info("TURN server disabled")
		return nil
	}
	// The shared secret is mandatory: authHandler derives every password from it.
	if s.config.SharedSecret == "" {
		return fmt.Errorf("TURN shared secret is required")
	}
	// Create UDP listener
	conn, err := net.ListenPacket("udp4", s.config.ListenAddr)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", s.config.ListenAddr, err)
	}
	s.conn = conn
	// Determine the public IP to advertise as the relay address.
	publicIP := s.config.PublicIP
	if publicIP == "" {
		// Try to auto-detect
		publicIP, err = getPublicIP()
		if err != nil {
			s.logger.Warn("Failed to auto-detect public IP, using listener address", zap.Error(err))
			host, _, _ := net.SplitHostPort(s.config.ListenAddr)
			if host == "0.0.0.0" || host == "" {
				host = "127.0.0.1"
			}
			publicIP = host
		}
	}
	relayIP := net.ParseIP(publicIP)
	if relayIP == nil {
		// Fix: this error path previously leaked the UDP listener.
		conn.Close()
		s.conn = nil
		return fmt.Errorf("invalid public IP: %s", publicIP)
	}
	s.logger.Info("Starting TURN server",
		zap.String("listen_addr", s.config.ListenAddr),
		zap.String("public_ip", publicIP),
		zap.String("realm", s.config.Realm),
		zap.Uint16("min_port", s.config.MinPort),
		zap.Uint16("max_port", s.config.MaxPort),
		zap.Bool("tls_enabled", s.config.TLSEnabled),
	)
	// Prepare listener configs for TLS (TURNS)
	var listenerConfigs []turn.ListenerConfig
	if s.config.TLSEnabled && s.config.TLSCertFile != "" && s.config.TLSKeyFile != "" {
		// Load TLS certificate
		cert, err := tls.LoadX509KeyPair(s.config.TLSCertFile, s.config.TLSKeyFile)
		if err != nil {
			conn.Close()
			s.conn = nil
			return fmt.Errorf("failed to load TLS certificate: %w", err)
		}
		tlsConfig := &tls.Config{
			Certificates: []tls.Certificate{cert},
			MinVersion:   tls.VersionTLS12,
		}
		// Determine TLS listen address
		tlsListenAddr := s.config.TLSListenAddr
		if tlsListenAddr == "" {
			tlsListenAddr = "0.0.0.0:443"
		}
		// Create TLS listener
		tlsListener, err := tls.Listen("tcp", tlsListenAddr, tlsConfig)
		if err != nil {
			conn.Close()
			s.conn = nil
			return fmt.Errorf("failed to start TLS listener on %s: %w", tlsListenAddr, err)
		}
		s.tlsListener = tlsListener
		listenerConfigs = append(listenerConfigs, turn.ListenerConfig{
			Listener: tlsListener,
			RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{
				RelayAddress: relayIP,
				Address:      "0.0.0.0",
				MinPort:      s.config.MinPort,
				MaxPort:      s.config.MaxPort,
			},
		})
		s.logger.Info("TURNS (TLS) listener started",
			zap.String("tls_addr", tlsListenAddr),
		)
	}
	// Create TURN server with HMAC-SHA1 auth
	turnServer, err := turn.NewServer(turn.ServerConfig{
		Realm: s.config.Realm,
		AuthHandler: func(username, realm string, srcAddr net.Addr) ([]byte, bool) {
			return s.authHandler(username, realm, srcAddr)
		},
		PacketConnConfigs: []turn.PacketConnConfig{
			{
				PacketConn: conn,
				RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{
					RelayAddress: relayIP,
					Address:      "0.0.0.0",
					MinPort:      s.config.MinPort,
					MaxPort:      s.config.MaxPort,
				},
			},
		},
		ListenerConfigs: listenerConfigs,
	})
	if err != nil {
		// Release everything opened above; also clear the stale fields so a
		// later Start begins from a clean state.
		conn.Close()
		s.conn = nil
		if s.tlsListener != nil {
			s.tlsListener.Close()
			s.tlsListener = nil
		}
		return fmt.Errorf("failed to create TURN server: %w", err)
	}
	s.turnServer = turnServer
	s.running = true
	s.logger.Info("TURN server started successfully",
		zap.String("addr", s.config.ListenAddr),
		zap.String("realm", s.config.Realm),
		zap.Bool("turns_enabled", s.config.TLSEnabled),
	)
	return nil
}
// authHandler validates HMAC-SHA1 credentials (coturn-compatible format)
// Username format: timestamp:userID (e.g., "1234567890:user123"), where the
// timestamp is the Unix time at which the credential expires. The password
// is HMAC-SHA1(SharedSecret, username), base64-encoded, matching the
// gateway's generateTURNCredentials function.
func (s *Server) authHandler(username, realm string, srcAddr net.Addr) ([]byte, bool) {
	// Locate the separator between the expiry timestamp and the user ID.
	sep := -1
	for i, c := range username {
		if c == ':' {
			sep = i
			break
		}
	}
	// Fix: previously a username without a "timestamp:" prefix left the
	// timestamp at 0 and skipped the expiry check entirely, so such a
	// credential never expired. Reject malformed usernames outright.
	if sep <= 0 {
		s.logger.Debug("Invalid timestamp in username", zap.String("username", username))
		return nil, false
	}
	timestamp, err := strconv.ParseInt(username[:sep], 10, 64)
	if err != nil {
		s.logger.Debug("Invalid timestamp in username", zap.String("username", username))
		return nil, false
	}
	// Check if credential has expired
	now := time.Now().Unix()
	if timestamp < now {
		s.logger.Debug("Credential expired",
			zap.String("username", username),
			zap.Int64("expired_at", timestamp),
			zap.Int64("now", now),
		)
		return nil, false
	}
	// Generate expected password using HMAC-SHA1
	// This matches the gateway's generateTURNCredentials function
	h := hmac.New(sha1.New, []byte(s.config.SharedSecret))
	h.Write([]byte(username))
	password := base64.StdEncoding.EncodeToString(h.Sum(nil))
	s.logger.Debug("TURN auth request",
		zap.String("username", username),
		zap.String("realm", realm),
		zap.String("src_addr", srcAddr.String()),
	)
	return turn.GenerateAuthKey(username, realm, password), true
}
// Stop stops the TURN server. It is a no-op when the server is not running
// and always returns nil; close errors are logged rather than propagated.
func (s *Server) Stop() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.running {
		return nil
	}
	s.logger.Info("Stopping TURN server")
	// Shut down the pion server first so it stops using the listeners.
	if srv := s.turnServer; srv != nil {
		if err := srv.Close(); err != nil {
			s.logger.Warn("Error closing TURN server", zap.Error(err))
		}
		s.turnServer = nil
	}
	// Then release the UDP socket and, if present, the TURNS listener.
	if c := s.conn; c != nil {
		c.Close()
		s.conn = nil
	}
	if l := s.tlsListener; l != nil {
		l.Close()
		s.tlsListener = nil
	}
	s.running = false
	s.logger.Info("TURN server stopped")
	return nil
}
// IsRunning reports whether the server has been started and not yet stopped.
func (s *Server) IsRunning() bool {
	s.mu.RLock()
	running := s.running
	s.mu.RUnlock()
	return running
}
// GetListenAddr returns the listen address
// (the configured value, not the resolved socket address).
func (s *Server) GetListenAddr() string {
	return s.config.ListenAddr
}
// GetPublicAddr returns the address clients should use to reach this
// server: the configured public IP paired with the listen port, or the raw
// listen address when no public IP is set (or the port cannot be parsed).
func (s *Server) GetPublicAddr() string {
	if s.config.PublicIP == "" {
		return s.config.ListenAddr
	}
	_, port, err := net.SplitHostPort(s.config.ListenAddr)
	if err != nil {
		// Fix: the SplitHostPort error was ignored, so a malformed listen
		// address yielded "publicip:" with an empty port. Fall back to the
		// listen address unchanged instead.
		return s.config.ListenAddr
	}
	return net.JoinHostPort(s.config.PublicIP, port)
}
// getPublicIP tries to determine the public IP address
func getPublicIP() (string, error) {
// Try to get outbound IP by connecting to a public address
conn, err := net.Dial("udp4", "8.8.8.8:80")
if err != nil {
return "", err
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.UDPAddr)
return localAddr.IP.String(), nil
}