Skip to content

Commit 0bf1ccf

Browse files
Merge pull request #21 from PierreBoyeau/logistic_pvals
P-values for logistic regression
2 parents 674fa40 + caef9a9 commit 0bf1ccf

File tree

2 files changed

+146
-1
lines changed

2 files changed

+146
-1
lines changed

ppi_py/ppi.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from scipy.stats import norm, binom
44
from scipy.optimize import brentq, minimize
55
from statsmodels.regression.linear_model import OLS, WLS
6-
from statsmodels.stats.weightstats import _zconfint_generic, _zstat_generic
6+
from statsmodels.stats.weightstats import (
7+
_zconfint_generic,
8+
_zstat_generic,
9+
_zstat_generic2,
10+
)
711
from sklearn.linear_model import LogisticRegression, PoissonRegressor
812
import warnings
913

@@ -962,6 +966,112 @@ def _logistic_get_stats(
962966
return grads, grads_hat, grads_hat_unlabeled, inv_hessian
963967

964968

969+
def ppi_logistic_pval(
    X,
    Y,
    Yhat,
    X_unlabeled,
    Yhat_unlabeled,
    lam=None,
    coord=None,
    optimizer_options=None,
    w=None,
    w_unlabeled=None,
    alternative="two-sided",
):
    """Computes the prediction-powered p-values on the logistic regression coefficients for the null hypothesis that each coefficient is zero.

    Args:
        X (ndarray): Covariates corresponding to the gold-standard labels.
        Y (ndarray): Gold-standard labels.
        Yhat (ndarray): Predictions corresponding to the gold-standard labels.
        X_unlabeled (ndarray): Covariates corresponding to the unlabeled data.
        Yhat_unlabeled (ndarray): Predictions corresponding to the unlabeled data.
        lam (float, optional): Power-tuning parameter (see `[ADZ23] <https://arxiv.org/abs/2311.01453>`__). The default value `None` will estimate the optimal value from data. Setting `lam=1` recovers PPI with no power tuning, and setting `lam=0` recovers the classical point estimate.
        coord (int, optional): Coordinate for which to optimize `lam`. If `None`, it optimizes the total variance over all coordinates. Must be in {1, ..., d} where d is the shape of the estimand.
        optimizer_options (dict, optional): Options to pass to the optimizer. See scipy.optimize.minimize for details.
        w (ndarray, optional): Sample weights for the labeled data set.
        w_unlabeled (ndarray, optional): Sample weights for the unlabeled data set.
        alternative (str, optional): Alternative hypothesis, either 'two-sided', 'larger' or 'smaller'.

    Returns:
        ndarray: Prediction-powered p-values, one per logistic regression coefficient.

    Notes:
        `[ADZ23] <https://arxiv.org/abs/2311.01453>`__ A. N. Angelopoulos, J. C. Duchi, and T. Zrnic. PPI++: Efficient Prediction Powered Inference. arxiv:2311.01453, 2023.
    """
    n = Y.shape[0]
    d = X.shape[1]
    N = Yhat_unlabeled.shape[0]
    # Normalize any user-supplied weights so each set sums to its sample size.
    w = np.ones(n) if w is None else w / w.sum() * n
    w_unlabeled = (
        np.ones(N)
        if w_unlabeled is None
        else w_unlabeled / w_unlabeled.sum() * N
    )
    # lam=0 is classical inference; the unlabeled data contributes nothing then.
    use_unlabeled = lam != 0

    ppi_pointest = ppi_logistic_pointestimate(
        X,
        Y,
        Yhat,
        X_unlabeled,
        Yhat_unlabeled,
        optimizer_options=optimizer_options,
        lam=lam,
        coord=coord,
        w=w,
        w_unlabeled=w_unlabeled,
    )

    grads, grads_hat, grads_hat_unlabeled, inv_hessian = _logistic_get_stats(
        ppi_pointest,
        X,
        Y,
        Yhat,
        X_unlabeled,
        Yhat_unlabeled,
        w,
        w_unlabeled,
        use_unlabeled=use_unlabeled,
    )

    if lam is None:
        # Estimate the optimal power-tuning parameter from the data, then
        # recurse exactly once with lam fixed to a float (so this branch
        # cannot be re-entered).
        lam = _calc_lam_glm(
            grads,
            grads_hat,
            grads_hat_unlabeled,
            inv_hessian,
            clip=True,
        )
        return ppi_logistic_pval(
            X,
            Y,
            Yhat,
            X_unlabeled,
            Yhat_unlabeled,
            optimizer_options=optimizer_options,
            lam=lam,
            coord=coord,
            w=w,
            w_unlabeled=w_unlabeled,
            alternative=alternative,
        )

    # Sandwich covariance estimate of the PPI estimator (see [ADZ23]).
    var_unlabeled = np.cov(lam * grads_hat_unlabeled.T).reshape(d, d)
    var = np.cov(grads.T - lam * grads_hat.T).reshape(d, d)
    Sigma_hat = inv_hessian @ (n / N * var_unlabeled + var) @ inv_hessian

    # Per-coordinate standard errors of the point estimate.
    var_diag = np.sqrt(np.diag(Sigma_hat) / n)

    # z-test of each coefficient against zero; _zstat_generic2 returns
    # (zstat, pvalue) — keep only the p-values.
    pvals = _zstat_generic2(
        ppi_pointest,
        var_diag,
        alternative=alternative,
    )[1]
    return pvals
1073+
1074+
9651075
def ppi_logistic_ci(
9661076
X,
9671077
Y,

tests/test_logistic.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,41 @@ def test_ppi_logistic_pointestimate_recovers():
6464
assert np.linalg.norm(beta_ppi_pointestimate - beta) < 0.2
6565

6666

67+
def test_ppi_logistic_pval_makesense():
    # Make a synthetic regression problem
    n, N, d = 10000, 100000, 3
    X = np.random.randn(n, d)
    beta = np.array([0, 0, 1.0])

    Y = np.random.binomial(1, expit(X.dot(beta)))
    Yhat = expit(X.dot(beta))
    X_unlabeled = np.random.randn(N, d)
    Yhat_unlabeled = expit(X_unlabeled.dot(beta))

    # The coefficient on the last coordinate is truly nonzero, so its p-value
    # should be small — both with a fixed lam and with the data-driven lam.
    for lam_choice in (0.5, None):
        beta_ppi_pval = ppi_logistic_pval(
            X,
            Y,
            Yhat,
            X_unlabeled,
            Yhat_unlabeled,
            lam=lam_choice,
            optimizer_options={"gtol": 1e-3},
        )
        assert beta_ppi_pval[-1] < 0.1
100+
101+
67102
def ppi_logistic_ci_subtest(i, alphas, n=1000, N=10000, d=1, epsilon=0.02):
68103
includeds = np.zeros(len(alphas))
69104
# Make a synthetic regression problem

0 commit comments

Comments
 (0)