|
a |
|
b/CaraNet/test_blood.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn.functional as F |
|
|
3 |
import numpy as np |
|
|
4 |
import os, argparse |
|
|
5 |
from scipy import misc |
|
|
6 |
#from lib.HarDMSEG import HarDMSEG |
|
|
7 |
from utils.dataloader import test_dataset |
|
|
8 |
#from CFP_Res2Net import cfpnet_res2net |
|
|
9 |
from collections import OrderedDict |
|
|
10 |
#from pranet import PraNet |
|
|
11 |
from CaraNet import caranet |
|
|
12 |
|
|
|
13 |
import cv2 |
|
|
14 |
|
|
|
15 |
parser = argparse.ArgumentParser() |
|
|
16 |
parser.add_argument('--testsize', type=int, default=352, help='testing size') |
|
|
17 |
parser.add_argument('--pth_path', type=str, default='/home/data/spleen_blood/CaraNet/snapshots/CaraNet-best.pth') |
|
|
18 |
|
|
|
19 |
for _data_name in ['test']: |
|
|
20 |
##### put your data_path here ##### |
|
|
21 |
data_path = '/home/data/spleen_blood/CaraNet/TestDataset/{}/'.format(_data_name) |
|
|
22 |
################################### |
|
|
23 |
|
|
|
24 |
save_path = '/home/data/spleen_blood/CaraNet/results/CaraNet/{}/'.format(_data_name) |
|
|
25 |
|
|
|
26 |
if not os.path.exists( save_path ): |
|
|
27 |
os.makedirs( save_path ) |
|
|
28 |
|
|
|
29 |
opt = parser.parse_args() |
|
|
30 |
model = caranet() |
|
|
31 |
weights = torch.load(opt.pth_path) |
|
|
32 |
new_state_dict = OrderedDict() |
|
|
33 |
|
|
|
34 |
for k, v in weights.items(): |
|
|
35 |
|
|
|
36 |
|
|
|
37 |
if 'total_ops' not in k and 'total_params' not in k: |
|
|
38 |
name = k |
|
|
39 |
new_state_dict[name] = v |
|
|
40 |
# print(new_state_dict[k]) |
|
|
41 |
|
|
|
42 |
# # print(k) |
|
|
43 |
# fp = open('./log3.txt','a') |
|
|
44 |
# fp.write(str(k)+'\n') |
|
|
45 |
# fp.close() |
|
|
46 |
# print(new_state_dict) |
|
|
47 |
|
|
|
48 |
model.load_state_dict(new_state_dict) |
|
|
49 |
model.cuda() |
|
|
50 |
model.eval() |
|
|
51 |
|
|
|
52 |
|
|
|
53 |
os.makedirs(save_path, exist_ok=True) |
|
|
54 |
image_root = '{}/images/'.format(data_path) |
|
|
55 |
gt_root = '{}/masks/'.format(data_path) |
|
|
56 |
test_loader = test_dataset(image_root, gt_root, opt.testsize) |
|
|
57 |
|
|
|
58 |
for i in range(test_loader.size): |
|
|
59 |
image, gt, name = test_loader.load_data() |
|
|
60 |
gt = np.asarray(gt, np.float32) |
|
|
61 |
gt /= (gt.max() + 1e-8) |
|
|
62 |
image = image.cuda() |
|
|
63 |
|
|
|
64 |
# res = model(image) |
|
|
65 |
res5,res4,res2,res1 = model(image) |
|
|
66 |
res = res5 |
|
|
67 |
res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) |
|
|
68 |
res = res.sigmoid().data.cpu().numpy().squeeze() |
|
|
69 |
res = (res - res.min()) / (res.max() - res.min() + 1e-8) |
|
|
70 |
|
|
|
71 |
#misc.imsave(save_path+name, res) |
|
|
72 |
cv2.imwrite(save_path+name, res) |
|
|
73 |
|
|
|
74 |
|
|
|
75 |
|
|
|
76 |
|
|
|
77 |
|
|
|
78 |
|
|
|
79 |
|
|
|
80 |
|
|
|
81 |
|
|
|
82 |
|
|
|
83 |
|
|
|
84 |
|
|
|
85 |
|
|
|
86 |
|
|
|
87 |
|
|
|
88 |
|
|
|
89 |
|
|
|
90 |
|
|
|
91 |
|
|
|
92 |
|
|
|
93 |
|