Skip to content

Commit 3185f8c

Browse files
Julien RousselJulien Roussel
authored andcommitted
metrics tests reworked
1 parent 15510c6 commit 3185f8c

File tree

2 files changed

+15
-12
lines changed

2 files changed

+15
-12
lines changed

qolmat/benchmark/metrics.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def columnwise_metric(
6565
f"({df1.columns} != {df2.columns})"
6666
)
6767
if type_cols == "all":
68-
cols = df1.columns
68+
cols = df1.columns.tolist()
6969
elif type_cols == "numerical":
7070
cols = utils._get_numerical_features(df1)
7171
elif type_cols == "categorical":
@@ -74,6 +74,8 @@ def columnwise_metric(
7474
raise ValueError(
7575
f"Value {type_cols} is not valid for parameter `type_cols`!"
7676
)
77+
if cols == []:
78+
raise ValueError(f"No column found for the type {type_cols}!")
7779
values = {}
7880
for col in cols:
7981
df1_col = df1.loc[df_mask[col], col]
@@ -510,6 +512,8 @@ def mean_difference_correlation_matrix_numerical_features(
510512
_check_same_number_columns(df1, df2)
511513

512514
cols_numerical = utils._get_numerical_features(df1)
515+
if cols_numerical == []:
516+
raise Exception("No numerical feature found")
513517
df_corr1 = _get_correlation_pearson_matrix(
514518
df1[cols_numerical], use_p_value=use_p_value
515519
)
@@ -594,6 +598,8 @@ def mean_difference_correlation_matrix_categorical_features(
594598
_check_same_number_columns(df1, df2)
595599

596600
cols_categorical = utils._get_categorical_features(df1)
601+
if cols_categorical == []:
602+
raise Exception("No categorical feature found")
597603
df_corr1 = _get_correlation_chi2_matrix(
598604
df1[cols_categorical], use_p_value=use_p_value
599605
)
@@ -681,7 +687,11 @@ def mean_diff_corr_matrix_categorical_vs_numerical_features(
681687
_check_same_number_columns(df1, df2)
682688

683689
cols_categorical = utils._get_categorical_features(df1)
690+
if cols_categorical == []:
691+
raise Exception("No categorical feature found")
684692
cols_numerical = utils._get_numerical_features(df1)
693+
if cols_numerical == []:
694+
raise Exception("No numerical feature found")
685695
df_corr1 = _get_correlation_f_oneway_matrix(
686696
df1, cols_categorical, cols_numerical, use_p_value=use_p_value
687697
)

qolmat/utils/utils.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,11 @@ def _get_numerical_features(df1: pd.DataFrame) -> List[str]:
2727
Raises
2828
------
2929
Exception
30-
No numerical feature is found
30+
No numerical feature found
3131
3232
"""
3333
cols_numerical = df1.select_dtypes(include=np.number).columns.tolist()
34-
if len(cols_numerical) == 0:
35-
print(df1)
36-
raise Exception("No numerical feature is found.")
37-
else:
38-
return cols_numerical
34+
return cols_numerical
3935

4036

4137
def _get_categorical_features(df1: pd.DataFrame) -> List[str]:
@@ -54,17 +50,14 @@ def _get_categorical_features(df1: pd.DataFrame) -> List[str]:
5450
Raises
5551
------
5652
Exception
57-
No categorical feature is found
53+
No categorical feature found
5854
5955
"""
6056
cols_numerical = df1.select_dtypes(include=np.number).columns.tolist()
6157
cols_categorical = [
6258
col for col in df1.columns.to_list() if col not in cols_numerical
6359
]
64-
if len(cols_categorical) == 0:
65-
raise Exception("No categorical feature is found.")
66-
else:
67-
return cols_categorical
60+
return cols_categorical
6861

6962

7063
def _validate_input(X: NDArray) -> pd.DataFrame:

0 commit comments

Comments
 (0)