Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sdmetrics/column_pairs/statistical/contingency_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def compute_breakdown(
contingency_synthetic = synthetic.groupby(list(columns), dropna=False).size() / len(
synthetic
)
combined_index = contingency_real.index.union(contingency_synthetic.index, sort=False)
combined_index = contingency_real.index.union(
contingency_synthetic.index, sort=False
).drop_duplicates()
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixes a test error on the minimum supported versions: on older versions of pandas, `Index.union` does not treat NaN values as a single value, so if a column has a null category, duplicate index entries get created. Adding the `drop_duplicates` call removes those duplicated indices.

contingency_synthetic = contingency_synthetic.reindex(combined_index, fill_value=0)
contingency_real = contingency_real.reindex(combined_index, fill_value=0)
diff = abs(contingency_real - contingency_synthetic).fillna(0)
Expand Down
27 changes: 17 additions & 10 deletions sdmetrics/reports/multi_table/_properties/inter_table_trends.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ColumnPairTrends as SingleTableColumnPairTrends,
)
from sdmetrics.reports.utils import PlotConfig
from sdmetrics.utils import _cast_to_iterable


class InterTableTrends(BaseMultiTableProperty):
Expand Down Expand Up @@ -50,16 +51,16 @@ def _denormalize_tables(self, real_data, synthetic_data, relationship):
"""
parent = relationship['parent_table_name']
child = relationship['child_table_name']
foreign_key = relationship['child_foreign_key']
primary_key = relationship['parent_primary_key']
foreign_key = _cast_to_iterable(relationship['child_foreign_key'])
primary_key = _cast_to_iterable(relationship['parent_primary_key'])

real_parent = real_data[parent].add_prefix(f'{parent}.')
real_child = real_data[child].add_prefix(f'{child}.')
synthetic_parent = synthetic_data[parent].add_prefix(f'{parent}.')
synthetic_child = synthetic_data[child].add_prefix(f'{child}.')

child_index = f'{child}.{foreign_key}'
parent_index = f'{parent}.{primary_key}'
child_index = [f'{child}.{key_col}' for key_col in foreign_key]
parent_index = [f'{parent}.{key_col}' for key_col in primary_key]

denormalized_real = real_child.merge(
real_parent, left_on=child_index, right_on=parent_index
Expand Down Expand Up @@ -101,7 +102,12 @@ def _merge_metadata(self, metadata, parent_table, child_table):
merged_metadata['columns'] = {**child_cols, **parent_cols}
if 'primary_key' in merged_metadata:
primary_key = merged_metadata['primary_key']
merged_metadata['primary_key'] = f'{child_table}.{primary_key}'
if isinstance(primary_key, list):
merged_metadata['primary_key'] = [
f'{child_table}.{pk_col}' for pk_col in primary_key
]
else:
merged_metadata['primary_key'] = f'{child_table}.{primary_key}'

return merged_metadata, list(parent_cols.keys()), list(child_cols.keys())

Expand All @@ -123,6 +129,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
parent = relationship['parent_table_name']
child = relationship['child_table_name']
foreign_key = relationship['child_foreign_key']
fk_tuple = tuple(foreign_key) if isinstance(foreign_key, list) else foreign_key

denormalized_real, denormalized_synthetic = self._denormalize_tables(
real_data, synthetic_data, relationship
Expand All @@ -132,14 +139,14 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No

parent_child_pairs = itertools.product(parent_cols, child_cols)

self._properties[(parent, child, foreign_key)] = SingleTableColumnPairTrends()
self._properties[(parent, child, fk_tuple)] = SingleTableColumnPairTrends()
self._properties[
(parent, child, foreign_key)
(parent, child, fk_tuple)
].real_correlation_threshold = self.real_correlation_threshold
self._properties[
(parent, child, foreign_key)
(parent, child, fk_tuple)
].real_association_threshold = self.real_association_threshold
details = self._properties[(parent, child, foreign_key)]._generate_details(
details = self._properties[(parent, child, fk_tuple)]._generate_details(
denormalized_real,
denormalized_synthetic,
merged_metadata,
Expand All @@ -149,7 +156,7 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No

details['Parent Table'] = parent
details['Child Table'] = child
details['Foreign Key'] = foreign_key
details['Foreign Key'] = str(foreign_key)
if not details.empty:
details['Column 1'] = details['Column 1'].str.replace(
f'{parent}.', '', n=1, regex=False
Expand Down
39 changes: 27 additions & 12 deletions sdmetrics/reports/multi_table/base_multi_table_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd

from sdmetrics.reports.base_report import BaseReport
from sdmetrics.utils import _cast_to_iterable
from sdmetrics.visualization import set_plotly_config


Expand Down Expand Up @@ -43,22 +44,36 @@ def _validate_data_format(self, real_data, synthetic_data):
def _validate_relationships(self, real_data, synthetic_data, metadata):
"""Validate that the relationships are valid."""
for rel in metadata.get('relationships', []):
parent_dtype = real_data[rel['parent_table_name']][rel['parent_primary_key']].dtype
child_dtype = real_data[rel['child_table_name']][rel['child_foreign_key']].dtype
if (parent_dtype == 'object' and child_dtype != 'object') or (
parent_dtype != 'object' and child_dtype == 'object'
):
parent = rel['parent_table_name']
parent_key = rel['parent_primary_key']
child = rel['child_table_name']
child_key = rel['child_foreign_key']
parent = rel['parent_table_name']
parent_key = rel['parent_primary_key']
child = rel['child_table_name']
child_key = rel['child_foreign_key']
parent_key_str = f"'{parent_key}'" if isinstance(parent_key, str) else str(parent_key)
child_key_str = f"'{child_key}'" if isinstance(child_key, str) else str(child_key)
parent_primary_key = _cast_to_iterable(parent_key)
child_foreign_key = _cast_to_iterable(child_key)

if len(parent_primary_key) != len(child_foreign_key):
error_msg = (
f"The '{parent}' table and '{child}' table cannot be merged "
'for computing the cardinality. Please make sure the primary key'
f" in '{parent}' ('{parent_key}') and the foreign key in '{child}'"
f" ('{child_key}') have the same data type."
'for computing the cardinality. Please make sure the number of columns '
f'in the primary key ({parent_key_str}) matches the number of '
f'columns in the foreign key ({child_key_str}).'
)
raise ValueError(error_msg)
parent_dtypes = real_data[rel['parent_table_name']][parent_primary_key].dtypes
child_dtypes = real_data[rel['child_table_name']][child_foreign_key].dtypes
for parent_dtype, child_dtype in zip(parent_dtypes, child_dtypes):
if (parent_dtype == 'object' and child_dtype != 'object') or (
parent_dtype != 'object' and child_dtype == 'object'
):
error_msg = (
f"The '{parent}' table and '{child}' table cannot be merged "
'for computing the cardinality. Please make sure the primary key'
f" in '{parent}' ({parent_key_str}) and the foreign key in '{child}'"
f' ({child_key_str}) have the same data types.'
)
raise ValueError(error_msg)

def _validate_metadata_matches_data(self, real_data, synthetic_data, metadata):
"""Validate that the metadata matches the data."""
Expand Down
10 changes: 8 additions & 2 deletions sdmetrics/reports/single_table/_properties/data_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,19 @@ def _generate_details(self, real_data, synthetic_data, metadata, progress_bar=No
The progress bar to use. Defaults to None.
"""
column_names, metric_names, scores = [], [], []
column_sdtypes = [(col, metadata['columns'][col]['sdtype']) for col in metadata['columns']]
error_messages = []
primary_key = metadata.get('primary_key')
if isinstance(primary_key, list):
if len(primary_key) > 1:
column_sdtypes = [(primary_key, None)] + column_sdtypes
else:
primary_key = primary_key[0]

alternate_keys = metadata.get('alternate_keys', [])
sequence_index = metadata.get('sequence_index')

for column_name in metadata['columns']:
sdtype = metadata['columns'][column_name]['sdtype']
for column_name, sdtype in column_sdtypes:
primary_key_match = column_name == primary_key
alternate_key_match = column_name in alternate_keys
is_unique = primary_key_match or alternate_key_match
Expand Down
8 changes: 8 additions & 0 deletions sdmetrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,11 @@ def strip_characters(list_character, a_string):
result = result.replace(character, '')

return result


def _cast_to_iterable(value):
"""Return a ``list`` if the input object is not a ``list`` or ``tuple``."""
if isinstance(value, (list, tuple)):
return value

return [value]
48 changes: 48 additions & 0 deletions tests/integration/reports/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import pytest

from sdmetrics.demos import load_demo


@pytest.fixture(scope='module')
def composite_keys_single_table_demo():
    """Single-table demo data whose metadata declares a composite primary key."""
    data_real, data_synthetic, meta = load_demo(modality='single_table')
    meta['primary_key'] = ['student_id', 'degree_type']
    return data_real, data_synthetic, meta


@pytest.fixture(scope='module')
def composite_keys_multi_table_demo():
    """Multi-table demo data rewired to use composite primary/foreign keys.

    Adds a ``user_type`` column (with nulls) to ``users``/``sessions`` and a
    ``device`` column to ``transactions``, then updates the metadata and both
    relationships to use two-column keys.
    """
    real_data, synthetic_data, metadata = load_demo(modality='multi_table')
    metadata['tables']['users']['columns']['user_type'] = {'sdtype': 'categorical'}
    metadata['tables']['users']['primary_key'] = ['user_id', 'user_type']
    # Fix: the original fixture assigned the sessions 'user_type' column twice;
    # the redundant duplicate statement has been removed.
    metadata['tables']['sessions']['columns']['user_type'] = {'sdtype': 'categorical'}
    metadata['tables']['sessions']['primary_key'] = ['session_id', 'device']
    metadata['tables']['transactions']['columns']['device'] = {'sdtype': 'categorical'}

    metadata['relationships'][0]['parent_primary_key'] = ['user_id', 'user_type']
    metadata['relationships'][0]['child_foreign_key'] = ['user_id', 'user_type']
    metadata['relationships'][1]['parent_primary_key'] = ['session_id', 'device']
    metadata['relationships'][1]['child_foreign_key'] = ['session_id', 'device']

    # Half the users are 'PREMIUM', half null — exercises null categories in a
    # composite key (assumes the demo 'users' table has exactly 10 rows — TODO confirm).
    real_data['users']['user_type'] = ['PREMIUM'] * 5 + [None] * 5
    synthetic_data['users']['user_type'] = ['PREMIUM'] * 5 + [None] * 5
    for data in [real_data, synthetic_data]:
        # Propagate the new key columns to the child tables. NOTE: these lookups
        # rely on the ORIGINAL unique 'user_id' values, so they must run before
        # the user_id remapping below.
        data['sessions']['user_type'] = (
            data['users']
            .set_index('user_id')
            .loc[data['sessions']['user_id']]['user_type']
            .to_numpy()
        )
        data['transactions']['device'] = (
            data['sessions']
            .set_index('session_id')
            .loc[data['transactions']['session_id']]['device']
            .to_numpy()
        )
        # Remap user_ids so that 'user_id' alone is no longer unique — only the
        # (user_id, user_type) pair identifies a user.
        premium_mask = data['users']['user_type'] == 'PREMIUM'
        data['users'].loc[premium_mask, 'user_id'] = range(5)
        data['users'].loc[~premium_mask, 'user_id'] = range(5)
        # Shift null-type sessions to match the remapped ids (original null-type
        # users presumably had ids 5-9 — verify against the demo data).
        data['sessions'].loc[data['sessions']['user_type'].isna(), 'user_id'] -= 5

    return real_data, synthetic_data, metadata
12 changes: 12 additions & 0 deletions tests/integration/reports/multi_table/test_diagnostic_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@ def test_end_to_end(self):
# Assert
assert results == 1.0

def test_end_to_end_composite_keys(self, composite_keys_multi_table_demo):
    """Test the end-to-end functionality of the ``DiagnosticReport`` report."""
    # Setup
    real_data, synthetic_data, metadata = composite_keys_multi_table_demo

    # Run
    report = DiagnosticReport()
    report.generate(real_data, synthetic_data, metadata, verbose=False)

    # Assert
    score = report.get_score()
    assert score == 1.0

def test_end_to_end_with_object_datetimes(self):
"""Test the ``DiagnosticReport`` report with object datetimes."""
real_data, synthetic_data, metadata = load_demo(modality='multi_table')
Expand Down
19 changes: 15 additions & 4 deletions tests/integration/reports/multi_table/test_quality_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
import pytest
from packaging import version

from sdmetrics.demos import load_demo
Expand Down Expand Up @@ -230,10 +231,14 @@ def test_multi_table_quality_report():
assert report_info['generation_time'] <= generate_end_time - generate_start_time


def test_quality_report_end_to_end():
@pytest.mark.parametrize('key_type', ['single', 'composite'])
def test_quality_report_end_to_end(key_type, composite_keys_multi_table_demo):
"""Test the multi table QualityReport end to end."""
# Setup
real_data, synthetic_data, metadata = load_demo(modality='multi_table')
if key_type == 'single':
real_data, synthetic_data, metadata = load_demo(modality='multi_table')
else:
real_data, synthetic_data, metadata = composite_keys_multi_table_demo
report = QualityReport()
_set_thresholds_zero(report)

Expand All @@ -244,11 +249,17 @@ def test_quality_report_end_to_end():
info = report.get_info()

# Assert
expected_single_scores = [0.7978174603174604, 0.45654629583521095, 0.95, 0.4416666666666666]
expected_composite_scores = [0.82568543, 0.53305494, 0.95, 0.5375]
expected_properties = pd.DataFrame({
'Property': ['Column Shapes', 'Column Pair Trends', 'Cardinality', 'Intertable Trends'],
'Score': [0.7978174603174604, 0.45654629583521095, 0.95, 0.4416666666666666],
'Score': expected_single_scores if key_type == 'single' else expected_composite_scores,
})
assert score == 0.6615076057048344
if key_type == 'single':
assert score == 0.6615076057048344
else:
assert score == 0.7115600909354644

pd.testing.assert_frame_equal(properties, expected_properties)
expected_info_keys = {
'report_type',
Expand Down
81 changes: 81 additions & 0 deletions tests/integration/reports/single_table/test_diagnostic_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,88 @@ def test_end_to_end(self):
1.0,
],
})
expected_details_data_structure = pd.DataFrame({
'Metric': ['TableStructure'],
'Score': [1.0],
})

pd.testing.assert_frame_equal(
report.get_details('Data Validity'), expected_details_data_validity
)

pd.testing.assert_frame_equal(
report.get_details('Data Structure'), expected_details_data_structure
)

def test_end_to_end_composite_keys(self, composite_keys_single_table_demo):
"""Test the end-to-end functionality of the diagnostic report."""
# Setup
real_data, synthetic_data, metadata = composite_keys_single_table_demo
report = DiagnosticReport()

# Run
report.generate(real_data, synthetic_data, metadata)

# Assert
expected_details_data_validity = pd.DataFrame({
'Column': [
['student_id', 'degree_type'],
'start_date',
'end_date',
'salary',
'duration',
'high_perc',
'high_spec',
'mba_spec',
'second_perc',
'gender',
'degree_perc',
'placed',
'experience_years',
'employability_perc',
'mba_perc',
'work_experience',
'degree_type',
],
'Metric': [
'KeyUniqueness',
'BoundaryAdherence',
'BoundaryAdherence',
'BoundaryAdherence',
'BoundaryAdherence',
'BoundaryAdherence',
'CategoryAdherence',
'CategoryAdherence',
'BoundaryAdherence',
'CategoryAdherence',
'BoundaryAdherence',
'CategoryAdherence',
'BoundaryAdherence',
'BoundaryAdherence',
'BoundaryAdherence',
'CategoryAdherence',
'CategoryAdherence',
],
'Score': [
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
1.0,
],
})
expected_details_data_structure = pd.DataFrame({
'Metric': ['TableStructure'],
'Score': [1.0],
Expand Down
Loading