@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+import enum
 import pathlib
-from typing import Dict, TextIO, cast
+from typing import Dict, Optional, TextIO, cast
 
 import pandas as pd
 
@@ -12,6 +13,12 @@
 from isic_challenge_scoring.types import DataFrameDict, RocDict, Score, ScoreDict, SeriesDict
 
 
+class ValidationMetric(enum.Enum):
+    BALANCED_ACCURACY = 'balanced_accuracy'
+    AUC = 'auc'
+    AVERAGE_PRECISION = 'ap'
+
+
 @dataclass(init=False)
 class ClassificationScore(Score):
     per_category: pd.DataFrame
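The enum values double as stable string identifiers, so a caller (a CLI option, for instance) can round-trip a metric name by value. A quick illustration of standard enum.Enum behavior; the import path is assumed, not shown in this diff:

    >>> from isic_challenge_scoring.classification import ValidationMetric  # module path assumed
    >>> ValidationMetric('auc')
    <ValidationMetric.AUC: 'auc'>
    >>> ValidationMetric.AVERAGE_PRECISION.value
    'ap'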
@@ -24,6 +31,7 @@ def __init__(
         truth_probabilities: pd.DataFrame,
         prediction_probabilities: pd.DataFrame,
         truth_weights: pd.DataFrame,
+        validation_metric: Optional[ValidationMetric] = None,
     ) -> None:
         categories = truth_probabilities.columns
 
@@ -61,9 +69,36 @@ def __init__(
         )
 
         self.overall = self.aggregate.at['balanced_accuracy']
-        self.validation = metrics.balanced_multiclass_accuracy(
-            truth_probabilities, prediction_probabilities, truth_weights.validation_weight
-        )
+
+        if validation_metric:
+            if validation_metric == ValidationMetric.BALANCED_ACCURACY:
+                self.validation = metrics.balanced_multiclass_accuracy(
+                    truth_probabilities, prediction_probabilities, truth_weights.validation_weight
+                )
+            elif validation_metric == ValidationMetric.AVERAGE_PRECISION:
+                per_category_ap = pd.Series(
+                    [
+                        metrics.average_precision(
+                            truth_probabilities[category],
+                            prediction_probabilities[category],
+                            truth_weights.validation_weight,
+                        )
+                        for category in categories
+                    ]
+                )
+                self.validation = per_category_ap.mean()
+            elif validation_metric == ValidationMetric.AUC:
+                per_category_auc = pd.Series(
+                    [
+                        metrics.auc(
+                            truth_probabilities[category],
+                            prediction_probabilities[category],
+                            truth_weights.validation_weight,
+                        )
+                        for category in categories
+                    ]
+                )
+                self.validation = per_category_auc.mean()
 
     @staticmethod
     def _category_score(
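Both new branches compute a macro average: each category is scored one-vs-rest on its own truth/prediction columns (weighted by validation_weight), and the unweighted mean across categories becomes the validation score. A minimal sketch of the same aggregation using scikit-learn, assuming binary 0/1 truth columns; macro_auc and its arguments are illustrative, not part of this module:

    import numpy as np
    import pandas as pd
    from sklearn.metrics import roc_auc_score

    def macro_auc(truth: pd.DataFrame, pred: pd.DataFrame, weights: pd.Series) -> float:
        # Score each category independently, then average the per-category
        # AUCs without weighting, mirroring per_category_auc.mean() above.
        per_category = [
            roc_auc_score(truth[column], pred[column], sample_weight=weights)
            for column in truth.columns
        ]
        return float(np.mean(per_category))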
@@ -153,7 +188,10 @@ def to_dict(self, rocs: bool = True) -> ScoreDict:
 
     @classmethod
     def from_stream(
-        cls, truth_file_stream: TextIO, prediction_file_stream: TextIO
+        cls,
+        truth_file_stream: TextIO,
+        prediction_file_stream: TextIO,
+        validation_metric: Optional[ValidationMetric] = None,
     ) -> ClassificationScore:
         truth_probabilities, truth_weights = parse_truth_csv(truth_file_stream)
         categories = truth_probabilities.columns
@@ -164,16 +202,21 @@ def from_stream(
         sort_rows(truth_probabilities)
         sort_rows(prediction_probabilities)
 
-        score = cls(truth_probabilities, prediction_probabilities, truth_weights)
+        score = cls(truth_probabilities, prediction_probabilities, truth_weights, validation_metric)
         return score
 
     @classmethod
     def from_file(
-        cls, truth_file: pathlib.Path, prediction_file: pathlib.Path
+        cls,
+        truth_file: pathlib.Path,
+        prediction_file: pathlib.Path,
+        validation_metric: Optional[ValidationMetric] = None,
     ) -> ClassificationScore:
         with truth_file.open('r') as truth_file_stream, prediction_file.open(
             'r'
         ) as prediction_file_stream:
             return cls.from_stream(
-                cast(TextIO, truth_file_stream), cast(TextIO, prediction_file_stream)
+                cast(TextIO, truth_file_stream),
+                cast(TextIO, prediction_file_stream),
+                validation_metric,
             )
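For context, a minimal sketch of calling the extended from_file signature; the CSV paths are placeholders and the module path is assumed:

    import pathlib
    from isic_challenge_scoring.classification import ClassificationScore, ValidationMetric

    score = ClassificationScore.from_file(
        pathlib.Path('truth.csv'),        # ground-truth probabilities plus weight columns
        pathlib.Path('predictions.csv'),  # submitted prediction probabilities
        validation_metric=ValidationMetric.AUC,
    )
    print(score.overall, score.validation)

Note that validation_metric defaults to None, in which case the if validation_metric: block is skipped entirely and no validation score is computed.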