|
a |
|
b/src/callbacks.py |
|
|
1 |
from catalyst.dl.core import Callback, RunnerState, CallbackOrder |
|
|
2 |
from catalyst.dl.callbacks import CriterionCallback |
|
|
3 |
from catalyst.contrib.criterion import FocalLossBinary |
|
|
4 |
from catalyst.dl.utils.criterion import accuracy |
|
|
5 |
from catalyst.utils import get_activation_fn |
|
|
6 |
import torch |
|
|
7 |
import torch.nn as nn |
|
|
8 |
import numpy as np |
|
|
9 |
from typing import List |
|
|
10 |
|
|
|
11 |
import torch |
|
|
12 |
from catalyst.utils import get_activation_fn |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
class MultiTaskCriterionCallback(Callback): |
|
|
16 |
def __init__( |
|
|
17 |
self, |
|
|
18 |
input_seg_key: str = "targets", |
|
|
19 |
input_cls_key: str = "labels", |
|
|
20 |
output_seg_key: str = "logits", |
|
|
21 |
output_cls_key: str = "cls_logits", |
|
|
22 |
prefix: str = "loss", |
|
|
23 |
criterion_key: str = None, |
|
|
24 |
loss_key: str = None, |
|
|
25 |
multiplier: float = 1.0, |
|
|
26 |
loss_weights: List[float] = None, |
|
|
27 |
): |
|
|
28 |
super(MultiTaskCriterionCallback, self).__init__(CallbackOrder.Criterion) |
|
|
29 |
self.input_seg_key = input_seg_key |
|
|
30 |
self.input_cls_key = input_cls_key |
|
|
31 |
self.output_seg_key = output_seg_key |
|
|
32 |
self.output_cls_key = output_cls_key |
|
|
33 |
self.prefix = prefix |
|
|
34 |
self.criterion_key = criterion_key |
|
|
35 |
self.loss_key = loss_key |
|
|
36 |
self.multiplier = multiplier |
|
|
37 |
self.loss_weights = loss_weights |
|
|
38 |
|
|
|
39 |
self.criterion_cls = nn.BCEWithLogitsLoss() |
|
|
40 |
|
|
|
41 |
def _add_loss_to_state(self, state: RunnerState, loss): |
|
|
42 |
if self.loss_key is None: |
|
|
43 |
if state.loss is not None: |
|
|
44 |
if isinstance(state.loss, list): |
|
|
45 |
state.loss.append(loss) |
|
|
46 |
else: |
|
|
47 |
state.loss = [state.loss, loss] |
|
|
48 |
else: |
|
|
49 |
state.loss = loss |
|
|
50 |
else: |
|
|
51 |
if state.loss is not None: |
|
|
52 |
assert isinstance(state.loss, dict) |
|
|
53 |
state.loss[self.loss_key] = loss |
|
|
54 |
else: |
|
|
55 |
state.loss = {self.loss_key: loss} |
|
|
56 |
|
|
|
57 |
def _compute_loss(self, state: RunnerState, criterion): |
|
|
58 |
output_seg = state.output[self.output_seg_key] |
|
|
59 |
output_cls = state.output[self.output_cls_key] |
|
|
60 |
input_seg = state.input[self.input_seg_key] |
|
|
61 |
input_cls = state.input[self.input_cls_key] |
|
|
62 |
|
|
|
63 |
# assert len(self.loss_weights) == len(outputs) |
|
|
64 |
loss = 0 |
|
|
65 |
|
|
|
66 |
# Segmentation loss |
|
|
67 |
loss += criterion(output_seg, input_seg) * self.loss_weights[0] |
|
|
68 |
# Classification loss |
|
|
69 |
loss += self.criterion_cls(output_cls, input_cls) * self.loss_weights[1] |
|
|
70 |
|
|
|
71 |
return loss |
|
|
72 |
|
|
|
73 |
def on_stage_start(self, state: RunnerState): |
|
|
74 |
assert state.criterion is not None |
|
|
75 |
|
|
|
76 |
def on_batch_end(self, state: RunnerState): |
|
|
77 |
criterion = state.get_key( |
|
|
78 |
key="criterion", inner_key=self.criterion_key |
|
|
79 |
) |
|
|
80 |
|
|
|
81 |
loss = self._compute_loss(state, criterion) * self.multiplier |
|
|
82 |
|
|
|
83 |
state.metrics.add_batch_value( |
|
|
84 |
metrics_dict={ |
|
|
85 |
self.prefix: loss.item(), |
|
|
86 |
} |
|
|
87 |
) |
|
|
88 |
|
|
|
89 |
self._add_loss_to_state(state, loss) |