[0f1df3]: / AICare-baselines / losses / __init__.py

Download this file

21 lines (15 with data), 632 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
import torch
import torch.nn.functional as F
from .multitask_loss import get_multitask_loss
from .time_aware_loss import get_time_aware_loss
def get_loss(y_pred, y_true, task, time_aware=False):
    """Return the training loss for the requested prediction task.

    Args:
        y_pred: Model predictions. For ``"outcome"`` they are fed directly to
            ``F.binary_cross_entropy``, so they must already be probabilities
            in [0, 1] and match the shape of ``y_true[:, 0]``. For
            ``"multitask"`` columns 0 and 1 are treated as the outcome and
            ``"los"`` predictions respectively.
        y_true: 2-D target tensor. Column 0 is the outcome label (the BCE
            target); column 1 is the ``"los"`` regression target (presumably
            length of stay -- TODO confirm against the data pipeline).
        task: One of ``"outcome"``, ``"los"`` or ``"multitask"``.
        time_aware: Only honoured when ``task == "outcome"``; replaces plain
            BCE with ``get_time_aware_loss``. Ignored for the other tasks.

    Returns:
        A loss tensor as produced by the selected loss function.

    Raises:
        ValueError: If ``task`` is not one of the supported values.
    """
    if task == "outcome":
        if time_aware:
            # The time-aware variant additionally consumes column 1 of the
            # targets; checked first so the plain BCE is not computed and
            # then discarded.
            return get_time_aware_loss(y_pred, y_true[:, 0], y_true[:, 1])
        return F.binary_cross_entropy(y_pred, y_true[:, 0])
    if task == "los":
        return F.mse_loss(y_pred, y_true[:, 1])
    if task == "multitask":
        return get_multitask_loss(
            y_pred[:, 0], y_pred[:, 1], y_true[:, 0], y_true[:, 1]
        )
    # Previously an unknown task fell through to `return loss` with `loss`
    # unbound, surfacing as an UnboundLocalError; fail explicitly instead.
    raise ValueError(
        f"Unknown task {task!r}; expected 'outcome', 'los' or 'multitask'"
    )