--- /dev/null
+++ b/medseg/poly_seg_3d_module.py
@@ -0,0 +1,127 @@
+'''
+Handles transforming segmentation targets according to their annotation format.
+Currently only implemented for the lung variation.
+'''
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import StepLR
+from medseg.utils import get_optimizer, DICELoss, DICEMetric
+from medseg.unet_v2 import UNet
+
+
+class PolySeg3DModule(pl.LightningModule):
+    def __init__(self, hparams):
+        super().__init__()
+        self.save_hyperparameters(hparams)
+
+        dropout = getattr(self.hparams, "dropout", None)
+        self.weight_decay = getattr(self.hparams, "weight_decay", None)
+        self.scheduling_factor = getattr(self.hparams, "scheduling_factor", None)
+        if dropout is True:
+            # Old checkpoints stored dropout as a boolean; normalize to the string form.
+            print("WARNING: Replacing old hparams boolean dropout with full dropout")
+            dropout = "full"
+
+        # NOTE: dropout is normalized above but not currently forwarded to the UNet constructor.
+        self.model = UNet(self.hparams.nin, self.hparams.seg_nout, "instance", "3d", self.hparams.init_channel)
+
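+        # Per-channel Dice: DICELoss drives training, DICEMetric reports one value per class.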
+        self.lossfn = DICELoss(volumetric=False, per_channel=True, check_bounds=False)
+        self.dicer = DICEMetric(per_channel_metric=True, check_bounds=False)
+
+    def forward(self, x, get_bg=False):
+        '''
+        Produces per-lung probabilities (left, right) and a combined whole-lung map.
+        '''
+        logits = self.model(x)  # 3 channels: background, left lung, right lung
+        probs = logits.softmax(dim=1)
+        y_hat = probs if get_bg else probs[:, 1:]
+        # Sum only the foreground channels; including background would make the map identically 1.
+        y_hat_lung = probs[:, 1:].sum(dim=1, keepdim=True)
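+        # Expected shapes, assuming x is (B, nin, D, H, W) and seg_nout == 3:
+        #   y_hat: (B, 3, D, H, W) with get_bg=True, else (B, 2, D, H, W)
+        #   y_hat_lung: (B, 1, D, H, W)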
+
+        return y_hat, y_hat_lung
+
+    def polymorphic_loss_metrics(self, x, y, y_hat, y_hat_lung, meta, mode=None):
+        target_format = meta["target_format"][0]
+        assert target_format in ["simple", "has_left_right", "has_ggo_con", "full_poly"]
+        loss, metrics = None, None
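+        # Channel layout by target format (as consumed by the slicing below):
+        #   simple / has_ggo_con:       channel 0 is the whole-lung mask
+        #   has_left_right / full_poly: channels 0-1 are the left and right lungs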
+
+        if target_format in ["simple", "has_ggo_con"]:
+            # These formats provide only a binary whole-lung mask; extract it as the target
+            y = y[:, 0:1]
+            if mode == "val":
+                metrics = self.dicer(y_hat_lung, y)[0]
+            elif mode == "train":
+                loss = self.lossfn(y_hat_lung, y)
+        elif target_format in ["has_left_right", "full_poly"]:
+            # Extract the left and right lung targets
+            y = y[:, 0:2]
+            if mode == "val":
+                metrics = self.dicer(y_hat, y)
+            elif mode == "train":
+                loss = self.lossfn(y_hat, y)
+
+        return y, loss, metrics
+
+    def compute_loss(self, x, y, y_hat, y_hat_lung, meta, prestr):
+        _, loss, _ = self.polymorphic_loss_metrics(x, y, y_hat, y_hat_lung, meta, mode="train")
+
+        self.log(f"{prestr}loss", loss, on_step=True, on_epoch=True)
+
+        return loss
+
+    def compute_metrics(self, x, y, y_hat, y_hat_lung, meta):
+        _, _, metrics = self.polymorphic_loss_metrics(x, y, y_hat, y_hat_lung, meta, mode="val")
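+        # DICEMetric with per_channel_metric=True yields a 2-element list (left, right)
+        # for the left/right formats and a single value for the whole-lung formats.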
+
+        if isinstance(metrics, list):
+            left_lung_dice, right_lung_dice = metrics
+            self.log("left_lung_dice", left_lung_dice, on_epoch=True, on_step=True, prog_bar=True)
+            self.log("right_lung_dice", right_lung_dice, on_epoch=True, on_step=True, prog_bar=True)
+        else:
+            self.log("lung_dice", metrics, on_epoch=True, on_step=True, prog_bar=True)
+
+    def training_step(self, train_batch, batch_idx):
+        x, y, meta = train_batch
+
+        y_hat, y_hat_lung = self.forward(x)
+
+        loss = self.compute_loss(x, y, y_hat, y_hat_lung, meta, prestr="")
+
+        return loss
+
+    def validation_step(self, val_batch, batch_idx):
+        x, y, meta = val_batch
+
+        y_hat, y_hat_lung = self.forward(x)
+
+        self.compute_loss(x, y, y_hat, y_hat_lung, meta, prestr="val_")
+        self.compute_metrics(x, y, y_hat, y_hat_lung, meta)
+
+    def configure_optimizers(self):
+        '''
+        Select optimizer and scheduling strategy according to hparams.
+        '''
+        opt = getattr(self.hparams, "opt", "Adam")
+        optimizer = get_optimizer(opt, self.model.parameters(), self.hparams.lr, wd=self.weight_decay)
+        print(f"Weight decay: {self.weight_decay}")
+
+        if self.scheduling_factor is not None:
+            print(f"Using step LR {self.scheduling_factor}!")
+            scheduler = StepLR(optimizer, step_size=1, gamma=self.scheduling_factor)
+            return [optimizer], [scheduler]
+        else:
+            return optimizer
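+
+
+# Minimal usage sketch (hypothetical hparams values; actual keys come from the training config):
+#   module = PolySeg3DModule({"nin": 1, "seg_nout": 3, "init_channel": 16, "lr": 1e-4,
+#                             "weight_decay": 1e-5, "scheduling_factor": 0.95, "opt": "Adam"})
+#   pl.Trainer(max_epochs=100).fit(module, datamodule=...)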