
Clean up resources in DHT/P2P, improve test robustness #636


Merged
11 commits merged on Nov 4, 2024
1 change: 1 addition & 0 deletions .github/workflows/run-tests.yml
@@ -12,6 +12,7 @@ jobs:
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
fail-fast: false
timeout-minutes: 15
steps:
- uses: actions/checkout@v3
4 changes: 3 additions & 1 deletion hivemind/dht/dht.py
@@ -72,7 +72,7 @@ def __init__(
self.num_workers = num_workers

self._record_validator = CompositeValidator(record_validators)
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
Member Author:
Everything works with duplex=False according to the tests. It is better to use a one-way pipe here, since a duplex Pipe relies on a socketpair, which is slower [1] [2]
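As a minimal, illustrative sketch (not the hivemind code; the worker and variable names are hypothetical), a one-way pipe covers this use because only the parent process ever sends and only the child ever receives:

```python
import multiprocessing as mp

def _worker(inner_pipe):
    # The child only ever receives commands from the parent
    while True:
        method, args, kwargs = inner_pipe.recv()
        if method == "_shutdown":
            break

if __name__ == "__main__":
    # With duplex=False, the first end is receive-only and the second is send-only
    inner_pipe, outer_pipe = mp.Pipe(duplex=False)
    proc = mp.Process(target=_worker, args=(inner_pipe,), daemon=True)
    proc.start()
    outer_pipe.send(("_shutdown", [], {}))  # the parent only ever sends
    proc.join()
```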

self.shutdown_timeout = shutdown_timeout
self._ready = MPFuture()
self.daemon = daemon
@@ -137,6 +137,7 @@ async def _run():
break

loop.run_until_complete(_run())
loop.close()

def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
"""
@@ -154,6 +155,7 @@ def shutdown(self) -> None:
"""Shut down a running dht process"""
if self.is_alive():
self._outer_pipe.send(("_shutdown", [], {}))
self._outer_pipe.close()
self.join(self.shutdown_timeout)
if self.is_alive():
logger.warning("DHT did not shut down within the grace period; terminating it the hard way")
18 changes: 13 additions & 5 deletions hivemind/hivemind_cli/run_dht.py
@@ -1,6 +1,7 @@
import time
from argparse import ArgumentParser
from secrets import token_hex
from signal import SIGINT, SIGTERM, signal, strsignal
from threading import Event

from hivemind.dht import DHT, DHTNode
from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -84,12 +85,19 @@ def main():
)
log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

exit_event = Event()

def signal_handler(signal_number: int, _) -> None:
logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
exit_event.set()

signal(SIGTERM, signal_handler)
signal(SIGINT, signal_handler)
Member Author:
We terminate hivemind-server and hivemind-dht in tests, so a signal handler is necessary to ensure child processes are terminated correctly
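For context, a rough sketch (names and timeouts are illustrative, not taken from the tests verbatim) of how the tests interact with this handler: Popen.terminate() delivers SIGTERM, so the handler sets exit_event and the CLI process shuts down gracefully instead of being killed mid-operation:

```python
from subprocess import PIPE, Popen

dht_proc = Popen(
    ["hivemind-dht", "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
    stderr=PIPE, text=True, encoding="utf-8",
)
# ... read stderr, run assertions ...
dht_proc.terminate()       # sends SIGTERM to the child process
dht_proc.wait(timeout=30)  # the handler above lets the child exit cleanly
```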


try:
while True:
while not exit_event.is_set():
dht.run_coroutine(report_status, return_future=False)
time.sleep(args.refresh_period)
except KeyboardInterrupt:
logger.info("Caught KeyboardInterrupt, shutting down")
exit_event.wait(args.refresh_period)
finally:
dht.shutdown()

16 changes: 14 additions & 2 deletions hivemind/hivemind_cli/run_server.py
@@ -1,5 +1,7 @@
from functools import partial
from pathlib import Path
from signal import SIGINT, SIGTERM, signal, strsignal
from threading import Event

import configargparse
import torch
@@ -104,10 +106,20 @@ def main():

server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression)

exit_event = Event()

def signal_handler(signal_number: int, _) -> None:
logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
exit_event.set()

signal(SIGTERM, signal_handler)
signal(SIGINT, signal_handler)

try:
exit_event.wait()
finally:
server.shutdown()
server.join()
except KeyboardInterrupt:
logger.info("Caught KeyboardInterrupt, shutting down")


if __name__ == "__main__":
6 changes: 6 additions & 0 deletions hivemind/p2p/p2p_daemon_bindings/control.py
@@ -322,6 +322,7 @@ async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
resp = p2pd_pb.Response() # type: ignore
await read_pbmsg_safe(reader, resp)
writer.close()
await writer.wait_closed()
Comment on lines 324 to +325
Member Author:
wait_closed is recommended after close()
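For reference, the general asyncio pattern being applied throughout this file (a standalone sketch, not hivemind code):

```python
import asyncio

async def request(host: str, port: int, payload: bytes) -> bytes:
    reader, writer = await asyncio.open_connection(host, port)
    writer.write(payload)
    await writer.drain()        # wait until the write buffer is flushed
    response = await reader.read(-1)
    writer.close()              # begin closing the stream
    await writer.wait_closed()  # wait until the underlying connection is closed
    return response
```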


raise_if_failed(resp)
peer_id_bytes = resp.identify.id
@@ -343,6 +344,7 @@ async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
resp = p2pd_pb.Response() # type: ignore
await read_pbmsg_safe(reader, resp)
writer.close()
await writer.wait_closed()
raise_if_failed(resp)

async def list_peers(self) -> Tuple[PeerInfo, ...]:
@@ -352,6 +354,7 @@ async def list_peers(self) -> Tuple[PeerInfo, ...]:
resp = p2pd_pb.Response() # type: ignore
await read_pbmsg_safe(reader, resp)
writer.close()
await writer.wait_closed()
raise_if_failed(resp)

peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
@@ -365,6 +368,7 @@ async def disconnect(self, peer_id: PeerID) -> None:
resp = p2pd_pb.Response() # type: ignore
await read_pbmsg_safe(reader, resp)
writer.close()
await writer.wait_closed()
raise_if_failed(resp)

async def stream_open(
@@ -403,6 +407,7 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced:
resp = p2pd_pb.Response() # type: ignore
await read_pbmsg_safe(reader, resp)
writer.close()
await writer.wait_closed()
raise_if_failed(resp)

async def remove_stream_handler(self, proto: str) -> None:
@@ -420,6 +425,7 @@ async def remove_stream_handler(self, proto: str) -> None:
resp = p2pd_pb.Response() # type: ignore
await read_pbmsg_safe(reader, resp)
writer.close()
await writer.wait_closed()
raise_if_failed(resp)

del self.handlers[proto]
2 changes: 2 additions & 0 deletions hivemind/p2p/p2p_daemon_bindings/utils.py
@@ -46,6 +46,7 @@ async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_
value |= 0x80
byte = value.to_bytes(1, "big")
stream.write(byte)
await stream.drain()
Member Author:
drain() is recommended after write()
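In isolation, the pattern looks like this (a hedged sketch; `stream` stands for any asyncio.StreamWriter):

```python
import asyncio
from typing import Iterable

async def send_all(stream: asyncio.StreamWriter, chunks: Iterable[bytes]) -> None:
    for chunk in chunks:
        stream.write(chunk)   # write() only buffers the bytes
        await stream.drain()  # drain() applies backpressure until the buffer is flushed
```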

if integer == 0:
break

@@ -77,6 +78,7 @@ async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
await write_unsigned_varint(stream, size)
msg_bytes: bytes = pbmsg.SerializeToString()
stream.write(msg_bytes)
await stream.drain()


async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
4 changes: 3 additions & 1 deletion tests/test_allreduce.py
@@ -108,7 +108,9 @@ async def wait_synchronously():
wall_time = time.perf_counter() - start_time
# check that event loop had enough time to respond to incoming requests; this is over 50% most of the time
# we set 33% threshold to ensure that the test will pass reliably. If we break prefetch, this drops to <10%
assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time"
assert (
time_in_waiting > wall_time / 3
), f"Event loop could only run {time_in_waiting / wall_time * 100 :.5f}% of the time"


@pytest.mark.parametrize("num_senders", [1, 2, 4, 10])
24 changes: 17 additions & 7 deletions tests/test_cli_scripts.py
@@ -3,7 +3,7 @@
from subprocess import PIPE, Popen
from time import sleep

DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")
_DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")


def test_dht_connection_successful():
@@ -23,32 +23,39 @@ def test_dht_connection_successful():

first_line = dht_proc.stderr.readline()
second_line = dht_proc.stderr.readline()
dht_pattern_match = DHT_START_PATTERN.search(first_line)
dht_pattern_match = _DHT_START_PATTERN.search(first_line)
assert dht_pattern_match is not None, first_line
assert "Full list of visible multiaddresses:" in second_line, second_line

initial_peers = dht_pattern_match.group(1).split(" ")

dht_client_proc = Popen(
["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
[
"hivemind-dht",
*initial_peers,
"--host_maddrs",
"/ip4/127.0.0.1/tcp/0",
"--refresh_period",
str(dht_refresh_period),
],
stderr=PIPE,
text=True,
encoding="utf-8",
env=cloned_env,
)

# ensure we get the output of dht_proc after the start of dht_client_proc
sleep(2 * dht_refresh_period)

# skip first two lines with connectivity info
for _ in range(2):
dht_client_proc.stderr.readline()
first_report_msg = dht_client_proc.stderr.readline()

assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg

# ensure we get the output of dht_proc after the start of dht_client_proc
sleep(dht_refresh_period)

# expect that one of the next logging outputs from the first peer shows a new connection
for _ in range(5):
for _ in range(10):
Comment on lines -51 to +58
Member Author:
For some reason, we might not be getting the updated routing table in DHT processes after 5 iterations. Increasing the limit seems to help, although the solution is still quite hacky

first_report_msg = dht_proc.stderr.readline()
second_report_msg = dht_proc.stderr.readline()

@@ -63,6 +70,9 @@
and "Local storage contains 0 keys" in second_report_msg
)

dht_proc.stderr.close()
dht_client_proc.stderr.close()

dht_proc.terminate()
dht_client_proc.terminate()

4 changes: 3 additions & 1 deletion tests/test_moe.py
@@ -282,6 +282,7 @@ def test_client_anomaly_detection():
experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")

dht = DHT(start=True)
dht.get_visible_maddrs(latest=True)
Member Author:
I'm using the solution suggested in #635, otherwise replicating P2P seems to hang sometimes

server = Server(dht, experts, num_connection_handlers=1)
server.start()
try:
@@ -318,7 +319,8 @@ def test_client_anomaly_detection():
def _measure_coro_running_time(n_coros, elapsed_fut, counter):
async def coro():
await asyncio.sleep(0.1)
counter.value += 1
with counter.get_lock():
counter.value += 1
Comment on lines -321 to +323
Member Author:
In-place updates of mp.Value are not atomic, so we need to acquire the lock to avoid race conditions
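A minimal standalone sketch of the race and the fix (illustrative only, not part of the test):

```python
import multiprocessing as mp

def _bump(counter, n):
    for _ in range(n):
        with counter.get_lock():  # mp.Value(..., lock=True) carries an RLock
            counter.value += 1    # read-modify-write is only safe under the lock

if __name__ == "__main__":
    counter = mp.Value("i", 0)
    procs = [mp.Process(target=_bump, args=(counter, 10_000)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    assert counter.value == 40_000  # without the lock, increments can be lost
```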


try:
start_time = time.perf_counter()
4 changes: 2 additions & 2 deletions tests/test_optimizer.py
@@ -414,8 +414,8 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
loss.backward()

optimizer.step()

total_samples_accumulated.value += batch_size
with total_samples_accumulated.get_lock():
total_samples_accumulated.value += batch_size

if not reuse_grad_buffers:
optimizer.zero_grad()
3 changes: 2 additions & 1 deletion tests/test_p2p_daemon_bindings.py
@@ -71,7 +71,8 @@ async def readexactly(self, n):


class MockWriter(io.BytesIO):
pass
async def drain(self):
pass


class MockReaderWriter(MockReader, MockWriter):
15 changes: 12 additions & 3 deletions tests/test_start_server.py
@@ -27,11 +27,18 @@ def test_cli_run_server_identity_path():
with TemporaryDirectory() as tempdir:
id_path = os.path.join(tempdir, "id")

cloned_env = os.environ.copy()
# overriding the loglevel to prevent debug print statements
cloned_env["HIVEMIND_LOGLEVEL"] = "INFO"

common_server_args = ["--hidden_dim", "4", "--num_handlers", "1"]

server_1_proc = Popen(
["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args,
stderr=PIPE,
text=True,
encoding="utf-8",
env=cloned_env,
)

line = server_1_proc.stderr.readline()
@@ -46,10 +53,11 @@ def test_cli_run_server_identity_path():
assert len(ids_1) == 1

server_2_proc = Popen(
["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args,
stderr=PIPE,
text=True,
encoding="utf-8",
env=cloned_env,
)

line = server_2_proc.stderr.readline()
@@ -61,10 +69,11 @@ def test_cli_run_server_identity_path():
assert len(ids_2) == 1

server_3_proc = Popen(
["hivemind-server", "--num_experts", "1"],
["hivemind-server", "--num_experts", "1"] + common_server_args,
stderr=PIPE,
text=True,
encoding="utf-8",
env=cloned_env,
)

line = server_3_proc.stderr.readline()
15 changes: 9 additions & 6 deletions tests/test_util_modules.py
@@ -233,7 +233,7 @@ def _future_creator():
@pytest.mark.forked
def test_mpfuture_done_callback():
receiver, sender = mp.Pipe(duplex=False)
events = [mp.Event() for _ in range(6)]
events = [mp.Event() for _ in range(7)]

def _future_creator():
future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()
@@ -250,7 +250,7 @@ def _check_result_and_set(future):

sender.send((future1, future2))
future2.cancel() # trigger future2 callback from the same process

events[6].set()
events[0].wait()
future1.add_done_callback(
lambda future: events[4].set()
@@ -262,6 +262,7 @@ def _check_result_and_set(future):

future1, future2 = receiver.recv()
future1.set_result(123)
events[6].wait()
Member Author:
Without this event, we could assert future2.done() in the main process before it's cancelled in the child process


with pytest.raises(RuntimeError):
future1.add_done_callback(lambda future: (1, 2, 3))
@@ -514,21 +515,23 @@ async def test_async_context_flooding():

Here's how the test below works: suppose that the thread pool has at most N workers;
If at least N + 1 coroutines await lock1 concurrently, N of them occupy workers and the rest are awaiting workers;
When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(1e-2);
When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep();
During that sleep, one of the worker-less coroutines will take up the worker freed by coroutine A.
Finally, coroutine A finishes sleeping and immediately gets stuck at lock2, because there are no free workers.
Thus, every single coroutine is either awaiting an already acquired lock, or awaiting for free workers in executor.

"""
total_sleep_time = 1
Member Author:
I don't quite understand the purpose of this test since there are no assert statements, but I decided to optimize it a bit (especially on machines with >=128 cores) by fixing the total sleep time across all coroutines

lock1, lock2 = mp.Lock(), mp.Lock()

num_coros = max(33, mp.cpu_count() * 5 + 1)

async def coro():
async with enter_asynchronously(lock1):
await asyncio.sleep(1e-2)
await asyncio.sleep(total_sleep_time / (num_coros * 2))
async with enter_asynchronously(lock2):
await asyncio.sleep(1e-2)
await asyncio.sleep(total_sleep_time / (num_coros * 2))

num_coros = max(33, mp.cpu_count() * 5 + 1)
await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)})

