[0f1df3]: / AICare-baselines / losses / multitask_loss.py

Download this file

20 lines (16 with data), 667 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch
from torch import nn
class MultitaskLoss(nn.Module):
    """Combine a binary-outcome loss and a regression loss into one scalar.

    The module owns a learnable weight vector ``alpha`` (one entry per task,
    all initialised to 1.0). ``forward`` returns
    ``alpha[0] * BCE(outcome_pred, outcome) + alpha[1] * MSE(los_pred, los)``.

    Args:
        task_num: Length of the ``alpha`` weight vector. ``forward`` reads
            only the first two entries, so any extra entries are unused.
    """

    def __init__(self, task_num=2):
        super().__init__()
        self.task_num = task_num
        # Learnable per-task weights; registered as a parameter so an
        # optimizer that receives this module's parameters can update them.
        self.alpha = nn.Parameter(torch.ones(task_num))
        self.mse = nn.MSELoss()
        self.bce = nn.BCELoss()

    def forward(self, outcome_pred, los_pred, outcome, los):
        """Return the weighted sum of the two task losses.

        Args:
            outcome_pred: Predictions for the binary task. They are passed
                straight to ``nn.BCELoss``, which expects probabilities in
                [0, 1] (not raw logits).
            los_pred: Predictions for the regression task (presumably
                length of stay, judging by the name -- confirm with caller).
            outcome: Targets for the binary task, same shape as
                ``outcome_pred``.
            los: Targets for the regression task.

        Returns:
            A scalar tensor: ``bce * alpha[0] + mse * alpha[1]``.
        """
        weighted_outcome = self.bce(outcome_pred, outcome) * self.alpha[0]
        weighted_los = self.mse(los_pred, los) * self.alpha[1]
        return weighted_outcome + weighted_los
def get_multitask_loss(outcome_pred, los_pred, outcome, los):
    """Compute the two-task loss with a freshly constructed ``MultitaskLoss``.

    Args:
        outcome_pred: Predictions for the binary task.
        los_pred: Predictions for the regression task.
        outcome: Targets for the binary task.
        los: Targets for the regression task.

    Returns:
        The scalar tensor produced by ``MultitaskLoss.forward``.

    NOTE(review): a new ``MultitaskLoss`` is built on every call, so its
    ``alpha`` weights are always at their initial value of 1.0 and are not
    shared between calls; the result is therefore BCE + MSE with unit
    weights. Callers that want the weights to be learned would need to keep
    a single ``MultitaskLoss`` instance instead -- confirm intent.
    """
    criterion = MultitaskLoss(task_num=2)
    combined = criterion(outcome_pred, los_pred, outcome, los)
    return combined