diff --git a/examples/pcovc/KPCovC_Comparison.py b/examples/pcovc/KPCovC_Comparison.py
index 0dd1277e5..5028a9b5c 100644
--- a/examples/pcovc/KPCovC_Comparison.py
+++ b/examples/pcovc/KPCovC_Comparison.py
@@ -105,7 +105,7 @@
     t_train = model.fit_transform(X_train_scaled, y_train)
     t_test = model.transform(X_test_scaled)

-    ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_p, cmap=cm_bright, c=y_test)
+    ax.scatter(t_test[:, 0], t_test[:, 1], alpha=alpha_p, cmap=cm_bright, c=y_test)
     ax.scatter(t_train[:, 0], t_train[:, 1], cmap=cm_bright, c=y_train)

     ax.set_title(models[model])

@@ -197,20 +197,16 @@ models = {
     LogisticRegressionCV(random_state=random_state): {
         "kernel_params": {"kernel": "rbf", "gamma": 12},
-        "title": "Logistic Regression",
     },
     RidgeClassifierCV(): {
         "kernel_params": {"kernel": "rbf", "gamma": 1},
-        "title": "Ridge Classifier",
         "eps": 0.40,
     },
     LinearSVC(random_state=random_state): {
         "kernel_params": {"kernel": "rbf", "gamma": 15},
-        "title": "Support Vector Classification",
     },
     SGDClassifier(random_state=random_state): {
         "kernel_params": {"kernel": "rbf", "gamma": 15},
-        "title": "SGD Classifier",
         "eps": 10,
     },
 }
diff --git a/src/skmatter/decomposition/_kernel_pcovc.py b/src/skmatter/decomposition/_kernel_pcovc.py
index 63cc5f9fb..e8965a223 100644
--- a/src/skmatter/decomposition/_kernel_pcovc.py
+++ b/src/skmatter/decomposition/_kernel_pcovc.py
@@ -24,8 +24,8 @@ class KernelPCovC(LinearClassifierMixin, _BaseKPCov):
     r"""Kernel Principal Covariates Classification (KPCovC).

-    KPCovC is a modification on the PrincipalCovariates Classification
-    proposed in [Jorgensen2025]_. It determines a latent-space projection
+    KPCovC is a modification of the Principal Covariates Classification
+    proposed in [Jorgensen2025]_. It determines a latent-space projection
     :math:`\mathbf{T}` which minimizes a combined loss in supervised and
     unsupervised tasks in the reproducing kernel Hilbert space (RKHS).
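For orientation on the class whose docstring is corrected above, here is a minimal usage sketch of KernelPCovC, consistent with the call patterns exercised by the tests later in this patch. The data, kernel parameters, and classifier choice are illustrative assumptions, not values taken from the patch:

    import numpy as np
    from sklearn.linear_model import LogisticRegressionCV
    from skmatter.decomposition import KernelPCovC

    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 4))      # assume features are already centered/scaled
    y = rng.integers(0, 2, size=100)   # binary labels

    kpcovc = KernelPCovC(
        mixing=0.5,                    # blends the unsupervised and supervised parts of the loss
        n_components=2,
        classifier=LogisticRegressionCV(),
        kernel="rbf",
        gamma=0.1,
    )
    kpcovc.fit(X, y)
    T = kpcovc.transform(X)            # latent-space projection T
    Z = kpcovc.decision_function(X)    # evidence; shape (n_samples,) for binary y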
@@ -272,7 +272,7 @@ def fit(self, X, Y, W=None):
         check_classification_targets(Y)
         self.classes_ = np.unique(Y)

-        super().fit(X)
+        super()._set_fit_params(X)

         K = self._get_kernel(X)

@@ -314,14 +314,13 @@ def fit(self, X, Y, W=None):
             # Check if classifier is fitted; if not, fit with precomputed K
             self.z_classifier_ = check_cl_fit(classifier, K, Y)
-            W = self.z_classifier_.coef_.T.reshape(K.shape[1], -1)
+            W = self.z_classifier_.coef_.T

         else:
             # If precomputed, use default classifier to predict Y from T
             classifier = LogisticRegression(max_iter=500)
             if W is None:
                 W = LogisticRegression().fit(K, Y).coef_.T
-                W = W.reshape(K.shape[1], -1)

         Z = K @ W

@@ -440,7 +439,7 @@ def decision_function(self, X=None, T=None):
             if self.center:
                 K = self.centerer_.transform(K)

-            # Or self.classifier_.decision_function(K @ self.pxt_)
+            # Or self.classifier_.decision_function(K @ self.pkt_)
             return K @ self.pkz_ + self.classifier_.intercept_

         else:
diff --git a/src/skmatter/decomposition/_kernel_pcovr.py b/src/skmatter/decomposition/_kernel_pcovr.py
index ed276739a..05fb4d436 100644
--- a/src/skmatter/decomposition/_kernel_pcovr.py
+++ b/src/skmatter/decomposition/_kernel_pcovr.py
@@ -242,7 +242,7 @@ def fit(self, X, Y, W=None):
         """
         X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True)

-        super().fit(X)
+        super()._set_fit_params(X)

         K = self._get_kernel(X)

diff --git a/src/skmatter/decomposition/_kpcov.py b/src/skmatter/decomposition/_kpcov.py
index 62a5614da..5436d3ef4 100644
--- a/src/skmatter/decomposition/_kpcov.py
+++ b/src/skmatter/decomposition/_kpcov.py
@@ -74,10 +74,12 @@ def _get_kernel(self, X, Y=None):
             X, Y, metric=self.kernel, filter_params=True, n_jobs=self.n_jobs, **params
         )

-    def fit(self, X):
-        """Contains the common functionality for the KPCovR and KPCovC fit methods,
-        but leaves the rest of the functionality to the subclass.
-        """
+    @abstractmethod
+    def fit(self, X, Y):
+        """Fit the model with X and Y. Subclasses should implement this method."""
+
+    def _set_fit_params(self, X):
+        """Initialize common fit parameters for KPCovR and KPCovC."""
         self.X_fit_ = X.copy()

         if self.n_components is None:
diff --git a/src/skmatter/decomposition/_pcov.py b/src/skmatter/decomposition/_pcov.py
index 04dc93b4e..e490b774d 100644
--- a/src/skmatter/decomposition/_pcov.py
+++ b/src/skmatter/decomposition/_pcov.py
@@ -48,10 +48,12 @@ def __init__(
         self.random_state = random_state
         self.whiten = whiten

-    def fit(self, X):
-        """Contains the common functionality for the PCovR and PCovC fit methods,
-        but leaves the rest of the functionality to the subclass.
-        """
+    @abstractmethod
+    def fit(self, X, Y):
+        """Fit the model with X and Y. Subclasses should implement this method."""
+
+    def _set_fit_params(self, X):
+        """Initialize common fit parameters for PCovR and PCovC."""
         # saved for inverse transformations from the latent space,
         # should be zero in the case that the features have been properly centered
         self.mean_ = np.mean(X, axis=0)
diff --git a/src/skmatter/decomposition/_pcovc.py b/src/skmatter/decomposition/_pcovc.py
index ec8ce3202..e0cee034e 100644
--- a/src/skmatter/decomposition/_pcovc.py
+++ b/src/skmatter/decomposition/_pcovc.py
@@ -97,11 +97,11 @@ class PCovC(LinearClassifierMixin, _BasePCov):
         Must be of range [0.0, infinity).

     space: {'feature', 'sample', 'auto'}, default='auto'
-        whether to compute the PCovC in `sample` or `feature` space
-        default=`sample` when :math:`{n_{samples} < n_{features}}` and
-        `feature` when :math:`{n_{features} < n_{samples}}`
+        whether to compute the PCovC in ``sample`` or ``feature`` space.
+        The default is equal to ``sample`` when :math:`{n_{samples} < n_{features}}`
+        and ``feature`` when :math:`{n_{features} < n_{samples}}`

-    classifier: `estimator object` or `precomputed`, default=None
+    classifier: ``estimator object`` or ``precomputed``, default=None
         classifier for computing :math:`{\mathbf{Z}}`. The classifier should be
         one of the following:

@@ -144,9 +144,9 @@ class PCovC(LinearClassifierMixin, _BasePCov):
         Must be of range [0.0, infinity).

     space: {'feature', 'sample', 'auto'}, default='auto'
-        whether to compute the PCovC in `sample` or `feature` space
-        default=`sample` when :math:`{n_{samples} < n_{features}}` and
-        `feature` when :math:`{n_{features} < n_{samples}}`
+        whether to compute the PCovC in ``sample`` or ``feature`` space.
+        The default is equal to ``sample`` when :math:`{n_{samples} < n_{features}}`
+        and ``feature`` when :math:`{n_{features} < n_{samples}}`

     n_components_ : int
         The estimated number of components, which equals the parameter
@@ -160,7 +160,7 @@ class PCovC(LinearClassifierMixin, _BasePCov):
         The linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`.

     classifier_ : estimator object
-        The linear classifier fit between :math:`\mathbf{T}` and :math:`\mathbf{Y}`. 
+        The linear classifier fit between :math:`\mathbf{T}` and :math:`\mathbf{Y}`.

     pxt_ : ndarray of size :math:`({n_{features}, n_{components}})`
         the projector, or weights, from the input space :math:`\mathbf{X}`
@@ -254,7 +254,7 @@ def fit(self, X, Y, W=None):
             Training data, where n_samples is the number of samples.

         W : numpy.ndarray, shape (n_features, n_classes)
-            Classification weights, optional when classifier = `precomputed`. If
+            Classification weights, optional when classifier is ``precomputed``. If
             not passed, it is assumed that the weights will be taken from a
             linear classifier fit between :math:`\mathbf{X}` and :math:`\mathbf{Y}`
         """
@@ -262,7 +262,7 @@ def fit(self, X, Y, W=None):
         check_classification_targets(Y)
         self.classes_ = np.unique(Y)

-        super().fit(X)
+        super()._set_fit_params(X)

         compatible_classifiers = (
             LogisticRegression,
@@ -291,14 +291,13 @@ def fit(self, X, Y, W=None):
                 classifier = self.classifier

             self.z_classifier_ = check_cl_fit(classifier, X, Y)
-            W = self.z_classifier_.coef_.T.reshape(X.shape[1], -1)
+            W = self.z_classifier_.coef_.T

         else:
             # If precomputed, use default classifier to predict Y from T
             classifier = LogisticRegression()
             if W is None:
                 W = LogisticRegression().fit(X, Y).coef_.T
-                W = W.reshape(X.shape[1], -1)

         Z = X @ W

diff --git a/src/skmatter/decomposition/_pcovr.py b/src/skmatter/decomposition/_pcovr.py
index 9a038c6ea..4bcc03c7e 100644
--- a/src/skmatter/decomposition/_pcovr.py
+++ b/src/skmatter/decomposition/_pcovr.py
@@ -88,9 +88,9 @@ class PCovR(RegressorMixin, MultiOutputMixin, _BasePCov):
         range [0.0, infinity).

     space: {'feature', 'sample', 'auto'}, default='auto'
-        whether to compute the PCovR in `sample` or `feature` space default=`sample`
-        when :math:`{n_{samples} < n_{features}}` and `feature` when
-        :math:`{n_{features} < n_{samples}}`
+        whether to compute the PCovR in ``sample`` or ``feature`` space.
+ The default is equal to ``sample`` when :math:`{n_{samples} < n_{features}}` + and ``feature`` when :math:`{n_{features} < n_{samples}}` regressor: {`Ridge`, `RidgeCV`, `LinearRegression`, `precomputed`}, default=None regressor for computing approximated :math:`{\mathbf{\hat{Y}}}`. The regressor @@ -126,9 +126,9 @@ class PCovR(RegressorMixin, MultiOutputMixin, _BasePCov): Must be of range [0.0, infinity). space: {'feature', 'sample', 'auto'}, default='auto' - whether to compute the PCovR in `sample` or `feature` space default=`sample` - when :math:`{n_{samples} < n_{features}}` and `feature` when - :math:`{n_{features} < n_{samples}}` + whether to compute the PCovR in ``sample`` or ``feature`` space. + The default is equal to ``sample`` when :math:`{n_{samples} < n_{features}}` + and ``feature`` when :math:`{n_{features} < n_{samples}}` n_components_ : int The estimated number of components, which equals the parameter n_components, or @@ -227,11 +227,12 @@ def fit(self, X, Y, W=None): regressed form of the properties, :math:`{\mathbf{\hat{Y}}}`. W : numpy.ndarray, shape (n_features, n_properties) - Regression weights, optional when regressor= `precomputed`. If not + Regression weights, optional when regressor is ``precomputed``. If not passed, it is assumed that `W = np.linalg.lstsq(X, Y, self.tol)[0]` """ X, Y = validate_data(self, X, Y, y_numeric=True, multi_output=True) - super().fit(X) + + super()._set_fit_params(X) compatible_regressors = (LinearRegression, Ridge, RidgeCV) @@ -414,7 +415,7 @@ def score(self, X, y, T=None): Negative sum of the loss in reconstructing X from the latent-space projection T and the loss in predicting Y from the latent-space projection T """ - X, y = validate_data(self, X, y, reset=False) + X, y = validate_data(self, X, y, multi_output=True, reset=False) if T is None: T = self.transform(X) diff --git a/tests/test_kernel_pcovc.py b/tests/test_kernel_pcovc.py index 10ef589af..9b29b8437 100644 --- a/tests/test_kernel_pcovc.py +++ b/tests/test_kernel_pcovc.py @@ -160,22 +160,17 @@ def test_Z_shape(self): kpcovc.fit(self.X, self.Y) # Shape (n_samples, ) for binary classifcation - Z = kpcovc.decision_function(self.X) + Z_binary = kpcovc.decision_function(self.X) - self.assertTrue(Z.ndim == 1) - self.assertTrue(Z.shape[0] == self.X.shape[0]) - - # Modify Y so that it now contains three classes - Y_multiclass = self.Y.copy() - Y_multiclass[0] = 2 - kpcovc.fit(self.X, Y_multiclass) - n_classes = len(np.unique(Y_multiclass)) + self.assertEqual(Z_binary.ndim, 1) + self.assertEqual(Z_binary.shape[0], self.X.shape[0]) # Shape (n_samples, n_classes) for multiclass classification - Z = kpcovc.decision_function(self.X) + kpcovc.fit(self.X, np.random.randint(0, 3, size=self.X.shape[0])) + Z_multi = kpcovc.decision_function(self.X) - self.assertTrue(Z.ndim == 2) - self.assertTrue((Z.shape[0], Z.shape[1]) == (self.X.shape[0], n_classes)) + self.assertEqual(Z_multi.ndim, 2) + self.assertEqual(Z_multi.shape, (self.X.shape[0], len(kpcovc.classes_))) def test_decision_function(self): """Check that KPCovC's decision_function works when only T is @@ -225,11 +220,11 @@ def test_prefit_classifier(self): kpcovc = KernelPCovC(mixing=0.5, classifier=classifier, **kernel_params) kpcovc.fit(self.X, self.Y) - Z_classifier = classifier.decision_function(K).reshape(K.shape[0], -1) - W_classifier = classifier.coef_.T.reshape(K.shape[1], -1) + Z_classifier = classifier.decision_function(K) + W_classifier = classifier.coef_.T - Z_kpcovc = 
kpcovc.z_classifier_.decision_function(K).reshape(K.shape[0], -1) - W_kpcovc = kpcovc.z_classifier_.coef_.T.reshape(K.shape[1], -1) + Z_kpcovc = kpcovc.z_classifier_.decision_function(K) + W_kpcovc = kpcovc.z_classifier_.coef_.T self.assertTrue(np.allclose(Z_classifier, Z_kpcovc)) self.assertTrue(np.allclose(W_classifier, W_kpcovc)) @@ -273,40 +268,37 @@ def test_none_classifier(self): self.assertTrue(kpcovc.classifier_ is not None) def test_incompatible_coef_shape(self): - kernel_params = {"kernel": "rbf", "gamma": 0.1, "degree": 3, "coef0": 0} - - K = pairwise_kernels(self.X, metric="rbf", filter_params=True, **kernel_params) - - # Modify Y to be multiclass - Y_multiclass = self.Y.copy() - Y_multiclass[0] = 2 + kernel_params = {"kernel": "sigmoid", "gamma": 0.1, "degree": 3, "coef0": 0} + K = pairwise_kernels( + self.X, metric="sigmoid", filter_params=True, **kernel_params + ) - classifier1 = LinearSVC() - classifier1.fit(K, Y_multiclass) - kpcovc1 = self.model(mixing=0.5, classifier=classifier1, **kernel_params) + cl_multi = LinearSVC() + cl_multi.fit(K, np.random.randint(0, 3, size=self.X.shape[0])) + kpcovc_binary = self.model(mixing=0.5, classifier=cl_multi) # Binary classification shape mismatch with self.assertRaises(ValueError) as cm: - kpcovc1.fit(self.X, self.Y) + kpcovc_binary.fit(self.X, self.Y) self.assertEqual( str(cm.exception), "For binary classification, expected classifier coefficients " "to have shape (1, %d) but got shape %r" - % (K.shape[1], classifier1.coef_.shape), + % (K.shape[1], cl_multi.coef_.shape), ) - classifier2 = LinearSVC() - classifier2.fit(K, self.Y) - kpcovc2 = self.model(mixing=0.5, classifier=classifier2) + cl_binary = LinearSVC() + cl_binary.fit(K, self.Y) + kpcovc_multi = self.model(mixing=0.5, classifier=cl_binary) # Multiclass classification shape mismatch with self.assertRaises(ValueError) as cm: - kpcovc2.fit(self.X, Y_multiclass) + kpcovc_multi.fit(self.X, np.random.randint(0, 3, size=self.X.shape[0])) self.assertEqual( str(cm.exception), "For multiclass classification, expected classifier coefficients " "to have shape (%d, %d) but got shape %r" - % (len(np.unique(Y_multiclass)), K.shape[1], classifier2.coef_.shape), + % (len(kpcovc_multi.classes_), K.shape[1], cl_binary.coef_.shape), ) def test_precomputed_classification(self): @@ -316,7 +308,7 @@ def test_precomputed_classification(self): classifier = LogisticRegression() classifier.fit(K, self.Y) - W = classifier.coef_.T.reshape(K.shape[1], -1) + W = classifier.coef_.T kpcovc1 = self.model(mixing=0.5, classifier="precomputed", **kernel_params) kpcovc1.fit(self.X, self.Y, W) t1 = kpcovc1.transform(self.X) diff --git a/tests/test_pcovc.py b/tests/test_pcovc.py index 5746c610f..8607a2e2a 100644 --- a/tests/test_pcovc.py +++ b/tests/test_pcovc.py @@ -4,9 +4,9 @@ import numpy as np from sklearn import exceptions from sklearn.calibration import LinearSVC -from sklearn.datasets import load_breast_cancer as get_dataset +from sklearn.datasets import load_iris as get_dataset from sklearn.decomposition import PCA -from sklearn.linear_model import LogisticRegression +from sklearn.linear_model import LogisticRegression, RidgeClassifier from sklearn.naive_bayes import GaussianNB from sklearn.preprocessing import StandardScaler from sklearn.utils.validation import check_X_y @@ -25,9 +25,13 @@ def __init__(self, *args, **kwargs): ) self.error_tol = 1e-5 - self.X, self.Y = get_dataset(return_X_y=True) + # n_samples > 500 to ensure our svd_solver tests catch all cases + X_stacked = np.tile(self.X, (4, 1)) + 
Y_stacked = np.tile(self.Y, 4) + self.X, self.Y = X_stacked, Y_stacked + scaler = StandardScaler() self.X = scaler.fit_transform(self.X) @@ -75,11 +79,16 @@ def test_simple_reconstruction(self): def test_simple_prediction(self): """ Check that PCovC with a full eigendecomposition at mixing=0 - can fully reconstruct the input properties. + can reproduce a linear classification result. """ for space in ["feature", "sample", "auto"]: with self.subTest(space=space): - pcovc = self.model(mixing=0.0, n_components=2, space=space) + pcovc = self.model( + mixing=0.0, + classifier=RidgeClassifier(), + n_components=2, + space=space, + ) pcovc.classifier.fit(self.X, self.Y) Yhat = pcovc.classifier.predict(self.X) @@ -170,10 +179,11 @@ def test_select_sample_space(self): Check that PCovC implements the sample space version when :math:`n_{features} > n_{samples}``. """ - pcovc = self.model(n_components=2, tol=1e-12) + pcovc = self.model(n_components=1, tol=1e-12, svd_solver="arpack") + n_samples = 2 - n_samples = self.X.shape[1] - 1 - pcovc.fit(self.X[:n_samples], self.Y[:n_samples]) + # select range where there are at least 2 classes in Y + pcovc.fit(self.X[49 : 49 + n_samples], self.Y[49 : 49 + n_samples]) self.assertTrue(pcovc.space_ == "sample") @@ -289,7 +299,8 @@ def test_bad_n_components(self): pcovc = self.model( n_components="mle", classifier=LinearSVC(), svd_solver="full" ) - pcovc.fit(self.X[:20], self.Y[:20]) + # select range where there are at least 2 classes in Y + pcovc.fit(self.X[49:51], self.Y[49:51]) self.assertEqual( str(cm.exception), "n_components='mle' is only supported if n_samples >= n_features", @@ -395,7 +406,7 @@ def test_T_shape(self): """Check that PCovC returns a latent space projection consistent with the shape of the input matrix. """ - n_components = 5 + n_components = 4 pcovc = self.model(n_components=n_components, tol=1e-12) pcovc.fit(self.X, self.Y) T = pcovc.transform(self.X) @@ -414,27 +425,21 @@ def test_Z_shape(self): """Check that PCovC returns an evidence matrix consistent with the number of samples and the number of classes. 
""" - n_components = 5 + n_components = 2 pcovc = self.model(n_components=n_components, tol=1e-12) - pcovc.fit(self.X, self.Y) + pcovc.fit(self.X, np.random.randint(0, 2, size=self.X.shape[0])) # Shape (n_samples, ) for binary classifcation - Z = pcovc.decision_function(self.X) - - self.assertTrue(Z.ndim == 1) - self.assertTrue(Z.shape[0] == self.X.shape[0]) - - # Modify Y so that it now contains three classes - Y_multiclass = self.Y.copy() - Y_multiclass[0] = 2 - pcovc.fit(self.X, Y_multiclass) - n_classes = len(np.unique(Y_multiclass)) + Z_binary = pcovc.decision_function(self.X) + self.assertEqual(Z_binary.ndim, 1) + self.assertEqual(Z_binary.shape[0], self.X.shape[0]) # Shape (n_samples, n_classes) for multiclass classification - Z = pcovc.decision_function(self.X) + pcovc.fit(self.X, self.Y) + Z_multi = pcovc.decision_function(self.X) - self.assertTrue(Z.ndim == 2) - self.assertTrue((Z.shape[0], Z.shape[1]) == (self.X.shape[0], n_classes)) + self.assertEqual(Z_multi.ndim, 2) + self.assertEqual(Z_multi.shape, (self.X.shape[0], len(pcovc.classes_))) def test_decision_function(self): """Check that PCovC's decision_function works when only T is @@ -464,13 +469,11 @@ def test_prefit_classifier(self): pcovc = self.model(mixing=0.5, classifier=classifier) pcovc.fit(self.X, self.Y) - Z_classifier = classifier.decision_function(self.X).reshape(self.X.shape[0], -1) - W_classifier = classifier.coef_.T.reshape(self.X.shape[1], -1) + Z_classifier = classifier.decision_function(self.X) + W_classifier = classifier.coef_.T - Z_pcovc = pcovc.z_classifier_.decision_function(self.X).reshape( - self.X.shape[0], -1 - ) - W_pcovc = pcovc.z_classifier_.coef_.T.reshape(self.X.shape[1], -1) + Z_pcovc = pcovc.z_classifier_.decision_function(self.X) + W_pcovc = pcovc.z_classifier_.coef_.T self.assertTrue(np.allclose(Z_classifier, Z_pcovc)) self.assertTrue(np.allclose(W_classifier, W_pcovc)) @@ -479,7 +482,7 @@ def test_precomputed_classification(self): classifier = LogisticRegression() classifier.fit(self.X, self.Y) - W = classifier.coef_.T.reshape(self.X.shape[1], -1) + W = classifier.coef_.T pcovc1 = self.model(mixing=0.5, classifier="precomputed", n_components=1) pcovc1.fit(self.X, self.Y, W) t1 = pcovc1.transform(self.X) @@ -544,37 +547,32 @@ def test_none_classifier(self): self.assertTrue(pcovc.classifier_ is not None) def test_incompatible_coef_shape(self): - classifier1 = LogisticRegression() - - # Modify Y to be multiclass - Y_multiclass = self.Y.copy() - Y_multiclass[0] = 2 - - classifier1.fit(self.X, Y_multiclass) - pcovc1 = self.model(mixing=0.5, classifier=classifier1) + cl_multi = LogisticRegression() + cl_multi.fit(self.X, self.Y) + pcovc_binary = self.model(mixing=0.5, classifier=cl_multi) # Binary classification shape mismatch with self.assertRaises(ValueError) as cm: - pcovc1.fit(self.X, self.Y) + pcovc_binary.fit(self.X, np.random.randint(0, 2, size=self.X.shape[0])) self.assertEqual( str(cm.exception), "For binary classification, expected classifier coefficients " "to have shape (1, %d) but got shape %r" - % (self.X.shape[1], classifier1.coef_.shape), + % (self.X.shape[1], cl_multi.coef_.shape), ) - classifier2 = LogisticRegression() - classifier2.fit(self.X, self.Y) - pcovc2 = self.model(mixing=0.5, classifier=classifier2) + cl_binary = LogisticRegression() + cl_binary.fit(self.X, np.random.randint(0, 2, size=self.X.shape[0])) + pcovc_multi = self.model(mixing=0.5, classifier=cl_binary) # Multiclass classification shape mismatch with self.assertRaises(ValueError) as cm: - pcovc2.fit(self.X, 
Y_multiclass) + pcovc_multi.fit(self.X, self.Y) self.assertEqual( str(cm.exception), "For multiclass classification, expected classifier coefficients " "to have shape (%d, %d) but got shape %r" - % (len(np.unique(Y_multiclass)), self.X.shape[1], classifier2.coef_.shape), + % (len(pcovc_multi.classes_), self.X.shape[1], cl_binary.coef_.shape), ) diff --git a/tests/test_pcovr.py b/tests/test_pcovr.py index 597dcc2ba..0b5dfcb1d 100644 --- a/tests/test_pcovr.py +++ b/tests/test_pcovr.py @@ -401,7 +401,7 @@ def test_default_ncomponents(self): self.assertEqual(pcovr.n_components_, min(self.X.shape)) - def test_Y_Shape(self): + def test_Y_shape(self): pcovr = self.model() self.Y = np.vstack(self.Y) pcovr.fit(self.X, self.Y)
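To round out the picture, here is a companion sketch of the precomputed-weights path exercised by the updated tests: after this change, classifier coefficients are passed to fit as coef_.T directly, with no trailing reshape. The dataset, mixing value, and component count are illustrative assumptions:

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression
    from sklearn.preprocessing import StandardScaler
    from skmatter.decomposition import PCovC

    X, y = load_iris(return_X_y=True)
    X = StandardScaler().fit_transform(X)

    # Fit a linear classifier separately and hand PCovC its weights.
    clf = LogisticRegression(max_iter=500).fit(X, y)
    W = clf.coef_.T                  # shape (n_features, n_classes); no reshape needed

    pcovc = PCovC(mixing=0.5, n_components=2, classifier="precomputed")
    pcovc.fit(X, y, W)
    T = pcovc.transform(X)           # (n_samples, 2) latent projection
    Z = pcovc.decision_function(X)   # (n_samples, n_classes) for multiclass y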