[d5c425]: / losses.py

"""Loss functions: a Cox partial-likelihood survival loss and a multi-task wrapper."""

import torch
from torch import nn
from torch.nn.modules.loss import _WeightedLoss
import numpy as np


class CoxLoss(_WeightedLoss):
    # This calculation is credited to Travers Ching, https://github.com/traversc/cox-nnet
    # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
    def forward(self, hazard_pred: torch.Tensor, survtime: torch.Tensor, censor: torch.Tensor):
        # Negative Cox partial log-likelihood over the batch; censor is the
        # event indicator (1 = event observed, 0 = censored).
        current_batch_len = len(survtime)
        # Risk-set indicator, modified for speed: R_mat[i, j] is True when
        # sample j is still at risk at the event time of sample i.
        R_mat = survtime.reshape((1, current_batch_len)) >= survtime.reshape(
            (current_batch_len, 1)
        )
        theta = hazard_pred.reshape(-1)
        exp_theta = torch.exp(theta)
        loss_cox = -torch.mean(
            (theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))) * censor
        )
        return loss_cox

class MultiTaskLoss(nn.Module):
    # Weighted combination of a classification loss on the grade labels and a
    # Cox survival loss on the predicted hazards; gamma balances the two terms.
    def __init__(
        self,
        gamma=0.5,
        criterion_class=nn.BCEWithLogitsLoss(),
        criterion_cox=CoxLoss(),
    ) -> None:
        super().__init__()
        self.gamma = gamma
        self.criterion_class = criterion_class
        self.criterion_cox = criterion_cox

    def forward(self, task, pred_grade, pred_hazard, grade, time, event=None):
        if task == "multitask":
            grade_loss = self.criterion_class(pred_grade, grade)
            cox_loss = self.criterion_cox(pred_hazard, time, event)
            return self.gamma * grade_loss + (1 - self.gamma) * cox_loss
        elif task == "classification":
            grade_loss = self.criterion_class(pred_grade, grade)
            return grade_loss
        elif task == "survival":
            cox_loss = self.criterion_cox(pred_hazard, time, event)
            return cox_loss
        else:
            raise NotImplementedError(f'task method {task} is not implemented')
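

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: the batch size,
    # tensor shapes, and random inputs below are assumptions for illustration
    # only (binary grade logits and one hazard score per sample).
    torch.manual_seed(0)
    batch_size = 8
    pred_grade = torch.randn(batch_size, 1)             # classification logits
    pred_hazard = torch.randn(batch_size, 1)            # log-risk scores
    grade = torch.randint(0, 2, (batch_size, 1)).float()
    time = torch.rand(batch_size) * 100.0               # survival / follow-up times
    event = torch.randint(0, 2, (batch_size,)).float()  # 1 = event observed

    criterion = MultiTaskLoss(gamma=0.5)
    print("multitask loss:",
          criterion("multitask", pred_grade, pred_hazard, grade, time, event).item())
    print("classification loss:",
          criterion("classification", pred_grade, pred_hazard, grade, time, event).item())
    print("survival loss:",
          criterion("survival", pred_grade, pred_hazard, grade, time, event).item())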