[637b40]: / adpkd_segmentation / utils / criterions.py

Download this file

32 lines (26 with data), 1.1 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
"""Dictionary of criterions"""
class LossesMetrics:
    """Factory for a function computing a dict of losses/metrics for a batch.

    Calling an instance returns a function ``losses_dict(y_hat, y, extra_dict)``
    which evaluates every criterion in ``criterions_dict`` and returns the
    results keyed by criterion name.
    """

    def __init__(self, criterions_dict, requires_extra_info=None):
        """
        Args:
            criterions_dict {dict} -- key (str) : criterion callable.
                Each criterion is called as ``criterion(y_hat, y)`` or, when
                its key is listed in `requires_extra_info`, as
                ``criterion(y_hat, y, extra_dict)``.
            requires_extra_info {list, str or None} -- keys in
                `criterions_dict` noting criterions which require extra
                information. A single string is treated as one key.
        """
        self.criterions_dict = criterions_dict
        if requires_extra_info is None:
            self.requires_extra_info = set()
        elif isinstance(requires_extra_info, str):
            # set("loss") would yield {"l", "o", "s"}; a bare string is
            # one criterion name, not an iterable of names.
            self.requires_extra_info = {requires_extra_info}
        else:
            self.requires_extra_info = set(requires_extra_info)

    def __call__(self):
        """Return the function that evaluates all criterions on a batch."""

        def losses_dict(y_hat, y, extra_dict=None):
            """Evaluate every criterion and return {name: value}.

            `extra_dict` is forwarded only to criterions whose name is in
            `self.requires_extra_info`; the others receive `(y_hat, y)`.
            """
            needs_extra = self.requires_extra_info
            return {
                c_name: (
                    # optional info for some criterions
                    criterion(y_hat, y, extra_dict)
                    if c_name in needs_extra
                    else criterion(y_hat, y)
                )
                for c_name, criterion in self.criterions_dict.items()
            }

        return losses_dict