Diff of /test.py [000000] .. [9ff54e]

Switch to unified view

a b/test.py
1
import sys
2
from time import time
3
import pickle
4
from os import makedirs
5
from os.path import join
6
import argparse
7
8
import torch
9
from torch import nn, sigmoid
10
from torch.utils.data import DataLoader
11
12
from model import UNet3D
13
from dataset import MRIDataset
14
from utils import Report, transfer_weights
15
"""
16
Run inference on test set.
17
This script also saves voxel-wise inference results and labels on numpy arrays
18
for future reuse (without running on gpu).
19
"""
20
21
argv = sys.argv[1:]
22
parser = argparse.ArgumentParser(
23
    formatter_class=argparse.RawDescriptionHelpFormatter,
24
    prog='PROG',)
25
parser.add_argument('--data_dir',
26
                    type=str,
27
                    required=True,
28
                    help="path to images (should have subdirs: test)")
29
parser.add_argument('--model_path', '-m', type=str,
30
                    required=True,
31
                    help='path to a model snapshot')
32
parser.add_argument('--size', type=int, nargs='+', default=[128])
33
parser.add_argument('--dump_name',
34
                    type=str,
35
                    default='',
36
                    help='name dump file to avoid overwrting files.')
37
args = parser.parse_args(argv)
38
39
net = UNet3D(1, 1, use_bias=True, inplanes=16)
40
transfer_weights(net, args.model_path)
41
net.cuda()
42
net.train(False)
43
torch.no_grad()
44
45
test_dir = join(args.data_dir, 'test')
46
batch_size = 1
47
48
49
def inference(target_dir):
50
    volume_size = args.size*3 if len(args.size) == 1 else args.size
51
    dataloader = DataLoader(MRIDataset(target_dir,
52
                                       volume_size,
53
                                       sampling_mode='center',
54
                                       deterministic=True),
55
                            batch_size=1,
56
                            num_workers=4)
57
    input_paths = dataloader.dataset.inputs
58
    label_paths = dataloader.dataset.labels
59
60
    reporter = Report()
61
    preds_list = list()
62
    labels_list = list()
63
    sum_hd = 0
64
    sum_sd = 0
65
    num_voxels_hd = list()
66
    for i, (inputs, labels) in enumerate(dataloader):
67
        inputs = inputs.cuda()
68
        preds = sigmoid(net(inputs.detach()).detach())
69
70
        full_labels = dataloader.dataset._load_full_label(label_paths[i])
71
        full_preds = dataloader.dataset._project_full_label(input_paths[i],
72
                                                            preds.cpu())
73
        preds = full_preds
74
        labels = full_labels
75
76
        reporter.feed(preds, labels)
77
        temp_reporter = Report()
78
        temp_reporter.feed(preds, labels)
79
        sum_hd += temp_reporter.hard_dice()
80
        sum_sd += temp_reporter.soft_dice()
81
        num_voxels_hd.append(
82
            (labels.view(-1).sum().item(), temp_reporter.hard_dice()))
83
        del temp_reporter
84
85
        preds_list.append(preds.cpu())
86
        labels_list.append(labels.cpu())
87
        del inputs, labels, preds
88
    print("Micro Averaged Dice {}, {}".format(sum_hd / len(dataloader),
89
                                              sum_sd / len(dataloader)))
90
    if len(args.dump_name):
91
        # dump preds for visualization
92
        pickle.dump([input_paths, preds_list, labels_list],
93
                    open('preds_dump_{}.pickle'.format(args.dump_name), 'wb'))
94
    print(reporter)
95
    print(reporter.stats())
96
    preds = torch.stack(preds_list).view(-1).numpy()
97
    labels = torch.stack(labels_list).view(-1).numpy().astype(int)
98
    print(num_voxels_hd)
99
    return preds, labels
100
101
102
preds, labels = inference(test_dir)