14
14
import logging
15
15
import functools
16
16
17
+ from inspect import getcallargs
18
+
17
19
import numpy as np
18
20
19
21
from sklearn .metrics .classification import (_check_targets , _prf_divide ,
22
24
from sklearn .utils .fixes import bincount
23
25
from sklearn .utils .multiclass import unique_labels
24
26
27
+ try :
28
+ from inspect import signature
29
+ except ImportError :
30
+ from sklearn .externals .funcsigs import signature
31
+
32
+
25
33
LOGGER = logging .getLogger (__name__ )
26
34
27
35
@@ -563,10 +571,10 @@ def geometric_mean_score(y_true,
563
571
564
572
565
573
def make_index_balanced_accuracy (alpha = 0.1 , squared = True ):
566
- """Balance any scoring function using the indexed balanced accuracy
574
+ """Balance any scoring function using the index balanced accuracy
567
575
568
576
This factory function wraps a scoring function to express it as the
569
- indexed balanced accuracy (IBA). You need to use this function to
577
+ index balanced accuracy (IBA). You need to use this function to
570
578
decorate any scoring function.
571
579
572
580
Parameters
@@ -582,7 +590,7 @@ def make_index_balanced_accuracy(alpha=0.1, squared=True):
582
590
-------
583
591
iba_scoring_func : callable,
584
592
Returns the decorated scoring metric, which will automatically compute
585
- the indexed balanced accuracy.
593
+ the index balanced accuracy.
586
594
587
595
Examples
588
596
--------
@@ -603,21 +611,16 @@ def compute_score(*args, **kwargs):
603
611
# Square if desired
604
612
if squared :
605
613
_score = np .power (_score , 2 )
606
- # args will contain the y_pred and y_true
607
- # kwargs will contain the other parameters
608
- labels = kwargs .get ('labels' , None )
609
- pos_label = kwargs .get ('pos_label' , 1 )
610
- average = kwargs .get ('average' , 'binary' )
611
- sample_weight = kwargs .get ('sample_weight' , None )
612
- # Compute the sensitivity and specificity
613
- dict_sen_spe = {
614
- 'labels' : labels ,
615
- 'pos_label' : pos_label ,
616
- 'average' : average ,
617
- 'sample_weight' : sample_weight
618
- }
619
- sen , spe , _ = sensitivity_specificity_support (* args ,
620
- ** dict_sen_spe )
614
+ # Create the list of tags
615
+ tags_scoring_func = getcallargs (scoring_func , * args , ** kwargs )
616
+ # Get the signature of the sens/spec function
617
+ sens_spec_sig = signature (sensitivity_specificity_support )
618
+ # Filter the inputs required by the sens/spec function
619
+ tags_sens_spec = sens_spec_sig .bind (** tags_scoring_func )
620
+ # Call the sens/spec function
621
+ sen , spe , _ = sensitivity_specificity_support (
622
+ * tags_sens_spec .args ,
623
+ ** tags_sens_spec .kwargs )
621
624
# Compute the dominance
622
625
dom = sen - spe
623
626
return (1. + alpha * dom ) * _score
@@ -640,7 +643,7 @@ def classification_report_imbalanced(y_true,
640
643
Specific metrics have been proposed to evaluate the classification
641
644
performed on imbalanced dataset. This report compiles the
642
645
state-of-the-art metrics: precision/recall/specificity, geometric
643
- mean, and indexed balanced accuracy of the
646
+ mean, and index balanced accuracy of the
644
647
geometric mean.
645
648
646
649
Parameters
@@ -674,7 +677,7 @@ def classification_report_imbalanced(y_true,
674
677
-------
675
678
report : string
676
679
Text summary of the precision, recall, specificity, geometric mean,
677
- and indexed balanced accuracy.
680
+ and index balanced accuracy.
678
681
679
682
Examples
680
683
--------
@@ -746,7 +749,7 @@ class 2 1.00 0.67 1.00 0.80 0.82 0.69\
746
749
labels = labels ,
747
750
average = None ,
748
751
sample_weight = sample_weight )
749
- # Indexed balanced accuracy
752
+ # Index balanced accuracy
750
753
iba_gmean = make_index_balanced_accuracy (
751
754
alpha = alpha , squared = True )(geometric_mean_score )
752
755
iba = iba_gmean (
0 commit comments