--- a +++ b/AICare-baselines/losses/__init__.py @@ -0,0 +1,20 @@ +import torch +import torch.nn.functional as F + +from .multitask_loss import get_multitask_loss +from .time_aware_loss import get_time_aware_loss + + +def get_loss(y_pred, y_true, task, time_aware=False): + if task == "outcome": + loss = F.binary_cross_entropy(y_pred, y_true[:, 0]) + elif task == "los": + loss = F.mse_loss(y_pred, y_true[:, 1]) + elif task == "multitask": + loss = get_multitask_loss(y_pred[:,0], y_pred[:,1], y_true[:,0], y_true[:,1]) + + # If use time aware loss: + if task == "outcome" and time_aware: + loss = get_time_aware_loss(y_pred, y_true[:, 0], y_true[:, 1]) + + return loss