Skip to content

Commit

Permalink
Transcription events feature (#440)
Browse files Browse the repository at this point in the history
davidliu authored Jun 24, 2024
1 parent 67d6912 commit c0c04d0
Showing 12 changed files with 276 additions and 4 deletions.
2 changes: 1 addition & 1 deletion livekit-android-sdk/build.gradle
Original file line number Diff line number Diff line change
@@ -60,7 +60,7 @@ android {
buildConfig = true
}
kotlinOptions {
freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn"]
freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn", "-opt-in=io.livekit.android.annotations.Beta"]
jvmTarget = java_version
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.livekit.android.annotations

/**
 * Opt-in marker for SDK declarations labelled as experimental.
 *
 * Because this annotation is itself annotated with [RequiresOptIn], any usage of a declaration
 * marked with it must explicitly opt in (for example via `@OptIn(Experimental::class)` or the
 * `-opt-in` compiler argument).
 */
@Retention(AnnotationRetention.BINARY)
@RequiresOptIn
annotation class Experimental

/**
 * Opt-in marker for SDK declarations labelled as alpha.
 *
 * Usages of declarations marked with it must explicitly opt in, see [RequiresOptIn].
 */
@Retention(AnnotationRetention.BINARY)
@RequiresOptIn
annotation class Alpha

/**
 * Opt-in marker for SDK declarations labelled as beta.
 *
 * Usages of declarations marked with it must explicitly opt in, see [RequiresOptIn].
 */
@Retention(AnnotationRetention.BINARY)
@RequiresOptIn
annotation class Beta
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023 LiveKit, Inc.
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import io.livekit.android.room.track.LocalTrackPublication
import io.livekit.android.room.track.RemoteTrackPublication
import io.livekit.android.room.track.Track
import io.livekit.android.room.track.TrackPublication
import io.livekit.android.room.types.TranscriptionSegment

sealed class ParticipantEvent(open val participant: Participant) : Event() {
// all participants
@@ -152,4 +153,16 @@ sealed class ParticipantEvent(open val participant: Participant) : Event() {
val newPermissions: ParticipantPermission?,
val oldPermissions: ParticipantPermission?,
) : ParticipantEvent(participant)

/**
 * Participant-level event carrying transcription segments attributed to [participant].
 */
class TranscriptionReceived(
override val participant: Participant,
/**
* The transcription segments delivered with this event.
*/
val transcriptions: List<TranscriptionSegment>,
/**
* The track publication these transcriptions apply to, or null when none is associated.
*/
val publication: TrackPublication?,
) : ParticipantEvent(participant)
}
Original file line number Diff line number Diff line change
@@ -16,6 +16,7 @@

package io.livekit.android.events

import io.livekit.android.annotations.Beta
import io.livekit.android.e2ee.E2EEState
import io.livekit.android.room.Room
import io.livekit.android.room.participant.ConnectionQuality
@@ -27,6 +28,7 @@ import io.livekit.android.room.track.LocalTrackPublication
import io.livekit.android.room.track.RemoteTrackPublication
import io.livekit.android.room.track.Track
import io.livekit.android.room.track.TrackPublication
import io.livekit.android.room.types.TranscriptionSegment
import livekit.LivekitModels

sealed class RoomEvent(val room: Room) : Event() {
@@ -219,6 +221,23 @@ sealed class RoomEvent(val room: Room) : Event() {
val participant: Participant,
var state: E2EEState,
) : RoomEvent(room)

/**
 * Room-level event carrying transcription segments received for this room.
 *
 * Marked [Beta]: callers must opt in to use it.
 */
@Beta
class TranscriptionReceived(
room: Room,
/**
* The transcription segments delivered with this event.
*/
val transcriptionSegments: List<TranscriptionSegment>,
/**
* The participant these transcriptions apply to, or null when none is associated.
*/
val participant: Participant?,
/**
* The track publication these transcriptions apply to, or null when none is associated.
*/
val publication: TrackPublication?,
) : RoomEvent(room)
}

enum class DisconnectReason {
Original file line number Diff line number Diff line change
@@ -758,6 +758,7 @@ internal constructor(
fun onFullReconnecting()
suspend fun onPostReconnect(isFullReconnect: Boolean)
fun onLocalTrackUnpublished(trackUnpublished: LivekitRtc.TrackUnpublishedResponse)
fun onTranscriptionReceived(transcription: LivekitModels.Transcription)
}

companion object {
@@ -981,7 +982,7 @@ internal constructor(
}

LivekitModels.DataPacket.ValueCase.TRANSCRIPTION -> {
// TODO
listener?.onTranscriptionReceived(dp.transcription)
}

LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET,
21 changes: 21 additions & 0 deletions livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@ import io.livekit.android.room.network.NetworkCallbackManagerFactory
import io.livekit.android.room.participant.*
import io.livekit.android.room.provisions.LKObjects
import io.livekit.android.room.track.*
import io.livekit.android.room.types.toSDKType
import io.livekit.android.util.FlowObservable
import io.livekit.android.util.LKLog
import io.livekit.android.util.flow
@@ -1004,6 +1005,26 @@ constructor(
participant?.onDataReceived(data, topic)
}

/**
 * Handles a transcription packet: resolves the transcribed participant and track publication,
 * posts a [RoomEvent.TranscriptionReceived] on the room's event bus, then forwards the same
 * event to the resolved participant (if any).
 *
 * @suppress
 */
override fun onTranscriptionReceived(transcription: LivekitModels.Transcription) {
    // Either lookup may come back empty; the event is still emitted with null fields in that case.
    val transcribedParticipant = getParticipantByIdentity(transcription.transcribedParticipantIdentity)
    val trackPublication = transcribedParticipant?.trackPublications?.get(transcription.trackId)

    val roomEvent = RoomEvent.TranscriptionReceived(
        room = this,
        transcriptionSegments = transcription.segmentsList.map { protoSegment -> protoSegment.toSDKType() },
        participant = transcribedParticipant,
        publication = trackPublication,
    )

    eventBus.tryPostEvent(roomEvent)
    transcribedParticipant?.onTranscriptionReceived(roomEvent)
    // TODO: Emit for publication
}

/**
* @suppress
*/
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ import androidx.annotation.VisibleForTesting
import io.livekit.android.dagger.InjectionNames
import io.livekit.android.events.BroadcastEventBus
import io.livekit.android.events.ParticipantEvent
import io.livekit.android.events.RoomEvent
import io.livekit.android.events.TrackEvent
import io.livekit.android.room.track.LocalTrackPublication
import io.livekit.android.room.track.RemoteTrackPublication
@@ -366,6 +367,20 @@ open class Participant(
)
}

/**
 * Re-emits a room-level transcription event as a [ParticipantEvent.TranscriptionReceived]
 * on this participant's event bus. Events addressed to a different participant are ignored.
 */
internal fun onTranscriptionReceived(transcription: RoomEvent.TranscriptionReceived) {
    if (transcription.participant == this) {
        val participantEvent = ParticipantEvent.TranscriptionReceived(
            participant = this,
            transcriptions = transcription.transcriptionSegments,
            publication = transcription.publication,
        )
        eventBus.postEvent(participantEvent, scope)
    }
}

internal fun reinitialize() {
if (!scope.isActive) {
scope = createScope()
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.livekit.android.room.types

import io.livekit.android.util.LKLog
import livekit.LivekitModels

/**
 * A single segment of a transcription.
 *
 * Equality and hashing are based solely on [id]: two segments with the same id are considered
 * equal even if their text, language, timing or [final] flag differ.
 *
 * @property id Identifier of the segment; the only field used by [equals] and [hashCode].
 * @property text The transcribed text of the segment.
 * @property language Language of the segment.
 * @property startTime Start time of the segment (unit not established here; confirm against the protocol definition).
 * @property endTime End time of the segment (unit not established here; confirm against the protocol definition).
 * @property final Whether the segment is flagged as final.
 */
data class TranscriptionSegment(
    val id: String,
    val text: String,
    val language: String,
    val startTime: Long,
    val endTime: Long,
    val final: Boolean,
) {
    // Data classes are final, so an `is` check is equivalent to comparing runtime classes.
    override fun equals(other: Any?): Boolean =
        this === other || (other is TranscriptionSegment && other.id == id)

    override fun hashCode(): Int = id.hashCode()
}

/**
 * Merges [newSegments] into this map, keyed by segment id. The map's keys should correspond
 * to the segment ids.
 *
 * Every incoming segment replaces any entry stored under the same id. When the entry being
 * replaced was already flagged final, a debug message is logged and the replacement still happens.
 */
fun MutableMap<String, TranscriptionSegment>.mergeNewSegments(newSegments: Collection<TranscriptionSegment>) {
    newSegments.forEach { segment ->
        val previous = this[segment.id]
        if (previous != null && previous.final) {
            LKLog.d { "new segment for ${segment.id} overwriting final segment?" }
        }
        this[segment.id] = segment
    }
}

/**
 * Converts a protocol transcription segment into the SDK's [TranscriptionSegment],
 * copying every field one-to-one.
 *
 * @suppress
 */
fun LivekitModels.TranscriptionSegment.toSDKType(): TranscriptionSegment {
    val proto = this
    return TranscriptionSegment(
        id = proto.id,
        text = proto.text,
        language = proto.language,
        startTime = proto.startTime,
        endTime = proto.endTime,
        final = proto.final,
    )
}
2 changes: 1 addition & 1 deletion livekit-android-test/build.gradle
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ android {
targetCompatibility java_version
}
kotlinOptions {
freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn"]
freeCompilerArgs = ["-Xinline-classes", "-opt-in=kotlin.RequiresOptIn", "-opt-in=io.livekit.android.annotations.Beta"]
jvmTarget = java_version
}
testOptions {
Original file line number Diff line number Diff line change
@@ -302,4 +302,25 @@ object TestData {
}
build()
}

// Data packets

/**
 * A data packet containing a transcription with a single final segment,
 * attributed to the identity of the participant in the [JOIN] response.
 */
val DATA_PACKET_TRANSCRIPTION = LivekitModels.DataPacket.newBuilder().apply {
    transcription = LivekitModels.Transcription.newBuilder().apply {
        transcribedParticipantIdentity = JOIN.join.participant.identity // Local participant's identity
        addSegments(
            LivekitModels.TranscriptionSegment.newBuilder().apply {
                id = "id"
                language = "enUS"
                text = "This is a transcription."
                startTime = 1
                endTime = 10
                final = true
            }.build(),
        )
    }.build()
}.build()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright 2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.livekit.android.test.util

import com.google.protobuf.MessageLite
import livekit.org.webrtc.DataChannel
import java.nio.ByteBuffer

/**
 * Serializes this message and wraps the bytes in a [DataChannel.Buffer].
 */
fun MessageLite.toDataChannelBuffer(): DataChannel.Buffer {
    val payload = ByteBuffer.wrap(toByteArray())
    // Second constructor argument is the buffer's `binary` flag in the WebRTC API.
    return DataChannel.Buffer(payload, true)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright 2023-2024 LiveKit, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.livekit.android.room

import io.livekit.android.events.RoomEvent
import io.livekit.android.test.MockE2ETest
import io.livekit.android.test.assert.assertIsClass
import io.livekit.android.test.events.EventCollector
import io.livekit.android.test.mock.MockDataChannel
import io.livekit.android.test.mock.MockPeerConnection
import io.livekit.android.test.mock.TestData
import io.livekit.android.test.util.toDataChannelBuffer
import kotlinx.coroutines.ExperimentalCoroutinesApi
import org.junit.Assert.assertEquals
import org.junit.Test

/**
 * Verifies that a transcription data packet arriving on the subscriber's data channel
 * is surfaced as a single [RoomEvent.TranscriptionReceived] on the room's event stream.
 */
@OptIn(ExperimentalCoroutinesApi::class)
class RoomTranscriptionMockE2ETest : MockE2ETest() {
    @Test
    fun transcriptionReceived() = runTest {
        connect()

        // Attach a mock data channel to the subscriber peer connection.
        val subscriberConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection
        val subscriberChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL)
        subscriberConnection.observer?.onDataChannel(subscriberChannel)

        // Collect room events while the transcription packet is delivered.
        val roomEventCollector = EventCollector(room.events, coroutineRule.scope)
        val packet = TestData.DATA_PACKET_TRANSCRIPTION
        subscriberChannel.observer?.onMessage(packet.toDataChannelBuffer())
        val collected = roomEventCollector.stopCollecting()

        assertEquals(1, collected.size)
        assertIsClass(RoomEvent.TranscriptionReceived::class.java, collected[0])

        val received = collected.first() as RoomEvent.TranscriptionReceived
        assertEquals(room, received.room)
        assertEquals(room.localParticipant, received.participant)

        // The emitted segment should mirror the one in the test packet.
        val sentSegment = packet.transcription.getSegments(0)
        val emittedSegment = received.transcriptionSegments.first()
        assertEquals(sentSegment.id, emittedSegment.id)
        assertEquals(sentSegment.text, emittedSegment.text)
    }
}

0 comments on commit c0c04d0

Please sign in to comment.