|
a |
|
b/src/production.py |
|
|
1 |
from utils import get_model |
|
|
2 |
from data_functions import get_transforms |
|
|
3 |
from torch.utils.data import Dataset, DataLoader |
|
|
4 |
import cv2 |
|
|
5 |
import torch |
|
|
6 |
import numpy as np |
|
|
7 |
import nibabel as nib |
|
|
8 |
import random |
|
|
9 |
import string |
|
|
10 |
import os |
|
|
11 |
from config import BinaryModelConfig, MultiModelConfig, LungsModelConfig |
|
|
12 |
from PIL import Image, ImageFont, ImageDraw |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
def get_setup():
    """Load the three segmentation models and their test-time transforms.

    For each of BinaryModelConfig, MultiModelConfig and LungsModelConfig (in
    that order) a model is built with ``get_model(cfg)(cfg)``, its weights are
    loaded from ``cfg.best_dict``, and it is moved to CUDA when available
    (CPU otherwise) and switched to eval mode.

    Returns:
        ``(models, transforms)`` -- two lists in config order
        ``[binary, multi, lungs]``; ``transforms[k]`` is the test transform
        returned by ``get_transforms`` for the k-th config.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models = []
    transforms = []

    for cfg in (BinaryModelConfig, MultiModelConfig, LungsModelConfig):
        model = get_model(cfg)(cfg)
        model.load_state_dict(torch.load(cfg.best_dict, map_location=device))
        # load_state_dict copies values into the module's existing parameters
        # without relocating them, so the module must be moved explicitly or
        # CUDA inputs will meet weights left on the construction device.
        model.to(device)
        model.eval()
        models.append(model)

        # only the test-time half of the (train, test) pair is needed here
        _, test_transforms = get_transforms(cfg)
        transforms.append(test_transforms)

    return models, transforms
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def generate_folder_name(length=7):
    """Return a random folder name of ``length`` lowercase letters plus a trailing '/'.

    The name isolates one user's uploads, so it is drawn with ``secrets``
    (not ``random``) to make it impractical for other users to guess.
    """
    import secrets

    return ''.join(secrets.choice(string.ascii_lowercase) for _ in range(length)) + '/'
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def _load_legend_font(size=30):
    """Return Arial at ``size`` points, or Pillow's built-in font if Arial cannot be opened."""
    try:
        return ImageFont.truetype("arial.ttf", size)
    except OSError:
        # "arial.ttf" is usually not installed on Linux hosts; degrade to the
        # default bitmap font instead of failing the whole render.
        return ImageFont.load_default()


def make_legend(image, annotation):
    """Append a text legend strip below ``image`` and return the combined image as an array.

    Args:
        image: array convertible by ``PIL.Image.fromarray`` after rounding and
            casting to uint8 (presumably H x W x 3 with values in [0, 255] --
            confirm with callers).
        annotation: newline-separated text. With exactly 3 lines, line 0 is a
            header and lines 1 and 2 get a green and a blue marker. Otherwise,
            line 0 is the header and line 1 gets a cyan marker; at least 2 lines
            are required in that case.

    Returns:
        numpy array of the new RGB image: the original image on top and a strip
        130 px (3-line annotation) or 90 px (otherwise) tall underneath.
    """
    lines = annotation.split('\n')
    rgb_image = np.round(image).astype(np.uint8)
    image = Image.fromarray(rgb_image)
    old_width, old_height = image.size

    # (text, marker colour) per coloured row, listed from the bottom row upwards
    if len(lines) == 3:
        extra_height = 130
        rows = [(lines[1], (0, 255, 0)), (lines[2], (0, 0, 255))]
    else:
        extra_height = 90
        rows = [(lines[1], (0, 255, 255))]

    new_size = (old_width, old_height + extra_height)
    new_image = Image.new('RGB', new_size)
    new_image.paste(image)
    font = _load_legend_font(30)
    draw = ImageDraw.Draw(new_image)
    bottom = new_size[1]

    for row_index, (text, colour) in enumerate(rows):
        offset = 40 * row_index  # each row sits 40 px above the previous one
        draw.ellipse((20 + 2, bottom - 30 - offset + 2, 40 - 2, bottom - 10 - offset - 2), fill=colour)
        draw.text((50, bottom - 40 - offset), text, (255, 255, 255), font=font)

    # header line goes one row above the topmost coloured row
    draw.text((50, bottom - 40 * (len(rows) + 1)), lines[0], (255, 255, 255), font=font)
    return np.asarray(new_image)
|
|
70 |
|
|
|
71 |
|
|
|
72 |
def data_to_paths(data, save_folder):
    """Collect image paths from a file or folder, converting NIfTI volumes to PNG slices.

    Args:
        data: path to a single file, or to a folder whose direct entries are
            each processed.
        save_folder: created if missing; NIfTI slices are written into its
            ``slices`` subfolder as ``<name>_<index>.png``.

    Returns:
        List of paths: ``.png``/``.jpg``/``.jpeg`` inputs as given, plus one PNG
        per slice (along the last axis) of every ``.nii``/``.nii.gz`` input.
        Missing paths and unsupported extensions are reported with ``print``
        and skipped.
    """
    all_paths = []
    create_folder(save_folder)
    if os.path.isdir(data):  # folder of files
        candidates = [os.path.join(data, x) for x in os.listdir(data)]
    else:  # single file
        candidates = [data]

    slices_folder = os.path.join(save_folder, 'slices')
    for path in candidates:
        if not os.path.exists(path):
            print(f'Path \"{path}\" not exists')
            continue

        if path.endswith(('.png', '.jpg', '.jpeg')):
            all_paths.append(path)
        elif path.endswith(('.nii', '.nii.gz')):
            # NIfTI volumes become PNG slices in the "slices" folder
            create_folder(slices_folder)

            # basename works with the platform's separators; strip the full
            # NIfTI suffix so dotted names ("scan.v2.nii") don't collide
            base_name = os.path.basename(path)
            suffix = '.nii.gz' if base_name.endswith('.nii.gz') else '.nii'
            nii_name = base_name[:-len(suffix)]

            volume = np.array(nib.load(path).dataobj)
            volume = np.moveaxis(volume, -1, 0)  # iterate slices along the last axis

            for i, image in enumerate(volume):
                image = window_image(image).astype(np.float64)
                # min-max scale to [0, 1]; subtracting the minimum is correct
                # even when it is positive, and constant slices stay all-zero
                # instead of becoming NaN
                image -= image.min()
                peak = image.max()
                if peak > 0:
                    image /= peak

                image_path = os.path.join(slices_folder, f'{nii_name}_{i}.png')
                cv2.imwrite(image_path, image * 255)
                all_paths.append(image_path)
        else:
            print(f'Path \"{path}\" is not supported format')
    return all_paths
|
|
113 |
|
|
|
114 |
|
|
|
115 |
def window_image(image, window_center=-600, window_width=1500):
    """Clamp ``image`` into the intensity window centred on ``window_center``.

    Values below ``window_center - window_width // 2`` are raised to that bound
    and values above ``window_center + window_width // 2`` are lowered to it.
    The defaults (-1350..150) look like a lung CT window in Hounsfield units;
    that assumes HU input -- confirm against the data source.

    Returns:
        A clamped copy with the same dtype; ``image`` itself is left untouched.
    """
    half_width = window_width // 2
    lower, upper = window_center - half_width, window_center + half_width
    clamped = image.copy()
    clamped[clamped < lower] = lower
    clamped[clamped > upper] = upper
    return clamped
|
|
122 |
|
|
|
123 |
|
|
|
124 |
def read_files(files):
    """Save uploaded files into a fresh per-user folder under ``images/``.

    NIfTI uploads (``.nii`` / ``.nii.gz``) are expanded into one PNG per slice
    along the last axis; the raw NIfTI file is deleted afterwards. Other files
    are written as-is.

    Args:
        files: iterable of uploaded-file objects exposing ``.name`` and
            ``.getvalue()`` (presumably Streamlit ``UploadedFile`` -- confirm).

    Returns:
        ``(paths, folder_name)``: ``paths[k]`` is the list of image paths
        produced from the k-th file, and ``folder_name`` is the generated folder
        name (with a trailing '/').
    """
    # creating folder for user; retry so two sessions never share a folder
    os.makedirs('images', exist_ok=True)
    while True:
        folder_name = generate_folder_name()
        path = 'images/' + folder_name
        try:
            os.mkdir(path)
        except FileExistsError:
            continue
        break

    paths = []
    for file in files:
        file_paths = []
        paths.append(file_paths)
        # the upload name is untrusted: keep only its final component so it
        # cannot point outside the user's folder
        file_name = os.path.basename(file.name.replace('\\', '/'))

        if file_name.endswith(('.nii', '.nii.gz')):
            nii_path = path + file_name
            with open(nii_path, 'wb') as f:
                f.write(file.getvalue())
            try:
                volume = np.array(nib.load(nii_path).dataobj)
            finally:
                os.remove(nii_path)  # the raw upload is never kept
            volume = np.moveaxis(volume, -1, 0)

            suffix = '.nii.gz' if file_name.endswith('.nii.gz') else '.nii'
            stem = file_name[:-len(suffix)]
            for i, image in enumerate(volume):  # saving every slice
                image = window_image(image).astype(np.float64)
                # min-max scale to [0, 1]; constant slices stay all-zero
                image -= image.min()
                peak = image.max()
                if peak > 0:
                    image /= peak

                image_path = path + stem + f'_{i}.png'
                cv2.imwrite(image_path, image * 255)
                file_paths.append(image_path)
        else:
            target = path + file_name
            with open(target, 'wb') as f:
                f.write(file.getvalue())
            file_paths.append(target)
    return paths, folder_name
|
|
164 |
|
|
|
165 |
|
|
|
166 |
def create_folder(path):
    """Create directory ``path`` (and any missing parents) if it does not already exist.

    ``exist_ok=True`` makes this safe when another process creates the folder
    concurrently; it still raises ``FileExistsError`` if ``path`` is a file.
    """
    os.makedirs(path, exist_ok=True)
|
|
169 |
|
|
|
170 |
|
|
|
171 |
def get_predictions(paths, models, transforms, multi_class=True):
    """Run the segmentation models over every image path and yield per-image label maps.

    Args:
        paths: image file paths, loaded through ``ProductionCovid19Dataset``.
        models: ``(binary_model, multi_model, lung_model)``.
        transforms: list of test transforms; only ``transforms[0]`` is used,
            and it is applied to the shared input of all three models.
        multi_class: when True, ``multi_model`` is also run, and pixels it
            assigns to class 2 (after ``% 3``) get 1 added to the binary label.

    Yields:
        ``(img, pred, lung)`` numpy arrays: the input image scaled by its
        maximum, the disease label map and the lung label map (argmax over
        dim 0 of each squeezed model output -- assumes outputs are
        (1, C, H, W); confirm with the model code).
    """
    binary_model, multi_model, lung_model = models
    # Send inputs to wherever the weights actually live, so models that were
    # left on the CPU still work on a machine with CUDA.
    first_param = next(binary_model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataloader = DataLoader(ProductionCovid19Dataset(paths, transform=transforms[0]), batch_size=1, drop_last=False)

    for X, _ in dataloader:
        X = X.to(device)
        peak = torch.max(X)
        if peak != 0:  # an all-zero image would otherwise become NaN
            X = X / peak

        # every forward pass is inference-only: no autograd graph
        with torch.no_grad():
            pred = binary_model(X)
            lung = lung_model(X)
            multi_output = multi_model(X) if multi_class else None

        img = X.squeeze().cpu()
        pred = torch.argmax(pred.squeeze().cpu(), 0).float()
        lung = torch.argmax(lung.squeeze().cpu(), 0).float()

        if multi_class:
            multi_pred = torch.argmax(multi_output.squeeze().cpu(), 0).float()
            multi_pred = multi_pred % 3  # model was trained on 3 classes but only 2 are used
            # ground-glass comes from the binary model, consolidation from the multi model.
            # NOTE(review): a pixel the binary model calls healthy (0) but the
            # multi model calls class 2 ends up as label 1 -- confirm intended.
            pred = pred + (multi_pred == 2)
        yield img.numpy(), pred.numpy(), lung.numpy()
|
|
201 |
|
|
|
202 |
|
|
|
203 |
def combo_with_lungs(disease, lungs):
    """Split ``disease`` into the parts lying in lung label 1 and lung label 2.

    Returns:
        ``(left, right)`` -- ``disease`` multiplied by the ``lungs == 1`` mask
        and by the ``lungs == 2`` mask respectively.
    """
    left_mask = lungs == 1
    right_mask = lungs == 2
    return disease * left_mask, disease * right_mask
|
|
205 |
|
|
|
206 |
|
|
|
207 |
def _lung_share(mask, lung_mask):
    """Return the percentage of ``lung_mask`` pixels also set in ``mask`` (0.0 when the lung mask is empty)."""
    lung_area = np.sum(lung_mask)
    if lung_area == 0:
        # slices without visible lung would otherwise divide by zero and show "nan%"
        return 0.0
    return np.sum(mask * lung_mask) / lung_area * 100


def make_masks(paths, models, transforms, multi_class=True):
    """Yield a colour overlay, an annotation string and the source path for every image.

    Predictions come from ``get_predictions``. Lung label 1 is treated as the
    left lung and label 2 as the right lung. Diseased pixels are painted into
    the overlay; every other pixel keeps its grey image value on all channels.

    Yields:
        ``(img, annotation, path)``: ``img`` is a float array scaled to
        [0, 255], rotated 90 degrees counter-clockwise and flipped vertically;
        ``annotation`` is a newline-separated text with per-lung disease
        percentages (3 lines when ``multi_class``, 2 lines otherwise).
    """
    for path, (img, pred, lung) in zip(paths, get_predictions(paths, models, transforms, multi_class)):
        lung_left = (lung == 1)
        lung_right = (lung == 2)
        not_disease = (pred == 0)

        if multi_class:
            consolidation = (pred == 2)  # channel 2 (blue when rendered as RGB)
            ground_glass = (pred == 1)  # channel 1 (green)

            img = np.array([np.zeros_like(img), ground_glass, consolidation]) + img * not_disease

            annotation = (
                ' left | right\n'
                f' Ground-glass - {_lung_share(ground_glass, lung_left):.1f}% | {_lung_share(ground_glass, lung_right):.1f}%\n'
                f'Consolidation - {_lung_share(consolidation, lung_left):.1f}% | {_lung_share(consolidation, lung_right):.1f}%'
            )
        else:
            disease = (pred == 1)

            annotation = (
                ' left | right\n'
                f'Disease - {_lung_share(disease, lung_left):.1f}% | {_lung_share(disease, lung_right):.1f}%'
            )

            # channels 1 and 2 together (cyan when rendered as RGB)
            img = np.array([np.zeros_like(img), disease, disease]) + img * not_disease

        img = img.swapaxes(0, -1)  # channel-first -> channel-last
        img = np.round(img * 255)
        img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        img = cv2.flip(img, 0)
        yield img, annotation, path
|
|
235 |
|
|
|
236 |
|
|
|
237 |
class ProductionCovid19Dataset(Dataset):
    """Inference-only dataset that reads image files as single-channel float32 tensors.

    Each item is ``(image, 'None')``. ``image`` is the greyscale picture
    (after ``transform``, if one is given) wrapped in a leading channel axis:
    (1, H, W), assuming the transform returns a 2-D array. The string
    ``'None'`` only fills the label position.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform  # callable taking ``image=`` and returning a dict with 'image'
        self._len = len(paths)

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        path = self.paths[index]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None
            raise FileNotFoundError(f'Could not read image "{path}"')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        if self.transform:
            transformed = self.transform(image=image)
            image = transformed['image']
        # np.float was removed in NumPy 1.24; build the float32 array directly
        # (the original float64 -> FloatTensor round trip gives the same tensor)
        image = torch.from_numpy(np.array([image], dtype=np.float32))
        return image, 'None'