|
a |
|
b/Segmentation/train/reshape.py |
|
|
1 |
from Segmentation.plotting.voxels import plot_volume, plot_to_image |
|
|
2 |
import tensorflow as tf |
|
|
3 |
import numpy as np |
|
|
4 |
|
|
|
5 |
# TODO: Replace all tf.slice operations with more Pythonic expressions |
|
|
6 |
# TODO: colour maps should be consistent with the ones for 2D |
|
|
7 |
|
|
|
8 |
# RGB colours (0-255 scale) for each predicted class index.
# Class 0 has no entry: pixels whose argmax is 0 are never replaced and keep
# the value (0, 0, 0), i.e. they render black.
# NOTE(review): confirm which class index is "background" in the label scheme;
# class 1 is rendered yellow, not black.
_CLASS_RGB = {
    1: (255, 255, 0),    # yellow
    2: (0, 255, 255),    # cyan
    3: (255, 0, 255),    # magenta
    4: (255, 255, 255),  # white
    5: (120, 120, 120),  # grey
    6: (255, 165, 0),    # orange
}

# Lookup table consumed by replace_vector in the visualisation helpers.
# Each entry maps class index k to [search, colour]:
#   search - float32 vector [k, k, k]; matched against an argmax label map
#            that has been cast to float32 and stacked into 3 channels.
#   colour - float32 RGB on a 0-255 scale with shape (1, 1, 1, 3).
colour_maps = {
    k: [
        tf.constant([k, k, k], dtype=tf.float32),
        tf.constant([[[list(rgb)]]], dtype=tf.float32),
    ]
    for k, rgb in _CLASS_RGB.items()
}
|
|
16 |
|
|
|
17 |
def replace_vector(img, search, replace):
    """Replace every position of ``img`` whose last-axis vector equals ``search``.

    A position matches when all entries along the last axis of ``img`` equal
    ``search``. The comparison uses ``tf.equal``, which broadcasts ``search``
    against ``img``. Matching positions take the value of ``replace``; all
    other positions are returned unchanged.

    Args:
        img: Tensor whose last axis holds the vector to compare (for example 3
            stacked channels). Any rank is accepted, including shapes that are
            only known at run time.
        search: Vector compared against the last axis of ``img``.
        replace: Value written at matching positions. It must be broadcastable
            to ``tf.shape(img)`` (e.g. shape ``(1, 1, 1, 3)`` or ``(3,)``) and
            have the same dtype as ``img``.

    Returns:
        A tensor with the same shape and dtype as ``img``.
    """
    # keepdims keeps a trailing size-1 axis, so the mask broadcasts across
    # every channel inside tf.where (no need to stack it channel by channel).
    mask = tf.reduce_all(tf.equal(img, search), axis=-1, keepdims=True)
    # broadcast_to works for any rank and for dynamic shapes. Tiling by the
    # static img.shape[:-1] required replace.rank == img.rank - 1.
    replacement = tf.broadcast_to(replace, tf.shape(img))
    return tf.where(mask, replacement, img)
|
|
24 |
|
|
|
25 |
def get_mid_slice(x, y, pred, multi_class):
    """Return the middle depth slice of ``x``, ``y`` and ``pred`` tiled side by side.

    All three inputs are cut with 5-element begin/size vectors, so they must be
    rank 5. The slice comes from batch element 0 at index
    ``tf.shape(y)[1] // 2`` along axis 1. The three slices are concatenated
    along axis -2 (width), separated by 3-pixel-wide strips of ones.

    Args:
        x: Input volume. In multi-class mode its last axis is squeezed, so it
            is assumed to have a single channel.
        y: Ground-truth volume.
        pred: Prediction volume.
        multi_class: If True, ``y`` and ``pred`` are argmax-reduced over the
            last axis and coloured with ``colour_maps``. The RGB values are
            scaled to [0, 1], matching the ones-valued separators and the
            other visualisation helpers. ``x`` is repeated into 3 channels.
            If False, ``pred`` is rounded, and ``x`` and ``y`` are used as-is.
            All three must then share a channel count.

    Returns:
        A tensor of shape ``(1, H, W_total, C)``: the concatenated slice with
        the leading batch axis dropped.
    """
    # Integer floor division; same as the former cast(divide(...)) for the
    # non-negative sizes returned by tf.shape.
    mid = tf.shape(y)[1] // 2
    start = [0, mid, 0, 0, 0]
    size = [1, 1, -1, -1, -1]
    x_slice = tf.slice(x, start, size)
    y_slice = tf.slice(y, start, size)
    pred_slice = tf.slice(pred, start, size)

    if multi_class:
        x_slice = tf.squeeze(x_slice, axis=-1)
        x_slice = tf.stack((x_slice,) * 3, axis=-1)
        y_slice = tf.argmax(y_slice, axis=-1)
        y_slice = tf.stack((y_slice,) * 3, axis=-1)
        y_slice = tf.cast(y_slice, tf.float32)
        pred_slice = tf.argmax(pred_slice, axis=-1)
        pred_slice = tf.stack((pred_slice,) * 3, axis=-1)
        pred_slice = tf.cast(pred_slice, tf.float32)
        for search, colour in colour_maps.values():
            # Scale the 0-255 palette to [0, 1], as the sibling helpers do.
            # Unscaled values > 1 would saturate if the consumer clips to
            # [0, 1]: grey and white would both become (1, 1, 1).
            colour = tf.divide(colour, 255)
            y_slice = replace_vector(y_slice, search, colour)
            pred_slice = replace_vector(pred_slice, search, colour)
    else:
        pred_slice = tf.math.round(pred_slice)

    # Separator strip of ones, 3 pixels wide, with pred's channel count.
    img_pad = tf.ones((pred_slice.shape[0], pred_slice.shape[1], pred_slice.shape[2], 3, pred_slice.shape[4]))
    img = tf.concat((x_slice, img_pad, y_slice, img_pad, pred_slice), axis=-2)

    # Drop the leading size-1 batch axis.
    return tf.reshape(img, img.shape[1:])
|
|
51 |
|
|
|
52 |
def get_mid_vol(y, pred, multi_class, rad=12, check_empty=False):
    """Render the central cube of ``y`` and ``pred`` as 3-D plots side by side.

    A cube with edge ``2 * rad``, centred on axes 1-3, is cut from batch
    element 0 of both tensors. Each cube is converted to RGB voxels, drawn
    with ``plot_volume``, and rasterised with ``plot_to_image``. The two
    images are concatenated along axis -2.

    Args:
        y: Ground-truth volume (rank 5; cut with 5-element begin/size vectors).
        pred: Prediction volume with the same layout as ``y``.
        multi_class: If True, both cubes are argmax-reduced over the last axis
            and coloured with ``colour_maps``, with RGB scaled to [0, 1].
            If False, the prediction is rounded. Both cubes are then reshaped
            to their three spatial dims, so they must have a single channel,
            and are repeated into 3 channels.
        rad: Half the edge length of the cube.
        check_empty: If True, return ``None`` when the ground-truth cube sums
            to less than 25.
            NOTE(review): the sum covers every channel of ``y``. If ``y`` is
            one-hot including a background channel, the sum equals the voxel
            count and this check never fires. Confirm the intended behaviour.

    Returns:
        The concatenated image tensor, or ``None`` (see ``check_empty``).
    """
    dims = tf.shape(y)
    begin = [0] + [dims[axis] // 2 - rad for axis in (1, 2, 3)] + [0]
    extent = [1] + [rad * 2] * 3 + [-1]

    y_cube = tf.slice(y, begin, extent)
    if check_empty and tf.math.reduce_sum(y_cube) < 25:
        return None

    def to_rgb(cube, round_first):
        # Turn one sliced cube into 3-channel voxels ready for plot_volume.
        if multi_class:
            cube = tf.cast(tf.argmax(cube, axis=-1), tf.float32)
        else:
            if round_first:
                cube = tf.math.round(cube)
            cube = tf.reshape(cube, cube.shape[1:4])
        cube = tf.stack((cube,) * 3, axis=-1)
        if not multi_class:
            return cube
        for search, colour in colour_maps.values():
            cube = replace_vector(cube, search, tf.divide(colour, 255))
        # Multi-class cubes still carry the size-1 batch axis.
        return tf.squeeze(cube, axis=0)

    # Build both RGB cubes before any plotting happens.
    y_rgb = to_rgb(y_cube, round_first=False)
    pred_rgb = to_rgb(tf.slice(pred, begin, extent), round_first=True)

    y_img = plot_to_image(plot_volume(y_rgb, show=False))
    pred_img = plot_to_image(plot_volume(pred_rgb, show=False))
    return tf.concat((y_img, pred_img), axis=-2)
|
|
90 |
|
|
|
91 |
def plot_through_slices(batch_idx, x_crop, y_crop, mean_pred, writer, multi_class=False, num_slices=160):
    """Log depth slices of one volume to the summary writer and return them as uint8 images.

    For each index ``i`` along axis 1, the slices of ``x_crop``, ``y_crop`` and
    ``mean_pred`` for batch element ``batch_idx`` are concatenated along
    axis -2 (width). The result is written with ``tf.summary.image`` at
    ``step=i``.

    Args:
        batch_idx: Index of the batch element to visualise.
        x_crop: Input volume (rank 5; cut with 5-element begin/size vectors).
        y_crop: Ground-truth volume with the same layout as ``x_crop``.
        mean_pred: Prediction volume with the same layout as ``x_crop``.
        writer: Summary writer, entered via ``writer.as_default()``.
        multi_class: If True, ``y_crop`` and ``mean_pred`` are argmax-reduced
            over the last axis and coloured with ``colour_maps`` (RGB scaled
            to [0, 1]). ``x_crop``'s last axis is squeezed, so it is assumed
            single-channel. If False, the prediction is rounded and the input
            slice is min-max normalised. All inputs must then have a single
            channel, because that axis is reshaped away before returning.
        num_slices: Number of leading slices along axis 1 to render. Defaults
            to 160. ``None`` renders every slice. The value is clamped to the
            actual depth, so shorter volumes do not index out of range.

    Returns:
        A list of numpy uint8 arrays, one per slice, holding the image scaled
        by 255. Each has shape ``(H, 3 * W, 3)`` when ``multi_class`` is True,
        otherwise ``(H, 3 * W)``.
    """
    depth = int(tf.shape(x_crop)[1])
    n_slices = depth if num_slices is None else min(num_slices, depth)
    size = [1, 1, -1, -1, -1]
    # Scale the 0-255 palette to [0, 1] once instead of on every slice.
    palette = [(search, tf.divide(colour, 255)) for search, colour in colour_maps.values()] if multi_class else []

    imgs = []
    for i in range(n_slices):
        start = [batch_idx, i, 0, 0, 0]
        x_slice = tf.slice(x_crop, start, size)
        y_slice = tf.slice(y_crop, start, size)
        m_slice = tf.slice(mean_pred, start, size)

        if multi_class:
            x_slice = tf.squeeze(x_slice, axis=-1)
            x_slice = tf.stack((x_slice,) * 3, axis=-1)
            y_slice = tf.cast(tf.argmax(y_slice, axis=-1), tf.float32)
            y_slice = tf.stack((y_slice,) * 3, axis=-1)
            m_slice = tf.cast(tf.argmax(m_slice, axis=-1), tf.float32)
            m_slice = tf.stack((m_slice,) * 3, axis=-1)
            for search, colour in palette:
                y_slice = replace_vector(y_slice, search, colour)
                m_slice = replace_vector(m_slice, search, colour)
        else:
            m_slice = tf.math.round(m_slice)
            x_min = tf.reduce_min(x_slice)
            x_max = tf.reduce_max(x_slice)
            # A constant slice (x_max == x_min) would give 0/0 = NaN. The NaN
            # would propagate into the summary and into the uint8 cast below,
            # so map it to 0 instead. Assumes x_crop is a float tensor.
            x_slice = tf.math.divide_no_nan(x_slice - x_min, x_max - x_min)

        img = tf.concat((x_slice, y_slice, m_slice), axis=-2)
        # (1, 1, H, 3W, C) -> (1, H, 3W, C): the batched image layout logged here.
        img = tf.reshape(img, img.shape[1:])
        with writer.as_default():
            tf.summary.image("Whole Validation - All Slices", img, step=i)
        # Drop the batch axis; for single-channel input also drop the channel axis.
        img = tf.reshape(img, img.shape[1:])
        if not multi_class:
            img = tf.reshape(img, img.shape[:-1])
        imgs.append((img * 255).numpy().astype(np.uint8))
    return imgs