Diff of /utils/loss_func.py [000000] .. [405115]

Switch to unified view

a b/utils/loss_func.py
1
import torch
2
import torch.nn as nn
3
import numpy as np
4
import torch
5
import torch.nn.functional as F
6
7
8
class NLLSurvLoss(nn.Module):
9
    """
10
    The negative log-likelihood loss function for the discrete time to event model (Zadeh and Schmid, 2020).
11
    Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py
12
    Parameters
13
    ----------
14
    alpha: float
15
        TODO: document
16
    eps: float
17
        Numerical constant; lower bound to avoid taking logs of tiny numbers.
18
    reduction: str
19
        Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum']
20
    """
21
    def __init__(self, alpha=0.0, eps=1e-7, reduction='mean'):
22
        super().__init__()
23
        self.alpha = alpha
24
        self.eps = eps
25
        self.reduction = reduction
26
27
    def __call__(self, h, y, t, c):
28
        """
29
        Parameters
30
        ----------
31
        h: (n_batches, n_classes)
32
            The neural network output discrete survival predictions such that hazards = sigmoid(h).
33
        y_c: (n_batches, 2) or (n_batches, 3)
34
            The true time bin label (first column) and censorship indicator (second column).
35
        """
36
37
        return nll_loss(h=h, y=y.unsqueeze(dim=1), c=c.unsqueeze(dim=1),
38
                        alpha=self.alpha, eps=self.eps,
39
                        reduction=self.reduction)
40
41
42
# TODO: document better and clean up
43
def nll_loss(h, y, c, alpha=0.0, eps=1e-7, reduction='mean'):
44
    """
45
    The negative log-likelihood loss function for the discrete time to event model (Zadeh and Schmid, 2020).
46
    Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py
47
    Parameters
48
    ----------
49
    h: (n_batches, n_classes)
50
        The neural network output discrete survival predictions such that hazards = sigmoid(h).
51
    y: (n_batches, 1)
52
        The true time bin index label.
53
    c: (n_batches, 1)
54
        The censoring status indicator.
55
    alpha: float
56
        TODO: document
57
    eps: float
58
        Numerical constant; lower bound to avoid taking logs of tiny numbers.
59
    reduction: str
60
        Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum']
61
    References
62
    ----------
63
    Zadeh, S.G. and Schmid, M., 2020. Bias in cross-entropy-based training of deep survival networks. IEEE transactions on pattern analysis and machine intelligence.
64
    """
65
    # print("h shape", h.shape)
66
67
    # make sure these are ints
68
    y = y.type(torch.int64)
69
    c = c.type(torch.int64)
70
71
    hazards = torch.sigmoid(h)
72
    # print("hazards shape", hazards.shape)
73
74
    S = torch.cumprod(1 - hazards, dim=1)
75
    # print("S.shape", S.shape, S)
76
77
    S_padded = torch.cat([torch.ones_like(c), S], 1)
78
    # S(-1) = 0, all patients are alive from (-inf, 0) by definition
79
    # after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0]
80
    # hazards[y] = hazards(1)
81
    # S[1] = S(1)
82
    # TODO: document and check
83
84
    # print("S_padded.shape", S_padded.shape, S_padded)
85
86
87
    # TODO: document/better naming
88
    s_prev = torch.gather(S_padded, dim=1, index=y).clamp(min=eps)
89
    h_this = torch.gather(hazards, dim=1, index=y).clamp(min=eps)
90
    s_this = torch.gather(S_padded, dim=1, index=y+1).clamp(min=eps)
91
    # print('s_prev.s_prev', s_prev.shape, s_prev)
92
    # print('h_this.shape', h_this.shape, h_this)
93
    # print('s_this.shape', s_this.shape, s_this)
94
95
    uncensored_loss = -(1 - c) * (torch.log(s_prev) + torch.log(h_this))
96
    censored_loss = - c * torch.log(s_this)
97
    
98
99
    # print('uncensored_loss.shape', uncensored_loss.shape)
100
    # print('censored_loss.shape', censored_loss.shape)
101
102
    neg_l = censored_loss + uncensored_loss
103
    if alpha is not None:
104
        loss = (1 - alpha) * neg_l + alpha * uncensored_loss
105
106
    if reduction == 'mean':
107
        loss = loss.mean()
108
    elif reduction == 'sum':
109
        loss = loss.sum()
110
    else:
111
        raise ValueError("Bad input for reduction: {}".format(reduction))
112
113
    return loss