b/segmentation/trainddp.py

'''
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
'''

from monai.transforms import (
    AsDiscrete,
    Compose,
)
import argparse
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, decollate_batch
import torch
import matplotlib.pyplot as plt
import os
import pandas as pd
import time
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from initialize_train import (
    create_data_split_files,
    get_train_valid_data_in_dict_format,
    get_train_transforms,
    get_valid_transforms,
    get_model,
    get_loss_function,
    get_optimizer,
    get_scheduler,
    get_metric,
    get_validation_sliding_window_size
)

import sys
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
sys.path.append(config_dir)
from config import RESULTS_FOLDER
torch.backends.cudnn.benchmark = True
#%%
def ddp_setup():
    dist.init_process_group(backend='nccl', init_method="env://")
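    # Note (added): the 'env://' init method reads MASTER_ADDR, MASTER_PORT, RANK and
    # WORLD_SIZE from the environment; launching the script with torchrun (or the older
    # torch.distributed.launch) sets these automatically for every spawned process.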

def convert_to_4digits(str_num):
    if len(str_num) == 1:
        new_num = '000' + str_num
    elif len(str_num) == 2:
        new_num = '00' + str_num
    elif len(str_num) == 3:
        new_num = '0' + str_num
    else:
        new_num = str_num
    return new_num
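# Example (added): convert_to_4digits('7') returns '0007', i.e. the same as str_num.zfill(4).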

#%%
def load_train_objects(args):
    train_data, valid_data = get_train_valid_data_in_dict_format(args.fold)
    train_transforms = get_train_transforms(args.input_patch_size)
    valid_transforms = get_valid_transforms()
    model = get_model(args.network_name, args.input_patch_size)
    optimizer = get_optimizer(model, learning_rate=args.lr, weight_decay=args.wd)
    loss_function = get_loss_function()
    scheduler = get_scheduler(optimizer, args.epochs)
    metric = get_metric()

    return (
        train_data,
        valid_data,
        train_transforms,
        valid_transforms,
        model,
        loss_function,
        optimizer,
        scheduler,
        metric
    )


def prepare_dataset(data, transforms, args):
    dataset = CacheDataset(data=data, transform=transforms, cache_rate=args.cache_rate, num_workers=args.num_workers)
    return dataset
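# Note (added): CacheDataset pre-computes and keeps in memory the deterministic part of the
# transform chain for a cache_rate fraction of the samples, trading RAM for faster epochs;
# cache_rate=1.0 caches everything, 0.0 behaves like a plain Dataset.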


def main_worker(save_models_dir, save_logs_dir, args):
    # init_process_group
    ddp_setup()
    # get this process's rank (equal to the local GPU index for a single-node launch)
    local_rank = int(dist.get_rank())
    if local_rank == 0:
        print(f"Training {args.network_name} on fold {args.fold}")
        print(f"The models will be saved in {save_models_dir}")
        print(f"The training/validation logs will be saved in {save_logs_dir}")

    # get all training and validation objects
    train_data, valid_data, train_transforms, valid_transforms, model, loss_function, optimizer, scheduler, metric = load_train_objects(args)

    # get dataset of object-type CacheDataset
    train_dataset = prepare_dataset(train_data, train_transforms, args)
    valid_dataset = prepare_dataset(valid_data, valid_transforms, args)

    # get DistributedSampler instances for both training and validation dataloader
    # this will be used to split data into different GPUs
    train_sampler = DistributedSampler(dataset=train_dataset, shuffle=True)
    valid_sampler = DistributedSampler(dataset=valid_dataset, shuffle=False)
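    # Note (added): each DistributedSampler hands every rank its own shard of the dataset,
    # so the effective global batch size is train_bs * world_size; the DataLoaders below
    # keep shuffle=False because shuffling is delegated to the sampler (re-seeded each
    # epoch via train_sampler.set_epoch in the training loop).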

    # initializing train and valid dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_bs,
        pin_memory=True,
        shuffle=False,
        sampler=train_sampler,
        num_workers=args.num_workers
    )
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=1,
        pin_memory=True,
        shuffle=False,
        sampler=valid_sampler,
        num_workers=args.num_workers
    )

    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])
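    # Note (added): post_pred converts the 2-channel network output to a one-hot segmentation
    # via argmax, and post_label one-hot-encodes the integer ground-truth mask, so predictions
    # and labels share the format expected by the Dice metric used during validation.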

    # filepaths for storing training and validation logs from different GPUs
    trainlog_fpath = os.path.join(save_logs_dir, f'trainlog_gpu{local_rank}.csv')
    validlog_fpath = os.path.join(save_logs_dir, f'validlog_gpu{local_rank}.csv')

    # initialize the GPU device
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # number of epochs and epoch interval for running validation
    max_epochs = args.epochs
    val_interval = args.val_interval

    # push model to device
    model = model.to(device)

    epoch_loss_values = []
    metric_values = []

    # wrap the model with DDP
    model = DDP(model, device_ids=[device])
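    # Note (added): DDP averages gradients across ranks with an all-reduce during backward(),
    # so after optimizer.step() every GPU holds identical weights.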

    experiment_start_time = time.time()

    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        print(f"[GPU{local_rank}]: Running training: epoch = {epoch + 1}")
        model.train()
        epoch_loss = 0
        step = 0
        train_sampler.set_epoch(epoch)
        for batch_data in train_dataloader:
            step += 1
            inputs, labels = (
                batch_data['CTPT'].to(device),
                batch_data['GT'].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= step
        print(f"[GPU:{local_rank}]: epoch {epoch + 1}/{max_epochs}: average loss: {epoch_loss:.4f}")
        epoch_loss_values.append(epoch_loss)

        # steps forward the CosineAnnealingLR scheduler
        scheduler.step()

        # update the training log file
        epoch_loss_values_df = pd.DataFrame(data=epoch_loss_values, columns=['Loss'])
        epoch_loss_values_df.to_csv(trainlog_fpath, index=False)


        if (epoch + 1) % val_interval == 0:
            print(f"[GPU{local_rank}]: Running validation")
            model.eval()
            with torch.no_grad():
                for val_data in valid_dataloader:
                    val_inputs, val_labels = (
                        val_data['CTPT'].to(device),
                        val_data['GT'].to(device),
                    )
                    roi_size = get_validation_sliding_window_size(args.input_patch_size)
                    sw_batch_size = args.sw_bs
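                    # Note (added): sliding_window_inference tiles the full validation volume
                    # into roi_size patches (sw_batch_size patches per forward pass) and
                    # stitches the patch predictions back into a full-size output.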
                    val_outputs = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, model)
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    # compute metric for current iteration
                    metric(y_pred=val_outputs, y=val_labels)

                # aggregate the final mean dice result
                metric_val = metric.aggregate().item()
                metric.reset()
                metric_values.append(metric_val)
                metric_values_df = pd.DataFrame(data=metric_values, columns=['Metric'])
                metric_values_df.to_csv(validlog_fpath, index=False)

                print(f"[GPU:{local_rank}] SAVING MODEL at epoch: {epoch + 1}; Mean DSC: {metric_val:.4f}")
                savepath = os.path.join(save_models_dir, "model_ep="+convert_to_4digits(str(int(epoch + 1)))+".pth")
                torch.save(model.module.state_dict(), savepath)
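                # Note (added): model.module unwraps the DDP container so the checkpoint can be
                # loaded into a plain (non-DDP) model later; since DDP keeps weights in sync,
                # every rank writes an identical file to the same path here.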

        epoch_end_time = (time.time() - epoch_start_time)/60
        print(f"[GPU:{local_rank}]: Epoch {epoch + 1} time: {round(epoch_end_time,2)} min")

    experiment_end_time = (time.time() - experiment_start_time)/(60*60)
    print(f"[GPU:{local_rank}]: Total time: {round(experiment_end_time,2)} hr")

    dist.destroy_process_group()

def main(args):
    os.environ['OMP_NUM_THREADS'] = '6'
    fold = args.fold
    network = args.network_name
    inputsize = f'randcrop{args.input_patch_size}'

    experiment_code = f"{network}_fold{fold}_{inputsize}"

    # save models folder
    save_models_dir = os.path.join(RESULTS_FOLDER, 'models')
    save_models_dir = os.path.join(save_models_dir, 'fold'+str(fold), network, experiment_code)
    os.makedirs(save_models_dir, exist_ok=True)
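    # Example (added): with the default arguments this resolves to
    # <RESULTS_FOLDER>/models/fold0/unet/unet_fold0_randcrop192/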

    # save train and valid logs folder
    save_logs_dir = os.path.join(RESULTS_FOLDER, 'logs')
    save_logs_dir = os.path.join(save_logs_dir, 'fold'+str(fold), network, experiment_code)
    os.makedirs(save_logs_dir, exist_ok=True)

    main_worker(save_models_dir, save_logs_dir, args)


if __name__ == "__main__":
    # create datasplit files for train and test images
    # follow all the instructions for dataset directory creation and images/labels file names as given in: LINK
    create_data_split_files()
    parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch')
    parser.add_argument('--fold', type=int, default=0, metavar='fold',
                        help='validation fold (default: 0), remaining folds will be used for training')
    parser.add_argument('--network-name', type=str, default='unet', metavar='netname',
                        help='network name for training (default: unet)')
    parser.add_argument('--epochs', type=int, default=500, metavar='epochs',
                        help='number of epochs to train (default: 500)')
    parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize',
                        help='size of cropped input patch for training (default: 192)')
    parser.add_argument('--train-bs', type=int, default=1, metavar='train-bs',
                        help='mini-batchsize for training (default: 1)')
    parser.add_argument('--num_workers', type=int, default=2, metavar='nw',
                        help='num_workers for train and validation dataloaders (default: 2)')
    parser.add_argument('--cache-rate', type=float, default=0.1, metavar='cr',
                        help='cache_rate for CacheDataset from MONAI (default: 0.1)')
    parser.add_argument('--lr', type=float, default=2e-4, metavar='lr',
                        help='initial learning rate for AdamW optimizer (default: 2e-4); the cosine scheduler will decrease this to 0 over args.epochs epochs')
    parser.add_argument('--wd', type=float, default=1e-5, metavar='wd',
                        help='weight-decay for AdamW optimizer (default: 1e-5)')
    parser.add_argument('--val-interval', type=int, default=2, metavar='val-interval',
                        help='epoch interval at which validation will be performed (default: 2)')
    parser.add_argument('--sw-bs', type=int, default=2, metavar='sw-bs',
                        help='batchsize for sliding window inference (default: 2)')
    args = parser.parse_args()

    main(args)
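    # Example launch (added; illustrative, adjust --nproc_per_node to your GPU count and the
    # script path to where this file lives):
    #   torchrun --nproc_per_node=4 segmentation/trainddp.py --fold 0 --network-name unet --epochs 500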