|
a |
|
b/fetal_net/utils/patches.py |
|
|
1 |
from multiprocessing import cpu_count |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
from tqdm import tqdm |
|
|
5 |
from multiprocessing.dummy import Pool |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
def compute_patch_indices(image_shape, patch_size, overlap, start=None):
    """
    Computes the corner indices of a regular grid of patches laid over an image.
    :param image_shape: shape of the image, one entry per axis.
    :param patch_size: size of a patch along each axis.
    :param overlap: overlap between neighbouring patches; a plain int is repeated for every axis.
    :param start: corner of the first patch. A plain int is repeated for every axis. When None, the
    grid starts half of the overflow (rounded up) before the image origin.
    :return: whatever get_set_of_patch_indices returns for the computed start, stop and step.
    """
    n_axes = len(image_shape)
    if isinstance(overlap, int):
        overlap = np.full(n_axes, overlap)

    # Distance between the corners of two consecutive patches.
    step = patch_size - overlap

    if start is None:
        patches_per_axis = np.ceil(image_shape / step)
        # How far the grid of patches sticks out past the image.
        overflow = step * patches_per_axis - image_shape + overlap
        start = -np.ceil(overflow / 2)
    elif isinstance(start, int):
        start = np.full(n_axes, start)

    stop = image_shape + start
    return get_set_of_patch_indices(start, stop, step)
|
|
20 |
|
|
|
21 |
|
|
|
22 |
# def compute_patch_indices_full(image_shape, patch_size, overlap, start=None): |
|
|
23 |
# if isinstance(overlap, int): |
|
|
24 |
# overlap = np.asarray([overlap] * len(image_shape)) |
|
|
25 |
# step = patch_size - overlap |
|
|
26 |
# max_index = image_shape - patch_size |
|
|
27 |
# if start is None: |
|
|
28 |
# n_patches = np.ceil(max_index / step) |
|
|
29 |
# overflow = step * n_patches - image_shape |
|
|
30 |
# start = -np.ceil(overflow / 2) |
|
|
31 |
# elif isinstance(start, int): |
|
|
32 |
# start = np.asarray([start] * len(image_shape)) |
|
|
33 |
# stop = max_index - start |
|
|
34 |
# return get_set_of_patch_indices(start, stop, step) |
|
|
35 |
|
|
|
36 |
|
|
|
37 |
def get_set_of_patch_indices(start, stop, step):
    """
    Enumerates every corner index on a regular 3D grid.
    :param start: first corner coordinate along each of the three axes.
    :param stop: exclusive upper bound of the corner coordinates along each axis.
    :param step: spacing between consecutive corners along each axis.
    :return: integer numpy array with one row of three coordinates per grid point.
    """
    grid = np.mgrid[start[0]:stop[0]:step[0],
                    start[1]:stop[1]:step[1],
                    start[2]:stop[2]:step[2]]
    # np.int was only an alias of the builtin int and no longer exists in
    # NumPy >= 1.24, so the builtin is used directly.
    return np.asarray(grid.reshape(3, -1).T, dtype=int)
|
|
40 |
|
|
|
41 |
|
|
|
42 |
def get_random_patch_index(image_shape, patch_shape):
    """
    Draws a random corner index at which a patch of the given shape can be taken from the image.
    Note that when this is used for training, pixels in the middle of the image are sampled far
    more often than pixels near the edges (which is probably undesirable).
    :param image_shape: Shape of the image
    :param patch_shape: Shape of the patch
    :return: a tuple containing the corner index which can be used to get a patch from an image
    """
    # Largest corner coordinate per axis: image size minus patch size.
    largest_corner = np.subtract(image_shape, patch_shape)
    return get_random_nd_index(largest_corner)
|
|
51 |
|
|
|
52 |
|
|
|
53 |
def get_random_nd_index(index_max):
    """
    Draws one random coordinate per axis.
    :param index_max: largest allowed coordinate for each axis (inclusive).
    :return: tuple with one random coordinate per entry of index_max.
    """
    # np.random.choice(n) samples from range(n), hence the + 1 to include the maximum itself.
    return tuple(np.random.choice(upper + 1) for upper in index_max)
|
|
55 |
|
|
|
56 |
|
|
|
57 |
def get_patch_from_3d_data(data, patch_shape, patch_index):
    """
    Cuts a patch out of the last three axes of a numpy array.
    :param data: numpy array from which to get the patch.
    :param patch_shape: shape/size of the patch.
    :param patch_index: corner index of the patch (converted to int16 internally).
    :return: numpy array taken from the data with the patch shape specified.
    """
    corner = np.asarray(patch_index, dtype=np.int16)
    extent = np.asarray(patch_shape)
    spatial_shape = data.shape[-3:]

    starts_before_image = np.any(corner < 0)
    if starts_before_image or np.any((corner + extent) > spatial_shape):
        # The patch sticks out of the image: pad the data and shift the corner accordingly.
        data, corner = fix_out_of_bound_patch_attempt(data, extent, corner)

    window = tuple(slice(corner[axis], corner[axis] + extent[axis]) for axis in range(3))
    return data[(Ellipsis,) + window]
|
|
73 |
|
|
|
74 |
|
|
|
75 |
def fix_out_of_bound_patch_attempt(data, patch_shape, patch_index, ndim=3):
    """
    Pads the data and alters the patch index so that a patch will be correct.
    The input patch_index is left untouched; the shifted index is returned as a new array.
    :param data: numpy array whose last ndim axes are the image axes.
    :param patch_shape: size of the patch along each image axis.
    :param patch_index: corner index of the patch; may be negative or reach past the image.
    :param ndim: number of trailing axes of data that make up the image.
    :return: padded data, fixed patch index
    """
    patch_index = np.asarray(patch_index)
    patch_shape = np.asarray(patch_shape)
    image_shape = np.asarray(data.shape[-ndim:])

    # Amount by which the patch sticks out before the start / past the end of each image axis.
    pad_before = np.maximum(-patch_index, 0)
    pad_after = np.maximum(patch_index + patch_shape - image_shape, 0)

    # Axes in front of the image axes (e.g. channels) are never padded.
    leading_axes = data.ndim - len(pad_before)
    pad_width = [[0, 0]] * leading_axes + np.stack([pad_before, pad_after], axis=1).tolist()
    data = np.pad(data, pad_width, mode="edge")

    # Padding in front moves the origin, so the corner shifts by the same amount.
    # A new array is built instead of "+=" so the caller's index is not modified.
    return data, patch_index + pad_before
|
|
92 |
|
|
|
93 |
|
|
|
94 |
def reconstruct_from_patches(patches, patch_indices, data_shape, default_value=0):
    """
    Reconstructs an array of the original shape from the lists of patches and corresponding patch indices. Overlapping
    patches are averaged.
    The first three axes of the data and of every patch are treated as the image axes, so 4D arrays with a
    single trailing axis are expected. Parts of a patch that fall outside of the image are discarded.
    :param patches: List of numpy array patches.
    :param patch_indices: List of indices that corresponds to the list of patches. It is not modified.
    :param data_shape: Shape of the array from which the patches were extracted.
    :param default_value: The default value of the resulting data. if the patch coverage is complete, this value will
    be overwritten.
    :return: numpy array containing the data reconstructed by the patches.
    """
    data_shape = tuple(data_shape)
    image_shape = np.asarray(data_shape[-4:-1])

    def compute_full_picture(patch_subset, index_subset):
        # Sums one subset of patches into a full-size buffer and counts the hits per element.
        partial_sum = np.zeros(data_shape)
        partial_count = np.zeros(data_shape, dtype=int)
        for patch, index in zip(patch_subset, index_subset):
            patch = np.asarray(patch)
            # Copy the index so that clamping never leaks into the caller's patch_indices.
            index = np.array(index, dtype=int)
            patch_extent = np.asarray(patch.shape[-4:-1])

            # Part of the patch lying before the image start, and the clamped corner.
            skip = np.maximum(-index, 0)
            corner = np.maximum(index, 0)
            # Number of elements per axis that actually land inside the image.
            keep = np.maximum(np.minimum(patch_extent - skip, image_shape - corner), 0)

            inside = tuple(slice(lo, lo + n) for lo, n in zip(skip.tolist(), keep.tolist()))
            region = tuple(slice(lo, lo + n) for lo, n in zip(corner.tolist(), keep.tolist()))

            # Basic slicing yields views, so no full-size mask is needed per patch.
            target = partial_sum[region + (Ellipsis,)]
            target += patch[inside + (Ellipsis,)].reshape(target.shape)
            partial_count[region + (Ellipsis,)] += 1
        return partial_sum, partial_count

    # Never start more threads than there are patches; always keep at least one.
    workers = max(1, min(cpu_count(), len(patches)))
    chunks = [(patches[i::workers], patch_indices[i::workers]) for i in range(workers)]
    # The context manager shuts the thread pool down instead of leaking its threads.
    with Pool(workers) as pool:
        partial_results = pool.starmap(compute_full_picture, chunks)

    data = np.zeros(data_shape)
    count = np.zeros(data_shape, dtype=int)
    for partial_sum, partial_count in partial_results:
        data += partial_sum
        count += partial_count

    # Average where at least one patch contributed; keep default_value elsewhere.
    result = np.full(data_shape, default_value, dtype=float)
    np.divide(data, count, out=result, where=count > 0)
    return result