Skip to content

Commit 46066a4

Browse files
authored
Merge pull request #4 from floriangardin/master
Any variable name can be used in "feature_names"
2 parents 645b189 + 63d3528 commit 46066a4

File tree

4 files changed

+47
-9
lines changed

4 files changed

+47
-9
lines changed

skrules/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .skope_rules import SkopeRules
2-
from .rule import Rule
2+
from .rule import Rule, replace_feature_name
33

44
# Public names of the package; replace_feature_name is imported above and
# used from the package namespace by the tests, so it is exported as well.
__all__ = ['SkopeRules', 'Rule', 'replace_feature_name']

skrules/rule.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
import re
2+
3+
def replace_feature_name(rule, replace_dict):
    """Substitute whole-word names in a rule string.

    Parameters
    ----------
    rule : str
        Rule expression, e.g. ``"__C__0 <= 3 and __C__1 > 4"``.
    replace_dict : dict
        Maps each name to look for in ``rule`` to its replacement text.

    Returns
    -------
    str
        ``rule`` with every occurrence of a key that is delimited by word
        boundaries replaced by the corresponding value. ``rule`` is
        returned unchanged when ``replace_dict`` is empty.
    """
    if not replace_dict:
        # Joining zero alternatives yields the empty pattern, which matches
        # at every position and would make the lookup below raise KeyError.
        return rule

    # Longest keys first: alternation picks the first alternative that
    # matches, so a short key must not shadow a longer key it is a prefix of.
    keys = sorted(replace_dict, key=len, reverse=True)
    pattern = '|'.join(r'\b%s\b' % re.escape(key) for key in keys)

    def replace(match):
        # A callable replacement inserts the value literally, so backslashes
        # or group references in feature names are not interpreted.
        return replace_dict[match.group(0)]

    return re.sub(pattern, replace, rule)
10+
111
class Rule:
212
""" An object modelizing a logical rule and add factorization methods.
313
It is used to simplify rules and deduplicate them.
@@ -56,3 +66,4 @@ def __repr__(self):
5666
[feature, symbol, str(self.agg_dict[(feature, symbol)])])
5767
for feature, symbol in sorted(self.agg_dict.keys())
5868
])
69+

skrules/skope_rules.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
from sklearn.externals import six
1313
from sklearn.tree import _tree
1414

15-
from .rule import Rule
15+
from .rule import Rule, replace_feature_name
1616

1717
INTEGER_TYPES = (numbers.Integral, np.integer)
18-
18+
BASE_FEATURE_NAME = "__C__"
1919

2020
class SkopeRules(BaseEstimator):
2121
""" An easy-interpretable classifier optimizing simple logical rules.
@@ -249,11 +249,17 @@ def fit(self, X, y, sample_weight=None):
249249
self.estimators_samples_ = []
250250
self.estimators_features_ = []
251251

252-
# default columns names of the form ['c0', 'c1', ...]:
253-
feature_names_ = (self.feature_names if self.feature_names is not None
254-
else ['c' + x for x in
255-
np.arange(X.shape[1]).astype(str)])
252+
# default columns names :
253+
feature_names_ = [BASE_FEATURE_NAME + x for x in
254+
np.arange(X.shape[1]).astype(str)]
255+
if self.feature_names is not None:
256+
self.feature_dict_ = {BASE_FEATURE_NAME + str(i): feat
257+
for i, feat in enumerate(self.feature_names)}
258+
else:
259+
self.feature_dict_ = {BASE_FEATURE_NAME + str(i): feat
260+
for i, feat in enumerate(feature_names_)}
256261
self.feature_names_ = feature_names_
262+
257263
clfs = []
258264
regs = []
259265

@@ -356,6 +362,10 @@ def fit(self, X, y, sample_weight=None):
356362
for rule in
357363
[Rule(r, args=args) for r, args in rules_]]
358364

365+
366+
367+
368+
359369
# keep only rules verifying precision_min and recall_min:
360370
for rule, score in rules_:
361371
if score[0] >= self.precision_min and score[1] >= self.recall_min:
@@ -377,7 +387,14 @@ def fit(self, X, y, sample_weight=None):
377387
# Deduplicate the rule using semantic tree
378388
if self.max_depth_duplication is not None:
379389
self.rules_ = self.deduplicate(self.rules_)
390+
380391
self.rules_ = sorted(self.rules_, key=lambda x: - self.f1_score(x))
392+
self.rules_without_feature_names_ = self.rules_
393+
394+
# Replace generic feature names by real feature names
395+
self.rules_ = [(replace_feature_name(rule, self.feature_dict_), perf)
396+
for rule, perf in self.rules_]
397+
381398
return self
382399

383400
def predict(self, X):
@@ -432,7 +449,7 @@ def decision_function(self, X):
432449
% (X.shape[1], self.n_features_))
433450

434451
df = pandas.DataFrame(X, columns=self.feature_names_)
435-
selected_rules = self.rules_
452+
selected_rules = self.rules_without_feature_names_
436453

437454
scores = np.zeros(X.shape[0])
438455
for (r, w) in selected_rules:

skrules/tests/test_rule.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from sklearn.utils.testing import assert_equal, assert_not_equal
22

3-
from skrules import Rule
3+
from skrules import Rule, replace_feature_name
44

55

66
def test_rule():
@@ -53,3 +53,13 @@ def test_equals_rule():
5353

5454
rule3 = "a < 3.0 and a == a"
5555
assert_equal(rule3, str(Rule(rule3)))
56+
57+
58+
def test_replace_feature_name():
    """Generic column placeholders are swapped for the mapped names."""
    mapping = {"__C__0": "$b", "__C__1": "c(4)"}
    generic = "__C__0 <= 3 and __C__1 > 4"
    expected = "$b <= 3 and c(4) > 4"

    renamed = replace_feature_name(generic, replace_dict=mapping)
    assert_equal(renamed, expected)

0 commit comments

Comments
 (0)