a b/AICare-baselines/losses/__init__.py
1
import torch
2
import torch.nn.functional as F
3
4
from .multitask_loss import get_multitask_loss
5
from .time_aware_loss import get_time_aware_loss
6
7
8
def get_loss(y_pred, y_true, task, time_aware=False):
    """Compute the training loss for the requested prediction task.

    Args:
        y_pred: Model predictions. For ``"outcome"`` they are fed directly to
            ``F.binary_cross_entropy`` (so presumably probabilities in
            [0, 1] — confirm against the model head). For ``"multitask"``,
            column 0 is the outcome prediction and column 1 the ``los``
            prediction.
        y_true: 2-D target tensor. Column 0 is the outcome label; column 1 is
            the ``los`` target (presumably length of stay — confirm).
        task: One of ``"outcome"``, ``"los"`` or ``"multitask"``.
        time_aware: When True and ``task == "outcome"``, use the time-aware
            loss (which additionally receives column 1 of ``y_true``) instead
            of plain binary cross-entropy. Ignored for the other tasks.

    Returns:
        The loss tensor produced by the selected criterion.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
    """
    if task == "outcome":
        if time_aware:
            # Dispatch straight to the time-aware loss; computing BCE first
            # and discarding it would be wasted work.
            return get_time_aware_loss(y_pred, y_true[:, 0], y_true[:, 1])
        return F.binary_cross_entropy(y_pred, y_true[:, 0])
    if task == "los":
        return F.mse_loss(y_pred, y_true[:, 1])
    if task == "multitask":
        return get_multitask_loss(
            y_pred[:, 0], y_pred[:, 1], y_true[:, 0], y_true[:, 1]
        )
    # Previously an unknown task fell through and crashed with an opaque
    # UnboundLocalError on the return statement.
    raise ValueError(
        f"Unknown task {task!r}; expected 'outcome', 'los' or 'multitask'."
    )