|
4 | 4 | import networkx as nx
|
5 | 5 | import numpy as np
|
6 | 6 | import pytest
|
7 |
| -from numpy.testing import assert_array_equal, assert_array_almost_equal |
| 7 | +from bert_sklearn import BertClassifier |
| 8 | +from numpy.testing import assert_array_almost_equal, assert_array_equal |
8 | 9 | from scipy.sparse import csr_matrix
|
9 | 10 | from sklearn.exceptions import NotFittedError
|
10 | 11 | from sklearn.linear_model import LogisticRegression
|
11 | 12 | from sklearn.utils.estimator_checks import parametrize_with_checks
|
12 | 13 | from sklearn.utils.validation import check_is_fitted
|
| 14 | + |
13 | 15 | from hiclass import LocalClassifierPerParentNode
|
14 | 16 | from hiclass._calibration.Calibrator import _Calibrator
|
15 | 17 | from hiclass.HierarchicalClassifier import make_leveled
|
@@ -393,3 +395,37 @@ def test_fit_calibrate_predict_predict_proba_bert():
|
393 | 395 | classifier.calibrate(x, y)
|
394 | 396 | classifier.predict(x)
|
395 | 397 | classifier.predict_proba(x)
|
| 398 | + |
| 399 | + |
| 400 | +# Note: bert only works with the local classifier per parent node |
| 401 | +# It does not have the attribute classes_, which are necessary |
| 402 | +# for the local classifiers per level and per node |
| 403 | +def test_fit_bert(): |
| 404 | + bert = BertClassifier() |
| 405 | + clf = LocalClassifierPerParentNode( |
| 406 | + local_classifier=bert, |
| 407 | + bert=True, |
| 408 | + ) |
| 409 | + x = ["Batman", "rorschach"] |
| 410 | + y = [ |
| 411 | + ["Action", "The Dark Night"], |
| 412 | + ["Action", "Watchmen"], |
| 413 | + ] |
| 414 | + clf.fit(x, y) |
| 415 | + check_is_fitted(clf) |
| 416 | + predictions = clf.predict(x) |
| 417 | + assert_array_equal(y, predictions) |
| 418 | + |
| 419 | + |
| 420 | +def test_bert_unleveled(): |
| 421 | + clf = LocalClassifierPerParentNode( |
| 422 | + local_classifier=BertClassifier(), |
| 423 | + bert=True, |
| 424 | + ) |
| 425 | + x = ["Batman", "Jaws"] |
| 426 | + y = [["Action", "The Dark Night"], ["Thriller"]] |
| 427 | + ground_truth = [["Action", "The Dark Night"], ["Action", "The Dark Night"]] |
| 428 | + clf.fit(x, y) |
| 429 | + check_is_fitted(clf) |
| 430 | + predictions = clf.predict(x) |
| 431 | + assert_array_equal(ground_truth, predictions) |
0 commit comments