--- a +++ b/segmentation/inference.py @@ -0,0 +1,125 @@ +#%% +''' +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +''' +import numpy as np +import glob +import os +import pandas as pd +import SimpleITK as sitk +import sys +import argparse +from monai.inferers import sliding_window_inference +from monai.data import DataLoader, Dataset, decollate_batch +import torch +import os +import glob +import pandas as pd +import numpy as np +import torch.nn as nn +import time +from initialize_train import ( + get_validation_sliding_window_size, + get_model, + get_test_data_in_dict_format, + get_valid_transforms, + get_post_transforms +) +config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..") +sys.path.append(config_dir) +from config import RESULTS_FOLDER +#%% +def convert_to_4digits(str_num): + if len(str_num) == 1: + new_num = '000' + str_num + elif len(str_num) == 2: + new_num = '00' + str_num + elif len(str_num) == 3: + new_num = '0' + str_num + else: + new_num = str_num + return new_num + +def create_dictionary_ctptgt(ctpaths, ptpaths, gtpaths): + data = [] + for i in range(len(gtpaths)): + ctpath = ctpaths[i] + ptpath = ptpaths[i] + gtpath = gtpaths[i] + data.append({'CT':ctpath, 'PT':ptpath, 'GT':gtpath}) + return data + +def read_image_array(path): + img = sitk.ReadImage(path) + array = np.transpose(sitk.GetArrayFromImage(img), (2,1,0)) + return array + +#%% +def main(args): + # initialize inference + fold = args.fold + network = args.network_name + inputsize = args.input_patch_size + experiment_code = f"{network}_fold{fold}_randcrop{inputsize}" + sw_roi_size = get_validation_sliding_window_size(inputsize) # get sliding_window inference size for given input patch size + + # find the best model for this experiment from the training/validation logs + # best model is the model with the best validation `Metric` (DSC) + save_logs_dir = os.path.join(RESULTS_FOLDER, 'logs') + validlog_fname = os.path.join(save_logs_dir, 'fold'+str(fold), network, experiment_code, 'validlog_gpu0.csv') + validlog = pd.read_csv(validlog_fname) + best_epoch = 2*(np.argmax(validlog['Metric']) + 1) + best_metric = np.max(validlog['Metric']) + print(f"Using the {network} model at epoch={best_epoch} with mean valid DSC = {round(best_metric, 4)}") + + # get the best model and push it to device=cuda:0 + save_models_dir = os.path.join(RESULTS_FOLDER,'models') + save_models_dir = os.path.join(save_models_dir, 'fold'+str(fold), network, experiment_code) + best_model_fname = 'model_ep=' + convert_to_4digits(str(best_epoch)) +'.pth' + model_path = os.path.join(save_models_dir, best_model_fname) + device = torch.device(f"cuda:0") + model = get_model(network, input_patch_size=inputsize) + model.load_state_dict(torch.load(model_path, map_location=device)) + model.to(device) + + # initialize the location to save predicted masks + save_preds_dir = os.path.join(RESULTS_FOLDER, f'predictions') + save_preds_dir = os.path.join(save_preds_dir, 'fold'+str(fold), network, experiment_code) + os.makedirs(save_preds_dir, exist_ok=True) + + # get test data (in dictionary format for MONAI dataloader), test_transforms and post_transforms + test_data = get_test_data_in_dict_format() + test_transforms = get_valid_transforms() + post_transforms = get_post_transforms(test_transforms, save_preds_dir) + + # initalize PyTorch dataset and Dataloader + dataset_test = Dataset(data=test_data, transform=test_transforms) + dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=args.num_workers) + + model.eval() + with torch.no_grad(): + for data in dataloader_test: + inputs = data['CTPT'].to(device) + sw_batch_size = args.sw_bs + print(sw_batch_size) + data['Pred'] = sliding_window_inference(inputs, sw_roi_size, sw_batch_size, model) + data = [post_transforms(i) for i in decollate_batch(data)] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch') + parser.add_argument('--fold', type=int, default=0, metavar='fold', + help='validation fold (default: 0), remaining folds will be used for training') + parser.add_argument('--network-name', type=str, default='unet', metavar='netname', + help='network name for training (default: unet)') + parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize', + help='size of cropped input patch for training (default: 192)') + parser.add_argument('--num_workers', type=int, default=2, metavar='nw', + help='num_workers for train and validation dataloaders (default: 2)') + parser.add_argument('--sw-bs', type=int, default=2, metavar='sw-bs', + help='batchsize for sliding window inference (default=2)') + args = parser.parse_args() + + main(args) +