33import  csv 
44import  json 
55import  os 
6+ import  pathlib 
7+ import  uuid 
8+ from  dataclasses  import  dataclass 
69from  pathlib  import  Path 
710from  typing  import  (
811    Any ,
912    Dict ,
1013    List ,
1114    Literal ,
1215    Optional ,
16+     Union ,
17+     cast ,
1318)
1419
1520import  pytest 
21+ import  xdist 
22+ import  xdist .dsession 
23+ import  xdist .workermanage 
1624from  constants  import  PACKAGE_NAME 
1725from  filelock  import  FileLock 
1826from  typing_extensions  import  NotRequired , TypedDict 
1927
2028from  django_utils_lib .logger  import  build_heading_block , pkg_logger 
21- from  django_utils_lib .testing .utils  import  PytestNodeID , validate_requirement_tagging 
29+ from  django_utils_lib .testing .utils  import  PytestNodeID , is_main_pytest_runner ,  validate_requirement_tagging 
2230
2331BASE_DIR  =  Path (__file__ ).resolve ().parent 
2432
25- # Due to the parallelized nature of xdist (we our library consumer might or might 
26- # not be using), we are going to use a file-based system for implementing both 
27- # a concurrency lock, as well as a way to easily share the metadata across 
28- # processes. 
29- temp_file_path  =  os .path .join (BASE_DIR , "test.temp.json" )
30- temp_file_lock_path  =  f"{ temp_file_path }  
31- file_lock  =  FileLock (temp_file_lock_path )
32- 
3333
3434TestStatus  =  Literal ["PASS" , "FAIL" , "" ]
3535
@@ -103,6 +103,36 @@ class PluginConfigurationItem(TypedDict):
103103}
104104
105105
106+ class  InternalSessionConfig (TypedDict ):
107+     global_session_id : str 
108+     temp_shared_session_dir_path : str 
109+ 
110+ 
111+ # Note: Redundant typing of InternalSessionConfig, but likely unavoidable 
112+ # due to lack of type-coercion features in Python types 
113+ @dataclass  
114+ class  InternalSessionConfigDataClass :
115+     global_session_id : str 
116+     temp_shared_session_dir_path : str 
117+ 
118+ 
119+ class  InternalWorkerConfig (InternalSessionConfig ):
120+     # These values are provided by xdist automatically 
121+     workerid : str 
122+     """ 
123+     Auto-generated worker ID (`gw0`, `gw1`, etc.) 
124+     """ 
125+     workercount : int 
126+     testrunuid : str 
127+     # Our own injected values 
128+     temp_worker_dir_path : str 
129+ 
130+ 
131+ @dataclass  
132+ class  WorkerConfigInstance :
133+     workerinput : InternalWorkerConfig 
134+ 
135+ 
106136class  CollectedTestMetadata (TypedDict ):
107137    """ 
108138    Metadata that is collected for each test "node" 
@@ -138,11 +168,26 @@ class CollectedTests:
138168    File-backed data-store for collected test info 
139169    """ 
140170
171+     def  __init__ (self , run_id : str ) ->  None :
172+         """ 
173+         Args: 
174+             run_id: This should be a global session ID, unless you want to isolate results by worker 
175+         """ 
176+         self .tmp_dir_path  =  os .path .join (BASE_DIR , ".pytest_run_cache" , run_id )
177+         os .makedirs (self .tmp_dir_path , exist_ok = True )
178+         # Due to the parallelized nature of xdist (we our library consumer might or might 
179+         # not be using), we are going to use a file-based system for implementing both 
180+         # a concurrency lock, as well as a way to easily share the metadata across 
181+         # processes. 
182+         self .temp_file_path  =  os .path .join (self .tmp_dir_path , "test.temp.json" )
183+         self .temp_file_lock_path  =  f"{ self .temp_file_path }  
184+         self .file_lock  =  FileLock (self .temp_file_lock_path )
185+ 
141186    def  _get_data (self ) ->  CollectedTestsMapping :
142-         with  file_lock :
143-             if  not  os .path .exists (temp_file_path ):
187+         with  self . file_lock :
188+             if  not  os .path .exists (self . temp_file_path ):
144189                return  {}
145-             with  open (temp_file_path , "r" ) as  f :
190+             with  open (self . temp_file_path , "r" ) as  f :
146191                return  json .load (f )
147192
148193    def  __getitem__ (self , node_id : PytestNodeID ) ->  CollectedTestMetadata :
@@ -151,21 +196,18 @@ def __getitem__(self, node_id: PytestNodeID) -> CollectedTestMetadata:
151196    def  __setitem__ (self , node_id : str , item : CollectedTestMetadata ):
152197        updated_data  =  self ._get_data ()
153198        updated_data [node_id ] =  item 
154-         with  file_lock :
155-             with  open (temp_file_path , "w" ) as  f :
199+         with  self . file_lock :
200+             with  open (self . temp_file_path , "w" ) as  f :
156201                json .dump (updated_data , f )
157202
158203    def  update_test_status (self , node_id : PytestNodeID , updated_status : TestStatus ):
159204        updated_data  =  self ._get_data ()
160205        updated_data [node_id ]["status" ] =  updated_status 
161-         with  file_lock :
162-             with  open (temp_file_path , "w" ) as  f :
206+         with  self . file_lock :
207+             with  open (self . temp_file_path , "w" ) as  f :
163208                json .dump (updated_data , f )
164209
165210
166- collected_tests  =  CollectedTests ()
167- 
168- 
169211@pytest .hookimpl () 
170212def  pytest_addoption (parser : pytest .Parser ):
171213    # Register all config key-pairs with INI parser 
@@ -175,58 +217,114 @@ def pytest_addoption(parser: pytest.Parser):
175217
176218@pytest .hookimpl () 
177219def  pytest_configure (config : pytest .Config ):
178-     if  hasattr (config ,  "workerinput" ):
220+     if  not   is_main_pytest_runner (config ):
179221        return 
180222
181223    # Register markers 
182224    config .addinivalue_line ("markers" , "requirements(requirements: List[str]): Attach requirements to test" )
183225
184-     # Register plugin 
185-     plugin  =  CustomPytestPlugin (config )
186-     config .pluginmanager .register (plugin )
226+ 
227+ @pytest .hookimpl () 
228+ def  pytest_sessionstart (session : pytest .Session ):
229+     if  is_main_pytest_runner (session ):
230+         # If we are on the main runner, this is either a non-xdist run, or 
231+         # this is the main xdist process, before nodes been distributed. 
232+         # Regardless, we should set up a shared temporary directory, which can 
233+         # be shared among all n{0,} nodes 
234+         global_session_id  =  uuid .uuid4 ().hex 
235+         temp_shared_session_dir_path  =  os .path .join (BASE_DIR , ".pytest_run_cache" , global_session_id )
236+         pathlib .Path (temp_shared_session_dir_path ).mkdir (parents = True , exist_ok = True )
237+         session_config  =  cast (InternalSessionConfigDataClass , session .config )
238+         session_config .global_session_id  =  global_session_id 
239+         session_config .temp_shared_session_dir_path  =  temp_shared_session_dir_path 
240+ 
241+     plugin  =  CustomPytestPlugin (session .config )
242+     session .config .pluginmanager .register (plugin )
187243    pkg_logger .debug (f"{ PACKAGE_NAME }  )
188244    plugin .auto_engage_debugger ()
189245
190246
247+ def  pytest_configure_node (node : xdist .workermanage .WorkerController ):
248+     """ 
249+     Special xdist-only hook, which is called as a node is configured, before instantiation & distribution 
250+ 
251+     This hook only runs on the main process (not workers), and is skipped entirely if xdist is not being used 
252+     """ 
253+     worker_id : str  =  node .workerinput ["workerid" ]
254+ 
255+     # Retrieve global shared session config 
256+     session_config  =  cast (InternalSessionConfigDataClass , node .config )
257+     temp_shared_session_dir_path  =  session_config .temp_shared_session_dir_path 
258+ 
259+     # Construct worker-scoped temp directory 
260+     temp_worker_dir_path  =  os .path .join (temp_shared_session_dir_path , worker_id )
261+     pathlib .Path (temp_worker_dir_path ).mkdir (parents = True , exist_ok = True )
262+ 
263+     # Copy worker-specific, as well as shared config values, into the node config 
264+     node .workerinput ["temp_worker_dir_path" ] =  temp_worker_dir_path 
265+     node .workerinput ["temp_shared_session_dir_path" ] =  temp_shared_session_dir_path 
266+     node .workerinput ["global_session_id" ] =  session_config .global_session_id 
267+ 
268+ 
191269class  CustomPytestPlugin :
192270    # Tell Pytest that this is not a test class 
193271    __test__  =  False 
194272
195273    def  __init__ (self , pytest_config : pytest .Config ) ->  None :
196274        self .pytest_config  =  pytest_config 
275+         self .collected_tests  =  CollectedTests (self .get_internal_shared_config (pytest_config )["global_session_id" ])
197276        self .debugger_listening  =  False 
198277        # We might or might not be running inside an xdist worker 
199-         self ._is_running_on_worker  =  False 
278+         self ._is_running_on_worker  =  not   is_main_pytest_runner ( pytest_config ) 
200279
201-     def  get_config_val (self , config_key : PluginConfigKey ):
280+     def  get_global_config_val (self , config_key : PluginConfigKey ):
202281        """ 
203282        Wrapper function just to add some extra type-safety around dynamic config keys 
204283        """ 
205284        return  self .pytest_config .getini (config_key )
206285
286+     def  get_internal_shared_config (
287+         self , pytest_obj : Union [pytest .Session , pytest .Config , pytest .FixtureRequest ]
288+     ) ->  InternalSessionConfig :
289+         """ 
290+         Utility function to get shared config values, because it can be a little tricky to know 
291+         where to retrieve them from (for main vs worker) 
292+         """ 
293+         config  =  pytest_obj  if  isinstance (pytest_obj , pytest .Config ) else  pytest_obj .config 
294+         # If we are on the main runner, we can just directly access 
295+         if  is_main_pytest_runner (config ):
296+             session_config  =  cast (InternalSessionConfigDataClass , config )
297+             return  {
298+                 "temp_shared_session_dir_path" : session_config .temp_shared_session_dir_path ,
299+                 "global_session_id" : session_config .global_session_id ,
300+             }
301+         # If we are on a worker, we can retrieve the shared config values via the `workerinput` property 
302+         worker_input  =  cast (WorkerConfigInstance , config ).workerinput 
303+         return  worker_input 
304+ 
207305    @property  
208306    def  auto_debug (self ) ->  bool :
209307        # Disable if CI is detected 
210308        if  os .getenv ("CI" , "" ).lower () ==  "true" :
211309            return  False 
212-         return  bool (self .get_config_val ("auto_debug" )) or  bool (os .getenv (f"{ PACKAGE_NAME }  , "" ))
310+         return  bool (self .get_global_config_val ("auto_debug" )) or  bool (os .getenv (f"{ PACKAGE_NAME }  , "" ))
213311
214312    @property  
215313    def  auto_debug_wait_for_connect (self ) ->  bool :
216-         return  bool (self .get_config_val ("auto_debug_wait_for_connect" ))
314+         return  bool (self .get_global_config_val ("auto_debug_wait_for_connect" ))
217315
218316    @property  
219317    def  mandate_requirement_markers (self ) ->  bool :
220-         return  bool (self .get_config_val ("mandate_requirement_markers" ))
318+         return  bool (self .get_global_config_val ("mandate_requirement_markers" ))
221319
222320    @property  
223321    def  reporting_config (self ) ->  Optional [PluginReportingConfiguration ]:
224-         csv_export_path  =  self .get_config_val ("reporting.csv_export_path" )
322+         csv_export_path  =  self .get_global_config_val ("reporting.csv_export_path" )
225323        if  not  isinstance (csv_export_path , str ):
226324            return  None 
227325        return  {
228326            "csv_export_path" : csv_export_path ,
229-             "omit_unexecuted_tests" : bool (self .get_config_val ("reporting.omit_unexecuted_tests" )),
327+             "omit_unexecuted_tests" : bool (self .get_global_config_val ("reporting.omit_unexecuted_tests" )),
230328        }
231329
232330    @property  
@@ -282,7 +380,7 @@ def pytest_collection_modifyitems(self, config: pytest.Config, items: List[pytes
282380                requirements  =  validation_results ["validated_requirements" ]
283381
284382            doc_string : str  =  item .obj .__doc__  or  ""   # type: ignore 
285-             collected_tests [item .nodeid ] =  {
383+             self . collected_tests [item .nodeid ] =  {
286384                "node_id" : item .nodeid ,
287385                "requirements" : requirements ,
288386                "doc_string" : doc_string .strip (),
@@ -294,10 +392,8 @@ def pytest_collection_modifyitems(self, config: pytest.Config, items: List[pytes
294392
295393    @pytest .hookimpl () 
296394    def  pytest_sessionstart (self , session : pytest .Session ):
297-         self ._is_running_on_worker  =  getattr (session .config , "workerinput" , None ) is  not None 
298- 
299-         if  self ._is_running_on_worker :
300-             # Nothing to do here at the moment 
395+         if  not  is_main_pytest_runner (session ):
396+             self ._is_running_on_worker  =  True 
301397            return 
302398
303399        # Init debugpy listener on main 
@@ -311,7 +407,7 @@ def pytest_collection_finish(self, session: pytest.Session):
311407    def  pytest_sessionfinish (self , session : pytest .Session , exitstatus ):
312408        if  not  self .reporting_config :
313409            return 
314-         collected_test_mappings  =  collected_tests ._get_data ()
410+         collected_test_mappings  =  self . collected_tests ._get_data ()
315411        with  open (self .reporting_config ["csv_export_path" ], "w" ) as  csv_file :
316412            # Use keys of first entry, since all entries should have same keys 
317413            fieldnames  =  collected_test_mappings [next (iter (collected_test_mappings ))].keys ()
@@ -327,4 +423,4 @@ def pytest_sessionfinish(self, session: pytest.Session, exitstatus):
327423    def  pytest_runtest_logreport (self , report : pytest .TestReport ):
328424        # Capture test outcomes and save to collection 
329425        if  report .when  ==  "call" :
330-             collected_tests .update_test_status (report .nodeid , "PASS"  if  report .passed  else  "FAIL" )
426+             self . collected_tests .update_test_status (report .nodeid , "PASS"  if  report .passed  else  "FAIL" )
0 commit comments