4 changes: 4 additions & 0 deletions CHANGELOG
@@ -13,6 +13,10 @@ The rules for CHANGELOG file:

0.3.0 (XXXX/XX/XX)
------------------
- Add ``_BasePCov`` class (#248)
- Add ``PCovC`` class that inherits shared functionality from ``_BasePCov`` (#248)
- Add ``PCovC`` testing suite and examples (#248)
- Modify ``PCovR`` to inherit shared functionality from ``_BasePCov`` (#248)
- Update to sklearn >= 1.6.0 and scipy >= 1.15.0 (#239)
- Fix import of a function moved in scipy and bump the scipy dependency to 1.15.0 (#236)
- Fix rendering issues for ``SparseKDE`` and ``QuickShift`` (#236)
6 changes: 6 additions & 0 deletions docs/src/bibliography.rst
@@ -45,3 +45,9 @@ References
Michele Ceriotti, "Improving Sample and Feature Selection with Principal Covariates
Regression" 2021 Mach. Learn.: Sci. Technol. 2 035038.
https://iopscience.iop.org/article/10.1088/2632-2153/abfe7c.

.. [Jorgensen2025]
Christian Jorgensen, Arthur Y. Lin, and Rose K. Cersonsky,
"Interpretable Visualizations of Data Spaces for Classification Problems"
2025 arXiv:2503.05861.
https://doi.org/10.48550/arXiv.2503.05861.
9 changes: 8 additions & 1 deletion docs/src/conf.py
@@ -54,7 +54,14 @@
"sphinx_toggleprompt",
]

example_subdirs = ["pcovr", "selection", "regression", "reconstruction", "neighbors"]
example_subdirs = [
"pcovr",
"pcovc",
"selection",
"regression",
"reconstruction",
"neighbors",
]
sphinx_gallery_conf = {
"filename_pattern": "/*",
"examples_dirs": [f"../../examples/{p}" for p in example_subdirs],
24 changes: 22 additions & 2 deletions docs/src/references/decomposition.rst
@@ -1,5 +1,5 @@
Principal Covariates Regression (PCovR)
=======================================
Hybrid Mapping Techniques (PCovR and PCovC)
===========================================

.. _PCovR-api:

@@ -20,6 +20,26 @@ PCovR
.. automethod:: inverse_transform
.. automethod:: score

.. _PCovC-api:

PCovC
-----

.. autoclass:: skmatter.decomposition.PCovC
:show-inheritance:
:special-members:

.. automethod:: fit

.. automethod:: _fit_feature_space
.. automethod:: _fit_sample_space

.. automethod:: transform
.. automethod:: predict
.. automethod:: inverse_transform
.. automethod:: decision_function
.. automethod:: score
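
A minimal usage sketch (parameter choices here are illustrative; it mirrors the
examples added in this PR, which pass linear ``sklearn`` classifiers as
``classifier``):

.. code-block:: python

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegressionCV

    from skmatter.decomposition import PCovC

    X, y = load_iris(return_X_y=True)

    pcovc = PCovC(mixing=0.5, n_components=2, classifier=LogisticRegressionCV())
    pcovc.fit(X, y)

    T = pcovc.transform(X)  # projection into the hybrid latent space
    y_pred = pcovc.predict(X)  # class labels from the latent-space classifier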

.. _KPCovR-api:

Kernel PCovR
1 change: 1 addition & 0 deletions docs/src/tutorials.rst
@@ -3,6 +3,7 @@
.. toctree::

examples/pcovr/index
examples/pcovc/index
examples/selection/index
examples/regression/index
examples/reconstruction/index
122 changes: 122 additions & 0 deletions examples/pcovc/PCovC-BreastCancerDataset.py
@@ -0,0 +1,122 @@
#!/usr/bin/env python
# coding: utf-8

"""
PCovC with the Breast Cancer Dataset
====================================
"""
# %%
#

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegressionCV
from sklearn.preprocessing import StandardScaler

from skmatter.decomposition import PCovC


plt.rcParams["image.cmap"] = "tab10"
plt.rcParams["scatter.edgecolors"] = "k"

random_state = 0

# %%
#
# For this example, we will use the Breast Cancer dataset from ``sklearn``,
# loaded via :func:`sklearn.datasets.load_breast_cancer`.

X, y = load_breast_cancer(return_X_y=True)
print(load_breast_cancer().DESCR)

# %%
#
# Scale Feature Data
# ------------------
#
# Below, we transform the Breast Cancer feature data to have a mean of zero
# and standard deviation of one, while preserving relative relationships
# between feature values.

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# %%
#
# PCA
# ---
#
# We use Principal Component Analysis to reduce the Breast Cancer feature
# data to two features that retain as much information as possible
# about the original dataset.

pca = PCA(n_components=2)

pca.fit(X_scaled)
T_pca = pca.transform(X_scaled)

fig, axis = plt.subplots()
scatter = axis.scatter(T_pca[:, 0], T_pca[:, 1], c=y)
axis.set(xlabel="PC$_1$", ylabel="PC$_2$")
axis.legend(
scatter.legend_elements()[0][::-1],
load_breast_cancer().target_names[::-1],
loc="upper right",
title="Classes",
)

# %%
#
# LDA
# ---
#
# Here, we use Linear Discriminant Analysis to find a projection
# of the feature data that maximizes separability between
# the benign/malignant classes.

lda = LinearDiscriminantAnalysis(n_components=1)
lda.fit(X_scaled, y)

T_lda = lda.transform(X_scaled)

fig, axis = plt.subplots()
axis.scatter(-T_lda.ravel(), np.zeros(len(T_lda)), c=y)

# %%
#
# PCA, PCovC, and LDA
# -------------------
#
# Below, we see a side-by-side comparison of PCA, PCovC (Logistic
# Regression classifier, :math:`\alpha` = 0.5), and LDA maps of the data.

mixing = 0.5
n_models = 3
fig, axes = plt.subplots(1, n_models, figsize=(6 * n_models, 5))

models = {
PCA(n_components=2): "PCA",
PCovC(
mixing=mixing,
n_components=2,
random_state=random_state,
classifier=LogisticRegressionCV(),
): "PCovC",
LinearDiscriminantAnalysis(n_components=1): "LDA",
}

for i, (model, name) in enumerate(models.items()):
    model.fit(X_scaled, y)
    T = model.transform(X_scaled)

    if isinstance(model, LinearDiscriminantAnalysis):
        # LDA gives a one-dimensional projection, so plot it along a line
        axes[i].scatter(-T.ravel(), np.zeros(len(T)), c=y)
    else:
        axes[i].scatter(T[:, 0], T[:, 1], c=y)

    axes[i].set_title(name)
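
# %%
#
# As a quick sanity check (a minimal sketch, not part of the comparison above),
# we can refit the PCovC model and inspect the class predictions of its
# latent-space classifier; here we assume ``predict`` returns class labels,
# as listed in the API reference.

pcovc = PCovC(
    mixing=mixing,
    n_components=2,
    random_state=random_state,
    classifier=LogisticRegressionCV(),
)
pcovc.fit(X_scaled, y)

y_pred = pcovc.predict(X_scaled)
print(f"PCovC training accuracy: {(y_pred == y).mean():.3f}")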
178 changes: 178 additions & 0 deletions examples/pcovc/PCovC-IrisDataset.py
@@ -0,0 +1,178 @@
#!/usr/bin/env python
# coding: utf-8

"""
PCovC with the Iris Dataset
===========================
"""
# %%
#

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.linear_model import LogisticRegressionCV, Perceptron, RidgeClassifierCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

from skmatter.decomposition import PCovC


plt.rcParams["image.cmap"] = "tab10"
plt.rcParams["scatter.edgecolors"] = "k"

random_state = 10
n_components = 2

# %%
#
# For this example, we will use the Iris dataset from ``sklearn``,
# loaded via :func:`sklearn.datasets.load_iris`.

X, y = load_iris(return_X_y=True)
print(load_iris().DESCR)

# %%
#
# Scale Feature Data
# ------------------
#
# Below, we transform the Iris feature data to have a mean of zero and
# standard deviation of one, while preserving relative relationships
# between feature values.

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# %%
#
# PCA
# ---
#
# We use Principal Component Analysis to reduce the Iris feature
# data to two features that retain as much information as possible
# about the original dataset.

pca = PCA(n_components=n_components)

pca.fit(X_scaled)
T_pca = pca.transform(X_scaled)

fig, axis = plt.subplots()
scatter = axis.scatter(T_pca[:, 0], T_pca[:, 1], c=y)
axis.set(xlabel="PC$_1$", ylabel="PC$_2$")
axis.legend(
scatter.legend_elements()[0],
load_iris().target_names,
loc="lower right",
title="Classes",
)

# %%
#
# Effect of Mixing Parameter :math:`\alpha` on PCovC Map
# ------------------------------------------------------
#
# Below, we see how different :math:`\alpha` values for our PCovC model
# result in varying class distinctions between setosa, versicolor,
# and virginica on the PCovC map. In the limiting cases, :math:`\alpha = 1`
# recovers a purely PCA-like projection, while :math:`\alpha = 0` is driven
# entirely by the classifier.

n_mixing = 5
mixing_params = [0, 0.25, 0.50, 0.75, 1]

fig, axes = plt.subplots(1, n_mixing, figsize=(4 * n_mixing, 4), sharey="row")

for i, mixing in enumerate(mixing_params):
    pcovc = PCovC(
        mixing=mixing,
        n_components=n_components,
        random_state=random_state,
        classifier=LogisticRegressionCV(),
    )

    pcovc.fit(X_scaled, y)
    T = pcovc.transform(X_scaled)

    axes[i].set_xticks([])
    axes[i].set_yticks([])

    axes[i].set_title(r"$\alpha=$" + str(mixing))
    axes[i].set_xlabel("PCov$_1$")
    axes[i].scatter(T[:, 0], T[:, 1], c=y)

axes[0].set_ylabel("PCov$_2$")

fig.subplots_adjust(wspace=0)

# %%
#
# Effect of PCovC Classifier on PCovC Map and Decision Boundaries
# ---------------------------------------------------------------
#
# Here, we see how a PCovC model (:math:`\alpha` = 0.5) fitted with
# different classifiers produces varying PCovC maps. In addition,
# we see the varying decision boundaries produced by the
# respective PCovC classifiers.

soft_dots = ["#ff3333", "#339933", "#3333ff"]
soft_fill = ["#f5bcbc", "#b7d4b7", "#bcbcf5"]

cmap_dots = LinearSegmentedColormap.from_list("SoftDots", soft_dots)
cmap_fill = LinearSegmentedColormap.from_list("SoftFill", soft_fill)

mixing = 0.5
n_models = 4
fig, axes = plt.subplots(1, n_models, figsize=(4 * n_models, 4))

models = {
RidgeClassifierCV(): "Ridge Classification",
LogisticRegressionCV(random_state=random_state): "Logistic Regression",
LinearSVC(random_state=random_state): "Support Vector Classification",
Perceptron(random_state=random_state): "Single-Layer Perceptron",
}

for i, (model, name) in enumerate(models.items()):
    pcovc = PCovC(
        mixing=mixing,
        n_components=n_components,
        random_state=random_state,
        classifier=model,
    )

    pcovc.fit(X_scaled, y)
    T = pcovc.transform(X_scaled)

    graph = axes[i]
    graph.set_title(name)

DecisionBoundaryDisplay.from_estimator(
estimator=pcovc.classifier_,
X=T,
ax=graph,
response_method="predict",
grid_resolution=1000,
cmap=cmap_fill,
)

scatter = graph.scatter(T[:, 0], T[:, 1], c=y, cmap=cmap_dots)

graph.set_xlabel("PCov$_1$")
graph.set_xticks([])
graph.set_yticks([])

axes[0].set_ylabel("PCov$_2$")
axes[0].legend(
scatter.legend_elements()[0],
load_iris().target_names,
loc="lower right",
title="Classes",
fontsize=8,
)

fig.subplots_adjust(wspace=0.04)
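
# %%
#
# As a final check (a minimal sketch), we compare the predictions of the last
# PCovC model fitted above, whose latent-space classifier is the single-layer
# perceptron, against the true labels; we assume ``predict`` applies the
# classifier fitted in the latent space.

y_pred = pcovc.predict(X_scaled)
print(f"PCovC training accuracy: {(y_pred == y).mean():.3f}")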
2 changes: 2 additions & 0 deletions examples/pcovc/README.rst
@@ -0,0 +1,2 @@
PCovC
=====