from multiprocessing import cpu_count
import numpy as np
from tqdm import tqdm
from multiprocessing.dummy import Pool
def compute_patch_indices(image_shape, patch_size, overlap, start=None):
    """
    Computes the corner indices of a regular grid of (possibly overlapping) patches.

    When ``start`` is None, the grid start is chosen so that the amount by which the
    patches overrun the image is split (rounded up) between the low and high edges.
    The resulting indices may therefore be negative or extend past the image.

    :param image_shape: spatial shape of the image (sequence of ints).
    :param patch_size: patch shape; either a scalar applied to every axis or one value per axis.
    :param overlap: overlap between neighbouring patches; scalar or one value per axis.
    :param start: optional grid start; scalar or one value per axis. Computed when None.
    :return: integer array of shape (n_patches, 3) with one corner index per row.
    """
    image_shape = np.asarray(image_shape)
    n_dims = len(image_shape)
    # broadcast_to accepts python ints, numpy integer scalars and per-axis sequences alike,
    # so every argument becomes a per-axis array before any arithmetic happens.
    patch_size = np.broadcast_to(patch_size, (n_dims,))
    overlap = np.broadcast_to(overlap, (n_dims,))
    step = patch_size - overlap
    if start is None:
        n_patches = np.ceil(image_shape / step)
        overflow = step * n_patches - image_shape + overlap
        start = -np.ceil(overflow / 2)
    else:
        # Converting to an array prevents ``image_shape + start`` from concatenating tuples.
        start = np.broadcast_to(start, (n_dims,))
    stop = image_shape + start
    return get_set_of_patch_indices(start, stop, step)
# def compute_patch_indices_full(image_shape, patch_size, overlap, start=None):
# if isinstance(overlap, int):
# overlap = np.asarray([overlap] * len(image_shape))
# step = patch_size - overlap
# max_index = image_shape - patch_size
# if start is None:
# n_patches = np.ceil(max_index / step)
# overflow = step * n_patches - image_shape
# start = -np.ceil(overflow / 2)
# elif isinstance(start, int):
# start = np.asarray([start] * len(image_shape))
# stop = max_index - start
# return get_set_of_patch_indices(start, stop, step)
def get_set_of_patch_indices(start, stop, step):
    """
    Enumerates every point of a regular 3-D grid.

    :param start: per-axis grid start (first three entries are used).
    :param stop: per-axis exclusive grid end (first three entries are used).
    :param step: per-axis grid spacing (first three entries are used).
    :return: integer array of shape (n_points, 3), rows in C (x-major) order.
    """
    grid = np.mgrid[start[0]:stop[0]:step[0],
                    start[1]:stop[1]:step[1],
                    start[2]:stop[2]:step[2]]
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` selects the same default
    # integer dtype. Float grid values (e.g. from a float ``start``) are truncated here.
    return np.asarray(grid.reshape(3, -1).T, dtype=int)
def get_random_patch_index(image_shape, patch_shape):
    """
    Picks a random corner index such that the patch fits entirely inside the image.

    Because every valid corner is equally likely, voxels near the centre of the image are
    covered by sampled patches far more often than voxels near the edges, which is
    probably undesirable when used for training.

    :param image_shape: Shape of the image
    :param patch_shape: Shape of the patch
    :return: tuple holding the corner index, usable for slicing a patch out of the image
    """
    # The largest admissible corner along each axis is image size minus patch size.
    largest_corner = np.subtract(image_shape, patch_shape)
    return get_random_nd_index(largest_corner)
def get_random_nd_index(index_max):
    """
    Draws one uniformly random integer per axis.

    :param index_max: per-axis inclusive upper bounds (non-negative integers).
    :return: tuple whose i-th entry is drawn from ``0 .. index_max[i]`` inclusive.
    """
    # np.random.choice(n) samples from range(n), hence the +1 to make the bound inclusive.
    return tuple(np.random.choice(upper + 1) for upper in index_max)
def get_patch_from_3d_data(data, patch_shape, patch_index):
    """
    Returns a patch from a numpy array, padding with edge values when the patch extends
    beyond the data.

    :param data: numpy array from which to get the patch; its last three axes are sliced.
    :param patch_shape: shape/size of the patch along the last three axes.
    :param patch_index: corner index of the patch; entries may be negative or run past the data.
    :return: numpy array taken from the data with the patch shape specified.
    """
    # np.array (not np.asarray) always copies, so the out-of-bound fix below can never write
    # into an index array owned by the caller. A platform int also avoids the silent
    # wrap-around the former int16 cast had for coordinates beyond +/-32767.
    patch_index = np.array(patch_index, dtype=int)
    patch_shape = np.asarray(patch_shape)
    image_shape = data.shape[-3:]
    if np.any(patch_index < 0) or np.any((patch_index + patch_shape) > image_shape):
        data, patch_index = fix_out_of_bound_patch_attempt(data, patch_shape, patch_index)
    return data[...,
                patch_index[0]:patch_index[0] + patch_shape[0],
                patch_index[1]:patch_index[1] + patch_shape[1],
                patch_index[2]:patch_index[2] + patch_shape[2]]
def fix_out_of_bound_patch_attempt(data, patch_shape, patch_index, ndim=3):
    """
    Pads the data and shifts the patch index so that slicing the patch yields the right shape.

    Only the last ``ndim`` axes are padded (with repeated edge values, ``mode="edge"``);
    any leading axes receive zero padding.

    :param data: numpy array whose last ``ndim`` axes are the ones the patch spans.
    :param patch_shape: patch size along those ``ndim`` axes.
    :param patch_index: corner index of the patch along those axes; may be negative.
    :param ndim: number of trailing axes the patch spans.
    :return: padded data, fixed patch index (a new array; the input index is left unchanged)
    """
    patch_index = np.asarray(patch_index)
    patch_shape = np.asarray(patch_shape)
    image_shape = data.shape[-ndim:]
    patch_end = patch_index + patch_shape
    # Negative corner coordinates need that many voxels of padding in front.
    pad_before = np.abs((patch_index < 0) * patch_index)
    # Patch ends beyond the image need padding behind by the overrun amount.
    pad_after = np.abs((patch_end > image_shape) * (patch_end - image_shape))
    pad_args = np.stack([pad_before, pad_after], axis=1)
    if pad_args.shape[0] < len(data.shape):
        pad_args = [[0, 0]] * (len(data.shape) - pad_args.shape[0]) + pad_args.tolist()
    data = np.pad(data, pad_args, mode="edge")
    # Rebind instead of ``+=`` so an index array owned by the caller is not mutated.
    patch_index = patch_index + pad_before
    return data, patch_index
def _clip_patch_to_image(patch, index, image_shape):
    """
    Crops a patch to the part that lands inside an image of spatial shape ``image_shape``.

    :param patch: numpy array whose first three axes are spatial.
    :param index: integer corner index of the patch in image coordinates; may be negative.
    :param image_shape: spatial image shape (length-3 integer array).
    :return: ``(image_slices, cropped_patch)``, or None if the patch lies entirely outside.
    """
    lower = np.maximum(index, 0)
    upper = np.minimum(index + np.asarray(patch.shape[:3]), image_shape)
    if np.any(upper <= lower):
        return None
    image_slices = tuple(slice(lo, hi) for lo, hi in zip(lower, upper))
    # The same region expressed in the patch's own coordinate frame.
    patch_slices = tuple(slice(lo - i, hi - i) for lo, hi, i in zip(lower, upper, index))
    return image_slices, patch[patch_slices]


def reconstruct_from_patches(patches, patch_indices, data_shape, default_value=0):
    """
    Reconstructs an array of the original shape from the lists of patches and corresponding
    patch indices. Overlapping patches are averaged.

    The first three axes of ``data_shape`` (and of every patch) are treated as spatial; any
    trailing axes must match between the patches and ``data_shape``. Parts of a patch that
    fall outside the data are discarded.

    :param patches: List of numpy array patches.
    :param patch_indices: List of indices that corresponds to the list of patches. Not modified.
    :param data_shape: Shape of the array from which the patches were extracted.
    :param default_value: Value given to every element not covered by any patch. If the
        patch coverage is complete, this value does not appear in the result.
    :return: float numpy array containing the data reconstructed by the patches.
    :raises ValueError: if a patch's non-spatial axes do not match ``data_shape``.
    """
    data_shape = tuple(data_shape)
    image_shape = np.asarray(data_shape[:3])
    data = np.zeros(data_shape)
    count = np.zeros(data_shape, dtype=int)
    for patch, index in zip(patches, patch_indices):
        patch = np.asarray(patch)
        # Converting gives us a local index; the caller's patch_indices stay untouched.
        clipped = _clip_patch_to_image(patch, np.asarray(index, dtype=int), image_shape)
        if clipped is None:
            continue
        image_slices, patch_part = clipped
        target = data[image_slices]
        if patch_part.shape != target.shape:
            raise ValueError("patch of shape {} does not fit data of shape {}".format(
                patch.shape, data_shape))
        # Slice accumulation replaces a full-volume boolean mask per patch.
        data[image_slices] += patch_part
        count[image_slices] += 1
    result = np.full(data_shape, default_value, dtype=float)
    # Average where at least one patch contributed; leave default_value elsewhere.
    np.divide(data, count, out=result, where=count > 0)
    return result