Skip to content

Commit cb245e0

Browse files
Merge pull request #15 from justinkay/bugfix/sorted_highlow
Remove sorted_highlow from form_discrete_distribution call in ppi_distribution_label_shift_ci.
2 parents d5522d5 + da5970a commit cb245e0

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

ppi_py/ppi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1759,7 +1759,7 @@ def ppi_distribution_label_shift_ci(
17591759

17601760
# Invert Ahat
17611761
Ahatinv = np.linalg.inv(Ahat)
1762-
qfhat = form_discrete_distribution(Yhat_unlabeled, sorted_highlow=True)
1762+
qfhat = form_discrete_distribution(Yhat_unlabeled)
17631763

17641764
# Calculate the bound
17651765
point_estimate = nu @ Ahatinv @ qfhat

tests/test_labelshift.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import numpy as np
2+
from ppi_py import *
3+
4+
def test_ppi_distribution_label_shift_ci_sorted():
5+
"""Test ppi_distribution_label_shift_ci with various sorting scenarios."""
6+
n = 1000
7+
N = 10000
8+
num_classes = 10
9+
10+
# Generate multi-class predictions (integer labels)
11+
Y = np.random.randint(0, num_classes, n)
12+
Yhat = np.random.randint(0, num_classes, n)
13+
Yhat_unlabeled = np.random.randint(0, num_classes, N)
14+
15+
# Determine K and nu
16+
unique_Yhat = np.unique(Yhat)
17+
K = len(unique_Yhat) # Should be equal to num_classes if all classes are present in Yhat
18+
19+
# Set nu (example: vector of ones)
20+
nu = np.ones(K)
21+
22+
# Test 1: Unsorted Yhat_unlabeled (implicit in the current implementation)
23+
ci_unsorted = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled, K, nu, alpha=0.1)
24+
25+
# Test 2: Explicitly sorted Yhat_unlabeled (to demonstrate no dependence on order)
26+
uq, uq_counts = np.unique(Yhat_unlabeled, return_counts=True)
27+
sort_indices = np.argsort(uq)
28+
Yhat_unlabeled_sorted = Yhat_unlabeled[np.argsort(np.argsort(Yhat_unlabeled))] # Stable sort to mimic previous behavior.
29+
ci_sorted = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled_sorted, K, nu, alpha=0.1)
30+
31+
32+
# Test 3: Reverse sorted Yhat_unlabeled
33+
uq, uq_counts = np.unique(Yhat_unlabeled, return_counts=True)
34+
sort_indices = np.argsort(uq)[::-1]
35+
Yhat_unlabeled_reverse_sorted = Yhat_unlabeled[np.argsort(np.argsort(Yhat_unlabeled))[::-1]]
36+
ci_reverse_sorted = ppi_distribution_label_shift_ci(Y, Yhat, Yhat_unlabeled_reverse_sorted, K, nu, alpha=0.1)
37+
38+
39+
# The confidence intervals should be identical regardless of the sorting of Yhat_unlabeled
40+
np.testing.assert_allclose(ci_unsorted, ci_sorted)
41+
np.testing.assert_allclose(ci_unsorted, ci_reverse_sorted)

0 commit comments

Comments
 (0)