diff --git a/sdmetrics/column_pairs/statistical/contingency_similarity.py b/sdmetrics/column_pairs/statistical/contingency_similarity.py index 348b0896..61ace34e 100644 --- a/sdmetrics/column_pairs/statistical/contingency_similarity.py +++ b/sdmetrics/column_pairs/statistical/contingency_similarity.py @@ -117,7 +117,9 @@ def compute_breakdown( contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len( synthetic ) - combined_index = contingency_real.index.union(contingency_synthetic.index, sort=False) + combined_index = contingency_real.index.union( + contingency_synthetic.index, sort=False + ).drop_duplicates() contingency_synthetic = contingency_synthetic.reindex(combined_index, fill_value=0) contingency_real = contingency_real.reindex(combined_index, fill_value=0) diff = abs(contingency_real - contingency_synthetic).fillna(0) diff --git a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py index 04fae41e..1a33caf2 100644 --- a/sdmetrics/reports/multi_table/_properties/inter_table_trends.py +++ b/sdmetrics/reports/multi_table/_properties/inter_table_trends.py @@ -11,6 +11,7 @@ ColumnPairTrends as SingleTableColumnPairTrends, ) from sdmetrics.reports.utils import PlotConfig +from sdmetrics.utils import _cast_to_iterable class InterTableTrends(BaseMultiTableProperty): @@ -50,16 +51,16 @@ def _denormalize_tables(self, real_data, synthetic_data, relationship): """ parent = relationship['parent_table_name'] child = relationship['child_table_name'] - foreign_key = relationship['child_foreign_key'] - primary_key = relationship['parent_primary_key'] + foreign_key = _cast_to_iterable(relationship['child_foreign_key']) + primary_key = _cast_to_iterable(relationship['parent_primary_key']) real_parent = real_data[parent].add_prefix(f'{parent}.') real_child = real_data[child].add_prefix(f'{child}.') synthetic_parent = synthetic_data[parent].add_prefix(f'{parent}.') synthetic_child = synthetic_data[child].add_prefix(f'{child}.') - child_index = f'{child}.{foreign_key}' - parent_index = f'{parent}.{primary_key}' + child_index = [f'{child}.{key_col}' for key_col in foreign_key] + parent_index = [f'{parent}.{key_col}' for key_col in primary_key] denormalized_real = real_child.merge( real_parent, left_on=child_index, right_on=parent_index @@ -101,7 +102,12 @@ def _merge_metadata(self, metadata, parent_table, child_table): merged_metadata['columns'] = {**child_cols, **parent_cols} if 'primary_key' in merged_metadata: primary_key = merged_metadata['primary_key'] - merged_metadata['primary_key'] = f'{child_table}.{primary_key}' + if isinstance(primary_key, list): + merged_metadata['primary_key'] = [ + f'{child_table}.{pk_col}' for pk_col in primary_key + ] + else: + merged_metadata['primary_key'] = f'{child_table}.{primary_key}' return merged_metadata, list(parent_cols.keys()), list(child_cols.keys()) @@ -123,6 +129,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No parent = relationship['parent_table_name'] child = relationship['child_table_name'] foreign_key = relationship['child_foreign_key'] + fk_tuple = tuple(foreign_key) if isinstance(foreign_key, list) else foreign_key denormalized_real, denormalized_synthetic = self._denormalize_tables( real_data, synthetic_data, relationship @@ -132,14 +139,14 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No parent_child_pairs = itertools.product(parent_cols, child_cols) - self._properties[(parent, child, foreign_key)] = SingleTableColumnPairTrends() + self._properties[(parent, child, fk_tuple)] = SingleTableColumnPairTrends() self._properties[ - (parent, child, foreign_key) + (parent, child, fk_tuple) ].real_correlation_threshold = self.real_correlation_threshold self._properties[ - (parent, child, foreign_key) + (parent, child, fk_tuple) ].real_association_threshold = self.real_association_threshold - details = self._properties[(parent, child, foreign_key)]._generate_details( + details = self._properties[(parent, child, fk_tuple)]._generate_details( denormalized_real, denormalized_synthetic, merged_metadata, @@ -149,7 +156,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No details['Parent Table'] = parent details['Child Table'] = child - details['Foreign Key'] = foreign_key + details['Foreign Key'] = str(foreign_key) if not details.empty: details['Column 1'] = details['Column 1'].str.replace( f'{parent}.', '', n=1, regex=False diff --git a/sdmetrics/reports/multi_table/base_multi_table_report.py b/sdmetrics/reports/multi_table/base_multi_table_report.py index 319c182f..6379d570 100644 --- a/sdmetrics/reports/multi_table/base_multi_table_report.py +++ b/sdmetrics/reports/multi_table/base_multi_table_report.py @@ -3,6 +3,7 @@ import pandas as pd from sdmetrics.reports.base_report import BaseReport +from sdmetrics.utils import _cast_to_iterable from sdmetrics.visualization import set_plotly_config @@ -43,22 +44,36 @@ def _validate_data_format(self, real_data, synthetic_data): def _validate_relationships(self, real_data, synthetic_data, metadata): """Validate that the relationships are valid.""" for rel in metadata.get('relationships', []): - parent_dtype = real_data[rel['parent_table_name']][rel['parent_primary_key']].dtype - child_dtype = real_data[rel['child_table_name']][rel['child_foreign_key']].dtype - if (parent_dtype == 'object' and child_dtype != 'object') or ( - parent_dtype != 'object' and child_dtype == 'object' - ): - parent = rel['parent_table_name'] - parent_key = rel['parent_primary_key'] - child = rel['child_table_name'] - child_key = rel['child_foreign_key'] + parent = rel['parent_table_name'] + parent_key = rel['parent_primary_key'] + child = rel['child_table_name'] + child_key = rel['child_foreign_key'] + parent_key_str = f"'{parent_key}'" if isinstance(parent_key, str) else str(parent_key) + child_key_str = f"'{child_key}'" if isinstance(child_key, str) else str(child_key) + parent_primary_key = _cast_to_iterable(parent_key) + child_foreign_key = _cast_to_iterable(child_key) + + if len(parent_primary_key) != len(child_foreign_key): error_msg = ( f"The '{parent}' table and '{child}' table cannot be merged " - 'for computing the cardinality. Please make sure the primary key' - f" in '{parent}' ('{parent_key}') and the foreign key in '{child}'" - f" ('{child_key}') have the same data type." + 'for computing the cardinality. Please make sure the number of columns ' + f'in the primary key ({parent_key_str}) matches the number of ' + f'columns in the foreign key ({child_key_str}).' ) raise ValueError(error_msg) + parent_dtypes = real_data[rel['parent_table_name']][parent_primary_key].dtypes + child_dtypes = real_data[rel['child_table_name']][child_foreign_key].dtypes + for parent_dtype, child_dtype in zip(parent_dtypes, child_dtypes): + if (parent_dtype == 'object' and child_dtype != 'object') or ( + parent_dtype != 'object' and child_dtype == 'object' + ): + error_msg = ( + f"The '{parent}' table and '{child}' table cannot be merged " + 'for computing the cardinality. Please make sure the primary key' + f" in '{parent}' ({parent_key_str}) and the foreign key in '{child}'" + f' ({child_key_str}) have the same data types.' + ) + raise ValueError(error_msg) def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata): """Validate that the metadata matches the data.""" diff --git a/sdmetrics/reports/single_table/_properties/data_validity.py b/sdmetrics/reports/single_table/_properties/data_validity.py index 7a7eea9d..6c661560 100644 --- a/sdmetrics/reports/single_table/_properties/data_validity.py +++ b/sdmetrics/reports/single_table/_properties/data_validity.py @@ -39,13 +39,19 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No The progress bar to use. Defaults to None. """ column_names, metric_names, scores = [], [], [] + column_sdtypes = [(col, metadata['columns'][col]['sdtype']) for col in metadata['columns']] error_messages = [] primary_key = metadata.get('primary_key') + if isinstance(primary_key, list): + if len(primary_key) > 1: + column_sdtypes = [(primary_key, None)] + column_sdtypes + else: + primary_key = primary_key[0] + alternate_keys = metadata.get('alternate_keys', []) sequence_index = metadata.get('sequence_index') - for column_name in metadata['columns']: - sdtype = metadata['columns'][column_name]['sdtype'] + for column_name, sdtype in column_sdtypes: primary_key_match = column_name == primary_key alternate_key_match = column_name in alternate_keys is_unique = primary_key_match or alternate_key_match diff --git a/sdmetrics/utils.py b/sdmetrics/utils.py index db1267e3..c516fbbb 100644 --- a/sdmetrics/utils.py +++ b/sdmetrics/utils.py @@ -321,3 +321,11 @@ def strip_characters(list_character, a_string): result = result.replace(character, '') return result + + +def _cast_to_iterable(value): + """Return a ``list`` if the input object is not a ``list`` or ``tuple``.""" + if isinstance(value, (list, tuple)): + return value + + return [value] diff --git a/tests/integration/reports/conftest.py b/tests/integration/reports/conftest.py new file mode 100644 index 00000000..d8a0014b --- /dev/null +++ b/tests/integration/reports/conftest.py @@ -0,0 +1,48 @@ +import pytest + +from sdmetrics.demos import load_demo + + +@pytest.fixture(scope='module') +def composite_keys_single_table_demo(): + real_data, synthetic_data, metadata = load_demo(modality='single_table') + metadata['primary_key'] = ['student_id', 'degree_type'] + return real_data, synthetic_data, metadata + + +@pytest.fixture(scope='module') +def composite_keys_multi_table_demo(): + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + metadata['tables']['users']['columns']['user_type'] = {'sdtype': 'categorical'} + metadata['tables']['users']['primary_key'] = ['user_id', 'user_type'] + metadata['tables']['sessions']['columns']['user_type'] = {'sdtype': 'categorical'} + metadata['tables']['sessions']['columns']['user_type'] = {'sdtype': 'categorical'} + metadata['tables']['sessions']['primary_key'] = ['session_id', 'device'] + metadata['tables']['transactions']['columns']['device'] = {'sdtype': 'categorical'} + + metadata['relationships'][0]['parent_primary_key'] = ['user_id', 'user_type'] + metadata['relationships'][0]['child_foreign_key'] = ['user_id', 'user_type'] + metadata['relationships'][1]['parent_primary_key'] = ['session_id', 'device'] + metadata['relationships'][1]['child_foreign_key'] = ['session_id', 'device'] + + real_data['users']['user_type'] = ['PREMIUM'] * 5 + [None] * 5 + synthetic_data['users']['user_type'] = ['PREMIUM'] * 5 + [None] * 5 + for data in [real_data, synthetic_data]: + data['sessions']['user_type'] = ( + data['users'] + .set_index('user_id') + .loc[data['sessions']['user_id']]['user_type'] + .to_numpy() + ) + data['transactions']['device'] = ( + data['sessions'] + .set_index('session_id') + .loc[data['transactions']['session_id']]['device'] + .to_numpy() + ) + premium_mask = data['users']['user_type'] == 'PREMIUM' + data['users'].loc[premium_mask, 'user_id'] = range(5) + data['users'].loc[~premium_mask, 'user_id'] = range(5) + data['sessions'].loc[data['sessions']['user_type'].isna(), 'user_id'] -= 5 + + return real_data, synthetic_data, metadata diff --git a/tests/integration/reports/multi_table/test_diagnostic_report.py b/tests/integration/reports/multi_table/test_diagnostic_report.py index f006915c..c983f055 100644 --- a/tests/integration/reports/multi_table/test_diagnostic_report.py +++ b/tests/integration/reports/multi_table/test_diagnostic_report.py @@ -19,6 +19,18 @@ def test_end_to_end(self): # Assert assert results == 1.0 + def test_end_to_end_composite_keys(self, composite_keys_multi_table_demo): + """Test the end-to-end functionality of the ``DiagnosticReport`` report.""" + real_data, synthetic_data, metadata = composite_keys_multi_table_demo + report = DiagnosticReport() + + # Run + report.generate(real_data, synthetic_data, metadata, verbose=False) + results = report.get_score() + + # Assert + assert results == 1.0 + def test_end_to_end_with_object_datetimes(self): """Test the ``DiagnosticReport`` report with object datetimes.""" real_data, synthetic_data, metadata = load_demo(modality='multi_table') diff --git a/tests/integration/reports/multi_table/test_quality_report.py b/tests/integration/reports/multi_table/test_quality_report.py index 8e96415a..48527ead 100644 --- a/tests/integration/reports/multi_table/test_quality_report.py +++ b/tests/integration/reports/multi_table/test_quality_report.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +import pytest from packaging import version from sdmetrics.demos import load_demo @@ -230,10 +231,14 @@ def test_multi_table_quality_report(): assert report_info['generation_time'] <= generate_end_time - generate_start_time -def test_quality_report_end_to_end(): +@pytest.mark.parametrize('key_type', ['single', 'composite']) +def test_quality_report_end_to_end(key_type, composite_keys_multi_table_demo): """Test the multi table QualityReport end to end.""" # Setup - real_data, synthetic_data, metadata = load_demo(modality='multi_table') + if key_type == 'single': + real_data, synthetic_data, metadata = load_demo(modality='multi_table') + else: + real_data, synthetic_data, metadata = composite_keys_multi_table_demo report = QualityReport() _set_thresholds_zero(report) @@ -244,11 +249,17 @@ def test_quality_report_end_to_end(): info = report.get_info() # Assert + expected_single_scores = [0.7978174603174604, 0.45654629583521095, 0.95, 0.4416666666666666] + expected_composite_scores = [0.82568543, 0.53305494, 0.95, 0.5375] expected_properties = pd.DataFrame({ 'Property': ['Column Shapes', 'Column Pair Trends', 'Cardinality', 'Intertable Trends'], - 'Score': [0.7978174603174604, 0.45654629583521095, 0.95, 0.4416666666666666], + 'Score': expected_single_scores if key_type == 'single' else expected_composite_scores, }) - assert score == 0.6615076057048344 + if key_type == 'single': + assert score == 0.6615076057048344 + else: + assert score == 0.7115600909354644 + pd.testing.assert_frame_equal(properties, expected_properties) expected_info_keys = { 'report_type', diff --git a/tests/integration/reports/single_table/test_diagnostic_report.py b/tests/integration/reports/single_table/test_diagnostic_report.py index e5226c17..517439da 100644 --- a/tests/integration/reports/single_table/test_diagnostic_report.py +++ b/tests/integration/reports/single_table/test_diagnostic_report.py @@ -119,7 +119,88 @@ def test_end_to_end(self): 1.0, ], }) + expected_details_data_structure = pd.DataFrame({ + 'Metric': ['TableStructure'], + 'Score': [1.0], + }) + + pd.testing.assert_frame_equal( + report.get_details('Data Validity'), expected_details_data_validity + ) + + pd.testing.assert_frame_equal( + report.get_details('Data Structure'), expected_details_data_structure + ) + + def test_end_to_end_composite_keys(self, composite_keys_single_table_demo): + """Test the end-to-end functionality of the diagnostic report.""" + # Setup + real_data, synthetic_data, metadata = composite_keys_single_table_demo + report = DiagnosticReport() + # Run + report.generate(real_data, synthetic_data, metadata) + + # Assert + expected_details_data_validity = pd.DataFrame({ + 'Column': [ + ['student_id', 'degree_type'], + 'start_date', + 'end_date', + 'salary', + 'duration', + 'high_perc', + 'high_spec', + 'mba_spec', + 'second_perc', + 'gender', + 'degree_perc', + 'placed', + 'experience_years', + 'employability_perc', + 'mba_perc', + 'work_experience', + 'degree_type', + ], + 'Metric': [ + 'KeyUniqueness', + 'BoundaryAdherence', + 'BoundaryAdherence', + 'BoundaryAdherence', + 'BoundaryAdherence', + 'BoundaryAdherence', + 'CategoryAdherence', + 'CategoryAdherence', + 'BoundaryAdherence', + 'CategoryAdherence', + 'BoundaryAdherence', + 'CategoryAdherence', + 'BoundaryAdherence', + 'BoundaryAdherence', + 'BoundaryAdherence', + 'CategoryAdherence', + 'CategoryAdherence', + ], + 'Score': [ + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + ], + }) expected_details_data_structure = pd.DataFrame({ 'Metric': ['TableStructure'], 'Score': [1.0], diff --git a/tests/integration/reports/single_table/test_quality_report.py b/tests/integration/reports/single_table/test_quality_report.py index 998334ab..066dd4bc 100644 --- a/tests/integration/reports/single_table/test_quality_report.py +++ b/tests/integration/reports/single_table/test_quality_report.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +import pytest from sdmetrics.demos import load_demo from sdmetrics.reports.single_table import QualityReport @@ -70,14 +71,18 @@ def test__get_properties(self): }), ) - def test_report_end_to_end(self): + @pytest.mark.parametrize('key_type', ['single', 'composite']) + def test_report_end_to_end(self, key_type, composite_keys_single_table_demo): """Test the quality report end to end. The report must compute each property and the overall quality score. """ # Setup column_names = ['student_id', 'degree_type', 'start_date', 'second_perc', 'work_experience'] - real_data, synthetic_data, metadata = load_demo(modality='single_table') + if key_type == 'single': + real_data, synthetic_data, metadata = load_demo(modality='single_table') + else: + real_data, synthetic_data, metadata = composite_keys_single_table_demo metadata['columns'] = { key: val for key, val in metadata['columns'].items() if key in column_names diff --git a/tests/unit/reports/multi_table/test_base_multi_table_report.py b/tests/unit/reports/multi_table/test_base_multi_table_report.py index 02cafb11..b26b7038 100644 --- a/tests/unit/reports/multi_table/test_base_multi_table_report.py +++ b/tests/unit/reports/multi_table/test_base_multi_table_report.py @@ -87,11 +87,55 @@ def test__validate_relationships(self): expected_error_message = re.escape( "The 'Table_1' table and 'Table_2' table cannot be merged for computing" " the cardinality. Please make sure the primary key in 'Table_1' ('col1')" - " and the foreign key in 'Table_2' ('col2') have the same data type." + " and the foreign key in 'Table_2' ('col2') have the same data types." ) with pytest.raises(ValueError, match=expected_error_message): report._validate_metadata_matches_data(real_data_bad, synthetic_data, metadata) + def test__validate_relationships_num_key_cols_mismatch(self): + """Test the ``_validate_relationships`` method.""" + # Setup + real_data = { + 'Table_1': pd.DataFrame({'col1': [1, 2, 3]}), + 'Table_2': pd.DataFrame({'col2': [1, 2, 3]}), + } + synthetic_data = { + 'Table_1': pd.DataFrame({'col1': [1, 2, 3]}), + 'Table_2': pd.DataFrame({'col2': [1, 2, 3]}), + } + metadata = { + 'tables': { + 'Table_1': { + 'columns': { + 'col1': {}, + }, + }, + 'Table_2': { + 'columns': {'col2': {}}, + }, + }, + 'relationships': [ + { + 'parent_table_name': 'Table_1', + 'parent_primary_key': ['col1', 'col2'], + 'child_table_name': 'Table_2', + 'child_foreign_key': 'col2', + }, + ], + } + + report = BaseMultiTableReport() + + # Run and Assert + expected_error_message = re.escape( + "The 'Table_1' table and 'Table_2' table cannot be merged for computing" + ' the cardinality. Please make sure the number of columns' + " in the primary key (['col1', 'col2']) matches the number of" + " columns in the foreign key ('col2')." + ) + with pytest.raises(ValueError, match=expected_error_message): + report._validate_relationships(real_data, synthetic_data, metadata) + @patch('sdmetrics.reports.base_report.BaseReport._validate_metadata_matches_data') def test__validate_metadata_matches_data(self, mock__validate_metadata_matches_data): """Test the ``_validate_metadata_matches_data`` method."""