|
a |
|
b/utils_loss.py |
|
|
1 |
# Used from https://github.com/mahmoodlab/MCAT |
|
|
2 |
|
|
|
3 |
import torch |
|
|
4 |
import numpy as np |
|
|
5 |
|
|
|
6 |
# divide continuous time scale into k discrete bins in total, T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)} |
|
|
7 |
# Y = T_discrete is the discrete event time: |
|
|
8 |
# Y = 0 if T_cont \in (-inf, 0), Y = 1 if T_cont \in [0, a_1), Y = 2 if T_cont in [a_1, a_2), ..., Y = k if T_cont in [a_(k-1), inf) |
|
|
9 |
# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X), t = 0,1,2,...,k |
|
|
10 |
# S: survival function: P(Y > t | X) |
|
|
11 |
# all patients are alive from (-inf, 0) by definition, so P(Y=0) = 0 |
|
|
12 |
# h(0) = 0 ---> do not need to model |
|
|
13 |
# S(0) = P(Y > 0 | X) = 1 ----> do not need to model |
|
|
14 |
''' |
|
|
15 |
Summary: neural network is hazard probability function, h(t) for t = 1,2,...,k |
|
|
16 |
corresponding to Y = 1, ..., k. h(t) is the probability that the patient dies in the t-th interval of [0, a_1), [a_1, a_2), ..., [a_(k-1), inf), given survival up to that interval |
|
|
17 |
''' |
|
|
18 |
# def neg_likelihood_loss(hazards, Y, c): |
|
|
19 |
# batch_size = len(Y) |
|
|
20 |
# Y = Y.view(batch_size, 1) # ground truth bin, 1,2,...,k |
|
|
21 |
# c = c.view(batch_size, 1).float() #censorship status, 0 or 1 |
|
|
22 |
# S = torch.cumprod(1 - hazards, dim=1) # survival is cumulative product of 1 - hazards |
|
|
23 |
# # without padding, S(1) = S[0], h(1) = h[0] |
|
|
24 |
# S_padded = torch.cat([torch.ones_like(c), S], 1) #S(0) = 1, all patients are alive from (-inf, 0) by definition |
|
|
25 |
# # after padding, S(0) = S[0], S(1) = S[1], etc, h(1) = h[0] |
|
|
26 |
# #h[y] = h(1) |
|
|
27 |
# #S[1] = S(1) |
|
|
28 |
# neg_l = - c * torch.log(torch.gather(S_padded, 1, Y)) - (1 - c) * (torch.log(torch.gather(S_padded, 1, Y-1)) + torch.log(hazards[:, Y-1])) |
|
|
29 |
# neg_l = neg_l.mean() |
|
|
30 |
# return neg_l |
|
|
31 |
|
|
|
32 |
|
|
|
33 |
# divide continuous time scale into k discrete bins in total, T_cont \in {[0, a_1), [a_1, a_2), ...., [a_(k-1), inf)} |
|
|
34 |
# Y = T_discrete is the discrete event time: |
|
|
35 |
# Y = -1 if T_cont \in (-inf, 0), Y = 0 if T_cont \in [0, a_1), Y = 1 if T_cont in [a_1, a_2), ..., Y = k-1 if T_cont in [a_(k-1), inf) |
|
|
36 |
# discrete hazards: discrete probability of h(t) = P(Y=t | Y>=t, X), t = -1,0,1,...,k-1 |
|
|
37 |
# S: survival function: P(Y > t | X) |
|
|
38 |
# all patients are alive from (-inf, 0) by definition, so P(Y=-1) = 0 |
|
|
39 |
# h(-1) = 0 ---> do not need to model |
|
|
40 |
# S(-1) = P(Y > -1 | X) = 1 ----> do not need to model |
|
|
41 |
''' |
|
|
42 |
Summary: neural network is hazard probability function, h(t) for t = 0,1,2,...,k-1 |
|
|
43 |
corresponding to Y = 0, 1, ..., k-1. h(t) is the probability that the patient dies in the t-th interval of [0, a_1), [a_1, a_2), ..., [a_(k-1), inf), given survival up to that interval |
|
|
44 |
''' |
|
|
45 |
def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    """Negative log-likelihood loss for discrete-time survival prediction.

    Args:
        hazards: (batch_size, k) tensor of discrete hazards, where
            hazards[:, t] = h(t) = P(Y = t | Y >= t, X) for bins t = 0..k-1.
        S: (batch_size, k) tensor of survival probabilities
            S(t) = P(Y > t | X), or None to derive it from ``hazards`` as the
            cumulative product of ``1 - hazards``.
        Y: ground-truth bin index per sample (0..k-1); any integer dtype,
            with ``batch_size`` elements.
        c: censorship status per sample (1 = censored, 0 = event observed),
            with ``batch_size`` elements.
        alpha: extra weight on the uncensored term; the per-sample loss is
            ``(1 - alpha) * (censored + uncensored) + alpha * uncensored``.
        eps: lower clamp applied before every ``log`` to avoid ``log(0)``.

    Returns:
        Scalar tensor: the per-sample loss averaged over the batch.
    """
    batch_size = len(Y)
    # torch.gather only accepts int64 indices, so coerce labels that arrive
    # with another dtype (e.g. int32); reshape also tolerates non-contiguous
    # inputs, which view does not.
    Y = Y.reshape(batch_size, 1).long()  # ground-truth bin, 0,1,...,k-1
    c = c.reshape(batch_size, 1).float()  # censorship status, 0 or 1
    if S is None:
        # survival is the cumulative product of 1 - hazards
        S = torch.cumprod(1 - hazards, dim=1)
    # Without padding: S(0) = S[0], h(0) = hazards[0].
    # Prepend S(-1) = 1: all patients are alive on (-inf, 0) by definition.
    S_padded = torch.cat([torch.ones_like(c), S], 1)
    # After padding: S(t) = S_padded[t + 1], while h(t) = hazards[t].

    # Event observed in bin Y: survived through bin Y - 1, then died in bin Y,
    # i.e. log S(Y - 1) + log h(Y).
    log_surv_before = torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps))
    log_hazard_at = torch.log(torch.gather(hazards, 1, Y).clamp(min=eps))
    uncensored_loss = -(1 - c) * (log_surv_before + log_hazard_at)

    # Censored in bin Y: only known to have survived through bin Y, i.e. log S(Y).
    log_surv_through = torch.log(torch.gather(S_padded, 1, Y + 1).clamp(min=eps))
    censored_loss = -c * log_surv_through

    neg_l = censored_loss + uncensored_loss
    loss = (1 - alpha) * neg_l + alpha * uncensored_loss
    return loss.mean()
|
|
62 |
|
|
|
63 |
# loss_fn(hazards=hazards, S=S, Y=Y_hat, c=c, alpha=0) |
|
|
64 |
class NLLSurvLoss(object):
    """Callable wrapper around ``nll_loss`` that carries a default ``alpha``."""

    def __init__(self, alpha=0.15):
        # Default weighting used whenever a call does not supply its own alpha.
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        """Evaluate ``nll_loss``; an explicit ``alpha`` overrides ``self.alpha``."""
        chosen_alpha = self.alpha if alpha is None else alpha
        return nll_loss(hazards, S, Y, c, alpha=chosen_alpha)