|
a |
|
b/AICare-baselines/losses/__init__.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn.functional as F |
|
|
3 |
|
|
|
4 |
from .multitask_loss import get_multitask_loss |
|
|
5 |
from .time_aware_loss import get_time_aware_loss |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def get_loss(y_pred, y_true, task, time_aware=False):
    """Compute the training loss for the requested prediction task.

    Args:
        y_pred: Model predictions.
            For ``"outcome"``, they are fed to ``F.binary_cross_entropy``, so
            they must be probabilities in [0, 1].
            For ``"multitask"``, column 0 is forwarded as the outcome
            prediction and column 1 as the ``"los"`` prediction.
            NOTE(review): for the single-task paths the shape is assumed to
            match ``y_true[:, k]`` (e.g. ``(batch,)``). Confirm with callers.
        y_true: Target tensor indexed as ``y_true[:, k]``. Column 0 is the
            outcome label. Column 1 is the ``"los"`` target (presumably
            length of stay).
        task: One of ``"outcome"``, ``"los"`` or ``"multitask"``.
        time_aware: When true and ``task == "outcome"``, use the time-aware
            loss instead of plain binary cross-entropy. That loss also
            receives the column-1 target. Ignored for the other tasks.

    Returns:
        The loss tensor. For the BCE / MSE paths this is a scalar, because
        both use the default ``'mean'`` reduction.

    Raises:
        ValueError: If ``task`` is not one of the supported task names.
    """
    if task == "outcome":
        if time_aware:
            # The time-aware loss fully replaces BCE, so don't compute BCE
            # only to throw it away.
            return get_time_aware_loss(y_pred, y_true[:, 0], y_true[:, 1])
        return F.binary_cross_entropy(y_pred, y_true[:, 0])
    if task == "los":
        return F.mse_loss(y_pred, y_true[:, 1])
    if task == "multitask":
        return get_multitask_loss(
            y_pred[:, 0], y_pred[:, 1], y_true[:, 0], y_true[:, 1]
        )
    # Previously an unknown task fell through and crashed with an
    # UnboundLocalError on `loss`; fail with an explicit message instead.
    raise ValueError(
        f"Unknown task {task!r}; expected 'outcome', 'los' or 'multitask'"
    )