Skip to content

Commit

Permalink
Implement LocalAudioTrack.addSink (#516)
Browse files Browse the repository at this point in the history
* Implement LocalAudioTrack.addSink

* test fix

* fix tests

* fix tests
  • Loading branch information
davidliu authored Oct 2, 2024
1 parent 285ba5e commit 4d97868
Show file tree
Hide file tree
Showing 15 changed files with 223 additions and 113 deletions.
5 changes: 5 additions & 0 deletions .changeset/bright-pillows-pay.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"client-sdk-android": minor
---

Implement `LocalAudioTrack.addSink` to receive audio data from the local microphone.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.livekit.android.audio

import android.media.AudioFormat
import android.os.SystemClock
import livekit.org.webrtc.AudioTrackSink
import livekit.org.webrtc.audio.JavaAudioDeviceModule
import livekit.org.webrtc.audio.JavaAudioDeviceModule.SamplesReadyCallback
import java.nio.ByteBuffer

/**
 * Forwards recorded audio samples received through [SamplesReadyCallback] to every
 * registered [AudioTrackSink].
 *
 * Registration, unregistration and sample delivery are all `@Synchronized` on this
 * instance, so they never run concurrently with each other.
 */
class AudioRecordSamplesDispatcher : SamplesReadyCallback {

    private val sinks = mutableSetOf<AudioTrackSink>()

    /**
     * Adds [sink] to the set of receivers. Registering the same sink twice has no
     * additional effect.
     */
    @Synchronized
    fun registerSink(sink: AudioTrackSink) {
        sinks.add(sink)
    }

    /**
     * Removes [sink] from the set of receivers. Does nothing if it was not registered.
     */
    @Synchronized
    fun unregisterSink(sink: AudioTrackSink) {
        sinks.remove(sink)
    }

    // Reference from Android code, AudioFormat.getBytesPerSample. BitPerSample / 8
    // Default audio data format is PCM 16 bits per sample.
    // Guaranteed to be supported by all devices
    //
    // Any encoding not listed here (including AudioFormat.ENCODING_INVALID) is rejected
    // with an IllegalArgumentException.
    private fun getBytesPerSample(audioFormat: Int): Int {
        return when (audioFormat) {
            AudioFormat.ENCODING_PCM_8BIT -> 1
            AudioFormat.ENCODING_PCM_16BIT, AudioFormat.ENCODING_IEC61937, AudioFormat.ENCODING_DEFAULT -> 2
            AudioFormat.ENCODING_PCM_FLOAT -> 4
            else -> throw IllegalArgumentException("Bad audio format $audioFormat")
        }
    }

    /**
     * Delivers [samples] to every registered sink.
     *
     * Each sink receives its own [ByteBuffer] wrapper, but all wrappers share the same
     * backing array (`samples.data`); sinks that need the data after `onData` returns
     * must copy it.
     *
     * @throws IllegalArgumentException if at least one sink is registered and the audio
     * format of [samples] is not one of the supported encodings.
     */
    @Synchronized
    override fun onWebRtcAudioRecordSamplesReady(samples: JavaAudioDeviceModule.AudioSamples) {
        // Nothing to deliver; skip the per-callback work entirely.
        if (sinks.isEmpty()) {
            return
        }

        val bitsPerSample = getBytesPerSample(samples.audioFormat) * 8
        val numFrames = samples.sampleRate / 100 // 10ms worth of samples.
        val timestamp = SystemClock.elapsedRealtime()

        // Iterate over a snapshot: a sink may call registerSink/unregisterSink from inside
        // onData (the monitor is reentrant on this thread), and mutating the set while
        // iterating it directly would throw ConcurrentModificationException.
        // A sink removed by an earlier sink during this pass still receives this buffer.
        for (sink in sinks.toList()) {
            val byteBuffer = ByteBuffer.wrap(samples.data)
            sink.onData(
                byteBuffer,
                bitsPerSample,
                samples.sampleRate,
                samples.channelCount,
                numFrames,
                timestamp,
            )
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ object InjectionNames {

const val LIB_WEBRTC_INITIALIZATION = "lib_webrtc_initialization"

const val LOCAL_AUDIO_RECORD_SAMPLES_DISPATCHER = "local_audio_record_samples_dispatcher"

// Overrides
const val OVERRIDE_OKHTTP = "override_okhttp"
const val OVERRIDE_AUDIO_DEVICE_MODULE = "override_audio_device_module"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import dagger.Provides
import io.livekit.android.LiveKit
import io.livekit.android.audio.AudioProcessingController
import io.livekit.android.audio.AudioProcessorOptions
import io.livekit.android.audio.AudioRecordSamplesDispatcher
import io.livekit.android.audio.CommunicationWorkaround
import io.livekit.android.memory.CloseableManager
import io.livekit.android.util.LKLog
Expand Down Expand Up @@ -127,6 +128,13 @@ internal object RTCModule {
return LibWebrtcInitialization
}

/**
 * Supplies the singleton [AudioRecordSamplesDispatcher] bound under
 * [InjectionNames.LOCAL_AUDIO_RECORD_SAMPLES_DISPATCHER].
 */
@Provides
@Singleton
@Named(InjectionNames.LOCAL_AUDIO_RECORD_SAMPLES_DISPATCHER)
fun localAudioSamplesDispatcher(): AudioRecordSamplesDispatcher = AudioRecordSamplesDispatcher()

@Provides
@Singleton
@JvmSuppressWildcards
Expand All @@ -141,6 +149,8 @@ internal object RTCModule {
appContext: Context,
closeableManager: CloseableManager,
communicationWorkaround: CommunicationWorkaround,
@Named(InjectionNames.LOCAL_AUDIO_RECORD_SAMPLES_DISPATCHER)
audioRecordSamplesDispatcher: AudioRecordSamplesDispatcher,
): AudioDeviceModule {
if (audioDeviceModuleOverride != null) {
return audioDeviceModuleOverride
Expand Down Expand Up @@ -215,6 +225,7 @@ internal object RTCModule {
.setAudioTrackErrorCallback(audioTrackErrorCallback)
.setAudioRecordStateCallback(audioRecordStateCallback)
.setAudioTrackStateCallback(audioTrackStateCallback)
.setSamplesReadyCallback(audioRecordSamplesDispatcher)
// VOICE_COMMUNICATION needs to be used for echo cancelling.
.setAudioSource(MediaRecorder.AudioSource.VOICE_COMMUNICATION)
.setAudioAttributes(audioOutputAttributes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.livekit.android.room.track

import livekit.org.webrtc.AudioTrack
import livekit.org.webrtc.AudioTrackSink

/**
* A class representing an audio track.
Expand All @@ -27,4 +28,22 @@ abstract class AudioTrack(
* The underlying WebRTC audio track.
*/
override val rtcTrack: AudioTrack,
) : Track(name, Kind.AUDIO, rtcTrack)
) : Track(name, Kind.AUDIO, rtcTrack) {

/**
* Adds a sink that receives the audio bytes and related information
* for this audio track. Repeated calls using the same sink will
* only add the sink once.
*
* Implementations should copy the audio data into a local copy if they wish
* to use the data after the [AudioTrackSink.onData] callback returns.
* Long running processing of the received audio data should be done in a separate
* thread, as doing so inline may block the audio thread.
*
* @param sink the receiver to add; pass the same instance to [removeSink] to remove it.
*/
abstract fun addSink(sink: AudioTrackSink)

/**
* Removes a sink previously added with [addSink].
*
* NOTE(review): behavior when [sink] was never added is left to each subclass — confirm.
*
* @param sink the receiver to remove.
*/
abstract fun removeSink(sink: AudioTrackSink)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import dagger.assisted.Assisted
import dagger.assisted.AssistedFactory
import dagger.assisted.AssistedInject
import io.livekit.android.audio.AudioProcessingController
import io.livekit.android.audio.AudioRecordSamplesDispatcher
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.room.participant.LocalParticipant
import io.livekit.android.util.FlowObservable
Expand All @@ -37,10 +38,13 @@ import kotlinx.coroutines.flow.combine
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.stateIn
import livekit.LivekitModels.AudioTrackFeature
import livekit.org.webrtc.AudioTrackSink
import livekit.org.webrtc.MediaConstraints
import livekit.org.webrtc.PeerConnectionFactory
import livekit.org.webrtc.RtpSender
import livekit.org.webrtc.RtpTransceiver
import livekit.org.webrtc.audio.AudioDeviceModule
import livekit.org.webrtc.audio.JavaAudioDeviceModule
import java.util.UUID
import javax.inject.Named

Expand All @@ -58,6 +62,8 @@ constructor(
private val audioProcessingController: AudioProcessingController,
@Named(InjectionNames.DISPATCHER_DEFAULT)
private val dispatcher: CoroutineDispatcher,
@Named(InjectionNames.LOCAL_AUDIO_RECORD_SAMPLES_DISPATCHER)
private val audioRecordSamplesDispatcher: AudioRecordSamplesDispatcher,
) : AudioTrack(name, mediaTrack) {
/**
* To only be used for flow delegate scoping, and should not be cancelled.
Expand All @@ -68,6 +74,30 @@ constructor(
internal val sender: RtpSender?
get() = transceiver?.sender

private val trackSinks = mutableSetOf<AudioTrackSink>()

/**
 * Records [sink] on this track and registers it with the audio record samples
 * dispatcher.
 *
 * Note: delivery relies on the SDK setting
 * [JavaAudioDeviceModule.Builder.setSamplesReadyCallback].
 * If you provide your own [AudioDeviceModule], or set your own
 * callback, your sink will not receive any audio data.
 *
 * @see AudioTrack.addSink
 */
override fun addSink(sink: AudioTrackSink) {
    synchronized(trackSinks) {
        // Tracked locally as well, so dispose() knows which sinks this track registered.
        trackSinks += sink
        audioRecordSamplesDispatcher.registerSink(sink)
    }
}

/**
 * Forgets [sink] on this track and unregisters it from the audio record samples
 * dispatcher.
 *
 * @see AudioTrack.removeSink
 */
override fun removeSink(sink: AudioTrackSink) {
    synchronized(trackSinks) {
        trackSinks -= sink
        audioRecordSamplesDispatcher.unregisterSink(sink)
    }
}

/**
* Changes can be observed by using [io.livekit.android.util.flow]
*/
Expand Down Expand Up @@ -107,6 +137,16 @@ constructor(
return features
}

/**
 * Unregisters every sink this track added to the audio record samples dispatcher,
 * forgets them, and then disposes the track via the superclass.
 */
override fun dispose() {
    synchronized(trackSinks) {
        // Unregister first, then clear in one step. Removing elements from the set
        // inside a loop that iterates the same set throws
        // ConcurrentModificationException as soon as more than one sink is present.
        for (sink in trackSinks) {
            audioRecordSamplesDispatcher.unregisterSink(sink)
        }
        trackSinks.clear()
    }
    super.dispose()
}

companion object {
internal fun createTrack(
context: Context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,13 @@ class RemoteAudioTrack(
internal val receiver: RtpReceiver,
) : io.livekit.android.room.track.AudioTrack(name, rtcTrack) {

/**
* Adds a sink that receives the audio bytes and related information
* for this audio track. Repeated calls using the same sink will
* only add the sink once.
*
* Implementations should copy the audio data into a local copy if they wish
* to use the data after this function returns.
*/
fun addSink(sink: AudioTrackSink) {
/**
* Adds [sink] to the underlying WebRTC track by calling `rtcTrack.addSink`.
*
* The call is wrapped in `withRTCTrack`, which is defined outside this view.
* NOTE(review): presumably it guards against using a disposed track — confirm.
*/
override fun addSink(sink: AudioTrackSink) {
withRTCTrack {
rtcTrack.addSink(sink)
}
}

/**
* Removes a previously added sink.
*/
fun removeSink(sink: AudioTrackSink) {
override fun removeSink(sink: AudioTrackSink) {
withRTCTrack {
rtcTrack.removeSink(sink)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import android.javax.sdp.SdpFactory
import dagger.Module
import dagger.Provides
import io.livekit.android.audio.AudioProcessingController
import io.livekit.android.audio.AudioRecordSamplesDispatcher
import io.livekit.android.dagger.CapabilitiesGetter
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.test.mock.MockAudioDeviceModule
Expand Down Expand Up @@ -58,6 +59,13 @@ object TestRTCModule {
return MockAudioDeviceModule()
}

/**
 * Test binding for [InjectionNames.LOCAL_AUDIO_RECORD_SAMPLES_DISPATCHER]: a singleton
 * [AudioRecordSamplesDispatcher] created with its default constructor.
 */
@Provides
@Singleton
@Named(InjectionNames.LOCAL_AUDIO_RECORD_SAMPLES_DISPATCHER)
fun localAudioSamplesDispatcher(): AudioRecordSamplesDispatcher = AudioRecordSamplesDispatcher()

@Provides
@Singleton
fun peerConnectionFactory(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.livekit.android.test.mock.room.track

import io.livekit.android.audio.AudioProcessingController
import io.livekit.android.audio.AudioRecordSamplesDispatcher
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.LocalAudioTrackOptions
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.mock.MockAudioProcessingController
import io.livekit.android.test.mock.MockAudioStreamTrack
import io.livekit.android.test.mock.TestData
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.ExperimentalCoroutinesApi
import livekit.org.webrtc.AudioTrack

/**
 * Builds a [LocalAudioTrack] for use in [MockE2ETest]s.
 *
 * Every argument has a default (mock media track, mock audio processing controller, the
 * test's coroutine dispatcher, a fresh [AudioRecordSamplesDispatcher]), so callers only
 * need to pass the dependencies they want to control.
 */
@OptIn(ExperimentalCoroutinesApi::class)
fun MockE2ETest.createMockLocalAudioTrack(
    name: String = "",
    mediaTrack: AudioTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
    options: LocalAudioTrackOptions = LocalAudioTrackOptions(),
    audioProcessingController: AudioProcessingController = MockAudioProcessingController(),
    dispatcher: CoroutineDispatcher = coroutineRule.dispatcher,
    audioRecordSamplesDispatcher: AudioRecordSamplesDispatcher = AudioRecordSamplesDispatcher(),
): LocalAudioTrack = LocalAudioTrack(
    name = name,
    mediaTrack = mediaTrack,
    options = options,
    audioProcessingController = audioProcessingController,
    dispatcher = dispatcher,
    audioRecordSamplesDispatcher = audioRecordSamplesDispatcher,
)
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,17 @@ import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.events.convert
import io.livekit.android.room.participant.ConnectionQuality
import io.livekit.android.room.track.LocalAudioTrack
import io.livekit.android.room.track.LocalAudioTrackOptions
import io.livekit.android.room.track.Track
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClassList
import io.livekit.android.test.events.EventCollector
import io.livekit.android.test.events.FlowCollector
import io.livekit.android.test.mock.MockAudioProcessingController
import io.livekit.android.test.mock.MockAudioStreamTrack
import io.livekit.android.test.mock.MockMediaStream
import io.livekit.android.test.mock.MockRtpReceiver
import io.livekit.android.test.mock.TestData
import io.livekit.android.test.mock.createMediaStreamId
import io.livekit.android.test.mock.room.track.createMockLocalAudioTrack
import io.livekit.android.util.flow
import io.livekit.android.util.toOkioByteString
import junit.framework.Assert.assertEquals
Expand Down Expand Up @@ -378,13 +376,7 @@ class RoomMockE2ETest : MockE2ETest() {
connect()

room.localParticipant.publishAudioTrack(
LocalAudioTrack(
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
track = createMockLocalAudioTrack(),
)

val eventCollector = EventCollector(room.events, coroutineRule.scope)
Expand Down Expand Up @@ -427,13 +419,7 @@ class RoomMockE2ETest : MockE2ETest() {
return@registerSignalRequestHandler false
}
room.localParticipant.publishAudioTrack(
LocalAudioTrack(
name = "",
mediaTrack = MockAudioStreamTrack(id = TestData.LOCAL_TRACK_PUBLISHED.trackPublished.cid),
options = LocalAudioTrackOptions(),
audioProcessingController = MockAudioProcessingController(),
dispatcher = coroutineRule.dispatcher,
),
track = createMockLocalAudioTrack(),
)

val eventCollector = EventCollector(room.events, coroutineRule.scope)
Expand Down
Loading

0 comments on commit 4d97868

Please sign in to comment.