Commit
Handle multiple messages in send and group_send
olzhasar committed Oct 8, 2024
1 parent 13cef45 commit 15b50ae
Showing 2 changed files with 117 additions and 30 deletions.
93 changes: 63 additions & 30 deletions channels_redis/core.py
@@ -178,24 +178,37 @@ def _setup_encryption(self, symmetric_encryption_keys):

async def send(self, channel, message):
"""
Send a message onto a (general or specific) channel.
Send one or multiple messages onto a (general or specific) channel.
The `message` can be a single dict or an iterable of dicts.
"""
messages = self._parse_messages(message)

# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
# Make sure the message does not contain reserved keys
assert "__asgi_channel__" not in message

# If it's a process-local channel, strip off local part and stick full name in message
channel_non_local_name = channel
if "!" in channel:
message = dict(message.items())
message["__asgi_channel__"] = channel
process_local = "!" in channel
if process_local:
channel_non_local_name = self.non_local_name(channel)

now = time.time()
mapping = {}
for message in messages:
assert isinstance(message, dict), "message is not a dict"
# Make sure the message does not contain reserved keys
assert "__asgi_channel__" not in message
if process_local:
message = dict(message.items())
message["__asgi_channel__"] = channel

mapping[self.serialize(message)] = now

# Write out message into expiring key (avoids big items in list)
channel_key = self.prefix + channel_non_local_name
# Pick a connection to the right server - consistent for specific
# channels, random for general channels
if "!" in channel:
if process_local:
index = self.consistent_hash(channel)
else:
index = next(self._send_index_generator)
@@ -207,15 +220,23 @@ async def send(self, channel, message):

# Check the length of the list before send
# This can allow the list to leak slightly over capacity, but that's fine.
if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
channel
):
current_length = await connection.zcount(channel_key, "-inf", "+inf")

if current_length + len(messages) > self.get_capacity(channel):
raise ChannelFull()

# Push onto the list then set it to expire in case it's not consumed
await connection.zadd(channel_key, {self.serialize(message): time.time()})
await connection.zadd(channel_key, mapping)
await connection.expire(channel_key, int(self.expiry))

def _parse_messages(self, message):
"""
Convert a passed message arg to a tuple of messages.
"""
if not isinstance(message, dict) and hasattr(message, "__iter__"):
return tuple(message)
return (message,)
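
For orientation, a minimal usage sketch of the extended send() call with this change applied (assumes a Redis server on localhost:6379; the channel name and message contents are illustrative):

import asyncio
from channels_redis.core import RedisChannelLayer

async def demo():
    layer = RedisChannelLayer(hosts=[("localhost", 6379)])
    # Single message, unchanged behaviour:
    await layer.send("demo-channel", {"type": "demo.message", "text": "one"})
    # Several messages in one call, the new behaviour in this commit:
    await layer.send("demo-channel", [
        {"type": "demo.message", "text": "two"},
        {"type": "demo.message", "text": "three"},
    ])
    # Receive all three messages back, one at a time:
    for _ in range(3):
        print(await layer.receive("demo-channel"))
    await layer.flush()

asyncio.run(demo())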

def _backup_channel_name(self, channel):
"""
Construct the key used as a backup queue for the given channel.
@@ -519,8 +540,11 @@ async def group_discard(self, group, channel):

async def group_send(self, group, message):
"""
Sends a message to the entire group.
Sends one or multiple messages to the entire group.
The `message` can be a single dict or an iterable of dicts.
"""
messages = self._parse_messages(message)

assert self.valid_group_name(group), "Group name not valid"
# Retrieve list of all channel names
key = self._group_key(group)
@@ -536,7 +560,7 @@ async def group_send(self, group, message):
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
) = self._map_channel_keys_to_connection(channel_names, message)
) = self._map_channel_keys_to_connection(channel_names, messages)

for connection_index, channel_redis_keys in connection_to_channel_keys.items():
# Discard old messages based on expiry
@@ -548,17 +572,23 @@ async def group_send(self, group, message):
await pipe.execute()

# Create a LUA script specific for this connection.
# Make sure to use the message specific to this channel, it is
# stored in channel_to_message dict and contains the
# Make sure to use the message list specific to this channel, it is
# stored in channel_to_message dict and each message contains the
# __asgi_channel__ key.

group_send_lua = """
local over_capacity = 0
local num_messages = tonumber(ARGV[#ARGV - 2])
local current_time = ARGV[#ARGV - 1]
local expiry = ARGV[#ARGV]
for i=1,#KEYS do
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
redis.call('ZADD', KEYS[i], current_time, ARGV[i])
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
local messages = {}
for j=num_messages * (i - 1) + 1, num_messages * i do
table.insert(messages, current_time)
table.insert(messages, ARGV[j])
end
redis.call('ZADD', KEYS[i], unpack(messages))
redis.call('EXPIRE', KEYS[i], expiry)
else
over_capacity = over_capacity + 1
@@ -568,18 +598,18 @@ async def group_send(self, group, message):
"""

# We need to filter the messages to keep those related to the connection
args = [
channel_keys_to_message[channel_key]
for channel_key in channel_redis_keys
]
args = []

for channel_key in channel_redis_keys:
args += channel_keys_to_message[channel_key]

# We need to send the capacity for each channel
args += [
channel_keys_to_capacity[channel_key]
for channel_key in channel_redis_keys
]

args += [time.time(), self.expiry]
args += [len(messages), time.time(), self.expiry]

# channel_keys does not contain a single redis key more than once
connection = self.connection(connection_index)
@@ -594,7 +624,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
group,
)

def _map_channel_keys_to_connection(self, channel_names, message):
def _map_channel_keys_to_connection(self, channel_names, messages):
"""
For a list of channel names, GET
@@ -609,7 +639,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Connection dict keyed by index to list of redis keys mapped on that index
connection_to_channel_keys = collections.defaultdict(list)
# Message dict maps redis key to the messages that need to be sent on that key
channel_key_to_message = dict()
channel_key_to_message = collections.defaultdict(list)
# Channel key mapped to its capacity
channel_key_to_capacity = dict()

@@ -623,20 +653,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Have we come across the same redis key?
if channel_key not in channel_key_to_message:
# If not, fill the corresponding dicts
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key] = message
for message in messages:
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key].append(message)
channel_key_to_capacity[channel_key] = self.get_capacity(channel)
idx = self.consistent_hash(channel_non_local_name)
connection_to_channel_keys[idx].append(channel_key)
else:
# Yes, append the channel to each message dict
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
for message in channel_key_to_message[channel_key]:
message["__asgi_channel__"].append(channel)

# Now that we know what messages need to be sent on a redis key we serialize them
for key, value in channel_key_to_message.items():
# Serialize the message stored for each redis key
channel_key_to_message[key] = self.serialize(value)
for idx, message in enumerate(value):
channel_key_to_message[key][idx] = self.serialize(message)

return (
connection_to_channel_keys,
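To make the new per-key message mapping concrete, here is an illustrative sketch (hypothetical names, shown before serialization) of what _map_channel_keys_to_connection now returns when two channels share one Redis key and two messages are sent:

connection_to_channel_keys = {0: ["redis-key-1"]}
channel_keys_to_message = {
    "redis-key-1": [
        {"type": "m.1", "__asgi_channel__": ["chan-a", "chan-b"]},
        {"type": "m.2", "__asgi_channel__": ["chan-a", "chan-b"]},
    ],
}
channel_keys_to_capacity = {"redis-key-1": 100}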
54 changes: 54 additions & 0 deletions tests/test_core.py
@@ -1,4 +1,5 @@
import asyncio
import collections
import random

import async_timeout
@@ -125,6 +126,25 @@ async def listen2():
async_to_sync(channel_layer.flush)()


@pytest.mark.asyncio
async def test_send_multiple(channel_layer):
messages = [
{"type": "test.message.1"},
{"type": "test.message.2"},
{"type": "test.message.3"},
]

await channel_layer.send("test-channel-1", messages)

expected = {"test.message.1", "test.message.2", "test.message.3"}
received = set()
for _ in range(3):
msg = await channel_layer.receive("test-channel-1")
received.add(msg["type"])

assert received == expected


@pytest.mark.asyncio
async def test_send_capacity(channel_layer):
"""
Expand Down Expand Up @@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
await channel_layer.flush()


@pytest.mark.asyncio
async def test_groups_multiple(channel_layer):
"""
Tests sending multiple messages to a group.
"""
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
await channel_layer.group_add("test-group", channel_name1)
await channel_layer.group_add("test-group", channel_name2)
await channel_layer.group_add("test-group", channel_name3)

messages = [
{"type": "message.1"},
{"type": "message.2"},
{"type": "message.3"},
]

expected = {msg["type"] for msg in messages}

await channel_layer.group_send("test-group", messages)

received = collections.defaultdict(set)

for channel_name in (channel_name1, channel_name2, channel_name3):
async with async_timeout.timeout(1):
for _ in range(len(messages)):
received[channel_name].add(
(await channel_layer.receive(channel_name))["type"]
)

assert received[channel_name] == expected


@pytest.mark.asyncio
async def test_groups_channel_full(channel_layer):
"""
