[ccb1dd]: / fetal_net / utils / patches.py

Download this file

156 lines (134 with data), 7.0 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from multiprocessing import cpu_count
import numpy as np
from tqdm import tqdm
from multiprocessing.dummy import Pool
def compute_patch_indices(image_shape, patch_size, overlap, start=None):
    """
    Computes the corner indices of a regular grid of (possibly overlapping) 3D patches.

    When ``start`` is None, the grid begins at a non-positive offset of roughly half the
    overflow, so the patches stick out past both borders by about the same amount.
    :param image_shape: Shape of the image along the three spatial axes.
    :param patch_size: Shape of each patch, one entry per spatial axis.
    :param overlap: Number of voxels shared by neighbouring patches. A scalar applies to every axis.
    :param start: Optional first corner index. A scalar applies to every axis.
    :return: integer array with one corner index (3 values) per row.
    :raises ValueError: if ``overlap`` is not smaller than ``patch_size`` on every axis.
    """
    image_shape = np.asarray(image_shape)
    patch_size = np.asarray(patch_size)
    # Broadcasting handles Python ints, numpy integer scalars and per-axis sequences alike.
    overlap = np.broadcast_to(np.asarray(overlap), image_shape.shape)
    step = patch_size - overlap
    if np.any(step <= 0):
        # A non-positive step would divide by zero or yield an empty grid.
        raise ValueError("overlap must be smaller than patch_size in every dimension")
    if start is None:
        n_patches = np.ceil(image_shape / step)
        overflow = step * n_patches - image_shape + overlap
        start = -np.ceil(overflow / 2)
    else:
        start = np.broadcast_to(np.asarray(start), image_shape.shape)
    stop = image_shape + start
    return get_set_of_patch_indices(start, stop, step)
# def compute_patch_indices_full(image_shape, patch_size, overlap, start=None):
# if isinstance(overlap, int):
# overlap = np.asarray([overlap] * len(image_shape))
# step = patch_size - overlap
# max_index = image_shape - patch_size
# if start is None:
# n_patches = np.ceil(max_index / step)
# overflow = step * n_patches - image_shape
# start = -np.ceil(overflow / 2)
# elif isinstance(start, int):
# start = np.asarray([start] * len(image_shape))
# stop = max_index - start
# return get_set_of_patch_indices(start, stop, step)
def get_set_of_patch_indices(start, stop, step):
    """
    Enumerates every corner index of a regular 3D grid.
    :param start: first index on each of the three axes (may be negative or non-integral).
    :param stop: exclusive upper bound on each axis.
    :param step: grid spacing on each axis.
    :return: integer array of shape (n, 3), one corner per row, last axis varying fastest.
    """
    grid = np.mgrid[start[0]:stop[0]:step[0],
                    start[1]:stop[1]:step[1],
                    start[2]:stop[2]:step[2]]
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` selects the same default integer dtype.
    return np.asarray(grid.reshape(3, -1).T, dtype=int)
def get_random_patch_index(image_shape, patch_shape):
    """
    Draws a random corner index at which a patch of ``patch_shape`` lies inside an image of ``image_shape``.

    Corners are drawn uniformly. If this drives patch sampling during training, central voxels are
    therefore covered far more often than edge voxels (which is probably a bad thing).
    :param image_shape: Shape of the image
    :param patch_shape: Shape of the patch
    :return: a tuple containing the corner index which can be used to get a patch from an image
    """
    largest_corner = np.subtract(image_shape, patch_shape)
    return get_random_nd_index(largest_corner)
def get_random_nd_index(index_max):
    """
    Draws one random integer per dimension, each in the closed range ``[0, index_max[d]]``.
    :param index_max: inclusive upper bound for every dimension.
    :return: tuple holding one drawn value per dimension.
    """
    picks = []
    for upper in index_max:
        # np.random.choice(k) samples uniformly from range(k), so k = upper + 1 includes ``upper``.
        picks.append(np.random.choice(upper + 1))
    return tuple(picks)
def get_patch_from_3d_data(data, patch_shape, patch_index):
    """
    Returns a patch from a numpy array, cut along its last three axes.

    If the requested patch extends past the array, the data is first padded via
    ``fix_out_of_bound_patch_attempt`` so the returned patch always has ``patch_shape``.
    :param data: numpy array from which to get the patch.
    :param patch_shape: shape/size of the patch along the last three axes.
    :param patch_index: corner index of the patch (not modified).
    :return: numpy array taken from the data with the patch shape specified.
    """
    # A 64-bit *copy*: int16 wrapped for indices above 32767. Copying also guarantees that the
    # out-of-bound fix below cannot alter an index array owned by the caller.
    patch_index = np.array(patch_index, dtype=np.int64)
    patch_shape = np.asarray(patch_shape)
    image_shape = data.shape[-3:]
    if np.any(patch_index < 0) or np.any((patch_index + patch_shape) > image_shape):
        data, patch_index = fix_out_of_bound_patch_attempt(data, patch_shape, patch_index)
    return data[...,
                patch_index[0]:patch_index[0] + patch_shape[0],
                patch_index[1]:patch_index[1] + patch_shape[1],
                patch_index[2]:patch_index[2] + patch_shape[2]]
def fix_out_of_bound_patch_attempt(data, patch_shape, patch_index, ndim=3):
    """
    Pads the data and shifts the patch index so that a patch reaching outside the array can be extracted.

    Padding uses ``mode="edge"`` (border values are repeated) and is applied only to the last
    ``ndim`` axes. Leading axes are left unpadded.
    :param data: numpy array whose last ``ndim`` axes are the spatial axes.
    :param patch_shape: shape of the patch along the spatial axes.
    :param patch_index: corner index of the patch, possibly negative or running past the end.
        It is not modified.
    :param ndim: number of trailing spatial axes.
    :return: padded data, fixed patch index
    """
    image_shape = np.asarray(data.shape[-ndim:])
    patch_index = np.asarray(patch_index)
    patch_end = patch_index + np.asarray(patch_shape)
    pad_before = np.maximum(-patch_index, 0)
    pad_after = np.maximum(patch_end - image_shape, 0)
    spatial_pad = np.stack([pad_before, pad_after], axis=1).tolist()
    pad_args = [[0, 0]] * (data.ndim - len(spatial_pad)) + spatial_pad
    data = np.pad(data, pad_args, mode="edge")
    # Build a new array instead of `+=`, so the caller's index array is left intact.
    return data, patch_index + pad_before
def reconstruct_from_patches(patches, patch_indices, data_shape, default_value=0):
    """
    Reconstructs an array of the original shape from the lists of patches and corresponding patch indices. Overlapping
    patches are averaged.

    Patches are placed along the first three axes of ``data_shape``. Any remaining axes (e.g. channels)
    must match between each patch and ``data_shape``. Parts of a patch lying outside the volume
    (negative corner or running past the end) are cropped away, and patches entirely outside are ignored.
    :param patches: List of numpy array patches.
    :param patch_indices: List of indices that corresponds to the list of patches. Not modified.
    :param data_shape: Shape of the array from which the patches were extracted.
    :param default_value: Value given to every voxel not covered by any patch. If the patch coverage is
        complete, it never appears in the result.
    :return: float numpy array of shape ``data_shape`` containing the data reconstructed by the patches.
    """
    summed = np.zeros(data_shape)
    count = np.zeros(data_shape, dtype=np.int64)
    volume_shape = np.asarray(data_shape[:3])
    for patch, index in zip(patches, patch_indices):
        patch = np.asarray(patch)
        # Local copy so the caller's index rows are never mutated.
        start = np.asarray(index, dtype=np.int64)[:3]
        stop = start + patch.shape[:3]
        # Clip the patch footprint to the volume and crop the patch by the same amounts.
        clipped_start = np.maximum(start, 0)
        clipped_stop = np.minimum(stop, volume_shape)
        if np.any(clipped_stop <= clipped_start):
            continue
        crop_begin = clipped_start - start
        crop_end = crop_begin + (clipped_stop - clipped_start)
        data_slices = tuple(slice(b, e) for b, e in zip(clipped_start, clipped_stop))
        patch_slices = tuple(slice(b, e) for b, e in zip(crop_begin, crop_end))
        # Slice assignment touches only the patch region. The old full-volume boolean mask
        # cost O(volume) per patch.
        summed[data_slices] += patch[patch_slices]
        count[data_slices] += 1
    covered = count > 0
    result = np.full(data_shape, default_value, dtype=float)
    result[covered] = summed[covered] / count[covered]
    return result