|
14 | 14 | from sklearn.linear_model import ElasticNet as ElasticNet_sklearn |
15 | 15 | from sklearn.linear_model import LogisticRegression as LogReg_sklearn |
16 | 16 | from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn |
| 17 | +from sklearn.linear_model import PoissonRegressor, GammaRegressor |
17 | 18 | from sklearn.model_selection import GridSearchCV |
18 | 19 | from sklearn.svm import LinearSVC as LinearSVC_sklearn |
19 | 20 | from sklearn.utils.estimator_checks import check_estimator |
|
23 | 24 | from skglm.estimators import ( |
24 | 25 | GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet, |
25 | 26 | MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator) |
26 | | -from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox |
27 | | -from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE |
| 27 | +from skglm.datafits import ( |
| 28 | + Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma |
| 29 | +) |
| 30 | +from skglm.penalties import ( |
| 31 | + L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE |
| 32 | +) |
28 | 33 | from skglm.solvers import AndersonCD, FISTA, ProxNewton |
29 | 34 |
|
30 | 35 | n_samples = 50 |
@@ -629,5 +634,45 @@ def test_SLOPE_printing(): |
629 | 634 | assert isinstance(res, str) |
630 | 635 |
|
631 | 636 |
|
def test_poisson_predictions_match_sklearn():
    """Check skglm's Poisson GLM against sklearn's ``PoissonRegressor``.

    Both models are fit without regularization (``alpha=0``) because the two
    libraries scale the regularization strength differently, so regularized
    solutions are not directly comparable. With no penalty, the unpenalized
    maximum-likelihood fits should agree up to solver tolerance.
    """
    # Use a local RandomState instead of np.random.seed(42) so the test does
    # not mutate the global RNG state shared with other tests.
    # RandomState(42) reproduces the exact same draws as the global-seed form.
    rng = np.random.RandomState(42)
    X = rng.randn(20, 5)
    y = rng.poisson(np.exp(X.sum(axis=1) * 0.1))

    # Reference fit: sklearn's Poisson GLM, unregularized, tight tolerance.
    sklearn_pred = PoissonRegressor(
        alpha=0.0, max_iter=10_000, tol=1e-8).fit(X, y).predict(X)

    # skglm equivalent: Poisson datafit with a zero-strength penalty.
    skglm_pred = GeneralizedLinearEstimator(
        datafit=Poisson(),
        penalty=L1_plus_L2(0.0, l1_ratio=0.0),
        solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
    ).fit(X, y).predict(X)

    np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)

| 656 | + |
def test_gamma_predictions_match_sklearn():
    """Check skglm's Gamma GLM against sklearn's ``GammaRegressor``.

    Both models are fit without regularization (``alpha=0``) because the two
    libraries scale the regularization strength differently, so regularized
    solutions are not directly comparable. With no penalty, the unpenalized
    maximum-likelihood fits should agree up to solver tolerance.
    """
    # Use a local RandomState instead of np.random.seed(42) so the test does
    # not mutate the global RNG state shared with other tests.
    # RandomState(42) reproduces the exact same draws as the global-seed form.
    rng = np.random.RandomState(42)
    X = rng.randn(20, 5)
    # Gamma targets: shape 2.0, per-sample scale exp(X @ 1 * 0.1) > 0.
    y = rng.gamma(2.0, np.exp(X.sum(axis=1) * 0.1))

    # Reference fit: sklearn's Gamma GLM, unregularized, tight tolerance.
    sklearn_pred = GammaRegressor(
        alpha=0.0, max_iter=10_000, tol=1e-8).fit(X, y).predict(X)

    # skglm equivalent: Gamma datafit with a zero-strength penalty.
    skglm_pred = GeneralizedLinearEstimator(
        datafit=Gamma(),
        penalty=L1_plus_L2(0.0, l1_ratio=0.0),
        solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
    ).fit(X, y).predict(X)

    np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)

| 676 | + |
632 | 677 | if __name__ == "__main__": |
633 | 678 | pass |
0 commit comments