Skip to content

Commit 4d5c414

Browse files
authored
Upgrade pytest-asyncio to 0.26.0 (#652)
* Upgrade pytest-asyncio to 0.26.0 * Add newline to pyproject.toml * Remove the fixture override * Cancel tasks for ControlClient only if the event loop is running * Use cancel_task_if_running more broadly
1 parent 5353328 commit 4d5c414

File tree

10 files changed

+32
-30
lines changed

10 files changed

+32
-30
lines changed

hivemind/p2p/p2p_daemon.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from hivemind.p2p.p2p_daemon_bindings.utils import ControlFailure
2121
from hivemind.proto import crypto_pb2
2222
from hivemind.proto.p2pd_pb2 import RPCError
23-
from hivemind.utils.asyncio import as_aiter, asingle
23+
from hivemind.utils.asyncio import as_aiter, asingle, cancel_task_if_running
2424
from hivemind.utils.crypto import RSAPrivateKey
2525
from hivemind.utils.logging import get_logger, golog_level_to_python, loglevel, python_level_to_golog
2626
from hivemind.utils.multiaddr import Multiaddr
@@ -647,9 +647,9 @@ def _terminate(self) -> None:
647647
if self._client is not None:
648648
self._client.close()
649649
if self._listen_task is not None:
650-
self._listen_task.cancel()
650+
cancel_task_if_running(self._listen_task)
651651
if self._reader_task is not None:
652-
self._reader_task.cancel()
652+
cancel_task_if_running(self._reader_task)
653653

654654
self._alive = False
655655
if self._child is not None and self._child.returncode is None:

hivemind/p2p/p2p_daemon_bindings/control.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
write_pbmsg,
2020
)
2121
from hivemind.proto import p2pd_pb2 as p2pd_pb
22+
from hivemind.utils.asyncio import cancel_task_if_running
2223
from hivemind.utils.logging import get_logger
2324
from hivemind.utils.multiaddr import Multiaddr, protocols
2425

@@ -134,10 +135,8 @@ async def create(
134135
return control
135136

136137
def close(self) -> None:
137-
if self._read_task is not None:
138-
self._read_task.cancel()
139-
if self._write_task is not None:
140-
self._write_task.cancel()
138+
cancel_task_if_running(self._read_task)
139+
cancel_task_if_running(self._write_task)
141140

142141
def __del__(self):
143142
self.close()
@@ -194,7 +193,7 @@ async def _read_from_persistent_conn(self, reader: asyncio.StreamReader):
194193
self._handler_tasks[call_id] = handler_task
195194

196195
elif call_id in self._handler_tasks and resp.HasField("cancel"):
197-
self._handler_tasks[call_id].cancel()
196+
cancel_task_if_running(self._handler_tasks[call_id])
198197

199198
elif call_id in self._pending_calls and resp.HasField("daemonError"):
200199
daemon_exc = P2PDaemonError(resp.daemonError.message)

hivemind/utils/asyncio.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,16 @@ async def enter_asynchronously(context: AbstractContextManager):
195195
"""Wrap a non-async context so that it can be entered asynchronously"""
196196
async with _AsyncContextWrapper(context) as ret_value:
197197
yield ret_value
198+
199+
200+
def cancel_task_if_running(task: Optional[asyncio.Task]) -> None:
201+
"""Safely cancel a task if it's still running and the event loop is available."""
202+
if task is not None and not task.done():
203+
try:
204+
loop = asyncio.get_event_loop()
205+
if loop.is_running():
206+
task.cancel()
207+
except RuntimeError as e:
208+
# Only ignore event loop closure errors
209+
if "Event loop is closed" not in str(e):
210+
raise

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,6 @@ dummy-variable-rgx = "^_$"
2121

2222
[tool.ruff.lint.isort]
2323
known-local-folder = ["arguments", "test_utils", "tests", "utils"]
24+
25+
[tool.pytest.ini_options]
26+
asyncio_default_fixture_loop_scope = "function"

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
pytest==8.3.5
22
pytest-forked
3-
pytest-asyncio==0.16.0
3+
pytest-asyncio==0.26.0
44
pytest-cov
55
pytest-timeout
66
coverage

tests/conftest.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import asyncio
21
import gc
32

43
import psutil
@@ -12,22 +11,6 @@
1211
logger = get_logger(__name__)
1312

1413

15-
@pytest.fixture
16-
def event_loop():
17-
"""
18-
This overrides the ``event_loop`` fixture from pytest-asyncio
19-
(e.g. to make it compatible with ``asyncio.subprocess``).
20-
21-
This fixture is identical to the original one but does not call ``loop.close()`` in the end.
22-
Indeed, at this point, the loop is already stopped (i.e. next tests are free to create new loops).
23-
However, finalizers of objects created in the current test may reference the current loop and fail if it is closed.
24-
For example, this happens while using ``asyncio.subprocess`` (the ``asyncio.subprocess.Process`` finalizer
25-
fails if the loop is closed, but works if the loop is only stopped).
26-
"""
27-
28-
yield asyncio.get_event_loop()
29-
30-
3114
@pytest.fixture(autouse=True, scope="session")
3215
def cleanup_children():
3316
yield

tests/test_connection_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Dict
66

77
import pytest
8+
import pytest_asyncio
89
import torch
910

1011
from hivemind.compression import deserialize_tensor_stream, deserialize_torch_tensor, serialize_torch_tensor
@@ -20,7 +21,7 @@
2021
from hivemind.utils.tensor_descr import BatchTensorDescriptor
2122

2223

23-
@pytest.fixture
24+
@pytest_asyncio.fixture
2425
async def client_stub():
2526
handler_dht = DHT(start=True)
2627
module_backends = {"expert1": DummyModuleBackend("expert1", k=1), "expert2": DummyModuleBackend("expert2", k=2)}

tests/test_dht_schema.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Dict
33

44
import pytest
5+
import pytest_asyncio
56
from pydantic.v1 import BaseModel, StrictInt, conint
67

78
import hivemind
@@ -17,7 +18,7 @@ class SampleSchema(BaseModel):
1718
signed_data: Dict[BytesWithPublicKey, bytes]
1819

1920

20-
@pytest.fixture
21+
@pytest_asyncio.fixture
2122
async def dht_nodes_with_schema():
2223
validator = SchemaValidator(SampleSchema)
2324

tests/test_p2p_daemon_bindings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from contextlib import AsyncExitStack
44

55
import pytest
6+
import pytest_asyncio
67
from google.protobuf.message import EncodeError
78

89
from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol
@@ -372,7 +373,7 @@ async def test_write_pbmsg_missing_fields(pb_msg):
372373
await write_pbmsg(MockReaderWriter(), pb_msg)
373374

374375

375-
@pytest.fixture
376+
@pytest_asyncio.fixture
376377
async def p2pcs():
377378
# TODO: Change back to gather style
378379
async with AsyncExitStack() as stack:

tests/test_p2p_servicer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from typing import AsyncIterator
33

44
import pytest
5+
import pytest_asyncio
56

67
from hivemind.p2p import P2P, P2PContext, P2PDaemonError, ServicerBase
78
from hivemind.proto import test_pb2
89
from hivemind.utils.asyncio import anext
910

1011

11-
@pytest.fixture
12+
@pytest_asyncio.fixture
1213
async def server_client():
1314
server = await P2P.create()
1415
client = await P2P.create(initial_peers=await server.get_visible_maddrs())

0 commit comments

Comments
 (0)