Diff of /inference_URPC.py [000000] .. [903821]

Switch to unified view

a b/inference_URPC.py
1
import os
2
import argparse
3
import torch
4
from networks.unet_urpc import unet_3D_dv_semi
5
from utils.test_patch import test_all_case
6
7
parser = argparse.ArgumentParser()
8
parser.add_argument('--dataset_name', type=str,  default='LA', help='dataset_name')
9
parser.add_argument('--root_path', type=str, default='/***/data_set/LASet/data/', help='Name of Experiment')
10
parser.add_argument('--exp', type=str,  default='vnet', help='exp_name')
11
parser.add_argument('--model', type=str,  default='URPC', help='model_name')
12
parser.add_argument('--gpu', type=str,  default='0', help='GPU to use')
13
parser.add_argument('--detail', type=int,  default=1, help='print metrics for every samples?')
14
parser.add_argument('--labelnum', type=int, default=25, help='labeled data')
15
parser.add_argument('--nms', type=int, default=0, help='apply NMS post-procssing?')
16
17
FLAGS = parser.parse_args()
18
os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu
19
test_save_path = 'predictions/URPC/'
20
21
num_classes = 2
22
patch_size = (112, 112, 80)
23
FLAGS.root_path = FLAGS.root_path
24
with open(FLAGS.root_path + '/../test.list', 'r') as f:
25
    image_list = f.readlines()
26
image_list = [FLAGS.root_path + item.replace('\n', '') + "/mri_norm2.h5" for item in image_list]
27
28
if not os.path.exists(test_save_path):
29
    os.makedirs(test_save_path)
30
print(test_save_path)
31
32
def test_calculate_metric():
33
    net = unet_3D_dv_semi(n_classes=num_classes, in_channels=1).cuda()
34
    save_mode_path = 'model/LA_vnet_25_labeled/URPC/URPC_best_model.pth'
35
    net.load_state_dict(torch.load(save_mode_path), strict=False)  # False
36
    print("init weight from {}".format(save_mode_path))
37
    net.eval()
38
39
    avg_metric = test_all_case(FLAGS.model, 1, net, image_list, num_classes=num_classes,
40
                    patch_size=(112, 112, 80), stride_xy=18, stride_z=4,
41
                    save_result=False, test_save_path=test_save_path,
42
                    metric_detail=FLAGS.detail, nms=FLAGS.nms)
43
44
    return avg_metric
45
46
47
if __name__ == '__main__':
48
    metric = test_calculate_metric()
49
    print(metric)