Switch to side-by-side view

--- a
+++ b/AICare-baselines/losses/time_aware_loss.py
@@ -0,0 +1,29 @@
+import torch
+from torch import nn
+
+
+class TimeAwareLoss(nn.Module):
+    def __init__(self, decay_rate=0.1, reward_factor=0.1):
+        super(TimeAwareLoss, self).__init__()
+        self.bce = nn.BCELoss(reduction='none')
+        self.decay_rate = decay_rate
+        self.reward_factor = reward_factor
+
+    def forward(self, outcome_pred, outcome_true, los_true):
+        los_weights = torch.exp(-self.decay_rate * los_true)  # Exponential decay
+        loss_unreduced = self.bce(outcome_pred, outcome_true)
+
+        reward_term = (los_true * torch.abs(outcome_true - outcome_pred)).mean()  # Reward term
+        loss = (loss_unreduced * los_weights).mean()-self.reward_factor * reward_term  # Weighted loss
+        
+        return torch.clamp(loss, min=0)
+
+def get_time_aware_loss(outcome_pred, outcome_true, los_true):
+    time_aware_loss = TimeAwareLoss()
+    return time_aware_loss(outcome_pred, outcome_true, los_true)
+
+if __name__ == "__main__":
+    outcome_pred = torch.tensor([0.1])
+    outcome_true = torch.tensor([1.])
+    los_true = torch.tensor([-4.0])
+    print(get_time_aware_loss(outcome_pred, outcome_true, los_true))