"""Run inference on the test set.

This script also saves voxel-wise inference results and labels as numpy
arrays for future reuse (without running on a GPU).
"""

import argparse
import pickle
import sys
from os import makedirs
from os.path import join
from time import time

import torch
from torch import nn, sigmoid
from torch.utils.data import DataLoader

from dataset import MRIDataset
from model import UNet3D
from utils import Report, transfer_weights
|
|
# Command-line interface; ``args`` is consumed by the rest of the script.
argv = sys.argv[1:]
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawDescriptionHelpFormatter,
    prog='PROG',)
parser.add_argument('--data_dir',
                    type=str,
                    required=True,
                    help="path to images (should have subdirs: test)")
parser.add_argument('--model_path', '-m', type=str,
                    required=True,
                    help='path to a model snapshot')
# A single value is treated as a cubic volume (replicated to 3 dims below).
parser.add_argument('--size', type=int, nargs='+', default=[128])
parser.add_argument('--dump_name',
                    type=str,
                    default='',
                    help='name dump file to avoid overwriting files.')
args = parser.parse_args(argv)
|
|
# Model: single-channel in/out 3D U-Net matching the training configuration.
net = UNet3D(1, 1, use_bias=True, inplanes=16)
transfer_weights(net, args.model_path)  # load weights from the snapshot
net.cuda()
net.train(False)  # eval mode: freezes dropout / batch-norm statistics
# BUG FIX: the original called ``torch.no_grad()`` as a bare statement, which
# builds a context manager and immediately discards it -- a no-op, so autograd
# was never disabled.  Disable gradient tracking globally instead, which is
# what was intended for a pure-inference script.
torch.set_grad_enabled(False)

test_dir = join(args.data_dir, 'test')
batch_size = 1
|
|
def inference(target_dir):
    """Run the network over every volume under *target_dir* and report Dice.

    Parameters:
        target_dir: directory holding the split to evaluate (e.g. the
            ``test`` subdirectory of ``--data_dir``).

    Returns:
        (preds, labels): flat numpy arrays of voxel-wise sigmoid scores and
        integer ground-truth labels, for offline reuse without a GPU.
    """
    # A single --size value means a cubic volume: [128] -> [128, 128, 128].
    volume_size = args.size * 3 if len(args.size) == 1 else args.size
    dataloader = DataLoader(MRIDataset(target_dir,
                                       volume_size,
                                       sampling_mode='center',
                                       deterministic=True),
                            batch_size=batch_size,  # was a hard-coded 1
                            num_workers=4)
    input_paths = dataloader.dataset.inputs
    label_paths = dataloader.dataset.labels

    reporter = Report()      # accumulates statistics over the whole split
    preds_list = list()
    labels_list = list()
    sum_hd = 0
    sum_sd = 0
    num_voxels_hd = list()   # per case: (label voxel count, hard Dice)
    # BUG FIX: the forward pass previously ran with autograd enabled (the
    # module-level ``torch.no_grad()`` call was a discarded no-op).
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            inputs = inputs.cuda()
            preds = sigmoid(net(inputs))

            # Project the cropped prediction back onto the full-size volume
            # so metrics are computed against the full ground-truth label.
            full_labels = dataloader.dataset._load_full_label(label_paths[i])
            full_preds = dataloader.dataset._project_full_label(input_paths[i],
                                                                preds.cpu())
            preds = full_preds
            labels = full_labels

            reporter.feed(preds, labels)
            # Fresh reporter per case for per-sample (macro-averaged) Dice.
            temp_reporter = Report()
            temp_reporter.feed(preds, labels)
            hd = temp_reporter.hard_dice()  # compute once, use twice
            sum_hd += hd
            sum_sd += temp_reporter.soft_dice()
            num_voxels_hd.append((labels.view(-1).sum().item(), hd))
            del temp_reporter

            preds_list.append(preds.cpu())
            labels_list.append(labels.cpu())
            del inputs, labels, preds
    # NOTE(review): these are per-sample averages (macro), despite the label.
    print("Micro Averaged Dice {}, {}".format(sum_hd / len(dataloader),
                                              sum_sd / len(dataloader)))
    if len(args.dump_name):
        # Dump preds for visualization.  BUG FIX: the original passed a bare
        # ``open(...)`` to pickle.dump and leaked the file handle; ``with``
        # guarantees the file is flushed and closed.
        with open('preds_dump_{}.pickle'.format(args.dump_name), 'wb') as f:
            pickle.dump([input_paths, preds_list, labels_list], f)
    print(reporter)
    print(reporter.stats())
    preds = torch.stack(preds_list).view(-1).numpy()
    labels = torch.stack(labels_list).view(-1).numpy().astype(int)
    print(num_voxels_hd)
    return preds, labels
|
|
100 |
|
|
|
101 |
|
|
|
102 |
preds, labels = inference(test_dir) |