Skip to content

Commit 5ee9b59

Browse files
authored
Add flat classifier #minor (#128)
1 parent 6f37990 commit 5ee9b59

File tree

7 files changed

+141
-3
lines changed

7 files changed

+141
-3
lines changed

CONTRIBUTING.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,18 @@ pre-commit install
2929
```
3030

3131
If black is not executed locally and there are formatting errors, the CI/CD pipeline will fail.
32+
33+
## Building the documentation locally
34+
35+
To build the documentation locally, you need to install an additional set of dependencies that are specific to the documentation. It is easiest to create a separate conda environment and run the following command:
36+
37+
```
38+
pip install -r docs/requirements.txt
39+
```
40+
41+
To build the documentation, change to the `docs` directory and run the following commands:
42+
43+
```
44+
cd docs
45+
make html
46+
```

docs/source/api/classifiers.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,12 @@ LocalClassifierPerParentNode
4949
:show-inheritance:
5050
:inherited-members:
5151
:special-members: __init__
52+
53+
Flat Classifier
54+
===============
55+
56+
FlatClassifier
57+
^^^^^^^^^^^^^^^^^^^^^^
58+
.. autoclass:: FlatClassifier.FlatClassifier
59+
:members:
60+
:special-members: __init__

hiclass/BinaryPolicy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from abc import ABC
44

5-
from scipy.sparse import vstack, csr_matrix
5+
from scipy.sparse import vstack, csr_matrix, csr_array
66
import networkx as nx
77
import numpy as np
88

@@ -160,7 +160,7 @@ def get_binary_examples(self, node) -> tuple:
160160
)
161161
y = np.zeros(len(X))
162162
y[: len(positive_x)] = 1
163-
elif isinstance(self.X, csr_matrix):
163+
elif isinstance(self.X, csr_matrix) or isinstance(self.X, csr_array):
164164
X = vstack([positive_x, negative_x])
165165
sample_weights = (
166166
vstack([positive_weights, negative_weights])

hiclass/FlatClassifier.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""
2+
Flat classifier approach, used for comparison purposes.
3+
4+
Implementation by @lpfgarcia
5+
"""
6+
7+
import numpy as np
8+
from sklearn.base import BaseEstimator
9+
from sklearn.linear_model import LogisticRegression
10+
from sklearn.utils.validation import check_is_fitted
11+
12+
13+
class FlatClassifier(BaseEstimator):
14+
"""
15+
A flat classifier utility that accepts as input a hierarchy and flattens it internally.
16+
17+
Examples
18+
--------
19+
>>> from hiclass import FlatClassifier
20+
>>> y = [['1', '1.1'], ['2', '2.1']]
21+
>>> X = [[1, 2], [3, 4]]
22+
>>> flat = FlatClassifier()
23+
>>> flat.fit(X, y)
24+
>>> flat.predict(X)
25+
array([['1', '1.1'],
26+
['2', '2.1']])
27+
"""
28+
29+
def __init__(
30+
self,
31+
local_classifier: BaseEstimator = LogisticRegression(),
32+
):
33+
"""
34+
Initialize a flat classifier.
35+
36+
Parameters
37+
----------
38+
local_classifier : BaseEstimator, default=LogisticRegression
39+
The scikit-learn model used for the flat classification. Needs to have fit, predict and clone methods.
40+
"""
41+
self.local_classifier = local_classifier
42+
43+
def fit(self, X, y, sample_weight=None):
44+
"""
45+
Fit a flat classifier.
46+
47+
Parameters
48+
----------
49+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
50+
The training input samples. Internally, its dtype will be converted
51+
to ``dtype=np.float32``. If a sparse matrix is provided, it will be
52+
converted into a sparse ``csc_matrix``.
53+
y : array-like of shape (n_samples, n_levels)
54+
The target values, i.e., hierarchical class labels for classification.
55+
sample_weight : array-like of shape (n_samples,), default=None
56+
Array of weights that are assigned to individual samples.
57+
If not provided, then each sample is given unit weight.
58+
59+
Returns
60+
-------
61+
self : object
62+
Fitted estimator.
63+
"""
64+
# Convert from hierarchical labels to flat labels
65+
self.separator_ = "::HiClass::Separator::"
66+
y = [self.separator_.join(i) for i in y]
67+
68+
# Fit flat classifier
69+
self.local_classifier.fit(X, y, sample_weight=sample_weight)
70+
71+
# Return the classifier
72+
return self
73+
74+
def predict(self, X):
75+
"""
76+
Predict classes for the given data.
77+
78+
Hierarchical labels are returned.
79+
80+
Parameters
81+
----------
82+
X : {array-like, sparse matrix} of shape (n_samples, n_features)
83+
The input samples. Internally, its dtype will be converted
84+
to ``dtype=np.float32``. If a sparse matrix is provided, it will be
85+
converted into a sparse ``csr_matrix``.
86+
Returns
87+
-------
88+
y : ndarray of shape (n_samples, n_levels)
89+
The predicted classes.
90+
"""
91+
# Check if fit has been called
92+
check_is_fitted(self)
93+
94+
# Predict and remove separator
95+
predictions = [
96+
i.split(self.separator_) for i in self.local_classifier.predict(X)
97+
]
98+
99+
return np.array(predictions)

hiclass/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .LocalClassifierPerLevel import LocalClassifierPerLevel
66
from .LocalClassifierPerNode import LocalClassifierPerNode
77
from .LocalClassifierPerParentNode import LocalClassifierPerParentNode
8+
from .FlatClassifier import FlatClassifier
89
from .MultiLabelLocalClassifierPerNode import MultiLabelLocalClassifierPerNode
910
from .MultiLabelLocalClassifierPerParentNode import (
1011
MultiLabelLocalClassifierPerParentNode,
@@ -19,6 +20,7 @@
1920
"LocalClassifierPerNode",
2021
"LocalClassifierPerParentNode",
2122
"LocalClassifierPerLevel",
23+
"FlatClassifier",
2224
"Explainer",
2325
"MultiLabelLocalClassifierPerNode",
2426
"MultiLabelLocalClassifierPerParentNode",

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
KEYWORDS = ["hierarchical classification"]
2828
DACS_SOFTWARE = "https://gitlab.com/dacs-hpi"
2929
# What packages are required for this module to be executed?
30-
REQUIRED = ["networkx", "numpy", "scikit-learn", "scipy<1.13"]
30+
REQUIRED = ["networkx", "numpy", "scikit-learn<1.5", "scipy<1.13"]
3131

3232
# What packages are optional?
3333
# 'fancy feature': ['django'],}

tests/test_FlatClassifier.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import numpy as np
2+
from numpy.testing import assert_array_equal
3+
4+
from hiclass import FlatClassifier
5+
6+
7+
def test_fit_predict():
8+
flat = FlatClassifier()
9+
x = np.array([[1, 2], [3, 4]])
10+
y = np.array([["a", "b"], ["b", "c"]])
11+
flat.fit(x, y)
12+
predictions = flat.predict(x)
13+
assert_array_equal(y, predictions)

0 commit comments

Comments
 (0)