
Commit a88b22e

Merge pull request #106 from ImageMarkup/customizable-validation-metric

2 parents: be759b3 + 688d38c

File tree: 5 files changed (+76 / -16 lines)

isic_challenge_scoring/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
-from isic_challenge_scoring.classification import ClassificationScore
+from isic_challenge_scoring.classification import ClassificationScore, ValidationMetric
 from isic_challenge_scoring.segmentation import SegmentationScore
 from isic_challenge_scoring.types import ScoreException
 
-__all__ = ['ClassificationScore', 'SegmentationScore', 'ScoreException']
+__all__ = ['ClassificationScore', 'SegmentationScore', 'ScoreException', 'ValidationMetric']
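With this change, ValidationMetric is part of the package's public API, so callers can import it from the package root rather than from the classification submodule. A minimal sketch of the new import:

    # Sketch: the enum is now importable from the package root.
    from isic_challenge_scoring import ClassificationScore, ValidationMetric

    metric = ValidationMetric.AUC  # its value is the short string identifier 'auc'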

isic_challenge_scoring/classification.py

Lines changed: 51 additions & 8 deletions

@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
+import enum
 import pathlib
-from typing import Dict, TextIO, cast
+from typing import Dict, Optional, TextIO, cast
 
 import pandas as pd
 
@@ -12,6 +13,12 @@
 from isic_challenge_scoring.types import DataFrameDict, RocDict, Score, ScoreDict, SeriesDict
 
 
+class ValidationMetric(enum.Enum):
+    BALANCED_ACCURACY = 'balanced_accuracy'
+    AUC = 'auc'
+    AVERAGE_PRECISION = 'ap'
+
+
 @dataclass(init=False)
 class ClassificationScore(Score):
     per_category: pd.DataFrame
@@ -24,6 +31,7 @@ def __init__(
         truth_probabilities: pd.DataFrame,
         prediction_probabilities: pd.DataFrame,
         truth_weights: pd.DataFrame,
+        validation_metric: Optional[ValidationMetric] = None,
     ) -> None:
         categories = truth_probabilities.columns
 
@@ -61,9 +69,36 @@ def __init__(
         )
 
         self.overall = self.aggregate.at['balanced_accuracy']
-        self.validation = metrics.balanced_multiclass_accuracy(
-            truth_probabilities, prediction_probabilities, truth_weights.validation_weight
-        )
+
+        if validation_metric:
+            if validation_metric == ValidationMetric.BALANCED_ACCURACY:
+                self.validation = metrics.balanced_multiclass_accuracy(
+                    truth_probabilities, prediction_probabilities, truth_weights.validation_weight
+                )
+            elif validation_metric == ValidationMetric.AVERAGE_PRECISION:
+                per_category_ap = pd.Series(
+                    [
+                        metrics.average_precision(
+                            truth_probabilities[category],
+                            prediction_probabilities[category],
+                            truth_weights.validation_weight,
+                        )
+                        for category in categories
+                    ]
+                )
+                self.validation = per_category_ap.mean()
+            elif validation_metric == ValidationMetric.AUC:
+                per_category_auc = pd.Series(
+                    [
+                        metrics.auc(
+                            truth_probabilities[category],
+                            prediction_probabilities[category],
+                            truth_weights.validation_weight,
+                        )
+                        for category in categories
+                    ]
+                )
+                self.validation = per_category_auc.mean()
 
     @staticmethod
     def _category_score(
@@ -153,7 +188,10 @@ def to_dict(self, rocs: bool = True) -> ScoreDict:
 
     @classmethod
     def from_stream(
-        cls, truth_file_stream: TextIO, prediction_file_stream: TextIO
+        cls,
+        truth_file_stream: TextIO,
+        prediction_file_stream: TextIO,
+        validation_metric: Optional[ValidationMetric] = None,
     ) -> ClassificationScore:
         truth_probabilities, truth_weights = parse_truth_csv(truth_file_stream)
         categories = truth_probabilities.columns
@@ -164,16 +202,21 @@ def from_stream(
         sort_rows(truth_probabilities)
         sort_rows(prediction_probabilities)
 
-        score = cls(truth_probabilities, prediction_probabilities, truth_weights)
+        score = cls(truth_probabilities, prediction_probabilities, truth_weights, validation_metric)
         return score
 
     @classmethod
     def from_file(
-        cls, truth_file: pathlib.Path, prediction_file: pathlib.Path
+        cls,
+        truth_file: pathlib.Path,
+        prediction_file: pathlib.Path,
+        validation_metric: Optional[ValidationMetric] = None,
    ) -> ClassificationScore:
         with truth_file.open('r') as truth_file_stream, prediction_file.open(
             'r'
         ) as prediction_file_stream:
             return cls.from_stream(
-                cast(TextIO, truth_file_stream), cast(TextIO, prediction_file_stream)
+                cast(TextIO, truth_file_stream),
+                cast(TextIO, prediction_file_stream),
+                validation_metric,
             )
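Taken together, the new keyword argument threads from from_file through from_stream into __init__, which computes score.validation as balanced accuracy, mean per-category AUC, or mean per-category average precision; when the argument is omitted, the validation score is simply not computed. A hedged usage sketch (the CSV paths here are hypothetical, not part of the repository):

    import pathlib

    from isic_challenge_scoring import ClassificationScore, ValidationMetric

    # Hypothetical paths; real truth/prediction CSVs come from the challenge data.
    truth_file = pathlib.Path('truth.csv')
    prediction_file = pathlib.Path('prediction.csv')

    # Compute the validation score as the mean per-category AUC.
    score = ClassificationScore.from_file(
        truth_file,
        prediction_file,
        validation_metric=ValidationMetric.AUC,
    )
    print(score.overall)     # balanced multiclass accuracy, unchanged by this PR
    print(score.validation)  # mean AUC across categories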

isic_challenge_scoring/types.py

Lines changed: 3 additions & 3 deletions

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 
 class ScoreException(Exception):
@@ -9,13 +9,13 @@ class ScoreException(Exception):
 SeriesDict = Dict[str, float]
 DataFrameDict = Dict[str, SeriesDict]
 RocDict = Dict[str, List[float]]
-ScoreDict = Dict[str, Union[float, SeriesDict, DataFrameDict, Dict[str, RocDict]]]
+ScoreDict = Dict[str, Union[float, Optional[float], SeriesDict, DataFrameDict, Dict[str, RocDict]]]
 
 
 @dataclass
 class Score:
     overall: float
-    validation: float
+    validation: Optional[float]
 
     def to_string(self) -> str:
         output = f'Overall: {self.overall}\n'
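Because Score.validation is now Optional[float], and ClassificationScore.__init__ only assigns it when a validation_metric is passed, downstream code should read the attribute defensively. A minimal sketch (validation_summary is an illustrative helper, not part of the package):

    from isic_challenge_scoring.types import Score


    def validation_summary(score: Score) -> str:
        # When no validation_metric was requested, __init__ skips the
        # assignment, so fall back to None instead of risking AttributeError.
        validation = getattr(score, 'validation', None)
        if validation is None:
            return 'Validation: not computed'
        return f'Validation: {validation}'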
Lines changed: 2 additions & 2 deletions

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6cc39a6c579a5bd90639705c6a1ad4a326cb3df0626682be1cc693ec30a20ff6
-size 420666
+oid sha256:ece19dc406aa5293cd0185b6872b00149bea68a2106a396343b4db326642dd28
+size 490801

tests/test_classification.py

Lines changed: 18 additions & 1 deletion

@@ -1,7 +1,24 @@
-from isic_challenge_scoring.classification import ClassificationScore
+import pytest
+
+from isic_challenge_scoring.classification import ClassificationScore, ValidationMetric
 
 
 def test_score(classification_truth_file_path, classification_prediction_file_path):
     assert ClassificationScore.from_file(
         classification_truth_file_path, classification_prediction_file_path
     )
+
+
+@pytest.mark.parametrize(
+    'validation_metric',
+    [ValidationMetric.AUC, ValidationMetric.BALANCED_ACCURACY, ValidationMetric.AVERAGE_PRECISION],
+)
+def test_score_validation_metric(
+    classification_truth_file_path, classification_prediction_file_path, validation_metric
+):
+    score = ClassificationScore.from_file(
+        classification_truth_file_path,
+        classification_prediction_file_path,
+        validation_metric=validation_metric,
+    )
+    assert isinstance(score.validation, float)
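The tests lean on the classification_truth_file_path and classification_prediction_file_path fixtures, which are defined outside this diff (presumably in the suite's conftest.py). A hypothetical sketch of what such fixtures might look like; the data directory and file names are assumptions, not taken from the repository:

    import pathlib

    import pytest

    DATA_DIR = pathlib.Path(__file__).parent / 'data'  # hypothetical test-data location


    @pytest.fixture
    def classification_truth_file_path() -> pathlib.Path:
        # Hypothetical file name; the real fixture lives in the repo's conftest.
        return DATA_DIR / 'classification_truth.csv'


    @pytest.fixture
    def classification_prediction_file_path() -> pathlib.Path:
        return DATA_DIR / 'classification_prediction.csv'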
