Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
224 changes: 163 additions & 61 deletions channels/testing/live.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
from functools import partial
import threading

from daphne.testing import DaphneProcess
from daphne.endpoints import build_endpoint_description_strings
from daphne.server import Server
from daphne.testing import _reinstall_reactor
from django.contrib.staticfiles.handlers import ASGIStaticFilesHandler
from django.core.exceptions import ImproperlyConfigured
from django.db import connections
from django.db.backends.base.creation import TEST_DATABASE_PREFIX
from django.test.testcases import TransactionTestCase
from django.test.utils import modify_settings
from django.utils.functional import classproperty
from django.utils.version import PY311

from channels.routing import get_default_application


def make_application(*, static_wrapper):
# Module-level function for pickle-ability
application = get_default_application()
if static_wrapper is not None:
application = static_wrapper(application)
return application
if not PY311:
# Backport of unittest.case._enter_context() from Python 3.11.
def _enter_context(cm, addcleanup):
# Look up the special methods on the type to match the with statement.
cls = type(cm)
try:
enter = cls.__enter__
exit = cls.__exit__
except AttributeError:
raise TypeError(
f"'{cls.__module__}.{cls.__qualname__}' object does not support the "
f"context manager protocol"
) from None
result = enter(cm)
addcleanup(exit, cm, None, None, None)
return result


def set_database_connection():
Expand All @@ -28,71 +40,161 @@ def set_database_connection():
settings.DATABASES["default"]["NAME"] = test_db_name


class ChannelsLiveServerThread(threading.Thread):
"""Thread for running a live ASGI server while the tests are running."""

server_class = Server

def __init__(self, host, static_handler, connections_override=None, port=0):
self.host = host
self.port = port
self.is_ready = threading.Event()
self.error = None
self.static_handler = static_handler
self.connections_override = connections_override
super().__init__()

def run(self):
"""
Set up the live server and databases, and then loop over handling
ASGI requests.
"""
if self.connections_override:
# Override this thread's database connections with the ones
# provided by the main thread.
for alias, conn in self.connections_override.items():
connections[alias] = conn
try:
# Reinstall the reactor for this thread (same as DaphneProcess)
_reinstall_reactor()

self.httpd = self._create_server(
connections_override=self.connections_override,
)

# Run database setup
set_database_connection()

# The server will call ready_callable when ready
self.httpd.run()
except Exception as e:
self.error = e
self.is_ready.set()
finally:
connections.close_all()

def _create_server(self, connections_override=None):
endpoints = build_endpoint_description_strings(host=self.host, port=self.port)
# Create the handler for serving static files
application = self.static_handler(get_default_application())
return self.server_class(
application=application,
endpoints=endpoints,
signal_handlers=False,
ready_callable=self._server_is_ready,
verbosity=0,
)

def _server_is_ready(self):
"""Called by Daphne when the server is ready and listening."""
# If binding to port zero, assign the port allocated by the OS.
if self.port == 0:
self.port = self.httpd.listening_addresses[0][1]
self.is_ready.set()

def terminate(self):
if hasattr(self, "httpd"):
# Stop the ASGI server
from twisted.internet import reactor

if reactor.running:
reactor.callFromThread(reactor.stop)
self.join(timeout=5)


class ChannelsLiveServerTestCase(TransactionTestCase):
"""
Does basically the same as TransactionTestCase but also launches a
live Daphne server in a separate process, so
that the tests may use another test framework, such as Selenium,
instead of the built-in dummy client.
Do basically the same as TransactionTestCase but also launch a live ASGI
server in a separate thread so that the tests may use another testing
framework, such as Selenium for example, instead of the built-in dummy
client.
It inherits from TransactionTestCase instead of TestCase because the
threads don't share the same transactions (unless if using in-memory
sqlite) and each thread needs to commit all their transactions so that the
other thread can see the changes.
"""

host = "localhost"
ProtocolServerProcess = DaphneProcess
static_wrapper = ASGIStaticFilesHandler
serve_static = True
port = 0
server_thread_class = ChannelsLiveServerThread
static_handler = ASGIStaticFilesHandler

if not PY311:
# Backport of unittest.TestCase.enterClassContext() from Python 3.11.
@classmethod
def enterClassContext(cls, cm):
return _enter_context(cm, cls.addClassCleanup)

@property
def live_server_url(self):
return "http://%s:%s" % (self.host, self._port)
@classproperty
def live_server_url(cls):
return "http://%s:%s" % (cls.host, cls.server_thread.port)

@property
def live_server_ws_url(self):
return "ws://%s:%s" % (self.host, self._port)
@classproperty
def live_server_ws_url(cls):
return "ws://%s:%s" % (cls.host, cls.server_thread.port)

@classproperty
def allowed_host(cls):
return cls.host

@classmethod
def setUpClass(cls):
for connection in connections.all():
if cls._is_in_memory_db(connection):
raise ImproperlyConfigured(
"ChannelLiveServerTestCase can not be used with in memory databases"
)
def _make_connections_override(cls):
connections_override = {}
for conn in connections.all():
# If using in-memory sqlite databases, pass the connections to
# the server thread.
if conn.vendor == "sqlite" and conn.is_in_memory_db():
connections_override[conn.alias] = conn
return connections_override

@classmethod
def setUpClass(cls):
super().setUpClass()

cls._live_server_modified_settings = modify_settings(
ALLOWED_HOSTS={"append": cls.host}
cls.enterClassContext(
modify_settings(ALLOWED_HOSTS={"append": cls.allowed_host})
)
cls._live_server_modified_settings.enable()
cls._start_server_thread()

get_application = partial(
make_application,
static_wrapper=cls.static_wrapper if cls.serve_static else None,
)
cls._server_process = cls.ProtocolServerProcess(
cls.host,
get_application,
setup=set_database_connection,
)
cls._server_process.start()
while True:
if not cls._server_process.ready.wait(timeout=1):
if cls._server_process.is_alive():
continue
raise RuntimeError("Server stopped") from None
break
cls._port = cls._server_process.port.value
@classmethod
def _start_server_thread(cls):
connections_override = cls._make_connections_override()
for conn in connections_override.values():
# Explicitly enable thread-shareability for this connection.
conn.inc_thread_sharing()

cls.server_thread = cls._create_server_thread(connections_override)
cls.server_thread.daemon = True
cls.server_thread.start()
cls.addClassCleanup(cls._terminate_thread)

# Wait for the live server to be ready
cls.server_thread.is_ready.wait()
if cls.server_thread.error:
raise cls.server_thread.error

@classmethod
def tearDownClass(cls):
cls._server_process.terminate()
cls._server_process.join()
cls._live_server_modified_settings.disable()
super().tearDownClass()
def _create_server_thread(cls, connections_override):
return cls.server_thread_class(
cls.host,
cls.static_handler,
connections_override=connections_override,
port=cls.port,
)

@classmethod
def _is_in_memory_db(cls, connection):
"""
Check if DatabaseWrapper holds in memory database.
"""
if connection.vendor == "sqlite":
return connection.is_in_memory_db()
def _terminate_thread(cls):
# Terminate the live server's thread.
cls.server_thread.terminate()
# Restore shared connections' non-shareability.
for conn in cls.server_thread.connections_override.values():
conn.dec_thread_sharing()
1 change: 0 additions & 1 deletion tests/sample_project/tests/test_selenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@


class TestSampleApp(SeleniumMixin, ChannelsLiveServerTestCase):
serve_static = True

def setUp(self):
super().setUp()
Expand Down