Skip to content

Commit 2479119

Browse files
add inverse link function
1 parent 9df9d80 commit 2479119

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

skglm/datafits/group.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ class PoissonGroup(Poisson):
184184
def __init__(self, grp_ptr, grp_indices):
185185
self.grp_ptr, self.grp_indices = grp_ptr, grp_indices
186186

187+
@staticmethod
188+
def inverse_link(x):
189+
return np.exp(x)
190+
187191
def get_spec(self):
188192
return (
189193
('grp_ptr', int32[:]),

skglm/datafits/single_task.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,10 @@ class Poisson(BaseDatafit):
590590
def __init__(self):
591591
pass
592592

593+
@staticmethod
594+
def inverse_link(x):
595+
return np.exp(x)
596+
593597
def get_spec(self):
594598
pass
595599

@@ -664,6 +668,10 @@ class Gamma(BaseDatafit):
664668
def __init__(self):
665669
pass
666670

671+
@staticmethod
672+
def inverse_link(x):
673+
return np.exp(x)
674+
667675
def get_spec(self):
668676
pass
669677

skglm/estimators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from skglm.solvers import AndersonCD, MultiTaskBCD, GroupBCD, ProxNewton, LBFGS
2121
from skglm.datafits import (
22-
Cox, Quadratic, Logistic, Poisson, PoissonGroup, Gamma, QuadraticSVC,
22+
Cox, Quadratic, Logistic, QuadraticSVC,
2323
QuadraticMultiTask, QuadraticGroup,)
2424
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
2525
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
@@ -266,8 +266,8 @@ def predict(self, X):
266266
else:
267267
indices = scores.argmax(axis=1)
268268
return self.classes_[indices]
269-
elif isinstance(self.datafit, (Poisson, PoissonGroup, Gamma)):
270-
return np.exp(self._decision_function(X))
269+
elif hasattr(self.datafit, "inverse_link"):
270+
return self.datafit.inverse_link(self._decision_function(X))
271271
else:
272272
return self._decision_function(X)
273273

0 commit comments

Comments
 (0)