Diff of /train.py [000000] .. [83198a]

b/train.py

# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from init import Options
from networks import build_net, update_learning_rate, build_UNETR

import logging
import os
import sys
from glob import glob

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,
    CropForegroundd,
    EnsureType,
    LoadImaged,
    NormalizeIntensityd,
    Rand3DElasticd,
    RandAdjustContrastd,
    RandAffined,
    RandFlipd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandShiftIntensityd,
    RandSpatialCropd,
    ScaleIntensityd,
    Spacingd,
    SpatialPadd,
    ThresholdIntensityd,
    ToTensord,
)
from monai.visualize import plot_2d_or_3d_image

def main():
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # check GPUs
    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPUs:', num_gpus)

    # Data loader creation
    # train images
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # the same training images, used unaugmented for the epoch-level Dice check
    train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # validation images
    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    # test images
    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # replicate the training list to increase the number of patches per epoch;
    # note that each pass extends the list with itself, so it doubles every
    # iteration and grows by a factor of 2 ** increase_factor_data
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # Creation of data dictionaries for the data loaders
    train_dicts = [{'image': image_name, 'label': label_name}
                   for image_name, label_name in zip(train_images, train_segs)]

    train_dice_dicts = [{'image': image_name, 'label': label_name}
                        for image_name, label_name in zip(train_images_for_dice, train_segs_for_dice)]

    val_dicts = [{'image': image_name, 'label': label_name}
                 for image_name, label_name in zip(val_images, val_segs)]

    test_dicts = [{'image': image_name, 'label': label_name}
                  for image_name, label_name in zip(test_images, test_segs)]

    # Transforms list
    if opt.resolution is not None:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # CT HU filter
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),  # crop to the foreground bounding box

            NormalizeIntensityd(keys=['image']),  # intensity normalization
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),  # resample to target resolution

            # spatial augmentations
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),

            # intensity augmentations
            RandGaussianSmoothd(keys=['image'], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than the patch
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),  # crop to the foreground bounding box

            NormalizeIntensityd(keys=['image']),  # intensity normalization
            ScaleIntensityd(keys=['image']),
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')),  # resample to target resolution

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than the patch
            ToTensord(keys=['image', 'label'])
        ]
    else:
        train_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),  # crop to the foreground bounding box

            NormalizeIntensityd(keys=['image']),  # intensity normalization
            ScaleIntensityd(keys=['image']),

            # spatial augmentations
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
            RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
            RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                        rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
            Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                           sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15),
                           padding_mode="zeros"),

            # intensity augmentations
            RandGaussianSmoothd(keys=['image'], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1),
            RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
            RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)),
            RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0, 0.3), prob=0.1),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than the patch
            RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
            ToTensord(keys=['image', 'label'])
        ]

        val_transforms = [
            LoadImaged(keys=['image', 'label']),
            AddChanneld(keys=['image', 'label']),
            # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),
            # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
            CropForegroundd(keys=['image', 'label'], source_key='image'),  # crop to the foreground bounding box

            NormalizeIntensityd(keys=['image']),  # intensity normalization
            ScaleIntensityd(keys=['image']),

            SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),  # pad if the image is smaller than the patch
            ToTensord(keys=['image', 'label'])
        ]
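
    # Note (applies to both branches above): the np.random.uniform(...) arguments
    # passed to RandGaussianNoised and RandShiftIntensityd are evaluated once, when
    # the transform list is built, so the same drawn mean/std/offset is reused for
    # every sample and every epoch. To resample per call, pass plain values/ranges
    # and let the transforms draw internally; a minimal sketch (values are
    # assumptions, not tuned settings):
    #   RandGaussianNoised(keys=['image'], prob=0.1, mean=0.0, std=0.1)
    #   RandShiftIntensityd(keys=['image'], offsets=0.3, prob=0.1)
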
    train_transforms = Compose(train_transforms)
    val_transforms = Compose(val_transforms)

    # create a training data loader
    check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True, collate_fn=list_data_collate, num_workers=opt.workers, pin_memory=False)

    # create a training-Dice data loader (unaugmented training images)
    check_train_dice = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(check_train_dice, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

    # create a validation data loader
    check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

    # create a test data loader
    check_test = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(check_test, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False)

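    # The evaluation loaders use batch_size=1 because whole volumes have varying
    # shapes after cropping/resampling; sliding-window inference below handles them
    # patch-wise. If disk I/O becomes the bottleneck, monai.data.CacheDataset is a
    # drop-in alternative that caches the deterministic transform outputs, e.g.
    # (a sketch; cache_rate value is an assumption):
    #   check_train = monai.data.CacheDataset(data=train_dicts, transform=train_transforms, cache_rate=1.0)
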
    # build the network
    if opt.network == 'nnunet':
        net = build_net()  # nnU-Net
    elif opt.network == 'unetr':
        net = build_UNETR()  # UNETR
    else:
        raise ValueError(f"unknown network: {opt.network}")
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

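    # Note: a checkpoint saved from a DataParallel-wrapped net stores every key with
    # a "module." prefix, so weights saved with num_gpus > 1 will not load into a
    # bare net (and vice versa) without remapping the state-dict keys.
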
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
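
    # post_trans binarizes the raw network output for the Dice computation: a sigmoid
    # maps logits to [0, 1], then AsDiscrete thresholds at 0.5. This mirrors
    # DiceCELoss(sigmoid=True) below, which applies the sigmoid inside the loss and
    # therefore expects raw logits from the network.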

    loss_function = monai.losses.DiceCELoss(sigmoid=True)
    torch.backends.cudnn.benchmark = opt.benchmark

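    # cudnn.benchmark autotunes convolution algorithms for the observed input shape;
    # it pays off here because RandSpatialCropd yields fixed-size patches, but it can
    # slow things down when input shapes vary between iterations.
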
    if opt.network == 'nnunet':

        optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True)
        net_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)

    elif opt.network == 'unetr':

        optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5)

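    # The LambdaLR above implements the nnU-Net "poly" schedule:
    #   lr(epoch) = opt.lr * (1 - epoch / opt.epochs) ** 0.9
    # e.g. with opt.lr = 0.01 and opt.epochs = 1000, epoch 500 runs at roughly
    # 0.01 * 0.5 ** 0.9 ≈ 0.0054, decaying to 0 by the final epoch.
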
    # start a typical PyTorch training
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    writer = SummaryWriter()
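    # SummaryWriter() with no arguments logs to ./runs/<timestamp>; monitor the run
    # with `tensorboard --logdir runs`.
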
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            epoch_len = len(check_train) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        if opt.network == 'nnunet':
            update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

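                # plot_dice runs full-volume evaluation: sliding_window_inference
                # tiles each volume into opt.patch_size patches (sw_batch_size patches
                # per forward pass), stitches the patch predictions back into a
                # whole-volume prediction, and accumulates a per-image Dice score.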
                def plot_dice(images_loader):

                    val_images = None
                    val_labels = None
                    val_outputs = None
                    for data in images_loader:
                        val_images, val_labels = data["image"].cuda(), data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net)
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        dice_metric(y_pred=val_outputs, y=val_labels)

                    # aggregate the final mean Dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for the next validation round
                    dice_metric.reset()

                    return metric, val_images, val_labels, val_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # save the best model (selected on validation Dice)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # log summary
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
310
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
311
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
312
                writer.add_scalar("Validation_dice", metric, epoch + 1)
313
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
314
                # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
315
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
316
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
317
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
318
                plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
319
                plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
320
                plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")
321
322
    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
323
    writer.close()
324
325
326
if __name__ == "__main__":
327
    main()
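
# Example invocation (flag names depend on init.Options and are shown here as an
# assumption, not the script's verified interface):
#   python train.py --images_folder ./data/images --labels_folder ./data/labels \
#                   --network nnunet --batch_size 2 --epochs 300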