Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Allow sending multiple messages in bulk #400

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 145 additions & 34 deletions channels_redis/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,40 @@ async def send(self, channel, message):
"""
Send a message onto a (general or specific) channel.
"""
await self.send_bulk(channel, (message,))

async def send_bulk(self, channel, messages):
"""
Send multiple messages in bulk onto a (general or specific) channel.
The `messages` argument should be an iterable of dicts.
"""

# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
# Make sure the message does not contain reserved keys
assert "__asgi_channel__" not in message

# If it's a process-local channel, strip off local part and stick full name in message
olzhasar marked this conversation as resolved.
Show resolved Hide resolved
channel_non_local_name = channel
if "!" in channel:
message = dict(message.items())
message["__asgi_channel__"] = channel
process_local = "!" in channel
if process_local:
channel_non_local_name = self.non_local_name(channel)

now = time.time()
mapping = {}
for message in messages:
assert isinstance(message, dict), "message is not a dict"
# Make sure the message does not contain reserved keys
assert "__asgi_channel__" not in message
if process_local:
message = dict(message.items())
message["__asgi_channel__"] = channel

mapping[self.serialize(message)] = now

# Write out message into expiring key (avoids big items in list)
channel_key = self.prefix + channel_non_local_name
# Pick a connection to the right server - consistent for specific
# channels, random for general channels
if "!" in channel:
if process_local:
index = self.consistent_hash(channel)
else:
index = next(self._send_index_generator)
Expand All @@ -207,13 +225,13 @@ async def send(self, channel, message):

# Check the length of the list before send
# This can allow the list to leak slightly over capacity, but that's fine.
if await connection.zcount(channel_key, "-inf", "+inf") >= self.get_capacity(
channel
):
current_length = await connection.zcount(channel_key, "-inf", "+inf")

if current_length + len(messages) > self.get_capacity(channel):
raise ChannelFull()

# Push onto the list then set it to expire in case it's not consumed
await connection.zadd(channel_key, {self.serialize(message): time.time()})
await connection.zadd(channel_key, mapping)
await connection.expire(channel_key, int(self.expiry))

def _backup_channel_name(self, channel):
Expand Down Expand Up @@ -517,10 +535,7 @@ async def group_discard(self, group, channel):
connection = self.connection(self.consistent_hash(group))
await connection.zrem(key, channel)

async def group_send(self, group, message):
"""
Sends a message to the entire group.
"""
async def _get_group_connection_and_channels(self, group):
assert self.valid_group_name(group), "Group name not valid"
# Retrieve list of all channel names
key = self._group_key(group)
Expand All @@ -532,11 +547,36 @@ async def group_send(self, group, message):

channel_names = [x.decode("utf8") for x in await connection.zrange(key, 0, -1)]

return connection, channel_names

async def _exec_group_lua_script(
    self, conn_idx, group, channel_redis_keys, channel_names, script, args
):
    """
    Evaluate *script* on the Redis shard at *conn_idx* and log how many of
    the group's channels rejected the send because they were at capacity.

    The script is expected to return the count of over-capacity channels.
    """
    # channel_redis_keys never repeats a redis key, so the KEYS list the
    # script sees is duplicate-free.
    shard = self.connection(conn_idx)
    over_capacity = await shard.eval(
        script, len(channel_redis_keys), *channel_redis_keys, *args
    )
    if over_capacity > 0:
        logger.info(
            "%s of %s channels over capacity in group %s",
            over_capacity,
            len(channel_names),
            group,
        )

async def group_send(self, group, message):
"""
Sends a message to the entire group.
"""

connection, channel_names = await self._get_group_connection_and_channels(group)

(
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
) = self._map_channel_keys_to_connection(channel_names, message)
) = self._map_channel_keys_to_connection(channel_names, (message,))

for connection_index, channel_redis_keys in connection_to_channel_keys.items():
# Discard old messages based on expiry
Expand Down Expand Up @@ -569,7 +609,7 @@ async def group_send(self, group, message):

# We need to filter the messages to keep those related to the connection
args = [
channel_keys_to_message[channel_key]
channel_keys_to_message[channel_key][0]
for channel_key in channel_redis_keys
]

Expand All @@ -581,20 +621,88 @@ async def group_send(self, group, message):

args += [time.time(), self.expiry]

# channel_keys does not contain a single redis key more than once
connection = self.connection(connection_index)
channels_over_capacity = await connection.eval(
group_send_lua, len(channel_redis_keys), *channel_redis_keys, *args
await self._exec_group_lua_script(
connection_index,
group,
channel_redis_keys,
channel_names,
group_send_lua,
args,
)
if channels_over_capacity > 0:
logger.info(
"%s of %s channels over capacity in group %s",
channels_over_capacity,
len(channel_names),
group,

async def group_send_bulk(self, group, messages):
"""
Sends multiple messages in bulk to the entire group.
The `messages` argument should be an iterable of dicts.
"""

# Resolve the group to its member channel names (also validates the
# group name via _get_group_connection_and_channels).
connection, channel_names = await self._get_group_connection_and_channels(group)

# Bucket the channel redis keys per shard index, and map each key to its
# per-channel serialized message batch and its capacity.
(
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
) = self._map_channel_keys_to_connection(channel_names, messages)

for connection_index, channel_redis_keys in connection_to_channel_keys.items():
# Discard old messages based on expiry
# NOTE(review): `connection` is the shard holding the *group* key,
# while these channel keys hash to `connection_index` — confirm this
# expiry sweep runs against the intended server.
pipe = connection.pipeline()
for key in channel_redis_keys:
pipe.zremrangebyscore(
key, min=0, max=int(time.time()) - int(self.expiry)
)
await pipe.execute()

# Create a LUA script specific for this connection.
# Make sure to use the message list specific to this channel, it is
# stored in channel_to_message dict and each message contains the
# __asgi_channel__ key.

# ARGV layout consumed by the script:
#   ARGV[1 .. #KEYS*num_messages]              serialized messages, grouped per key
#   ARGV[#KEYS*num_messages+1 .. +#KEYS]       capacity of each key, in KEYS order
#   ARGV[#ARGV-2] / ARGV[#ARGV-1] / ARGV[#ARGV]  num_messages, current time, expiry
# NOTE(review): the capacity check admits the whole batch whenever the
# channel is below capacity, so a bulk send can overshoot the configured
# capacity by up to num_messages - 1 — confirm that is acceptable.
group_send_lua = """
local over_capacity = 0
local num_messages = tonumber(ARGV[#ARGV - 2])
local current_time = ARGV[#ARGV - 1]
local expiry = ARGV[#ARGV]
for i=1,#KEYS do
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS * num_messages]) then
local messages = {}
for j=num_messages * (i - 1) + 1, num_messages * i do
table.insert(messages, current_time)
table.insert(messages, ARGV[j])
end
redis.call('ZADD', KEYS[i], unpack(messages))
redis.call('EXPIRE', KEYS[i], expiry)
else
over_capacity = over_capacity + 1
end
end
return over_capacity
"""

# We need to filter the messages to keep those related to the connection
args = []

# Per-key serialized message batches, flattened in KEYS order so the Lua
# script can slice them back out with num_messages-sized windows.
for channel_key in channel_redis_keys:
args += channel_keys_to_message[channel_key]

# We need to send the capacity for each channel
args += [
channel_keys_to_capacity[channel_key]
for channel_key in channel_redis_keys
]

# NOTE(review): the next line looks like stray diff residue (the old
# signature of _map_channel_keys_to_connection) — it should not be here.
def _map_channel_keys_to_connection(self, channel_names, message):
args += [len(messages), time.time(), self.expiry]

await self._exec_group_lua_script(
connection_index,
group,
channel_redis_keys,
channel_names,
group_send_lua,
args,
)

def _map_channel_keys_to_connection(self, channel_names, messages):
"""
For a list of channel names, GET
Expand All @@ -609,7 +717,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Connection dict keyed by index to list of redis keys mapped on that index
connection_to_channel_keys = collections.defaultdict(list)
# Message dict maps redis key to the message that needs to be send on that key
channel_key_to_message = dict()
channel_key_to_message = collections.defaultdict(list)
# Channel key mapped to its capacity
channel_key_to_capacity = dict()

Expand All @@ -623,20 +731,23 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Have we come across the same redis key?
if channel_key not in channel_key_to_message:
# If not, fill the corresponding dicts
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key] = message
for message in messages:
message = dict(message.items())
message["__asgi_channel__"] = [channel]
channel_key_to_message[channel_key].append(message)
channel_key_to_capacity[channel_key] = self.get_capacity(channel)
idx = self.consistent_hash(channel_non_local_name)
connection_to_channel_keys[idx].append(channel_key)
else:
# Yes, Append the channel in message dict
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)
for message in channel_key_to_message[channel_key]:
message["__asgi_channel__"].append(channel)

# Now that we know what message needs to be send on a redis key we serialize it
for key, value in channel_key_to_message.items():
# Serialize the message stored for each redis key
channel_key_to_message[key] = self.serialize(value)
for idx, message in enumerate(value):
channel_key_to_message[key][idx] = self.serialize(message)

return (
connection_to_channel_keys,
Expand Down
54 changes: 54 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import collections
import random

import async_timeout
Expand Down Expand Up @@ -125,6 +126,25 @@ async def listen2():
async_to_sync(channel_layer.flush)()


@pytest.mark.asyncio
async def test_send_multiple(channel_layer):
    """
    send_bulk() delivers every message in the iterable to the channel, and
    each one is independently receivable.
    """
    # Fixed: local variable was misspelled "messsages".
    messages = [
        {"type": "test.message.1"},
        {"type": "test.message.2"},
        {"type": "test.message.3"},
    ]

    await channel_layer.send_bulk("test-channel-1", messages)

    expected = {"test.message.1", "test.message.2", "test.message.3"}
    received = set()
    # Receive order is not guaranteed, so compare as sets.
    for _ in range(len(messages)):
        msg = await channel_layer.receive("test-channel-1")
        received.add(msg["type"])

    assert received == expected


@pytest.mark.asyncio
async def test_send_capacity(channel_layer):
"""
Expand Down Expand Up @@ -225,6 +245,40 @@ async def test_groups_basic(channel_layer):
await channel_layer.flush()


@pytest.mark.asyncio
async def test_groups_multiple(channel_layer):
    """
    Every channel that joined the group receives all messages sent with
    group_send_bulk().
    """
    group = "test-group"
    channel_names = []
    # Create three member channels and subscribe each to the group.
    for prefix in ("test-gr-chan-1", "test-gr-chan-2", "test-gr-chan-3"):
        channel = await channel_layer.new_channel(prefix=prefix)
        await channel_layer.group_add(group, channel)
        channel_names.append(channel)

    messages = [
        {"type": "message.1"},
        {"type": "message.2"},
        {"type": "message.3"},
    ]
    expected = {message["type"] for message in messages}

    await channel_layer.group_send_bulk(group, messages)

    # Receive order is not guaranteed, so collect the types per channel
    # into sets before comparing.
    received = collections.defaultdict(set)
    for channel in channel_names:
        async with async_timeout.timeout(1):
            for _ in messages:
                received[channel].add(
                    (await channel_layer.receive(channel))["type"]
                )

        assert received[channel] == expected


@pytest.mark.asyncio
async def test_groups_channel_full(channel_layer):
"""
Expand Down
Loading