1+ import numpy as np
2+ from ppi_py import *
3+
4+ def test_ppi_distribution_label_shift_ci_sorted ():
5+ """Test ppi_distribution_label_shift_ci with various sorting scenarios."""
6+ n = 1000
7+ N = 10000
8+ num_classes = 10
9+
10+ # Generate multi-class predictions (integer labels)
11+ Y = np .random .randint (0 , num_classes , n )
12+ Yhat = np .random .randint (0 , num_classes , n )
13+ Yhat_unlabeled = np .random .randint (0 , num_classes , N )
14+
15+ # Determine K and nu
16+ unique_Yhat = np .unique (Yhat )
17+ K = len (unique_Yhat ) # Should be equal to num_classes if all classes are present in Yhat
18+
19+ # Set nu (example: vector of ones)
20+ nu = np .ones (K )
21+
22+ # Test 1: Unsorted Yhat_unlabeled (implicit in the current implementation)
23+ ci_unsorted = ppi_distribution_label_shift_ci (Y , Yhat , Yhat_unlabeled , K , nu , alpha = 0.1 )
24+
25+ # Test 2: Explicitly sorted Yhat_unlabeled (to demonstrate no dependence on order)
26+ uq , uq_counts = np .unique (Yhat_unlabeled , return_counts = True )
27+ sort_indices = np .argsort (uq )
28+ Yhat_unlabeled_sorted = Yhat_unlabeled [np .argsort (np .argsort (Yhat_unlabeled ))] # Stable sort to mimic previous behavior.
29+ ci_sorted = ppi_distribution_label_shift_ci (Y , Yhat , Yhat_unlabeled_sorted , K , nu , alpha = 0.1 )
30+
31+
32+ # Test 3: Reverse sorted Yhat_unlabeled
33+ uq , uq_counts = np .unique (Yhat_unlabeled , return_counts = True )
34+ sort_indices = np .argsort (uq )[::- 1 ]
35+ Yhat_unlabeled_reverse_sorted = Yhat_unlabeled [np .argsort (np .argsort (Yhat_unlabeled ))[::- 1 ]]
36+ ci_reverse_sorted = ppi_distribution_label_shift_ci (Y , Yhat , Yhat_unlabeled_reverse_sorted , K , nu , alpha = 0.1 )
37+
38+
39+ # The confidence intervals should be identical regardless of the sorting of Yhat_unlabeled
40+ np .testing .assert_allclose (ci_unsorted , ci_sorted )
41+ np .testing .assert_allclose (ci_unsorted , ci_reverse_sorted )
0 commit comments