a b/AICare-baselines/losses/multitask_loss.py
1
import torch
2
from torch import nn
3
4
5
class MultitaskLoss(nn.Module):
    """Weighted sum of an outcome BCE loss and a ``los`` MSE regression loss.

    ``forward`` returns
    ``alpha[0] * BCE(outcome_pred, outcome) + alpha[1] * MSE(los_pred, los)``,
    where ``alpha`` is an ``nn.Parameter`` initialised to ones.

    NOTE(review): ``alpha`` is trainable but unconstrained and unregularised.
    If it is passed to an optimiser together with the model parameters,
    gradient descent can push it towards zero or below, which makes the loss
    unbounded below. Confirm whether ``alpha`` is meant to be optimised.

    Args:
        task_num: Number of entries in ``alpha``. ``forward`` only uses the
            first two, so this must be at least 2; extra entries are unused.

    Raises:
        ValueError: If ``task_num`` is less than 2.
    """

    def __init__(self, task_num=2):
        super().__init__()
        if task_num < 2:
            # forward() indexes alpha[0] and alpha[1]; fail at construction
            # instead of with an IndexError on the first forward pass.
            raise ValueError(f"task_num must be at least 2, got {task_num}")
        self.task_num = task_num
        self.alpha = nn.Parameter(torch.ones(task_num))
        self.mse = nn.MSELoss()
        self.bce = nn.BCELoss()

    def forward(self, outcome_pred, los_pred, outcome, los):
        """Return the alpha-weighted sum of the outcome BCE and ``los`` MSE.

        Args:
            outcome_pred: Predicted outcome probabilities. ``nn.BCELoss``
                requires values in [0, 1] and the same shape as ``outcome``.
            los_pred: Regression prediction (presumably length of stay).
            outcome: Outcome targets for the BCE term (float tensor).
            los: Regression targets. NOTE(review): ``nn.MSELoss`` broadcasts
                (with a warning) when this shape differs from ``los_pred``,
                e.g. (B, 1) vs (B,); confirm that the caller's shapes match.

        Returns:
            A 0-dim tensor, because both losses use the default mean
            reduction.
        """
        outcome_loss = self.bce(outcome_pred, outcome)
        los_loss = self.mse(los_pred, los)
        return outcome_loss * self.alpha[0] + los_loss * self.alpha[1]
17
18
def get_multitask_loss(outcome_pred, los_pred, outcome, los):
    """Return ``BCE(outcome_pred, outcome) + MSE(los_pred, los)``.

    This equals ``MultitaskLoss`` evaluated with its initial weights (all
    ones). It is computed with the functional API on purpose: a
    ``MultitaskLoss`` constructed on every call would recreate ``alpha`` as
    ones each time, so its weights could never be learned. Per-call
    construction would also only add a module allocation and a throwaway
    CPU parameter to the autograd graph.

    Args:
        outcome_pred: Predicted outcome probabilities. Binary cross-entropy
            requires values in [0, 1] and the same shape as ``outcome``.
        los_pred: Regression prediction (presumably length of stay).
        outcome: Outcome targets for the BCE term (float tensor).
        los: Regression targets. NOTE(review): ``mse_loss`` broadcasts (with
            a warning) when this shape differs from ``los_pred``; confirm
            that the caller's shapes match.

    Returns:
        A 0-dim tensor, because both terms use the default mean reduction.
    """
    outcome_loss = nn.functional.binary_cross_entropy(outcome_pred, outcome)
    los_loss = nn.functional.mse_loss(los_pred, los)
    return outcome_loss + los_loss