Skip to content

Commit

Permalink
fixes & tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davidzhao committed Jan 14, 2024
1 parent 66e0c4b commit c6d7769
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 0 deletions.
3 changes: 3 additions & 0 deletions livekit-rtc/livekit/rtc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
)
from .video_source import VideoSource
from .video_stream import VideoStream
from .chat import ChatManager, ChatMessage

from .version import __version__

Expand Down Expand Up @@ -134,5 +135,7 @@
"VideoFrameBuffer",
"VideoSource",
"VideoStream",
"ChatManager",
"ChatMessage",
"__version__",
]
29 changes: 29 additions & 0 deletions livekit-rtc/livekit/rtc/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
from collections import deque
import ctypes
import random
from typing import Callable, Generic, List, TypeVar

logger = logging.getLogger("livekit")
Expand Down Expand Up @@ -101,3 +116,17 @@ async def join(self) -> None:
subs = self._subscribers.copy()
for queue in subs:
await queue.join()


_base62_characters = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"


def generate_random_base62(length=12):
"""
Generate a random base62 encoded string of a specified length.
:param length: The desired length of the base62 encoded string.
:return: A base62 encoded string.
"""
global _base62_characters
return "".join(random.choice(_base62_characters) for _ in range(length))
142 changes: 142 additions & 0 deletions livekit-rtc/livekit/rtc/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2023 LiveKit, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from datetime import datetime
import json
import logging
from typing import Any, Callable, Dict, Optional
import uuid

Check failure on line 20 in livekit-rtc/livekit/rtc/chat.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F401)

livekit-rtc/livekit/rtc/chat.py:20:8: F401 `uuid` imported but unused

from .room import Room, Participant, DataPacket
from ._proto.room_pb2 import DataPacketKind
from ._utils import generate_random_base62

_CHAT_TOPIC = "lk-chat-topic"
_CHAT_UPDATE_TOPIC = "lk-chat-update-topic"


class ChatManager:
"""A utility class that sends and receives chat messages in the active session.
It implements LiveKit Chat Protocol, and serializes data to/from JSON data packets.
"""

def __init__(
self, room: Room, *, on_message: Callable[["ChatMessage"], None] = None
):
self._lp = room.local_participant
self._room = room
self._callback: Callable[["ChatMessage"], None] = None

room.on("data_received", self._on_data_received)
if on_message:
self.on_message(on_message)

def close(self):
self._room.off("data_received", self._on_data_received)

async def send_message(self, message: str) -> "ChatMessage":
"""Send a chat message to the end user using LiveKit Chat Protocol.
Args:
message (str): the message to send
Returns:
ChatMessage: the message that was sent
"""
msg = ChatMessage(
message=message,
is_local=True,
participant=self._lp,
)
await self._lp.publish_data(
payload=json.dumps(msg.asjsondict()),
kind=DataPacketKind.KIND_RELIABLE,
topic=_CHAT_TOPIC,
)
return msg

async def update_message(self, message: "ChatMessage"):
"""Update a chat message that was previously sent.
If message.deleted is set to True, we'll signal to remote participants that the message
should be deleted.
"""
await self._lp.publish_data(
payload=json.dumps(message.asjsondict()),
kind=DataPacketKind.KIND_RELIABLE,
topic=_CHAT_UPDATE_TOPIC,
)

def on_message(self, callback: Callable[["ChatMessage"], None]):
"""Register a callback to be called when a chat message is received from the end user."""
self._callback = callback

def _on_data_received(self, dp: DataPacket):
# handle both new and updates the same way, as long as the ID is in there
# the user can decide how to replace the previous message
if dp.topic == _CHAT_TOPIC or dp.topic == _CHAT_UPDATE_TOPIC:
try:
parsed = json.loads(dp.data)
msg = ChatMessage.from_jsondict(parsed)
if dp.participant:
msg.participant = dp.participant
if self._callback:
self._callback(msg)
except Exception as e:
logging.warning(
"failed to parse chat message: %s", e, exc_info=e)


@dataclass
class ChatMessage:
message: str = None
id: str = field(default_factory=generate_random_base62)
timestamp: datetime = field(default_factory=datetime.now)
deleted: bool = field(default=False)

# These fields are not part of the wire protocol. They are here to provide
# context for the application.
participant: Optional[Participant] = None
is_local: bool = field(default=False)

@classmethod
def from_jsondict(cls, d: Dict[str, Any]) -> "ChatMessage":
# older version of the protocol didn't contain a message ID, so we'll create one
id = d.get("id") or generate_random_base62()
timestamp = datetime.now()
if d.get("timestamp"):
timestamp = datetime.fromtimestamp(d.get("timestamp") / 1000.0)
msg = cls(
id=id,
timestamp=timestamp,
)
msg.update_from_jsondict(d)
return msg

def update_from_jsondict(self, d: Dict[str, Any]) -> None:
self.message = d.get("message")
self.deleted = d.get("deleted", False)

def asjsondict(self):
"""Returns a JSON serializable dictionary representation of the message."""
d = {
"id": self.id,
"message": self.message,
"timestamp": int(self.timestamp.timestamp() * 1000),
}
if self.deleted:
d["deleted"] = True
return d
31 changes: 31 additions & 0 deletions livekit-rtc/tests/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from datetime import datetime
import json

from livekit.rtc import ChatMessage


def test_message_basics():
msg = ChatMessage()
assert msg.id is not None, "message id should be set"
assert msg.timestamp is not None, "timestamp should be set"
assert msg.timestamp.day == datetime.now().day, "timestamp should be today"
assert len(msg.id) > 5, "message id should be long enough"


def test_message_serialization():
msg = ChatMessage(
message="hello",
)
data = msg.asjsondict()
msg2 = ChatMessage.from_jsondict(json.loads(json.dumps(data)))
assert msg2.message == msg.message, "message should be the same"
assert msg2.id == msg.id, "id should be the same"
assert int(msg2.timestamp.timestamp()/1000) == int(msg.timestamp.timestamp() /
1000), "timestamp should be the same"
assert not msg2.deleted, "not deleted"

# deletion is handled
msg.deleted = True
data = msg.asjsondict()
msg2 = ChatMessage.from_jsondict(json.loads(json.dumps(data)))
assert msg2.deleted, "should be deleted"

0 comments on commit c6d7769

Please sign in to comment.