Skip to content

Commit 9df9d80

Browse files
Follow-up to PR 321: add unit tests checking that Poisson and Gamma predictions match sklearn
1 parent 296adf0 commit 9df9d80

File tree

2 files changed

+49
-4
lines changed

2 files changed

+49
-4
lines changed

skglm/estimators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS
2121
from skglm.datafits import (
22-
Cox, Quadratic, Logistic, Poisson, PoissonGroup, QuadraticSVC,
22+
Cox, Quadratic, Logistic, Poisson, PoissonGroup, Gamma, QuadraticSVC,
2323
QuadraticMultiTask, QuadraticGroup,)
2424
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
2525
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
@@ -266,7 +266,7 @@ def predict(self, X):
266266
else:
267267
indices = scores.argmax(axis=1)
268268
return self.classes_[indices]
269-
elif isinstance(self.datafit, (Poisson, PoissonGroup)):
269+
elif isinstance(self.datafit, (Poisson, PoissonGroup, Gamma)):
270270
return np.exp(self._decision_function(X))
271271
else:
272272
return self._decision_function(X)

skglm/tests/test_estimators.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from sklearn.linear_model import ElasticNet as ElasticNet_sklearn
1515
from sklearn.linear_model import LogisticRegression as LogReg_sklearn
1616
from sklearn.linear_model import MultiTaskLasso as MultiTaskLasso_sklearn
17+
from sklearn.linear_model import PoissonRegressor, GammaRegressor
1718
from sklearn.model_selection import GridSearchCV
1819
from sklearn.svm import LinearSVC as LinearSVC_sklearn
1920
from sklearn.utils.estimator_checks import check_estimator
@@ -23,8 +24,12 @@
2324
from skglm.estimators import (
2425
GeneralizedLinearEstimator, Lasso, MultiTaskLasso, WeightedLasso, ElasticNet,
2526
MCPRegression, SparseLogisticRegression, LinearSVC, GroupLasso, CoxEstimator)
26-
from skglm.datafits import Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox
27-
from skglm.penalties import L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
27+
from skglm.datafits import (
28+
Logistic, Quadratic, QuadraticSVC, QuadraticMultiTask, Cox, Poisson, Gamma
29+
)
30+
from skglm.penalties import (
31+
L1, IndicatorBox, L1_plus_L2, MCPenalty, WeightedL1, SLOPE
32+
)
2833
from skglm.solvers import AndersonCD, FISTA, ProxNewton
2934

3035
n_samples = 50
@@ -629,5 +634,45 @@ def test_SLOPE_printing():
629634
assert isinstance(res, str)
630635

631636

637+
def test_poisson_predictions_match_sklearn():
638+
"""Test that skglm Poisson estimator predictions match sklearn PoissonRegressor."""
639+
np.random.seed(42)
640+
X = np.random.randn(20, 5)
641+
y = np.random.poisson(np.exp(X.sum(axis=1) * 0.1))
642+
643+
# Fit sklearn PoissonRegressor with alpha=0.0 (no regularization), since sklearn and skglm scale alpha differently
644+
sklearn_pred = PoissonRegressor(
645+
alpha=0.0, max_iter=10_000, tol=1e-8).fit(X, y).predict(X)
646+
647+
# Fit skglm equivalent (no regularization)
648+
skglm_pred = GeneralizedLinearEstimator(
649+
datafit=Poisson(),
650+
penalty=L1_plus_L2(0.0, l1_ratio=0.0),
651+
solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
652+
).fit(X, y).predict(X)
653+
654+
np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)
655+
656+
657+
def test_gamma_predictions_match_sklearn():
658+
"""Test that skglm Gamma estimator predictions match sklearn GammaRegressor."""
659+
np.random.seed(42)
660+
X = np.random.randn(20, 5)
661+
y = np.random.gamma(2.0, np.exp(X.sum(axis=1) * 0.1))
662+
663+
# Fit sklearn GammaRegressor with alpha=0.0 (no regularization), since sklearn and skglm scale alpha differently
664+
sklearn_pred = GammaRegressor(
665+
alpha=0.0, max_iter=10_000, tol=1e-8).fit(X, y).predict(X)
666+
667+
# Fit skglm equivalent (no regularization)
668+
skglm_pred = GeneralizedLinearEstimator(
669+
datafit=Gamma(),
670+
penalty=L1_plus_L2(0.0, l1_ratio=0.0),
671+
solver=ProxNewton(fit_intercept=True, max_iter=10_000, tol=1e-8)
672+
).fit(X, y).predict(X)
673+
674+
np.testing.assert_allclose(sklearn_pred, skglm_pred, rtol=1e-6, atol=1e-8)
675+
676+
632677
if __name__ == "__main__":
633678
pass

0 commit comments

Comments
 (0)