diff --git a/pulpcore/app/models/task.py b/pulpcore/app/models/task.py index 80d6ef7650..b058e3d9c3 100644 --- a/pulpcore/app/models/task.py +++ b/pulpcore/app/models/task.py @@ -71,7 +71,7 @@ class Task(BaseModel, AutoAddObjPermsMixin): The transitions to CANCELING (marked with *) are the only ones allowed to happen without holding the tasks advisory lock. Canceling is meant to be initiated asyncronously by a sparate - process before signalling the worker via Postgres LISTEN. + process before signalling the worker via a pubsub notification (e.g, Postgres LISTEN). Fields: diff --git a/pulpcore/constants.py b/pulpcore/constants.py index fb128573b9..b6b2441cba 100644 --- a/pulpcore/constants.py +++ b/pulpcore/constants.py @@ -16,6 +16,13 @@ TASK_WAKEUP_UNBLOCK = "unblock" TASK_WAKEUP_HANDLE = "handle" +#: All valid tasking pubsub channels +TASK_PUBSUB = SimpleNamespace( + WAKEUP_WORKER="pulp_worker_wakeup", + CANCEL_TASK="pulp_worker_cancel", + WORKER_METRICS="pulp_worker_metrics_heartbeat", +) + #: All valid task states. TASK_STATES = SimpleNamespace( WAITING="waiting", diff --git a/pulpcore/tasking/pubsub.py b/pulpcore/tasking/pubsub.py new file mode 100644 index 0000000000..6a7b7d3518 --- /dev/null +++ b/pulpcore/tasking/pubsub.py @@ -0,0 +1,151 @@ +from typing import NamedTuple +from pulpcore.constants import TASK_PUBSUB +import os +import logging +import select +from django.db import connection +from contextlib import suppress + +logger = logging.getLogger(__name__) + + +class BasePubSubBackend: + # Utils + @classmethod + def wakeup_worker(cls, reason="unknown"): + cls.publish(TASK_PUBSUB.WAKEUP_WORKER, reason) + + @classmethod + def cancel_task(cls, task_pk): + cls.publish(TASK_PUBSUB.CANCEL_TASK, str(task_pk)) + + @classmethod + def record_worker_metrics(cls, now): + cls.publish(TASK_PUBSUB.WORKER_METRICS, str(now)) + + # Interface + def subscribe(self, channel): + raise NotImplementedError() + + def unsubscribe(self, channel): + raise NotImplementedError() + + def get_subscriptions(self): + raise NotImplementedError() + + @classmethod + def publish(cls, channel, payload=None): + raise NotImplementedError() + + def fileno(self): + """Add support for being used in select loop.""" + raise NotImplementedError() + + def fetch(self): + """Fetch messages new message, if required.""" + raise NotImplementedError() + + def close(self): + raise NotImplementedError() + + +class PubsubMessage(NamedTuple): + channel: str + payload: str + + +def drain_non_blocking_fd(fd): + with suppress(BlockingIOError): + while True: + os.read(fd, 256) + + +class PostgresPubSub(BasePubSubBackend): + PID = os.getpid() + + def __init__(self): + self._subscriptions = set() + self.message_buffer = [] + # ensures a connection is initialized + with connection.cursor() as cursor: + cursor.execute("select 1") + self.backend_pid = connection.connection.info.backend_pid + self.sentinel_r, self.sentinel_w = os.pipe() + os.set_blocking(self.sentinel_r, False) + os.set_blocking(self.sentinel_w, False) + connection.connection.add_notify_handler(self._store_messages) + + @classmethod + def _debug(cls, message): + logger.debug(f"[{cls.PID}] {message}") + + def _store_messages(self, notification): + self.message_buffer.append( + PubsubMessage(channel=notification.channel, payload=notification.payload) + ) + if notification.pid == self.backend_pid: + os.write(self.sentinel_w, b"1") + self._debug(f"Received message: {notification}") + + @classmethod + def publish(cls, channel, payload=""): + query = ( + (f"NOTIFY {channel}",) + if not payload + else ("SELECT pg_notify(%s, %s)", (channel, str(payload))) + ) + + with connection.cursor() as cursor: + cursor.execute(*query) + cls._debug(f"Sent message: ({channel}, {str(payload)})") + + def subscribe(self, channel): + self._subscriptions.add(channel) + with connection.cursor() as cursor: + cursor.execute(f"LISTEN {channel}") + + def unsubscribe(self, channel): + self._subscriptions.remove(channel) + for i in range(0, len(self.message_buffer), -1): + if self.message_buffer[i].channel == channel: + self.message_buffer.pop(i) + with connection.cursor() as cursor: + cursor.execute(f"UNLISTEN {channel}") + + def get_subscriptions(self): + return self._subscriptions.copy() + + def fileno(self) -> int: + # when pub/sub clients are the same, the notification callback may be called + # asynchronously, making select on connection miss new notifications + ready, _, _ = select.select([self.sentinel_r], [], [], 0) + if self.sentinel_r in ready: + return self.sentinel_r + return connection.connection.fileno() + + def fetch(self) -> list[PubsubMessage]: + with connection.cursor() as cursor: + cursor.execute("SELECT 1").fetchone() + result = self.message_buffer.copy() + self.message_buffer.clear() + drain_non_blocking_fd(self.sentinel_r) + self._debug(f"Fetched messages: {result}") + return result + + def close(self): + self.message_buffer.clear() + connection.connection.remove_notify_handler(self._store_messages) + drain_non_blocking_fd(self.sentinel_r) + os.close(self.sentinel_r) + os.close(self.sentinel_w) + for channel in self.get_subscriptions(): + self.unsubscribe(channel) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + +backend = PostgresPubSub diff --git a/pulpcore/tasking/tasks.py b/pulpcore/tasking/tasks.py index e15b0d1cec..023ba3a196 100644 --- a/pulpcore/tasking/tasks.py +++ b/pulpcore/tasking/tasks.py @@ -26,10 +26,10 @@ TASK_INCOMPLETE_STATES, TASK_STATES, IMMEDIATE_TIMEOUT, - TASK_WAKEUP_HANDLE, TASK_WAKEUP_UNBLOCK, ) from pulpcore.middleware import x_task_diagnostics_var +from pulpcore.tasking import pubsub from pulpcore.tasking.kafka import send_task_notification _logger = logging.getLogger(__name__) @@ -50,12 +50,6 @@ def _validate_and_get_resources(resources): return list(resource_set) -def wakeup_worker(reason): - # Notify workers - with connection.connection.cursor() as cursor: - cursor.execute("SELECT pg_notify('pulp_worker_wakeup', %s)", (reason,)) - - def execute_task(task): # This extra stack is needed to isolate the current_task ContextVar contextvars.copy_context().run(_execute_task, task) @@ -257,7 +251,8 @@ def dispatch( task.set_canceling() task.set_canceled(TASK_STATES.CANCELED, "Resources temporarily unavailable.") if send_wakeup_signal: - wakeup_worker(TASK_WAKEUP_UNBLOCK) + with pubsub.PostgresPubSub(connection) as pubsub_client: + pubsub_client.wakeup_worker(reason=TASK_WAKEUP_UNBLOCK) return task @@ -297,7 +292,8 @@ async def adispatch( task.set_canceling() task.set_canceled(TASK_STATES.CANCELED, "Resources temporarily unavailable.") if send_wakeup_signal: - await sync_to_async(wakeup_worker)(TASK_WAKEUP_UNBLOCK) + with pubsub.PostgresPubSub(connection) as pubsub_client: + pubsub_client.wakeup_worker(reason=TASK_WAKEUP_UNBLOCK) return task @@ -429,12 +425,9 @@ def cancel_task(task_id): # This is the only valid transition without holding the task lock. task.set_canceling() - # Notify the worker that might be running that task. - with connection.cursor() as cursor: - if task.app_lock is None: - wakeup_worker(TASK_WAKEUP_HANDLE) - else: - cursor.execute("SELECT pg_notify('pulp_worker_cancel', %s)", (str(task.pk),)) + # Notify the worker that might be running that task and other workers to clean up + pubsub.backend.cancel_task(task_pk=task.pk) + pubsub.backend.wakeup_worker() return task diff --git a/pulpcore/tasking/worker.py b/pulpcore/tasking/worker.py index d6236959d8..956c41ff7b 100644 --- a/pulpcore/tasking/worker.py +++ b/pulpcore/tasking/worker.py @@ -7,13 +7,14 @@ import signal import socket import contextlib +import collections from datetime import datetime, timedelta from multiprocessing import Process from tempfile import TemporaryDirectory from packaging.version import parse as parse_version from django.conf import settings -from django.db import connection, DatabaseError, IntegrityError +from django.db import DatabaseError, IntegrityError from django.db.models import Case, Count, F, Max, Value, When from django.utils import timezone @@ -25,6 +26,7 @@ TASK_METRICS_HEARTBEAT_LOCK, TASK_WAKEUP_UNBLOCK, TASK_WAKEUP_HANDLE, + TASK_PUBSUB, ) from pulpcore.metrics import init_otel_meter from pulpcore.app.apps import pulp_plugin_configs @@ -32,6 +34,7 @@ from pulpcore.app.util import PGAdvisoryLock from pulpcore.exceptions import AdvisoryLockError +from pulpcore.tasking import pubsub from pulpcore.tasking.storage import WorkerDirectory from pulpcore.tasking._util import ( delete_incomplete_resources, @@ -76,7 +79,6 @@ def __init__(self, auxiliary=False): self.heartbeat_period = timedelta(seconds=settings.WORKER_TTL / 3) self.last_metric_heartbeat = timezone.now() self.versions = {app.label: app.version for app in pulp_plugin_configs()} - self.cursor = connection.cursor() self.app_status = AppStatus.objects.create( name=self.name, app_type="worker", versions=self.versions ) @@ -88,6 +90,9 @@ def __init__(self, auxiliary=False): self.worker_cleanup_countdown = random.randint( int(WORKER_CLEANUP_INTERVAL / 10), WORKER_CLEANUP_INTERVAL ) + # Pubsub handling + self.pubsub_client = pubsub.backend() + self.pubsub_channel_callback = {} # Add a file descriptor to trigger select on signals self.sentinel, sentinel_w = os.pipe() @@ -134,29 +139,6 @@ def _signal_handler(self, thesignal, frame): ) self.shutdown_requested = True - def _pg_notify_handler(self, notification): - if notification.channel == "pulp_worker_wakeup": - if notification.payload == TASK_WAKEUP_UNBLOCK: - # Auxiliary workers don't do this. - self.wakeup_unblock = not self.auxiliary - elif notification.payload == TASK_WAKEUP_HANDLE: - self.wakeup_handle = True - else: - _logger.warning("Unknown wakeup call recieved. Reason: '%s'", notification.payload) - # We cannot be sure so assume everything happened. - self.wakeup_unblock = not self.auxiliary - self.wakeup_handle = True - - elif notification.channel == "pulp_worker_metrics_heartbeat": - self.last_metric_heartbeat = datetime.fromisoformat(notification.payload) - elif self.task and notification.channel == "pulp_worker_cancel": - if notification.payload == str(self.task.pk): - self.cancel_task = True - - def shutdown(self): - self.app_status.delete() - _logger.info(_("Worker %s was shut down."), self.name) - def handle_worker_heartbeat(self): """ Update worker heartbeat records. @@ -217,9 +199,6 @@ def beat(self): # to be able to report on a congested tasking system to produce reliable results. self.record_unblocked_waiting_tasks_metric() - def notify_workers(self, reason): - self.cursor.execute("SELECT pg_notify('pulp_worker_wakeup', %s)", (reason,)) - def cancel_abandoned_task(self, task, final_state, reason=None): """Cancel and clean up an abandoned task. @@ -247,7 +226,8 @@ def cancel_abandoned_task(self, task, final_state, reason=None): delete_incomplete_resources(task) task.set_canceled(final_state=final_state, reason=reason) if task.reserved_resources_record: - self.notify_workers(TASK_WAKEUP_UNBLOCK) + self.pubsub_client.wakeup_worker(reason=TASK_WAKEUP_UNBLOCK) + return True def is_compatible(self, task): unmatched_versions = [ @@ -366,14 +346,11 @@ def sleep(self): _logger.debug(_("Worker %s entering sleep state."), self.name) while not self.shutdown_requested and not self.wakeup_handle: r, w, x = select.select( - [self.sentinel, connection.connection], - [], - [], - 0 if self.wakeup_unblock else self.heartbeat_period.seconds, + [self.sentinel, self.pubsub_client], [], [], self.heartbeat_period.seconds ) self.beat() - if connection.connection in r: - connection.connection.execute("SELECT 1") + if self.pubsub_client in r: + self.pubsub_handle_messages(self.pubsub_client.fetch()) if self.wakeup_unblock: self.unblock_tasks() if self.sentinel in r: @@ -408,15 +385,24 @@ def supervise_task(self, task): ) os.kill(task_process.pid, signal.SIGUSR1) + if self.cancel_task: + _logger.info( + _("Received signal to cancel current task %s in domain: %s."), + task.pk, + domain.name, + ) + cancel_state = TASK_STATES.CANCELED + self.cancel_task = False + r, w, x = select.select( - [self.sentinel, connection.connection, task_process.sentinel], + [self.sentinel, self.pubsub_client, task_process.sentinel], [], [], - 0 if self.wakeup_unblock or self.cancel_task else self.heartbeat_period.seconds, + self.heartbeat_period.seconds, ) self.beat() - if connection.connection in r: - connection.connection.execute("SELECT 1") + if self.pubsub_client in r: + self.pubsub_handle_messages(self.pubsub_client.fetch()) if self.cancel_task: _logger.info( _("Received signal to cancel current task %s in domain: %s."), @@ -468,7 +454,7 @@ def supervise_task(self, task): if cancel_state: self.cancel_abandoned_task(task, cancel_state, cancel_reason) if task.reserved_resources_record: - self.notify_workers(TASK_WAKEUP_UNBLOCK) + self.unblock_tasks() self.task = None def fetch_task(self): @@ -579,18 +565,62 @@ def _record_unblocked_waiting_tasks_metric(self): unblocked_tasks_stats["longest_unblocked_waiting_time"].seconds ) - self.cursor.execute(f"NOTIFY pulp_worker_metrics_heartbeat, '{str(now)}'") + self.pubsub_client.record_worker_metrics(now) + + def pubsub_handle_messages(self): + messages = self.pubsub_client.fetch() + by_channel = collections.defaultdict(list) + for message in messages: + by_channel[message.channel].append(message.payload) + for channel, channel_messages in by_channel.items(): + callback = self.pubsub_channel_callback[channel] + callback(channel_messages) + + def pubsub_setup(self): + def cancellation_callback(messages): + for message in messages: + if self.task and message == str(self.task.pk): + self.cancel_task = True + + def wakeup_callback(messages): + if len(messages) != 1: + message = "unknown" + else: + message = messages[0] + + if message == TASK_WAKEUP_UNBLOCK: + # Auxiliary workers don't do this. + self.wakeup_unblock = not self.auxiliary + elif message == TASK_WAKEUP_HANDLE: + self.wakeup_handle = True + else: + _logger.warn("Unknown wakeup call recieved. Reason: '%s'", message) + # We cannot be sure so assume everything happened. + self.wakeup_unblock = not self.auxiliary + self.wakeup_handle = True + + def metric_callback(messages): + message = messages[0] + self.last_metric_heartbeat = datetime.fromisoformat(message) + + self.pubsub_client.subscribe(TASK_PUBSUB.WAKEUP_WORKER) + self.pubsub_channel_callback[TASK_PUBSUB.WAKEUP_WORKER] = wakeup_callback + self.pubsub_client.subscribe(TASK_PUBSUB.CANCEL_TASK) + self.pubsub_channel_callback[TASK_PUBSUB.CANCEL_TASK] = cancellation_callback + self.pubsub_client.subscribe(TASK_PUBSUB.WORKER_METRICS) + self.pubsub_channel_callback[TASK_PUBSUB.WORKER_METRICS] = metric_callback + + def pubsub_teardown(self): + self.pubsub_client.close() def run(self, burst=False): with WorkerDirectory(self.name): signal.signal(signal.SIGINT, self._signal_handler) signal.signal(signal.SIGTERM, self._signal_handler) signal.signal(signal.SIGHUP, self._signal_handler) - # Subscribe to pgsql channels - connection.connection.add_notify_handler(self._pg_notify_handler) - self.cursor.execute("LISTEN pulp_worker_cancel") - self.cursor.execute("LISTEN pulp_worker_metrics_heartbeat") + self.pubsub_setup() if burst: + self.pubsub_client.unsubscribe(self.pubsub_client.WORKER_WAKEUP) if not self.auxiliary: # Attempt to flush the task queue completely. # Stop iteration if no new tasks were found to unblock. @@ -598,7 +628,6 @@ def run(self, burst=False): self.handle_unblocked_tasks() self.handle_unblocked_tasks() else: - self.cursor.execute("LISTEN pulp_worker_wakeup") while not self.shutdown_requested: # do work if self.shutdown_requested: @@ -608,7 +637,5 @@ def run(self, burst=False): break # rest until notified to wakeup self.sleep() - self.cursor.execute("UNLISTEN pulp_worker_wakeup") - self.cursor.execute("UNLISTEN pulp_worker_metrics_heartbeat") - self.cursor.execute("UNLISTEN pulp_worker_cancel") + self.pubsub_teardown() self.shutdown() diff --git a/pulpcore/tests/functional/test_pubsub.py b/pulpcore/tests/functional/test_pubsub.py new file mode 100644 index 0000000000..96c93ad7c7 --- /dev/null +++ b/pulpcore/tests/functional/test_pubsub.py @@ -0,0 +1,304 @@ +from types import SimpleNamespace +from datetime import datetime +import select +import pytest +from pulpcore.tasking import pubsub +from pulpcore.tests.functional.utils import IpcUtil + + +@pytest.fixture(autouse=True) +def django_connection_reset(request): + # django_db_blocker is from pytest-django. We don't want it to try to safeguard + # us from using our functional Pulp instance. + # https://pytest-django.readthedocs.io/en/latest/database.html#django-db-blocker + pytest_django_installed = False + try: + django_db_blocker = request.getfixturevalue("django_db_blocker") + django_db_blocker.unblock() + pytest_django_installed = True + except pytest.FixtureLookupError: + pass + + # If we dont' reset the connections we'll get interference between tests, + # as listen/notify is connection based. + from django.db import connections + + connections.close_all() + yield + if pytest_django_installed: + django_db_blocker.block() + + +class TestPostgresSpecifics: + def test_listen_notify_in_same_process(self): + """Testing postgres low-level implementation.""" + from django.db import connection + + state = SimpleNamespace() + state.got_message = False + with connection.cursor() as cursor: + assert connection.connection is cursor.connection + conn = cursor.connection + # Listen and Notify + conn.execute("LISTEN abc") + conn.add_notify_handler(lambda notification: setattr(state, "got_message", True)) + cursor.execute("NOTIFY abc, 'foo'") + assert state.got_message is True + conn.execute("SELECT 1") + assert state.got_message is True + + # Reset and retry + state.got_message = False + conn.execute("UNLISTEN abc") + cursor.execute("NOTIFY abc, 'foo'") + assert state.got_message is False + + def test_low_level_assumptions_on_multiprocess(self): + """Asserts that we are really testing two different connections. + + From psycopg, the backend_id is: + "The process ID (PID) of the backend process handling this connection." + """ + from django.db import connection + + def host_act(host_turn, log): + with host_turn(): # 1 + assert connection.connection is None + with connection.cursor() as cursor: + cursor.execute("select 1") + assert connection.connection is not None + log.put(connection.connection.info.backend_pid) + + def child_act(child_turn, log): + with child_turn(): # 2 + assert connection.connection is None + with connection.cursor() as cursor: + cursor.execute("select 1") + assert connection.connection is not None + log.put(connection.connection.info.backend_pid) + + log = IpcUtil.run(host_act, child_act) + assert len(log) == 2 + host_connection_pid, child_connection_pid = log + assert host_connection_pid != child_connection_pid + + +M = pubsub.PubsubMessage +PUBSUB_BACKENDS = [ + pubsub.PostgresPubSub, +] + + +def unsubscribe_all(channels, subscriber): + for channel in channels: + subscriber.unsubscribe(channel) + + +def subscribe_all(channels, subscriber): + for channel in channels: + subscriber.subscribe(channel) + + +def publish_all(messages, publisher): + for channel, payload in messages: + publisher.publish(channel, payload=payload) + + +@pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS) +@pytest.mark.parametrize( + "payload", + ( + pytest.param(None, id="none"), + pytest.param("", id="empty-string"), + pytest.param("payload", id="non-empty-string"), + pytest.param(123, id="int"), + pytest.param(datetime.now(), id="datetime"), + pytest.param(True, id="bool"), + ), +) +class TestPublish: + def test_with_payload_as(self, pubsub_backend: pubsub.BasePubSubBackend, payload): + pubsub_backend.publish("channel", payload=payload) + + +@pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS) +@pytest.mark.parametrize( + "messages", + ( + pytest.param([M("a", "A1")], id="single-message"), + pytest.param([M("a", "A1"), M("a", "A2")], id="two-messages-in-same-channel"), + pytest.param( + [M("a", "A1"), M("a", "A2"), M("b", "B1"), M("c", "C1")], + id="tree-msgs-in-different-channels", + ), + ), +) +class TestNoIpcSubscribeFetch: + def test_with( + self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage] + ): + channels = {m.channel for m in messages} + publisher = pubsub_backend + with pubsub_backend() as subscriber: + subscribe_all(channels, subscriber) + publish_all(messages, publisher) + assert subscriber.fetch() == messages + + unsubscribe_all(channels, subscriber) + assert subscriber.fetch() == [] + + def test_select_readiness_with( + self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage] + ): + TIMEOUT = 0.1 + CHANNELS = {m.channel for m in messages} + publisher = pubsub_backend + with pubsub_backend() as subscriber: + subscribe_all(CHANNELS, subscriber) + assert subscriber.get_subscriptions() == CHANNELS + + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber not in ready + assert subscriber.fetch() == [] + + publish_all(messages, publisher) + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber in ready + + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber in ready + assert subscriber.fetch() == messages + + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber not in ready + assert subscriber.fetch() == [] + + unsubscribe_all(CHANNELS, subscriber) + publish_all(messages, publisher) + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber not in ready + assert subscriber.fetch() == [] + + +@pytest.mark.parametrize("pubsub_backend", PUBSUB_BACKENDS) +@pytest.mark.parametrize( + "messages", + ( + pytest.param([M("a", "A1")], id="single-message"), + pytest.param([M("a", "A1")], id="test-if-leaking"), + pytest.param([M("b", "B1"), M("b", "B2")], id="two-messages-in-same-channel"), + pytest.param( + [M("c", "C1"), M("c", "C2"), M("d", "D1"), M("d", "D1")], + id="four-msgs-in-different-channels", + ), + ), +) +class TestIpcSubscribeFetch: + + def test_with( + self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage] + ): + CHANNELS = {m.channel for m in messages} + EXPECTED_LOG = [ + "subscribe", + "publish", + "fetch", + "publish", + "fetch+unsubscribe", + "publish", + "fetch-empty", + ] + + # host + def subscriber_act(subscriber_turn, log): + with pubsub_backend() as subscriber: + with subscriber_turn(): # 1 + log.put("subscribe") + subscribe_all(CHANNELS, subscriber) + + with subscriber_turn(): # 3 + log.put("fetch") + assert subscriber.get_subscriptions() == CHANNELS + assert subscriber.fetch() == messages + assert subscriber.fetch() == [] + + with subscriber_turn(): # 5 + log.put("fetch+unsubscribe") + assert subscriber.fetch() == messages + assert subscriber.fetch() == [] + unsubscribe_all(CHANNELS, subscriber) + + with subscriber_turn(done=True): # 7 + log.put("fetch-empty") + assert subscriber.fetch() == [] + + # child + def publisher_act(publisher_turn, log): + publisher = pubsub_backend + with publisher_turn(): # 2 + log.put("publish") + publish_all(messages, publisher) + + with publisher_turn(): # 4 + log.put("publish") + publish_all(messages, publisher) + + with publisher_turn(): # 6 + log.put("publish") + publish_all(messages, publisher) + + log = IpcUtil.run(subscriber_act, publisher_act) + assert log == EXPECTED_LOG + + def test_select_readiness_with( + self, pubsub_backend: pubsub.BasePubSubBackend, messages: list[pubsub.PubsubMessage] + ): + TIMEOUT = 0.1 + CHANNELS = {m.channel for m in messages} + EXPECTED_LOG = [ + "subscribe/select-empty", + "publish", + "fetch/select-ready/unsubscribe", + "publish", + "fetch/select-empty", + ] + + def subscriber_act(subscriber_turn, log): + with pubsub_backend() as subscriber: + with subscriber_turn(): # 1 + log.put("subscribe/select-empty") + subscribe_all(CHANNELS, subscriber) + assert subscriber.get_subscriptions() == CHANNELS + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber not in ready + assert subscriber.fetch() == [] + + with subscriber_turn(): # 3 + log.put("fetch/select-ready/unsubscribe") + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber in ready + + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber in ready + assert subscriber.fetch() == messages + assert subscriber.fetch() == [] + unsubscribe_all(CHANNELS, subscriber) + + with subscriber_turn(done=True): # 5 + log.put("fetch/select-empty") + ready, _, _ = select.select([subscriber], [], [], TIMEOUT) + assert subscriber not in ready + assert subscriber.fetch() == [] + + def publisher_act(publisher_turn, log): + publisher = pubsub_backend + with publisher_turn(): # 2 + log.put("publish") + publish_all(messages, publisher) + + with publisher_turn(): # 4 + log.put("publish") + publish_all(messages, publisher) + + log = IpcUtil.run(subscriber_act, publisher_act) + assert log == EXPECTED_LOG diff --git a/pulpcore/tests/functional/test_utils.py b/pulpcore/tests/functional/test_utils.py new file mode 100644 index 0000000000..37efe69c14 --- /dev/null +++ b/pulpcore/tests/functional/test_utils.py @@ -0,0 +1,56 @@ +import pytest +from pulpcore.tests.functional.utils import IpcUtil + + +class TestIpcUtil: + + def test_catch_subprocess_errors(self): + + def host_act(host_turn, log): + with host_turn(): + log.put(0) + + def child_act(child_turn, log): + with child_turn(): + log.put(1) + assert 1 == 0 + + error_msg = "AssertionError: assert 1 == 0" + with pytest.raises(Exception, match=error_msg): + IpcUtil.run(host_act, child_act) + + def test_turns_are_deterministic(self): + RUNS = 1000 + errors = 0 + + def host_act(host_turn, log): + with host_turn(): + log.put(0) + + with host_turn(): + log.put(2) + + with host_turn(): + log.put(4) + + def child_act(child_turn, log): + with child_turn(): + log.put(1) + + with child_turn(): + log.put(3) + + with child_turn(): + log.put(5) + + def run(): + log = IpcUtil.run(host_act, child_act) + if log != [0, 1, 2, 3, 4, 5]: + return 1 + return 0 + + for _ in range(RUNS): + errors += run() + + error_rate = errors / RUNS + assert error_rate == 0 diff --git a/pulpcore/tests/functional/utils.py b/pulpcore/tests/functional/utils.py index aeab1328ca..9b5bf789ee 100644 --- a/pulpcore/tests/functional/utils.py +++ b/pulpcore/tests/functional/utils.py @@ -5,7 +5,14 @@ import hashlib import os import random - +import traceback +import sys +import multiprocessing as mp + +from multiprocessing.connection import Connection +from functools import partial +from contextlib import contextmanager +from typing import NamedTuple from aiohttp import web from dataclasses import dataclass from multidict import CIMultiDict @@ -156,3 +163,99 @@ async def _get_from_url(url, auth=None, headers=None): async with aiohttp.ClientSession(auth=auth) as session: async with session.get(url, ssl=False, headers=headers) as response: return response + + +class ProcessErrorData(NamedTuple): + error: Exception + stacktrace: str + + +class IpcUtil: + TIMEOUT_ERROR_MESSAGE = ( + "Tip: make sure the last 'with turn()' (in execution order) " + "is called with 'actor_turn(done=True)', otherwise it may hang." + ) + SUBPROCESS_ERROR_HEADER_TEMPLATE = "Error from sub-process (pid={pid}) on test using IpcUtil" + TURN_WAIT_TIMEOUT = 1 + + @staticmethod + def run(host_act, child_act) -> list: + """Run two processes in synchronous alternate turns. + + The act are functions with the signature (act_turn, log), where act_turn is + a context manager where each step of the act takes place, and log is a + queue where each actor can put messages in using Q.put(item). + + Args: + host_act: The function of the act that start the communication + child_act: The function of the act that follows host_act + Returns: + A list with the items collected via log. + """ + conn_1, conn_2 = mp.Pipe() + log = mp.SimpleQueue() + lock = mp.Lock() + turn_1 = partial(IpcUtil._actor_turn, conn_1, starts=True, log=log, lock=lock) + turn_2 = partial(IpcUtil._actor_turn, conn_2, starts=False, log=log, lock=lock) + proc_1 = mp.Process(target=host_act, args=(turn_1, log)) + proc_2 = mp.Process(target=child_act, args=(turn_2, log)) + proc_1.start() + proc_2.start() + try: + proc_1.join() + finally: + conn_1.send("1") + try: + proc_2.join() + finally: + conn_2.send("1") + conn_1.close() + conn_2.close() + result = IpcUtil.read_log(log) + log.close() + if proc_1.exitcode != 0 or proc_2.exitcode != 0: + error = Exception("General exception") + stacktrace = "No stacktrace" + for item in result: + if isinstance(item, ProcessErrorData): + error, stacktrace = item + break + raise Exception(stacktrace) from error + return result + + @staticmethod + @contextmanager + def _actor_turn(conn: Connection, starts: bool, log, lock: mp.Lock, done: bool = False): + def flush_conn(conn: Connection): + if not conn.poll(IpcUtil.TURN_WAIT_TIMEOUT): + raise TimeoutError(IpcUtil.TIMEOUT_ERROR_MESSAGE) + conn.recv() + + try: + if starts: + with lock: + conn.send("done") + yield + if not done: + flush_conn(conn) + else: + flush_conn(conn) + with lock: + yield + conn.send("done") + except Exception as e: + traceback.print_exc(file=sys.stderr) + error_header = IpcUtil.SUBPROCESS_ERROR_HEADER_TEMPLATE.format(pid=os.getpid()) + traceback_str = f"{error_header}\n\n{traceback.format_exc()}" + error = ProcessErrorData(e, traceback_str) + log.put(error) + exit(1) + + @staticmethod + def read_log(log: mp.SimpleQueue) -> list: + result = [] + while not log.empty(): + result.append(log.get()) + for item in result: + log.put(item) + return result