mirror of
https://github.com/DeBrosOfficial/orama.git
synced 2026-03-17 09:36:56 +00:00
initial
This commit is contained in:
parent
ade6241357
commit
ea48a21ae4
@ -73,6 +73,27 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load YAML
|
// Load YAML
|
||||||
|
type yamlICEServer struct {
|
||||||
|
URLs []string `yaml:"urls"`
|
||||||
|
Username string `yaml:"username,omitempty"`
|
||||||
|
Credential string `yaml:"credential,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type yamlTURN struct {
|
||||||
|
SharedSecret string `yaml:"shared_secret"`
|
||||||
|
TTL string `yaml:"ttl"`
|
||||||
|
ExternalHost string `yaml:"external_host"`
|
||||||
|
STUNURLs []string `yaml:"stun_urls"`
|
||||||
|
TURNURLs []string `yaml:"turn_urls"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type yamlSFU struct {
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
MaxParticipants int `yaml:"max_participants"`
|
||||||
|
MediaTimeout string `yaml:"media_timeout"`
|
||||||
|
ICEServers []yamlICEServer `yaml:"ice_servers"`
|
||||||
|
}
|
||||||
|
|
||||||
type yamlCfg struct {
|
type yamlCfg struct {
|
||||||
ListenAddr string `yaml:"listen_addr"`
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
ClientNamespace string `yaml:"client_namespace"`
|
ClientNamespace string `yaml:"client_namespace"`
|
||||||
@ -87,6 +108,8 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
|||||||
IPFSAPIURL string `yaml:"ipfs_api_url"`
|
IPFSAPIURL string `yaml:"ipfs_api_url"`
|
||||||
IPFSTimeout string `yaml:"ipfs_timeout"`
|
IPFSTimeout string `yaml:"ipfs_timeout"`
|
||||||
IPFSReplicationFactor int `yaml:"ipfs_replication_factor"`
|
IPFSReplicationFactor int `yaml:"ipfs_replication_factor"`
|
||||||
|
TURN yamlTURN `yaml:"turn"`
|
||||||
|
SFU yamlSFU `yaml:"sfu"`
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := os.ReadFile(configPath)
|
data, err := os.ReadFile(configPath)
|
||||||
@ -191,6 +214,64 @@ func parseGatewayConfig(logger *logging.ColoredLogger) *gateway.Config {
|
|||||||
cfg.IPFSReplicationFactor = y.IPFSReplicationFactor
|
cfg.IPFSReplicationFactor = y.IPFSReplicationFactor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TURN configuration
|
||||||
|
if y.TURN.SharedSecret != "" || len(y.TURN.STUNURLs) > 0 || len(y.TURN.TURNURLs) > 0 {
|
||||||
|
turnCfg := &config.TURNConfig{
|
||||||
|
SharedSecret: y.TURN.SharedSecret,
|
||||||
|
ExternalHost: y.TURN.ExternalHost,
|
||||||
|
STUNURLs: y.TURN.STUNURLs,
|
||||||
|
TURNURLs: y.TURN.TURNURLs,
|
||||||
|
}
|
||||||
|
// Check for environment variable overrides
|
||||||
|
if envSecret := os.Getenv("TURN_SHARED_SECRET"); envSecret != "" {
|
||||||
|
turnCfg.SharedSecret = envSecret
|
||||||
|
}
|
||||||
|
if envHost := os.Getenv("TURN_EXTERNAL_HOST"); envHost != "" {
|
||||||
|
turnCfg.ExternalHost = envHost
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(y.TURN.TTL); v != "" {
|
||||||
|
if parsed, err := time.ParseDuration(v); err == nil {
|
||||||
|
turnCfg.TTL = parsed
|
||||||
|
} else {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "invalid turn.ttl, using default", zap.String("value", v), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg.TURN = turnCfg
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "TURN configuration loaded",
|
||||||
|
zap.Int("stun_urls", len(turnCfg.STUNURLs)),
|
||||||
|
zap.Int("turn_urls", len(turnCfg.TURNURLs)),
|
||||||
|
zap.String("external_host", turnCfg.ExternalHost),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SFU configuration
|
||||||
|
if y.SFU.Enabled {
|
||||||
|
sfuCfg := &config.SFUConfig{
|
||||||
|
Enabled: true,
|
||||||
|
MaxParticipants: y.SFU.MaxParticipants,
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(y.SFU.MediaTimeout); v != "" {
|
||||||
|
if parsed, err := time.ParseDuration(v); err == nil {
|
||||||
|
sfuCfg.MediaTimeout = parsed
|
||||||
|
} else {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "invalid sfu.media_timeout, using default", zap.String("value", v), zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Parse ICE servers
|
||||||
|
for _, iceServer := range y.SFU.ICEServers {
|
||||||
|
sfuCfg.ICEServers = append(sfuCfg.ICEServers, config.ICEServerConfig{
|
||||||
|
URLs: iceServer.URLs,
|
||||||
|
Username: iceServer.Username,
|
||||||
|
Credential: iceServer.Credential,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
cfg.SFU = sfuCfg
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "SFU configuration loaded",
|
||||||
|
zap.Int("max_participants", sfuCfg.MaxParticipants),
|
||||||
|
zap.Int("ice_servers", len(sfuCfg.ICEServers)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
// Validate configuration
|
// Validate configuration
|
||||||
if errs := cfg.ValidateConfig(); len(errs) > 0 {
|
if errs := cfg.ValidateConfig(); len(errs) > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs))
|
fmt.Fprintf(os.Stderr, "\nGateway configuration errors (%d):\n", len(errs))
|
||||||
|
|||||||
8
go.mod
8
go.mod
@ -18,6 +18,10 @@ require (
|
|||||||
github.com/mattn/go-sqlite3 v1.14.32
|
github.com/mattn/go-sqlite3 v1.14.32
|
||||||
github.com/multiformats/go-multiaddr v0.15.0
|
github.com/multiformats/go-multiaddr v0.15.0
|
||||||
github.com/olric-data/olric v0.7.0
|
github.com/olric-data/olric v0.7.0
|
||||||
|
github.com/pion/interceptor v0.1.37
|
||||||
|
github.com/pion/rtcp v1.2.15
|
||||||
|
github.com/pion/turn/v4 v4.0.0
|
||||||
|
github.com/pion/webrtc/v4 v4.0.10
|
||||||
github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8
|
github.com/rqlite/gorqlite v0.0.0-20250609141355-ac86a4a1c9a8
|
||||||
github.com/tetratelabs/wazero v1.11.0
|
github.com/tetratelabs/wazero v1.11.0
|
||||||
go.uber.org/zap v1.27.0
|
go.uber.org/zap v1.27.0
|
||||||
@ -113,11 +117,9 @@ require (
|
|||||||
github.com/pion/dtls/v2 v2.2.12 // indirect
|
github.com/pion/dtls/v2 v2.2.12 // indirect
|
||||||
github.com/pion/dtls/v3 v3.0.4 // indirect
|
github.com/pion/dtls/v3 v3.0.4 // indirect
|
||||||
github.com/pion/ice/v4 v4.0.8 // indirect
|
github.com/pion/ice/v4 v4.0.8 // indirect
|
||||||
github.com/pion/interceptor v0.1.37 // indirect
|
|
||||||
github.com/pion/logging v0.2.3 // indirect
|
github.com/pion/logging v0.2.3 // indirect
|
||||||
github.com/pion/mdns/v2 v2.0.7 // indirect
|
github.com/pion/mdns/v2 v2.0.7 // indirect
|
||||||
github.com/pion/randutil v0.1.0 // indirect
|
github.com/pion/randutil v0.1.0 // indirect
|
||||||
github.com/pion/rtcp v1.2.15 // indirect
|
|
||||||
github.com/pion/rtp v1.8.11 // indirect
|
github.com/pion/rtp v1.8.11 // indirect
|
||||||
github.com/pion/sctp v1.8.37 // indirect
|
github.com/pion/sctp v1.8.37 // indirect
|
||||||
github.com/pion/sdp/v3 v3.0.10 // indirect
|
github.com/pion/sdp/v3 v3.0.10 // indirect
|
||||||
@ -126,8 +128,6 @@ require (
|
|||||||
github.com/pion/stun/v3 v3.0.0 // indirect
|
github.com/pion/stun/v3 v3.0.0 // indirect
|
||||||
github.com/pion/transport/v2 v2.2.10 // indirect
|
github.com/pion/transport/v2 v2.2.10 // indirect
|
||||||
github.com/pion/transport/v3 v3.0.7 // indirect
|
github.com/pion/transport/v3 v3.0.7 // indirect
|
||||||
github.com/pion/turn/v4 v4.0.0 // indirect
|
|
||||||
github.com/pion/webrtc/v4 v4.0.10 // indirect
|
|
||||||
github.com/pkg/errors v0.9.1 // indirect
|
github.com/pkg/errors v0.9.1 // indirect
|
||||||
github.com/prometheus/client_golang v1.22.0 // indirect
|
github.com/prometheus/client_golang v1.22.0 // indirect
|
||||||
github.com/prometheus/client_model v0.6.2 // indirect
|
github.com/prometheus/client_model v0.6.2 // indirect
|
||||||
|
|||||||
@ -3,7 +3,6 @@ package config
|
|||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/config/validate"
|
|
||||||
"github.com/multiformats/go-multiaddr"
|
"github.com/multiformats/go-multiaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,69 +14,248 @@ type Config struct {
|
|||||||
Security SecurityConfig `yaml:"security"`
|
Security SecurityConfig `yaml:"security"`
|
||||||
Logging LoggingConfig `yaml:"logging"`
|
Logging LoggingConfig `yaml:"logging"`
|
||||||
HTTPGateway HTTPGatewayConfig `yaml:"http_gateway"`
|
HTTPGateway HTTPGatewayConfig `yaml:"http_gateway"`
|
||||||
|
TURNServer TURNServerConfig `yaml:"turn_server"` // Built-in TURN server
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidationError represents a single validation error with context.
|
// NodeConfig contains node-specific configuration
|
||||||
// This is exported from the validate subpackage for backward compatibility.
|
type NodeConfig struct {
|
||||||
type ValidationError = validate.ValidationError
|
ID string `yaml:"id"` // Auto-generated if empty
|
||||||
|
ListenAddresses []string `yaml:"listen_addresses"` // LibP2P listen addresses
|
||||||
// ValidateSwarmKey validates that a swarm key is 64 hex characters.
|
DataDir string `yaml:"data_dir"` // Data directory
|
||||||
// This is exported from the validate subpackage for backward compatibility.
|
MaxConnections int `yaml:"max_connections"` // Maximum peer connections
|
||||||
func ValidateSwarmKey(key string) error {
|
Domain string `yaml:"domain"` // Domain for this node (e.g., node-1.orama.network)
|
||||||
return validate.ValidateSwarmKey(key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate performs comprehensive validation of the entire config.
|
// DatabaseConfig contains database-related configuration
|
||||||
// It aggregates all errors and returns them, allowing the caller to print all issues at once.
|
type DatabaseConfig struct {
|
||||||
func (c *Config) Validate() []error {
|
DataDir string `yaml:"data_dir"`
|
||||||
var errs []error
|
ReplicationFactor int `yaml:"replication_factor"`
|
||||||
|
ShardCount int `yaml:"shard_count"`
|
||||||
|
MaxDatabaseSize int64 `yaml:"max_database_size"` // In bytes
|
||||||
|
BackupInterval time.Duration `yaml:"backup_interval"`
|
||||||
|
|
||||||
// Validate node config
|
// RQLite-specific configuration
|
||||||
errs = append(errs, validate.ValidateNode(validate.NodeConfig{
|
RQLitePort int `yaml:"rqlite_port"` // RQLite HTTP API port
|
||||||
ID: c.Node.ID,
|
RQLiteRaftPort int `yaml:"rqlite_raft_port"` // RQLite Raft consensus port
|
||||||
ListenAddresses: c.Node.ListenAddresses,
|
RQLiteJoinAddress string `yaml:"rqlite_join_address"` // Address to join RQLite cluster
|
||||||
DataDir: c.Node.DataDir,
|
|
||||||
MaxConnections: c.Node.MaxConnections,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
// Validate database config
|
// RQLite node-to-node TLS encryption (for inter-node Raft communication)
|
||||||
errs = append(errs, validate.ValidateDatabase(validate.DatabaseConfig{
|
// See: https://rqlite.io/docs/guides/security/#encrypting-node-to-node-communication
|
||||||
DataDir: c.Database.DataDir,
|
NodeCert string `yaml:"node_cert"` // Path to X.509 certificate for node-to-node communication
|
||||||
ReplicationFactor: c.Database.ReplicationFactor,
|
NodeKey string `yaml:"node_key"` // Path to X.509 private key for node-to-node communication
|
||||||
ShardCount: c.Database.ShardCount,
|
NodeCACert string `yaml:"node_ca_cert"` // Path to CA certificate (optional, uses system CA if not set)
|
||||||
MaxDatabaseSize: c.Database.MaxDatabaseSize,
|
NodeNoVerify bool `yaml:"node_no_verify"` // Skip certificate verification (for testing/self-signed certs)
|
||||||
RQLitePort: c.Database.RQLitePort,
|
|
||||||
RQLiteRaftPort: c.Database.RQLiteRaftPort,
|
|
||||||
RQLiteJoinAddress: c.Database.RQLiteJoinAddress,
|
|
||||||
ClusterSyncInterval: c.Database.ClusterSyncInterval,
|
|
||||||
PeerInactivityLimit: c.Database.PeerInactivityLimit,
|
|
||||||
MinClusterSize: c.Database.MinClusterSize,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
// Validate discovery config
|
// Dynamic discovery configuration (always enabled)
|
||||||
errs = append(errs, validate.ValidateDiscovery(validate.DiscoveryConfig{
|
ClusterSyncInterval time.Duration `yaml:"cluster_sync_interval"` // default: 30s
|
||||||
BootstrapPeers: c.Discovery.BootstrapPeers,
|
PeerInactivityLimit time.Duration `yaml:"peer_inactivity_limit"` // default: 24h
|
||||||
DiscoveryInterval: c.Discovery.DiscoveryInterval,
|
MinClusterSize int `yaml:"min_cluster_size"` // default: 1
|
||||||
BootstrapPort: c.Discovery.BootstrapPort,
|
|
||||||
HttpAdvAddress: c.Discovery.HttpAdvAddress,
|
|
||||||
RaftAdvAddress: c.Discovery.RaftAdvAddress,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
// Validate security config
|
// Olric cache configuration
|
||||||
errs = append(errs, validate.ValidateSecurity(validate.SecurityConfig{
|
OlricHTTPPort int `yaml:"olric_http_port"` // Olric HTTP API port (default: 3320)
|
||||||
EnableTLS: c.Security.EnableTLS,
|
OlricMemberlistPort int `yaml:"olric_memberlist_port"` // Olric memberlist port (default: 3322)
|
||||||
PrivateKeyFile: c.Security.PrivateKeyFile,
|
|
||||||
CertificateFile: c.Security.CertificateFile,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
// Validate logging config
|
// IPFS storage configuration
|
||||||
errs = append(errs, validate.ValidateLogging(validate.LoggingConfig{
|
IPFS IPFSConfig `yaml:"ipfs"`
|
||||||
Level: c.Logging.Level,
|
}
|
||||||
Format: c.Logging.Format,
|
|
||||||
OutputFile: c.Logging.OutputFile,
|
|
||||||
})...)
|
|
||||||
|
|
||||||
return errs
|
// IPFSConfig contains IPFS storage configuration
|
||||||
|
type IPFSConfig struct {
|
||||||
|
// ClusterAPIURL is the IPFS Cluster HTTP API URL (e.g., "http://localhost:9094")
|
||||||
|
// If empty, IPFS storage is disabled for this node
|
||||||
|
ClusterAPIURL string `yaml:"cluster_api_url"`
|
||||||
|
|
||||||
|
// APIURL is the IPFS HTTP API URL for content retrieval (e.g., "http://localhost:5001")
|
||||||
|
// If empty, defaults to "http://localhost:5001"
|
||||||
|
APIURL string `yaml:"api_url"`
|
||||||
|
|
||||||
|
// Timeout for IPFS operations
|
||||||
|
// If zero, defaults to 60 seconds
|
||||||
|
Timeout time.Duration `yaml:"timeout"`
|
||||||
|
|
||||||
|
// ReplicationFactor is the replication factor for pinned content
|
||||||
|
// If zero, defaults to 3
|
||||||
|
ReplicationFactor int `yaml:"replication_factor"`
|
||||||
|
|
||||||
|
// EnableEncryption enables client-side encryption before upload
|
||||||
|
// Defaults to true
|
||||||
|
EnableEncryption bool `yaml:"enable_encryption"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DiscoveryConfig contains peer discovery configuration
|
||||||
|
type DiscoveryConfig struct {
|
||||||
|
BootstrapPeers []string `yaml:"bootstrap_peers"` // Peer addresses to connect to
|
||||||
|
DiscoveryInterval time.Duration `yaml:"discovery_interval"` // Discovery announcement interval
|
||||||
|
BootstrapPort int `yaml:"bootstrap_port"` // Default port for peer discovery
|
||||||
|
HttpAdvAddress string `yaml:"http_adv_address"` // HTTP advertisement address
|
||||||
|
RaftAdvAddress string `yaml:"raft_adv_address"` // Raft advertisement
|
||||||
|
NodeNamespace string `yaml:"node_namespace"` // Namespace for node identifiers
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecurityConfig contains security-related configuration
|
||||||
|
type SecurityConfig struct {
|
||||||
|
EnableTLS bool `yaml:"enable_tls"`
|
||||||
|
PrivateKeyFile string `yaml:"private_key_file"`
|
||||||
|
CertificateFile string `yaml:"certificate_file"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoggingConfig contains logging configuration
|
||||||
|
type LoggingConfig struct {
|
||||||
|
Level string `yaml:"level"` // debug, info, warn, error
|
||||||
|
Format string `yaml:"format"` // json, console
|
||||||
|
OutputFile string `yaml:"output_file"` // Empty for stdout
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPGatewayConfig contains HTTP reverse proxy gateway configuration
|
||||||
|
type HTTPGatewayConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"` // Enable HTTP gateway
|
||||||
|
ListenAddr string `yaml:"listen_addr"` // Address to listen on (e.g., ":8080")
|
||||||
|
NodeName string `yaml:"node_name"` // Node name for routing
|
||||||
|
Routes map[string]RouteConfig `yaml:"routes"` // Service routes
|
||||||
|
HTTPS HTTPSConfig `yaml:"https"` // HTTPS/TLS configuration
|
||||||
|
SNI SNIConfig `yaml:"sni"` // SNI-based TCP routing configuration
|
||||||
|
|
||||||
|
// Full gateway configuration (for API, auth, pubsub)
|
||||||
|
ClientNamespace string `yaml:"client_namespace"` // Namespace for network client
|
||||||
|
RQLiteDSN string `yaml:"rqlite_dsn"` // RQLite database DSN
|
||||||
|
OlricServers []string `yaml:"olric_servers"` // List of Olric server addresses
|
||||||
|
OlricTimeout time.Duration `yaml:"olric_timeout"` // Timeout for Olric operations
|
||||||
|
IPFSClusterAPIURL string `yaml:"ipfs_cluster_api_url"` // IPFS Cluster API URL
|
||||||
|
IPFSAPIURL string `yaml:"ipfs_api_url"` // IPFS API URL
|
||||||
|
IPFSTimeout time.Duration `yaml:"ipfs_timeout"` // Timeout for IPFS operations
|
||||||
|
|
||||||
|
// WebRTC configuration for video/audio calls
|
||||||
|
TURN *TURNConfig `yaml:"turn"` // TURN/STUN server configuration
|
||||||
|
SFU *SFUConfig `yaml:"sfu"` // SFU (Selective Forwarding Unit) configuration
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPSConfig contains HTTPS/TLS configuration for the gateway
|
||||||
|
type HTTPSConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"` // Enable HTTPS (port 443)
|
||||||
|
Domain string `yaml:"domain"` // Primary domain (e.g., node-123.orama.network)
|
||||||
|
AutoCert bool `yaml:"auto_cert"` // Use Let's Encrypt for automatic certificate
|
||||||
|
UseSelfSigned bool `yaml:"use_self_signed"` // Use self-signed certificates (pre-generated)
|
||||||
|
CertFile string `yaml:"cert_file"` // Path to certificate file (if not using auto_cert)
|
||||||
|
KeyFile string `yaml:"key_file"` // Path to key file (if not using auto_cert)
|
||||||
|
CacheDir string `yaml:"cache_dir"` // Directory for Let's Encrypt certificate cache
|
||||||
|
HTTPPort int `yaml:"http_port"` // HTTP port for ACME challenge (default: 80)
|
||||||
|
HTTPSPort int `yaml:"https_port"` // HTTPS port (default: 443)
|
||||||
|
Email string `yaml:"email"` // Email for Let's Encrypt account
|
||||||
|
}
|
||||||
|
|
||||||
|
// SNIConfig contains SNI-based TCP routing configuration for port 7001
|
||||||
|
type SNIConfig struct {
|
||||||
|
Enabled bool `yaml:"enabled"` // Enable SNI-based TCP routing
|
||||||
|
ListenAddr string `yaml:"listen_addr"` // Address to listen on (e.g., ":7001")
|
||||||
|
Routes map[string]string `yaml:"routes"` // SNI hostname -> backend address mapping
|
||||||
|
CertFile string `yaml:"cert_file"` // Path to certificate file
|
||||||
|
KeyFile string `yaml:"key_file"` // Path to key file
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteConfig defines a single reverse proxy route
|
||||||
|
type RouteConfig struct {
|
||||||
|
PathPrefix string `yaml:"path_prefix"` // URL path prefix (e.g., "/rqlite/http")
|
||||||
|
BackendURL string `yaml:"backend_url"` // Backend service URL
|
||||||
|
Timeout time.Duration `yaml:"timeout"` // Request timeout
|
||||||
|
WebSocket bool `yaml:"websocket"` // Support WebSocket upgrades
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientConfig represents configuration for network clients
|
||||||
|
type ClientConfig struct {
|
||||||
|
AppName string `yaml:"app_name"`
|
||||||
|
DatabaseName string `yaml:"database_name"`
|
||||||
|
BootstrapPeers []string `yaml:"bootstrap_peers"`
|
||||||
|
ConnectTimeout time.Duration `yaml:"connect_timeout"`
|
||||||
|
RetryAttempts int `yaml:"retry_attempts"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TURNConfig contains TURN/STUN server credential configuration
|
||||||
|
type TURNConfig struct {
|
||||||
|
// SharedSecret is the shared secret for TURN credential generation (HMAC-SHA1)
|
||||||
|
// Should be set via TURN_SHARED_SECRET environment variable
|
||||||
|
SharedSecret string `yaml:"shared_secret"`
|
||||||
|
|
||||||
|
// TTL is the time-to-live for generated credentials
|
||||||
|
// Default: 24 hours
|
||||||
|
TTL time.Duration `yaml:"ttl"`
|
||||||
|
|
||||||
|
// ExternalHost is the external hostname or IP address for STUN/TURN URLs
|
||||||
|
// - Production: Set to your public domain (e.g., "turn.example.com")
|
||||||
|
// - Development: Leave empty for auto-detection of LAN IP
|
||||||
|
// Can also be set via TURN_EXTERNAL_HOST environment variable
|
||||||
|
ExternalHost string `yaml:"external_host"`
|
||||||
|
|
||||||
|
// STUNURLs are the STUN server URLs to return to clients
|
||||||
|
// Use "::" as placeholder for ExternalHost (e.g., "stun:::3478" -> "stun:turn.example.com:3478")
|
||||||
|
// e.g., ["stun:::3478"] or ["stun:gateway.orama.com:3478"]
|
||||||
|
STUNURLs []string `yaml:"stun_urls"`
|
||||||
|
|
||||||
|
// TURNURLs are the TURN server URLs to return to clients
|
||||||
|
// Use "::" as placeholder for ExternalHost (e.g., "turn:::3478" -> "turn:turn.example.com:3478")
|
||||||
|
// e.g., ["turn:::3478?transport=udp"] or ["turn:gateway.orama.com:3478?transport=udp"]
|
||||||
|
TURNURLs []string `yaml:"turn_urls"`
|
||||||
|
|
||||||
|
// TLSEnabled indicates whether TURNS (TURN over TLS) is available
|
||||||
|
// When true, turns:// URLs will be included in the response
|
||||||
|
TLSEnabled bool `yaml:"tls_enabled"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SFUConfig contains WebRTC SFU (Selective Forwarding Unit) configuration
|
||||||
|
type SFUConfig struct {
|
||||||
|
// Enabled enables the SFU service
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
|
||||||
|
// MaxParticipants is the maximum number of participants per room
|
||||||
|
// Default: 10
|
||||||
|
MaxParticipants int `yaml:"max_participants"`
|
||||||
|
|
||||||
|
// MediaTimeout is the timeout for media operations
|
||||||
|
// Default: 30 seconds
|
||||||
|
MediaTimeout time.Duration `yaml:"media_timeout"`
|
||||||
|
|
||||||
|
// ICEServers are additional ICE servers for WebRTC connections
|
||||||
|
// These are used in addition to the TURN servers from TURNConfig
|
||||||
|
ICEServers []ICEServerConfig `yaml:"ice_servers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICEServerConfig represents a single ICE server configuration
|
||||||
|
type ICEServerConfig struct {
|
||||||
|
URLs []string `yaml:"urls"`
|
||||||
|
Username string `yaml:"username,omitempty"`
|
||||||
|
Credential string `yaml:"credential,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TURNServerConfig contains built-in TURN server configuration
|
||||||
|
type TURNServerConfig struct {
|
||||||
|
// Enabled enables the built-in TURN server
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
|
||||||
|
// ListenAddr is the UDP address to listen on (e.g., "0.0.0.0:3478")
|
||||||
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
|
|
||||||
|
// PublicIP is the public IP address to advertise for relay
|
||||||
|
// If empty, will try to auto-detect
|
||||||
|
PublicIP string `yaml:"public_ip"`
|
||||||
|
|
||||||
|
// Realm is the TURN realm (e.g., "orama.network")
|
||||||
|
Realm string `yaml:"realm"`
|
||||||
|
|
||||||
|
// MinPort and MaxPort define the relay port range
|
||||||
|
MinPort uint16 `yaml:"min_port"`
|
||||||
|
MaxPort uint16 `yaml:"max_port"`
|
||||||
|
|
||||||
|
// TLS Configuration for TURNS (TURN over TLS)
|
||||||
|
// TLSEnabled enables TURNS listener
|
||||||
|
TLSEnabled bool `yaml:"tls_enabled"`
|
||||||
|
|
||||||
|
// TLSListenAddr is the TCP/TLS address to listen on (e.g., "0.0.0.0:443")
|
||||||
|
TLSListenAddr string `yaml:"tls_listen_addr"`
|
||||||
|
|
||||||
|
// TLSCertFile is the path to the TLS certificate file
|
||||||
|
TLSCertFile string `yaml:"tls_cert_file"`
|
||||||
|
|
||||||
|
// TLSKeyFile is the path to the TLS private key file
|
||||||
|
TLSKeyFile string `yaml:"tls_key_file"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseMultiaddrs converts string addresses to multiaddr objects
|
// ParseMultiaddrs converts string addresses to multiaddr objects
|
||||||
|
|||||||
@ -47,6 +47,15 @@ logging:
|
|||||||
level: "info"
|
level: "info"
|
||||||
format: "console"
|
format: "console"
|
||||||
|
|
||||||
|
# Built-in TURN server for WebRTC NAT traversal
|
||||||
|
turn_server:
|
||||||
|
enabled: true
|
||||||
|
listen_addr: "0.0.0.0:3478"
|
||||||
|
public_ip: "" # Auto-detect if empty, or set to your public IP
|
||||||
|
realm: "orama.network"
|
||||||
|
min_port: 49152
|
||||||
|
max_port: 65535
|
||||||
|
|
||||||
http_gateway:
|
http_gateway:
|
||||||
enabled: true
|
enabled: true
|
||||||
listen_addr: "{{if .EnableHTTPS}}:{{.HTTPSPort}}{{else}}:{{.UnifiedGatewayPort}}{{end}}"
|
listen_addr: "{{if .EnableHTTPS}}:{{.HTTPSPort}}{{else}}:{{.UnifiedGatewayPort}}{{end}}"
|
||||||
@ -86,3 +95,19 @@ http_gateway:
|
|||||||
|
|
||||||
# Routes for internal service reverse proxy (kept for backwards compatibility but not used by full gateway)
|
# Routes for internal service reverse proxy (kept for backwards compatibility but not used by full gateway)
|
||||||
routes: {}
|
routes: {}
|
||||||
|
|
||||||
|
# TURN/STUN URLs returned to clients (points to built-in TURN server)
|
||||||
|
turn:
|
||||||
|
shared_secret: "dev-secret-12345"
|
||||||
|
ttl: "24h"
|
||||||
|
stun_urls:
|
||||||
|
- "stun:::3478"
|
||||||
|
turn_urls:
|
||||||
|
- "turn:::3478?transport=udp"
|
||||||
|
|
||||||
|
# SFU (Selective Forwarding Unit) configuration for WebRTC group calls
|
||||||
|
sfu:
|
||||||
|
enabled: true
|
||||||
|
max_participants: 10
|
||||||
|
media_timeout: "30s"
|
||||||
|
ice_servers: [] # Additional ICE servers beyond TURN config (optional)
|
||||||
|
|||||||
@ -1,31 +1,75 @@
|
|||||||
// Package gateway provides the main API Gateway for the Orama Network.
|
|
||||||
// It orchestrates traffic between clients and various backend services including
|
|
||||||
// distributed caching (Olric), decentralized storage (IPFS), and serverless
|
|
||||||
// WebAssembly (WASM) execution. The gateway implements robust security through
|
|
||||||
// wallet-based cryptographic authentication and JWT lifecycle management.
|
|
||||||
package gateway
|
package gateway
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/client"
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/config"
|
||||||
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
authhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/auth"
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/cache"
|
|
||||||
pubsubhandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/pubsub"
|
|
||||||
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/gateway/handlers/storage"
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
"github.com/DeBrosOfficial/network/pkg/ipfs"
|
||||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
"github.com/DeBrosOfficial/network/pkg/olric"
|
"github.com/DeBrosOfficial/network/pkg/olric"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||||
|
"github.com/multiformats/go-multiaddr"
|
||||||
|
olriclib "github.com/olric-data/olric"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
_ "github.com/rqlite/gorqlite/stdlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
olricInitMaxAttempts = 5
|
||||||
|
olricInitInitialBackoff = 500 * time.Millisecond
|
||||||
|
olricInitMaxBackoff = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds configuration for the gateway server
|
||||||
|
type Config struct {
|
||||||
|
ListenAddr string
|
||||||
|
ClientNamespace string
|
||||||
|
BootstrapPeers []string
|
||||||
|
NodePeerID string // The node's actual peer ID from its identity file
|
||||||
|
|
||||||
|
// Optional DSN for rqlite database/sql driver, e.g. "http://localhost:4001"
|
||||||
|
// If empty, defaults to "http://localhost:4001".
|
||||||
|
RQLiteDSN string
|
||||||
|
|
||||||
|
// HTTPS configuration
|
||||||
|
EnableHTTPS bool // Enable HTTPS with ACME (Let's Encrypt)
|
||||||
|
DomainName string // Domain name for HTTPS certificate
|
||||||
|
TLSCacheDir string // Directory to cache TLS certificates (default: ~/.orama/tls-cache)
|
||||||
|
|
||||||
|
// Olric cache configuration
|
||||||
|
OlricServers []string // List of Olric server addresses (e.g., ["localhost:3320"]). If empty, defaults to ["localhost:3320"]
|
||||||
|
OlricTimeout time.Duration // Timeout for Olric operations (default: 10s)
|
||||||
|
|
||||||
|
// IPFS Cluster configuration
|
||||||
|
IPFSClusterAPIURL string // IPFS Cluster HTTP API URL (e.g., "http://localhost:9094"). If empty, gateway will discover from node configs
|
||||||
|
IPFSAPIURL string // IPFS HTTP API URL for content retrieval (e.g., "http://localhost:5001"). If empty, gateway will discover from node configs
|
||||||
|
IPFSTimeout time.Duration // Timeout for IPFS operations (default: 60s)
|
||||||
|
IPFSReplicationFactor int // Replication factor for pins (default: 3)
|
||||||
|
IPFSEnableEncryption bool // Enable client-side encryption before upload (default: true, discovered from node configs)
|
||||||
|
|
||||||
|
// TURN/STUN configuration for WebRTC
|
||||||
|
TURN *config.TURNConfig
|
||||||
|
|
||||||
|
// SFU configuration for WebRTC group calls
|
||||||
|
SFU *config.SFUConfig
|
||||||
|
}
|
||||||
|
|
||||||
type Gateway struct {
|
type Gateway struct {
|
||||||
logger *logging.ColoredLogger
|
logger *logging.ColoredLogger
|
||||||
@ -42,29 +86,28 @@ type Gateway struct {
|
|||||||
// Olric cache client
|
// Olric cache client
|
||||||
olricClient *olric.Client
|
olricClient *olric.Client
|
||||||
olricMu sync.RWMutex
|
olricMu sync.RWMutex
|
||||||
cacheHandlers *cache.CacheHandlers
|
|
||||||
|
|
||||||
// IPFS storage client
|
// IPFS storage client
|
||||||
ipfsClient ipfs.IPFSClient
|
ipfsClient ipfs.IPFSClient
|
||||||
storageHandlers *storage.Handlers
|
|
||||||
|
|
||||||
// Local pub/sub bypass for same-gateway subscribers
|
// Local pub/sub bypass for same-gateway subscribers
|
||||||
localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers
|
localSubscribers map[string][]*localSubscriber // topic+namespace -> subscribers
|
||||||
presenceMembers map[string][]PresenceMember // topicKey -> members
|
presenceMembers map[string][]PresenceMember // topicKey -> members
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
presenceMu sync.RWMutex
|
presenceMu sync.RWMutex
|
||||||
pubsubHandlers *pubsubhandlers.PubSubHandlers
|
|
||||||
|
|
||||||
// Serverless function engine
|
// Serverless function engine
|
||||||
serverlessEngine *serverless.Engine
|
serverlessEngine *serverless.Engine
|
||||||
serverlessRegistry *serverless.Registry
|
serverlessRegistry *serverless.Registry
|
||||||
serverlessInvoker *serverless.Invoker
|
serverlessInvoker *serverless.Invoker
|
||||||
serverlessWSMgr *serverless.WSManager
|
serverlessWSMgr *serverless.WSManager
|
||||||
serverlessHandlers *serverlesshandlers.ServerlessHandlers
|
serverlessHandlers *ServerlessHandlers
|
||||||
|
|
||||||
// Authentication service
|
// Authentication service
|
||||||
authService *auth.Service
|
authService *auth.Service
|
||||||
authHandlers *authhandlers.Handlers
|
|
||||||
|
// SFU manager for WebRTC group calls
|
||||||
|
sfuManager *SFUManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// localSubscriber represents a WebSocket subscriber for local message delivery
|
// localSubscriber represents a WebSocket subscriber for local message delivery
|
||||||
@ -81,113 +124,359 @@ type PresenceMember struct {
|
|||||||
ConnID string `json:"-"` // Internal: for tracking which connection
|
ConnID string `json:"-"` // Internal: for tracking which connection
|
||||||
}
|
}
|
||||||
|
|
||||||
// authClientAdapter adapts client.NetworkClient to authhandlers.NetworkClient
|
// New creates and initializes a new Gateway instance
|
||||||
type authClientAdapter struct {
|
|
||||||
client client.NetworkClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *authClientAdapter) Database() authhandlers.DatabaseClient {
|
|
||||||
return &authDatabaseAdapter{db: a.client.Database()}
|
|
||||||
}
|
|
||||||
|
|
||||||
// authDatabaseAdapter adapts client.DatabaseClient to authhandlers.DatabaseClient
|
|
||||||
type authDatabaseAdapter struct {
|
|
||||||
db client.DatabaseClient
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *authDatabaseAdapter) Query(ctx context.Context, sql string, args ...interface{}) (*authhandlers.QueryResult, error) {
|
|
||||||
result, err := a.db.Query(ctx, sql, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Convert client.QueryResult to authhandlers.QueryResult
|
|
||||||
// The auth handlers expect []interface{} but client returns [][]interface{}
|
|
||||||
convertedRows := make([]interface{}, len(result.Rows))
|
|
||||||
for i, row := range result.Rows {
|
|
||||||
convertedRows[i] = row
|
|
||||||
}
|
|
||||||
return &authhandlers.QueryResult{
|
|
||||||
Count: int(result.Count),
|
|
||||||
Rows: convertedRows,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates and initializes a new Gateway instance.
|
|
||||||
// It establishes all necessary service connections and dependencies.
|
|
||||||
func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
func New(logger *logging.ColoredLogger, cfg *Config) (*Gateway, error) {
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway dependencies...")
|
logger.ComponentInfo(logging.ComponentGeneral, "Building client config...")
|
||||||
|
|
||||||
// Initialize all dependencies (network client, database, cache, storage, serverless)
|
// Build client config from gateway cfg
|
||||||
deps, err := NewDependencies(logger, cfg)
|
cliCfg := client.DefaultClientConfig(cfg.ClientNamespace)
|
||||||
|
if len(cfg.BootstrapPeers) > 0 {
|
||||||
|
cliCfg.BootstrapPeers = cfg.BootstrapPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Creating network client...")
|
||||||
|
c, err := client.NewClient(cliCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.ComponentError(logging.ComponentGeneral, "failed to create dependencies", zap.Error(err))
|
logger.ComponentError(logging.ComponentClient, "failed to create network client", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Connecting network client...")
|
||||||
|
if err := c.Connect(); err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentClient, "failed to connect network client", zap.Error(err))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentClient, "Network client connected",
|
||||||
|
zap.String("namespace", cliCfg.AppName),
|
||||||
|
zap.Int("peer_count", len(cliCfg.BootstrapPeers)),
|
||||||
|
)
|
||||||
|
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway instance...")
|
logger.ComponentInfo(logging.ComponentGeneral, "Creating gateway instance...")
|
||||||
gw := &Gateway{
|
gw := &Gateway{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
client: deps.Client,
|
client: c,
|
||||||
nodePeerID: cfg.NodePeerID,
|
nodePeerID: cfg.NodePeerID,
|
||||||
startedAt: time.Now(),
|
startedAt: time.Now(),
|
||||||
sqlDB: deps.SQLDB,
|
localSubscribers: make(map[string][]*localSubscriber),
|
||||||
ormClient: deps.ORMClient,
|
presenceMembers: make(map[string][]PresenceMember),
|
||||||
ormHTTP: deps.ORMHTTP,
|
|
||||||
olricClient: deps.OlricClient,
|
|
||||||
ipfsClient: deps.IPFSClient,
|
|
||||||
serverlessEngine: deps.ServerlessEngine,
|
|
||||||
serverlessRegistry: deps.ServerlessRegistry,
|
|
||||||
serverlessInvoker: deps.ServerlessInvoker,
|
|
||||||
serverlessWSMgr: deps.ServerlessWSMgr,
|
|
||||||
serverlessHandlers: deps.ServerlessHandlers,
|
|
||||||
authService: deps.AuthService,
|
|
||||||
localSubscribers: make(map[string][]*localSubscriber),
|
|
||||||
presenceMembers: make(map[string][]PresenceMember),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize handler instances
|
logger.ComponentInfo(logging.ComponentGeneral, "Initializing RQLite ORM HTTP gateway...")
|
||||||
gw.pubsubHandlers = pubsubhandlers.NewPubSubHandlers(deps.Client, logger)
|
dsn := cfg.RQLiteDSN
|
||||||
|
if dsn == "" {
|
||||||
if deps.OlricClient != nil {
|
dsn = "http://localhost:5001"
|
||||||
gw.cacheHandlers = cache.NewCacheHandlers(logger, deps.OlricClient)
|
|
||||||
}
|
}
|
||||||
|
db, dbErr := sql.Open("rqlite", dsn)
|
||||||
|
if dbErr != nil {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "failed to open rqlite sql db; http orm gateway disabled", zap.Error(dbErr))
|
||||||
|
} else {
|
||||||
|
// Configure connection pool with proper timeouts and limits
|
||||||
|
db.SetMaxOpenConns(25) // Maximum number of open connections
|
||||||
|
db.SetMaxIdleConns(5) // Maximum number of idle connections
|
||||||
|
db.SetConnMaxLifetime(5 * time.Minute) // Maximum lifetime of a connection
|
||||||
|
db.SetConnMaxIdleTime(2 * time.Minute) // Maximum idle time before closing
|
||||||
|
|
||||||
if deps.IPFSClient != nil {
|
gw.sqlDB = db
|
||||||
gw.storageHandlers = storage.New(deps.IPFSClient, logger, storage.Config{
|
orm := rqlite.NewClient(db)
|
||||||
IPFSReplicationFactor: cfg.IPFSReplicationFactor,
|
gw.ormClient = orm
|
||||||
IPFSAPIURL: cfg.IPFSAPIURL,
|
gw.ormHTTP = rqlite.NewHTTPGateway(orm, "/v1/db")
|
||||||
})
|
// Set a reasonable timeout for HTTP requests (30 seconds)
|
||||||
}
|
gw.ormHTTP.Timeout = 30 * time.Second
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "RQLite ORM HTTP gateway ready",
|
||||||
if deps.AuthService != nil {
|
zap.String("dsn", dsn),
|
||||||
// Create adapter for auth handlers to use the client
|
zap.String("base_path", "/v1/db"),
|
||||||
authClientAdapter := &authClientAdapter{client: deps.Client}
|
zap.Duration("timeout", gw.ormHTTP.Timeout),
|
||||||
gw.authHandlers = authhandlers.NewHandlers(
|
|
||||||
logger,
|
|
||||||
deps.AuthService,
|
|
||||||
authClientAdapter,
|
|
||||||
cfg.ClientNamespace,
|
|
||||||
gw.withInternalAuth,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start background Olric reconnection if initial connection failed
|
logger.ComponentInfo(logging.ComponentGeneral, "Initializing Olric cache client...")
|
||||||
if deps.OlricClient == nil {
|
|
||||||
olricCfg := olric.Config{
|
// Discover Olric servers dynamically from LibP2P peers if not explicitly configured
|
||||||
Servers: cfg.OlricServers,
|
olricServers := cfg.OlricServers
|
||||||
Timeout: cfg.OlricTimeout,
|
if len(olricServers) == 0 {
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Olric servers not configured, discovering from LibP2P peers...")
|
||||||
|
discovered := discoverOlricServers(c, logger.Logger)
|
||||||
|
if len(discovered) > 0 {
|
||||||
|
olricServers = discovered
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Discovered Olric servers from LibP2P peers",
|
||||||
|
zap.Strings("servers", olricServers))
|
||||||
|
} else {
|
||||||
|
// Fallback to localhost for local development
|
||||||
|
olricServers = []string{"localhost:3320"}
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "No Olric servers discovered, using localhost fallback")
|
||||||
}
|
}
|
||||||
if len(olricCfg.Servers) == 0 {
|
} else {
|
||||||
olricCfg.Servers = []string{"localhost:3320"}
|
logger.ComponentInfo(logging.ComponentGeneral, "Using explicitly configured Olric servers",
|
||||||
}
|
zap.Strings("servers", olricServers))
|
||||||
gw.startOlricReconnectLoop(olricCfg)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed")
|
olricCfg := olric.Config{
|
||||||
|
Servers: olricServers,
|
||||||
|
Timeout: cfg.OlricTimeout,
|
||||||
|
}
|
||||||
|
olricClient, olricErr := initializeOlricClientWithRetry(olricCfg, logger)
|
||||||
|
if olricErr != nil {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize Olric cache client; cache endpoints disabled", zap.Error(olricErr))
|
||||||
|
gw.startOlricReconnectLoop(olricCfg)
|
||||||
|
} else {
|
||||||
|
gw.setOlricClient(olricClient)
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client ready",
|
||||||
|
zap.Strings("servers", olricCfg.Servers),
|
||||||
|
zap.Duration("timeout", olricCfg.Timeout),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Initializing IPFS Cluster client...")
|
||||||
|
|
||||||
|
// Discover IPFS endpoints from node configs if not explicitly configured
|
||||||
|
ipfsClusterURL := cfg.IPFSClusterAPIURL
|
||||||
|
ipfsAPIURL := cfg.IPFSAPIURL
|
||||||
|
ipfsTimeout := cfg.IPFSTimeout
|
||||||
|
ipfsReplicationFactor := cfg.IPFSReplicationFactor
|
||||||
|
ipfsEnableEncryption := cfg.IPFSEnableEncryption
|
||||||
|
|
||||||
|
if ipfsClusterURL == "" {
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster URL not configured, discovering from node configs...")
|
||||||
|
discovered := discoverIPFSFromNodeConfigs(logger.Logger)
|
||||||
|
if discovered.clusterURL != "" {
|
||||||
|
ipfsClusterURL = discovered.clusterURL
|
||||||
|
ipfsAPIURL = discovered.apiURL
|
||||||
|
if discovered.timeout > 0 {
|
||||||
|
ipfsTimeout = discovered.timeout
|
||||||
|
}
|
||||||
|
if discovered.replicationFactor > 0 {
|
||||||
|
ipfsReplicationFactor = discovered.replicationFactor
|
||||||
|
}
|
||||||
|
ipfsEnableEncryption = discovered.enableEncryption
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Discovered IPFS endpoints from node configs",
|
||||||
|
zap.String("cluster_url", ipfsClusterURL),
|
||||||
|
zap.String("api_url", ipfsAPIURL),
|
||||||
|
zap.Bool("encryption_enabled", ipfsEnableEncryption))
|
||||||
|
} else {
|
||||||
|
// Fallback to localhost defaults
|
||||||
|
ipfsClusterURL = "http://localhost:9094"
|
||||||
|
ipfsAPIURL = "http://localhost:5001"
|
||||||
|
ipfsEnableEncryption = true // Default to true
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "No IPFS config found in node configs, using localhost defaults")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ipfsAPIURL == "" {
|
||||||
|
ipfsAPIURL = "http://localhost:5001"
|
||||||
|
}
|
||||||
|
if ipfsTimeout == 0 {
|
||||||
|
ipfsTimeout = 60 * time.Second
|
||||||
|
}
|
||||||
|
if ipfsReplicationFactor == 0 {
|
||||||
|
ipfsReplicationFactor = 3
|
||||||
|
}
|
||||||
|
if !cfg.IPFSEnableEncryption && !ipfsEnableEncryption {
|
||||||
|
// Only disable if explicitly set to false in both places
|
||||||
|
ipfsEnableEncryption = false
|
||||||
|
} else {
|
||||||
|
// Default to true if not explicitly disabled
|
||||||
|
ipfsEnableEncryption = true
|
||||||
|
}
|
||||||
|
|
||||||
|
ipfsCfg := ipfs.Config{
|
||||||
|
ClusterAPIURL: ipfsClusterURL,
|
||||||
|
Timeout: ipfsTimeout,
|
||||||
|
}
|
||||||
|
ipfsClient, ipfsErr := ipfs.NewClient(ipfsCfg, logger.Logger)
|
||||||
|
if ipfsErr != nil {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize IPFS Cluster client; storage endpoints disabled", zap.Error(ipfsErr))
|
||||||
|
} else {
|
||||||
|
gw.ipfsClient = ipfsClient
|
||||||
|
|
||||||
|
// Check peer count and warn if insufficient (use background context to avoid blocking)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if peerCount, err := ipfsClient.GetPeerCount(ctx); err == nil {
|
||||||
|
if peerCount < ipfsReplicationFactor {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "insufficient cluster peers for replication factor",
|
||||||
|
zap.Int("peer_count", peerCount),
|
||||||
|
zap.Int("replication_factor", ipfsReplicationFactor),
|
||||||
|
zap.String("message", "Some pin operations may fail until more peers join the cluster"))
|
||||||
|
} else {
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster peer count sufficient",
|
||||||
|
zap.Int("peer_count", peerCount),
|
||||||
|
zap.Int("replication_factor", ipfsReplicationFactor))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "failed to get cluster peer count", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "IPFS Cluster client ready",
|
||||||
|
zap.String("cluster_api_url", ipfsCfg.ClusterAPIURL),
|
||||||
|
zap.String("ipfs_api_url", ipfsAPIURL),
|
||||||
|
zap.Duration("timeout", ipfsCfg.Timeout),
|
||||||
|
zap.Int("replication_factor", ipfsReplicationFactor),
|
||||||
|
zap.Bool("encryption_enabled", ipfsEnableEncryption),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
// Store IPFS settings in gateway for use by handlers
|
||||||
|
gw.cfg.IPFSAPIURL = ipfsAPIURL
|
||||||
|
gw.cfg.IPFSReplicationFactor = ipfsReplicationFactor
|
||||||
|
gw.cfg.IPFSEnableEncryption = ipfsEnableEncryption
|
||||||
|
|
||||||
|
// Initialize serverless function engine
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Initializing serverless function engine...")
|
||||||
|
if gw.ormClient != nil && gw.ipfsClient != nil {
|
||||||
|
// Create serverless registry (stores functions in RQLite + IPFS)
|
||||||
|
registryCfg := serverless.RegistryConfig{
|
||||||
|
IPFSAPIURL: ipfsAPIURL,
|
||||||
|
}
|
||||||
|
registry := serverless.NewRegistry(gw.ormClient, gw.ipfsClient, registryCfg, logger.Logger)
|
||||||
|
gw.serverlessRegistry = registry
|
||||||
|
|
||||||
|
// Create WebSocket manager for function streaming
|
||||||
|
gw.serverlessWSMgr = serverless.NewWSManager(logger.Logger)
|
||||||
|
|
||||||
|
// Get underlying Olric client if available
|
||||||
|
var olricClient olriclib.Client
|
||||||
|
if oc := gw.getOlricClient(); oc != nil {
|
||||||
|
olricClient = oc.UnderlyingClient()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create host functions provider (allows functions to call Orama services)
|
||||||
|
// Get pubsub adapter from client for serverless functions
|
||||||
|
var pubsubAdapter *pubsub.ClientAdapter
|
||||||
|
if gw.client != nil {
|
||||||
|
if concreteClient, ok := gw.client.(*client.Client); ok {
|
||||||
|
pubsubAdapter = concreteClient.PubSubAdapter()
|
||||||
|
if pubsubAdapter != nil {
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "pubsub adapter available for serverless functions")
|
||||||
|
} else {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "pubsub adapter is nil - serverless pubsub will be unavailable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
hostFuncsCfg := serverless.HostFunctionsConfig{
|
||||||
|
IPFSAPIURL: ipfsAPIURL,
|
||||||
|
HTTPTimeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
hostFuncs := serverless.NewHostFunctions(
|
||||||
|
gw.ormClient,
|
||||||
|
olricClient,
|
||||||
|
gw.ipfsClient,
|
||||||
|
pubsubAdapter, // pubsub adapter for serverless functions
|
||||||
|
gw.serverlessWSMgr,
|
||||||
|
nil, // secrets manager - TODO: implement
|
||||||
|
hostFuncsCfg,
|
||||||
|
logger.Logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create WASM engine configuration
|
||||||
|
engineCfg := serverless.DefaultConfig()
|
||||||
|
engineCfg.DefaultMemoryLimitMB = 128
|
||||||
|
engineCfg.MaxMemoryLimitMB = 256
|
||||||
|
engineCfg.DefaultTimeoutSeconds = 30
|
||||||
|
engineCfg.MaxTimeoutSeconds = 60
|
||||||
|
engineCfg.ModuleCacheSize = 100
|
||||||
|
|
||||||
|
// Create WASM engine
|
||||||
|
engine, engineErr := serverless.NewEngine(engineCfg, registry, hostFuncs, logger.Logger, serverless.WithInvocationLogger(registry))
|
||||||
|
if engineErr != nil {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize serverless engine; functions disabled", zap.Error(engineErr))
|
||||||
|
} else {
|
||||||
|
gw.serverlessEngine = engine
|
||||||
|
|
||||||
|
// Create invoker
|
||||||
|
gw.serverlessInvoker = serverless.NewInvoker(engine, registry, hostFuncs, logger.Logger)
|
||||||
|
|
||||||
|
// Create trigger manager
|
||||||
|
triggerManager := serverless.NewDBTriggerManager(gw.ormClient, logger.Logger)
|
||||||
|
|
||||||
|
// Create HTTP handlers
|
||||||
|
gw.serverlessHandlers = NewServerlessHandlers(
|
||||||
|
gw.serverlessInvoker,
|
||||||
|
registry,
|
||||||
|
gw.serverlessWSMgr,
|
||||||
|
triggerManager,
|
||||||
|
logger.Logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Initialize auth service
|
||||||
|
// For now using ephemeral key, can be loaded from config later
|
||||||
|
key, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||||
|
})
|
||||||
|
authService, err := auth.NewService(logger, c, string(keyPEM), cfg.ClientNamespace)
|
||||||
|
if err != nil {
|
||||||
|
logger.ComponentError(logging.ComponentGeneral, "failed to initialize auth service", zap.Error(err))
|
||||||
|
} else {
|
||||||
|
gw.authService = authService
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Serverless function engine ready",
|
||||||
|
zap.Int("default_memory_mb", engineCfg.DefaultMemoryLimitMB),
|
||||||
|
zap.Int("default_timeout_sec", engineCfg.DefaultTimeoutSeconds),
|
||||||
|
zap.Int("module_cache_size", engineCfg.ModuleCacheSize),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "serverless engine requires RQLite and IPFS; functions disabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize SFU manager for WebRTC calls
|
||||||
|
if err := gw.initializeSFUManager(); err != nil {
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "failed to initialize SFU manager; WebRTC calls disabled", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Gateway creation completed, returning...")
|
||||||
return gw, nil
|
return gw, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// withInternalAuth creates a context for internal gateway operations that bypass authentication
|
||||||
|
func (g *Gateway) withInternalAuth(ctx context.Context) context.Context {
|
||||||
|
return client.WithInternalAuth(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close disconnects the gateway client
|
||||||
|
func (g *Gateway) Close() {
|
||||||
|
// Close SFU manager first
|
||||||
|
if g.sfuManager != nil {
|
||||||
|
if err := g.sfuManager.Close(); err != nil {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentGeneral, "error during SFU manager close", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Close serverless engine
|
||||||
|
if g.serverlessEngine != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
if err := g.serverlessEngine.Close(ctx); err != nil {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentGeneral, "error during serverless engine close", zap.Error(err))
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
if g.client != nil {
|
||||||
|
if err := g.client.Disconnect(); err != nil {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentClient, "error during client disconnect", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if g.sqlDB != nil {
|
||||||
|
_ = g.sqlDB.Close()
|
||||||
|
}
|
||||||
|
if client := g.getOlricClient(); client != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := client.Close(ctx); err != nil {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentGeneral, "error during Olric client close", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if g.ipfsClient != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := g.ipfsClient.Close(ctx); err != nil {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentGeneral, "error during IPFS client close", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// getLocalSubscribers returns all local subscribers for a given topic and namespace
|
// getLocalSubscribers returns all local subscribers for a given topic and namespace
|
||||||
func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscriber {
|
func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscriber {
|
||||||
topicKey := namespace + "." + topic
|
topicKey := namespace + "." + topic
|
||||||
@ -197,32 +486,23 @@ func (g *Gateway) getLocalSubscribers(topic, namespace string) []*localSubscribe
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setOlricClient atomically sets the Olric client and reinitializes cache handlers.
|
|
||||||
func (g *Gateway) setOlricClient(client *olric.Client) {
|
func (g *Gateway) setOlricClient(client *olric.Client) {
|
||||||
g.olricMu.Lock()
|
g.olricMu.Lock()
|
||||||
defer g.olricMu.Unlock()
|
defer g.olricMu.Unlock()
|
||||||
g.olricClient = client
|
g.olricClient = client
|
||||||
if client != nil {
|
|
||||||
g.cacheHandlers = cache.NewCacheHandlers(g.logger, client)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOlricClient atomically retrieves the current Olric client.
|
|
||||||
func (g *Gateway) getOlricClient() *olric.Client {
|
func (g *Gateway) getOlricClient() *olric.Client {
|
||||||
g.olricMu.RLock()
|
g.olricMu.RLock()
|
||||||
defer g.olricMu.RUnlock()
|
defer g.olricMu.RUnlock()
|
||||||
return g.olricClient
|
return g.olricClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// startOlricReconnectLoop starts a background goroutine that continuously attempts
|
|
||||||
// to reconnect to the Olric cluster with exponential backoff.
|
|
||||||
func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) {
|
func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) {
|
||||||
go func() {
|
go func() {
|
||||||
retryDelay := 5 * time.Second
|
retryDelay := 5 * time.Second
|
||||||
maxBackoff := 30 * time.Second
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
client, err := olric.NewClient(cfg, g.logger.Logger)
|
client, err := initializeOlricClientWithRetry(cfg, g.logger)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
g.setOlricClient(client)
|
g.setOlricClient(client)
|
||||||
g.logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client connected after background retries",
|
g.logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client connected after background retries",
|
||||||
@ -236,13 +516,211 @@ func (g *Gateway) startOlricReconnectLoop(cfg olric.Config) {
|
|||||||
zap.Error(err))
|
zap.Error(err))
|
||||||
|
|
||||||
time.Sleep(retryDelay)
|
time.Sleep(retryDelay)
|
||||||
if retryDelay < maxBackoff {
|
if retryDelay < olricInitMaxBackoff {
|
||||||
retryDelay *= 2
|
retryDelay *= 2
|
||||||
if retryDelay > maxBackoff {
|
if retryDelay > olricInitMaxBackoff {
|
||||||
retryDelay = maxBackoff
|
retryDelay = olricInitMaxBackoff
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func initializeOlricClientWithRetry(cfg olric.Config, logger *logging.ColoredLogger) (*olric.Client, error) {
|
||||||
|
backoff := olricInitInitialBackoff
|
||||||
|
|
||||||
|
for attempt := 1; attempt <= olricInitMaxAttempts; attempt++ {
|
||||||
|
client, err := olric.NewClient(cfg, logger.Logger)
|
||||||
|
if err == nil {
|
||||||
|
if attempt > 1 {
|
||||||
|
logger.ComponentInfo(logging.ComponentGeneral, "Olric cache client initialized after retries",
|
||||||
|
zap.Int("attempts", attempt))
|
||||||
|
}
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.ComponentWarn(logging.ComponentGeneral, "Olric cache client init attempt failed",
|
||||||
|
zap.Int("attempt", attempt),
|
||||||
|
zap.Duration("retry_in", backoff),
|
||||||
|
zap.Error(err))
|
||||||
|
|
||||||
|
if attempt == olricInitMaxAttempts {
|
||||||
|
return nil, fmt.Errorf("failed to initialize Olric cache client after %d attempts: %w", attempt, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(backoff)
|
||||||
|
backoff *= 2
|
||||||
|
if backoff > olricInitMaxBackoff {
|
||||||
|
backoff = olricInitMaxBackoff
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to initialize Olric cache client")
|
||||||
|
}
|
||||||
|
|
||||||
|
// discoverOlricServers discovers Olric server addresses from LibP2P peers
|
||||||
|
// Returns a list of IP:port addresses where Olric servers are expected to run (port 3320)
|
||||||
|
func discoverOlricServers(networkClient client.NetworkClient, logger *zap.Logger) []string {
|
||||||
|
// Get network info to access peer information
|
||||||
|
networkInfo := networkClient.Network()
|
||||||
|
if networkInfo == nil {
|
||||||
|
logger.Debug("Network info not available for Olric discovery")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
peers, err := networkInfo.GetPeers(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Debug("Failed to get peers for Olric discovery", zap.Error(err))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
olricServers := make([]string, 0)
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
for _, addrStr := range peer.Addresses {
|
||||||
|
// Parse multiaddr
|
||||||
|
ma, err := multiaddr.NewMultiaddr(addrStr)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract IP address
|
||||||
|
var ip string
|
||||||
|
if ipv4, err := ma.ValueForProtocol(multiaddr.P_IP4); err == nil && ipv4 != "" {
|
||||||
|
ip = ipv4
|
||||||
|
} else if ipv6, err := ma.ValueForProtocol(multiaddr.P_IP6); err == nil && ipv6 != "" {
|
||||||
|
ip = ipv6
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip localhost loopback addresses (we'll use localhost:3320 as fallback)
|
||||||
|
if ip == "localhost" || ip == "::1" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Olric server address (standard port 3320)
|
||||||
|
olricAddr := net.JoinHostPort(ip, "3320")
|
||||||
|
if !seen[olricAddr] {
|
||||||
|
olricServers = append(olricServers, olricAddr)
|
||||||
|
seen[olricAddr] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also check peers from config
|
||||||
|
if cfg := networkClient.Config(); cfg != nil {
|
||||||
|
for _, peerAddr := range cfg.BootstrapPeers {
|
||||||
|
ma, err := multiaddr.NewMultiaddr(peerAddr)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var ip string
|
||||||
|
if ipv4, err := ma.ValueForProtocol(multiaddr.P_IP4); err == nil && ipv4 != "" {
|
||||||
|
ip = ipv4
|
||||||
|
} else if ipv6, err := ma.ValueForProtocol(multiaddr.P_IP6); err == nil && ipv6 != "" {
|
||||||
|
ip = ipv6
|
||||||
|
} else {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip localhost
|
||||||
|
if ip == "localhost" || ip == "::1" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
olricAddr := net.JoinHostPort(ip, "3320")
|
||||||
|
if !seen[olricAddr] {
|
||||||
|
olricServers = append(olricServers, olricAddr)
|
||||||
|
seen[olricAddr] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we found servers, log them
|
||||||
|
if len(olricServers) > 0 {
|
||||||
|
logger.Info("Discovered Olric servers from LibP2P network",
|
||||||
|
zap.Strings("servers", olricServers))
|
||||||
|
}
|
||||||
|
|
||||||
|
return olricServers
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipfsDiscoveryResult holds discovered IPFS configuration
|
||||||
|
type ipfsDiscoveryResult struct {
|
||||||
|
clusterURL string
|
||||||
|
apiURL string
|
||||||
|
timeout time.Duration
|
||||||
|
replicationFactor int
|
||||||
|
enableEncryption bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// discoverIPFSFromNodeConfigs discovers IPFS configuration from node.yaml files
|
||||||
|
// Checks node-1.yaml through node-5.yaml for IPFS configuration
|
||||||
|
func discoverIPFSFromNodeConfigs(logger *zap.Logger) ipfsDiscoveryResult {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
logger.Debug("Failed to get home directory for IPFS discovery", zap.Error(err))
|
||||||
|
return ipfsDiscoveryResult{}
|
||||||
|
}
|
||||||
|
|
||||||
|
configDir := filepath.Join(homeDir, ".orama")
|
||||||
|
|
||||||
|
// Try all node config files for IPFS settings
|
||||||
|
configFiles := []string{"node-1.yaml", "node-2.yaml", "node-3.yaml", "node-4.yaml", "node-5.yaml"}
|
||||||
|
|
||||||
|
for _, filename := range configFiles {
|
||||||
|
configPath := filepath.Join(configDir, filename)
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var nodeCfg config.Config
|
||||||
|
if err := config.DecodeStrict(strings.NewReader(string(data)), &nodeCfg); err != nil {
|
||||||
|
logger.Debug("Failed to parse node config for IPFS discovery",
|
||||||
|
zap.String("file", filename), zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if IPFS is configured
|
||||||
|
if nodeCfg.Database.IPFS.ClusterAPIURL != "" {
|
||||||
|
result := ipfsDiscoveryResult{
|
||||||
|
clusterURL: nodeCfg.Database.IPFS.ClusterAPIURL,
|
||||||
|
apiURL: nodeCfg.Database.IPFS.APIURL,
|
||||||
|
timeout: nodeCfg.Database.IPFS.Timeout,
|
||||||
|
replicationFactor: nodeCfg.Database.IPFS.ReplicationFactor,
|
||||||
|
enableEncryption: nodeCfg.Database.IPFS.EnableEncryption,
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.apiURL == "" {
|
||||||
|
result.apiURL = "http://localhost:5001"
|
||||||
|
}
|
||||||
|
if result.timeout == 0 {
|
||||||
|
result.timeout = 60 * time.Second
|
||||||
|
}
|
||||||
|
if result.replicationFactor == 0 {
|
||||||
|
result.replicationFactor = 3
|
||||||
|
}
|
||||||
|
// Default encryption to true if not set
|
||||||
|
if !result.enableEncryption {
|
||||||
|
result.enableEncryption = true
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Discovered IPFS config from node config",
|
||||||
|
zap.String("file", filename),
|
||||||
|
zap.String("cluster_url", result.clusterURL),
|
||||||
|
zap.String("api_url", result.apiURL),
|
||||||
|
zap.Bool("encryption_enabled", result.enableEncryption))
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipfsDiscoveryResult{}
|
||||||
|
}
|
||||||
|
|||||||
@ -191,6 +191,16 @@ func isPublicPath(p string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TURN credentials (public for development - requires secret for actual use)
|
||||||
|
if strings.HasPrefix(p, "/v1/turn/") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SFU endpoints (public for development)
|
||||||
|
if strings.HasPrefix(p, "/v1/sfu/") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
switch p {
|
switch p {
|
||||||
case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers":
|
case "/health", "/v1/health", "/status", "/v1/status", "/v1/auth/jwks", "/.well-known/jwks.json", "/v1/version", "/v1/auth/login", "/v1/auth/challenge", "/v1/auth/verify", "/v1/auth/register", "/v1/auth/refresh", "/v1/auth/logout", "/v1/auth/api-key", "/v1/auth/simple-key", "/v1/network/status", "/v1/network/peers":
|
||||||
return true
|
return true
|
||||||
|
|||||||
475
pkg/gateway/pubsub_handlers.go
Normal file
475
pkg/gateway/pubsub_handlers.go
Normal file
@ -0,0 +1,475 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/client"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
var wsUpgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
// For early development we accept any origin; tighten later.
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
|
||||||
|
// pubsubWebsocketHandler upgrades to WS, subscribes to a namespaced topic, and
|
||||||
|
// forwards received PubSub messages to the client. Messages sent by the client
|
||||||
|
// are published to the same namespaced topic.
|
||||||
|
func (g *Gateway) pubsubWebsocketHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if g.client == nil {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: client not initialized")
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: method not allowed")
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resolve namespace from auth context
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: namespace not resolved")
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topic := r.URL.Query().Get("topic")
|
||||||
|
if topic == "" {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: missing topic")
|
||||||
|
writeError(w, http.StatusBadRequest, "missing 'topic'")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Presence handling
|
||||||
|
enablePresence := r.URL.Query().Get("presence") == "true"
|
||||||
|
memberID := r.URL.Query().Get("member_id")
|
||||||
|
memberMetaStr := r.URL.Query().Get("member_meta")
|
||||||
|
var memberMeta map[string]interface{}
|
||||||
|
if memberMetaStr != "" {
|
||||||
|
_ = json.Unmarshal([]byte(memberMetaStr), &memberMeta)
|
||||||
|
}
|
||||||
|
|
||||||
|
if enablePresence && memberID == "" {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: presence enabled but missing member_id")
|
||||||
|
writeError(w, http.StatusBadRequest, "missing 'member_id' for presence")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: upgrade failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Channel to deliver PubSub messages to WS writer
|
||||||
|
msgs := make(chan []byte, 128)
|
||||||
|
|
||||||
|
// NEW: Register as local subscriber for direct message delivery
|
||||||
|
localSub := &localSubscriber{
|
||||||
|
msgChan: msgs,
|
||||||
|
namespace: ns,
|
||||||
|
}
|
||||||
|
topicKey := fmt.Sprintf("%s.%s", ns, topic)
|
||||||
|
|
||||||
|
g.mu.Lock()
|
||||||
|
g.localSubscribers[topicKey] = append(g.localSubscribers[topicKey], localSub)
|
||||||
|
subscriberCount := len(g.localSubscribers[topicKey])
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
connID := uuid.New().String()
|
||||||
|
if enablePresence {
|
||||||
|
member := PresenceMember{
|
||||||
|
MemberID: memberID,
|
||||||
|
JoinedAt: time.Now().Unix(),
|
||||||
|
Meta: memberMeta,
|
||||||
|
ConnID: connID,
|
||||||
|
}
|
||||||
|
|
||||||
|
g.presenceMu.Lock()
|
||||||
|
g.presenceMembers[topicKey] = append(g.presenceMembers[topicKey], member)
|
||||||
|
g.presenceMu.Unlock()
|
||||||
|
|
||||||
|
// Broadcast join event
|
||||||
|
joinEvent := map[string]interface{}{
|
||||||
|
"type": "presence.join",
|
||||||
|
"member_id": memberID,
|
||||||
|
"meta": memberMeta,
|
||||||
|
"timestamp": member.JoinedAt,
|
||||||
|
}
|
||||||
|
eventData, _ := json.Marshal(joinEvent)
|
||||||
|
// Use a background context for the broadcast to ensure it finishes even if the connection closes immediately
|
||||||
|
broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
|
||||||
|
_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: member joined presence",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.String("member_id", memberID))
|
||||||
|
}
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: registered local subscriber",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.Int("total_subscribers", subscriberCount))
|
||||||
|
|
||||||
|
// Unregister on close
|
||||||
|
defer func() {
|
||||||
|
g.mu.Lock()
|
||||||
|
subs := g.localSubscribers[topicKey]
|
||||||
|
for i, sub := range subs {
|
||||||
|
if sub == localSub {
|
||||||
|
g.localSubscribers[topicKey] = append(subs[:i], subs[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
remainingCount := len(g.localSubscribers[topicKey])
|
||||||
|
if remainingCount == 0 {
|
||||||
|
delete(g.localSubscribers, topicKey)
|
||||||
|
}
|
||||||
|
g.mu.Unlock()
|
||||||
|
|
||||||
|
if enablePresence {
|
||||||
|
g.presenceMu.Lock()
|
||||||
|
members := g.presenceMembers[topicKey]
|
||||||
|
for i, m := range members {
|
||||||
|
if m.ConnID == connID {
|
||||||
|
g.presenceMembers[topicKey] = append(members[:i], members[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(g.presenceMembers[topicKey]) == 0 {
|
||||||
|
delete(g.presenceMembers, topicKey)
|
||||||
|
}
|
||||||
|
g.presenceMu.Unlock()
|
||||||
|
|
||||||
|
// Broadcast leave event
|
||||||
|
leaveEvent := map[string]interface{}{
|
||||||
|
"type": "presence.leave",
|
||||||
|
"member_id": memberID,
|
||||||
|
"timestamp": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
eventData, _ := json.Marshal(leaveEvent)
|
||||||
|
broadcastCtx := pubsub.WithNamespace(client.WithInternalAuth(context.Background()), ns)
|
||||||
|
_ = g.client.PubSub().Publish(broadcastCtx, topic, eventData)
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: member left presence",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.String("member_id", memberID))
|
||||||
|
}
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: unregistered local subscriber",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Int("remaining_subscribers", remainingCount))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Use internal auth context when interacting with client to avoid circular auth requirements
|
||||||
|
ctx := client.WithInternalAuth(r.Context())
|
||||||
|
// Apply namespace isolation
|
||||||
|
ctx = pubsub.WithNamespace(ctx, ns)
|
||||||
|
|
||||||
|
// Writer loop - START THIS FIRST before libp2p subscription
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: writer goroutine started",
|
||||||
|
zap.String("topic", topic))
|
||||||
|
defer g.logger.ComponentInfo("gateway", "pubsub ws: writer goroutine exiting",
|
||||||
|
zap.String("topic", topic))
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case b, ok := <-msgs:
|
||||||
|
if !ok {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: message channel closed",
|
||||||
|
zap.String("topic", topic))
|
||||||
|
_ = conn.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(5*time.Second))
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: sending message to client",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Int("data_len", len(b)))
|
||||||
|
|
||||||
|
// Format message as JSON envelope with data (base64 encoded), timestamp, and topic
|
||||||
|
// This matches the SDK's Message interface: {data: string, timestamp: number, topic: string}
|
||||||
|
envelope := map[string]interface{}{
|
||||||
|
"data": base64.StdEncoding.EncodeToString(b),
|
||||||
|
"timestamp": time.Now().UnixMilli(),
|
||||||
|
"topic": topic,
|
||||||
|
}
|
||||||
|
envelopeJSON, err := json.Marshal(envelope)
|
||||||
|
if err != nil {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: failed to marshal envelope",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Error(err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
g.logger.ComponentDebug("gateway", "pubsub ws: envelope created",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Int("envelope_len", len(envelopeJSON)))
|
||||||
|
|
||||||
|
conn.SetWriteDeadline(time.Now().Add(30 * time.Second))
|
||||||
|
if err := conn.WriteMessage(websocket.TextMessage, envelopeJSON); err != nil {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: failed to write to websocket",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Error(err))
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: message sent successfully",
|
||||||
|
zap.String("topic", topic))
|
||||||
|
case <-ticker.C:
|
||||||
|
// Ping keepalive
|
||||||
|
_ = conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(5*time.Second))
|
||||||
|
case <-ctx.Done():
|
||||||
|
close(done)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Subscribe to libp2p for cross-node messages (in background, non-blocking)
|
||||||
|
go func() {
|
||||||
|
h := func(_ string, data []byte) error {
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: received message from libp2p",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Int("data_len", len(data)))
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msgs <- data:
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: forwarded to client",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.String("source", "libp2p"))
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
// Drop if client is slow to avoid blocking network
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: client slow, dropping message",
|
||||||
|
zap.String("topic", topic))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := g.client.PubSub().Subscribe(ctx, topic, h); err != nil {
|
||||||
|
g.logger.ComponentWarn("gateway", "pubsub ws: libp2p subscribe failed (will use local-only)",
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: libp2p subscription established",
|
||||||
|
zap.String("topic", topic))
|
||||||
|
|
||||||
|
// Keep subscription alive until done
|
||||||
|
<-done
|
||||||
|
_ = g.client.PubSub().Unsubscribe(ctx, topic)
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: libp2p subscription closed",
|
||||||
|
zap.String("topic", topic))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Reader loop: treat any client message as publish to the same topic
|
||||||
|
for {
|
||||||
|
mt, data, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if mt != websocket.TextMessage && mt != websocket.BinaryMessage {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out WebSocket heartbeat messages
|
||||||
|
// Don't publish them to the topic
|
||||||
|
var msg map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &msg); err == nil {
|
||||||
|
if msgType, ok := msg["type"].(string); ok && msgType == "ping" {
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub ws: filtering out heartbeat ping")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.client.PubSub().Publish(ctx, topic, data); err != nil {
|
||||||
|
// Best-effort notify client
|
||||||
|
_ = conn.WriteMessage(websocket.TextMessage, []byte("publish_error"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
// pubsubPublishHandler handles POST /v1/pubsub/publish {topic, data_base64}
|
||||||
|
func (g *Gateway) pubsubPublishHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if g.client == nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var body struct {
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
DataB64 string `json:"data_base64"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil || body.Topic == "" || body.DataB64 == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid body: expected {topic,data_base64}")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := base64.StdEncoding.DecodeString(body.DataB64)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid base64 data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// NEW: Check for local websocket subscribers FIRST and deliver directly
|
||||||
|
g.mu.RLock()
|
||||||
|
localSubs := g.getLocalSubscribers(body.Topic, ns)
|
||||||
|
g.mu.RUnlock()
|
||||||
|
|
||||||
|
localDeliveryCount := 0
|
||||||
|
if len(localSubs) > 0 {
|
||||||
|
for _, sub := range localSubs {
|
||||||
|
select {
|
||||||
|
case sub.msgChan <- data:
|
||||||
|
localDeliveryCount++
|
||||||
|
g.logger.ComponentDebug("gateway", "delivered to local subscriber",
|
||||||
|
zap.String("topic", body.Topic))
|
||||||
|
default:
|
||||||
|
// Drop if buffer full
|
||||||
|
g.logger.ComponentWarn("gateway", "local subscriber buffer full, dropping message",
|
||||||
|
zap.String("topic", body.Topic))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
g.logger.ComponentInfo("gateway", "pubsub publish: processing message",
|
||||||
|
zap.String("topic", body.Topic),
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.Int("data_len", len(data)),
|
||||||
|
zap.Int("local_subscribers", len(localSubs)),
|
||||||
|
zap.Int("local_delivered", localDeliveryCount))
|
||||||
|
|
||||||
|
// Publish to libp2p asynchronously for cross-node delivery
|
||||||
|
// This prevents blocking the HTTP response if libp2p network is slow
|
||||||
|
go func() {
|
||||||
|
publishCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
ctx := pubsub.WithNamespace(client.WithInternalAuth(publishCtx), ns)
|
||||||
|
if err := g.client.PubSub().Publish(ctx, body.Topic, data); err != nil {
|
||||||
|
g.logger.ComponentWarn("gateway", "async libp2p publish failed",
|
||||||
|
zap.String("topic", body.Topic),
|
||||||
|
zap.Error(err))
|
||||||
|
} else {
|
||||||
|
g.logger.ComponentDebug("gateway", "async libp2p publish succeeded",
|
||||||
|
zap.String("topic", body.Topic))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Return immediately after local delivery
|
||||||
|
// Local WebSocket subscribers already received the message
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"status": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// pubsubTopicsHandler lists topics within the caller's namespace
|
||||||
|
func (g *Gateway) pubsubTopicsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if g.client == nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "client not initialized")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Apply namespace isolation
|
||||||
|
ctx := pubsub.WithNamespace(client.WithInternalAuth(r.Context()), ns)
|
||||||
|
all, err := g.client.PubSub().ListTopics(ctx)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Client returns topics already trimmed to its namespace; return as-is
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{"topics": all})
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveNamespaceFromRequest gets namespace from context set by auth middleware
|
||||||
|
// Falls back to query parameter "namespace" for development/testing
|
||||||
|
func resolveNamespaceFromRequest(r *http.Request) string {
|
||||||
|
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback: check query parameter for development
|
||||||
|
if ns := strings.TrimSpace(r.URL.Query().Get("namespace")); ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func namespacePrefix(ns string) string {
|
||||||
|
return "ns::" + ns + "::"
|
||||||
|
}
|
||||||
|
|
||||||
|
func namespacedTopic(ns, topic string) string {
|
||||||
|
return namespacePrefix(ns) + topic
|
||||||
|
}
|
||||||
|
|
||||||
|
// pubsubPresenceHandler handles GET /v1/pubsub/presence?topic=mytopic
|
||||||
|
func (g *Gateway) pubsubPresenceHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topic := r.URL.Query().Get("topic")
|
||||||
|
if topic == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "missing 'topic'")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
topicKey := fmt.Sprintf("%s.%s", ns, topic)
|
||||||
|
|
||||||
|
g.presenceMu.RLock()
|
||||||
|
members, ok := g.presenceMembers[topicKey]
|
||||||
|
g.presenceMu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"topic": topic,
|
||||||
|
"members": []PresenceMember{},
|
||||||
|
"count": 0,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]any{
|
||||||
|
"topic": topic,
|
||||||
|
"members": members,
|
||||||
|
"count": len(members),
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -79,5 +79,14 @@ func (g *Gateway) Routes() http.Handler {
|
|||||||
g.serverlessHandlers.RegisterRoutes(mux)
|
g.serverlessHandlers.RegisterRoutes(mux)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TURN credentials for WebRTC
|
||||||
|
mux.HandleFunc("/v1/turn/credentials", g.turnCredentialsHandler)
|
||||||
|
|
||||||
|
// SFU endpoints for WebRTC group calls (if enabled)
|
||||||
|
if g.sfuManager != nil {
|
||||||
|
mux.HandleFunc("/v1/sfu/room", g.sfuCreateRoomHandler)
|
||||||
|
mux.HandleFunc("/v1/sfu/room/", g.sfuRoomHandler) // Handles :roomId/* paths
|
||||||
|
}
|
||||||
|
|
||||||
return g.withMiddleware(mux)
|
return g.withMiddleware(mux)
|
||||||
}
|
}
|
||||||
|
|||||||
933
pkg/gateway/serverless_handlers.go
Normal file
933
pkg/gateway/serverless_handlers.go
Normal file
@ -0,0 +1,933 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ServerlessHandlers contains handlers for serverless function endpoints.
|
||||||
|
// It's a separate struct to keep the Gateway struct clean.
|
||||||
|
type ServerlessHandlers struct {
|
||||||
|
invoker *serverless.Invoker
|
||||||
|
registry serverless.FunctionRegistry
|
||||||
|
wsManager *serverless.WSManager
|
||||||
|
triggerManager serverless.TriggerManager
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerlessHandlers creates a new ServerlessHandlers instance.
|
||||||
|
func NewServerlessHandlers(
|
||||||
|
invoker *serverless.Invoker,
|
||||||
|
registry serverless.FunctionRegistry,
|
||||||
|
wsManager *serverless.WSManager,
|
||||||
|
triggerManager serverless.TriggerManager,
|
||||||
|
logger *zap.Logger,
|
||||||
|
) *ServerlessHandlers {
|
||||||
|
return &ServerlessHandlers{
|
||||||
|
invoker: invoker,
|
||||||
|
registry: registry,
|
||||||
|
wsManager: wsManager,
|
||||||
|
triggerManager: triggerManager,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRoutes registers all serverless routes on the given mux.
|
||||||
|
func (h *ServerlessHandlers) RegisterRoutes(mux *http.ServeMux) {
|
||||||
|
// Function management
|
||||||
|
mux.HandleFunc("/v1/functions", h.handleFunctions)
|
||||||
|
mux.HandleFunc("/v1/functions/", h.handleFunctionByName)
|
||||||
|
|
||||||
|
// Direct invoke endpoint
|
||||||
|
mux.HandleFunc("/v1/invoke/", h.handleInvoke)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFunctions handles GET /v1/functions (list) and POST /v1/functions (deploy)
|
||||||
|
func (h *ServerlessHandlers) handleFunctions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.listFunctions(w, r)
|
||||||
|
case http.MethodPost:
|
||||||
|
h.deployFunction(w, r)
|
||||||
|
default:
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFunctionByName handles operations on a specific function
|
||||||
|
// Routes:
|
||||||
|
// - GET /v1/functions/{name} - Get function info
|
||||||
|
// - DELETE /v1/functions/{name} - Delete function
|
||||||
|
// - POST /v1/functions/{name}/invoke - Invoke function
|
||||||
|
// - GET /v1/functions/{name}/versions - List versions
|
||||||
|
// - GET /v1/functions/{name}/logs - Get logs
|
||||||
|
// - WS /v1/functions/{name}/ws - WebSocket invoke
|
||||||
|
func (h *ServerlessHandlers) handleFunctionByName(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Parse path: /v1/functions/{name}[/{action}]
|
||||||
|
path := strings.TrimPrefix(r.URL.Path, "/v1/functions/")
|
||||||
|
parts := strings.SplitN(path, "/", 2)
|
||||||
|
|
||||||
|
if len(parts) == 0 || parts[0] == "" {
|
||||||
|
http.Error(w, "Function name required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
name := parts[0]
|
||||||
|
action := ""
|
||||||
|
if len(parts) > 1 {
|
||||||
|
action = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse version from name if present (e.g., "myfunction@2")
|
||||||
|
version := 0
|
||||||
|
if idx := strings.Index(name, "@"); idx > 0 {
|
||||||
|
vStr := name[idx+1:]
|
||||||
|
name = name[:idx]
|
||||||
|
if v, err := strconv.Atoi(vStr); err == nil {
|
||||||
|
version = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch action {
|
||||||
|
case "invoke":
|
||||||
|
h.invokeFunction(w, r, name, version)
|
||||||
|
case "ws":
|
||||||
|
h.handleWebSocket(w, r, name, version)
|
||||||
|
case "versions":
|
||||||
|
h.listVersions(w, r, name)
|
||||||
|
case "logs":
|
||||||
|
h.getFunctionLogs(w, r, name)
|
||||||
|
case "triggers":
|
||||||
|
h.handleFunctionTriggers(w, r, name)
|
||||||
|
case "":
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.getFunctionInfo(w, r, name, version)
|
||||||
|
case http.MethodDelete:
|
||||||
|
h.deleteFunction(w, r, name, version)
|
||||||
|
default:
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
http.Error(w, "Unknown action", http.StatusNotFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleInvoke handles POST /v1/invoke/{namespace}/{name}[@version]
|
||||||
|
func (h *ServerlessHandlers) handleInvoke(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse path: /v1/invoke/{namespace}/{name}[@version]
|
||||||
|
path := strings.TrimPrefix(r.URL.Path, "/v1/invoke/")
|
||||||
|
parts := strings.SplitN(path, "/", 2)
|
||||||
|
|
||||||
|
if len(parts) < 2 {
|
||||||
|
http.Error(w, "Path must be /v1/invoke/{namespace}/{name}", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace := parts[0]
|
||||||
|
name := parts[1]
|
||||||
|
|
||||||
|
// Parse version if present
|
||||||
|
version := 0
|
||||||
|
if idx := strings.Index(name, "@"); idx > 0 {
|
||||||
|
vStr := name[idx+1:]
|
||||||
|
name = name[:idx]
|
||||||
|
if v, err := strconv.Atoi(vStr); err == nil {
|
||||||
|
version = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.invokeFunction(w, r, namespace+"/"+name, version)
|
||||||
|
}
|
||||||
|
|
||||||
|
// listFunctions handles GET /v1/functions
|
||||||
|
func (h *ServerlessHandlers) listFunctions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
// Get namespace from JWT if available
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
functions, err := h.registry.List(ctx, namespace)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to list functions", zap.Error(err))
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to list functions")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"functions": functions,
|
||||||
|
"count": len(functions),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// deployFunction handles POST /v1/functions
|
||||||
|
func (h *ServerlessHandlers) deployFunction(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Parse multipart form (for WASM upload) or JSON
|
||||||
|
contentType := r.Header.Get("Content-Type")
|
||||||
|
|
||||||
|
var def serverless.FunctionDefinition
|
||||||
|
var wasmBytes []byte
|
||||||
|
|
||||||
|
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||||
|
// Parse multipart form
|
||||||
|
if err := r.ParseMultipartForm(32 << 20); err != nil { // 32MB max
|
||||||
|
writeError(w, http.StatusBadRequest, "Failed to parse form: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get metadata from form field
|
||||||
|
metadataStr := r.FormValue("metadata")
|
||||||
|
if metadataStr != "" {
|
||||||
|
if err := json.Unmarshal([]byte(metadataStr), &def); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Invalid metadata JSON: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get name from form if not in metadata
|
||||||
|
if def.Name == "" {
|
||||||
|
def.Name = r.FormValue("name")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get namespace from form if not in metadata
|
||||||
|
if def.Namespace == "" {
|
||||||
|
def.Namespace = r.FormValue("namespace")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get other configuration fields from form
|
||||||
|
if v := r.FormValue("is_public"); v != "" {
|
||||||
|
def.IsPublic, _ = strconv.ParseBool(v)
|
||||||
|
}
|
||||||
|
if v := r.FormValue("memory_limit_mb"); v != "" {
|
||||||
|
def.MemoryLimitMB, _ = strconv.Atoi(v)
|
||||||
|
}
|
||||||
|
if v := r.FormValue("timeout_seconds"); v != "" {
|
||||||
|
def.TimeoutSeconds, _ = strconv.Atoi(v)
|
||||||
|
}
|
||||||
|
if v := r.FormValue("retry_count"); v != "" {
|
||||||
|
def.RetryCount, _ = strconv.Atoi(v)
|
||||||
|
}
|
||||||
|
if v := r.FormValue("retry_delay_seconds"); v != "" {
|
||||||
|
def.RetryDelaySeconds, _ = strconv.Atoi(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get WASM file
|
||||||
|
file, _, err := r.FormFile("wasm")
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "WASM file required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
|
wasmBytes, err = io.ReadAll(file)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Failed to read WASM file: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// JSON body with base64-encoded WASM
|
||||||
|
var req struct {
|
||||||
|
serverless.FunctionDefinition
|
||||||
|
WASMBase64 string `json:"wasm_base64"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
def = req.FunctionDefinition
|
||||||
|
|
||||||
|
if req.WASMBase64 != "" {
|
||||||
|
// Decode base64 WASM - for now, just reject this method
|
||||||
|
writeError(w, http.StatusBadRequest, "Base64 WASM upload not supported, use multipart/form-data")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get namespace from JWT if not provided
|
||||||
|
if def.Namespace == "" {
|
||||||
|
def.Namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if def.Name == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "Function name required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if def.Namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "Namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(wasmBytes) == 0 {
|
||||||
|
writeError(w, http.StatusBadRequest, "WASM bytecode required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
oldFn, err := h.registry.Register(ctx, &def, wasmBytes)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to deploy function",
|
||||||
|
zap.String("name", def.Name),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to deploy: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalidate cache for the old version to ensure the new one is loaded
|
||||||
|
if oldFn != nil {
|
||||||
|
h.invoker.InvalidateCache(oldFn.WASMCID)
|
||||||
|
h.logger.Debug("Invalidated function cache",
|
||||||
|
zap.String("name", def.Name),
|
||||||
|
zap.String("old_wasm_cid", oldFn.WASMCID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("Function deployed",
|
||||||
|
zap.String("name", def.Name),
|
||||||
|
zap.String("namespace", def.Namespace),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Fetch the deployed function to return
|
||||||
|
fn, err := h.registry.Get(ctx, def.Namespace, def.Name, def.Version)
|
||||||
|
if err != nil {
|
||||||
|
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||||
|
"message": "Function deployed successfully",
|
||||||
|
"name": def.Name,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register PubSub triggers if provided in metadata
|
||||||
|
var triggersAdded []string
|
||||||
|
if len(def.PubSubTopics) > 0 && h.triggerManager != nil {
|
||||||
|
for _, topic := range def.PubSubTopics {
|
||||||
|
if err := h.triggerManager.AddPubSubTrigger(ctx, fn.ID, topic); err != nil {
|
||||||
|
// Log but don't fail deployment
|
||||||
|
h.logger.Warn("Failed to add pubsub trigger during deployment",
|
||||||
|
zap.String("function", def.Name),
|
||||||
|
zap.String("topic", topic),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
triggersAdded = append(triggersAdded, topic)
|
||||||
|
h.logger.Info("PubSub trigger added during deployment",
|
||||||
|
zap.String("function", def.Name),
|
||||||
|
zap.String("topic", topic),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"message": "Function deployed successfully",
|
||||||
|
"function": fn,
|
||||||
|
}
|
||||||
|
if len(triggersAdded) > 0 {
|
||||||
|
response["triggers_added"] = triggersAdded
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusCreated, response)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFunctionInfo handles GET /v1/functions/{name}
|
||||||
|
func (h *ServerlessHandlers) getFunctionInfo(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
fn, err := h.registry.Get(ctx, namespace, name, version)
|
||||||
|
if err != nil {
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "Function not found")
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteFunction handles DELETE /v1/functions/{name}
|
||||||
|
func (h *ServerlessHandlers) deleteFunction(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := h.registry.Delete(ctx, namespace, name, version); err != nil {
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "Function not found")
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to delete function")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{
|
||||||
|
"message": "Function deleted successfully",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// invokeFunction handles POST /v1/functions/{name}/invoke
|
||||||
|
func (h *ServerlessHandlers) invokeFunction(w http.ResponseWriter, r *http.Request, nameWithNS string, version int) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse namespace and name
|
||||||
|
var namespace, name string
|
||||||
|
if idx := strings.Index(nameWithNS, "/"); idx > 0 {
|
||||||
|
namespace = nameWithNS[:idx]
|
||||||
|
name = nameWithNS[idx+1:]
|
||||||
|
} else {
|
||||||
|
name = nameWithNS
|
||||||
|
namespace = r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read input body
|
||||||
|
input, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) // 1MB max
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get caller wallet from JWT
|
||||||
|
callerWallet := h.getWalletFromRequest(r)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := &serverless.InvokeRequest{
|
||||||
|
Namespace: namespace,
|
||||||
|
FunctionName: name,
|
||||||
|
Version: version,
|
||||||
|
Input: input,
|
||||||
|
TriggerType: serverless.TriggerTypeHTTP,
|
||||||
|
CallerWallet: callerWallet,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.invoker.Invoke(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
statusCode := http.StatusInternalServerError
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
statusCode = http.StatusNotFound
|
||||||
|
} else if serverless.IsResourceExhausted(err) {
|
||||||
|
statusCode = http.StatusTooManyRequests
|
||||||
|
} else if serverless.IsUnauthorized(err) {
|
||||||
|
statusCode = http.StatusUnauthorized
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, statusCode, map[string]interface{}{
|
||||||
|
"request_id": resp.RequestID,
|
||||||
|
"status": resp.Status,
|
||||||
|
"error": resp.Error,
|
||||||
|
"duration_ms": resp.DurationMS,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the function's output directly if it's JSON
|
||||||
|
w.Header().Set("X-Request-ID", resp.RequestID)
|
||||||
|
w.Header().Set("X-Duration-Ms", strconv.FormatInt(resp.DurationMS, 10))
|
||||||
|
|
||||||
|
// Try to detect if output is JSON
|
||||||
|
if len(resp.Output) > 0 && (resp.Output[0] == '{' || resp.Output[0] == '[') {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write(resp.Output)
|
||||||
|
} else {
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"request_id": resp.RequestID,
|
||||||
|
"output": string(resp.Output),
|
||||||
|
"status": resp.Status,
|
||||||
|
"duration_ms": resp.DurationMS,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleWebSocket handles WebSocket connections for function streaming
|
||||||
|
func (h *ServerlessHandlers) handleWebSocket(w http.ResponseWriter, r *http.Request, name string, version int) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
http.Error(w, "namespace required", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade to WebSocket
|
||||||
|
upgrader := websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool { return true },
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("WebSocket upgrade failed", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientID := uuid.New().String()
|
||||||
|
wsConn := &serverless.GorillaWSConn{Conn: conn}
|
||||||
|
|
||||||
|
// Register connection
|
||||||
|
h.wsManager.Register(clientID, wsConn)
|
||||||
|
defer h.wsManager.Unregister(clientID)
|
||||||
|
|
||||||
|
h.logger.Info("WebSocket connected",
|
||||||
|
zap.String("client_id", clientID),
|
||||||
|
zap.String("function", name),
|
||||||
|
)
|
||||||
|
|
||||||
|
callerWallet := h.getWalletFromRequest(r)
|
||||||
|
|
||||||
|
// Message loop
|
||||||
|
for {
|
||||||
|
_, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
|
h.logger.Warn("WebSocket error", zap.Error(err))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invoke function with WebSocket context
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
|
||||||
|
req := &serverless.InvokeRequest{
|
||||||
|
Namespace: namespace,
|
||||||
|
FunctionName: name,
|
||||||
|
Version: version,
|
||||||
|
Input: message,
|
||||||
|
TriggerType: serverless.TriggerTypeWebSocket,
|
||||||
|
CallerWallet: callerWallet,
|
||||||
|
WSClientID: clientID,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := h.invoker.Invoke(ctx, req)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
// Send response back
|
||||||
|
response := map[string]interface{}{
|
||||||
|
"request_id": resp.RequestID,
|
||||||
|
"status": resp.Status,
|
||||||
|
"duration_ms": resp.DurationMS,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
response["error"] = resp.Error
|
||||||
|
} else if len(resp.Output) > 0 {
|
||||||
|
// Try to parse output as JSON
|
||||||
|
var output interface{}
|
||||||
|
if json.Unmarshal(resp.Output, &output) == nil {
|
||||||
|
response["output"] = output
|
||||||
|
} else {
|
||||||
|
response["output"] = string(resp.Output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respBytes, _ := json.Marshal(response)
|
||||||
|
if err := conn.WriteMessage(websocket.TextMessage, respBytes); err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// listVersions handles GET /v1/functions/{name}/versions
|
||||||
|
func (h *ServerlessHandlers) listVersions(w http.ResponseWriter, r *http.Request, name string) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Get registry with extended methods
|
||||||
|
reg, ok := h.registry.(*serverless.Registry)
|
||||||
|
if !ok {
|
||||||
|
writeError(w, http.StatusNotImplemented, "Version listing not supported")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
versions, err := reg.ListVersions(ctx, namespace, name)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to list versions")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"versions": versions,
|
||||||
|
"count": len(versions),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFunctionLogs handles GET /v1/functions/{name}/logs
|
||||||
|
func (h *ServerlessHandlers) getFunctionLogs(w http.ResponseWriter, r *http.Request, name string) {
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
limit := 100
|
||||||
|
if lStr := r.URL.Query().Get("limit"); lStr != "" {
|
||||||
|
if l, err := strconv.Atoi(lStr); err == nil {
|
||||||
|
limit = l
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
logs, err := h.registry.GetLogs(ctx, namespace, name, limit)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to get function logs",
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to get logs")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"name": name,
|
||||||
|
"namespace": namespace,
|
||||||
|
"logs": logs,
|
||||||
|
"count": len(logs),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNamespaceFromRequest extracts namespace from JWT or query param
|
||||||
|
func (h *ServerlessHandlers) getNamespaceFromRequest(r *http.Request) string {
|
||||||
|
// Try context first (set by auth middleware) - most secure
|
||||||
|
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
|
||||||
|
if ns, ok := v.(string); ok && ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try query param as fallback (e.g. for public access or admin)
|
||||||
|
if ns := r.URL.Query().Get("namespace"); ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try header as fallback
|
||||||
|
if ns := r.Header.Get("X-Namespace"); ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
|
||||||
|
return "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
// getWalletFromRequest extracts wallet address from JWT
|
||||||
|
func (h *ServerlessHandlers) getWalletFromRequest(r *http.Request) string {
|
||||||
|
// 1. Try X-Wallet header (legacy/direct bypass)
|
||||||
|
if wallet := r.Header.Get("X-Wallet"); wallet != "" {
|
||||||
|
return wallet
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Try JWT claims from context
|
||||||
|
if v := r.Context().Value(ctxKeyJWT); v != nil {
|
||||||
|
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||||
|
subj := strings.TrimSpace(claims.Sub)
|
||||||
|
// Ensure it's not an API key (standard Orama logic)
|
||||||
|
if !strings.HasPrefix(strings.ToLower(subj), "ak_") && !strings.Contains(subj, ":") {
|
||||||
|
return subj
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Fallback to API key identity (namespace)
|
||||||
|
if v := r.Context().Value(ctxKeyNamespaceOverride); v != nil {
|
||||||
|
if ns, ok := v.(string); ok && ns != "" {
|
||||||
|
return ns
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthStatus returns the health status of the serverless engine
|
||||||
|
func (h *ServerlessHandlers) HealthStatus() map[string]interface{} {
|
||||||
|
stats := h.wsManager.GetStats()
|
||||||
|
return map[string]interface{}{
|
||||||
|
"status": "ok",
|
||||||
|
"connections": stats.ConnectionCount,
|
||||||
|
"topics": stats.TopicCount,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleFunctionTriggers handles trigger operations for a function
|
||||||
|
// Routes:
|
||||||
|
// - GET /v1/functions/{name}/triggers - List all triggers
|
||||||
|
// - POST /v1/functions/{name}/triggers/pubsub - Add pubsub trigger
|
||||||
|
// - DELETE /v1/functions/{name}/triggers/{id} - Remove trigger
|
||||||
|
func (h *ServerlessHandlers) handleFunctionTriggers(w http.ResponseWriter, r *http.Request, name string) {
|
||||||
|
if h.triggerManager == nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "Trigger management not available")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace := r.URL.Query().Get("namespace")
|
||||||
|
if namespace == "" {
|
||||||
|
namespace = h.getNamespaceFromRequest(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
if namespace == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "namespace required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse sub-path for trigger type or ID
|
||||||
|
// Path after "triggers" could be: "", "pubsub", or "{trigger_id}"
|
||||||
|
fullPath := r.URL.Path
|
||||||
|
triggersIdx := strings.Index(fullPath, "/triggers")
|
||||||
|
subPath := ""
|
||||||
|
if triggersIdx > 0 {
|
||||||
|
subPath = strings.TrimPrefix(fullPath[triggersIdx:], "/triggers")
|
||||||
|
subPath = strings.TrimPrefix(subPath, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.listFunctionTriggers(w, r, namespace, name)
|
||||||
|
case http.MethodPost:
|
||||||
|
if subPath == "pubsub" {
|
||||||
|
h.addPubSubTrigger(w, r, namespace, name)
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusBadRequest, "Invalid trigger type. Use /triggers/pubsub")
|
||||||
|
}
|
||||||
|
case http.MethodDelete:
|
||||||
|
if subPath == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "Trigger ID required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.removeFunctionTrigger(w, r, namespace, name, subPath)
|
||||||
|
default:
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// listFunctionTriggers handles GET /v1/functions/{name}/triggers
|
||||||
|
func (h *ServerlessHandlers) listFunctionTriggers(w http.ResponseWriter, r *http.Request, namespace, name string) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Get function to verify it exists and get its ID
|
||||||
|
fn, err := h.registry.Get(ctx, namespace, name, 0)
|
||||||
|
if err != nil {
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "Function not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("Failed to get function",
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get pubsub triggers
|
||||||
|
pubsubTriggers, err := h.triggerManager.ListPubSubTriggers(ctx, fn.ID)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to list triggers",
|
||||||
|
zap.String("function_id", fn.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to list triggers")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"name": name,
|
||||||
|
"namespace": namespace,
|
||||||
|
"function_id": fn.ID,
|
||||||
|
"pubsub_triggers": pubsubTriggers,
|
||||||
|
"count": len(pubsubTriggers),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// addPubSubTrigger handles POST /v1/functions/{name}/triggers/pubsub
|
||||||
|
func (h *ServerlessHandlers) addPubSubTrigger(w http.ResponseWriter, r *http.Request, namespace, name string) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Parse request body
|
||||||
|
var req struct {
|
||||||
|
Topic string `json:"topic"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "Invalid JSON: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Topic == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "topic is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get function to verify it exists and get its ID
|
||||||
|
fn, err := h.registry.Get(ctx, namespace, name, 0)
|
||||||
|
if err != nil {
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "Function not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("Failed to get function",
|
||||||
|
zap.String("name", name),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the trigger
|
||||||
|
err = h.triggerManager.AddPubSubTrigger(ctx, fn.ID, req.Topic)
|
||||||
|
if err != nil {
|
||||||
|
if serverless.IsValidationError(err) {
|
||||||
|
writeError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("Failed to add pubsub trigger",
|
||||||
|
zap.String("function_id", fn.ID),
|
||||||
|
zap.String("topic", req.Topic),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to add trigger")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the triggers to return the newly created one
|
||||||
|
triggers, _ := h.triggerManager.ListPubSubTriggers(ctx, fn.ID)
|
||||||
|
var newTrigger *serverless.PubSubTrigger
|
||||||
|
for i := range triggers {
|
||||||
|
if triggers[i].Topic == req.Topic {
|
||||||
|
newTrigger = &triggers[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("PubSub trigger added",
|
||||||
|
zap.String("function", name),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
zap.String("topic", req.Topic),
|
||||||
|
)
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusCreated, map[string]interface{}{
|
||||||
|
"message": "Trigger added successfully",
|
||||||
|
"trigger": newTrigger,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeFunctionTrigger handles DELETE /v1/functions/{name}/triggers/{id}
|
||||||
|
func (h *ServerlessHandlers) removeFunctionTrigger(w http.ResponseWriter, r *http.Request, namespace, name, triggerID string) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Verify function exists
|
||||||
|
fn, err := h.registry.Get(ctx, namespace, name, 0)
|
||||||
|
if err != nil {
|
||||||
|
if serverless.IsNotFound(err) {
|
||||||
|
writeError(w, http.StatusNotFound, "Function not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to get function")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the trigger
|
||||||
|
err = h.triggerManager.RemoveTrigger(ctx, triggerID)
|
||||||
|
if err != nil {
|
||||||
|
if err == serverless.ErrTriggerNotFound {
|
||||||
|
writeError(w, http.StatusNotFound, "Trigger not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.logger.Error("Failed to remove trigger",
|
||||||
|
zap.String("trigger_id", triggerID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
writeError(w, http.StatusInternalServerError, "Failed to remove trigger")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.logger.Info("Trigger removed",
|
||||||
|
zap.String("function", name),
|
||||||
|
zap.String("function_id", fn.ID),
|
||||||
|
zap.String("trigger_id", triggerID),
|
||||||
|
)
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"message": "Trigger removed successfully",
|
||||||
|
"trigger_id": triggerID,
|
||||||
|
})
|
||||||
|
}
|
||||||
@ -8,7 +8,6 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
serverlesshandlers "github.com/DeBrosOfficial/network/pkg/gateway/handlers/serverless"
|
|
||||||
"github.com/DeBrosOfficial/network/pkg/serverless"
|
"github.com/DeBrosOfficial/network/pkg/serverless"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
@ -50,12 +49,12 @@ func TestServerlessHandlers_ListFunctions(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger)
|
h := NewServerlessHandlers(nil, registry, nil, nil, logger)
|
||||||
|
|
||||||
req, _ := http.NewRequest("GET", "/v1/functions?namespace=ns1", nil)
|
req, _ := http.NewRequest("GET", "/v1/functions?namespace=ns1", nil)
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
h.ListFunctions(rr, req)
|
h.handleFunctions(rr, req)
|
||||||
|
|
||||||
if rr.Code != http.StatusOK {
|
if rr.Code != http.StatusOK {
|
||||||
t.Errorf("expected status 200, got %d", rr.Code)
|
t.Errorf("expected status 200, got %d", rr.Code)
|
||||||
@ -73,7 +72,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) {
|
|||||||
logger := zap.NewNop()
|
logger := zap.NewNop()
|
||||||
registry := &mockFunctionRegistry{}
|
registry := &mockFunctionRegistry{}
|
||||||
|
|
||||||
h := serverlesshandlers.NewServerlessHandlers(nil, registry, nil, logger)
|
h := NewServerlessHandlers(nil, registry, nil, nil, logger)
|
||||||
|
|
||||||
// Test JSON deploy (which is partially supported according to code)
|
// Test JSON deploy (which is partially supported according to code)
|
||||||
// Should be 400 because WASM is missing or base64 not supported
|
// Should be 400 because WASM is missing or base64 not supported
|
||||||
@ -81,7 +80,7 @@ func TestServerlessHandlers_DeployFunction(t *testing.T) {
|
|||||||
req, _ := http.NewRequest("POST", "/v1/functions", bytes.NewBufferString(`{"name": "test"}`))
|
req, _ := http.NewRequest("POST", "/v1/functions", bytes.NewBufferString(`{"name": "test"}`))
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
h.DeployFunction(writer, req)
|
h.handleFunctions(writer, req)
|
||||||
|
|
||||||
if writer.Code != http.StatusBadRequest {
|
if writer.Code != http.StatusBadRequest {
|
||||||
t.Errorf("expected status 400, got %d", writer.Code)
|
t.Errorf("expected status 400, got %d", writer.Code)
|
||||||
|
|||||||
181
pkg/gateway/sfu/config.go
Normal file
181
pkg/gateway/sfu/config.go
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/interceptor"
|
||||||
|
"github.com/pion/interceptor/pkg/intervalpli"
|
||||||
|
"github.com/pion/interceptor/pkg/nack"
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds SFU configuration
|
||||||
|
type Config struct {
|
||||||
|
// MaxParticipants is the maximum number of participants per room
|
||||||
|
MaxParticipants int
|
||||||
|
|
||||||
|
// MediaTimeout is the timeout for media operations
|
||||||
|
MediaTimeout time.Duration
|
||||||
|
|
||||||
|
// ICEServers are the ICE servers for WebRTC connections
|
||||||
|
ICEServers []webrtc.ICEServer
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig returns a default SFU configuration
|
||||||
|
func DefaultConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
MaxParticipants: 10,
|
||||||
|
MediaTimeout: 30 * time.Second,
|
||||||
|
ICEServers: []webrtc.ICEServer{
|
||||||
|
{
|
||||||
|
URLs: []string{"stun:stun.l.google.com:19302"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMediaEngine creates a MediaEngine with supported codecs for the SFU
|
||||||
|
func NewMediaEngine() (*webrtc.MediaEngine, error) {
|
||||||
|
m := &webrtc.MediaEngine{}
|
||||||
|
|
||||||
|
// RTCP feedback for video codecs - enables NACK, PLI, FIR
|
||||||
|
videoRTCPFeedback := []webrtc.RTCPFeedback{
|
||||||
|
{Type: "goog-remb", Parameter: ""}, // Bandwidth estimation
|
||||||
|
{Type: "ccm", Parameter: "fir"}, // Full Intra Request
|
||||||
|
{Type: "nack", Parameter: ""}, // Generic NACK
|
||||||
|
{Type: "nack", Parameter: "pli"}, // Picture Loss Indication
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register Opus codec for audio with NACK support
|
||||||
|
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||||
|
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||||
|
MimeType: webrtc.MimeTypeOpus,
|
||||||
|
ClockRate: 48000,
|
||||||
|
Channels: 2,
|
||||||
|
SDPFmtpLine: "minptime=10;useinbandfec=1",
|
||||||
|
},
|
||||||
|
PayloadType: 111,
|
||||||
|
}, webrtc.RTPCodecTypeAudio); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register VP8 codec for video with full RTCP feedback
|
||||||
|
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||||
|
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||||
|
MimeType: webrtc.MimeTypeVP8,
|
||||||
|
ClockRate: 90000,
|
||||||
|
RTCPFeedback: videoRTCPFeedback,
|
||||||
|
},
|
||||||
|
PayloadType: 96,
|
||||||
|
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register RTX for VP8 (retransmission)
|
||||||
|
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||||
|
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||||
|
MimeType: "video/rtx",
|
||||||
|
ClockRate: 90000,
|
||||||
|
SDPFmtpLine: "apt=96",
|
||||||
|
},
|
||||||
|
PayloadType: 97,
|
||||||
|
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register H264 codec for video with full RTCP feedback (fallback)
|
||||||
|
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||||
|
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||||
|
MimeType: webrtc.MimeTypeH264,
|
||||||
|
ClockRate: 90000,
|
||||||
|
SDPFmtpLine: "level-asymmetry-allowed=1;packetization-mode=1;profile-level-id=42e01f",
|
||||||
|
RTCPFeedback: videoRTCPFeedback,
|
||||||
|
},
|
||||||
|
PayloadType: 102,
|
||||||
|
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register RTX for H264 (retransmission)
|
||||||
|
if err := m.RegisterCodec(webrtc.RTPCodecParameters{
|
||||||
|
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||||
|
MimeType: "video/rtx",
|
||||||
|
ClockRate: 90000,
|
||||||
|
SDPFmtpLine: "apt=102",
|
||||||
|
},
|
||||||
|
PayloadType: 103,
|
||||||
|
}, webrtc.RTPCodecTypeVideo); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWebRTCAPI creates a new WebRTC API with the configured MediaEngine
|
||||||
|
func NewWebRTCAPI() (*webrtc.API, error) {
|
||||||
|
mediaEngine, err := NewMediaEngine()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create interceptor registry for RTCP feedback
|
||||||
|
i := &interceptor.Registry{}
|
||||||
|
|
||||||
|
// Register NACK responder - handles retransmission requests from receivers
|
||||||
|
// This is critical for video quality: when a receiver loses a packet,
|
||||||
|
// it sends NACK and the sender retransmits the lost packet
|
||||||
|
nackResponderFactory, err := nack.NewResponderInterceptor()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i.Add(nackResponderFactory)
|
||||||
|
|
||||||
|
// Register NACK generator - sends NACK when we detect packet loss as receiver
|
||||||
|
nackGeneratorFactory, err := nack.NewGeneratorInterceptor()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i.Add(nackGeneratorFactory)
|
||||||
|
|
||||||
|
// Register interval PLI - automatically sends PLI periodically for video tracks
|
||||||
|
// This helps new receivers get keyframes faster
|
||||||
|
intervalPLIFactory, err := intervalpli.NewReceiverInterceptor()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
i.Add(intervalPLIFactory)
|
||||||
|
|
||||||
|
// Configure settings for better media performance
|
||||||
|
settingEngine := webrtc.SettingEngine{}
|
||||||
|
|
||||||
|
return webrtc.NewAPI(
|
||||||
|
webrtc.WithMediaEngine(mediaEngine),
|
||||||
|
webrtc.WithInterceptorRegistry(i),
|
||||||
|
webrtc.WithSettingEngine(settingEngine),
|
||||||
|
), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRTPCapabilities returns the RTP capabilities of the SFU
|
||||||
|
// This is used by clients to negotiate codecs
|
||||||
|
func GetRTPCapabilities() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"codecs": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"kind": "audio",
|
||||||
|
"mimeType": "audio/opus",
|
||||||
|
"clockRate": 48000,
|
||||||
|
"channels": 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"kind": "video",
|
||||||
|
"mimeType": "video/VP8",
|
||||||
|
"clockRate": 90000,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"kind": "video",
|
||||||
|
"mimeType": "video/H264",
|
||||||
|
"clockRate": 90000,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
228
pkg/gateway/sfu/manager.go
Normal file
228
pkg/gateway/sfu/manager.go
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrRoomNotFound = errors.New("room not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// RoomManager manages all rooms in the SFU
|
||||||
|
type RoomManager struct {
|
||||||
|
rooms map[string]*Room // key: namespace:roomId
|
||||||
|
roomsMu sync.RWMutex
|
||||||
|
|
||||||
|
api *webrtc.API
|
||||||
|
config *Config
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoomManager creates a new room manager
|
||||||
|
func NewRoomManager(config *Config, logger *zap.Logger) (*RoomManager, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
api, err := NewWebRTCAPI()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RoomManager{
|
||||||
|
rooms: make(map[string]*Room),
|
||||||
|
api: api,
|
||||||
|
config: config,
|
||||||
|
logger: logger.With(zap.String("component", "sfu_manager")),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// roomKey creates a unique key for namespace:roomId
|
||||||
|
func roomKey(namespace, roomID string) string {
|
||||||
|
return namespace + ":" + roomID
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateRoom gets an existing room or creates a new one
|
||||||
|
func (m *RoomManager) GetOrCreateRoom(namespace, roomID string) (*Room, bool) {
|
||||||
|
key := roomKey(namespace, roomID)
|
||||||
|
|
||||||
|
m.roomsMu.Lock()
|
||||||
|
defer m.roomsMu.Unlock()
|
||||||
|
|
||||||
|
// Check if room already exists
|
||||||
|
if room, ok := m.rooms[key]; ok {
|
||||||
|
if !room.IsClosed() {
|
||||||
|
return room, false
|
||||||
|
}
|
||||||
|
// Room is closed, remove it and create a new one
|
||||||
|
delete(m.rooms, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create new room
|
||||||
|
room := NewRoom(roomID, namespace, m.api, m.config, m.logger)
|
||||||
|
|
||||||
|
// Set up empty room handler
|
||||||
|
room.OnEmpty(func(r *Room) {
|
||||||
|
m.logger.Info("Room is empty, will be garbage collected",
|
||||||
|
zap.String("room_id", r.ID),
|
||||||
|
zap.String("namespace", r.Namespace),
|
||||||
|
)
|
||||||
|
// Optionally close the room immediately
|
||||||
|
// m.CloseRoom(r.Namespace, r.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
m.rooms[key] = room
|
||||||
|
|
||||||
|
m.logger.Info("Room created",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
)
|
||||||
|
|
||||||
|
return room, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoom gets an existing room
|
||||||
|
func (m *RoomManager) GetRoom(namespace, roomID string) (*Room, error) {
|
||||||
|
key := roomKey(namespace, roomID)
|
||||||
|
|
||||||
|
m.roomsMu.RLock()
|
||||||
|
defer m.roomsMu.RUnlock()
|
||||||
|
|
||||||
|
room, ok := m.rooms[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrRoomNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if room.IsClosed() {
|
||||||
|
return nil, ErrRoomClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
return room, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseRoom closes and removes a room
|
||||||
|
func (m *RoomManager) CloseRoom(namespace, roomID string) error {
|
||||||
|
key := roomKey(namespace, roomID)
|
||||||
|
|
||||||
|
m.roomsMu.Lock()
|
||||||
|
room, ok := m.rooms[key]
|
||||||
|
if !ok {
|
||||||
|
m.roomsMu.Unlock()
|
||||||
|
return ErrRoomNotFound
|
||||||
|
}
|
||||||
|
delete(m.rooms, key)
|
||||||
|
m.roomsMu.Unlock()
|
||||||
|
|
||||||
|
if err := room.Close(); err != nil {
|
||||||
|
m.logger.Warn("Error closing room",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Info("Room closed",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.String("namespace", namespace),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRooms returns all rooms for a namespace
|
||||||
|
func (m *RoomManager) ListRooms(namespace string) []*Room {
|
||||||
|
m.roomsMu.RLock()
|
||||||
|
defer m.roomsMu.RUnlock()
|
||||||
|
|
||||||
|
prefix := namespace + ":"
|
||||||
|
rooms := make([]*Room, 0)
|
||||||
|
|
||||||
|
for key, room := range m.rooms {
|
||||||
|
if len(key) > len(prefix) && key[:len(prefix)] == prefix && !room.IsClosed() {
|
||||||
|
rooms = append(rooms, room)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rooms
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoomInfo contains public information about a room
|
||||||
|
type RoomInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Namespace string `json:"namespace"`
|
||||||
|
ParticipantCount int `json:"participantCount"`
|
||||||
|
Participants []ParticipantInfo `json:"participants"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoomInfo returns public information about a room
|
||||||
|
func (m *RoomManager) GetRoomInfo(namespace, roomID string) (*RoomInfo, error) {
|
||||||
|
room, err := m.GetRoom(namespace, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RoomInfo{
|
||||||
|
ID: room.ID,
|
||||||
|
Namespace: room.Namespace,
|
||||||
|
ParticipantCount: room.GetParticipantCount(),
|
||||||
|
Participants: room.GetParticipants(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStats returns statistics about the room manager
|
||||||
|
func (m *RoomManager) GetStats() map[string]interface{} {
|
||||||
|
m.roomsMu.RLock()
|
||||||
|
defer m.roomsMu.RUnlock()
|
||||||
|
|
||||||
|
totalRooms := 0
|
||||||
|
totalParticipants := 0
|
||||||
|
|
||||||
|
for _, room := range m.rooms {
|
||||||
|
if !room.IsClosed() {
|
||||||
|
totalRooms++
|
||||||
|
totalParticipants += room.GetParticipantCount()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return map[string]interface{}{
|
||||||
|
"totalRooms": totalRooms,
|
||||||
|
"totalParticipants": totalParticipants,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes all rooms and cleans up resources
|
||||||
|
func (m *RoomManager) Close() error {
|
||||||
|
m.roomsMu.Lock()
|
||||||
|
rooms := make([]*Room, 0, len(m.rooms))
|
||||||
|
for _, room := range m.rooms {
|
||||||
|
rooms = append(rooms, room)
|
||||||
|
}
|
||||||
|
m.rooms = make(map[string]*Room)
|
||||||
|
m.roomsMu.Unlock()
|
||||||
|
|
||||||
|
for _, room := range rooms {
|
||||||
|
if err := room.Close(); err != nil {
|
||||||
|
m.logger.Warn("Error closing room during shutdown",
|
||||||
|
zap.String("room_id", room.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Info("Room manager closed", zap.Int("rooms_closed", len(rooms)))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig returns the SFU configuration
|
||||||
|
func (m *RoomManager) GetConfig() *Config {
|
||||||
|
return m.config
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateICEServers updates the ICE servers configuration
|
||||||
|
func (m *RoomManager) UpdateICEServers(servers []webrtc.ICEServer) {
|
||||||
|
m.config.ICEServers = servers
|
||||||
|
}
|
||||||
505
pkg/gateway/sfu/peer.go
Normal file
505
pkg/gateway/sfu/peer.go
Normal file
@ -0,0 +1,505 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/pion/rtcp"
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Peer represents a participant in a room
|
||||||
|
type Peer struct {
|
||||||
|
ID string
|
||||||
|
UserID string
|
||||||
|
DisplayName string
|
||||||
|
|
||||||
|
// WebRTC connection
|
||||||
|
pc *webrtc.PeerConnection
|
||||||
|
|
||||||
|
// Tracks published by this peer (local tracks that others receive)
|
||||||
|
localTracks map[string]*webrtc.TrackLocalStaticRTP
|
||||||
|
localTracksMu sync.RWMutex
|
||||||
|
|
||||||
|
// Track receivers for consuming other peers' tracks
|
||||||
|
trackReceivers map[string]*webrtc.RTPReceiver
|
||||||
|
trackReceiversMu sync.RWMutex
|
||||||
|
|
||||||
|
// WebSocket connection for signaling
|
||||||
|
conn *websocket.Conn
|
||||||
|
connMu sync.Mutex
|
||||||
|
|
||||||
|
// State
|
||||||
|
audioMuted bool
|
||||||
|
videoMuted bool
|
||||||
|
closed bool
|
||||||
|
closedMu sync.RWMutex
|
||||||
|
negotiationPending bool
|
||||||
|
negotiationPendingMu sync.Mutex
|
||||||
|
initialOfferHandled bool
|
||||||
|
initialOfferMu sync.Mutex
|
||||||
|
batchingTracks bool // When true, suppress automatic negotiation
|
||||||
|
batchingTracksMu sync.Mutex
|
||||||
|
|
||||||
|
// Room reference
|
||||||
|
room *Room
|
||||||
|
logger *zap.Logger
|
||||||
|
|
||||||
|
// Callbacks
|
||||||
|
onClose func(*Peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPeer creates a new peer
|
||||||
|
func NewPeer(userID, displayName string, conn *websocket.Conn, room *Room, logger *zap.Logger) *Peer {
|
||||||
|
return &Peer{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
UserID: userID,
|
||||||
|
DisplayName: displayName,
|
||||||
|
localTracks: make(map[string]*webrtc.TrackLocalStaticRTP),
|
||||||
|
trackReceivers: make(map[string]*webrtc.RTPReceiver),
|
||||||
|
conn: conn,
|
||||||
|
room: room,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitPeerConnection initializes the WebRTC peer connection
|
||||||
|
func (p *Peer) InitPeerConnection(api *webrtc.API, config webrtc.Configuration) error {
|
||||||
|
pc, err := api.NewPeerConnection(config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
p.pc = pc
|
||||||
|
|
||||||
|
// Handle ICE connection state changes
|
||||||
|
pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
|
||||||
|
p.logger.Info("ICE connection state changed",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("state", state.String()),
|
||||||
|
)
|
||||||
|
|
||||||
|
if state == webrtc.ICEConnectionStateFailed ||
|
||||||
|
state == webrtc.ICEConnectionStateDisconnected ||
|
||||||
|
state == webrtc.ICEConnectionStateClosed {
|
||||||
|
p.handleDisconnect()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle ICE candidates
|
||||||
|
pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
|
||||||
|
if candidate == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Debug("ICE candidate generated",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("candidate", candidate.String()),
|
||||||
|
)
|
||||||
|
|
||||||
|
candidateJSON := candidate.ToJSON()
|
||||||
|
data := &ICECandidateData{
|
||||||
|
Candidate: candidateJSON.Candidate,
|
||||||
|
}
|
||||||
|
if candidateJSON.SDPMid != nil {
|
||||||
|
data.SDPMid = *candidateJSON.SDPMid
|
||||||
|
}
|
||||||
|
if candidateJSON.SDPMLineIndex != nil {
|
||||||
|
data.SDPMLineIndex = *candidateJSON.SDPMLineIndex
|
||||||
|
}
|
||||||
|
if candidateJSON.UsernameFragment != nil {
|
||||||
|
data.UsernameFragment = *candidateJSON.UsernameFragment
|
||||||
|
}
|
||||||
|
|
||||||
|
p.SendMessage(NewServerMessage(MessageTypeICECandidate, data))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle incoming tracks from remote peers
|
||||||
|
pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||||
|
codec := track.Codec()
|
||||||
|
p.logger.Info("Track received from client",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("user_id", p.UserID),
|
||||||
|
zap.String("track_id", track.ID()),
|
||||||
|
zap.String("stream_id", track.StreamID()),
|
||||||
|
zap.String("kind", track.Kind().String()),
|
||||||
|
zap.String("codec_mime", codec.MimeType),
|
||||||
|
zap.Uint32("codec_clock_rate", codec.ClockRate),
|
||||||
|
zap.Uint8("codec_payload_type", uint8(codec.PayloadType)),
|
||||||
|
)
|
||||||
|
|
||||||
|
p.trackReceiversMu.Lock()
|
||||||
|
p.trackReceivers[track.ID()] = receiver
|
||||||
|
p.trackReceiversMu.Unlock()
|
||||||
|
|
||||||
|
// Start RTCP reader to monitor for packet loss (NACK) and PLI requests
|
||||||
|
go p.readRTCP(receiver, track)
|
||||||
|
|
||||||
|
// Forward track to other peers in the room
|
||||||
|
p.room.BroadcastTrack(p.ID, track)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle negotiation needed - only trigger when in stable state
|
||||||
|
pc.OnNegotiationNeeded(func() {
|
||||||
|
p.logger.Debug("Negotiation needed",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("signaling_state", pc.SignalingState().String()),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check if we're batching tracks - if so, just mark as pending
|
||||||
|
p.batchingTracksMu.Lock()
|
||||||
|
batching := p.batchingTracks
|
||||||
|
p.batchingTracksMu.Unlock()
|
||||||
|
|
||||||
|
if batching {
|
||||||
|
p.negotiationPendingMu.Lock()
|
||||||
|
p.negotiationPending = true
|
||||||
|
p.negotiationPendingMu.Unlock()
|
||||||
|
p.logger.Debug("Negotiation deferred - batching tracks",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only create offer if we're in stable state
|
||||||
|
// Otherwise, mark negotiation as pending
|
||||||
|
if pc.SignalingState() == webrtc.SignalingStateStable {
|
||||||
|
p.createAndSendOffer()
|
||||||
|
} else {
|
||||||
|
p.negotiationPendingMu.Lock()
|
||||||
|
p.negotiationPending = true
|
||||||
|
p.negotiationPendingMu.Unlock()
|
||||||
|
p.logger.Debug("Negotiation queued - not in stable state",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("signaling_state", pc.SignalingState().String()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle signaling state changes to process pending negotiations
|
||||||
|
pc.OnSignalingStateChange(func(state webrtc.SignalingState) {
|
||||||
|
p.logger.Debug("Signaling state changed",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("state", state.String()),
|
||||||
|
)
|
||||||
|
|
||||||
|
// When we return to stable state, check if negotiation was pending
|
||||||
|
if state == webrtc.SignalingStateStable {
|
||||||
|
p.negotiationPendingMu.Lock()
|
||||||
|
pending := p.negotiationPending
|
||||||
|
p.negotiationPending = false
|
||||||
|
p.negotiationPendingMu.Unlock()
|
||||||
|
|
||||||
|
if pending {
|
||||||
|
p.logger.Debug("Processing pending negotiation", zap.String("peer_id", p.ID))
|
||||||
|
p.createAndSendOffer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createAndSendOffer creates an SDP offer and sends it to the peer
|
||||||
|
func (p *Peer) createAndSendOffer() {
|
||||||
|
if p.pc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Double-check signaling state before creating offer
|
||||||
|
if p.pc.SignalingState() != webrtc.SignalingStateStable {
|
||||||
|
p.logger.Debug("Skipping offer - not in stable state",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("signaling_state", p.pc.SignalingState().String()),
|
||||||
|
)
|
||||||
|
// Mark as pending so it will be retried when state becomes stable
|
||||||
|
p.negotiationPendingMu.Lock()
|
||||||
|
p.negotiationPending = true
|
||||||
|
p.negotiationPendingMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info("Creating SDP offer", zap.String("peer_id", p.ID))
|
||||||
|
|
||||||
|
offer, err := p.pc.CreateOffer(nil)
|
||||||
|
if err != nil {
|
||||||
|
p.logger.Error("Failed to create offer",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.pc.SetLocalDescription(offer); err != nil {
|
||||||
|
p.logger.Error("Failed to set local description",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.logger.Info("Sending SDP offer", zap.String("peer_id", p.ID))
|
||||||
|
p.SendMessage(NewServerMessage(MessageTypeOffer, &OfferData{
|
||||||
|
SDP: offer.SDP,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleOffer processes an SDP offer from the client
|
||||||
|
func (p *Peer) HandleOffer(sdp string) error {
|
||||||
|
if p.pc == nil {
|
||||||
|
return ErrPeerNotInitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
offer := webrtc.SessionDescription{
|
||||||
|
Type: webrtc.SDPTypeOffer,
|
||||||
|
SDP: sdp,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.pc.SetRemoteDescription(offer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
answer, err := p.pc.CreateAnswer(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.pc.SetLocalDescription(answer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p.SendMessage(NewServerMessage(MessageTypeAnswer, &AnswerData{
|
||||||
|
SDP: answer.SDP,
|
||||||
|
}))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleAnswer processes an SDP answer from the client
|
||||||
|
func (p *Peer) HandleAnswer(sdp string) error {
|
||||||
|
if p.pc == nil {
|
||||||
|
return ErrPeerNotInitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
answer := webrtc.SessionDescription{
|
||||||
|
Type: webrtc.SDPTypeAnswer,
|
||||||
|
SDP: sdp,
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.pc.SetRemoteDescription(answer)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleICECandidate processes an ICE candidate from the client
|
||||||
|
func (p *Peer) HandleICECandidate(data *ICECandidateData) error {
|
||||||
|
if p.pc == nil {
|
||||||
|
return ErrPeerNotInitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.pc.AddICECandidate(data.ToWebRTCCandidate())
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddTrack adds a track to send to this peer (from another peer)
|
||||||
|
func (p *Peer) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
|
||||||
|
if p.pc == nil {
|
||||||
|
return nil, ErrPeerNotInitialized
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.pc.AddTrack(track)
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartTrackBatch starts batching track additions.
|
||||||
|
// Call EndTrackBatch when done to trigger a single renegotiation.
|
||||||
|
func (p *Peer) StartTrackBatch() {
|
||||||
|
p.batchingTracksMu.Lock()
|
||||||
|
p.batchingTracks = true
|
||||||
|
p.batchingTracksMu.Unlock()
|
||||||
|
p.logger.Debug("Started track batching", zap.String("peer_id", p.ID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndTrackBatch ends track batching and triggers renegotiation if needed.
|
||||||
|
func (p *Peer) EndTrackBatch() {
|
||||||
|
p.batchingTracksMu.Lock()
|
||||||
|
p.batchingTracks = false
|
||||||
|
p.batchingTracksMu.Unlock()
|
||||||
|
|
||||||
|
// Check if negotiation was pending during batching
|
||||||
|
p.negotiationPendingMu.Lock()
|
||||||
|
pending := p.negotiationPending
|
||||||
|
p.negotiationPending = false
|
||||||
|
p.negotiationPendingMu.Unlock()
|
||||||
|
|
||||||
|
if pending && p.pc != nil && p.pc.SignalingState() == webrtc.SignalingStateStable {
|
||||||
|
p.logger.Debug("Processing batched negotiation", zap.String("peer_id", p.ID))
|
||||||
|
p.createAndSendOffer()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage sends a signaling message to the peer via WebSocket
|
||||||
|
func (p *Peer) SendMessage(msg *ServerMessage) error {
|
||||||
|
p.closedMu.RLock()
|
||||||
|
if p.closed {
|
||||||
|
p.closedMu.RUnlock()
|
||||||
|
return ErrPeerClosed
|
||||||
|
}
|
||||||
|
p.closedMu.RUnlock()
|
||||||
|
|
||||||
|
p.connMu.Lock()
|
||||||
|
defer p.connMu.Unlock()
|
||||||
|
|
||||||
|
if p.conn == nil {
|
||||||
|
return ErrWebSocketClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.conn.WriteMessage(websocket.TextMessage, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInfo returns public information about this peer
|
||||||
|
func (p *Peer) GetInfo() ParticipantInfo {
|
||||||
|
p.localTracksMu.RLock()
|
||||||
|
hasAudio := false
|
||||||
|
hasVideo := false
|
||||||
|
for _, track := range p.localTracks {
|
||||||
|
if track.Kind() == webrtc.RTPCodecTypeAudio {
|
||||||
|
hasAudio = true
|
||||||
|
} else if track.Kind() == webrtc.RTPCodecTypeVideo {
|
||||||
|
hasVideo = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
p.localTracksMu.RUnlock()
|
||||||
|
|
||||||
|
return ParticipantInfo{
|
||||||
|
ID: p.ID,
|
||||||
|
UserID: p.UserID,
|
||||||
|
DisplayName: p.DisplayName,
|
||||||
|
HasAudio: hasAudio,
|
||||||
|
HasVideo: hasVideo,
|
||||||
|
AudioMuted: p.audioMuted,
|
||||||
|
VideoMuted: p.videoMuted,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAudioMuted sets the audio mute state
|
||||||
|
func (p *Peer) SetAudioMuted(muted bool) {
|
||||||
|
p.audioMuted = muted
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVideoMuted sets the video mute state
|
||||||
|
func (p *Peer) SetVideoMuted(muted bool) {
|
||||||
|
p.videoMuted = muted
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkInitialOfferHandled marks that the initial offer has been processed.
|
||||||
|
// Returns true if this is the first time it's being marked (i.e., this was the first offer).
|
||||||
|
func (p *Peer) MarkInitialOfferHandled() bool {
|
||||||
|
p.initialOfferMu.Lock()
|
||||||
|
defer p.initialOfferMu.Unlock()
|
||||||
|
|
||||||
|
if p.initialOfferHandled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
p.initialOfferHandled = true
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDisconnect handles peer disconnection
|
||||||
|
func (p *Peer) handleDisconnect() {
|
||||||
|
p.closedMu.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.closed = true
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
|
||||||
|
if p.onClose != nil {
|
||||||
|
p.onClose(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the peer connection and cleans up resources
|
||||||
|
func (p *Peer) Close() error {
|
||||||
|
p.closedMu.Lock()
|
||||||
|
if p.closed {
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
p.closed = true
|
||||||
|
p.closedMu.Unlock()
|
||||||
|
|
||||||
|
p.logger.Info("Closing peer", zap.String("peer_id", p.ID))
|
||||||
|
|
||||||
|
// Close WebSocket
|
||||||
|
p.connMu.Lock()
|
||||||
|
if p.conn != nil {
|
||||||
|
p.conn.Close()
|
||||||
|
p.conn = nil
|
||||||
|
}
|
||||||
|
p.connMu.Unlock()
|
||||||
|
|
||||||
|
// Close peer connection
|
||||||
|
if p.pc != nil {
|
||||||
|
return p.pc.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnClose sets a callback for when the peer is closed
|
||||||
|
func (p *Peer) OnClose(fn func(*Peer)) {
|
||||||
|
p.onClose = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// readRTCP reads RTCP packets from a receiver to monitor feedback
|
||||||
|
// This helps detect packet loss (via NACK) for adaptive quality adjustments
|
||||||
|
func (p *Peer) readRTCP(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) {
|
||||||
|
localTrackID := track.Kind().String() + "-" + p.ID
|
||||||
|
|
||||||
|
for {
|
||||||
|
packets, _, err := receiver.ReadRTCP()
|
||||||
|
if err != nil {
|
||||||
|
// Connection closed, exit gracefully
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pkt := range packets {
|
||||||
|
switch rtcpPkt := pkt.(type) {
|
||||||
|
case *rtcp.TransportLayerNack:
|
||||||
|
// NACK received - indicates packet loss from receivers
|
||||||
|
// Increment the NACK counter for adaptive keyframe logic
|
||||||
|
p.room.IncrementNackCount(localTrackID)
|
||||||
|
|
||||||
|
p.logger.Debug("NACK received",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("track_id", localTrackID),
|
||||||
|
zap.Uint32("sender_ssrc", rtcpPkt.SenderSSRC),
|
||||||
|
zap.Int("nack_pairs", len(rtcpPkt.Nacks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
case *rtcp.PictureLossIndication:
|
||||||
|
// PLI received - receiver needs a keyframe
|
||||||
|
p.logger.Debug("PLI received from receiver",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("track_id", localTrackID),
|
||||||
|
)
|
||||||
|
// Request keyframe from source
|
||||||
|
p.room.RequestKeyframe(localTrackID)
|
||||||
|
|
||||||
|
case *rtcp.FullIntraRequest:
|
||||||
|
// FIR received - receiver needs a full keyframe
|
||||||
|
p.logger.Debug("FIR received from receiver",
|
||||||
|
zap.String("peer_id", p.ID),
|
||||||
|
zap.String("track_id", localTrackID),
|
||||||
|
)
|
||||||
|
// Request keyframe from source
|
||||||
|
p.room.RequestKeyframe(localTrackID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
775
pkg/gateway/sfu/room.go
Normal file
775
pkg/gateway/sfu/room.go
Normal file
@ -0,0 +1,775 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/rtcp"
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Common errors
|
||||||
|
var (
|
||||||
|
ErrRoomFull = errors.New("room is full")
|
||||||
|
ErrRoomClosed = errors.New("room is closed")
|
||||||
|
ErrPeerNotFound = errors.New("peer not found")
|
||||||
|
ErrPeerNotInitialized = errors.New("peer not initialized")
|
||||||
|
ErrPeerClosed = errors.New("peer is closed")
|
||||||
|
ErrWebSocketClosed = errors.New("websocket connection closed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// publishedTrack holds information about a track published to the room
|
||||||
|
type publishedTrack struct {
|
||||||
|
sourcePeerID string
|
||||||
|
sourceUserID string
|
||||||
|
localTrack *webrtc.TrackLocalStaticRTP
|
||||||
|
remoteTrackSSRC uint32 // SSRC of the remote track (for PLI requests)
|
||||||
|
remoteTrack *webrtc.TrackRemote // Reference to the remote track
|
||||||
|
kind string
|
||||||
|
forwarderActive bool // tracks if the RTP forwarder goroutine is still running
|
||||||
|
packetCount int // number of packets forwarded (for debugging)
|
||||||
|
|
||||||
|
// Packet loss tracking for adaptive keyframe requests
|
||||||
|
lastKeyframeRequest time.Time
|
||||||
|
nackCount atomic.Int64 // Number of NACK packets received (indicates packet loss)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Room represents a WebRTC room with multiple participants
|
||||||
|
type Room struct {
|
||||||
|
ID string
|
||||||
|
Namespace string
|
||||||
|
|
||||||
|
// Participants in the room
|
||||||
|
peers map[string]*Peer
|
||||||
|
peersMu sync.RWMutex
|
||||||
|
|
||||||
|
// Published tracks in the room (for sending to new joiners)
|
||||||
|
publishedTracks map[string]*publishedTrack // key: trackID
|
||||||
|
publishedTracksMu sync.RWMutex
|
||||||
|
|
||||||
|
// WebRTC API for creating peer connections
|
||||||
|
api *webrtc.API
|
||||||
|
|
||||||
|
// Configuration
|
||||||
|
config *Config
|
||||||
|
logger *zap.Logger
|
||||||
|
|
||||||
|
// State
|
||||||
|
closed bool
|
||||||
|
closedMu sync.RWMutex
|
||||||
|
|
||||||
|
// Callbacks
|
||||||
|
onEmpty func(*Room)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoom creates a new room
|
||||||
|
func NewRoom(id, namespace string, api *webrtc.API, config *Config, logger *zap.Logger) *Room {
|
||||||
|
return &Room{
|
||||||
|
ID: id,
|
||||||
|
Namespace: namespace,
|
||||||
|
peers: make(map[string]*Peer),
|
||||||
|
publishedTracks: make(map[string]*publishedTrack),
|
||||||
|
api: api,
|
||||||
|
config: config,
|
||||||
|
logger: logger.With(zap.String("room_id", id)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeer adds a new peer to the room
|
||||||
|
func (r *Room) AddPeer(peer *Peer) error {
|
||||||
|
r.closedMu.RLock()
|
||||||
|
if r.closed {
|
||||||
|
r.closedMu.RUnlock()
|
||||||
|
return ErrRoomClosed
|
||||||
|
}
|
||||||
|
r.closedMu.RUnlock()
|
||||||
|
|
||||||
|
r.peersMu.Lock()
|
||||||
|
|
||||||
|
// Check max participants
|
||||||
|
if r.config.MaxParticipants > 0 && len(r.peers) >= r.config.MaxParticipants {
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
return ErrRoomFull
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize peer connection
|
||||||
|
pcConfig := webrtc.Configuration{
|
||||||
|
ICEServers: r.config.ICEServers,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := peer.InitPeerConnection(r.api, pcConfig); err != nil {
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up peer close handler
|
||||||
|
peer.OnClose(func(p *Peer) {
|
||||||
|
r.RemovePeer(p.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
r.peers[peer.ID] = peer
|
||||||
|
peerInfo := peer.GetInfo() // Get info while holding lock
|
||||||
|
totalPeers := len(r.peers)
|
||||||
|
|
||||||
|
// Release lock BEFORE broadcasting to avoid deadlock
|
||||||
|
// (broadcastMessage also acquires the lock)
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
|
||||||
|
r.logger.Info("Peer added to room",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.String("user_id", peer.UserID),
|
||||||
|
zap.Int("total_peers", totalPeers),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Notify other peers (now safe since we released the lock)
|
||||||
|
r.broadcastMessage(peer.ID, NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{
|
||||||
|
Participant: peerInfo,
|
||||||
|
}))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes a peer from the room
|
||||||
|
func (r *Room) RemovePeer(peerID string) error {
|
||||||
|
r.peersMu.Lock()
|
||||||
|
peer, ok := r.peers[peerID]
|
||||||
|
if !ok {
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
return ErrPeerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(r.peers, peerID)
|
||||||
|
remainingPeers := len(r.peers)
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
|
||||||
|
// Remove tracks published by this peer
|
||||||
|
r.publishedTracksMu.Lock()
|
||||||
|
removedTracks := make([]string, 0)
|
||||||
|
for trackID, track := range r.publishedTracks {
|
||||||
|
if track.sourcePeerID == peerID {
|
||||||
|
delete(r.publishedTracks, trackID)
|
||||||
|
removedTracks = append(removedTracks, trackID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.Unlock()
|
||||||
|
|
||||||
|
if len(removedTracks) > 0 {
|
||||||
|
r.logger.Info("Removed tracks for departing peer",
|
||||||
|
zap.String("peer_id", peerID),
|
||||||
|
zap.Strings("track_ids", removedTracks),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close the peer
|
||||||
|
if err := peer.Close(); err != nil {
|
||||||
|
r.logger.Warn("Error closing peer",
|
||||||
|
zap.String("peer_id", peerID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("Peer removed from room",
|
||||||
|
zap.String("peer_id", peerID),
|
||||||
|
zap.Int("remaining_peers", remainingPeers),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Notify other peers
|
||||||
|
r.broadcastMessage(peerID, NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{
|
||||||
|
ParticipantID: peerID,
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Check if room is empty
|
||||||
|
if remainingPeers == 0 && r.onEmpty != nil {
|
||||||
|
r.onEmpty(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeer returns a peer by ID
|
||||||
|
func (r *Room) GetPeer(peerID string) (*Peer, error) {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
|
||||||
|
peer, ok := r.peers[peerID]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrPeerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeers returns all peers in the room
|
||||||
|
func (r *Room) GetPeers() []*Peer {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
|
||||||
|
peers := make([]*Peer, 0, len(r.peers))
|
||||||
|
for _, peer := range r.peers {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
return peers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipants returns info about all participants
|
||||||
|
func (r *Room) GetParticipants() []ParticipantInfo {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
|
||||||
|
participants := make([]ParticipantInfo, 0, len(r.peers))
|
||||||
|
for _, peer := range r.peers {
|
||||||
|
participants = append(participants, peer.GetInfo())
|
||||||
|
}
|
||||||
|
return participants
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetParticipantCount returns the number of participants
|
||||||
|
func (r *Room) GetParticipantCount() int {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
return len(r.peers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendExistingTracksTo sends all existing tracks from other participants to the specified peer.
|
||||||
|
// This should be called AFTER the welcome message is sent to ensure the client is ready.
|
||||||
|
// Uses batch mode to send all tracks with a single renegotiation for faster joins.
|
||||||
|
func (r *Room) SendExistingTracksTo(peer *Peer) {
|
||||||
|
r.publishedTracksMu.RLock()
|
||||||
|
existingTracks := make([]*publishedTrack, 0, len(r.publishedTracks))
|
||||||
|
for _, track := range r.publishedTracks {
|
||||||
|
// Don't send peer's own tracks back to them
|
||||||
|
if track.sourcePeerID != peer.ID {
|
||||||
|
existingTracks = append(existingTracks, track)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
|
||||||
|
if len(existingTracks) == 0 {
|
||||||
|
r.logger.Info("No existing tracks to send to new peer",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("Sending existing tracks to new peer (batch mode)",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.Int("track_count", len(existingTracks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
videoTrackIDs := make([]string, 0)
|
||||||
|
|
||||||
|
// Start batch mode - suppresses individual renegotiations
|
||||||
|
peer.StartTrackBatch()
|
||||||
|
|
||||||
|
for _, track := range existingTracks {
|
||||||
|
// Log forwarder status to help diagnose video issues
|
||||||
|
r.logger.Info("Adding existing track to new peer",
|
||||||
|
zap.String("new_peer_id", peer.ID),
|
||||||
|
zap.String("source_peer_id", track.sourcePeerID),
|
||||||
|
zap.String("source_user_id", track.sourceUserID),
|
||||||
|
zap.String("track_id", track.localTrack.ID()),
|
||||||
|
zap.String("kind", track.kind),
|
||||||
|
zap.Bool("forwarder_active", track.forwarderActive),
|
||||||
|
zap.Int("packets_forwarded", track.packetCount),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Warn if forwarder is no longer active
|
||||||
|
if !track.forwarderActive {
|
||||||
|
r.logger.Warn("WARNING: Track forwarder is NOT active - track may not receive data",
|
||||||
|
zap.String("track_id", track.localTrack.ID()),
|
||||||
|
zap.String("kind", track.kind),
|
||||||
|
zap.Int("final_packet_count", track.packetCount),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := peer.AddTrack(track.localTrack); err != nil {
|
||||||
|
r.logger.Warn("Failed to add existing track to new peer",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.String("track_id", track.localTrack.ID()),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Track video tracks for keyframe requests
|
||||||
|
if track.kind == "video" {
|
||||||
|
videoTrackIDs = append(videoTrackIDs, track.localTrack.ID())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify new peer about the existing track
|
||||||
|
peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
|
||||||
|
ParticipantID: track.sourcePeerID,
|
||||||
|
UserID: track.sourceUserID,
|
||||||
|
TrackID: track.localTrack.ID(),
|
||||||
|
StreamID: track.localTrack.StreamID(),
|
||||||
|
Kind: track.kind,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// End batch mode - triggers single renegotiation for all tracks
|
||||||
|
peer.EndTrackBatch()
|
||||||
|
|
||||||
|
r.logger.Info("Batch track addition complete - single renegotiation triggered",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.Int("total_tracks", len(existingTracks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Request keyframes for video tracks after a short delay
|
||||||
|
// This ensures the receiver has time to set up the track before receiving the keyframe
|
||||||
|
if len(videoTrackIDs) > 0 {
|
||||||
|
go func() {
|
||||||
|
// Wait for negotiation to complete
|
||||||
|
time.Sleep(300 * time.Millisecond)
|
||||||
|
r.logger.Info("Requesting keyframes for new peer",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.Int("video_track_count", len(videoTrackIDs)),
|
||||||
|
)
|
||||||
|
for _, trackID := range videoTrackIDs {
|
||||||
|
r.RequestKeyframe(trackID)
|
||||||
|
}
|
||||||
|
// Request again after 500ms in case the first was too early
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
for _, trackID := range videoTrackIDs {
|
||||||
|
r.RequestKeyframe(trackID)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BroadcastTrack broadcasts a track from one peer to all other peers
|
||||||
|
func (r *Room) BroadcastTrack(sourcePeerID string, track *webrtc.TrackRemote) {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
|
||||||
|
// Get source peer's user ID for the track-added message
|
||||||
|
sourceUserID := ""
|
||||||
|
if sourcePeer, ok := r.peers[sourcePeerID]; ok {
|
||||||
|
sourceUserID = sourcePeer.UserID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a local track from the remote track
|
||||||
|
// Use participant ID as the stream ID so clients can identify the source
|
||||||
|
// Format: trackId stays the same, streamId = sourcePeerID (or sourceUserID if available)
|
||||||
|
streamID := sourcePeerID
|
||||||
|
if sourceUserID != "" {
|
||||||
|
streamID = sourceUserID // Use userID for easier client-side mapping
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("Creating local track for broadcast",
|
||||||
|
zap.String("source_peer_id", sourcePeerID),
|
||||||
|
zap.String("source_user_id", sourceUserID),
|
||||||
|
zap.String("original_track_id", track.ID()),
|
||||||
|
zap.String("original_stream_id", track.StreamID()),
|
||||||
|
zap.String("new_stream_id", streamID),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Log codec information for debugging
|
||||||
|
codec := track.Codec()
|
||||||
|
r.logger.Info("Track codec info",
|
||||||
|
zap.String("source_peer_id", sourcePeerID),
|
||||||
|
zap.String("track_kind", track.Kind().String()),
|
||||||
|
zap.String("mime_type", codec.MimeType),
|
||||||
|
zap.Uint32("clock_rate", codec.ClockRate),
|
||||||
|
zap.Uint16("channels", codec.Channels),
|
||||||
|
zap.String("sdp_fmtp_line", codec.SDPFmtpLine),
|
||||||
|
)
|
||||||
|
|
||||||
|
localTrack, err := webrtc.NewTrackLocalStaticRTP(
|
||||||
|
codec.RTPCodecCapability,
|
||||||
|
track.Kind().String()+"-"+sourcePeerID, // Include peer ID in track ID
|
||||||
|
streamID, // Use participant/user ID as stream ID
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
r.logger.Error("Failed to create local track",
|
||||||
|
zap.String("source_peer", sourcePeerID),
|
||||||
|
zap.String("track_id", track.ID()),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store the track for new joiners
|
||||||
|
pubTrack := &publishedTrack{
|
||||||
|
sourcePeerID: sourcePeerID,
|
||||||
|
sourceUserID: sourceUserID,
|
||||||
|
localTrack: localTrack,
|
||||||
|
remoteTrackSSRC: uint32(track.SSRC()),
|
||||||
|
remoteTrack: track,
|
||||||
|
kind: track.Kind().String(),
|
||||||
|
forwarderActive: true,
|
||||||
|
packetCount: 0,
|
||||||
|
lastKeyframeRequest: time.Now(),
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.Lock()
|
||||||
|
r.publishedTracks[localTrack.ID()] = pubTrack
|
||||||
|
r.publishedTracksMu.Unlock()
|
||||||
|
|
||||||
|
r.logger.Info("Track stored for new joiners",
|
||||||
|
zap.String("track_id", localTrack.ID()),
|
||||||
|
zap.String("source_peer_id", sourcePeerID),
|
||||||
|
zap.Int("total_published_tracks", len(r.publishedTracks)),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Forward RTP packets from remote track to local track
|
||||||
|
go func() {
|
||||||
|
trackID := track.ID()
|
||||||
|
localTrackID := localTrack.ID()
|
||||||
|
trackKind := track.Kind().String()
|
||||||
|
buf := make([]byte, 1600) // Slightly larger than MTU to handle RTP extensions
|
||||||
|
packetCount := 0
|
||||||
|
byteCount := 0
|
||||||
|
startTime := time.Now()
|
||||||
|
firstPacketReceived := false
|
||||||
|
|
||||||
|
r.logger.Info("RTP forwarder started",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("local_track_id", localTrackID),
|
||||||
|
zap.String("kind", trackKind),
|
||||||
|
zap.String("source_peer_id", sourcePeerID),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Start a goroutine to log warning if no packets received after 5 seconds
|
||||||
|
go func() {
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
if !firstPacketReceived {
|
||||||
|
r.logger.Warn("RTP forwarder WARNING: No packets received after 5 seconds - host may not be sending",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("local_track_id", localTrackID),
|
||||||
|
zap.String("kind", trackKind),
|
||||||
|
zap.String("source_peer_id", sourcePeerID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// For video tracks, use adaptive keyframe requests based on packet loss
|
||||||
|
if trackKind == "video" {
|
||||||
|
go func() {
|
||||||
|
// Use a faster ticker for checking, but only send keyframes when needed
|
||||||
|
ticker := time.NewTicker(500 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
var lastNackCount int64
|
||||||
|
var consecutiveLossDetections int
|
||||||
|
baseInterval := 3 * time.Second
|
||||||
|
minInterval := 500 * time.Millisecond // Minimum interval between keyframes
|
||||||
|
lastKeyframeTime := time.Now()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
// Check if room is closed
|
||||||
|
if r.IsClosed() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if forwarder is still active
|
||||||
|
r.publishedTracksMu.RLock()
|
||||||
|
pt, ok := r.publishedTracks[localTrackID]
|
||||||
|
if !ok || !pt.forwarderActive {
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current NACK count to detect packet loss
|
||||||
|
currentNackCount := pt.nackCount.Load()
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
|
||||||
|
timeSinceLastKeyframe := time.Since(lastKeyframeTime)
|
||||||
|
|
||||||
|
// Detect if packet loss is happening (NACKs increasing)
|
||||||
|
if currentNackCount > lastNackCount {
|
||||||
|
consecutiveLossDetections++
|
||||||
|
lastNackCount = currentNackCount
|
||||||
|
|
||||||
|
// If we detect packet loss and haven't requested a keyframe recently,
|
||||||
|
// request one immediately to help receivers recover
|
||||||
|
if timeSinceLastKeyframe >= minInterval {
|
||||||
|
r.logger.Debug("Adaptive keyframe request due to packet loss",
|
||||||
|
zap.String("track_id", localTrackID),
|
||||||
|
zap.Int64("nack_count", currentNackCount),
|
||||||
|
zap.Int("consecutive_loss_detections", consecutiveLossDetections),
|
||||||
|
)
|
||||||
|
r.RequestKeyframe(localTrackID)
|
||||||
|
lastKeyframeTime = time.Now()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Reset consecutive loss counter when no new NACKs
|
||||||
|
if consecutiveLossDetections > 0 {
|
||||||
|
consecutiveLossDetections--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regular keyframe request at base interval (regardless of loss)
|
||||||
|
if timeSinceLastKeyframe >= baseInterval {
|
||||||
|
r.RequestKeyframe(localTrackID)
|
||||||
|
lastKeyframeTime = time.Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark forwarder as stopped when we exit
|
||||||
|
defer func() {
|
||||||
|
r.publishedTracksMu.Lock()
|
||||||
|
if pt, ok := r.publishedTracks[localTrackID]; ok {
|
||||||
|
pt.forwarderActive = false
|
||||||
|
pt.packetCount = packetCount
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.Unlock()
|
||||||
|
|
||||||
|
r.logger.Info("RTP forwarder exiting",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("local_track_id", localTrackID),
|
||||||
|
zap.String("kind", trackKind),
|
||||||
|
zap.Duration("lifetime", time.Since(startTime)),
|
||||||
|
zap.Int("total_packets", packetCount),
|
||||||
|
zap.Int("total_bytes", byteCount),
|
||||||
|
)
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
n, _, readErr := track.Read(buf)
|
||||||
|
if readErr != nil {
|
||||||
|
r.logger.Info("RTP forwarder stopped - read error",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("local_track_id", localTrackID),
|
||||||
|
zap.String("kind", trackKind),
|
||||||
|
zap.Int("packets_forwarded", packetCount),
|
||||||
|
zap.Int("bytes_forwarded", byteCount),
|
||||||
|
zap.Error(readErr),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log first packet to confirm data is flowing
|
||||||
|
if packetCount == 0 {
|
||||||
|
firstPacketReceived = true
|
||||||
|
r.logger.Info("RTP forwarder received FIRST packet",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("local_track_id", localTrackID),
|
||||||
|
zap.String("kind", trackKind),
|
||||||
|
zap.Int("packet_size", n),
|
||||||
|
zap.Duration("time_to_first_packet", time.Since(startTime)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, writeErr := localTrack.Write(buf[:n]); writeErr != nil {
|
||||||
|
r.logger.Info("RTP forwarder stopped - write error",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("local_track_id", localTrackID),
|
||||||
|
zap.String("kind", trackKind),
|
||||||
|
zap.Int("packets_forwarded", packetCount),
|
||||||
|
zap.Int("bytes_forwarded", byteCount),
|
||||||
|
zap.Error(writeErr),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
packetCount++
|
||||||
|
byteCount += n
|
||||||
|
|
||||||
|
// Update packet count periodically (not every packet to reduce lock contention)
|
||||||
|
if packetCount%50 == 0 {
|
||||||
|
r.publishedTracksMu.Lock()
|
||||||
|
if pt, ok := r.publishedTracks[localTrackID]; ok {
|
||||||
|
pt.packetCount = packetCount
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log progress every 100 packets for video, 500 for audio
|
||||||
|
logInterval := 500
|
||||||
|
if trackKind == "video" {
|
||||||
|
logInterval = 100
|
||||||
|
}
|
||||||
|
if packetCount%logInterval == 0 {
|
||||||
|
r.logger.Info("RTP forwarder progress",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("kind", trackKind),
|
||||||
|
zap.Int("packets", packetCount),
|
||||||
|
zap.Int("bytes", byteCount),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Add track to all other peers
|
||||||
|
for peerID, peer := range r.peers {
|
||||||
|
if peerID == sourcePeerID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("Adding track to peer",
|
||||||
|
zap.String("target_peer", peerID),
|
||||||
|
zap.String("source_peer", sourcePeerID),
|
||||||
|
zap.String("source_user", sourceUserID),
|
||||||
|
zap.String("track_id", localTrack.ID()),
|
||||||
|
zap.String("stream_id", localTrack.StreamID()),
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := peer.AddTrack(localTrack); err != nil {
|
||||||
|
r.logger.Warn("Failed to add track to peer",
|
||||||
|
zap.String("target_peer", peerID),
|
||||||
|
zap.String("track_id", localTrack.ID()),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Notify peer about new track
|
||||||
|
// Use consistent IDs: participantId=sourcePeerID, streamId matches the track
|
||||||
|
peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
|
||||||
|
ParticipantID: sourcePeerID,
|
||||||
|
UserID: sourceUserID,
|
||||||
|
TrackID: localTrack.ID(), // Format: "{kind}-{participantId}"
|
||||||
|
StreamID: localTrack.StreamID(), // Same as userId for easy matching
|
||||||
|
Kind: track.Kind().String(),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("Track broadcast to room",
|
||||||
|
zap.String("source_peer", sourcePeerID),
|
||||||
|
zap.String("source_user", sourceUserID),
|
||||||
|
zap.String("track_id", track.ID()),
|
||||||
|
zap.String("kind", track.Kind().String()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// broadcastMessage sends a message to all peers except the specified one
|
||||||
|
func (r *Room) broadcastMessage(excludePeerID string, msg *ServerMessage) {
|
||||||
|
r.peersMu.RLock()
|
||||||
|
defer r.peersMu.RUnlock()
|
||||||
|
|
||||||
|
for peerID, peer := range r.peers {
|
||||||
|
if peerID == excludePeerID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := peer.SendMessage(msg); err != nil {
|
||||||
|
r.logger.Warn("Failed to send message to peer",
|
||||||
|
zap.String("peer_id", peerID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the room and all peer connections
|
||||||
|
func (r *Room) Close() error {
|
||||||
|
r.closedMu.Lock()
|
||||||
|
if r.closed {
|
||||||
|
r.closedMu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.closed = true
|
||||||
|
r.closedMu.Unlock()
|
||||||
|
|
||||||
|
r.logger.Info("Closing room")
|
||||||
|
|
||||||
|
r.peersMu.Lock()
|
||||||
|
peers := make([]*Peer, 0, len(r.peers))
|
||||||
|
for _, peer := range r.peers {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
r.peers = make(map[string]*Peer)
|
||||||
|
r.peersMu.Unlock()
|
||||||
|
|
||||||
|
// Close all peers
|
||||||
|
for _, peer := range peers {
|
||||||
|
if err := peer.Close(); err != nil {
|
||||||
|
r.logger.Warn("Error closing peer",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnEmpty sets a callback for when the room becomes empty
|
||||||
|
func (r *Room) OnEmpty(fn func(*Room)) {
|
||||||
|
r.onEmpty = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed returns whether the room is closed
|
||||||
|
func (r *Room) IsClosed() bool {
|
||||||
|
r.closedMu.RLock()
|
||||||
|
defer r.closedMu.RUnlock()
|
||||||
|
return r.closed
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestKeyframe sends a PLI (Picture Loss Indication) to the source peer for a video track.
|
||||||
|
// This causes the source to send a keyframe, which is needed for new receivers to start decoding.
|
||||||
|
func (r *Room) RequestKeyframe(trackID string) {
|
||||||
|
r.publishedTracksMu.RLock()
|
||||||
|
track, ok := r.publishedTracks[trackID]
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
|
||||||
|
if !ok || track.kind != "video" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.peersMu.RLock()
|
||||||
|
sourcePeer, ok := r.peers[track.sourcePeerID]
|
||||||
|
r.peersMu.RUnlock()
|
||||||
|
|
||||||
|
if !ok || sourcePeer.pc == nil {
|
||||||
|
r.logger.Debug("Cannot request keyframe - source peer not found",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("source_peer_id", track.sourcePeerID),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a PLI packet
|
||||||
|
pli := &rtcp.PictureLossIndication{
|
||||||
|
MediaSSRC: track.remoteTrackSSRC,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the PLI to the source peer
|
||||||
|
if err := sourcePeer.pc.WriteRTCP([]rtcp.Packet{pli}); err != nil {
|
||||||
|
r.logger.Debug("Failed to send PLI",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("source_peer_id", track.sourcePeerID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Debug("PLI keyframe request sent",
|
||||||
|
zap.String("track_id", trackID),
|
||||||
|
zap.String("source_peer_id", track.sourcePeerID),
|
||||||
|
zap.Uint32("ssrc", track.remoteTrackSSRC),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestKeyframeForAllVideoTracks sends PLI requests for all video tracks in the room.
|
||||||
|
// This is useful when a new peer joins to ensure they get keyframes quickly.
|
||||||
|
func (r *Room) RequestKeyframeForAllVideoTracks() {
|
||||||
|
r.publishedTracksMu.RLock()
|
||||||
|
videoTrackIDs := make([]string, 0)
|
||||||
|
for trackID, track := range r.publishedTracks {
|
||||||
|
if track.kind == "video" {
|
||||||
|
videoTrackIDs = append(videoTrackIDs, trackID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
|
||||||
|
for _, trackID := range videoTrackIDs {
|
||||||
|
r.RequestKeyframe(trackID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncrementNackCount increments the NACK counter for a track.
|
||||||
|
// This is called when we receive NACK feedback indicating packet loss.
|
||||||
|
func (r *Room) IncrementNackCount(trackID string) {
|
||||||
|
r.publishedTracksMu.RLock()
|
||||||
|
track, ok := r.publishedTracks[trackID]
|
||||||
|
r.publishedTracksMu.RUnlock()
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
track.nackCount.Add(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
145
pkg/gateway/sfu/signaling.go
Normal file
145
pkg/gateway/sfu/signaling.go
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
package sfu
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MessageType represents the type of signaling message
|
||||||
|
type MessageType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Client -> Server message types
|
||||||
|
MessageTypeJoin MessageType = "join"
|
||||||
|
MessageTypeLeave MessageType = "leave"
|
||||||
|
MessageTypeOffer MessageType = "offer"
|
||||||
|
MessageTypeAnswer MessageType = "answer"
|
||||||
|
MessageTypeICECandidate MessageType = "ice-candidate"
|
||||||
|
MessageTypeMute MessageType = "mute"
|
||||||
|
MessageTypeUnmute MessageType = "unmute"
|
||||||
|
MessageTypeStartVideo MessageType = "start-video"
|
||||||
|
MessageTypeStopVideo MessageType = "stop-video"
|
||||||
|
|
||||||
|
// Server -> Client message types
|
||||||
|
MessageTypeWelcome MessageType = "welcome"
|
||||||
|
MessageTypeParticipantJoined MessageType = "participant-joined"
|
||||||
|
MessageTypeParticipantLeft MessageType = "participant-left"
|
||||||
|
MessageTypeTrackAdded MessageType = "track-added"
|
||||||
|
MessageTypeTrackRemoved MessageType = "track-removed"
|
||||||
|
MessageTypeError MessageType = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClientMessage represents a message from client to server
|
||||||
|
type ClientMessage struct {
|
||||||
|
Type MessageType `json:"type"`
|
||||||
|
Data json.RawMessage `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerMessage represents a message from server to client
|
||||||
|
type ServerMessage struct {
|
||||||
|
Type MessageType `json:"type"`
|
||||||
|
Data interface{} `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinData is the payload for join messages
|
||||||
|
type JoinData struct {
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
AudioOnly bool `json:"audioOnly,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// OfferData is the payload for offer messages
|
||||||
|
type OfferData struct {
|
||||||
|
SDP string `json:"sdp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AnswerData is the payload for answer messages
|
||||||
|
type AnswerData struct {
|
||||||
|
SDP string `json:"sdp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICECandidateData is the payload for ICE candidate messages
|
||||||
|
type ICECandidateData struct {
|
||||||
|
Candidate string `json:"candidate"`
|
||||||
|
SDPMid string `json:"sdpMid,omitempty"`
|
||||||
|
SDPMLineIndex uint16 `json:"sdpMLineIndex,omitempty"`
|
||||||
|
UsernameFragment string `json:"usernameFragment,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToWebRTCCandidate converts ICECandidateData to webrtc.ICECandidateInit
|
||||||
|
func (c *ICECandidateData) ToWebRTCCandidate() webrtc.ICECandidateInit {
|
||||||
|
return webrtc.ICECandidateInit{
|
||||||
|
Candidate: c.Candidate,
|
||||||
|
SDPMid: &c.SDPMid,
|
||||||
|
SDPMLineIndex: &c.SDPMLineIndex,
|
||||||
|
UsernameFragment: &c.UsernameFragment,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WelcomeData is sent to a client when they successfully join
|
||||||
|
type WelcomeData struct {
|
||||||
|
ParticipantID string `json:"participantId"`
|
||||||
|
RoomID string `json:"roomId"`
|
||||||
|
Participants []ParticipantInfo `json:"participants"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantInfo contains public information about a participant
|
||||||
|
type ParticipantInfo struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
HasAudio bool `json:"hasAudio"`
|
||||||
|
HasVideo bool `json:"hasVideo"`
|
||||||
|
AudioMuted bool `json:"audioMuted"`
|
||||||
|
VideoMuted bool `json:"videoMuted"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantJoinedData is sent when a new participant joins
|
||||||
|
type ParticipantJoinedData struct {
|
||||||
|
Participant ParticipantInfo `json:"participant"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParticipantLeftData is sent when a participant leaves
|
||||||
|
type ParticipantLeftData struct {
|
||||||
|
ParticipantID string `json:"participantId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackAddedData is sent when a new track is available
|
||||||
|
type TrackAddedData struct {
|
||||||
|
ParticipantID string `json:"participantId"` // Internal SFU peer ID
|
||||||
|
UserID string `json:"userId"` // The user's actual ID for easier mapping
|
||||||
|
TrackID string `json:"trackId"` // Format: "{kind}-{participantId}"
|
||||||
|
StreamID string `json:"streamId"` // Same as userId (for WebRTC stream matching)
|
||||||
|
Kind string `json:"kind"` // "audio" or "video"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackRemovedData is sent when a track is removed
|
||||||
|
type TrackRemovedData struct {
|
||||||
|
ParticipantID string `json:"participantId"`
|
||||||
|
UserID string `json:"userId"` // The user's actual ID for easier mapping
|
||||||
|
TrackID string `json:"trackId"`
|
||||||
|
StreamID string `json:"streamId"`
|
||||||
|
Kind string `json:"kind"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorData is sent when an error occurs
|
||||||
|
type ErrorData struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServerMessage creates a new server message
|
||||||
|
func NewServerMessage(msgType MessageType, data interface{}) *ServerMessage {
|
||||||
|
return &ServerMessage{
|
||||||
|
Type: msgType,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewErrorMessage creates a new error message
|
||||||
|
func NewErrorMessage(code, message string) *ServerMessage {
|
||||||
|
return NewServerMessage(MessageTypeError, &ErrorData{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
573
pkg/gateway/sfu_handlers.go
Normal file
573
pkg/gateway/sfu_handlers.go
Normal file
@ -0,0 +1,573 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/sfu"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SFUManager is a type alias for the SFU room manager
|
||||||
|
type SFUManager = sfu.RoomManager
|
||||||
|
|
||||||
|
// CreateRoomRequest is the request body for creating a room
|
||||||
|
type CreateRoomRequest struct {
|
||||||
|
RoomID string `json:"roomId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateRoomResponse is the response for creating/joining a room
|
||||||
|
type CreateRoomResponse struct {
|
||||||
|
RoomID string `json:"roomId"`
|
||||||
|
Created bool `json:"created"`
|
||||||
|
RTPCapabilities map[string]interface{} `json:"rtpCapabilities"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinRoomRequest is the request body for joining a room
|
||||||
|
type JoinRoomRequest struct {
|
||||||
|
DisplayName string `json:"displayName"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// JoinRoomResponse is the response for joining a room
|
||||||
|
type JoinRoomResponse struct {
|
||||||
|
ParticipantID string `json:"participantId"`
|
||||||
|
Participants []sfu.ParticipantInfo `json:"participants"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// sfuCreateRoomHandler handles POST /v1/sfu/room
|
||||||
|
func (g *Gateway) sfuCreateRoomHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if SFU is enabled
|
||||||
|
if g.sfuManager == nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "SFU service not enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get namespace from auth context
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse request
|
||||||
|
var req CreateRoomRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.RoomID == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "roomId is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create or get room
|
||||||
|
room, created := g.sfuManager.GetOrCreateRoom(ns, req.RoomID)
|
||||||
|
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU room request",
|
||||||
|
zap.String("room_id", req.RoomID),
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.Bool("created", created),
|
||||||
|
zap.Int("participants", room.GetParticipantCount()),
|
||||||
|
)
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, &CreateRoomResponse{
|
||||||
|
RoomID: req.RoomID,
|
||||||
|
Created: created,
|
||||||
|
RTPCapabilities: sfu.GetRTPCapabilities(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// sfuRoomHandler handles all /v1/sfu/room/:roomId/* endpoints
|
||||||
|
func (g *Gateway) sfuRoomHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Check if SFU is enabled
|
||||||
|
if g.sfuManager == nil {
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "SFU service not enabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get namespace from auth context
|
||||||
|
ns := resolveNamespaceFromRequest(r)
|
||||||
|
if ns == "" {
|
||||||
|
writeError(w, http.StatusForbidden, "namespace not resolved")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse room ID and action from path
|
||||||
|
// Path format: /v1/sfu/room/{roomId}/{action}
|
||||||
|
path := strings.TrimPrefix(r.URL.Path, "/v1/sfu/room/")
|
||||||
|
parts := strings.SplitN(path, "/", 2)
|
||||||
|
|
||||||
|
if len(parts) == 0 || parts[0] == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "room ID required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
roomID := parts[0]
|
||||||
|
action := ""
|
||||||
|
if len(parts) > 1 {
|
||||||
|
action = parts[1]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route to appropriate handler
|
||||||
|
switch action {
|
||||||
|
case "":
|
||||||
|
// GET /v1/sfu/room/:roomId - Get room info
|
||||||
|
g.sfuGetRoomHandler(w, r, ns, roomID)
|
||||||
|
case "join":
|
||||||
|
// POST /v1/sfu/room/:roomId/join - Join room
|
||||||
|
g.sfuJoinRoomHandler(w, r, ns, roomID)
|
||||||
|
case "leave":
|
||||||
|
// POST /v1/sfu/room/:roomId/leave - Leave room
|
||||||
|
g.sfuLeaveRoomHandler(w, r, ns, roomID)
|
||||||
|
case "participants":
|
||||||
|
// GET /v1/sfu/room/:roomId/participants - List participants
|
||||||
|
g.sfuParticipantsHandler(w, r, ns, roomID)
|
||||||
|
case "ws":
|
||||||
|
// GET /v1/sfu/room/:roomId/ws - WebSocket signaling
|
||||||
|
g.sfuWebSocketHandler(w, r, ns, roomID)
|
||||||
|
default:
|
||||||
|
writeError(w, http.StatusNotFound, "unknown action")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sfuGetRoomHandler handles GET /v1/sfu/room/:roomId
|
||||||
|
func (g *Gateway) sfuGetRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := g.sfuManager.GetRoomInfo(ns, roomID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sfu.ErrRoomNotFound {
|
||||||
|
writeError(w, http.StatusNotFound, "room not found")
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sfuJoinRoomHandler handles POST /v1/sfu/room/:roomId/join
|
||||||
|
func (g *Gateway) sfuJoinRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse request
|
||||||
|
var req JoinRoomRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
// Default display name if not provided
|
||||||
|
req.DisplayName = "Anonymous"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get user ID from query param or auth context
|
||||||
|
userID := r.URL.Query().Get("userId")
|
||||||
|
if userID == "" {
|
||||||
|
userID = g.extractUserID(r)
|
||||||
|
}
|
||||||
|
if userID == "" {
|
||||||
|
userID = "anonymous"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get or create room
|
||||||
|
room, _ := g.sfuManager.GetOrCreateRoom(ns, roomID)
|
||||||
|
|
||||||
|
// This endpoint just returns room info
|
||||||
|
// The actual peer connection is established via WebSocket
|
||||||
|
writeJSON(w, http.StatusOK, &JoinRoomResponse{
|
||||||
|
ParticipantID: "", // Will be assigned when WebSocket connects
|
||||||
|
Participants: room.GetParticipants(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// sfuLeaveRoomHandler handles POST /v1/sfu/room/:roomId/leave
|
||||||
|
func (g *Gateway) sfuLeaveRoomHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse participant ID from request body
|
||||||
|
var req struct {
|
||||||
|
ParticipantID string `json:"participantId"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.ParticipantID == "" {
|
||||||
|
writeError(w, http.StatusBadRequest, "participantId required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
room, err := g.sfuManager.GetRoom(ns, roomID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sfu.ErrRoomNotFound {
|
||||||
|
writeError(w, http.StatusNotFound, "room not found")
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := room.RemovePeer(req.ParticipantID); err != nil {
|
||||||
|
if err == sfu.ErrPeerNotFound {
|
||||||
|
writeError(w, http.StatusNotFound, "participant not found")
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"success": true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// sfuParticipantsHandler handles GET /v1/sfu/room/:roomId/participants
|
||||||
|
func (g *Gateway) sfuParticipantsHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
room, err := g.sfuManager.GetRoom(ns, roomID)
|
||||||
|
if err != nil {
|
||||||
|
if err == sfu.ErrRoomNotFound {
|
||||||
|
writeError(w, http.StatusNotFound, "room not found")
|
||||||
|
} else {
|
||||||
|
writeError(w, http.StatusInternalServerError, err.Error())
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"participants": room.GetParticipants(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// sfuWebSocketHandler handles WebSocket signaling for a room
|
||||||
|
func (g *Gateway) sfuWebSocketHandler(w http.ResponseWriter, r *http.Request, ns, roomID string) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get user ID and display name from query parameters
|
||||||
|
// Priority: 1) userId query param, 2) JWT/API key auth, 3) "anonymous"
|
||||||
|
userID := r.URL.Query().Get("userId")
|
||||||
|
if userID == "" {
|
||||||
|
// Fall back to authentication-based user ID
|
||||||
|
userID = g.extractUserID(r)
|
||||||
|
}
|
||||||
|
if userID == "" {
|
||||||
|
userID = "anonymous"
|
||||||
|
}
|
||||||
|
|
||||||
|
displayName := r.URL.Query().Get("displayName")
|
||||||
|
if displayName == "" {
|
||||||
|
displayName = "Anonymous"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upgrade to WebSocket
|
||||||
|
conn, err := wsUpgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket upgrade failed",
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recover from panics to avoid silent crashes
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
g.logger.ComponentError(logging.ComponentGeneral, "SFU WebSocket handler panic",
|
||||||
|
zap.Any("panic", r),
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.String("user_id", userID),
|
||||||
|
)
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU WebSocket connected",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
zap.String("user_id", userID),
|
||||||
|
zap.String("display_name", displayName),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Get or create room
|
||||||
|
g.logger.ComponentDebug(logging.ComponentGeneral, "SFU getting/creating room",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.String("namespace", ns),
|
||||||
|
)
|
||||||
|
room, created := g.sfuManager.GetOrCreateRoom(ns, roomID)
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU room ready",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.Bool("created", created),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Create peer
|
||||||
|
peer := sfu.NewPeer(userID, displayName, conn, room, g.logger.Logger)
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU peer created",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Add peer to room
|
||||||
|
g.logger.ComponentDebug(logging.ComponentGeneral, "SFU adding peer to room (will init peer connection)",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
)
|
||||||
|
if err := room.AddPeer(peer); err != nil {
|
||||||
|
g.logger.ComponentError(logging.ComponentGeneral, "Failed to add peer to room",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("join_failed", err.Error()))
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU peer added to room successfully",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Send welcome message
|
||||||
|
g.logger.ComponentDebug(logging.ComponentGeneral, "SFU preparing welcome message",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.Int("num_participants", len(room.GetParticipants())),
|
||||||
|
)
|
||||||
|
welcomeMsg := sfu.NewServerMessage(sfu.MessageTypeWelcome, &sfu.WelcomeData{
|
||||||
|
ParticipantID: peer.ID,
|
||||||
|
RoomID: roomID,
|
||||||
|
Participants: room.GetParticipants(),
|
||||||
|
})
|
||||||
|
g.logger.ComponentDebug(logging.ComponentGeneral, "SFU sending welcome message via SendMessage",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
)
|
||||||
|
if err := peer.SendMessage(welcomeMsg); err != nil {
|
||||||
|
g.logger.ComponentError(logging.ComponentGeneral, "Failed to send welcome message",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU welcome message sent successfully",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handle signaling messages
|
||||||
|
// Note: existing tracks are sent AFTER the first offer/answer exchange completes
|
||||||
|
g.handleSFUSignaling(conn, peer, room)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSFUSignaling handles WebSocket signaling messages for a peer
|
||||||
|
func (g *Gateway) handleSFUSignaling(conn *websocket.Conn, peer *sfu.Peer, room *sfu.Room) {
|
||||||
|
defer func() {
|
||||||
|
room.RemovePeer(peer.ID)
|
||||||
|
conn.Close()
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU WebSocket disconnected",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.String("room_id", room.ID),
|
||||||
|
)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set up ping/pong for keepalive
|
||||||
|
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||||
|
conn.SetPongHandler(func(string) error {
|
||||||
|
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
// Start ping ticker
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for range ticker.C {
|
||||||
|
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Read messages
|
||||||
|
for {
|
||||||
|
_, data, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentGeneral, "SFU WebSocket read error",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
zap.Error(err),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset read deadline
|
||||||
|
conn.SetReadDeadline(time.Now().Add(60 * time.Second))
|
||||||
|
|
||||||
|
// Parse message
|
||||||
|
var msg sfu.ClientMessage
|
||||||
|
if err := json.Unmarshal(data, &msg); err != nil {
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("invalid_message", "failed to parse message"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle message
|
||||||
|
g.handleSFUMessage(peer, room, &msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSFUMessage handles a single signaling message
|
||||||
|
func (g *Gateway) handleSFUMessage(peer *sfu.Peer, room *sfu.Room, msg *sfu.ClientMessage) {
|
||||||
|
switch msg.Type {
|
||||||
|
case sfu.MessageTypeOffer:
|
||||||
|
var data sfu.OfferData
|
||||||
|
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("invalid_offer", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := peer.HandleOffer(data.SDP); err != nil {
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("offer_failed", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// After successfully handling the FIRST offer, send existing tracks
|
||||||
|
// This ensures the WebRTC connection is established before adding more tracks
|
||||||
|
if peer.MarkInitialOfferHandled() {
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "First offer handled, sending existing tracks",
|
||||||
|
zap.String("peer_id", peer.ID),
|
||||||
|
)
|
||||||
|
room.SendExistingTracksTo(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
case sfu.MessageTypeAnswer:
|
||||||
|
var data sfu.AnswerData
|
||||||
|
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("invalid_answer", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := peer.HandleAnswer(data.SDP); err != nil {
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("answer_failed", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
case sfu.MessageTypeICECandidate:
|
||||||
|
var data sfu.ICECandidateData
|
||||||
|
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("invalid_candidate", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := peer.HandleICECandidate(&data); err != nil {
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("candidate_failed", err.Error()))
|
||||||
|
}
|
||||||
|
|
||||||
|
case sfu.MessageTypeMute:
|
||||||
|
peer.SetAudioMuted(true)
|
||||||
|
g.logger.ComponentDebug(logging.ComponentGeneral, "Peer muted audio", zap.String("peer_id", peer.ID))
|
||||||
|
|
||||||
|
case sfu.MessageTypeUnmute:
|
||||||
|
peer.SetAudioMuted(false)
|
||||||
|
g.logger.ComponentDebug(logging.ComponentGeneral, "Peer unmuted audio", zap.String("peer_id", peer.ID))
|
||||||
|
|
||||||
|
case sfu.MessageTypeStartVideo:
|
||||||
|
peer.SetVideoMuted(false)
|
||||||
|
g.logger.ComponentDebug(logging.ComponentGeneral, "Peer started video", zap.String("peer_id", peer.ID))
|
||||||
|
|
||||||
|
case sfu.MessageTypeStopVideo:
|
||||||
|
peer.SetVideoMuted(true)
|
||||||
|
g.logger.ComponentDebug(logging.ComponentGeneral, "Peer stopped video", zap.String("peer_id", peer.ID))
|
||||||
|
|
||||||
|
case sfu.MessageTypeLeave:
|
||||||
|
// Will be handled by deferred cleanup
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "Peer leaving room", zap.String("peer_id", peer.ID))
|
||||||
|
|
||||||
|
default:
|
||||||
|
peer.SendMessage(sfu.NewErrorMessage("unknown_type", "unknown message type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// initializeSFUManager initializes the SFU manager with the gateway's TURN config
|
||||||
|
func (g *Gateway) initializeSFUManager() error {
|
||||||
|
if g.cfg.SFU == nil || !g.cfg.SFU.Enabled {
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU service disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build ICE servers from config
|
||||||
|
iceServers := make([]webrtc.ICEServer, 0)
|
||||||
|
|
||||||
|
// Add configured ICE servers
|
||||||
|
if g.cfg.SFU.ICEServers != nil {
|
||||||
|
for _, server := range g.cfg.SFU.ICEServers {
|
||||||
|
iceServers = append(iceServers, webrtc.ICEServer{
|
||||||
|
URLs: server.URLs,
|
||||||
|
Username: server.Username,
|
||||||
|
Credential: server.Credential,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add TURN servers if configured (credentials will be generated dynamically)
|
||||||
|
if g.cfg.TURN != nil {
|
||||||
|
// Determine hostname for ICE server URLs
|
||||||
|
// Use configured domain, or fallback to localhost for local development
|
||||||
|
iceHost := g.cfg.DomainName
|
||||||
|
if iceHost == "" {
|
||||||
|
iceHost = "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(g.cfg.TURN.STUNURLs) > 0 {
|
||||||
|
// Process URLs to replace empty hostnames (e.g., "stun::3478" -> "stun:localhost:3478")
|
||||||
|
processedURLs := processURLsWithHost(g.cfg.TURN.STUNURLs, iceHost)
|
||||||
|
iceServers = append(iceServers, webrtc.ICEServer{
|
||||||
|
URLs: processedURLs,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// Note: TURN credentials are time-limited, so clients should fetch them
|
||||||
|
// via the /v1/turn/credentials endpoint before joining a room
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create SFU config
|
||||||
|
sfuConfig := &sfu.Config{
|
||||||
|
MaxParticipants: g.cfg.SFU.MaxParticipants,
|
||||||
|
MediaTimeout: g.cfg.SFU.MediaTimeout,
|
||||||
|
ICEServers: iceServers,
|
||||||
|
}
|
||||||
|
|
||||||
|
if sfuConfig.MaxParticipants == 0 {
|
||||||
|
sfuConfig.MaxParticipants = 10
|
||||||
|
}
|
||||||
|
if sfuConfig.MediaTimeout == 0 {
|
||||||
|
sfuConfig.MediaTimeout = 30 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := sfu.NewRoomManager(sfuConfig, g.logger.Logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
g.sfuManager = manager
|
||||||
|
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "SFU manager initialized",
|
||||||
|
zap.Int("max_participants", sfuConfig.MaxParticipants),
|
||||||
|
zap.Duration("media_timeout", sfuConfig.MediaTimeout),
|
||||||
|
zap.Int("ice_servers", len(iceServers)),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
192
pkg/gateway/turn_handlers.go
Normal file
192
pkg/gateway/turn_handlers.go
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
package gateway
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/gateway/auth"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TURNCredentialsResponse is the response for TURN credential requests
|
||||||
|
type TURNCredentialsResponse struct {
|
||||||
|
Username string `json:"username"` // Format: "timestamp:userId"
|
||||||
|
Credential string `json:"credential"` // HMAC-SHA1(username, shared_secret) base64 encoded
|
||||||
|
TTL int64 `json:"ttl"` // Time-to-live in seconds
|
||||||
|
STUNURLs []string `json:"stun_urls"` // STUN server URLs
|
||||||
|
TURNURLs []string `json:"turn_urls"` // TURN server URLs
|
||||||
|
}
|
||||||
|
|
||||||
|
// turnCredentialsHandler handles POST /v1/turn/credentials
|
||||||
|
// Returns time-limited TURN credentials for WebRTC connections
|
||||||
|
func (g *Gateway) turnCredentialsHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost && r.Method != http.MethodGet {
|
||||||
|
writeError(w, http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if TURN is configured
|
||||||
|
if g.cfg.TURN == nil || g.cfg.TURN.SharedSecret == "" {
|
||||||
|
g.logger.ComponentWarn(logging.ComponentGeneral, "TURN credentials requested but not configured")
|
||||||
|
writeError(w, http.StatusServiceUnavailable, "TURN service not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get user ID from JWT claims or API key
|
||||||
|
userID := g.extractUserID(r)
|
||||||
|
if userID == "" {
|
||||||
|
userID = "anonymous"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get gateway hostname from request
|
||||||
|
gatewayHost := r.Host
|
||||||
|
if idx := strings.Index(gatewayHost, ":"); idx != -1 {
|
||||||
|
gatewayHost = gatewayHost[:idx] // Remove port
|
||||||
|
}
|
||||||
|
if gatewayHost == "" {
|
||||||
|
gatewayHost = "localhost"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate credentials
|
||||||
|
credentials := g.generateTURNCredentials(userID, gatewayHost)
|
||||||
|
|
||||||
|
g.logger.ComponentInfo(logging.ComponentGeneral, "TURN credentials generated",
|
||||||
|
zap.String("user_id", userID),
|
||||||
|
zap.String("gateway_host", gatewayHost),
|
||||||
|
zap.Int64("ttl", credentials.TTL),
|
||||||
|
)
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, credentials)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTURNCredentials creates time-limited TURN credentials using HMAC-SHA1
|
||||||
|
func (g *Gateway) generateTURNCredentials(userID, gatewayHost string) *TURNCredentialsResponse {
|
||||||
|
cfg := g.cfg.TURN
|
||||||
|
|
||||||
|
// Default TTL to 24 hours if not configured
|
||||||
|
ttl := cfg.TTL
|
||||||
|
if ttl == 0 {
|
||||||
|
ttl = 24 * time.Hour
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate expiry timestamp
|
||||||
|
timestamp := time.Now().Unix() + int64(ttl.Seconds())
|
||||||
|
|
||||||
|
// Format: "timestamp:userId" (coturn format)
|
||||||
|
username := fmt.Sprintf("%d:%s", timestamp, userID)
|
||||||
|
|
||||||
|
// Generate HMAC-SHA1 credential
|
||||||
|
h := hmac.New(sha1.New, []byte(cfg.SharedSecret))
|
||||||
|
h.Write([]byte(username))
|
||||||
|
credential := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
|
||||||
|
// Determine the host to use for STUN/TURN URLs
|
||||||
|
// Priority: 1) ExternalHost from config, 2) Auto-detect LAN IP, 3) Gateway host from request
|
||||||
|
host := cfg.ExternalHost
|
||||||
|
if host == "" {
|
||||||
|
// Auto-detect LAN IP for development
|
||||||
|
host = detectLANIP()
|
||||||
|
if host == "" {
|
||||||
|
// Fallback to gateway host from request (may be localhost)
|
||||||
|
host = gatewayHost
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process URLs - replace empty hostnames (::) with determined host
|
||||||
|
stunURLs := processURLsWithHost(cfg.STUNURLs, host)
|
||||||
|
turnURLs := processURLsWithHost(cfg.TURNURLs, host)
|
||||||
|
|
||||||
|
// If TLS is enabled, ensure we have turns:// URLs
|
||||||
|
if cfg.TLSEnabled {
|
||||||
|
hasTurns := false
|
||||||
|
for _, url := range turnURLs {
|
||||||
|
if strings.HasPrefix(url, "turns:") {
|
||||||
|
hasTurns = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Auto-add turns:// URL if not already configured
|
||||||
|
if !hasTurns {
|
||||||
|
turnsURL := fmt.Sprintf("turns:%s:443?transport=tcp", host)
|
||||||
|
turnURLs = append(turnURLs, turnsURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TURNCredentialsResponse{
|
||||||
|
Username: username,
|
||||||
|
Credential: credential,
|
||||||
|
TTL: int64(ttl.Seconds()),
|
||||||
|
STUNURLs: stunURLs,
|
||||||
|
TURNURLs: turnURLs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectLANIP returns the first non-loopback IPv4 address found on the system.
|
||||||
|
// Returns empty string if no suitable address is found.
|
||||||
|
func detectLANIP() string {
|
||||||
|
addrs, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
|
||||||
|
if ipNet.IP.To4() != nil {
|
||||||
|
return ipNet.IP.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// processURLsWithHost replaces empty hostnames in URLs with the given host
|
||||||
|
// Supports two patterns:
|
||||||
|
// - "stun:::3478" (triple colon) -> "stun:host:3478"
|
||||||
|
// - "stun::3478" (double colon) -> "stun:host:3478"
|
||||||
|
func processURLsWithHost(urls []string, host string) []string {
|
||||||
|
result := make([]string, 0, len(urls))
|
||||||
|
for _, url := range urls {
|
||||||
|
// Check for triple colon pattern first (e.g., "stun:::3478")
|
||||||
|
// This is the preferred format: protocol:::port
|
||||||
|
if strings.Contains(url, ":::") {
|
||||||
|
url = strings.Replace(url, ":::", ":"+host+":", 1)
|
||||||
|
} else if strings.Contains(url, "::") {
|
||||||
|
// Fallback for double colon pattern (e.g., "stun::3478")
|
||||||
|
url = strings.Replace(url, "::", ":"+host+":", 1)
|
||||||
|
}
|
||||||
|
result = append(result, url)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractUserID extracts the user ID from the request context
|
||||||
|
func (g *Gateway) extractUserID(r *http.Request) string {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
// Try JWT claims first
|
||||||
|
if v := ctx.Value(ctxKeyJWT); v != nil {
|
||||||
|
if claims, ok := v.(*auth.JWTClaims); ok && claims != nil {
|
||||||
|
if claims.Sub != "" {
|
||||||
|
return claims.Sub
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to API key
|
||||||
|
if v := ctx.Value(ctxKeyAPIKey); v != nil {
|
||||||
|
if key, ok := v.(string); ok && key != "" {
|
||||||
|
// Use a hash of the API key as the user ID for privacy
|
||||||
|
return fmt.Sprintf("ak_%s", key[:min(8, len(key))])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
@ -33,19 +33,21 @@ func (n *Node) startHTTPGateway(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
gwCfg := &gateway.Config{
|
gwCfg := &gateway.Config{
|
||||||
ListenAddr: n.config.HTTPGateway.ListenAddr,
|
ListenAddr: n.config.HTTPGateway.ListenAddr,
|
||||||
ClientNamespace: n.config.HTTPGateway.ClientNamespace,
|
ClientNamespace: n.config.HTTPGateway.ClientNamespace,
|
||||||
BootstrapPeers: n.config.Discovery.BootstrapPeers,
|
BootstrapPeers: n.config.Discovery.BootstrapPeers,
|
||||||
NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir),
|
NodePeerID: loadNodePeerIDFromIdentity(n.config.Node.DataDir),
|
||||||
RQLiteDSN: n.config.HTTPGateway.RQLiteDSN,
|
RQLiteDSN: n.config.HTTPGateway.RQLiteDSN,
|
||||||
OlricServers: n.config.HTTPGateway.OlricServers,
|
OlricServers: n.config.HTTPGateway.OlricServers,
|
||||||
OlricTimeout: n.config.HTTPGateway.OlricTimeout,
|
OlricTimeout: n.config.HTTPGateway.OlricTimeout,
|
||||||
IPFSClusterAPIURL: n.config.HTTPGateway.IPFSClusterAPIURL,
|
IPFSClusterAPIURL: n.config.HTTPGateway.IPFSClusterAPIURL,
|
||||||
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
|
IPFSAPIURL: n.config.HTTPGateway.IPFSAPIURL,
|
||||||
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
|
IPFSTimeout: n.config.HTTPGateway.IPFSTimeout,
|
||||||
EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled,
|
EnableHTTPS: n.config.HTTPGateway.HTTPS.Enabled,
|
||||||
DomainName: n.config.HTTPGateway.HTTPS.Domain,
|
DomainName: n.config.HTTPGateway.HTTPS.Domain,
|
||||||
TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir,
|
TLSCacheDir: n.config.HTTPGateway.HTTPS.CacheDir,
|
||||||
|
TURN: n.config.HTTPGateway.TURN,
|
||||||
|
SFU: n.config.HTTPGateway.SFU,
|
||||||
}
|
}
|
||||||
|
|
||||||
apiGateway, err := gateway.New(gatewayLogger, gwCfg)
|
apiGateway, err := gateway.New(gatewayLogger, gwCfg)
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/DeBrosOfficial/network/pkg/logging"
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
"github.com/DeBrosOfficial/network/pkg/pubsub"
|
||||||
database "github.com/DeBrosOfficial/network/pkg/rqlite"
|
database "github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||||
"github.com/libp2p/go-libp2p/core/host"
|
"github.com/libp2p/go-libp2p/core/host"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"golang.org/x/crypto/acme/autocert"
|
"golang.org/x/crypto/acme/autocert"
|
||||||
@ -55,6 +56,9 @@ type Node struct {
|
|||||||
|
|
||||||
// Certificate ready signal - closed when TLS certificates are extracted and ready for use
|
// Certificate ready signal - closed when TLS certificates are extracted and ready for use
|
||||||
certReady chan struct{}
|
certReady chan struct{}
|
||||||
|
|
||||||
|
// Built-in TURN server for WebRTC NAT traversal
|
||||||
|
turnServer *turn.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNode creates a new network node
|
// NewNode creates a new network node
|
||||||
@ -96,6 +100,11 @@ func (n *Node) Start(ctx context.Context) error {
|
|||||||
n.logger.ComponentWarn(logging.ComponentNode, "Failed to start HTTP Gateway", zap.Error(err))
|
n.logger.ComponentWarn(logging.ComponentNode, "Failed to start HTTP Gateway", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start built-in TURN server if enabled
|
||||||
|
if err := n.startTURNServer(); err != nil {
|
||||||
|
n.logger.ComponentWarn(logging.ComponentNode, "Failed to start TURN server", zap.Error(err))
|
||||||
|
}
|
||||||
|
|
||||||
// Start LibP2P host first (needed for cluster discovery)
|
// Start LibP2P host first (needed for cluster discovery)
|
||||||
if err := n.startLibP2P(); err != nil {
|
if err := n.startLibP2P(); err != nil {
|
||||||
return fmt.Errorf("failed to start LibP2P: %w", err)
|
return fmt.Errorf("failed to start LibP2P: %w", err)
|
||||||
@ -135,6 +144,11 @@ func (n *Node) Start(ctx context.Context) error {
|
|||||||
func (n *Node) Stop() error {
|
func (n *Node) Stop() error {
|
||||||
n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node")
|
n.logger.ComponentInfo(logging.ComponentNode, "Stopping network node")
|
||||||
|
|
||||||
|
// Stop TURN server
|
||||||
|
if n.turnServer != nil {
|
||||||
|
_ = n.turnServer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
// Stop HTTP Gateway server
|
// Stop HTTP Gateway server
|
||||||
if n.apiGatewayServer != nil {
|
if n.apiGatewayServer != nil {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
|||||||
90
pkg/node/turn.go
Normal file
90
pkg/node/turn.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package node
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/logging"
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/turn"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// startTURNServer initializes and starts the built-in TURN server
|
||||||
|
func (n *Node) startTURNServer() error {
|
||||||
|
if !n.config.TURNServer.Enabled {
|
||||||
|
n.logger.ComponentInfo(logging.ComponentNode, "Built-in TURN server disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
n.logger.ComponentInfo(logging.ComponentNode, "Starting built-in TURN server")
|
||||||
|
|
||||||
|
// Get shared secret - env var takes priority over config file (for production)
|
||||||
|
sharedSecret := os.Getenv("TURN_SHARED_SECRET")
|
||||||
|
if sharedSecret == "" && n.config.HTTPGateway.TURN != nil && n.config.HTTPGateway.TURN.SharedSecret != "" {
|
||||||
|
sharedSecret = n.config.HTTPGateway.TURN.SharedSecret
|
||||||
|
}
|
||||||
|
|
||||||
|
if sharedSecret == "" {
|
||||||
|
n.logger.ComponentWarn(logging.ComponentNode, "TURN server enabled but no shared_secret configured in http_gateway.turn")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get public IP - env var takes priority over config file (for production)
|
||||||
|
publicIP := os.Getenv("TURN_PUBLIC_IP")
|
||||||
|
if publicIP == "" {
|
||||||
|
publicIP = n.config.TURNServer.PublicIP
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build TURN server config
|
||||||
|
turnCfg := &turn.Config{
|
||||||
|
Enabled: true,
|
||||||
|
ListenAddr: n.config.TURNServer.ListenAddr,
|
||||||
|
PublicIP: publicIP,
|
||||||
|
Realm: n.config.TURNServer.Realm,
|
||||||
|
SharedSecret: sharedSecret,
|
||||||
|
CredentialTTL: 24 * 60 * 60, // 24 hours in seconds (will be converted)
|
||||||
|
MinPort: n.config.TURNServer.MinPort,
|
||||||
|
MaxPort: n.config.TURNServer.MaxPort,
|
||||||
|
// TLS configuration for TURNS
|
||||||
|
TLSEnabled: n.config.TURNServer.TLSEnabled,
|
||||||
|
TLSListenAddr: n.config.TURNServer.TLSListenAddr,
|
||||||
|
TLSCertFile: n.config.TURNServer.TLSCertFile,
|
||||||
|
TLSKeyFile: n.config.TURNServer.TLSKeyFile,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply defaults
|
||||||
|
if turnCfg.ListenAddr == "" {
|
||||||
|
turnCfg.ListenAddr = "0.0.0.0:3478"
|
||||||
|
}
|
||||||
|
if turnCfg.Realm == "" {
|
||||||
|
turnCfg.Realm = "orama.network"
|
||||||
|
}
|
||||||
|
if turnCfg.MinPort == 0 {
|
||||||
|
turnCfg.MinPort = 49152
|
||||||
|
}
|
||||||
|
if turnCfg.MaxPort == 0 {
|
||||||
|
turnCfg.MaxPort = 65535
|
||||||
|
}
|
||||||
|
if turnCfg.TLSListenAddr == "" && turnCfg.TLSEnabled {
|
||||||
|
turnCfg.TLSListenAddr = "0.0.0.0:443"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create and start TURN server
|
||||||
|
server, err := turn.NewServer(turnCfg, n.logger.Logger)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := server.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
n.turnServer = server
|
||||||
|
|
||||||
|
n.logger.ComponentInfo(logging.ComponentNode, "Built-in TURN server started",
|
||||||
|
zap.String("listen_addr", turnCfg.ListenAddr),
|
||||||
|
zap.String("realm", turnCfg.Realm),
|
||||||
|
zap.Bool("turns_enabled", turnCfg.TLSEnabled),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@ -214,3 +214,9 @@ func IsServiceUnavailable(err error) bool {
|
|||||||
errors.Is(err, ErrDatabaseUnavailable) ||
|
errors.Is(err, ErrDatabaseUnavailable) ||
|
||||||
errors.Is(err, ErrCacheUnavailable)
|
errors.Is(err, ErrCacheUnavailable)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsValidationError checks if an error is a validation error.
|
||||||
|
func IsValidationError(err error) bool {
|
||||||
|
var validationErr *ValidationError
|
||||||
|
return errors.As(err, &validationErr)
|
||||||
|
}
|
||||||
|
|||||||
239
pkg/serverless/triggers.go
Normal file
239
pkg/serverless/triggers.go
Normal file
@ -0,0 +1,239 @@
|
|||||||
|
package serverless
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DeBrosOfficial/network/pkg/rqlite"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensure DBTriggerManager implements TriggerManager interface.
|
||||||
|
var _ TriggerManager = (*DBTriggerManager)(nil)
|
||||||
|
|
||||||
|
// DBTriggerManager manages function triggers using RQLite database.
|
||||||
|
type DBTriggerManager struct {
|
||||||
|
db rqlite.Client
|
||||||
|
logger *zap.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDBTriggerManager creates a new trigger manager.
|
||||||
|
func NewDBTriggerManager(db rqlite.Client, logger *zap.Logger) *DBTriggerManager {
|
||||||
|
return &DBTriggerManager{
|
||||||
|
db: db,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pubsubTriggerRow represents a row from the function_pubsub_triggers table.
|
||||||
|
type pubsubTriggerRow struct {
|
||||||
|
ID string `db:"id"`
|
||||||
|
FunctionID string `db:"function_id"`
|
||||||
|
Topic string `db:"topic"`
|
||||||
|
Enabled bool `db:"enabled"`
|
||||||
|
CreatedAt time.Time `db:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPubSubTrigger adds a pubsub trigger to a function.
|
||||||
|
func (m *DBTriggerManager) AddPubSubTrigger(ctx context.Context, functionID, topic string) error {
|
||||||
|
functionID = strings.TrimSpace(functionID)
|
||||||
|
topic = strings.TrimSpace(topic)
|
||||||
|
|
||||||
|
if functionID == "" {
|
||||||
|
return &ValidationError{Field: "function_id", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
if topic == "" {
|
||||||
|
return &ValidationError{Field: "topic", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if trigger already exists for this function and topic
|
||||||
|
var existing []pubsubTriggerRow
|
||||||
|
checkQuery := `SELECT id FROM function_pubsub_triggers WHERE function_id = ? AND topic = ? AND enabled = TRUE`
|
||||||
|
if err := m.db.Query(ctx, &existing, checkQuery, functionID, topic); err != nil {
|
||||||
|
return fmt.Errorf("failed to check existing trigger: %w", err)
|
||||||
|
}
|
||||||
|
if len(existing) > 0 {
|
||||||
|
return &ValidationError{Field: "topic", Message: "trigger already exists for this topic"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate trigger ID
|
||||||
|
triggerID := "trig_" + uuid.New().String()[:8]
|
||||||
|
|
||||||
|
// Insert trigger
|
||||||
|
query := `INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) VALUES (?, ?, ?, TRUE, ?)`
|
||||||
|
_, err := m.db.Exec(ctx, query, triggerID, functionID, topic, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create pubsub trigger: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Info("PubSub trigger created",
|
||||||
|
zap.String("trigger_id", triggerID),
|
||||||
|
zap.String("function_id", functionID),
|
||||||
|
zap.String("topic", topic),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPubSubTriggerWithID adds a pubsub trigger and returns the trigger ID.
|
||||||
|
func (m *DBTriggerManager) AddPubSubTriggerWithID(ctx context.Context, functionID, topic string) (string, error) {
|
||||||
|
functionID = strings.TrimSpace(functionID)
|
||||||
|
topic = strings.TrimSpace(topic)
|
||||||
|
|
||||||
|
if functionID == "" {
|
||||||
|
return "", &ValidationError{Field: "function_id", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
if topic == "" {
|
||||||
|
return "", &ValidationError{Field: "topic", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if trigger already exists for this function and topic
|
||||||
|
var existing []pubsubTriggerRow
|
||||||
|
checkQuery := `SELECT id FROM function_pubsub_triggers WHERE function_id = ? AND topic = ? AND enabled = TRUE`
|
||||||
|
if err := m.db.Query(ctx, &existing, checkQuery, functionID, topic); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to check existing trigger: %w", err)
|
||||||
|
}
|
||||||
|
if len(existing) > 0 {
|
||||||
|
return "", &ValidationError{Field: "topic", Message: "trigger already exists for this topic"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate trigger ID
|
||||||
|
triggerID := "trig_" + uuid.New().String()[:8]
|
||||||
|
|
||||||
|
// Insert trigger
|
||||||
|
query := `INSERT INTO function_pubsub_triggers (id, function_id, topic, enabled, created_at) VALUES (?, ?, ?, TRUE, ?)`
|
||||||
|
_, err := m.db.Exec(ctx, query, triggerID, functionID, topic, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to create pubsub trigger: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Info("PubSub trigger created",
|
||||||
|
zap.String("trigger_id", triggerID),
|
||||||
|
zap.String("function_id", functionID),
|
||||||
|
zap.String("topic", topic),
|
||||||
|
)
|
||||||
|
|
||||||
|
return triggerID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddCronTrigger adds a cron-based trigger to a function.
|
||||||
|
func (m *DBTriggerManager) AddCronTrigger(ctx context.Context, functionID, cronExpr string) error {
|
||||||
|
// TODO: Implement cron trigger support
|
||||||
|
return fmt.Errorf("cron triggers not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDBTrigger adds a database trigger to a function.
|
||||||
|
func (m *DBTriggerManager) AddDBTrigger(ctx context.Context, functionID, tableName string, operation DBOperation, condition string) error {
|
||||||
|
// TODO: Implement database trigger support
|
||||||
|
return fmt.Errorf("database triggers not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScheduleOnce schedules a one-time execution.
|
||||||
|
func (m *DBTriggerManager) ScheduleOnce(ctx context.Context, functionID string, runAt time.Time, payload []byte) (string, error) {
|
||||||
|
// TODO: Implement one-time timer support
|
||||||
|
return "", fmt.Errorf("one-time timers not yet implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveTrigger removes a trigger by ID.
|
||||||
|
func (m *DBTriggerManager) RemoveTrigger(ctx context.Context, triggerID string) error {
|
||||||
|
triggerID = strings.TrimSpace(triggerID)
|
||||||
|
if triggerID == "" {
|
||||||
|
return &ValidationError{Field: "trigger_id", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to delete from pubsub triggers first
|
||||||
|
query := `DELETE FROM function_pubsub_triggers WHERE id = ?`
|
||||||
|
result, err := m.db.Exec(ctx, query, triggerID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove trigger: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowsAffected, _ := result.RowsAffected()
|
||||||
|
if rowsAffected == 0 {
|
||||||
|
return ErrTriggerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Info("Trigger removed", zap.String("trigger_id", triggerID))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListPubSubTriggers returns all pubsub triggers for a function.
|
||||||
|
func (m *DBTriggerManager) ListPubSubTriggers(ctx context.Context, functionID string) ([]PubSubTrigger, error) {
|
||||||
|
functionID = strings.TrimSpace(functionID)
|
||||||
|
if functionID == "" {
|
||||||
|
return nil, &ValidationError{Field: "function_id", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE function_id = ? AND enabled = TRUE`
|
||||||
|
var rows []pubsubTriggerRow
|
||||||
|
if err := m.db.Query(ctx, &rows, query, functionID); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list pubsub triggers: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
triggers := make([]PubSubTrigger, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
triggers[i] = PubSubTrigger{
|
||||||
|
ID: row.ID,
|
||||||
|
FunctionID: row.FunctionID,
|
||||||
|
Topic: row.Topic,
|
||||||
|
Enabled: row.Enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return triggers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTriggersByTopic returns all enabled triggers for a specific topic.
|
||||||
|
func (m *DBTriggerManager) GetTriggersByTopic(ctx context.Context, topic string) ([]PubSubTrigger, error) {
|
||||||
|
topic = strings.TrimSpace(topic)
|
||||||
|
if topic == "" {
|
||||||
|
return nil, &ValidationError{Field: "topic", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE topic = ? AND enabled = TRUE`
|
||||||
|
var rows []pubsubTriggerRow
|
||||||
|
if err := m.db.Query(ctx, &rows, query, topic); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get triggers by topic: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
triggers := make([]PubSubTrigger, len(rows))
|
||||||
|
for i, row := range rows {
|
||||||
|
triggers[i] = PubSubTrigger{
|
||||||
|
ID: row.ID,
|
||||||
|
FunctionID: row.FunctionID,
|
||||||
|
Topic: row.Topic,
|
||||||
|
Enabled: row.Enabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return triggers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPubSubTrigger returns a specific pubsub trigger by ID.
|
||||||
|
func (m *DBTriggerManager) GetPubSubTrigger(ctx context.Context, triggerID string) (*PubSubTrigger, error) {
|
||||||
|
triggerID = strings.TrimSpace(triggerID)
|
||||||
|
if triggerID == "" {
|
||||||
|
return nil, &ValidationError{Field: "trigger_id", Message: "cannot be empty"}
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `SELECT id, function_id, topic, enabled FROM function_pubsub_triggers WHERE id = ?`
|
||||||
|
var rows []pubsubTriggerRow
|
||||||
|
if err := m.db.Query(ctx, &rows, query, triggerID); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get trigger: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rows) == 0 {
|
||||||
|
return nil, ErrTriggerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
row := rows[0]
|
||||||
|
return &PubSubTrigger{
|
||||||
|
ID: row.ID,
|
||||||
|
FunctionID: row.FunctionID,
|
||||||
|
Topic: row.Topic,
|
||||||
|
Enabled: row.Enabled,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@ -131,6 +131,12 @@ type TriggerManager interface {
|
|||||||
|
|
||||||
// RemoveTrigger removes a trigger by ID.
|
// RemoveTrigger removes a trigger by ID.
|
||||||
RemoveTrigger(ctx context.Context, triggerID string) error
|
RemoveTrigger(ctx context.Context, triggerID string) error
|
||||||
|
|
||||||
|
// ListPubSubTriggers returns all pubsub triggers for a function.
|
||||||
|
ListPubSubTriggers(ctx context.Context, functionID string) ([]PubSubTrigger, error)
|
||||||
|
|
||||||
|
// GetTriggersByTopic returns all enabled triggers for a specific topic.
|
||||||
|
GetTriggersByTopic(ctx context.Context, topic string) ([]PubSubTrigger, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// JobManager manages background job execution.
|
// JobManager manages background job execution.
|
||||||
|
|||||||
343
pkg/turn/server.go
Normal file
343
pkg/turn/server.go
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
// Package turn provides a built-in TURN/STUN server using Pion.
|
||||||
|
package turn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha1"
|
||||||
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/turn/v4"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config contains TURN server configuration
|
||||||
|
type Config struct {
|
||||||
|
// Enabled enables the built-in TURN server
|
||||||
|
Enabled bool `yaml:"enabled"`
|
||||||
|
|
||||||
|
// ListenAddr is the UDP address to listen on (e.g., "0.0.0.0:3478")
|
||||||
|
ListenAddr string `yaml:"listen_addr"`
|
||||||
|
|
||||||
|
// PublicIP is the public IP address to advertise for relay
|
||||||
|
// If empty, will try to auto-detect
|
||||||
|
PublicIP string `yaml:"public_ip"`
|
||||||
|
|
||||||
|
// Realm is the TURN realm (e.g., "orama.network")
|
||||||
|
Realm string `yaml:"realm"`
|
||||||
|
|
||||||
|
// SharedSecret is the secret for HMAC-SHA1 credential generation
|
||||||
|
// Should match the gateway's TURN_SHARED_SECRET
|
||||||
|
SharedSecret string `yaml:"shared_secret"`
|
||||||
|
|
||||||
|
// CredentialTTL is the lifetime of generated credentials
|
||||||
|
CredentialTTL time.Duration `yaml:"credential_ttl"`
|
||||||
|
|
||||||
|
// MinPort and MaxPort define the relay port range
|
||||||
|
MinPort uint16 `yaml:"min_port"`
|
||||||
|
MaxPort uint16 `yaml:"max_port"`
|
||||||
|
|
||||||
|
// TLS Configuration for TURNS (TURN over TLS)
|
||||||
|
// TLSEnabled enables TURNS listener on TLSListenAddr
|
||||||
|
TLSEnabled bool `yaml:"tls_enabled"`
|
||||||
|
|
||||||
|
// TLSListenAddr is the TCP/TLS address to listen on (e.g., "0.0.0.0:443")
|
||||||
|
TLSListenAddr string `yaml:"tls_listen_addr"`
|
||||||
|
|
||||||
|
// TLSCertFile is the path to the TLS certificate file
|
||||||
|
TLSCertFile string `yaml:"tls_cert_file"`
|
||||||
|
|
||||||
|
// TLSKeyFile is the path to the TLS private key file
|
||||||
|
TLSKeyFile string `yaml:"tls_key_file"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig returns a default TURN server configuration
|
||||||
|
func DefaultConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
Enabled: false,
|
||||||
|
ListenAddr: "0.0.0.0:3478",
|
||||||
|
Realm: "orama.network",
|
||||||
|
CredentialTTL: 24 * time.Hour,
|
||||||
|
MinPort: 49152,
|
||||||
|
MaxPort: 65535,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server is a built-in TURN/STUN server
|
||||||
|
type Server struct {
|
||||||
|
config *Config
|
||||||
|
logger *zap.Logger
|
||||||
|
turnServer *turn.Server
|
||||||
|
conn net.PacketConn // UDP listener
|
||||||
|
tlsListener net.Listener // TLS listener for TURNS
|
||||||
|
mu sync.RWMutex
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer creates a new TURN server
|
||||||
|
func NewServer(config *Config, logger *zap.Logger) (*Server, error) {
|
||||||
|
if config == nil {
|
||||||
|
config = DefaultConfig()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Server{
|
||||||
|
config: config,
|
||||||
|
logger: logger,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the TURN server
|
||||||
|
func (s *Server) Start() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.running {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.config.Enabled {
|
||||||
|
s.logger.Info("TURN server disabled")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.config.SharedSecret == "" {
|
||||||
|
return fmt.Errorf("TURN shared secret is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create UDP listener
|
||||||
|
conn, err := net.ListenPacket("udp4", s.config.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to listen on %s: %w", s.config.ListenAddr, err)
|
||||||
|
}
|
||||||
|
s.conn = conn
|
||||||
|
|
||||||
|
// Determine public IP
|
||||||
|
publicIP := s.config.PublicIP
|
||||||
|
if publicIP == "" {
|
||||||
|
// Try to auto-detect
|
||||||
|
publicIP, err = getPublicIP()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("Failed to auto-detect public IP, using listener address", zap.Error(err))
|
||||||
|
host, _, _ := net.SplitHostPort(s.config.ListenAddr)
|
||||||
|
if host == "0.0.0.0" || host == "" {
|
||||||
|
host = "127.0.0.1"
|
||||||
|
}
|
||||||
|
publicIP = host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
relayIP := net.ParseIP(publicIP)
|
||||||
|
if relayIP == nil {
|
||||||
|
return fmt.Errorf("invalid public IP: %s", publicIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("Starting TURN server",
|
||||||
|
zap.String("listen_addr", s.config.ListenAddr),
|
||||||
|
zap.String("public_ip", publicIP),
|
||||||
|
zap.String("realm", s.config.Realm),
|
||||||
|
zap.Uint16("min_port", s.config.MinPort),
|
||||||
|
zap.Uint16("max_port", s.config.MaxPort),
|
||||||
|
zap.Bool("tls_enabled", s.config.TLSEnabled),
|
||||||
|
)
|
||||||
|
|
||||||
|
// Prepare listener configs for TLS (TURNS)
|
||||||
|
var listenerConfigs []turn.ListenerConfig
|
||||||
|
|
||||||
|
if s.config.TLSEnabled && s.config.TLSCertFile != "" && s.config.TLSKeyFile != "" {
|
||||||
|
// Load TLS certificate
|
||||||
|
cert, err := tls.LoadX509KeyPair(s.config.TLSCertFile, s.config.TLSKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return fmt.Errorf("failed to load TLS certificate: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine TLS listen address
|
||||||
|
tlsListenAddr := s.config.TLSListenAddr
|
||||||
|
if tlsListenAddr == "" {
|
||||||
|
tlsListenAddr = "0.0.0.0:443"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TLS listener
|
||||||
|
tlsListener, err := tls.Listen("tcp", tlsListenAddr, tlsConfig)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return fmt.Errorf("failed to start TLS listener on %s: %w", tlsListenAddr, err)
|
||||||
|
}
|
||||||
|
s.tlsListener = tlsListener
|
||||||
|
|
||||||
|
listenerConfigs = append(listenerConfigs, turn.ListenerConfig{
|
||||||
|
Listener: tlsListener,
|
||||||
|
RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{
|
||||||
|
RelayAddress: relayIP,
|
||||||
|
Address: "0.0.0.0",
|
||||||
|
MinPort: s.config.MinPort,
|
||||||
|
MaxPort: s.config.MaxPort,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
s.logger.Info("TURNS (TLS) listener started",
|
||||||
|
zap.String("tls_addr", tlsListenAddr),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create TURN server with HMAC-SHA1 auth
|
||||||
|
turnServer, err := turn.NewServer(turn.ServerConfig{
|
||||||
|
Realm: s.config.Realm,
|
||||||
|
AuthHandler: func(username, realm string, srcAddr net.Addr) ([]byte, bool) {
|
||||||
|
return s.authHandler(username, realm, srcAddr)
|
||||||
|
},
|
||||||
|
PacketConnConfigs: []turn.PacketConnConfig{
|
||||||
|
{
|
||||||
|
PacketConn: conn,
|
||||||
|
RelayAddressGenerator: &turn.RelayAddressGeneratorPortRange{
|
||||||
|
RelayAddress: relayIP,
|
||||||
|
Address: "0.0.0.0",
|
||||||
|
MinPort: s.config.MinPort,
|
||||||
|
MaxPort: s.config.MaxPort,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ListenerConfigs: listenerConfigs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
if s.tlsListener != nil {
|
||||||
|
s.tlsListener.Close()
|
||||||
|
}
|
||||||
|
return fmt.Errorf("failed to create TURN server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.turnServer = turnServer
|
||||||
|
s.running = true
|
||||||
|
|
||||||
|
s.logger.Info("TURN server started successfully",
|
||||||
|
zap.String("addr", s.config.ListenAddr),
|
||||||
|
zap.String("realm", s.config.Realm),
|
||||||
|
zap.Bool("turns_enabled", s.config.TLSEnabled),
|
||||||
|
)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// authHandler validates HMAC-SHA1 credentials (coturn-compatible format)
|
||||||
|
// Username format: timestamp:userID (e.g., "1234567890:user123")
|
||||||
|
func (s *Server) authHandler(username, realm string, srcAddr net.Addr) ([]byte, bool) {
|
||||||
|
// Parse timestamp from username
|
||||||
|
// Format: timestamp:userID
|
||||||
|
var timestamp int64
|
||||||
|
for i, c := range username {
|
||||||
|
if c == ':' {
|
||||||
|
ts, err := strconv.ParseInt(username[:i], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Debug("Invalid timestamp in username", zap.String("username", username))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
timestamp = ts
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if credential has expired
|
||||||
|
now := time.Now().Unix()
|
||||||
|
if timestamp > 0 && timestamp < now {
|
||||||
|
s.logger.Debug("Credential expired",
|
||||||
|
zap.String("username", username),
|
||||||
|
zap.Int64("expired_at", timestamp),
|
||||||
|
zap.Int64("now", now),
|
||||||
|
)
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate expected password using HMAC-SHA1
|
||||||
|
// This matches the gateway's generateTURNCredentials function
|
||||||
|
h := hmac.New(sha1.New, []byte(s.config.SharedSecret))
|
||||||
|
h.Write([]byte(username))
|
||||||
|
password := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
|
||||||
|
s.logger.Debug("TURN auth request",
|
||||||
|
zap.String("username", username),
|
||||||
|
zap.String("realm", realm),
|
||||||
|
zap.String("src_addr", srcAddr.String()),
|
||||||
|
)
|
||||||
|
|
||||||
|
return turn.GenerateAuthKey(username, realm, password), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the TURN server
|
||||||
|
func (s *Server) Stop() error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if !s.running {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("Stopping TURN server")
|
||||||
|
|
||||||
|
if s.turnServer != nil {
|
||||||
|
if err := s.turnServer.Close(); err != nil {
|
||||||
|
s.logger.Warn("Error closing TURN server", zap.Error(err))
|
||||||
|
}
|
||||||
|
s.turnServer = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.conn != nil {
|
||||||
|
s.conn.Close()
|
||||||
|
s.conn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close TLS listener
|
||||||
|
if s.tlsListener != nil {
|
||||||
|
s.tlsListener.Close()
|
||||||
|
s.tlsListener = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.running = false
|
||||||
|
s.logger.Info("TURN server stopped")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning returns whether the server is running
|
||||||
|
func (s *Server) IsRunning() bool {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return s.running
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetListenAddr returns the listen address
|
||||||
|
func (s *Server) GetListenAddr() string {
|
||||||
|
return s.config.ListenAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPublicAddr returns the public address for clients
|
||||||
|
func (s *Server) GetPublicAddr() string {
|
||||||
|
if s.config.PublicIP != "" {
|
||||||
|
_, port, _ := net.SplitHostPort(s.config.ListenAddr)
|
||||||
|
return net.JoinHostPort(s.config.PublicIP, port)
|
||||||
|
}
|
||||||
|
return s.config.ListenAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPublicIP tries to determine the public IP address
|
||||||
|
func getPublicIP() (string, error) {
|
||||||
|
// Try to get outbound IP by connecting to a public address
|
||||||
|
conn, err := net.Dial("udp4", "8.8.8.8:80")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
return localAddr.IP.String(), nil
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user