--- a +++ b/adpkd_segmentation/utils/criterions.py @@ -0,0 +1,31 @@ +"""Dictionary of criterions""" + + +class LossesMetrics: + """ Produces function to generate dict of keys: losses/metrics for batch""" + + def __init__(self, criterions_dict, requires_extra_info=None): + """ + Args: + criterions_dict {dict} -- key (str) : criterion_losses + requires_extra_info {list or None}, keys in `criterions_dict` + noting losses which require extra information + """ + self.criterions_dict = criterions_dict + if requires_extra_info is None: + self.requires_extra_info = set() + else: + self.requires_extra_info = set(requires_extra_info) + + def __call__(self): + def losses_dict(y_hat, y, extra_dict=None): + res = {} + for c_name, criterion in self.criterions_dict.items(): + # optional info for some criterions + if c_name in self.requires_extra_info: + res[c_name] = criterion(y_hat, y, extra_dict) + else: + res[c_name] = criterion(y_hat, y) + return res + + return losses_dict