--- a
+++ b/src/callbacks.py
@@ -0,0 +1,89 @@
+from catalyst.dl.core import Callback, RunnerState, CallbackOrder
+from catalyst.dl.callbacks import CriterionCallback
+from catalyst.contrib.criterion import FocalLossBinary
+from catalyst.dl.utils.criterion import accuracy
+from catalyst.utils import get_activation_fn
+import torch
+import torch.nn as nn
+import numpy as np
+from typing import List
+
+import torch
+from catalyst.utils import get_activation_fn
+
+
+class MultiTaskCriterionCallback(Callback):
+    """Criterion callback summing a segmentation and a classification loss.
+
+    The segmentation criterion is fetched with ``state.get_key`` (optionally
+    narrowed by ``criterion_key``); the classification loss is
+    ``nn.BCEWithLogitsLoss``. ``loss_weights`` is ``[seg_weight, cls_weight]``
+    and defaults to ``[1.0, 1.0]``; the sum is then scaled by ``multiplier``.
+    """
+
+    def __init__(
+        self,
+        input_seg_key: str = "targets",
+        input_cls_key: str = "labels",
+        output_seg_key: str = "logits",
+        output_cls_key: str = "cls_logits",
+        prefix: str = "loss",
+        criterion_key: str = None,
+        loss_key: str = None,
+        multiplier: float = 1.0,
+        loss_weights: List[float] = None,
+    ):
+        super().__init__(CallbackOrder.Criterion)
+        # None used to crash on the first batch (indexing None); use 1:1.
+        if loss_weights is None:
+            loss_weights = [1.0, 1.0]
+        if len(loss_weights) < 2:
+            raise ValueError("loss_weights needs [seg_weight, cls_weight]")
+        self.input_seg_key = input_seg_key
+        self.input_cls_key = input_cls_key
+        self.output_seg_key = output_seg_key
+        self.output_cls_key = output_cls_key
+        self.prefix = prefix
+        self.criterion_key = criterion_key
+        self.loss_key = loss_key
+        self.multiplier = multiplier
+        self.loss_weights = loss_weights
+        self.criterion_cls = nn.BCEWithLogitsLoss()
+
+    def _add_loss_to_state(self, state: RunnerState, loss):
+        """Store ``loss`` on ``state.loss`` without dropping earlier losses."""
+        if self.loss_key is not None:
+            # Keyed mode: ``state.loss`` is a dict of named losses.
+            if state.loss is None:
+                state.loss = {self.loss_key: loss}
+            else:
+                assert isinstance(state.loss, dict)
+                state.loss[self.loss_key] = loss
+        elif state.loss is None:
+            state.loss = loss
+        elif isinstance(state.loss, list):
+            state.loss.append(loss)
+        else:
+            # Promote a single earlier loss to a list holding both.
+            state.loss = [state.loss, loss]
+
+    def _compute_loss(self, state: RunnerState, criterion):
+        """Return the weighted sum of segmentation and classification loss."""
+        output_seg = state.output[self.output_seg_key]
+        output_cls = state.output[self.output_cls_key]
+        input_seg = state.input[self.input_seg_key]
+        input_cls = state.input[self.input_cls_key]
+        seg_loss = criterion(output_seg, input_seg) * self.loss_weights[0]
+        cls_loss = self.criterion_cls(output_cls, input_cls)
+        return seg_loss + cls_loss * self.loss_weights[1]
+
+    def on_stage_start(self, state: RunnerState):
+        """Check that the runner state provides a criterion."""
+        assert state.criterion is not None
+
+    def on_batch_end(self, state: RunnerState):
+        """Compute the batch loss, record it as a metric, and store it."""
+        criterion = state.get_key(key="criterion", inner_key=self.criterion_key)
+        loss = self._compute_loss(state, criterion) * self.multiplier
+        state.metrics.add_batch_value(metrics_dict={self.prefix: loss.item()})
+        self._add_loss_to_state(state, loss)