Switch to unified view

a b/src/dataset/utils/visualization.py
1
import os
2
from matplotlib import pyplot as plt
3
import matplotlib
4
import time
5
import numpy as np
6
from matplotlib.colors import LinearSegmentedColormap
7
from nilearn.plotting import plot_anat
8
from matplotlib import cm
9
from skimage.transform import resize
10
import io
11
12
13
def plot_batch(batch, seg: bool = False, slice: int = 32, batch_size: int=4):
14
15
    def unnorm(data, epsilon=1e-8):
16
        non_zero = data[data > 0.0]
17
        mean = non_zero.mean()
18
        std = non_zero.std() + epsilon
19
        out = data * std + mean
20
        out[data == 0] = 0
21
        return out
22
23
    plt.figure(figsize=(10, 3.5))
24
25
    for i, volume in enumerate(batch):
26
        plt.subplot(1, batch_size + 1, i + 1)
27
28
        img = volume[:, slice, :].T if seg else volume[0, :, slice, :].T
29
30
        npimg = img.cpu().detach().numpy()
31
        img = npimg if seg else unnorm(npimg)
32
        plt.imshow(img, cmap="gray")
33
        plt.axis("off")
34
35
36
    buf = io.BytesIO()
37
    plt.savefig(buf, format='png')
38
    buf.seek(0)
39
    plt.close()
40
    return buf
41
42
43
44
def plot_3_view(modal: str, vol: np.ndarray, s: int=100, discrete: bool=False,
45
                color_map: str="gray", save: bool=True):
46
47
    views = [vol[s, :, :], vol[:, s, :], vol[:, :, s]]
48
    fig = plt.figure(figsize=(10, 3.5))
49
50
    for position in range(1, len(views) + 1):
51
        plt.subplot(1, len(views), position)
52
        plt.imshow(views[position - 1].T, cmap=color_map)
53
        plt.axis("off")
54
        if discrete:
55
            plt.clim(0, 4)
56
57
    if discrete:
58
        plt.colorbar(ticks=range(5))
59
    else:
60
        plt.colorbar()
61
62
    if save:
63
        fig.savefig(f'plot_{modal}_{time.time()}.png')
64
    else:
65
        plt.show()
66
67
68
def plot_3_view_uncertainty(modal: str, vol: np.ndarray, s: int=100, color_map: str="gray", save: bool=True):
69
70
    views = [vol[s, :, :], vol[:, s, :], vol[:, :, s]]
71
    fig = plt.figure(figsize=(10, 3.5))
72
73
    for position in range(1, len(views) + 1):
74
        plt.subplot(1, len(views), position)
75
        plt.imshow(views[position - 1].T, cmap=color_map)
76
        plt.axis("off")
77
        plt.clim(0, 100)
78
79
    plt.colorbar()
80
81
    if save:
82
        fig.savefig(f'plot_unc_{modal}_{time.time()}.png')
83
    else:
84
        plt.show()
85
86
87
def plot_axis_overlayed(modalities: dict, segmentation_mask: str, subject: int, axis: str = 'x', save: bool=False):
88
    """Save or show figure of provided axis"""
89
    fig, axes = plt.subplots(len(modalities), 1)
90
91
    for i, (modality_name, modality_path) in enumerate(modalities.items()):
92
        display = plot_anat(modality_path, draw_cross=False, display_mode=axis, axes=axes[i], figure=fig, title=modality_name)
93
        display.add_overlay(segmentation_mask)
94
95
    if save:
96
        fig.savefig(f'results/patient_{subject}.png')
97
    else:
98
        matplotlib.use('TkAgg')
99
        plt.show()
100
101
102
def plot_brain_batch_per_patient(patient_ids, data, save=True):
103
    for patient in patient_ids:
104
        patient = data[patient.item()]
105
        patient_modalities = list(map(lambda x: os.path.join(patient.data_path, patient.patch_name, x), [patient.flair, patient.t2, patient.t1, patient.t1ce]))
106
        patient_modalities = {"flair": patient_modalities[0],"t2": patient_modalities[1],"t1": patient_modalities[2],"t1ce": patient_modalities[3] }
107
        patient_seg = os.path.join(patient.data_path, patient.patch_name, patient.seg)
108
        plot_axis_overlayed(patient_modalities, patient_seg, patient.patch_name, axis='x', save=save)
109
110
111
def plot_batch_slice(images, gt, slice = 10, save=True):
112
    """Plot, for a given batch, different types of visualizations.
113
    If paths: plot overlayed axis plot
114
    If paths=None: plot slice of volume
115
    """
116
    for element_index in range(0, len(images)):
117
        for i, mod_id in enumerate(images):
118
            patient_mod = images[i][element_index]
119
120
            plot_3_view(f"batch_element_{i}", patient_mod, slice, save=save)
121
122
        patient_seg = gt[element_index]
123
        plot_3_view('seg', patient_seg, slice, save=save)
124
125
126
def plot_batch_cubes(patient_ids, batch_volumes, batch_gt, patches=1, img_size=30):
127
    for batch_pos, patient in enumerate(patient_ids[:patches]):
128
        patient = patient.item()
129
        modality = batch_volumes[batch_pos][0]
130
        gt =  batch_gt[batch_pos][0]
131
        resized_modality = resize(modality,(img_size, img_size, img_size), mode='constant')
132
        resized_gt = resize(gt, (img_size, img_size, img_size), mode='constant')
133
134
        fig = plot_cube(resized_modality, img_size)
135
        fig.savefig(f'results/3Dplot_{patient}_{batch_pos}.png')
136
137
        fig_seg = plot_cube(resized_gt, img_size)
138
        fig_seg.savefig(f'results/3Dplot_{patient}_gt_{batch_pos}.png')
139
140
141
def plot_cube(cube, dim, gt=False, angle=320):
142
    def normalize(arr):
143
        arr_min = np.min(arr)
144
        return (arr - arr_min) / (np.max(arr) - arr_min)
145
146
    def explode(data):
147
        shape_arr = np.array(data.shape)
148
        size = shape_arr[:3] * 2 - 1
149
        exploded = np.zeros(np.concatenate([size, shape_arr[3:]]), dtype=data.dtype)
150
        exploded[::2, ::2, ::2] = data
151
        return exploded
152
153
    def expand_coordinates(indices):
154
        x, y, z = indices
155
        x[1::2, :, :] += 1
156
        y[:, 1::2, :] += 1
157
        z[:, :, 1::2] += 1
158
        return x, y, z
159
160
    if gt:
161
        colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]
162
        my_cmap = LinearSegmentedColormap.from_list('my_cmap', colors, 4)
163
        plt.register_cmap(cmap=my_cmap)
164
        cmap = plt.get_cmap('my_cmap')
165
        cube = np.around(cube)
166
    else:
167
        cmap = cm.viridis
168
169
    cube = normalize(cube)
170
    facecolors = cmap(cube)
171
    facecolors[:, :, :, -1] = cube
172
    facecolors = explode(facecolors)
173
174
    filled = facecolors[:, :, :, -1] != 0
175
    x, y, z = expand_coordinates(np.indices(np.array(filled.shape) + 1))
176
177
    fig = plt.figure(figsize=(30 / 2.54, 30 / 2.54))
178
    ax = fig.gca(projection='3d')
179
    ax.view_init(30, angle)
180
    ax.set_xlim(right=dim * 2)
181
    ax.set_ylim(top=dim * 2)
182
    ax.set_zlim(top=dim * 2)
183
184
    ax.voxels(x, y, z, filled, facecolors=facecolors, shade=False)
185
    return fig