Diff of /cell_magic_wand.py [000000] .. [467f44]

Switch to unified view

a b/cell_magic_wand.py
1
###########################################################################
2
# 
3
# Python implementation of the ImageJ Cell Magic Wand plugin
4
# (http://www.maxplanckflorida.org/fitzpatricklab/software/cellMagicWand/)
5
# with modifications to reduce variability due to seed point selection
6
# and to support edge detection using all z slices of a 3D image
7
#
8
# Author: Noah Apthorpe (apthorpe@cs.princeton.edu)
9
#
10
# Description: Draws a border within a specified radius
11
#              around a specified center "seed" point
12
#              using a polar transform and a dynamic
13
#              programming edge-following algorithm
14
#
15
# Usage: Import and call the cell_magic_wand() function
16
#        or cell_magic_wand_3d () function with
17
#        a source image, radius window, and location of center
18
#        point.  Other parameters set as optional arguments.
19
#        Returns a binary mask with 1s inside the detected edge and
20
#        a list of points along the detected edge.
21
#
22
###########################################################################
23
24
import numpy as np
25
from scipy.ndimage.interpolation import zoom
26
from scipy.ndimage.morphology import binary_fill_holes
27
28
29
def coord_polar_to_cart(r, theta, center):
30
    '''Converts polar coordinates around center to Cartesian'''
31
    x = r * np.cos(theta) + center[0]
32
    y = r * np.sin(theta) + center[1]
33
    return x, y
34
35
36
def coord_cart_to_polar(x, y, center):
37
    '''Converts Cartesian coordinates to polar'''
38
    r = np.sqrt((x-center[0])**2 + (y-center[1])**2)
39
    theta = np.arctan2((y-center[1]), (x-center[0]))
40
    return r, theta
41
42
43
def image_cart_to_polar(image, center, min_radius, max_radius, phase_width, zoom_factor=1):
44
    '''Converts an image from cartesian to polar coordinates around center'''
45
46
    # Upsample image
47
    if zoom_factor != 1:
48
        image = zoom(image, (zoom_factor, zoom_factor), order=4)
49
        center = (center[0]*zoom_factor + zoom_factor/2, center[1]*zoom_factor + zoom_factor/2)
50
        min_radius = min_radius * zoom_factor
51
        max_radius = max_radius * zoom_factor
52
    
53
    # pad if necessary
54
    max_x, max_y = image.shape[0], image.shape[1]
55
    pad_dist_x = np.max([(center[0] + max_radius) - max_x, -(center[0] - max_radius)])
56
    pad_dist_y = np.max([(center[1] + max_radius) - max_y, -(center[1] - max_radius)])
57
    pad_dist = int(np.max([0, pad_dist_x, pad_dist_y]))
58
    if pad_dist != 0:
59
        image = np.pad(image, pad_dist, 'constant')
60
61
    # coordinate conversion
62
    theta, r = np.meshgrid(np.linspace(0, 2*np.pi, phase_width),
63
                           np.arange(min_radius, max_radius))
64
    x, y = coord_polar_to_cart(r, theta, center)
65
    x, y = np.round(x), np.round(y)
66
    x, y = x.astype(int), y.astype(int)
67
    
68
    polar = image[x, y]
69
    polar.reshape((max_radius - min_radius, phase_width))
70
71
    return polar
72
73
74
def mask_polar_to_cart(mask, center, min_radius, max_radius, output_shape, zoom_factor=1):
75
    '''Converts a polar binary mask to Cartesian and places in an image of zeros'''
76
77
    # Account for upsampling 
78
    if zoom_factor != 1:
79
        center = (center[0]*zoom_factor + zoom_factor/2, center[1]*zoom_factor + zoom_factor/2)
80
        min_radius = min_radius * zoom_factor
81
        max_radius = max_radius * zoom_factor
82
        output_shape = map(lambda a: a * zoom_factor, output_shape)
83
84
    # new image
85
    image = np.zeros(output_shape)
86
87
    # coordinate conversion
88
    theta, r = np.meshgrid(np.linspace(0, 2*np.pi, mask.shape[1]),
89
                           np.arange(0, max_radius))
90
    x, y = coord_polar_to_cart(r, theta, center)
91
    x, y = np.round(x), np.round(y)
92
    x, y = x.astype(int), y.astype(int)
93
94
    x = np.clip(x, 0, image.shape[0]-1)
95
    y = np.clip(y, 0, image.shape[1]-1)
96
    ix,iy = np.meshgrid(np.arange(0,mask.shape[1]), np.arange(0,mask.shape[0]))    
97
    image[x,y] = mask
98
99
    # downsample image
100
    if zoom_factor != 1:
101
        zf = 1/float(zoom_factor)
102
        image = zoom(image, (zf, zf), order=4)
103
104
    # ensure image remains a filled binary mask
105
    image = (image > 0.5).astype(int)
106
    image = binary_fill_holes(image)
107
    return image
108
    
109
110
def find_edge_2d(polar, min_radius):
111
    '''Dynamic programming algorithm to find edge given polar image'''
112
    if len(polar.shape) != 2:
113
        raise ValueError("argument to find_edge_2d must be 2D")
114
115
    # Dynamic programming phase
116
    values_right_shift = np.pad(polar, ((0, 0), (0, 1)), 'constant', constant_values=0)[:, 1:]
117
    values_closeright_shift = np.pad(polar, ((1, 0),(0, 1)), 'constant', constant_values=0)[:-1, 1:]
118
    values_awayright_shift = np.pad(polar, ((0, 1), (0, 1)), 'constant', constant_values=0)[1:, 1:]
119
120
    values_move = np.zeros((polar.shape[0], polar.shape[1], 3))
121
    values_move[:, :, 2] = np.add(polar, values_closeright_shift)  # closeright
122
    values_move[:, :, 1] = np.add(polar, values_right_shift)  # right
123
    values_move[:, :, 0] = np.add(polar, values_awayright_shift)  # awayright
124
    values = np.amax(values_move, axis=2)
125
126
    directions = np.argmax(values_move, axis=2)
127
    directions = np.subtract(directions, 1)
128
    directions = np.negative(directions)
129
        
130
    # Edge following phase
131
    edge = []
132
    mask = np.zeros(polar.shape)
133
    r_max, r = 0, 0
134
    for i,v in enumerate(values[:,0]):
135
        if v >= r_max:
136
            r, r_max = i, v
137
    edge.append((r+min_radius, 0))
138
    mask[0:r+1, 0] = 1
139
    for t in range(1,polar.shape[1]):
140
        r += directions[r, t-1]
141
        if r >= directions.shape[0]: r = directions.shape[0]-1
142
        if r < 0: r = 0
143
        edge.append((r+min_radius, t))
144
        mask[0:r+1, t] = 1
145
146
    # add to inside of mask accounting for min_radius
147
    new_mask = np.ones((min_radius+mask.shape[0], mask.shape[1]))
148
    new_mask[min_radius:, :] = mask
149
    
150
    return np.array(edge), new_mask
151
152
153
def edge_polar_to_cart(edge, center):
154
    '''Converts a list of polar edge points to a list of cartesian edge points'''
155
    cart_edge = [] 
156
    for (r,t) in edge:
157
        x, y = coord_polar_to_cart(r, t, center)
158
        cart_edge.append((round(x), round(y)))
159
    return cart_edge
160
161
162
def cell_magic_wand_single_point(image, center, min_radius, max_radius,
163
                                 roughness=2, zoom_factor=1):
164
    '''Draws a border within a specified radius around a specified center "seed" point
165
    using a polar transform and a dynamic programming edge-following algorithm.
166
167
    Returns a binary mask with 1s inside the detected edge and
168
    a list of points along the detected edge.'''
169
    if roughness < 1:
170
        roughness = 1
171
        print("roughness must be >= 1, setting roughness to 1")
172
    if min_radius < 0:
173
        min_radius = 0
174
        print("min_radius must be >=0, setting min_radius to 0")
175
    if max_radius <= min_radius:
176
        max_radius = min_radius + 1
177
        print("max_radius must be larger than min_radius, setting max_radius to " + str(max_radius))
178
    if zoom_factor <= 0:
179
        zoom_factor = 1
180
        print("negative zoom_factor not allowed, setting zoom_factor to 1")
181
    phase_width = int(2 * np.pi * max_radius * roughness)
182
    polar_image = image_cart_to_polar(image, center, min_radius, max_radius,
183
                                      phase_width=phase_width, zoom_factor=zoom_factor)
184
    polar_edge, polar_mask = find_edge_2d(polar_image, min_radius)
185
    cart_edge = edge_polar_to_cart(polar_edge, center)
186
    cart_mask = mask_polar_to_cart(polar_mask, center, min_radius, max_radius,
187
                                   image.shape, zoom_factor=zoom_factor)
188
    return cart_mask, cart_edge
189
190
191
def cell_magic_wand(image, center, min_radius, max_radius,
192
                    roughness=2, zoom_factor=1, center_range=2):
193
    '''Runs the cell magic wand tool on multiple points near the supplied center and 
194
    combines the results for a more robust edge detection then provided by the vanilla wand tool.
195
196
    Returns a binary mask with 1s inside detected edge'''
197
    
198
    centers = []
199
    for i in [-center_range, 0, center_range]:
200
        for j in [-center_range, 0, center_range]:
201
            centers.append((center[0]+i, center[1]+j))
202
    masks = np.zeros((image.shape[0], image.shape[1], len(centers)))
203
    for i, c in enumerate(centers):
204
        mask, edge = cell_magic_wand_single_point(image, c, min_radius, max_radius,
205
                                                  roughness=roughness, zoom_factor=zoom_factor)
206
        masks[:,:,i] = mask
207
    mean_mask = np.mean(masks, axis=2)
208
    final_mask = (mean_mask > 0.5).astype(int)
209
    return final_mask
210
211
212
def cell_magic_wand_3d(image_3d, center, min_radius, max_radius,
213
                       roughness=2, zoom_factor=1, center_range=2, z_step=1):
214
    '''Robust cell magic wand tool for 3D images with dimensions (z, x, y) - default for tifffile.load.
215
    This functions runs the robust wand tool on each z slice in the image and returns the mean mask
216
    thresholded to 0.5'''
217
    masks = np.zeros((int(image_3d.shape[0]/z_step), image_3d.shape[1], image_3d.shape[2]))
218
    for s in range(int(image_3d.shape[0]/z_step)):
219
        mask = cell_magic_wand(image_3d[s*z_step,:,:], center, min_radius, max_radius,
220
                               roughness=roughness, zoom_factor=zoom_factor,
221
                               center_range=center_range)
222
        masks[s,:,:] = mask
223
    mean_mask = np.mean(masks, axis=0)
224
    final_mask = (mean_mask > 0.5).astype(int)
225
    return final_mask