Skip to content

Commit 589cb2c

Browse files
authored
Allow RemoteExpertWorker run coroutines concurrently (#561)
Previously, `RemoteExpertWorker` ran one coroutine at a time, so hivemind.moe/Petals clients were very slow for concurrent calls.
1 parent 3164928 commit 589cb2c

File tree

4 files changed

+67
-31
lines changed

4 files changed

+67
-31
lines changed

.github/workflows/check-style.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,4 @@ jobs:
3232
- uses: codespell-project/actions-codespell@v1
3333
with:
3434
only_warn: 1
35+
ignore_words_list: ibrary,nd
Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import asyncio
12
import os
23
from concurrent.futures import Future
3-
from queue import Queue
44
from threading import Thread
55
from typing import Awaitable, Optional
66

@@ -10,39 +10,27 @@
1010
class RemoteExpertWorker:
1111
"""Local thread for managing async tasks related to RemoteExpert"""
1212

13-
_task_queue: Queue = Queue()
14-
_event_thread: Optional[Thread] = None
15-
_pid: int = -1
13+
_event_thread = None
14+
_event_loop_fut = None
15+
_pid = None
1616

1717
@classmethod
18-
def _run(cls):
19-
loop = switch_to_uvloop()
20-
21-
async def receive_tasks():
22-
while True:
23-
cor, future = cls._task_queue.get()
24-
try:
25-
result = await cor
26-
except Exception as e:
27-
future.set_exception(e)
28-
continue
29-
if not future.cancelled():
30-
future.set_result(result)
31-
32-
loop.run_until_complete(receive_tasks())
18+
def _run_event_loop(cls):
19+
try:
20+
loop = switch_to_uvloop()
21+
cls._event_loop_fut.set_result(loop)
22+
except Exception as e:
23+
cls._event_loop_fut.set_exception(e)
24+
loop.run_forever()
3325

3426
@classmethod
3527
def run_coroutine(cls, coro: Awaitable, return_future: bool = False):
3628
if cls._event_thread is None or cls._pid != os.getpid():
3729
cls._pid = os.getpid()
38-
cls._event_thread = Thread(target=cls._run, daemon=True)
30+
cls._event_loop_fut = Future()
31+
cls._event_thread = Thread(target=cls._run_event_loop, daemon=True)
3932
cls._event_thread.start()
4033

41-
future = Future()
42-
cls._task_queue.put((coro, future))
43-
44-
if return_future:
45-
return future
46-
47-
result = future.result()
48-
return result
34+
loop = cls._event_loop_fut.result()
35+
future = asyncio.run_coroutine_threadsafe(coro, loop)
36+
return future if return_future else future.result()

hivemind/p2p/servicer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class RPCHandler:
1818

1919
class StubBase:
2020
"""
21-
Base class for P2P RPC stubs. The interface mimicks gRPC stubs.
21+
Base class for P2P RPC stubs. The interface mimics gRPC stubs.
2222
2323
Servicer derives stub classes for particular services (e.g. DHT, averager, etc.) from StubBase,
2424
adding the necessary rpc_* methods. Calls to these methods are translated to calls to the remote peer.
@@ -32,7 +32,7 @@ def __init__(self, p2p: P2P, peer: PeerID, namespace: Optional[str]):
3232

3333
class ServicerBase:
3434
"""
35-
Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimicks gRPC servicers.
35+
Base class for P2P RPC servicers (e.g. DHT, averager, MoE server). The interface mimics gRPC servicers.
3636
3737
- ``add_p2p_handlers(self, p2p)`` registers all rpc_* methods of the derived class as P2P handlers, allowing
3838
other peers to call them. It uses type annotations for the ``request`` parameter and the return value

tests/test_moe.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1+
import asyncio
2+
import ctypes
3+
import multiprocessing as mp
4+
import threading
5+
import time
6+
17
import numpy as np
28
import pytest
39
import torch
410

511
from hivemind.dht import DHT
612
from hivemind.moe.client.expert import RemoteExpert, create_remote_experts
713
from hivemind.moe.client.moe import DUMMY, RemoteMixtureOfExperts, _RemoteCallMany
14+
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
815
from hivemind.moe.client.switch_moe import RemoteSwitchMixtureOfExperts
916
from hivemind.moe.expert_uid import ExpertInfo
1017
from hivemind.moe.server import ModuleBackend, Server, background_server, declare_experts
1118
from hivemind.moe.server.layers import name_to_block
1219
from hivemind.p2p.p2p_daemon_bindings.control import P2PHandlerError
13-
from hivemind.utils import BatchTensorDescriptor, get_dht_time
20+
from hivemind.utils import BatchTensorDescriptor, MPFuture, get_dht_time
1421

1522

1623
@pytest.mark.forked
@@ -306,3 +313,43 @@ def test_client_anomaly_detection():
306313

307314
finally:
308315
server.shutdown()
316+
317+
318+
def _measure_coro_running_time(n_coros, elapsed_fut, counter):
319+
async def coro():
320+
await asyncio.sleep(0.1)
321+
counter.value += 1
322+
323+
try:
324+
start_time = time.perf_counter()
325+
326+
futures = [
327+
RemoteExpertWorker.run_coroutine(coro(), return_future=True) for _ in range(n_coros - 1)
328+
] # Non-blocking calls
329+
RemoteExpertWorker.run_coroutine(coro(), return_future=False) # A blocking call
330+
for fut in futures:
331+
fut.result()
332+
333+
elapsed_fut.set_result(time.perf_counter() - start_time)
334+
except Exception as e:
335+
elapsed_fut.set_exception(e)
336+
337+
338+
@pytest.mark.forked
339+
def test_remote_expert_worker_runs_coros_concurrently(n_processes=4, n_coros=10):
340+
processes = []
341+
counter = mp.Value(ctypes.c_int64)
342+
for i in range(n_processes):
343+
elapsed_fut = MPFuture()
344+
factory = threading.Thread if i % 2 == 0 else mp.Process # Test both threads and processes
345+
346+
proc = factory(target=_measure_coro_running_time, args=(n_coros, elapsed_fut, counter))
347+
proc.start()
348+
processes.append((proc, elapsed_fut))
349+
350+
for proc, elapsed_fut in processes:
351+
# Ensure that the coroutines were run concurrently, not sequentially
352+
assert elapsed_fut.result() < 0.2
353+
proc.join()
354+
355+
assert counter.value == n_processes * n_coros # Ensure all couroutines have finished

0 commit comments

Comments
 (0)