|
3 | 3 | from scipy.stats import norm, binom |
4 | 4 | from scipy.optimize import brentq, minimize |
5 | 5 | from statsmodels.regression.linear_model import OLS, WLS |
6 | | -from statsmodels.stats.weightstats import _zconfint_generic, _zstat_generic |
| 6 | +from statsmodels.stats.weightstats import ( |
| 7 | + _zconfint_generic, |
| 8 | + _zstat_generic, |
| 9 | + _zstat_generic2, |
| 10 | +) |
7 | 11 | from sklearn.linear_model import LogisticRegression, PoissonRegressor |
8 | 12 | import warnings |
9 | 13 |
|
@@ -962,6 +966,112 @@ def _logistic_get_stats( |
962 | 966 | return grads, grads_hat, grads_hat_unlabeled, inv_hessian |
963 | 967 |
|
964 | 968 |
|
def ppi_logistic_pval(
    X,
    Y,
    Yhat,
    X_unlabeled,
    Yhat_unlabeled,
    lam=None,
    coord=None,
    optimizer_options=None,
    w=None,
    w_unlabeled=None,
    alternative="two-sided",
):
    """Computes the prediction-powered p-values on the logistic regression coefficients for the null hypothesis that the coefficient is zero.

    Args:
        X (ndarray): Covariates corresponding to the gold-standard labels.
        Y (ndarray): Gold-standard labels.
        Yhat (ndarray): Predictions corresponding to the gold-standard labels.
        X_unlabeled (ndarray): Covariates corresponding to the unlabeled data.
        Yhat_unlabeled (ndarray): Predictions corresponding to the unlabeled data.
        lam (float, optional): Power-tuning parameter (see `[ADZ23] <https://arxiv.org/abs/2311.01453>`__). The default value `None` will estimate the optimal value from data. Setting `lam=1` recovers PPI with no power tuning, and setting `lam=0` recovers the classical point estimate.
        coord (int, optional): Coordinate for which to optimize `lam`. If `None`, it optimizes the total variance over all coordinates. Must be in {1, ..., d} where d is the shape of the estimand.
        optimizer_options (dict, optional): Options to pass to the optimizer. See scipy.optimize.minimize for details.
        w (ndarray, optional): Sample weights for the labeled data set.
        w_unlabeled (ndarray, optional): Sample weights for the unlabeled data set.
        alternative (str, optional): Alternative hypothesis, either 'two-sided', 'larger' or 'smaller'.

    Returns:
        ndarray: Prediction-powered p-value for each logistic regression coefficient.

    Notes:
        `[ADZ23] <https://arxiv.org/abs/2311.01453>`__ A. N. Angelopoulos, J. C. Duchi, and T. Zrnic. PPI++: Efficient Prediction Powered Inference. arxiv:2311.01453, 2023.
    """
    n = Y.shape[0]
    d = X.shape[1]
    N = Yhat_unlabeled.shape[0]
    # Normalize user-supplied weights so they sum to the respective sample
    # sizes; default to uniform weights.
    w = np.ones(n) if w is None else w / w.sum() * n
    w_unlabeled = (
        np.ones(N)
        if w_unlabeled is None
        else w_unlabeled / w_unlabeled.sum() * N
    )
    # lam == 0 recovers the classical estimator, which ignores the
    # unlabeled data entirely.
    use_unlabeled = lam != 0

    ppi_pointest = ppi_logistic_pointestimate(
        X,
        Y,
        Yhat,
        X_unlabeled,
        Yhat_unlabeled,
        optimizer_options=optimizer_options,
        lam=lam,
        coord=coord,
        w=w,
        w_unlabeled=w_unlabeled,
    )

    grads, grads_hat, grads_hat_unlabeled, inv_hessian = _logistic_get_stats(
        ppi_pointest,
        X,
        Y,
        Yhat,
        X_unlabeled,
        Yhat_unlabeled,
        w,
        w_unlabeled,
        use_unlabeled=use_unlabeled,
    )

    if lam is None:
        # Estimate the optimal power-tuning parameter from the data, then
        # recurse exactly once with a concrete `lam` so that the point
        # estimate and variance below are computed consistently.
        lam = _calc_lam_glm(
            grads,
            grads_hat,
            grads_hat_unlabeled,
            inv_hessian,
            clip=True,
        )
        return ppi_logistic_pval(
            X,
            Y,
            Yhat,
            X_unlabeled,
            Yhat_unlabeled,
            optimizer_options=optimizer_options,
            lam=lam,
            coord=coord,
            w=w,
            w_unlabeled=w_unlabeled,
            alternative=alternative,
        )

    # Sandwich variance estimate of the PPI estimator (see [ADZ23]).
    var_unlabeled = np.cov(lam * grads_hat_unlabeled.T).reshape(d, d)
    var = np.cov(grads.T - lam * grads_hat.T).reshape(d, d)
    Sigma_hat = inv_hessian @ (n / N * var_unlabeled + var) @ inv_hessian

    var_diag = np.sqrt(np.diag(Sigma_hat) / n)

    # Coordinate-wise z-test of H0: coefficient = 0; _zstat_generic2
    # returns (zstat, pvalue), so keep index [1].
    pvals = _zstat_generic2(
        ppi_pointest,
        var_diag,
        alternative=alternative,
    )[1]
    return pvals
| 1073 | + |
| 1074 | + |
965 | 1075 | def ppi_logistic_ci( |
966 | 1076 | X, |
967 | 1077 | Y, |
|
0 commit comments