a b/src/production.py
1
from utils import get_model
2
from data_functions import get_transforms
3
from torch.utils.data import Dataset, DataLoader
4
import cv2
5
import torch
6
import numpy as np
7
import nibabel as nib
8
import random
9
import string
10
import os
11
from config import BinaryModelConfig, MultiModelConfig, LungsModelConfig
12
from PIL import Image, ImageFont, ImageDraw
13
14
15
def get_setup():
    """Build every model with its weights and collect the test-time transforms.

    The configs are processed in the fixed order ``BinaryModelConfig``,
    ``MultiModelConfig``, ``LungsModelConfig``. For each one the model is
    created through ``get_model``, the state dict stored at ``cfg.best_dict``
    is loaded, and the model is moved to the inference device and switched to
    evaluation mode.

    Returns:
        tuple: ``(models, transforms)`` - two lists with one entry per config,
        in the order listed above.
    """
    # CUDA when available, CPU otherwise - the same rule get_predictions uses
    # when it moves the input batch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models = []
    transforms = []

    for cfg in [BinaryModelConfig, MultiModelConfig, LungsModelConfig]:
        # get_model(cfg) returns a callable that is itself invoked with cfg.
        model = get_model(cfg)(cfg)
        model.load_state_dict(torch.load(cfg.best_dict, map_location=device))
        # map_location only decides where the loaded tensors are placed;
        # load_state_dict copies them into the existing parameters, so the
        # module has to be moved explicitly to match the inputs' device.
        model.to(device)
        model.eval()
        models.append(model)

        # get_transforms returns a pair; only the second element is kept.
        _, test_transforms = get_transforms(cfg)
        transforms.append(test_transforms)
    return models, transforms
33
34
35
def generate_folder_name():
    """Return a random folder name: seven lowercase ASCII letters and a '/'."""
    letters = [random.choice(string.ascii_lowercase) for _ in range(7)]
    name = ''.join(letters)
    return name + '/'
37
38
39
def make_legend(image, annotation):
    """Extend an image with a black strip below it and draw a text legend there.

    ``annotation`` is split on newlines. Its first line is drawn as a header
    with no marker. With exactly three lines, line 2 gets a green marker and
    line 3 a blue one, and the strip is 130 px high. With any other line count
    only line 2 is drawn, with a cyan marker, and the strip is 90 px high.

    Args:
        image: array accepted by ``PIL.Image.fromarray`` after being rounded
            and cast to ``uint8``.
        annotation: newline-separated legend text (at least two lines).

    Returns:
        numpy.ndarray: the extended RGB image.

    Raises:
        IndexError: if ``annotation`` has fewer than two lines.
    """
    lines = annotation.split('\n')
    source = Image.fromarray(np.round(image).astype(np.uint8))
    width, height = source.size

    # (text, marker colour) pairs, listed from the bottom row upwards.
    if len(lines) == 3:
        strip_height = 130
        entries = [(lines[1], (0, 255, 0)), (lines[2], (0, 0, 255))]
    else:
        strip_height = 90
        entries = [(lines[1], (0, 255, 255))]

    new_height = height + strip_height
    new_image = Image.new('RGB', (width, new_height))
    new_image.paste(source)

    try:
        font = ImageFont.truetype("arial.ttf", 30)
    except OSError:
        # Arial is not installed everywhere; fall back to Pillow's built-in
        # font instead of failing the whole rendering.
        font = ImageFont.load_default()

    draw = ImageDraw.Draw(new_image)
    white = (255, 255, 255)
    row_step = 40  # vertical distance between consecutive legend rows
    for row, (text, colour) in enumerate(entries):
        shift = row_step * row
        draw.ellipse(
            (22, new_height - 28 - shift, 38, new_height - 12 - shift),
            fill=colour,
        )
        draw.text((50, new_height - 40 - shift), text, white, font=font)

    # The header sits one row above the last marker row.
    header_y = new_height - 40 - row_step * len(entries)
    draw.text((50, header_y), lines[0], white, font=font)
    return np.asarray(new_image)
70
71
72
def data_to_paths(data, save_folder):
    """Turn a file or a folder of files into a list of image paths.

    PNG/JPEG files are returned as they are. Every NIfTI volume
    (``.nii`` / ``.nii.gz``) is windowed with ``window_image``, rescaled per
    slice, and written as ``<save_folder>/slices/<name>_<i>.png``; those PNG
    paths are returned instead of the volume path. Missing paths and
    unsupported extensions are reported with ``print`` and skipped.

    Args:
        data: path to a single file, or to a directory whose entries are
            processed one by one.
        save_folder: output directory; created through ``create_folder``.

    Returns:
        list: image paths in processing order.
    """
    all_paths = []
    create_folder(save_folder)
    if os.path.isdir(data):  # folder of files
        candidates = [os.path.join(data, x) for x in os.listdir(data)]
    else:  # single file
        candidates = [data]

    for path in candidates:
        if not os.path.exists(path):  # path not exists
            print(f'Path \"{path}\" not exists')
            continue

        # reformatting by type
        if path.endswith(('.png', '.jpg', '.jpeg')):
            all_paths.append(path)
        elif path.endswith(('.nii', '.nii.gz')):
            # NIfTI volumes are stored as PNG slices in the folder "slices"
            slices_dir = os.path.join(save_folder, 'slices')
            os.makedirs(slices_dir, exist_ok=True)

            # basename honours the platform's separators; splitting on '\\'
            # alone left the whole directory part in the name on POSIX.
            nii_name = os.path.basename(path).split('.')[0]
            volume = np.array(nib.load(path).dataobj)
            volume = np.moveaxis(volume, -1, 0)  # iterate over the last axis

            for i, image in enumerate(volume):
                image = window_image(image).astype(np.float64)  # windowing
                # Min-max rescale to [0, 1]. Subtracting the minimum is also
                # right when it is positive (adding abs(min) is not), and the
                # guard avoids 0/0 on a constant slice.
                image -= image.min()
                peak = image.max()
                if peak > 0:
                    image /= peak

                # saving as a png image
                image_path = os.path.join(slices_dir, nii_name + '_' + str(i) + '.png')
                cv2.imwrite(image_path, image * 255)
                all_paths.append(image_path)
        else:
            print(f'Path \"{path}\" is not supported format')
    return all_paths
113
114
115
def window_image(image, window_center=-600, window_width=1500):
    """Clamp intensities to a window and return the result as a new array.

    The window runs from ``window_center - window_width // 2`` to
    ``window_center + window_width // 2``. Values outside it are replaced by
    the nearest bound; ``image`` itself is not modified.
    """
    half_width = window_width // 2
    lower = window_center - half_width
    upper = window_center + half_width

    clipped = image.copy()
    too_low = clipped < lower
    clipped[too_low] = lower
    too_high = clipped > upper
    clipped[too_high] = upper
    return clipped
122
123
124
def read_files(files):
    """Store uploaded files in a fresh folder under ``images/``.

    NIfTI uploads (``.nii`` / ``.nii.gz``) are written to disk temporarily,
    loaded, removed again, and saved slice by slice as
    ``<name>_<i>.png`` after windowing with ``window_image`` and rescaling.
    Any other upload is written unchanged.

    Args:
        files: iterable of objects exposing ``name`` and ``getvalue()``.

    Returns:
        tuple: ``(paths, folder_name)`` where ``paths`` holds one list of
        saved image paths per input file and ``folder_name`` is the generated
        folder name (with its trailing ``/``).
    """
    # creating folder for user
    folder_name = generate_folder_name()
    path = 'images/' + folder_name
    # makedirs also creates "images/" itself when it is missing
    os.makedirs(path, exist_ok=True)

    paths = []
    for file in files:
        # The name comes from the client: keep only its last component so it
        # cannot point outside the user's folder.
        safe_name = os.path.basename(file.name.replace('\\', '/'))
        target = path + safe_name
        saved = []

        # if NIfTI we should get slices
        if safe_name.endswith(('.nii', '.nii.gz')):
            # saving file from user; the with-block closes the handle before
            # the file is read back
            with open(target, 'wb') as f:
                f.write(file.getvalue())

            try:
                volume = np.array(nib.load(target).dataobj)
            finally:
                os.remove(target)  # clearing, also when loading fails
            volume = np.moveaxis(volume, -1, 0)  # iterate over the last axis

            stem = safe_name.split('.')[0]
            for i, image in enumerate(volume):  # saving every slice in NIfTI
                image = window_image(image).astype(np.float64)  # windowing
                # Min-max rescale to [0, 1]; the guard avoids 0/0 on a
                # constant slice.
                image -= image.min()
                peak = image.max()
                if peak > 0:
                    image /= peak

                # saving
                image_path = path + stem + f'_{i}.png'
                cv2.imwrite(image_path, image * 255)
                saved.append(image_path)
        else:
            with open(target, 'wb') as f:
                f.write(file.getvalue())
            saved.append(target)

        paths.append(saved)
    return paths, folder_name
164
165
166
def create_folder(path):
    """Make sure the directory ``path`` exists, creating missing parents.

    Calling it for an already existing directory is a no-op.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # One call replaces the exists()/mkdir() pair: no gap between the check
    # and the creation, and nested paths work as well.
    os.makedirs(path, exist_ok=True)
169
170
171
def get_predictions(paths, models, transforms, multi_class=True):
    """Run the models over every image path and yield the results one by one.

    Args:
        paths: image paths handed to ``ProductionCovid19Dataset``.
        models: ``(binary_model, multi_model, lung_model)``.
        transforms: sequence of transforms; only ``transforms[0]`` is used,
            for all three models.
        multi_class: when true, ``multi_model`` is evaluated as well and its
            class 2 is added on top of the binary prediction.

    Yields:
        tuple: ``(img, pred, lung)`` as NumPy arrays - the scaled input, the
        disease labels, and the lung labels (argmax over the first dimension
        of the squeezed lung output).
    """
    # preparing
    binary_model, multi_model, lung_model = models
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = ProductionCovid19Dataset(paths, transform=transforms[0])
    dataloader = DataLoader(dataset, batch_size=1, drop_last=False)

    # prediction
    for X, _ in dataloader:
        X = X.to(device)
        # Scale by the maximum; an all-zero image is left as is instead of
        # being turned into NaNs by 0/0.
        peak = torch.max(X)
        if peak > 0:
            X = X / peak

        with torch.no_grad():
            pred = binary_model(X)
            lung = lung_model(X)

            img = X.squeeze().cpu()
            pred = torch.argmax(pred.squeeze().cpu(), 0).float()
            lung = torch.argmax(lung.squeeze().cpu(), 0).float()

            # if multi class we should use both models to predict
            if multi_class:
                multi_pred = multi_model(X).squeeze().cpu()
                multi_pred = torch.argmax(multi_pred, 0).float()
                multi_pred = (multi_pred % 3)  # model trained on 3 classes but only 2 are used
                # ground-glass from binary model and consolidation from second
                # NOTE(review): a pixel the binary model labels 0 but the
                # multi model labels 2 ends up as 1 - confirm this is intended.
                pred = pred + (multi_pred == 2)

        # Yield outside the no_grad block so that gradient tracking is not
        # left disabled for the consumer while this generator is suspended.
        yield img.numpy(), pred.numpy(), lung.numpy()
201
202
203
def combo_with_lungs(disease, lungs):
    """Split ``disease`` by lung label.

    Returns a pair: ``disease`` multiplied by the mask ``lungs == 1`` and
    ``disease`` multiplied by the mask ``lungs == 2``.
    """
    label_one = lungs == 1
    label_two = lungs == 2
    return disease * label_one, disease * label_two
205
206
207
def _share_percent(mask, region):
    """Return the percentage of ``region`` pixels that are also set in ``mask``.

    An empty region gives 0.0 instead of dividing by zero.
    """
    total = np.sum(region)
    if total == 0:
        return 0.0
    return np.sum(mask * region) / total * 100


def make_masks(paths, models, transforms, multi_class=True):
    """Yield a colour overlay and a text annotation for every input path.

    Predictions come from ``get_predictions``. Pixels predicted as disease are
    written into separate colour channels, all other pixels keep the grey
    input value, and the annotation reports per-lung percentages (lung label
    1 under "left", label 2 under "right").

    Args:
        paths: image paths; paired in order with the predictions.
        models: forwarded to ``get_predictions``.
        transforms: forwarded to ``get_predictions``.
        multi_class: when true, ground-glass (label 1) and consolidation
            (label 2) are reported separately; otherwise a single "Disease"
            line is produced.

    Yields:
        tuple: ``(img, annotation, path)`` - the overlay scaled by 255 and
        rounded, the newline-separated annotation, and the source path.
    """
    predictions = get_predictions(paths, models, transforms, multi_class)
    for path, (img, pred, lung) in zip(paths, predictions):
        lung_left = (lung == 1)
        lung_right = (lung == 2)
        not_disease = (pred == 0)
        if multi_class:
            consolidation = (pred == 2)  # stored in channel index 2
            ground_glass = (pred == 1)  # stored in channel index 1

            # Channel-first stack; the grey image is added to every channel
            # wherever no disease was predicted.
            img = np.array([np.zeros_like(img), ground_glass, consolidation]) + img * not_disease

            annotation = f'              left   |   right\n' \
                         f' Ground-glass - {_share_percent(ground_glass, lung_left):.1f}% | {_share_percent(ground_glass, lung_right):.1f}%\n' \
                         f'Consolidation - {_share_percent(consolidation, lung_left):.1f}% | {_share_percent(consolidation, lung_right):.1f}%'
        else:
            # disease percents
            disease = (pred == 1)

            annotation = f'              left   |   right\n' \
                         f'Disease - {_share_percent(disease, lung_left):.1f}%  |  {_share_percent(disease, lung_right):.1f}%'

            # Disease goes into channel indices 1 and 2 at once.
            img = np.array([np.zeros_like(img), disease, disease]) + img * not_disease

        img = img.swapaxes(0, -1)  # channels to the last axis
        img = np.round(img * 255)
        img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        img = cv2.flip(img, 0)
        yield img, annotation, path
235
236
237
class ProductionCovid19Dataset(Dataset):
    """Dataset that reads images from disk as single-channel float tensors.

    Each item is ``(image, 'None')``: the image is read with OpenCV, converted
    to grayscale, optionally transformed, and wrapped in a leading dimension
    of size one. The string ``'None'`` stands in for a label.
    """

    def __init__(self, paths, transform=None):
        """Remember the image paths and the optional transform.

        Args:
            paths: sequence of image file paths.
            transform: optional callable invoked as ``transform(image=...)``
                and expected to return a mapping with an ``'image'`` entry.
        """
        self.paths = paths
        self.transform = transform
        self._len = len(paths)

    def __len__(self):
        """Return the number of paths given at construction."""
        return self._len

    def __getitem__(self, index):
        """Load, convert and return the image at ``index``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = self.paths[index]
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None; report the path
            # instead of letting cvtColor fail with an opaque error.
            raise FileNotFoundError(f'Cannot read image "{path}"')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        if self.transform:
            image = self.transform(image=image)['image']
        # np.float was removed from NumPy; building the array as float32
        # yields the same float32 tensor the old float64 -> FloatTensor cast did.
        image = torch.from_numpy(np.array([image], dtype=np.float32))
        return image, 'None'