@@ -1326,6 +1326,107 @@ def check_exception(failed_task, exception): # pylint: disable=unused-variable
13261326 )
13271327 ]
13281328
1329+ def test_pre_process_change_index (self , tmpdir , TestTask ):
1330+ """Test that the process fails if the index is changed by the preprocess."""
1331+ dataset_df_path = str (tmpdir / "dataset.csv" )
1332+ base_dataset_df = pd .DataFrame ({"a" : [1 , 2 , 3 , 4 ], "b" : [5 , 6 , 7 , 8 ]}, index = [0 , 1 , 2 , 3 ])
1333+ base_dataset_df .to_csv (dataset_df_path , index = True , index_label = "index_col" )
1334+
1335+ class TestTaskUpdateIndex (TestTask ):
1336+ def pre_process (self , df , args , kwargs ):
1337+ df .sort_index (ascending = False , inplace = True )
1338+
1339+ @staticmethod
1340+ def validation_function (df , output_path , * args , ** kwargs ):
1341+ pass
1342+
1343+ failed_tasks = []
1344+ exceptions = []
1345+
1346+ @TestTaskUpdateIndex .event_handler (luigi .Event .FAILURE )
1347+ def check_exception (failed_task , exception ): # pylint: disable=unused-variable
1348+ failed_tasks .append (str (failed_task ))
1349+ exceptions .append (str (exception ))
1350+
1351+ failing_task = TestTaskUpdateIndex (
1352+ dataset_df = dataset_df_path ,
1353+ input_index_col = "index_col" ,
1354+ result_path = str (tmpdir / "out_preprocess_update_index" ),
1355+ )
1356+ assert not luigi .build ([failing_task ], local_scheduler = True )
1357+
1358+ assert failed_tasks == [str (failing_task )]
1359+ assert exceptions == [
1360+ str (
1361+ IndexError (
1362+ "The index changed during the process. Please update your validation function "
1363+ "or your pre/post process functions to avoid this behaviour."
1364+ )
1365+ )
1366+ ]
1367+
1368+ @pytest .mark .parametrize (
1369+ "task_type" ,
1370+ [int , str , object , float ],
1371+ )
1372+ @pytest .mark .parametrize (
1373+ "workflow_type" ,
1374+ [int , str , object , float ],
1375+ )
1376+ def test_read_dataset_change_index (
1377+ self , tmpdir , TestTask , dataset_df_path , task_type , workflow_type
1378+ ):
1379+ """Test that the process succeeds if the index is only changed by the preprocess."""
1380+
1381+ class TestTaskUpdateIndex (TestTask ):
1382+ """A simple Task."""
1383+
1384+ def transform_index (self , df ):
1385+ df .index = df .index .astype (task_type )
1386+ return df
1387+
1388+ class TestWorkflow (task .ValidationWorkflow ):
1389+ """A validation workflow."""
1390+
1391+ def transform_index (self , df ):
1392+ df .index = df .index .astype (workflow_type )
1393+ return df
1394+
1395+ def inputs (self ):
1396+ return {
1397+ TestTaskUpdateIndex : {},
1398+ }
1399+
1400+ @staticmethod
1401+ def validation_function (df , output_path , * args , ** kwargs ):
1402+ if task_type == float and workflow_type == str :
1403+ assert len (df ) == 0
1404+ else :
1405+ assert len (df ) == 2
1406+
1407+ failed_tasks = []
1408+ exceptions = []
1409+
1410+ @TestWorkflow .event_handler (luigi .Event .FAILURE )
1411+ def check_exception (failed_task , exception ): # pylint: disable=unused-variable
1412+ failed_tasks .append (str (failed_task ))
1413+ exceptions .append (str (exception ))
1414+
1415+ workflow_with_index_cast = TestWorkflow (
1416+ dataset_df = dataset_df_path ,
1417+ result_path = str (tmpdir / "out_preprocess_update_index" ),
1418+ )
1419+ assert luigi .build ([workflow_with_index_cast ], local_scheduler = True )
1420+
1421+ assert not failed_tasks
1422+ assert not exceptions
1423+ res = pd .read_csv (tmpdir / "out_preprocess_update_index" / "TestWorkflow" / "report.csv" )
1424+ if task_type == float and workflow_type == str :
1425+ assert len (res ) == 0
1426+ else :
1427+ assert len (res ) == 2
1428+ assert res ["is_valid" ].all ()
1429+
13291430 def test_missing_retcodes (self , tmpdir , dataset_df_path , TestTask ):
13301431 """Test invalid retcodes."""
13311432
@@ -1512,6 +1613,7 @@ def validation_function(df, output_path, *args, **kwargs):
15121613 res = pd .read_csv (tmpdir / "extra_requires" / "TestTaskB" / "report.csv" )
15131614 assert (res ["extra_path" ] == str (tmpdir / "file.test" )).all ()
15141615 assert (res ["extra_result" ] == "result of TestTaskA" ).all ()
1616+ assert Path (res .loc [0 , "extra_path" ]).exists ()
15151617
15161618 def test_static_args_kwargs (self , dataset_df_path ):
15171619 """Test the args and kwargs feature."""
0 commit comments