--- a
+++ b/src/dataset/utils/visualization.py
@@ -0,0 +1,185 @@
+import os
+from matplotlib import pyplot as plt
+import matplotlib
+import time
+import numpy as np
+from matplotlib.colors import LinearSegmentedColormap
+from nilearn.plotting import plot_anat
+from matplotlib import cm
+from skimage.transform import resize
+import io
+
+
+def plot_batch(batch, seg: bool = False, slice: int = 32, batch_size: int=4):
+
+    def unnorm(data, epsilon=1e-8):
+        non_zero = data[data > 0.0]
+        mean = non_zero.mean()
+        std = non_zero.std() + epsilon
+        out = data * std + mean
+        out[data == 0] = 0
+        return out
+
+    plt.figure(figsize=(10, 3.5))
+
+    for i, volume in enumerate(batch):
+        plt.subplot(1, batch_size + 1, i + 1)
+
+        img = volume[:, slice, :].T if seg else volume[0, :, slice, :].T
+
+        npimg = img.cpu().detach().numpy()
+        img = npimg if seg else unnorm(npimg)
+        plt.imshow(img, cmap="gray")
+        plt.axis("off")
+
+
+    buf = io.BytesIO()
+    plt.savefig(buf, format='png')
+    buf.seek(0)
+    plt.close()
+    return buf
+
+
+
+def plot_3_view(modal: str, vol: np.ndarray, s: int=100, discrete: bool=False,
+                color_map: str="gray", save: bool=True):
+
+    views = [vol[s, :, :], vol[:, s, :], vol[:, :, s]]
+    fig = plt.figure(figsize=(10, 3.5))
+
+    for position in range(1, len(views) + 1):
+        plt.subplot(1, len(views), position)
+        plt.imshow(views[position - 1].T, cmap=color_map)
+        plt.axis("off")
+        if discrete:
+            plt.clim(0, 4)
+
+    if discrete:
+        plt.colorbar(ticks=range(5))
+    else:
+        plt.colorbar()
+
+    if save:
+        fig.savefig(f'plot_{modal}_{time.time()}.png')
+    else:
+        plt.show()
+
+
+def plot_3_view_uncertainty(modal: str, vol: np.ndarray, s: int=100, color_map: str="gray", save: bool=True):
+
+    views = [vol[s, :, :], vol[:, s, :], vol[:, :, s]]
+    fig = plt.figure(figsize=(10, 3.5))
+
+    for position in range(1, len(views) + 1):
+        plt.subplot(1, len(views), position)
+        plt.imshow(views[position - 1].T, cmap=color_map)
+        plt.axis("off")
+        plt.clim(0, 100)
+
+    plt.colorbar()
+
+    if save:
+        fig.savefig(f'plot_unc_{modal}_{time.time()}.png')
+    else:
+        plt.show()
+
+
+def plot_axis_overlayed(modalities: dict, segmentation_mask: str, subject: int, axis: str = 'x', save: bool=False):
+    """Save or show figure of provided axis"""
+    fig, axes = plt.subplots(len(modalities), 1)
+
+    for i, (modality_name, modality_path) in enumerate(modalities.items()):
+        display = plot_anat(modality_path, draw_cross=False, display_mode=axis, axes=axes[i], figure=fig, title=modality_name)
+        display.add_overlay(segmentation_mask)
+
+    if save:
+        fig.savefig(f'results/patient_{subject}.png')
+    else:
+        matplotlib.use('TkAgg')
+        plt.show()
+
+
+def plot_brain_batch_per_patient(patient_ids, data, save=True):
+    for patient in patient_ids:
+        patient = data[patient.item()]
+        patient_modalities = list(map(lambda x: os.path.join(patient.data_path, patient.patch_name, x), [patient.flair, patient.t2, patient.t1, patient.t1ce]))
+        patient_modalities = {"flair": patient_modalities[0],"t2": patient_modalities[1],"t1": patient_modalities[2],"t1ce": patient_modalities[3] }
+        patient_seg = os.path.join(patient.data_path, patient.patch_name, patient.seg)
+        plot_axis_overlayed(patient_modalities, patient_seg, patient.patch_name, axis='x', save=save)
+
+
+def plot_batch_slice(images, gt, slice = 10, save=True):
+    """Plot, for a given batch, different types of visualizations.
+    If paths: plot overlayed axis plot
+    If paths=None: plot slice of volume
+    """
+    for element_index in range(0, len(images)):
+        for i, mod_id in enumerate(images):
+            patient_mod = images[i][element_index]
+
+            plot_3_view(f"batch_element_{i}", patient_mod, slice, save=save)
+
+        patient_seg = gt[element_index]
+        plot_3_view('seg', patient_seg, slice, save=save)
+
+
+def plot_batch_cubes(patient_ids, batch_volumes, batch_gt, patches=1, img_size=30):
+    for batch_pos, patient in enumerate(patient_ids[:patches]):
+        patient = patient.item()
+        modality = batch_volumes[batch_pos][0]
+        gt =  batch_gt[batch_pos][0]
+        resized_modality = resize(modality,(img_size, img_size, img_size), mode='constant')
+        resized_gt = resize(gt, (img_size, img_size, img_size), mode='constant')
+
+        fig = plot_cube(resized_modality, img_size)
+        fig.savefig(f'results/3Dplot_{patient}_{batch_pos}.png')
+
+        fig_seg = plot_cube(resized_gt, img_size)
+        fig_seg.savefig(f'results/3Dplot_{patient}_gt_{batch_pos}.png')
+
+
+def plot_cube(cube, dim, gt=False, angle=320):
+    def normalize(arr):
+        arr_min = np.min(arr)
+        return (arr - arr_min) / (np.max(arr) - arr_min)
+
+    def explode(data):
+        shape_arr = np.array(data.shape)
+        size = shape_arr[:3] * 2 - 1
+        exploded = np.zeros(np.concatenate([size, shape_arr[3:]]), dtype=data.dtype)
+        exploded[::2, ::2, ::2] = data
+        return exploded
+
+    def expand_coordinates(indices):
+        x, y, z = indices
+        x[1::2, :, :] += 1
+        y[:, 1::2, :] += 1
+        z[:, :, 1::2] += 1
+        return x, y, z
+
+    if gt:
+        colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]
+        my_cmap = LinearSegmentedColormap.from_list('my_cmap', colors, 4)
+        plt.register_cmap(cmap=my_cmap)
+        cmap = plt.get_cmap('my_cmap')
+        cube = np.around(cube)
+    else:
+        cmap = cm.viridis
+
+    cube = normalize(cube)
+    facecolors = cmap(cube)
+    facecolors[:, :, :, -1] = cube
+    facecolors = explode(facecolors)
+
+    filled = facecolors[:, :, :, -1] != 0
+    x, y, z = expand_coordinates(np.indices(np.array(filled.shape) + 1))
+
+    fig = plt.figure(figsize=(30 / 2.54, 30 / 2.54))
+    ax = fig.gca(projection='3d')
+    ax.view_init(30, angle)
+    ax.set_xlim(right=dim * 2)
+    ax.set_ylim(top=dim * 2)
+    ax.set_zlim(top=dim * 2)
+
+    ax.voxels(x, y, z, filled, facecolors=facecolors, shade=False)
+    return fig
\ No newline at end of file