# utils.py
import os
import io
import base64

import numpy as np
import pandas as pd
import cv2

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from albumentations import Normalize

import time
from IPython.display import clear_output
from IPython.display import HTML

from loss_metric import dice_coef_metric_per_classes, jaccard_coef_metric_per_classes

def get_one_slice_data(img_name: str,
                       mask_name: str,
                       root_imgs_path: str = "images/",
                       root_masks_path: str = "masks/") -> tuple:
    """
    Read one CT slice and its mask, and binarize the mask.
    Params:
        img_name: image file name.
        mask_name: mask file name.
        root_imgs_path: root dir with images.
        root_masks_path: root dir with masks.
    """
    img_path = os.path.join(root_imgs_path, img_name)
    mask_path = os.path.join(root_masks_path, mask_name)
    one_slice_img = cv2.imread(img_path)  # [:, :, 0] uncomment for grayscale
    one_slice_mask = cv2.imread(mask_path)
    one_slice_mask[one_slice_mask < 240] = 0    # remove artifacts
    one_slice_mask[one_slice_mask >= 240] = 255

    return one_slice_img, one_slice_mask
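
# Minimal usage sketch (the slice file name below is hypothetical; it assumes the
# default "images/" and "masks/" directories contain matching files):
# img, mask = get_one_slice_data("slice_0001.png", "slice_0001.png")
# print(img.shape, mask.shape)  # e.g. (512, 512, 3) (512, 512, 3)
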
def get_id_predictions(net: nn.Module,
                       ct_scan_id_df: pd.DataFrame,
                       root_imgs_dir: str,
                       threshold: float = 0.3) -> tuple:
    """
    Get predictions for every slice of one patient and store them, together with
    the original images, in lists as uint8 images.
    Params:
        net: model for prediction.
        ct_scan_id_df: df with the slices of one unique patient id.
        root_imgs_dir: root path for images.
        threshold: threshold for probabilities.
    """
    sigmoid = lambda x: 1 / (1 + np.exp(-x))
    images = []
    predictions = []
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("device:", device)
    net = net.to(device)
    net.eval()
    with torch.no_grad():
        for idx in range(len(ct_scan_id_df)):
            img_name = ct_scan_id_df.loc[idx, "ImageId"]
            path = os.path.join(root_imgs_dir, img_name)

            img_ = cv2.imread(path)

            img = Normalize().apply(img_)
            tensor = torch.FloatTensor(img).permute(2, 0, 1).unsqueeze(0)
            prediction = net(tensor.to(device))
            prediction = prediction.cpu().detach().numpy()
            prediction = prediction.squeeze(0).transpose(1, 2, 0)
            prediction = sigmoid(prediction)
            prediction = (prediction >= threshold).astype(np.float32)

            predictions.append((prediction * 255).astype("uint8"))
            images.append(img_)

    return images, predictions
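
# Minimal usage sketch (assumes `model` is a trained segmentation net and `patient_df`
# holds the slices of a single patient with an "ImageId" column; both names are hypothetical):
# images, predictions = get_id_predictions(model, patient_df, root_imgs_dir="images/", threshold=0.3)
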
# Save image in original resolution.
# Helpful link: https://stackoverflow.com/questions/34768717/matplotlib-unable-to-save-image-in-same-resolution-as-original-image

def get_overlaid_masks_on_image(
        one_slice_image: np.ndarray,
        one_slice_mask: np.ndarray,
        w: int = 512,
        h: int = 512,
        dpi: int = 100,
        write: bool = False,
        path_to_save: str = '/content/',
        name_to_save: str = 'img_name'):
    """Overlay the masks on an image and save the result as a new image."""

    path_to_save_ = os.path.join(path_to_save, name_to_save)
    lung, heart, trachea = [one_slice_mask[:, :, i] for i in range(3)]
    figsize = (w / dpi), (h / dpi)
    fig = plt.figure(figsize=figsize)
    fig.add_axes([0, 0, 1, 1])

    # image
    plt.imshow(one_slice_image, cmap="bone")

    # overlaid segmentation masks (zero pixels are masked out, so only the
    # segmented regions are drawn on top of the image)
    plt.imshow(np.ma.masked_where(lung == 0, lung),
               cmap='cool', alpha=0.3)
    plt.imshow(np.ma.masked_where(heart == 0, heart),
               cmap='autumn', alpha=0.3)
    plt.imshow(np.ma.masked_where(trachea == 0, trachea),
               cmap='autumn_r', alpha=0.3)

    plt.axis('off')
    fig.savefig(f"{path_to_save_}.png", bbox_inches='tight',
                pad_inches=0.0, dpi=dpi, format="png")
    if write:
        plt.close()
    else:
        plt.show()
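
# Minimal usage sketch (reuses get_one_slice_data from above; the file names are hypothetical):
# img, mask = get_one_slice_data("slice_0001.png", "slice_0001.png")
# get_overlaid_masks_on_image(img, mask, write=True, path_to_save="overlaid/", name_to_save="0")
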
def get_overlaid_masks_on_full_ctscan(ct_scan_id_df: pd.DataFrame, path_to_save: str):
    """
    Create an image with overlaid masks for each slice of a CT scan.
    Params:
        ct_scan_id_df: df with the slices of one unique patient id.
        path_to_save: path to save images to.
    """
    num_slice = len(ct_scan_id_df)
    for slice_ in range(num_slice):
        img_name = ct_scan_id_df.loc[slice_, "ImageId"]
        mask_name = ct_scan_id_df.loc[slice_, "MaskId"]
        one_slice_img, one_slice_mask = get_one_slice_data(img_name, mask_name)
        get_overlaid_masks_on_image(one_slice_img,
                                    one_slice_mask,
                                    write=True,
                                    path_to_save=path_to_save,
                                    name_to_save=str(slice_)
                                    )
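
# Minimal usage sketch (assumes `patient_df` has "ImageId" and "MaskId" columns and
# that the output directory already exists; both names are hypothetical):
# get_overlaid_masks_on_full_ctscan(patient_df, path_to_save="overlaid/")
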
def create_video(path_to_imgs: str, video_name: str, framerate: int):
    """
    Create a video from images.
    Params:
        path_to_imgs: path to dir with images.
        video_name: name for saving the video.
        framerate: number of frames per second in the video.
    """
    # file names must be numeric, e.g. "0.png", "1.png", ...
    img_names = sorted(os.listdir(path_to_imgs), key=lambda x: int(x[:-4]))
    img_path = os.path.join(path_to_imgs, img_names[0])
    # cv2.imread returns (height, width, channels); VideoWriter expects (width, height)
    frame_height, frame_width, _ = cv2.imread(img_path).shape
    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    video = cv2.VideoWriter(video_name + ".mp4",
                            fourcc,
                            framerate,
                            (frame_width, frame_height))

    for img_name in img_names:
        img_path = os.path.join(path_to_imgs, img_name)
        image = cv2.imread(img_path)
        video.write(image)

    cv2.destroyAllWindows()
    video.release()
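
# Minimal usage sketch (assumes "overlaid/" contains numerically named frames, e.g. as
# written by get_overlaid_masks_on_full_ctscan; the paths are hypothetical):
# create_video("overlaid/", video_name="ct_scan_overlay", framerate=10)
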
def compute_scores_per_classes(model: nn.Module,
                               dataloader,
                               classes: list):
    """
    Compute Dice and Jaccard coefficients for each class.
    Params:
        model: neural net used to make predictions.
        dataloader: dataloader to load data from.
        classes: list with class names.
    Returns: dictionaries with Dice and Jaccard coefficients for each class for each slice.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    model.eval()
    dice_scores_per_classes = {key: list() for key in classes}
    iou_scores_per_classes = {key: list() for key in classes}

    with torch.no_grad():
        for imgs, targets in dataloader:
            imgs, targets = imgs.to(device), targets.to(device)
            logits = model(imgs)
            logits = logits.detach().cpu().numpy()
            targets = targets.detach().cpu().numpy()

            dice_scores = dice_coef_metric_per_classes(logits, targets)
            iou_scores = jaccard_coef_metric_per_classes(logits, targets)

            for key in dice_scores.keys():
                dice_scores_per_classes[key].extend(dice_scores[key])

            for key in iou_scores.keys():
                iou_scores_per_classes[key].extend(iou_scores[key])

    return dice_scores_per_classes, iou_scores_per_classes
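
# Minimal usage sketch (assumes `model` is a trained net and `val_dataloader` yields
# (image, mask) batches; the class names below are illustrative assumptions):
# dice_scores, iou_scores = compute_scores_per_classes(
#     model, val_dataloader, classes=["lung", "heart", "trachea"])
# print(np.mean(dice_scores["lung"]))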