|
a |
|
b/cell_magic_wand.py |
|
|
1 |
########################################################################### |
|
|
2 |
# |
|
|
3 |
# Python implementation of the ImageJ Cell Magic Wand plugin |
|
|
4 |
# (http://www.maxplanckflorida.org/fitzpatricklab/software/cellMagicWand/) |
|
|
5 |
# with modifications to reduce variability due to seed point selection |
|
|
6 |
# and to support edge detection using all z slices of a 3D image |
|
|
7 |
# |
|
|
8 |
# Author: Noah Apthorpe (apthorpe@cs.princeton.edu) |
|
|
9 |
# |
|
|
10 |
# Description: Draws a border within a specified radius |
|
|
11 |
# around a specified center "seed" point |
|
|
12 |
# using a polar transform and a dynamic |
|
|
13 |
# programming edge-following algorithm |
|
|
14 |
# |
|
|
15 |
# Usage: Import and call the cell_magic_wand() function |
|
|
16 |
# or cell_magic_wand_3d () function with |
|
|
17 |
# a source image, radius window, and location of center |
|
|
18 |
# point. Other parameters set as optional arguments. |
|
|
19 |
# Returns a binary mask with 1s inside the detected edge and |
|
|
20 |
# a list of points along the detected edge. |
|
|
21 |
# |
|
|
22 |
########################################################################### |
|
|
23 |
|
|
|
24 |
import numpy as np |
|
|
25 |
from scipy.ndimage.interpolation import zoom |
|
|
26 |
from scipy.ndimage.morphology import binary_fill_holes |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
def coord_polar_to_cart(r, theta, center): |
|
|
30 |
'''Converts polar coordinates around center to Cartesian''' |
|
|
31 |
x = r * np.cos(theta) + center[0] |
|
|
32 |
y = r * np.sin(theta) + center[1] |
|
|
33 |
return x, y |
|
|
34 |
|
|
|
35 |
|
|
|
36 |
def coord_cart_to_polar(x, y, center): |
|
|
37 |
'''Converts Cartesian coordinates to polar''' |
|
|
38 |
r = np.sqrt((x-center[0])**2 + (y-center[1])**2) |
|
|
39 |
theta = np.arctan2((y-center[1]), (x-center[0])) |
|
|
40 |
return r, theta |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
def image_cart_to_polar(image, center, min_radius, max_radius, phase_width, zoom_factor=1): |
|
|
44 |
'''Converts an image from cartesian to polar coordinates around center''' |
|
|
45 |
|
|
|
46 |
# Upsample image |
|
|
47 |
if zoom_factor != 1: |
|
|
48 |
image = zoom(image, (zoom_factor, zoom_factor), order=4) |
|
|
49 |
center = (center[0]*zoom_factor + zoom_factor/2, center[1]*zoom_factor + zoom_factor/2) |
|
|
50 |
min_radius = min_radius * zoom_factor |
|
|
51 |
max_radius = max_radius * zoom_factor |
|
|
52 |
|
|
|
53 |
# pad if necessary |
|
|
54 |
max_x, max_y = image.shape[0], image.shape[1] |
|
|
55 |
pad_dist_x = np.max([(center[0] + max_radius) - max_x, -(center[0] - max_radius)]) |
|
|
56 |
pad_dist_y = np.max([(center[1] + max_radius) - max_y, -(center[1] - max_radius)]) |
|
|
57 |
pad_dist = int(np.max([0, pad_dist_x, pad_dist_y])) |
|
|
58 |
if pad_dist != 0: |
|
|
59 |
image = np.pad(image, pad_dist, 'constant') |
|
|
60 |
|
|
|
61 |
# coordinate conversion |
|
|
62 |
theta, r = np.meshgrid(np.linspace(0, 2*np.pi, phase_width), |
|
|
63 |
np.arange(min_radius, max_radius)) |
|
|
64 |
x, y = coord_polar_to_cart(r, theta, center) |
|
|
65 |
x, y = np.round(x), np.round(y) |
|
|
66 |
x, y = x.astype(int), y.astype(int) |
|
|
67 |
|
|
|
68 |
polar = image[x, y] |
|
|
69 |
polar.reshape((max_radius - min_radius, phase_width)) |
|
|
70 |
|
|
|
71 |
return polar |
|
|
72 |
|
|
|
73 |
|
|
|
74 |
def mask_polar_to_cart(mask, center, min_radius, max_radius, output_shape, zoom_factor=1): |
|
|
75 |
'''Converts a polar binary mask to Cartesian and places in an image of zeros''' |
|
|
76 |
|
|
|
77 |
# Account for upsampling |
|
|
78 |
if zoom_factor != 1: |
|
|
79 |
center = (center[0]*zoom_factor + zoom_factor/2, center[1]*zoom_factor + zoom_factor/2) |
|
|
80 |
min_radius = min_radius * zoom_factor |
|
|
81 |
max_radius = max_radius * zoom_factor |
|
|
82 |
output_shape = map(lambda a: a * zoom_factor, output_shape) |
|
|
83 |
|
|
|
84 |
# new image |
|
|
85 |
image = np.zeros(output_shape) |
|
|
86 |
|
|
|
87 |
# coordinate conversion |
|
|
88 |
theta, r = np.meshgrid(np.linspace(0, 2*np.pi, mask.shape[1]), |
|
|
89 |
np.arange(0, max_radius)) |
|
|
90 |
x, y = coord_polar_to_cart(r, theta, center) |
|
|
91 |
x, y = np.round(x), np.round(y) |
|
|
92 |
x, y = x.astype(int), y.astype(int) |
|
|
93 |
|
|
|
94 |
x = np.clip(x, 0, image.shape[0]-1) |
|
|
95 |
y = np.clip(y, 0, image.shape[1]-1) |
|
|
96 |
ix,iy = np.meshgrid(np.arange(0,mask.shape[1]), np.arange(0,mask.shape[0])) |
|
|
97 |
image[x,y] = mask |
|
|
98 |
|
|
|
99 |
# downsample image |
|
|
100 |
if zoom_factor != 1: |
|
|
101 |
zf = 1/float(zoom_factor) |
|
|
102 |
image = zoom(image, (zf, zf), order=4) |
|
|
103 |
|
|
|
104 |
# ensure image remains a filled binary mask |
|
|
105 |
image = (image > 0.5).astype(int) |
|
|
106 |
image = binary_fill_holes(image) |
|
|
107 |
return image |
|
|
108 |
|
|
|
109 |
|
|
|
110 |
def find_edge_2d(polar, min_radius): |
|
|
111 |
'''Dynamic programming algorithm to find edge given polar image''' |
|
|
112 |
if len(polar.shape) != 2: |
|
|
113 |
raise ValueError("argument to find_edge_2d must be 2D") |
|
|
114 |
|
|
|
115 |
# Dynamic programming phase |
|
|
116 |
values_right_shift = np.pad(polar, ((0, 0), (0, 1)), 'constant', constant_values=0)[:, 1:] |
|
|
117 |
values_closeright_shift = np.pad(polar, ((1, 0),(0, 1)), 'constant', constant_values=0)[:-1, 1:] |
|
|
118 |
values_awayright_shift = np.pad(polar, ((0, 1), (0, 1)), 'constant', constant_values=0)[1:, 1:] |
|
|
119 |
|
|
|
120 |
values_move = np.zeros((polar.shape[0], polar.shape[1], 3)) |
|
|
121 |
values_move[:, :, 2] = np.add(polar, values_closeright_shift) # closeright |
|
|
122 |
values_move[:, :, 1] = np.add(polar, values_right_shift) # right |
|
|
123 |
values_move[:, :, 0] = np.add(polar, values_awayright_shift) # awayright |
|
|
124 |
values = np.amax(values_move, axis=2) |
|
|
125 |
|
|
|
126 |
directions = np.argmax(values_move, axis=2) |
|
|
127 |
directions = np.subtract(directions, 1) |
|
|
128 |
directions = np.negative(directions) |
|
|
129 |
|
|
|
130 |
# Edge following phase |
|
|
131 |
edge = [] |
|
|
132 |
mask = np.zeros(polar.shape) |
|
|
133 |
r_max, r = 0, 0 |
|
|
134 |
for i,v in enumerate(values[:,0]): |
|
|
135 |
if v >= r_max: |
|
|
136 |
r, r_max = i, v |
|
|
137 |
edge.append((r+min_radius, 0)) |
|
|
138 |
mask[0:r+1, 0] = 1 |
|
|
139 |
for t in range(1,polar.shape[1]): |
|
|
140 |
r += directions[r, t-1] |
|
|
141 |
if r >= directions.shape[0]: r = directions.shape[0]-1 |
|
|
142 |
if r < 0: r = 0 |
|
|
143 |
edge.append((r+min_radius, t)) |
|
|
144 |
mask[0:r+1, t] = 1 |
|
|
145 |
|
|
|
146 |
# add to inside of mask accounting for min_radius |
|
|
147 |
new_mask = np.ones((min_radius+mask.shape[0], mask.shape[1])) |
|
|
148 |
new_mask[min_radius:, :] = mask |
|
|
149 |
|
|
|
150 |
return np.array(edge), new_mask |
|
|
151 |
|
|
|
152 |
|
|
|
153 |
def edge_polar_to_cart(edge, center): |
|
|
154 |
'''Converts a list of polar edge points to a list of cartesian edge points''' |
|
|
155 |
cart_edge = [] |
|
|
156 |
for (r,t) in edge: |
|
|
157 |
x, y = coord_polar_to_cart(r, t, center) |
|
|
158 |
cart_edge.append((round(x), round(y))) |
|
|
159 |
return cart_edge |
|
|
160 |
|
|
|
161 |
|
|
|
162 |
def cell_magic_wand_single_point(image, center, min_radius, max_radius, |
|
|
163 |
roughness=2, zoom_factor=1): |
|
|
164 |
'''Draws a border within a specified radius around a specified center "seed" point |
|
|
165 |
using a polar transform and a dynamic programming edge-following algorithm. |
|
|
166 |
|
|
|
167 |
Returns a binary mask with 1s inside the detected edge and |
|
|
168 |
a list of points along the detected edge.''' |
|
|
169 |
if roughness < 1: |
|
|
170 |
roughness = 1 |
|
|
171 |
print("roughness must be >= 1, setting roughness to 1") |
|
|
172 |
if min_radius < 0: |
|
|
173 |
min_radius = 0 |
|
|
174 |
print("min_radius must be >=0, setting min_radius to 0") |
|
|
175 |
if max_radius <= min_radius: |
|
|
176 |
max_radius = min_radius + 1 |
|
|
177 |
print("max_radius must be larger than min_radius, setting max_radius to " + str(max_radius)) |
|
|
178 |
if zoom_factor <= 0: |
|
|
179 |
zoom_factor = 1 |
|
|
180 |
print("negative zoom_factor not allowed, setting zoom_factor to 1") |
|
|
181 |
phase_width = int(2 * np.pi * max_radius * roughness) |
|
|
182 |
polar_image = image_cart_to_polar(image, center, min_radius, max_radius, |
|
|
183 |
phase_width=phase_width, zoom_factor=zoom_factor) |
|
|
184 |
polar_edge, polar_mask = find_edge_2d(polar_image, min_radius) |
|
|
185 |
cart_edge = edge_polar_to_cart(polar_edge, center) |
|
|
186 |
cart_mask = mask_polar_to_cart(polar_mask, center, min_radius, max_radius, |
|
|
187 |
image.shape, zoom_factor=zoom_factor) |
|
|
188 |
return cart_mask, cart_edge |
|
|
189 |
|
|
|
190 |
|
|
|
191 |
def cell_magic_wand(image, center, min_radius, max_radius, |
|
|
192 |
roughness=2, zoom_factor=1, center_range=2): |
|
|
193 |
'''Runs the cell magic wand tool on multiple points near the supplied center and |
|
|
194 |
combines the results for a more robust edge detection then provided by the vanilla wand tool. |
|
|
195 |
|
|
|
196 |
Returns a binary mask with 1s inside detected edge''' |
|
|
197 |
|
|
|
198 |
centers = [] |
|
|
199 |
for i in [-center_range, 0, center_range]: |
|
|
200 |
for j in [-center_range, 0, center_range]: |
|
|
201 |
centers.append((center[0]+i, center[1]+j)) |
|
|
202 |
masks = np.zeros((image.shape[0], image.shape[1], len(centers))) |
|
|
203 |
for i, c in enumerate(centers): |
|
|
204 |
mask, edge = cell_magic_wand_single_point(image, c, min_radius, max_radius, |
|
|
205 |
roughness=roughness, zoom_factor=zoom_factor) |
|
|
206 |
masks[:,:,i] = mask |
|
|
207 |
mean_mask = np.mean(masks, axis=2) |
|
|
208 |
final_mask = (mean_mask > 0.5).astype(int) |
|
|
209 |
return final_mask |
|
|
210 |
|
|
|
211 |
|
|
|
212 |
def cell_magic_wand_3d(image_3d, center, min_radius, max_radius, |
|
|
213 |
roughness=2, zoom_factor=1, center_range=2, z_step=1): |
|
|
214 |
'''Robust cell magic wand tool for 3D images with dimensions (z, x, y) - default for tifffile.load. |
|
|
215 |
This functions runs the robust wand tool on each z slice in the image and returns the mean mask |
|
|
216 |
thresholded to 0.5''' |
|
|
217 |
masks = np.zeros((int(image_3d.shape[0]/z_step), image_3d.shape[1], image_3d.shape[2])) |
|
|
218 |
for s in range(int(image_3d.shape[0]/z_step)): |
|
|
219 |
mask = cell_magic_wand(image_3d[s*z_step,:,:], center, min_radius, max_radius, |
|
|
220 |
roughness=roughness, zoom_factor=zoom_factor, |
|
|
221 |
center_range=center_range) |
|
|
222 |
masks[s,:,:] = mask |
|
|
223 |
mean_mask = np.mean(masks, axis=0) |
|
|
224 |
final_mask = (mean_mask > 0.5).astype(int) |
|
|
225 |
return final_mask |