# inference_ssas.py
1
import os
2
import argparse
3
import torch
4
from networks.vnet_sdf import VNet
5
from utils.test_patch_sass import test_all_case
6
7
# Command-line options for the evaluation run.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_name', type=str, default='LA', help='dataset_name')
# The dataset root holds one sub-directory per case; the test list is read
# from the parent directory of this path. The default is a machine-specific
# absolute path and normally needs to be overridden.
parser.add_argument('--root_path', type=str, default='/data/omnisky/postgraduate/Yb/data_set/LASet/data', help='root directory of the dataset')
parser.add_argument('--exp', type=str, default='vnet', help='exp_name')
parser.add_argument('--model', type=str, default='vnet_DTC', help='model_name')
# Kept as a string because it is written verbatim into CUDA_VISIBLE_DEVICES.
parser.add_argument('--gpu', type=str, default='1', help='GPU to use')
parser.add_argument('--labelnum', type=int, default=11, help='labeled data')
parser.add_argument('--iter', type=int, default=6000, help='model iteration')
# --detail and --nms are integer switches (0 / 1) forwarded to test_all_case.
parser.add_argument('--detail', type=int, default=1, help='print metrics for every samples?')
parser.add_argument('--nms', type=int, default=0, help='apply NMS post-procssing?')

# Parsed at import time: every later module-level statement reads FLAGS.
FLAGS = parser.parse_args()
20
21
# Select which GPU(s) are visible to CUDA for this process.
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
snapshot_path = "../model/{}".format(FLAGS.model)

num_classes = 2

# Directory handed to test_all_case as the location for saved predictions.
test_save_path = "model/{}_{}_{}_labeled/{}_predictions/".format(FLAGS.dataset_name, FLAGS.exp, FLAGS.labelnum, FLAGS.model)
# exist_ok=True replaces the separate exists() check, so there is no window
# in which another process can create the directory between check and create.
os.makedirs(test_save_path, exist_ok=True)
print(test_save_path)
30
# Read the test case names (one per line) from test.list, which sits next to
# the dataset root, and turn each into the path of that case's volume file.
with open(FLAGS.root_path + '/../test.list', 'r') as f:
    # strip() also removes a trailing '\r' from CRLF-terminated list files,
    # which replace('\n', '') would leave embedded in the path.
    image_list = [line.strip() for line in f]
# Blank lines (e.g. a trailing empty line) are skipped instead of producing
# a path with an empty case directory.
image_list = [FLAGS.root_path + "/" + item + "/mri_norm2.h5" for item in
              image_list if item]
34
35
36
def test_calculate_metric(epoch_num):
    """Load a saved VNet checkpoint and evaluate it on ``image_list``.

    Args:
        epoch_num: Iteration number of the checkpoint to evaluate.
            NOTE(review): currently unused -- the checkpoint path below is
            hard-coded, so ``--iter`` and ``snapshot_path`` have no effect.
            Confirm whether the path should be derived from them.

    Returns:
        The value returned by ``test_all_case`` for the whole test list.
    """
    net = VNet(n_channels=1, n_classes=num_classes-1, normalization='batchnorm', has_dropout=False).cuda()
    save_mode_path = 'model/LA_vnet_12_labeled/sassnet_label12/iter_5200_dice_0.8954771273472677.pth'
    # Deserialize onto the CPU first: a checkpoint saved from a GPU index that
    # is not visible under the current CUDA_VISIBLE_DEVICES would otherwise
    # fail to load. load_state_dict then copies the weights into the CUDA
    # parameters of ``net``.
    net.load_state_dict(torch.load(save_mode_path, map_location='cpu'))
    print("init weight from {}".format(save_mode_path))
    # Evaluation mode: batchnorm uses its running statistics.
    net.eval()

    avg_metric = test_all_case(net, image_list, num_classes=num_classes,
                               patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
                               save_result=False, test_save_path=test_save_path,
                               metric_detail=FLAGS.detail, nms=FLAGS.nms)

    return avg_metric
49
50
51
if __name__ == '__main__':
    # FLAGS.iter defaults to 6000.
    metric = test_calculate_metric(FLAGS.iter)
    print(metric)
54
55
# Example usage: python inference_ssas.py --model 0214_re01 --gpu 0