orama/pkg/sfu/peer.go
anonpenguin23 8ee606bfb1 feat: implement SFU and TURN server functionality
- Add signaling package with message types and structures for SFU communication.
- Implement client and server message serialization/deserialization tests.
- Enhance systemd manager to handle SFU and TURN services, including start/stop logic.
- Create TURN server configuration and main server logic with HMAC-SHA1 authentication.
- Add tests for TURN server credential generation and validation.
- Define systemd service files for SFU and TURN services.
2026-02-21 11:17:13 +02:00

341 lines
8.0 KiB
Go

package sfu
import (
"encoding/json"
"errors"
"sync"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v4"
"go.uber.org/zap"
)
// Sentinel errors returned by Peer methods; compare them with errors.Is.
var (
	// ErrPeerNotInitialized is returned by the signaling and track methods
	// when InitPeerConnection has not yet created the PeerConnection.
	ErrPeerNotInitialized = errors.New("peer connection not initialized")

	// ErrPeerClosed is returned by SendMessage once the peer is marked closed.
	ErrPeerClosed = errors.New("peer is closed")

	// ErrWebSocketClosed is returned by SendMessage when the signaling
	// WebSocket has already been released.
	ErrWebSocketClosed = errors.New("websocket connection closed")
)
// Peer represents a participant in a room with a WebRTC PeerConnection.
//
// A Peer pairs the participant's signaling WebSocket with its server-side
// PeerConnection. Its state is read and written both by the signaling
// handlers and by pion event callbacks, which is why the negotiation, closed
// and connection state each have their own mutex.
type Peer struct {
	// ID is a per-connection UUID assigned by NewPeer.
	ID string
	// UserID is the user identifier passed to NewPeer.
	UserID string

	// pc is nil until InitPeerConnection succeeds.
	pc *webrtc.PeerConnection
	// conn is the signaling WebSocket; guarded by connMu and set to nil by Close.
	conn *websocket.Conn
	// room is the Room this peer belongs to; incoming tracks are broadcast
	// through it and keyframe requests are routed to it.
	room *Room

	// Negotiation state machine, guarded by negotiationMu.
	// negotiationPending records that an offer was needed but deferred
	// (during a track batch or while signaling was not stable).
	negotiationPending bool
	// batchingTracks is true between StartTrackBatch and EndTrackBatch.
	batchingTracks bool
	negotiationMu  sync.Mutex

	// Connection state.
	// closed is set by handleDisconnect or Close; guarded by closedMu.
	closed   bool
	closedMu sync.RWMutex
	// connMu serializes writes to conn and its release in Close.
	connMu sync.Mutex

	logger *zap.Logger
	// onClose is invoked by handleDisconnect; registered via OnClose.
	onClose func(*Peer)
}
// NewPeer creates a Peer for userID that signals over conn and belongs to
// room. A fresh UUID is assigned as the peer ID. The WebRTC PeerConnection is
// not created here; call InitPeerConnection before using the signaling or
// track methods.
//
// A nil logger is replaced with a no-op logger so the peer never panics while
// logging.
func NewPeer(userID string, conn *websocket.Conn, room *Room, logger *zap.Logger) *Peer {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &Peer{
		ID:     uuid.New().String(),
		UserID: userID,
		conn:   conn,
		room:   room,
		// InitPeerConnection binds the peer_id field onto this logger. Binding
		// a placeholder here as well would make every log line carry peer_id
		// twice (once empty), since zap does not de-duplicate keys.
		logger: logger,
	}
}
// InitPeerConnection creates this peer's WebRTC PeerConnection from api and
// registers its event handlers. Until it succeeds, p.pc is nil and the
// signaling/track methods return ErrPeerNotInitialized.
//
// The connection uses ICETransportPolicyRelay, so only relay (TURN) candidates
// are used for connectivity. The handlers registered here are:
//   - ICE state: Disconnected starts handleReconnectTimeout in a new goroutine;
//     Failed or Closed calls handleDisconnect directly.
//   - ICE candidates: each locally gathered candidate is sent to the client as
//     a MessageTypeICECandidate message.
//   - Tracks: each track received from the client gets a readRTCP goroutine and
//     is passed to the room via BroadcastTrack.
//   - Negotiation: offers are deferred (negotiationPending) while a track batch
//     is open or signaling is not stable, and sent once signaling returns to
//     stable.
//
// It returns the error from api.NewPeerConnection, if any; handlers are only
// registered on success.
func (p *Peer) InitPeerConnection(api *webrtc.API, iceServers []webrtc.ICEServer) error {
	pc, err := api.NewPeerConnection(webrtc.Configuration{
		ICEServers:         iceServers,
		ICETransportPolicy: webrtc.ICETransportPolicyRelay, // Force TURN relay
	})
	if err != nil {
		return err
	}
	p.pc = pc
	// Tag every later log line from this peer with its ID.
	p.logger = p.logger.With(zap.String("peer_id", p.ID))

	// ICE connection state changes
	pc.OnICEConnectionStateChange(func(state webrtc.ICEConnectionState) {
		p.logger.Info("ICE state changed", zap.String("state", state.String()))
		switch state {
		case webrtc.ICEConnectionStateDisconnected:
			// Disconnected may recover; wait reconnectTimeout (defined
			// elsewhere; the original intent was 15 seconds) before removing.
			// NOTE(review): every Disconnected transition spawns another timer
			// goroutine, and none is cancelled if ICE recovers.
			go p.handleReconnectTimeout()
		case webrtc.ICEConnectionStateFailed, webrtc.ICEConnectionStateClosed:
			p.handleDisconnect()
		}
	})

	// ICE candidate generation: trickle each local candidate to the client.
	pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
		// pion signals the end of gathering with a nil candidate; nothing to send.
		if candidate == nil {
			return
		}
		c := candidate.ToJSON()
		data := &ICECandidateData{Candidate: c.Candidate}
		// The optional fields are pointers; copy them only when present.
		if c.SDPMid != nil {
			data.SDPMid = *c.SDPMid
		}
		if c.SDPMLineIndex != nil {
			data.SDPMLineIndex = *c.SDPMLineIndex
		}
		if c.UsernameFragment != nil {
			data.UsernameFragment = *c.UsernameFragment
		}
		// The send error is ignored: an undeliverable candidate is dropped.
		p.SendMessage(NewServerMessage(MessageTypeICECandidate, data))
	})

	// Incoming tracks from the client
	pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
		p.logger.Info("Track received",
			zap.String("track_id", track.ID()),
			zap.String("kind", track.Kind().String()),
			zap.String("codec", track.Codec().MimeType))
		// Drain this receiver's RTCP in the background; readRTCP requests a
		// keyframe from the room on PLI/FIR.
		go p.readRTCP(receiver, track)
		// Forward track to all other peers
		p.room.BroadcastTrack(p.ID, track)
	})

	// Negotiation needed — only offer when signaling is stable
	pc.OnNegotiationNeeded(func() {
		p.negotiationMu.Lock()
		// During a track batch just remember that an offer is owed;
		// EndTrackBatch sends it.
		if p.batchingTracks {
			p.negotiationPending = true
			p.negotiationMu.Unlock()
			return
		}
		p.negotiationMu.Unlock()
		// NOTE(review): the lock is released before the signaling-state check,
		// so the batch flag and the signaling state are not observed atomically.
		if pc.SignalingState() == webrtc.SignalingStateStable {
			p.createAndSendOffer()
		} else {
			// An offer/answer exchange is in flight; defer until stable.
			p.negotiationMu.Lock()
			p.negotiationPending = true
			p.negotiationMu.Unlock()
		}
	})

	// When state returns to stable, fire pending negotiation
	pc.OnSignalingStateChange(func(state webrtc.SignalingState) {
		if state == webrtc.SignalingStateStable {
			// Consume the flag under the lock, then offer outside it:
			// createAndSendOffer may take negotiationMu itself.
			p.negotiationMu.Lock()
			pending := p.negotiationPending
			p.negotiationPending = false
			p.negotiationMu.Unlock()
			if pending {
				p.createAndSendOffer()
			}
		}
	})
	return nil
}
// createAndSendOffer starts a server-initiated renegotiation. It creates an SDP
// offer, applies it as the local description and sends it to the client.
//
// If signaling is not stable (an offer/answer exchange is already in flight),
// the request is recorded in negotiationPending instead. The
// OnSignalingStateChange handler registered in InitPeerConnection replays it
// once signaling returns to stable. Failures are logged, not returned.
func (p *Peer) createAndSendOffer() {
	if p.pc == nil {
		return
	}
	if p.pc.SignalingState() != webrtc.SignalingStateStable {
		p.negotiationMu.Lock()
		p.negotiationPending = true
		p.negotiationMu.Unlock()
		return
	}
	offer, err := p.pc.CreateOffer(nil)
	if err != nil {
		p.logger.Error("Failed to create offer", zap.Error(err))
		return
	}
	if err := p.pc.SetLocalDescription(offer); err != nil {
		p.logger.Error("Failed to set local description", zap.Error(err))
		return
	}
	// The offer is already applied locally, so a lost send leaves the peer
	// waiting for an answer that will never arrive. Surface it, but stay quiet
	// for peers that are already closed (an expected shutdown race).
	if err := p.SendMessage(NewServerMessage(MessageTypeOffer, &OfferData{SDP: offer.SDP})); err != nil && !errors.Is(err, ErrPeerClosed) {
		p.logger.Warn("Failed to send offer", zap.Error(err))
	}
}
// HandleOffer processes an SDP offer from the client. It applies the offer as
// the remote description, creates and applies a local answer, and sends that
// answer back over the signaling WebSocket.
//
// It returns ErrPeerNotInitialized if InitPeerConnection has not run, any error
// from the WebRTC layer, or the error from delivering the answer (for example
// ErrPeerClosed or ErrWebSocketClosed).
func (p *Peer) HandleOffer(sdp string) error {
	if p.pc == nil {
		return ErrPeerNotInitialized
	}
	if err := p.pc.SetRemoteDescription(webrtc.SessionDescription{
		Type: webrtc.SDPTypeOffer, SDP: sdp,
	}); err != nil {
		return err
	}
	answer, err := p.pc.CreateAnswer(nil)
	if err != nil {
		return err
	}
	if err := p.pc.SetLocalDescription(answer); err != nil {
		return err
	}
	// The client cannot complete the exchange without the answer, so a failed
	// send is a failure of HandleOffer rather than something to swallow.
	return p.SendMessage(NewServerMessage(MessageTypeAnswer, &AnswerData{SDP: answer.SDP}))
}
// HandleAnswer applies the client's SDP answer, responding to an offer the
// server sent, as the remote description.
//
// It returns ErrPeerNotInitialized when InitPeerConnection has not been
// called, otherwise whatever SetRemoteDescription reports.
func (p *Peer) HandleAnswer(sdp string) error {
	pc := p.pc
	if pc == nil {
		return ErrPeerNotInitialized
	}
	remote := webrtc.SessionDescription{Type: webrtc.SDPTypeAnswer, SDP: sdp}
	return pc.SetRemoteDescription(remote)
}
// HandleICECandidate adds a remote ICE candidate sent by the client to the
// PeerConnection.
//
// It returns ErrPeerNotInitialized when InitPeerConnection has not run, an
// error when data is nil (rather than handing a nil pointer to
// ToWebRTCCandidate), or the error from AddICECandidate.
func (p *Peer) HandleICECandidate(data *ICECandidateData) error {
	if p.pc == nil {
		return ErrPeerNotInitialized
	}
	if data == nil {
		return errors.New("nil ICE candidate data")
	}
	return p.pc.AddICECandidate(data.ToWebRTCCandidate())
}
// AddTrack attaches a server-side local track to this peer's PeerConnection so
// it is sent to this peer, returning the RTPSender created for it.
// ErrPeerNotInitialized is returned if InitPeerConnection has not run yet.
func (p *Peer) AddTrack(track *webrtc.TrackLocalStaticRTP) (*webrtc.RTPSender, error) {
	if pc := p.pc; pc != nil {
		return pc.AddTrack(track)
	}
	return nil, ErrPeerNotInitialized
}
// StartTrackBatch opens a batch of track additions. Until EndTrackBatch is
// called, the OnNegotiationNeeded handler only records that an offer is owed
// instead of sending one. Several AddTrack calls therefore result in a single
// deferred offer.
func (p *Peer) StartTrackBatch() {
	p.negotiationMu.Lock()
	defer p.negotiationMu.Unlock()
	p.batchingTracks = true
}
// EndTrackBatch closes a batch opened by StartTrackBatch and performs any
// renegotiation that was deferred while the batch was open.
//
// A deferred request is handed to createAndSendOffer even when signaling is
// not currently stable. In that case createAndSendOffer marks it pending
// again, and the OnSignalingStateChange handler sends the offer once the
// in-flight exchange completes. Clearing the flag and offering only when
// stable would silently drop the renegotiation if the batch ended
// mid-exchange.
func (p *Peer) EndTrackBatch() {
	p.negotiationMu.Lock()
	p.batchingTracks = false
	pending := p.negotiationPending
	p.negotiationPending = false
	p.negotiationMu.Unlock()

	if pending && p.pc != nil {
		p.createAndSendOffer()
	}
}
// SendMessage encodes msg as JSON and writes it to the peer's signaling
// WebSocket as a single text frame. connMu is held for the write because
// gorilla/websocket supports only one concurrent writer per connection.
//
// It returns ErrPeerClosed once the peer is marked closed, ErrWebSocketClosed
// if the connection has already been released, or any marshal/write error.
func (p *Peer) SendMessage(msg *ServerMessage) error {
	p.closedMu.RLock()
	isClosed := p.closed
	p.closedMu.RUnlock()
	if isClosed {
		return ErrPeerClosed
	}

	p.connMu.Lock()
	defer p.connMu.Unlock()

	ws := p.conn
	if ws == nil {
		return ErrWebSocketClosed
	}
	payload, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	return ws.WriteMessage(websocket.TextMessage, payload)
}
// GetInfo returns the public identity of this peer, its peer ID and user ID,
// as a ParticipantInfo.
func (p *Peer) GetInfo() ParticipantInfo {
	var info ParticipantInfo
	info.PeerID = p.ID
	info.UserID = p.UserID
	return info
}
// handleReconnectTimeout runs in its own goroutine after ICE reports
// Disconnected. It blocks for reconnectTimeout (via timeAfter). If the
// connection is then still Disconnected or has moved to Failed, it tears the
// peer down through handleDisconnect; otherwise ICE recovered and it does
// nothing. The wait cannot be cancelled, so it always runs to completion.
func (p *Peer) handleReconnectTimeout() {
	<-timeAfter(reconnectTimeout)

	pc := p.pc
	if pc == nil {
		return
	}
	switch pc.ICEConnectionState() {
	case webrtc.ICEConnectionStateDisconnected, webrtc.ICEConnectionStateFailed:
		p.logger.Info("Peer did not reconnect within timeout, removing")
		p.handleDisconnect()
	}
}
// handleDisconnect marks the peer closed and, on the first call only, invokes
// the registered onClose callback. Later calls are no-ops. It does not close
// the PeerConnection or the WebSocket itself.
func (p *Peer) handleDisconnect() {
	p.closedMu.Lock()
	alreadyClosed := p.closed
	p.closed = true
	p.closedMu.Unlock()

	if alreadyClosed {
		return
	}
	if cb := p.onClose; cb != nil {
		cb(p)
	}
}
// Close marks the peer closed and releases its signaling WebSocket and its
// PeerConnection. It is safe to call more than once.
//
// Releasing resources does not depend on the closed flag. handleDisconnect
// sets closed without touching either connection and then runs the onClose
// callback, so a Close issued from or after that callback must still tear
// them down instead of returning early. On repeated calls conn is already nil,
// and the call relies on PeerConnection.Close returning nil for an
// already-closed connection.
//
// It returns the error from closing the PeerConnection, if any.
func (p *Peer) Close() error {
	p.closedMu.Lock()
	p.closed = true
	p.closedMu.Unlock()

	p.connMu.Lock()
	if p.conn != nil {
		_ = p.conn.Close() // best effort: the socket is being discarded
		p.conn = nil
	}
	p.connMu.Unlock()

	if p.pc != nil {
		return p.pc.Close()
	}
	return nil
}
// OnClose registers callback to be run once when handleDisconnect tears the
// peer down, replacing any earlier callback. The field is not mutex-guarded,
// so register it before the connection can fail.
func (p *Peer) OnClose(callback func(*Peer)) {
	p.onClose = callback
}
// readRTCP reads RTCP packets from receiver until ReadRTCP returns an error.
// Whenever a PLI or FIR is seen, it asks the room for a keyframe on this
// peer's forwarded track.
//
// The forwarded track is addressed as "<kind>-<peerID>". This presumably
// matches the ID Room.BroadcastTrack assigns; TODO confirm.
//
// NOTE(review): in an SFU, PLI/FIR from subscribers normally arrive on the
// RTPSenders of their own PeerConnections, not on this publisher's
// RTPReceiver. Verify that this path actually fires.
func (p *Peer) readRTCP(receiver *webrtc.RTPReceiver, track *webrtc.TrackRemote) {
	trackKey := track.Kind().String() + "-" + p.ID
	wantsKeyframe := func(pkt rtcp.Packet) bool {
		switch pkt.(type) {
		case *rtcp.PictureLossIndication, *rtcp.FullIntraRequest:
			return true
		default:
			return false
		}
	}
	for {
		batch, _, err := receiver.ReadRTCP()
		if err != nil {
			return
		}
		for _, pkt := range batch {
			if wantsKeyframe(pkt) {
				p.room.RequestKeyframe(trackKey)
			}
		}
	}
}