|
a |
|
b/AICare-baselines/losses/multitask_loss.py |
|
|
1 |
import torch |
|
|
2 |
from torch import nn |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
class MultitaskLoss(nn.Module):
    """Weighted sum of a BCE outcome loss and an MSE ``los`` loss.

    The loss is ``alpha[0] * BCE(outcome_pred, outcome) + alpha[1] * MSE(los_pred, los)``,
    where ``alpha`` is an ``nn.Parameter`` initialised to ones. A freshly
    constructed instance therefore returns the plain sum of the two losses;
    ``alpha`` only changes if the caller passes this module's parameters to an
    optimizer.

    NOTE(review): ``alpha`` is unconstrained and enters the loss linearly.
    Both BCE and MSE are non-negative, so the gradient w.r.t. each weight is
    non-negative and gradient descent keeps shrinking the weights, eventually
    past zero, at which point the total loss is unbounded below. Confirm
    whether callers actually train ``alpha``; if they do, a bounded
    parameterisation (e.g. learned log-variance / uncertainty weighting) is
    probably what was intended.

    Args:
        task_num: Length of the ``alpha`` weight vector. ``forward`` always
            combines exactly two tasks via ``alpha[0]`` and ``alpha[1]``, so
            this must be at least 2; entries past index 1 are unused.

    Raises:
        ValueError: If ``task_num`` is less than 2.
    """

    def __init__(self, task_num=2):
        # Fail fast: forward indexes alpha[0] and alpha[1], so a shorter
        # weight vector would otherwise surface later as an IndexError.
        if task_num < 2:
            raise ValueError(
                "task_num must be at least 2 because forward weights two "
                f"tasks, got {task_num}"
            )
        super().__init__()
        self.task_num = task_num
        # One learnable weight per task; all start at 1 (equal weighting).
        self.alpha = nn.Parameter(torch.ones(task_num))
        # Both criteria use the default reduction='mean', so each yields a 0-dim tensor.
        self.mse = nn.MSELoss()
        # nn.BCELoss applies no sigmoid: predictions must already be probabilities.
        self.bce = nn.BCELoss()

    def forward(self, outcome_pred, los_pred, outcome, los):
        """Combine the two task losses using the learnable weights.

        Args:
            outcome_pred: Predicted probabilities for the outcome task; values
                must lie in [0, 1] because ``nn.BCELoss`` expects probabilities.
            los_pred: Predictions for the ``los`` regression task.
            outcome: Outcome targets in [0, 1], same shape as ``outcome_pred``
                (``nn.BCELoss`` rejects a shape mismatch).
            los: Regression targets for ``los_pred``. A shape mismatch is
                broadcast by ``nn.MSELoss`` with only a warning, so keep the
                shapes equal.

        Returns:
            0-dim tensor ``alpha[0] * bce + alpha[1] * mse``.
        """
        outcome_loss = self.bce(outcome_pred, outcome)
        los_loss = self.mse(los_pred, los)
        return outcome_loss * self.alpha[0] + los_loss * self.alpha[1]
|
|
17 |
|
|
|
18 |
def get_multitask_loss(outcome_pred, los_pred, outcome, los):
    """Return the unweighted multitask loss ``BCE(outcome_pred, outcome) + MSE(los_pred, los)``.

    This is exactly the value a freshly constructed ``MultitaskLoss`` returns,
    because its learnable weights start at ones. Constructing such a module
    per call would create new weights every time, so they could never be
    learned. Computing the sum directly yields the same value without
    allocating a throwaway parameter or routing autograd through it. Use
    ``MultitaskLoss`` itself, with its parameters registered in an optimizer,
    when learnable task weights are wanted.

    Args:
        outcome_pred: Predicted probabilities for the outcome task; values
            must lie in [0, 1] (binary cross-entropy applies no sigmoid).
        los_pred: Predictions for the ``los`` regression task.
        outcome: Outcome targets in [0, 1], same shape as ``outcome_pred``.
        los: Regression targets, expected to match the shape of ``los_pred``.

    Returns:
        0-dim tensor: mean-reduced BCE plus mean-reduced MSE.
    """
    # Same defaults as nn.BCELoss() / nn.MSELoss(): no class weights, reduction='mean'.
    outcome_loss = nn.functional.binary_cross_entropy(outcome_pred, outcome)
    los_loss = nn.functional.mse_loss(los_pred, los)
    return outcome_loss + los_loss