a b/Segmentation/train/reshape.py
1
from Segmentation.plotting.voxels import plot_volume, plot_to_image
2
import tensorflow as tf
3
import numpy as np
4
5
# TODO: Replace all tf.slice operations with more Pythonic expressions
6
# TODO: colour maps should be consistent with the ones for 2D
7
8
# Class index -> [search, colour] pairs consumed by replace_vector().
#   search: the class index repeated over three channels (float32, shape (3,)),
#           i.e. what a pixel looks like after argmax + stacking to 3 channels.
#   colour: RGB colour on a 0-255 scale (float32, shape (1, 1, 1, 3)); the
#           functions in this module that build [0, 1] images divide it by 255.
# Class 0 has no entry, so class-0 pixels keep the value [0, 0, 0] and render
# black (presumably class 0 is the background -- confirm against the labels).
# Order matters: once divided by 255, class 4's colour becomes [1, 1, 1], which
# equals class 1's search vector, so class 1 must be replaced before class 4.
# The loops over this dict rely on its insertion order (1..6).
colour_maps = {
    1: [tf.constant([1, 1, 1], dtype=tf.float32), tf.constant([[[[255, 255, 0]]]], dtype=tf.float32)],  # yellow
    2: [tf.constant([2, 2, 2], dtype=tf.float32), tf.constant([[[[0, 255, 255]]]], dtype=tf.float32)],  # cyan
    3: [tf.constant([3, 3, 3], dtype=tf.float32), tf.constant([[[[255, 0, 255]]]], dtype=tf.float32)],  # magenta
    4: [tf.constant([4, 4, 4], dtype=tf.float32), tf.constant([[[[255, 255, 255]]]], dtype=tf.float32)],  # white
    5: [tf.constant([5, 5, 5], dtype=tf.float32), tf.constant([[[[120, 120, 120]]]], dtype=tf.float32)],  # grey
    6: [tf.constant([6, 6, 6], dtype=tf.float32), tf.constant([[[[255, 165, 0]]]], dtype=tf.float32)],  # orange
}
16
17
def replace_vector(img, search, replace):
    """Replace every pixel of ``img`` that equals ``search`` with ``replace``.

    A "pixel" is the vector along the last axis of ``img``. It is replaced only
    when all of its channels equal the corresponding entries of ``search``;
    every other pixel is returned unchanged.

    Args:
        img: Tensor of rank >= 1 whose last axis holds the channels.
        search: Values compared against each pixel; must broadcast against the
            last axis of ``img`` (e.g. shape (C,)) and share its dtype.
        replace: Values written into matching pixels. It is flattened to a
            length-C vector, so shapes such as (C,) or (1, 1, 1, C) both work.
            Must share ``img``'s dtype.

    Returns:
        A tensor with the same shape and dtype as ``img``.
    """
    img_shape = tf.shape(img)
    # True where every channel matches; keepdims so the mask broadcasts over channels.
    match = tf.reduce_all(tf.equal(img, search), axis=-1, keepdims=True)
    match = tf.broadcast_to(match, img_shape)
    # Flatten + broadcast rather than tile + reshape: the previous tiling only
    # worked when rank(replace) == rank(img) - 1 and needed a static shape.
    replacement = tf.broadcast_to(tf.reshape(replace, [-1]), img_shape)
    return tf.where(match, replacement, img)
24
25
def get_mid_slice(x, y, pred, multi_class):
    """Build one image showing the middle slice of input, label and prediction.

    The slice is taken from the first batch element at index ``depth // 2``
    along axis 1, where ``depth = tf.shape(y)[1]``. The three panels
    (``x``, ``y``, ``pred``) are concatenated along axis -2 with a 3-pixel strip
    of ones between neighbouring panels.

    Args:
        x: 5-D input tensor; presumably (batch, depth, height, width, 1) -- TODO
            confirm. In multi-class mode its last axis must have size 1.
        y: 5-D label tensor with the same leading dimensions as ``x``.
        pred: 5-D prediction tensor with the same shape as ``y``.
        multi_class: If True, ``y`` and ``pred`` are reduced with ``argmax`` over
            the last axis and coloured with ``colour_maps`` scaled to [0, 1], and
            ``x`` is repeated to three channels. Otherwise ``pred`` is rounded.

    Returns:
        A 4-D tensor of shape (1, height, 3 * width + 6, channels).
    """
    # Floor of half the depth (same value as the former divide + int cast).
    mid = tf.shape(y)[1] // 2
    start = [0, mid, 0, 0, 0]
    size = [1, 1, -1, -1, -1]
    x_slice = tf.slice(x, start, size)
    y_slice = tf.slice(y, start, size)
    pred_slice = tf.slice(pred, start, size)

    if multi_class:
        # Repeat the single input channel so it can sit next to the RGB panels.
        x_slice = tf.stack((tf.squeeze(x_slice, axis=-1),) * 3, axis=-1)
        y_slice = tf.stack((tf.cast(tf.argmax(y_slice, axis=-1), tf.float32),) * 3, axis=-1)
        pred_slice = tf.stack((tf.cast(tf.argmax(pred_slice, axis=-1), tf.float32),) * 3, axis=-1)
        # Iterate in insertion order: class 1 must be handled before class 4.
        for search, colour in colour_maps.values():
            # Scale to [0, 1] like get_mid_vol / plot_through_slices. Raw 0-255
            # values are clipped by float image consumers such as
            # tf.summary.image, which made classes 1/6 and 4/5 look identical,
            # and were inconsistent with the ones-valued separator.
            colour = tf.divide(colour, 255)
            y_slice = replace_vector(y_slice, search, colour)
            pred_slice = replace_vector(pred_slice, search, colour)
    else:
        pred_slice = tf.math.round(pred_slice)

    # 3-pixel-wide separator of ones (white on a [0, 1] scale).
    pad_shape = pred_slice.shape.as_list()
    pad_shape[3] = 3
    img_pad = tf.ones(pad_shape)
    img = tf.concat((x_slice, img_pad, y_slice, img_pad, pred_slice), axis=-2)

    # Drop the leading batch axis, which has size 1 after the slice.
    return tf.squeeze(img, axis=0)
51
52
def _subvol_to_rgb(subvol, multi_class, round_values=False):
    """Convert a (1, D, H, W, C) sub-volume into a (D, H, W, 3) volume for plot_volume.

    Multi-class input is reduced with ``argmax`` over the last axis and coloured
    with ``colour_maps`` scaled to [0, 1]. Otherwise the single channel (C must
    be 1) is optionally rounded and repeated to three channels.
    """
    if multi_class:
        labels = tf.cast(tf.argmax(subvol, axis=-1), tf.float32)
        rgb = tf.stack((labels,) * 3, axis=-1)
        # Iterate in insertion order: class 1 must be handled before class 4.
        for search, colour in colour_maps.values():
            rgb = replace_vector(rgb, search, tf.divide(colour, 255))
        return tf.squeeze(rgb, axis=0)
    if round_values:
        subvol = tf.math.round(subvol)
    subvol = tf.reshape(subvol, subvol.shape[1:4])
    return tf.stack((subvol,) * 3, axis=-1)


def get_mid_vol(y, pred, multi_class, rad=12, check_empty=False, min_voxels=25):
    """Render the central cube of a label volume and of the matching prediction.

    A cube with side ``2 * rad`` centred on the middle of axes 1-3 is cut from
    the first batch element of ``y`` and ``pred``. Each cube is converted to a
    three-channel volume, drawn with ``plot_volume`` and turned into an image
    with ``plot_to_image``.

    Args:
        y: 5-D label tensor whose axes 1-3 are spatial; each must be at least
            ``2 * rad`` long.
        pred: 5-D prediction tensor with the same shape as ``y``.
        multi_class: If True, cubes are reduced with ``argmax`` over the last
            axis and coloured with ``colour_maps``. Otherwise the prediction is
            rounded and the single channel is repeated to RGB.
        rad: Half the side length of the cube.
        check_empty: If True, return None when the label cube sums to less than
            ``min_voxels``.
        min_voxels: Threshold used by ``check_empty`` (default 25, the value
            previously hard-coded).

    Returns:
        The label image and the prediction image (as returned by
        ``plot_to_image``) concatenated along axis -2, or None when
        ``check_empty`` rejects the cube.
    """
    y_shape = tf.shape(y)
    slice_start = [0, (y_shape[1] // 2) - rad, (y_shape[2] // 2) - rad, (y_shape[3] // 2) - rad, 0]
    slice_size = [1, rad * 2, rad * 2, rad * 2, -1]
    y_subvol = tf.slice(y, slice_start, slice_size)

    # NOTE(review): if multi-class labels are one-hot, every voxel contributes
    # exactly 1 to this sum, so the check presumably never fires in that mode.
    if check_empty and tf.math.reduce_sum(y_subvol) < min_voxels:
        return None

    y_rgb = _subvol_to_rgb(y_subvol, multi_class)
    pred_rgb = _subvol_to_rgb(tf.slice(pred, slice_start, slice_size), multi_class, round_values=True)
    del y_subvol

    y_img = plot_to_image(plot_volume(y_rgb, show=False))
    del y_rgb
    pred_img = plot_to_image(plot_volume(pred_rgb, show=False))
    del pred_rgb

    return tf.concat((y_img, pred_img), axis=-2)
90
91
def plot_through_slices(batch_idx, x_crop, y_crop, mean_pred, writer, multi_class=False, num_slices=160):
    """Log the depth slices of one validation volume to TensorBoard.

    For each slice index ``i`` along axis 1 of batch element ``batch_idx``, the
    min-max-normalised input, the label and the prediction are placed side by
    side along the width axis. The result is written with ``tf.summary.image``
    at ``step=i`` and collected as a uint8 array.

    Args:
        batch_idx: Index of the batch element to visualise.
        x_crop: 5-D input tensor; axis 1 is the slice axis and the last axis
            must have size 1.
        y_crop: 5-D label tensor aligned with ``x_crop``.
        mean_pred: 5-D prediction tensor aligned with ``y_crop``.
        writer: Summary writer, used via ``writer.as_default()``.
        multi_class: If True, labels and predictions are reduced with ``argmax``
            and coloured with ``colour_maps`` (scaled to [0, 1]). Otherwise the
            prediction is rounded.
        num_slices: Maximum number of slices to plot (default 160, the value
            previously hard-coded). It is clamped to the depth of ``x_crop``.

    Returns:
        A list of uint8 numpy arrays on a 0-255 scale, one per slice, shaped
        (H, 3 * W, 3) when ``multi_class`` else (H, 3 * W).
    """
    depth = int(tf.shape(x_crop)[1])
    slice_size = [1, 1, -1, -1, -1]
    imgs = []
    for i in range(min(num_slices, depth)):
        slice_start = [batch_idx, i, 0, 0, 0]
        x_slice = tf.slice(x_crop, slice_start, slice_size)
        y_slice = tf.slice(y_crop, slice_start, slice_size)
        m_slice = tf.slice(mean_pred, slice_start, slice_size)

        # Min-max normalise the input to [0, 1] in both modes; divide_no_nan maps
        # a constant slice to 0 instead of NaN.
        x_min = tf.reduce_min(x_slice)
        x_slice = tf.math.divide_no_nan(x_slice - x_min, tf.reduce_max(x_slice) - x_min)

        if multi_class:
            x_slice = tf.stack((tf.squeeze(x_slice, axis=-1),) * 3, axis=-1)
            y_slice = tf.stack((tf.cast(tf.argmax(y_slice, axis=-1), tf.float32),) * 3, axis=-1)
            m_slice = tf.stack((tf.cast(tf.argmax(m_slice, axis=-1), tf.float32),) * 3, axis=-1)
            # Iterate in insertion order: class 1 must be handled before class 4.
            for search, colour in colour_maps.values():
                colour = tf.divide(colour, 255)
                y_slice = replace_vector(y_slice, search, colour)
                m_slice = replace_vector(m_slice, search, colour)
        else:
            m_slice = tf.math.round(m_slice)

        img = tf.concat((x_slice, y_slice, m_slice), axis=-2)
        img = tf.squeeze(img, axis=0)  # (1, H, 3W, C): a batch of one image
        with writer.as_default():
            tf.summary.image("Whole Validation - All Slices", img, step=i)
        img = tf.squeeze(img, axis=0)  # (H, 3W, C)
        if not multi_class:
            img = tf.squeeze(img, axis=-1)  # (H, 3W)
        # Scale [0, 1] -> [0, 255] in BOTH modes. Previously the multi-class
        # frames were cast to uint8 unscaled and came out almost entirely black.
        frame = np.clip(np.round(img.numpy() * 255), 0, 255)
        imgs.append(frame.astype(np.uint8))
    return imgs