Commit 18e4c8e

Minor reorganizing and typos
1 parent aee34e6 commit 18e4c8e

9 files changed: +44 −38 lines changed

examples/pcovc/KPCovC_Comparison.py

Lines changed: 1 addition & 5 deletions
@@ -105,7 +105,7 @@
     t_train = model.fit_transform(X_train_scaled, y_train)
     t_test = model.transform(X_test_scaled)

-    ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_d, cmap=cm_bright, c=y_test)
+    ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_p, cmap=cm_bright, c=y_test)
     ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train)

     ax.set_title(models[model])
@@ -197,20 +197,16 @@
 models = {
     LogisticRegressionCV(random_state=random_state): {
         "kernel_params": {"kernel": "rbf", "gamma": 12},
-        "title": "Logistic Regression",
     },
     RidgeClassifierCV(): {
         "kernel_params": {"kernel": "rbf", "gamma": 1},
-        "title": "Ridge Classifier",
         "eps": 0.40,
     },
     LinearSVC(random_state=random_state): {
         "kernel_params": {"kernel": "rbf", "gamma": 15},
-        "title": "Support Vector Classification",
     },
     SGDClassifier(random_state=random_state): {
         "kernel_params": {"kernel": "rbf", "gamma": 15},
-        "title": "SGD Classifier",
         "eps": 10,
     },
 }
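
Besides the alpha_d → alpha_p fix in the first hunk, this change drops the hard-coded "title" entries from the per-classifier settings. A minimal, hypothetical sketch of deriving a display title from the estimator class itself instead of a "title" key (illustrative only, not taken from the example's actual plotting code):

# Hypothetical sketch: derive titles from the estimator class rather than a
# hard-coded "title" entry (illustration, not the example itself).
from sklearn.linear_model import LogisticRegressionCV, RidgeClassifierCV

models = {
    LogisticRegressionCV(): {"kernel_params": {"kernel": "rbf", "gamma": 12}},
    RidgeClassifierCV(): {"kernel_params": {"kernel": "rbf", "gamma": 1}, "eps": 0.40},
}

for clf, settings in models.items():
    title = type(clf).__name__  # e.g. "LogisticRegressionCV"
    print(title, settings["kernel_params"], settings.get("eps", "default eps"))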

src/skmatter/decomposition/_kernel_pcovc.py

Lines changed: 3 additions & 4 deletions
@@ -272,7 +272,7 @@ def fit(self, X, Y, W=None):
         check_classification_targets(Y)
         self.classes_ = np.unique(Y)

-        super().fit(X)
+        super()._set_fit_params(X)

         K = self._get_kernel(X)

@@ -314,14 +314,13 @@ def fit(self, X, Y, W=None):

             # Check if classifier is fitted; if not, fit with precomputed K
             self.z_classifier_ = check_cl_fit(classifier, K, Y)
-            W = self.z_classifier_.coef_.T.reshape(K.shape[1], -1)
+            W = self.z_classifier_.coef_.T

         else:
             # If precomputed, use default classifier to predict Y from T
             classifier = LogisticRegression(max_iter=500)
             if W is None:
                 W = LogisticRegression().fit(K, Y).coef_.T
-                W = W.reshape(K.shape[1], -1)

         Z = K @ W

@@ -440,7 +439,7 @@ def decision_function(self, X=None, T=None):
            if self.center:
                K = self.centerer_.transform(K)

-            # Or self.classifier_.decision_function(K @ self.pxt_)
+            # Or self.classifier_.decision_function(K @ self.pkt_)
            return K @ self.pkz_ + self.classifier_.intercept_

        else:
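
The corrected comment in decision_function leans on the fact that a fitted scikit-learn linear classifier's decision function is the affine map T @ coef_.T + intercept_, which is why applying a combined projector plus the intercept matches projecting to the latent space first and then calling the classifier. A small self-contained check of that identity on synthetic data (pkt_ and pkz_ are skmatter internals and are not reproduced here; treat the connection to them as an assumption):

# Sketch: for linear classifiers, decision_function(T) == T @ coef_.T + intercept_.
# Synthetic data only; nothing below touches skmatter.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

T, y = make_classification(n_samples=80, n_features=4, random_state=0)
clf = LogisticRegression().fit(T, y)

manual = T @ clf.coef_.T + clf.intercept_          # shape (80, 1) for binary targets
assert np.allclose(manual.ravel(), clf.decision_function(T))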

src/skmatter/decomposition/_kernel_pcovr.py

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ def fit(self, X, Y, W=None):
         """
         X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True)

-        super().fit(X)
+        super()._set_fit_params(X)

         K = self._get_kernel(X)

src/skmatter/decomposition/_kpcov.py

Lines changed: 7 additions & 4 deletions
@@ -74,10 +74,13 @@ def _get_kernel(self, X, Y=None):
             X, Y, metric=self.kernel, filter_params=True, n_jobs=self.n_jobs, **params
         )

-    def fit(self, X):
-        """Contains the common functionality for the KPCovR and KPCovC fit methods,
-        but leaves the rest of the functionality to the subclass.
-        """
+    @abstractmethod
+    def fit(self, X, Y):
+        """Fit the model with X and Y. Subclasses should implement this method."""
+        pass
+
+    def _set_fit_params(self, X):
+        """Initializes common fit parameters for PCovR and PCovC."""
         self.X_fit_ = X.copy()

         if self.n_components is None:
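
The same refactor appears in _pcov.py below: the shared setup moves out of a concrete fit into a _set_fit_params helper, and fit itself becomes abstract so each subclass defines its own fit(X, Y). A stripped-down sketch of the pattern with simplified names (not the skmatter classes themselves):

# Stripped-down sketch of the base-class pattern introduced here (names simplified;
# this is not the skmatter API): fit(X, Y) is abstract, and the shared bookkeeping
# lives in a protected helper that subclasses call explicitly.
from abc import ABCMeta, abstractmethod

import numpy as np


class _Base(metaclass=ABCMeta):
    def __init__(self, n_components=None):
        self.n_components = n_components

    @abstractmethod
    def fit(self, X, Y):
        """Fit the model with X and Y. Subclasses should implement this method."""

    def _set_fit_params(self, X):
        """Initialize attributes shared by every subclass's fit."""
        self.X_fit_ = X.copy()
        self.n_components_ = (
            min(X.shape) if self.n_components is None else self.n_components
        )


class _Child(_Base):
    def fit(self, X, Y):
        super()._set_fit_params(X)  # replaces the previous super().fit(X) call
        # ...subclass-specific fitting would continue here...
        return self


model = _Child().fit(np.eye(3), np.arange(3))
print(model.n_components_)  # 3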

src/skmatter/decomposition/_pcov.py

Lines changed: 7 additions & 4 deletions
@@ -48,10 +48,13 @@ def __init__(
         self.random_state = random_state
         self.whiten = whiten

-    def fit(self, X):
-        """Contains the common functionality for the PCovR and PCovC fit methods,
-        but leaves the rest of the functionality to the subclass.
-        """
+    @abstractmethod
+    def fit(self, X, Y):
+        """Fit the model with X and Y. Subclasses should implement this method."""
+        pass
+
+    def _set_fit_params(self, X):
+        """Initializes common fit parameters for PCovR and PCovC."""
         # saved for inverse transformations from the latent space,
         # should be zero in the case that the features have been properly centered
         self.mean_ = np.mean(X, axis=0)

src/skmatter/decomposition/_pcovc.py

Lines changed: 7 additions & 8 deletions
@@ -97,11 +97,11 @@ class PCovC(LinearClassifierMixin, _BasePCov):
         Must be of range [0.0, infinity).

     space: {'feature', 'sample', 'auto'}, default='auto'
-        whether to compute the PCovC in `sample` or `feature` space
-        default=`sample` when :math:`{n_{samples} < n_{features}}` and
+        whether to compute the PCovC in `sample` or `feature` space.
+        Default = `sample` when :math:`{n_{samples} < n_{features}}` and
         `feature` when :math:`{n_{features} < n_{samples}}`

-    classifier: `estimator object` or `precomputed`, default=None
+    classifier: `estimator object` or `precomputed`, default=None
         classifier for computing :math:`{\mathbf{Z}}`. The classifier should be
         one of the following:
@@ -144,8 +144,8 @@ class PCovC(LinearClassifierMixin, _BasePCov):
         Must be of range [0.0, infinity).

     space: {'feature', 'sample', 'auto'}, default='auto'
-        whether to compute the PCovC in `sample` or `feature` space
-        default=`sample` when :math:`{n_{samples} < n_{features}}` and
+        whether to compute the PCovC in `sample` or `feature` space.
+        Default = `sample` when :math:`{n_{samples} < n_{features}}` and
         `feature` when :math:`{n_{features} < n_{samples}}`

     n_components_ : int
@@ -262,7 +262,7 @@ def fit(self, X, Y, W=None):
         check_classification_targets(Y)
         self.classes_ = np.unique(Y)

-        super().fit(X)
+        super()._set_fit_params(X)

         compatible_classifiers = (
             LogisticRegression,
@@ -291,14 +291,13 @@ def fit(self, X, Y, W=None):
             classifier = self.classifier

             self.z_classifier_ = check_cl_fit(classifier, X, Y)
-            W = self.z_classifier_.coef_.T.reshape(X.shape[1], -1)
+            W = self.z_classifier_.coef_.T

         else:
             # If precomputed, use default classifier to predict Y from T
             classifier = LogisticRegression()
             if W is None:
                 W = LogisticRegression().fit(X, Y).coef_.T
-                W = W.reshape(X.shape[1], -1)

         Z = X @ W
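
The dropped reshape calls here and in the kernel variant were no-ops: for scikit-learn's linear classifiers, coef_ has shape (n_classes, n_features) (or (1, n_features) for binary problems), so coef_.T already has n_features rows and reshaping it to (X.shape[1], -1) returns the same array. A quick standalone shape check on synthetic data:

# Quick shape check for the removed reshape (illustration only, synthetic data):
# coef_.T already has n_features rows, so reshape(X.shape[1], -1) changed nothing.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(
    n_samples=60, n_features=5, n_informative=3, n_classes=3, random_state=0
)
W = LogisticRegression(max_iter=500).fit(X, y).coef_.T

print(W.shape)                                       # (5, 3): (n_features, n_classes)
assert np.array_equal(W, W.reshape(X.shape[1], -1))  # the reshape was a no-op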

src/skmatter/decomposition/_pcovr.py

Lines changed: 6 additions & 5 deletions
@@ -88,7 +88,7 @@ class PCovR(RegressorMixin, MultiOutputMixin, _BasePCov):
         range [0.0, infinity).

     space: {'feature', 'sample', 'auto'}, default='auto'
-        whether to compute the PCovR in `sample` or `feature` space default=`sample`
+        whether to compute the PCovR in `sample` or `feature` space. Default = `sample`
         when :math:`{n_{samples} < n_{features}}` and `feature` when
         :math:`{n_{features} < n_{samples}}`

@@ -126,7 +126,7 @@ class PCovR(RegressorMixin, MultiOutputMixin, _BasePCov):
         Must be of range [0.0, infinity).

     space: {'feature', 'sample', 'auto'}, default='auto'
-        whether to compute the PCovR in `sample` or `feature` space default=`sample`
+        whether to compute the PCovR in `sample` or `feature` space. Default = `sample`
         when :math:`{n_{samples} < n_{features}}` and `feature` when
         :math:`{n_{features} < n_{samples}}`

@@ -227,11 +227,12 @@ def fit(self, X, Y, W=None):
            regressed form of the properties, :math:`{\mathbf{\hat{Y}}}`.

        W : numpy.ndarray, shape (n_features, n_properties)
-            Regression weights, optional when regressor= `precomputed`. If not
+            Regression weights, optional when regressor = `precomputed`. If not
            passed, it is assumed that `W = np.linalg.lstsq(X, Y, self.tol)[0]`
        """
        X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True)
-        super().fit(X)
+
+        super()._set_fit_params(X)

        compatible_regressors = (LinearRegression, Ridge, RidgeCV)

@@ -414,7 +415,7 @@ def score(self, X, y, T=None):
            Negative sum of the loss in reconstructing X from the latent-space
            projection T and the loss in predicting Y from the latent-space projection T
        """
-        X, y = validate_data(self, X, y, reset=False)
+        X, y = validate_data(self, X, y, multi_output=True, reset=False)

        if T is None:
            T = self.transform(X)
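
The added multi_output=True in score mirrors the flag already passed in fit, so a multi-column Y passes validation. A small illustration of the difference using scikit-learn's public check_X_y helper (the assumption here is that validate_data forwards these keyword checks to the same validation machinery):

# Illustration of the multi_output flag (assumption: validate_data forwards these
# keyword checks to the same machinery as the public check_X_y helper).
import numpy as np
from sklearn.utils import check_X_y

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 3))
Y = rng.normal(size=(10, 2))            # two regression targets per sample

try:
    check_X_y(X, Y)                     # multi_output defaults to False -> rejected
except ValueError as err:
    print("without multi_output:", err)

X_checked, Y_checked = check_X_y(X, Y, multi_output=True)
print("with multi_output:", Y_checked.shape)   # (10, 2)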

tests/test_kernel_pcovc.py

Lines changed: 11 additions & 6 deletions
@@ -30,12 +30,17 @@ def __init__(self, *args, **kwargs):
         scaler = StandardScaler()
         self.X = scaler.fit_transform(self.X)

-        self.model = lambda mixing=0.5, classifier=LogisticRegression(), n_components=4, **kwargs: KernelPCovC(
-            mixing=mixing,
-            classifier=classifier,
-            n_components=n_components,
-            svd_solver=kwargs.pop("svd_solver", "full"),
-            **kwargs,
+        self.model = (
+            lambda mixing=0.5,
+            classifier=LogisticRegression(),
+            n_components=4,
+            **kwargs: KernelPCovC(
+                mixing=mixing,
+                classifier=classifier,
+                n_components=n_components,
+                svd_solver=kwargs.pop("svd_solver", "full"),
+                **kwargs,
+            )
         )

     def setUp(self):
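
This test change only re-wraps the model factory lambda for line length; the keyword defaults and the svd_solver pop behave exactly as before. A stand-alone sketch of the same construct with a stub class in place of KernelPCovC, showing that the parenthesized multi-line lambda is equivalent to the original single-line form:

# Stand-alone sketch of the re-wrapped factory lambda (stub class instead of
# KernelPCovC; the point is only that the multi-line, parenthesized lambda is
# equivalent to the single-line original).
class _Stub:
    def __init__(self, **kwargs):
        self.kwargs = kwargs


model = (
    lambda mixing=0.5,
    n_components=4,
    **kwargs: _Stub(
        mixing=mixing,
        n_components=n_components,
        svd_solver=kwargs.pop("svd_solver", "full"),
        **kwargs,
    )
)

print(model().kwargs)                                 # defaults, svd_solver='full'
print(model(mixing=0.1, svd_solver="arpack").kwargs)  # overrides are passed through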

tests/test_pcovr.py

Lines changed: 1 addition & 1 deletion
@@ -401,7 +401,7 @@ def test_default_ncomponents(self):

         self.assertEqual(pcovr.n_components_, min(self.X.shape))

-    def test_Y_Shape(self):
+    def test_Y_shape(self):
         pcovr = self.model()
         self.Y = np.vstack(self.Y)
         pcovr.fit(self.X, self.Y)
