--- a +++ b/inference_URPC.py @@ -0,0 +1,49 @@ +import os +import argparse +import torch +from networks.unet_urpc import unet_3D_dv_semi +from utils.test_patch import test_all_case + +parser = argparse.ArgumentParser() +parser.add_argument('--dataset_name', type=str, default='LA', help='dataset_name') +parser.add_argument('--root_path', type=str, default='/***/data_set/LASet/data/', help='Name of Experiment') +parser.add_argument('--exp', type=str, default='vnet', help='exp_name') +parser.add_argument('--model', type=str, default='URPC', help='model_name') +parser.add_argument('--gpu', type=str, default='0', help='GPU to use') +parser.add_argument('--detail', type=int, default=1, help='print metrics for every samples?') +parser.add_argument('--labelnum', type=int, default=25, help='labeled data') +parser.add_argument('--nms', type=int, default=0, help='apply NMS post-procssing?') + +FLAGS = parser.parse_args() +os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu +test_save_path = 'predictions/URPC/' + +num_classes = 2 +patch_size = (112, 112, 80) +FLAGS.root_path = FLAGS.root_path +with open(FLAGS.root_path + '/../test.list', 'r') as f: + image_list = f.readlines() +image_list = [FLAGS.root_path + item.replace('\n', '') + "/mri_norm2.h5" for item in image_list] + +if not os.path.exists(test_save_path): + os.makedirs(test_save_path) +print(test_save_path) + +def test_calculate_metric(): + net = unet_3D_dv_semi(n_classes=num_classes, in_channels=1).cuda() + save_mode_path = 'model/LA_vnet_25_labeled/URPC/URPC_best_model.pth' + net.load_state_dict(torch.load(save_mode_path), strict=False) # False + print("init weight from {}".format(save_mode_path)) + net.eval() + + avg_metric = test_all_case(FLAGS.model, 1, net, image_list, num_classes=num_classes, + patch_size=(112, 112, 80), stride_xy=18, stride_z=4, + save_result=False, test_save_path=test_save_path, + metric_detail=FLAGS.detail, nms=FLAGS.nms) + + return avg_metric + + +if __name__ == '__main__': + metric = test_calculate_metric() + print(metric)