# Source: monai 0.5.0/train.py
# (reconstructed from a diff-viewer export; viewer chrome and interleaved
#  line numbers removed)
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE: `from networks import *` is intentionally kept FIRST so that any name
# it exports is overridden by the explicit stdlib/third-party imports below.
from init import Options
from networks import *
# from networks import build_net
import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_3d, list_data_collate
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    BorderPadd,
    Compose,
    CropForegroundd,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    Rand3DElasticd,
    RandAdjustContrastd,
    RandAffined,
    RandFlipd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandShiftIntensityd,
    RandSpatialCropd,
    RandZoomd,
    Resized,
    ScaleIntensityd,
    Spacingd,
    SpatialPadd,
    ThresholdIntensityd,
    ToTensord,
    Transpose,
)
from monai.visualize import plot_2d_or_3d_image
# from monai.engines import get_devices_spec
def _make_transforms(opt):
    """Build the training and evaluation transform pipelines.

    Both pipelines load the NIfTI pair, make it channel-first, crop to the
    image's foreground bounding box and normalise/scale the intensity.  When
    ``opt.resolution`` is set the data is additionally resampled to that voxel
    spacing.  The training pipeline then applies spatial and intensity
    augmentation plus random patch extraction; the evaluation pipeline only
    pads to ``opt.patch_size``.

    Returns:
        (train_transforms, val_transforms): two ``Compose`` objects.
    """
    # Steps shared by every pipeline.
    preprocessing = [
        LoadImaged(keys=['image', 'label']),
        AddChanneld(keys=['image', 'label']),
        # ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135),  # CT HU filter
        # ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215),
        CropForegroundd(keys=['image', 'label'], source_key='image'),
        NormalizeIntensityd(keys=['image']),
        ScaleIntensityd(keys=['image']),
    ]
    if opt.resolution is not None:
        # Resample image (bilinear) and label (nearest) to the target spacing.
        preprocessing.append(
            Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')))
        # NOTE(review): the original used an upper bound of 15 here (vs 1 in
        # the native-resolution branch).  On [0, 1]-scaled intensities that
        # looks like a typo for 0.15 -- confirm before changing it.
        noise_std = 15
    else:
        noise_std = 1

    augmentation = [
        # Random flips along each spatial axis.
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1),
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0),
        RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2),
        # One large rotation per axis (the other two ranges stay small).
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"),
        RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                    rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"),
        Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1,
                       sigma_range=(5, 8), magnitude_range=(100, 200),
                       scale_range=(0.15, 0.15, 0.15), padding_mode="zeros"),
        RandGaussianSmoothd(keys=['image'], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15),
                            sigma_z=(0.5, 1.15), prob=0.1),
        RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1),
        # FIX: the original passed np.random.uniform(...) for mean/std.  Those
        # expressions are evaluated ONCE at construction, freezing a single
        # random value for the whole run.  RandGaussianNoised already samples
        # its std from (0, std) on every call, so the intended upper bound is
        # passed directly (and a zero-centred noise mean is used).
        RandGaussianNoised(keys=['image'], prob=0.1, mean=0.0, std=noise_std),
        # FIX: likewise a frozen np.random.uniform(0, 0.3); a plain 0.3 lets
        # the transform draw a shift in (-0.3, 0.3) per call.
        RandShiftIntensityd(keys=['image'], offsets=0.3, prob=0.1),
    ]

    patching = [
        # Pad first so images smaller than the patch can still be cropped.
        SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),
        RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False),
        ToTensord(keys=['image', 'label']),
    ]

    train_transforms = Compose(preprocessing + augmentation + patching)
    val_transforms = Compose(preprocessing + [
        SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method='end'),
        ToTensord(keys=['image', 'label']),
    ])
    return train_transforms, val_transforms


def main():
    """End-to-end training of a 3D segmentation network with MONAI.

    Reads ``image*.nii``/``label*.nii`` pairs from the ``train``/``val``/``test``
    sub-folders of ``opt.images_folder``/``opt.labels_folder``, trains
    ``build_net()`` with Dice+CE loss and SGD, evaluates Dice on all three
    splits every ``val_interval`` epochs via sliding-window inference, logs to
    TensorBoard, and saves the best model (by validation Dice) to
    ``best_metric_model.pth``.  Requires CUDA.
    """
    opt = Options().parse()
    # monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # ---- GPU bookkeeping -------------------------------------------------
    if opt.gpu_ids != '-1':
        num_gpus = len(opt.gpu_ids.split(','))
    else:
        num_gpus = 0
    print('number of GPU:', num_gpus)

    # ---- file lists ------------------------------------------------------
    train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    # A second, never-duplicated copy of the training list, used only to
    # compute the training Dice on whole images.
    train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii')))
    train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii')))

    val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii')))
    val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii')))

    test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii')))
    test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii')))

    # Enlarge the training list so each epoch draws more random patches.
    # NOTE(review): extending a list with itself doubles it on every
    # iteration, so the list grows by 2 ** increase_factor_data rather than
    # linearly -- confirm this is the intended behaviour.
    for _ in range(int(opt.increase_factor_data)):
        train_images.extend(train_images)
        train_segs.extend(train_segs)

    print('Number of training patches per epoch:', len(train_images))
    print('Number of training images per epoch:', len(train_images_for_dice))
    print('Number of validation images per epoch:', len(val_images))
    print('Number of test images per epoch:', len(test_images))

    # ---- data dictionaries for the dictionary-based transforms -----------
    train_dicts = [{'image': img, 'label': seg}
                   for img, seg in zip(train_images, train_segs)]
    train_dice_dicts = [{'image': img, 'label': seg}
                        for img, seg in zip(train_images_for_dice, train_segs_for_dice)]
    val_dicts = [{'image': img, 'label': seg}
                 for img, seg in zip(val_images, val_segs)]
    test_dicts = [{'image': img, 'label': seg}
                  for img, seg in zip(test_images, test_segs)]

    # ---- transforms and data loaders -------------------------------------
    train_transforms, val_transforms = _make_transforms(opt)

    # Patch-based training loader.
    train_ds = monai.data.Dataset(data=train_dicts, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=opt.batch_size, shuffle=True,
                              num_workers=opt.workers, pin_memory=False)

    # Whole-image loaders (batch size 1) used only for Dice evaluation.
    train_dice_ds = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms)
    train_dice_loader = DataLoader(train_dice_ds, batch_size=1,
                                   num_workers=opt.workers, pin_memory=False)

    val_ds = monai.data.Dataset(data=val_dicts, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=opt.workers, pin_memory=False)

    test_ds = monai.data.Dataset(data=test_dicts, transform=val_transforms)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=opt.workers, pin_memory=False)

    # # try to use all the available GPUs
    # devices = get_devices_spec(None)

    # ---- network, loss, optimiser ----------------------------------------
    net = build_net()  # provided by the `networks` star-import
    net.cuda()

    if num_gpus > 1:
        net = torch.nn.DataParallel(net)

    if opt.preload is not None:
        net.load_state_dict(torch.load(opt.preload))

    dice_metric = DiceMetric(include_background=True, reduction="mean")
    # Sigmoid + threshold to turn logits into a binary mask for Dice.
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])

    loss_function = monai.losses.DiceCELoss(sigmoid=True)

    optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99,
                            weight_decay=3e-5, nesterov=True)

    # Polynomial decay: lr * (1 - epoch / epochs) ** 0.9 (nnU-Net style).
    net_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9)

    # ---- training loop ----------------------------------------------------
    val_interval = 1
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter()
    # Hoisted out of the loops: constant for the whole run.
    epoch_len = len(train_ds) // train_loader.batch_size
    for epoch in range(opt.epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{opt.epochs}")
        net.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda()
            optim.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optim.step()
            epoch_loss += loss.item()
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        # NOTE(review): not defined in this file -- presumably exported by the
        # `networks` star-import; confirm.
        update_learning_rate(net_scheduler, optim)

        if (epoch + 1) % val_interval == 0:
            net.eval()
            with torch.no_grad():

                def plot_dice(images_loader):
                    """Run sliding-window inference over `images_loader` and
                    return (mean Dice, last image, last label, last prediction).
                    Also appends the mean Dice to the enclosing `metric_values`.
                    """
                    metric_sum = 0.0
                    metric_count = 0
                    eval_images = None
                    eval_labels = None
                    eval_outputs = None
                    for data in images_loader:
                        eval_images, eval_labels = data["image"].cuda(), data["label"].cuda()
                        roi_size = opt.patch_size
                        sw_batch_size = 4
                        eval_outputs = sliding_window_inference(
                            eval_images, roi_size, sw_batch_size, net)
                        eval_outputs = post_trans(eval_outputs)
                        value, _ = dice_metric(y_pred=eval_outputs, y=eval_labels)
                        metric_count += len(value)
                        metric_sum += value.item() * len(value)
                    metric = metric_sum / metric_count
                    metric_values.append(metric)
                    return metric, eval_images, eval_labels, eval_outputs

                metric, val_images, val_labels, val_outputs = plot_dice(val_loader)

                # Save the best model by validation Dice.
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(net.state_dict(), "best_metric_model.pth")
                    print("saved new best metric model")

                metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader)
                metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader)

                # Logger bar
                print(
                    "current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format(
                        epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch
                    )
                )

                writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1)
                writer.add_scalar("Testing_dice", metric_test, epoch + 1)
                writer.add_scalar("Training_dice", metric_train, epoch + 1)
                writer.add_scalar("Validation_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                # val_outputs = (val_outputs.sigmoid() >= 0.5).float()
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference")
                plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image")
                plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label")
                plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    writer.close()


if __name__ == "__main__":
    main()