Diff of /segmentation/inference.py [000000] .. [18498b]

Switch to unified view

a b/segmentation/inference.py
1
#%%
2
'''
3
Copyright (c) Microsoft Corporation. All rights reserved.
4
Licensed under the MIT License.
5
'''
6
import numpy as np 
7
import glob
8
import os 
9
import pandas as pd 
10
import SimpleITK as sitk
11
import sys
12
import argparse
13
from monai.inferers import sliding_window_inference
14
from monai.data import DataLoader, Dataset, decollate_batch
15
import torch
16
import os
17
import glob
18
import pandas as pd
19
import numpy as np
20
import torch.nn as nn
21
import time
22
from initialize_train import (
23
    get_validation_sliding_window_size,
24
    get_model,
25
    get_test_data_in_dict_format,
26
    get_valid_transforms,
27
    get_post_transforms
28
)
29
config_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
30
sys.path.append(config_dir)
31
from config import RESULTS_FOLDER
32
#%%
33
def convert_to_4digits(str_num):
34
    if len(str_num) == 1:
35
        new_num = '000' + str_num
36
    elif len(str_num) == 2:
37
        new_num = '00' + str_num
38
    elif len(str_num) == 3:
39
        new_num = '0' + str_num
40
    else:
41
        new_num = str_num
42
    return new_num
43
44
def create_dictionary_ctptgt(ctpaths, ptpaths, gtpaths):
45
    data = []
46
    for i in range(len(gtpaths)):
47
        ctpath = ctpaths[i]
48
        ptpath = ptpaths[i]
49
        gtpath = gtpaths[i]
50
        data.append({'CT':ctpath, 'PT':ptpath, 'GT':gtpath})
51
    return data
52
53
def read_image_array(path):
54
    img =  sitk.ReadImage(path)
55
    array = np.transpose(sitk.GetArrayFromImage(img), (2,1,0))
56
    return array
57
58
#%%
59
def main(args):
60
    # initialize inference
61
    fold = args.fold
62
    network = args.network_name
63
    inputsize = args.input_patch_size
64
    experiment_code = f"{network}_fold{fold}_randcrop{inputsize}"
65
    sw_roi_size = get_validation_sliding_window_size(inputsize) # get sliding_window inference size for given input patch size
66
    
67
    # find the best model for this experiment from the training/validation logs
68
    # best model is the model with the best validation `Metric` (DSC)
69
    save_logs_dir = os.path.join(RESULTS_FOLDER, 'logs')
70
    validlog_fname = os.path.join(save_logs_dir, 'fold'+str(fold), network, experiment_code, 'validlog_gpu0.csv')
71
    validlog = pd.read_csv(validlog_fname)
72
    best_epoch = 2*(np.argmax(validlog['Metric']) + 1)
73
    best_metric = np.max(validlog['Metric'])
74
    print(f"Using the {network} model at epoch={best_epoch} with mean valid DSC = {round(best_metric, 4)}")
75
76
    # get the best model and push it to device=cuda:0
77
    save_models_dir = os.path.join(RESULTS_FOLDER,'models')
78
    save_models_dir = os.path.join(save_models_dir, 'fold'+str(fold), network, experiment_code)
79
    best_model_fname = 'model_ep=' + convert_to_4digits(str(best_epoch)) +'.pth'
80
    model_path = os.path.join(save_models_dir, best_model_fname)
81
    device = torch.device(f"cuda:0")
82
    model = get_model(network, input_patch_size=inputsize)
83
    model.load_state_dict(torch.load(model_path, map_location=device))
84
    model.to(device)
85
        
86
    # initialize the location to save predicted masks
87
    save_preds_dir = os.path.join(RESULTS_FOLDER, f'predictions')
88
    save_preds_dir = os.path.join(save_preds_dir, 'fold'+str(fold), network, experiment_code)
89
    os.makedirs(save_preds_dir, exist_ok=True)
90
91
    # get test data (in dictionary format for MONAI dataloader), test_transforms and post_transforms
92
    test_data = get_test_data_in_dict_format()
93
    test_transforms = get_valid_transforms()
94
    post_transforms = get_post_transforms(test_transforms, save_preds_dir)
95
    
96
    # initalize PyTorch dataset and Dataloader
97
    dataset_test = Dataset(data=test_data, transform=test_transforms)
98
    dataloader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=args.num_workers)
99
100
    model.eval()
101
    with torch.no_grad():
102
        for data in dataloader_test:
103
            inputs = data['CTPT'].to(device)
104
            sw_batch_size = args.sw_bs
105
            print(sw_batch_size)
106
            data['Pred'] = sliding_window_inference(inputs, sw_roi_size, sw_batch_size, model)
107
            data = [post_transforms(i) for i in decollate_batch(data)]
108
109
110
if __name__ == "__main__": 
111
    parser = argparse.ArgumentParser(description='Lymphoma PET/CT lesion segmentation using MONAI-PyTorch')
112
    parser.add_argument('--fold', type=int, default=0, metavar='fold',
113
                        help='validation fold (default: 0), remaining folds will be used for training')
114
    parser.add_argument('--network-name', type=str, default='unet', metavar='netname',
115
                        help='network name for training (default: unet)')
116
    parser.add_argument('--input-patch-size', type=int, default=192, metavar='inputsize',
117
                        help='size of cropped input patch for training (default: 192)')
118
    parser.add_argument('--num_workers', type=int, default=2, metavar='nw',
119
                        help='num_workers for train and validation dataloaders (default: 2)')
120
    parser.add_argument('--sw-bs', type=int, default=2, metavar='sw-bs',
121
                        help='batchsize for sliding window inference (default=2)')
122
    args = parser.parse_args()
123
    
124
    main(args)
125