--- a +++ b/utils/loss_func.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import numpy as np +import torch +import torch.nn.functional as F + + +class NLLSurvLoss(nn.Module): + """ + The negative log-likelihood loss function for the discrete time to event model (Zadeh and Schmid, 2020). + Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py + Parameters + ---------- + alpha: float + TODO: document + eps: float + Numerical constant; lower bound to avoid taking logs of tiny numbers. + reduction: str + Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum'] + """ + def __init__(self, alpha=0.0, eps=1e-7, reduction='mean'): + super().__init__() + self.alpha = alpha + self.eps = eps + self.reduction = reduction + + def __call__(self, h, y, t, c): + """ + Parameters + ---------- + h: (n_batches, n_classes) + The neural network output discrete survival predictions such that hazards = sigmoid(h). + y_c: (n_batches, 2) or (n_batches, 3) + The true time bin label (first column) and censorship indicator (second column). + """ + + return nll_loss(h=h, y=y.unsqueeze(dim=1), c=c.unsqueeze(dim=1), + alpha=self.alpha, eps=self.eps, + reduction=self.reduction) + + +# TODO: document better and clean up +def nll_loss(h, y, c, alpha=0.0, eps=1e-7, reduction='mean'): + """ + The negative log-likelihood loss function for the discrete time to event model (Zadeh and Schmid, 2020). + Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py + Parameters + ---------- + h: (n_batches, n_classes) + The neural network output discrete survival predictions such that hazards = sigmoid(h). + y: (n_batches, 1) + The true time bin index label. + c: (n_batches, 1) + The censoring status indicator. + alpha: float + TODO: document + eps: float + Numerical constant; lower bound to avoid taking logs of tiny numbers. + reduction: str + Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum'] + References + ---------- + Zadeh, S.G. and Schmid, M., 2020. Bias in cross-entropy-based training of deep survival networks. IEEE transactions on pattern analysis and machine intelligence. + """ + # print("h shape", h.shape) + + # make sure these are ints + y = y.type(torch.int64) + c = c.type(torch.int64) + + hazards = torch.sigmoid(h) + # print("hazards shape", hazards.shape) + + S = torch.cumprod(1 - hazards, dim=1) + # print("S.shape", S.shape, S) + + S_padded = torch.cat([torch.ones_like(c), S], 1) + # S(-1) = 0, all patients are alive from (-inf, 0) by definition + # after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0] + # hazards[y] = hazards(1) + # S[1] = S(1) + # TODO: document and check + + # print("S_padded.shape", S_padded.shape, S_padded) + + + # TODO: document/better naming + s_prev = torch.gather(S_padded, dim=1, index=y).clamp(min=eps) + h_this = torch.gather(hazards, dim=1, index=y).clamp(min=eps) + s_this = torch.gather(S_padded, dim=1, index=y+1).clamp(min=eps) + # print('s_prev.s_prev', s_prev.shape, s_prev) + # print('h_this.shape', h_this.shape, h_this) + # print('s_this.shape', s_this.shape, s_this) + + uncensored_loss = -(1 - c) * (torch.log(s_prev) + torch.log(h_this)) + censored_loss = - c * torch.log(s_this) + + + # print('uncensored_loss.shape', uncensored_loss.shape) + # print('censored_loss.shape', censored_loss.shape) + + neg_l = censored_loss + uncensored_loss + if alpha is not None: + loss = (1 - alpha) * neg_l + alpha * uncensored_loss + + if reduction == 'mean': + loss = loss.mean() + elif reduction == 'sum': + loss = loss.sum() + else: + raise ValueError("Bad input for reduction: {}".format(reduction)) + + return loss \ No newline at end of file