Skip to content

Commit 1ff4fed

Browse files
author
Christian Jorgensen
committed
Fixing ptz and pxz for multioutput
1 parent c20262c commit 1ff4fed

File tree

4 files changed

+43
-40
lines changed

4 files changed

+43
-40
lines changed

examples/pcovc/PCovC_multioutput.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
plt.rcParams["image.cmap"] = "tab10"
2323
plt.rcParams["scatter.edgecolors"] = "k"
2424
# %%
25-
#
26-
#
25+
# For this, we will use the `sklearn.datasets.load_digits` dataset.
26+
# This dataset contains 8x8 images of handwritten digits (0-9).
2727
X, y = load_digits(return_X_y=True)
2828
x_scaler = StandardScaler()
2929
X_scaled = StandardScaler().fit_transform(X)
@@ -127,3 +127,5 @@
127127
axs[1, 0].set_ylabel("PCovC")
128128
fig.colorbar(scat_pca, ax=axs, orientation="horizontal")
129129
fig.suptitle("Multiclass-Multilabel PCovC")
130+
131+
# %%

src/skmatter/decomposition/_kernel_pcovc.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
3939
4040
where :math:`\alpha` is a mixing parameter,
4141
:math:`\mathbf{K}` is the input kernel of shape :math:`(n_{samples}, n_{samples})`
42-
and :math:`\mathbf{Z}` is a matrix of class confidence scores of shape
43-
:math:`(n_{samples}, n_{classes})`
42+
and :math:`\mathbf{Z}` is a tensor of class confidence scores of shape
43+
:math:`(n_{samples}, n_{classes}, n_{labels})`
4444
4545
Parameters
4646
----------
@@ -82,10 +82,10 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
8282
- ``sklearn.linear_model.LogisticRegressionCV()``
8383
- ``sklearn.svm.LinearSVC()``
8484
- ``sklearn.discriminant_analysis.LinearDiscriminantAnalysis()``
85-
- ``sklearn.multioutput.MultiOutputClassifier()``
85+
- ``sklearn.linear_model.Perceptron()``
8686
- ``sklearn.linear_model.RidgeClassifier()``
8787
- ``sklearn.linear_model.RidgeClassifierCV()``
88-
- ``sklearn.linear_model.Perceptron()``
88+
- ``sklearn.multioutput.MultiOutputClassifier()``
8989
9090
If a pre-fitted classifier
9191
is provided, it is used to compute :math:`{\mathbf{Z}}`.
@@ -167,13 +167,15 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
167167
the projector, or weights, from the input kernel :math:`\mathbf{K}`
168168
to the latent-space projection :math:`\mathbf{T}`
169169
170-
pkz_: numpy.ndarray of size :math:`({n_{samples}, })` or :math:`({n_{samples}, n_{classes}})`
171-
the projector, or weights, from the input kernel :math:`\mathbf{K}`
172-
to the class confidence scores :math:`\mathbf{Z}`
170+
pkz_ : ndarray of size :math:`({n_{features}, {n_{classes}}})`, or list of
171+
ndarrays of size :math:`({n_{features}, {n_{classes_i}}})` for a dataset
172+
with :math: `i` labels.
173+
the projector, or weights, from the input space :math:`\mathbf{X}`
174+
to the class confidence scores :math:`\mathbf{Z}`.
173175
174-
ptz_: numpy.ndarray of size :math:`({n_{components}, })` or :math:`({n_{components}, n_{classes}})`
175-
the projector, or weights, from the latent-space projection
176-
:math:`\mathbf{T}` to the class confidence scores :math:`\mathbf{Z}`
176+
ptz_ : ndarray of size :math:`({n_{components}, {n_{classes}}})`, or list of
177+
ndarrays of size :math:`({n_{components}, {n_{classes_i}}})` for a dataset
178+
with :math: `i` labels.
177179
178180
ptx_: numpy.ndarray of size :math:`({n_{components}, n_{features}})`
179181
the projector, or weights, from the latent-space projection
@@ -271,13 +273,16 @@ def fit(self, X, Y, W=None):
271273
scaled to have unit variance, otherwise :math:`\mathbf{X}` should
272274
be scaled so that each feature has a variance of 1 / n_features.
273275
274-
Y : numpy.ndarray, shape (n_samples,)
275-
Training data, where n_samples is the number of samples.
276+
Y : numpy.ndarray, shape (n_samples,) or (n_samples, n_outputs)
277+
Training data, where n_samples is the number of samples and
278+
n_outputs is the number of outputs.
276279
277-
W : numpy.ndarray, shape (n_features, n_classes)
280+
W : numpy.ndarray, shape (n_features, n_classes) or (n_features, )
278281
Classification weights, optional when classifier = `precomputed`. If
279282
not passed, it is assumed that the weights will be taken from a
280-
linear classifier fit between K and Y.
283+
linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`.
284+
In the multioutput case, use
285+
`` W = np.hstack([est_.coef_.T for est_ in classifier.estimators_])``.
281286
282287
Returns
283288
-------
@@ -355,7 +360,7 @@ def fit(self, X, Y, W=None):
355360
else:
356361
W = _.coef_.T
357362

358-
else:
363+
elif W is None:
359364
self.z_classifier_ = check_cl_fit(classifier, K, Y)
360365
if multioutput:
361366
W = np.hstack([est_.coef_.T for est_ in self.z_classifier_.estimators_])
@@ -374,10 +379,8 @@ def fit(self, X, Y, W=None):
374379
self.classifier_ = clone(classifier).fit(K @ self.pkt_, Y)
375380

376381
if multioutput:
377-
self.ptz_ = np.hstack(
378-
[est_.coef_.T for est_ in self.classifier_.estimators_]
379-
)
380-
self.pkz_ = self.pkt_ @ self.ptz_
382+
self.ptz_ = [est_.coef_.T for est_ in self.classifier_.estimators_]
383+
self.pkz_ = [self.pkt_ @ ptz for ptz in self.ptz_]
381384
else:
382385
self.ptz_ = self.classifier_.coef_.T
383386
self.pkz_ = self.pkt_ @ self.ptz_

src/skmatter/decomposition/_pcovc.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
)
1212
from sklearn.linear_model._base import LinearClassifierMixin
1313

14-
from sklearn.base import MultiOutputMixin
1514
from sklearn.multioutput import MultiOutputClassifier
1615
from sklearn.svm import LinearSVC
1716
from sklearn.utils import check_array
@@ -36,8 +35,8 @@ class PCovC(LinearClassifierMixin, _BasePCov):
3635
(1 - \alpha) \mathbf{Z}\mathbf{Z}^T
3736
3837
where :math:`\alpha` is a mixing parameter, :math:`\mathbf{X}` is an input matrix of shape
39-
:math:`(n_{samples}, n_{features})`, and :math:`\mathbf{Z}` is a matrix of class confidence scores
40-
of shape :math:`(n_{samples}, n_{classes})`. For :math:`(n_{samples} < n_{features})`,
38+
:math:`(n_{samples}, n_{features})`, and :math:`\mathbf{Z}` is a tensor of class confidence scores
39+
of shape :math:`(n_{samples}, n_{classes}, n_{labels})`. For :math:`(n_{samples} < n_{features})`,
4140
this can be more efficiently computed using the eigendecomposition of a modified covariance matrix
4241
:math:`\mathbf{\tilde{C}}`
4342
@@ -112,10 +111,10 @@ class PCovC(LinearClassifierMixin, _BasePCov):
112111
- ``sklearn.linear_model.LogisticRegressionCV()``
113112
- ``sklearn.svm.LinearSVC()``
114113
- ``sklearn.discriminant_analysis.LinearDiscriminantAnalysis()``
115-
- ``sklearn.multioutput.MultiOutputClassifier()``
114+
- ``sklearn.linear_model.Perceptron()``
116115
- ``sklearn.linear_model.RidgeClassifier()``
117116
- ``sklearn.linear_model.RidgeClassifierCV()``
118-
- ``sklearn.linear_model.Perceptron()``
117+
- ``sklearn.multioutput.MultiOutputClassifier()``
119118
120119
If a pre-fitted classifier
121120
is provided, it is used to compute :math:`{\mathbf{Z}}`.
@@ -175,11 +174,15 @@ class PCovC(LinearClassifierMixin, _BasePCov):
175174
the projector, or weights, from the input space :math:`\mathbf{X}`
176175
to the latent-space projection :math:`\mathbf{T}`
177176
178-
pxz_ : ndarray of size :math:`({n_{features}, })`, :math:`({n_{features}, n_{classes}})`
177+
pxz_ : ndarray of size :math:`({n_{features}, {n_{classes}}})`, or list of
178+
ndarrays of size :math:`({n_{features}, {n_{classes_i}}})` for a dataset
179+
with :math: `i` labels.
179180
the projector, or weights, from the input space :math:`\mathbf{X}`
180181
to the class confidence scores :math:`\mathbf{Z}`.
181182
182-
ptz_ : ndarray of size :math:`({n_{components}, })`, :math:`({n_{components}, n_{classes}})`
183+
ptz_ : ndarray of size :math:`({n_{components}, {n_{classes}}})`, or list of
184+
ndarrays of size :math:`({n_{components}, {n_{classes_i}}})` for a dataset
185+
with :math: `i` labels.
183186
the projector, or weights, from from the latent-space projection
184187
:math:`\mathbf{T}` to the class confidence scores :math:`\mathbf{Z}`.
185188
@@ -267,7 +270,7 @@ def fit(self, X, Y, W=None):
267270
Classification weights, optional when classifier is ``precomputed``. If
268271
not passed, it is assumed that the weights will be taken from a
269272
linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`.
270-
In the multioutput case,
273+
In the multioutput case, use
271274
`` W = np.hstack([est_.coef_.T for est_ in classifier.estimators_])``.
272275
"""
273276
X, Y = validate_data(self, X, Y, multi_output=True, y_numeric=False)
@@ -329,15 +332,15 @@ def fit(self, X, Y, W=None):
329332
W = np.hstack([_.coef_.T for _ in _.estimators_])
330333
else:
331334
W = _.coef_.T
332-
else:
335+
elif W is None:
333336
self.z_classifier_ = check_cl_fit(classifier, X, Y)
334337
if multioutput:
335338
W = np.hstack([est_.coef_.T for est_ in self.z_classifier_.estimators_])
336339
else:
337340
W = self.z_classifier_.coef_.T
338341

339342
Z = X @ W
340-
343+
341344
if self.space_ == "feature":
342345
self._fit_feature_space(X, Y, Z)
343346
else:
@@ -348,19 +351,12 @@ def fit(self, X, Y, W=None):
348351
self.classifier_ = clone(classifier).fit(X @ self.pxt_, Y)
349352

350353
if multioutput:
351-
self.ptz_ = np.hstack(
352-
[est_.coef_.T for est_ in self.classifier_.estimators_]
353-
)
354-
# print(f"pxt {self.pxt_.shape}")
355-
# print(f"ptz {self.ptz_.shape}")
356-
self.pxz_ = self.pxt_ @ self.ptz_
357-
# print(f"pxz {self.pxz_.shape}")
354+
self.ptz_ = [est_.coef_.T for est_ in self.classifier_.estimators_]
355+
self.pxz_ = [self.pxt_ @ ptz for ptz in self.ptz_]
358356
else:
359357
self.ptz_ = self.classifier_.coef_.T
360-
# print(self.ptz_.shape)
361358
self.pxz_ = self.pxt_ @ self.ptz_
362359

363-
# print(self.ptz_.shape)
364360
if not multioutput and type_of_target(Y) == "binary":
365361
self.pxz_ = self.pxz_.reshape(
366362
X.shape[1],
@@ -531,3 +527,4 @@ def score(self, X, y, sample_weight=None):
531527

532528
# Inherit the docstring from scikit-learn
533529
score.__doc__ = LinearClassifierMixin.score.__doc__
530+

tests/test_pcovc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def test_precomputed_multioutput(self):
606606

607607
classifier.fit(self.X, Y_double)
608608
W = np.hstack([est_.coef_.T for est_ in classifier.estimators_])
609+
print(W.shape)
609610
pcovc1 = self.model(mixing=0.5, classifier="precomputed", n_components=1)
610611
pcovc1.fit(self.X, Y_double, W)
611612
t1 = pcovc1.transform(self.X)

0 commit comments

Comments
 (0)