Skip to content

Commit 838c116

Browse files
committed
Adding multioutput support for KPCovC
1 parent bb7147c commit 838c116

File tree

6 files changed

+237
-148
lines changed

6 files changed

+237
-148
lines changed

src/skmatter/decomposition/_kernel_pcovc.py

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from sklearn import clone
4+
from sklearn.multioutput import MultiOutputClassifier
45
from sklearn.svm import LinearSVC
56
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
67
from sklearn.linear_model import (
@@ -24,7 +25,7 @@
2425
class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
2526
r"""Kernel Principal Covariates Classification (KPCovC).
2627
27-
KPCovC is a modification on the PrincipalCovariates Classification
28+
KPCovC is a modification on the Principal Covariates Classification
2829
proposed in [Jorgensen2025]_. It determines a latent-space projection
2930
:math:`\mathbf{T}` which minimizes a combined loss in supervised and unsupervised
3031
tasks in the reproducing kernel Hilbert space (RKHS).
@@ -52,6 +53,9 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
5253
5354
n_components == n_samples
5455
56+
n_outputs : int
57+
The number of outputs when ``fit`` is performed.
58+
5559
svd_solver : {'auto', 'full', 'arpack', 'randomized'}, default='auto'
5660
If auto :
5761
The solver is selected by a default policy based on `X.shape` and
@@ -78,13 +82,14 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
7882
- ``sklearn.linear_model.LogisticRegressionCV()``
7983
- ``sklearn.svm.LinearSVC()``
8084
- ``sklearn.discriminant_analysis.LinearDiscriminantAnalysis()``
85+
- ``sklearn.multioutput.MultiOutputClassifier()``
8186
- ``sklearn.linear_model.RidgeClassifier()``
8287
- ``sklearn.linear_model.RidgeClassifierCV()``
8388
- ``sklearn.linear_model.Perceptron()``
8489
8590
If a pre-fitted classifier is provided, it is used to compute :math:`{\mathbf{Z}}`.
86-
If None, ``sklearn.linear_model.LogisticRegression()``
87-
is used as the classifier.
91+
If None and ``n_outputs < 2``, ``sklearn.linear_model.LogisticRegression()`` is used.
92+
If None and ``n_outputs >= 2``, ``sklearn.multioutput.MultiOutputClassifier()`` is used.
8893
8994
kernel : {"linear", "poly", "rbf", "sigmoid", "precomputed"} or callable, default="linear"
9095
Kernel.
@@ -132,6 +137,9 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
132137
133138
Attributes
134139
----------
140+
n_outputs : int
141+
The number of outputs when ``fit`` is performed.
142+
135143
classifier : estimator object
136144
The linear classifier passed for fitting. If pre-fitted, it is assumed
137145
to be fit on a precomputed kernel :math:`\mathbf{K}` and :math:`\mathbf{Y}`.
@@ -268,9 +276,11 @@ def fit(self, X, Y, W=None):
268276
self: object
269277
Returns the instance itself.
270278
"""
271-
X, Y = validate_data(self, X, Y, y_numeric=False)
279+
X, Y = validate_data(self, X, Y, multi_output=True, y_numeric=False)
280+
272281
check_classification_targets(Y)
273282
self.classes_ = np.unique(Y)
283+
self.n_outputs = 1 if Y.ndim == 1 else Y.shape[1]
274284

275285
super().fit(X)
276286

@@ -285,6 +295,7 @@ def fit(self, X, Y, W=None):
285295
LogisticRegressionCV,
286296
LinearSVC,
287297
LinearDiscriminantAnalysis,
298+
MultiOutputClassifier,
288299
RidgeClassifier,
289300
RidgeClassifierCV,
290301
SGDClassifier,
@@ -300,28 +311,37 @@ def fit(self, X, Y, W=None):
300311
", or `precomputed`"
301312
)
302313

303-
if self.classifier != "precomputed":
304-
if self.classifier is None:
305-
classifier = LogisticRegression()
306-
else:
307-
classifier = self.classifier
314+
multioutput = self.n_outputs != 1
315+
precomputed = self.classifier == "precomputed"
308316

309-
# for convergence warnings
310-
if hasattr(classifier, "max_iter") and (
311-
classifier.max_iter is None or classifier.max_iter < 500
312-
):
313-
classifier.max_iter = 500
317+
if self.classifier is None or precomputed:
318+
# used as the default classifier for subsequent computations
319+
classifier = (
320+
MultiOutputClassifier(LogisticRegression())
321+
if multioutput
322+
else LogisticRegression()
323+
)
324+
else:
325+
classifier = self.classifier
314326

315-
# Check if classifier is fitted; if not, fit with precomputed K
316-
self.z_classifier_ = check_cl_fit(classifier, K, Y)
317-
W = self.z_classifier_.coef_.T.reshape(K.shape[1], -1)
327+
if hasattr(classifier, "max_iter") and (
328+
classifier.max_iter is None or classifier.max_iter < 500
329+
):
330+
classifier.max_iter = 500
331+
332+
if precomputed and W is None:
333+
fitted = clone(classifier).fit(K, Y)
334+
if multioutput:
335+
W = np.hstack([est_.coef_.T for est_ in fitted.estimators_])
336+
else:
337+
W = fitted.coef_.T
318338

319339
else:
320-
# If precomputed, use default classifier to predict Y from T
321-
classifier = LogisticRegression(max_iter=500)
322-
if W is None:
323-
W = LogisticRegression().fit(K, Y).coef_.T
324-
W = W.reshape(K.shape[1], -1)
340+
self.z_classifier_ = check_cl_fit(classifier, K, Y)
341+
if multioutput:
342+
W = np.hstack([est_.coef_.T for est_ in self.z_classifier_.estimators_])
343+
else:
344+
W = self.z_classifier_.coef_.T
325345

326346
Z = K @ W
327347

@@ -334,10 +354,16 @@ def fit(self, X, Y, W=None):
334354

335355
self.classifier_ = clone(classifier).fit(K @ self.pkt_, Y)
336356

337-
self.ptz_ = self.classifier_.coef_.T
338-
self.pkz_ = self.pkt_ @ self.ptz_
357+
if multioutput:
358+
self.ptz_ = np.hstack(
359+
[est_.coef_.T for est_ in self.classifier_.estimators_]
360+
)
361+
self.pkz_ = self.pkt_ @ self.ptz_
362+
else:
363+
self.ptz_ = self.classifier_.coef_.T
364+
self.pkz_ = self.pkt_ @ self.ptz_
339365

340-
if len(Y.shape) == 1 and type_of_target(Y) == "binary":
366+
if not multioutput and type_of_target(Y) == "binary":
341367
self.pkz_ = self.pkz_.reshape(
342368
K.shape[1],
343369
)
@@ -346,6 +372,7 @@ def fit(self, X, Y, W=None):
346372
)
347373

348374
self.components_ = self.pkt_.T # for sklearn compatibility
375+
349376
return self
350377

351378
def predict(self, X=None, T=None):
@@ -425,9 +452,12 @@ def decision_function(self, X=None, T=None):
425452
426453
Returns
427454
-------
428-
Z : numpy.ndarray, shape (n_samples,) or (n_samples, n_classes)
455+
Z : numpy.ndarray, shape (n_samples,) or (n_samples, n_classes), or a list of \
456+
n_outputs such arrays if n_outputs > 1
429457
Confidence scores. For binary classification, has shape `(n_samples,)`,
430-
for multiclass classification, has shape `(n_samples, n_classes)`
458+
for multiclass classification, has shape `(n_samples, n_classes)`.
459+
If n_outputs > 1, the list can contain arrays with differing shapes
460+
depending on the number of classes in each output of Y.
431461
"""
432462
check_is_fitted(self, attributes=["pkz_", "ptz_"])
433463

@@ -440,9 +470,21 @@ def decision_function(self, X=None, T=None):
440470
if self.center:
441471
K = self.centerer_.transform(K)
442472

443-
# Or self.classifier_.decision_function(K @ self.pxt_)
444-
return K @ self.pkz_ + self.classifier_.intercept_
473+
if self.n_outputs == 1:
474+
# Or self.classifier_.decision_function(K @ self.pkt_)
475+
return K @ self.pkz_ + self.classifier_.intercept_
476+
else:
477+
return [
478+
est_.decision_function(K @ self.pkt_)
479+
for est_ in self.classifier_.estimators_
480+
]
445481

446482
else:
447483
T = check_array(T)
448-
return T @ self.ptz_ + self.classifier_.intercept_
484+
485+
if self.n_outputs == 1:
486+
return T @ self.ptz_ + self.classifier_.intercept_
487+
else:
488+
return [
489+
est_.decision_function(T) for est_ in self.classifier_.estimators_
490+
]

0 commit comments

Comments
 (0)