Commit 72a0760

Clean up resources in DHT/P2P, improve test robustness (#636)
* Clean up resources in DHT and P2P
* Update the tests
* Gracefully handle SIGTERM in run_server.py
* Optimize tests, add another synchronization event in test_mpfuture_done_callback
* Disable fail-fast for test matrix
* Try temporary fix of test_client_anomaly_detection with DHT init
* Acquire locks before mp.Value updates

(cherry picked from commit 94c1bf4)
1 parent f76c070 commit 72a0760

File tree: 13 files changed, +87 -29 lines
.github/workflows/run-tests.yml

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ jobs:
     strategy:
       matrix:
         python-version: [ '3.8', '3.9', '3.10', '3.11' ]
+      fail-fast: false
     timeout-minutes: 15
     steps:
       - uses: actions/checkout@v3
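Setting fail-fast: false tells GitHub Actions not to cancel the remaining matrix jobs when one Python version fails, so a failure on a single version no longer hides the results for the others.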

hivemind/dht/dht.py

Lines changed: 3 additions & 1 deletion
@@ -72,7 +72,7 @@ def __init__(
         self.num_workers = num_workers

         self._record_validator = CompositeValidator(record_validators)
-        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
+        self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=False)
         self.shutdown_timeout = shutdown_timeout
         self._ready = MPFuture()
         self.daemon = daemon
@@ -137,6 +137,7 @@ async def _run():
                     break

         loop.run_until_complete(_run())
+        loop.close()

     def run_in_background(self, await_ready: bool = True, timeout: Optional[float] = None) -> None:
         """
@@ -154,6 +155,7 @@ def shutdown(self) -> None:
         """Shut down a running dht process"""
         if self.is_alive():
             self._outer_pipe.send(("_shutdown", [], {}))
+            self._outer_pipe.close()
             self.join(self.shutdown_timeout)
             if self.is_alive():
                 logger.warning("DHT did not shut down within the grace period; terminating it the hard way")

hivemind/hivemind_cli/run_dht.py

Lines changed: 13 additions & 5 deletions
@@ -1,6 +1,7 @@
-import time
 from argparse import ArgumentParser
 from secrets import token_hex
+from signal import SIGINT, SIGTERM, signal, strsignal
+from threading import Event

 from hivemind.dht import DHT, DHTNode
 from hivemind.utils.logging import get_logger, use_hivemind_log_handler
@@ -84,12 +85,19 @@ def main():
     )
     log_visible_maddrs(dht.get_visible_maddrs(), only_p2p=args.use_ipfs)

+    exit_event = Event()
+
+    def signal_handler(signal_number: int, _) -> None:
+        logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
+        exit_event.set()
+
+    signal(SIGTERM, signal_handler)
+    signal(SIGINT, signal_handler)
+
     try:
-        while True:
+        while not exit_event.is_set():
             dht.run_coroutine(report_status, return_future=False)
-            time.sleep(args.refresh_period)
-    except KeyboardInterrupt:
-        logger.info("Caught KeyboardInterrupt, shutting down")
+            exit_event.wait(args.refresh_period)
     finally:
         dht.shutdown()
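Both CLI scripts now share this idiom: a threading.Event is set from a signal handler, and the main loop waits on the event instead of sleeping, so SIGTERM (for example from a process supervisor) is handled as cleanly as Ctrl+C. A standalone sketch of the idiom, with a placeholder do_periodic_work() standing in for the DHT status report:

```python
from signal import SIGINT, SIGTERM, signal, strsignal
from threading import Event

exit_event = Event()


def signal_handler(signal_number: int, _) -> None:
    print(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
    exit_event.set()


# register the handler for both SIGTERM (e.g. from a supervisor) and SIGINT (Ctrl+C)
signal(SIGTERM, signal_handler)
signal(SIGINT, signal_handler)


def do_periodic_work() -> None:
    # placeholder for dht.run_coroutine(report_status, return_future=False)
    pass


while not exit_event.is_set():
    do_periodic_work()
    exit_event.wait(5)  # wakes up early as soon as the handler sets the event
```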

hivemind/hivemind_cli/run_server.py

Lines changed: 14 additions & 2 deletions
@@ -1,5 +1,7 @@
 from functools import partial
 from pathlib import Path
+from signal import SIGINT, SIGTERM, signal, strsignal
+from threading import Event

 import configargparse
 import torch
@@ -104,10 +106,20 @@ def main():

     server = Server.create(**args, optim_cls=optim_cls, start=True, compression=compression)

+    exit_event = Event()
+
+    def signal_handler(signal_number: int, _) -> None:
+        logger.info(f"Caught signal {signal_number} ({strsignal(signal_number)}), shutting down")
+        exit_event.set()
+
+    signal(SIGTERM, signal_handler)
+    signal(SIGINT, signal_handler)
+
     try:
+        exit_event.wait()
+    finally:
+        server.shutdown()
         server.join()
-    except KeyboardInterrupt:
-        logger.info("Caught KeyboardInterrupt, shutting down")


 if __name__ == "__main__":
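Unlike run_dht.py, the server script has no periodic work to do, so exit_event.wait() is called without a timeout: the main thread simply blocks until SIGINT or SIGTERM fires, and the finally block then shuts the server down and joins it.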

hivemind/p2p/p2p_daemon_bindings/control.py

Lines changed: 6 additions & 0 deletions
@@ -322,6 +322,7 @@ async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()

         raise_if_failed(resp)
         peer_id_bytes = resp.identify.id
@@ -343,6 +344,7 @@ async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def list_peers(self) -> Tuple[PeerInfo, ...]:
@@ -352,6 +354,7 @@ async def list_peers(self) -> Tuple[PeerInfo, ...]:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

         peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers)
@@ -365,6 +368,7 @@ async def disconnect(self, peer_id: PeerID) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def stream_open(
@@ -403,6 +407,7 @@ async def stream_handler(self, proto: str, handler_cb: StreamHandler, balanced:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

     async def remove_stream_handler(self, proto: str) -> None:
@@ -420,6 +425,7 @@ async def remove_stream_handler(self, proto: str) -> None:
         resp = p2pd_pb.Response()  # type: ignore
         await read_pbmsg_safe(reader, resp)
         writer.close()
+        await writer.wait_closed()
         raise_if_failed(resp)

         del self.handlers[proto]
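Every control RPC now awaits wait_closed() after close(), so the connection to the daemon is fully torn down (and its socket released) before the call returns. A minimal sketch of the pattern against a generic TCP endpoint rather than the p2pd control protocol (the query function, host/port, and payload are illustrative):

```python
import asyncio


async def query(host: str, port: int, payload: bytes) -> bytes:
    # open a connection, send one request, read one response, then tear everything down
    reader, writer = await asyncio.open_connection(host, port)
    try:
        writer.write(payload)
        await writer.drain()
        return await reader.read(1024)
    finally:
        writer.close()              # start closing the transport
        await writer.wait_closed()  # wait until the socket is actually released


# example usage against some local service: asyncio.run(query("127.0.0.1", 8080, b"ping"))
```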

hivemind/p2p/p2p_daemon_bindings/utils.py

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,7 @@ async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_
             value |= 0x80
         byte = value.to_bytes(1, "big")
         stream.write(byte)
+        await stream.drain()
         if integer == 0:
             break

@@ -77,6 +78,7 @@ async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None:
     await write_unsigned_varint(stream, size)
     msg_bytes: bytes = pbmsg.SerializeToString()
     stream.write(msg_bytes)
+    await stream.drain()


 async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None:
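StreamWriter.write() only places bytes in the transport's buffer; awaiting drain() afterwards applies backpressure, suspending the coroutine until the buffer drops below its high-water mark. A small illustrative helper (not part of hivemind) using the same write-then-drain pattern:

```python
import asyncio


async def send_chunks(writer: asyncio.StreamWriter, data: bytes, chunk_size: int = 4096) -> None:
    """Write data in chunks, letting drain() pause us whenever the peer reads slowly."""
    for offset in range(0, len(data), chunk_size):
        writer.write(data[offset : offset + chunk_size])  # enqueue bytes into the transport buffer
        await writer.drain()  # suspend until the buffer is below the high-water mark
```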

tests/test_allreduce.py

Lines changed: 3 additions & 1 deletion
@@ -108,7 +108,9 @@ async def wait_synchronously():
     wall_time = time.perf_counter() - start_time
     # check that event loop had enough time to respond to incoming requests; this is over 50% most of the time
     # we set 33% threshold to ensure that the test will pass reliably. If we break prefetch, this drops to <10%
-    assert time_in_waiting > wall_time / 3, f"Event loop could only run {time_in_waiting / wall_time :.5f} of the time"
+    assert (
+        time_in_waiting > wall_time / 3
+    ), f"Event loop could only run {time_in_waiting / wall_time * 100 :.5f}% of the time"


 @pytest.mark.parametrize("num_senders", [1, 2, 4, 10])

tests/test_cli_scripts.py

Lines changed: 17 additions & 7 deletions
@@ -3,7 +3,7 @@
 from subprocess import PIPE, Popen
 from time import sleep

-DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")
+_DHT_START_PATTERN = re.compile(r"Running a DHT instance. To connect other peers to this one, use (.+)$")


 def test_dht_connection_successful():
@@ -23,32 +23,39 @@ def test_dht_connection_successful():

     first_line = dht_proc.stderr.readline()
     second_line = dht_proc.stderr.readline()
-    dht_pattern_match = DHT_START_PATTERN.search(first_line)
+    dht_pattern_match = _DHT_START_PATTERN.search(first_line)
     assert dht_pattern_match is not None, first_line
     assert "Full list of visible multiaddresses:" in second_line, second_line

     initial_peers = dht_pattern_match.group(1).split(" ")

     dht_client_proc = Popen(
-        ["hivemind-dht", *initial_peers, "--host_maddrs", "/ip4/127.0.0.1/tcp/0"],
+        [
+            "hivemind-dht",
+            *initial_peers,
+            "--host_maddrs",
+            "/ip4/127.0.0.1/tcp/0",
+            "--refresh_period",
+            str(dht_refresh_period),
+        ],
         stderr=PIPE,
         text=True,
         encoding="utf-8",
         env=cloned_env,
     )

+    # ensure we get the output of dht_proc after the start of dht_client_proc
+    sleep(2 * dht_refresh_period)
+
     # skip first two lines with connectivity info
     for _ in range(2):
         dht_client_proc.stderr.readline()
     first_report_msg = dht_client_proc.stderr.readline()

     assert "2 DHT nodes (including this one) are in the local routing table" in first_report_msg, first_report_msg

-    # ensure we get the output of dht_proc after the start of dht_client_proc
-    sleep(dht_refresh_period)
-
     # expect that one of the next logging outputs from the first peer shows a new connection
-    for _ in range(5):
+    for _ in range(10):
         first_report_msg = dht_proc.stderr.readline()
         second_report_msg = dht_proc.stderr.readline()

@@ -63,6 +70,9 @@ def test_dht_connection_successful():
             and "Local storage contains 0 keys" in second_report_msg
         )

+    dht_proc.stderr.close()
+    dht_client_proc.stderr.close()
+
     dht_proc.terminate()
     dht_client_proc.terminate()

tests/test_moe.py

Lines changed: 3 additions & 1 deletion
@@ -282,6 +282,7 @@ def test_client_anomaly_detection():
     experts["expert.3"].module.ffn.weight.data[0, 0] = float("nan")

     dht = DHT(start=True)
+    dht.get_visible_maddrs(latest=True)
     server = Server(dht, experts, num_connection_handlers=1)
     server.start()
     try:
@@ -318,7 +319,8 @@ def _measure_coro_running_time(n_coros, elapsed_fut, counter):
 def _measure_coro_running_time(n_coros, elapsed_fut, counter):
     async def coro():
         await asyncio.sleep(0.1)
-        counter.value += 1
+        with counter.get_lock():
+            counter.value += 1

     try:
         start_time = time.perf_counter()
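Here counter is a multiprocessing.Value, and counter.value += 1 is a read-modify-write, so concurrent processes can lose increments unless they hold the Value's built-in lock. A toy reproduction of the race and the fix (illustrative, not part of the test suite):

```python
import multiprocessing as mp


def increment(counter, n: int) -> None:
    for _ in range(n):
        with counter.get_lock():  # serialize the read-modify-write on the shared counter
            counter.value += 1


if __name__ == "__main__":
    counter = mp.Value("i", 0)
    workers = [mp.Process(target=increment, args=(counter, 10_000)) for _ in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    assert counter.value == 40_000  # without get_lock(), increments can be lost
```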

tests/test_optimizer.py

Lines changed: 2 additions & 2 deletions
@@ -414,8 +414,8 @@ def run_trainer(batch_size: int, batch_time: float, client_mode: bool):
             loss.backward()

             optimizer.step()
-
-            total_samples_accumulated.value += batch_size
+            with total_samples_accumulated.get_lock():
+                total_samples_accumulated.value += batch_size

             if not reuse_grad_buffers:
                 optimizer.zero_grad()
