orama/pkg/gateway/sfu/room.go

package sfu
import (
"errors"
"sync"
"sync/atomic"
"time"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v4"
"go.uber.org/zap"
)
// Common errors
var (
ErrRoomFull = errors.New("room is full")
ErrRoomClosed = errors.New("room is closed")
ErrPeerNotFound = errors.New("peer not found")
ErrPeerNotInitialized = errors.New("peer not initialized")
ErrPeerClosed = errors.New("peer is closed")
ErrWebSocketClosed = errors.New("websocket connection closed")
)
// publishedTrack holds information about a track published to the room
type publishedTrack struct {
sourcePeerID string
sourceUserID string
localTrack *webrtc.TrackLocalStaticRTP
remoteTrackSSRC uint32 // SSRC of the remote track (for PLI requests)
remoteTrack *webrtc.TrackRemote // Reference to the remote track
kind string
forwarderActive bool // tracks if the RTP forwarder goroutine is still running
packetCount int // number of packets forwarded (for debugging)
// Packet loss tracking for adaptive keyframe requests
lastKeyframeRequest time.Time
nackCount atomic.Int64 // Number of NACK packets received (indicates packet loss)
}
// Room represents a WebRTC room with multiple participants
type Room struct {
ID string
Namespace string
// Participants in the room
peers map[string]*Peer
peersMu sync.RWMutex
// Published tracks in the room (for sending to new joiners)
publishedTracks map[string]*publishedTrack // key: trackID
publishedTracksMu sync.RWMutex
// WebRTC API for creating peer connections
api *webrtc.API
// Configuration
config *Config
logger *zap.Logger
// State
closed bool
closedMu sync.RWMutex
// Callbacks
onEmpty func(*Room)
}
// NewRoom creates a new room
func NewRoom(id, namespace string, api *webrtc.API, config *Config, logger *zap.Logger) *Room {
return &Room{
ID: id,
Namespace: namespace,
peers: make(map[string]*Peer),
publishedTracks: make(map[string]*publishedTrack),
api: api,
config: config,
logger: logger.With(zap.String("room_id", id)),
}
}
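// A minimal usage sketch for room construction and cleanup (illustrative only:
// the concrete Config values and the cleanup logic are assumptions, while
// Config.MaxParticipants and OnEmpty do come from this package):
//
//	api := webrtc.NewAPI()
//	logger, _ := zap.NewProduction()
//	room := NewRoom("room-1", "default", api, &Config{MaxParticipants: 8}, logger)
//	room.OnEmpty(func(r *Room) {
//		logger.Info("room is empty", zap.String("room_id", r.ID))
//	})
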
// AddPeer adds a new peer to the room
func (r *Room) AddPeer(peer *Peer) error {
r.closedMu.RLock()
if r.closed {
r.closedMu.RUnlock()
return ErrRoomClosed
}
r.closedMu.RUnlock()
r.peersMu.Lock()
// Check max participants
if r.config.MaxParticipants > 0 && len(r.peers) >= r.config.MaxParticipants {
r.peersMu.Unlock()
return ErrRoomFull
}
// Initialize peer connection
pcConfig := webrtc.Configuration{
ICEServers: r.config.ICEServers,
}
if err := peer.InitPeerConnection(r.api, pcConfig); err != nil {
r.peersMu.Unlock()
return err
}
// Set up peer close handler
peer.OnClose(func(p *Peer) {
r.RemovePeer(p.ID)
})
r.peers[peer.ID] = peer
peerInfo := peer.GetInfo() // Get info while holding lock
totalPeers := len(r.peers)
// Release lock BEFORE broadcasting to avoid deadlock
// (broadcastMessage also acquires the lock)
r.peersMu.Unlock()
r.logger.Info("Peer added to room",
zap.String("peer_id", peer.ID),
zap.String("user_id", peer.UserID),
zap.Int("total_peers", totalPeers),
)
// Notify other peers (now safe since we released the lock)
r.broadcastMessage(peer.ID, NewServerMessage(MessageTypeParticipantJoined, &ParticipantJoinedData{
Participant: peerInfo,
}))
return nil
}
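// Hedged sketch of a signaling caller mapping AddPeer's sentinel errors to a
// client-facing response (sendError and the surrounding handler are hypothetical,
// not defined in this package):
//
//	if err := room.AddPeer(peer); err != nil {
//		switch {
//		case errors.Is(err, ErrRoomFull):
//			sendError(peer, "room is full")
//		case errors.Is(err, ErrRoomClosed):
//			sendError(peer, "room is closed")
//		default:
//			sendError(peer, err.Error())
//		}
//		return
//	}
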
// RemovePeer removes a peer from the room
func (r *Room) RemovePeer(peerID string) error {
r.peersMu.Lock()
peer, ok := r.peers[peerID]
if !ok {
r.peersMu.Unlock()
return ErrPeerNotFound
}
delete(r.peers, peerID)
remainingPeers := len(r.peers)
r.peersMu.Unlock()
// Remove tracks published by this peer
r.publishedTracksMu.Lock()
removedTracks := make([]string, 0)
for trackID, track := range r.publishedTracks {
if track.sourcePeerID == peerID {
delete(r.publishedTracks, trackID)
removedTracks = append(removedTracks, trackID)
}
}
r.publishedTracksMu.Unlock()
if len(removedTracks) > 0 {
r.logger.Info("Removed tracks for departing peer",
zap.String("peer_id", peerID),
zap.Strings("track_ids", removedTracks),
)
}
// Close the peer
if err := peer.Close(); err != nil {
r.logger.Warn("Error closing peer",
zap.String("peer_id", peerID),
zap.Error(err),
)
}
r.logger.Info("Peer removed from room",
zap.String("peer_id", peerID),
zap.Int("remaining_peers", remainingPeers),
)
// Notify other peers
r.broadcastMessage(peerID, NewServerMessage(MessageTypeParticipantLeft, &ParticipantLeftData{
ParticipantID: peerID,
}))
// Check if room is empty
if remainingPeers == 0 && r.onEmpty != nil {
r.onEmpty(r)
}
return nil
}
// GetPeer returns a peer by ID
func (r *Room) GetPeer(peerID string) (*Peer, error) {
r.peersMu.RLock()
defer r.peersMu.RUnlock()
peer, ok := r.peers[peerID]
if !ok {
return nil, ErrPeerNotFound
}
return peer, nil
}
// GetPeers returns all peers in the room
func (r *Room) GetPeers() []*Peer {
r.peersMu.RLock()
defer r.peersMu.RUnlock()
peers := make([]*Peer, 0, len(r.peers))
for _, peer := range r.peers {
peers = append(peers, peer)
}
return peers
}
// GetParticipants returns info about all participants
func (r *Room) GetParticipants() []ParticipantInfo {
r.peersMu.RLock()
defer r.peersMu.RUnlock()
participants := make([]ParticipantInfo, 0, len(r.peers))
for _, peer := range r.peers {
participants = append(participants, peer.GetInfo())
}
return participants
}
// GetParticipantCount returns the number of participants
func (r *Room) GetParticipantCount() int {
r.peersMu.RLock()
defer r.peersMu.RUnlock()
return len(r.peers)
}
// SendExistingTracksTo sends all existing tracks from other participants to the specified peer.
// This should be called AFTER the welcome message is sent to ensure the client is ready.
// Uses batch mode to send all tracks with a single renegotiation for faster joins.
func (r *Room) SendExistingTracksTo(peer *Peer) {
r.publishedTracksMu.RLock()
existingTracks := make([]*publishedTrack, 0, len(r.publishedTracks))
for _, track := range r.publishedTracks {
// Don't send peer's own tracks back to them
if track.sourcePeerID != peer.ID {
existingTracks = append(existingTracks, track)
}
}
r.publishedTracksMu.RUnlock()
if len(existingTracks) == 0 {
r.logger.Info("No existing tracks to send to new peer",
zap.String("peer_id", peer.ID),
)
return
}
r.logger.Info("Sending existing tracks to new peer (batch mode)",
zap.String("peer_id", peer.ID),
zap.Int("track_count", len(existingTracks)),
)
videoTrackIDs := make([]string, 0)
// Start batch mode - suppresses individual renegotiations
peer.StartTrackBatch()
for _, track := range existingTracks {
// Snapshot mutable forwarder state under the lock to avoid racing with the
// RTP forwarder goroutine, which updates these fields while it runs
r.publishedTracksMu.RLock()
forwarderActive := track.forwarderActive
packetsForwarded := track.packetCount
r.publishedTracksMu.RUnlock()
// Log forwarder status to help diagnose video issues
r.logger.Info("Adding existing track to new peer",
zap.String("new_peer_id", peer.ID),
zap.String("source_peer_id", track.sourcePeerID),
zap.String("source_user_id", track.sourceUserID),
zap.String("track_id", track.localTrack.ID()),
zap.String("kind", track.kind),
zap.Bool("forwarder_active", forwarderActive),
zap.Int("packets_forwarded", packetsForwarded),
)
// Warn if the forwarder is no longer active
if !forwarderActive {
r.logger.Warn("Track forwarder is NOT active - track may not receive data",
zap.String("track_id", track.localTrack.ID()),
zap.String("kind", track.kind),
zap.Int("final_packet_count", packetsForwarded),
)
}
if _, err := peer.AddTrack(track.localTrack); err != nil {
r.logger.Warn("Failed to add existing track to new peer",
zap.String("peer_id", peer.ID),
zap.String("track_id", track.localTrack.ID()),
zap.Error(err),
)
continue
}
// Track video tracks for keyframe requests
if track.kind == "video" {
videoTrackIDs = append(videoTrackIDs, track.localTrack.ID())
}
// Notify new peer about the existing track
peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
ParticipantID: track.sourcePeerID,
UserID: track.sourceUserID,
TrackID: track.localTrack.ID(),
StreamID: track.localTrack.StreamID(),
Kind: track.kind,
}))
}
// End batch mode - triggers single renegotiation for all tracks
peer.EndTrackBatch()
r.logger.Info("Batch track addition complete - single renegotiation triggered",
zap.String("peer_id", peer.ID),
zap.Int("total_tracks", len(existingTracks)),
)
// Request keyframes for video tracks after a short delay
// This ensures the receiver has time to set up the track before receiving the keyframe
if len(videoTrackIDs) > 0 {
go func() {
// Wait for negotiation to complete
time.Sleep(300 * time.Millisecond)
r.logger.Info("Requesting keyframes for new peer",
zap.String("peer_id", peer.ID),
zap.Int("video_track_count", len(videoTrackIDs)),
)
for _, trackID := range videoTrackIDs {
r.RequestKeyframe(trackID)
}
// Request again after 500ms in case the first was too early
time.Sleep(500 * time.Millisecond)
for _, trackID := range videoTrackIDs {
r.RequestKeyframe(trackID)
}
}()
}
}
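// The join ordering described in SendExistingTracksTo's doc comment, sketched end
// to end (MessageTypeWelcome and welcomeData are assumptions about the caller;
// AddPeer and SendExistingTracksTo are the pieces defined in this file):
//
//	if err := room.AddPeer(peer); err != nil {
//		return err
//	}
//	peer.SendMessage(NewServerMessage(MessageTypeWelcome, welcomeData)) // client wires up handlers first
//	room.SendExistingTracksTo(peer)                                     // then one batched renegotiation
//	return nil
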
// BroadcastTrack broadcasts a track from one peer to all other peers
func (r *Room) BroadcastTrack(sourcePeerID string, track *webrtc.TrackRemote) {
r.peersMu.RLock()
defer r.peersMu.RUnlock()
// Get source peer's user ID for the track-added message
sourceUserID := ""
if sourcePeer, ok := r.peers[sourcePeerID]; ok {
sourceUserID = sourcePeer.UserID
}
// Create a local track from the remote track.
// The stream ID identifies the source so clients can map tracks to participants:
// the track ID becomes "{kind}-{sourcePeerID}", and the stream ID is the user ID when available, otherwise the peer ID
streamID := sourcePeerID
if sourceUserID != "" {
streamID = sourceUserID // Use userID for easier client-side mapping
}
r.logger.Info("Creating local track for broadcast",
zap.String("source_peer_id", sourcePeerID),
zap.String("source_user_id", sourceUserID),
zap.String("original_track_id", track.ID()),
zap.String("original_stream_id", track.StreamID()),
zap.String("new_stream_id", streamID),
)
// Log codec information for debugging
codec := track.Codec()
r.logger.Info("Track codec info",
zap.String("source_peer_id", sourcePeerID),
zap.String("track_kind", track.Kind().String()),
zap.String("mime_type", codec.MimeType),
zap.Uint32("clock_rate", codec.ClockRate),
zap.Uint16("channels", codec.Channels),
zap.String("sdp_fmtp_line", codec.SDPFmtpLine),
)
localTrack, err := webrtc.NewTrackLocalStaticRTP(
codec.RTPCodecCapability,
track.Kind().String()+"-"+sourcePeerID, // Include peer ID in track ID
streamID, // Use participant/user ID as stream ID
)
if err != nil {
r.logger.Error("Failed to create local track",
zap.String("source_peer", sourcePeerID),
zap.String("track_id", track.ID()),
zap.Error(err),
)
return
}
// Store the track for new joiners
pubTrack := &publishedTrack{
sourcePeerID: sourcePeerID,
sourceUserID: sourceUserID,
localTrack: localTrack,
remoteTrackSSRC: uint32(track.SSRC()),
remoteTrack: track,
kind: track.Kind().String(),
forwarderActive: true,
packetCount: 0,
lastKeyframeRequest: time.Now(),
}
r.publishedTracksMu.Lock()
r.publishedTracks[localTrack.ID()] = pubTrack
totalPublished := len(r.publishedTracks)
r.publishedTracksMu.Unlock()
r.logger.Info("Track stored for new joiners",
zap.String("track_id", localTrack.ID()),
zap.String("source_peer_id", sourcePeerID),
zap.Int("total_published_tracks", totalPublished),
)
// Forward RTP packets from remote track to local track
go func() {
trackID := track.ID()
localTrackID := localTrack.ID()
trackKind := track.Kind().String()
buf := make([]byte, 1600) // Slightly larger than MTU to handle RTP extensions
packetCount := 0
byteCount := 0
startTime := time.Now()
var firstPacketReceived atomic.Bool // written by the read loop, checked by the watchdog goroutine below
r.logger.Info("RTP forwarder started",
zap.String("track_id", trackID),
zap.String("local_track_id", localTrackID),
zap.String("kind", trackKind),
zap.String("source_peer_id", sourcePeerID),
)
// Watchdog: log a warning if no packets have arrived after 5 seconds
go func() {
time.Sleep(5 * time.Second)
if !firstPacketReceived.Load() {
r.logger.Warn("RTP forwarder received no packets after 5 seconds - source may not be sending",
zap.String("track_id", trackID),
zap.String("local_track_id", localTrackID),
zap.String("kind", trackKind),
zap.String("source_peer_id", sourcePeerID),
)
}
}()
// For video tracks, use adaptive keyframe requests based on packet loss
if trackKind == "video" {
go func() {
// Use a faster ticker for checking, but only send keyframes when needed
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
var lastNackCount int64
var consecutiveLossDetections int
baseInterval := 3 * time.Second
minInterval := 500 * time.Millisecond // Minimum interval between keyframes
lastKeyframeTime := time.Now()
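// With these values the source receives at most one PLI every 500ms while NACKs
// keep arriving, plus a steady-state refresh roughly every 3s when the stream is
// healthy; the figures simply restate minInterval and baseInterval above.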
for range ticker.C {
// Check if room is closed
if r.IsClosed() {
return
}
// Check if forwarder is still active
r.publishedTracksMu.RLock()
pt, ok := r.publishedTracks[localTrackID]
if !ok || !pt.forwarderActive {
r.publishedTracksMu.RUnlock()
return
}
// Get current NACK count to detect packet loss
currentNackCount := pt.nackCount.Load()
r.publishedTracksMu.RUnlock()
timeSinceLastKeyframe := time.Since(lastKeyframeTime)
// Detect if packet loss is happening (NACKs increasing)
if currentNackCount > lastNackCount {
consecutiveLossDetections++
lastNackCount = currentNackCount
// If we detect packet loss and haven't requested a keyframe recently,
// request one immediately to help receivers recover
if timeSinceLastKeyframe >= minInterval {
r.logger.Debug("Adaptive keyframe request due to packet loss",
zap.String("track_id", localTrackID),
zap.Int64("nack_count", currentNackCount),
zap.Int("consecutive_loss_detections", consecutiveLossDetections),
)
r.RequestKeyframe(localTrackID)
lastKeyframeTime = time.Now()
}
} else {
// Reset consecutive loss counter when no new NACKs
if consecutiveLossDetections > 0 {
consecutiveLossDetections--
}
}
// Regular keyframe request at base interval (regardless of loss)
if timeSinceLastKeyframe >= baseInterval {
r.RequestKeyframe(localTrackID)
lastKeyframeTime = time.Now()
}
}
}()
}
// Mark forwarder as stopped when we exit
defer func() {
r.publishedTracksMu.Lock()
if pt, ok := r.publishedTracks[localTrackID]; ok {
pt.forwarderActive = false
pt.packetCount = packetCount
}
r.publishedTracksMu.Unlock()
r.logger.Info("RTP forwarder exiting",
zap.String("track_id", trackID),
zap.String("local_track_id", localTrackID),
zap.String("kind", trackKind),
zap.Duration("lifetime", time.Since(startTime)),
zap.Int("total_packets", packetCount),
zap.Int("total_bytes", byteCount),
)
}()
for {
n, _, readErr := track.Read(buf)
if readErr != nil {
r.logger.Info("RTP forwarder stopped - read error",
zap.String("track_id", trackID),
zap.String("local_track_id", localTrackID),
zap.String("kind", trackKind),
zap.Int("packets_forwarded", packetCount),
zap.Int("bytes_forwarded", byteCount),
zap.Error(readErr),
)
return
}
// Log first packet to confirm data is flowing
if packetCount == 0 {
firstPacketReceived.Store(true)
r.logger.Info("RTP forwarder received FIRST packet",
zap.String("track_id", trackID),
zap.String("local_track_id", localTrackID),
zap.String("kind", trackKind),
zap.Int("packet_size", n),
zap.Duration("time_to_first_packet", time.Since(startTime)),
)
}
if _, writeErr := localTrack.Write(buf[:n]); writeErr != nil {
r.logger.Info("RTP forwarder stopped - write error",
zap.String("track_id", trackID),
zap.String("local_track_id", localTrackID),
zap.String("kind", trackKind),
zap.Int("packets_forwarded", packetCount),
zap.Int("bytes_forwarded", byteCount),
zap.Error(writeErr),
)
return
}
packetCount++
byteCount += n
// Update packet count periodically (not every packet to reduce lock contention)
if packetCount%50 == 0 {
r.publishedTracksMu.Lock()
if pt, ok := r.publishedTracks[localTrackID]; ok {
pt.packetCount = packetCount
}
r.publishedTracksMu.Unlock()
}
// Log progress every 100 packets for video, 500 for audio
logInterval := 500
if trackKind == "video" {
logInterval = 100
}
if packetCount%logInterval == 0 {
r.logger.Info("RTP forwarder progress",
zap.String("track_id", trackID),
zap.String("kind", trackKind),
zap.Int("packets", packetCount),
zap.Int("bytes", byteCount),
)
}
}
}()
// Add track to all other peers
for peerID, peer := range r.peers {
if peerID == sourcePeerID {
continue
}
r.logger.Info("Adding track to peer",
zap.String("target_peer", peerID),
zap.String("source_peer", sourcePeerID),
zap.String("source_user", sourceUserID),
zap.String("track_id", localTrack.ID()),
zap.String("stream_id", localTrack.StreamID()),
)
if _, err := peer.AddTrack(localTrack); err != nil {
r.logger.Warn("Failed to add track to peer",
zap.String("target_peer", peerID),
zap.String("track_id", localTrack.ID()),
zap.Error(err),
)
continue
}
// Notify peer about new track
// Use consistent IDs: participantId=sourcePeerID, streamId matches the track
peer.SendMessage(NewServerMessage(MessageTypeTrackAdded, &TrackAddedData{
ParticipantID: sourcePeerID,
UserID: sourceUserID,
TrackID: localTrack.ID(), // Format: "{kind}-{participantId}"
StreamID: localTrack.StreamID(), // User ID when available, otherwise the peer ID
Kind: track.Kind().String(),
}))
}
r.logger.Info("Track broadcast to room",
zap.String("source_peer", sourcePeerID),
zap.String("source_user", sourceUserID),
zap.String("track_id", track.ID()),
zap.String("kind", track.Kind().String()),
)
}
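// A sketch of how BroadcastTrack is typically driven from pion's OnTrack callback;
// the wiring below is an assumption about the Peer setup code, which lives elsewhere
// in this package:
//
//	peer.pc.OnTrack(func(remote *webrtc.TrackRemote, _ *webrtc.RTPReceiver) {
//		room.BroadcastTrack(peer.ID, remote)
//	})
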
// broadcastMessage sends a message to all peers except the specified one
func (r *Room) broadcastMessage(excludePeerID string, msg *ServerMessage) {
r.peersMu.RLock()
defer r.peersMu.RUnlock()
for peerID, peer := range r.peers {
if peerID == excludePeerID {
continue
}
if err := peer.SendMessage(msg); err != nil {
r.logger.Warn("Failed to send message to peer",
zap.String("peer_id", peerID),
zap.Error(err),
)
}
}
}
// Close closes the room and all peer connections
func (r *Room) Close() error {
r.closedMu.Lock()
if r.closed {
r.closedMu.Unlock()
return nil
}
r.closed = true
r.closedMu.Unlock()
r.logger.Info("Closing room")
r.peersMu.Lock()
peers := make([]*Peer, 0, len(r.peers))
for _, peer := range r.peers {
peers = append(peers, peer)
}
r.peers = make(map[string]*Peer)
r.peersMu.Unlock()
// Close all peers
for _, peer := range peers {
if err := peer.Close(); err != nil {
r.logger.Warn("Error closing peer",
zap.String("peer_id", peer.ID),
zap.Error(err),
)
}
}
return nil
}
// OnEmpty sets a callback for when the room becomes empty.
// The callback is read without synchronization, so set it before peers are added.
func (r *Room) OnEmpty(fn func(*Room)) {
r.onEmpty = fn
}
// IsClosed returns whether the room is closed
func (r *Room) IsClosed() bool {
r.closedMu.RLock()
defer r.closedMu.RUnlock()
return r.closed
}
// RequestKeyframe sends a PLI (Picture Loss Indication) to the source peer for a video track.
// This causes the source to send a keyframe, which is needed for new receivers to start decoding.
func (r *Room) RequestKeyframe(trackID string) {
r.publishedTracksMu.RLock()
track, ok := r.publishedTracks[trackID]
r.publishedTracksMu.RUnlock()
if !ok || track.kind != "video" {
return
}
r.peersMu.RLock()
sourcePeer, ok := r.peers[track.sourcePeerID]
r.peersMu.RUnlock()
if !ok || sourcePeer.pc == nil {
r.logger.Debug("Cannot request keyframe - source peer not found",
zap.String("track_id", trackID),
zap.String("source_peer_id", track.sourcePeerID),
)
return
}
// Create a PLI packet
pli := &rtcp.PictureLossIndication{
MediaSSRC: track.remoteTrackSSRC,
}
// Send the PLI to the source peer
if err := sourcePeer.pc.WriteRTCP([]rtcp.Packet{pli}); err != nil {
r.logger.Debug("Failed to send PLI",
zap.String("track_id", trackID),
zap.String("source_peer_id", track.sourcePeerID),
zap.Error(err),
)
return
}
r.logger.Debug("PLI keyframe request sent",
zap.String("track_id", trackID),
zap.String("source_peer_id", track.sourcePeerID),
zap.Uint32("ssrc", track.remoteTrackSSRC),
)
}
// RequestKeyframeForAllVideoTracks sends PLI requests for all video tracks in the room.
// This is useful when a new peer joins to ensure they get keyframes quickly.
func (r *Room) RequestKeyframeForAllVideoTracks() {
r.publishedTracksMu.RLock()
videoTrackIDs := make([]string, 0)
for trackID, track := range r.publishedTracks {
if track.kind == "video" {
videoTrackIDs = append(videoTrackIDs, trackID)
}
}
r.publishedTracksMu.RUnlock()
for _, trackID := range videoTrackIDs {
r.RequestKeyframe(trackID)
}
}
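// Hedged usage sketch: besides the delayed requests in SendExistingTracksTo, a caller
// might refresh all video tracks once a subscriber's connection settles (the
// OnConnectionStateChange wiring below is an assumption, not part of this file):
//
//	peer.pc.OnConnectionStateChange(func(s webrtc.PeerConnectionState) {
//		if s == webrtc.PeerConnectionStateConnected {
//			room.RequestKeyframeForAllVideoTracks()
//		}
//	})
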
// IncrementNackCount increments the NACK counter for a track.
// This is called when we receive NACK feedback indicating packet loss.
func (r *Room) IncrementNackCount(trackID string) {
r.publishedTracksMu.RLock()
track, ok := r.publishedTracks[trackID]
r.publishedTracksMu.RUnlock()
if ok {
track.nackCount.Add(1)
}
}
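// Sketch of the feedback loop that would feed IncrementNackCount: reading RTCP from
// the pion RTPSender attached to a forwarded track and counting TransportLayerNack
// packets (the sender and trackID plumbing shown here is an assumption about the
// caller, not something this file provides):
//
//	go func() {
//		for {
//			pkts, _, err := sender.ReadRTCP()
//			if err != nil {
//				return
//			}
//			for _, pkt := range pkts {
//				if _, ok := pkt.(*rtcp.TransportLayerNack); ok {
//					room.IncrementNackCount(trackID)
//				}
//			}
//		}
//	}()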