|
a |
|
b/train.py |
|
|
1 |
# Copyright 2020 MONAI Consortium |
|
|
2 |
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
3 |
# you may not use this file except in compliance with the License. |
|
|
4 |
# You may obtain a copy of the License at |
|
|
5 |
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
6 |
# Unless required by applicable law or agreed to in writing, software |
|
|
7 |
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
8 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
9 |
# See the License for the specific language governing permissions and |
|
|
10 |
# limitations under the License. |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
from init import Options |
|
|
14 |
from networks import build_net, update_learning_rate, build_UNETR |
|
|
15 |
# from networks import build_net |
|
|
16 |
import logging |
|
|
17 |
import os |
|
|
18 |
import sys |
|
|
19 |
import tempfile |
|
|
20 |
from glob import glob |
|
|
21 |
|
|
|
22 |
import nibabel as nib |
|
|
23 |
import numpy as np |
|
|
24 |
import torch |
|
|
25 |
from torch.utils.data import DataLoader |
|
|
26 |
from torch.utils.tensorboard import SummaryWriter |
|
|
27 |
|
|
|
28 |
import monai |
|
|
29 |
from monai.data import create_test_image_3d, list_data_collate, decollate_batch |
|
|
30 |
from monai.inferers import sliding_window_inference |
|
|
31 |
from monai.metrics import DiceMetric |
|
|
32 |
from monai.transforms import (EnsureType, Compose, LoadImaged, AddChanneld, Transpose,Activations,AsDiscrete, RandGaussianSmoothd, CropForegroundd, SpatialPadd, |
|
|
33 |
ScaleIntensityd, ToTensord, RandSpatialCropd, Rand3DElasticd, RandAffined, RandZoomd, |
|
|
34 |
Spacingd, Orientationd, Resized, ThresholdIntensityd, RandShiftIntensityd, BorderPadd, RandGaussianNoised, RandAdjustContrastd,NormalizeIntensityd,RandFlipd) |
|
|
35 |
|
|
|
36 |
from monai.visualize import plot_2d_or_3d_image |
|
|
37 |
|
|
|
38 |
def main(): |
|
|
39 |
opt = Options().parse() |
|
|
40 |
# monai.config.print_config() |
|
|
41 |
logging.basicConfig(stream=sys.stdout, level=logging.INFO) |
|
|
42 |
|
|
|
43 |
# check gpus |
|
|
44 |
if opt.gpu_ids != '-1': |
|
|
45 |
num_gpus = len(opt.gpu_ids.split(',')) |
|
|
46 |
else: |
|
|
47 |
num_gpus = 0 |
|
|
48 |
print('number of GPU:', num_gpus) |
|
|
49 |
|
|
|
50 |
# Data loader creation |
|
|
51 |
# train images |
|
|
52 |
train_images = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii'))) |
|
|
53 |
train_segs = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii'))) |
|
|
54 |
|
|
|
55 |
train_images_for_dice = sorted(glob(os.path.join(opt.images_folder, 'train', 'image*.nii'))) |
|
|
56 |
train_segs_for_dice = sorted(glob(os.path.join(opt.labels_folder, 'train', 'label*.nii'))) |
|
|
57 |
|
|
|
58 |
# validation images |
|
|
59 |
val_images = sorted(glob(os.path.join(opt.images_folder, 'val', 'image*.nii'))) |
|
|
60 |
val_segs = sorted(glob(os.path.join(opt.labels_folder, 'val', 'label*.nii'))) |
|
|
61 |
|
|
|
62 |
# test images |
|
|
63 |
test_images = sorted(glob(os.path.join(opt.images_folder, 'test', 'image*.nii'))) |
|
|
64 |
test_segs = sorted(glob(os.path.join(opt.labels_folder, 'test', 'label*.nii'))) |
|
|
65 |
|
|
|
66 |
# augment the data list for training |
|
|
67 |
for i in range(int(opt.increase_factor_data)): |
|
|
68 |
|
|
|
69 |
train_images.extend(train_images) |
|
|
70 |
train_segs.extend(train_segs) |
|
|
71 |
|
|
|
72 |
print('Number of training patches per epoch:', len(train_images)) |
|
|
73 |
print('Number of training images per epoch:', len(train_images_for_dice)) |
|
|
74 |
print('Number of validation images per epoch:', len(val_images)) |
|
|
75 |
print('Number of test images per epoch:', len(test_images)) |
|
|
76 |
|
|
|
77 |
# Creation of data directories for data_loader |
|
|
78 |
|
|
|
79 |
train_dicts = [{'image': image_name, 'label': label_name} |
|
|
80 |
for image_name, label_name in zip(train_images, train_segs)] |
|
|
81 |
|
|
|
82 |
train_dice_dicts = [{'image': image_name, 'label': label_name} |
|
|
83 |
for image_name, label_name in zip(train_images_for_dice, train_segs_for_dice)] |
|
|
84 |
|
|
|
85 |
val_dicts = [{'image': image_name, 'label': label_name} |
|
|
86 |
for image_name, label_name in zip(val_images, val_segs)] |
|
|
87 |
|
|
|
88 |
test_dicts = [{'image': image_name, 'label': label_name} |
|
|
89 |
for image_name, label_name in zip(test_images, test_segs)] |
|
|
90 |
|
|
|
91 |
# Transforms list |
|
|
92 |
|
|
|
93 |
if opt.resolution is not None: |
|
|
94 |
train_transforms = [ |
|
|
95 |
LoadImaged(keys=['image', 'label']), |
|
|
96 |
AddChanneld(keys=['image', 'label']), |
|
|
97 |
# ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), # CT HU filter |
|
|
98 |
# ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215), |
|
|
99 |
CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground |
|
|
100 |
|
|
|
101 |
NormalizeIntensityd(keys=['image']), # augmentation |
|
|
102 |
ScaleIntensityd(keys=['image']), # intensity |
|
|
103 |
Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')), # resolution |
|
|
104 |
|
|
|
105 |
RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1), |
|
|
106 |
RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0), |
|
|
107 |
RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2), |
|
|
108 |
RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1, |
|
|
109 |
rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"), |
|
|
110 |
RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1, |
|
|
111 |
rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"), |
|
|
112 |
RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1, |
|
|
113 |
rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"), |
|
|
114 |
Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1, |
|
|
115 |
sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15), |
|
|
116 |
padding_mode="zeros"), |
|
|
117 |
RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1,), |
|
|
118 |
RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1), |
|
|
119 |
RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 15)), |
|
|
120 |
RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=0.1), |
|
|
121 |
|
|
|
122 |
SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'), # pad if the image is smaller than patch |
|
|
123 |
RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False), |
|
|
124 |
ToTensord(keys=['image', 'label']) |
|
|
125 |
] |
|
|
126 |
|
|
|
127 |
val_transforms = [ |
|
|
128 |
LoadImaged(keys=['image', 'label']), |
|
|
129 |
AddChanneld(keys=['image', 'label']), |
|
|
130 |
# ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), |
|
|
131 |
# ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215), |
|
|
132 |
CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground |
|
|
133 |
|
|
|
134 |
NormalizeIntensityd(keys=['image']), # intensity |
|
|
135 |
ScaleIntensityd(keys=['image']), |
|
|
136 |
Spacingd(keys=['image', 'label'], pixdim=opt.resolution, mode=('bilinear', 'nearest')), # resolution |
|
|
137 |
|
|
|
138 |
SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'), # pad if the image is smaller than patch |
|
|
139 |
ToTensord(keys=['image', 'label']) |
|
|
140 |
] |
|
|
141 |
else: |
|
|
142 |
train_transforms = [ |
|
|
143 |
LoadImaged(keys=['image', 'label']), |
|
|
144 |
AddChanneld(keys=['image', 'label']), |
|
|
145 |
# ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), |
|
|
146 |
# ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215), |
|
|
147 |
CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground |
|
|
148 |
|
|
|
149 |
NormalizeIntensityd(keys=['image']), # augmentation |
|
|
150 |
ScaleIntensityd(keys=['image']), # intensity |
|
|
151 |
|
|
|
152 |
RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=1), |
|
|
153 |
RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=0), |
|
|
154 |
RandFlipd(keys=['image', 'label'], prob=0.15, spatial_axis=2), |
|
|
155 |
RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1, |
|
|
156 |
rotate_range=(np.pi / 36, np.pi / 36, np.pi * 2), padding_mode="zeros"), |
|
|
157 |
RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1, |
|
|
158 |
rotate_range=(np.pi / 36, np.pi / 2, np.pi / 36), padding_mode="zeros"), |
|
|
159 |
RandAffined(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1, |
|
|
160 |
rotate_range=(np.pi / 2, np.pi / 36, np.pi / 36), padding_mode="zeros"), |
|
|
161 |
Rand3DElasticd(keys=['image', 'label'], mode=('bilinear', 'nearest'), prob=0.1, |
|
|
162 |
sigma_range=(5, 8), magnitude_range=(100, 200), scale_range=(0.15, 0.15, 0.15), |
|
|
163 |
padding_mode="zeros"), |
|
|
164 |
RandGaussianSmoothd(keys=["image"], sigma_x=(0.5, 1.15), sigma_y=(0.5, 1.15), sigma_z=(0.5, 1.15), prob=0.1,), |
|
|
165 |
RandAdjustContrastd(keys=['image'], gamma=(0.5, 2.5), prob=0.1), |
|
|
166 |
RandGaussianNoised(keys=['image'], prob=0.1, mean=np.random.uniform(0, 0.5), std=np.random.uniform(0, 1)), |
|
|
167 |
RandShiftIntensityd(keys=['image'], offsets=np.random.uniform(0,0.3), prob=0.1), |
|
|
168 |
|
|
|
169 |
SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'), # pad if the image is smaller than patch |
|
|
170 |
RandSpatialCropd(keys=['image', 'label'], roi_size=opt.patch_size, random_size=False), |
|
|
171 |
ToTensord(keys=['image', 'label']) |
|
|
172 |
] |
|
|
173 |
|
|
|
174 |
val_transforms = [ |
|
|
175 |
LoadImaged(keys=['image', 'label']), |
|
|
176 |
AddChanneld(keys=['image', 'label']), |
|
|
177 |
# ThresholdIntensityd(keys=['image'], threshold=-135, above=True, cval=-135), |
|
|
178 |
# ThresholdIntensityd(keys=['image'], threshold=215, above=False, cval=215), |
|
|
179 |
CropForegroundd(keys=['image', 'label'], source_key='image'), # crop CropForeground |
|
|
180 |
|
|
|
181 |
NormalizeIntensityd(keys=['image']), # intensity |
|
|
182 |
ScaleIntensityd(keys=['image']), |
|
|
183 |
|
|
|
184 |
SpatialPadd(keys=['image', 'label'], spatial_size=opt.patch_size, method= 'end'), # pad if the image is smaller than patch |
|
|
185 |
ToTensord(keys=['image', 'label']) |
|
|
186 |
] |
|
|
187 |
|
|
|
188 |
train_transforms = Compose(train_transforms) |
|
|
189 |
val_transforms = Compose(val_transforms) |
|
|
190 |
|
|
|
191 |
# create a training data loader |
|
|
192 |
check_train = monai.data.Dataset(data=train_dicts, transform=train_transforms) |
|
|
193 |
train_loader = DataLoader(check_train, batch_size=opt.batch_size, shuffle=True, collate_fn=list_data_collate, num_workers=opt.workers, pin_memory=False) |
|
|
194 |
|
|
|
195 |
# create a training_dice data loader |
|
|
196 |
check_val = monai.data.Dataset(data=train_dice_dicts, transform=val_transforms) |
|
|
197 |
train_dice_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False) |
|
|
198 |
|
|
|
199 |
# create a validation data loader |
|
|
200 |
check_val = monai.data.Dataset(data=val_dicts, transform=val_transforms) |
|
|
201 |
val_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False) |
|
|
202 |
|
|
|
203 |
# create a validation data loader |
|
|
204 |
check_val = monai.data.Dataset(data=test_dicts, transform=val_transforms) |
|
|
205 |
test_loader = DataLoader(check_val, batch_size=1, num_workers=opt.workers, collate_fn=list_data_collate, pin_memory=False) |
|
|
206 |
|
|
|
207 |
# build the network |
|
|
208 |
if opt.network is 'nnunet': |
|
|
209 |
net = build_net() # nn build_net |
|
|
210 |
elif opt.network is 'unetr': |
|
|
211 |
net = build_UNETR() # UneTR |
|
|
212 |
net.cuda() |
|
|
213 |
|
|
|
214 |
if num_gpus > 1: |
|
|
215 |
net = torch.nn.DataParallel(net) |
|
|
216 |
|
|
|
217 |
if opt.preload is not None: |
|
|
218 |
net.load_state_dict(torch.load(opt.preload)) |
|
|
219 |
|
|
|
220 |
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) |
|
|
221 |
post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)]) |
|
|
222 |
|
|
|
223 |
loss_function = monai.losses.DiceCELoss(sigmoid=True) |
|
|
224 |
torch.backends.cudnn.benchmark = opt.benchmark |
|
|
225 |
|
|
|
226 |
|
|
|
227 |
if opt.network is 'nnunet': |
|
|
228 |
|
|
|
229 |
optim = torch.optim.SGD(net.parameters(), lr=opt.lr, momentum=0.99, weight_decay=3e-5, nesterov=True,) |
|
|
230 |
net_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda epoch: (1 - epoch / opt.epochs) ** 0.9) |
|
|
231 |
|
|
|
232 |
elif opt.network is 'unetr': |
|
|
233 |
|
|
|
234 |
optim = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=1e-5) |
|
|
235 |
|
|
|
236 |
# start a typical PyTorch training |
|
|
237 |
val_interval = 1 |
|
|
238 |
best_metric = -1 |
|
|
239 |
best_metric_epoch = -1 |
|
|
240 |
epoch_loss_values = list() |
|
|
241 |
writer = SummaryWriter() |
|
|
242 |
for epoch in range(opt.epochs): |
|
|
243 |
print("-" * 10) |
|
|
244 |
print(f"epoch {epoch + 1}/{opt.epochs}") |
|
|
245 |
net.train() |
|
|
246 |
epoch_loss = 0 |
|
|
247 |
step = 0 |
|
|
248 |
for batch_data in train_loader: |
|
|
249 |
step += 1 |
|
|
250 |
inputs, labels = batch_data["image"].cuda(), batch_data["label"].cuda() |
|
|
251 |
optim.zero_grad() |
|
|
252 |
outputs = net(inputs) |
|
|
253 |
loss = loss_function(outputs, labels) |
|
|
254 |
loss.backward() |
|
|
255 |
optim.step() |
|
|
256 |
epoch_loss += loss.item() |
|
|
257 |
epoch_len = len(check_train) // train_loader.batch_size |
|
|
258 |
print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}") |
|
|
259 |
writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step) |
|
|
260 |
epoch_loss /= step |
|
|
261 |
epoch_loss_values.append(epoch_loss) |
|
|
262 |
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}") |
|
|
263 |
if opt.network is 'nnunet': |
|
|
264 |
update_learning_rate(net_scheduler, optim) |
|
|
265 |
|
|
|
266 |
if (epoch + 1) % val_interval == 0: |
|
|
267 |
net.eval() |
|
|
268 |
with torch.no_grad(): |
|
|
269 |
|
|
|
270 |
def plot_dice(images_loader): |
|
|
271 |
|
|
|
272 |
val_images = None |
|
|
273 |
val_labels = None |
|
|
274 |
val_outputs = None |
|
|
275 |
for data in images_loader: |
|
|
276 |
val_images, val_labels = data["image"].cuda(), data["label"].cuda() |
|
|
277 |
roi_size = opt.patch_size |
|
|
278 |
sw_batch_size = 4 |
|
|
279 |
val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, net) |
|
|
280 |
val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)] |
|
|
281 |
dice_metric(y_pred=val_outputs, y=val_labels) |
|
|
282 |
|
|
|
283 |
# aggregate the final mean dice result |
|
|
284 |
metric = dice_metric.aggregate().item() |
|
|
285 |
# reset the status for next validation round |
|
|
286 |
dice_metric.reset() |
|
|
287 |
|
|
|
288 |
return metric, val_images, val_labels, val_outputs |
|
|
289 |
|
|
|
290 |
metric, val_images, val_labels, val_outputs = plot_dice(val_loader) |
|
|
291 |
|
|
|
292 |
# Save best model |
|
|
293 |
if metric > best_metric: |
|
|
294 |
best_metric = metric |
|
|
295 |
best_metric_epoch = epoch + 1 |
|
|
296 |
torch.save(net.state_dict(), "best_metric_model.pth") |
|
|
297 |
print("saved new best metric model") |
|
|
298 |
|
|
|
299 |
metric_train, train_images, train_labels, train_outputs = plot_dice(train_dice_loader) |
|
|
300 |
metric_test, test_images, test_labels, test_outputs = plot_dice(test_loader) |
|
|
301 |
|
|
|
302 |
# Logger bar |
|
|
303 |
print( |
|
|
304 |
"current epoch: {} Training dice: {:.4f} Validation dice: {:.4f} Testing dice: {:.4f} Best Validation dice: {:.4f} at epoch {}".format( |
|
|
305 |
epoch + 1, metric_train, metric, metric_test, best_metric, best_metric_epoch |
|
|
306 |
) |
|
|
307 |
) |
|
|
308 |
|
|
|
309 |
writer.add_scalar("Mean_epoch_loss", epoch_loss, epoch + 1) |
|
|
310 |
writer.add_scalar("Testing_dice", metric_test, epoch + 1) |
|
|
311 |
writer.add_scalar("Training_dice", metric_train, epoch + 1) |
|
|
312 |
writer.add_scalar("Validation_dice", metric, epoch + 1) |
|
|
313 |
# plot the last model output as GIF image in TensorBoard with the corresponding image and label |
|
|
314 |
# val_outputs = (val_outputs.sigmoid() >= 0.5).float() |
|
|
315 |
plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="validation image") |
|
|
316 |
plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="validation label") |
|
|
317 |
plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="validation inference") |
|
|
318 |
plot_2d_or_3d_image(test_images, epoch + 1, writer, index=0, tag="test image") |
|
|
319 |
plot_2d_or_3d_image(test_labels, epoch + 1, writer, index=0, tag="test label") |
|
|
320 |
plot_2d_or_3d_image(test_outputs, epoch + 1, writer, index=0, tag="test inference") |
|
|
321 |
|
|
|
322 |
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}") |
|
|
323 |
writer.close() |
|
|
324 |
|
|
|
325 |
|
|
|
326 |
if __name__ == "__main__": |
|
|
327 |
main() |