11import numpy as np
22from ppi_py import *
33
4+
45def test_ppi_distribution_label_shift_ci_sorted ():
56 """Test ppi_distribution_label_shift_ci with various sorting scenarios."""
67 n = 1000
@@ -14,28 +15,38 @@ def test_ppi_distribution_label_shift_ci_sorted():
1415
1516 # Determine K and nu
1617 unique_Yhat = np .unique (Yhat )
17- K = len (unique_Yhat ) # Should be equal to num_classes if all classes are present in Yhat
18+ K = len (
19+ unique_Yhat
20+ ) # Should be equal to num_classes if all classes are present in Yhat
1821
1922 # Set nu (example: vector of ones)
2023 nu = np .ones (K )
2124
2225 # Test 1: Unsorted Yhat_unlabeled (implicit in the current implementation)
23- ci_unsorted = ppi_distribution_label_shift_ci (Y , Yhat , Yhat_unlabeled , K , nu , alpha = 0.1 )
26+ ci_unsorted = ppi_distribution_label_shift_ci (
27+ Y , Yhat , Yhat_unlabeled , K , nu , alpha = 0.1
28+ )
2429
2530 # Test 2: Explicitly sorted Yhat_unlabeled (to demonstrate no dependence on order)
2631 uq , uq_counts = np .unique (Yhat_unlabeled , return_counts = True )
2732 sort_indices = np .argsort (uq )
28- Yhat_unlabeled_sorted = Yhat_unlabeled [np .argsort (np .argsort (Yhat_unlabeled ))] # Stable sort to mimic previous behavior.
29- ci_sorted = ppi_distribution_label_shift_ci (Y , Yhat , Yhat_unlabeled_sorted , K , nu , alpha = 0.1 )
30-
33+ Yhat_unlabeled_sorted = Yhat_unlabeled [
34+ np .argsort (np .argsort (Yhat_unlabeled ))
35+ ] # Stable sort to mimic previous behavior.
36+ ci_sorted = ppi_distribution_label_shift_ci (
37+ Y , Yhat , Yhat_unlabeled_sorted , K , nu , alpha = 0.1
38+ )
3139
3240 # Test 3: Reverse sorted Yhat_unlabeled
3341 uq , uq_counts = np .unique (Yhat_unlabeled , return_counts = True )
3442 sort_indices = np .argsort (uq )[::- 1 ]
35- Yhat_unlabeled_reverse_sorted = Yhat_unlabeled [np .argsort (np .argsort (Yhat_unlabeled ))[::- 1 ]]
36- ci_reverse_sorted = ppi_distribution_label_shift_ci (Y , Yhat , Yhat_unlabeled_reverse_sorted , K , nu , alpha = 0.1 )
37-
43+ Yhat_unlabeled_reverse_sorted = Yhat_unlabeled [
44+ np .argsort (np .argsort (Yhat_unlabeled ))[::- 1 ]
45+ ]
46+ ci_reverse_sorted = ppi_distribution_label_shift_ci (
47+ Y , Yhat , Yhat_unlabeled_reverse_sorted , K , nu , alpha = 0.1
48+ )
3849
3950 # The confidence intervals should be identical regardless of the sorting of Yhat_unlabeled
4051 np .testing .assert_allclose (ci_unsorted , ci_sorted )
41- np .testing .assert_allclose (ci_unsorted , ci_reverse_sorted )
52+ np .testing .assert_allclose (ci_unsorted , ci_reverse_sorted )
0 commit comments