Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: improve channel typings, add method stubs #2022

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 101 additions & 32 deletions channels/layers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
from __future__ import annotations

import asyncio
import fnmatch
import random
import re
import string
import time
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import (
Dict,
Iterable,
List,
NoReturn,
Optional,
Protocol,
Tuple,
runtime_checkable,
)

from django.conf import settings
from django.core.signals import setting_changed
Expand All @@ -20,6 +33,8 @@ class ChannelLayerManager:
Takes a settings dictionary of backends and initialises them on request.
"""

backends: Dict[str, BaseChannelLayer]

def __init__(self):
self.backends = {}
setting_changed.connect(self._reset_backends)
Expand All @@ -36,14 +51,14 @@ def configs(self):
# Lazy load settings so we can be imported
return getattr(settings, "CHANNEL_LAYERS", {})

def make_backend(self, name):
def make_backend(self, name) -> BaseChannelLayer:
"""
Instantiate channel layer.
"""
config = self.configs[name].get("CONFIG", {})
return self._make_backend(name, config)

def make_test_backend(self, name):
def make_test_backend(self, name) -> BaseChannelLayer:
"""
Instantiate channel layer using its test config.
"""
Expand All @@ -53,7 +68,7 @@ def make_test_backend(self, name):
raise InvalidChannelLayerError("No TEST_CONFIG specified for %s" % name)
return self._make_backend(name, config)

def _make_backend(self, name, config):
def _make_backend(self, name, config) -> BaseChannelLayer:
# Check for old format config
if "ROUTING" in self.configs[name]:
raise InvalidChannelLayerError(
Expand Down Expand Up @@ -81,7 +96,7 @@ def __getitem__(self, key):
def __contains__(self, key):
return key in self.configs

def set(self, key, layer):
def set(self, key: str, layer: BaseChannelLayer):
"""
Sets an alias to point to a new ChannelLayerWrapper instance, and
returns the old one that it replaced. Useful for swapping out the
Expand All @@ -92,20 +107,63 @@ def set(self, key, layer):
return old


class BaseChannelLayer:
@runtime_checkable
class WithFlushExtension(Protocol):
async def flush(self) -> NoReturn:
"""
Clears messages and if available groups
"""

async def close(self) -> NoReturn:
"""
Close connection to the layer. Called before stopping layer.
Unusable after.
"""


@runtime_checkable
class WithGroupsExtension(Protocol):
async def group_add(self, group: str, channel: str):
"""
Adds the channel name to a group.
"""

async def group_discard(self, group: str, channel: str) -> NoReturn:
"""
Removes the channel name from a group when it exists.
"""

async def group_send(self, group: str, message: dict) -> NoReturn:
"""
Sends message to group
"""


class BaseChannelLayer(ABC):
"""
Base channel layer class that others can inherit from, with useful
common functionality.
"""

MAX_NAME_LENGTH = 100
extensions: Iterable[str] = ()
expiry: int
capacity: int
channel_capacity: Dict[str, int]

def __init__(self, expiry=60, capacity=100, channel_capacity=None):
def __init__(
self,
expiry: int = 60,
capacity: Optional[int] = 100,
channel_capacity: Optional[int] = None,
):
self.expiry = expiry
self.capacity = capacity
self.channel_capacity = channel_capacity or {}

def compile_capacities(self, channel_capacity):
def compile_capacities(
self, channel_capacity
) -> List[Tuple[re.Pattern, Optional[int]]]:
"""
Takes an input channel_capacity dict and returns the compiled list
of regexes that get_capacity will look for as self.channel_capacity
Expand All @@ -120,7 +178,7 @@ def compile_capacities(self, channel_capacity):
result.append((re.compile(fnmatch.translate(pattern)), value))
return result

def get_capacity(self, channel):
def get_capacity(self, channel: str) -> Optional[int]:
"""
Gets the correct capacity for the given channel; either the default,
or a matching result from channel_capacity. Returns the first matching
Expand All @@ -132,7 +190,7 @@ def get_capacity(self, channel):
return capacity
return self.capacity

def match_type_and_length(self, name):
def match_type_and_length(self, name) -> bool:
if isinstance(name, str) and (len(name) < self.MAX_NAME_LENGTH):
return True
return False
Expand All @@ -148,7 +206,7 @@ def match_type_and_length(self, name):
+ "not {}"
)

def valid_channel_name(self, name, receive=False):
def valid_channel_name(self, name: str, receive=False) -> bool:
if self.match_type_and_length(name):
if bool(self.channel_name_regex.match(name)):
# Check cases for special channels
Expand All @@ -159,13 +217,13 @@ def valid_channel_name(self, name, receive=False):
return True
raise TypeError(self.invalid_name_error.format("Channel", name))

def valid_group_name(self, name):
def valid_group_name(self, name: str) -> bool:
if self.match_type_and_length(name):
if bool(self.group_name_regex.match(name)):
return True
raise TypeError(self.invalid_name_error.format("Group", name))

def valid_channel_names(self, names, receive=False):
def valid_channel_names(self, names: List[str], receive=False) -> bool:
_non_empty_list = True if names else False
_names_type = isinstance(names, list)
assert _non_empty_list and _names_type, "names must be a non-empty list"
Expand All @@ -175,7 +233,7 @@ def valid_channel_names(self, names, receive=False):
)
return True

def non_local_name(self, name):
def non_local_name(self, name: str) -> str:
"""
Given a channel name, returns the "non-local" part. If the channel name
is a process-specific channel (contains !) this means the part up to
Expand All @@ -186,8 +244,34 @@ def non_local_name(self, name):
else:
return name

@abstractmethod
async def send(self, channel: str, message: dict):
"""
Send a message onto a (general or specific) channel.
"""

class InMemoryChannelLayer(BaseChannelLayer):
@abstractmethod
async def receive(self, channel: str) -> dict:
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""

@abstractmethod
async def new_channel(self, prefix: str = "specific.") -> str:
"""
Returns a new channel name that can be used by something in our
process as a specific channel.
"""


# WARNING: Protocols must be last
class InMemoryChannelLayer(
BaseChannelLayer,
WithFlushExtension,
WithGroupsExtension,
):
"""
In-memory channel layer implementation
"""
Expand All @@ -198,13 +282,13 @@ def __init__(
group_expiry=86400,
capacity=100,
channel_capacity=None,
**kwargs
**kwargs,
):
super().__init__(
expiry=expiry,
capacity=capacity,
channel_capacity=channel_capacity,
**kwargs
**kwargs,
)
self.channels = {}
self.groups = {}
Expand All @@ -215,9 +299,6 @@ def __init__(
extensions = ["groups", "flush"]

async def send(self, channel, message):
"""
Send a message onto a (general or specific) channel.
"""
# Typecheck
assert isinstance(message, dict), "message is not a dict"
assert self.valid_channel_name(channel), "Channel name not valid"
Expand All @@ -234,11 +315,6 @@ async def send(self, channel, message):
await queue.put((time.time() + self.expiry, deepcopy(message)))

async def receive(self, channel):
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
assert self.valid_channel_name(channel)
self._clean_expired()

Expand All @@ -254,10 +330,6 @@ async def receive(self, channel):
return message

async def new_channel(self, prefix="specific."):
"""
Returns a new channel name that can be used by something in our
process as a specific channel.
"""
return "%s.inmemory!%s" % (
prefix,
"".join(random.choice(string.ascii_letters) for i in range(12)),
Expand Down Expand Up @@ -314,9 +386,6 @@ def _remove_from_groups(self, channel):
# Groups extension

async def group_add(self, group, channel):
"""
Adds the channel name to a group.
"""
# Check the inputs
assert self.valid_group_name(group), "Group name not valid"
assert self.valid_channel_name(channel), "Channel name not valid"
Expand Down Expand Up @@ -349,7 +418,7 @@ async def group_send(self, group, message):
pass


def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER):
def get_channel_layer(alias=DEFAULT_CHANNEL_LAYER) -> Optional[BaseChannelLayer]:
"""
Returns a channel layer by alias, or None if it is not configured.
"""
Expand Down
14 changes: 13 additions & 1 deletion tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@
)


# when starting with Test it would be tried to collect by pytest
class StubChannelLayer(BaseChannelLayer):
async def send(self, channel: str, message: dict):
raise NotImplementedError()

async def receive(self, channel: str) -> dict:
raise NotImplementedError()

async def new_channel(self, prefix: str = "specific.") -> str:
raise NotImplementedError()


class TestChannelLayerManager(unittest.TestCase):
@override_settings(
CHANNEL_LAYERS={"default": {"BACKEND": "channels.layers.InMemoryChannelLayer"}}
Expand Down Expand Up @@ -72,7 +84,7 @@ async def test_send_receive():

@pytest.mark.parametrize(
"method",
[BaseChannelLayer().valid_channel_name, BaseChannelLayer().valid_group_name],
[StubChannelLayer().valid_channel_name, StubChannelLayer().valid_group_name],
)
@pytest.mark.parametrize(
"channel_name,expected_valid",
Expand Down