Skip to content

Commit

Permalink
add extensions as runtime-checkable protocols, make BaseChannelLayer
Browse files Browse the repository at this point in the history
abstract
  • Loading branch information
devkral committed Jan 29, 2024
1 parent 8244d70 commit 39835c2
Showing 1 changed file with 48 additions and 28 deletions.
76 changes: 48 additions & 28 deletions channels/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,18 @@
import re
import string
import time
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Tuple
from typing import (
Dict,
Iterable,
List,
NoReturn,
Optional,
Protocol,
Tuple,
runtime_checkable,
)

from django.conf import settings
from django.core.signals import setting_changed
Expand Down Expand Up @@ -97,7 +107,39 @@ def set(self, key: str, layer: BaseChannelLayer):
return old


class BaseChannelLayer:
@runtime_checkable
class WithFlushExtension(Protocol):
async def flush(self) -> NoReturn:
"""
Clears messages and if available groups
"""

async def close(self) -> NoReturn:
"""
Close connection to the layer. Called before stopping layer.
Unusable after.
"""


@runtime_checkable
class WithGroupsExtension(Protocol):
async def group_add(self, group: str, channel: str):
"""
Adds the channel name to a group.
"""

async def group_discard(self, group: str, channel: str) -> NoReturn:
"""
Removes the channel name from a group when it exists.
"""

async def group_send(self, group: str, message: dict) -> NoReturn:
"""
Sends message to group
"""


class BaseChannelLayer(ABC):
"""
Base channel layer class that others can inherit from, with useful
common functionality.
Expand Down Expand Up @@ -199,51 +241,29 @@ def non_local_name(self, name: str) -> str:
else:
return name

@abstractmethod
async def send(self, channel: str, message: dict):
"""
Send a message onto a (general or specific) channel.
"""
raise NotImplementedError()

@abstractmethod
async def receive(self, channel: str) -> dict:
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, a random one
of the waiting coroutines will get the result.
"""
raise NotImplementedError()

@abstractmethod
async def new_channel(self, prefix: str = "specific.") -> str:
"""
Returns a new channel name that can be used by something in our
process as a specific channel.
"""
raise NotImplementedError()

# Flush extension

async def flush(self):
raise NotImplementedError()

async def close(self):
raise NotImplementedError()

# Groups extension

async def group_add(self, group: str, channel: str):
"""
Adds the channel name to a group.
"""
raise NotImplementedError()

async def group_discard(self, group: str, channel: str):
raise NotImplementedError()

async def group_send(self, group: str, message: dict):
raise NotImplementedError()


class InMemoryChannelLayer(BaseChannelLayer):
class InMemoryChannelLayer(WithFlushExtension, WithGroupsExtension, BaseChannelLayer):
"""
In-memory channel layer implementation
"""
Expand Down

0 comments on commit 39835c2

Please sign in to comment.