Skip to content

Commit 5eb8a00

Browse files
committed
more fixes for tests
1 parent 6625136 commit 5eb8a00

File tree

4 files changed

+228
-139
lines changed

4 files changed

+228
-139
lines changed

skrules/datasets/credit_data.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,15 @@
1818

1919
import pandas as pd
2020
import numpy as np
21-
<<<<<<< HEAD
22-
from sklearn.datasets.base import get_data_home, Bunch
23-
=======
2421

2522
try:
26-
from sklearn.datasets.base import get_data_home, Bunch
27-
except ModuleNotFoundError:
23+
from sklearn.datasets.base import get_data_home, Bunch, _fetch_remote, RemoteFileMetadata
24+
except (ModuleNotFoundError, ImportError):
25+
# For scikit-learn >= 0.24 compatibility
2826
from sklearn.datasets import get_data_home
2927
from sklearn.utils import Bunch
28+
from sklearn.datasets._base import _fetch_remote, RemoteFileMetadata
3029

31-
>>>>>>> aa9588c... fix typo
32-
from sklearn.datasets.base import _fetch_remote, RemoteFileMetadata
3330
from os.path import exists, join
3431

3532

skrules/tests/test_common.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,28 @@
11
from sklearn.utils.estimator_checks import check_estimator
22
from skrules import SkopeRules
33
from skrules.datasets import load_credit_data
4+
import sklearn
45

56

67
def test_classifier():
7-
check_estimator(SkopeRules)
8+
try:
9+
check_estimator(SkopeRules)
10+
except TypeError:
11+
# For sklearn >= 0.24.0 compatibility
12+
from sklearn.utils._testing import SkipTest
13+
from sklearn.utils.estimator_checks import check_sample_weights_invariance
14+
15+
checks = check_estimator(SkopeRules(), generate_only=True)
16+
for estimator, check in checks:
17+
# Here we ignore this particular estimator check because
18+
# sample weights are treated differently in skope-rules
19+
if check.func != check_sample_weights_invariance:
20+
try:
21+
check(estimator)
22+
except SkipTest as exception:
23+
# SkipTest is thrown when pandas can't be imported, or by checks
24+
# that are in the xfail_checks tag
25+
warnings.warn(str(exception), SkipTestWarning)
826

927

1028
def test_load_credit_data():

skrules/tests/test_rule.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,48 @@
1-
from sklearn.utils.testing import assert_equal, assert_not_equal
2-
31
from skrules import Rule, replace_feature_name
42

53

64
def test_rule():
7-
assert_equal(Rule('a <= 10 and a <= 12'),
8-
Rule('a <= 10'))
9-
assert_equal(Rule('a <= 10 and a <= 12 and a > 3'),
10-
Rule('a > 3 and a <= 10'))
5+
assert Rule("a <= 10 and a <= 12") == Rule("a <= 10")
6+
assert Rule("a <= 10 and a <= 12 and a > 3") == Rule("a > 3 and a <= 10")
117

12-
assert_equal(Rule('a <= 10 and a <= 10 and a > 3'),
13-
Rule('a > 3 and a <= 10'))
8+
assert Rule("a <= 10 and a <= 10 and a > 3") == Rule("a > 3 and a <= 10")
149

15-
assert_equal(Rule('a <= 10 and a <= 12 and b > 3 and b > 6'),
16-
Rule('a <= 10 and b > 6'))
10+
assert Rule("a <= 10 and a <= 12 and b > 3 and b > 6") == Rule("a <= 10 and b > 6")
1711

18-
assert_equal(len({Rule('a <= 2 and a <= 3'),
19-
Rule('a <= 2')
20-
}), 1)
12+
assert len({Rule("a <= 2 and a <= 3"), Rule("a <= 2")}) == 1
2113

22-
assert_equal(len({Rule('a > 2 and a > 3 and b <= 2 and b <= 3'),
23-
Rule('a > 3 and b <= 2')
24-
}), 1)
14+
assert (
15+
len({Rule("a > 2 and a > 3 and b <= 2 and b <= 3"), Rule("a > 3 and b <= 2")})
16+
== 1
17+
)
2518

26-
assert_equal(len({Rule('a <= 3 and b <= 2'),
27-
Rule('b <= 2 and a <= 3')
28-
}), 1)
19+
assert len({Rule("a <= 3 and b <= 2"), Rule("b <= 2 and a <= 3")}) == 1
2920

3021

3122
def test_hash_rule():
32-
assert_equal(len({
33-
Rule('a <= 2 and a <= 3'),
34-
Rule('a <= 2')
35-
}), 1)
36-
assert_not_equal(len({
37-
Rule('a <= 4 and a <= 3'),
38-
Rule('a <= 2')
39-
}), 1)
23+
assert len({Rule("a <= 2 and a <= 3"), Rule("a <= 2")}) == 1
24+
assert len({Rule("a <= 4 and a <= 3"), Rule("a <= 2")}) != 1
4025

4126

4227
def test_str_rule():
43-
rule = 'a <= 10.0 and b > 3.0'
44-
assert_equal(rule, str(Rule(rule)))
28+
rule = "a <= 10.0 and b > 3.0"
29+
assert rule == str(Rule(rule))
4530

4631

4732
def test_equals_rule():
4833
rule = "a == a"
49-
assert_equal(rule, str(Rule(rule)))
34+
assert rule == str(Rule(rule))
5035

5136
rule2 = "a == a and a == a"
52-
assert_equal(rule, str(Rule(rule2)))
37+
assert rule == str(Rule(rule2))
5338

5439
rule3 = "a < 3.0 and a == a"
55-
assert_equal(rule3, str(Rule(rule3)))
40+
assert rule3 == str(Rule(rule3))
5641

5742

5843
def test_replace_feature_name():
5944
rule = "__C__0 <= 3 and __C__1 > 4"
6045
real_rule = "$b <= 3 and c(4) > 4"
61-
replace_dict = {
62-
"__C__0": "$b",
63-
"__C__1": "c(4)"
64-
}
65-
assert_equal(replace_feature_name(rule, replace_dict=replace_dict), real_rule)
46+
replace_dict = {"__C__0": "$b", "__C__1": "c(4)"}
47+
assert replace_feature_name(rule, replace_dict=replace_dict) == real_rule
48+

0 commit comments

Comments
 (0)