Diff of /segmentation/trainddp.py [000000] .. [18498b]

Switch to unified view

a b/segmentation/trainddp.py
1
'''
2
Copyright (c) Microsoft Corporation. All rights reserved.
3
Licensed under the MIT License.
4
'''
5
6
from monai.transforms import (
7
    AsDiscrete,
8
    Compose,
9
)
10
import argparse
11
from monai.inferers import sliding_window_inference
12
from monai.data import CacheDataset, DataLoader, decollate_batch
13
import torch
14
import matplotlib.pyplot as plt
15
import os
16
import pandas as pd
17
import time
18
from torch.utils.data.distributed import DistributedSampler
19
from torch.nn.parallel import DistributedDataParallel as DDP
20
import torch.distributed as dist 
21
import os
22
from initialize_train import (
23
    create_data_split_files,
24
    get_train_valid_data_in_dict_format, 
25
    get_train_transforms, 
26
    get_valid_transforms, 
27
    get_model, 
28
    get_loss_function,
29
    get_optimizer, 
30
    get_scheduler,
31
    get_metric,
32
    get_validation_sliding_window_size
33
)
34
35
import sys
36
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
37
sys.path.append(config_dir)
38
from config import RESULTS_FOLDER
39
torch.backends.cudnn.benchmark = True
40
#%%
41
def ddp_setup():
    """Join the default torch.distributed process group.

    Uses the NCCL backend (GPU-to-GPU collectives) and the ``env://``
    init method, i.e. rank/world-size/master address are read from the
    environment variables set by ``torchrun`` / ``torch.distributed.launch``.
    """
    dist.init_process_group(backend='nccl', init_method="env://")
43
44
def convert_to_4digits(str_num):
    """Left-pad a numeric string with zeros to a width of 4 characters.

    Used to build sortable checkpoint filenames (e.g. ``'7'`` -> ``'0007'``).
    Strings already 4 characters or longer are returned unchanged.

    Parameters
    ----------
    str_num : str
        Number already converted to a string.

    Returns
    -------
    str
        ``str_num`` zero-padded on the left to at least 4 characters.
    """
    # str.zfill implements exactly the original if/elif padding ladder.
    return str_num.zfill(4)
54
55
#%%
56
def load_train_objects(args):
    """Construct every object needed for one training run.

    Delegates to the factory helpers in ``initialize_train`` using the
    fold, network name, patch size and optimizer hyperparameters carried
    by ``args``.

    Returns
    -------
    tuple
        (train_data, valid_data, train_transforms, valid_transforms,
         model, loss_function, optimizer, scheduler, metric)
    """
    train_data, valid_data = get_train_valid_data_in_dict_format(args.fold)
    # model must exist before the optimizer/scheduler that wrap it
    model = get_model(args.network_name, args.input_patch_size)
    optimizer = get_optimizer(model, learning_rate=args.lr, weight_decay=args.wd)

    return (
        train_data,
        valid_data,
        get_train_transforms(args.input_patch_size),
        get_valid_transforms(),
        model,
        get_loss_function(),
        optimizer,
        get_scheduler(optimizer, args.epochs),
        get_metric(),
    )
77
78
79
def prepare_dataset(data, transforms, args):
    """Wrap a list of data dicts in a MONAI ``CacheDataset``.

    ``args.cache_rate`` controls the fraction of items pre-cached and
    ``args.num_workers`` the workers used to fill the cache.
    """
    return CacheDataset(
        data=data,
        transform=transforms,
        cache_rate=args.cache_rate,
        num_workers=args.num_workers,
    )
82
83
84
def main_worker(save_models_dir, save_logs_dir, args):
    """Per-process DDP entry point: train and periodically validate the model.

    Each spawned process joins the NCCL process group, builds its own
    datasets/dataloaders with a DistributedSampler shard, trains the
    DDP-wrapped model, writes per-rank CSV logs, and saves checkpoints
    at every validation epoch. The group is destroyed on exit.

    Parameters
    ----------
    save_models_dir : str
        Directory where ``model_ep=NNNN.pth`` checkpoints are written.
    save_logs_dir : str
        Directory for per-GPU train/valid CSV logs.
    args : argparse.Namespace
        Parsed command-line options (epochs, batch sizes, lr, etc.).
    """
    # join the distributed process group (NCCL, env:// rendezvous)
    ddp_setup() 
    # NOTE(review): dist.get_rank() returns the *global* rank, not the
    # local (per-node) rank despite the variable name. Using it as a CUDA
    # device index below only works for single-node runs — confirm.
    local_rank = int(dist.get_rank())
    if local_rank == 0:
        # only rank 0 prints the run banner to avoid duplicated output
        print(f"Training {args.network_name} on fold {args.fold}")
        print(f"The models will be saved in {save_models_dir}")
        print(f"The training/validation logs will be saved in {save_logs_dir}")

    # build data, transforms, model, loss, optimizer, scheduler and metric
    train_data, valid_data, train_transforms, valid_transforms, model, loss_function, optimizer, scheduler, metric = load_train_objects(args)

    # wrap both splits in CacheDataset (each rank keeps its own cache)
    train_dataset = prepare_dataset(train_data, train_transforms, args)
    valid_dataset = prepare_dataset(valid_data, valid_transforms, args)

    # DistributedSampler shards the datasets across ranks; training is
    # shuffled per epoch (see set_epoch below), validation is not
    train_sampler = DistributedSampler(dataset=train_dataset, shuffle=True)
    valid_sampler = DistributedSampler(dataset=valid_dataset, shuffle=False)
    
    # dataloaders: shuffle must be False because the sampler shuffles
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_bs,
        pin_memory=True,
        shuffle=False,
        sampler=train_sampler,
        num_workers=args.num_workers
    )
    # validation uses batch_size=1 so sliding-window inference sees one volume
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=1,
        pin_memory=True,
        shuffle=False,
        sampler=valid_sampler,
        num_workers=args.num_workers
    )

    # post-processing: argmax + one-hot for predictions, one-hot for labels
    # (2 classes: background / lesion)
    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
    post_label = Compose([AsDiscrete(to_onehot=2)])

    # per-rank CSV log files (one pair per GPU)
    trainlog_fpath = os.path.join(save_logs_dir, f'trainlog_gpu{local_rank}.csv')
    validlog_fpath = os.path.join(save_logs_dir, f'validlog_gpu{local_rank}.csv')

    # bind this process to its CUDA device
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # number of epochs and how often to run validation
    max_epochs = args.epochs
    val_interval = args.val_interval

    # move model to this rank's device before DDP wrapping
    model = model.to(device)

    epoch_loss_values = []  # mean training loss per epoch (this rank)
    metric_values = []      # aggregated Dice per validation epoch (this rank)

    # wrap the model with DDP for gradient synchronization across ranks
    model = DDP(model, device_ids=[device])
        
    experiment_start_time = time.time()
    
    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        print(f"[GPU{local_rank}]: Running training: epoch = {epoch + 1}")
        model.train()
        epoch_loss = 0
        step = 0
        # re-seed the sampler so each epoch gets a different shuffle
        train_sampler.set_epoch(epoch)
        for batch_data in train_dataloader:
            step += 1
            # 'CTPT' = stacked CT+PET input channels, 'GT' = ground-truth mask
            inputs, labels = (
                batch_data['CTPT'].to(device),
                batch_data['GT'].to(device),
            )
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # NOTE(review): raises ZeroDivisionError if this rank's shard is
        # empty (step == 0) — assumed non-empty in practice.
        epoch_loss /= step
        print(f"[GPU:{local_rank}]: epoch {epoch + 1}/{max_epochs}: average loss: {epoch_loss:.4f}")
        epoch_loss_values.append(epoch_loss)

        # step the LR scheduler once per epoch
        scheduler.step()

        # rewrite the full training log each epoch (cheap, crash-safe)
        epoch_loss_values_df = pd.DataFrame(data=epoch_loss_values, columns=['Loss'])
        epoch_loss_values_df.to_csv(trainlog_fpath, index=False)


        if (epoch + 1) % val_interval == 0:
            print(f"[GPU{local_rank}]: Running validation")
            model.eval()
            with torch.no_grad():
                for val_data in valid_dataloader:
                    val_inputs, val_labels = (
                        val_data['CTPT'].to(device),
                        val_data['GT'].to(device),
                    )
                    # window size is derived from the training patch size
                    roi_size = get_validation_sliding_window_size(args.input_patch_size) 
                    sw_batch_size = args.sw_bs
                    val_outputs = sliding_window_inference(
                        val_inputs, roi_size, sw_batch_size, model)
                    # decollate to per-sample tensors, then one-hot both sides
                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                    # accumulate the metric for this batch
                    metric(y_pred=val_outputs, y=val_labels)

                # aggregate the mean Dice over this rank's validation shard
                metric_val = metric.aggregate().item()
                metric.reset()
                metric_values.append(metric_val)
                metric_values_df = pd.DataFrame(data=metric_values, columns=['Metric'])
                metric_values_df.to_csv(validlog_fpath, index=False)
               
                # NOTE(review): every rank executes this save; all ranks
                # write the same filepath (DDP keeps weights in sync), so
                # this is redundant but not incorrect — consider rank-0-only.
                print(f"[GPU:{local_rank}] SAVING MODEL at epoch: {epoch + 1}; Mean DSC: {metric_val:.4f}")
                savepath = os.path.join(save_models_dir, "model_ep="+convert_to_4digits(str(int(epoch + 1)))+".pth")
                # save the unwrapped module's weights, not the DDP wrapper's
                torch.save(model.module.state_dict(), savepath)

        epoch_end_time = (time.time() - epoch_start_time)/60
        print(f"[GPU:{local_rank}]: Epoch {epoch + 1} time: {round(epoch_end_time,2)} min")
       
    experiment_end_time = (time.time() - experiment_start_time)/(60*60)
    print(f"[GPU:{local_rank}]: Total time: {round(experiment_end_time,2)} hr")

    # leave the process group cleanly
    dist.destroy_process_group()
217
218
def main(args):
    """Prepare output directories for this experiment and launch training.

    Builds an experiment code from the network name, fold and patch size,
    creates ``RESULTS_FOLDER/{models,logs}/fold<k>/<net>/<code>`` and
    hands both paths to :func:`main_worker`.
    """
    os.environ['OMP_NUM_THREADS'] = '6'

    experiment_code = f"{args.network_name}_fold{args.fold}_randcrop{args.input_patch_size}"
    # common sub-path under both the models and logs roots
    subpath = ('fold' + str(args.fold), args.network_name, experiment_code)

    # folder for saved model checkpoints
    save_models_dir = os.path.join(RESULTS_FOLDER, 'models', *subpath)
    # folder for train/valid CSV logs
    save_logs_dir = os.path.join(RESULTS_FOLDER, 'logs', *subpath)
    for directory in (save_models_dir, save_logs_dir):
        os.makedirs(directory, exist_ok=True)

    main_worker(save_models_dir, save_logs_dir, args)
237
    
238
239
240
if __name__ == "__main__": 
    # create datasplit files for train and test images
    # follow all the instructions for dataset directory creation and images/labels file names as given in: LINK
    create_data_split_files() 
    # command-line interface for the DDP training script
    parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch')
    parser.add_argument('--fold', type=int, default=0, metavar='fold',
                        help='validation fold (default: 0), remaining folds will be used for training')
    parser.add_argument('--network-name', type=str, default='unet', metavar='netname',
                        help='network name for training (default: unet)')
    # fix: help text previously said "(default: 10)" but the actual default is 500
    parser.add_argument('--epochs', type=int, default=500, metavar='epochs',
                        help='number of epochs to train (default: 500)')
    parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize',
                        help='size of cropped input patch for training (default: 192)')
    parser.add_argument('--train-bs', type=int, default=1, metavar='train-bs',
                        help='mini-batchsize for training (default: 1)')
    parser.add_argument('--num_workers', type=int, default=2, metavar='nw',
                        help='num_workers for train and validation dataloaders (default: 2)')
    parser.add_argument('--cache-rate', type=float, default=0.1, metavar='cr',
                        help='cache_rate for CacheDataset from MONAI (default=0.1)')
    parser.add_argument('--lr', type=float, default=2e-4, metavar='lr',
                        help='initial learning rate for AdamW optimizer (default=2e-4); Cosine scheduler will decrease this to 0 in args.epochs epochs')
    parser.add_argument('--wd', type=float, default=1e-5, metavar='wd',
                        help='weight-decay for AdamW optimizer (default=1e-5)')
    parser.add_argument('--val-interval', type=int, default=2, metavar='val-interval',
                        help='epochs interval for which validation will be performed (default=2)')
    parser.add_argument('--sw-bs', type=int, default=2, metavar='sw-bs',
                        help='batchsize for sliding window inference (default=2)')
    args = parser.parse_args()
    
    main(args)
270