Diff of /monai/multilabel_train.py [000000] .. [4e96d3]

import argparse
import gc
import importlib
import os
import sys
import shutil

import numpy as np
import pandas as pd
import torch
from torch import nn
from monai.handlers.utils import from_engine
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from utils import *

from monai.transforms import (
    Compose,
    Activations,
    AsDiscrete,
    Activationsd,
    AsDiscreted,
    KeepLargestConnectedComponentd,
    Invertd,
    LoadImage,
    Transposed,
)
import json
from metric import HausdorffScore
from monai.utils import set_determinism
from monai.losses import DiceLoss, DiceCELoss
from monai.networks.nets import UNet, SegResNet, DynUnet
from monai.optimizers import Novograd
from monai.metrics import DiceMetric
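# cudnn autotuning: benchmark=True lets cudnn pick the fastest convolution
# algorithms for the fixed input shapes used here, at the cost of a short warm-up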
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True


def main(cfg):

    # pick the dataset-split json for this fold (fold == -1 trains on all data)
    if cfg.fold != -1:
        cfg.data_json_dir = cfg.data_dir + f"dataset_3d_fold_{cfg.fold}.json"
    else:
        cfg.data_json_dir = cfg.data_dir + "dataset_3d_all.json"

    with open(cfg.data_json_dir, "r") as f:
        cfg.data_json = json.load(f)

    if cfg.fold != -1:
        fold_dir = f"fold{cfg.fold}"
    else:
        fold_dir = "all"
    os.makedirs(str(cfg.output_dir + f"/{fold_dir}/"), exist_ok=True)

    # # set random seed
    # set_determinism(cfg.seed)

    train_dataset = get_train_dataset(cfg)
    train_dataloader = get_train_dataloader(train_dataset, cfg)

    val_dataset = get_val_dataset(cfg)
    val_dataloader = get_val_dataloader(val_dataset, cfg)

    print(f"run fold {cfg.fold}, train len: {len(train_dataset)}")
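
    # "segres<N>" selects a SegResNet with init_filters=N (e.g. "segres32" -> 32);
    # out_channels=3 gives one sigmoid channel per label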
    if cfg.model_type.startswith("segres"):
        model = SegResNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=3,
            init_filters=int(cfg.model_type.replace("segres", "")),
            norm="BATCH",
            act="PRELU",
        ).to(cfg.device)
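
    # optionally warm-start from pretrained weights; the checkpoint may nest the
    # state dict under "model" or "state_dict"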
    print(cfg.weights)
    if cfg.weights is not None:
        stt = torch.load(cfg.weights, map_location="cpu")
        if "model" in stt:
            stt = stt["model"]
        if "state_dict" in stt:
            stt = stt["state_dict"]
            # drop the output conv so a checkpoint with a different number of
            # output channels still loads (strict=False skips the missing keys)
            del stt["out.conv.conv.weight"], stt["out.conv.conv.bias"]
        model.load_state_dict(stt, strict=False)
        print(f"weights from: {cfg.weights} are loaded.")

    # set optimizer and lr scheduler; the scheduler is stepped once per training
    # iteration (see run_train) and receives len(train_dataset) as total_steps
    total_steps = len(train_dataset)
    optimizer = get_optimizer(model, cfg)
    # optimizer = Novograd(model.parameters(), cfg.lr)
    scheduler = get_scheduler(cfg, optimizer, total_steps)

    # dice + bce loss, weighted to sum to 1
    seg_loss_func = DiceBceMultilabelLoss(w_dice=cfg.w_dice, w_bce=1 - cfg.w_dice)
    # seg_loss_func = DiceLoss(sigmoid=True, smooth_nr=0.01, smooth_dr=0.01, include_background=True, batch=True)
    dice_metric = DiceMetric(reduction="mean")
    hausdorff_metric = HausdorffScore(reduction="mean")
    metric_function = [dice_metric, hausdorff_metric]
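
    # validation post-processing: per-channel sigmoid, then binarise at 0.5
    # (multilabel output, so no softmax/argmax across channels)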
    post_pred = Compose([
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ])

    # train and val loop
    step = 0
    i = 0
    if cfg.eval:
        # score any starting weights first so a worse epoch never overwrites them
        best_val_metric = run_eval(
            model=model,
            val_dataloader=val_dataloader,
            post_pred=post_pred,
            metric_function=metric_function,
            seg_loss_func=seg_loss_func,
            cfg=cfg,
            epoch=0,
        )
    else:
        best_val_metric = 0.0
    best_weights_name = "best_weights"
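
    # main loop: train each epoch, evaluate periodically, checkpoint on improvement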
    for epoch in range(cfg.epochs):
        print("EPOCH:", epoch)
        gc.collect()
        if cfg.train:
            run_train(
                model=model,
                train_dataloader=train_dataloader,
                optimizer=optimizer,
                scheduler=scheduler,
                seg_loss_func=seg_loss_func,
                cfg=cfg,
                # writer=writer,
                epoch=epoch,
                step=step,
                iteration=i,
            )

        if (epoch + 1) % cfg.eval_epochs == 0 and cfg.eval and epoch > cfg.start_eval_epoch:
            val_metric = run_eval(
                model=model,
                val_dataloader=val_dataloader,
                post_pred=post_pred,
                metric_function=metric_function,
                seg_loss_func=seg_loss_func,
                cfg=cfg,
                epoch=epoch,
            )

            if val_metric > best_val_metric:
                print(f"Found better metric: val_metric {best_val_metric:.5} -> {val_metric:.5}")
                best_val_metric = val_metric
                checkpoint = create_checkpoint(
                    model,
                    optimizer,
                    epoch,
                    scheduler=scheduler,
                )
                torch.save(
                    checkpoint,
                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth",
                )
            else:
                if cfg.load_best_weights:
                    # rewind to the best checkpoint when the metric regresses
                    try:
                        model.load_state_dict(torch.load(f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth")["model"])
                        print(f"metric did not improve, loaded the saved best weights with score: {best_val_metric}.")
                    except Exception:
                        pass

        if (epoch + 1) == cfg.epochs:
            # copy the final best weights to a name carrying the score, to avoid mix-ups
            if os.path.exists(f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth"):
                shutil.copyfile(
                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}.pth",
                    f"{cfg.output_dir}/{fold_dir}/{best_weights_name}_{best_val_metric:.4f}.pth",
                )

        torch.save(
            model.state_dict(),
            f"{cfg.output_dir}/{fold_dir}/last.pth",
        )


def run_train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    seg_loss_func,
    cfg,
    # writer,
    epoch,
    step,
    iteration,
):
    model.train()
    # a fresh GradScaler per epoch means the AMP loss scale re-warms on every call
    scaler = GradScaler()
    progress_bar = tqdm(range(len(train_dataloader)))
    tr_it = iter(train_dataloader)
    dataset_size = 0
    running_loss = 0.0
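
    # one pass over the training set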
    for itr in progress_bar:
        iteration += 1
        batch = next(tr_it)
        inputs, masks = (
            batch["image"].to(cfg.device),
            batch["mask"].to(cfg.device),
        )

        step += cfg.batch_size

        if cfg.amp:
            with autocast():
                outputs = model(inputs)
                loss = seg_loss_func(outputs, masks)
        else:
            outputs = model(inputs)
            loss = seg_loss_func(outputs, masks)
        if cfg.amp:
            scaler.scale(loss).backward()
            # unscale before clipping so the max-norm applies to the true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 12)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        optimizer.zero_grad()
        scheduler.step()  # per-iteration LR schedule

        running_loss += (loss.item() * cfg.batch_size)
        dataset_size += cfg.batch_size
        losses = running_loss / dataset_size
        progress_bar.set_description(f"loss: {losses:.4f} lr: {optimizer.param_groups[0]['lr']:.6f}")
        del batch, inputs, masks, outputs, loss
    print(f"Train loss: {losses:.4f}")
    torch.cuda.empty_cache()
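

# evaluation: sliding-window inference over whole volumes with optional flip TTA;
# Dice and Hausdorff are computed per 2D slice and combined 0.4 / 0.6 below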
def run_eval(model, val_dataloader, post_pred, metric_function, seg_loss_func, cfg, epoch):

    model.eval()

    dice_metric, hausdorff_metric = metric_function

    progress_bar = tqdm(range(len(val_dataloader)))
    val_it = iter(val_dataloader)
    with torch.no_grad():
        for itr in progress_bar:
            batch = next(val_it)
            val_inputs, val_masks = (
                batch["image"].to(cfg.device),
                batch["mask"].to(cfg.device),
            )
            if cfg.val_amp:
                with autocast():
                    val_outputs = sliding_window_inference(val_inputs, cfg.roi_size, cfg.sw_batch_size, model)
            else:
                val_outputs = sliding_window_inference(val_inputs, cfg.roi_size, cfg.sw_batch_size, model)
            # optional test-time augmentation: average logits over H, W and HW flips
            if cfg.run_tta_val:
                tta_ct = 1
                for dims in [[2], [3], [2, 3]]:
                    flip_val_outputs = sliding_window_inference(torch.flip(val_inputs, dims=dims), cfg.roi_size, cfg.sw_batch_size, model)
                    val_outputs += torch.flip(flip_val_outputs, dims=dims)
                    tta_ct += 1

                val_outputs /= tta_ct

            val_outputs = [post_pred(i) for i in val_outputs]
            val_outputs = torch.stack(val_outputs)
            # the metrics are slice-level: (n, c, h, w, d) -> (n, d, c, h, w) -> (n*d, c, h, w)
            val_outputs = val_outputs.permute([0, 4, 1, 2, 3]).flatten(0, 1)
            val_masks = val_masks.permute([0, 4, 1, 2, 3]).flatten(0, 1)

            hausdorff_metric(y_pred=val_outputs, y=val_masks)
            dice_metric(y_pred=val_outputs, y=val_masks)

            del val_outputs, val_inputs, val_masks, batch

    dice_score = dice_metric.aggregate().item()
    hausdorff_score = hausdorff_metric.aggregate().item()
    dice_metric.reset()
    hausdorff_metric.reset()

    # combined score, weighted 0.4 Dice / 0.6 Hausdorff
    all_score = dice_score * 0.4 + hausdorff_score * 0.6
    print(f"dice_score: {dice_score} hausdorff_score: {hausdorff_score} all_score: {all_score}")
    torch.cuda.empty_cache()

    return all_score


if __name__ == "__main__":

    sys.path.append("configs")

    parser = argparse.ArgumentParser(description="")

    parser.add_argument("-c", "--config", default="cfg_unet_multilabel", help="config filename")
    parser.add_argument("-f", "--fold", type=int, default=0, help="fold")
    parser.add_argument("-s", "--seed", type=int, default=20220421, help="seed")
    parser.add_argument("-w", "--weights", default=None, help="the path of weights")

    parser_args, _ = parser.parse_known_args()

    cfg = importlib.import_module(parser_args.config).cfg
    cfg.fold = parser_args.fold
    cfg.seed = parser_args.seed
    cfg.weights = parser_args.weights

    main(cfg)
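
# example invocation (assuming configs/cfg_unet_multilabel.py exposes a `cfg` object):
#   python multilabel_train.py -c cfg_unet_multilabel -f 0 -w pretrained.pth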