
Commit 72a0760

Clean up resources in DHT/P2P, improve test robustness (#636)
* Clean up resources in DHT and P2P
* Update the tests
* Gracefully handle SIGTERM in run_server.py
* Optimize tests, add another synchronization event in test_mpfuture_done_callback
* Disable fail-fast for test matrix
* Try temporary fix of test_client_anomaly_detection with DHT init
* Acquire locks before mp.Value updates

(cherry picked from commit 94c1bf4)
1 parent f76c070 commit 72a0760

13 files changed (+87, -29 lines)

.github/workflows/run-tests.yml

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ jobs:
     strategy:
       matrix:
         python-version: [ '3.8', '3.9', '3.10', '3.11' ]
+      fail-fast: false
     timeout-minutes: 15
     steps:
       - uses: actions/checkout@v3
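Note: fail-fast defaults to true in GitHub Actions, which cancels the remaining matrix jobs as soon as any Python version fails; fail-fast: false lets every version run to completion, matching the "Disable fail-fast for test matrix" item in the commit message.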

hivemind/dht/dht.py

Lines changed: 3 additions & 1 deletion
@@ -72,7 +72,7 @@ def __init__(
         self.num_workers = num_workers

         self._record_validator = CompositeValidator(record_validators)
-        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
         self.shutdown_timeout = shutdown_timeout
         self._ready = MPFuture()
         self.daemon = daemon
@@ -137,6 +137,7 @@ async def _run():
                     break

         loop.run_until_complete(_run())
+        loop.close()

     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
@@ -154,6 +155,7 @@ def shutdown(self) -> None:
         """Shut down a running dht process"""
         if self.is_alive():
             self._outer_pipe.send(("_shutdown", [], {}))
+            self._outer_pipe.close()
             self.join(self.shutdown_timeout)
             if self.is_alive():
                 logger.warning("DHT did not shut down within the grace period; terminating it the hard way")
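The three hunks above share one goal: releasing IPC resources deterministically. A minimal standalone sketch of the pipe part of that pattern (illustrative names, not hivemind code): a one-way pipe carries parent-to-child commands, and closing the send end after the final "_shutdown" message means a child blocked on recv() wakes up with EOFError instead of hanging; dht.py additionally closes its asyncio event loop once run() returns.

import multiprocessing as mp


def worker(inner_pipe) -> None:
    # Receive commands until "_shutdown" arrives or the sending end is closed.
    while True:
        try:
            method, args, kwargs = inner_pipe.recv()
        except EOFError:
            break  # sender closed the pipe; nothing more will arrive
        if method == "_shutdown":
            break


if __name__ == "__main__":
    # duplex=False returns (receive-only, send-only) connection ends,
    # mirroring the mp.Pipe(duplex=False) change in dht.py above
    inner_pipe, outer_pipe = mp.Pipe(duplex=False)
    proc = mp.Process(target=worker, args=(inner_pipe,), daemon=True)
    proc.start()

    outer_pipe.send(("_shutdown", [], {}))
    outer_pipe.close()  # release the descriptor; no half-open pipe is left behind
    proc.join(timeout=5)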

hivemind/hivemind_cli/run_dht.py

Lines changed: 13 additions & 5 deletions
@@ -1,6 +1,7 @@
-import time
 from argparse import ArgumentParser
 from secrets import token_hex
+from signal import SIGINT, SIGTERM, signal, strsignal
+from threading import Event

 from hivemind.dht import DHT, DHTNode
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -84,12 +85,19 @@ def main():
     )
     log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

+    exit_event = Event()
+
+    def signal_handler(signal_number: int, _) -> None:
+        logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
+        exit_event.set()
+
+    signal(SIGTERM, signal_handler)
+    signal(SIGINT, signal_handler)
+
     try:
-        while True:
+        while not exit_event.is_set():
             dht.run_coroutine(report_status, return_future=False)
-            time.sleep(args.refresh_period)
-    except KeyboardInterrupt:
-        logger.info("Caught KeyboardInterrupt, shutting down")
+            exit_event.wait(args.refresh_period)
     finally:
         dht.shutdown()
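The hunk above swaps the KeyboardInterrupt-only loop for explicit SIGTERM/SIGINT handlers that set a threading.Event. A condensed standalone sketch of the same pattern (the 5-second period and print() stand in for args.refresh_period and the hivemind logger):

from signal import SIGINT, SIGTERM, signal, strsignal
from threading import Event

exit_event = Event()


def signal_handler(signal_number: int, _frame) -> None:
    print(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
    exit_event.set()


signal(SIGTERM, signal_handler)
signal(SIGINT, signal_handler)

try:
    while not exit_event.is_set():
        # do one round of periodic work here, then wait; Event.wait() returns
        # early as soon as a signal handler sets the event
        exit_event.wait(5.0)
finally:
    # cleanup always runs, whether the loop exited via SIGTERM or Ctrl+C
    pass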

hivemind/hivemind_cli/run_server.py

Lines changed: 14 additions & 2 deletions
@@ -1,5 +1,7 @@
 from functools import partial
 from pathlib import Path
+from signal import SIGINT, SIGTERM, signal, strsignal
+from threading import Event

 import configargparse
 import torch
@@ -104,10 +106,20 @@ def main():

     server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression)

+    exit_event = Event()
+
+    def signal_handler(signal_number: int, _) -> None:
+        logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
+        exit_event.set()
+
+    signal(SIGTERM, signal_handler)
+    signal(SIGINT, signal_handler)
+
     try:
+        exit_event.wait()
+    finally:
+        server.shutdown()
         server.join()
-    except KeyboardInterrupt:
-        logger.info("Caught KeyboardInterrupt, shutting down")


 if __name__ == "__main__":
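run_server.py adopts the same Event-based signal handling as run_dht.py above; the notable difference is that server.shutdown() and server.join() now run in a finally block, so the server is torn down on SIGTERM as well as on Ctrl+C.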

hivemind/p2p/p2p_daemon_bindings/control.py

Lines changed: 6 additions & 0 deletions
@@ -322,6 +322,7 @@ async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()

         raise_if_failed(resp)
         peer_id_bytes = resp.identify.id
@@ -343,6 +344,7 @@ async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def list_peers(self) -> Tuple[PeerInfo, ...]:
@@ -352,6 +354,7 @@ async def list_peers(self) -> Tuple[PeerInfo, ...]:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

         peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
@@ -365,6 +368,7 @@ async def disconnect(self, peer_id: PeerID) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def stream_open(
@@ -403,6 +407,7 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def remove_stream_handler(self, proto: str) -> None:
@@ -420,6 +425,7 @@ async def remove_stream_handler(self, proto: str) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

         del self.handlers[proto]
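Every hunk above adds await writer.wait_closed() right after writer.close(). A small generic asyncio sketch (not the p2pd control API; host, port, and payload are placeholders) showing why both calls belong together: close() only schedules the transport shutdown, while wait_closed() waits until the connection is actually torn down, so no socket is left half-closed when the coroutine returns.

import asyncio


async def query_daemon(host: str, port: int, payload: bytes) -> bytes:
    reader, writer = await asyncio.open_connection(host, port)
    try:
        writer.write(payload)
        await writer.drain()          # flush the send buffer
        response = await reader.read(-1)
    finally:
        writer.close()                # begin closing the transport
        await writer.wait_closed()    # block until it is really closed
    return response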

hivemind/p2p/p2p_daemon_bindings/utils.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,7 @@ async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_
             value |= 0x80
         byte = value.to_bytes(1, "big")
         stream.write(byte)
+        await stream.drain()
         if integer == 0:
             break

@@ -77,6 +78,7 @@ async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
     await write_unsigned_varint(stream, size)
     msg_bytes: bytes = pbmsg.SerializeToString()
     stream.write(msg_bytes)
+    await stream.drain()


 async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
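For context: asyncio.StreamWriter.write() is synchronous and only appends to the transport's buffer; awaiting stream.drain() afterwards flushes it and applies flow control, pausing the coroutine if the peer reads slowly. That is also why the MockWriter in tests/test_p2p_daemon_bindings.py below now needs an async drain() method.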

tests/test_allreduce.py

Lines changed: 3 additions & 1 deletion
@@ -108,7 +108,9 @@ async def wait_synchronously():
     wall_time = time.perf_counter() - start_time
     # check that event loop had enough time to respond to incoming requests; this is over 50% most of the time
     # we set 33% threshold to ensure that the test will pass reliably. If we break prefetch, this drops to <10%
-    assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time"
+    assert (
+        time_in_waiting > wall_time / 3
+    ), f"Event loop could only run {time_in_waiting / wall_time * 100 :.5f}% of the time"


 @pytest.mark.parametrize("num_senders", [1, 2, 4, 10])

tests/test_cli_scripts.py

Lines changed: 17 additions & 7 deletions
@@ -3,7 +3,7 @@
 from subprocess import PIPE, Popen
 from time import sleep

-DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")
+_DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")


 def test_dht_connection_successful():
@@ -23,32 +23,39 @@ def test_dht_connection_successful():

     first_line = dht_proc.stderr.readline()
     second_line = dht_proc.stderr.readline()
-    dht_pattern_match = DHT_START_PATTERN.search(first_line)
+    dht_pattern_match = _DHT_START_PATTERN.search(first_line)
     assert dht_pattern_match is not None, first_line
     assert "Full list of visible multiaddresses:" in second_line, second_line

     initial_peers = dht_pattern_match.group(1).split(" ")

     dht_client_proc = Popen(
-        ["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
+        [
+            "hivemind-dht",
+            *initial_peers,
+            "--host_maddrs",
+            "/ip4/127.0.0.1/tcp/0",
+            "--refresh_period",
+            str(dht_refresh_period),
+        ],
         stderr=PIPE,
         text=True,
         encoding="utf-8",
         env=cloned_env,
     )

+    # ensure we get the output of dht_proc after the start of dht_client_proc
+    sleep(2 * dht_refresh_period)
+
     # skip first two lines with connectivity info
     for _ in range(2):
         dht_client_proc.stderr.readline()
     first_report_msg = dht_client_proc.stderr.readline()

     assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg

-    # ensure we get the output of dht_proc after the start of dht_client_proc
-    sleep(dht_refresh_period)
-
     # expect that one of the next logging outputs from the first peer shows a new connection
-    for _ in range(5):
+    for _ in range(10):
         first_report_msg = dht_proc.stderr.readline()
         second_report_msg = dht_proc.stderr.readline()
@@ -63,6 +70,9 @@ def test_dht_connection_successful():
         and "Local storage contains 0 keys" in second_report_msg
     )

+    dht_proc.stderr.close()
+    dht_client_proc.stderr.close()
+
     dht_proc.terminate()
     dht_client_proc.terminate()

tests/test_moe.py

Lines changed: 3 additions & 1 deletion
@@ -282,6 +282,7 @@ def test_client_anomaly_detection():
     experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")

     dht = DHT(start=True)
+    dht.get_visible_maddrs(latest=True)
     server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
@@ -318,7 +319,8 @@ def test_client_anomaly_detection():
 def _measure_coro_running_time(n_coros, elapsed_fut, counter):
     async def coro():
         await asyncio.sleep(0.1)
-        counter.value += 1
+        with counter.get_lock():
+            counter.value += 1

     try:
         start_time = time.perf_counter()
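This fix and the matching one in tests/test_optimizer.py below wrap mp.Value updates in the value's own lock. A standalone sketch (illustrative names, not hivemind code) of the race being closed: += on a shared Value is a read-modify-write, so concurrent processes can overwrite each other's increments unless the lock is held across the whole update.

import multiprocessing as mp


def bump(counter, n: int) -> None:
    for _ in range(n):
        with counter.get_lock():  # hold the lock across the read-modify-write
            counter.value += 1


if __name__ == "__main__":
    counter = mp.Value("i", 0)
    workers = [mp.Process(target=bump, args=(counter, 10_000)) for _ in range(4)]
    for p in workers:
        p.start()
    for p in workers:
        p.join()
    # without get_lock(), some increments would typically be lost
    assert counter.value == 40_000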

tests/test_optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -414,8 +414,8 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
             loss.backward()

             optimizer.step()
-
-            total_samples_accumulated.value += batch_size
+            with total_samples_accumulated.get_lock():
+                total_samples_accumulated.value += batch_size

             if not reuse_grad_buffers:
                 optimizer.zero_grad()

tests/test_p2p_daemon_bindings.py

Lines changed: 2 additions & 1 deletion
@@ -71,7 +71,8 @@ async def readexactly(self, n):


 class MockWriter(io.BytesIO):
-    pass
+    async def drain(self):
+        pass


 class MockReaderWriter(MockReader, MockWriter):

tests/test_start_server.py

Lines changed: 12 additions & 3 deletions
@@ -27,11 +27,18 @@ def test_cli_run_server_identity_path():
     with TemporaryDirectory() as tempdir:
         id_path = os.path.join(tempdir, "id")

+        cloned_env = os.environ.copy()
+        # overriding the loglevel to prevent debug print statements
+        cloned_env["HIVEMIND_LOGLEVEL"] = "INFO"
+
+        common_server_args = ["--hidden_dim", "4", "--num_handlers", "1"]
+
         server_1_proc = Popen(
-            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
+            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args,
             stderr=PIPE,
             text=True,
             encoding="utf-8",
+            env=cloned_env,
         )

         line = server_1_proc.stderr.readline()
@@ -46,10 +53,11 @@ def test_cli_run_server_identity_path():
         assert len(ids_1) == 1

         server_2_proc = Popen(
-            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
+            ["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args,
             stderr=PIPE,
             text=True,
             encoding="utf-8",
+            env=cloned_env,
         )

         line = server_2_proc.stderr.readline()
@@ -61,10 +69,11 @@ def test_cli_run_server_identity_path():
         assert len(ids_2) == 1

         server_3_proc = Popen(
-            ["hivemind-server", "--num_experts", "1"],
+            ["hivemind-server", "--num_experts", "1"] + common_server_args,
             stderr=PIPE,
             text=True,
             encoding="utf-8",
+            env=cloned_env,
         )

         line = server_3_proc.stderr.readline()

tests/test_util_modules.py

Lines changed: 9 additions & 6 deletions
@@ -233,7 +233,7 @@ def _future_creator():
 @pytest.mark.forked
 def test_mpfuture_done_callback():
     receiver, sender = mp.Pipe(duplex=False)
-    events = [mp.Event() for _ in range(6)]
+    events = [mp.Event() for _ in range(7)]

     def _future_creator():
         future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()
@@ -250,7 +250,7 @@ def _check_result_and_set(future):

         sender.send((future1, future2))
         future2.cancel()  # trigger future2 callback from the same process
-
+        events[6].set()
         events[0].wait()
         future1.add_done_callback(
             lambda future: events[4].set()
@@ -262,6 +262,7 @@ def _check_result_and_set(future):

     future1, future2 = receiver.recv()
     future1.set_result(123)
+    events[6].wait()

     with pytest.raises(RuntimeError):
         future1.add_done_callback(lambda future: (1, 2, 3))
@@ -514,21 +515,23 @@ async def test_async_context_flooding():

     Here's how the test below works: suppose that the thread pool has at most N workers;
     If at least N + 1 coroutines await lock1 concurrently, N of them occupy workers and the rest are awaiting workers;
-    When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(1e-2);
+    When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep();
     During that sleep, one of the worker-less coroutines will take up the worker freed by coroutine A.
     Finally, coroutine A finishes sleeping and immediately gets stuck at lock2, because there are no free workers.
     Thus, every single coroutine is either awaiting an already acquired lock, or awaiting for free workers in executor.

     """
+    total_sleep_time = 1
     lock1, lock2 = mp.Lock(), mp.Lock()

+    num_coros = max(33, mp.cpu_count() * 5 + 1)
+
     async def coro():
         async with enter_asynchronously(lock1):
-            await asyncio.sleep(1e-2)
+            await asyncio.sleep(total_sleep_time / (num_coros * 2))
             async with enter_asynchronously(lock2):
-                await asyncio.sleep(1e-2)
+                await asyncio.sleep(total_sleep_time / (num_coros * 2))

-    num_coros = max(33, mp.cpu_count() * 5 + 1)
     await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)})
