6 changes: 1 addition & 5 deletions examples/pcovc/KPCovC_Comparison.py
@@ -105,7 +105,7 @@
t_train = model.fit_transform(X_train_scaled, y_train)
t_test = model.transform(X_test_scaled)

-ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_d, cmap=cm_bright, c=y_test)
+ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_p, cmap=cm_bright, c=y_test)
ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train)

ax.set_title(models[model])
@@ -197,20 +197,16 @@
models = {
LogisticRegressionCV(random_state=random_state): {
"kernel_params": {"kernel": "rbf", "gamma": 12},
"title": "Logistic Regression",
},
RidgeClassifierCV(): {
"kernel_params": {"kernel": "rbf", "gamma": 1},
"title": "Ridge Classifier",
"eps": 0.40,
},
LinearSVC(random_state=random_state): {
"kernel_params": {"kernel": "rbf", "gamma": 15},
"title": "Support Vector Classification",
},
SGDClassifier(random_state=random_state): {
"kernel_params": {"kernel": "rbf", "gamma": 15},
"title": "SGD Classifier",
"eps": 10,
},
}
11 changes: 5 additions & 6 deletions src/skmatter/decomposition/_kernel_pcovc.py
@@ -24,8 +24,8 @@
class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
r"""Kernel Principal Covariates Classification (KPCovC).

-KPCovC is a modification on the PrincipalCovariates Classification
-proposed in [Jorgensen2025]_. It determines a latent-space projection
+KPCovC is a modification on the Principal Covariates Classification
+proposed in [Jorgensen2025]_. It determines a latent-space projection
:math:`\mathbf{T}` which minimizes a combined loss in supervised and unsupervised
tasks in the reproducing kernel Hilbert space (RKHS).

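For readers skimming the diff, a minimal usage sketch of the class being touched here (constructor arguments and method names inferred from this PR's examples and tests, so treat it as illustrative rather than canonical):

```python
# Hedged sketch: KernelPCovC usage, based on the calls visible in this diff.
import numpy as np
from sklearn.datasets import make_moons
from sklearn.preprocessing import StandardScaler
from skmatter.decomposition import KernelPCovC

X, y = make_moons(n_samples=100, noise=0.1, random_state=0)
X = StandardScaler().fit_transform(X)

kpcovc = KernelPCovC(mixing=0.5, n_components=2, kernel="rbf", gamma=1.0)
T = kpcovc.fit(X, y).transform(X)  # latent-space projection T
Z = kpcovc.decision_function(X)    # shape (n_samples,) for binary y
```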
@@ -272,7 +272,7 @@ def fit(self, X, Y, W=None):
check_classification_targets(Y)
self.classes_ = np.unique(Y)

-super().fit(X)
+super()._set_fit_params(X)

K = self._get_kernel(X)

@@ -314,14 +314,13 @@ def fit(self, X, Y, W=None):

# Check if classifier is fitted; if not, fit with precomputed K
self.z_classifier_ = check_cl_fit(classifier, K, Y)
-W = self.z_classifier_.coef_.T.reshape(K.shape[1], -1)
+W = self.z_classifier_.coef_.T

else:
# If precomputed, use default classifier to predict Y from T
classifier = LogisticRegression(max_iter=500)
if W is None:
W = LogisticRegression().fit(K, Y).coef_.T
-W = W.reshape(K.shape[1], -1)

Z = K @ W

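The reshape removals above work because scikit-learn linear classifiers already expose a 2-D `coef_`: shape `(1, n_features)` for binary problems and `(n_classes, n_features)` for multiclass, so `coef_.T` is already `(n_features, 1)` or `(n_features, n_classes)` and the extra `reshape(K.shape[1], -1)` was a no-op. A quick sketch (illustrative, not part of the PR):

```python
# Demonstrates the sklearn coef_ shape convention that makes the reshape redundant.
import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.random.default_rng(0).random((21, 5))
y_bin = np.tile([0, 1], 11)[:21]   # two classes
y_multi = np.tile([0, 1, 2], 7)    # three classes

W_bin = LogisticRegression().fit(X, y_bin).coef_.T      # (5, 1)
W_multi = LogisticRegression().fit(X, y_multi).coef_.T  # (5, 3)
assert W_bin.shape == W_bin.reshape(X.shape[1], -1).shape      # reshape is a no-op
assert W_multi.shape == W_multi.reshape(X.shape[1], -1).shape
```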
@@ -440,7 +439,7 @@ def decision_function(self, X=None, T=None):
if self.center:
K = self.centerer_.transform(K)

-# Or self.classifier_.decision_function(K @ self.pxt_)
+# Or self.classifier_.decision_function(K @ self.pkt_)
return K @ self.pkz_ + self.classifier_.intercept_

else:
2 changes: 1 addition & 1 deletion src/skmatter/decomposition/_kernel_pcovr.py
@@ -242,7 +242,7 @@ def fit(self, X, Y, W=None):
"""
X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True)

-super().fit(X)
+super()._set_fit_params(X)

K = self._get_kernel(X)

10 changes: 6 additions & 4 deletions src/skmatter/decomposition/_kpcov.py
@@ -74,10 +74,12 @@ def _get_kernel(self, X, Y=None):
X, Y, metric=self.kernel, filter_params=True, n_jobs=self.n_jobs, **params
)

-def fit(self, X):
-    """Contains the common functionality for the KPCovR and KPCovC fit methods,
-    but leaves the rest of the functionality to the subclass.
-    """
+@abstractmethod
+def fit(self, X, Y):
+    """Fit the model with X and Y. Subclasses should implement this method."""
+
+def _set_fit_params(self, X):
+    """Initializes common fit parameters for PCovR and PCovC."""
self.X_fit_ = X.copy()

if self.n_components is None:
10 changes: 6 additions & 4 deletions src/skmatter/decomposition/_pcov.py
@@ -48,10 +48,12 @@ def __init__(
self.random_state = random_state
self.whiten = whiten

-def fit(self, X):
-    """Contains the common functionality for the PCovR and PCovC fit methods,
-    but leaves the rest of the functionality to the subclass.
-    """
+@abstractmethod
+def fit(self, X, Y):
+    """Fit the model with X and Y. Subclasses should implement this method."""
+
+def _set_fit_params(self, X):
+    """Initializes common fit parameters for PCovR and PCovC."""
# saved for inverse transformations from the latent space,
# should be zero in the case that the features have been properly centered
self.mean_ = np.mean(X, axis=0)
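The refactor in these two base classes follows the standard abstract-base pattern: `fit` becomes an abstract method that each subclass implements, while the shared setup moves into a concrete `_set_fit_params` helper that subclasses call. A schematic of the pattern (not the actual skmatter code):

```python
from abc import ABC, abstractmethod
import numpy as np

class _Base(ABC):
    @abstractmethod
    def fit(self, X, Y):
        """Fit the model with X and Y. Subclasses implement this."""

    def _set_fit_params(self, X):
        # common initialization shared by all subclasses
        self.mean_ = np.mean(X, axis=0)

class Concrete(_Base):
    def fit(self, X, Y):
        super()._set_fit_params(X)          # shared setup first...
        self.coef_ = np.zeros(X.shape[1])   # ...then subclass-specific fitting
        return self
```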
19 changes: 9 additions & 10 deletions src/skmatter/decomposition/_pcovc.py
@@ -97,11 +97,11 @@ class PCovC(LinearClassifierMixin, _BasePCov):
Must be of range [0.0, infinity).

space: {'feature', 'sample', 'auto'}, default='auto'
-whether to compute the PCovC in `sample` or `feature` space
-default=`sample` when :math:`{n_{samples} < n_{features}}` and
+whether to compute the PCovC in `sample` or `feature` space.
+Default = `sample` when :math:`{n_{samples} < n_{features}}` and
Review comment (Collaborator):
This looks a bit weird to me. Maybe write a "real" sentence like

Suggested change:
-Default = `sample` when :math:`{n_{samples} < n_{features}}` and
+The default value is `sample` when :math:`{n_{samples} < n_{features}}` and

or similar.

Reply (Collaborator Author): Changed!

Review comment (Collaborator): Did you push the changes? Somehow I don't see them here.
`feature` when :math:`{n_{features} < n_{samples}}`

classifier: `estimator object` or `precomputed`, default=None
Review comment (Collaborator):
If you want to do code highlighting, use double ticks (``). Our documentation pages use rst, where inline code is highlighted with double ticks. Sorry for the confusion and this very subtle difference between markdown and rst.

Reply (Collaborator Author):
Ok, that makes sense. I think the reason I have single ticks is to match how the PCovR/KPCovR docstrings handle `precomputed`, but I will keep that in mind.

Review comment (Collaborator):
Suggested change:
-classifier: `estimator object` or `precomputed`, default=None
+classifier: ``estimator object`` or ``precomputed``, default=None
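For anyone hitting the same Markdown/rst mixup, a small illustrative docstring (hypothetical function, not from this PR) showing the difference:

```python
def example():
    """Docstring in reStructuredText.

    classifier: ``estimator object`` or ``precomputed``, default=None
        Double backticks are rst inline literals and render as code;
        single backticks, as in `precomputed`, are rst "interpreted
        text" and typically render as italics, which is why the
        Markdown habit of single ticks reads wrong in the built docs.
    """
```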

classifier for computing :math:`{\mathbf{Z}}`. The classifier should be
one of the following:

@@ -144,9 +144,9 @@ class PCovC(LinearClassifierMixin, _BasePCov):
Must be of range [0.0, infinity).

space: {'feature', 'sample', 'auto'}, default='auto'
-whether to compute the PCovC in `sample` or `feature` space
-default=`sample` when :math:`{n_{samples} < n_{features}}` and
-`feature` when :math:`{n_{features} < n_{samples}}`
+whether to compute the PCovC in `sample` or `feature` space.
+The default is = `sample` when :math:`{n_{samples} < n_{features}}`
+and `feature` when :math:`{n_{features} < n_{samples}}`
Review comment (Collaborator):
Suggested change:
-whether to compute the PCovC in `sample` or `feature` space.
-The default is = `sample` when :math:`{n_{samples} < n_{features}}`
-and `feature` when :math:`{n_{features} < n_{samples}}`
+whether to compute the PCovC in ``sample`` or ``feature`` space.
+The default is equal to ``sample`` when :math:`{n_{samples} < n_{features}}`
+and ``feature`` when :math:`{n_{features} < n_{samples}}`


n_components_ : int
The estimated number of components, which equals the parameter
@@ -160,7 +160,7 @@ class PCovC(LinearClassifierMixin, _BasePCov):
The linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`.

classifier_ : estimator object
The linear classifier fit between :math:`\mathbf{T}` and :math:`\mathbf{Y}`.

pxt_ : ndarray of size :math:`({n_{features}, n_{components}})`
the projector, or weights, from the input space :math:`\mathbf{X}`
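As a side note on the `space` heuristic documented above, a short sketch of how the default is expected to resolve (behavior inferred from the docstring, not verified against the implementation):

```python
import numpy as np
from skmatter.decomposition import PCovC

rng = np.random.default_rng(0)
X = rng.random((50, 100))      # n_samples (50) < n_features (100)
y = rng.integers(0, 2, 50)

# With space="auto", the docstring says the fit should behave as if
# space="sample" here, since n_samples < n_features.
PCovC(mixing=0.5, n_components=2, space="auto").fit(X, y)
```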
@@ -262,7 +262,7 @@ def fit(self, X, Y, W=None):
check_classification_targets(Y)
self.classes_ = np.unique(Y)

-super().fit(X)
+super()._set_fit_params(X)

compatible_classifiers = (
LogisticRegression,
@@ -291,14 +291,13 @@ def fit(self, X, Y, W=None):
classifier = self.classifier

self.z_classifier_ = check_cl_fit(classifier, X, Y)
-W = self.z_classifier_.coef_.T.reshape(X.shape[1], -1)
+W = self.z_classifier_.coef_.T

else:
# If precomputed, use default classifier to predict Y from T
classifier = LogisticRegression()
if W is None:
W = LogisticRegression().fit(X, Y).coef_.T
-W = W.reshape(X.shape[1], -1)

Z = X @ W

15 changes: 8 additions & 7 deletions src/skmatter/decomposition/_pcovr.py
@@ -88,9 +88,9 @@ class PCovR(RegressorMixin, MultiOutputMixin, _BasePCov):
range [0.0, infinity).

space: {'feature', 'sample', 'auto'}, default='auto'
-whether to compute the PCovR in `sample` or `feature` space default=`sample`
-when :math:`{n_{samples} < n_{features}}` and `feature` when
-:math:`{n_{features} < n_{samples}}`
+whether to compute the PCovR in `sample` or `feature` space.
+The default is = `sample` when :math:`{n_{samples} < n_{features}}`
+and `feature` when :math:`{n_{features} < n_{samples}}`
Review comment (Collaborator):
Suggested change:
-The default is = `sample` when :math:`{n_{samples} < n_{features}}`
-and `feature` when :math:`{n_{features} < n_{samples}}`
+The default is equal to ``sample`` when :math:`{n_{samples} < n_{features}}`
+and ``feature`` when :math:`{n_{features} < n_{samples}}`

regressor: {`Ridge`, `RidgeCV`, `LinearRegression`, `precomputed`}, default=None
regressor for computing approximated :math:`{\mathbf{\hat{Y}}}`. The regressor
@@ -126,7 +126,7 @@ class PCovR(RegressorMixin, MultiOutputMixin, _BasePCov):
Must be of range [0.0, infinity).

space: {'feature', 'sample', 'auto'}, default='auto'
-whether to compute the PCovR in `sample` or `feature` space default=`sample`
+whether to compute the PCovR in `sample` or `feature` space. Default = `sample`
Review comment (Collaborator):
Suggested change:
-whether to compute the PCovR in `sample` or `feature` space. Default = `sample`
+whether to compute the PCovR in `sample` or `feature` space. Default is equal to ``sample``
when :math:`{n_{samples} < n_{features}}` and `feature` when
:math:`{n_{features} < n_{samples}}`

@@ -227,11 +227,12 @@ def fit(self, X, Y, W=None):
regressed form of the properties, :math:`{\mathbf{\hat{Y}}}`.

W : numpy.ndarray, shape (n_features, n_properties)
-Regression weights, optional when regressor= `precomputed`. If not
+Regression weights, optional when regressor = `precomputed`. If not
Review comment (Collaborator):
Suggested change:
-Regression weights, optional when regressor = `precomputed`. If not
+Regression weights, optional when regressor is ``precomputed``. If not

passed, it is assumed that `W = np.linalg.lstsq(X, Y, self.tol)[0]`
"""
X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True)
-super().fit(X)
+
+super()._set_fit_params(X)

compatible_regressors = (LinearRegression, Ridge, RidgeCV)

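A sketch of the precomputed-regressor path the docstring above describes, with W computed the way the default is documented (`np.linalg.lstsq`); illustrative only:

```python
import numpy as np
from skmatter.decomposition import PCovR

rng = np.random.default_rng(0)
X = rng.random((40, 6))
Y = rng.random((40, 2))

# Same W the docstring says is assumed when none is passed
# (rcond stands in for self.tol here).
W = np.linalg.lstsq(X, Y, rcond=None)[0]   # (n_features, n_properties)

pcovr = PCovR(mixing=0.5, n_components=2, regressor="precomputed")
pcovr.fit(X, Y, W)
```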
@@ -414,7 +415,7 @@ def score(self, X, y, T=None):
Negative sum of the loss in reconstructing X from the latent-space
projection T and the loss in predicting Y from the latent-space projection T
"""
-X, y = validate_data(self, X, y, reset=False)
+X, y = validate_data(self, X, y, multi_output=True, reset=False)

if T is None:
T = self.transform(X)
60 changes: 26 additions & 34 deletions tests/test_kernel_pcovc.py
@@ -160,22 +160,17 @@ def test_Z_shape(self):
kpcovc.fit(self.X, self.Y)

# Shape (n_samples, ) for binary classification
-Z = kpcovc.decision_function(self.X)
+Z_binary = kpcovc.decision_function(self.X)

-self.assertTrue(Z.ndim == 1)
-self.assertTrue(Z.shape[0] == self.X.shape[0])
-
-# Modify Y so that it now contains three classes
-Y_multiclass = self.Y.copy()
-Y_multiclass[0] = 2
-kpcovc.fit(self.X, Y_multiclass)
-n_classes = len(np.unique(Y_multiclass))
+self.assertEqual(Z_binary.ndim, 1)
+self.assertEqual(Z_binary.shape[0], self.X.shape[0])

# Shape (n_samples, n_classes) for multiclass classification
-Z = kpcovc.decision_function(self.X)
+kpcovc.fit(self.X, np.random.randint(0, 3, size=self.X.shape[0]))
+Z_multi = kpcovc.decision_function(self.X)

-self.assertTrue(Z.ndim == 2)
-self.assertTrue((Z.shape[0], Z.shape[1]) == (self.X.shape[0], n_classes))
+self.assertEqual(Z_multi.ndim, 2)
+self.assertEqual(Z_multi.shape, (self.X.shape[0], len(kpcovc.classes_)))

def test_decision_function(self):
"""Check that KPCovC's decision_function works when only T is
@@ -225,11 +220,11 @@ def test_prefit_classifier(self):
kpcovc = KernelPCovC(mixing=0.5, classifier=classifier, **kernel_params)
kpcovc.fit(self.X, self.Y)

-Z_classifier = classifier.decision_function(K).reshape(K.shape[0], -1)
-W_classifier = classifier.coef_.T.reshape(K.shape[1], -1)
+Z_classifier = classifier.decision_function(K)
+W_classifier = classifier.coef_.T

-Z_kpcovc = kpcovc.z_classifier_.decision_function(K).reshape(K.shape[0], -1)
-W_kpcovc = kpcovc.z_classifier_.coef_.T.reshape(K.shape[1], -1)
+Z_kpcovc = kpcovc.z_classifier_.decision_function(K)
+W_kpcovc = kpcovc.z_classifier_.coef_.T

self.assertTrue(np.allclose(Z_classifier, Z_kpcovc))
self.assertTrue(np.allclose(W_classifier, W_kpcovc))
@@ -273,40 +268,37 @@ def test_none_classifier(self):
self.assertTrue(kpcovc.classifier_ is not None)

def test_incompatible_coef_shape(self):
kernel_params = {"kernel": "rbf", "gamma": 0.1, "degree": 3, "coef0": 0}

K = pairwise_kernels(self.X, metric="rbf", filter_params=True, **kernel_params)

# Modify Y to be multiclass
Y_multiclass = self.Y.copy()
Y_multiclass[0] = 2
kernel_params = {"kernel": "sigmoid", "gamma": 0.1, "degree": 3, "coef0": 0}
K = pairwise_kernels(
self.X, metric="sigmoid", filter_params=True, **kernel_params
)

-classifier1 = LinearSVC()
-classifier1.fit(K, Y_multiclass)
-kpcovc1 = self.model(mixing=0.5, classifier=classifier1, **kernel_params)
+cl_multi = LinearSVC()
+cl_multi.fit(K, np.random.randint(0, 3, size=self.X.shape[0]))
+kpcovc_binary = self.model(mixing=0.5, classifier=cl_multi)

# Binary classification shape mismatch
with self.assertRaises(ValueError) as cm:
-kpcovc1.fit(self.X, self.Y)
+kpcovc_binary.fit(self.X, self.Y)
self.assertEqual(
str(cm.exception),
"For binary classification, expected classifier coefficients "
"to have shape (1, %d) but got shape %r"
-% (K.shape[1], classifier1.coef_.shape),
+% (K.shape[1], cl_multi.coef_.shape),
)

-classifier2 = LinearSVC()
-classifier2.fit(K, self.Y)
-kpcovc2 = self.model(mixing=0.5, classifier=classifier2)
+cl_binary = LinearSVC()
+cl_binary.fit(K, self.Y)
+kpcovc_multi = self.model(mixing=0.5, classifier=cl_binary)

# Multiclass classification shape mismatch
with self.assertRaises(ValueError) as cm:
-kpcovc2.fit(self.X, Y_multiclass)
+kpcovc_multi.fit(self.X, np.random.randint(0, 3, size=self.X.shape[0]))
self.assertEqual(
str(cm.exception),
"For multiclass classification, expected classifier coefficients "
"to have shape (%d, %d) but got shape %r"
-% (len(np.unique(Y_multiclass)), K.shape[1], classifier2.coef_.shape),
+% (len(kpcovc_multi.classes_), K.shape[1], cl_binary.coef_.shape),
)

def test_precomputed_classification(self):
Expand All @@ -316,7 +308,7 @@ def test_precomputed_classification(self):
classifier = LogisticRegression()
classifier.fit(K, self.Y)

-W = classifier.coef_.T.reshape(K.shape[1], -1)
+W = classifier.coef_.T
kpcovc1 = self.model(mixing=0.5, classifier="precomputed", **kernel_params)
kpcovc1.fit(self.X, self.Y, W)
t1 = kpcovc1.transform(self.X)