--- a +++ b/medseg/poly_seg_3d_module.py @@ -0,0 +1,110 @@ +''' +This should handle transforming targets, currently only implemented for lung variation +''' +import pytorch_lightning as pl +from torch.optim.lr_scheduler import StepLR +from medseg.utils import get_optimizer, DICELoss, DICEMetric, itk_snap_spawner +from medseg.unet_v2 import UNet + + +class PolySeg3DModule(pl.LightningModule): + def __init__(self, hparams): + super().__init__() + self.save_hyperparameters(hparams) + + dropout = getattr(self.hparams, "dropout", None) + self.weight_decay = getattr(self.hparams, "weight_decay", None) + self.scheduling_factor = getattr(self.hparams, "scheduling_factor", None) + if dropout == True: + print("WARNING: Replacing old hparams true dropout by full dropout") + dropout = "full" + + self.model = UNet(self.hparams.nin, self.hparams.seg_nout, "instance", "3d", self.hparams.init_channel) + + self.lossfn = DICELoss(volumetric=False, per_channel=True, check_bounds=False) + self.dicer = DICEMetric(per_channel_metric=True, check_bounds=False) + + def forward(self, x, get_bg=False): + ''' + Testing with 2 channel output for both lungs + ''' + logits = self.model(x) # 3 canais, bg, ll, rl + if get_bg: + y_hat = logits.softmax(dim=1) + else: + y_hat = logits.softmax(dim=1)[:, 1:] + y_hat_lung = y_hat.sum(dim=1, keepdim=True) + + return y_hat, y_hat_lung + + def polymorphic_loss_metrics(self, x, y, y_hat, y_hat_lung, meta, mode=None): + target_format = meta["target_format"][0] + assert target_format in ["simple", "has_left_right", "has_ggo_con", "full_poly"] + loss, metrics = None, None + + if target_format in ["simple", "has_ggo_con"]: + # Format returns only lung binary mask, extract lung target + y = y[:, 0:1] + if mode == "val": + metrics = self.dicer(y_hat_lung, y)[0] + elif mode == "train": + loss = self.lossfn(y_hat_lung, y) + elif target_format in ["has_left_right", "full_poly"]: + # Extract left right lung + y = y[:, 0:2] + if mode == "val": + metrics = self.dicer(y_hat, y) + elif mode == "train": + loss = self.lossfn(y_hat, y) + + return y, loss, metrics + + def compute_loss(self, x, y, y_hat, y_hat_lung, meta, prestr): + _, loss, _ = self.polymorphic_loss_metrics(x, y, y_hat, y_hat_lung, meta, mode="train") + + self.log(f"{prestr}loss", loss, on_step=True, on_epoch=True) + + return loss + + def compute_metrics(self, x, y, y_hat, y_hat_lung, meta): + _, _, metrics = self.polymorphic_loss_metrics(x, y, y_hat, y_hat_lung, meta, mode="val") + + if isinstance(metrics, list): + left_lung_dice, right_lung_dice = metrics + self.log("left_lung_dice", left_lung_dice, on_epoch=True, on_step=True, prog_bar=True) + self.log("right_lung_dice", right_lung_dice, on_epoch=True, on_step=True, prog_bar=True) + else: + self.log("lung_dice", metrics, on_epoch=True, on_step=True, prog_bar=True) + + + def training_step(self, train_batch, batch_idx): + x, y, meta = train_batch + + y_hat, y_hat_lung = self.forward(x) + + loss = self.compute_loss(x, y, y_hat, y_hat_lung, meta, prestr='') + + return loss + + def validation_step(self, val_batch, batch_idx): + x, y, meta = val_batch + + y_hat, y_hat_lung = self.forward(x) + + self.compute_loss(x, y, y_hat, y_hat_lung, meta, prestr="val_") + self.compute_metrics(x, y, y_hat, y_hat_lung, meta) + + def configure_optimizers(self): + ''' + Select optimizer and scheduling strategy according to hparams. + ''' + opt = getattr(self.hparams, "opt", "Adam") + optimizer = get_optimizer(opt, self.model.parameters(), self.hparams.lr, wd=self.weight_decay) + print(f"Weight decay: {self.weight_decay}") + + if self.scheduling_factor is not None: + print(f"Using step LR {self.scheduling_factor}!") + scheduler = StepLR(optimizer, 1, self.scheduling_factor) + return [optimizer], [scheduler] + else: + return optimizer