@@ -45,13 +45,13 @@ class BorutaPy(BaseEstimator, SelectorMixin):
4545 crucial parameter. For more info, please read about the perc parameter.
4646 - Automatic tree number:
4747 Setting the n_estimator to 'auto' will calculate the number of trees
48- in each itartion based on the number of features under investigation.
48+ in each iteration based on the number of features under investigation.
4949 This way more trees are used when the training data has many features
5050 and fewer when most of the features have been rejected.
5151 - Ranking of features:
5252 After fitting BorutaPy it provides the user with ranking of features.
5353 Confirmed ones are 1, Tentatives are 2, and the rejected are ranked
54- starting from 3, based on their feautre importance history through
54+ starting from 3, based on their feature importance history through
5555 the iterations.
5656
5757 We highly recommend using pruned trees with a depth between 3-7.
@@ -140,7 +140,7 @@ class BorutaPy(BaseEstimator, SelectorMixin):
140140 support_weak_ : array of shape [n_features]
141141
142142 The mask of selected tentative features, which haven't gained enough
143- support during the max_iter number of iterations..
143+ support during the max_iter number of iterations.
144144
145145 ranking_ : array of shape [n_features]
146146
@@ -328,7 +328,7 @@ def _fit(self, X, y):
328328
329329 # set n_estimators
330330         if self.n_estimators != 'auto':
331-             self.estimator.set_params(n_estimators=self.n_estimators)
331+             self._set_n_estimators(self.n_estimators)
332332
333333 # main feature selection loop
334334         while np.any(dec_reg == 0) and _iter < self.max_iter:
@@ -337,7 +337,7 @@ def _fit(self, X, y):
337337 # number of features that aren't rejected
338338             not_rejected = np.where(dec_reg >= 0)[0].shape[0]
339339             n_tree = self._get_tree_num(not_rejected)
340-             self.estimator.set_params(n_estimators=n_tree)
340+             self._set_n_estimators(n_estimators=n_tree)
341341
342342 # make sure we start with a new tree in each iteration
343343             if self._is_lightgbm:
@@ -358,13 +358,15 @@ def _fit(self, X, y):
358358 # register which feature is more imp than the max of shadows
359359             hit_reg = self._assign_hits(hit_reg, cur_imp, imp_sha_max)
360360
361-             # based on hit_reg we check if a feature is doing better than
362-             # expected by chance
363-             dec_reg = self._do_tests(dec_reg, hit_reg, _iter)
361+             # Only test from the 5th round onward.
362+             if _iter > 4:
363+                 # based on hit_reg we check if a feature is doing better than
364+                 # expected by chance
365+                 dec_reg = self._do_tests(dec_reg, hit_reg, _iter)
364366 
365-             # print out confirmed features
366-             if self.verbose > 0 and _iter < self.max_iter:
367-                 self._print_results(dec_reg, _iter, 0)
367+                 # print out confirmed features
368+                 if self.verbose > 0 and _iter < self.max_iter:
369+                     self._print_results(dec_reg, _iter, 0)
368370             if _iter < self.max_iter:
369371                 _iter += 1
370372
@@ -454,6 +456,17 @@ def _transform(self, X, weak=False, return_df=False):
454456         X = X[:, indices]
455457         return X
456458
459+     def _set_n_estimators(self, n_estimators):
460+         try:
461+             self.estimator.set_params(n_estimators=n_estimators)
462+         except ValueError:
463+             raise ValueError(
464+                 f"The estimator {self.estimator} does not take the parameter "
465+                 "n_estimators. Use Random Forests or gradient boosting machines "
466+                 "instead."
467+             )
468+         return self
469+
457470     def _get_support_mask(self):
458471         check_is_fitted(self, 'support_')
459472         return self.support_
0 commit comments