# medseg/poly_seg_3d_module.py
'''
This module should handle transforming polymorphic targets; currently this is only implemented for the lung variation.
'''
import pytorch_lightning as pl
from torch.optim.lr_scheduler import StepLR
from medseg.utils import get_optimizer, DICELoss, DICEMetric, itk_snap_spawner
from medseg.unet_v2 import UNet


class PolySeg3DModule(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        dropout = getattr(self.hparams, "dropout", None)
        self.weight_decay = getattr(self.hparams, "weight_decay", None)
        self.scheduling_factor = getattr(self.hparams, "scheduling_factor", None)
        if dropout is True:
            print("WARNING: Replacing old hparams true dropout by full dropout")
            dropout = "full"

        self.model = UNet(self.hparams.nin, self.hparams.seg_nout, "instance", "3d", self.hparams.init_channel)

        self.lossfn = DICELoss(volumetric=False, per_channel=True, check_bounds=False)
        self.dicer = DICEMetric(per_channel_metric=True, check_bounds=False)

    def forward(self, x, get_bg=False):
        '''
        Testing with a 2-channel output, one channel per lung.
        Returns per-class probabilities and the summed whole-lung probability.
        '''
        logits = self.model(x)  # 3 channels: background, left lung, right lung
        if get_bg:
            y_hat = logits.softmax(dim=1)
        else:
            y_hat = logits.softmax(dim=1)[:, 1:]
        y_hat_lung = y_hat.sum(dim=1, keepdim=True)

        return y_hat, y_hat_lung

    def polymorphic_loss_metrics(self, x, y, y_hat, y_hat_lung, meta, mode=None):
        target_format = meta["target_format"][0]
        assert target_format in ["simple", "has_left_right", "has_ggo_con", "full_poly"]
        loss, metrics = None, None

        if target_format in ["simple", "has_ggo_con"]:
            # Format provides only a binary lung mask, so extract the whole-lung target
            y = y[:, 0:1]
            if mode == "val":
                metrics = self.dicer(y_hat_lung, y)[0]
            elif mode == "train":
                loss = self.lossfn(y_hat_lung, y)
        elif target_format in ["has_left_right", "full_poly"]:
            # Extract the left and right lung targets
            y = y[:, 0:2]
            if mode == "val":
                metrics = self.dicer(y_hat, y)
            elif mode == "train":
                loss = self.lossfn(y_hat, y)

        return y, loss, metrics

    def compute_loss(self, x, y, y_hat, y_hat_lung, meta, prestr):
        _, loss, _ = self.polymorphic_loss_metrics(x, y, y_hat, y_hat_lung, meta, mode="train")

        self.log(f"{prestr}loss", loss, on_step=True, on_epoch=True)

        return loss

    def compute_metrics(self, x, y, y_hat, y_hat_lung, meta):
        _, _, metrics = self.polymorphic_loss_metrics(x, y, y_hat, y_hat_lung, meta, mode="val")

        if isinstance(metrics, list):
            left_lung_dice, right_lung_dice = metrics
            self.log("left_lung_dice", left_lung_dice, on_epoch=True, on_step=True, prog_bar=True)
            self.log("right_lung_dice", right_lung_dice, on_epoch=True, on_step=True, prog_bar=True)
        else:
            self.log("lung_dice", metrics, on_epoch=True, on_step=True, prog_bar=True)

    def training_step(self, train_batch, batch_idx):
        x, y, meta = train_batch

        y_hat, y_hat_lung = self.forward(x)

        loss = self.compute_loss(x, y, y_hat, y_hat_lung, meta, prestr='')

        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y, meta = val_batch

        y_hat, y_hat_lung = self.forward(x)

        self.compute_loss(x, y, y_hat, y_hat_lung, meta, prestr="val_")
        self.compute_metrics(x, y, y_hat, y_hat_lung, meta)

    def configure_optimizers(self):
        '''
        Select optimizer and scheduling strategy according to hparams.
        '''
        opt = getattr(self.hparams, "opt", "Adam")
        optimizer = get_optimizer(opt, self.model.parameters(), self.hparams.lr, wd=self.weight_decay)
        print(f"Weight decay: {self.weight_decay}")

        if self.scheduling_factor is not None:
            print(f"Using step LR {self.scheduling_factor}!")
            scheduler = StepLR(optimizer, step_size=1, gamma=self.scheduling_factor)
            return [optimizer], [scheduler]
        else:
            return optimizer
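

# Minimal usage sketch (an assumption, not part of the project's training pipeline):
# shows how this module might be instantiated and run on a dummy volume. The hparams
# fields below are only the ones this file reads (nin, seg_nout, init_channel, lr,
# plus the optional dropout/weight_decay/scheduling_factor); all values are
# placeholders, and the input size must be compatible with the UNet's downsampling.
if __name__ == "__main__":
    from argparse import Namespace

    import torch

    hparams = Namespace(nin=1, seg_nout=3, init_channel=16, lr=1e-4,
                        dropout=None, weight_decay=None, scheduling_factor=None)
    module = PolySeg3DModule(hparams)

    # Dummy single-channel 3D volume: (batch, channels, depth, height, width)
    x = torch.randn(1, 1, 32, 64, 64)
    y_hat, y_hat_lung = module(x)  # per-lung probabilities and whole-lung probability
    print(y_hat.shape, y_hat_lung.shape)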