From 9d4d5766ef5377f01b917bba5acd3471755c74f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miroslav=20=C5=A0ediv=C3=BD?= Date: Mon, 27 Sep 2021 00:50:49 +0200 Subject: [PATCH] webrtc refactor peer track. --- internal/capture/stream.go | 156 ++++++++++++++++++-------------- internal/types/capture.go | 8 +- internal/webrtc/manager.go | 167 ++++++----------------------------- internal/webrtc/peertrack.go | 97 ++++++++++++++++++++ 4 files changed, 219 insertions(+), 209 deletions(-) create mode 100644 internal/webrtc/peertrack.go diff --git a/internal/capture/stream.go b/internal/capture/stream.go index f8af877c..3e810a4e 100644 --- a/internal/capture/stream.go +++ b/internal/capture/stream.go @@ -1,6 +1,7 @@ package capture import ( + "errors" "reflect" "sync" @@ -13,18 +14,22 @@ import ( ) type StreamManagerCtx struct { - logger zerolog.Logger - mu sync.Mutex - wg sync.WaitGroup - codec codec.RTPCodec - pipelineStr func() string + logger zerolog.Logger + mu sync.Mutex + wg sync.WaitGroup + codec codec.RTPCodec + pipeline *gst.Pipeline - sample chan types.Sample - listeners map[uintptr]*func(sample types.Sample) - emitMu sync.Mutex - emitUpdate chan bool - emitStop chan bool - started bool + pipelineMu sync.Mutex + pipelineStr func() string + + sample chan types.Sample + sampleStop chan interface{} + sampleUpdate chan interface{} + + listeners map[uintptr]*func(sample types.Sample) + listenersMu sync.Mutex + listenersCount uint32 } func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string) *StreamManagerCtx { @@ -34,13 +39,12 @@ func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string) Str("video_id", video_id).Logger() manager := &StreamManagerCtx{ - logger: logger, - codec: codec, - pipelineStr: pipelineStr, - listeners: map[uintptr]*func(sample types.Sample){}, - emitUpdate: make(chan bool), - emitStop: make(chan bool), - started: false, + logger: logger, + codec: codec, + pipelineStr: pipelineStr, + sampleStop: make(chan interface{}), + sampleUpdate: make(chan interface{}), + listeners: map[uintptr]*func(sample types.Sample){}, } manager.wg.Add(1) @@ -51,17 +55,17 @@ func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string) for { select { - case <-manager.emitStop: + case <-manager.sampleStop: manager.logger.Debug().Msg("stopped emitting samples") return - case <-manager.emitUpdate: + case <-manager.sampleUpdate: manager.logger.Debug().Msg("update emitting samples") case sample := <-manager.sample: - manager.emitMu.Lock() + manager.listenersMu.Lock() for _, emit := range manager.listeners { (*emit)(sample) } - manager.emitMu.Unlock() + manager.listenersMu.Unlock() } } }() @@ -72,15 +76,15 @@ func streamNew(codec codec.RTPCodec, pipelineStr func() string, video_id string) func (manager *StreamManagerCtx) shutdown() { manager.logger.Info().Msgf("shutdown") - manager.emitMu.Lock() + manager.listenersMu.Lock() for key := range manager.listeners { delete(manager.listeners, key) } - manager.emitMu.Unlock() + manager.listenersMu.Unlock() manager.destroyPipeline() - manager.emitStop <- true + close(manager.sampleStop) manager.wg.Wait() } @@ -88,63 +92,78 @@ func (manager *StreamManagerCtx) Codec() codec.RTPCodec { return manager.codec } -func (manager *StreamManagerCtx) AddListener(listener *func(sample types.Sample)) { - manager.emitMu.Lock() - defer manager.emitMu.Unlock() - - if listener != nil { - ptr := reflect.ValueOf(listener).Pointer() - manager.listeners[ptr] = listener - manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener") +func (manager *StreamManagerCtx) NewListener(listener *func(sample types.Sample)) (addListener func(), err error) { + if listener == nil { + return addListener, errors.New("listener cannot be nil") } + + manager.mu.Lock() + defer manager.mu.Unlock() + + if manager.listenersCount == 0 { + err := manager.createPipeline() + if err != nil && !errors.Is(err, types.ErrCapturePipelineAlreadyExists) { + return addListener, err + } + + manager.listenersCount++ + manager.logger.Info().Msgf("first listener, starting") + } + + return func() { + ptr := reflect.ValueOf(listener).Pointer() + + manager.listenersMu.Lock() + manager.listeners[ptr] = listener + manager.listenersMu.Unlock() + + manager.logger.Debug().Interface("ptr", ptr).Msgf("adding listener") + }, nil } func (manager *StreamManagerCtx) RemoveListener(listener *func(sample types.Sample)) { - manager.emitMu.Lock() - defer manager.emitMu.Unlock() - - if listener != nil { - ptr := reflect.ValueOf(listener).Pointer() - delete(manager.listeners, ptr) - manager.logger.Debug().Interface("ptr", ptr).Msgf("removing listener") + if listener == nil { + return } + + ptr := reflect.ValueOf(listener).Pointer() + + manager.listenersMu.Lock() + delete(manager.listeners, ptr) + manager.listenersMu.Unlock() + + manager.logger.Debug().Interface("ptr", ptr).Msgf("removing listener") + + go func() { + manager.mu.Lock() + defer manager.mu.Unlock() + + if manager.listenersCount == 1 { + manager.destroyPipeline() + manager.listenersCount = 0 + manager.logger.Info().Msgf("last listener, stopping") + } + }() } func (manager *StreamManagerCtx) ListenersCount() int { - manager.emitMu.Lock() - defer manager.emitMu.Unlock() + manager.listenersMu.Lock() + defer manager.listenersMu.Unlock() return len(manager.listeners) } -func (manager *StreamManagerCtx) Start() error { - manager.mu.Lock() - defer manager.mu.Unlock() - - err := manager.createPipeline() - if err != nil { - return err - } - - manager.logger.Info().Msgf("start") - manager.started = true - return nil -} - -func (manager *StreamManagerCtx) Stop() { - manager.mu.Lock() - defer manager.mu.Unlock() - - manager.logger.Info().Msgf("stop") - manager.started = false - manager.destroyPipeline() -} - func (manager *StreamManagerCtx) Started() bool { - return manager.started + manager.mu.Lock() + defer manager.mu.Unlock() + + return manager.listenersCount > 0 } func (manager *StreamManagerCtx) createPipeline() error { + manager.pipelineMu.Lock() + defer manager.pipelineMu.Unlock() + if manager.pipeline != nil { return types.ErrCapturePipelineAlreadyExists } @@ -166,11 +185,14 @@ func (manager *StreamManagerCtx) createPipeline() error { manager.pipeline.Start() manager.sample = manager.pipeline.Sample - manager.emitUpdate <- true + manager.sampleUpdate <- struct{}{} return nil } func (manager *StreamManagerCtx) destroyPipeline() { + manager.pipelineMu.Lock() + defer manager.pipelineMu.Unlock() + if manager.pipeline == nil { return } diff --git a/internal/types/capture.go b/internal/types/capture.go index 8901bdf6..84c35672 100644 --- a/internal/types/capture.go +++ b/internal/types/capture.go @@ -35,12 +35,12 @@ type ScreencastManager interface { type StreamManager interface { Codec() codec.RTPCodec - AddListener(listener *func(sample Sample)) + // starts pipeline if was not running before and returns register function + NewListener(listener *func(sample Sample)) (addListener func(), err error) + // stops pipeline if it was last listener RemoveListener(listener *func(sample Sample)) - ListenersCount() int - Start() error - Stop() + ListenersCount() int Started() bool } diff --git a/internal/webrtc/manager.go b/internal/webrtc/manager.go index 55355e12..04e9b7a7 100644 --- a/internal/webrtc/manager.go +++ b/internal/webrtc/manager.go @@ -1,15 +1,11 @@ package webrtc import ( - "errors" "fmt" - "io" "strings" - "sync" "time" "github.com/pion/webrtc/v3" - "github.com/pion/webrtc/v3/pkg/media" "github.com/rs/zerolog" "github.com/rs/zerolog/log" @@ -39,13 +35,10 @@ func New(desktop types.DesktopManager, capture types.CaptureManager, config *con capture: capture, curImage: cursor.NewImage(desktop), curPosition: cursor.NewPosition(desktop), - - participants: 0, } } type WebRTCManagerCtx struct { - mu sync.Mutex logger zerolog.Logger config *config.WebRTC @@ -53,33 +46,9 @@ type WebRTCManagerCtx struct { capture types.CaptureManager curImage *cursor.ImageCtx curPosition *cursor.PositionCtx - - audioTrack *webrtc.TrackLocalStaticSample - audioListener func(sample types.Sample) - participants uint32 } func (manager *WebRTCManagerCtx) Start() { - var err error - - // create audio track - audio := manager.capture.Audio() - manager.audioTrack, err = webrtc.NewTrackLocalStaticSample(audio.Codec().Capability, "audio", "stream") - if err != nil { - manager.logger.Panic().Err(err).Msg("unable to create audio track") - } - - manager.audioListener = func(sample types.Sample) { - if err := manager.audioTrack.WriteSample(media.Sample(sample)); err != nil { - if errors.Is(err, io.ErrClosedPipe) { - // The peerConnection has been closed. - return - } - manager.logger.Warn().Err(err).Msg("audio pipeline failed to write") - } - } - audio.AddListener(&manager.audioListener) - manager.curImage.Start() manager.logger.Info(). @@ -97,9 +66,6 @@ func (manager *WebRTCManagerCtx) Shutdown() error { manager.curImage.Shutdown() manager.curPosition.Shutdown() - audio := manager.capture.Audio() - audio.RemoveListener(&manager.audioListener) - return nil } @@ -112,6 +78,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin logger := manager.logger.With().Str("session_id", session.ID()).Logger() logger.Info().Msg("creating webrtc peer") + // all audios must have the same codec + audioStream := manager.capture.Audio() + // all videos must have the same codec videoStream, ok := manager.capture.Video(videoID) if !ok { @@ -119,8 +88,8 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin } connection, err := manager.newPeerConnection([]codec.RTPCodec{ + audioStream.Codec(), videoStream.Codec(), - manager.capture.Audio().Codec(), }, logger) if err != nil { return nil, err @@ -142,79 +111,32 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin }) } - // create video track - videoTrack, err := webrtc.NewTrackLocalStaticSample(videoStream.Codec().Capability, "video", "stream") + // audio track + + audioTrack, err := manager.newPeerTrack(audioStream, logger) if err != nil { return nil, err } - videoListener := func(sample types.Sample) { - if err := videoTrack.WriteSample(media.Sample(sample)); err != nil { - if errors.Is(err, io.ErrClosedPipe) { - // The peerConnection has been closed. - return - } - logger.Warn().Err(err).Msg("video pipeline failed to write") - } - } - - manager.mu.Lock() - - // should be stream started - if videoStream.ListenersCount() == 0 { - if err := videoStream.Start(); err != nil { - return nil, err - } - } - - videoStream.AddListener(&videoListener) - - // start audio, when first participant connects - if !manager.capture.Audio().Started() { - if err := manager.capture.Audio().Start(); err != nil { - manager.logger.Panic().Err(err).Msg("unable to start audio stream") - } - } - - manager.participants = manager.participants + 1 - manager.mu.Unlock() - - changeVideo := func(videoID string) error { - newVideoStream, ok := manager.capture.Video(videoID) - if !ok { - return types.ErrWebRTCVideoNotFound - } - - // should be new stream started - if newVideoStream.ListenersCount() == 0 { - if err := newVideoStream.Start(); err != nil { - return err - } - } - - // switch videoListeners - videoStream.RemoveListener(&videoListener) - newVideoStream.AddListener(&videoListener) - - // should be old stream stopped - if videoStream.ListenersCount() == 0 { - videoStream.Stop() - } - - videoStream = newVideoStream - return nil - } - - rtpAudio, err := connection.AddTrack(manager.audioTrack) + audioTrack.AddToConnection(connection) if err != nil { return nil, err } - rtpVideo, err := connection.AddTrack(videoTrack) + // video track + + videoTrack, err := manager.newPeerTrack(videoStream, logger) if err != nil { return nil, err } + videoTrack.AddToConnection(connection) + if err != nil { + return nil, err + } + + // data channel + dataChannel, err := connection.CreateDataChannel("data", nil) if err != nil { return nil, err @@ -224,8 +146,15 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin logger: logger, connection: connection, dataChannel: dataChannel, - changeVideo: changeVideo, - iceTrickle: manager.config.ICETrickle, + changeVideo: func(videoID string) error { + videoStream, ok := manager.capture.Video(videoID) + if !ok { + return types.ErrWebRTCVideoNotFound + } + + return videoTrack.SetStream(videoStream) + }, + iceTrickle: manager.config.ICETrickle, } cursorImage := func(entry *cursor.ImageEntry) { @@ -252,29 +181,9 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin webrtc.PeerConnectionStateFailed: connection.Close() case webrtc.PeerConnectionStateClosed: - manager.mu.Lock() - session.SetWebRTCConnected(peer, false) - videoStream.RemoveListener(&videoListener) - - // should be stream stopped - if videoStream.ListenersCount() == 0 { - videoStream.Stop() - } - - // decrease participants - manager.participants = manager.participants - 1 - - // stop audio, if last participant disonnects - if manager.participants <= 0 { - manager.participants = 0 - - if manager.capture.Audio().Started() { - manager.capture.Audio().Stop() - } - } - - manager.mu.Unlock() + videoTrack.RemoveStream() + audioTrack.RemoveStream() } }) @@ -310,24 +219,6 @@ func (manager *WebRTCManagerCtx) CreatePeer(session types.Session, videoID strin } }) - go func() { - rtcpBuf := make([]byte, 1500) - for { - if _, _, err := rtpAudio.Read(rtcpBuf); err != nil { - return - } - } - }() - - go func() { - rtcpBuf := make([]byte, 1500) - for { - if _, _, err := rtpVideo.Read(rtcpBuf); err != nil { - return - } - } - }() - session.SetWebRTCPeer(peer) return peer.CreateOffer(false) } diff --git a/internal/webrtc/peertrack.go b/internal/webrtc/peertrack.go new file mode 100644 index 00000000..ef7cf8c6 --- /dev/null +++ b/internal/webrtc/peertrack.go @@ -0,0 +1,97 @@ +package webrtc + +import ( + "demodesk/neko/internal/types" + "errors" + "io" + "sync" + + "github.com/pion/webrtc/v3" + "github.com/pion/webrtc/v3/pkg/media" + "github.com/rs/zerolog" +) + +func (manager *WebRTCManagerCtx) newPeerTrack(stream types.StreamManager, logger zerolog.Logger) (*PeerTrack, error) { + codec := stream.Codec() + + id := codec.Type.String() + track, err := webrtc.NewTrackLocalStaticSample(codec.Capability, id, "stream") + if err != nil { + return nil, err + } + + logger = logger.With().Str("id", id).Logger() + + peer := &PeerTrack{ + logger: logger, + track: track, + listener: func(sample types.Sample) { + err := track.WriteSample(media.Sample(sample)) + if err != nil && errors.Is(err, io.ErrClosedPipe) { + logger.Warn().Err(err).Msg("pipeline failed to write") + } + }, + } + + peer.SetStream(stream) + return peer, nil + +} + +type PeerTrack struct { + logger zerolog.Logger + track *webrtc.TrackLocalStaticSample + listener func(sample types.Sample) + + streamMu sync.Mutex + stream types.StreamManager +} + +func (peer *PeerTrack) SetStream(stream types.StreamManager) error { + peer.streamMu.Lock() + defer peer.streamMu.Unlock() + + // prepare new listener + addListener, err := stream.NewListener(&peer.listener) + if err != nil { + return err + } + + // remove previous listener (in case it existed) + if peer.stream != nil { + peer.stream.RemoveListener(&peer.listener) + } + + // add new listener + addListener() + + peer.stream = stream + return nil +} + +func (peer *PeerTrack) RemoveStream() { + peer.streamMu.Lock() + defer peer.streamMu.Unlock() + + if peer.stream != nil { + peer.stream.RemoveListener(&peer.listener) + } +} + +func (peer *PeerTrack) AddToConnection(connection *webrtc.PeerConnection) error { + sender, err := connection.AddTrack(peer.track) + if err != nil { + return err + } + + go func() { + rtcpBuf := make([]byte, 1500) + for { + if _, _, err := sender.Read(rtcpBuf); err != nil { + return + } + } + }() + + return nil +}