Skip to content

Commit fb29cbf

Browse files
committed
Fix up xdist concurrent handling logic
- Large refactor of plugin code to better isolate per session vs per worker, and fix test collection logic
1 parent e30651b commit fb29cbf

File tree

3 files changed

+159
-39
lines changed

3 files changed

+159
-39
lines changed

Taskfile.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ tasks:
6262
- task: lint:types
6363
test:
6464
deps: [_verify_python_venv, install]
65-
cmd: poetry run pytest -n auto
65+
cmd: |
66+
# Clear cache files
67+
rm -rf $PACKAGE_DIR/testing/.pytest_run_cache
68+
# Run pytest
69+
poetry run pytest -n auto
6670
#============================================================#
6771
#================= SECTION_HEADING ==========================#
6872
#============================================================#

django_utils_lib/testing/pytest_plugin.py

Lines changed: 133 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,33 +3,33 @@
33
import csv
44
import json
55
import os
6+
import pathlib
7+
import uuid
8+
from dataclasses import dataclass
69
from pathlib import Path
710
from typing import (
811
Any,
912
Dict,
1013
List,
1114
Literal,
1215
Optional,
16+
Union,
17+
cast,
1318
)
1419

1520
import pytest
21+
import xdist
22+
import xdist.dsession
23+
import xdist.workermanage
1624
from constants import PACKAGE_NAME
1725
from filelock import FileLock
1826
from typing_extensions import NotRequired, TypedDict
1927

2028
from django_utils_lib.logger import build_heading_block, pkg_logger
21-
from django_utils_lib.testing.utils import PytestNodeID, validate_requirement_tagging
29+
from django_utils_lib.testing.utils import PytestNodeID, is_main_pytest_runner, validate_requirement_tagging
2230

2331
BASE_DIR = Path(__file__).resolve().parent
2432

25-
# Due to the parallelized nature of xdist (we our library consumer might or might
26-
# not be using), we are going to use a file-based system for implementing both
27-
# a concurrency lock, as well as a way to easily share the metadata across
28-
# processes.
29-
temp_file_path = os.path.join(BASE_DIR, "test.temp.json")
30-
temp_file_lock_path = f"{temp_file_path}.lock"
31-
file_lock = FileLock(temp_file_lock_path)
32-
3333

3434
TestStatus = Literal["PASS", "FAIL", ""]
3535

@@ -103,6 +103,36 @@ class PluginConfigurationItem(TypedDict):
103103
}
104104

105105

106+
class InternalSessionConfig(TypedDict):
107+
global_session_id: str
108+
temp_shared_session_dir_path: str
109+
110+
111+
# Note: Redundant typing of InternalSessionConfig, but likely unavoidable
112+
# due to lack of type-coercion features in Python types
113+
@dataclass
114+
class InternalSessionConfigDataClass:
115+
global_session_id: str
116+
temp_shared_session_dir_path: str
117+
118+
119+
class InternalWorkerConfig(InternalSessionConfig):
120+
# These values are provided by xdist automatically
121+
workerid: str
122+
"""
123+
Auto-generated worker ID (`gw0`, `gw1`, etc.)
124+
"""
125+
workercount: int
126+
testrunuid: str
127+
# Our own injected values
128+
temp_worker_dir_path: str
129+
130+
131+
@dataclass
132+
class WorkerConfigInstance:
133+
workerinput: InternalWorkerConfig
134+
135+
106136
class CollectedTestMetadata(TypedDict):
107137
"""
108138
Metadata that is collected for each test "node"
@@ -138,11 +168,26 @@ class CollectedTests:
138168
File-backed data-store for collected test info
139169
"""
140170

171+
def __init__(self, run_id: str) -> None:
172+
"""
173+
Args:
174+
run_id: This should be a global session ID, unless you want to isolate results by worker
175+
"""
176+
self.tmp_dir_path = os.path.join(BASE_DIR, ".pytest_run_cache", run_id)
177+
os.makedirs(self.tmp_dir_path, exist_ok=True)
178+
# Due to the parallelized nature of xdist (which our library consumer might or might
179+
# not be using), we are going to use a file-based system for implementing both
180+
# a concurrency lock, as well as a way to easily share the metadata across
181+
# processes.
182+
self.temp_file_path = os.path.join(self.tmp_dir_path, "test.temp.json")
183+
self.temp_file_lock_path = f"{self.temp_file_path}.lock"
184+
self.file_lock = FileLock(self.temp_file_lock_path)
185+
141186
def _get_data(self) -> CollectedTestsMapping:
142-
with file_lock:
143-
if not os.path.exists(temp_file_path):
187+
with self.file_lock:
188+
if not os.path.exists(self.temp_file_path):
144189
return {}
145-
with open(temp_file_path, "r") as f:
190+
with open(self.temp_file_path, "r") as f:
146191
return json.load(f)
147192

148193
def __getitem__(self, node_id: PytestNodeID) -> CollectedTestMetadata:
@@ -151,21 +196,18 @@ def __getitem__(self, node_id: PytestNodeID) -> CollectedTestMetadata:
151196
def __setitem__(self, node_id: str, item: CollectedTestMetadata):
152197
updated_data = self._get_data()
153198
updated_data[node_id] = item
154-
with file_lock:
155-
with open(temp_file_path, "w") as f:
199+
with self.file_lock:
200+
with open(self.temp_file_path, "w") as f:
156201
json.dump(updated_data, f)
157202

158203
def update_test_status(self, node_id: PytestNodeID, updated_status: TestStatus):
159204
updated_data = self._get_data()
160205
updated_data[node_id]["status"] = updated_status
161-
with file_lock:
162-
with open(temp_file_path, "w") as f:
206+
with self.file_lock:
207+
with open(self.temp_file_path, "w") as f:
163208
json.dump(updated_data, f)
164209

165210

166-
collected_tests = CollectedTests()
167-
168-
169211
@pytest.hookimpl()
170212
def pytest_addoption(parser: pytest.Parser):
171213
# Register all config key-pairs with INI parser
@@ -175,58 +217,114 @@ def pytest_addoption(parser: pytest.Parser):
175217

176218
@pytest.hookimpl()
177219
def pytest_configure(config: pytest.Config):
178-
if hasattr(config, "workerinput"):
220+
if not is_main_pytest_runner(config):
179221
return
180222

181223
# Register markers
182224
config.addinivalue_line("markers", "requirements(requirements: List[str]): Attach requirements to test")
183225

184-
# Register plugin
185-
plugin = CustomPytestPlugin(config)
186-
config.pluginmanager.register(plugin)
226+
227+
@pytest.hookimpl()
228+
def pytest_sessionstart(session: pytest.Session):
229+
if is_main_pytest_runner(session):
230+
# If we are on the main runner, this is either a non-xdist run, or
231+
# this is the main xdist process, before nodes have been distributed.
232+
# Regardless, we should set up a shared temporary directory, which can
233+
# be shared among all nodes (zero or more)
234+
global_session_id = uuid.uuid4().hex
235+
temp_shared_session_dir_path = os.path.join(BASE_DIR, ".pytest_run_cache", global_session_id)
236+
pathlib.Path(temp_shared_session_dir_path).mkdir(parents=True, exist_ok=True)
237+
session_config = cast(InternalSessionConfigDataClass, session.config)
238+
session_config.global_session_id = global_session_id
239+
session_config.temp_shared_session_dir_path = temp_shared_session_dir_path
240+
241+
plugin = CustomPytestPlugin(session.config)
242+
session.config.pluginmanager.register(plugin)
187243
pkg_logger.debug(f"{PACKAGE_NAME} plugin registered")
188244
plugin.auto_engage_debugger()
189245

190246

247+
def pytest_configure_node(node: xdist.workermanage.WorkerController):
248+
"""
249+
Special xdist-only hook, which is called as a node is configured, before instantiation & distribution
250+
251+
This hook only runs on the main process (not workers), and is skipped entirely if xdist is not being used
252+
"""
253+
worker_id: str = node.workerinput["workerid"]
254+
255+
# Retrieve global shared session config
256+
session_config = cast(InternalSessionConfigDataClass, node.config)
257+
temp_shared_session_dir_path = session_config.temp_shared_session_dir_path
258+
259+
# Construct worker-scoped temp directory
260+
temp_worker_dir_path = os.path.join(temp_shared_session_dir_path, worker_id)
261+
pathlib.Path(temp_worker_dir_path).mkdir(parents=True, exist_ok=True)
262+
263+
# Copy worker-specific, as well as shared config values, into the node config
264+
node.workerinput["temp_worker_dir_path"] = temp_worker_dir_path
265+
node.workerinput["temp_shared_session_dir_path"] = temp_shared_session_dir_path
266+
node.workerinput["global_session_id"] = session_config.global_session_id
267+
268+
191269
class CustomPytestPlugin:
192270
# Tell Pytest that this is not a test class
193271
__test__ = False
194272

195273
def __init__(self, pytest_config: pytest.Config) -> None:
196274
self.pytest_config = pytest_config
275+
self.collected_tests = CollectedTests(self.get_internal_shared_config(pytest_config)["global_session_id"])
197276
self.debugger_listening = False
198277
# We might or might not be running inside an xdist worker
199-
self._is_running_on_worker = False
278+
self._is_running_on_worker = not is_main_pytest_runner(pytest_config)
200279

201-
def get_config_val(self, config_key: PluginConfigKey):
280+
def get_global_config_val(self, config_key: PluginConfigKey):
202281
"""
203282
Wrapper function just to add some extra type-safety around dynamic config keys
204283
"""
205284
return self.pytest_config.getini(config_key)
206285

286+
def get_internal_shared_config(
287+
self, pytest_obj: Union[pytest.Session, pytest.Config, pytest.FixtureRequest]
288+
) -> InternalSessionConfig:
289+
"""
290+
Utility function to get shared config values, because it can be a little tricky to know
291+
where to retrieve them from (for main vs worker)
292+
"""
293+
config = pytest_obj if isinstance(pytest_obj, pytest.Config) else pytest_obj.config
294+
# If we are on the main runner, we can just directly access
295+
if is_main_pytest_runner(config):
296+
session_config = cast(InternalSessionConfigDataClass, config)
297+
return {
298+
"temp_shared_session_dir_path": session_config.temp_shared_session_dir_path,
299+
"global_session_id": session_config.global_session_id,
300+
}
301+
# If we are on a worker, we can retrieve the shared config values via the `workerinput` property
302+
worker_input = cast(WorkerConfigInstance, config).workerinput
303+
return worker_input
304+
207305
@property
208306
def auto_debug(self) -> bool:
209307
# Disable if CI is detected
210308
if os.getenv("CI", "").lower() == "true":
211309
return False
212-
return bool(self.get_config_val("auto_debug")) or bool(os.getenv(f"{PACKAGE_NAME}_AUTO_DEBUG", ""))
310+
return bool(self.get_global_config_val("auto_debug")) or bool(os.getenv(f"{PACKAGE_NAME}_AUTO_DEBUG", ""))
213311

214312
@property
215313
def auto_debug_wait_for_connect(self) -> bool:
216-
return bool(self.get_config_val("auto_debug_wait_for_connect"))
314+
return bool(self.get_global_config_val("auto_debug_wait_for_connect"))
217315

218316
@property
219317
def mandate_requirement_markers(self) -> bool:
220-
return bool(self.get_config_val("mandate_requirement_markers"))
318+
return bool(self.get_global_config_val("mandate_requirement_markers"))
221319

222320
@property
223321
def reporting_config(self) -> Optional[PluginReportingConfiguration]:
224-
csv_export_path = self.get_config_val("reporting.csv_export_path")
322+
csv_export_path = self.get_global_config_val("reporting.csv_export_path")
225323
if not isinstance(csv_export_path, str):
226324
return None
227325
return {
228326
"csv_export_path": csv_export_path,
229-
"omit_unexecuted_tests": bool(self.get_config_val("reporting.omit_unexecuted_tests")),
327+
"omit_unexecuted_tests": bool(self.get_global_config_val("reporting.omit_unexecuted_tests")),
230328
}
231329

232330
@property
@@ -282,7 +380,7 @@ def pytest_collection_modifyitems(self, config: pytest.Config, items: List[pytes
282380
requirements = validation_results["validated_requirements"]
283381

284382
doc_string: str = item.obj.__doc__ or "" # type: ignore
285-
collected_tests[item.nodeid] = {
383+
self.collected_tests[item.nodeid] = {
286384
"node_id": item.nodeid,
287385
"requirements": requirements,
288386
"doc_string": doc_string.strip(),
@@ -294,10 +392,8 @@ def pytest_collection_modifyitems(self, config: pytest.Config, items: List[pytes
294392

295393
@pytest.hookimpl()
296394
def pytest_sessionstart(self, session: pytest.Session):
297-
self._is_running_on_worker = getattr(session.config, "workerinput", None) is not None
298-
299-
if self._is_running_on_worker:
300-
# Nothing to do here at the moment
395+
if not is_main_pytest_runner(session):
396+
self._is_running_on_worker = True
301397
return
302398

303399
# Init debugpy listener on main
@@ -311,7 +407,7 @@ def pytest_collection_finish(self, session: pytest.Session):
311407
def pytest_sessionfinish(self, session: pytest.Session, exitstatus):
312408
if not self.reporting_config:
313409
return
314-
collected_test_mappings = collected_tests._get_data()
410+
collected_test_mappings = self.collected_tests._get_data()
315411
with open(self.reporting_config["csv_export_path"], "w") as csv_file:
316412
# Use keys of first entry, since all entries should have same keys
317413
fieldnames = collected_test_mappings[next(iter(collected_test_mappings))].keys()
@@ -327,4 +423,4 @@ def pytest_sessionfinish(self, session: pytest.Session, exitstatus):
327423
def pytest_runtest_logreport(self, report: pytest.TestReport):
328424
# Capture test outcomes and save to collection
329425
if report.when == "call":
330-
collected_tests.update_test_status(report.nodeid, "PASS" if report.passed else "FAIL")
426+
self.collected_tests.update_test_status(report.nodeid, "PASS" if report.passed else "FAIL")

django_utils_lib/testing/utils.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,37 @@
11
from __future__ import annotations
22

33
import re
4-
from typing import List, Tuple, cast
4+
from typing import List, Tuple, Union, cast
55

66
import pytest
77
from typing_extensions import TypedDict
8+
from xdist import is_xdist_worker
89

910
PytestNodeID = str
1011
"""
1112
A pytest node ID follows the format of `file_path::test_name`
1213
"""
1314

1415

16+
def is_main_pytest_runner(pytest_obj: Union[pytest.Config, pytest.FixtureRequest, pytest.Session]):
17+
"""
18+
Utility function that returns True only if we are in the main runner (not an xdist worker)
19+
20+
This should work in both xdist and non-xdist modes of operation.
21+
"""
22+
# Pytest config or worker node
23+
if isinstance(pytest_obj, pytest.Config) or hasattr(pytest_obj, "workerinput"):
24+
# The presence of "workerinput", on either a config or distributed node,
25+
# indicates we are on a worker
26+
return getattr(pytest_obj, "workerinput", None) is None
27+
28+
# Pytest session objects or requests
29+
if hasattr(pytest_obj, "config"):
30+
return is_xdist_worker(pytest_obj) is False
31+
32+
return False
33+
34+
1535
class RequirementValidationResults(TypedDict):
1636
valid: bool
1737
errors: List[str]

0 commit comments

Comments
 (0)