[d129b2]: / medicalbert / evaluator / validation_metric_factory.py

Download this file

64 lines (47 with data), 2.0 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class ROCValidation:
    """Keeps the checkpoint whose validation ROC score is the highest seen so far."""

    def __init__(self):
        self.best_score = 0
        self.best_checkpoint = None
        self.classifier = None

    def update(self, metrics, classifier, checkpoint):
        """Consider one evaluation; retain it if its ROC beats the current best.

        ``metrics`` is a DataFrame whose first row carries the evaluation
        columns (ROC, precision, accuracy, loss); only ROC drives selection.
        """
        candidate = metrics.iloc[0]['ROC']
        print("This score is {} - current best is {}".format(candidate, self.best_score))
        if candidate < self.best_score:
            return
        # >= comparison: a tie prefers the most recent checkpoint.
        self.best_score = candidate
        self.best_checkpoint = checkpoint
        self.classifier = classifier

    def get_checkpoint(self):
        return self.best_checkpoint

    def get_score(self):
        return self.best_score

    def get_classifier(self):
        return self.classifier
class LossValidation:
    """Tracks the checkpoint with the lowest validation loss seen so far."""

    def __init__(self):
        # Bug fix: this was initialised to 1.0, so a run whose losses never
        # dropped below 1.0 would record no checkpoint at all — the original
        # comment itself noted that reported losses can exceed 1 due to
        # rounding. Starting from +inf guarantees the first update is kept.
        self.best_score = float("inf")
        self.best_checkpoint = None
        self.classifier = None

    def update(self, metrics, classifier, checkpoint):
        """Record *checkpoint*/*classifier* when this loss is the lowest so far.

        Args:
            metrics: DataFrame whose first row holds the evaluation columns
                (roc, precision, accuracy, loss); only 'loss' is used here.
            classifier: the model associated with these metrics.
            checkpoint: identifier of the checkpoint that produced them.
        """
        loss_score = metrics.iloc[0]['loss']
        # <= so that a tie prefers the most recent checkpoint, matching the
        # >= convention used by ROCValidation.
        if loss_score <= self.best_score:
            self.best_score = loss_score
            self.best_checkpoint = checkpoint
            self.classifier = classifier

    def get_checkpoint(self):
        return self.best_checkpoint

    def get_score(self):
        return self.best_score

    def get_classifier(self):
        return self.classifier
class ValidationMetricFactory:
    """Registry of validation-metric strategies used for checkpoint selection."""

    def __init__(self):
        # Built-in validators, keyed by the name used in configuration.
        self._validators = {"roc": ROCValidation, "loss": LossValidation}

    def register_validator(self, name, validator):
        """Register a validator class under *name* (overwrites any existing entry)."""
        self._validators[name] = validator

    def make_validator(self, validator):
        """Instantiate and return the validator registered under *validator*.

        Raises:
            ValueError: if no validator is registered under that name.
        """
        vd = self._validators.get(validator)
        if not vd:
            # Bug fix: previously raised ValueError(format), passing the
            # builtin ``format`` function instead of an error message.
            raise ValueError("Unknown validation metric: {}".format(validator))
        return vd()