--- a
+++ b/segmentation/inference.py
@@ -0,0 +1,125 @@
+#%%
+'''
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT License.
+'''
+import numpy as np 
+import glob
+import os 
+import pandas as pd 
+import SimpleITK as sitk
+import sys
+import argparse
+from monai.inferers import sliding_window_inference
+from monai.data import DataLoader, Dataset, decollate_batch
+import torch
+import os
+import glob
+import pandas as pd
+import numpy as np
+import torch.nn as nn
+import time
+from initialize_train import (
+    get_validation_sliding_window_size,
+    get_model,
+    get_test_data_in_dict_format,
+    get_valid_transforms,
+    get_post_transforms
+)
+config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
+sys.path.append(config_dir)
+from config import RESULTS_FOLDER
+#%%
+def convert_to_4digits(str_num):
+    if len(str_num) == 1:
+        new_num = '000' + str_num
+    elif len(str_num) == 2:
+        new_num = '00' + str_num
+    elif len(str_num) == 3:
+        new_num = '0' + str_num
+    else:
+        new_num = str_num
+    return new_num
+
+def create_dictionary_ctptgt(ctpaths, ptpaths, gtpaths):
+    data = []
+    for i in range(len(gtpaths)):
+        ctpath = ctpaths[i]
+        ptpath = ptpaths[i]
+        gtpath = gtpaths[i]
+        data.append({'CT':ctpath, 'PT':ptpath, 'GT':gtpath})
+    return data
+
+def read_image_array(path):
+    img =  sitk.ReadImage(path)
+    array = np.transpose(sitk.GetArrayFromImage(img), (2,1,0))
+    return array
+
+#%%
+def main(args):
+    # initialize inference
+    fold = args.fold
+    network = args.network_name
+    inputsize = args.input_patch_size
+    experiment_code = f"{network}_fold{fold}_randcrop{inputsize}"
+    sw_roi_size = get_validation_sliding_window_size(inputsize) # get sliding_window inference size for given input patch size
+    
+    # find the best model for this experiment from the training/validation logs
+    # best model is the model with the best validation `Metric` (DSC)
+    save_logs_dir = os.path.join(RESULTS_FOLDER, 'logs')
+    validlog_fname = os.path.join(save_logs_dir, 'fold'+str(fold), network, experiment_code, 'validlog_gpu0.csv')
+    validlog = pd.read_csv(validlog_fname)
+    best_epoch = 2*(np.argmax(validlog['Metric']) + 1)
+    best_metric = np.max(validlog['Metric'])
+    print(f"Using the {network} model at epoch={best_epoch} with mean valid DSC = {round(best_metric, 4)}")
+
+    # get the best model and push it to device=cuda:0
+    save_models_dir = os.path.join(RESULTS_FOLDER,'models')
+    save_models_dir = os.path.join(save_models_dir, 'fold'+str(fold), network, experiment_code)
+    best_model_fname = 'model_ep=' + convert_to_4digits(str(best_epoch)) +'.pth'
+    model_path = os.path.join(save_models_dir, best_model_fname)
+    device = torch.device(f"cuda:0")
+    model = get_model(network, input_patch_size=inputsize)
+    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.to(device)
+        
+    # initialize the location to save predicted masks
+    save_preds_dir = os.path.join(RESULTS_FOLDER, f'predictions')
+    save_preds_dir = os.path.join(save_preds_dir, 'fold'+str(fold), network, experiment_code)
+    os.makedirs(save_preds_dir, exist_ok=True)
+
+    # get test data (in dictionary format for MONAI dataloader), test_transforms and post_transforms
+    test_data = get_test_data_in_dict_format()
+    test_transforms = get_valid_transforms()
+    post_transforms = get_post_transforms(test_transforms, save_preds_dir)
+    
+    # initalize PyTorch dataset and Dataloader
+    dataset_test = Dataset(data=test_data, transform=test_transforms)
+    dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=args.num_workers)
+
+    model.eval()
+    with torch.no_grad():
+        for data in dataloader_test:
+            inputs = data['CTPT'].to(device)
+            sw_batch_size = args.sw_bs
+            print(sw_batch_size)
+            data['Pred'] = sliding_window_inference(inputs, sw_roi_size, sw_batch_size, model)
+            data = [post_transforms(i) for i in decollate_batch(data)]
+
+
+if __name__ == "__main__": 
+    parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch')
+    parser.add_argument('--fold', type=int, default=0, metavar='fold',
+                        help='validation fold (default: 0), remaining folds will be used for training')
+    parser.add_argument('--network-name', type=str, default='unet', metavar='netname',
+                        help='network name for training (default: unet)')
+    parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize',
+                        help='size of cropped input patch for training (default: 192)')
+    parser.add_argument('--num_workers', type=int, default=2, metavar='nw',
+                        help='num_workers for train and validation dataloaders (default: 2)')
+    parser.add_argument('--sw-bs', type=int, default=2, metavar='sw-bs',
+                        help='batchsize for sliding window inference (default=2)')
+    args = parser.parse_args()
+    
+    main(args)
+