Skip to content

Commit da5bb83

Browse files
Feat: Add a transform_index() method to transform the dataset index (#41)
1 parent ce9ea3e commit da5bb83

File tree

2 files changed

+116
-2
lines changed

2 files changed

+116
-2
lines changed

data_validation_framework/task.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,15 @@ def read_dataset(self):
297297
"""
298298
return pd.read_csv(self.dataset_df, index_col=self.input_index_col)
299299

300+
def transform_index(self, df):
301+
"""Method executed after loading the dataset to transform its index.
302+
303+
.. note::
304+
305+
This transformation is applied to both the dataset and the input reports.
306+
"""
307+
return df
308+
300309
def pre_process(self, df, args, kwargs):
301310
"""Method executed before applying the external function."""
302311

@@ -440,6 +449,7 @@ def _get_dataset(self):
440449
if self.dataset_df is not None:
441450
L.info("Input dataset: %s", Path(self.dataset_df).resolve())
442451
new_df = self.read_dataset()
452+
new_df = self.transform_index(new_df)
443453
duplicated_index = new_df.index.duplicated()
444454
if duplicated_index.any():
445455
raise IndexError(
@@ -473,8 +483,10 @@ def _join_inputs(self, new_df):
473483
}
474484
L.debug("Importing the following reports: %s", all_report_paths)
475485
all_dfs = {
476-
task_obj: self._rename_cols(
477-
pd.read_csv(path, index_col=INDEX_LABEL).rename_axis(index="index")
486+
task_obj: self.transform_index(
487+
self._rename_cols(
488+
pd.read_csv(path, index_col=INDEX_LABEL).rename_axis(index="index")
489+
)
478490
)
479491
for task_obj, path in all_report_paths.items()
480492
}

tests/test_task.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,107 @@ def check_exception(failed_task, exception): # pylint: disable=unused-variable
13261326
)
13271327
]
13281328

1329+
def test_pre_process_change_index(self, tmpdir, TestTask):
1330+
"""Test that the process fails if the index is changed by the preprocess."""
1331+
dataset_df_path = str(tmpdir / "dataset.csv")
1332+
base_dataset_df = pd.DataFrame({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}, index=[0, 1, 2, 3])
1333+
base_dataset_df.to_csv(dataset_df_path, index=True, index_label="index_col")
1334+
1335+
class TestTaskUpdateIndex(TestTask):
1336+
def pre_process(self, df, args, kwargs):
1337+
df.sort_index(ascending=False, inplace=True)
1338+
1339+
@staticmethod
1340+
def validation_function(df, output_path, *args, **kwargs):
1341+
pass
1342+
1343+
failed_tasks = []
1344+
exceptions = []
1345+
1346+
@TestTaskUpdateIndex.event_handler(luigi.Event.FAILURE)
1347+
def check_exception(failed_task, exception): # pylint: disable=unused-variable
1348+
failed_tasks.append(str(failed_task))
1349+
exceptions.append(str(exception))
1350+
1351+
failing_task = TestTaskUpdateIndex(
1352+
dataset_df=dataset_df_path,
1353+
input_index_col="index_col",
1354+
result_path=str(tmpdir / "out_preprocess_update_index"),
1355+
)
1356+
assert not luigi.build([failing_task], local_scheduler=True)
1357+
1358+
assert failed_tasks == [str(failing_task)]
1359+
assert exceptions == [
1360+
str(
1361+
IndexError(
1362+
"The index changed during the process. Please update your validation function "
1363+
"or your pre/post process functions to avoid this behaviour."
1364+
)
1365+
)
1366+
]
1367+
1368+
@pytest.mark.parametrize(
1369+
"task_type",
1370+
[int, str, object, float],
1371+
)
1372+
@pytest.mark.parametrize(
1373+
"workflow_type",
1374+
[int, str, object, float],
1375+
)
1376+
def test_read_dataset_change_index(
1377+
self, tmpdir, TestTask, dataset_df_path, task_type, workflow_type
1378+
):
1379+
"""Test that the process succeeds if the index is only changed by the preprocess."""
1380+
1381+
class TestTaskUpdateIndex(TestTask):
1382+
"""A simple Task."""
1383+
1384+
def transform_index(self, df):
1385+
df.index = df.index.astype(task_type)
1386+
return df
1387+
1388+
class TestWorkflow(task.ValidationWorkflow):
1389+
"""A validation workflow."""
1390+
1391+
def transform_index(self, df):
1392+
df.index = df.index.astype(workflow_type)
1393+
return df
1394+
1395+
def inputs(self):
1396+
return {
1397+
TestTaskUpdateIndex: {},
1398+
}
1399+
1400+
@staticmethod
1401+
def validation_function(df, output_path, *args, **kwargs):
1402+
if task_type == float and workflow_type == str:
1403+
assert len(df) == 0
1404+
else:
1405+
assert len(df) == 2
1406+
1407+
failed_tasks = []
1408+
exceptions = []
1409+
1410+
@TestWorkflow.event_handler(luigi.Event.FAILURE)
1411+
def check_exception(failed_task, exception): # pylint: disable=unused-variable
1412+
failed_tasks.append(str(failed_task))
1413+
exceptions.append(str(exception))
1414+
1415+
workflow_with_index_cast = TestWorkflow(
1416+
dataset_df=dataset_df_path,
1417+
result_path=str(tmpdir / "out_preprocess_update_index"),
1418+
)
1419+
assert luigi.build([workflow_with_index_cast], local_scheduler=True)
1420+
1421+
assert not failed_tasks
1422+
assert not exceptions
1423+
res = pd.read_csv(tmpdir / "out_preprocess_update_index" / "TestWorkflow" / "report.csv")
1424+
if task_type == float and workflow_type == str:
1425+
assert len(res) == 0
1426+
else:
1427+
assert len(res) == 2
1428+
assert res["is_valid"].all()
1429+
13291430
def test_missing_retcodes(self, tmpdir, dataset_df_path, TestTask):
13301431
"""Test invalid retcodes."""
13311432

@@ -1512,6 +1613,7 @@ def validation_function(df, output_path, *args, **kwargs):
15121613
res = pd.read_csv(tmpdir / "extra_requires" / "TestTaskB" / "report.csv")
15131614
assert (res["extra_path"] == str(tmpdir / "file.test")).all()
15141615
assert (res["extra_result"] == "result of TestTaskA").all()
1616+
assert Path(res.loc[0, "extra_path"]).exists()
15151617

15161618
def test_static_args_kwargs(self, dataset_df_path):
15171619
"""Test the args and kwargs feature."""

0 commit comments

Comments
 (0)