Diff of /utils_loss.py [000000] .. [352cae]

# Used from https://github.com/mahmoodlab/MCAT

import torch
import numpy as np

# divide the continuous time scale into k discrete bins in total:  T_cont \in {[0, a_1), [a_1, a_2), ..., [a_(k-1), inf)}
# Y = T_discrete is the discrete event time:
# Y = 0 if T_cont \in (-inf, 0), Y = 1 if T_cont \in [0, a_1), Y = 2 if T_cont \in [a_1, a_2), ..., Y = k if T_cont \in [a_(k-1), inf)
# discrete hazard: h(t) = P(Y = t | Y >= t, X),  t = 0, 1, 2, ..., k
# survival function: S(t) = P(Y > t | X)
# all patients are alive throughout (-inf, 0) by definition, so P(Y = 0) = 0
# h(0) = 0 ---> no need to model
# S(0) = P(Y > 0 | X) = 1 ----> no need to model
'''
Summary: the neural network models the hazard probability function h(t) for t = 1, 2, ..., k,
corresponding to Y = 1, ..., k. h(t) is the probability that the patient dies in
[0, a_1), [a_1, a_2), ..., [a_(k-1), inf) respectively.
'''
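# Worked example (added for clarity; the numbers are illustrative, not from the
# original MCAT code): with k = 3 bins and hazards h = (0.1, 0.2, 0.3), survival
# is the running product S(t) = prod_{s <= t} (1 - h(s)):
#   torch.cumprod(1 - torch.tensor([0.1, 0.2, 0.3]), dim=0)  # -> [0.9000, 0.7200, 0.5040]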
# def neg_likelihood_loss(hazards, Y, c):
#     batch_size = len(Y)
#     Y = Y.view(batch_size, 1)  # ground-truth bin, 1, 2, ..., k
#     c = c.view(batch_size, 1).float()  # censorship status, 0 or 1
#     S = torch.cumprod(1 - hazards, dim=1)  # survival is the cumulative product of (1 - hazards)
#     # without padding, S(1) = S[0], h(1) = h[0]
#     S_padded = torch.cat([torch.ones_like(c), S], 1)  # S(0) = 1: all patients are alive throughout (-inf, 0) by definition
#     # after padding, S(0) = S_padded[:, 0], S(1) = S_padded[:, 1], etc., and h(1) = hazards[:, 0]
#     neg_l = - c * torch.log(torch.gather(S_padded, 1, Y)) - (1 - c) * (torch.log(torch.gather(S_padded, 1, Y - 1)) + torch.log(torch.gather(hazards, 1, Y - 1)))
#     neg_l = neg_l.mean()
#     return neg_l


# divide the continuous time scale into k discrete bins in total:  T_cont \in {[0, a_1), [a_1, a_2), ..., [a_(k-1), inf)}
# Y = T_discrete is the discrete event time:
# Y = -1 if T_cont \in (-inf, 0), Y = 0 if T_cont \in [0, a_1), Y = 1 if T_cont \in [a_1, a_2), ..., Y = k-1 if T_cont \in [a_(k-1), inf)
# discrete hazard: h(t) = P(Y = t | Y >= t, X),  t = -1, 0, 1, ..., k-1
# survival function: S(t) = P(Y > t | X)
# all patients are alive throughout (-inf, 0) by definition, so P(Y = -1) = 0
# h(-1) = 0 ---> no need to model
# S(-1) = P(Y > -1 | X) = 1 ----> no need to model
'''
Summary: the neural network models the hazard probability function h(t) for t = 0, 1, ..., k-1,
corresponding to Y = 0, 1, ..., k-1. h(t) is the probability that the patient dies in
[0, a_1), [a_1, a_2), ..., [a_(k-1), inf) respectively.
'''
def nll_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    batch_size = len(Y)
    Y = Y.view(batch_size, 1)  # ground-truth bin, 0, 1, ..., k-1
    c = c.view(batch_size, 1).float()  # censorship status, 0 or 1
    if S is None:
        S = torch.cumprod(1 - hazards, dim=1)  # survival is the cumulative product of (1 - hazards)
    # without padding, S(0) = S[0], h(0) = h[0]
    S_padded = torch.cat([torch.ones_like(c), S], 1)  # S(-1) = 1: all patients are alive throughout (-inf, 0) by definition
    # after padding, S(0) = S_padded[:, 1], S(1) = S_padded[:, 2], etc., while h(t) = hazards[:, t];
    # gathering S_padded at Y gives S(Y-1), gathering hazards at Y gives h(Y), and gathering S_padded at Y+1 gives S(Y)
    uncensored_loss = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps)) + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored_loss = - c * torch.log(torch.gather(S_padded, 1, Y + 1).clamp(min=eps))
    neg_l = censored_loss + uncensored_loss
    # equivalent to uncensored_loss + (1 - alpha) * censored_loss: alpha down-weights censored patients
    loss = (1 - alpha) * neg_l + alpha * uncensored_loss
    loss = loss.mean()
    return loss
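# Worked example (added for illustration, not part of the original file): one
# patient, k = 3 bins, event observed (c = 0) in bin Y = 1, hazards h = [0.1, 0.2, 0.3]:
#   S = [0.9, 0.72, 0.504], so S_padded = [1.0, 0.9, 0.72, 0.504]
#   uncensored term = -(log S_padded[1] + log h[1]) = -(log 0.9 + log 0.2)
#   censored term = 0, and with alpha = 0 the loss reduces to the uncensored term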

# loss_fn(hazards=hazards, S=S, Y=Y_hat, c=c, alpha=0)
class NLLSurvLoss(object):
    def __init__(self, alpha=0.15):
        self.alpha = alpha

    def __call__(self, hazards, S, Y, c, alpha=None):
        # a per-call alpha overrides the value set at construction
        if alpha is None:
            return nll_loss(hazards, S, Y, c, alpha=self.alpha)
        else:
            return nll_loss(hazards, S, Y, c, alpha=alpha)
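

# Minimal usage sketch (added for illustration; the tensor shapes and values
# below are assumptions, not taken from the original MCAT repository):
if __name__ == '__main__':
    torch.manual_seed(0)
    k = 4                                           # number of discrete time bins
    hazards = torch.sigmoid(torch.randn(2, k))      # (batch, k), each h(t) in (0, 1)
    S = torch.cumprod(1 - hazards, dim=1)           # per-patient survival curve S(t)
    Y = torch.tensor([[1], [3]])                    # ground-truth bin index in 0..k-1
    c = torch.tensor([[0.], [1.]])                  # 0 = event observed, 1 = censored
    loss_fn = NLLSurvLoss(alpha=0.15)
    print(loss_fn(hazards=hazards, S=S, Y=Y, c=c))  # scalar loss tensor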