a b/fetal_net/utils/patches.py
1
from multiprocessing import cpu_count
2
3
import numpy as np
4
from tqdm import tqdm
5
from multiprocessing.dummy import Pool
6
7
8
def compute_patch_indices(image_shape, patch_size, overlap, start=None):
    """
    Computes the corner indices of a regular grid of patches for an image.

    Neighbouring corners are ``patch_size - overlap`` apart along every axis. When no start is given, the start
    is set to minus half of the computed overflow (rounded up), so the first patches may begin at negative indices.
    :param image_shape: Shape of the image (sequence or numpy array, one entry per axis).
    :param patch_size: Shape of a patch (scalar, sequence or numpy array).
    :param overlap: Overlap between neighbouring patches (scalar or one value per axis).
    :param start: Corner of the first patch (scalar or one value per axis). Derived from the overflow when None.
    :return: the array returned by get_set_of_patch_indices for the computed start, stop and step.
    :raises ValueError: if the overlap is not smaller than the patch size along some axis.
    """
    image_shape = np.asarray(image_shape)
    patch_size = np.asarray(patch_size)
    # broadcasting handles plain ints, numpy integer scalars and per-axis sequences alike
    overlap = np.broadcast_to(np.asarray(overlap), image_shape.shape)
    step = patch_size - overlap
    if np.any(step <= 0):
        # a zero or negative step would otherwise surface as a division by zero or a meaningless grid
        raise ValueError("overlap must be smaller than patch_size along every axis")
    if start is None:
        n_patches = np.ceil(image_shape / step)
        overflow = step * n_patches - image_shape + overlap
        start = -np.ceil(overflow / 2)
    else:
        start = np.broadcast_to(np.asarray(start), image_shape.shape)
    stop = image_shape + start
    return get_set_of_patch_indices(start, stop, step)
20
21
22
# def compute_patch_indices_full(image_shape, patch_size, overlap, start=None):
23
#     if isinstance(overlap, int):
24
#         overlap = np.asarray([overlap] * len(image_shape))
25
#     step = patch_size - overlap
26
#     max_index = image_shape - patch_size
27
#     if start is None:
28
#         n_patches = np.ceil(max_index / step)
29
#         overflow = step * n_patches - image_shape
30
#         start = -np.ceil(overflow / 2)
31
#     elif isinstance(start, int):
32
#         start = np.asarray([start] * len(image_shape))
33
#     stop = max_index - start
34
#     return get_set_of_patch_indices(start, stop, step)
35
36
37
def get_set_of_patch_indices(start, stop, step):
    """
    Enumerates every point of the regular grid running from start up to (not including) stop.
    :param start: First coordinate along each axis.
    :param stop: Exclusive upper bound along each axis.
    :param step: Spacing between grid points along each axis.
    :return: integer numpy array of shape (n_points, n_axes) holding one grid point per row.
    """
    n_axes = len(start)
    axis_ranges = tuple(slice(first, last, stride) for first, last, stride in zip(start, stop, step))
    grid = np.mgrid[axis_ranges].reshape(n_axes, -1)
    # the np.int alias was removed from NumPy; the builtin int selects the same dtype
    return np.asarray(grid.T, dtype=int)
40
41
42
def get_random_patch_index(image_shape, patch_shape):
    """
    Picks a random corner index for a patch.

    Note that when this is used during training, pixels in the middle of the image end up inside far more
    patches than pixels near the edges, which is probably undesirable.
    :param image_shape: Shape of the image
    :param patch_shape: Shape of the patch
    :return: a tuple with the corner index at which a patch can be taken from the image
    """
    # the largest corner per axis is the image extent minus the patch extent
    largest_corner = np.subtract(image_shape, patch_shape)
    return get_random_nd_index(largest_corner)
51
52
53
def get_random_nd_index(index_max):
    """
    Draws one random index per axis, each between 0 and the maximum of that axis (inclusive).
    :param index_max: Largest allowed index along each axis.
    :return: a tuple with one randomly drawn index per axis.
    """
    # np.random.choice(n) samples from range(n), hence the + 1 to make the maximum itself reachable
    return tuple(np.random.choice(axis_max + 1) for axis_max in index_max)
55
56
57
def get_patch_from_3d_data(data, patch_shape, patch_index):
    """
    Returns a patch from a numpy array.

    The patch is cut from the last three axes of the data; any leading axes are kept whole. When the patch
    reaches outside the data, the data is first padded by fix_out_of_bound_patch_attempt.
    :param data: numpy array from which to get the patch.
    :param patch_shape: shape/size of the patch along the last three axes.
    :param patch_index: corner index of the patch.
    :return: numpy array taken from the data with the patch shape specified.
    """
    # np.array makes a private copy, so the padding helper cannot modify the caller's index array; the
    # platform integer is used because int16 cannot represent corners beyond 32767
    patch_index = np.array(patch_index, dtype=int)
    patch_shape = np.asarray(patch_shape)
    image_shape = data.shape[-3:]
    if np.any(patch_index < 0) or np.any((patch_index + patch_shape) > image_shape):
        data, patch_index = fix_out_of_bound_patch_attempt(data, patch_shape, patch_index)
    window = tuple(slice(patch_index[axis], patch_index[axis] + patch_shape[axis]) for axis in range(3))
    return data[(Ellipsis,) + window]
73
74
75
def fix_out_of_bound_patch_attempt(data, patch_shape, patch_index, ndim=3):
76
    """
77
    Pads the data and alters the patch index so that a patch will be correct.
78
    :param data:
79
    :param patch_shape:
80
    :param patch_index:
81
    :return: padded data, fixed patch index
82
    """
83
    image_shape = data.shape[-ndim:]
84
    pad_before = np.abs((patch_index < 0) * patch_index)
85
    pad_after = np.abs(((patch_index + patch_shape) > image_shape) * ((patch_index + patch_shape) - image_shape))
86
    pad_args = np.stack([pad_before, pad_after], axis=1)
87
    if pad_args.shape[0] < len(data.shape):
88
        pad_args = [[0, 0]] * (len(data.shape) - pad_args.shape[0]) + pad_args.tolist()
89
    data = np.pad(data, pad_args, mode="edge")
90
    patch_index += pad_before
91
    return data, patch_index
92
93
94
def _accumulate_patches(patches, patch_indices, data_shape):
    """
    Sums the given patches into an empty volume and counts how many patches contributed to every element.
    :param patches: Sequence of numpy array patches.
    :param patch_indices: Sequence of corner indices, one per patch.
    :param data_shape: Shape of the volume to accumulate into.
    :return: tuple (summed data, contribution count), both of shape data_shape.
    """
    data = np.zeros(data_shape)
    count = np.zeros(data_shape, dtype=int)
    image_shape = np.asarray(data_shape[-4:-1])
    for patch, index in zip(patches, patch_indices):
        patch = np.asarray(patch)
        # np.array copies, so clamping the corner below never alters the caller's indices
        index = np.array(index, dtype=int)
        # drop the part of the patch that lies before the start of the volume
        cut_before = np.maximum(-index, 0)
        if np.any(cut_before):
            patch = patch[cut_before[0]:, cut_before[1]:, cut_before[2]:, ...]
            index = index + cut_before
        # drop the part of the patch that reaches past the end of the volume
        keep = np.maximum(image_shape - index, 0)
        if np.any(keep < patch.shape[-4:-1]):
            patch = patch[:keep[0], :keep[1], :keep[2], ...]
        region = (slice(index[0], index[0] + patch.shape[-4]),
                  slice(index[1], index[1] + patch.shape[-3]),
                  slice(index[2], index[2] + patch.shape[-2]),
                  Ellipsis)
        # accumulate through a view of the region rather than a boolean mask over the whole volume
        target = data[region]
        target += patch.reshape(target.shape)
        count[region] += 1
    return data, count


def reconstruct_from_patches(patches, patch_indices, data_shape, default_value=0):
    """
    Reconstructs an array of the original shape from the lists of patches and corresponding patch indices. Overlapping
    patches are averaged.

    The patches are distributed over a pool of worker threads; every worker accumulates its share and the
    partial sums and counts are combined before averaging. The patch indices are not modified.
    :param patches: Sequence of numpy array patches.
    :param patch_indices: Sequence of corner indices that corresponds to the patches. Indices may be negative or
    reach past the data; the parts of a patch outside the data are discarded.
    :param data_shape: Shape of the array from which the patches were extracted. The first three axes are the ones
    addressed by the patch indices, followed by the remaining axes (expected layout: x, y, z, channels).
    :param default_value: The value of the resulting data wherever no patch contributed. If the patch coverage is
    complete, this value does not appear in the result.
    :return: numpy array containing the data reconstructed by the patches.
    """
    # never start more workers than there are patches to hand out
    workers = max(1, min(cpu_count(), len(patches)))
    data = np.zeros(data_shape)
    count = np.zeros(data_shape, dtype=int)
    # the context manager shuts the pool down once all partial results have been collected
    with Pool(workers) as pool:
        pending = [pool.apply_async(_accumulate_patches,
                                    args=(patches[i::workers], patch_indices[i::workers], data_shape))
                   for i in range(workers)]
        for result in pending:
            data_i, count_i = result.get()
            data += data_i
            count += count_i
    # average only where something was written; everything else keeps the default value
    reconstruction = np.full(data_shape, default_value, dtype=data.dtype)
    np.divide(data, count, out=reconstruction, where=count > 0)
    return reconstruction