Diff of /losses.py [000000] .. [d5c425]

Switch to side-by-side view

--- a
+++ b/losses.py
@@ -0,0 +1,52 @@
+import torch
+from torch import nn
+from torch.nn.modules.loss import _WeightedLoss
+import numpy as np
+
+
+class CoxLoss(_WeightedLoss):
+    # This calculation credit to Travers Ching https://github.com/traversc/cox-nnet
+    # Cox-nnet: An artificial neural network method for prognosis prediction of high-throughput omics data
+    def forward(self, hazard_pred: torch.Tensor, survtime: torch.Tensor, censor: torch.Tensor,):
+
+        current_batch_len = len(survtime)
+        # modified for speed
+        R_mat = survtime.reshape((1, current_batch_len)) >= survtime.reshape(
+            (current_batch_len, 1)
+        )
+        theta = hazard_pred.reshape(-1)
+        exp_theta = torch.exp(theta)
+        loss_cox = -torch.mean(
+            (theta - torch.log(torch.sum(exp_theta * R_mat, dim=1))) * censor
+        )
+        return loss_cox
+    
+
+
+
+class MultiTaskLoss(nn.Module):
+    def __init__(
+        self,
+        gamma=0.5,
+        criterion_class=nn.BCEWithLogitsLoss(),
+        criterion_cox=CoxLoss()
+    ) -> None:
+        super().__init__()
+        self.gamma = gamma
+        self.criterion_class = criterion_class
+        self.criterion_cox = criterion_cox
+
+    def forward(self, task, pred_grade, pred_hazard, grade, time, event=None):
+        if task == "multitask":
+            grade_loss = self.criterion_class(pred_grade, grade)
+            cox_loss = self.criterion_cox(pred_hazard, time, event)
+            return self.gamma * grade_loss + (1 - self.gamma) * cox_loss
+        elif task == "classification":
+            grade_loss = self.criterion_class(pred_grade, grade)
+            return grade_loss
+        elif task == "survival":
+            cox_loss = self.criterion_cox(pred_hazard, time, event)
+            return cox_loss
+        else:
+            raise NotImplementedError(
+                f'task method {task} is not implemented')