1919from python_prtree import PRTree2D , PRTree3D , PRTree4D
2020
2121
22- # Module-level function for multiprocessing (must be picklable)
22+ # Module-level functions for multiprocessing (must be picklable)
2323def _process_query_helper (query_data ):
2424 """Helper function for multiprocessing tests."""
2525 tree_class , idx_data , boxes_data , query_box = query_data
@@ -28,13 +28,41 @@ def _process_query_helper(query_data):
2828 return tree .query (query_box )
2929
3030
31+ def _concurrent_query_worker (proc_id , tree_class , dim ):
32+ """Worker function for concurrent multiprocessing tests."""
33+ try :
34+ np .random .seed (proc_id )
35+ n = 500
36+ idx = np .arange (n )
37+ boxes = np .random .rand (n , 2 * dim ) * 100
38+ for i in range (dim ):
39+ boxes [:, i + dim ] += boxes [:, i ] + 1
40+
41+ # Each process creates its own tree
42+ tree = tree_class (idx , boxes )
43+
44+ # Do queries
45+ results = []
46+ for i in range (50 ):
47+ query_box = np .random .rand (2 * dim ) * 100
48+ for d in range (dim ):
49+ query_box [d + dim ] += query_box [d ] + 1
50+
51+ result = tree .query (query_box )
52+ results .append (len (result ))
53+
54+ return sum (results )
55+ except Exception as e :
56+ return f"ERROR: { e } "
57+
58+
3159class TestPythonThreading :
3260 """Test Python threading safety."""
3361
3462 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 ), (PRTree3D , 3 ), (PRTree4D , 4 )])
3563 @pytest .mark .parametrize ("num_threads" , [2 , 4 , 8 ])
3664 def test_concurrent_queries_multiple_threads (self , PRTree , dim , num_threads ):
37- """Verify safe concurrent queries from multiple Python threadsVerify that ."""
65+ """Verify safe concurrent queries from multiple Python threads ."""
3866 np .random .seed (42 )
3967 n = 1000
4068 idx = np .arange (n )
@@ -78,7 +106,7 @@ def query_worker(thread_id):
78106 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 ), (PRTree3D , 3 ), (PRTree4D , 4 )])
79107 @pytest .mark .parametrize ("num_threads" , [2 , 4 ])
80108 def test_concurrent_batch_queries_multiple_threads (self , PRTree , dim , num_threads ):
81- """Verify safe concurrent batch_query from multiple Python threadsVerify that. """
109+ """Verify safe concurrent batch_query from multiple Python threads """
82110 np .random .seed (42 )
83111 n = 1000
84112 idx = np .arange (n )
@@ -118,7 +146,7 @@ def batch_query_worker(thread_id):
118146
119147 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 ), (PRTree3D , 3 ), (PRTree4D , 4 )])
120148 def test_read_only_concurrent_access (self , PRTree , dim ):
121- """Verify that read-only concurrent access is safeVerify that. """
149+ """Verify that read-only concurrent access is safe """
122150 np .random .seed (42 )
123151 n = 500
124152 idx = np .arange (n )
@@ -154,56 +182,30 @@ class TestPythonMultiprocessing:
154182 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 ), (PRTree3D , 3 )])
155183 @pytest .mark .parametrize ("num_processes" , [2 , 4 ])
156184 def test_concurrent_queries_multiple_processes (self , PRTree , dim , num_processes ):
157- """Verify safe concurrent queries from multiple Python processesVerify that."""
158-
159- def query_worker (proc_id , return_dict ):
160- try :
161- np .random .seed (proc_id )
162- n = 500
163- idx = np .arange (n )
164- boxes = np .random .rand (n , 2 * dim ) * 100
165- for i in range (dim ):
166- boxes [:, i + dim ] += boxes [:, i ] + 1
167-
168- # Each process creates its own tree
169- tree = PRTree (idx , boxes )
170-
171- # Do queries
172- results = []
173- for i in range (50 ):
174- query_box = np .random .rand (2 * dim ) * 100
175- for d in range (dim ):
176- query_box [d + dim ] += query_box [d ] + 1
177-
178- result = tree .query (query_box )
179- results .append (len (result ))
180-
181- return_dict [proc_id ] = sum (results )
182- except Exception as e :
183- return_dict [proc_id ] = f"ERROR: { e } "
184-
185- manager = mp .Manager ()
186- return_dict = manager .dict ()
187- processes = []
188-
189- for i in range (num_processes ):
190- p = mp .Process (target = query_worker , args = (i , return_dict ))
191- processes .append (p )
192- p .start ()
193-
194- for p in processes :
195- p .join (timeout = 30 )
196- if p .is_alive ():
197- p .terminate ()
198- pytest .fail ("Process timed out" )
185+ """Verify safe concurrent queries from multiple Python processes"""
186+ # Use ProcessPoolExecutor with module-level function for Windows compatibility
187+ with concurrent .futures .ProcessPoolExecutor (max_workers = num_processes ) as executor :
188+ # Submit tasks for each process
189+ futures = [executor .submit (_concurrent_query_worker , i , PRTree , dim )
190+ for i in range (num_processes )]
191+
192+ # Collect results with timeout
193+ results = []
194+ for future in concurrent .futures .as_completed (futures , timeout = 30 ):
195+ result = future .result ()
196+ # Check for errors
197+ assert not isinstance (result , str ) or not result .startswith ("ERROR" ), f"Process failed: { result } "
198+ results .append (result )
199199
200- assert len (return_dict ) == num_processes
201- for proc_id , result in return_dict .items ():
202- assert not isinstance (result , str ) or not result .startswith ("ERROR" ), f"Process { proc_id } failed: { result } "
200+ # Verify all processes completed
201+ assert len (results ) == num_processes
202+ # Verify each process got some query results
203+ for result in results :
204+ assert isinstance (result , int ) and result > 0
203205
204206 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 )])
205207 def test_process_pool_queries (self , PRTree , dim ):
206- """Verify that queries with ProcessPoolExecutor are safeVerify that. """
208+ """Verify that queries with ProcessPoolExecutor are safe """
207209 np .random .seed (42 )
208210 n = 500
209211 idx = np .arange (n )
@@ -310,7 +312,7 @@ class TestThreadPoolExecutor:
310312 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 ), (PRTree3D , 3 ), (PRTree4D , 4 )])
311313 @pytest .mark .parametrize ("max_workers" , [2 , 4 , 8 ])
312314 def test_thread_pool_queries (self , PRTree , dim , max_workers ):
313- """Verify that queries with ThreadPoolExecutor are safeVerify that. """
315+ """Verify that queries with ThreadPoolExecutor are safe """
314316 np .random .seed (42 )
315317 n = 1000
316318 idx = np .arange (n )
@@ -341,7 +343,7 @@ def query_task(query_box):
341343 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 ), (PRTree3D , 3 )])
342344 @pytest .mark .parametrize ("max_workers" , [2 , 4 ])
343345 def test_thread_pool_batch_queries (self , PRTree , dim , max_workers ):
344- """Verify that batch_query with ThreadPoolExecutor is safeVerify that. """
346+ """Verify that batch_query with ThreadPoolExecutor is safe """
345347 np .random .seed (42 )
346348 n = 1000
347349 idx = np .arange (n )
@@ -372,7 +374,7 @@ class TestConcurrentModification:
372374
373375 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 ), (PRTree3D , 3 )])
374376 def test_insert_from_multiple_threads_sequential (self , PRTree , dim ):
375- """Verify safe sequential insert from multiple threadsVerify that. """
377+ """Verify safe sequential insert from multiple threads """
376378 tree = PRTree ()
377379 lock = threading .Lock ()
378380 errors = []
@@ -403,7 +405,7 @@ def insert_worker(thread_id):
403405
404406 @pytest .mark .parametrize ("PRTree, dim" , [(PRTree2D , 2 )])
405407 def test_query_during_save_load (self , PRTree , dim , tmp_path ):
406- """Verify that queries during save/load are safeVerify that. """
408+ """Verify that queries during save/load are safe """
407409 np .random .seed (42 )
408410 n = 500
409411 idx = np .arange (n )
0 commit comments