2 changes: 2 additions & 0 deletions doc/changes/0.6.rst
@@ -2,3 +2,5 @@
 
 Version 0.6 (in progress)
 -------------------------
+
+- :class:`skglm.solvers.LBFGS` now supports fitting an intercept with the ``fit_intercept`` parameter.
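
Note: a minimal usage sketch of the new option, based on the API exercised in the updated test further down (the toy data here is illustrative, not part of the PR):

    import numpy as np
    from skglm.datafits import Logistic
    from skglm.penalties import L2
    from skglm.solvers import LBFGS

    # illustrative toy data; Logistic expects +/-1 labels
    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 50))
    y = np.where(rng.standard_normal(100) > 0, 1.0, -1.0)

    # with fit_intercept=True the returned vector has one extra entry:
    # the intercept is stored last
    w, *_ = LBFGS(tol=1e-12, fit_intercept=True).solve(X, y, Logistic(), L2(1.0))
    coef, intercept = w[:-1], w[-1]
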
50 changes: 35 additions & 15 deletions skglm/solvers/lbfgs.py
@@ -24,16 +24,21 @@ class LBFGS(BaseSolver):
     tol : float, default 1e-4
         Tolerance for convergence.
 
+    fit_intercept : bool, default False
+        Whether or not to fit an intercept.
+
     verbose : bool, default False
         Amount of verbosity. 0/False is silent.
     """
 
     _datafit_required_attr = ("gradient",)
     _penalty_required_attr = ("gradient",)
 
-    def __init__(self, max_iter=50, tol=1e-4, verbose=False):
+    def __init__(self, max_iter=50, tol=1e-4, fit_intercept=False, verbose=False):
         self.max_iter = max_iter
         self.tol = tol
+        self.fit_intercept = fit_intercept
         self.warm_start = False
         self.verbose = verbose
@@ -46,25 +51,40 @@ def _solve(self, X, y, datafit, penalty, w_init=None, Xw_init=None):
         datafit.initialize(X, y)
 
         def objective(w):
-            Xw = X @ w
-            datafit_value = datafit.value(y, w, Xw)
-            penalty_value = penalty.value(w)
-
+            w_features = w[:n_features]
+            Xw = X @ w_features
+            if self.fit_intercept:
+                Xw += w[-1]
+            datafit_value = datafit.value(y, w_features, Xw)
+            penalty_value = penalty.value(w_features)
             return datafit_value + penalty_value
 
         def d_jac(w):
-            Xw = X @ w
+            w_features = w[:n_features]
+            Xw = X @ w_features
+            if self.fit_intercept:
+                Xw += w[-1]
             datafit_grad = datafit.gradient(X, y, Xw)
-            penalty_grad = penalty.gradient(w)
-
-            return datafit_grad + penalty_grad
+            penalty_grad = penalty.gradient(w_features)
+            if self.fit_intercept:
+                intercept_grad = datafit.raw_grad(y, Xw).sum()
+                return np.concatenate([datafit_grad + penalty_grad, [intercept_grad]])
+            else:
+                return datafit_grad + penalty_grad
 
         def s_jac(w):
-            Xw = X @ w
-            datafit_grad = datafit.gradient_sparse(X.data, X.indptr, X.indices, y, Xw)
-            penalty_grad = penalty.gradient(w)
-
-            return datafit_grad + penalty_grad
+            w_features = w[:n_features]
+            Xw = X @ w_features
+            if self.fit_intercept:
+                Xw += w[-1]
+            datafit_grad = datafit.gradient_sparse(
+                X.data, X.indptr, X.indices, y, Xw)
+            penalty_grad = penalty.gradient(w_features)
+            if self.fit_intercept:
+                intercept_grad = datafit.raw_grad(y, Xw).sum()
+                return np.concatenate([datafit_grad + penalty_grad, [intercept_grad]])
+            else:
+                return datafit_grad + penalty_grad
 
         def callback_post_iter(w_k):
             # save p_obj
@@ -81,7 +101,7 @@ def callback_post_iter(w_k):
             )
 
         n_features = X.shape[1]
-        w = np.zeros(n_features) if w_init is None else w_init
+        w = np.zeros(n_features + self.fit_intercept) if w_init is None else w_init
         jac = s_jac if issparse(X) else d_jac
         p_objs_out = []
 
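Note on the intercept gradient: for an objective F(w, b) = datafit(y, Xw + b·1) + penalty(w), the chain rule gives ∂F/∂b = Σ_i ∂datafit/∂(Xw)_i, i.e. the sum of the per-sample derivatives that skglm datafits expose as `raw_grad`; that is what `datafit.raw_grad(y, Xw).sum()` computes above. (Also, `np.zeros(n_features + self.fit_intercept)` relies on Python's bool-to-int coercion to add one extra slot for the intercept.) A standalone finite-difference sanity check of this identity, assuming the Logistic datafit (illustrative sketch, not part of the PR):

    import numpy as np
    from skglm.datafits import Logistic

    rng = np.random.default_rng(0)
    X = rng.standard_normal((30, 5))
    y = np.where(rng.standard_normal(30) > 0, 1.0, -1.0)
    w, b, eps = rng.standard_normal(5), 0.3, 1e-6

    datafit = Logistic()
    Xw = X @ w + b

    # analytic intercept gradient: sum of per-sample raw gradients
    grad_b = datafit.raw_grad(y, Xw).sum()

    # central finite difference of b -> datafit.value(y, w, X @ w + b)
    fd = (datafit.value(y, w, X @ w + (b + eps))
          - datafit.value(y, w, X @ w + (b - eps))) / (2 * eps)

    np.testing.assert_allclose(grad_b, fd, rtol=1e-4)
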
13 changes: 9 additions & 4 deletions skglm/tests/test_lbfgs_solver.py
@@ -12,7 +12,8 @@
 
 
 @pytest.mark.parametrize("X_sparse", [True, False])
-def test_lbfgs_L2_logreg(X_sparse):
+@pytest.mark.parametrize("fit_intercept", [True, False])
+def test_lbfgs_L2_logreg(X_sparse, fit_intercept):
     reg = 1.0
     X_density = 1.0 if not X_sparse else 0.5
     n_samples, n_features = 100, 50
@@ -28,17 +29,21 @@ def test_lbfgs_L2_logreg(X_sparse):
     # fit L-BFGS
     datafit = Logistic()
     penalty = L2(reg)
-    w, *_ = LBFGS(tol=1e-12).solve(X, y, datafit, penalty)
+    w, *_ = LBFGS(tol=1e-12, fit_intercept=fit_intercept).solve(X, y, datafit, penalty)
 
     # fit scikit learn
     estimator = LogisticRegression(
         penalty="l2",
         C=1 / (n_samples * reg),
-        fit_intercept=False,
+        fit_intercept=fit_intercept,
         tol=1e-12,
     ).fit(X, y)
 
-    np.testing.assert_allclose(w, estimator.coef_.flatten(), atol=1e-5)
+    if fit_intercept:
+        np.testing.assert_allclose(w[:-1], estimator.coef_.flatten(), atol=1e-5)
+        np.testing.assert_allclose(w[-1], estimator.intercept_[0], atol=1e-5)
+    else:
+        np.testing.assert_allclose(w, estimator.coef_.flatten(), atol=1e-5)
 
 
 @pytest.mark.parametrize("use_efron", [True, False])
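
Note on the regularization mapping in this test: skglm's Logistic datafit averages the log-loss over samples while scikit-learn's LogisticRegression sums it, which is why the test sets C = 1 / (n_samples * reg). A small numerical check of that equivalence, under my reading of the two objectives (illustrative, not part of the PR):

    import numpy as np

    # skglm   minimizes  mean(log1p(exp(-y * Xw))) + 0.5 * reg * ||w||^2
    # sklearn minimizes  C * sum(log1p(exp(-y * Xw))) + 0.5 * ||w||^2
    # scaling the skglm objective by n_samples makes the two proportional
    # exactly when C = 1 / (n_samples * reg)
    rng = np.random.default_rng(0)
    n_samples, reg = 100, 1.0
    C = 1 / (n_samples * reg)
    X = rng.standard_normal((n_samples, 5))
    y = np.where(rng.standard_normal(n_samples) > 0, 1.0, -1.0)
    w = rng.standard_normal(5)

    loss = np.log1p(np.exp(-y * (X @ w)))
    skglm_obj = loss.mean() + 0.5 * reg * (w @ w)
    sklearn_obj = C * loss.sum() + 0.5 * (w @ w)
    np.testing.assert_allclose(n_samples * skglm_obj, n_samples * reg * sklearn_obj)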