diff --git a/.changeset/nasty-kids-fail.md b/.changeset/nasty-kids-fail.md new file mode 100644 index 000000000..518b0671e --- /dev/null +++ b/.changeset/nasty-kids-fail.md @@ -0,0 +1,8 @@ +--- +"client-sdk-android": minor +--- + +End to end encryption for data channels option + +* Added EncryptionType fields to DataReceived events and StreamInfo objects to indicate the + encryption status. diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 143210cdb..fb47049cb 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,5 +1,5 @@ [versions] -webrtc = "137.7151.03" +webrtc = "137.7151.04" androidJainSipRi = "1.3.0-91" androidx-activity = "1.9.0" diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/dagger/RTCModule.kt b/livekit-android-sdk/src/main/java/io/livekit/android/dagger/RTCModule.kt index c4a02ed61..9f6da9221 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/dagger/RTCModule.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/dagger/RTCModule.kt @@ -33,6 +33,8 @@ import io.livekit.android.audio.AudioRecordSamplesDispatcher import io.livekit.android.audio.CommunicationWorkaround import io.livekit.android.audio.JavaAudioRecordPrewarmer import io.livekit.android.audio.NoAudioRecordPrewarmer +import io.livekit.android.e2ee.DataPacketCryptorManager +import io.livekit.android.e2ee.DataPacketCryptorManagerImpl import io.livekit.android.memory.CloseableManager import io.livekit.android.util.LKLog import io.livekit.android.util.LoggingLevel @@ -373,6 +375,11 @@ internal object RTCModule { }!! } + @Provides + fun dataPacketCryptorManagerFactory(): DataPacketCryptorManager.Factory { + return DataPacketCryptorManagerImpl.Factory + } + @Provides @Singleton fun peerConnectionFactory( diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/DataPacketCryptorManager.kt b/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/DataPacketCryptorManager.kt new file mode 100644 index 000000000..bb06b701d --- /dev/null +++ b/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/DataPacketCryptorManager.kt @@ -0,0 +1,129 @@ +/* + * Copyright 2025 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.e2ee + +import io.livekit.android.room.participant.Participant +import io.livekit.android.util.LKLog +import livekit.LivekitModels +import livekit.org.webrtc.DataPacketCryptor +import livekit.org.webrtc.DataPacketCryptorFactory +import livekit.org.webrtc.FrameCryptorAlgorithm + +/** + * @suppress + */ +interface DataPacketCryptorManager { + fun encrypt(participantId: Participant.Identity, keyIndex: Int, payload: ByteArray): EncryptedPacket? + fun decrypt(participantId: Participant.Identity, packet: EncryptedPacket): ByteArray? + fun dispose() + + interface Factory { + fun create(keyProvider: KeyProvider): DataPacketCryptorManager + } +} + +/** + * @suppress + */ +class EncryptedPacket( + val payload: ByteArray, + val iv: ByteArray, + val keyIndex: Int, +) + +/** + * @suppress + */ +fun LivekitModels.EncryptedPacket.toSdkType() = + EncryptedPacket( + payload = this.encryptedValue.toByteArray(), + iv = this.iv.toByteArray(), + keyIndex = this.keyIndex, + ) + +internal class DataPacketCryptorManagerImpl( + keyProvider: KeyProvider, +) : DataPacketCryptorManager { + var isDisposed = false + private val dataPacketCryptor: DataPacketCryptor = DataPacketCryptorFactory.createDataPacketCryptor(FrameCryptorAlgorithm.AES_GCM, keyProvider.rtcKeyProvider) + + @Synchronized + override fun encrypt(participantId: Participant.Identity, keyIndex: Int, payload: ByteArray): EncryptedPacket? { + if (isDisposed) { + return null + } + val packet = dataPacketCryptor.encrypt( + participantId.value, + keyIndex, + payload, + ) + + if (packet == null) { + LKLog.i { "Error encrypting packet: null packet" } + return null + } + + val payload = packet.payload + val iv = packet.iv + val keyIndex = packet.keyIndex + + if (payload == null) { + LKLog.w { "Error encrypting packet: null payload" } + return null + } + if (iv == null) { + LKLog.i { "Error encrypting packet: null iv returned" } + return null + } + + return EncryptedPacket( + payload = payload, + iv = iv, + keyIndex = keyIndex, + ) + } + + @Synchronized + override fun decrypt(participantId: Participant.Identity, packet: EncryptedPacket): ByteArray? { + if (isDisposed) { + return null + } + return dataPacketCryptor.decrypt( + participantId.value, + DataPacketCryptor.EncryptedPacket( + packet.payload, + packet.iv, + packet.keyIndex, + ), + ) + } + + @Synchronized + override fun dispose() { + if (isDisposed) { + return + } + isDisposed = true + dataPacketCryptor.dispose() + } + + object Factory : DataPacketCryptorManager.Factory { + override fun create(keyProvider: KeyProvider): DataPacketCryptorManager { + return DataPacketCryptorManagerImpl(keyProvider) + } + } +} diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/E2EEManager.kt b/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/E2EEManager.kt index fa5857d1f..cbbeb86eb 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/E2EEManager.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/E2EEManager.kt @@ -19,6 +19,7 @@ package io.livekit.android.e2ee import dagger.assisted.Assisted import dagger.assisted.AssistedFactory import dagger.assisted.AssistedInject +import io.livekit.android.annotations.Beta import io.livekit.android.events.RoomEvent import io.livekit.android.room.Room import io.livekit.android.room.participant.LocalParticipant @@ -42,45 +43,59 @@ import livekit.org.webrtc.RtpSender class E2EEManager @AssistedInject constructor( - @Assisted keyProvider: KeyProvider, - peerConnectionFactory: PeerConnectionFactory, + @Assisted val keyProvider: KeyProvider, + val peerConnectionFactory: PeerConnectionFactory, + dataPacketCryptorManagerFactory: DataPacketCryptorManager.Factory, ) { private var room: Room? = null - private var keyProvider: KeyProvider - private var peerConnectionFactory: PeerConnectionFactory private var frameCryptors = mutableMapOf, FrameCryptor>() private var algorithm: FrameCryptorAlgorithm = FrameCryptorAlgorithm.AES_GCM private lateinit var emitEvent: (roomEvent: RoomEvent) -> Unit? + + internal var dataPacketCryptorManager: DataPacketCryptorManager = dataPacketCryptorManagerFactory.create(keyProvider) + var enabled: Boolean = false + set(value) { + field = value + for (item in frameCryptors.entries) { + val frameCryptor = item.value + frameCryptor.isEnabled = enabled + } + } - init { - this.keyProvider = keyProvider - this.peerConnectionFactory = peerConnectionFactory + /** + * Enables data channel encryption. Decryption is always enabled for forward compatibility. + */ + @Beta + var dataChannelEncryptionEnabled = false + + fun isDataChannelEncryptionEnabled(): Boolean { + return enabled && dataChannelEncryptionEnabled } fun keyProvider(): KeyProvider { return this.keyProvider } - suspend fun setup(room: Room, emitEvent: (roomEvent: RoomEvent) -> Unit) { - if (this.room != room) { + fun setup(room: Room, emitEvent: (roomEvent: RoomEvent) -> Unit) { + if (this.room != room && this.room != null) { // E2EEManager already setup, clean up first - cleanUp() + cleanup() } this.enabled = true this.room = room this.emitEvent = emitEvent this.room?.localParticipant?.trackPublications?.forEach { item -> - var participant = this.room!!.localParticipant - var publication = item.value + val participant = this.room!!.localParticipant + val publication = item.value if (publication.track != null) { addPublishedTrack(publication.track!!, publication, participant, room) } } this.room?.remoteParticipants?.forEach { item -> - var participant = item.value + val participant = item.value participant.trackPublications.forEach { item -> - var publication = item.value + val publication = item.value if (publication.track != null) { addSubscribedTrack(publication.track!!, publication, participant, room) } @@ -89,14 +104,14 @@ constructor( } fun addSubscribedTrack(track: Track, publication: TrackPublication, participant: RemoteParticipant, room: Room) { - var rtpReceiver: RtpReceiver? = when (publication.track!!) { + val rtpReceiver: RtpReceiver? = when (publication.track!!) { is RemoteAudioTrack -> (publication.track!! as RemoteAudioTrack).receiver is RemoteVideoTrack -> (publication.track!! as RemoteVideoTrack).receiver else -> { throw IllegalArgumentException("unsupported track type") } } - var frameCryptor = addRtpReceiver(rtpReceiver!!, participant.identity!!, publication.sid, publication.track!!.kind.name.lowercase()) + val frameCryptor = addRtpReceiver(rtpReceiver!!, participant.identity!!, publication.sid, publication.track!!.kind.name.lowercase()) frameCryptor.setObserver { trackId, state -> LKLog.i { "Receiver::onFrameCryptionStateChanged: $trackId, state: $state" } emitEvent( @@ -112,9 +127,9 @@ constructor( } fun removeSubscribedTrack(track: Track, publication: TrackPublication, participant: RemoteParticipant, room: Room) { - var trackId = publication.sid - var participantId = participant.identity - var frameCryptor = frameCryptors.get(trackId to participantId) + val trackId = publication.sid + val participantId = participant.identity + val frameCryptor = frameCryptors.get(trackId to participantId) if (frameCryptor != null) { frameCryptor.isEnabled = false frameCryptor.dispose() @@ -123,7 +138,7 @@ constructor( } fun addPublishedTrack(track: Track, publication: TrackPublication, participant: LocalParticipant, room: Room) { - var rtpSender: RtpSender? = when (publication.track!!) { + val rtpSender: RtpSender? = when (publication.track!!) { is LocalAudioTrack -> (publication.track!! as LocalAudioTrack)?.sender is LocalVideoTrack -> (publication.track!! as LocalVideoTrack)?.sender else -> { @@ -131,12 +146,12 @@ constructor( } } ?: throw IllegalArgumentException("rtpSender is null") - var frameCryptor = addRtpSender(rtpSender!!, participant.identity!!, publication.sid, publication.track!!.kind.name.lowercase()) + val frameCryptor = addRtpSender(rtpSender!!, participant.identity!!, publication.sid, publication.track!!.kind.name.lowercase()) frameCryptor.setObserver { trackId, state -> LKLog.i { "Sender::onFrameCryptionStateChanged: $trackId, state: $state" } emitEvent( RoomEvent.TrackE2EEStateEvent( - room!!, + room, publication.track!!, publication, participant, @@ -147,9 +162,9 @@ constructor( } fun removePublishedTrack(track: Track, publication: TrackPublication, participant: LocalParticipant, room: Room) { - var trackId = publication.sid - var participantId = participant.identity - var frameCryptor = frameCryptors.get(trackId to participantId) + val trackId = publication.sid + val participantId = participant.identity + val frameCryptor = frameCryptors.get(trackId to participantId) if (frameCryptor != null) { frameCryptor.isEnabled = false frameCryptor.dispose() @@ -171,7 +186,7 @@ constructor( } private fun addRtpSender(sender: RtpSender, participantId: Participant.Identity, trackId: String, kind: String): FrameCryptor { - var frameCryptor = FrameCryptorFactory.createFrameCryptorForRtpSender( + val frameCryptor = FrameCryptorFactory.createFrameCryptorForRtpSender( peerConnectionFactory, sender, participantId.value, @@ -180,12 +195,13 @@ constructor( ) frameCryptors[trackId to participantId] = frameCryptor - frameCryptor.setEnabled(enabled) + frameCryptor.isEnabled = enabled + frameCryptor.keyIndex = keyProvider.getLatestKeyIndex(participantId.value) return frameCryptor } private fun addRtpReceiver(receiver: RtpReceiver, participantId: Participant.Identity, trackId: String, kind: String): FrameCryptor { - var frameCryptor = FrameCryptorFactory.createFrameCryptorForRtpReceiver( + val frameCryptor = FrameCryptorFactory.createFrameCryptorForRtpReceiver( peerConnectionFactory, receiver, participantId.value, @@ -194,7 +210,8 @@ constructor( ) frameCryptors[trackId to participantId] = frameCryptor - frameCryptor.setEnabled(enabled) + frameCryptor.isEnabled = enabled + frameCryptor.keyIndex = keyProvider.getLatestKeyIndex(participantId.value) return frameCryptor } @@ -204,27 +221,43 @@ constructor( */ fun enableE2EE(enabled: Boolean) { this.enabled = enabled - for (item in frameCryptors.entries) { - var frameCryptor = item.value - frameCryptor.setEnabled(enabled) - } } /** * Ratchet key for local participant */ fun ratchetKey() { - var newKey = keyProvider.ratchetSharedKey() + val newKey = keyProvider.ratchetSharedKey() LKLog.d { "ratchetSharedKey: newKey: $newKey" } } - fun cleanUp() { + internal fun cleanup() { for (frameCryptor in frameCryptors.values) { frameCryptor.dispose() } frameCryptors.clear() } + internal fun dispose() { + dataPacketCryptorManager.dispose() + } + + fun encrypt(byteArray: ByteArray): EncryptedPacket? { + val participantId = room?.localParticipant?.identity ?: Participant.Identity("") + return dataPacketCryptorManager.encrypt( + participantId, + keyIndex = keyProvider.getLatestKeyIndex(participantId.value), + payload = byteArray, + ) + } + + fun decrypt(participantId: Participant.Identity, packet: EncryptedPacket): ByteArray? { + return dataPacketCryptorManager.decrypt( + participantId = participantId, + packet = packet, + ) + } + @AssistedFactory interface Factory { fun create( diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/KeyProvider.kt b/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/KeyProvider.kt index b57821c39..a30828f4a 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/KeyProvider.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/e2ee/KeyProvider.kt @@ -20,7 +20,7 @@ import io.livekit.android.util.LKLog import livekit.org.webrtc.FrameCryptorFactory import livekit.org.webrtc.FrameCryptorKeyProvider -class KeyInfo(var participantId: String, var keyIndex: Int, var key: String) { +internal class KeyInfo(var participantId: String, var keyIndex: Int, var key: String) { override fun toString(): String { return "KeyInfo(participantId='$participantId', keyIndex=$keyIndex)" } @@ -34,6 +34,7 @@ interface KeyProvider { fun ratchetKey(participantId: String, keyIndex: Int? = 0): ByteArray fun exportKey(participantId: String, keyIndex: Int? = 0): ByteArray fun setSifTrailer(trailer: ByteArray) + fun getLatestKeyIndex(participantId: String): Int val rtcKeyProvider: FrameCryptorKeyProvider @@ -49,21 +50,19 @@ class BaseKeyProvider( keyRingSize: Int = defaultKeyRingSize, discardFrameWhenCryptorNotReady: Boolean = defaultDiscardFrameWhenCryptorNotReady, ) : KeyProvider { - override val rtcKeyProvider: FrameCryptorKeyProvider - - private var keys: MutableMap> = mutableMapOf() - - init { - this.rtcKeyProvider = FrameCryptorFactory.createFrameCryptorKeyProvider( - enableSharedKey, - ratchetSalt.toByteArray(), - ratchetWindowSize, - uncryptedMagicBytes.toByteArray(), - failureTolerance, - keyRingSize, - discardFrameWhenCryptorNotReady, - ) - } + + private val latestSetIndex = mutableMapOf() + + override val rtcKeyProvider: FrameCryptorKeyProvider = FrameCryptorFactory.createFrameCryptorKeyProvider( + enableSharedKey, + ratchetSalt.toByteArray(), + ratchetWindowSize, + uncryptedMagicBytes.toByteArray(), + failureTolerance, + keyRingSize, + discardFrameWhenCryptorNotReady, + ) + override fun setSharedKey(key: String, keyIndex: Int?): Boolean { return rtcKeyProvider.setSharedKey(keyIndex ?: 0, key.toByteArray()) } @@ -92,13 +91,10 @@ class BaseKeyProvider( return } - var keyInfo = KeyInfo(participantId, keyIndex ?: 0, key) + val keyIndex = keyIndex ?: 0 + latestSetIndex[participantId] = keyIndex - if (!keys.containsKey(keyInfo.participantId)) { - keys[keyInfo.participantId] = mutableMapOf() - } - keys[keyInfo.participantId]!![keyInfo.keyIndex] = keyInfo.key - rtcKeyProvider.setKey(participantId, keyInfo.keyIndex, key.toByteArray()) + rtcKeyProvider.setKey(participantId, keyIndex, key.toByteArray()) } override fun ratchetKey(participantId: String, keyIndex: Int?): ByteArray { @@ -112,4 +108,8 @@ class BaseKeyProvider( override fun setSifTrailer(trailer: ByteArray) { rtcKeyProvider.setSifTrailer(trailer) } + + override fun getLatestKeyIndex(participantId: String): Int { + return latestSetIndex[participantId] ?: 0 + } } diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/events/ParticipantEvent.kt b/livekit-android-sdk/src/main/java/io/livekit/android/events/ParticipantEvent.kt index 2d2489399..d189cc9dc 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/events/ParticipantEvent.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/events/ParticipantEvent.kt @@ -26,6 +26,7 @@ import io.livekit.android.room.track.Track import io.livekit.android.room.track.TrackException import io.livekit.android.room.track.TrackPublication import io.livekit.android.room.types.TranscriptionSegment +import livekit.LivekitModels sealed class ParticipantEvent(open val participant: Participant) : Event() { // all participants @@ -155,6 +156,7 @@ sealed class ParticipantEvent(open val participant: Participant) : Event() { override val participant: RemoteParticipant, val data: ByteArray, val topic: String?, + val encryptionType: LivekitModels.Encryption.Type, ) : ParticipantEvent(participant) /** diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/events/RoomEvent.kt b/livekit-android-sdk/src/main/java/io/livekit/android/events/RoomEvent.kt index 9f57e4617..0181b0cb6 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/events/RoomEvent.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/events/RoomEvent.kt @@ -213,7 +213,13 @@ sealed class RoomEvent(val room: Room) : Event() { * @param data the published data * @param participant the participant if available */ - class DataReceived(room: Room, val data: ByteArray, val participant: RemoteParticipant?, val topic: String?) : + class DataReceived( + room: Room, + val data: ByteArray, + val participant: RemoteParticipant?, + val topic: String?, + val encryptionType: LivekitModels.Encryption.Type, + ) : RoomEvent(room) /** diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt index 3ca981080..1db8bbbeb 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/RTCEngine.kt @@ -24,8 +24,12 @@ import com.vdurmont.semver4j.Semver import io.livekit.android.ConnectOptions import io.livekit.android.RoomOptions import io.livekit.android.dagger.InjectionNames +import io.livekit.android.e2ee.DataPacketCryptorManager +import io.livekit.android.e2ee.E2EEManager +import io.livekit.android.e2ee.EncryptedPacket import io.livekit.android.events.DisconnectReason import io.livekit.android.events.convert +import io.livekit.android.room.participant.Participant import io.livekit.android.room.participant.ParticipantTrackPermission import io.livekit.android.room.track.TrackException import io.livekit.android.room.util.MediaConstraintKeys @@ -109,6 +113,7 @@ internal constructor( @Named(InjectionNames.DISPATCHER_IO) private val ioDispatcher: CoroutineDispatcher, private val rtcThreadToken: RTCThreadToken, + private val dataPacketCryptorFactory: DataPacketCryptorManager.Factory, ) : SignalClient.Listener { internal var listener: Listener? = null @@ -190,6 +195,10 @@ internal constructor( private var coroutineScope = CloseableCoroutineScope(SupervisorJob() + ioDispatcher) + internal var e2EEManager: E2EEManager? = null + private val dataPacketCryptorManager: DataPacketCryptorManager? + get() = e2EEManager?.dataPacketCryptorManager + /** * Note: If this lock is ever used in conjunction with the RTC thread, * this must be grabbed on the RTC thread to prevent deadlocks. @@ -671,6 +680,28 @@ internal constructor( try { // Redeclare to make variable var dataPacket = dataPacket + + val e2EEManager = e2EEManager + val dataEncryptionEnabled = e2EEManager?.isDataChannelEncryptionEnabled() ?: false + if (dataEncryptionEnabled && e2EEManager != null) { + val encryptedPacketPayload = dataPacket.asEncryptedPacketPayload() + if (encryptedPacketPayload != null) { + val encryptedData = e2EEManager.encrypt(encryptedPacketPayload.toByteArray()) + if (encryptedData != null) { + dataPacket = with(dataPacket.toBuilder()) { + encryptedPacket = with(LivekitModels.EncryptedPacket.newBuilder()) { + encryptedValue = ByteString.copyFrom(encryptedData.payload) + iv = ByteString.copyFrom(encryptedData.iv) + keyIndex = encryptedData.keyIndex + encryptionType = LivekitModels.Encryption.Type.GCM + build() + } + build() + } + } + } + } + if (dataPacket.kind == LivekitModels.DataPacket.Kind.RELIABLE) { dataPacket = dataPacket.toBuilder() .setSequence(reliableDataSequence) @@ -900,7 +931,7 @@ internal constructor( fun onRoomUpdate(update: LivekitModels.Room) fun onConnectionQuality(updates: List) fun onSpeakersChanged(speakers: List) - fun onUserPacket(packet: LivekitModels.UserPacket, kind: LivekitModels.DataPacket.Kind) + fun onUserPacket(packet: LivekitModels.UserPacket, kind: LivekitModels.DataPacket.Kind, encryptionType: LivekitModels.Encryption.Type) fun onStreamStateUpdate(streamStates: List) fun onSubscribedQualityUpdate(subscribedQualityUpdate: LivekitRtc.SubscribedQualityUpdate) fun onSubscriptionPermissionUpdate(subscriptionPermissionUpdate: LivekitRtc.SubscriptionPermissionUpdate) @@ -911,7 +942,7 @@ internal constructor( fun onTranscriptionReceived(transcription: LivekitModels.Transcription) fun onLocalTrackSubscribed(trackSubscribed: LivekitRtc.TrackSubscribed) fun onRpcPacketReceived(dp: LivekitModels.DataPacket) - fun onDataStreamPacket(dp: LivekitModels.DataPacket) + fun onDataStreamPacket(dp: LivekitModels.DataPacket, encryptionType: LivekitModels.Encryption.Type) } companion object { @@ -1154,7 +1185,7 @@ internal constructor( if (buffer == null) { return } - val dp = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffer.data)) + var dp = LivekitModels.DataPacket.parseFrom(ByteString.copyFrom(buffer.data)) if (dp.sequence > 0 && dp.participantSid.isNotEmpty()) { synchronized(reliableStateLock) { @@ -1166,13 +1197,38 @@ internal constructor( this.reliableReceivedState[dp.participantSid] = dp.sequence } } + + // Always decrypt if able, to allow for backward compatibility. + val dataPacketCryptor = dataPacketCryptorManager + var encryptionType = LivekitModels.Encryption.Type.NONE + if (dp.hasEncryptedPacket() && dataPacketCryptor != null) { + val encryptedPacket = EncryptedPacket( + dp.encryptedPacket.encryptedValue.toByteArray(), + dp.encryptedPacket.iv.toByteArray(), + dp.encryptedPacket.keyIndex, + ) + encryptionType = dp.encryptedPacket.encryptionType + + val decryptedData = dataPacketCryptor.decrypt(Participant.Identity(dp.participantIdentity), encryptedPacket) + if (decryptedData == null) { + LKLog.i { "Failed to decrypt data packet." } + return + } + val payload = LivekitModels.EncryptedPacketPayload.parseFrom(decryptedData) + + dp = with(dp.toBuilder()) { + setFromEncryptedPayload(payload) + build() + } + } + when (dp.valueCase) { LivekitModels.DataPacket.ValueCase.SPEAKER -> { listener?.onActiveSpeakersUpdate(dp.speaker.speakersList) } LivekitModels.DataPacket.ValueCase.USER -> { - listener?.onUserPacket(dp.user, dp.kind) + listener?.onUserPacket(dp.user, dp.kind, encryptionType) } LivekitModels.DataPacket.ValueCase.SIP_DTMF -> { @@ -1198,17 +1254,21 @@ internal constructor( listener?.onRpcPacketReceived(dp) } - LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET, - null, - -> { - LKLog.v { "invalid value for data packet" } - } - LivekitModels.DataPacket.ValueCase.STREAM_HEADER, LivekitModels.DataPacket.ValueCase.STREAM_CHUNK, LivekitModels.DataPacket.ValueCase.STREAM_TRAILER, -> { - listener?.onDataStreamPacket(dp) + listener?.onDataStreamPacket(dp, encryptionType) + } + + LivekitModels.DataPacket.ValueCase.ENCRYPTED_PACKET -> { + // should be handled above. + } + + LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET, + null, + -> { + LKLog.v { "invalid value for data packet" } } } } @@ -1347,6 +1407,9 @@ enum class ReconnectType { FORCE_FULL_RECONNECT, } +/** + * @suppress + */ fun LivekitRtc.ICEServer.toWebrtc(): PeerConnection.IceServer = PeerConnection.IceServer.builder(urlsList) .setUsername(username ?: "") .setPassword(credential ?: "") @@ -1355,3 +1418,106 @@ fun LivekitRtc.ICEServer.toWebrtc(): PeerConnection.IceServer = PeerConnection.I .createIceServer() typealias PeerConnectionStateListener = (PeerConnectionState) -> Unit + +internal fun LivekitModels.DataPacket.asEncryptedPacketPayload(): LivekitModels.EncryptedPacketPayload? { + return when (valueCase) { + LivekitModels.DataPacket.ValueCase.USER -> { + LivekitModels.EncryptedPacketPayload.newBuilder() + .setUser(this.user) + .build() + } + + LivekitModels.DataPacket.ValueCase.RPC_REQUEST -> { + LivekitModels.EncryptedPacketPayload.newBuilder() + .setRpcRequest(this.rpcRequest) + .build() + } + + LivekitModels.DataPacket.ValueCase.RPC_ACK -> { + LivekitModels.EncryptedPacketPayload.newBuilder() + .setRpcAck(this.rpcAck) + .build() + } + + LivekitModels.DataPacket.ValueCase.RPC_RESPONSE -> { + LivekitModels.EncryptedPacketPayload.newBuilder() + .setRpcResponse(this.rpcResponse) + .build() + } + + LivekitModels.DataPacket.ValueCase.STREAM_HEADER -> { + LivekitModels.EncryptedPacketPayload.newBuilder() + .setStreamHeader(this.streamHeader) + .build() + } + + LivekitModels.DataPacket.ValueCase.STREAM_CHUNK -> { + LivekitModels.EncryptedPacketPayload.newBuilder() + .setStreamChunk(this.streamChunk) + .build() + } + + LivekitModels.DataPacket.ValueCase.STREAM_TRAILER -> { + LivekitModels.EncryptedPacketPayload.newBuilder() + .setStreamTrailer(this.streamTrailer) + .build() + } + + LivekitModels.DataPacket.ValueCase.CHAT_MESSAGE -> { + LivekitModels.EncryptedPacketPayload.newBuilder() + .setChatMessage(this.chatMessage) + .build() + } + + LivekitModels.DataPacket.ValueCase.METRICS, + LivekitModels.DataPacket.ValueCase.SIP_DTMF, + LivekitModels.DataPacket.ValueCase.SPEAKER, + LivekitModels.DataPacket.ValueCase.ENCRYPTED_PACKET, + LivekitModels.DataPacket.ValueCase.TRANSCRIPTION, + LivekitModels.DataPacket.ValueCase.VALUE_NOT_SET, + -> { + null + } + } +} + +internal fun LivekitModels.DataPacket.Builder.setFromEncryptedPayload(payload: LivekitModels.EncryptedPacketPayload) { + when (payload.valueCase) { + LivekitModels.EncryptedPacketPayload.ValueCase.USER -> { + this.user = payload.user + } + + LivekitModels.EncryptedPacketPayload.ValueCase.CHAT_MESSAGE -> { + this.chatMessage = payload.chatMessage + } + + LivekitModels.EncryptedPacketPayload.ValueCase.RPC_REQUEST -> { + this.rpcRequest = payload.rpcRequest + } + + LivekitModels.EncryptedPacketPayload.ValueCase.RPC_ACK -> { + this.rpcAck = payload.rpcAck + } + + LivekitModels.EncryptedPacketPayload.ValueCase.RPC_RESPONSE -> { + this.rpcResponse = payload.rpcResponse + } + + LivekitModels.EncryptedPacketPayload.ValueCase.STREAM_HEADER -> { + this.streamHeader = payload.streamHeader + } + + LivekitModels.EncryptedPacketPayload.ValueCase.STREAM_CHUNK -> { + this.streamChunk = payload.streamChunk + } + + LivekitModels.EncryptedPacketPayload.ValueCase.STREAM_TRAILER -> { + this.streamTrailer = payload.streamTrailer + } + + LivekitModels.EncryptedPacketPayload.ValueCase.VALUE_NOT_SET -> { + // decryption likely failed + LKLog.w { "Attempting to set from non-valid payload" } + } + } +} diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt index 69b31f909..e33cb67b5 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/Room.kt @@ -452,6 +452,7 @@ constructor( } } } + engine.e2EEManager = e2eeManager } } @@ -976,7 +977,7 @@ constructor( * Removes all participants and tracks from the room. */ private fun cleanupRoom() { - e2eeManager?.cleanUp() + e2eeManager?.dispose() e2eeManager = null localParticipant.cleanup() remoteParticipants.keys.toMutableSet() // copy keys to avoid concurrent modifications. @@ -1271,7 +1272,7 @@ constructor( /** * @suppress */ - override fun onUserPacket(packet: LivekitModels.UserPacket, kind: LivekitModels.DataPacket.Kind) { + override fun onUserPacket(packet: LivekitModels.UserPacket, kind: LivekitModels.DataPacket.Kind, encryptionType: LivekitModels.Encryption.Type) { val participant = getParticipantBySid(packet.participantSid) as? RemoteParticipant val data = packet.payload.toByteArray() val topic = if (packet.hasTopic()) { @@ -1280,25 +1281,26 @@ constructor( null } - eventBus.postEvent(RoomEvent.DataReceived(this, data, participant, topic), coroutineScope) - participant?.onDataReceived(data, topic) + val event = RoomEvent.DataReceived(this, data, participant, topic, encryptionType) + eventBus.postEvent(event, coroutineScope) + participant?.onDataReceived(event) } /** * @suppress */ - override fun onDataStreamPacket(dp: LivekitModels.DataPacket) { + override fun onDataStreamPacket(dp: LivekitModels.DataPacket, encryptionType: LivekitModels.Encryption.Type) { when (dp.valueCase) { LivekitModels.DataPacket.ValueCase.STREAM_HEADER -> { - incomingDataStreamManager.handleStreamHeader(dp.streamHeader, Participant.Identity(dp.participantIdentity)) + incomingDataStreamManager.handleStreamHeader(dp.streamHeader, Participant.Identity(dp.participantIdentity), encryptionType) } LivekitModels.DataPacket.ValueCase.STREAM_CHUNK -> { - incomingDataStreamManager.handleDataChunk(dp.streamChunk) + incomingDataStreamManager.handleDataChunk(dp.streamChunk, encryptionType) } LivekitModels.DataPacket.ValueCase.STREAM_TRAILER -> { - incomingDataStreamManager.handleStreamTrailer(dp.streamTrailer) + incomingDataStreamManager.handleStreamTrailer(dp.streamTrailer, encryptionType) } // Ignore other cases. diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/SignalClient.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/SignalClient.kt index 9320848c5..fb6a48f0e 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/SignalClient.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/SignalClient.kt @@ -784,6 +784,10 @@ constructor( LivekitRtc.SignalResponse.MessageCase.ROOM_MOVED -> { // TODO } + + LivekitRtc.SignalResponse.MessageCase.MEDIA_SECTIONS_REQUIREMENT -> { + // TODO + } } } diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamException.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamException.kt index ae8b8b583..9c5ccb57d 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamException.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamException.kt @@ -61,4 +61,9 @@ sealed class StreamException(message: String? = null) : Exception(message) { * Unable to read information about the file to send. */ class FileInfoUnavailableException : StreamException() + + /** + * + */ + class EncryptionTypeMismatch(message: String? = null) : StreamException(message) } diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamInfo.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamInfo.kt index dbb7ebe24..bcd1abf5f 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamInfo.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamInfo.kt @@ -27,6 +27,7 @@ sealed class StreamInfo( open val timestampMs: Long, open val totalSize: Long?, open val attributes: Map, + open val encryptionType: LivekitModels.Encryption.Type, ) data class TextStreamInfo( @@ -40,8 +41,9 @@ data class TextStreamInfo( val replyToStreamId: String?, val attachedStreamIds: List, val generated: Boolean, -) : StreamInfo(id, topic, timestampMs, totalSize, attributes) { - constructor(header: Header, textHeader: TextHeader) : this( + override val encryptionType: LivekitModels.Encryption.Type, +) : StreamInfo(id, topic, timestampMs, totalSize, attributes, encryptionType) { + constructor(header: Header, textHeader: TextHeader, encryptionType: LivekitModels.Encryption.Type) : this( id = header.streamId, topic = header.topic, timestampMs = header.timestamp, @@ -60,6 +62,7 @@ data class TextStreamInfo( }, attachedStreamIds = textHeader.attachedStreamIdsList ?: emptyList(), generated = textHeader.generated, + encryptionType = encryptionType, ) enum class OperationType { @@ -105,8 +108,9 @@ data class ByteStreamInfo( override val attributes: Map, val mimeType: String, val name: String?, -) : StreamInfo(id, topic, timestampMs, totalSize, attributes) { - constructor(header: Header, byteHeader: ByteHeader) : this( + override val encryptionType: LivekitModels.Encryption.Type, +) : StreamInfo(id, topic, timestampMs, totalSize, attributes, encryptionType) { + constructor(header: Header, byteHeader: ByteHeader, encryptionType: LivekitModels.Encryption.Type) : this( id = header.streamId, topic = header.topic, timestampMs = header.timestamp, @@ -118,5 +122,6 @@ data class ByteStreamInfo( attributes = header.attributesMap.toMap(), mimeType = header.mimeType, name = byteHeader.name, + encryptionType = encryptionType, ) } diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamOptions.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamOptions.kt index da1ea717a..0a987eecf 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamOptions.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/StreamOptions.kt @@ -17,18 +17,8 @@ package io.livekit.android.room.datastream import io.livekit.android.room.participant.Participant -import livekit.LivekitModels import java.util.UUID -interface StreamOptions { - val topic: String? - val attributes: Map? - val totalLength: Long? - val mimeType: String? - val encryptionType: LivekitModels.Encryption.Type? - val destinationIdentities: List -} - data class StreamTextOptions( val topic: String = "", val attributes: Map = emptyMap(), diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/incoming/IncomingDataStreamManager.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/incoming/IncomingDataStreamManager.kt index 8fab499ba..337c02acd 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/incoming/IncomingDataStreamManager.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/incoming/IncomingDataStreamManager.kt @@ -27,6 +27,7 @@ import io.livekit.android.util.LKLog import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.channels.BufferOverflow import kotlinx.coroutines.channels.Channel +import livekit.LivekitModels import livekit.LivekitModels.DataStream import java.util.Collections import javax.inject.Inject @@ -66,17 +67,17 @@ interface IncomingDataStreamManager { /** * @suppress */ - fun handleStreamHeader(header: DataStream.Header, fromIdentity: Participant.Identity) + fun handleStreamHeader(header: DataStream.Header, fromIdentity: Participant.Identity, encryptionType: LivekitModels.Encryption.Type) /** * @suppress */ - fun handleDataChunk(chunk: DataStream.Chunk) + fun handleDataChunk(chunk: DataStream.Chunk, encryptionType: LivekitModels.Encryption.Type) /** * @suppress */ - fun handleStreamTrailer(trailer: DataStream.Trailer) + fun handleStreamTrailer(trailer: DataStream.Trailer, encryptionType: LivekitModels.Encryption.Type) /** * @suppress @@ -89,12 +90,20 @@ interface IncomingDataStreamManager { */ class IncomingDataStreamManagerImpl @Inject constructor() : IncomingDataStreamManager { + /** + * A stream descriptor for any open incoming streams. + */ private data class Descriptor( val streamInfo: StreamInfo, /** * Measured by SystemClock.elapsedRealtime() */ val openTime: Long, + /** + * The channel to pipe any incoming data into. + * + * Calling [Channel.close] will automatically call [closeStream] and remove the stream. + */ val channel: Channel, var readLength: Long = 0, ) @@ -154,8 +163,8 @@ class IncomingDataStreamManagerImpl @Inject constructor() : IncomingDataStreamMa /** * @suppress */ - override fun handleStreamHeader(header: DataStream.Header, fromIdentity: Participant.Identity) { - val info = streamInfoFromHeader(header) ?: return + override fun handleStreamHeader(header: DataStream.Header, fromIdentity: Participant.Identity, encryptionType: LivekitModels.Encryption.Type) { + val info = streamInfoFromHeader(header, encryptionType) ?: return openStream(info, fromIdentity) } @@ -189,9 +198,18 @@ class IncomingDataStreamManagerImpl @Inject constructor() : IncomingDataStreamMa /** * @suppress */ - override fun handleDataChunk(chunk: DataStream.Chunk) { + override fun handleDataChunk(chunk: DataStream.Chunk, encryptionType: LivekitModels.Encryption.Type) { val content = chunk.content ?: return val descriptor = openStreams[chunk.streamId] ?: return + + if (encryptionType != descriptor.streamInfo.encryptionType) { + descriptor.channel.close( + StreamException.EncryptionTypeMismatch( + "Encryption type mismatch for stream ${chunk.streamId}. Expected ${descriptor.streamInfo.encryptionType}, got $encryptionType", + ), + ) + } + val totalReadLength = descriptor.readLength + content.size() val totalLength = descriptor.streamInfo.totalSize @@ -208,13 +226,21 @@ class IncomingDataStreamManagerImpl @Inject constructor() : IncomingDataStreamMa /** * @suppress */ - override fun handleStreamTrailer(trailer: DataStream.Trailer) { + override fun handleStreamTrailer(trailer: DataStream.Trailer, encryptionType: LivekitModels.Encryption.Type) { val descriptor = openStreams[trailer.streamId] if (descriptor == null) { LKLog.w { "Received trailer for unknown stream: ${trailer.streamId}" } return } + if (encryptionType != descriptor.streamInfo.encryptionType) { + descriptor.channel.close( + StreamException.EncryptionTypeMismatch( + "Encryption type mismatch for stream ${trailer.streamId}. Expected ${descriptor.streamInfo.encryptionType}, got $encryptionType", + ), + ) + } + val totalLength = descriptor.streamInfo.totalSize if (totalLength != null) { if (descriptor.readLength != totalLength) { @@ -294,15 +320,15 @@ class IncomingDataStreamManagerImpl @Inject constructor() : IncomingDataStreamMa } } - private fun streamInfoFromHeader(header: DataStream.Header): StreamInfo? { + private fun streamInfoFromHeader(header: DataStream.Header, encryptionType: LivekitModels.Encryption.Type): StreamInfo? { try { return when (header.contentHeaderCase) { DataStream.Header.ContentHeaderCase.TEXT_HEADER -> { - TextStreamInfo(header, header.textHeader) + TextStreamInfo(header, header.textHeader, encryptionType) } DataStream.Header.ContentHeaderCase.BYTE_HEADER -> { - ByteStreamInfo(header, header.byteHeader) + ByteStreamInfo(header, header.byteHeader, encryptionType) } DataStream.Header.ContentHeaderCase.CONTENTHEADER_NOT_SET, diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/outgoing/OutgoingDataStreamManager.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/outgoing/OutgoingDataStreamManager.kt index 7e903cc84..dcdf557b5 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/outgoing/OutgoingDataStreamManager.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/datastream/outgoing/OutgoingDataStreamManager.kt @@ -27,6 +27,7 @@ import io.livekit.android.room.datastream.StreamTextOptions import io.livekit.android.room.datastream.TextStreamInfo import io.livekit.android.room.participant.Participant import io.livekit.android.util.LKLog +import livekit.LivekitModels import livekit.LivekitModels.DataPacket import livekit.LivekitModels.DataStream import java.io.File @@ -242,6 +243,11 @@ constructor( replyToStreamId = options.replyToStreamId, attachedStreamIds = options.attachedStreamIds, generated = false, + encryptionType = if (engine.e2EEManager?.isDataChannelEncryptionEnabled() ?: false) { + LivekitModels.Encryption.Type.GCM + } else { + LivekitModels.Encryption.Type.NONE + }, ) val streamId = options.streamId @@ -267,6 +273,11 @@ constructor( attributes = options.attributes, mimeType = options.mimeType, name = options.name, + encryptionType = if (engine.e2EEManager?.isDataChannelEncryptionEnabled() ?: false) { + LivekitModels.Encryption.Type.GCM + } else { + LivekitModels.Encryption.Type.NONE + }, ) val streamId = options.streamId diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/room/participant/RemoteParticipant.kt b/livekit-android-sdk/src/main/java/io/livekit/android/room/participant/RemoteParticipant.kt index 00e57ad37..b4dc649c7 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/room/participant/RemoteParticipant.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/room/participant/RemoteParticipant.kt @@ -21,6 +21,7 @@ import dagger.assisted.AssistedFactory import dagger.assisted.AssistedInject import io.livekit.android.dagger.InjectionNames import io.livekit.android.events.ParticipantEvent +import io.livekit.android.events.RoomEvent import io.livekit.android.room.SignalClient import io.livekit.android.room.track.KIND_AUDIO import io.livekit.android.room.track.KIND_VIDEO @@ -233,7 +234,7 @@ class RemoteParticipant( } // Internal methods just for posting events. - internal fun onDataReceived(data: ByteArray, topic: String?) { - eventBus.postEvent(ParticipantEvent.DataReceived(this, data, topic), scope) + internal fun onDataReceived(event: RoomEvent.DataReceived) { + eventBus.postEvent(ParticipantEvent.DataReceived(this, event.data, event.topic, event.encryptionType), scope) } } diff --git a/livekit-android-sdk/src/main/java/io/livekit/android/webrtc/peerconnection/RTCThreadUtils.kt b/livekit-android-sdk/src/main/java/io/livekit/android/webrtc/peerconnection/RTCThreadUtils.kt index 21657ab81..6e14c3471 100644 --- a/livekit-android-sdk/src/main/java/io/livekit/android/webrtc/peerconnection/RTCThreadUtils.kt +++ b/livekit-android-sdk/src/main/java/io/livekit/android/webrtc/peerconnection/RTCThreadUtils.kt @@ -145,6 +145,9 @@ interface RTCThreadToken { val isDisposed: Boolean } +/** + * This token should live throughout the lifecycle of the underlying PeerConnectionFactory. + */ internal class RTCThreadTokenImpl( private val peerConnectionFactoryManager: PeerConnectionFactoryManager, ) : RTCThreadToken { diff --git a/livekit-android-test/src/main/java/io/livekit/android/test/mock/dagger/TestRTCModule.kt b/livekit-android-test/src/main/java/io/livekit/android/test/mock/dagger/TestRTCModule.kt index e3d1e2720..0ed771c7e 100644 --- a/livekit-android-test/src/main/java/io/livekit/android/test/mock/dagger/TestRTCModule.kt +++ b/livekit-android-test/src/main/java/io/livekit/android/test/mock/dagger/TestRTCModule.kt @@ -28,9 +28,12 @@ import io.livekit.android.audio.AudioRecordSamplesDispatcher import io.livekit.android.audio.NoAudioRecordPrewarmer import io.livekit.android.dagger.CapabilitiesGetter import io.livekit.android.dagger.InjectionNames +import io.livekit.android.e2ee.DataPacketCryptorManager +import io.livekit.android.e2ee.KeyProvider import io.livekit.android.test.mock.MockAudioDeviceModule import io.livekit.android.test.mock.MockAudioProcessingController import io.livekit.android.test.mock.MockEglBase +import io.livekit.android.test.mock.e2ee.ReversingDataPacketCryptorManager import io.livekit.android.webrtc.PeerConnectionFactoryManager import io.livekit.android.webrtc.peerconnection.RTCThreadToken import livekit.org.webrtc.EglBase @@ -128,4 +131,11 @@ object TestRTCModule { @Provides fun sdpFactory() = SdpFactory.getInstance() + + @Provides + fun dataPacketCryptorManagerFactory(): DataPacketCryptorManager.Factory = object : DataPacketCryptorManager.Factory { + override fun create(keyProvider: KeyProvider): DataPacketCryptorManager { + return ReversingDataPacketCryptorManager() + } + } } diff --git a/livekit-android-test/src/main/java/io/livekit/android/test/mock/e2ee/NoopKeyProvider.kt b/livekit-android-test/src/main/java/io/livekit/android/test/mock/e2ee/NoopKeyProvider.kt new file mode 100644 index 000000000..eb7aa744b --- /dev/null +++ b/livekit-android-test/src/main/java/io/livekit/android/test/mock/e2ee/NoopKeyProvider.kt @@ -0,0 +1,53 @@ +/* + * Copyright 2025 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.test.mock.e2ee + +import io.livekit.android.e2ee.KeyProvider +import livekit.org.webrtc.FrameCryptorKeyProvider +import org.mockito.kotlin.mock + +class NoopKeyProvider(override val rtcKeyProvider: FrameCryptorKeyProvider = mock(), override var enableSharedKey: Boolean = true) : KeyProvider { + override fun setSharedKey(key: String, keyIndex: Int?): Boolean { + return true + } + + override fun ratchetSharedKey(keyIndex: Int?): ByteArray { + return ByteArray(0) + } + + override fun exportSharedKey(keyIndex: Int?): ByteArray { + return ByteArray(0) + } + + override fun setKey(key: String, participantId: String?, keyIndex: Int?) { + } + + override fun ratchetKey(participantId: String, keyIndex: Int?): ByteArray { + return ByteArray(0) + } + + override fun exportKey(participantId: String, keyIndex: Int?): ByteArray { + return ByteArray(0) + } + + override fun setSifTrailer(trailer: ByteArray) { + } + + override fun getLatestKeyIndex(participantId: String): Int { + return 0 + } +} diff --git a/livekit-android-test/src/main/java/io/livekit/android/test/mock/e2ee/ReversingDataPacketCryptorManager.kt b/livekit-android-test/src/main/java/io/livekit/android/test/mock/e2ee/ReversingDataPacketCryptorManager.kt new file mode 100644 index 000000000..672df7778 --- /dev/null +++ b/livekit-android-test/src/main/java/io/livekit/android/test/mock/e2ee/ReversingDataPacketCryptorManager.kt @@ -0,0 +1,42 @@ +/* + * Copyright 2025 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.test.mock.e2ee + +import io.livekit.android.e2ee.DataPacketCryptorManager +import io.livekit.android.e2ee.EncryptedPacket +import io.livekit.android.room.participant.Participant + +class ReversingDataPacketCryptorManager : DataPacketCryptorManager { + override fun encrypt( + participantId: Participant.Identity, + keyIndex: Int, + payload: ByteArray, + ): EncryptedPacket? { + return EncryptedPacket( + payload = payload.reversedArray(), + iv = "$participantId,$keyIndex".toByteArray(), + keyIndex = keyIndex, + ) + } + + override fun decrypt(participantId: Participant.Identity, packet: EncryptedPacket): ByteArray? { + return packet.payload.reversedArray() + } + + override fun dispose() { + } +} diff --git a/livekit-android-test/src/test/java/io/livekit/android/e2ee/DataPacketCryptorMockE2ETest.kt b/livekit-android-test/src/test/java/io/livekit/android/e2ee/DataPacketCryptorMockE2ETest.kt new file mode 100644 index 000000000..77563d166 --- /dev/null +++ b/livekit-android-test/src/test/java/io/livekit/android/e2ee/DataPacketCryptorMockE2ETest.kt @@ -0,0 +1,126 @@ +/* + * Copyright 2025 LiveKit, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.livekit.android.e2ee + +import com.google.protobuf.ByteString +import io.livekit.android.events.RoomEvent +import io.livekit.android.room.RTCEngine +import io.livekit.android.room.participant.Participant +import io.livekit.android.test.MockE2ETest +import io.livekit.android.test.events.EventCollector +import io.livekit.android.test.mock.MockDataChannel +import io.livekit.android.test.mock.MockPeerConnection +import io.livekit.android.test.mock.e2ee.NoopKeyProvider +import io.livekit.android.test.mock.e2ee.ReversingDataPacketCryptorManager +import kotlinx.coroutines.ExperimentalCoroutinesApi +import livekit.LivekitModels +import livekit.LivekitModels.DataPacket +import livekit.LivekitRtc +import livekit.org.webrtc.DataChannel +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Test +import java.nio.ByteBuffer + +@OptIn(ExperimentalCoroutinesApi::class) +class DataPacketCryptorMockE2ETest : MockE2ETest() { + + lateinit var pubDataChannel: MockDataChannel + lateinit var subDataChannel: MockDataChannel + + override suspend fun connect(joinResponse: LivekitRtc.SignalResponse) { + super.connect(joinResponse) + + val pubPeerConnection = component.rtcEngine().getPublisherPeerConnection() as MockPeerConnection + pubDataChannel = pubPeerConnection.dataChannels[RTCEngine.RELIABLE_DATA_CHANNEL_LABEL] as MockDataChannel + + val subPeerConnection = component.rtcEngine().getSubscriberPeerConnection() as MockPeerConnection + subDataChannel = MockDataChannel(RTCEngine.RELIABLE_DATA_CHANNEL_LABEL) + subPeerConnection.observer?.onDataChannel(subDataChannel) + } + + @Test + fun sendsDataEncrypted() = runTest { + room.e2eeOptions = E2EEOptions( + keyProvider = NoopKeyProvider(), + ) + + connect() + + room.e2eeManager?.dataChannelEncryptionEnabled = true + + // The mock data cryptor used just reverses the bytes. + val data = "1234".toByteArray() + assertTrue(room.localParticipant.publishData(data).isSuccess) + + assertEquals(1, pubDataChannel.sentBuffers.size) + + val encryptedPacket = DataPacket.parseFrom(ByteString.copyFrom(pubDataChannel.sentBuffers[0].data)) + assertTrue(encryptedPacket.hasEncryptedPacket()) + + val dataCryptor = ReversingDataPacketCryptorManager() + val decryptedBytes = dataCryptor.decrypt(Participant.Identity(""), encryptedPacket.encryptedPacket.toSdkType()) + val payload = LivekitModels.EncryptedPacketPayload.parseFrom(decryptedBytes) + + assertTrue(data.contentEquals(payload.user.payload.toByteArray())) + } + + @Test + fun receivesDataDecrypted() = runTest { + room.e2eeOptions = E2EEOptions( + keyProvider = NoopKeyProvider(), + ) + + connect() + + // The mock data cryptor used just reverses the bytes. + val data = "1234".toByteArray() + assertTrue(room.localParticipant.publishData(data).isSuccess) + + val dataCryptor = ReversingDataPacketCryptorManager() + val encryptedPacketPayload = with(LivekitModels.EncryptedPacketPayload.newBuilder()) { + user = with(LivekitModels.UserPacket.newBuilder()) { + payload = ByteString.copyFrom(data) + build() + } + build() + } + val encrypted = dataCryptor.encrypt(Participant.Identity(""), 0, encryptedPacketPayload.toByteArray())!! + val dataPacket = with(DataPacket.newBuilder()) { + encryptedPacket = with(LivekitModels.EncryptedPacket.newBuilder()) { + encryptedValue = ByteString.copyFrom(encrypted.payload) + iv = ByteString.copyFrom(encrypted.iv) + keyIndex = encrypted.keyIndex + encryptionType = LivekitModels.Encryption.Type.CUSTOM + build() + } + build() + } + + val eventCollector = EventCollector(room.events, coroutineRule.scope) + + subDataChannel.simulateBufferReceived(DataChannel.Buffer(ByteBuffer.wrap(dataPacket.toByteArray()), true)) + + val events = eventCollector.stopCollecting() + + assertEquals(1, events.size) + val event = events[0] as RoomEvent.DataReceived + + assertTrue(data.contentEquals(event.data)) + assertEquals(LivekitModels.Encryption.Type.CUSTOM, event.encryptionType) + } +} diff --git a/livekit-android-test/src/test/java/io/livekit/android/room/datastream/StreamReaderTest.kt b/livekit-android-test/src/test/java/io/livekit/android/room/datastream/StreamReaderTest.kt index 044db9f2c..6b2fcff26 100644 --- a/livekit-android-test/src/test/java/io/livekit/android/room/datastream/StreamReaderTest.kt +++ b/livekit-android-test/src/test/java/io/livekit/android/room/datastream/StreamReaderTest.kt @@ -22,6 +22,7 @@ import io.livekit.android.test.BaseTest import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.runBlocking +import livekit.LivekitModels import org.junit.Assert.assertEquals import org.junit.Assert.assertTrue import org.junit.Before @@ -40,7 +41,16 @@ class StreamReaderTest : BaseTest() { channel.trySend(ByteArray(1) { 1 }) channel.trySend(ByteArray(1) { 2 }) channel.close() - val streamInfo = ByteStreamInfo(id = "id", topic = "topic", timestampMs = 3, totalSize = null, attributes = mapOf(), mimeType = "mime", name = null) + val streamInfo = ByteStreamInfo( + id = "id", + topic = "topic", + timestampMs = 3, + totalSize = null, + attributes = mapOf(), + mimeType = "mime", + name = null, + encryptionType = LivekitModels.Encryption.Type.NONE, + ) reader = ByteStreamReceiver(streamInfo, channel) } diff --git a/livekit-android-test/src/test/java/io/livekit/android/room/datastream/outgoing/ByteStreamSenderTest.kt b/livekit-android-test/src/test/java/io/livekit/android/room/datastream/outgoing/ByteStreamSenderTest.kt index 0c2a0415e..23c56bbcd 100644 --- a/livekit-android-test/src/test/java/io/livekit/android/room/datastream/outgoing/ByteStreamSenderTest.kt +++ b/livekit-android-test/src/test/java/io/livekit/android/room/datastream/outgoing/ByteStreamSenderTest.kt @@ -20,6 +20,7 @@ import io.livekit.android.room.datastream.ByteStreamInfo import io.livekit.android.test.BaseTest import io.livekit.android.test.mock.room.datastream.outgoing.MockStreamDestination import kotlinx.coroutines.ExperimentalCoroutinesApi +import livekit.LivekitModels import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertTrue @@ -83,5 +84,14 @@ class ByteStreamSenderTest : BaseTest() { assertTrue(sender.write(ByteArray(100)).isFailure) } - fun createInfo(): ByteStreamInfo = ByteStreamInfo(id = "stream_id", topic = "topic", timestampMs = 0, totalSize = null, attributes = mapOf(), mimeType = "", name = null) + fun createInfo(): ByteStreamInfo = ByteStreamInfo( + id = "stream_id", + topic = "topic", + timestampMs = 0, + totalSize = null, + attributes = mapOf(), + mimeType = "", + name = null, + encryptionType = LivekitModels.Encryption.Type.NONE, + ) } diff --git a/livekit-android-test/src/test/java/io/livekit/android/room/datastream/outgoing/TextStreamSenderTest.kt b/livekit-android-test/src/test/java/io/livekit/android/room/datastream/outgoing/TextStreamSenderTest.kt index b0b9bae99..9ac0a726b 100644 --- a/livekit-android-test/src/test/java/io/livekit/android/room/datastream/outgoing/TextStreamSenderTest.kt +++ b/livekit-android-test/src/test/java/io/livekit/android/room/datastream/outgoing/TextStreamSenderTest.kt @@ -20,6 +20,7 @@ import io.livekit.android.room.datastream.TextStreamInfo import io.livekit.android.test.BaseTest import io.livekit.android.test.mock.room.datastream.outgoing.MockStreamDestination import kotlinx.coroutines.ExperimentalCoroutinesApi +import livekit.LivekitModels import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse import org.junit.Assert.assertNotEquals @@ -113,5 +114,6 @@ class TextStreamSenderTest : BaseTest() { replyToStreamId = null, attachedStreamIds = listOf(), generated = false, + encryptionType = LivekitModels.Encryption.Type.NONE, ) } diff --git a/protocol b/protocol index 727624357..9bd0fc7c9 160000 --- a/protocol +++ b/protocol @@ -1 +1 @@ -Subproject commit 7276243574cd9bd8a621700d2f696cda5fd1b6bd +Subproject commit 9bd0fc7c95be67400c6086216d79f39cac791e5c diff --git a/sample-app-common/src/main/java/io/livekit/android/sample/CallViewModel.kt b/sample-app-common/src/main/java/io/livekit/android/sample/CallViewModel.kt index 5a7ef5f54..9f87c3de4 100644 --- a/sample-app-common/src/main/java/io/livekit/android/sample/CallViewModel.kt +++ b/sample-app-common/src/main/java/io/livekit/android/sample/CallViewModel.kt @@ -39,6 +39,8 @@ import io.livekit.android.e2ee.E2EEOptions import io.livekit.android.events.RoomEvent import io.livekit.android.events.collect import io.livekit.android.room.Room +import io.livekit.android.room.datastream.StreamTextOptions +import io.livekit.android.room.datastream.incoming.TextStreamReceiver import io.livekit.android.room.participant.LocalParticipant import io.livekit.android.room.participant.Participant import io.livekit.android.room.participant.RemoteParticipant @@ -153,6 +155,17 @@ class CallViewModel( } } + // Handling text streams + room.registerTextStreamHandler( + topic = "lk.chat", + handler = { receiver: TextStreamReceiver, identity: Participant.Identity -> + viewModelScope.launch { + val message = receiver.readAll().joinToString(separator = "") + mutableDataReceived.emit("$identity: $message") + } + }, + ) + viewModelScope.launch(Dispatchers.Default) { // Collect any errors. launch { @@ -177,6 +190,7 @@ class CallViewModel( when (it) { is RoomEvent.FailedToConnect -> mutableError.value = it.error is RoomEvent.DataReceived -> { + // Handling basic data packets. val identity = it.participant?.identity ?: "server" val message = it.data.toString(Charsets.UTF_8) mutableDataReceived.emit("$identity: $message") @@ -368,7 +382,7 @@ class CallViewModel( fun sendData(message: String) { viewModelScope.launch(Dispatchers.IO) { - room.localParticipant.publishData(message.toByteArray(Charsets.UTF_8)) + room.localParticipant.sendText(message, StreamTextOptions(topic = "lk.chat")) } }