|
a |
|
b/medseg/poly_seg_3d_module.py |
|
|
1 |
''' |
|
|
2 |
This should handle transforming targets, currently only implemented for lung variation |
|
|
3 |
''' |
|
|
4 |
import pytorch_lightning as pl |
|
|
5 |
from torch.optim.lr_scheduler import StepLR |
|
|
6 |
from medseg.utils import get_optimizer, DICELoss, DICEMetric, itk_snap_spawner |
|
|
7 |
from medseg.unet_v2 import UNet |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
class PolySeg3DModule(pl.LightningModule):
    """Lightning module for 3D lung segmentation trained on polymorphic targets.

    The network predicts a background channel followed by left- and right-lung
    channels. Each batch declares its target format in
    ``meta["target_format"]``. Batches that only carry a binary lung mask are
    scored against the summed lung probability. Batches with separate
    left/right lung masks are scored against the per-lung probabilities.
    """

    # Target formats whose channel 0 is a binary lung mask.
    _LUNG_ONLY_FORMATS = ("simple", "has_ggo_con")
    # Target formats whose channels 0 and 1 are the left and right lung masks.
    _LEFT_RIGHT_FORMATS = ("has_left_right", "full_poly")

    def __init__(self, hparams):
        """Build the UNet, the DICE loss and the DICE metric.

        Args:
            hparams: hyperparameters stored via ``save_hyperparameters``.
                Required: ``nin``, ``seg_nout``, ``init_channel`` and ``lr``
                (the last is read in ``configure_optimizers``). Optional:
                ``dropout``, ``weight_decay``, ``scheduling_factor``, ``opt``.
        """
        super().__init__()
        self.save_hyperparameters(hparams)

        dropout = getattr(self.hparams, "dropout", None)
        self.weight_decay = getattr(self.hparams, "weight_decay", None)
        self.scheduling_factor = getattr(self.hparams, "scheduling_factor", None)
        if dropout == True:  # noqa: E712 - old hparams stored a plain boolean here
            print("WARNING: Replacing old hparams true dropout by full dropout")
            dropout = "full"
        # NOTE(review): ``dropout`` is normalized above but never passed to the
        # UNet below. Confirm whether UNet is meant to receive it.

        self.model = UNet(
            self.hparams.nin,
            self.hparams.seg_nout,
            "instance",
            "3d",
            self.hparams.init_channel,
        )

        self.lossfn = DICELoss(volumetric=False, per_channel=True, check_bounds=False)
        self.dicer = DICEMetric(per_channel_metric=True, check_bounds=False)

    def forward(self, x, get_bg=False):
        """Run the UNet and turn its logits into class probabilities.

        The network is expected to output 3 channels: background, left lung
        and right lung.

        Args:
            x: input batch. Softmax is taken over dim 1, so channels are
                expected on dim 1 (presumably ``(N, C, D, H, W)``; TODO confirm).
            get_bg: if True, ``y_hat`` also contains the background channel
                (as its first channel).

        Returns:
            ``(y_hat, y_hat_lung)``. ``y_hat`` holds the softmax probabilities
            of the lung channels, with background prepended when ``get_bg`` is
            True. ``y_hat_lung`` is the single-channel sum of the lung
            (non-background) probabilities, regardless of ``get_bg``.
        """
        probs = self.model(x).softmax(dim=1)
        lung_probs = probs[:, 1:]
        # Sum only the foreground channels. Including background would make
        # the "lung" probability identically 1 whenever get_bg=True.
        y_hat_lung = lung_probs.sum(dim=1, keepdim=True)
        y_hat = probs if get_bg else lung_probs

        return y_hat, y_hat_lung

    def polymorphic_loss_metrics(self, x, y, y_hat, y_hat_lung, meta, mode=None):
        """Slice the target to match the batch's format, then compute the loss or metric.

        Args:
            x: input batch (not used here; kept for a uniform signature).
            y: target batch. For lung-only formats channel 0 is the lung mask;
                for left/right formats channels 0 and 1 are the left and right
                lung masks.
            y_hat: per-lung probabilities from ``forward`` (without background).
            y_hat_lung: summed lung probability from ``forward``.
            meta: batch metadata. Only ``meta["target_format"][0]`` is read, so
                the first sample's format is applied to the whole batch
                (assumes homogeneous batches; TODO confirm with the data loader).
            mode: ``"train"`` computes the loss and ``"val"`` the DICE metric.
                Any other value computes neither.

        Returns:
            ``(y, loss, metrics)``: the sliced target, plus the loss and metric.
            Whichever of the two was not requested is ``None``.
        """
        target_format = meta["target_format"][0]
        assert target_format in self._LUNG_ONLY_FORMATS + self._LEFT_RIGHT_FORMATS, (
            f"Unknown target_format: {target_format!r}"
        )
        loss, metrics = None, None

        if target_format in self._LUNG_ONLY_FORMATS:
            # Only a binary lung mask is available: score the summed lung probability.
            y = y[:, 0:1]
            if mode == "val":
                # Single channel, so keep only the first per-channel score.
                metrics = self.dicer(y_hat_lung, y)[0]
            elif mode == "train":
                loss = self.lossfn(y_hat_lung, y)
        elif target_format in self._LEFT_RIGHT_FORMATS:
            # Separate left/right lung masks: score each lung channel.
            y = y[:, 0:2]
            if mode == "val":
                metrics = self.dicer(y_hat, y)
            elif mode == "train":
                loss = self.lossfn(y_hat, y)

        return y, loss, metrics

    def compute_loss(self, x, y, y_hat, y_hat_lung, meta, prestr):
        """Compute the format-dependent DICE loss and log it as ``{prestr}loss``."""
        _, loss, _ = self.polymorphic_loss_metrics(x, y, y_hat, y_hat_lung, meta, mode="train")

        self.log(f"{prestr}loss", loss, on_step=True, on_epoch=True)

        return loss

    def compute_metrics(self, x, y, y_hat, y_hat_lung, meta):
        """Compute the format-dependent DICE metric and log it.

        A list result is treated as per-lung scores (left, right). Anything
        else is logged as a single lung score.
        """
        _, _, metrics = self.polymorphic_loss_metrics(x, y, y_hat, y_hat_lung, meta, mode="val")

        if isinstance(metrics, list):
            left_lung_dice, right_lung_dice = metrics
            self.log("left_lung_dice", left_lung_dice, on_epoch=True, on_step=True, prog_bar=True)
            self.log("right_lung_dice", right_lung_dice, on_epoch=True, on_step=True, prog_bar=True)
        else:
            self.log("lung_dice", metrics, on_epoch=True, on_step=True, prog_bar=True)

    def training_step(self, train_batch, batch_idx):
        """Run one training step and return the loss to optimize.

        ``train_batch`` is unpacked as ``(x, y, meta)``.
        """
        x, y, meta = train_batch

        y_hat, y_hat_lung = self.forward(x)

        loss = self.compute_loss(x, y, y_hat, y_hat_lung, meta, prestr='')

        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log the validation loss (``val_loss``) and the DICE metrics.

        ``val_batch`` is unpacked as ``(x, y, meta)``.
        """
        x, y, meta = val_batch

        y_hat, y_hat_lung = self.forward(x)

        self.compute_loss(x, y, y_hat, y_hat_lung, meta, prestr="val_")
        self.compute_metrics(x, y, y_hat, y_hat_lung, meta)

    def configure_optimizers(self):
        '''
        Select optimizer and scheduling strategy according to hparams.

        Uses ``hparams.opt`` (default ``"Adam"``), ``hparams.lr`` and the
        optional weight decay. When ``scheduling_factor`` is set, the learning
        rate is multiplied by that factor after every scheduler step.
        '''
        opt = getattr(self.hparams, "opt", "Adam")
        optimizer = get_optimizer(opt, self.model.parameters(), self.hparams.lr, wd=self.weight_decay)
        print(f"Weight decay: {self.weight_decay}")

        if self.scheduling_factor is None:
            return optimizer

        print(f"Using step LR {self.scheduling_factor}!")
        scheduler = StepLR(optimizer, step_size=1, gamma=self.scheduling_factor)
        return [optimizer], [scheduler]