diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml
index 74d778bd0..11792a3c9 100644
--- a/.github/workflows/run-tests.yml
+++ b/.github/workflows/run-tests.yml
@@ -12,6 +12,7 @@ jobs:
     strategy:
       matrix:
         python-version: [ '3.8', '3.9', '3.10', '3.11' ]
+      fail-fast: false
     timeout-minutes: 15
     steps:
       - uses: actions/checkout@v3
diff --git a/hivemind/dht/dht.py b/hivemind/dht/dht.py
index 85b371d1c..957c8d3df 100644
--- a/hivemind/dht/dht.py
+++ b/hivemind/dht/dht.py
@@ -72,7 +72,7 @@ def __init__(
         self.num_workers = num_workers
 
         self._record_validator = CompositeValidator(record_validators)
-        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
         self.shutdown_timeout = shutdown_timeout
         self._ready = MPFuture()
         self.daemon = daemon
@@ -137,6 +137,7 @@ async def _run():
                         break
 
             loop.run_until_complete(_run())
+            loop.close()
 
     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
@@ -154,6 +155,7 @@ def shutdown(self) -> None:
         """Shut down a running dht process"""
         if self.is_alive():
             self._outer_pipe.send(("_shutdown", [], {}))
+            self._outer_pipe.close()
             self.join(self.shutdown_timeout)
             if self.is_alive():
                 logger.warning("DHT did not shut down within the grace period; terminating it the hard way")
diff --git a/hivemind/hivemind_cli/run_dht.py b/hivemind/hivemind_cli/run_dht.py
index d72dbd22b..64f831e15 100644
--- a/hivemind/hivemind_cli/run_dht.py
+++ b/hivemind/hivemind_cli/run_dht.py
@@ -1,6 +1,7 @@
-import time
 from argparse import ArgumentParser
 from secrets import token_hex
+from signal import SIGINT, SIGTERM, signal, strsignal
+from threading import Event
 
 from hivemind.dht import DHT, DHTNode
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -84,12 +85,19 @@ def main():
     )
     log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)
 
+    exit_event = Event()
+
+    def signal_handler(signal_number: int, _) -> None:
+        logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
+        exit_event.set()
+
+    signal(SIGTERM, signal_handler)
+    signal(SIGINT, signal_handler)
+
     try:
-        while True:
+        while not exit_event.is_set():
             dht.run_coroutine(report_status, return_future=False)
-            time.sleep(args.refresh_period)
-    except KeyboardInterrupt:
-        logger.info("Caught KeyboardInterrupt, shutting down")
+            exit_event.wait(args.refresh_period)
     finally:
         dht.shutdown()
 
diff --git a/hivemind/hivemind_cli/run_server.py b/hivemind/hivemind_cli/run_server.py
index 1c6bc9a09..b5abd529d 100644
--- a/hivemind/hivemind_cli/run_server.py
+++ b/hivemind/hivemind_cli/run_server.py
@@ -1,5 +1,7 @@
 from functools import partial
 from pathlib import Path
+from signal import SIGINT, SIGTERM, signal, strsignal
+from threading import Event
 
 import configargparse
 import torch
@@ -104,10 +106,20 @@ def main():
 
     server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression)
 
+    exit_event = Event()
+
+    def signal_handler(signal_number: int, _) -> None:
+        logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
+        exit_event.set()
+
+    signal(SIGTERM, signal_handler)
+    signal(SIGINT, signal_handler)
+
     try:
+        exit_event.wait()
+    finally:
+        server.shutdown()
         server.join()
-    except KeyboardInterrupt:
-        logger.info("Caught KeyboardInterrupt, shutting down")
 
 
 if __name__ == "__main__":
diff --git a/hivemind/p2p/p2p_daemon_bindings/control.py b/hivemind/p2p/p2p_daemon_bindings/control.py
index 4f229bbdb..a8de5d74e 100644
--- a/hivemind/p2p/p2p_daemon_bindings/control.py
+++ b/hivemind/p2p/p2p_daemon_bindings/control.py
@@ -322,6 +322,7 @@ async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)
 
         peer_id_bytes = resp.identify.id
@@ -343,6 +344,7 @@ async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)
 
     async def list_peers(self) -> Tuple[PeerInfo, ...]:
@@ -352,6 +354,7 @@ async def list_peers(self) -> Tuple[PeerInfo, ...]:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)
 
         peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
@@ -365,6 +368,7 @@ async def disconnect(self, peer_id: PeerID) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)
 
     async def stream_open(
@@ -403,6 +407,7 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)
 
     async def remove_stream_handler(self, proto: str) -> None:
@@ -420,6 +425,7 @@ async def remove_stream_handler(self, proto: str) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)
 
         del self.handlers[proto]
diff --git a/hivemind/p2p/p2p_daemon_bindings/utils.py b/hivemind/p2p/p2p_daemon_bindings/utils.py
index c8ca87901..4a1f106c6 100644
--- a/hivemind/p2p/p2p_daemon_bindings/utils.py
+++ b/hivemind/p2p/p2p_daemon_bindings/utils.py
@@ -46,6 +46,7 @@ async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_
             value |= 0x80
         byte = value.to_bytes(1, "big")
         stream.write(byte)
+        await stream.drain()
         if integer == 0:
             break
 
@@ -77,6 +78,7 @@ async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
     await write_unsigned_varint(stream, size)
     msg_bytes: bytes = pbmsg.SerializeToString()
     stream.write(msg_bytes)
+    await stream.drain()
 
 
 async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
diff --git a/tests/test_allreduce.py b/tests/test_allreduce.py
index 43a1fdc5a..fb86951be 100644
--- a/tests/test_allreduce.py
+++ b/tests/test_allreduce.py
@@ -108,7 +108,9 @@ async def wait_synchronously():
     wall_time = time.perf_counter() - start_time
     # check that event loop had enough time to respond to incoming requests; this is over 50% most of the time
     # we set 33% threshold to ensure that the test will pass reliably. If we break prefetch, this drops to <10%
-    assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time"
+    assert (
+        time_in_waiting > wall_time / 3
+    ), f"Event loop could only run {time_in_waiting / wall_time * 100 :.5f}% of the time"
 
 
 @pytest.mark.parametrize("num_senders", [1, 2, 4, 10])
diff --git a/tests/test_cli_scripts.py b/tests/test_cli_scripts.py
index 97c674000..f9e044947 100644
--- a/tests/test_cli_scripts.py
+++ b/tests/test_cli_scripts.py
@@ -3,7 +3,7 @@
 from subprocess import PIPE, Popen
 from time import sleep
 
-DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")
+_DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")
 
 
 def test_dht_connection_successful():
@@ -23,20 +23,30 @@ def test_dht_connection_successful():
     first_line = dht_proc.stderr.readline()
     second_line = dht_proc.stderr.readline()
 
-    dht_pattern_match = DHT_START_PATTERN.search(first_line)
+    dht_pattern_match = _DHT_START_PATTERN.search(first_line)
     assert dht_pattern_match is not None, first_line
     assert "Full list of visible multiaddresses:" in second_line, second_line
 
     initial_peers = dht_pattern_match.group(1).split(" ")
 
     dht_client_proc = Popen(
-        ["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
+        [
+            "hivemind-dht",
+            *initial_peers,
+            "--host_maddrs",
+            "/ip4/127.0.0.1/tcp/0",
+            "--refresh_period",
+            str(dht_refresh_period),
+        ],
         stderr=PIPE,
         text=True,
         encoding="utf-8",
         env=cloned_env,
     )
 
+    # ensure we get the output of dht_proc after the start of dht_client_proc
+    sleep(2 * dht_refresh_period)
+
     # skip first two lines with connectivity info
     for _ in range(2):
         dht_client_proc.stderr.readline()
@@ -44,11 +54,8 @@
 
     assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg
 
-    # ensure we get the output of dht_proc after the start of dht_client_proc
-    sleep(dht_refresh_period)
-
     # expect that one of the next logging outputs from the first peer shows a new connection
-    for _ in range(5):
+    for _ in range(10):
         first_report_msg = dht_proc.stderr.readline()
         second_report_msg = dht_proc.stderr.readline()
 
@@ -63,6 +70,9 @@
         and "Local storage contains 0 keys" in second_report_msg
     )
 
+    dht_proc.stderr.close()
+    dht_client_proc.stderr.close()
+
     dht_proc.terminate()
     dht_client_proc.terminate()
 
diff --git a/tests/test_moe.py b/tests/test_moe.py
index f62c2159d..d788cb0dc 100644
--- a/tests/test_moe.py
+++ b/tests/test_moe.py
@@ -282,6 +282,7 @@ def test_client_anomaly_detection():
     experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")
 
     dht = DHT(start=True)
+    dht.get_visible_maddrs(latest=True)
     server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
@@ -318,7 +319,8 @@ def test_client_anomaly_detection():
 def _measure_coro_running_time(n_coros, elapsed_fut, counter):
     async def coro():
         await asyncio.sleep(0.1)
-        counter.value += 1
+        with counter.get_lock():
+            counter.value += 1
 
     try:
         start_time = time.perf_counter()
diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py
index c859e3879..16fb7f2f3 100644
--- a/tests/test_optimizer.py
+++ b/tests/test_optimizer.py
@@ -414,8 +414,8 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
 
             loss.backward()
             optimizer.step()
-
-            total_samples_accumulated.value += batch_size
+            with total_samples_accumulated.get_lock():
+                total_samples_accumulated.value += batch_size
 
             if not reuse_grad_buffers:
                 optimizer.zero_grad()
diff --git a/tests/test_p2p_daemon_bindings.py b/tests/test_p2p_daemon_bindings.py
index 71658f2ef..d9160e173 100644
--- a/tests/test_p2p_daemon_bindings.py
+++ b/tests/test_p2p_daemon_bindings.py
@@ -71,7 +71,8 @@ async def readexactly(self, n):
 
 
 class MockWriter(io.BytesIO):
-    pass
+    async def drain(self):
+        pass
 
 
 class MockReaderWriter(MockReader, MockWriter):
diff --git a/tests/test_start_server.py b/tests/test_start_server.py
index b84dd5407..b85507c1a 100644
--- a/tests/test_start_server.py
+++ b/tests/test_start_server.py
@@ -27,11 +27,18 @@ def test_cli_run_server_identity_path():
     with TemporaryDirectory() as tempdir:
         id_path = os.path.join(tempdir, "id")
 
+        cloned_env = os.environ.copy()
+        # overriding the loglevel to prevent debug print statements
+        cloned_env["HIVEMIND_LOGLEVEL"] = "INFO"
+
+        common_server_args = ["--hidden_dim", "4", "--num_handlers", "1"]
+
         server_1_proc = Popen(
-            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
+            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args,
             stderr=PIPE,
             text=True,
             encoding="utf-8",
+            env=cloned_env,
         )
 
         line = server_1_proc.stderr.readline()
@@ -46,10 +53,11 @@
         assert len(ids_1) == 1
 
         server_2_proc = Popen(
-            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
+            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args,
             stderr=PIPE,
             text=True,
             encoding="utf-8",
+            env=cloned_env,
         )
 
         line = server_2_proc.stderr.readline()
@@ -61,10 +69,11 @@
         assert len(ids_2) == 1
 
         server_3_proc = Popen(
-            ["hivemind-server", "--num_experts", "1"],
+            ["hivemind-server", "--num_experts", "1"] + common_server_args,
             stderr=PIPE,
             text=True,
             encoding="utf-8",
+            env=cloned_env,
         )
 
         line = server_3_proc.stderr.readline()
diff --git a/tests/test_util_modules.py b/tests/test_util_modules.py
index f245b777e..33b4597fe 100644
--- a/tests/test_util_modules.py
+++ b/tests/test_util_modules.py
@@ -233,7 +233,7 @@ def _future_creator():
 @pytest.mark.forked
 def test_mpfuture_done_callback():
     receiver, sender = mp.Pipe(duplex=False)
-    events = [mp.Event() for _ in range(6)]
+    events = [mp.Event() for _ in range(7)]
 
     def _future_creator():
         future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()
@@ -250,7 +250,7 @@ def _check_result_and_set(future):
 
         sender.send((future1, future2))
         future2.cancel()  # trigger future2 callback from the same process
-
+        events[6].set()
         events[0].wait()
         future1.add_done_callback(
             lambda future: events[4].set()
@@ -262,6 +262,7 @@ def _check_result_and_set(future):
 
     future1, future2 = receiver.recv()
     future1.set_result(123)
+    events[6].wait()
 
     with pytest.raises(RuntimeError):
         future1.add_done_callback(lambda future: (1, 2, 3))
@@ -514,21 +515,23 @@ async def test_async_context_flooding():
 
     Here's how the test below works: suppose that the thread pool has at most N workers;
     If at least N + 1 coroutines await lock1 concurrently, N of them occupy workers and the rest are awaiting workers;
-    When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(1e-2);
+    When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep();
     During that sleep, one of the worker-less coroutines will take up the worker freed by coroutine A.
     Finally, coroutine A finishes sleeping and immediately gets stuck at lock2, because there are no free workers.
     Thus, every single coroutine is either awaiting an already acquired lock, or awaiting for free workers in executor.
 
     """
+    total_sleep_time = 1
    lock1, lock2 = mp.Lock(), mp.Lock()
 
+    num_coros = max(33, mp.cpu_count() * 5 + 1)
+
     async def coro():
         async with enter_asynchronously(lock1):
-            await asyncio.sleep(1e-2)
+            await asyncio.sleep(total_sleep_time / (num_coros * 2))
             async with enter_asynchronously(lock2):
-                await asyncio.sleep(1e-2)
+                await asyncio.sleep(total_sleep_time / (num_coros * 2))
 
-    num_coros = max(33, mp.cpu_count() * 5 + 1)
     await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)})
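For reference, the shutdown logic that run_dht.py and run_server.py adopt above reduces to the following standalone sketch; `do_periodic_work` and `cleanup` are illustrative placeholders, not hivemind functions. The reason for replacing `time.sleep()` with `Event.wait()` is that the wait returns as soon as the signal handler sets the event, so the process reacts to SIGTERM or SIGINT immediately instead of sleeping out the full refresh period.

# Minimal sketch of the SIGTERM/SIGINT handling pattern used in run_dht.py and run_server.py.
# `do_periodic_work` and `cleanup` are hypothetical placeholders.
from signal import SIGINT, SIGTERM, signal, strsignal
from threading import Event

exit_event = Event()

def signal_handler(signal_number: int, _) -> None:
    print(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
    exit_event.set()

signal(SIGTERM, signal_handler)
signal(SIGINT, signal_handler)

def do_periodic_work() -> None:
    print("still running")

def cleanup() -> None:
    print("resources released")

try:
    while not exit_event.is_set():
        do_periodic_work()
        exit_event.wait(5.0)  # returns early once the handler sets the event
finally:
    cleanup()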
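The control.py and utils.py changes apply the usual asyncio stream discipline: follow every `StreamWriter.write()` with `await drain()` so the coroutine respects transport flow control, and follow `close()` with `await wait_closed()` so the connection is fully torn down before the method returns. A minimal sketch of that discipline against a generic TCP endpoint (the host, port, and payload are placeholders, not part of the patch):

import asyncio

async def send_request(host: str, port: int, payload: bytes) -> bytes:
    reader, writer = await asyncio.open_connection(host, port)
    writer.write(payload)
    await writer.drain()        # wait until the write buffer is flushed enough to continue
    response = await reader.read(1024)
    writer.close()
    await writer.wait_closed()  # wait until the underlying connection is actually closed
    return response

This is also why MockWriter in test_p2p_daemon_bindings.py now needs an async drain() method.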
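The test_moe.py and test_optimizer.py fixes guard increments of a shared multiprocessing.Value: `counter.value += 1` is a read-modify-write, so concurrent writers can lose updates unless they hold the value's built-in lock. A small self-contained illustration of the pattern; the worker function and counts here are made up for the example:

import multiprocessing as mp

def worker(counter, n):
    for _ in range(n):
        with counter.get_lock():  # serialize the read-modify-write
            counter.value += 1

if __name__ == "__main__":
    counter = mp.Value("i", 0)
    processes = [mp.Process(target=worker, args=(counter, 10_000)) for _ in range(4)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    print(counter.value)  # 40000 with the lock; without it, some increments may be lost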