a b/AICare-baselines/losses/time_aware_loss.py
1
import torch
2
from torch import nn
3
4
5
class TimeAwareLoss(nn.Module):
    """Length-of-stay-aware binary cross-entropy loss.

    The per-sample BCE is multiplied by ``exp(-decay_rate * los_true)`` and
    averaged.  A term ``reward_factor * mean(los_true * |outcome_true -
    outcome_pred|)`` is then subtracted.  Finally the result is clamped to be
    non-negative.

    Args:
        decay_rate: Rate of the exponential decay applied to ``los_true`` when
            weighting each sample's BCE.
        reward_factor: Multiplier of the subtracted LOS-times-error term.
    """

    def __init__(self, decay_rate=0.1, reward_factor=0.1):
        super().__init__()
        # Element-wise BCE so each sample can be re-weighted before reduction.
        self.bce = nn.BCELoss(reduction='none')
        self.decay_rate = decay_rate
        self.reward_factor = reward_factor

    def forward(self, outcome_pred, outcome_true, los_true):
        """Compute the time-aware loss for a batch.

        Args:
            outcome_pred: Predicted probabilities.  They are fed to
                ``nn.BCELoss``, so they must lie in [0, 1] and share
                ``outcome_true``'s shape.
            outcome_true: Binary targets, same shape as ``outcome_pred``.
            los_true: Length-of-stay values.  If it has one value per
                prediction but a different shape (e.g. ``(N,)`` vs
                ``(N, 1)``), it is reshaped to ``outcome_pred``'s shape.
                Otherwise it is broadcast as usual.

        Returns:
            A 0-dim tensor holding the loss, clamped to be >= 0.
        """
        # Without this alignment, (N,) * (N, 1) broadcasts to an (N, N) outer
        # product and both means below average over the wrong elements.
        if (los_true.shape != outcome_pred.shape
                and los_true.numel() == outcome_pred.numel()):
            los_true = los_true.reshape(outcome_pred.shape)

        # Exponential decay: larger los_true -> smaller weight on that sample.
        los_weights = torch.exp(-self.decay_rate * los_true)
        loss_unreduced = self.bce(outcome_pred, outcome_true)

        # NOTE(review): this term is *subtracted*, so a larger error on a
        # sample with larger los_true lowers the loss.  Confirm this is the
        # intended "reward" direction.
        reward_term = (los_true * torch.abs(outcome_true - outcome_pred)).mean()
        loss = (loss_unreduced * los_weights).mean() - self.reward_factor * reward_term

        # Where the clamp is active the gradient through it is zero.
        return torch.clamp(loss, min=0)
20
21
def get_time_aware_loss(outcome_pred, outcome_true, los_true, *, decay_rate=0.1, reward_factor=0.1):
    """Compute ``TimeAwareLoss`` for one batch using a freshly built criterion.

    Args:
        outcome_pred: Predicted probabilities, passed to ``TimeAwareLoss``.
        outcome_true: Binary targets, passed to ``TimeAwareLoss``.
        los_true: Length-of-stay values, passed to ``TimeAwareLoss``.
        decay_rate: Forwarded to ``TimeAwareLoss``.  The default 0.1 equals
            the value previously hard-wired through the class default.
        reward_factor: Forwarded to ``TimeAwareLoss``.  The default 0.1
            likewise equals the previous behaviour.

    Returns:
        The loss tensor produced by ``TimeAwareLoss``.
    """
    time_aware_loss = TimeAwareLoss(decay_rate=decay_rate, reward_factor=reward_factor)
    return time_aware_loss(outcome_pred, outcome_true, los_true)
24
25
if __name__ == "__main__":
    # Smoke check: a single prediction of 0.1 for a positive (1.0) target.
    # NOTE(review): los_true is negative here, which makes the decay weight
    # exceed 1 — confirm this demo value is intentional.
    demo_inputs = (
        torch.tensor([0.1]),   # outcome_pred
        torch.tensor([1.]),    # outcome_true
        torch.tensor([-4.0]),  # los_true
    )
    print(get_time_aware_loss(*demo_inputs))