mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-17 09:36:56 +00:00
initial
This commit is contained in:
parent
ade6241357
commit
ea48a21ae4
@ -73,6 +73,27 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
||||
}
|
||||
|
||||
// Load YAML
|
||||
type yamlICEServer struct {
|
||||
URLs []string `yaml:"urls"`
|
||||
Username string `yaml:"username,omitempty"`
|
||||
Credential string `yaml:"credential,omitempty"`
|
||||
}
|
||||
|
||||
type yamlTURN struct {
|
||||
SharedSecret string `yaml:"shared_secret"`
|
||||
TTL string `yaml:"ttl"`
|
||||
ExternalHost string `yaml:"external_host"`
|
||||
STUNURLs []string `yaml:"stun_urls"`
|
||||
TURNURLs []string `yaml:"turn_urls"`
|
||||
}
|
||||
|
||||
type yamlSFU struct {
|
||||
Enabled bool `yaml:"enabled"`
|
||||
MaxParticipants int `yaml:"max_participants"`
|
||||
MediaTimeout string `yaml:"media_timeout"`
|
||||
ICEServers []yamlICEServer `yaml:"ice_servers"`
|
||||
}
|
||||
|
||||
type yamlCfg struct {
|
||||
ListenAddr string `yaml:"listen_addr"`
|
||||
ClientNamespace string `yaml:"client_namespace"`
|
||||
@ -87,6 +108,8 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
||||
IPFSAPIURL string `yaml:"ipfs_api_url"`
|
||||
IPFSTimeout string `yaml:"ipfs_timeout"`
|
||||
IPFSReplicationFactor int `yaml:"ipfs_replication_factor"`
|
||||
TURN yamlTURN `yaml:"turn"`
|
||||
SFU yamlSFU `yaml:"sfu"`
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(configPath)
|
||||
@ -191,6 +214,64 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
||||
cfg.IPFSReplicationFactor = y.IPFSReplicationFactor
|
||||
}
|
||||
|
||||
// TURN configuration
|
||||
if y.TURN.SharedSecret != "" || len(y.TURN.STUNURLs) > 0 || len(y.TURN.TURNURLs) > 0 {
|
||||
turnCfg := &config.TURNConfig{
|
||||
SharedSecret: y.TURN.SharedSecret,
|
||||
ExternalHost: y.TURN.ExternalHost,
|
||||
STUNURLs: y.TURN.STUNURLs,
|
||||
TURNURLs: y.TURN.TURNURLs,
|
||||
}
|
||||
// Check for environment variable overrides
|
||||
if envSecret := os.Getenv("TURN_SHARED_SECRET"); envSecret != "" {
|
||||
turnCfg.SharedSecret = envSecret
|
||||
}
|
||||
if envHost := os.Getenv("TURN_EXTERNAL_HOST"); envHost != "" {
|
||||
turnCfg.ExternalHost = envHost
|
||||
}
|
||||
if v := strings.TrimSpace(y.TURN.TTL); v != "" {
|
||||
if parsed, err := time.ParseDuration(v); err == nil {
|
||||
turnCfg.TTL = parsed
|
||||
} else {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "invalid turn.ttl, using default", zap.String("value", v), zap.Error(err))
|
||||
}
|
||||
}
|
||||
cfg.TURN = turnCfg
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "TURN configuration loaded",
|
||||
zap.Int("stun_urls", len(turnCfg.STUNURLs)),
|
||||
zap.Int("turn_urls", len(turnCfg.TURNURLs)),
|
||||
zap.String("external_host", turnCfg.ExternalHost),
|
||||
)
|
||||
}
|
||||
|
||||
// SFU configuration
|
||||
if y.SFU.Enabled {
|
||||
sfuCfg := &config.SFUConfig{
|
||||
Enabled: true,
|
||||
MaxParticipants: y.SFU.MaxParticipants,
|
||||
}
|
||||
if v := strings.TrimSpace(y.SFU.MediaTimeout); v != "" {
|
||||
if parsed, err := time.ParseDuration(v); err == nil {
|
||||
sfuCfg.MediaTimeout = parsed
|
||||
} else {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "invalid sfu.media_timeout, using default", zap.String("value", v), zap.Error(err))
|
||||
}
|
||||
}
|
||||
// Parse ICE servers
|
||||
for _, iceServer := range y.SFU.ICEServers {
|
||||
sfuCfg.ICEServers = append(sfuCfg.ICEServers, config.ICEServerConfig{
|
||||
URLs: iceServer.URLs,
|
||||
Username: iceServer.Username,
|
||||
Credential: iceServer.Credential,
|
||||
})
|
||||
}
|
||||
cfg.SFU = sfuCfg
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "SFU configuration loaded",
|
||||
zap.Int("max_participants", sfuCfg.MaxParticipants),
|
||||
zap.Int("ice_servers", len(sfuCfg.ICEServers)),
|
||||
)
|
||||
}
|
||||
|
||||
// Validate configuration
|
||||
if errs := cfg.ValidateConfig(); len(errs) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs))
|
||||
|
||||
8
go.mod
8
go.mod
@ -18,6 +18,10 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.32
|
||||
github.com/multiformats/go-multiaddr v0.15.0
|
||||
github.com/olric-data/olric v0.7.0
|
||||
github.com/pion/interceptor v0.1.37
|
||||
github.com/pion/rtcp v1.2.15
|
||||
github.com/pion/turn/v4 v4.0.0
|
||||
github.com/pion/webrtc/v4 v4.0.10
|
||||
github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8
|
||||
github.com/tetratelabs/wazero v1.11.0
|
||||
go.uber.org/zap v1.27.0
|
||||
@ -113,11 +117,9 @@ require (
|
||||
github.com/pion/dtls/v2 v2.2.12 // indirect
|
||||
github.com/pion/dtls/v3 v3.0.4 // indirect
|
||||
github.com/pion/ice/v4 v4.0.8 // indirect
|
||||
github.com/pion/interceptor v0.1.37 // indirect
|
||||
github.com/pion/logging v0.2.3 // indirect
|
||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||
github.com/pion/randutil v0.1.0 // indirect
|
||||
github.com/pion/rtcp v1.2.15 // indirect
|
||||
github.com/pion/rtp v1.8.11 // indirect
|
||||
github.com/pion/sctp v1.8.37 // indirect
|
||||
github.com/pion/sdp/v3 v3.0.10 // indirect
|
||||
@ -126,8 +128,6 @@ require (
|
||||
github.com/pion/stun/v3 v3.0.0 // indirect
|
||||
github.com/pion/transport/v2 v2.2.10 // indirect
|
||||
github.com/pion/transport/v3 v3.0.7 // indirect
|
||||
github.com/pion/turn/v4 v4.0.0 // indirect
|
||||
github.com/pion/webrtc/v4 v4.0.10 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/prometheus/client_golang v1.22.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
|
||||
@ -3,7 +3,6 @@ package config
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/config/validate"
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
@ -15,69 +14,248 @@ type Config struct {
|
||||
Security SecurityConfig `yaml:"security"`
|
||||
Logging LoggingConfig `yaml:"logging"`
|
||||
HTTPGateway HTTPGatewayConfig `yaml:"http_gateway"`
|
||||
TURNServer TURNServerConfig `yaml:"turn_server"` // Built-in TURN server
|
||||
}
|
||||
|
||||
// ValidationError represents a single validation error with context.
|
||||
// This is exported from the validate subpackage for backward compatibility.
|
||||
type ValidationError = validate.ValidationError
|
||||
|
||||
// ValidateSwarmKey validates that a swarm key is 64 hex characters.
|
||||
// This is exported from the validate subpackage for backward compatibility.
|
||||
func ValidateSwarmKey(key string) error {
|
||||
return validate.ValidateSwarmKey(key)
|
||||
// NodeConfig contains node-specific configuration
|
||||
type NodeConfig struct {
|
||||
ID string `yaml:"id"` // Auto-generated if empty
|
||||
ListenAddresses []string `yaml:"listen_addresses"` // LibP2P listen addresses
|
||||
DataDir string `yaml:"data_dir"` // Data directory
|
||||
MaxConnections int `yaml:"max_connections"` // Maximum peer connections
|
||||
Domain string `yaml:"domain"` // Domain for this node (e.g., node-1.orama.network)
|
||||
}
|
||||
|
||||
// Validate performs comprehensive validation of the entire config.
|
||||
// It aggregates all errors and returns them, allowing the caller to print all issues at once.
|
||||
func (c *Config) Validate() []error {
|
||||
var errs []error
|
||||
// DatabaseConfig contains database-related configuration
|
||||
type DatabaseConfig struct {
|
||||
DataDir string `yaml:"data_dir"`
|
||||
ReplicationFactor int `yaml:"replication_factor"`
|
||||
ShardCount int `yaml:"shard_count"`
|
||||
MaxDatabaseSize int64 `yaml:"max_database_size"` // In bytes
|
||||
BackupInterval time.Duration `yaml:"backup_interval"`
|
||||
|
||||
// Validate node config
|
||||
errs = append(errs, validate.ValidateNode(validate.NodeConfig{
|
||||
ID: c.Node.ID,
|
||||
ListenAddresses: c.Node.ListenAddresses,
|
||||
DataDir: c.Node.DataDir,
|
||||
MaxConnections: c.Node.MaxConnections,
|
||||
})...)
|
||||
// RQLite-specific configuration
|
||||
RQLitePort int `yaml:"rqlite_port"` // RQLite HTTP API port
|
||||
RQLiteRaftPort int `yaml:"rqlite_raft_port"` // RQLite Raft consensus port
|
||||
RQLiteJoinAddress string `yaml:"rqlite_join_address"` // Address to join RQLite cluster
|
||||
|
||||
// Validate database config
|
||||
errs = append(errs, validate.ValidateDatabase(validate.DatabaseConfig{
|
||||
DataDir: c.Database.DataDir,
|
||||
ReplicationFactor: c.Database.ReplicationFactor,
|
||||
ShardCount: c.Database.ShardCount,
|
||||
MaxDatabaseSize: c.Database.MaxDatabaseSize,
|
||||
RQLitePort: c.Database.RQLitePort,
|
||||
RQLiteRaftPort: c.Database.RQLiteRaftPort,
|
||||
RQLiteJoinAddress: c.Database.RQLiteJoinAddress,
|
||||
ClusterSyncInterval: c.Database.ClusterSyncInterval,
|
||||
PeerInactivityLimit: c.Database.PeerInactivityLimit,
|
||||
MinClusterSize: c.Database.MinClusterSize,
|
||||
})...)
|
||||
// RQLite node-to-node TLS encryption (for inter-node Raft communication)
|
||||
// See: https://rqlite.io/docs/guides/security/#encrypting-node-to-node-communication
|
||||
NodeCert string `yaml:"node_cert"` // Path to X.509 certificate for node-to-node communication
|
||||
NodeKey string `yaml:"node_key"` // Path to X.509 private key for node-to-node communication
|
||||
NodeCACert string `yaml:"node_ca_cert"` // Path to CA certificate (optional, uses system CA if not set)
|
||||
NodeNoVerify bool `yaml:"node_no_verify"` // Skip certificate verification (for testing/self-signed certs)
|
||||
|
||||
// Validate discovery config
|
||||
errs = append(errs, validate.ValidateDiscovery(validate.DiscoveryConfig{
|
||||
BootstrapPeers: c.Discovery.BootstrapPeers,
|
||||
DiscoveryInterval: c.Discovery.DiscoveryInterval,
|
||||
BootstrapPort: c.Discovery.BootstrapPort,
|
||||
HttpAdvAddress: c.Discovery.HttpAdvAddress,
|
||||
RaftAdvAddress: c.Discovery.RaftAdvAddress,
|
||||
})...)
|
||||
// Dynamic discovery configuration (always enabled)
|
||||
ClusterSyncInterval time.Duration `yaml:"cluster_sync_interval"` // default: 30s
|
||||
PeerInactivityLimit time.Duration `yaml:"peer_inactivity_limit"` // default: 24h
|
||||
MinClusterSize int `yaml:"min_cluster_size"` // default: 1
|
||||
|
||||
// Validate security config
|
||||
errs = append(errs, validate.ValidateSecurity(validate.SecurityConfig{
|
||||
EnableTLS: c.Security.EnableTLS,
|
||||
PrivateKeyFile: c.Security.PrivateKeyFile,
|
||||
CertificateFile: c.Security.CertificateFile,
|
||||
})...)
|
||||
// Olric cache configuration
|
||||
OlricHTTPPort int `yaml:"olric_http_port"` // Olric HTTP API port (default: 3320)
|
||||
OlricMemberlistPort int `yaml:"olric_memberlist_port"` // Olric memberlist port (default: 3322)
|
||||
|
||||
// Validate logging config
|
||||
errs = append(errs, validate.ValidateLogging(validate.LoggingConfig{
|
||||
Level: c.Logging.Level,
|
||||
Format: c.Logging.Format,
|
||||
OutputFile: c.Logging.OutputFile,
|
||||
})...)
|
||||
// IPFS storage configuration
|
||||
IPFS IPFSConfig `yaml:"ipfs"`
|
||||
}
|
||||
|
||||
return errs
|
||||
// IPFSConfig contains IPFS storage configuration
|
||||
type IPFSConfig struct {
|
||||
// ClusterAPIURL is the IPFS Cluster HTTP API URL (e.g., "http://localhost:9094")
|
||||
// If empty, IPFS storage is disabled for this node
|
||||
ClusterAPIURL string `yaml:"cluster_api_url"`
|
||||
|
||||
// APIURL is the IPFS HTTP API URL for content retrieval (e.g., "http://localhost:5001")
|
||||
// If empty, defaults to "http://localhost:5001"
|
||||
APIURL string `yaml:"api_url"`
|
||||
|
||||
// Timeout for IPFS operations
|
||||
// If zero, defaults to 60 seconds
|
||||
Timeout time.Duration `yaml:"timeout"`
|
||||
|
||||
// ReplicationFactor is the replication factor for pinned content
|
||||
// If zero, defaults to 3
|
||||
ReplicationFactor int `yaml:"replication_factor"`
|
||||
|
||||
// EnableEncryption enables client-side encryption before upload
|
||||
// Defaults to true
|
||||
EnableEncryption bool `yaml:"enable_encryption"`
|
||||
}
|
||||
|
||||
// DiscoveryConfig contains peer discovery configuration
|
||||
type DiscoveryConfig struct {
|
||||
BootstrapPeers []string `yaml:"bootstrap_peers"` // Peer addresses to connect to
|
||||
DiscoveryInterval time.Duration `yaml:"discovery_interval"` // Discovery announcement interval
|
||||
BootstrapPort int `yaml:"bootstrap_port"` // Default port for peer discovery
|
||||
HttpAdvAddress string `yaml:"http_adv_address"` // HTTP advertisement address
|
||||
RaftAdvAddress string `yaml:"raft_adv_address"` // Raft advertisement
|
||||
NodeNamespace string `yaml:"node_namespace"` // Namespace for node identifiers
|
||||
}
|
||||
|
||||
// SecurityConfig contains security-related configuration
|
||||
type SecurityConfig struct {
|
||||
EnableTLS bool `yaml:"enable_tls"`
|
||||
PrivateKeyFile string `yaml:"private_key_file"`
|
||||
CertificateFile string `yaml:"certificate_file"`
|
||||
}
|
||||
|
||||
// LoggingConfig contains logging configuration
|
||||
type LoggingConfig struct {
|
||||
Level string `yaml:"level"` // debug, info, warn, error
|
||||
Format string `yaml:"format"` // json, console
|
||||
OutputFile string `yaml:"output_file"` // Empty for stdout
|
||||
}
|
||||
|
||||
// HTTPGatewayConfig contains HTTP reverse proxy gateway configuration
|
||||
type HTTPGatewayConfig struct {
|
||||
Enabled bool `yaml:"enabled"` // Enable HTTP gateway
|
||||
ListenAddr string `yaml:"listen_addr"` // Address to listen on (e.g., ":8080")
|
||||
NodeName string `yaml:"node_name"` // Node name for routing
|
||||
Routes map[string]RouteConfig `yaml:"routes"` // Service routes
|
||||
HTTPS HTTPSConfig `yaml:"https"` // HTTPS/TLS configuration
|
||||
SNI SNIConfig `yaml:"sni"` // SNI-based TCP routing configuration
|
||||
|
||||
// Full gateway configuration (for API, auth, pubsub)
|
||||
ClientNamespace string `yaml:"client_namespace"` // Namespace for network client
|
||||
RQLiteDSN string `yaml:"rqlite_dsn"` // RQLite database DSN
|
||||
OlricServers []string `yaml:"olric_servers"` // List of Olric server addresses
|
||||
OlricTimeout time.Duration `yaml:"olric_timeout"` // Timeout for Olric operations
|
||||
IPFSClusterAPIURL string `yaml:"ipfs_cluster_api_url"` // IPFS Cluster API URL
|
||||
IPFSAPIURL string `yaml:"ipfs_api_url"` // IPFS API URL
|
||||
IPFSTimeout time.Duration `yaml:"ipfs_timeout"` // Timeout for IPFS operations
|
||||
|
||||
// WebRTC configuration for video/audio calls
|
||||
TURN *TURNConfig `yaml:"turn"` // TURN/STUN server configuration
|
||||
SFU *SFUConfig `yaml:"sfu"` // SFU (Selective Forwarding Unit) configuration
|
||||
}
|
||||
|
||||
// HTTPSConfig contains HTTPS/TLS configuration for the gateway
|
||||
type HTTPSConfig struct {
|
||||
Enabled bool `yaml:"enabled"` // Enable HTTPS (port 443)
|
||||
Domain string `yaml:"domain"` // Primary domain (e.g., node-123.orama.network)
|
||||
AutoCert bool `yaml:"auto_cert"` // Use Let's Encrypt for automatic certificate
|
||||
UseSelfSigned bool `yaml:"use_self_signed"` // Use self-signed certificates (pre-generated)
|
||||
CertFile string `yaml:"cert_file"` // Path to certificate file (if not using auto_cert)
|
||||
KeyFile string `yaml:"key_file"` // Path to key file (if not using auto_cert)
|
||||
CacheDir string `yaml:"cache_dir"` // Directory for Let's Encrypt certificate cache
|
||||
HTTPPort int `yaml:"http_port"` // HTTP port for ACME challenge (default: 80)
|
||||
HTTPSPort int `yaml:"https_port"` // HTTPS port (default: 443)
|
||||
Email string `yaml:"email"` // Email for Let's Encrypt account
|
||||
}
|
||||
|
||||
// SNIConfig contains SNI-based TCP routing configuration for port 7001
|
||||
type SNIConfig struct {
|
||||
Enabled bool `yaml:"enabled"` // Enable SNI-based TCP routing
|
||||
ListenAddr string `yaml:"listen_addr"` // Address to listen on (e.g., ":7001")
|
||||
Routes map[string]string `yaml:"routes"` // SNI hostname -> backend address mapping
|
||||
CertFile string `yaml:"cert_file"` // Path to certificate file
|
||||
KeyFile string `yaml:"key_file"` // Path to key file
|
||||
}
|
||||
|
||||
// RouteConfig defines a single reverse proxy route
|
||||
type RouteConfig struct {
|
||||
PathPrefix string `yaml:"path_prefix"` // URL path prefix (e.g., "/rqlite/http")
|
||||
BackendURL string `yaml:"backend_url"` // Backend service URL
|
||||
Timeout time.Duration `yaml:"timeout"` // Request timeout
|
||||
WebSocket bool `yaml:"websocket"` // Support WebSocket upgrades
|
||||
}
|
||||
|
||||
// ClientConfig represents configuration for network clients
|
||||
type ClientConfig struct {
|
||||
AppName string `yaml:"app_name"`
|
||||
DatabaseName string `yaml:"database_name"`
|
||||
BootstrapPeers []string `yaml:"bootstrap_peers"`
|
||||
ConnectTimeout time.Duration `yaml:"connect_timeout"`
|
||||
RetryAttempts int `yaml:"retry_attempts"`
|
||||
}
|
||||
|
||||
// TURNConfig contains TURN/STUN server credential configuration
|
||||
type TURNConfig struct {
|
||||
// SharedSecret is the shared secret for TURN credential generation (HMAC-SHA1)
|
||||
// Should be set via TURN_SHARED_SECRET environment variable
|
||||
SharedSecret string `yaml:"shared_secret"`
|
||||
|
||||
// TTL is the time-to-live for generated credentials
|
||||
// Default: 24 hours
|
||||
TTL time.Duration `yaml:"ttl"`
|
||||
|
||||
// ExternalHost is the external hostname or IP address for STUN/TURN URLs
|
||||
// - Production: Set to your public domain (e.g., "turn.example.com")
|
||||
// - Development: Leave empty for auto-detection of LAN IP
|
||||
// Can also be set via TURN_EXTERNAL_HOST environment variable
|
||||
ExternalHost string `yaml:"external_host"`
|
||||
|
||||
// STUNURLs are the STUN server URLs to return to clients
|
||||
// Use "::" as placeholder for ExternalHost (e.g., "stun:::3478" -> "stun:turn.example.com:3478")
|
||||
// e.g., ["stun:::3478"] or ["stun:gateway.orama.com:3478"]
|
||||
STUNURLs []string `yaml:"stun_urls"`
|
||||
|
||||
// TURNURLs are the TURN server URLs to return to clients
|
||||
// Use "::" as placeholder for ExternalHost (e.g., "turn:::3478" -> "turn:turn.example.com:3478")
|
||||
// e.g., ["turn:::3478?transport=udp"] or ["turn:gateway.orama.com:3478?transport=udp"]
|
||||
TURNURLs []string `yaml:"turn_urls"`
|
||||
|
||||
// TLSEnabled indicates whether TURNS (TURN over TLS) is available
|
||||
// When true, turns:// URLs will be included in the response
|
||||
TLSEnabled bool `yaml:"tls_enabled"`
|
||||
}
|
||||
|
||||
// SFUConfig contains WebRTC SFU (Selective Forwarding Unit) configuration
|
||||
type SFUConfig struct {
|
||||
// Enabled enables the SFU service
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// MaxParticipants is the maximum number of participants per room
|
||||
// Default: 10
|
||||
MaxParticipants int `yaml:"max_participants"`
|
||||
|
||||
// MediaTimeout is the timeout for media operations
|
||||
// Default: 30 seconds
|
||||
MediaTimeout time.Duration `yaml:"media_timeout"`
|
||||
|
||||
// ICEServers are additional ICE servers for WebRTC connections
|
||||
// These are used in addition to the TURN servers from TURNConfig
|
||||
ICEServers []ICEServerConfig `yaml:"ice_servers"`
|
||||
}
|
||||
|
||||
// ICEServerConfig represents a single ICE server configuration
|
||||
type ICEServerConfig struct {
|
||||
URLs []string `yaml:"urls"`
|
||||
Username string `yaml:"username,omitempty"`
|
||||
Credential string `yaml:"credential,omitempty"`
|
||||
}
|
||||
|
||||
// TURNServerConfig contains built-in TURN server configuration
|
||||
type TURNServerConfig struct {
|
||||
// Enabled enables the built-in TURN server
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// ListenAddr is the UDP address to listen on (e.g., "0.0.0.0:3478")
|
||||
ListenAddr string `yaml:"listen_addr"`
|
||||
|
||||
// PublicIP is the public IP address to advertise for relay
|
||||
// If empty, will try to auto-detect
|
||||
PublicIP string `yaml:"public_ip"`
|
||||
|
||||
// Realm is the TURN realm (e.g., "orama.network")
|
||||
Realm string `yaml:"realm"`
|
||||
|
||||
// MinPort and MaxPort define the relay port range
|
||||
MinPort uint16 `yaml:"min_port"`
|
||||
MaxPort uint16 `yaml:"max_port"`
|
||||
|
||||
// TLS Configuration for TURNS (TURN over TLS)
|
||||
// TLSEnabled enables TURNS listener
|
||||
TLSEnabled bool `yaml:"tls_enabled"`
|
||||
|
||||
// TLSListenAddr is the TCP/TLS address to listen on (e.g., "0.0.0.0:443")
|
||||
TLSListenAddr string `yaml:"tls_listen_addr"`
|
||||
|
||||
// TLSCertFile is the path to the TLS certificate file
|
||||
TLSCertFile string `yaml:"tls_cert_file"`
|
||||
|
||||
// TLSKeyFile is the path to the TLS private key file
|
||||
TLSKeyFile string `yaml:"tls_key_file"`
|
||||
}
|
||||
|
||||
// ParseMultiaddrs converts string addresses to multiaddr objects
|
||||
|
||||
@ -47,6 +47,15 @@ logging:
|
||||
level: "info"
|
||||
format: "console"
|
||||
|
||||
# Built-in TURN server for WebRTC NAT traversal
|
||||
turn_server:
|
||||
enabled: true
|
||||
listen_addr: "0.0.0.0:3478"
|
||||
public_ip: "" # Auto-detect if empty, or set to your public IP
|
||||
realm: "orama.network"
|
||||
min_port: 49152
|
||||
max_port: 65535
|
||||
|
||||
http_gateway:
|
||||
enabled: true
|
||||
listen_addr: "{{if .EnableHTTPS}}:{{.HTTPSPort}}{{else}}:{{.UnifiedGatewayPort}}{{end}}"
|
||||
@ -86,3 +95,19 @@ http_gateway:
|
||||
|
||||
# Routes for internal service reverse proxy (kept for backwards compatibility but not used by full gateway)
|
||||
routes: {}
|
||||
|
||||
# TURN/STUN URLs returned to clients (points to built-in TURN server)
|
||||
turn:
|
||||
shared_secret: "dev-secret-12345"
|
||||
ttl: "24h"
|
||||
stun_urls:
|
||||
- "stun:::3478"
|
||||
turn_urls:
|
||||
- "turn:::3478?transport=udp"
|
||||
|
||||
# SFU (Selective Forwarding Unit) configuration for WebRTC group calls
|
||||
sfu:
|
||||
enabled: true
|
||||
max_participants: 10
|
||||
media_timeout: "30s"
|
||||
ice_servers: [] # Additional ICE servers beyond TURN config (optional)
|
||||
|
||||
@ -1,31 +1,75 @@
|
||||
// Package gateway provides the main API Gateway for the Orama Network.
|
||||
// It orchestrates traffic between clients and various backend services including
|
||||
// distributed caching (Olric), decentralized storage (IPFS), and serverless
|
||||
// WebAssembly (WASM) execution. The gateway implements robust security through
|
||||
// wallet-based cryptographic authentication and JWT lifecycle management.
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/client"
|
||||
"github.com/DeBrosOfficial/network/pkg/config"
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||
authhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/cache"
|
||||
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
|
||||
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
|
||||
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
"github.com/multiformats/go-multiaddr"
|
||||
olriclib "github.com/olric-data/olric"
|
||||
"go.uber.org/zap"
|
||||
|
||||
_ "github.com/rqlite/gorqlite/stdlib"
|
||||
)
|
||||
|
||||
const (
|
||||
olricInitMaxAttempts = 5
|
||||
olricInitInitialBackoff = 500 * time.Millisecond
|
||||
olricInitMaxBackoff = 5 * time.Second
|
||||
)
|
||||
|
||||
// Config holds configuration for the gateway server
|
||||
type Config struct {
|
||||
ListenAddr string
|
||||
ClientNamespace string
|
||||
BootstrapPeers []string
|
||||
NodePeerID string // The node's actual peer ID from its identity file
|
||||
|
||||
// Optional DSN for rqlite database/sql driver, e.g. "http://localhost:4001"
|
||||
// If empty, defaults to "http://localhost:4001".
|
||||
RQLiteDSN string
|
||||
|
||||
// HTTPS configuration
|
||||
EnableHTTPS bool // Enable HTTPS with ACME (Let's Encrypt)
|
||||
DomainName string // Domain name for HTTPS certificate
|
||||
TLSCacheDir string // Directory to cache TLS certificates (default: ~/.orama/tls-cache)
|
||||
|
||||
// Olric cache configuration
|
||||
OlricServers []string // List of Olric server addresses (e.g., ["localhost:3320"]). If empty, defaults to ["localhost:3320"]
|
||||
OlricTimeout time.Duration // Timeout for Olric operations (default: 10s)
|
||||
|
||||
// IPFS Cluster configuration
|
||||
IPFSClusterAPIURL string // IPFS Cluster HTTP API URL (e.g., "http://localhost:9094"). If empty, gateway will discover from node configs
|
||||
IPFSAPIURL string // IPFS HTTP API URL for content retrieval (e.g., "http://localhost:5001"). If empty, gateway will discover from node configs
|
||||
IPFSTimeout time.Duration // Timeout for IPFS operations (default: 60s)
|
||||
IPFSReplicationFactor int // Replication factor for pins (default: 3)
|
||||
IPFSEnableEncryption bool // Enable client-side encryption before upload (default: true, discovered from node configs)
|
||||
|
||||
// TURN/STUN configuration for WebRTC
|
||||
TURN *config.TURNConfig
|
||||
|
||||
// SFU configuration for WebRTC group calls
|
||||
SFU *config.SFUConfig
|
||||
}
|
||||
|
||||
type Gateway struct {
|
||||
logger *logging.ColoredLogger
|
||||
@ -42,29 +86,28 @@ type Gateway struct {
|
||||
// Olric cache client
|
||||
olricClient *olric.Client
|
||||
olricMu sync.RWMutex
|
||||
cacheHandlers *cache.CacheHandlers
|
||||
|
||||
// IPFS storage client
|
||||
ipfsClient ipfs.IPFSClient
|
||||
storageHandlers *storage.Handlers
|
||||
ipfsClient ipfs.IPFSClient
|
||||
|
||||
// Local pub/sub bypass for same-gateway subscribers
|
||||
localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers
|
||||
presenceMembers map[string][]PresenceMember // topicKey -> members
|
||||
mu sync.RWMutex
|
||||
presenceMu sync.RWMutex
|
||||
pubsubHandlers *pubsubhandlers.PubSubHandlers
|
||||
|
||||
// Serverless function engine
|
||||
serverlessEngine *serverless.Engine
|
||||
serverlessRegistry *serverless.Registry
|
||||
serverlessInvoker *serverless.Invoker
|
||||
serverlessWSMgr *serverless.WSManager
|
||||
serverlessHandlers *serverlesshandlers.ServerlessHandlers
|
||||
serverlessHandlers *ServerlessHandlers
|
||||
|
||||
// Authentication service
|
||||
authService *auth.Service
|
||||
authHandlers *authhandlers.Handlers
|
||||
authService *auth.Service
|
||||
|
||||
// SFU manager for WebRTC group calls
|
||||
sfuManager *SFUManager
|
||||
}
|
||||
|
||||
// localSubscriber represents a WebSocket subscriber for local message delivery
|
||||
@ -81,113 +124,359 @@ type PresenceMember struct {
|
||||
ConnID string `json:"-"` // Internal: for tracking which connection
|
||||
}
|
||||
|
||||
// authClientAdapter adapts client.NetworkClient to authhandlers.NetworkClient
|
||||
type authClientAdapter struct {
|
||||
client client.NetworkClient
|
||||
}
|
||||
|
||||
func (a *authClientAdapter) Database() authhandlers.DatabaseClient {
|
||||
return &authDatabaseAdapter{db: a.client.Database()}
|
||||
}
|
||||
|
||||
// authDatabaseAdapter adapts client.DatabaseClient to authhandlers.DatabaseClient
|
||||
type authDatabaseAdapter struct {
|
||||
db client.DatabaseClient
|
||||
}
|
||||
|
||||
func (a *authDatabaseAdapter) Query(ctx context.Context, sql string, args ...interface{}) (*authhandlers.QueryResult, error) {
|
||||
result, err := a.db.Query(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Convert client.QueryResult to authhandlers.QueryResult
|
||||
// The auth handlers expect []interface{} but client returns [][]interface{}
|
||||
convertedRows := make([]interface{}, len(result.Rows))
|
||||
for i, row := range result.Rows {
|
||||
convertedRows[i] = row
|
||||
}
|
||||
return &authhandlers.QueryResult{
|
||||
Count: int(result.Count),
|
||||
Rows: convertedRows,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// New creates and initializes a new Gateway instance.
|
||||
// It establishes all necessary service connections and dependencies.
|
||||
// New creates and initializes a new Gateway instance
|
||||
func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway dependencies...")
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Building client config...")
|
||||
|
||||
// Initialize all dependencies (network client, database, cache, storage, serverless)
|
||||
deps, err := NewDependencies(logger, cfg)
|
||||
// Build client config from gateway cfg
|
||||
cliCfg := client.DefaultClientConfig(cfg.ClientNamespace)
|
||||
if len(cfg.BootstrapPeers) > 0 {
|
||||
cliCfg.BootstrapPeers = cfg.BootstrapPeers
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Creating network client...")
|
||||
c, err := client.NewClient(cliCfg)
|
||||
if err != nil {
|
||||
logger.ComponentError(logging.ComponentGeneral, "failed to create dependencies", zap.Error(err))
|
||||
logger.ComponentError(logging.ComponentClient, "failed to create network client", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Connecting network client...")
|
||||
if err := c.Connect(); err != nil {
|
||||
logger.ComponentError(logging.ComponentClient, "failed to connect network client", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentClient, "Network client connected",
|
||||
zap.String("namespace", cliCfg.AppName),
|
||||
zap.Int("peer_count", len(cliCfg.BootstrapPeers)),
|
||||
)
|
||||
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway instance...")
|
||||
gw := &Gateway{
|
||||
logger: logger,
|
||||
cfg: cfg,
|
||||
client: deps.Client,
|
||||
nodePeerID: cfg.NodePeerID,
|
||||
startedAt: time.Now(),
|
||||
sqlDB: deps.SQLDB,
|
||||
ormClient: deps.ORMClient,
|
||||
ormHTTP: deps.ORMHTTP,
|
||||
olricClient: deps.OlricClient,
|
||||
ipfsClient: deps.IPFSClient,
|
||||
serverlessEngine: deps.ServerlessEngine,
|
||||
serverlessRegistry: deps.ServerlessRegistry,
|
||||
serverlessInvoker: deps.ServerlessInvoker,
|
||||
serverlessWSMgr: deps.ServerlessWSMgr,
|
||||
serverlessHandlers: deps.ServerlessHandlers,
|
||||
authService: deps.AuthService,
|
||||
localSubscribers: make(map[string][]*localSubscriber),
|
||||
presenceMembers: make(map[string][]PresenceMember),
|
||||
logger: logger,
|
||||
cfg: cfg,
|
||||
client: c,
|
||||
nodePeerID: cfg.NodePeerID,
|
||||
startedAt: time.Now(),
|
||||
localSubscribers: make(map[string][]*localSubscriber),
|
||||
presenceMembers: make(map[string][]PresenceMember),
|
||||
}
|
||||
|
||||
// Initialize handler instances
|
||||
gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger)
|
||||
|
||||
if deps.OlricClient != nil {
|
||||
gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient)
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Initializing RQLite ORM HTTP gateway...")
|
||||
dsn := cfg.RQLiteDSN
|
||||
if dsn == "" {
|
||||
dsn = "http://localhost:5001"
|
||||
}
|
||||
db, dbErr := sql.Open("rqlite", dsn)
|
||||
if dbErr != nil {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "failed to open rqlite sql db; http orm gateway disabled", zap.Error(dbErr))
|
||||
} else {
|
||||
// Configure connection pool with proper timeouts and limits
|
||||
db.SetMaxOpenConns(25) // Maximum number of open connections
|
||||
db.SetMaxIdleConns(5) // Maximum number of idle connections
|
||||
db.SetConnMaxLifetime(5 * time.Minute) // Maximum lifetime of a connection
|
||||
db.SetConnMaxIdleTime(2 * time.Minute) // Maximum idle time before closing
|
||||
|
||||
if deps.IPFSClient != nil {
|
||||
gw.storageHandlers = storage.New(deps.IPFSClient, logger, storage.Config{
|
||||
IPFSReplicationFactor: cfg.IPFSReplicationFactor,
|
||||
IPFSAPIURL: cfg.IPFSAPIURL,
|
||||
})
|
||||
}
|
||||
|
||||
if deps.AuthService != nil {
|
||||
// Create adapter for auth handlers to use the client
|
||||
authClientAdapter := &authClientAdapter{client: deps.Client}
|
||||
gw.authHandlers = authhandlers.NewHandlers(
|
||||
logger,
|
||||
deps.AuthService,
|
||||
authClientAdapter,
|
||||
cfg.ClientNamespace,
|
||||
gw.withInternalAuth,
|
||||
gw.sqlDB = db
|
||||
orm := rqlite.NewClient(db)
|
||||
gw.ormClient = orm
|
||||
gw.ormHTTP = rqlite.NewHTTPGateway(orm, "/v1/db")
|
||||
// Set a reasonable timeout for HTTP requests (30 seconds)
|
||||
gw.ormHTTP.Timeout = 30 * time.Second
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "RQLite ORM HTTP gateway ready",
|
||||
zap.String("dsn", dsn),
|
||||
zap.String("base_path", "/v1/db"),
|
||||
zap.Duration("timeout", gw.ormHTTP.Timeout),
|
||||
)
|
||||
}
|
||||
|
||||
// Start background Olric reconnection if initial connection failed
|
||||
if deps.OlricClient == nil {
|
||||
olricCfg := olric.Config{
|
||||
Servers: cfg.OlricServers,
|
||||
Timeout: cfg.OlricTimeout,
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Initializing Olric cache client...")
|
||||
|
||||
// Discover Olric servers dynamically from LibP2P peers if not explicitly configured
|
||||
olricServers := cfg.OlricServers
|
||||
if len(olricServers) == 0 {
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Olric servers not configured, discovering from LibP2P peers...")
|
||||
discovered := discoverOlricServers(c, logger.Logger)
|
||||
if len(discovered) > 0 {
|
||||
olricServers = discovered
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Discovered Olric servers from LibP2P peers",
|
||||
zap.Strings("servers", olricServers))
|
||||
} else {
|
||||
// Fallback to localhost for local development
|
||||
olricServers = []string{"localhost:3320"}
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "No Olric servers discovered, using localhost fallback")
|
||||
}
|
||||
if len(olricCfg.Servers) == 0 {
|
||||
olricCfg.Servers = []string{"localhost:3320"}
|
||||
}
|
||||
gw.startOlricReconnectLoop(olricCfg)
|
||||
} else {
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Using explicitly configured Olric servers",
|
||||
zap.Strings("servers", olricServers))
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed")
|
||||
olricCfg := olric.Config{
|
||||
Servers: olricServers,
|
||||
Timeout: cfg.OlricTimeout,
|
||||
}
|
||||
olricClient, olricErr := initializeOlricClientWithRetry(olricCfg, logger)
|
||||
if olricErr != nil {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize Olric cache client; cache endpoints disabled", zap.Error(olricErr))
|
||||
gw.startOlricReconnectLoop(olricCfg)
|
||||
} else {
|
||||
gw.setOlricClient(olricClient)
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client ready",
|
||||
zap.Strings("servers", olricCfg.Servers),
|
||||
zap.Duration("timeout", olricCfg.Timeout),
|
||||
)
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Initializing IPFS Cluster client...")
|
||||
|
||||
// Discover IPFS endpoints from node configs if not explicitly configured
|
||||
ipfsClusterURL := cfg.IPFSClusterAPIURL
|
||||
ipfsAPIURL := cfg.IPFSAPIURL
|
||||
ipfsTimeout := cfg.IPFSTimeout
|
||||
ipfsReplicationFactor := cfg.IPFSReplicationFactor
|
||||
ipfsEnableEncryption := cfg.IPFSEnableEncryption
|
||||
|
||||
if ipfsClusterURL == "" {
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster URL not configured, discovering from node configs...")
|
||||
discovered := discoverIPFSFromNodeConfigs(logger.Logger)
|
||||
if discovered.clusterURL != "" {
|
||||
ipfsClusterURL = discovered.clusterURL
|
||||
ipfsAPIURL = discovered.apiURL
|
||||
if discovered.timeout > 0 {
|
||||
ipfsTimeout = discovered.timeout
|
||||
}
|
||||
if discovered.replicationFactor > 0 {
|
||||
ipfsReplicationFactor = discovered.replicationFactor
|
||||
}
|
||||
ipfsEnableEncryption = discovered.enableEncryption
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Discovered IPFS endpoints from node configs",
|
||||
zap.String("cluster_url", ipfsClusterURL),
|
||||
zap.String("api_url", ipfsAPIURL),
|
||||
zap.Bool("encryption_enabled", ipfsEnableEncryption))
|
||||
} else {
|
||||
// Fallback to localhost defaults
|
||||
ipfsClusterURL = "http://localhost:9094"
|
||||
ipfsAPIURL = "http://localhost:5001"
|
||||
ipfsEnableEncryption = true // Default to true
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "No IPFS config found in node configs, using localhost defaults")
|
||||
}
|
||||
}
|
||||
|
||||
if ipfsAPIURL == "" {
|
||||
ipfsAPIURL = "http://localhost:5001"
|
||||
}
|
||||
if ipfsTimeout == 0 {
|
||||
ipfsTimeout = 60 * time.Second
|
||||
}
|
||||
if ipfsReplicationFactor == 0 {
|
||||
ipfsReplicationFactor = 3
|
||||
}
|
||||
if !cfg.IPFSEnableEncryption && !ipfsEnableEncryption {
|
||||
// Only disable if explicitly set to false in both places
|
||||
ipfsEnableEncryption = false
|
||||
} else {
|
||||
// Default to true if not explicitly disabled
|
||||
ipfsEnableEncryption = true
|
||||
}
|
||||
|
||||
ipfsCfg := ipfs.Config{
|
||||
ClusterAPIURL: ipfsClusterURL,
|
||||
Timeout: ipfsTimeout,
|
||||
}
|
||||
ipfsClient, ipfsErr := ipfs.NewClient(ipfsCfg, logger.Logger)
|
||||
if ipfsErr != nil {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize IPFS Cluster client; storage endpoints disabled", zap.Error(ipfsErr))
|
||||
} else {
|
||||
gw.ipfsClient = ipfsClient
|
||||
|
||||
// Check peer count and warn if insufficient (use background context to avoid blocking)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if peerCount, err := ipfsClient.GetPeerCount(ctx); err == nil {
|
||||
if peerCount < ipfsReplicationFactor {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "insufficient cluster peers for replication factor",
|
||||
zap.Int("peer_count", peerCount),
|
||||
zap.Int("replication_factor", ipfsReplicationFactor),
|
||||
zap.String("message", "Some pin operations may fail until more peers join the cluster"))
|
||||
} else {
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster peer count sufficient",
|
||||
zap.Int("peer_count", peerCount),
|
||||
zap.Int("replication_factor", ipfsReplicationFactor))
|
||||
}
|
||||
} else {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "failed to get cluster peer count", zap.Error(err))
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster client ready",
|
||||
zap.String("cluster_api_url", ipfsCfg.ClusterAPIURL),
|
||||
zap.String("ipfs_api_url", ipfsAPIURL),
|
||||
zap.Duration("timeout", ipfsCfg.Timeout),
|
||||
zap.Int("replication_factor", ipfsReplicationFactor),
|
||||
zap.Bool("encryption_enabled", ipfsEnableEncryption),
|
||||
)
|
||||
}
|
||||
// Store IPFS settings in gateway for use by handlers
|
||||
gw.cfg.IPFSAPIURL = ipfsAPIURL
|
||||
gw.cfg.IPFSReplicationFactor = ipfsReplicationFactor
|
||||
gw.cfg.IPFSEnableEncryption = ipfsEnableEncryption
|
||||
|
||||
// Initialize serverless function engine
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Initializing serverless function engine...")
|
||||
if gw.ormClient != nil && gw.ipfsClient != nil {
|
||||
// Create serverless registry (stores functions in RQLite + IPFS)
|
||||
registryCfg := serverless.RegistryConfig{
|
||||
IPFSAPIURL: ipfsAPIURL,
|
||||
}
|
||||
registry := serverless.NewRegistry(gw.ormClient, gw.ipfsClient, registryCfg, logger.Logger)
|
||||
gw.serverlessRegistry = registry
|
||||
|
||||
// Create WebSocket manager for function streaming
|
||||
gw.serverlessWSMgr = serverless.NewWSManager(logger.Logger)
|
||||
|
||||
// Get underlying Olric client if available
|
||||
var olricClient olriclib.Client
|
||||
if oc := gw.getOlricClient(); oc != nil {
|
||||
olricClient = oc.UnderlyingClient()
|
||||
}
|
||||
|
||||
// Create host functions provider (allows functions to call Orama services)
|
||||
// Get pubsub adapter from client for serverless functions
|
||||
var pubsubAdapter *pubsub.ClientAdapter
|
||||
if gw.client != nil {
|
||||
if concreteClient, ok := gw.client.(*client.Client); ok {
|
||||
pubsubAdapter = concreteClient.PubSubAdapter()
|
||||
if pubsubAdapter != nil {
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "pubsub adapter available for serverless functions")
|
||||
} else {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "pubsub adapter is nil - serverless pubsub will be unavailable")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hostFuncsCfg := serverless.HostFunctionsConfig{
|
||||
IPFSAPIURL: ipfsAPIURL,
|
||||
HTTPTimeout: 30 * time.Second,
|
||||
}
|
||||
hostFuncs := serverless.NewHostFunctions(
|
||||
gw.ormClient,
|
||||
olricClient,
|
||||
gw.ipfsClient,
|
||||
pubsubAdapter, // pubsub adapter for serverless functions
|
||||
gw.serverlessWSMgr,
|
||||
nil, // secrets manager - TODO: implement
|
||||
hostFuncsCfg,
|
||||
logger.Logger,
|
||||
)
|
||||
|
||||
// Create WASM engine configuration
|
||||
engineCfg := serverless.DefaultConfig()
|
||||
engineCfg.DefaultMemoryLimitMB = 128
|
||||
engineCfg.MaxMemoryLimitMB = 256
|
||||
engineCfg.DefaultTimeoutSeconds = 30
|
||||
engineCfg.MaxTimeoutSeconds = 60
|
||||
engineCfg.ModuleCacheSize = 100
|
||||
|
||||
// Create WASM engine
|
||||
engine, engineErr := serverless.NewEngine(engineCfg, registry, hostFuncs, logger.Logger, serverless.WithInvocationLogger(registry))
|
||||
if engineErr != nil {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize serverless engine; functions disabled", zap.Error(engineErr))
|
||||
} else {
|
||||
gw.serverlessEngine = engine
|
||||
|
||||
// Create invoker
|
||||
gw.serverlessInvoker = serverless.NewInvoker(engine, registry, hostFuncs, logger.Logger)
|
||||
|
||||
// Create trigger manager
|
||||
triggerManager := serverless.NewDBTriggerManager(gw.ormClient, logger.Logger)
|
||||
|
||||
// Create HTTP handlers
|
||||
gw.serverlessHandlers = NewServerlessHandlers(
|
||||
gw.serverlessInvoker,
|
||||
registry,
|
||||
gw.serverlessWSMgr,
|
||||
triggerManager,
|
||||
logger.Logger,
|
||||
)
|
||||
|
||||
// Initialize auth service
|
||||
// For now using ephemeral key, can be loaded from config later
|
||||
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
})
|
||||
authService, err := auth.NewService(logger, c, string(keyPEM), cfg.ClientNamespace)
|
||||
if err != nil {
|
||||
logger.ComponentError(logging.ComponentGeneral, "failed to initialize auth service", zap.Error(err))
|
||||
} else {
|
||||
gw.authService = authService
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready",
|
||||
zap.Int("default_memory_mb", engineCfg.DefaultMemoryLimitMB),
|
||||
zap.Int("default_timeout_sec", engineCfg.DefaultTimeoutSeconds),
|
||||
zap.Int("module_cache_size", engineCfg.ModuleCacheSize),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "serverless engine requires RQLite and IPFS; functions disabled")
|
||||
}
|
||||
|
||||
// Initialize SFU manager for WebRTC calls
|
||||
if err := gw.initializeSFUManager(); err != nil {
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize SFU manager; WebRTC calls disabled", zap.Error(err))
|
||||
}
|
||||
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed, returning...")
|
||||
return gw, nil
|
||||
}
|
||||
|
||||
// withInternalAuth creates a context for internal gateway operations that bypass authentication
|
||||
func (g *Gateway) withInternalAuth(ctx context.Context) context.Context {
|
||||
return client.WithInternalAuth(ctx)
|
||||
}
|
||||
|
||||
// Close disconnects the gateway client
|
||||
func (g *Gateway) Close() {
|
||||
// Close SFU manager first
|
||||
if g.sfuManager != nil {
|
||||
if err := g.sfuManager.Close(); err != nil {
|
||||
g.logger.ComponentWarn(logging.ComponentGeneral, "error during SFU manager close", zap.Error(err))
|
||||
}
|
||||
}
|
||||
// Close serverless engine
|
||||
if g.serverlessEngine != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
if err := g.serverlessEngine.Close(ctx); err != nil {
|
||||
g.logger.ComponentWarn(logging.ComponentGeneral, "error during serverless engine close", zap.Error(err))
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
if g.client != nil {
|
||||
if err := g.client.Disconnect(); err != nil {
|
||||
g.logger.ComponentWarn(logging.ComponentClient, "error during client disconnect", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if g.sqlDB != nil {
|
||||
_ = g.sqlDB.Close()
|
||||
}
|
||||
if client := g.getOlricClient(); client != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := client.Close(ctx); err != nil {
|
||||
g.logger.ComponentWarn(logging.ComponentGeneral, "error during Olric client close", zap.Error(err))
|
||||
}
|
||||
}
|
||||
if g.ipfsClient != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := g.ipfsClient.Close(ctx); err != nil {
|
||||
g.logger.ComponentWarn(logging.ComponentGeneral, "error during IPFS client close", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getLocalSubscribers returns all local subscribers for a given topic and namespace
|
||||
func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscriber {
|
||||
topicKey := namespace + "." + topic
|
||||
@ -197,32 +486,23 @@ func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscribe
|
||||
return nil
|
||||
}
|
||||
|
||||
// setOlricClient atomically sets the Olric client and reinitializes cache handlers.
|
||||
func (g *Gateway) setOlricClient(client *olric.Client) {
|
||||
g.olricMu.Lock()
|
||||
defer g.olricMu.Unlock()
|
||||
g.olricClient = client
|
||||
if client != nil {
|
||||
g.cacheHandlers = cache.NewCacheHandlers(g.logger, client)
|
||||
}
|
||||
}
|
||||
|
||||
// getOlricClient atomically retrieves the current Olric client.
|
||||
func (g *Gateway) getOlricClient() *olric.Client {
|
||||
g.olricMu.RLock()
|
||||
defer g.olricMu.RUnlock()
|
||||
return g.olricClient
|
||||
}
|
||||
|
||||
// startOlricReconnectLoop starts a background goroutine that continuously attempts
|
||||
// to reconnect to the Olric cluster with exponential backoff.
|
||||
func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) {
|
||||
go func() {
|
||||
retryDelay := 5 * time.Second
|
||||
maxBackoff := 30 * time.Second
|
||||
|
||||
for {
|
||||
client, err := olric.NewClient(cfg, g.logger.Logger)
|
||||
client, err := initializeOlricClientWithRetry(cfg, g.logger)
|
||||
if err == nil {
|
||||
g.setOlricClient(client)
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client connected after background retries",
|
||||
@ -236,13 +516,211 @@ func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) {
|
||||
zap.Error(err))
|
||||
|
||||
time.Sleep(retryDelay)
|
||||
if retryDelay < maxBackoff {
|
||||
if retryDelay < olricInitMaxBackoff {
|
||||
retryDelay *= 2
|
||||
if retryDelay > maxBackoff {
|
||||
retryDelay = maxBackoff
|
||||
if retryDelay > olricInitMaxBackoff {
|
||||
retryDelay = olricInitMaxBackoff
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func initializeOlricClientWithRetry(cfg olric.Config, logger *logging.ColoredLogger) (*olric.Client, error) {
|
||||
backoff := olricInitInitialBackoff
|
||||
|
||||
for attempt := 1; attempt <= olricInitMaxAttempts; attempt++ {
|
||||
client, err := olric.NewClient(cfg, logger.Logger)
|
||||
if err == nil {
|
||||
if attempt > 1 {
|
||||
logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client initialized after retries",
|
||||
zap.Int("attempts", attempt))
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
logger.ComponentWarn(logging.ComponentGeneral, "Olric cache client init attempt failed",
|
||||
zap.Int("attempt", attempt),
|
||||
zap.Duration("retry_in", backoff),
|
||||
zap.Error(err))
|
||||
|
||||
if attempt == olricInitMaxAttempts {
|
||||
return nil, fmt.Errorf("failed to initialize Olric cache client after %d attempts: %w", attempt, err)
|
||||
}
|
||||
|
||||
time.Sleep(backoff)
|
||||
backoff *= 2
|
||||
if backoff > olricInitMaxBackoff {
|
||||
backoff = olricInitMaxBackoff
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("failed to initialize Olric cache client")
|
||||
}
|
||||
|
||||
// discoverOlricServers discovers Olric server addresses from LibP2P peers
|
||||
// Returns a list of IP:port addresses where Olric servers are expected to run (port 3320)
|
||||
func discoverOlricServers(networkClient client.NetworkClient, logger *zap.Logger) []string {
|
||||
// Get network info to access peer information
|
||||
networkInfo := networkClient.Network()
|
||||
if networkInfo == nil {
|
||||
logger.Debug("Network info not available for Olric discovery")
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
peers, err := networkInfo.GetPeers(ctx)
|
||||
if err != nil {
|
||||
logger.Debug("Failed to get peers for Olric discovery", zap.Error(err))
|
||||
return nil
|
||||
}
|
||||
|
||||
olricServers := make([]string, 0)
|
||||
seen := make(map[string]bool)
|
||||
|
||||
for _, peer := range peers {
|
||||
for _, addrStr := range peer.Addresses {
|
||||
// Parse multiaddr
|
||||
ma, err := multiaddr.NewMultiaddr(addrStr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract IP address
|
||||
var ip string
|
||||
if ipv4, err := ma.ValueForProtocol(multiaddr.P_IP4); err == nil && ipv4 != "" {
|
||||
ip = ipv4
|
||||
} else if ipv6, err := ma.ValueForProtocol(multiaddr.P_IP6); err == nil && ipv6 != "" {
|
||||
ip = ipv6
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip localhost loopback addresses (we'll use localhost:3320 as fallback)
|
||||
if ip == "localhost" || ip == "::1" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Build Olric server address (standard port 3320)
|
||||
olricAddr := net.JoinHostPort(ip, "3320")
|
||||
if !seen[olricAddr] {
|
||||
olricServers = append(olricServers, olricAddr)
|
||||
seen[olricAddr] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Also check peers from config
|
||||
if cfg := networkClient.Config(); cfg != nil {
|
||||
for _, peerAddr := range cfg.BootstrapPeers {
|
||||
ma, err := multiaddr.NewMultiaddr(peerAddr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var ip string
|
||||
if ipv4, err := ma.ValueForProtocol(multiaddr.P_IP4); err == nil && ipv4 != "" {
|
||||
ip = ipv4
|
||||
} else if ipv6, err := ma.ValueForProtocol(multiaddr.P_IP6); err == nil && ipv6 != "" {
|
||||
ip = ipv6
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip localhost
|
||||
if ip == "localhost" || ip == "::1" {
|
||||
continue
|
||||
}
|
||||
|
||||
olricAddr := net.JoinHostPort(ip, "3320")
|
||||
if !seen[olricAddr] {
|
||||
olricServers = append(olricServers, olricAddr)
|
||||
seen[olricAddr] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we found servers, log them
|
||||
if len(olricServers) > 0 {
|
||||
logger.Info("Discovered Olric servers from LibP2P network",
|
||||
zap.Strings("servers", olricServers))
|
||||
}
|
||||
|
||||
return olricServers
|
||||
}
|
||||
|
||||
// ipfsDiscoveryResult holds discovered IPFS configuration
|
||||
type ipfsDiscoveryResult struct {
|
||||
clusterURL string
|
||||
apiURL string
|
||||
timeout time.Duration
|
||||
replicationFactor int
|
||||
enableEncryption bool
|
||||
}
|
||||
|
||||
// discoverIPFSFromNodeConfigs discovers IPFS configuration from node.yaml files
|
||||
// Checks node-1.yaml through node-5.yaml for IPFS configuration
|
||||
func discoverIPFSFromNodeConfigs(logger *zap.Logger) ipfsDiscoveryResult {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
logger.Debug("Failed to get home directory for IPFS discovery", zap.Error(err))
|
||||
return ipfsDiscoveryResult{}
|
||||
}
|
||||
|
||||
configDir := filepath.Join(homeDir, ".orama")
|
||||
|
||||
// Try all node config files for IPFS settings
|
||||
configFiles := []string{"node-1.yaml", "node-2.yaml", "node-3.yaml", "node-4.yaml", "node-5.yaml"}
|
||||
|
||||
for _, filename := range configFiles {
|
||||
configPath := filepath.Join(configDir, filename)
|
||||
data, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var nodeCfg config.Config
|
||||
if err := config.DecodeStrict(strings.NewReader(string(data)), &nodeCfg); err != nil {
|
||||
logger.Debug("Failed to parse node config for IPFS discovery",
|
||||
zap.String("file", filename), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if IPFS is configured
|
||||
if nodeCfg.Database.IPFS.ClusterAPIURL != "" {
|
||||
result := ipfsDiscoveryResult{
|
||||
clusterURL: nodeCfg.Database.IPFS.ClusterAPIURL,
|
||||
apiURL: nodeCfg.Database.IPFS.APIURL,
|
||||
timeout: nodeCfg.Database.IPFS.Timeout,
|
||||
replicationFactor: nodeCfg.Database.IPFS.ReplicationFactor,
|
||||
enableEncryption: nodeCfg.Database.IPFS.EnableEncryption,
|
||||
}
|
||||
|
||||
if result.apiURL == "" {
|
||||
result.apiURL = "http://localhost:5001"
|
||||
}
|
||||
if result.timeout == 0 {
|
||||
result.timeout = 60 * time.Second
|
||||
}
|
||||
if result.replicationFactor == 0 {
|
||||
result.replicationFactor = 3
|
||||
}
|
||||
// Default encryption to true if not set
|
||||
if !result.enableEncryption {
|
||||
result.enableEncryption = true
|
||||
}
|
||||
|
||||
logger.Info("Discovered IPFS config from node config",
|
||||
zap.String("file", filename),
|
||||
zap.String("cluster_url", result.clusterURL),
|
||||
zap.String("api_url", result.apiURL),
|
||||
zap.Bool("encryption_enabled", result.enableEncryption))
|
||||
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
return ipfsDiscoveryResult{}
|
||||
}
|
||||
|
||||
@ -191,6 +191,16 @@ func isPublicPath(p string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TURN credentials (public for development - requires secret for actual use)
|
||||
if strings.HasPrefix(p, "/v1/turn/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// SFU endpoints (public for development)
|
||||
if strings.HasPrefix(p, "/v1/sfu/") {
|
||||
return true
|
||||
}
|
||||
|
||||
switch p {
|
||||
case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers":
|
||||
return true
|
||||
|
||||
475
pkg/gateway/pubsub_handlers.go
Normal file
475
pkg/gateway/pubsub_handlers.go
Normal file
@ -0,0 +1,475 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/client"
|
||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var wsUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
// For early development we accept any origin; tighten later.
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
// pubsubWebsocketHandler upgrades to WS, subscribes to a namespaced topic, and
|
||||
// forwards received PubSub messages to the client. Messages sent by the client
|
||||
// are published to the same namespaced topic.
|
||||
func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if g.client == nil {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: client not initialized")
|
||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodGet {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: method not allowed")
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
// Resolve namespace from auth context
|
||||
ns := resolveNamespaceFromRequest(r)
|
||||
if ns == "" {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: namespace not resolved")
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
|
||||
topic := r.URL.Query().Get("topic")
|
||||
if topic == "" {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: missing topic")
|
||||
writeError(w, http.StatusBadRequest, "missing 'topic'")
|
||||
return
|
||||
}
|
||||
|
||||
// Presence handling
|
||||
enablePresence := r.URL.Query().Get("presence") == "true"
|
||||
memberID := r.URL.Query().Get("member_id")
|
||||
memberMetaStr := r.URL.Query().Get("member_meta")
|
||||
var memberMeta map[string]interface{}
|
||||
if memberMetaStr != "" {
|
||||
_ = json.Unmarshal([]byte(memberMetaStr), &memberMeta)
|
||||
}
|
||||
|
||||
if enablePresence && memberID == "" {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: presence enabled but missing member_id")
|
||||
writeError(w, http.StatusBadRequest, "missing 'member_id' for presence")
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: upgrade failed")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Channel to deliver PubSub messages to WS writer
|
||||
msgs := make(chan []byte, 128)
|
||||
|
||||
// NEW: Register as local subscriber for direct message delivery
|
||||
localSub := &localSubscriber{
|
||||
msgChan: msgs,
|
||||
namespace: ns,
|
||||
}
|
||||
topicKey := fmt.Sprintf("%s.%s", ns, topic)
|
||||
|
||||
g.mu.Lock()
|
||||
g.localSubscribers[topicKey] = append(g.localSubscribers[topicKey], localSub)
|
||||
subscriberCount := len(g.localSubscribers[topicKey])
|
||||
g.mu.Unlock()
|
||||
|
||||
connID := uuid.New().String()
|
||||
if enablePresence {
|
||||
member := PresenceMember{
|
||||
MemberID: memberID,
|
||||
JoinedAt: time.Now().Unix(),
|
||||
Meta: memberMeta,
|
||||
ConnID: connID,
|
||||
}
|
||||
|
||||
g.presenceMu.Lock()
|
||||
g.presenceMembers[topicKey] = append(g.presenceMembers[topicKey], member)
|
||||
g.presenceMu.Unlock()
|
||||
|
||||
// Broadcast join event
|
||||
joinEvent := map[string]interface{}{
|
||||
"type": "presence.join",
|
||||
"member_id": memberID,
|
||||
"meta": memberMeta,
|
||||
"timestamp": member.JoinedAt,
|
||||
}
|
||||
eventData, _ := json.Marshal(joinEvent)
|
||||
// Use a background context for the broadcast to ensure it finishes even if the connection closes immediately
|
||||
broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
|
||||
_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
|
||||
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: member joined presence",
|
||||
zap.String("topic", topic),
|
||||
zap.String("member_id", memberID))
|
||||
}
|
||||
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: registered local subscriber",
|
||||
zap.String("topic", topic),
|
||||
zap.String("namespace", ns),
|
||||
zap.Int("total_subscribers", subscriberCount))
|
||||
|
||||
// Unregister on close
|
||||
defer func() {
|
||||
g.mu.Lock()
|
||||
subs := g.localSubscribers[topicKey]
|
||||
for i, sub := range subs {
|
||||
if sub == localSub {
|
||||
g.localSubscribers[topicKey] = append(subs[:i], subs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
remainingCount := len(g.localSubscribers[topicKey])
|
||||
if remainingCount == 0 {
|
||||
delete(g.localSubscribers, topicKey)
|
||||
}
|
||||
g.mu.Unlock()
|
||||
|
||||
if enablePresence {
|
||||
g.presenceMu.Lock()
|
||||
members := g.presenceMembers[topicKey]
|
||||
for i, m := range members {
|
||||
if m.ConnID == connID {
|
||||
g.presenceMembers[topicKey] = append(members[:i], members[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
if len(g.presenceMembers[topicKey]) == 0 {
|
||||
delete(g.presenceMembers, topicKey)
|
||||
}
|
||||
g.presenceMu.Unlock()
|
||||
|
||||
// Broadcast leave event
|
||||
leaveEvent := map[string]interface{}{
|
||||
"type": "presence.leave",
|
||||
"member_id": memberID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
}
|
||||
eventData, _ := json.Marshal(leaveEvent)
|
||||
broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
|
||||
_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
|
||||
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: member left presence",
|
||||
zap.String("topic", topic),
|
||||
zap.String("member_id", memberID))
|
||||
}
|
||||
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: unregistered local subscriber",
|
||||
zap.String("topic", topic),
|
||||
zap.Int("remaining_subscribers", remainingCount))
|
||||
}()
|
||||
|
||||
// Use internal auth context when interacting with client to avoid circular auth requirements
|
||||
ctx := client.WithInternalAuth(r.Context())
|
||||
// Apply namespace isolation
|
||||
ctx = pubsub.WithNamespace(ctx, ns)
|
||||
|
||||
// Writer loop - START THIS FIRST before libp2p subscription
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: writer goroutine started",
|
||||
zap.String("topic", topic))
|
||||
defer g.logger.ComponentInfo("gateway", "pubsub ws: writer goroutine exiting",
|
||||
zap.String("topic", topic))
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case b, ok := <-msgs:
|
||||
if !ok {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: message channel closed",
|
||||
zap.String("topic", topic))
|
||||
_ = conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(5*time.Second))
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: sending message to client",
|
||||
zap.String("topic", topic),
|
||||
zap.Int("data_len", len(b)))
|
||||
|
||||
// Format message as JSON envelope with data (base64 encoded), timestamp, and topic
|
||||
// This matches the SDK's Message interface: {data: string, timestamp: number, topic: string}
|
||||
envelope := map[string]interface{}{
|
||||
"data": base64.StdEncoding.EncodeToString(b),
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
"topic": topic,
|
||||
}
|
||||
envelopeJSON, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: failed to marshal envelope",
|
||||
zap.String("topic", topic),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
g.logger.ComponentDebug("gateway", "pubsub ws: envelope created",
|
||||
zap.String("topic", topic),
|
||||
zap.Int("envelope_len", len(envelopeJSON)))
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||
if err := conn.WriteMessage(websocket.TextMessage, envelopeJSON); err != nil {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: failed to write to websocket",
|
||||
zap.String("topic", topic),
|
||||
zap.Error(err))
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: message sent successfully",
|
||||
zap.String("topic", topic))
|
||||
case <-ticker.C:
|
||||
// Ping keepalive
|
||||
_ = conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(5*time.Second))
|
||||
case <-ctx.Done():
|
||||
close(done)
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Subscribe to libp2p for cross-node messages (in background, non-blocking)
|
||||
go func() {
|
||||
h := func(_ string, data []byte) error {
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: received message from libp2p",
|
||||
zap.String("topic", topic),
|
||||
zap.Int("data_len", len(data)))
|
||||
|
||||
select {
|
||||
case msgs <- data:
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: forwarded to client",
|
||||
zap.String("topic", topic),
|
||||
zap.String("source", "libp2p"))
|
||||
return nil
|
||||
default:
|
||||
// Drop if client is slow to avoid blocking network
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: client slow, dropping message",
|
||||
zap.String("topic", topic))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := g.client.PubSub().Subscribe(ctx, topic, h); err != nil {
|
||||
g.logger.ComponentWarn("gateway", "pubsub ws: libp2p subscribe failed (will use local-only)",
|
||||
zap.String("topic", topic),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: libp2p subscription established",
|
||||
zap.String("topic", topic))
|
||||
|
||||
// Keep subscription alive until done
|
||||
<-done
|
||||
_ = g.client.PubSub().Unsubscribe(ctx, topic)
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: libp2p subscription closed",
|
||||
zap.String("topic", topic))
|
||||
}()
|
||||
|
||||
// Reader loop: treat any client message as publish to the same topic
|
||||
for {
|
||||
mt, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if mt != websocket.TextMessage && mt != websocket.BinaryMessage {
|
||||
continue
|
||||
}
|
||||
|
||||
// Filter out WebSocket heartbeat messages
|
||||
// Don't publish them to the topic
|
||||
var msg map[string]interface{}
|
||||
if err := json.Unmarshal(data, &msg); err == nil {
|
||||
if msgType, ok := msg["type"].(string); ok && msgType == "ping" {
|
||||
g.logger.ComponentInfo("gateway", "pubsub ws: filtering out heartbeat ping")
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.client.PubSub().Publish(ctx, topic, data); err != nil {
|
||||
// Best-effort notify client
|
||||
_ = conn.WriteMessage(websocket.TextMessage, []byte("publish_error"))
|
||||
}
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
// pubsubPublishHandler handles POST /v1/pubsub/publish {topic, data_base64}
|
||||
func (g *Gateway) pubsubPublishHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if g.client == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
||||
return
|
||||
}
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
ns := resolveNamespaceFromRequest(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Topic string `json:"topic"`
|
||||
DataB64 string `json:"data_base64"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Topic == "" || body.DataB64 == "" {
|
||||
writeError(w, http.StatusBadRequest, "invalid body: expected {topic,data_base64}")
|
||||
return
|
||||
}
|
||||
data, err := base64.StdEncoding.DecodeString(body.DataB64)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid base64 data")
|
||||
return
|
||||
}
|
||||
|
||||
// NEW: Check for local websocket subscribers FIRST and deliver directly
|
||||
g.mu.RLock()
|
||||
localSubs := g.getLocalSubscribers(body.Topic, ns)
|
||||
g.mu.RUnlock()
|
||||
|
||||
localDeliveryCount := 0
|
||||
if len(localSubs) > 0 {
|
||||
for _, sub := range localSubs {
|
||||
select {
|
||||
case sub.msgChan <- data:
|
||||
localDeliveryCount++
|
||||
g.logger.ComponentDebug("gateway", "delivered to local subscriber",
|
||||
zap.String("topic", body.Topic))
|
||||
default:
|
||||
// Drop if buffer full
|
||||
g.logger.ComponentWarn("gateway", "local subscriber buffer full, dropping message",
|
||||
zap.String("topic", body.Topic))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
g.logger.ComponentInfo("gateway", "pubsub publish: processing message",
|
||||
zap.String("topic", body.Topic),
|
||||
zap.String("namespace", ns),
|
||||
zap.Int("data_len", len(data)),
|
||||
zap.Int("local_subscribers", len(localSubs)),
|
||||
zap.Int("local_delivered", localDeliveryCount))
|
||||
|
||||
// Publish to libp2p asynchronously for cross-node delivery
|
||||
// This prevents blocking the HTTP response if libp2p network is slow
|
||||
go func() {
|
||||
publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
|
||||
if err := g.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
|
||||
g.logger.ComponentWarn("gateway", "async libp2p publish failed",
|
||||
zap.String("topic", body.Topic),
|
||||
zap.Error(err))
|
||||
} else {
|
||||
g.logger.ComponentDebug("gateway", "async libp2p publish succeeded",
|
||||
zap.String("topic", body.Topic))
|
||||
}
|
||||
}()
|
||||
|
||||
// Return immediately after local delivery
|
||||
// Local WebSocket subscribers already received the message
|
||||
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
|
||||
}
|
||||
|
||||
// pubsubTopicsHandler lists topics within the caller's namespace
|
||||
func (g *Gateway) pubsubTopicsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if g.client == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
||||
return
|
||||
}
|
||||
ns := resolveNamespaceFromRequest(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
// Apply namespace isolation
|
||||
ctx := pubsub.WithNamespace(client.WithInternalAuth(r.Context()), ns)
|
||||
all, err := g.client.PubSub().ListTopics(ctx)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
return
|
||||
}
|
||||
// Client returns topics already trimmed to its namespace; return as-is
|
||||
writeJSON(w, http.StatusOK, map[string]any{"topics": all})
|
||||
}
|
||||
|
||||
// resolveNamespaceFromRequest gets namespace from context set by auth middleware
|
||||
// Falls back to query parameter "namespace" for development/testing
|
||||
func resolveNamespaceFromRequest(r *http.Request) string {
|
||||
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
|
||||
if s, ok := v.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
// Fallback: check query parameter for development
|
||||
if ns := strings.TrimSpace(r.URL.Query().Get("namespace")); ns != "" {
|
||||
return ns
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func namespacePrefix(ns string) string {
|
||||
return "ns::" + ns + "::"
|
||||
}
|
||||
|
||||
func namespacedTopic(ns, topic string) string {
|
||||
return namespacePrefix(ns) + topic
|
||||
}
|
||||
|
||||
// pubsubPresenceHandler handles GET /v1/pubsub/presence?topic=mytopic
|
||||
func (g *Gateway) pubsubPresenceHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
ns := resolveNamespaceFromRequest(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
|
||||
topic := r.URL.Query().Get("topic")
|
||||
if topic == "" {
|
||||
writeError(w, http.StatusBadRequest, "missing 'topic'")
|
||||
return
|
||||
}
|
||||
|
||||
topicKey := fmt.Sprintf("%s.%s", ns, topic)
|
||||
|
||||
g.presenceMu.RLock()
|
||||
members, ok := g.presenceMembers[topicKey]
|
||||
g.presenceMu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"topic": topic,
|
||||
"members": []PresenceMember{},
|
||||
"count": 0,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"topic": topic,
|
||||
"members": members,
|
||||
"count": len(members),
|
||||
})
|
||||
}
|
||||
@ -79,5 +79,14 @@ func (g *Gateway) Routes() http.Handler {
|
||||
g.serverlessHandlers.RegisterRoutes(mux)
|
||||
}
|
||||
|
||||
// TURN credentials for WebRTC
|
||||
mux.HandleFunc("/v1/turn/credentials", g.turnCredentialsHandler)
|
||||
|
||||
// SFU endpoints for WebRTC group calls (if enabled)
|
||||
if g.sfuManager != nil {
|
||||
mux.HandleFunc("/v1/sfu/room", g.sfuCreateRoomHandler)
|
||||
mux.HandleFunc("/v1/sfu/room/", g.sfuRoomHandler) // Handles :roomId/* paths
|
||||
}
|
||||
|
||||
return g.withMiddleware(mux)
|
||||
}
|
||||
|
||||
933
pkg/gateway/serverless_handlers.go
Normal file
933
pkg/gateway/serverless_handlers.go
Normal file
@ -0,0 +1,933 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ServerlessHandlers contains handlers for serverless function endpoints.
|
||||
// It's a separate struct to keep the Gateway struct clean.
|
||||
type ServerlessHandlers struct {
|
||||
invoker *serverless.Invoker
|
||||
registry serverless.FunctionRegistry
|
||||
wsManager *serverless.WSManager
|
||||
triggerManager serverless.TriggerManager
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewServerlessHandlers creates a new ServerlessHandlers instance.
|
||||
func NewServerlessHandlers(
|
||||
invoker *serverless.Invoker,
|
||||
registry serverless.FunctionRegistry,
|
||||
wsManager *serverless.WSManager,
|
||||
triggerManager serverless.TriggerManager,
|
||||
logger *zap.Logger,
|
||||
) *ServerlessHandlers {
|
||||
return &ServerlessHandlers{
|
||||
invoker: invoker,
|
||||
registry: registry,
|
||||
wsManager: wsManager,
|
||||
triggerManager: triggerManager,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers all serverless routes on the given mux.
|
||||
func (h *ServerlessHandlers) RegisterRoutes(mux *http.ServeMux) {
|
||||
// Function management
|
||||
mux.HandleFunc("/v1/functions", h.handleFunctions)
|
||||
mux.HandleFunc("/v1/functions/", h.handleFunctionByName)
|
||||
|
||||
// Direct invoke endpoint
|
||||
mux.HandleFunc("/v1/invoke/", h.handleInvoke)
|
||||
}
|
||||
|
||||
// handleFunctions handles GET /v1/functions (list) and POST /v1/functions (deploy)
|
||||
func (h *ServerlessHandlers) handleFunctions(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.listFunctions(w, r)
|
||||
case http.MethodPost:
|
||||
h.deployFunction(w, r)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// handleFunctionByName handles operations on a specific function
|
||||
// Routes:
|
||||
// - GET /v1/functions/{name} - Get function info
|
||||
// - DELETE /v1/functions/{name} - Delete function
|
||||
// - POST /v1/functions/{name}/invoke - Invoke function
|
||||
// - GET /v1/functions/{name}/versions - List versions
|
||||
// - GET /v1/functions/{name}/logs - Get logs
|
||||
// - WS /v1/functions/{name}/ws - WebSocket invoke
|
||||
func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse path: /v1/functions/{name}[/{action}]
|
||||
path := strings.TrimPrefix(r.URL.Path, "/v1/functions/")
|
||||
parts := strings.SplitN(path, "/", 2)
|
||||
|
||||
if len(parts) == 0 || parts[0] == "" {
|
||||
http.Error(w, "Function name required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
name := parts[0]
|
||||
action := ""
|
||||
if len(parts) > 1 {
|
||||
action = parts[1]
|
||||
}
|
||||
|
||||
// Parse version from name if present (e.g., "myfunction@2")
|
||||
version := 0
|
||||
if idx := strings.Index(name, "@"); idx > 0 {
|
||||
vStr := name[idx+1:]
|
||||
name = name[:idx]
|
||||
if v, err := strconv.Atoi(vStr); err == nil {
|
||||
version = v
|
||||
}
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "invoke":
|
||||
h.invokeFunction(w, r, name, version)
|
||||
case "ws":
|
||||
h.handleWebSocket(w, r, name, version)
|
||||
case "versions":
|
||||
h.listVersions(w, r, name)
|
||||
case "logs":
|
||||
h.getFunctionLogs(w, r, name)
|
||||
case "triggers":
|
||||
h.handleFunctionTriggers(w, r, name)
|
||||
case "":
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.getFunctionInfo(w, r, name, version)
|
||||
case http.MethodDelete:
|
||||
h.deleteFunction(w, r, name, version)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
default:
|
||||
http.Error(w, "Unknown action", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// handleInvoke handles POST /v1/invoke/{namespace}/{name}[@version]
|
||||
func (h *ServerlessHandlers) handleInvoke(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse path: /v1/invoke/{namespace}/{name}[@version]
|
||||
path := strings.TrimPrefix(r.URL.Path, "/v1/invoke/")
|
||||
parts := strings.SplitN(path, "/", 2)
|
||||
|
||||
if len(parts) < 2 {
|
||||
http.Error(w, "Path must be /v1/invoke/{namespace}/{name}", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
namespace := parts[0]
|
||||
name := parts[1]
|
||||
|
||||
// Parse version if present
|
||||
version := 0
|
||||
if idx := strings.Index(name, "@"); idx > 0 {
|
||||
vStr := name[idx+1:]
|
||||
name = name[:idx]
|
||||
if v, err := strconv.Atoi(vStr); err == nil {
|
||||
version = v
|
||||
}
|
||||
}
|
||||
|
||||
h.invokeFunction(w, r, namespace+"/"+name, version)
|
||||
}
|
||||
|
||||
// listFunctions handles GET /v1/functions
|
||||
func (h *ServerlessHandlers) listFunctions(w http.ResponseWriter, r *http.Request) {
|
||||
namespace := r.URL.Query().Get("namespace")
|
||||
if namespace == "" {
|
||||
// Get namespace from JWT if available
|
||||
namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
|
||||
if namespace == "" {
|
||||
writeError(w, http.StatusBadRequest, "namespace required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
functions, err := h.registry.List(ctx, namespace)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to list functions", zap.Error(err))
|
||||
writeError(w, http.StatusInternalServerError, "Failed to list functions")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"functions": functions,
|
||||
"count": len(functions),
|
||||
})
|
||||
}
|
||||
|
||||
// deployFunction handles POST /v1/functions
|
||||
func (h *ServerlessHandlers) deployFunction(w http.ResponseWriter, r *http.Request) {
|
||||
// Parse multipart form (for WASM upload) or JSON
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
|
||||
var def serverless.FunctionDefinition
|
||||
var wasmBytes []byte
|
||||
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
// Parse multipart form
|
||||
if err := r.ParseMultipartForm(32 << 20); err != nil { // 32MB max
|
||||
writeError(w, http.StatusBadRequest, "Failed to parse form: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Get metadata from form field
|
||||
metadataStr := r.FormValue("metadata")
|
||||
if metadataStr != "" {
|
||||
if err := json.Unmarshal([]byte(metadataStr), &def); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "Invalid metadata JSON: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Get name from form if not in metadata
|
||||
if def.Name == "" {
|
||||
def.Name = r.FormValue("name")
|
||||
}
|
||||
|
||||
// Get namespace from form if not in metadata
|
||||
if def.Namespace == "" {
|
||||
def.Namespace = r.FormValue("namespace")
|
||||
}
|
||||
|
||||
// Get other configuration fields from form
|
||||
if v := r.FormValue("is_public"); v != "" {
|
||||
def.IsPublic, _ = strconv.ParseBool(v)
|
||||
}
|
||||
if v := r.FormValue("memory_limit_mb"); v != "" {
|
||||
def.MemoryLimitMB, _ = strconv.Atoi(v)
|
||||
}
|
||||
if v := r.FormValue("timeout_seconds"); v != "" {
|
||||
def.TimeoutSeconds, _ = strconv.Atoi(v)
|
||||
}
|
||||
if v := r.FormValue("retry_count"); v != "" {
|
||||
def.RetryCount, _ = strconv.Atoi(v)
|
||||
}
|
||||
if v := r.FormValue("retry_delay_seconds"); v != "" {
|
||||
def.RetryDelaySeconds, _ = strconv.Atoi(v)
|
||||
}
|
||||
|
||||
// Get WASM file
|
||||
file, _, err := r.FormFile("wasm")
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "WASM file required")
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
wasmBytes, err = io.ReadAll(file)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "Failed to read WASM file: "+err.Error())
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// JSON body with base64-encoded WASM
|
||||
var req struct {
|
||||
serverless.FunctionDefinition
|
||||
WASMBase64 string `json:"wasm_base64"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
def = req.FunctionDefinition
|
||||
|
||||
if req.WASMBase64 != "" {
|
||||
// Decode base64 WASM - for now, just reject this method
|
||||
writeError(w, http.StatusBadRequest, "Base64 WASM upload not supported, use multipart/form-data")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Get namespace from JWT if not provided
|
||||
if def.Namespace == "" {
|
||||
def.Namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
|
||||
if def.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "Function name required")
|
||||
return
|
||||
}
|
||||
if def.Namespace == "" {
|
||||
writeError(w, http.StatusBadRequest, "Namespace required")
|
||||
return
|
||||
}
|
||||
if len(wasmBytes) == 0 {
|
||||
writeError(w, http.StatusBadRequest, "WASM bytecode required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
oldFn, err := h.registry.Register(ctx, &def, wasmBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to deploy function",
|
||||
zap.String("name", def.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
writeError(w, http.StatusInternalServerError, "Failed to deploy: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Invalidate cache for the old version to ensure the new one is loaded
|
||||
if oldFn != nil {
|
||||
h.invoker.InvalidateCache(oldFn.WASMCID)
|
||||
h.logger.Debug("Invalidated function cache",
|
||||
zap.String("name", def.Name),
|
||||
zap.String("old_wasm_cid", oldFn.WASMCID),
|
||||
)
|
||||
}
|
||||
|
||||
h.logger.Info("Function deployed",
|
||||
zap.String("name", def.Name),
|
||||
zap.String("namespace", def.Namespace),
|
||||
)
|
||||
|
||||
// Fetch the deployed function to return
|
||||
fn, err := h.registry.Get(ctx, def.Namespace, def.Name, def.Version)
|
||||
if err != nil {
|
||||
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"message": "Function deployed successfully",
|
||||
"name": def.Name,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Register PubSub triggers if provided in metadata
|
||||
var triggersAdded []string
|
||||
if len(def.PubSubTopics) > 0 && h.triggerManager != nil {
|
||||
for _, topic := range def.PubSubTopics {
|
||||
if err := h.triggerManager.AddPubSubTrigger(ctx, fn.ID, topic); err != nil {
|
||||
// Log but don't fail deployment
|
||||
h.logger.Warn("Failed to add pubsub trigger during deployment",
|
||||
zap.String("function", def.Name),
|
||||
zap.String("topic", topic),
|
||||
zap.Error(err),
|
||||
)
|
||||
} else {
|
||||
triggersAdded = append(triggersAdded, topic)
|
||||
h.logger.Info("PubSub trigger added during deployment",
|
||||
zap.String("function", def.Name),
|
||||
zap.String("topic", topic),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"message": "Function deployed successfully",
|
||||
"function": fn,
|
||||
}
|
||||
if len(triggersAdded) > 0 {
|
||||
response["triggers_added"] = triggersAdded
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// getFunctionInfo handles GET /v1/functions/{name}
|
||||
func (h *ServerlessHandlers) getFunctionInfo(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||
namespace := r.URL.Query().Get("namespace")
|
||||
if namespace == "" {
|
||||
namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
|
||||
if namespace == "" {
|
||||
writeError(w, http.StatusBadRequest, "namespace required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
fn, err := h.registry.Get(ctx, namespace, name, version)
|
||||
if err != nil {
|
||||
if serverless.IsNotFound(err) {
|
||||
writeError(w, http.StatusNotFound, "Function not found")
|
||||
} else {
|
||||
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, fn)
|
||||
}
|
||||
|
||||
// deleteFunction handles DELETE /v1/functions/{name}
|
||||
func (h *ServerlessHandlers) deleteFunction(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||
namespace := r.URL.Query().Get("namespace")
|
||||
if namespace == "" {
|
||||
namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
|
||||
if namespace == "" {
|
||||
writeError(w, http.StatusBadRequest, "namespace required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.registry.Delete(ctx, namespace, name, version); err != nil {
|
||||
if serverless.IsNotFound(err) {
|
||||
writeError(w, http.StatusNotFound, "Function not found")
|
||||
} else {
|
||||
writeError(w, http.StatusInternalServerError, "Failed to delete function")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]string{
|
||||
"message": "Function deleted successfully",
|
||||
})
|
||||
}
|
||||
|
||||
// invokeFunction handles POST /v1/functions/{name}/invoke
|
||||
func (h *ServerlessHandlers) invokeFunction(w http.ResponseWriter, r *http.Request, nameWithNS string, version int) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse namespace and name
|
||||
var namespace, name string
|
||||
if idx := strings.Index(nameWithNS, "/"); idx > 0 {
|
||||
namespace = nameWithNS[:idx]
|
||||
name = nameWithNS[idx+1:]
|
||||
} else {
|
||||
name = nameWithNS
|
||||
namespace = r.URL.Query().Get("namespace")
|
||||
if namespace == "" {
|
||||
namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
}
|
||||
|
||||
if namespace == "" {
|
||||
writeError(w, http.StatusBadRequest, "namespace required")
|
||||
return
|
||||
}
|
||||
|
||||
// Read input body
|
||||
input, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB max
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "Failed to read request body")
|
||||
return
|
||||
}
|
||||
|
||||
// Get caller wallet from JWT
|
||||
callerWallet := h.getWalletFromRequest(r)
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req := &serverless.InvokeRequest{
|
||||
Namespace: namespace,
|
||||
FunctionName: name,
|
||||
Version: version,
|
||||
Input: input,
|
||||
TriggerType: serverless.TriggerTypeHTTP,
|
||||
CallerWallet: callerWallet,
|
||||
}
|
||||
|
||||
resp, err := h.invoker.Invoke(ctx, req)
|
||||
if err != nil {
|
||||
statusCode := http.StatusInternalServerError
|
||||
if serverless.IsNotFound(err) {
|
||||
statusCode = http.StatusNotFound
|
||||
} else if serverless.IsResourceExhausted(err) {
|
||||
statusCode = http.StatusTooManyRequests
|
||||
} else if serverless.IsUnauthorized(err) {
|
||||
statusCode = http.StatusUnauthorized
|
||||
}
|
||||
|
||||
writeJSON(w, statusCode, map[string]interface{}{
|
||||
"request_id": resp.RequestID,
|
||||
"status": resp.Status,
|
||||
"error": resp.Error,
|
||||
"duration_ms": resp.DurationMS,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Return the function's output directly if it's JSON
|
||||
w.Header().Set("X-Request-ID", resp.RequestID)
|
||||
w.Header().Set("X-Duration-Ms", strconv.FormatInt(resp.DurationMS, 10))
|
||||
|
||||
// Try to detect if output is JSON
|
||||
if len(resp.Output) > 0 && (resp.Output[0] == '{' || resp.Output[0] == '[') {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(resp.Output)
|
||||
} else {
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"request_id": resp.RequestID,
|
||||
"output": string(resp.Output),
|
||||
"status": resp.Status,
|
||||
"duration_ms": resp.DurationMS,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// handleWebSocket handles WebSocket connections for function streaming
|
||||
func (h *ServerlessHandlers) handleWebSocket(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||
namespace := r.URL.Query().Get("namespace")
|
||||
if namespace == "" {
|
||||
namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
|
||||
if namespace == "" {
|
||||
http.Error(w, "namespace required", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Upgrade to WebSocket
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
clientID := uuid.New().String()
|
||||
wsConn := &serverless.GorillaWSConn{Conn: conn}
|
||||
|
||||
// Register connection
|
||||
h.wsManager.Register(clientID, wsConn)
|
||||
defer h.wsManager.Unregister(clientID)
|
||||
|
||||
h.logger.Info("WebSocket connected",
|
||||
zap.String("client_id", clientID),
|
||||
zap.String("function", name),
|
||||
)
|
||||
|
||||
callerWallet := h.getWalletFromRequest(r)
|
||||
|
||||
// Message loop
|
||||
for {
|
||||
_, message, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
h.logger.Warn("WebSocket error", zap.Error(err))
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
// Invoke function with WebSocket context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
|
||||
req := &serverless.InvokeRequest{
|
||||
Namespace: namespace,
|
||||
FunctionName: name,
|
||||
Version: version,
|
||||
Input: message,
|
||||
TriggerType: serverless.TriggerTypeWebSocket,
|
||||
CallerWallet: callerWallet,
|
||||
WSClientID: clientID,
|
||||
}
|
||||
|
||||
resp, err := h.invoker.Invoke(ctx, req)
|
||||
cancel()
|
||||
|
||||
// Send response back
|
||||
response := map[string]interface{}{
|
||||
"request_id": resp.RequestID,
|
||||
"status": resp.Status,
|
||||
"duration_ms": resp.DurationMS,
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
response["error"] = resp.Error
|
||||
} else if len(resp.Output) > 0 {
|
||||
// Try to parse output as JSON
|
||||
var output interface{}
|
||||
if json.Unmarshal(resp.Output, &output) == nil {
|
||||
response["output"] = output
|
||||
} else {
|
||||
response["output"] = string(resp.Output)
|
||||
}
|
||||
}
|
||||
|
||||
respBytes, _ := json.Marshal(response)
|
||||
if err := conn.WriteMessage(websocket.TextMessage, respBytes); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// listVersions handles GET /v1/functions/{name}/versions
|
||||
func (h *ServerlessHandlers) listVersions(w http.ResponseWriter, r *http.Request, name string) {
|
||||
namespace := r.URL.Query().Get("namespace")
|
||||
if namespace == "" {
|
||||
namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
|
||||
if namespace == "" {
|
||||
writeError(w, http.StatusBadRequest, "namespace required")
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Get registry with extended methods
|
||||
reg, ok := h.registry.(*serverless.Registry)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotImplemented, "Version listing not supported")
|
||||
return
|
||||
}
|
||||
|
||||
versions, err := reg.ListVersions(ctx, namespace, name)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "Failed to list versions")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"versions": versions,
|
||||
"count": len(versions),
|
||||
})
|
||||
}
|
||||
|
||||
// getFunctionLogs handles GET /v1/functions/{name}/logs
|
||||
func (h *ServerlessHandlers) getFunctionLogs(w http.ResponseWriter, r *http.Request, name string) {
|
||||
namespace := r.URL.Query().Get("namespace")
|
||||
if namespace == "" {
|
||||
namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
|
||||
if namespace == "" {
|
||||
writeError(w, http.StatusBadRequest, "namespace required")
|
||||
return
|
||||
}
|
||||
|
||||
limit := 100
|
||||
if lStr := r.URL.Query().Get("limit"); lStr != "" {
|
||||
if l, err := strconv.Atoi(lStr); err == nil {
|
||||
limit = l
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
logs, err := h.registry.GetLogs(ctx, namespace, name, limit)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get function logs",
|
||||
zap.String("name", name),
|
||||
zap.String("namespace", namespace),
|
||||
zap.Error(err),
|
||||
)
|
||||
writeError(w, http.StatusInternalServerError, "Failed to get logs")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"name": name,
|
||||
"namespace": namespace,
|
||||
"logs": logs,
|
||||
"count": len(logs),
|
||||
})
|
||||
}
|
||||
|
||||
// getNamespaceFromRequest extracts namespace from JWT or query param
|
||||
func (h *ServerlessHandlers) getNamespaceFromRequest(r *http.Request) string {
|
||||
// Try context first (set by auth middleware) - most secure
|
||||
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
|
||||
if ns, ok := v.(string); ok && ns != "" {
|
||||
return ns
|
||||
}
|
||||
}
|
||||
|
||||
// Try query param as fallback (e.g. for public access or admin)
|
||||
if ns := r.URL.Query().Get("namespace"); ns != "" {
|
||||
return ns
|
||||
}
|
||||
|
||||
// Try header as fallback
|
||||
if ns := r.Header.Get("X-Namespace"); ns != "" {
|
||||
return ns
|
||||
}
|
||||
|
||||
return "default"
|
||||
}
|
||||
|
||||
// getWalletFromRequest extracts wallet address from JWT
|
||||
func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string {
|
||||
// 1. Try X-Wallet header (legacy/direct bypass)
|
||||
if wallet := r.Header.Get("X-Wallet"); wallet != "" {
|
||||
return wallet
|
||||
}
|
||||
|
||||
// 2. Try JWT claims from context
|
||||
if v := r.Context().Value(ctxKeyJWT); v != nil {
|
||||
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||
subj := strings.TrimSpace(claims.Sub)
|
||||
// Ensure it's not an API key (standard Orama logic)
|
||||
if !strings.HasPrefix(strings.ToLower(subj), "ak_") && !strings.Contains(subj, ":") {
|
||||
return subj
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fallback to API key identity (namespace)
|
||||
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
|
||||
if ns, ok := v.(string); ok && ns != "" {
|
||||
return ns
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// HealthStatus returns the health status of the serverless engine
|
||||
func (h *ServerlessHandlers) HealthStatus() map[string]interface{} {
|
||||
stats := h.wsManager.GetStats()
|
||||
return map[string]interface{}{
|
||||
"status": "ok",
|
||||
"connections": stats.ConnectionCount,
|
||||
"topics": stats.TopicCount,
|
||||
}
|
||||
}
|
||||
|
||||
// handleFunctionTriggers handles trigger operations for a function
|
||||
// Routes:
|
||||
// - GET /v1/functions/{name}/triggers - List all triggers
|
||||
// - POST /v1/functions/{name}/triggers/pubsub - Add pubsub trigger
|
||||
// - DELETE /v1/functions/{name}/triggers/{id} - Remove trigger
|
||||
func (h *ServerlessHandlers) handleFunctionTriggers(w http.ResponseWriter, r *http.Request, name string) {
|
||||
if h.triggerManager == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "Trigger management not available")
|
||||
return
|
||||
}
|
||||
|
||||
namespace := r.URL.Query().Get("namespace")
|
||||
if namespace == "" {
|
||||
namespace = h.getNamespaceFromRequest(r)
|
||||
}
|
||||
|
||||
if namespace == "" {
|
||||
writeError(w, http.StatusBadRequest, "namespace required")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse sub-path for trigger type or ID
|
||||
// Path after "triggers" could be: "", "pubsub", or "{trigger_id}"
|
||||
fullPath := r.URL.Path
|
||||
triggersIdx := strings.Index(fullPath, "/triggers")
|
||||
subPath := ""
|
||||
if triggersIdx > 0 {
|
||||
subPath = strings.TrimPrefix(fullPath[triggersIdx:], "/triggers")
|
||||
subPath = strings.TrimPrefix(subPath, "/")
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodGet:
|
||||
h.listFunctionTriggers(w, r, namespace, name)
|
||||
case http.MethodPost:
|
||||
if subPath == "pubsub" {
|
||||
h.addPubSubTrigger(w, r, namespace, name)
|
||||
} else {
|
||||
writeError(w, http.StatusBadRequest, "Invalid trigger type. Use /triggers/pubsub")
|
||||
}
|
||||
case http.MethodDelete:
|
||||
if subPath == "" {
|
||||
writeError(w, http.StatusBadRequest, "Trigger ID required")
|
||||
return
|
||||
}
|
||||
h.removeFunctionTrigger(w, r, namespace, name, subPath)
|
||||
default:
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
// listFunctionTriggers handles GET /v1/functions/{name}/triggers
|
||||
func (h *ServerlessHandlers) listFunctionTriggers(w http.ResponseWriter, r *http.Request, namespace, name string) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Get function to verify it exists and get its ID
|
||||
fn, err := h.registry.Get(ctx, namespace, name, 0)
|
||||
if err != nil {
|
||||
if serverless.IsNotFound(err) {
|
||||
writeError(w, http.StatusNotFound, "Function not found")
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to get function",
|
||||
zap.String("name", name),
|
||||
zap.String("namespace", namespace),
|
||||
zap.Error(err),
|
||||
)
|
||||
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||
return
|
||||
}
|
||||
|
||||
// Get pubsub triggers
|
||||
pubsubTriggers, err := h.triggerManager.ListPubSubTriggers(ctx, fn.ID)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to list triggers",
|
||||
zap.String("function_id", fn.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
writeError(w, http.StatusInternalServerError, "Failed to list triggers")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"name": name,
|
||||
"namespace": namespace,
|
||||
"function_id": fn.ID,
|
||||
"pubsub_triggers": pubsubTriggers,
|
||||
"count": len(pubsubTriggers),
|
||||
})
|
||||
}
|
||||
|
||||
// addPubSubTrigger handles POST /v1/functions/{name}/triggers/pubsub
|
||||
func (h *ServerlessHandlers) addPubSubTrigger(w http.ResponseWriter, r *http.Request, namespace, name string) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Parse request body
|
||||
var req struct {
|
||||
Topic string `json:"topic"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.Topic == "" {
|
||||
writeError(w, http.StatusBadRequest, "topic is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Get function to verify it exists and get its ID
|
||||
fn, err := h.registry.Get(ctx, namespace, name, 0)
|
||||
if err != nil {
|
||||
if serverless.IsNotFound(err) {
|
||||
writeError(w, http.StatusNotFound, "Function not found")
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to get function",
|
||||
zap.String("name", name),
|
||||
zap.String("namespace", namespace),
|
||||
zap.Error(err),
|
||||
)
|
||||
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||
return
|
||||
}
|
||||
|
||||
// Add the trigger
|
||||
err = h.triggerManager.AddPubSubTrigger(ctx, fn.ID, req.Topic)
|
||||
if err != nil {
|
||||
if serverless.IsValidationError(err) {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to add pubsub trigger",
|
||||
zap.String("function_id", fn.ID),
|
||||
zap.String("topic", req.Topic),
|
||||
zap.Error(err),
|
||||
)
|
||||
writeError(w, http.StatusInternalServerError, "Failed to add trigger")
|
||||
return
|
||||
}
|
||||
|
||||
// Get the triggers to return the newly created one
|
||||
triggers, _ := h.triggerManager.ListPubSubTriggers(ctx, fn.ID)
|
||||
var newTrigger *serverless.PubSubTrigger
|
||||
for i := range triggers {
|
||||
if triggers[i].Topic == req.Topic {
|
||||
newTrigger = &triggers[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("PubSub trigger added",
|
||||
zap.String("function", name),
|
||||
zap.String("namespace", namespace),
|
||||
zap.String("topic", req.Topic),
|
||||
)
|
||||
|
||||
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||
"message": "Trigger added successfully",
|
||||
"trigger": newTrigger,
|
||||
})
|
||||
}
|
||||
|
||||
// removeFunctionTrigger handles DELETE /v1/functions/{name}/triggers/{id}
|
||||
func (h *ServerlessHandlers) removeFunctionTrigger(w http.ResponseWriter, r *http.Request, namespace, name, triggerID string) {
|
||||
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Verify function exists
|
||||
fn, err := h.registry.Get(ctx, namespace, name, 0)
|
||||
if err != nil {
|
||||
if serverless.IsNotFound(err) {
|
||||
writeError(w, http.StatusNotFound, "Function not found")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||
return
|
||||
}
|
||||
|
||||
// Remove the trigger
|
||||
err = h.triggerManager.RemoveTrigger(ctx, triggerID)
|
||||
if err != nil {
|
||||
if err == serverless.ErrTriggerNotFound {
|
||||
writeError(w, http.StatusNotFound, "Trigger not found")
|
||||
return
|
||||
}
|
||||
h.logger.Error("Failed to remove trigger",
|
||||
zap.String("trigger_id", triggerID),
|
||||
zap.Error(err),
|
||||
)
|
||||
writeError(w, http.StatusInternalServerError, "Failed to remove trigger")
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("Trigger removed",
|
||||
zap.String("function", name),
|
||||
zap.String("function_id", fn.ID),
|
||||
zap.String("trigger_id", triggerID),
|
||||
)
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"message": "Trigger removed successfully",
|
||||
"trigger_id": triggerID,
|
||||
})
|
||||
}
|
||||
@ -8,7 +8,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
|
||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@ -50,12 +49,12 @@ func TestServerlessHandlers_ListFunctions(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger)
|
||||
h := NewServerlessHandlers(nil, registry, nil, nil, logger)
|
||||
|
||||
req, _ := http.NewRequest("GET", "/v1/functions?namespace=ns1", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.ListFunctions(rr, req)
|
||||
h.handleFunctions(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", rr.Code)
|
||||
@ -73,7 +72,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) {
|
||||
logger := zap.NewNop()
|
||||
registry := &mockFunctionRegistry{}
|
||||
|
||||
h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger)
|
||||
h := NewServerlessHandlers(nil, registry, nil, nil, logger)
|
||||
|
||||
// Test JSON deploy (which is partially supported according to code)
|
||||
// Should be 400 because WASM is missing or base64 not supported
|
||||
@ -81,7 +80,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) {
|
||||
req, _ := http.NewRequest("POST", "/v1/functions", bytes.NewBufferString(`{"name": "test"}`))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.DeployFunction(writer, req)
|
||||
h.handleFunctions(writer, req)
|
||||
|
||||
if writer.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected status 400, got %d", writer.Code)
|
||||
|
||||
181
pkg/gateway/sfu/config.go
Normal file
181
pkg/gateway/sfu/config.go
Normal file
@ -0,0 +1,181 @@
|
||||
package sfu
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pion/interceptor"
|
||||
"github.com/pion/interceptor/pkg/intervalpli"
|
||||
"github.com/pion/interceptor/pkg/nack"
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// Config holds SFU configuration
|
||||
type Config struct {
|
||||
// MaxParticipants is the maximum number of participants per room
|
||||
MaxParticipants int
|
||||
|
||||
// MediaTimeout is the timeout for media operations
|
||||
MediaTimeout time.Duration
|
||||
|
||||
// ICEServers are the ICE servers for WebRTC connections
|
||||
ICEServers []webrtc.ICEServer
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default SFU configuration
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
MaxParticipants: 10,
|
||||
MediaTimeout: 30 * time.Second,
|
||||
ICEServers: []webrtc.ICEServer{
|
||||
{
|
||||
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// NewMediaEngine creates a MediaEngine with supported codecs for the SFU
|
||||
func NewMediaEngine() (*webrtc.MediaEngine, error) {
|
||||
m := &webrtc.MediaEngine{}
|
||||
|
||||
// RTCP feedback for video codecs - enables NACK, PLI, FIR
|
||||
videoRTCPFeedback := []webrtc.RTCPFeedback{
|
||||
{Type: "goog-remb", Parameter: ""}, // Bandwidth estimation
|
||||
{Type: "ccm", Parameter: "fir"}, // Full Intra Request
|
||||
{Type: "nack", Parameter: ""}, // Generic NACK
|
||||
{Type: "nack", Parameter: "pli"}, // Picture Loss Indication
|
||||
}
|
||||
|
||||
// Register Opus codec for audio with NACK support
|
||||
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||
MimeType: webrtc.MimeTypeOpus,
|
||||
ClockRate: 48000,
|
||||
Channels: 2,
|
||||
SDPFmtpLine: "minptime=10;useinbandfec=1",
|
||||
},
|
||||
PayloadType: 111,
|
||||
}, webrtc.RTPCodecTypeAudio); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register VP8 codec for video with full RTCP feedback
|
||||
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||
MimeType: webrtc.MimeTypeVP8,
|
||||
ClockRate: 90000,
|
||||
RTCPFeedback: videoRTCPFeedback,
|
||||
},
|
||||
PayloadType: 96,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register RTX for VP8 (retransmission)
|
||||
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||
MimeType: "video/rtx",
|
||||
ClockRate: 90000,
|
||||
SDPFmtpLine: "apt=96",
|
||||
},
|
||||
PayloadType: 97,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register H264 codec for video with full RTCP feedback (fallback)
|
||||
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||
MimeType: webrtc.MimeTypeH264,
|
||||
ClockRate: 90000,
|
||||
SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f",
|
||||
RTCPFeedback: videoRTCPFeedback,
|
||||
},
|
||||
PayloadType: 102,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register RTX for H264 (retransmission)
|
||||
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||
MimeType: "video/rtx",
|
||||
ClockRate: 90000,
|
||||
SDPFmtpLine: "apt=102",
|
||||
},
|
||||
PayloadType: 103,
|
||||
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// NewWebRTCAPI creates a new WebRTC API with the configured MediaEngine
|
||||
func NewWebRTCAPI() (*webrtc.API, error) {
|
||||
mediaEngine, err := NewMediaEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create interceptor registry for RTCP feedback
|
||||
i := &interceptor.Registry{}
|
||||
|
||||
// Register NACK responder - handles retransmission requests from receivers
|
||||
// This is critical for video quality: when a receiver loses a packet,
|
||||
// it sends NACK and the sender retransmits the lost packet
|
||||
nackResponderFactory, err := nack.NewResponderInterceptor()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i.Add(nackResponderFactory)
|
||||
|
||||
// Register NACK generator - sends NACK when we detect packet loss as receiver
|
||||
nackGeneratorFactory, err := nack.NewGeneratorInterceptor()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i.Add(nackGeneratorFactory)
|
||||
|
||||
// Register interval PLI - automatically sends PLI periodically for video tracks
|
||||
// This helps new receivers get keyframes faster
|
||||
intervalPLIFactory, err := intervalpli.NewReceiverInterceptor()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
i.Add(intervalPLIFactory)
|
||||
|
||||
// Configure settings for better media performance
|
||||
settingEngine := webrtc.SettingEngine{}
|
||||
|
||||
return webrtc.NewAPI(
|
||||
webrtc.WithMediaEngine(mediaEngine),
|
||||
webrtc.WithInterceptorRegistry(i),
|
||||
webrtc.WithSettingEngine(settingEngine),
|
||||
), nil
|
||||
}
|
||||
|
||||
// GetRTPCapabilities returns the RTP capabilities of the SFU
|
||||
// This is used by clients to negotiate codecs
|
||||
func GetRTPCapabilities() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"codecs": []map[string]interface{}{
|
||||
{
|
||||
"kind": "audio",
|
||||
"mimeType": "audio/opus",
|
||||
"clockRate": 48000,
|
||||
"channels": 2,
|
||||
},
|
||||
{
|
||||
"kind": "video",
|
||||
"mimeType": "video/VP8",
|
||||
"clockRate": 90000,
|
||||
},
|
||||
{
|
||||
"kind": "video",
|
||||
"mimeType": "video/H264",
|
||||
"clockRate": 90000,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
228
pkg/gateway/sfu/manager.go
Normal file
228
pkg/gateway/sfu/manager.go
Normal file
@ -0,0 +1,228 @@
|
||||
package sfu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrRoomNotFound = errors.New("room not found")
|
||||
)
|
||||
|
||||
// RoomManager manages all rooms in the SFU
|
||||
type RoomManager struct {
|
||||
rooms map[string]*Room // key: namespace:roomId
|
||||
roomsMu sync.RWMutex
|
||||
|
||||
api *webrtc.API
|
||||
config *Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRoomManager creates a new room manager
|
||||
func NewRoomManager(config *Config, logger *zap.Logger) (*RoomManager, error) {
|
||||
if config == nil {
|
||||
config = DefaultConfig()
|
||||
}
|
||||
|
||||
api, err := NewWebRTCAPI()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RoomManager{
|
||||
rooms: make(map[string]*Room),
|
||||
api: api,
|
||||
config: config,
|
||||
logger: logger.With(zap.String("component", "sfu_manager")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// roomKey creates a unique key for namespace:roomId
|
||||
func roomKey(namespace, roomID string) string {
|
||||
return namespace + ":" + roomID
|
||||
}
|
||||
|
||||
// GetOrCreateRoom gets an existing room or creates a new one
|
||||
func (m *RoomManager) GetOrCreateRoom(namespace, roomID string) (*Room, bool) {
|
||||
key := roomKey(namespace, roomID)
|
||||
|
||||
m.roomsMu.Lock()
|
||||
defer m.roomsMu.Unlock()
|
||||
|
||||
// Check if room already exists
|
||||
if room, ok := m.rooms[key]; ok {
|
||||
if !room.IsClosed() {
|
||||
return room, false
|
||||
}
|
||||
// Room is closed, remove it and create a new one
|
||||
delete(m.rooms, key)
|
||||
}
|
||||
|
||||
// Create new room
|
||||
room := NewRoom(roomID, namespace, m.api, m.config, m.logger)
|
||||
|
||||
// Set up empty room handler
|
||||
room.OnEmpty(func(r *Room) {
|
||||
m.logger.Info("Room is empty, will be garbage collected",
|
||||
zap.String("room_id", r.ID),
|
||||
zap.String("namespace", r.Namespace),
|
||||
)
|
||||
// Optionally close the room immediately
|
||||
// m.CloseRoom(r.Namespace, r.ID)
|
||||
})
|
||||
|
||||
m.rooms[key] = room
|
||||
|
||||
m.logger.Info("Room created",
|
||||
zap.String("room_id", roomID),
|
||||
zap.String("namespace", namespace),
|
||||
)
|
||||
|
||||
return room, true
|
||||
}
|
||||
|
||||
// GetRoom gets an existing room
|
||||
func (m *RoomManager) GetRoom(namespace, roomID string) (*Room, error) {
|
||||
key := roomKey(namespace, roomID)
|
||||
|
||||
m.roomsMu.RLock()
|
||||
defer m.roomsMu.RUnlock()
|
||||
|
||||
room, ok := m.rooms[key]
|
||||
if !ok {
|
||||
return nil, ErrRoomNotFound
|
||||
}
|
||||
|
||||
if room.IsClosed() {
|
||||
return nil, ErrRoomClosed
|
||||
}
|
||||
|
||||
return room, nil
|
||||
}
|
||||
|
||||
// CloseRoom closes and removes a room
|
||||
func (m *RoomManager) CloseRoom(namespace, roomID string) error {
|
||||
key := roomKey(namespace, roomID)
|
||||
|
||||
m.roomsMu.Lock()
|
||||
room, ok := m.rooms[key]
|
||||
if !ok {
|
||||
m.roomsMu.Unlock()
|
||||
return ErrRoomNotFound
|
||||
}
|
||||
delete(m.rooms, key)
|
||||
m.roomsMu.Unlock()
|
||||
|
||||
if err := room.Close(); err != nil {
|
||||
m.logger.Warn("Error closing room",
|
||||
zap.String("room_id", roomID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
m.logger.Info("Room closed",
|
||||
zap.String("room_id", roomID),
|
||||
zap.String("namespace", namespace),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRooms returns all rooms for a namespace
|
||||
func (m *RoomManager) ListRooms(namespace string) []*Room {
|
||||
m.roomsMu.RLock()
|
||||
defer m.roomsMu.RUnlock()
|
||||
|
||||
prefix := namespace + ":"
|
||||
rooms := make([]*Room, 0)
|
||||
|
||||
for key, room := range m.rooms {
|
||||
if len(key) > len(prefix) && key[:len(prefix)] == prefix && !room.IsClosed() {
|
||||
rooms = append(rooms, room)
|
||||
}
|
||||
}
|
||||
|
||||
return rooms
|
||||
}
|
||||
|
||||
// RoomInfo contains public information about a room
|
||||
type RoomInfo struct {
|
||||
ID string `json:"id"`
|
||||
Namespace string `json:"namespace"`
|
||||
ParticipantCount int `json:"participantCount"`
|
||||
Participants []ParticipantInfo `json:"participants"`
|
||||
}
|
||||
|
||||
// GetRoomInfo returns public information about a room
|
||||
func (m *RoomManager) GetRoomInfo(namespace, roomID string) (*RoomInfo, error) {
|
||||
room, err := m.GetRoom(namespace, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &RoomInfo{
|
||||
ID: room.ID,
|
||||
Namespace: room.Namespace,
|
||||
ParticipantCount: room.GetParticipantCount(),
|
||||
Participants: room.GetParticipants(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStats returns statistics about the room manager
|
||||
func (m *RoomManager) GetStats() map[string]interface{} {
|
||||
m.roomsMu.RLock()
|
||||
defer m.roomsMu.RUnlock()
|
||||
|
||||
totalRooms := 0
|
||||
totalParticipants := 0
|
||||
|
||||
for _, room := range m.rooms {
|
||||
if !room.IsClosed() {
|
||||
totalRooms++
|
||||
totalParticipants += room.GetParticipantCount()
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"totalRooms": totalRooms,
|
||||
"totalParticipants": totalParticipants,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes all rooms and cleans up resources
|
||||
func (m *RoomManager) Close() error {
|
||||
m.roomsMu.Lock()
|
||||
rooms := make([]*Room, 0, len(m.rooms))
|
||||
for _, room := range m.rooms {
|
||||
rooms = append(rooms, room)
|
||||
}
|
||||
m.rooms = make(map[string]*Room)
|
||||
m.roomsMu.Unlock()
|
||||
|
||||
for _, room := range rooms {
|
||||
if err := room.Close(); err != nil {
|
||||
m.logger.Warn("Error closing room during shutdown",
|
||||
zap.String("room_id", room.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
m.logger.Info("Room manager closed", zap.Int("rooms_closed", len(rooms)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfig returns the SFU configuration
|
||||
func (m *RoomManager) GetConfig() *Config {
|
||||
return m.config
|
||||
}
|
||||
|
||||
// UpdateICEServers updates the ICE servers configuration
|
||||
func (m *RoomManager) UpdateICEServers(servers []webrtc.ICEServer) {
|
||||
m.config.ICEServers = servers
|
||||
}
|
||||
505
pkg/gateway/sfu/peer.go
Normal file
505
pkg/gateway/sfu/peer.go
Normal file
@ -0,0 +1,505 @@
|
||||
package sfu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Peer represents a participant in a room
|
||||
type Peer struct {
|
||||
ID string
|
||||
UserID string
|
||||
DisplayName string
|
||||
|
||||
// WebRTC connection
|
||||
pc *webrtc.PeerConnection
|
||||
|
||||
// Tracks published by this peer (local tracks that others receive)
|
||||
localTracks map[string]*webrtc.TrackLocalStaticRTP
|
||||
localTracksMu sync.RWMutex
|
||||
|
||||
// Track receivers for consuming other peers' tracks
|
||||
trackReceivers map[string]*webrtc.RTPReceiver
|
||||
trackReceiversMu sync.RWMutex
|
||||
|
||||
// WebSocket connection for signaling
|
||||
conn *websocket.Conn
|
||||
connMu sync.Mutex
|
||||
|
||||
// State
|
||||
audioMuted bool
|
||||
videoMuted bool
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
negotiationPending bool
|
||||
negotiationPendingMu sync.Mutex
|
||||
initialOfferHandled bool
|
||||
initialOfferMu sync.Mutex
|
||||
batchingTracks bool // When true, suppress automatic negotiation
|
||||
batchingTracksMu sync.Mutex
|
||||
|
||||
// Room reference
|
||||
room *Room
|
||||
logger *zap.Logger
|
||||
|
||||
// Callbacks
|
||||
onClose func(*Peer)
|
||||
}
|
||||
|
||||
// NewPeer creates a new peer
|
||||
func NewPeer(userID, displayName string, conn *websocket.Conn, room *Room, logger *zap.Logger) *Peer {
|
||||
return &Peer{
|
||||
ID: uuid.New().String(),
|
||||
UserID: userID,
|
||||
DisplayName: displayName,
|
||||
localTracks: make(map[string]*webrtc.TrackLocalStaticRTP),
|
||||
trackReceivers: make(map[string]*webrtc.RTPReceiver),
|
||||
conn: conn,
|
||||
room: room,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// InitPeerConnection initializes the WebRTC peer connection
|
||||
func (p *Peer) InitPeerConnection(api *webrtc.API, config webrtc.Configuration) error {
|
||||
pc, err := api.NewPeerConnection(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.pc = pc
|
||||
|
||||
// Handle ICE connection state changes
|
||||
pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
|
||||
p.logger.Info("ICE connection state changed",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("state", state.String()),
|
||||
)
|
||||
|
||||
if state == webrtc.ICEConnectionStateFailed ||
|
||||
state == webrtc.ICEConnectionStateDisconnected ||
|
||||
state == webrtc.ICEConnectionStateClosed {
|
||||
p.handleDisconnect()
|
||||
}
|
||||
})
|
||||
|
||||
// Handle ICE candidates
|
||||
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||
if candidate == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("ICE candidate generated",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("candidate", candidate.String()),
|
||||
)
|
||||
|
||||
candidateJSON := candidate.ToJSON()
|
||||
data := &ICECandidateData{
|
||||
Candidate: candidateJSON.Candidate,
|
||||
}
|
||||
if candidateJSON.SDPMid != nil {
|
||||
data.SDPMid = *candidateJSON.SDPMid
|
||||
}
|
||||
if candidateJSON.SDPMLineIndex != nil {
|
||||
data.SDPMLineIndex = *candidateJSON.SDPMLineIndex
|
||||
}
|
||||
if candidateJSON.UsernameFragment != nil {
|
||||
data.UsernameFragment = *candidateJSON.UsernameFragment
|
||||
}
|
||||
|
||||
p.SendMessage(NewServerMessage(MessageTypeICECandidate, data))
|
||||
})
|
||||
|
||||
// Handle incoming tracks from remote peers
|
||||
pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||
codec := track.Codec()
|
||||
p.logger.Info("Track received from client",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("user_id", p.UserID),
|
||||
zap.String("track_id", track.ID()),
|
||||
zap.String("stream_id", track.StreamID()),
|
||||
zap.String("kind", track.Kind().String()),
|
||||
zap.String("codec_mime", codec.MimeType),
|
||||
zap.Uint32("codec_clock_rate", codec.ClockRate),
|
||||
zap.Uint8("codec_payload_type", uint8(codec.PayloadType)),
|
||||
)
|
||||
|
||||
p.trackReceiversMu.Lock()
|
||||
p.trackReceivers[track.ID()] = receiver
|
||||
p.trackReceiversMu.Unlock()
|
||||
|
||||
// Start RTCP reader to monitor for packet loss (NACK) and PLI requests
|
||||
go p.readRTCP(receiver, track)
|
||||
|
||||
// Forward track to other peers in the room
|
||||
p.room.BroadcastTrack(p.ID, track)
|
||||
})
|
||||
|
||||
// Handle negotiation needed - only trigger when in stable state
|
||||
pc.OnNegotiationNeeded(func() {
|
||||
p.logger.Debug("Negotiation needed",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("signaling_state", pc.SignalingState().String()),
|
||||
)
|
||||
|
||||
// Check if we're batching tracks - if so, just mark as pending
|
||||
p.batchingTracksMu.Lock()
|
||||
batching := p.batchingTracks
|
||||
p.batchingTracksMu.Unlock()
|
||||
|
||||
if batching {
|
||||
p.negotiationPendingMu.Lock()
|
||||
p.negotiationPending = true
|
||||
p.negotiationPendingMu.Unlock()
|
||||
p.logger.Debug("Negotiation deferred - batching tracks",
|
||||
zap.String("peer_id", p.ID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Only create offer if we're in stable state
|
||||
// Otherwise, mark negotiation as pending
|
||||
if pc.SignalingState() == webrtc.SignalingStateStable {
|
||||
p.createAndSendOffer()
|
||||
} else {
|
||||
p.negotiationPendingMu.Lock()
|
||||
p.negotiationPending = true
|
||||
p.negotiationPendingMu.Unlock()
|
||||
p.logger.Debug("Negotiation queued - not in stable state",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("signaling_state", pc.SignalingState().String()),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
// Handle signaling state changes to process pending negotiations
|
||||
pc.OnSignalingStateChange(func(state webrtc.SignalingState) {
|
||||
p.logger.Debug("Signaling state changed",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("state", state.String()),
|
||||
)
|
||||
|
||||
// When we return to stable state, check if negotiation was pending
|
||||
if state == webrtc.SignalingStateStable {
|
||||
p.negotiationPendingMu.Lock()
|
||||
pending := p.negotiationPending
|
||||
p.negotiationPending = false
|
||||
p.negotiationPendingMu.Unlock()
|
||||
|
||||
if pending {
|
||||
p.logger.Debug("Processing pending negotiation", zap.String("peer_id", p.ID))
|
||||
p.createAndSendOffer()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAndSendOffer creates an SDP offer and sends it to the peer
|
||||
func (p *Peer) createAndSendOffer() {
|
||||
if p.pc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Double-check signaling state before creating offer
|
||||
if p.pc.SignalingState() != webrtc.SignalingStateStable {
|
||||
p.logger.Debug("Skipping offer - not in stable state",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("signaling_state", p.pc.SignalingState().String()),
|
||||
)
|
||||
// Mark as pending so it will be retried when state becomes stable
|
||||
p.negotiationPendingMu.Lock()
|
||||
p.negotiationPending = true
|
||||
p.negotiationPendingMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Info("Creating SDP offer", zap.String("peer_id", p.ID))
|
||||
|
||||
offer, err := p.pc.CreateOffer(nil)
|
||||
if err != nil {
|
||||
p.logger.Error("Failed to create offer",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.pc.SetLocalDescription(offer); err != nil {
|
||||
p.logger.Error("Failed to set local description",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Info("Sending SDP offer", zap.String("peer_id", p.ID))
|
||||
p.SendMessage(NewServerMessage(MessageTypeOffer, &OfferData{
|
||||
SDP: offer.SDP,
|
||||
}))
|
||||
}
|
||||
|
||||
// HandleOffer processes an SDP offer from the client
|
||||
func (p *Peer) HandleOffer(sdp string) error {
|
||||
if p.pc == nil {
|
||||
return ErrPeerNotInitialized
|
||||
}
|
||||
|
||||
offer := webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeOffer,
|
||||
SDP: sdp,
|
||||
}
|
||||
|
||||
if err := p.pc.SetRemoteDescription(offer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
answer, err := p.pc.CreateAnswer(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := p.pc.SetLocalDescription(answer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.SendMessage(NewServerMessage(MessageTypeAnswer, &AnswerData{
|
||||
SDP: answer.SDP,
|
||||
}))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleAnswer processes an SDP answer from the client
|
||||
func (p *Peer) HandleAnswer(sdp string) error {
|
||||
if p.pc == nil {
|
||||
return ErrPeerNotInitialized
|
||||
}
|
||||
|
||||
answer := webrtc.SessionDescription{
|
||||
Type: webrtc.SDPTypeAnswer,
|
||||
SDP: sdp,
|
||||
}
|
||||
|
||||
return p.pc.SetRemoteDescription(answer)
|
||||
}
|
||||
|
||||
// HandleICECandidate processes an ICE candidate from the client
|
||||
func (p *Peer) HandleICECandidate(data *ICECandidateData) error {
|
||||
if p.pc == nil {
|
||||
return ErrPeerNotInitialized
|
||||
}
|
||||
|
||||
return p.pc.AddICECandidate(data.ToWebRTCCandidate())
|
||||
}
|
||||
|
||||
// AddTrack adds a track to send to this peer (from another peer)
|
||||
func (p *Peer) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||
if p.pc == nil {
|
||||
return nil, ErrPeerNotInitialized
|
||||
}
|
||||
|
||||
return p.pc.AddTrack(track)
|
||||
}
|
||||
|
||||
// StartTrackBatch starts batching track additions.
|
||||
// Call EndTrackBatch when done to trigger a single renegotiation.
|
||||
func (p *Peer) StartTrackBatch() {
|
||||
p.batchingTracksMu.Lock()
|
||||
p.batchingTracks = true
|
||||
p.batchingTracksMu.Unlock()
|
||||
p.logger.Debug("Started track batching", zap.String("peer_id", p.ID))
|
||||
}
|
||||
|
||||
// EndTrackBatch ends track batching and triggers renegotiation if needed.
|
||||
func (p *Peer) EndTrackBatch() {
|
||||
p.batchingTracksMu.Lock()
|
||||
p.batchingTracks = false
|
||||
p.batchingTracksMu.Unlock()
|
||||
|
||||
// Check if negotiation was pending during batching
|
||||
p.negotiationPendingMu.Lock()
|
||||
pending := p.negotiationPending
|
||||
p.negotiationPending = false
|
||||
p.negotiationPendingMu.Unlock()
|
||||
|
||||
if pending && p.pc != nil && p.pc.SignalingState() == webrtc.SignalingStateStable {
|
||||
p.logger.Debug("Processing batched negotiation", zap.String("peer_id", p.ID))
|
||||
p.createAndSendOffer()
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessage sends a signaling message to the peer via WebSocket
|
||||
func (p *Peer) SendMessage(msg *ServerMessage) error {
|
||||
p.closedMu.RLock()
|
||||
if p.closed {
|
||||
p.closedMu.RUnlock()
|
||||
return ErrPeerClosed
|
||||
}
|
||||
p.closedMu.RUnlock()
|
||||
|
||||
p.connMu.Lock()
|
||||
defer p.connMu.Unlock()
|
||||
|
||||
if p.conn == nil {
|
||||
return ErrWebSocketClosed
|
||||
}
|
||||
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return p.conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
// GetInfo returns public information about this peer
|
||||
func (p *Peer) GetInfo() ParticipantInfo {
|
||||
p.localTracksMu.RLock()
|
||||
hasAudio := false
|
||||
hasVideo := false
|
||||
for _, track := range p.localTracks {
|
||||
if track.Kind() == webrtc.RTPCodecTypeAudio {
|
||||
hasAudio = true
|
||||
} else if track.Kind() == webrtc.RTPCodecTypeVideo {
|
||||
hasVideo = true
|
||||
}
|
||||
}
|
||||
p.localTracksMu.RUnlock()
|
||||
|
||||
return ParticipantInfo{
|
||||
ID: p.ID,
|
||||
UserID: p.UserID,
|
||||
DisplayName: p.DisplayName,
|
||||
HasAudio: hasAudio,
|
||||
HasVideo: hasVideo,
|
||||
AudioMuted: p.audioMuted,
|
||||
VideoMuted: p.videoMuted,
|
||||
}
|
||||
}
|
||||
|
||||
// SetAudioMuted sets the audio mute state
|
||||
func (p *Peer) SetAudioMuted(muted bool) {
|
||||
p.audioMuted = muted
|
||||
}
|
||||
|
||||
// SetVideoMuted sets the video mute state
|
||||
func (p *Peer) SetVideoMuted(muted bool) {
|
||||
p.videoMuted = muted
|
||||
}
|
||||
|
||||
// MarkInitialOfferHandled marks that the initial offer has been processed.
|
||||
// Returns true if this is the first time it's being marked (i.e., this was the first offer).
|
||||
func (p *Peer) MarkInitialOfferHandled() bool {
|
||||
p.initialOfferMu.Lock()
|
||||
defer p.initialOfferMu.Unlock()
|
||||
|
||||
if p.initialOfferHandled {
|
||||
return false
|
||||
}
|
||||
p.initialOfferHandled = true
|
||||
return true
|
||||
}
|
||||
|
||||
// handleDisconnect handles peer disconnection
|
||||
func (p *Peer) handleDisconnect() {
|
||||
p.closedMu.Lock()
|
||||
if p.closed {
|
||||
p.closedMu.Unlock()
|
||||
return
|
||||
}
|
||||
p.closed = true
|
||||
p.closedMu.Unlock()
|
||||
|
||||
if p.onClose != nil {
|
||||
p.onClose(p)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the peer connection and cleans up resources
|
||||
func (p *Peer) Close() error {
|
||||
p.closedMu.Lock()
|
||||
if p.closed {
|
||||
p.closedMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
p.closed = true
|
||||
p.closedMu.Unlock()
|
||||
|
||||
p.logger.Info("Closing peer", zap.String("peer_id", p.ID))
|
||||
|
||||
// Close WebSocket
|
||||
p.connMu.Lock()
|
||||
if p.conn != nil {
|
||||
p.conn.Close()
|
||||
p.conn = nil
|
||||
}
|
||||
p.connMu.Unlock()
|
||||
|
||||
// Close peer connection
|
||||
if p.pc != nil {
|
||||
return p.pc.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnClose sets a callback for when the peer is closed
|
||||
func (p *Peer) OnClose(fn func(*Peer)) {
|
||||
p.onClose = fn
|
||||
}
|
||||
|
||||
// readRTCP reads RTCP packets from a receiver to monitor feedback
|
||||
// This helps detect packet loss (via NACK) for adaptive quality adjustments
|
||||
func (p *Peer) readRTCP(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) {
|
||||
localTrackID := track.Kind().String() + "-" + p.ID
|
||||
|
||||
for {
|
||||
packets, _, err := receiver.ReadRTCP()
|
||||
if err != nil {
|
||||
// Connection closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
for _, pkt := range packets {
|
||||
switch rtcpPkt := pkt.(type) {
|
||||
case *rtcp.TransportLayerNack:
|
||||
// NACK received - indicates packet loss from receivers
|
||||
// Increment the NACK counter for adaptive keyframe logic
|
||||
p.room.IncrementNackCount(localTrackID)
|
||||
|
||||
p.logger.Debug("NACK received",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("track_id", localTrackID),
|
||||
zap.Uint32("sender_ssrc", rtcpPkt.SenderSSRC),
|
||||
zap.Int("nack_pairs", len(rtcpPkt.Nacks)),
|
||||
)
|
||||
|
||||
case *rtcp.PictureLossIndication:
|
||||
// PLI received - receiver needs a keyframe
|
||||
p.logger.Debug("PLI received from receiver",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("track_id", localTrackID),
|
||||
)
|
||||
// Request keyframe from source
|
||||
p.room.RequestKeyframe(localTrackID)
|
||||
|
||||
case *rtcp.FullIntraRequest:
|
||||
// FIR received - receiver needs a full keyframe
|
||||
p.logger.Debug("FIR received from receiver",
|
||||
zap.String("peer_id", p.ID),
|
||||
zap.String("track_id", localTrackID),
|
||||
)
|
||||
// Request keyframe from source
|
||||
p.room.RequestKeyframe(localTrackID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
775
pkg/gateway/sfu/room.go
Normal file
775
pkg/gateway/sfu/room.go
Normal file
@ -0,0 +1,775 @@
|
||||
package sfu
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pion/rtcp"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Common errors
|
||||
var (
|
||||
ErrRoomFull = errors.New("room is full")
|
||||
ErrRoomClosed = errors.New("room is closed")
|
||||
ErrPeerNotFound = errors.New("peer not found")
|
||||
ErrPeerNotInitialized = errors.New("peer not initialized")
|
||||
ErrPeerClosed = errors.New("peer is closed")
|
||||
ErrWebSocketClosed = errors.New("websocket connection closed")
|
||||
)
|
||||
|
||||
// publishedTrack holds information about a track published to the room
|
||||
type publishedTrack struct {
|
||||
sourcePeerID string
|
||||
sourceUserID string
|
||||
localTrack *webrtc.TrackLocalStaticRTP
|
||||
remoteTrackSSRC uint32 // SSRC of the remote track (for PLI requests)
|
||||
remoteTrack *webrtc.TrackRemote // Reference to the remote track
|
||||
kind string
|
||||
forwarderActive bool // tracks if the RTP forwarder goroutine is still running
|
||||
packetCount int // number of packets forwarded (for debugging)
|
||||
|
||||
// Packet loss tracking for adaptive keyframe requests
|
||||
lastKeyframeRequest time.Time
|
||||
nackCount atomic.Int64 // Number of NACK packets received (indicates packet loss)
|
||||
}
|
||||
|
||||
// Room represents a WebRTC room with multiple participants
|
||||
type Room struct {
|
||||
ID string
|
||||
Namespace string
|
||||
|
||||
// Participants in the room
|
||||
peers map[string]*Peer
|
||||
peersMu sync.RWMutex
|
||||
|
||||
// Published tracks in the room (for sending to new joiners)
|
||||
publishedTracks map[string]*publishedTrack // key: trackID
|
||||
publishedTracksMu sync.RWMutex
|
||||
|
||||
// WebRTC API for creating peer connections
|
||||
api *webrtc.API
|
||||
|
||||
// Configuration
|
||||
config *Config
|
||||
logger *zap.Logger
|
||||
|
||||
// State
|
||||
closed bool
|
||||
closedMu sync.RWMutex
|
||||
|
||||
// Callbacks
|
||||
onEmpty func(*Room)
|
||||
}
|
||||
|
||||
// NewRoom creates a new room
|
||||
func NewRoom(id, namespace string, api *webrtc.API, config *Config, logger *zap.Logger) *Room {
|
||||
return &Room{
|
||||
ID: id,
|
||||
Namespace: namespace,
|
||||
peers: make(map[string]*Peer),
|
||||
publishedTracks: make(map[string]*publishedTrack),
|
||||
api: api,
|
||||
config: config,
|
||||
logger: logger.With(zap.String("room_id", id)),
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to the room
|
||||
func (r *Room) AddPeer(peer *Peer) error {
|
||||
r.closedMu.RLock()
|
||||
if r.closed {
|
||||
r.closedMu.RUnlock()
|
||||
return ErrRoomClosed
|
||||
}
|
||||
r.closedMu.RUnlock()
|
||||
|
||||
r.peersMu.Lock()
|
||||
|
||||
// Check max participants
|
||||
if r.config.MaxParticipants > 0 && len(r.peers) >= r.config.MaxParticipants {
|
||||
r.peersMu.Unlock()
|
||||
return ErrRoomFull
|
||||
}
|
||||
|
||||
// Initialize peer connection
|
||||
pcConfig := webrtc.Configuration{
|
||||
ICEServers: r.config.ICEServers,
|
||||
}
|
||||
|
||||
if err := peer.InitPeerConnection(r.api, pcConfig); err != nil {
|
||||
r.peersMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
// Set up peer close handler
|
||||
peer.OnClose(func(p *Peer) {
|
||||
r.RemovePeer(p.ID)
|
||||
})
|
||||
|
||||
r.peers[peer.ID] = peer
|
||||
peerInfo := peer.GetInfo() // Get info while holding lock
|
||||
totalPeers := len(r.peers)
|
||||
|
||||
// Release lock BEFORE broadcasting to avoid deadlock
|
||||
// (broadcastMessage also acquires the lock)
|
||||
r.peersMu.Unlock()
|
||||
|
||||
r.logger.Info("Peer added to room",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.String("user_id", peer.UserID),
|
||||
zap.Int("total_peers", totalPeers),
|
||||
)
|
||||
|
||||
// Notify other peers (now safe since we released the lock)
|
||||
r.broadcastMessage(peer.ID, NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{
|
||||
Participant: peerInfo,
|
||||
}))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemovePeer removes a peer from the room
|
||||
func (r *Room) RemovePeer(peerID string) error {
|
||||
r.peersMu.Lock()
|
||||
peer, ok := r.peers[peerID]
|
||||
if !ok {
|
||||
r.peersMu.Unlock()
|
||||
return ErrPeerNotFound
|
||||
}
|
||||
|
||||
delete(r.peers, peerID)
|
||||
remainingPeers := len(r.peers)
|
||||
r.peersMu.Unlock()
|
||||
|
||||
// Remove tracks published by this peer
|
||||
r.publishedTracksMu.Lock()
|
||||
removedTracks := make([]string, 0)
|
||||
for trackID, track := range r.publishedTracks {
|
||||
if track.sourcePeerID == peerID {
|
||||
delete(r.publishedTracks, trackID)
|
||||
removedTracks = append(removedTracks, trackID)
|
||||
}
|
||||
}
|
||||
r.publishedTracksMu.Unlock()
|
||||
|
||||
if len(removedTracks) > 0 {
|
||||
r.logger.Info("Removed tracks for departing peer",
|
||||
zap.String("peer_id", peerID),
|
||||
zap.Strings("track_ids", removedTracks),
|
||||
)
|
||||
}
|
||||
|
||||
// Close the peer
|
||||
if err := peer.Close(); err != nil {
|
||||
r.logger.Warn("Error closing peer",
|
||||
zap.String("peer_id", peerID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
|
||||
r.logger.Info("Peer removed from room",
|
||||
zap.String("peer_id", peerID),
|
||||
zap.Int("remaining_peers", remainingPeers),
|
||||
)
|
||||
|
||||
// Notify other peers
|
||||
r.broadcastMessage(peerID, NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{
|
||||
ParticipantID: peerID,
|
||||
}))
|
||||
|
||||
// Check if room is empty
|
||||
if remainingPeers == 0 && r.onEmpty != nil {
|
||||
r.onEmpty(r)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPeer returns a peer by ID
|
||||
func (r *Room) GetPeer(peerID string) (*Peer, error) {
|
||||
r.peersMu.RLock()
|
||||
defer r.peersMu.RUnlock()
|
||||
|
||||
peer, ok := r.peers[peerID]
|
||||
if !ok {
|
||||
return nil, ErrPeerNotFound
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// GetPeers returns all peers in the room
|
||||
func (r *Room) GetPeers() []*Peer {
|
||||
r.peersMu.RLock()
|
||||
defer r.peersMu.RUnlock()
|
||||
|
||||
peers := make([]*Peer, 0, len(r.peers))
|
||||
for _, peer := range r.peers {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
return peers
|
||||
}
|
||||
|
||||
// GetParticipants returns info about all participants
|
||||
func (r *Room) GetParticipants() []ParticipantInfo {
|
||||
r.peersMu.RLock()
|
||||
defer r.peersMu.RUnlock()
|
||||
|
||||
participants := make([]ParticipantInfo, 0, len(r.peers))
|
||||
for _, peer := range r.peers {
|
||||
participants = append(participants, peer.GetInfo())
|
||||
}
|
||||
return participants
|
||||
}
|
||||
|
||||
// GetParticipantCount returns the number of participants
|
||||
func (r *Room) GetParticipantCount() int {
|
||||
r.peersMu.RLock()
|
||||
defer r.peersMu.RUnlock()
|
||||
return len(r.peers)
|
||||
}
|
||||
|
||||
// SendExistingTracksTo sends all existing tracks from other participants to the specified peer.
|
||||
// This should be called AFTER the welcome message is sent to ensure the client is ready.
|
||||
// Uses batch mode to send all tracks with a single renegotiation for faster joins.
|
||||
func (r *Room) SendExistingTracksTo(peer *Peer) {
|
||||
r.publishedTracksMu.RLock()
|
||||
existingTracks := make([]*publishedTrack, 0, len(r.publishedTracks))
|
||||
for _, track := range r.publishedTracks {
|
||||
// Don't send peer's own tracks back to them
|
||||
if track.sourcePeerID != peer.ID {
|
||||
existingTracks = append(existingTracks, track)
|
||||
}
|
||||
}
|
||||
r.publishedTracksMu.RUnlock()
|
||||
|
||||
if len(existingTracks) == 0 {
|
||||
r.logger.Info("No existing tracks to send to new peer",
|
||||
zap.String("peer_id", peer.ID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
r.logger.Info("Sending existing tracks to new peer (batch mode)",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.Int("track_count", len(existingTracks)),
|
||||
)
|
||||
|
||||
videoTrackIDs := make([]string, 0)
|
||||
|
||||
// Start batch mode - suppresses individual renegotiations
|
||||
peer.StartTrackBatch()
|
||||
|
||||
for _, track := range existingTracks {
|
||||
// Log forwarder status to help diagnose video issues
|
||||
r.logger.Info("Adding existing track to new peer",
|
||||
zap.String("new_peer_id", peer.ID),
|
||||
zap.String("source_peer_id", track.sourcePeerID),
|
||||
zap.String("source_user_id", track.sourceUserID),
|
||||
zap.String("track_id", track.localTrack.ID()),
|
||||
zap.String("kind", track.kind),
|
||||
zap.Bool("forwarder_active", track.forwarderActive),
|
||||
zap.Int("packets_forwarded", track.packetCount),
|
||||
)
|
||||
|
||||
// Warn if forwarder is no longer active
|
||||
if !track.forwarderActive {
|
||||
r.logger.Warn("WARNING: Track forwarder is NOT active - track may not receive data",
|
||||
zap.String("track_id", track.localTrack.ID()),
|
||||
zap.String("kind", track.kind),
|
||||
zap.Int("final_packet_count", track.packetCount),
|
||||
)
|
||||
}
|
||||
|
||||
if _, err := peer.AddTrack(track.localTrack); err != nil {
|
||||
r.logger.Warn("Failed to add existing track to new peer",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.String("track_id", track.localTrack.ID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Track video tracks for keyframe requests
|
||||
if track.kind == "video" {
|
||||
videoTrackIDs = append(videoTrackIDs, track.localTrack.ID())
|
||||
}
|
||||
|
||||
// Notify new peer about the existing track
|
||||
peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
|
||||
ParticipantID: track.sourcePeerID,
|
||||
UserID: track.sourceUserID,
|
||||
TrackID: track.localTrack.ID(),
|
||||
StreamID: track.localTrack.StreamID(),
|
||||
Kind: track.kind,
|
||||
}))
|
||||
}
|
||||
|
||||
// End batch mode - triggers single renegotiation for all tracks
|
||||
peer.EndTrackBatch()
|
||||
|
||||
r.logger.Info("Batch track addition complete - single renegotiation triggered",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.Int("total_tracks", len(existingTracks)),
|
||||
)
|
||||
|
||||
// Request keyframes for video tracks after a short delay
|
||||
// This ensures the receiver has time to set up the track before receiving the keyframe
|
||||
if len(videoTrackIDs) > 0 {
|
||||
go func() {
|
||||
// Wait for negotiation to complete
|
||||
time.Sleep(300 * time.Millisecond)
|
||||
r.logger.Info("Requesting keyframes for new peer",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.Int("video_track_count", len(videoTrackIDs)),
|
||||
)
|
||||
for _, trackID := range videoTrackIDs {
|
||||
r.RequestKeyframe(trackID)
|
||||
}
|
||||
// Request again after 500ms in case the first was too early
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
for _, trackID := range videoTrackIDs {
|
||||
r.RequestKeyframe(trackID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// BroadcastTrack broadcasts a track from one peer to all other peers
|
||||
func (r *Room) BroadcastTrack(sourcePeerID string, track *webrtc.TrackRemote) {
|
||||
r.peersMu.RLock()
|
||||
defer r.peersMu.RUnlock()
|
||||
|
||||
// Get source peer's user ID for the track-added message
|
||||
sourceUserID := ""
|
||||
if sourcePeer, ok := r.peers[sourcePeerID]; ok {
|
||||
sourceUserID = sourcePeer.UserID
|
||||
}
|
||||
|
||||
// Create a local track from the remote track
|
||||
// Use participant ID as the stream ID so clients can identify the source
|
||||
// Format: trackId stays the same, streamId = sourcePeerID (or sourceUserID if available)
|
||||
streamID := sourcePeerID
|
||||
if sourceUserID != "" {
|
||||
streamID = sourceUserID // Use userID for easier client-side mapping
|
||||
}
|
||||
|
||||
r.logger.Info("Creating local track for broadcast",
|
||||
zap.String("source_peer_id", sourcePeerID),
|
||||
zap.String("source_user_id", sourceUserID),
|
||||
zap.String("original_track_id", track.ID()),
|
||||
zap.String("original_stream_id", track.StreamID()),
|
||||
zap.String("new_stream_id", streamID),
|
||||
)
|
||||
|
||||
// Log codec information for debugging
|
||||
codec := track.Codec()
|
||||
r.logger.Info("Track codec info",
|
||||
zap.String("source_peer_id", sourcePeerID),
|
||||
zap.String("track_kind", track.Kind().String()),
|
||||
zap.String("mime_type", codec.MimeType),
|
||||
zap.Uint32("clock_rate", codec.ClockRate),
|
||||
zap.Uint16("channels", codec.Channels),
|
||||
zap.String("sdp_fmtp_line", codec.SDPFmtpLine),
|
||||
)
|
||||
|
||||
localTrack, err := webrtc.NewTrackLocalStaticRTP(
|
||||
codec.RTPCodecCapability,
|
||||
track.Kind().String()+"-"+sourcePeerID, // Include peer ID in track ID
|
||||
streamID, // Use participant/user ID as stream ID
|
||||
)
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to create local track",
|
||||
zap.String("source_peer", sourcePeerID),
|
||||
zap.String("track_id", track.ID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Store the track for new joiners
|
||||
pubTrack := &publishedTrack{
|
||||
sourcePeerID: sourcePeerID,
|
||||
sourceUserID: sourceUserID,
|
||||
localTrack: localTrack,
|
||||
remoteTrackSSRC: uint32(track.SSRC()),
|
||||
remoteTrack: track,
|
||||
kind: track.Kind().String(),
|
||||
forwarderActive: true,
|
||||
packetCount: 0,
|
||||
lastKeyframeRequest: time.Now(),
|
||||
}
|
||||
r.publishedTracksMu.Lock()
|
||||
r.publishedTracks[localTrack.ID()] = pubTrack
|
||||
r.publishedTracksMu.Unlock()
|
||||
|
||||
r.logger.Info("Track stored for new joiners",
|
||||
zap.String("track_id", localTrack.ID()),
|
||||
zap.String("source_peer_id", sourcePeerID),
|
||||
zap.Int("total_published_tracks", len(r.publishedTracks)),
|
||||
)
|
||||
|
||||
// Forward RTP packets from remote track to local track
|
||||
go func() {
|
||||
trackID := track.ID()
|
||||
localTrackID := localTrack.ID()
|
||||
trackKind := track.Kind().String()
|
||||
buf := make([]byte, 1600) // Slightly larger than MTU to handle RTP extensions
|
||||
packetCount := 0
|
||||
byteCount := 0
|
||||
startTime := time.Now()
|
||||
firstPacketReceived := false
|
||||
|
||||
r.logger.Info("RTP forwarder started",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("local_track_id", localTrackID),
|
||||
zap.String("kind", trackKind),
|
||||
zap.String("source_peer_id", sourcePeerID),
|
||||
)
|
||||
|
||||
// Start a goroutine to log warning if no packets received after 5 seconds
|
||||
go func() {
|
||||
time.Sleep(5 * time.Second)
|
||||
if !firstPacketReceived {
|
||||
r.logger.Warn("RTP forwarder WARNING: No packets received after 5 seconds - host may not be sending",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("local_track_id", localTrackID),
|
||||
zap.String("kind", trackKind),
|
||||
zap.String("source_peer_id", sourcePeerID),
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
// For video tracks, use adaptive keyframe requests based on packet loss
|
||||
if trackKind == "video" {
|
||||
go func() {
|
||||
// Use a faster ticker for checking, but only send keyframes when needed
|
||||
ticker := time.NewTicker(500 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
var lastNackCount int64
|
||||
var consecutiveLossDetections int
|
||||
baseInterval := 3 * time.Second
|
||||
minInterval := 500 * time.Millisecond // Minimum interval between keyframes
|
||||
lastKeyframeTime := time.Now()
|
||||
|
||||
for range ticker.C {
|
||||
// Check if room is closed
|
||||
if r.IsClosed() {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if forwarder is still active
|
||||
r.publishedTracksMu.RLock()
|
||||
pt, ok := r.publishedTracks[localTrackID]
|
||||
if !ok || !pt.forwarderActive {
|
||||
r.publishedTracksMu.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Get current NACK count to detect packet loss
|
||||
currentNackCount := pt.nackCount.Load()
|
||||
r.publishedTracksMu.RUnlock()
|
||||
|
||||
timeSinceLastKeyframe := time.Since(lastKeyframeTime)
|
||||
|
||||
// Detect if packet loss is happening (NACKs increasing)
|
||||
if currentNackCount > lastNackCount {
|
||||
consecutiveLossDetections++
|
||||
lastNackCount = currentNackCount
|
||||
|
||||
// If we detect packet loss and haven't requested a keyframe recently,
|
||||
// request one immediately to help receivers recover
|
||||
if timeSinceLastKeyframe >= minInterval {
|
||||
r.logger.Debug("Adaptive keyframe request due to packet loss",
|
||||
zap.String("track_id", localTrackID),
|
||||
zap.Int64("nack_count", currentNackCount),
|
||||
zap.Int("consecutive_loss_detections", consecutiveLossDetections),
|
||||
)
|
||||
r.RequestKeyframe(localTrackID)
|
||||
lastKeyframeTime = time.Now()
|
||||
}
|
||||
} else {
|
||||
// Reset consecutive loss counter when no new NACKs
|
||||
if consecutiveLossDetections > 0 {
|
||||
consecutiveLossDetections--
|
||||
}
|
||||
}
|
||||
|
||||
// Regular keyframe request at base interval (regardless of loss)
|
||||
if timeSinceLastKeyframe >= baseInterval {
|
||||
r.RequestKeyframe(localTrackID)
|
||||
lastKeyframeTime = time.Now()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Mark forwarder as stopped when we exit
|
||||
defer func() {
|
||||
r.publishedTracksMu.Lock()
|
||||
if pt, ok := r.publishedTracks[localTrackID]; ok {
|
||||
pt.forwarderActive = false
|
||||
pt.packetCount = packetCount
|
||||
}
|
||||
r.publishedTracksMu.Unlock()
|
||||
|
||||
r.logger.Info("RTP forwarder exiting",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("local_track_id", localTrackID),
|
||||
zap.String("kind", trackKind),
|
||||
zap.Duration("lifetime", time.Since(startTime)),
|
||||
zap.Int("total_packets", packetCount),
|
||||
zap.Int("total_bytes", byteCount),
|
||||
)
|
||||
}()
|
||||
|
||||
for {
|
||||
n, _, readErr := track.Read(buf)
|
||||
if readErr != nil {
|
||||
r.logger.Info("RTP forwarder stopped - read error",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("local_track_id", localTrackID),
|
||||
zap.String("kind", trackKind),
|
||||
zap.Int("packets_forwarded", packetCount),
|
||||
zap.Int("bytes_forwarded", byteCount),
|
||||
zap.Error(readErr),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Log first packet to confirm data is flowing
|
||||
if packetCount == 0 {
|
||||
firstPacketReceived = true
|
||||
r.logger.Info("RTP forwarder received FIRST packet",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("local_track_id", localTrackID),
|
||||
zap.String("kind", trackKind),
|
||||
zap.Int("packet_size", n),
|
||||
zap.Duration("time_to_first_packet", time.Since(startTime)),
|
||||
)
|
||||
}
|
||||
|
||||
if _, writeErr := localTrack.Write(buf[:n]); writeErr != nil {
|
||||
r.logger.Info("RTP forwarder stopped - write error",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("local_track_id", localTrackID),
|
||||
zap.String("kind", trackKind),
|
||||
zap.Int("packets_forwarded", packetCount),
|
||||
zap.Int("bytes_forwarded", byteCount),
|
||||
zap.Error(writeErr),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
packetCount++
|
||||
byteCount += n
|
||||
|
||||
// Update packet count periodically (not every packet to reduce lock contention)
|
||||
if packetCount%50 == 0 {
|
||||
r.publishedTracksMu.Lock()
|
||||
if pt, ok := r.publishedTracks[localTrackID]; ok {
|
||||
pt.packetCount = packetCount
|
||||
}
|
||||
r.publishedTracksMu.Unlock()
|
||||
}
|
||||
|
||||
// Log progress every 100 packets for video, 500 for audio
|
||||
logInterval := 500
|
||||
if trackKind == "video" {
|
||||
logInterval = 100
|
||||
}
|
||||
if packetCount%logInterval == 0 {
|
||||
r.logger.Info("RTP forwarder progress",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("kind", trackKind),
|
||||
zap.Int("packets", packetCount),
|
||||
zap.Int("bytes", byteCount),
|
||||
)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Add track to all other peers
|
||||
for peerID, peer := range r.peers {
|
||||
if peerID == sourcePeerID {
|
||||
continue
|
||||
}
|
||||
|
||||
r.logger.Info("Adding track to peer",
|
||||
zap.String("target_peer", peerID),
|
||||
zap.String("source_peer", sourcePeerID),
|
||||
zap.String("source_user", sourceUserID),
|
||||
zap.String("track_id", localTrack.ID()),
|
||||
zap.String("stream_id", localTrack.StreamID()),
|
||||
)
|
||||
|
||||
if _, err := peer.AddTrack(localTrack); err != nil {
|
||||
r.logger.Warn("Failed to add track to peer",
|
||||
zap.String("target_peer", peerID),
|
||||
zap.String("track_id", localTrack.ID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// Notify peer about new track
|
||||
// Use consistent IDs: participantId=sourcePeerID, streamId matches the track
|
||||
peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
|
||||
ParticipantID: sourcePeerID,
|
||||
UserID: sourceUserID,
|
||||
TrackID: localTrack.ID(), // Format: "{kind}-{participantId}"
|
||||
StreamID: localTrack.StreamID(), // Same as userId for easy matching
|
||||
Kind: track.Kind().String(),
|
||||
}))
|
||||
}
|
||||
|
||||
r.logger.Info("Track broadcast to room",
|
||||
zap.String("source_peer", sourcePeerID),
|
||||
zap.String("source_user", sourceUserID),
|
||||
zap.String("track_id", track.ID()),
|
||||
zap.String("kind", track.Kind().String()),
|
||||
)
|
||||
}
|
||||
|
||||
// broadcastMessage sends a message to all peers except the specified one
|
||||
func (r *Room) broadcastMessage(excludePeerID string, msg *ServerMessage) {
|
||||
r.peersMu.RLock()
|
||||
defer r.peersMu.RUnlock()
|
||||
|
||||
for peerID, peer := range r.peers {
|
||||
if peerID == excludePeerID {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := peer.SendMessage(msg); err != nil {
|
||||
r.logger.Warn("Failed to send message to peer",
|
||||
zap.String("peer_id", peerID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the room and all peer connections
|
||||
func (r *Room) Close() error {
|
||||
r.closedMu.Lock()
|
||||
if r.closed {
|
||||
r.closedMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
r.closed = true
|
||||
r.closedMu.Unlock()
|
||||
|
||||
r.logger.Info("Closing room")
|
||||
|
||||
r.peersMu.Lock()
|
||||
peers := make([]*Peer, 0, len(r.peers))
|
||||
for _, peer := range r.peers {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
r.peers = make(map[string]*Peer)
|
||||
r.peersMu.Unlock()
|
||||
|
||||
// Close all peers
|
||||
for _, peer := range peers {
|
||||
if err := peer.Close(); err != nil {
|
||||
r.logger.Warn("Error closing peer",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OnEmpty sets a callback for when the room becomes empty
|
||||
func (r *Room) OnEmpty(fn func(*Room)) {
|
||||
r.onEmpty = fn
|
||||
}
|
||||
|
||||
// IsClosed returns whether the room is closed
|
||||
func (r *Room) IsClosed() bool {
|
||||
r.closedMu.RLock()
|
||||
defer r.closedMu.RUnlock()
|
||||
return r.closed
|
||||
}
|
||||
|
||||
// RequestKeyframe sends a PLI (Picture Loss Indication) to the source peer for a video track.
|
||||
// This causes the source to send a keyframe, which is needed for new receivers to start decoding.
|
||||
func (r *Room) RequestKeyframe(trackID string) {
|
||||
r.publishedTracksMu.RLock()
|
||||
track, ok := r.publishedTracks[trackID]
|
||||
r.publishedTracksMu.RUnlock()
|
||||
|
||||
if !ok || track.kind != "video" {
|
||||
return
|
||||
}
|
||||
|
||||
r.peersMu.RLock()
|
||||
sourcePeer, ok := r.peers[track.sourcePeerID]
|
||||
r.peersMu.RUnlock()
|
||||
|
||||
if !ok || sourcePeer.pc == nil {
|
||||
r.logger.Debug("Cannot request keyframe - source peer not found",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("source_peer_id", track.sourcePeerID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Create a PLI packet
|
||||
pli := &rtcp.PictureLossIndication{
|
||||
MediaSSRC: track.remoteTrackSSRC,
|
||||
}
|
||||
|
||||
// Send the PLI to the source peer
|
||||
if err := sourcePeer.pc.WriteRTCP([]rtcp.Packet{pli}); err != nil {
|
||||
r.logger.Debug("Failed to send PLI",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("source_peer_id", track.sourcePeerID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
r.logger.Debug("PLI keyframe request sent",
|
||||
zap.String("track_id", trackID),
|
||||
zap.String("source_peer_id", track.sourcePeerID),
|
||||
zap.Uint32("ssrc", track.remoteTrackSSRC),
|
||||
)
|
||||
}
|
||||
|
||||
// RequestKeyframeForAllVideoTracks sends PLI requests for all video tracks in the room.
|
||||
// This is useful when a new peer joins to ensure they get keyframes quickly.
|
||||
func (r *Room) RequestKeyframeForAllVideoTracks() {
|
||||
r.publishedTracksMu.RLock()
|
||||
videoTrackIDs := make([]string, 0)
|
||||
for trackID, track := range r.publishedTracks {
|
||||
if track.kind == "video" {
|
||||
videoTrackIDs = append(videoTrackIDs, trackID)
|
||||
}
|
||||
}
|
||||
r.publishedTracksMu.RUnlock()
|
||||
|
||||
for _, trackID := range videoTrackIDs {
|
||||
r.RequestKeyframe(trackID)
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementNackCount increments the NACK counter for a track.
|
||||
// This is called when we receive NACK feedback indicating packet loss.
|
||||
func (r *Room) IncrementNackCount(trackID string) {
|
||||
r.publishedTracksMu.RLock()
|
||||
track, ok := r.publishedTracks[trackID]
|
||||
r.publishedTracksMu.RUnlock()
|
||||
|
||||
if ok {
|
||||
track.nackCount.Add(1)
|
||||
}
|
||||
}
|
||||
145
pkg/gateway/sfu/signaling.go
Normal file
145
pkg/gateway/sfu/signaling.go
Normal file
@ -0,0 +1,145 @@
|
||||
package sfu
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
)
|
||||
|
||||
// MessageType represents the type of signaling message
|
||||
type MessageType string
|
||||
|
||||
const (
|
||||
// Client -> Server message types
|
||||
MessageTypeJoin MessageType = "join"
|
||||
MessageTypeLeave MessageType = "leave"
|
||||
MessageTypeOffer MessageType = "offer"
|
||||
MessageTypeAnswer MessageType = "answer"
|
||||
MessageTypeICECandidate MessageType = "ice-candidate"
|
||||
MessageTypeMute MessageType = "mute"
|
||||
MessageTypeUnmute MessageType = "unmute"
|
||||
MessageTypeStartVideo MessageType = "start-video"
|
||||
MessageTypeStopVideo MessageType = "stop-video"
|
||||
|
||||
// Server -> Client message types
|
||||
MessageTypeWelcome MessageType = "welcome"
|
||||
MessageTypeParticipantJoined MessageType = "participant-joined"
|
||||
MessageTypeParticipantLeft MessageType = "participant-left"
|
||||
MessageTypeTrackAdded MessageType = "track-added"
|
||||
MessageTypeTrackRemoved MessageType = "track-removed"
|
||||
MessageTypeError MessageType = "error"
|
||||
)
|
||||
|
||||
// ClientMessage represents a message from client to server
|
||||
type ClientMessage struct {
|
||||
Type MessageType `json:"type"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ServerMessage represents a message from server to client
|
||||
type ServerMessage struct {
|
||||
Type MessageType `json:"type"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// JoinData is the payload for join messages
|
||||
type JoinData struct {
|
||||
DisplayName string `json:"displayName"`
|
||||
AudioOnly bool `json:"audioOnly,omitempty"`
|
||||
}
|
||||
|
||||
// OfferData is the payload for offer messages
|
||||
type OfferData struct {
|
||||
SDP string `json:"sdp"`
|
||||
}
|
||||
|
||||
// AnswerData is the payload for answer messages
|
||||
type AnswerData struct {
|
||||
SDP string `json:"sdp"`
|
||||
}
|
||||
|
||||
// ICECandidateData is the payload for ICE candidate messages
|
||||
type ICECandidateData struct {
|
||||
Candidate string `json:"candidate"`
|
||||
SDPMid string `json:"sdpMid,omitempty"`
|
||||
SDPMLineIndex uint16 `json:"sdpMLineIndex,omitempty"`
|
||||
UsernameFragment string `json:"usernameFragment,omitempty"`
|
||||
}
|
||||
|
||||
// ToWebRTCCandidate converts ICECandidateData to webrtc.ICECandidateInit
|
||||
func (c *ICECandidateData) ToWebRTCCandidate() webrtc.ICECandidateInit {
|
||||
return webrtc.ICECandidateInit{
|
||||
Candidate: c.Candidate,
|
||||
SDPMid: &c.SDPMid,
|
||||
SDPMLineIndex: &c.SDPMLineIndex,
|
||||
UsernameFragment: &c.UsernameFragment,
|
||||
}
|
||||
}
|
||||
|
||||
// WelcomeData is sent to a client when they successfully join
|
||||
type WelcomeData struct {
|
||||
ParticipantID string `json:"participantId"`
|
||||
RoomID string `json:"roomId"`
|
||||
Participants []ParticipantInfo `json:"participants"`
|
||||
}
|
||||
|
||||
// ParticipantInfo contains public information about a participant
|
||||
type ParticipantInfo struct {
|
||||
ID string `json:"id"`
|
||||
UserID string `json:"userId"`
|
||||
DisplayName string `json:"displayName"`
|
||||
HasAudio bool `json:"hasAudio"`
|
||||
HasVideo bool `json:"hasVideo"`
|
||||
AudioMuted bool `json:"audioMuted"`
|
||||
VideoMuted bool `json:"videoMuted"`
|
||||
}
|
||||
|
||||
// ParticipantJoinedData is sent when a new participant joins
|
||||
type ParticipantJoinedData struct {
|
||||
Participant ParticipantInfo `json:"participant"`
|
||||
}
|
||||
|
||||
// ParticipantLeftData is sent when a participant leaves
|
||||
type ParticipantLeftData struct {
|
||||
ParticipantID string `json:"participantId"`
|
||||
}
|
||||
|
||||
// TrackAddedData is sent when a new track is available
|
||||
type TrackAddedData struct {
|
||||
ParticipantID string `json:"participantId"` // Internal SFU peer ID
|
||||
UserID string `json:"userId"` // The user's actual ID for easier mapping
|
||||
TrackID string `json:"trackId"` // Format: "{kind}-{participantId}"
|
||||
StreamID string `json:"streamId"` // Same as userId (for WebRTC stream matching)
|
||||
Kind string `json:"kind"` // "audio" or "video"
|
||||
}
|
||||
|
||||
// TrackRemovedData is sent when a track is removed
|
||||
type TrackRemovedData struct {
|
||||
ParticipantID string `json:"participantId"`
|
||||
UserID string `json:"userId"` // The user's actual ID for easier mapping
|
||||
TrackID string `json:"trackId"`
|
||||
StreamID string `json:"streamId"`
|
||||
Kind string `json:"kind"`
|
||||
}
|
||||
|
||||
// ErrorData is sent when an error occurs
|
||||
type ErrorData struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// NewServerMessage creates a new server message
|
||||
func NewServerMessage(msgType MessageType, data interface{}) *ServerMessage {
|
||||
return &ServerMessage{
|
||||
Type: msgType,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
// NewErrorMessage creates a new error message
|
||||
func NewErrorMessage(code, message string) *ServerMessage {
|
||||
return NewServerMessage(MessageTypeError, &ErrorData{
|
||||
Code: code,
|
||||
Message: message,
|
||||
})
|
||||
}
|
||||
573
pkg/gateway/sfu_handlers.go
Normal file
573
pkg/gateway/sfu_handlers.go
Normal file
@ -0,0 +1,573 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/sfu"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/pion/webrtc/v4"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// SFUManager is a type alias for the SFU room manager
|
||||
type SFUManager = sfu.RoomManager
|
||||
|
||||
// CreateRoomRequest is the request body for creating a room
|
||||
type CreateRoomRequest struct {
|
||||
RoomID string `json:"roomId"`
|
||||
}
|
||||
|
||||
// CreateRoomResponse is the response for creating/joining a room
|
||||
type CreateRoomResponse struct {
|
||||
RoomID string `json:"roomId"`
|
||||
Created bool `json:"created"`
|
||||
RTPCapabilities map[string]interface{} `json:"rtpCapabilities"`
|
||||
}
|
||||
|
||||
// JoinRoomRequest is the request body for joining a room
|
||||
type JoinRoomRequest struct {
|
||||
DisplayName string `json:"displayName"`
|
||||
}
|
||||
|
||||
// JoinRoomResponse is the response for joining a room
|
||||
type JoinRoomResponse struct {
|
||||
ParticipantID string `json:"participantId"`
|
||||
Participants []sfu.ParticipantInfo `json:"participants"`
|
||||
}
|
||||
|
||||
// sfuCreateRoomHandler handles POST /v1/sfu/room
|
||||
func (g *Gateway) sfuCreateRoomHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if SFU is enabled
|
||||
if g.sfuManager == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "SFU service not enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Get namespace from auth context
|
||||
ns := resolveNamespaceFromRequest(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request
|
||||
var req CreateRoomRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.RoomID == "" {
|
||||
writeError(w, http.StatusBadRequest, "roomId is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Create or get room
|
||||
room, created := g.sfuManager.GetOrCreateRoom(ns, req.RoomID)
|
||||
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU room request",
|
||||
zap.String("room_id", req.RoomID),
|
||||
zap.String("namespace", ns),
|
||||
zap.Bool("created", created),
|
||||
zap.Int("participants", room.GetParticipantCount()),
|
||||
)
|
||||
|
||||
writeJSON(w, http.StatusOK, &CreateRoomResponse{
|
||||
RoomID: req.RoomID,
|
||||
Created: created,
|
||||
RTPCapabilities: sfu.GetRTPCapabilities(),
|
||||
})
|
||||
}
|
||||
|
||||
// sfuRoomHandler handles all /v1/sfu/room/:roomId/* endpoints
|
||||
func (g *Gateway) sfuRoomHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// Check if SFU is enabled
|
||||
if g.sfuManager == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, "SFU service not enabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Get namespace from auth context
|
||||
ns := resolveNamespaceFromRequest(r)
|
||||
if ns == "" {
|
||||
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse room ID and action from path
|
||||
// Path format: /v1/sfu/room/{roomId}/{action}
|
||||
path := strings.TrimPrefix(r.URL.Path, "/v1/sfu/room/")
|
||||
parts := strings.SplitN(path, "/", 2)
|
||||
|
||||
if len(parts) == 0 || parts[0] == "" {
|
||||
writeError(w, http.StatusBadRequest, "room ID required")
|
||||
return
|
||||
}
|
||||
|
||||
roomID := parts[0]
|
||||
action := ""
|
||||
if len(parts) > 1 {
|
||||
action = parts[1]
|
||||
}
|
||||
|
||||
// Route to appropriate handler
|
||||
switch action {
|
||||
case "":
|
||||
// GET /v1/sfu/room/:roomId - Get room info
|
||||
g.sfuGetRoomHandler(w, r, ns, roomID)
|
||||
case "join":
|
||||
// POST /v1/sfu/room/:roomId/join - Join room
|
||||
g.sfuJoinRoomHandler(w, r, ns, roomID)
|
||||
case "leave":
|
||||
// POST /v1/sfu/room/:roomId/leave - Leave room
|
||||
g.sfuLeaveRoomHandler(w, r, ns, roomID)
|
||||
case "participants":
|
||||
// GET /v1/sfu/room/:roomId/participants - List participants
|
||||
g.sfuParticipantsHandler(w, r, ns, roomID)
|
||||
case "ws":
|
||||
// GET /v1/sfu/room/:roomId/ws - WebSocket signaling
|
||||
g.sfuWebSocketHandler(w, r, ns, roomID)
|
||||
default:
|
||||
writeError(w, http.StatusNotFound, "unknown action")
|
||||
}
|
||||
}
|
||||
|
||||
// sfuGetRoomHandler handles GET /v1/sfu/room/:roomId
|
||||
func (g *Gateway) sfuGetRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
info, err := g.sfuManager.GetRoomInfo(ns, roomID)
|
||||
if err != nil {
|
||||
if err == sfu.ErrRoomNotFound {
|
||||
writeError(w, http.StatusNotFound, "room not found")
|
||||
} else {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, info)
|
||||
}
|
||||
|
||||
// sfuJoinRoomHandler handles POST /v1/sfu/room/:roomId/join
|
||||
func (g *Gateway) sfuJoinRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request
|
||||
var req JoinRoomRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
// Default display name if not provided
|
||||
req.DisplayName = "Anonymous"
|
||||
}
|
||||
|
||||
// Get user ID from query param or auth context
|
||||
userID := r.URL.Query().Get("userId")
|
||||
if userID == "" {
|
||||
userID = g.extractUserID(r)
|
||||
}
|
||||
if userID == "" {
|
||||
userID = "anonymous"
|
||||
}
|
||||
|
||||
// Get or create room
|
||||
room, _ := g.sfuManager.GetOrCreateRoom(ns, roomID)
|
||||
|
||||
// This endpoint just returns room info
|
||||
// The actual peer connection is established via WebSocket
|
||||
writeJSON(w, http.StatusOK, &JoinRoomResponse{
|
||||
ParticipantID: "", // Will be assigned when WebSocket connects
|
||||
Participants: room.GetParticipants(),
|
||||
})
|
||||
}
|
||||
|
||||
// sfuLeaveRoomHandler handles POST /v1/sfu/room/:roomId/leave
|
||||
func (g *Gateway) sfuLeaveRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||
if r.Method != http.MethodPost {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
// Parse participant ID from request body
|
||||
var req struct {
|
||||
ParticipantID string `json:"participantId"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.ParticipantID == "" {
|
||||
writeError(w, http.StatusBadRequest, "participantId required")
|
||||
return
|
||||
}
|
||||
|
||||
room, err := g.sfuManager.GetRoom(ns, roomID)
|
||||
if err != nil {
|
||||
if err == sfu.ErrRoomNotFound {
|
||||
writeError(w, http.StatusNotFound, "room not found")
|
||||
} else {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := room.RemovePeer(req.ParticipantID); err != nil {
|
||||
if err == sfu.ErrPeerNotFound {
|
||||
writeError(w, http.StatusNotFound, "participant not found")
|
||||
} else {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"success": true,
|
||||
})
|
||||
}
|
||||
|
||||
// sfuParticipantsHandler handles GET /v1/sfu/room/:roomId/participants
|
||||
func (g *Gateway) sfuParticipantsHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
room, err := g.sfuManager.GetRoom(ns, roomID)
|
||||
if err != nil {
|
||||
if err == sfu.ErrRoomNotFound {
|
||||
writeError(w, http.StatusNotFound, "room not found")
|
||||
} else {
|
||||
writeError(w, http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"participants": room.GetParticipants(),
|
||||
})
|
||||
}
|
||||
|
||||
// sfuWebSocketHandler handles WebSocket signaling for a room
|
||||
func (g *Gateway) sfuWebSocketHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||
if r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
// Get user ID and display name from query parameters
|
||||
// Priority: 1) userId query param, 2) JWT/API key auth, 3) "anonymous"
|
||||
userID := r.URL.Query().Get("userId")
|
||||
if userID == "" {
|
||||
// Fall back to authentication-based user ID
|
||||
userID = g.extractUserID(r)
|
||||
}
|
||||
if userID == "" {
|
||||
userID = "anonymous"
|
||||
}
|
||||
|
||||
displayName := r.URL.Query().Get("displayName")
|
||||
if displayName == "" {
|
||||
displayName = "Anonymous"
|
||||
}
|
||||
|
||||
// Upgrade to WebSocket
|
||||
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
g.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket upgrade failed",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// Recover from panics to avoid silent crashes
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
g.logger.ComponentError(logging.ComponentGeneral, "SFU WebSocket handler panic",
|
||||
zap.Any("panic", r),
|
||||
zap.String("room_id", roomID),
|
||||
zap.String("user_id", userID),
|
||||
)
|
||||
conn.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU WebSocket connected",
|
||||
zap.String("room_id", roomID),
|
||||
zap.String("namespace", ns),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("display_name", displayName),
|
||||
)
|
||||
|
||||
// Get or create room
|
||||
g.logger.ComponentDebug(logging.ComponentGeneral, "SFU getting/creating room",
|
||||
zap.String("room_id", roomID),
|
||||
zap.String("namespace", ns),
|
||||
)
|
||||
room, created := g.sfuManager.GetOrCreateRoom(ns, roomID)
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU room ready",
|
||||
zap.String("room_id", roomID),
|
||||
zap.Bool("created", created),
|
||||
)
|
||||
|
||||
// Create peer
|
||||
peer := sfu.NewPeer(userID, displayName, conn, room, g.logger.Logger)
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU peer created",
|
||||
zap.String("peer_id", peer.ID),
|
||||
)
|
||||
|
||||
// Add peer to room
|
||||
g.logger.ComponentDebug(logging.ComponentGeneral, "SFU adding peer to room (will init peer connection)",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.String("room_id", roomID),
|
||||
)
|
||||
if err := room.AddPeer(peer); err != nil {
|
||||
g.logger.ComponentError(logging.ComponentGeneral, "Failed to add peer to room",
|
||||
zap.String("room_id", roomID),
|
||||
zap.Error(err),
|
||||
)
|
||||
peer.SendMessage(sfu.NewErrorMessage("join_failed", err.Error()))
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU peer added to room successfully",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.String("room_id", roomID),
|
||||
)
|
||||
|
||||
// Send welcome message
|
||||
g.logger.ComponentDebug(logging.ComponentGeneral, "SFU preparing welcome message",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.Int("num_participants", len(room.GetParticipants())),
|
||||
)
|
||||
welcomeMsg := sfu.NewServerMessage(sfu.MessageTypeWelcome, &sfu.WelcomeData{
|
||||
ParticipantID: peer.ID,
|
||||
RoomID: roomID,
|
||||
Participants: room.GetParticipants(),
|
||||
})
|
||||
g.logger.ComponentDebug(logging.ComponentGeneral, "SFU sending welcome message via SendMessage",
|
||||
zap.String("peer_id", peer.ID),
|
||||
)
|
||||
if err := peer.SendMessage(welcomeMsg); err != nil {
|
||||
g.logger.ComponentError(logging.ComponentGeneral, "Failed to send welcome message",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU welcome message sent successfully",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.String("room_id", roomID),
|
||||
)
|
||||
|
||||
// Handle signaling messages
|
||||
// Note: existing tracks are sent AFTER the first offer/answer exchange completes
|
||||
g.handleSFUSignaling(conn, peer, room)
|
||||
}
|
||||
|
||||
// handleSFUSignaling handles WebSocket signaling messages for a peer
|
||||
func (g *Gateway) handleSFUSignaling(conn *websocket.Conn, peer *sfu.Peer, room *sfu.Room) {
|
||||
defer func() {
|
||||
room.RemovePeer(peer.ID)
|
||||
conn.Close()
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU WebSocket disconnected",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.String("room_id", room.ID),
|
||||
)
|
||||
}()
|
||||
|
||||
// Set up ping/pong for keepalive
|
||||
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Start ping ticker
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
go func() {
|
||||
for range ticker.C {
|
||||
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Read messages
|
||||
for {
|
||||
_, data, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
g.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket read error",
|
||||
zap.String("peer_id", peer.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Reset read deadline
|
||||
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||
|
||||
// Parse message
|
||||
var msg sfu.ClientMessage
|
||||
if err := json.Unmarshal(data, &msg); err != nil {
|
||||
peer.SendMessage(sfu.NewErrorMessage("invalid_message", "failed to parse message"))
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle message
|
||||
g.handleSFUMessage(peer, room, &msg)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSFUMessage handles a single signaling message
|
||||
func (g *Gateway) handleSFUMessage(peer *sfu.Peer, room *sfu.Room, msg *sfu.ClientMessage) {
|
||||
switch msg.Type {
|
||||
case sfu.MessageTypeOffer:
|
||||
var data sfu.OfferData
|
||||
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||
peer.SendMessage(sfu.NewErrorMessage("invalid_offer", err.Error()))
|
||||
return
|
||||
}
|
||||
if err := peer.HandleOffer(data.SDP); err != nil {
|
||||
peer.SendMessage(sfu.NewErrorMessage("offer_failed", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// After successfully handling the FIRST offer, send existing tracks
|
||||
// This ensures the WebRTC connection is established before adding more tracks
|
||||
if peer.MarkInitialOfferHandled() {
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "First offer handled, sending existing tracks",
|
||||
zap.String("peer_id", peer.ID),
|
||||
)
|
||||
room.SendExistingTracksTo(peer)
|
||||
}
|
||||
|
||||
case sfu.MessageTypeAnswer:
|
||||
var data sfu.AnswerData
|
||||
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||
peer.SendMessage(sfu.NewErrorMessage("invalid_answer", err.Error()))
|
||||
return
|
||||
}
|
||||
if err := peer.HandleAnswer(data.SDP); err != nil {
|
||||
peer.SendMessage(sfu.NewErrorMessage("answer_failed", err.Error()))
|
||||
}
|
||||
|
||||
case sfu.MessageTypeICECandidate:
|
||||
var data sfu.ICECandidateData
|
||||
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||
peer.SendMessage(sfu.NewErrorMessage("invalid_candidate", err.Error()))
|
||||
return
|
||||
}
|
||||
if err := peer.HandleICECandidate(&data); err != nil {
|
||||
peer.SendMessage(sfu.NewErrorMessage("candidate_failed", err.Error()))
|
||||
}
|
||||
|
||||
case sfu.MessageTypeMute:
|
||||
peer.SetAudioMuted(true)
|
||||
g.logger.ComponentDebug(logging.ComponentGeneral, "Peer muted audio", zap.String("peer_id", peer.ID))
|
||||
|
||||
case sfu.MessageTypeUnmute:
|
||||
peer.SetAudioMuted(false)
|
||||
g.logger.ComponentDebug(logging.ComponentGeneral, "Peer unmuted audio", zap.String("peer_id", peer.ID))
|
||||
|
||||
case sfu.MessageTypeStartVideo:
|
||||
peer.SetVideoMuted(false)
|
||||
g.logger.ComponentDebug(logging.ComponentGeneral, "Peer started video", zap.String("peer_id", peer.ID))
|
||||
|
||||
case sfu.MessageTypeStopVideo:
|
||||
peer.SetVideoMuted(true)
|
||||
g.logger.ComponentDebug(logging.ComponentGeneral, "Peer stopped video", zap.String("peer_id", peer.ID))
|
||||
|
||||
case sfu.MessageTypeLeave:
|
||||
// Will be handled by deferred cleanup
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "Peer leaving room", zap.String("peer_id", peer.ID))
|
||||
|
||||
default:
|
||||
peer.SendMessage(sfu.NewErrorMessage("unknown_type", "unknown message type"))
|
||||
}
|
||||
}
|
||||
|
||||
// initializeSFUManager initializes the SFU manager with the gateway's TURN config
|
||||
func (g *Gateway) initializeSFUManager() error {
|
||||
if g.cfg.SFU == nil || !g.cfg.SFU.Enabled {
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU service disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build ICE servers from config
|
||||
iceServers := make([]webrtc.ICEServer, 0)
|
||||
|
||||
// Add configured ICE servers
|
||||
if g.cfg.SFU.ICEServers != nil {
|
||||
for _, server := range g.cfg.SFU.ICEServers {
|
||||
iceServers = append(iceServers, webrtc.ICEServer{
|
||||
URLs: server.URLs,
|
||||
Username: server.Username,
|
||||
Credential: server.Credential,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Add TURN servers if configured (credentials will be generated dynamically)
|
||||
if g.cfg.TURN != nil {
|
||||
// Determine hostname for ICE server URLs
|
||||
// Use configured domain, or fallback to localhost for local development
|
||||
iceHost := g.cfg.DomainName
|
||||
if iceHost == "" {
|
||||
iceHost = "localhost"
|
||||
}
|
||||
|
||||
if len(g.cfg.TURN.STUNURLs) > 0 {
|
||||
// Process URLs to replace empty hostnames (e.g., "stun::3478" -> "stun:localhost:3478")
|
||||
processedURLs := processURLsWithHost(g.cfg.TURN.STUNURLs, iceHost)
|
||||
iceServers = append(iceServers, webrtc.ICEServer{
|
||||
URLs: processedURLs,
|
||||
})
|
||||
}
|
||||
// Note: TURN credentials are time-limited, so clients should fetch them
|
||||
// via the /v1/turn/credentials endpoint before joining a room
|
||||
}
|
||||
|
||||
// Create SFU config
|
||||
sfuConfig := &sfu.Config{
|
||||
MaxParticipants: g.cfg.SFU.MaxParticipants,
|
||||
MediaTimeout: g.cfg.SFU.MediaTimeout,
|
||||
ICEServers: iceServers,
|
||||
}
|
||||
|
||||
if sfuConfig.MaxParticipants == 0 {
|
||||
sfuConfig.MaxParticipants = 10
|
||||
}
|
||||
if sfuConfig.MediaTimeout == 0 {
|
||||
sfuConfig.MediaTimeout = 30 * time.Second
|
||||
}
|
||||
|
||||
manager, err := sfu.NewRoomManager(sfuConfig, g.logger.Logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
g.sfuManager = manager
|
||||
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU manager initialized",
|
||||
zap.Int("max_participants", sfuConfig.MaxParticipants),
|
||||
zap.Duration("media_timeout", sfuConfig.MediaTimeout),
|
||||
zap.Int("ice_servers", len(iceServers)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
192
pkg/gateway/turn_handlers.go
Normal file
192
pkg/gateway/turn_handlers.go
Normal file
@ -0,0 +1,192 @@
|
||||
package gateway
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TURNCredentialsResponse is the response for TURN credential requests
|
||||
type TURNCredentialsResponse struct {
|
||||
Username string `json:"username"` // Format: "timestamp:userId"
|
||||
Credential string `json:"credential"` // HMAC-SHA1(username, shared_secret) base64 encoded
|
||||
TTL int64 `json:"ttl"` // Time-to-live in seconds
|
||||
STUNURLs []string `json:"stun_urls"` // STUN server URLs
|
||||
TURNURLs []string `json:"turn_urls"` // TURN server URLs
|
||||
}
|
||||
|
||||
// turnCredentialsHandler handles POST /v1/turn/credentials
|
||||
// Returns time-limited TURN credentials for WebRTC connections
|
||||
func (g *Gateway) turnCredentialsHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost && r.Method != http.MethodGet {
|
||||
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if TURN is configured
|
||||
if g.cfg.TURN == nil || g.cfg.TURN.SharedSecret == "" {
|
||||
g.logger.ComponentWarn(logging.ComponentGeneral, "TURN credentials requested but not configured")
|
||||
writeError(w, http.StatusServiceUnavailable, "TURN service not configured")
|
||||
return
|
||||
}
|
||||
|
||||
// Get user ID from JWT claims or API key
|
||||
userID := g.extractUserID(r)
|
||||
if userID == "" {
|
||||
userID = "anonymous"
|
||||
}
|
||||
|
||||
// Get gateway hostname from request
|
||||
gatewayHost := r.Host
|
||||
if idx := strings.Index(gatewayHost, ":"); idx != -1 {
|
||||
gatewayHost = gatewayHost[:idx] // Remove port
|
||||
}
|
||||
if gatewayHost == "" {
|
||||
gatewayHost = "localhost"
|
||||
}
|
||||
|
||||
// Generate credentials
|
||||
credentials := g.generateTURNCredentials(userID, gatewayHost)
|
||||
|
||||
g.logger.ComponentInfo(logging.ComponentGeneral, "TURN credentials generated",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("gateway_host", gatewayHost),
|
||||
zap.Int64("ttl", credentials.TTL),
|
||||
)
|
||||
|
||||
writeJSON(w, http.StatusOK, credentials)
|
||||
}
|
||||
|
||||
// generateTURNCredentials creates time-limited TURN credentials using HMAC-SHA1
|
||||
func (g *Gateway) generateTURNCredentials(userID, gatewayHost string) *TURNCredentialsResponse {
|
||||
cfg := g.cfg.TURN
|
||||
|
||||
// Default TTL to 24 hours if not configured
|
||||
ttl := cfg.TTL
|
||||
if ttl == 0 {
|
||||
ttl = 24 * time.Hour
|
||||
}
|
||||
|
||||
// Calculate expiry timestamp
|
||||
timestamp := time.Now().Unix() + int64(ttl.Seconds())
|
||||
|
||||
// Format: "timestamp:userId" (coturn format)
|
||||
username := fmt.Sprintf("%d:%s", timestamp, userID)
|
||||
|
||||
// Generate HMAC-SHA1 credential
|
||||
h := hmac.New(sha1.New, []byte(cfg.SharedSecret))
|
||||
h.Write([]byte(username))
|
||||
credential := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
// Determine the host to use for STUN/TURN URLs
|
||||
// Priority: 1) ExternalHost from config, 2) Auto-detect LAN IP, 3) Gateway host from request
|
||||
host := cfg.ExternalHost
|
||||
if host == "" {
|
||||
// Auto-detect LAN IP for development
|
||||
host = detectLANIP()
|
||||
if host == "" {
|
||||
// Fallback to gateway host from request (may be localhost)
|
||||
host = gatewayHost
|
||||
}
|
||||
}
|
||||
|
||||
// Process URLs - replace empty hostnames (::) with determined host
|
||||
stunURLs := processURLsWithHost(cfg.STUNURLs, host)
|
||||
turnURLs := processURLsWithHost(cfg.TURNURLs, host)
|
||||
|
||||
// If TLS is enabled, ensure we have turns:// URLs
|
||||
if cfg.TLSEnabled {
|
||||
hasTurns := false
|
||||
for _, url := range turnURLs {
|
||||
if strings.HasPrefix(url, "turns:") {
|
||||
hasTurns = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// Auto-add turns:// URL if not already configured
|
||||
if !hasTurns {
|
||||
turnsURL := fmt.Sprintf("turns:%s:443?transport=tcp", host)
|
||||
turnURLs = append(turnURLs, turnsURL)
|
||||
}
|
||||
}
|
||||
|
||||
return &TURNCredentialsResponse{
|
||||
Username: username,
|
||||
Credential: credential,
|
||||
TTL: int64(ttl.Seconds()),
|
||||
STUNURLs: stunURLs,
|
||||
TURNURLs: turnURLs,
|
||||
}
|
||||
}
|
||||
|
||||
// detectLANIP returns the first non-loopback IPv4 address found on the system.
|
||||
// Returns empty string if no suitable address is found.
|
||||
func detectLANIP() string {
|
||||
addrs, err := net.InterfaceAddrs()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, addr := range addrs {
|
||||
if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
|
||||
if ipNet.IP.To4() != nil {
|
||||
return ipNet.IP.String()
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// processURLsWithHost replaces empty hostnames in URLs with the given host
|
||||
// Supports two patterns:
|
||||
// - "stun:::3478" (triple colon) -> "stun:host:3478"
|
||||
// - "stun::3478" (double colon) -> "stun:host:3478"
|
||||
func processURLsWithHost(urls []string, host string) []string {
|
||||
result := make([]string, 0, len(urls))
|
||||
for _, url := range urls {
|
||||
// Check for triple colon pattern first (e.g., "stun:::3478")
|
||||
// This is the preferred format: protocol:::port
|
||||
if strings.Contains(url, ":::") {
|
||||
url = strings.Replace(url, ":::", ":"+host+":", 1)
|
||||
} else if strings.Contains(url, "::") {
|
||||
// Fallback for double colon pattern (e.g., "stun::3478")
|
||||
url = strings.Replace(url, "::", ":"+host+":", 1)
|
||||
}
|
||||
result = append(result, url)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// extractUserID extracts the user ID from the request context
|
||||
func (g *Gateway) extractUserID(r *http.Request) string {
|
||||
ctx := r.Context()
|
||||
|
||||
// Try JWT claims first
|
||||
if v := ctx.Value(ctxKeyJWT); v != nil {
|
||||
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||
if claims.Sub != "" {
|
||||
return claims.Sub
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to API key
|
||||
if v := ctx.Value(ctxKeyAPIKey); v != nil {
|
||||
if key, ok := v.(string); ok && key != "" {
|
||||
// Use a hash of the API key as the user ID for privacy
|
||||
return fmt.Sprintf("ak_%s", key[:min(8, len(key))])
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
@ -33,19 +33,21 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
|
||||
}
|
||||
|
||||
gwCfg := &gateway.Config{
|
||||
ListenAddr: n.config.HTTPGateway.ListenAddr,
|
||||
ClientNamespace: n.config.HTTPGateway.ClientNamespace,
|
||||
BootstrapPeers: n.config.Discovery.BootstrapPeers,
|
||||
NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir),
|
||||
RQLiteDSN: n.config.HTTPGateway.RQLiteDSN,
|
||||
OlricServers: n.config.HTTPGateway.OlricServers,
|
||||
OlricTimeout: n.config.HTTPGateway.OlricTimeout,
|
||||
ListenAddr: n.config.HTTPGateway.ListenAddr,
|
||||
ClientNamespace: n.config.HTTPGateway.ClientNamespace,
|
||||
BootstrapPeers: n.config.Discovery.BootstrapPeers,
|
||||
NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir),
|
||||
RQLiteDSN: n.config.HTTPGateway.RQLiteDSN,
|
||||
OlricServers: n.config.HTTPGateway.OlricServers,
|
||||
OlricTimeout: n.config.HTTPGateway.OlricTimeout,
|
||||
IPFSClusterAPIURL: n.config.HTTPGateway.IPFSClusterAPIURL,
|
||||
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
|
||||
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
|
||||
EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled,
|
||||
DomainName: n.config.HTTPGateway.HTTPS.Domain,
|
||||
TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir,
|
||||
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
|
||||
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
|
||||
EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled,
|
||||
DomainName: n.config.HTTPGateway.HTTPS.Domain,
|
||||
TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir,
|
||||
TURN: n.config.HTTPGateway.TURN,
|
||||
SFU: n.config.HTTPGateway.SFU,
|
||||
}
|
||||
|
||||
apiGateway, err := gateway.New(gatewayLogger, gwCfg)
|
||||
|
||||
@ -16,6 +16,7 @@ import (
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||
database "github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||
"github.com/libp2p/go-libp2p/core/host"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
@ -55,6 +56,9 @@ type Node struct {
|
||||
|
||||
// Certificate ready signal - closed when TLS certificates are extracted and ready for use
|
||||
certReady chan struct{}
|
||||
|
||||
// Built-in TURN server for WebRTC NAT traversal
|
||||
turnServer *turn.Server
|
||||
}
|
||||
|
||||
// NewNode creates a new network node
|
||||
@ -96,6 +100,11 @@ func (n *Node) Start(ctx context.Context) error {
|
||||
n.logger.ComponentWarn(logging.ComponentNode, "Failed to start HTTP Gateway", zap.Error(err))
|
||||
}
|
||||
|
||||
// Start built-in TURN server if enabled
|
||||
if err := n.startTURNServer(); err != nil {
|
||||
n.logger.ComponentWarn(logging.ComponentNode, "Failed to start TURN server", zap.Error(err))
|
||||
}
|
||||
|
||||
// Start LibP2P host first (needed for cluster discovery)
|
||||
if err := n.startLibP2P(); err != nil {
|
||||
return fmt.Errorf("failed to start LibP2P: %w", err)
|
||||
@ -135,6 +144,11 @@ func (n *Node) Start(ctx context.Context) error {
|
||||
func (n *Node) Stop() error {
|
||||
n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node")
|
||||
|
||||
// Stop TURN server
|
||||
if n.turnServer != nil {
|
||||
_ = n.turnServer.Stop()
|
||||
}
|
||||
|
||||
// Stop HTTP Gateway server
|
||||
if n.apiGatewayServer != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
|
||||
90
pkg/node/turn.go
Normal file
90
pkg/node/turn.go
Normal file
@ -0,0 +1,90 @@
|
||||
package node
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// startTURNServer initializes and starts the built-in TURN server
|
||||
func (n *Node) startTURNServer() error {
|
||||
if !n.config.TURNServer.Enabled {
|
||||
n.logger.ComponentInfo(logging.ComponentNode, "Built-in TURN server disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
n.logger.ComponentInfo(logging.ComponentNode, "Starting built-in TURN server")
|
||||
|
||||
// Get shared secret - env var takes priority over config file (for production)
|
||||
sharedSecret := os.Getenv("TURN_SHARED_SECRET")
|
||||
if sharedSecret == "" && n.config.HTTPGateway.TURN != nil && n.config.HTTPGateway.TURN.SharedSecret != "" {
|
||||
sharedSecret = n.config.HTTPGateway.TURN.SharedSecret
|
||||
}
|
||||
|
||||
if sharedSecret == "" {
|
||||
n.logger.ComponentWarn(logging.ComponentNode, "TURN server enabled but no shared_secret configured in http_gateway.turn")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get public IP - env var takes priority over config file (for production)
|
||||
publicIP := os.Getenv("TURN_PUBLIC_IP")
|
||||
if publicIP == "" {
|
||||
publicIP = n.config.TURNServer.PublicIP
|
||||
}
|
||||
|
||||
// Build TURN server config
|
||||
turnCfg := &turn.Config{
|
||||
Enabled: true,
|
||||
ListenAddr: n.config.TURNServer.ListenAddr,
|
||||
PublicIP: publicIP,
|
||||
Realm: n.config.TURNServer.Realm,
|
||||
SharedSecret: sharedSecret,
|
||||
CredentialTTL: 24 * 60 * 60, // 24 hours in seconds (will be converted)
|
||||
MinPort: n.config.TURNServer.MinPort,
|
||||
MaxPort: n.config.TURNServer.MaxPort,
|
||||
// TLS configuration for TURNS
|
||||
TLSEnabled: n.config.TURNServer.TLSEnabled,
|
||||
TLSListenAddr: n.config.TURNServer.TLSListenAddr,
|
||||
TLSCertFile: n.config.TURNServer.TLSCertFile,
|
||||
TLSKeyFile: n.config.TURNServer.TLSKeyFile,
|
||||
}
|
||||
|
||||
// Apply defaults
|
||||
if turnCfg.ListenAddr == "" {
|
||||
turnCfg.ListenAddr = "0.0.0.0:3478"
|
||||
}
|
||||
if turnCfg.Realm == "" {
|
||||
turnCfg.Realm = "orama.network"
|
||||
}
|
||||
if turnCfg.MinPort == 0 {
|
||||
turnCfg.MinPort = 49152
|
||||
}
|
||||
if turnCfg.MaxPort == 0 {
|
||||
turnCfg.MaxPort = 65535
|
||||
}
|
||||
if turnCfg.TLSListenAddr == "" && turnCfg.TLSEnabled {
|
||||
turnCfg.TLSListenAddr = "0.0.0.0:443"
|
||||
}
|
||||
|
||||
// Create and start TURN server
|
||||
server, err := turn.NewServer(turnCfg, n.logger.Logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := server.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
n.turnServer = server
|
||||
|
||||
n.logger.ComponentInfo(logging.ComponentNode, "Built-in TURN server started",
|
||||
zap.String("listen_addr", turnCfg.ListenAddr),
|
||||
zap.String("realm", turnCfg.Realm),
|
||||
zap.Bool("turns_enabled", turnCfg.TLSEnabled),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -214,3 +214,9 @@ func IsServiceUnavailable(err error) bool {
|
||||
errors.Is(err, ErrDatabaseUnavailable) ||
|
||||
errors.Is(err, ErrCacheUnavailable)
|
||||
}
|
||||
|
||||
// IsValidationError checks if an error is a validation error.
|
||||
func IsValidationError(err error) bool {
|
||||
var validationErr *ValidationError
|
||||
return errors.As(err, &validationErr)
|
||||
}
|
||||
|
||||
239
pkg/serverless/triggers.go
Normal file
239
pkg/serverless/triggers.go
Normal file
@ -0,0 +1,239 @@
|
||||
package serverless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Ensure DBTriggerManager implements TriggerManager interface.
|
||||
var _ TriggerManager = (*DBTriggerManager)(nil)
|
||||
|
||||
// DBTriggerManager manages function triggers using RQLite database.
|
||||
type DBTriggerManager struct {
|
||||
db rqlite.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewDBTriggerManager creates a new trigger manager.
|
||||
func NewDBTriggerManager(db rqlite.Client, logger *zap.Logger) *DBTriggerManager {
|
||||
return &DBTriggerManager{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// pubsubTriggerRow represents a row from the function_pubsub_triggers table.
|
||||
type pubsubTriggerRow struct {
|
||||
ID string `db:"id"`
|
||||
FunctionID string `db:"function_id"`
|
||||
Topic string `db:"topic"`
|
||||
Enabled bool `db:"enabled"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
}
|
||||
|
||||
// AddPubSubTrigger adds a pubsub trigger to a function.
|
||||
func (m *DBTriggerManager) AddPubSubTrigger(ctx context.Context, functionID, topic string) error {
|
||||
functionID = strings.TrimSpace(functionID)
|
||||
topic = strings.TrimSpace(topic)
|
||||
|
||||
if functionID == "" {
|
||||
return &ValidationError{Field: "function_id", Message: "cannot be empty"}
|
||||
}
|
||||
if topic == "" {
|
||||
return &ValidationError{Field: "topic", Message: "cannot be empty"}
|
||||
}
|
||||
|
||||
// Check if trigger already exists for this function and topic
|
||||
var existing []pubsubTriggerRow
|
||||
checkQuery := `SELECT id FROM function_pubsub_triggers WHERE function_id = ? AND topic = ? AND enabled = TRUE`
|
||||
if err := m.db.Query(ctx, &existing, checkQuery, functionID, topic); err != nil {
|
||||
return fmt.Errorf("failed to check existing trigger: %w", err)
|
||||
}
|
||||
if len(existing) > 0 {
|
||||
return &ValidationError{Field: "topic", Message: "trigger already exists for this topic"}
|
||||
}
|
||||
|
||||
// Generate trigger ID
|
||||
triggerID := "trig_" + uuid.New().String()[:8]
|
||||
|
||||
// Insert trigger
|
||||
query := `INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) VALUES (?, ?, ?, TRUE, ?)`
|
||||
_, err := m.db.Exec(ctx, query, triggerID, functionID, topic, time.Now())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create pubsub trigger: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Info("PubSub trigger created",
|
||||
zap.String("trigger_id", triggerID),
|
||||
zap.String("function_id", functionID),
|
||||
zap.String("topic", topic),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddPubSubTriggerWithID adds a pubsub trigger and returns the trigger ID.
|
||||
func (m *DBTriggerManager) AddPubSubTriggerWithID(ctx context.Context, functionID, topic string) (string, error) {
|
||||
functionID = strings.TrimSpace(functionID)
|
||||
topic = strings.TrimSpace(topic)
|
||||
|
||||
if functionID == "" {
|
||||
return "", &ValidationError{Field: "function_id", Message: "cannot be empty"}
|
||||
}
|
||||
if topic == "" {
|
||||
return "", &ValidationError{Field: "topic", Message: "cannot be empty"}
|
||||
}
|
||||
|
||||
// Check if trigger already exists for this function and topic
|
||||
var existing []pubsubTriggerRow
|
||||
checkQuery := `SELECT id FROM function_pubsub_triggers WHERE function_id = ? AND topic = ? AND enabled = TRUE`
|
||||
if err := m.db.Query(ctx, &existing, checkQuery, functionID, topic); err != nil {
|
||||
return "", fmt.Errorf("failed to check existing trigger: %w", err)
|
||||
}
|
||||
if len(existing) > 0 {
|
||||
return "", &ValidationError{Field: "topic", Message: "trigger already exists for this topic"}
|
||||
}
|
||||
|
||||
// Generate trigger ID
|
||||
triggerID := "trig_" + uuid.New().String()[:8]
|
||||
|
||||
// Insert trigger
|
||||
query := `INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) VALUES (?, ?, ?, TRUE, ?)`
|
||||
_, err := m.db.Exec(ctx, query, triggerID, functionID, topic, time.Now())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create pubsub trigger: %w", err)
|
||||
}
|
||||
|
||||
m.logger.Info("PubSub trigger created",
|
||||
zap.String("trigger_id", triggerID),
|
||||
zap.String("function_id", functionID),
|
||||
zap.String("topic", topic),
|
||||
)
|
||||
|
||||
return triggerID, nil
|
||||
}
|
||||
|
||||
// AddCronTrigger adds a cron-based trigger to a function.
|
||||
func (m *DBTriggerManager) AddCronTrigger(ctx context.Context, functionID, cronExpr string) error {
|
||||
// TODO: Implement cron trigger support
|
||||
return fmt.Errorf("cron triggers not yet implemented")
|
||||
}
|
||||
|
||||
// AddDBTrigger adds a database trigger to a function.
|
||||
func (m *DBTriggerManager) AddDBTrigger(ctx context.Context, functionID, tableName string, operation DBOperation, condition string) error {
|
||||
// TODO: Implement database trigger support
|
||||
return fmt.Errorf("database triggers not yet implemented")
|
||||
}
|
||||
|
||||
// ScheduleOnce schedules a one-time execution.
|
||||
func (m *DBTriggerManager) ScheduleOnce(ctx context.Context, functionID string, runAt time.Time, payload []byte) (string, error) {
|
||||
// TODO: Implement one-time timer support
|
||||
return "", fmt.Errorf("one-time timers not yet implemented")
|
||||
}
|
||||
|
||||
// RemoveTrigger removes a trigger by ID.
|
||||
func (m *DBTriggerManager) RemoveTrigger(ctx context.Context, triggerID string) error {
|
||||
triggerID = strings.TrimSpace(triggerID)
|
||||
if triggerID == "" {
|
||||
return &ValidationError{Field: "trigger_id", Message: "cannot be empty"}
|
||||
}
|
||||
|
||||
// Try to delete from pubsub triggers first
|
||||
query := `DELETE FROM function_pubsub_triggers WHERE id = ?`
|
||||
result, err := m.db.Exec(ctx, query, triggerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove trigger: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, _ := result.RowsAffected()
|
||||
if rowsAffected == 0 {
|
||||
return ErrTriggerNotFound
|
||||
}
|
||||
|
||||
m.logger.Info("Trigger removed", zap.String("trigger_id", triggerID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListPubSubTriggers returns all pubsub triggers for a function.
|
||||
func (m *DBTriggerManager) ListPubSubTriggers(ctx context.Context, functionID string) ([]PubSubTrigger, error) {
|
||||
functionID = strings.TrimSpace(functionID)
|
||||
if functionID == "" {
|
||||
return nil, &ValidationError{Field: "function_id", Message: "cannot be empty"}
|
||||
}
|
||||
|
||||
query := `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE function_id = ? AND enabled = TRUE`
|
||||
var rows []pubsubTriggerRow
|
||||
if err := m.db.Query(ctx, &rows, query, functionID); err != nil {
|
||||
return nil, fmt.Errorf("failed to list pubsub triggers: %w", err)
|
||||
}
|
||||
|
||||
triggers := make([]PubSubTrigger, len(rows))
|
||||
for i, row := range rows {
|
||||
triggers[i] = PubSubTrigger{
|
||||
ID: row.ID,
|
||||
FunctionID: row.FunctionID,
|
||||
Topic: row.Topic,
|
||||
Enabled: row.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
// GetTriggersByTopic returns all enabled triggers for a specific topic.
|
||||
func (m *DBTriggerManager) GetTriggersByTopic(ctx context.Context, topic string) ([]PubSubTrigger, error) {
|
||||
topic = strings.TrimSpace(topic)
|
||||
if topic == "" {
|
||||
return nil, &ValidationError{Field: "topic", Message: "cannot be empty"}
|
||||
}
|
||||
|
||||
query := `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE topic = ? AND enabled = TRUE`
|
||||
var rows []pubsubTriggerRow
|
||||
if err := m.db.Query(ctx, &rows, query, topic); err != nil {
|
||||
return nil, fmt.Errorf("failed to get triggers by topic: %w", err)
|
||||
}
|
||||
|
||||
triggers := make([]PubSubTrigger, len(rows))
|
||||
for i, row := range rows {
|
||||
triggers[i] = PubSubTrigger{
|
||||
ID: row.ID,
|
||||
FunctionID: row.FunctionID,
|
||||
Topic: row.Topic,
|
||||
Enabled: row.Enabled,
|
||||
}
|
||||
}
|
||||
|
||||
return triggers, nil
|
||||
}
|
||||
|
||||
// GetPubSubTrigger returns a specific pubsub trigger by ID.
|
||||
func (m *DBTriggerManager) GetPubSubTrigger(ctx context.Context, triggerID string) (*PubSubTrigger, error) {
|
||||
triggerID = strings.TrimSpace(triggerID)
|
||||
if triggerID == "" {
|
||||
return nil, &ValidationError{Field: "trigger_id", Message: "cannot be empty"}
|
||||
}
|
||||
|
||||
query := `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE id = ?`
|
||||
var rows []pubsubTriggerRow
|
||||
if err := m.db.Query(ctx, &rows, query, triggerID); err != nil {
|
||||
return nil, fmt.Errorf("failed to get trigger: %w", err)
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
return nil, ErrTriggerNotFound
|
||||
}
|
||||
|
||||
row := rows[0]
|
||||
return &PubSubTrigger{
|
||||
ID: row.ID,
|
||||
FunctionID: row.FunctionID,
|
||||
Topic: row.Topic,
|
||||
Enabled: row.Enabled,
|
||||
}, nil
|
||||
}
|
||||
@ -131,6 +131,12 @@ type TriggerManager interface {
|
||||
|
||||
// RemoveTrigger removes a trigger by ID.
|
||||
RemoveTrigger(ctx context.Context, triggerID string) error
|
||||
|
||||
// ListPubSubTriggers returns all pubsub triggers for a function.
|
||||
ListPubSubTriggers(ctx context.Context, functionID string) ([]PubSubTrigger, error)
|
||||
|
||||
// GetTriggersByTopic returns all enabled triggers for a specific topic.
|
||||
GetTriggersByTopic(ctx context.Context, topic string) ([]PubSubTrigger, error)
|
||||
}
|
||||
|
||||
// JobManager manages background job execution.
|
||||
|
||||
343
pkg/turn/server.go
Normal file
343
pkg/turn/server.go
Normal file
@ -0,0 +1,343 @@
|
||||
// Package turn provides a built-in TURN/STUN server using Pion.
|
||||
package turn
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pion/turn/v4"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Config contains TURN server configuration
|
||||
type Config struct {
|
||||
// Enabled enables the built-in TURN server
|
||||
Enabled bool `yaml:"enabled"`
|
||||
|
||||
// ListenAddr is the UDP address to listen on (e.g., "0.0.0.0:3478")
|
||||
ListenAddr string `yaml:"listen_addr"`
|
||||
|
||||
// PublicIP is the public IP address to advertise for relay
|
||||
// If empty, will try to auto-detect
|
||||
PublicIP string `yaml:"public_ip"`
|
||||
|
||||
// Realm is the TURN realm (e.g., "orama.network")
|
||||
Realm string `yaml:"realm"`
|
||||
|
||||
// SharedSecret is the secret for HMAC-SHA1 credential generation
|
||||
// Should match the gateway's TURN_SHARED_SECRET
|
||||
SharedSecret string `yaml:"shared_secret"`
|
||||
|
||||
// CredentialTTL is the lifetime of generated credentials
|
||||
CredentialTTL time.Duration `yaml:"credential_ttl"`
|
||||
|
||||
// MinPort and MaxPort define the relay port range
|
||||
MinPort uint16 `yaml:"min_port"`
|
||||
MaxPort uint16 `yaml:"max_port"`
|
||||
|
||||
// TLS Configuration for TURNS (TURN over TLS)
|
||||
// TLSEnabled enables TURNS listener on TLSListenAddr
|
||||
TLSEnabled bool `yaml:"tls_enabled"`
|
||||
|
||||
// TLSListenAddr is the TCP/TLS address to listen on (e.g., "0.0.0.0:443")
|
||||
TLSListenAddr string `yaml:"tls_listen_addr"`
|
||||
|
||||
// TLSCertFile is the path to the TLS certificate file
|
||||
TLSCertFile string `yaml:"tls_cert_file"`
|
||||
|
||||
// TLSKeyFile is the path to the TLS private key file
|
||||
TLSKeyFile string `yaml:"tls_key_file"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns a default TURN server configuration
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Enabled: false,
|
||||
ListenAddr: "0.0.0.0:3478",
|
||||
Realm: "orama.network",
|
||||
CredentialTTL: 24 * time.Hour,
|
||||
MinPort: 49152,
|
||||
MaxPort: 65535,
|
||||
}
|
||||
}
|
||||
|
||||
// Server is a built-in TURN/STUN server
|
||||
type Server struct {
|
||||
config *Config
|
||||
logger *zap.Logger
|
||||
turnServer *turn.Server
|
||||
conn net.PacketConn // UDP listener
|
||||
tlsListener net.Listener // TLS listener for TURNS
|
||||
mu sync.RWMutex
|
||||
running bool
|
||||
}
|
||||
|
||||
// NewServer creates a new TURN server
|
||||
func NewServer(config *Config, logger *zap.Logger) (*Server, error) {
|
||||
if config == nil {
|
||||
config = DefaultConfig()
|
||||
}
|
||||
|
||||
return &Server{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start starts the TURN server
|
||||
func (s *Server) Start() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !s.config.Enabled {
|
||||
s.logger.Info("TURN server disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.config.SharedSecret == "" {
|
||||
return fmt.Errorf("TURN shared secret is required")
|
||||
}
|
||||
|
||||
// Create UDP listener
|
||||
conn, err := net.ListenPacket("udp4", s.config.ListenAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on %s: %w", s.config.ListenAddr, err)
|
||||
}
|
||||
s.conn = conn
|
||||
|
||||
// Determine public IP
|
||||
publicIP := s.config.PublicIP
|
||||
if publicIP == "" {
|
||||
// Try to auto-detect
|
||||
publicIP, err = getPublicIP()
|
||||
if err != nil {
|
||||
s.logger.Warn("Failed to auto-detect public IP, using listener address", zap.Error(err))
|
||||
host, _, _ := net.SplitHostPort(s.config.ListenAddr)
|
||||
if host == "0.0.0.0" || host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
publicIP = host
|
||||
}
|
||||
}
|
||||
|
||||
relayIP := net.ParseIP(publicIP)
|
||||
if relayIP == nil {
|
||||
return fmt.Errorf("invalid public IP: %s", publicIP)
|
||||
}
|
||||
|
||||
s.logger.Info("Starting TURN server",
|
||||
zap.String("listen_addr", s.config.ListenAddr),
|
||||
zap.String("public_ip", publicIP),
|
||||
zap.String("realm", s.config.Realm),
|
||||
zap.Uint16("min_port", s.config.MinPort),
|
||||
zap.Uint16("max_port", s.config.MaxPort),
|
||||
zap.Bool("tls_enabled", s.config.TLSEnabled),
|
||||
)
|
||||
|
||||
// Prepare listener configs for TLS (TURNS)
|
||||
var listenerConfigs []turn.ListenerConfig
|
||||
|
||||
if s.config.TLSEnabled && s.config.TLSCertFile != "" && s.config.TLSKeyFile != "" {
|
||||
// Load TLS certificate
|
||||
cert, err := tls.LoadX509KeyPair(s.config.TLSCertFile, s.config.TLSKeyFile)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to load TLS certificate: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}
|
||||
|
||||
// Determine TLS listen address
|
||||
tlsListenAddr := s.config.TLSListenAddr
|
||||
if tlsListenAddr == "" {
|
||||
tlsListenAddr = "0.0.0.0:443"
|
||||
}
|
||||
|
||||
// Create TLS listener
|
||||
tlsListener, err := tls.Listen("tcp", tlsListenAddr, tlsConfig)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return fmt.Errorf("failed to start TLS listener on %s: %w", tlsListenAddr, err)
|
||||
}
|
||||
s.tlsListener = tlsListener
|
||||
|
||||
listenerConfigs = append(listenerConfigs, turn.ListenerConfig{
|
||||
Listener: tlsListener,
|
||||
RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{
|
||||
RelayAddress: relayIP,
|
||||
Address: "0.0.0.0",
|
||||
MinPort: s.config.MinPort,
|
||||
MaxPort: s.config.MaxPort,
|
||||
},
|
||||
})
|
||||
|
||||
s.logger.Info("TURNS (TLS) listener started",
|
||||
zap.String("tls_addr", tlsListenAddr),
|
||||
)
|
||||
}
|
||||
|
||||
// Create TURN server with HMAC-SHA1 auth
|
||||
turnServer, err := turn.NewServer(turn.ServerConfig{
|
||||
Realm: s.config.Realm,
|
||||
AuthHandler: func(username, realm string, srcAddr net.Addr) ([]byte, bool) {
|
||||
return s.authHandler(username, realm, srcAddr)
|
||||
},
|
||||
PacketConnConfigs: []turn.PacketConnConfig{
|
||||
{
|
||||
PacketConn: conn,
|
||||
RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{
|
||||
RelayAddress: relayIP,
|
||||
Address: "0.0.0.0",
|
||||
MinPort: s.config.MinPort,
|
||||
MaxPort: s.config.MaxPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
ListenerConfigs: listenerConfigs,
|
||||
})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
if s.tlsListener != nil {
|
||||
s.tlsListener.Close()
|
||||
}
|
||||
return fmt.Errorf("failed to create TURN server: %w", err)
|
||||
}
|
||||
|
||||
s.turnServer = turnServer
|
||||
s.running = true
|
||||
|
||||
s.logger.Info("TURN server started successfully",
|
||||
zap.String("addr", s.config.ListenAddr),
|
||||
zap.String("realm", s.config.Realm),
|
||||
zap.Bool("turns_enabled", s.config.TLSEnabled),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// authHandler validates HMAC-SHA1 credentials (coturn-compatible format)
|
||||
// Username format: timestamp:userID (e.g., "1234567890:user123")
|
||||
func (s *Server) authHandler(username, realm string, srcAddr net.Addr) ([]byte, bool) {
|
||||
// Parse timestamp from username
|
||||
// Format: timestamp:userID
|
||||
var timestamp int64
|
||||
for i, c := range username {
|
||||
if c == ':' {
|
||||
ts, err := strconv.ParseInt(username[:i], 10, 64)
|
||||
if err != nil {
|
||||
s.logger.Debug("Invalid timestamp in username", zap.String("username", username))
|
||||
return nil, false
|
||||
}
|
||||
timestamp = ts
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Check if credential has expired
|
||||
now := time.Now().Unix()
|
||||
if timestamp > 0 && timestamp < now {
|
||||
s.logger.Debug("Credential expired",
|
||||
zap.String("username", username),
|
||||
zap.Int64("expired_at", timestamp),
|
||||
zap.Int64("now", now),
|
||||
)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Generate expected password using HMAC-SHA1
|
||||
// This matches the gateway's generateTURNCredentials function
|
||||
h := hmac.New(sha1.New, []byte(s.config.SharedSecret))
|
||||
h.Write([]byte(username))
|
||||
password := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
s.logger.Debug("TURN auth request",
|
||||
zap.String("username", username),
|
||||
zap.String("realm", realm),
|
||||
zap.String("src_addr", srcAddr.String()),
|
||||
)
|
||||
|
||||
return turn.GenerateAuthKey(username, realm, password), true
|
||||
}
|
||||
|
||||
// Stop stops the TURN server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.logger.Info("Stopping TURN server")
|
||||
|
||||
if s.turnServer != nil {
|
||||
if err := s.turnServer.Close(); err != nil {
|
||||
s.logger.Warn("Error closing TURN server", zap.Error(err))
|
||||
}
|
||||
s.turnServer = nil
|
||||
}
|
||||
|
||||
if s.conn != nil {
|
||||
s.conn.Close()
|
||||
s.conn = nil
|
||||
}
|
||||
|
||||
// Close TLS listener
|
||||
if s.tlsListener != nil {
|
||||
s.tlsListener.Close()
|
||||
s.tlsListener = nil
|
||||
}
|
||||
|
||||
s.running = false
|
||||
s.logger.Info("TURN server stopped")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRunning returns whether the server is running
|
||||
func (s *Server) IsRunning() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.running
|
||||
}
|
||||
|
||||
// GetListenAddr returns the listen address
|
||||
func (s *Server) GetListenAddr() string {
|
||||
return s.config.ListenAddr
|
||||
}
|
||||
|
||||
// GetPublicAddr returns the public address for clients
|
||||
func (s *Server) GetPublicAddr() string {
|
||||
if s.config.PublicIP != "" {
|
||||
_, port, _ := net.SplitHostPort(s.config.ListenAddr)
|
||||
return net.JoinHostPort(s.config.PublicIP, port)
|
||||
}
|
||||
return s.config.ListenAddr
|
||||
}
|
||||
|
||||
// getPublicIP tries to determine the public IP address
|
||||
func getPublicIP() (string, error) {
|
||||
// Try to get outbound IP by connecting to a public address
|
||||
conn, err := net.Dial("udp4", "8.8.8.8:80")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
return localAddr.IP.String(), nil
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user