|
a |
|
b/inference_ssas.py |
|
|
1 |
import os |
|
|
2 |
import argparse |
|
|
3 |
import torch |
|
|
4 |
from networks.vnet_sdf import VNet |
|
|
5 |
from utils.test_patch_sass import test_all_case |
|
|
6 |
|
|
|
7 |
# Command-line configuration for evaluating a trained VNet on the test split.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset_name')
parser.add_argument('--root_path', type=str, default='/data/omnisky/postgraduate/Yb/data_set/LASet/data', help='root directory of the dataset')
parser.add_argument('--exp', type=str, default='vnet', help='exp_name')
parser.add_argument('--model', type=str, default='vnet_DTC', help='model_name')
parser.add_argument('--gpu', type=str, default='1', help='GPU to use')
parser.add_argument('--labelnum', type=int, default=11, help='labeled data')
parser.add_argument('--iter', type=int, default=6000, help='model iteration')
parser.add_argument('--detail', type=int, default=1, help='print metrics for every samples?')
parser.add_argument('--nms', type=int, default=0, help='apply NMS post-processing?')

FLAGS = parser.parse_args()

# Restrict CUDA to the requested GPU; this must happen before any CUDA work.
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
# NOTE(review): not referenced in this file's visible code; kept for compatibility.
snapshot_path = "../model/{}".format(FLAGS.model)

num_classes = 2

test_save_path = "model/{}_{}_{}_labeled/{}_predictions/".format(FLAGS.dataset_name, FLAGS.exp, FLAGS.labelnum, FLAGS.model)
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(test_save_path, exist_ok=True)
print(test_save_path)

# test.list lives one directory above root_path and holds one case name per line.
with open(FLAGS.root_path + '/../test.list', 'r', encoding='utf-8') as f:
    # strip() also drops '\r' / stray whitespace, so CRLF files yield valid paths.
    case_names = [line.strip() for line in f]
# Skip blank lines, which would otherwise produce bogus "<root>//mri_norm2.h5" entries.
image_list = [FLAGS.root_path + "/" + name + "/mri_norm2.h5" for name in case_names if name]
def test_calculate_metric(epoch_num, save_mode_path=None):
    """Load a trained VNet checkpoint and evaluate it on every case in ``image_list``.

    Args:
        epoch_num: Iteration number passed from ``--iter``.
            NOTE(review): this value is not used; the checkpoint is chosen by
            ``save_mode_path``. It is kept so existing callers keep working.
        save_mode_path: Path of the state-dict checkpoint to load. When ``None``,
            the previously hard-coded checkpoint is used, so the default
            behavior is unchanged.

    Returns:
        The value returned by ``test_all_case`` (printed by the caller).
    """
    if save_mode_path is None:
        save_mode_path = 'model/LA_vnet_12_labeled/sassnet_label12/iter_5200_dice_0.8954771273472677.pth'

    # The network gets num_classes - 1 output channels (1 with num_classes = 2).
    net = VNet(n_channels=1, n_classes=num_classes - 1, normalization='batchnorm', has_dropout=False).cuda()
    # Deserialize onto the CPU first. load_state_dict then copies the tensors into
    # the already-CUDA parameters. This avoids failures when the checkpoint was
    # saved on a GPU index that CUDA_VISIBLE_DEVICES no longer exposes.
    state_dict = torch.load(save_mode_path, map_location='cpu')
    net.load_state_dict(state_dict)
    print("init weight from {}".format(save_mode_path))
    net.eval()

    # Pure inference: disable autograd so no computation graph is recorded.
    with torch.no_grad():
        avg_metric = test_all_case(net, image_list, num_classes=num_classes,
                                   patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                                   save_result=False, test_save_path=test_save_path,
                                   metric_detail=FLAGS.detail, nms=FLAGS.nms)

    return avg_metric
if __name__ == '__main__':
    # Evaluate the checkpoint (the --iter value is forwarded) and report the result.
    metric = test_calculate_metric(epoch_num=FLAGS.iter)
    print(metric)

# Example: python inference_ssas.py --model 0214_re01 --gpu 0