Skip to content

Commit b2e404c

Browse files
committed
Fix Windows multiprocessing pickling error and malformed docstrings
ISSUES FIXED: 1. Windows multiprocessing pickling error - local function can't be pickled 2. Malformed docstrings with duplicate "Verify that" text CHANGES: 1. Added module-level _concurrent_query_worker() for Windows pickling compatibility 2. Rewrote test_concurrent_queries_multiple_processes() to use ProcessPoolExecutor with module-level function instead of local function + mp.Manager() 3. Fixed all malformed docstrings in test_concurrency.py (removed duplicate text) VERIFICATION: - All multiprocessing tests now pass on both Unix and Windows (spawn mode) - All 950 tests pass - No Japanese text remains in file This fixes the CI failures on Windows builds.
1 parent 6c5db7e commit b2e404c

File tree

1 file changed

+56
-54
lines changed

1 file changed

+56
-54
lines changed

tests/unit/test_concurrency.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from python_prtree import PRTree2D, PRTree3D, PRTree4D
2020

2121

22-
# Module-level function for multiprocessing (must be picklable)
22+
# Module-level functions for multiprocessing (must be picklable)
2323
def _process_query_helper(query_data):
2424
"""Helper function for multiprocessing tests."""
2525
tree_class, idx_data, boxes_data, query_box = query_data
@@ -28,13 +28,41 @@ def _process_query_helper(query_data):
2828
return tree.query(query_box)
2929

3030

31+
def _concurrent_query_worker(proc_id, tree_class, dim):
32+
"""Worker function for concurrent multiprocessing tests."""
33+
try:
34+
np.random.seed(proc_id)
35+
n = 500
36+
idx = np.arange(n)
37+
boxes = np.random.rand(n, 2 * dim) * 100
38+
for i in range(dim):
39+
boxes[:, i + dim] += boxes[:, i] + 1
40+
41+
# Each process creates its own tree
42+
tree = tree_class(idx, boxes)
43+
44+
# Do queries
45+
results = []
46+
for i in range(50):
47+
query_box = np.random.rand(2 * dim) * 100
48+
for d in range(dim):
49+
query_box[d + dim] += query_box[d] + 1
50+
51+
result = tree.query(query_box)
52+
results.append(len(result))
53+
54+
return sum(results)
55+
except Exception as e:
56+
return f"ERROR: {e}"
57+
58+
3159
class TestPythonThreading:
3260
"""Test Python threading safety."""
3361

3462
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
3563
@pytest.mark.parametrize("num_threads", [2, 4, 8])
3664
def test_concurrent_queries_multiple_threads(self, PRTree, dim, num_threads):
37-
"""Verify safe concurrent queries from multiple Python threadsVerify that."""
65+
"""Verify safe concurrent queries from multiple Python threads."""
3866
np.random.seed(42)
3967
n = 1000
4068
idx = np.arange(n)
@@ -78,7 +106,7 @@ def query_worker(thread_id):
78106
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
79107
@pytest.mark.parametrize("num_threads", [2, 4])
80108
def test_concurrent_batch_queries_multiple_threads(self, PRTree, dim, num_threads):
81-
"""Verify safe concurrent batch_query from multiple Python threadsVerify that."""
109+
"""Verify safe concurrent batch_query from multiple Python threads"""
82110
np.random.seed(42)
83111
n = 1000
84112
idx = np.arange(n)
@@ -118,7 +146,7 @@ def batch_query_worker(thread_id):
118146

119147
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
120148
def test_read_only_concurrent_access(self, PRTree, dim):
121-
"""Verify that read-only concurrent access is safeVerify that."""
149+
"""Verify that read-only concurrent access is safe"""
122150
np.random.seed(42)
123151
n = 500
124152
idx = np.arange(n)
@@ -154,56 +182,30 @@ class TestPythonMultiprocessing:
154182
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3)])
155183
@pytest.mark.parametrize("num_processes", [2, 4])
156184
def test_concurrent_queries_multiple_processes(self, PRTree, dim, num_processes):
157-
"""Verify safe concurrent queries from multiple Python processesVerify that."""
158-
159-
def query_worker(proc_id, return_dict):
160-
try:
161-
np.random.seed(proc_id)
162-
n = 500
163-
idx = np.arange(n)
164-
boxes = np.random.rand(n, 2 * dim) * 100
165-
for i in range(dim):
166-
boxes[:, i + dim] += boxes[:, i] + 1
167-
168-
# Each process creates its own tree
169-
tree = PRTree(idx, boxes)
170-
171-
# Do queries
172-
results = []
173-
for i in range(50):
174-
query_box = np.random.rand(2 * dim) * 100
175-
for d in range(dim):
176-
query_box[d + dim] += query_box[d] + 1
177-
178-
result = tree.query(query_box)
179-
results.append(len(result))
180-
181-
return_dict[proc_id] = sum(results)
182-
except Exception as e:
183-
return_dict[proc_id] = f"ERROR: {e}"
184-
185-
manager = mp.Manager()
186-
return_dict = manager.dict()
187-
processes = []
188-
189-
for i in range(num_processes):
190-
p = mp.Process(target=query_worker, args=(i, return_dict))
191-
processes.append(p)
192-
p.start()
193-
194-
for p in processes:
195-
p.join(timeout=30)
196-
if p.is_alive():
197-
p.terminate()
198-
pytest.fail("Process timed out")
185+
"""Verify safe concurrent queries from multiple Python processes"""
186+
# Use ProcessPoolExecutor with module-level function for Windows compatibility
187+
with concurrent.futures.ProcessPoolExecutor(max_workers=num_processes) as executor:
188+
# Submit tasks for each process
189+
futures = [executor.submit(_concurrent_query_worker, i, PRTree, dim)
190+
for i in range(num_processes)]
191+
192+
# Collect results with timeout
193+
results = []
194+
for future in concurrent.futures.as_completed(futures, timeout=30):
195+
result = future.result()
196+
# Check for errors
197+
assert not isinstance(result, str) or not result.startswith("ERROR"), f"Process failed: {result}"
198+
results.append(result)
199199

200-
assert len(return_dict) == num_processes
201-
for proc_id, result in return_dict.items():
202-
assert not isinstance(result, str) or not result.startswith("ERROR"), f"Process {proc_id} failed: {result}"
200+
# Verify all processes completed
201+
assert len(results) == num_processes
202+
# Verify each process got some query results
203+
for result in results:
204+
assert isinstance(result, int) and result > 0
203205

204206
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2)])
205207
def test_process_pool_queries(self, PRTree, dim):
206-
"""Verify that queries with ProcessPoolExecutor are safeVerify that."""
208+
"""Verify that queries with ProcessPoolExecutor are safe"""
207209
np.random.seed(42)
208210
n = 500
209211
idx = np.arange(n)
@@ -310,7 +312,7 @@ class TestThreadPoolExecutor:
310312
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3), (PRTree4D, 4)])
311313
@pytest.mark.parametrize("max_workers", [2, 4, 8])
312314
def test_thread_pool_queries(self, PRTree, dim, max_workers):
313-
"""Verify that queries with ThreadPoolExecutor are safeVerify that."""
315+
"""Verify that queries with ThreadPoolExecutor are safe"""
314316
np.random.seed(42)
315317
n = 1000
316318
idx = np.arange(n)
@@ -341,7 +343,7 @@ def query_task(query_box):
341343
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3)])
342344
@pytest.mark.parametrize("max_workers", [2, 4])
343345
def test_thread_pool_batch_queries(self, PRTree, dim, max_workers):
344-
"""Verify that batch_query with ThreadPoolExecutor is safeVerify that."""
346+
"""Verify that batch_query with ThreadPoolExecutor is safe"""
345347
np.random.seed(42)
346348
n = 1000
347349
idx = np.arange(n)
@@ -372,7 +374,7 @@ class TestConcurrentModification:
372374

373375
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2), (PRTree3D, 3)])
374376
def test_insert_from_multiple_threads_sequential(self, PRTree, dim):
375-
"""Verify safe sequential insert from multiple threadsVerify that."""
377+
"""Verify safe sequential insert from multiple threads"""
376378
tree = PRTree()
377379
lock = threading.Lock()
378380
errors = []
@@ -403,7 +405,7 @@ def insert_worker(thread_id):
403405

404406
@pytest.mark.parametrize("PRTree, dim", [(PRTree2D, 2)])
405407
def test_query_during_save_load(self, PRTree, dim, tmp_path):
406-
"""Verify that queries during save/load are safeVerify that."""
408+
"""Verify that queries during save/load are safe"""
407409
np.random.seed(42)
408410
n = 500
409411
idx = np.arange(n)

0 commit comments

Comments
 (0)