Skip to content

Commit 674fa40

Browse files
committed
Auto-format with Black
1 parent cb245e0 commit 674fa40

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

tests/test_labelshift.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from ppi_py import *
33

4+
45
def test_ppi_distribution_label_shift_ci_sorted():
56
"""Test ppi_distribution_label_shift_ci with various sorting scenarios."""
67
n = 1000
@@ -14,28 +15,38 @@ def test_ppi_distribution_label_shift_ci_sorted():
1415

1516
# Determine K and nu
1617
unique_Yhat = np.unique(Yhat)
17-
K = len(unique_Yhat) # Should be equal to num_classes if all classes are present in Yhat
18+
K = len(
19+
unique_Yhat
20+
) # Should be equal to num_classes if all classes are present in Yhat
1821

1922
# Set nu (example: vector of ones)
2023
nu = np.ones(K)
2124

2225
# Test 1: Unsorted Yhat_unlabeled (implicit in the current implementation)
23-
ci_unsorted = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled, K, nu, alpha=0.1)
26+
ci_unsorted = ppi_distribution_label_shift_ci(
27+
Y, Yhat, Yhat_unlabeled, K, nu, alpha=0.1
28+
)
2429

2530
# Test 2: Explicitly sorted Yhat_unlabeled (to demonstrate no dependence on order)
2631
uq, uq_counts = np.unique(Yhat_unlabeled, return_counts=True)
2732
sort_indices = np.argsort(uq)
28-
Yhat_unlabeled_sorted = Yhat_unlabeled[np.argsort(np.argsort(Yhat_unlabeled))] # Stable sort to mimic previous behavior.
29-
ci_sorted = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled_sorted, K, nu, alpha=0.1)
30-
33+
Yhat_unlabeled_sorted = Yhat_unlabeled[
34+
np.argsort(np.argsort(Yhat_unlabeled))
35+
] # Stable sort to mimic previous behavior.
36+
ci_sorted = ppi_distribution_label_shift_ci(
37+
Y, Yhat, Yhat_unlabeled_sorted, K, nu, alpha=0.1
38+
)
3139

3240
# Test 3: Reverse sorted Yhat_unlabeled
3341
uq, uq_counts = np.unique(Yhat_unlabeled, return_counts=True)
3442
sort_indices = np.argsort(uq)[::-1]
35-
Yhat_unlabeled_reverse_sorted = Yhat_unlabeled[np.argsort(np.argsort(Yhat_unlabeled))[::-1]]
36-
ci_reverse_sorted = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled_reverse_sorted, K, nu, alpha=0.1)
37-
43+
Yhat_unlabeled_reverse_sorted = Yhat_unlabeled[
44+
np.argsort(np.argsort(Yhat_unlabeled))[::-1]
45+
]
46+
ci_reverse_sorted = ppi_distribution_label_shift_ci(
47+
Y, Yhat, Yhat_unlabeled_reverse_sorted, K, nu, alpha=0.1
48+
)
3849

3950
# The confidence intervals should be identical regardless of the sorting of Yhat_unlabeled
4051
np.testing.assert_allclose(ci_unsorted, ci_sorted)
41-
np.testing.assert_allclose(ci_unsorted, ci_reverse_sorted)
52+
np.testing.assert_allclose(ci_unsorted, ci_reverse_sorted)

0 commit comments

Comments
 (0)