Skip to content

Commit

Permalink
Handle multiple messages in send and group_send
Browse files Browse the repository at this point in the history
  • Loading branch information
olzhasar committed Oct 29, 2024
1 parent 13cef45 commit 13debf7
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 34 deletions.
179 changes: 145 additions & 34 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,40 @@ async def send(self, channel, message):
"""
Send a message onto a (general or specific) channel.
"""
await self.send_bulk(channel, (message,))

async def send_bulk(self, channel, messages):
"""
Send multiple messages in bulk onto a (general or specific) channel.
The `messages` argument should be an iterable of dicts.
"""

# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
# Make sure the message does not contain reserved keys
assert "__asgi_channel__" not in message

# If it's a process-local channel, strip off local part and stick full name in message
channel_non_local_name = channel
if "!" in channel:
message = dict(message.items())
message["__asgi_channel__"] = channel
process_local = "!" in channel
if process_local:
channel_non_local_name = self.non_local_name(channel)

now = time.time()
mapping = {}
for message in messages:
assert isinstance(message, dict), "message is not a dict"
# Make sure the message does not contain reserved keys
assert "__asgi_channel__" not in message
if process_local:
message = dict(message.items())
message["__asgi_channel__"] = channel

mapping[self.serialize(message)] = now

# Write out message into expiring key (avoids big items in list)
channel_key = self.prefix + channel_non_local_name
# Pick a connection to the right server - consistent for specific
# channels, random for general channels
if "!" in channel:
if process_local:
index = self.consistent_hash(channel)
else:
index = next(self._send_index_generator)
Expand All @@ -207,13 +225,13 @@ async def send(self, channel, message):

# Check the length of the list before send
# This can allow the list to leak slightly over capacity, but that's fine.
if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
channel
):
current_length = await connection.zcount(channel_key, "-inf", "+inf")

if current_length + len(messages) > self.get_capacity(channel):
raise ChannelFull()

# Push onto the list then set it to expire in case it's not consumed
await connection.zadd(channel_key, {self.serialize(message): time.time()})
await connection.zadd(channel_key, mapping)
await connection.expire(channel_key, int(self.expiry))

def _backup_channel_name(self, channel):
Expand Down Expand Up @@ -517,10 +535,7 @@ async def group_discard(self, group, channel):
connection = self.connection(self.consistent_hash(group))
await connection.zrem(key, channel)

async def group_send(self, group, message):
"""
Sends a message to the entire group.
"""
async def _get_group_connection_and_channels(self, group):
assert self.valid_group_name(group), "Group name not valid"
# Retrieve list of all channel names
key = self._group_key(group)
Expand All @@ -532,11 +547,36 @@ async def group_send(self, group, message):

channel_names = [x.decode("utf8") for x in await connection.zrange(key, 0, -1)]

return connection, channel_names

async def _exec_group_lua_script(
self, conn_idx, group, channel_redis_keys, channel_names, script, args
):
# channel_keys does not contain a single redis key more than once
connection = self.connection(conn_idx)
channels_over_capacity = await connection.eval(
script, len(channel_redis_keys), *channel_redis_keys, *args
)
if channels_over_capacity > 0:
logger.info(
"%s of %s channels over capacity in group %s",
channels_over_capacity,
len(channel_names),
group,
)

async def group_send(self, group, message):
"""
Sends a message to the entire group.
"""

connection, channel_names = await self._get_group_connection_and_channels(group)

(
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
) = self._map_channel_keys_to_connection(channel_names, message)
) = self._map_channel_keys_to_connection(channel_names, (message,))

for connection_index, channel_redis_keys in connection_to_channel_keys.items():
# Discard old messages based on expiry
Expand Down Expand Up @@ -569,7 +609,7 @@ async def group_send(self, group, message):

# We need to filter the messages to keep those related to the connection
args = [
channel_keys_to_message[channel_key]
channel_keys_to_message[channel_key][0]
for channel_key in channel_redis_keys
]

Expand All @@ -581,20 +621,88 @@ async def group_send(self, group, message):

args += [time.time(), self.expiry]

# channel_keys does not contain a single redis key more than once
connection = self.connection(connection_index)
channels_over_capacity = await connection.eval(
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
await self._exec_group_lua_script(
connection_index,
group,
channel_redis_keys,
channel_names,
group_send_lua,
args,
)
if channels_over_capacity > 0:
logger.info(
"%s of %s channels over capacity in group %s",
channels_over_capacity,
len(channel_names),
group,

async def group_send_bulk(self, group, messages):
"""
Sends multiple messages in bulk to the entire group.
The `messages` argument should be an iterable of dicts.
"""

connection, channel_names = await self._get_group_connection_and_channels(group)

(
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
) = self._map_channel_keys_to_connection(channel_names, messages)

for connection_index, channel_redis_keys in connection_to_channel_keys.items():
# Discard old messages based on expiry
pipe = connection.pipeline()
for key in channel_redis_keys:
pipe.zremrangebyscore(
key, min=0, max=int(time.time()) - int(self.expiry)
)
await pipe.execute()

# Create a LUA script specific for this connection.
# Make sure to use the message list specific to this channel, it is
# stored in channel_to_message dict and each message contains the
# __asgi_channel__ key.

group_send_lua = """
local over_capacity = 0
local num_messages = tonumber(ARGV[#ARGV - 2])
local current_time = ARGV[#ARGV - 1]
local expiry = ARGV[#ARGV]
for i=1,#KEYS do
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
local messages = {}
for j=num_messages * (i - 1) + 1, num_messages * i do
table.insert(messages, current_time)
table.insert(messages, ARGV[j])
end
redis.call('ZADD', KEYS[i], unpack(messages))
redis.call('EXPIRE', KEYS[i], expiry)
else
over_capacity = over_capacity + 1
end
end
return over_capacity
"""

# We need to filter the messages to keep those related to the connection
args = []

for channel_key in channel_redis_keys:
args += channel_keys_to_message[channel_key]

# We need to send the capacity for each channel
args += [
channel_keys_to_capacity[channel_key]
for channel_key in channel_redis_keys
]

def _map_channel_keys_to_connection(self, channel_names, message):
args += [len(messages), time.time(), self.expiry]

await self._exec_group_lua_script(
connection_index,
group,
channel_redis_keys,
channel_names,
group_send_lua,
args,
)

def _map_channel_keys_to_connection(self, channel_names, messages):
"""
For a list of channel names, GET
Expand All @@ -609,7 +717,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Connection dict keyed by index to list of redis keys mapped on that index
connection_to_channel_keys = collections.defaultdict(list)
# Message dict maps redis key to the message that needs to be send on that key
channel_key_to_message = dict()
channel_key_to_message = collections.defaultdict(list)
# Channel key mapped to its capacity
channel_key_to_capacity = dict()

Expand All @@ -623,20 +731,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Have we come across the same redis key?
if channel_key not in channel_key_to_message:
# If not, fill the corresponding dicts
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key] = message
for message in messages:
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key].append(message)
channel_key_to_capacity[channel_key] = self.get_capacity(channel)
idx = self.consistent_hash(channel_non_local_name)
connection_to_channel_keys[idx].append(channel_key)
else:
# Yes, Append the channel in message dict
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
for message in channel_key_to_message[channel_key]:
message["__asgi_channel__"].append(channel)

# Now that we know what message needs to be send on a redis key we serialize it
for key, value in channel_key_to_message.items():
# Serialize the message stored for each redis key
channel_key_to_message[key] = self.serialize(value)
for idx, message in enumerate(value):
channel_key_to_message[key][idx] = self.serialize(message)

return (
connection_to_channel_keys,
Expand Down
54 changes: 54 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import collections
import random

import async_timeout
Expand Down Expand Up @@ -125,6 +126,25 @@ async def listen2():
async_to_sync(channel_layer.flush)()


@pytest.mark.asyncio
async def test_send_multiple(channel_layer):
messsages = [
{"type": "test.message.1"},
{"type": "test.message.2"},
{"type": "test.message.3"},
]

await channel_layer.send_bulk("test-channel-1", messsages)

expected = {"test.message.1", "test.message.2", "test.message.3"}
received = set()
for _ in range(3):
msg = await channel_layer.receive("test-channel-1")
received.add(msg["type"])

assert received == expected


@pytest.mark.asyncio
async def test_send_capacity(channel_layer):
"""
Expand Down Expand Up @@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
await channel_layer.flush()


@pytest.mark.asyncio
async def test_groups_multiple(channel_layer):
"""
Tests basic group operation.
"""
channel_name1 = await channel_layer.new_channel(prefix="test-gr-chan-1")
channel_name2 = await channel_layer.new_channel(prefix="test-gr-chan-2")
channel_name3 = await channel_layer.new_channel(prefix="test-gr-chan-3")
await channel_layer.group_add("test-group", channel_name1)
await channel_layer.group_add("test-group", channel_name2)
await channel_layer.group_add("test-group", channel_name3)

messages = [
{"type": "message.1"},
{"type": "message.2"},
{"type": "message.3"},
]

expected = {msg["type"] for msg in messages}

await channel_layer.group_send_bulk("test-group", messages)

received = collections.defaultdict(set)

for channel_name in (channel_name1, channel_name2, channel_name3):
async with async_timeout.timeout(1):
for _ in range(len(messages)):
received[channel_name].add(
(await channel_layer.receive(channel_name))["type"]
)

assert received[channel_name] == expected


@pytest.mark.asyncio
async def test_groups_channel_full(channel_layer):
"""
Expand Down

0 comments on commit 13debf7

Please sign in to comment.