Skip to content

Commit

Permalink
Fix memory leak caused by disconnecting before connect finished (#386)
Browse files Browse the repository at this point in the history
* State locking for Room and RTC engine around critical spots

* Cancel connect job if invoking coroutine is cancelled

* cleanup

* Clean up test logs

* revert stress test changes to sample apps
  • Loading branch information
davidliu authored Feb 29, 2024
1 parent 760c536 commit bca9985
Show file tree
Hide file tree
Showing 5 changed files with 294 additions and 187 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ constructor(
restartingIce = true
}

if (this.peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
if (peerConnection.signalingState() == SignalingState.HAVE_LOCAL_OFFER) {
// we're waiting for the peer to accept our offer, so we'll just wait
// the only exception to this is when ICE restart is needed
val curSd = peerConnection.remoteDescription
Expand Down Expand Up @@ -313,7 +313,7 @@ constructor(
}

@OptIn(ExperimentalContracts::class)
private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend () -> T): T? {
private suspend inline fun <T> launchRTCIfNotClosed(noinline action: suspend CoroutineScope.() -> T): T? {
contract { callsInPlace(action, InvocationKind.AT_MOST_ONCE) }
if (isClosed()) {
return null
Expand Down
213 changes: 120 additions & 93 deletions livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ import io.livekit.android.util.FlowObservable
import io.livekit.android.util.LKLog
import io.livekit.android.util.flowDelegate
import io.livekit.android.util.nullSafe
import io.livekit.android.util.withCheckLock
import io.livekit.android.webrtc.RTCStatsGetter
import io.livekit.android.webrtc.copy
import io.livekit.android.webrtc.isConnected
import io.livekit.android.webrtc.isDisconnected
import io.livekit.android.webrtc.peerconnection.executeBlockingOnRTCThread
import io.livekit.android.webrtc.peerconnection.launchBlockingOnRTCThread
import io.livekit.android.webrtc.toProtoSessionDescription
import kotlinx.coroutines.*
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import livekit.LivekitModels
import livekit.LivekitRtc
import livekit.LivekitRtc.JoinResponse
Expand Down Expand Up @@ -134,6 +138,12 @@ internal constructor(

private var coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher)

/**
* Note: If this lock is ever used in conjunction with the RTC thread,
* this must be grabbed on the RTC thread to prevent deadlocks.
*/
private var configurationLock = Mutex()

init {
client.listener = this
}
Expand All @@ -158,8 +168,10 @@ internal constructor(
token: String,
options: ConnectOptions,
roomOptions: RoomOptions,
): JoinResponse {
): JoinResponse = coroutineScope {
val joinResponse = client.join(url, token, options, roomOptions)
ensureActive()

listener?.onJoinResponse(joinResponse)
isClosed = false
listener?.onSignalConnected(false)
Expand All @@ -169,93 +181,103 @@ internal constructor(
configure(joinResponse, options)

// create offer
if (!this.isSubscriberPrimary) {
if (!isSubscriberPrimary) {
negotiatePublisher()
}
client.onReadyForResponses()
return joinResponse

return@coroutineScope joinResponse
}

private suspend fun configure(joinResponse: JoinResponse, connectOptions: ConnectOptions) {
if (publisher != null && subscriber != null) {
// already configured
return
}
launchBlockingOnRTCThread {
configurationLock.withCheckLock(
{
ensureActive()
if (publisher != null && subscriber != null) {
// already configured
return@launchBlockingOnRTCThread
}
},
) {
participantSid = if (joinResponse.hasParticipant()) {
joinResponse.participant.sid
} else {
null
}

participantSid = if (joinResponse.hasParticipant()) {
joinResponse.participant.sid
} else {
null
}
// Setup peer connections
val rtcConfig = makeRTCConfig(Either.Left(joinResponse), connectOptions)

// Setup peer connections
val rtcConfig = makeRTCConfig(Either.Left(joinResponse), connectOptions)
publisher?.close()
publisher = pctFactory.create(
rtcConfig,
publisherObserver,
publisherObserver,
)
subscriber?.close()
subscriber = pctFactory.create(
rtcConfig,
subscriberObserver,
null,
)

publisher?.close()
publisher = pctFactory.create(
rtcConfig,
publisherObserver,
publisherObserver,
)
subscriber?.close()
subscriber = pctFactory.create(
rtcConfig,
subscriberObserver,
null,
)
val connectionStateListener: (PeerConnection.PeerConnectionState) -> Unit = { newState ->
LKLog.v { "onIceConnection new state: $newState" }
if (newState.isConnected()) {
connectionState = ConnectionState.CONNECTED
} else if (newState.isDisconnected()) {
connectionState = ConnectionState.DISCONNECTED
}
}

val connectionStateListener: (PeerConnection.PeerConnectionState) -> Unit = { newState ->
LKLog.v { "onIceConnection new state: $newState" }
if (newState.isConnected()) {
connectionState = ConnectionState.CONNECTED
} else if (newState.isDisconnected()) {
connectionState = ConnectionState.DISCONNECTED
}
}
if (joinResponse.subscriberPrimary) {
// in subscriber primary mode, server side opens sub data channels.
subscriberObserver.dataChannelListener = onDataChannel@{ dataChannel: DataChannel ->
when (dataChannel.label()) {
RELIABLE_DATA_CHANNEL_LABEL -> reliableDataChannelSub = dataChannel
LOSSY_DATA_CHANNEL_LABEL -> lossyDataChannelSub = dataChannel
else -> return@onDataChannel
}
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}

if (joinResponse.subscriberPrimary) {
// in subscriber primary mode, server side opens sub data channels.
subscriberObserver.dataChannelListener = onDataChannel@{ dataChannel: DataChannel ->
when (dataChannel.label()) {
RELIABLE_DATA_CHANNEL_LABEL -> reliableDataChannelSub = dataChannel
LOSSY_DATA_CHANNEL_LABEL -> lossyDataChannelSub = dataChannel
else -> return@onDataChannel
subscriberObserver.connectionChangeListener = connectionStateListener
// Also reconnect on publisher disconnect
publisherObserver.connectionChangeListener = { newState ->
if (newState.isDisconnected()) {
reconnect()
}
}
} else {
publisherObserver.connectionChangeListener = connectionStateListener
}
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}

subscriberObserver.connectionChangeListener = connectionStateListener
// Also reconnect on publisher disconnect
publisherObserver.connectionChangeListener = { newState ->
if (newState.isDisconnected()) {
reconnect()
ensureActive()
// data channels
val reliableInit = DataChannel.Init()
reliableInit.ordered = true
reliableDataChannel = publisher?.withPeerConnection {
createDataChannel(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}
}
}
} else {
publisherObserver.connectionChangeListener = connectionStateListener
}

// data channels
val reliableInit = DataChannel.Init()
reliableInit.ordered = true
reliableDataChannel = publisher?.withPeerConnection {
createDataChannel(
RELIABLE_DATA_CHANNEL_LABEL,
reliableInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}
}

val lossyInit = DataChannel.Init()
lossyInit.ordered = true
lossyInit.maxRetransmits = 0
lossyDataChannel = publisher?.withPeerConnection {
createDataChannel(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
ensureActive()
val lossyInit = DataChannel.Init()
lossyInit.ordered = true
lossyInit.maxRetransmits = 0
lossyDataChannel = publisher?.withPeerConnection {
createDataChannel(
LOSSY_DATA_CHANNEL_LABEL,
lossyInit,
).also { dataChannel ->
dataChannel.registerObserver(DataChannelObserver(dataChannel))
}
}
}
}
}
Expand Down Expand Up @@ -327,27 +349,32 @@ internal constructor(

private fun closeResources(reason: String) {
executeBlockingOnRTCThread {
publisherObserver.connectionChangeListener = null
subscriberObserver.connectionChangeListener = null
publisher?.closeBlocking()
publisher = null
subscriber?.closeBlocking()
subscriber = null

fun DataChannel?.completeDispose() {
this?.unregisterObserver()
this?.close()
this?.dispose()
runBlocking {
configurationLock.withLock {
publisherObserver.connectionChangeListener = null
subscriberObserver.connectionChangeListener = null
publisher?.closeBlocking()
publisher = null
subscriber?.closeBlocking()
subscriber = null

fun DataChannel?.completeDispose() {
this?.unregisterObserver()
this?.close()
this?.dispose()
}

reliableDataChannel?.completeDispose()
reliableDataChannel = null
reliableDataChannelSub?.completeDispose()
reliableDataChannelSub = null
lossyDataChannel?.completeDispose()
lossyDataChannel = null
lossyDataChannelSub?.completeDispose()
lossyDataChannelSub = null
isSubscriberPrimary = false
}
}
reliableDataChannel?.completeDispose()
reliableDataChannel = null
reliableDataChannelSub?.completeDispose()
reliableDataChannelSub = null
lossyDataChannel?.completeDispose()
lossyDataChannel = null
lossyDataChannelSub?.completeDispose()
lossyDataChannelSub = null
isSubscriberPrimary = false
}
client.close(reason = reason)
}
Expand Down
Loading

0 comments on commit bca9985

Please sign in to comment.