Skip to content

Commit 25a96f6

Browse files
committed
Changes to KPCovC tests
1 parent 46d2fca commit 25a96f6

File tree

1 file changed

+32
-45
lines changed

1 file changed

+32
-45
lines changed

tests/test_kernel_pcovc.py

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,12 @@ def __init__(self, *args, **kwargs):
3030
scaler = StandardScaler()
3131
self.X = scaler.fit_transform(self.X)
3232

33-
self.model = (
34-
lambda mixing=0.5,
35-
classifier=LogisticRegression(),
36-
n_components=4,
37-
**kwargs: KernelPCovC(
38-
mixing=mixing,
39-
classifier=classifier,
40-
n_components=n_components,
41-
svd_solver=kwargs.pop("svd_solver", "full"),
42-
**kwargs,
43-
)
33+
self.model = lambda mixing=0.5, classifier=LogisticRegression(), n_components=4, **kwargs: KernelPCovC(
34+
mixing=mixing,
35+
classifier=classifier,
36+
n_components=n_components,
37+
svd_solver=kwargs.pop("svd_solver", "full"),
38+
**kwargs,
4439
)
4540

4641
def setUp(self):
@@ -160,22 +155,17 @@ def test_Z_shape(self):
160155
kpcovc.fit(self.X, self.Y)
161156

162157
# Shape (n_samples, ) for binary classifcation
163-
Z = kpcovc.decision_function(self.X)
164-
165-
self.assertTrue(Z.ndim == 1)
166-
self.assertTrue(Z.shape[0] == self.X.shape[0])
158+
Z_binary = kpcovc.decision_function(self.X)
167159

168-
# Modify Y so that it now contains three classes
169-
Y_multiclass = self.Y.copy()
170-
Y_multiclass[0] = 2
171-
kpcovc.fit(self.X, Y_multiclass)
172-
n_classes = len(np.unique(Y_multiclass))
160+
self.assertEqual(Z_binary.ndim, 1)
161+
self.assertEqual(Z_binary.shape[0], self.X.shape[0])
173162

174163
# Shape (n_samples, n_classes) for multiclass classification
175-
Z = kpcovc.decision_function(self.X)
164+
kpcovc.fit(self.X, np.random.randint(0, 3, size=self.X.shape[0]))
165+
Z_multi = kpcovc.decision_function(self.X)
176166

177-
self.assertTrue(Z.ndim == 2)
178-
self.assertTrue((Z.shape[0], Z.shape[1]) == (self.X.shape[0], n_classes))
167+
self.assertEqual(Z_multi.ndim, 2)
168+
self.assertEqual(Z_multi.shape, (self.X.shape[0], len(kpcovc.classes_)))
179169

180170
def test_decision_function(self):
181171
"""Check that KPCovC's decision_function works when only T is
@@ -225,11 +215,11 @@ def test_prefit_classifier(self):
225215
kpcovc = KernelPCovC(mixing=0.5, classifier=classifier, **kernel_params)
226216
kpcovc.fit(self.X, self.Y)
227217

228-
Z_classifier = classifier.decision_function(K).reshape(K.shape[0], -1)
229-
W_classifier = classifier.coef_.T.reshape(K.shape[1], -1)
218+
Z_classifier = classifier.decision_function(K)
219+
W_classifier = classifier.coef_.T
230220

231-
Z_kpcovc = kpcovc.z_classifier_.decision_function(K).reshape(K.shape[0], -1)
232-
W_kpcovc = kpcovc.z_classifier_.coef_.T.reshape(K.shape[1], -1)
221+
Z_kpcovc = kpcovc.z_classifier_.decision_function(K)
222+
W_kpcovc = kpcovc.z_classifier_.coef_.T
233223

234224
self.assertTrue(np.allclose(Z_classifier, Z_kpcovc))
235225
self.assertTrue(np.allclose(W_classifier, W_kpcovc))
@@ -273,40 +263,37 @@ def test_none_classifier(self):
273263
self.assertTrue(kpcovc.classifier_ is not None)
274264

275265
def test_incompatible_coef_shape(self):
276-
kernel_params = {"kernel": "rbf", "gamma": 0.1, "degree": 3, "coef0": 0}
277-
278-
K = pairwise_kernels(self.X, metric="rbf", filter_params=True, **kernel_params)
279-
280-
# Modify Y to be multiclass
281-
Y_multiclass = self.Y.copy()
282-
Y_multiclass[0] = 2
266+
kernel_params = {"kernel": "sigmoid", "gamma": 0.1, "degree": 3, "coef0": 0}
267+
K = pairwise_kernels(
268+
self.X, metric="sigmoid", filter_params=True, **kernel_params
269+
)
283270

284-
classifier1 = LinearSVC()
285-
classifier1.fit(K, Y_multiclass)
286-
kpcovc1 = self.model(mixing=0.5, classifier=classifier1, **kernel_params)
271+
cl_multi = LinearSVC()
272+
cl_multi.fit(K, np.random.randint(0, 3, size=self.X.shape[0]))
273+
kpcovc_binary = self.model(mixing=0.5, classifier=cl_multi)
287274

288275
# Binary classification shape mismatch
289276
with self.assertRaises(ValueError) as cm:
290-
kpcovc1.fit(self.X, self.Y)
277+
kpcovc_binary.fit(self.X, self.Y)
291278
self.assertEqual(
292279
str(cm.exception),
293280
"For binary classification, expected classifier coefficients "
294281
"to have shape (1, %d) but got shape %r"
295-
% (K.shape[1], classifier1.coef_.shape),
282+
% (K.shape[1], cl_multi.coef_.shape),
296283
)
297284

298-
classifier2 = LinearSVC()
299-
classifier2.fit(K, self.Y)
300-
kpcovc2 = self.model(mixing=0.5, classifier=classifier2)
285+
cl_binary = LinearSVC()
286+
cl_binary.fit(K, self.Y)
287+
kpcovc_multi = self.model(mixing=0.5, classifier=cl_binary)
301288

302289
# Multiclass classification shape mismatch
303290
with self.assertRaises(ValueError) as cm:
304-
kpcovc2.fit(self.X, Y_multiclass)
291+
kpcovc_multi.fit(self.X, np.random.randint(0, 3, size=self.X.shape[0]))
305292
self.assertEqual(
306293
str(cm.exception),
307294
"For multiclass classification, expected classifier coefficients "
308295
"to have shape (%d, %d) but got shape %r"
309-
% (len(np.unique(Y_multiclass)), K.shape[1], classifier2.coef_.shape),
296+
% (len(kpcovc_multi.classes_), K.shape[1], cl_binary.coef_.shape),
310297
)
311298

312299
def test_precomputed_classification(self):
@@ -316,7 +303,7 @@ def test_precomputed_classification(self):
316303
classifier = LogisticRegression()
317304
classifier.fit(K, self.Y)
318305

319-
W = classifier.coef_.T.reshape(K.shape[1], -1)
306+
W = classifier.coef_.T
320307
kpcovc1 = self.model(mixing=0.5, classifier="precomputed", **kernel_params)
321308
kpcovc1.fit(self.X, self.Y, W)
322309
t1 = kpcovc1.transform(self.X)

0 commit comments

Comments
 (0)