Switch to side-by-side view

--- a
+++ b/AICare-baselines/losses/__init__.py
@@ -0,0 +1,20 @@
+import torch
+import torch.nn.functional as F
+
+from .multitask_loss import get_multitask_loss
+from .time_aware_loss import get_time_aware_loss
+
+
+def get_loss(y_pred, y_true, task, time_aware=False):
+    if task == "outcome":
+        loss = F.binary_cross_entropy(y_pred, y_true[:, 0])
+    elif task == "los":
+        loss = F.mse_loss(y_pred, y_true[:, 1])
+    elif task == "multitask":
+        loss = get_multitask_loss(y_pred[:,0], y_pred[:,1], y_true[:,0], y_true[:,1])
+
+    # If use time aware loss:
+    if task == "outcome" and time_aware:
+        loss = get_time_aware_loss(y_pred, y_true[:, 0], y_true[:, 1])
+
+    return loss