|
a |
|
b/src/dataset/utils/visualization.py |
|
|
1 |
import os |
|
|
2 |
from matplotlib import pyplot as plt |
|
|
3 |
import matplotlib |
|
|
4 |
import time |
|
|
5 |
import numpy as np |
|
|
6 |
from matplotlib.colors import LinearSegmentedColormap |
|
|
7 |
from nilearn.plotting import plot_anat |
|
|
8 |
from matplotlib import cm |
|
|
9 |
from skimage.transform import resize |
|
|
10 |
import io |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def plot_batch(batch, seg: bool = False, slice: int = 32, batch_size: int=4): |
|
|
14 |
|
|
|
15 |
def unnorm(data, epsilon=1e-8): |
|
|
16 |
non_zero = data[data > 0.0] |
|
|
17 |
mean = non_zero.mean() |
|
|
18 |
std = non_zero.std() + epsilon |
|
|
19 |
out = data * std + mean |
|
|
20 |
out[data == 0] = 0 |
|
|
21 |
return out |
|
|
22 |
|
|
|
23 |
plt.figure(figsize=(10, 3.5)) |
|
|
24 |
|
|
|
25 |
for i, volume in enumerate(batch): |
|
|
26 |
plt.subplot(1, batch_size + 1, i + 1) |
|
|
27 |
|
|
|
28 |
img = volume[:, slice, :].T if seg else volume[0, :, slice, :].T |
|
|
29 |
|
|
|
30 |
npimg = img.cpu().detach().numpy() |
|
|
31 |
img = npimg if seg else unnorm(npimg) |
|
|
32 |
plt.imshow(img, cmap="gray") |
|
|
33 |
plt.axis("off") |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
buf = io.BytesIO() |
|
|
37 |
plt.savefig(buf, format='png') |
|
|
38 |
buf.seek(0) |
|
|
39 |
plt.close() |
|
|
40 |
return buf |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
|
|
|
44 |
def plot_3_view(modal: str, vol: np.ndarray, s: int=100, discrete: bool=False, |
|
|
45 |
color_map: str="gray", save: bool=True): |
|
|
46 |
|
|
|
47 |
views = [vol[s, :, :], vol[:, s, :], vol[:, :, s]] |
|
|
48 |
fig = plt.figure(figsize=(10, 3.5)) |
|
|
49 |
|
|
|
50 |
for position in range(1, len(views) + 1): |
|
|
51 |
plt.subplot(1, len(views), position) |
|
|
52 |
plt.imshow(views[position - 1].T, cmap=color_map) |
|
|
53 |
plt.axis("off") |
|
|
54 |
if discrete: |
|
|
55 |
plt.clim(0, 4) |
|
|
56 |
|
|
|
57 |
if discrete: |
|
|
58 |
plt.colorbar(ticks=range(5)) |
|
|
59 |
else: |
|
|
60 |
plt.colorbar() |
|
|
61 |
|
|
|
62 |
if save: |
|
|
63 |
fig.savefig(f'plot_{modal}_{time.time()}.png') |
|
|
64 |
else: |
|
|
65 |
plt.show() |
|
|
66 |
|
|
|
67 |
|
|
|
68 |
def plot_3_view_uncertainty(modal: str, vol: np.ndarray, s: int=100, color_map: str="gray", save: bool=True): |
|
|
69 |
|
|
|
70 |
views = [vol[s, :, :], vol[:, s, :], vol[:, :, s]] |
|
|
71 |
fig = plt.figure(figsize=(10, 3.5)) |
|
|
72 |
|
|
|
73 |
for position in range(1, len(views) + 1): |
|
|
74 |
plt.subplot(1, len(views), position) |
|
|
75 |
plt.imshow(views[position - 1].T, cmap=color_map) |
|
|
76 |
plt.axis("off") |
|
|
77 |
plt.clim(0, 100) |
|
|
78 |
|
|
|
79 |
plt.colorbar() |
|
|
80 |
|
|
|
81 |
if save: |
|
|
82 |
fig.savefig(f'plot_unc_{modal}_{time.time()}.png') |
|
|
83 |
else: |
|
|
84 |
plt.show() |
|
|
85 |
|
|
|
86 |
|
|
|
87 |
def plot_axis_overlayed(modalities: dict, segmentation_mask: str, subject: int, axis: str = 'x', save: bool=False): |
|
|
88 |
"""Save or show figure of provided axis""" |
|
|
89 |
fig, axes = plt.subplots(len(modalities), 1) |
|
|
90 |
|
|
|
91 |
for i, (modality_name, modality_path) in enumerate(modalities.items()): |
|
|
92 |
display = plot_anat(modality_path, draw_cross=False, display_mode=axis, axes=axes[i], figure=fig, title=modality_name) |
|
|
93 |
display.add_overlay(segmentation_mask) |
|
|
94 |
|
|
|
95 |
if save: |
|
|
96 |
fig.savefig(f'results/patient_{subject}.png') |
|
|
97 |
else: |
|
|
98 |
matplotlib.use('TkAgg') |
|
|
99 |
plt.show() |
|
|
100 |
|
|
|
101 |
|
|
|
102 |
def plot_brain_batch_per_patient(patient_ids, data, save=True): |
|
|
103 |
for patient in patient_ids: |
|
|
104 |
patient = data[patient.item()] |
|
|
105 |
patient_modalities = list(map(lambda x: os.path.join(patient.data_path, patient.patch_name, x), [patient.flair, patient.t2, patient.t1, patient.t1ce])) |
|
|
106 |
patient_modalities = {"flair": patient_modalities[0],"t2": patient_modalities[1],"t1": patient_modalities[2],"t1ce": patient_modalities[3] } |
|
|
107 |
patient_seg = os.path.join(patient.data_path, patient.patch_name, patient.seg) |
|
|
108 |
plot_axis_overlayed(patient_modalities, patient_seg, patient.patch_name, axis='x', save=save) |
|
|
109 |
|
|
|
110 |
|
|
|
111 |
def plot_batch_slice(images, gt, slice = 10, save=True): |
|
|
112 |
"""Plot, for a given batch, different types of visualizations. |
|
|
113 |
If paths: plot overlayed axis plot |
|
|
114 |
If paths=None: plot slice of volume |
|
|
115 |
""" |
|
|
116 |
for element_index in range(0, len(images)): |
|
|
117 |
for i, mod_id in enumerate(images): |
|
|
118 |
patient_mod = images[i][element_index] |
|
|
119 |
|
|
|
120 |
plot_3_view(f"batch_element_{i}", patient_mod, slice, save=save) |
|
|
121 |
|
|
|
122 |
patient_seg = gt[element_index] |
|
|
123 |
plot_3_view('seg', patient_seg, slice, save=save) |
|
|
124 |
|
|
|
125 |
|
|
|
126 |
def plot_batch_cubes(patient_ids, batch_volumes, batch_gt, patches=1, img_size=30): |
|
|
127 |
for batch_pos, patient in enumerate(patient_ids[:patches]): |
|
|
128 |
patient = patient.item() |
|
|
129 |
modality = batch_volumes[batch_pos][0] |
|
|
130 |
gt = batch_gt[batch_pos][0] |
|
|
131 |
resized_modality = resize(modality,(img_size, img_size, img_size), mode='constant') |
|
|
132 |
resized_gt = resize(gt, (img_size, img_size, img_size), mode='constant') |
|
|
133 |
|
|
|
134 |
fig = plot_cube(resized_modality, img_size) |
|
|
135 |
fig.savefig(f'results/3Dplot_{patient}_{batch_pos}.png') |
|
|
136 |
|
|
|
137 |
fig_seg = plot_cube(resized_gt, img_size) |
|
|
138 |
fig_seg.savefig(f'results/3Dplot_{patient}_gt_{batch_pos}.png') |
|
|
139 |
|
|
|
140 |
|
|
|
141 |
def plot_cube(cube, dim, gt=False, angle=320): |
|
|
142 |
def normalize(arr): |
|
|
143 |
arr_min = np.min(arr) |
|
|
144 |
return (arr - arr_min) / (np.max(arr) - arr_min) |
|
|
145 |
|
|
|
146 |
def explode(data): |
|
|
147 |
shape_arr = np.array(data.shape) |
|
|
148 |
size = shape_arr[:3] * 2 - 1 |
|
|
149 |
exploded = np.zeros(np.concatenate([size, shape_arr[3:]]), dtype=data.dtype) |
|
|
150 |
exploded[::2, ::2, ::2] = data |
|
|
151 |
return exploded |
|
|
152 |
|
|
|
153 |
def expand_coordinates(indices): |
|
|
154 |
x, y, z = indices |
|
|
155 |
x[1::2, :, :] += 1 |
|
|
156 |
y[:, 1::2, :] += 1 |
|
|
157 |
z[:, :, 1::2] += 1 |
|
|
158 |
return x, y, z |
|
|
159 |
|
|
|
160 |
if gt: |
|
|
161 |
colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)] |
|
|
162 |
my_cmap = LinearSegmentedColormap.from_list('my_cmap', colors, 4) |
|
|
163 |
plt.register_cmap(cmap=my_cmap) |
|
|
164 |
cmap = plt.get_cmap('my_cmap') |
|
|
165 |
cube = np.around(cube) |
|
|
166 |
else: |
|
|
167 |
cmap = cm.viridis |
|
|
168 |
|
|
|
169 |
cube = normalize(cube) |
|
|
170 |
facecolors = cmap(cube) |
|
|
171 |
facecolors[:, :, :, -1] = cube |
|
|
172 |
facecolors = explode(facecolors) |
|
|
173 |
|
|
|
174 |
filled = facecolors[:, :, :, -1] != 0 |
|
|
175 |
x, y, z = expand_coordinates(np.indices(np.array(filled.shape) + 1)) |
|
|
176 |
|
|
|
177 |
fig = plt.figure(figsize=(30 / 2.54, 30 / 2.54)) |
|
|
178 |
ax = fig.gca(projection='3d') |
|
|
179 |
ax.view_init(30, angle) |
|
|
180 |
ax.set_xlim(right=dim * 2) |
|
|
181 |
ax.set_ylim(top=dim * 2) |
|
|
182 |
ax.set_zlim(top=dim * 2) |
|
|
183 |
|
|
|
184 |
ax.voxels(x, y, z, filled, facecolors=facecolors, shade=False) |
|
|
185 |
return fig |