|
a |
|
b/AICare-baselines/losses/time_aware_loss.py |
|
|
1 |
import torch |
|
|
2 |
from torch import nn |
|
|
3 |
|
|
|
4 |
|
|
|
5 |
class TimeAwareLoss(nn.Module):
    """BCE reweighted by length of stay, minus an LOS-scaled absolute-error term.

    Per element, the unreduced binary cross-entropy is multiplied by
    ``exp(-decay_rate * los_true)`` and averaged. Then ``reward_factor`` times
    the mean of ``los_true * |outcome_true - outcome_pred|`` is subtracted.
    The result is clamped below at zero.

    Args:
        decay_rate: Coefficient of the exponential LOS weighting.
        reward_factor: Scale applied to the LOS-weighted absolute-error term.
    """

    def __init__(self, decay_rate=0.1, reward_factor=0.1):
        super().__init__()
        # Per-element losses are kept so they can be LOS-weighted before averaging.
        self.bce = nn.BCELoss(reduction='none')
        self.decay_rate = decay_rate
        self.reward_factor = reward_factor

    def forward(self, outcome_pred, outcome_true, los_true):
        """Return the clamped time-aware loss as a 0-dim tensor.

        Args:
            outcome_pred: Predictions fed to ``nn.BCELoss`` (probabilities).
            outcome_true: Binary targets; ``nn.BCELoss`` requires the same
                shape as ``outcome_pred``.
            los_true: Length-of-stay values. NOTE(review): presumably the same
                shape as the targets -- a mismatched shape such as (N,) vs
                (N, 1) would broadcast silently into an (N, N) product;
                confirm with callers.
        """
        decay = torch.exp(-self.decay_rate * los_true)
        per_element_bce = self.bce(outcome_pred, outcome_true)
        weighted_bce = (per_element_bce * decay).mean()

        # NOTE(review): because this term is subtracted, for positive los_true a
        # larger |outcome_true - outcome_pred| lowers the loss, i.e. its gradient
        # pushes predictions away from the targets. Confirm the sign is intended.
        abs_error = torch.abs(outcome_true - outcome_pred)
        reward = (los_true * abs_error).mean()

        total = weighted_bce - self.reward_factor * reward
        # Clamping at 0 also zeroes the gradient whenever the reward term dominates.
        return torch.clamp(total, min=0)
|
|
20 |
|
|
|
21 |
def get_time_aware_loss(outcome_pred, outcome_true, los_true, *, decay_rate=0.1, reward_factor=0.1):
    """Functional form of ``TimeAwareLoss``.

    Builds a ``TimeAwareLoss`` with the given hyperparameters and applies it to
    the inputs. The keyword defaults equal ``TimeAwareLoss``'s own defaults, so
    existing three-argument calls behave exactly as before.

    Args:
        outcome_pred: Predictions passed to the loss module.
        outcome_true: Targets passed to the loss module.
        los_true: Length-of-stay values passed to the loss module.
        decay_rate: Forwarded to ``TimeAwareLoss(decay_rate=...)``.
        reward_factor: Forwarded to ``TimeAwareLoss(reward_factor=...)``.

    Returns:
        The tensor produced by ``TimeAwareLoss.forward``.
    """
    # The module is stateless apart from its hyperparameters, so building it
    # per call is cheap and keeps this wrapper side-effect free.
    criterion = TimeAwareLoss(decay_rate=decay_rate, reward_factor=reward_factor)
    return criterion(outcome_pred, outcome_true, los_true)
|
|
24 |
|
|
|
25 |
if __name__ == "__main__":
    # Smoke run: a single low-confidence prediction for a positive outcome.
    # NOTE(review): a negative LOS only makes sense if los_true is
    # standardized/normalized upstream -- confirm.
    demo_args = (
        torch.tensor([0.1]),   # outcome_pred
        torch.tensor([1.]),    # outcome_true
        torch.tensor([-4.0]),  # los_true
    )
    print(get_time_aware_loss(*demo_args))