--- /dev/null
+++ b/segmentation/trainddp.py
@@ -0,0 +1,270 @@
+'''
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT License.
+'''
+
+from monai.transforms import (
+    AsDiscrete,
+    Compose,
+)
+import argparse
+from monai.inferers import sliding_window_inference
+from monai.data import CacheDataset, DataLoader, decollate_batch
+import torch
+import os
+import pandas as pd
+import time
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
+from initialize_train import (
+    create_data_split_files,
+    get_train_valid_data_in_dict_format, 
+    get_train_transforms, 
+    get_valid_transforms, 
+    get_model, 
+    get_loss_function,
+    get_optimizer, 
+    get_scheduler,
+    get_metric,
+    get_validation_sliding_window_size
+)
+
+import sys
+config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
+sys.path.append(config_dir)
+from config import RESULTS_FOLDER
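+# cudnn.benchmark autotunes convolution algorithms; this helps here because the
+# cropped training inputs have a fixed size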
+torch.backends.cudnn.benchmark = True
+#%%
+def ddp_setup():
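+    # 'env://' reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE from the
+    # environment; torchrun sets these automatically for every process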
+    dist.init_process_group(backend='nccl', init_method="env://")
+
+def convert_to_4digits(str_num):
+    # zero-pad the epoch number to 4 digits, e.g. '7' -> '0007'
+    return str_num.zfill(4)
+
+#%%
+def load_train_objects(args):
+    train_data, valid_data = get_train_valid_data_in_dict_format(args.fold) 
+    train_transforms = get_train_transforms(args.input_patch_size)
+    valid_transforms = get_valid_transforms()
+    model = get_model(args.network_name, args.input_patch_size) 
+    optimizer = get_optimizer(model, learning_rate=args.lr, weight_decay=args.wd)
+    loss_function = get_loss_function()
+    scheduler = get_scheduler(optimizer, args.epochs)
+    metric = get_metric()
+
+    return (
+        train_data,
+        valid_data,
+        train_transforms,
+        valid_transforms,
+        model,
+        loss_function,
+        optimizer,
+        scheduler,
+        metric
+    )
+
+
+def prepare_dataset(data, transforms, args):
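+    # CacheDataset pre-computes the deterministic transforms and keeps a
+    # cache_rate fraction of the results in RAM; num_workers here controls the
+    # workers used for that initial caching pass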
+    dataset = CacheDataset(data=data, transform=transforms, cache_rate=args.cache_rate, num_workers=args.num_workers)
+    return dataset
+
+
+def main_worker(save_models_dir, save_logs_dir, args):
+    # init_process_group
+    ddp_setup() 
+    # dist.get_rank() is the global rank across all nodes; LOCAL_RANK (set by
+    # torchrun) is the GPU index on this node and is used for device placement
+    rank = dist.get_rank()
+    local_rank = int(os.environ.get('LOCAL_RANK', rank))
+    if rank == 0:
+        print(f"Training {args.network_name} on fold {args.fold}")
+        print(f"The models will be saved in {save_models_dir}")
+        print(f"The training/validation logs will be saved in {save_logs_dir}")
+
+    # get all training and validation objects
+    train_data, valid_data, train_transforms, valid_transforms, model, loss_function, optimizer, scheduler, metric = load_train_objects(args)
+
+    # build CacheDataset instances for training and validation
+    train_dataset = prepare_dataset(train_data, train_transforms, args)
+    valid_dataset = prepare_dataset(valid_data, valid_transforms, args)
+
+    # DistributedSampler splits the dataset across processes so that each GPU
+    # trains/validates on its own shard
+    train_sampler = DistributedSampler(dataset=train_dataset, shuffle=True)
+    valid_sampler = DistributedSampler(dataset=valid_dataset, shuffle=False)
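+    # note: DistributedSampler pads with repeated samples so every shard has
+    # equal length; per-sample validation metrics may count a few duplicates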
+    
+    # initializing train and valid dataloaders
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=args.train_bs,
+        pin_memory=True,
+        shuffle=False,
+        sampler=train_sampler,
+        num_workers=args.num_workers
+    )
+    valid_dataloader = DataLoader(
+        valid_dataset,
+        batch_size=1,
+        pin_memory=True,
+        shuffle=False,
+        sampler=valid_sampler,
+        num_workers=args.num_workers
+    )
+
+    post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
+    post_label = Compose([AsDiscrete(to_onehot=2)])
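+    # post_pred: argmax over the 2 output channels, then one-hot encode;
+    # post_label: one-hot encode the ground truth, so both are in the same
+    # 2-channel format expected by the dice metric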
+
+    # filepaths for storing training and validation logs from each process
+    trainlog_fpath = os.path.join(save_logs_dir, f'trainlog_gpu{rank}.csv')
+    validlog_fpath = os.path.join(save_logs_dir, f'validlog_gpu{rank}.csv')
+
+    # initialize the GPU device    
+    device = torch.device(f"cuda:{local_rank}")
+    torch.cuda.set_device(device)
+
+    # number of epochs and epoch interval for running validation
+    max_epochs = args.epochs
+    val_interval = args.val_interval
+
+    # push the model to this process's device
+    model = model.to(device)
+
+    epoch_loss_values = []
+    metric_values = []
+
+    # wrap the model with DDP
+    model = DDP(model, device_ids=[local_rank])
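+    # DDP averages gradients across all processes during backward(), so every
+    # replica applies identical parameter updates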
+        
+    experiment_start_time = time.time()
+    
+    for epoch in range(max_epochs):
+        epoch_start_time = time.time()
+        print(f"[GPU:{rank}]: Running training: epoch = {epoch + 1}")
+        model.train()
+        epoch_loss = 0
+        step = 0
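+        # set_epoch re-seeds the sampler so each epoch draws a different
+        # shuffling (otherwise every epoch would reuse the same order)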
+        train_sampler.set_epoch(epoch)
+        for batch_data in train_dataloader:
+            step += 1
+            inputs, labels = (
+                batch_data['CTPT'].to(device),
+                batch_data['GT'].to(device),
+            )
+            optimizer.zero_grad()
+            outputs = model(inputs)
+            loss = loss_function(outputs, labels)
+            loss.backward()
+            optimizer.step()
+            epoch_loss += loss.item()
+        epoch_loss /= step
+        print(f"[GPU:{rank}]: epoch {epoch + 1}/{max_epochs}: average loss: {epoch_loss:.4f}")
+        epoch_loss_values.append(epoch_loss)
+
+        # steps forward the CosineAnnealingLR scheduler
+        scheduler.step()
+
+        # update the training log file
+        epoch_loss_values_df = pd.DataFrame(data=epoch_loss_values, columns=['Loss'])
+        epoch_loss_values_df.to_csv(trainlog_fpath, index=False)
+
+
+        if (epoch + 1) % val_interval == 0:
+            print(f"[GPU:{rank}]: Running validation")
+            model.eval()
+            with torch.no_grad():
+                for val_data in valid_dataloader:
+                    val_inputs, val_labels = (
+                        val_data['CTPT'].to(device),
+                        val_data['GT'].to(device),
+                    )
+                    roi_size = get_validation_sliding_window_size(args.input_patch_size) 
+                    sw_batch_size = args.sw_bs
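+                    # tile the full volume into roi_size windows, run the model
+                    # on sw_batch_size windows per forward pass, and stitch the
+                    # window outputs back into a full-volume prediction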
+                    val_outputs = sliding_window_inference(
+                        val_inputs, roi_size, sw_batch_size, model)
+                    val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
+                    val_labels = [post_label(i) for i in decollate_batch(val_labels)]
+                    # compute metric for current iteration
+                    metric(y_pred=val_outputs, y=val_labels)
+
+                # aggregate the final mean dice result
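+                # (MONAI's cumulative metrics all-gather their buffers across
+                # ranks when a process group is initialized, so this mean
+                # covers the whole validation set, not just this GPU's shard)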
+                metric_val = metric.aggregate().item()
+                metric.reset()
+                metric_values.append(metric_val)
+                metric_values_df = pd.DataFrame(data=metric_values, columns=['Metric'])
+                metric_values_df.to_csv(validlog_fpath, index=False)
+
+                # save a checkpoint from rank 0 only so that processes do not
+                # write to the same file concurrently; model.module unwraps DDP
+                # so the weights load back into a plain (non-DDP) model
+                if rank == 0:
+                    print(f"[GPU:{rank}]: Saving model at epoch {epoch + 1}; mean DSC: {metric_val:.4f}")
+                    savepath = os.path.join(save_models_dir, f"model_ep={convert_to_4digits(str(epoch + 1))}.pth")
+                    torch.save(model.module.state_dict(), savepath)
+
+        epoch_time_min = (time.time() - epoch_start_time) / 60
+        print(f"[GPU:{rank}]: Epoch {epoch + 1} time: {round(epoch_time_min, 2)} min")
+
+    total_time_hr = (time.time() - experiment_start_time) / 3600
+    print(f"[GPU:{rank}]: Total time: {round(total_time_hr, 2)} hr")
+
+    dist.destroy_process_group()
+
+def main(args):
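+    # cap OpenMP threads per process to avoid CPU oversubscription when
+    # several DDP workers share a node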
+    os.environ['OMP_NUM_THREADS'] = '6'
+    fold = args.fold
+    network = args.network_name
+    inputsize = f'randcrop{args.input_patch_size}'
+
+    experiment_code = f"{network}_fold{fold}_{inputsize}"
+
+    # folder for saving models
+    save_models_dir = os.path.join(RESULTS_FOLDER, 'models', f'fold{fold}', network, experiment_code)
+    os.makedirs(save_models_dir, exist_ok=True)
+
+    # folder for saving training/validation logs
+    save_logs_dir = os.path.join(RESULTS_FOLDER, 'logs', f'fold{fold}', network, experiment_code)
+    os.makedirs(save_logs_dir, exist_ok=True)
+    
+    main_worker(save_models_dir, save_logs_dir, args)
+    
+
+
+if __name__ == "__main__": 
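+    # launch with torchrun so the env:// process group can initialize, e.g.:
+    #   torchrun --nproc_per_node=<num_gpus> trainddp.py --fold 0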
+    # create datasplit files for train and test images
+    # follow all the instructions for dataset directory creation and images/labels file names as given in: LINK
+    create_data_split_files() 
+    parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch')
+    parser.add_argument('--fold', type=int, default=0, metavar='fold',
+                        help='validation fold (default: 0), remaining folds will be used for training')
+    parser.add_argument('--network-name', type=str, default='unet', metavar='netname',
+                        help='network name for training (default: unet)')
+    parser.add_argument('--epochs', type=int, default=500, metavar='epochs',
+                        help='number of epochs to train (default: 500)')
+    parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize',
+                        help='size of cropped input patch for training (default: 192)')
+    parser.add_argument('--train-bs', type=int, default=1, metavar='train-bs',
+                        help='per-GPU mini-batch size for training (default: 1)')
+    parser.add_argument('--num-workers', type=int, default=2, metavar='nw',
+                        help='num_workers for train and validation dataloaders (default: 2)')
+    parser.add_argument('--cache-rate', type=float, default=0.1, metavar='cr',
+                        help='cache_rate for CacheDataset from MONAI (default=0.1)')
+    parser.add_argument('--lr', type=float, default=2e-4, metavar='lr',
+                        help='initial learning rate for AdamW optimizer (default=2e-4); Cosine scheduler will decrease this to 0 in args.epochs epochs')
+    parser.add_argument('--wd', type=float, default=1e-5, metavar='wd',
+                        help='weight-decay for AdamW optimizer (default=1e-5)')
+    parser.add_argument('--val-interval', type=int, default=2, metavar='val-interval',
+                        help='run validation every val-interval epochs (default=2)')
+    parser.add_argument('--sw-bs', type=int, default=2, metavar='sw-bs',
+                        help='batch size for sliding window inference (default=2)')
+    args = parser.parse_args()
+    
+    main(args)
+