Diff of /src/callbacks.py [000000] .. [95f789]

Switch to unified view

a b/src/callbacks.py
1
from catalyst.dl.core import Callback, RunnerState, CallbackOrder
2
from catalyst.dl.callbacks import CriterionCallback
3
from catalyst.contrib.criterion import FocalLossBinary
4
from catalyst.dl.utils.criterion import accuracy
5
from catalyst.utils import get_activation_fn
6
import torch
7
import torch.nn as nn
8
import numpy as np
9
from typing import List
10
11
import torch
12
from catalyst.utils import get_activation_fn
13
14
15
class MultiTaskCriterionCallback(Callback):
16
    def __init__(
17
        self,
18
        input_seg_key: str = "targets",
19
        input_cls_key: str = "labels",
20
        output_seg_key: str = "logits",
21
        output_cls_key: str = "cls_logits",
22
        prefix: str = "loss",
23
        criterion_key: str = None,
24
        loss_key: str = None,
25
        multiplier: float = 1.0,
26
        loss_weights: List[float] = None,
27
    ):
28
        super(MultiTaskCriterionCallback, self).__init__(CallbackOrder.Criterion)
29
        self.input_seg_key = input_seg_key
30
        self.input_cls_key = input_cls_key
31
        self.output_seg_key = output_seg_key
32
        self.output_cls_key = output_cls_key
33
        self.prefix = prefix
34
        self.criterion_key = criterion_key
35
        self.loss_key = loss_key
36
        self.multiplier = multiplier
37
        self.loss_weights = loss_weights
38
39
        self.criterion_cls = nn.BCEWithLogitsLoss()
40
41
    def _add_loss_to_state(self, state: RunnerState, loss):
42
        if self.loss_key is None:
43
            if state.loss is not None:
44
                if isinstance(state.loss, list):
45
                    state.loss.append(loss)
46
                else:
47
                    state.loss = [state.loss, loss]
48
            else:
49
                state.loss = loss
50
        else:
51
            if state.loss is not None:
52
                assert isinstance(state.loss, dict)
53
                state.loss[self.loss_key] = loss
54
            else:
55
                state.loss = {self.loss_key: loss}
56
57
    def _compute_loss(self, state: RunnerState, criterion):
58
        output_seg = state.output[self.output_seg_key]
59
        output_cls = state.output[self.output_cls_key]
60
        input_seg = state.input[self.input_seg_key]
61
        input_cls = state.input[self.input_cls_key]
62
63
        # assert len(self.loss_weights) == len(outputs)
64
        loss = 0
65
66
        # Segmentation loss
67
        loss += criterion(output_seg, input_seg) * self.loss_weights[0]
68
        # Classification loss
69
        loss += self.criterion_cls(output_cls, input_cls) * self.loss_weights[1]
70
71
        return loss
72
73
    def on_stage_start(self, state: RunnerState):
74
        assert state.criterion is not None
75
76
    def on_batch_end(self, state: RunnerState):
77
        criterion = state.get_key(
78
            key="criterion", inner_key=self.criterion_key
79
        )
80
81
        loss = self._compute_loss(state, criterion) * self.multiplier
82
83
        state.metrics.add_batch_value(
84
            metrics_dict={
85
                self.prefix: loss.item(),
86
            }
87
        )
88
89
        self._add_loss_to_state(state, loss)