[4d48b1]: /utils/loss_func.py

import torch
import torch.nn as nn
class NLLSurvLoss(nn.Module):
    """
    The negative log-likelihood loss function for the discrete time-to-event model (Zadeh and Schmid, 2020).
    Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py

    Parameters
    ----------
    alpha: float
        Weight in [0, 1] placed on the uncensored term of the loss; the total
        loss is (1 - alpha) * nll + alpha * uncensored_nll (see nll_loss below).
    eps: float
        Numerical constant; lower bound to avoid taking logs of tiny numbers.
    reduction: str
        Whether to average or sum the loss over the batch. Must be one of ['mean', 'sum'].
    """
    def __init__(self, alpha=0.0, eps=1e-7, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.eps = eps
        self.reduction = reduction

    def __call__(self, h, y, t, c):
        """
        Parameters
        ----------
        h: (n_batches, n_classes)
            The neural network output discrete survival predictions such that hazards = sigmoid(h).
        y: (n_batches,)
            The true time bin index label.
        t: (n_batches,)
            The event/censoring times; unused by this loss.
        c: (n_batches,)
            The censoring status indicator (1 = censored, 0 = event observed).
        """
        return nll_loss(h=h, y=y.unsqueeze(dim=1), c=c.unsqueeze(dim=1),
                        alpha=self.alpha, eps=self.eps,
                        reduction=self.reduction)
def nll_loss(h, y, c, alpha=0.0, eps=1e-7, reduction='mean'):
    """
    The negative log-likelihood loss function for the discrete time-to-event model (Zadeh and Schmid, 2020).
    Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py

    Parameters
    ----------
    h: (n_batches, n_classes)
        The neural network output discrete survival predictions such that hazards = sigmoid(h).
    y: (n_batches, 1)
        The true time bin index label.
    c: (n_batches, 1)
        The censoring status indicator (1 = censored, 0 = event observed).
    alpha: float
        Weight in [0, 1] placed on the uncensored term; the total loss is
        (1 - alpha) * nll + alpha * uncensored_nll.
    eps: float
        Numerical constant; lower bound to avoid taking logs of tiny numbers.
    reduction: str
        Whether to average or sum the loss over the batch. Must be one of ['mean', 'sum'].

    References
    ----------
    Zadeh, S.G. and Schmid, M., 2020. Bias in cross-entropy-based training of
    deep survival networks. IEEE Transactions on Pattern Analysis and Machine Intelligence.
    """
# print("h shape", h.shape)
# make sure these are ints
y = y.type(torch.int64)
c = c.type(torch.int64)
hazards = torch.sigmoid(h)
# print("hazards shape", hazards.shape)
S = torch.cumprod(1 - hazards, dim=1)
# print("S.shape", S.shape, S)
S_padded = torch.cat([torch.ones_like(c), S], 1)
# S(-1) = 0, all patients are alive from (-inf, 0) by definition
# after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0]
# hazards[y] = hazards(1)
# S[1] = S(1)
# TODO: document and check
# print("S_padded.shape", S_padded.shape, S_padded)
# TODO: document/better naming
s_prev = torch.gather(S_padded, dim=1, index=y).clamp(min=eps)
h_this = torch.gather(hazards, dim=1, index=y).clamp(min=eps)
s_this = torch.gather(S_padded, dim=1, index=y+1).clamp(min=eps)
# print('s_prev.s_prev', s_prev.shape, s_prev)
# print('h_this.shape', h_this.shape, h_this)
# print('s_this.shape', s_this.shape, s_this)
uncensored_loss = -(1 - c) * (torch.log(s_prev) + torch.log(h_this))
censored_loss = - c * torch.log(s_this)
# print('uncensored_loss.shape', uncensored_loss.shape)
# print('censored_loss.shape', censored_loss.shape)
neg_l = censored_loss + uncensored_loss
if alpha is not None:
loss = (1 - alpha) * neg_l + alpha * uncensored_loss
if reduction == 'mean':
loss = loss.mean()
elif reduction == 'sum':
loss = loss.sum()
else:
raise ValueError("Bad input for reduction: {}".format(reduction))
return loss
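

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): the batch size, the
# four-bin setup, and the toy numbers below are illustrative assumptions.
# Running this module directly exercises the loss on random inputs and
# hand-checks one patient against the formula in the comments above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    torch.manual_seed(0)
    batch_size, n_bins = 8, 4

    h = torch.randn(batch_size, n_bins, requires_grad=True)  # survival-head logits
    y = torch.randint(0, n_bins, (batch_size,))              # true time-bin index
    t = torch.rand(batch_size)                               # raw times (unused by the loss)
    c = torch.randint(0, 2, (batch_size,))                   # 1 = censored, 0 = event observed

    criterion = NLLSurvLoss(alpha=0.0, reduction='mean')
    loss = criterion(h, y, t, c)
    loss.backward()  # gradients flow back to h (normally a model's output)
    print('batch loss:', loss.item())

    # hand check for one uncensored patient with its event in bin 1:
    # expected loss is -log S(0) - log h(1) = -log(1 - haz(0)) - log haz(1)
    h1 = torch.tensor([[0.2, -0.3, 0.1]])
    y1, t1, c1 = torch.tensor([1]), torch.tensor([0.5]), torch.tensor([0])
    haz = torch.sigmoid(h1)
    expected = -(torch.log(1 - haz[0, 0]) + torch.log(haz[0, 1]))
    assert torch.isclose(criterion(h1, y1, t1, c1), expected)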