|
a |
|
b/U-Net/black_blood.py |
|
|
1 |
import argparse |
|
|
2 |
import logging |
|
|
3 |
import os |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
import torch |
|
|
7 |
import torch.nn.functional as F |
|
|
8 |
from PIL import Image |
|
|
9 |
from torchvision import transforms |
|
|
10 |
|
|
|
11 |
from utils.data_loading import BasicDataset |
|
|
12 |
from unet.unet_model import UNet |
|
|
13 |
from utils.utils import plot_img_and_mask |
|
|
14 |
|
|
|
15 |
device = torch.device('cuda:4' if torch.cuda.is_available() else 'cpu') |
|
|
16 |
|
|
|
17 |
directory = r'/home/data/spleen_blood/data/test/imgs' |
|
|
18 |
|
|
|
19 |
def predict_img(net, |
|
|
20 |
full_img, |
|
|
21 |
device, |
|
|
22 |
scale_factor=1, |
|
|
23 |
out_threshold=0.5): |
|
|
24 |
net.eval() |
|
|
25 |
img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False)) |
|
|
26 |
img = img.unsqueeze(0) |
|
|
27 |
img = img.to(device=device, dtype=torch.float32) |
|
|
28 |
|
|
|
29 |
with torch.no_grad(): |
|
|
30 |
output = net(img).cpu() |
|
|
31 |
output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear') |
|
|
32 |
if net.n_classes > 1: |
|
|
33 |
mask = output.argmax(dim=1) |
|
|
34 |
else: |
|
|
35 |
mask = torch.sigmoid(output) > out_threshold |
|
|
36 |
|
|
|
37 |
return mask[0].long().squeeze().numpy() |
|
|
38 |
|
|
|
39 |
def get_output_filenames(in_files): |
|
|
40 |
return f'{os.path.splitext(in_files)[0]}_OUT.png' |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
def mask_to_image(mask: np.ndarray, mask_values): |
|
|
44 |
if isinstance(mask_values[0], list): |
|
|
45 |
out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8) |
|
|
46 |
elif mask_values == [0, 1]: |
|
|
47 |
out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool) |
|
|
48 |
else: |
|
|
49 |
out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8) |
|
|
50 |
|
|
|
51 |
if mask.ndim == 3: |
|
|
52 |
mask = np.argmax(mask, axis=0) |
|
|
53 |
|
|
|
54 |
for i, v in enumerate(mask_values): |
|
|
55 |
out[mask == i] = v |
|
|
56 |
|
|
|
57 |
return Image.fromarray(out) |
|
|
58 |
|
|
|
59 |
|
|
|
60 |
if __name__ == '__main__': |
|
|
61 |
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') |
|
|
62 |
|
|
|
63 |
model = './epoch_26_acc_0.90_best_val_acc.pth' |
|
|
64 |
|
|
|
65 |
for fn in os.listdir(directory): |
|
|
66 |
in_files = os.path.join(directory, fn) |
|
|
67 |
out_files = get_output_filenames(in_files) |
|
|
68 |
#print(fn, in_files, out_files) |
|
|
69 |
|
|
|
70 |
net = UNet(n_channels=1, n_classes=5, bilinear=True) |
|
|
71 |
|
|
|
72 |
logging.info(f'Loading model {model}') |
|
|
73 |
logging.info(f'Using device {device}') |
|
|
74 |
|
|
|
75 |
net.to(device=device) |
|
|
76 |
state_dict = torch.load(model, map_location=device) |
|
|
77 |
mask_values = state_dict.pop('mask_values', [0, 1]) |
|
|
78 |
net.load_state_dict(state_dict) |
|
|
79 |
|
|
|
80 |
logging.info('Model loaded!') |
|
|
81 |
|
|
|
82 |
filename = in_files |
|
|
83 |
logging.info(f'Predicting image {filename} ...') |
|
|
84 |
img = Image.open(filename) |
|
|
85 |
|
|
|
86 |
mask = predict_img(net=net, |
|
|
87 |
full_img=img, |
|
|
88 |
scale_factor=0.5, |
|
|
89 |
out_threshold=0.5, |
|
|
90 |
device=device) |
|
|
91 |
|
|
|
92 |
out_filename = out_files |
|
|
93 |
result = mask_to_image(mask, mask_values) |
|
|
94 |
result.save(out_filename) |
|
|
95 |
logging.info(f'Mask saved to {out_filename}') |
|
|
96 |
|
|
|
97 |
#logging.info(f'Visualizing results for image {filename}, close to continue...') |
|
|
98 |
#plot_img_and_mask(img, mask) |
|
|
99 |
|
|
|
100 |
|
|
|
101 |
|
|
|
102 |
|
|
|
103 |
|
|
|
104 |
|
|
|
105 |
|
|
|
106 |
|
|
|
107 |
|
|
|
108 |
|
|
|
109 |
|
|
|
110 |
|
|
|
111 |
|
|
|
112 |
|
|
|
113 |
|
|
|
114 |
|
|
|
115 |
|
|
|
116 |
|
|
|
117 |
|
|
|
118 |
|
|
|
119 |
|
|
|
120 |
|
|
|
121 |
|
|
|
122 |
|
|
|
123 |
|
|
|
124 |
|
|
|
125 |
|