Clean up resources in DHT/P2P, improve test robustness #636
Changes from all commits: cac9298, bed05d7, 80e3ec3, 30ad882, c8dee55, 0b1264d, 85c86ad, e93416c, 17d2a82, aa9caf4, 0590954
----------------------------------------

@@ -1,6 +1,7 @@
 import time
 from argparse import ArgumentParser
 from secrets import token_hex
+from signal import SIGINT, SIGTERM, signal, strsignal
+from threading import Event

 from hivemind.dht import DHT, DHTNode
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -84,12 +85,19 @@ def main():
     )
     log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

+    exit_event = Event()
+
+    def signal_handler(signal_number: int, _) -> None:
+        logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
+        exit_event.set()
+
+    signal(SIGTERM, signal_handler)
+    signal(SIGINT, signal_handler)
+
     try:
-        while True:
+        while not exit_event.is_set():
             dht.run_coroutine(report_status, return_future=False)
-            time.sleep(args.refresh_period)
-    except KeyboardInterrupt:
-        logger.info("Caught KeyboardInterrupt, shutting down")
+            exit_event.wait(args.refresh_period)
     finally:
         dht.shutdown()

Review comment: We terminate hivemind-server and hivemind-dht in tests, so a signal handler is necessary to ensure child processes are terminated correctly.
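A minimal standalone sketch (illustrative only, not part of the PR) of why the handler matters: Popen.terminate() sends SIGTERM, and Python's default SIGTERM behavior kills the process without running finally blocks, so cleanup such as dht.shutdown() would be skipped. Installing a handler that sets an Event lets the loop exit normally and the cleanup run:

    # sketch: graceful shutdown on SIGTERM for a long-running loop
    from signal import SIGINT, SIGTERM, signal
    from threading import Event

    exit_event = Event()

    def handle_signal(signum, frame):
        exit_event.set()  # ask the main loop to stop

    signal(SIGTERM, handle_signal)
    signal(SIGINT, handle_signal)

    try:
        while not exit_event.is_set():
            exit_event.wait(1.0)  # periodic work would go here
    finally:
        print("cleanup runs even when the process is terminated via SIGTERM")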
----------------------------------------

@@ -322,6 +322,7 @@ async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()

         raise_if_failed(resp)
         peer_id_bytes = resp.identify.id

Review comment (lines 324 to +325): wait_closed() is recommended after close().
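For context, a generic asyncio sketch (an assumed standalone example, not code from this PR) of the close()/wait_closed() pattern: close() only schedules the transport to close, while wait_closed() suspends until the connection is actually gone, so the socket is released deterministically:

    import asyncio

    async def read_banner(host: str, port: int) -> bytes:
        reader, writer = await asyncio.open_connection(host, port)
        banner = await reader.readline()
        writer.close()              # schedules closing of the underlying transport
        await writer.wait_closed()  # waits until the transport has actually closed
        return banner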
@@ -343,6 +344,7 @@ async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def list_peers(self) -> Tuple[PeerInfo, ...]:

@@ -352,6 +354,7 @@ async def list_peers(self) -> Tuple[PeerInfo, ...]:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

         peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)

@@ -365,6 +368,7 @@ async def disconnect(self, peer_id: PeerID) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def stream_open(

@@ -403,6 +407,7 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def remove_stream_handler(self, proto: str) -> None:

@@ -420,6 +425,7 @@ async def remove_stream_handler(self, proto: str) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

         del self.handlers[proto]
----------------------------------------

@@ -46,6 +46,7 @@ async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_
         value |= 0x80
         byte = value.to_bytes(1, "big")
         stream.write(byte)
+        await stream.drain()
         if integer == 0:
             break

Review comment: drain() is recommended after write().
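As background, a generic sketch (assumed example, not from this PR) of the write()/drain() pattern: write() only buffers the bytes on the transport and returns immediately, while drain() applies flow control, suspending the coroutine whenever the write buffer is above the high-water mark:

    import asyncio

    async def send_payload(writer: asyncio.StreamWriter, payload: bytes) -> None:
        writer.write(payload)  # synchronously buffers the bytes; nothing is awaited here
        await writer.drain()   # suspends until the buffer drains below the high-water mark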
@@ -77,6 +78,7 @@ async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
     await write_unsigned_varint(stream, size)
     msg_bytes: bytes = pbmsg.SerializeToString()
     stream.write(msg_bytes)
+    await stream.drain()


 async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
----------------------------------------

@@ -3,7 +3,7 @@
 from subprocess import PIPE, Popen
 from time import sleep

-DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")
+_DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")


 def test_dht_connection_successful():

@@ -23,32 +23,39 @@ def test_dht_connection_successful():
     first_line = dht_proc.stderr.readline()
     second_line = dht_proc.stderr.readline()
-    dht_pattern_match = DHT_START_PATTERN.search(first_line)
+    dht_pattern_match = _DHT_START_PATTERN.search(first_line)
     assert dht_pattern_match is not None, first_line
     assert "Full list of visible multiaddresses:" in second_line, second_line

     initial_peers = dht_pattern_match.group(1).split(" ")

     dht_client_proc = Popen(
-        ["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
+        [
+            "hivemind-dht",
+            *initial_peers,
+            "--host_maddrs",
+            "/ip4/127.0.0.1/tcp/0",
+            "--refresh_period",
+            str(dht_refresh_period),
+        ],
         stderr=PIPE,
         text=True,
         encoding="utf-8",
         env=cloned_env,
     )

+    # ensure we get the output of dht_proc after the start of dht_client_proc
+    sleep(2 * dht_refresh_period)
+
     # skip first two lines with connectivity info
     for _ in range(2):
         dht_client_proc.stderr.readline()
     first_report_msg = dht_client_proc.stderr.readline()

     assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg

-    # ensure we get the output of dht_proc after the start of dht_client_proc
-    sleep(dht_refresh_period)
-
     # expect that one of the next logging outputs from the first peer shows a new connection
-    for _ in range(5):
+    for _ in range(10):
         first_report_msg = dht_proc.stderr.readline()
         second_report_msg = dht_proc.stderr.readline()

Review comment (lines -51 to +58): For some reason, we might not be getting the updated routing table in DHT processes after 5 iterations. Increasing the limit seems to help, although the solution is still quite hacky.
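A possibly less hacky alternative (just a sketch against the test's dht_proc, not what the PR implements) would be to poll with a wall-clock deadline instead of a fixed iteration count, so slow machines get more attempts without slowing down fast ones:

    from time import monotonic

    deadline = monotonic() + 30  # hypothetical 30-second budget
    while monotonic() < deadline:
        # readline() blocks until the peer prints its next report, once per refresh period
        report_msg = dht_proc.stderr.readline()
        if "2 DHT nodes (including this one) are in the local routing table" in report_msg:
            break
    else:
        assert False, "routing table was not updated before the deadline"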
@@ -63,6 +70,9 @@ def test_dht_connection_successful():
             and "Local storage contains 0 keys" in second_report_msg
         )

+    dht_proc.stderr.close()
+    dht_client_proc.stderr.close()
+
     dht_proc.terminate()
     dht_client_proc.terminate()
----------------------------------------

@@ -282,6 +282,7 @@ def test_client_anomaly_detection():
     experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")

     dht = DHT(start=True)
+    dht.get_visible_maddrs(latest=True)
     server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:

Review comment: I'm using the solution suggested in #635, otherwise replicating P2P seems to hang sometimes.
@@ -318,7 +319,8 @@ def test_client_anomaly_detection():
 def _measure_coro_running_time(n_coros, elapsed_fut, counter):
     async def coro():
         await asyncio.sleep(0.1)
-        counter.value += 1
+        with counter.get_lock():
+            counter.value += 1

     try:
         start_time = time.perf_counter()

Review comment (lines -321 to +323): In-place updates of mp.Value are not atomic, so we need to acquire the lock to avoid race conditions.
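A small self-contained illustration (an assumed example, not part of the PR) of why the lock matters: counter.value += 1 is a read-modify-write, so two processes can read the same old value and one increment gets lost; get_lock() serializes the update:

    import multiprocessing as mp

    def add_many(counter, n):
        for _ in range(n):
            with counter.get_lock():  # drop this line and the final count usually falls short
                counter.value += 1

    if __name__ == "__main__":
        counter = mp.Value("i", 0)  # "i" = shared signed int; created with a lock by default
        workers = [mp.Process(target=add_many, args=(counter, 100_000)) for _ in range(2)]
        for w in workers:
            w.start()
        for w in workers:
            w.join()
        print(counter.value)  # 200000 with the lock; typically less without it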
----------------------------------------

@@ -233,7 +233,7 @@ def _future_creator():
 @pytest.mark.forked
 def test_mpfuture_done_callback():
     receiver, sender = mp.Pipe(duplex=False)
-    events = [mp.Event() for _ in range(6)]
+    events = [mp.Event() for _ in range(7)]

     def _future_creator():
         future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()

@@ -250,7 +250,7 @@ def _check_result_and_set(future):
         sender.send((future1, future2))
         future2.cancel()  # trigger future2 callback from the same process
-
+        events[6].set()
         events[0].wait()
         future1.add_done_callback(
             lambda future: events[4].set()
@@ -262,6 +262,7 @@ def _check_result_and_set(future):
     future1, future2 = receiver.recv()
     future1.set_result(123)
+    events[6].wait()

     with pytest.raises(RuntimeError):
         future1.add_done_callback(lambda future: (1, 2, 3))

Review comment: Without this event, we could […]
@@ -514,21 +515,23 @@ async def test_async_context_flooding():
     Here's how the test below works: suppose that the thread pool has at most N workers;
     If at least N + 1 coroutines await lock1 concurrently, N of them occupy workers and the rest are awaiting workers;
-    When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(1e-2);
+    When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep();
     During that sleep, one of the worker-less coroutines will take up the worker freed by coroutine A.
     Finally, coroutine A finishes sleeping and immediately gets stuck at lock2, because there are no free workers.
     Thus, every single coroutine is either awaiting an already acquired lock, or awaiting for free workers in executor.

     """
+    total_sleep_time = 1
+
     lock1, lock2 = mp.Lock(), mp.Lock()

+    num_coros = max(33, mp.cpu_count() * 5 + 1)
+
     async def coro():
         async with enter_asynchronously(lock1):
-            await asyncio.sleep(1e-2)
+            await asyncio.sleep(total_sleep_time / (num_coros * 2))
             async with enter_asynchronously(lock2):
-                await asyncio.sleep(1e-2)
+                await asyncio.sleep(total_sleep_time / (num_coros * 2))

-    num_coros = max(33, mp.cpu_count() * 5 + 1)
     await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)})

Review comment (on total_sleep_time): I don't quite understand the purpose of this test, since there are no assert statements, but decided to optimize it a bit (especially on machines with >= 128 cores) by setting the total sleep time across all coroutines.
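For readers unfamiliar with the helper, a hypothetical sketch of what enter_asynchronously could look like (an assumption for illustration; hivemind's real implementation may differ): it acquires a blocking lock inside a thread-pool worker so the event loop itself is never blocked, which is exactly why each blocked acquire pins one worker in the scenario described above:

    import asyncio
    from contextlib import asynccontextmanager

    @asynccontextmanager
    async def enter_asynchronously(lock):
        loop = asyncio.get_running_loop()
        # a blocked acquire occupies one thread-pool worker until the lock is granted
        await loop.run_in_executor(None, lock.acquire)
        try:
            yield
        finally:
            lock.release()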
Review comment (on mp.Pipe(duplex=False)): Everything works with duplex=False according to tests; it's best to use pipes, since duplex relies on socketpair, which is slower [1] [2].
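As context, a minimal standalone example (assumed, not from the PR) of the one-way pipe: with duplex=False, multiprocessing.Pipe is backed by an OS pipe rather than a socket pair, and the two ends are strictly receive-only and send-only:

    import multiprocessing as mp

    def child(sender):
        sender.send("hello from the child process")
        sender.close()

    if __name__ == "__main__":
        receiver, sender = mp.Pipe(duplex=False)  # receiver.recv() only; sender.send() only
        proc = mp.Process(target=child, args=(sender,))
        proc.start()
        print(receiver.recv())
        proc.join()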