Diff of /cell_magic_wand.py [000000] .. [467f44]

Switch to side-by-side view

--- a
+++ b/cell_magic_wand.py
@@ -0,0 +1,225 @@
+###########################################################################
+# 
+# Python implementation of the ImageJ Cell Magic Wand plugin
+# (http://www.maxplanckflorida.org/fitzpatricklab/software/cellMagicWand/)
+# with modifications to reduce variability due to seed point selection
+# and to support edge detection using all z slices of a 3D image
+#
+# Author: Noah Apthorpe (apthorpe@cs.princeton.edu)
+#
+# Description: Draws a border within a specified radius
+#              around a specified center "seed" point
+#              using a polar transform and a dynamic
+#              programming edge-following algorithm
+#
+# Usage: Import and call the cell_magic_wand() function
+#        or cell_magic_wand_3d () function with
+#        a source image, radius window, and location of center
+#        point.  Other parameters set as optional arguments.
+#        Returns a binary mask with 1s inside the detected edge and
+#        a list of points along the detected edge.
+#
+###########################################################################
+
+import numpy as np
+from scipy.ndimage.interpolation import zoom
+from scipy.ndimage.morphology import binary_fill_holes
+
+
+def coord_polar_to_cart(r, theta, center):
+    '''Converts polar coordinates around center to Cartesian'''
+    x = r * np.cos(theta) + center[0]
+    y = r * np.sin(theta) + center[1]
+    return x, y
+
+
+def coord_cart_to_polar(x, y, center):
+    '''Converts Cartesian coordinates to polar'''
+    r = np.sqrt((x-center[0])**2 + (y-center[1])**2)
+    theta = np.arctan2((y-center[1]), (x-center[0]))
+    return r, theta
+
+
+def image_cart_to_polar(image, center, min_radius, max_radius, phase_width, zoom_factor=1):
+    '''Converts an image from cartesian to polar coordinates around center'''
+
+    # Upsample image
+    if zoom_factor != 1:
+        image = zoom(image, (zoom_factor, zoom_factor), order=4)
+        center = (center[0]*zoom_factor + zoom_factor/2, center[1]*zoom_factor + zoom_factor/2)
+        min_radius = min_radius * zoom_factor
+        max_radius = max_radius * zoom_factor
+    
+    # pad if necessary
+    max_x, max_y = image.shape[0], image.shape[1]
+    pad_dist_x = np.max([(center[0] + max_radius) - max_x, -(center[0] - max_radius)])
+    pad_dist_y = np.max([(center[1] + max_radius) - max_y, -(center[1] - max_radius)])
+    pad_dist = int(np.max([0, pad_dist_x, pad_dist_y]))
+    if pad_dist != 0:
+        image = np.pad(image, pad_dist, 'constant')
+
+    # coordinate conversion
+    theta, r = np.meshgrid(np.linspace(0, 2*np.pi, phase_width),
+                           np.arange(min_radius, max_radius))
+    x, y = coord_polar_to_cart(r, theta, center)
+    x, y = np.round(x), np.round(y)
+    x, y = x.astype(int), y.astype(int)
+    
+    polar = image[x, y]
+    polar.reshape((max_radius - min_radius, phase_width))
+
+    return polar
+
+
+def mask_polar_to_cart(mask, center, min_radius, max_radius, output_shape, zoom_factor=1):
+    '''Converts a polar binary mask to Cartesian and places in an image of zeros'''
+
+    # Account for upsampling 
+    if zoom_factor != 1:
+        center = (center[0]*zoom_factor + zoom_factor/2, center[1]*zoom_factor + zoom_factor/2)
+        min_radius = min_radius * zoom_factor
+        max_radius = max_radius * zoom_factor
+        output_shape = map(lambda a: a * zoom_factor, output_shape)
+
+    # new image
+    image = np.zeros(output_shape)
+
+    # coordinate conversion
+    theta, r = np.meshgrid(np.linspace(0, 2*np.pi, mask.shape[1]),
+                           np.arange(0, max_radius))
+    x, y = coord_polar_to_cart(r, theta, center)
+    x, y = np.round(x), np.round(y)
+    x, y = x.astype(int), y.astype(int)
+
+    x = np.clip(x, 0, image.shape[0]-1)
+    y = np.clip(y, 0, image.shape[1]-1)
+    ix,iy = np.meshgrid(np.arange(0,mask.shape[1]), np.arange(0,mask.shape[0]))    
+    image[x,y] = mask
+
+    # downsample image
+    if zoom_factor != 1:
+        zf = 1/float(zoom_factor)
+        image = zoom(image, (zf, zf), order=4)
+
+    # ensure image remains a filled binary mask
+    image = (image > 0.5).astype(int)
+    image = binary_fill_holes(image)
+    return image
+    
+
+def find_edge_2d(polar, min_radius):
+    '''Dynamic programming algorithm to find edge given polar image'''
+    if len(polar.shape) != 2:
+        raise ValueError("argument to find_edge_2d must be 2D")
+
+    # Dynamic programming phase
+    values_right_shift = np.pad(polar, ((0, 0), (0, 1)), 'constant', constant_values=0)[:, 1:]
+    values_closeright_shift = np.pad(polar, ((1, 0),(0, 1)), 'constant', constant_values=0)[:-1, 1:]
+    values_awayright_shift = np.pad(polar, ((0, 1), (0, 1)), 'constant', constant_values=0)[1:, 1:]
+
+    values_move = np.zeros((polar.shape[0], polar.shape[1], 3))
+    values_move[:, :, 2] = np.add(polar, values_closeright_shift)  # closeright
+    values_move[:, :, 1] = np.add(polar, values_right_shift)  # right
+    values_move[:, :, 0] = np.add(polar, values_awayright_shift)  # awayright
+    values = np.amax(values_move, axis=2)
+
+    directions = np.argmax(values_move, axis=2)
+    directions = np.subtract(directions, 1)
+    directions = np.negative(directions)
+        
+    # Edge following phase
+    edge = []
+    mask = np.zeros(polar.shape)
+    r_max, r = 0, 0
+    for i,v in enumerate(values[:,0]):
+        if v >= r_max:
+            r, r_max = i, v
+    edge.append((r+min_radius, 0))
+    mask[0:r+1, 0] = 1
+    for t in range(1,polar.shape[1]):
+        r += directions[r, t-1]
+        if r >= directions.shape[0]: r = directions.shape[0]-1
+        if r < 0: r = 0
+        edge.append((r+min_radius, t))
+        mask[0:r+1, t] = 1
+
+    # add to inside of mask accounting for min_radius
+    new_mask = np.ones((min_radius+mask.shape[0], mask.shape[1]))
+    new_mask[min_radius:, :] = mask
+    
+    return np.array(edge), new_mask
+
+
+def edge_polar_to_cart(edge, center):
+    '''Converts a list of polar edge points to a list of cartesian edge points'''
+    cart_edge = [] 
+    for (r,t) in edge:
+        x, y = coord_polar_to_cart(r, t, center)
+        cart_edge.append((round(x), round(y)))
+    return cart_edge
+
+
+def cell_magic_wand_single_point(image, center, min_radius, max_radius,
+                                 roughness=2, zoom_factor=1):
+    '''Draws a border within a specified radius around a specified center "seed" point
+    using a polar transform and a dynamic programming edge-following algorithm.
+
+    Returns a binary mask with 1s inside the detected edge and
+    a list of points along the detected edge.'''
+    if roughness < 1:
+        roughness = 1
+        print("roughness must be >= 1, setting roughness to 1")
+    if min_radius < 0:
+        min_radius = 0
+        print("min_radius must be >=0, setting min_radius to 0")
+    if max_radius <= min_radius:
+        max_radius = min_radius + 1
+        print("max_radius must be larger than min_radius, setting max_radius to " + str(max_radius))
+    if zoom_factor <= 0:
+        zoom_factor = 1
+        print("negative zoom_factor not allowed, setting zoom_factor to 1")
+    phase_width = int(2 * np.pi * max_radius * roughness)
+    polar_image = image_cart_to_polar(image, center, min_radius, max_radius,
+                                      phase_width=phase_width, zoom_factor=zoom_factor)
+    polar_edge, polar_mask = find_edge_2d(polar_image, min_radius)
+    cart_edge = edge_polar_to_cart(polar_edge, center)
+    cart_mask = mask_polar_to_cart(polar_mask, center, min_radius, max_radius,
+                                   image.shape, zoom_factor=zoom_factor)
+    return cart_mask, cart_edge
+
+
+def cell_magic_wand(image, center, min_radius, max_radius,
+                    roughness=2, zoom_factor=1, center_range=2):
+    '''Runs the cell magic wand tool on multiple points near the supplied center and 
+    combines the results for a more robust edge detection then provided by the vanilla wand tool.
+
+    Returns a binary mask with 1s inside detected edge'''
+    
+    centers = []
+    for i in [-center_range, 0, center_range]:
+        for j in [-center_range, 0, center_range]:
+            centers.append((center[0]+i, center[1]+j))
+    masks = np.zeros((image.shape[0], image.shape[1], len(centers)))
+    for i, c in enumerate(centers):
+        mask, edge = cell_magic_wand_single_point(image, c, min_radius, max_radius,
+                                                  roughness=roughness, zoom_factor=zoom_factor)
+        masks[:,:,i] = mask
+    mean_mask = np.mean(masks, axis=2)
+    final_mask = (mean_mask > 0.5).astype(int)
+    return final_mask
+
+
+def cell_magic_wand_3d(image_3d, center, min_radius, max_radius,
+                       roughness=2, zoom_factor=1, center_range=2, z_step=1):
+    '''Robust cell magic wand tool for 3D images with dimensions (z, x, y) - default for tifffile.load.
+    This functions runs the robust wand tool on each z slice in the image and returns the mean mask
+    thresholded to 0.5'''
+    masks = np.zeros((int(image_3d.shape[0]/z_step), image_3d.shape[1], image_3d.shape[2]))
+    for s in range(int(image_3d.shape[0]/z_step)):
+        mask = cell_magic_wand(image_3d[s*z_step,:,:], center, min_radius, max_radius,
+                               roughness=roughness, zoom_factor=zoom_factor,
+                               center_range=center_range)
+        masks[s,:,:] = mask
+    mean_mask = np.mean(masks, axis=0)
+    final_mask = (mean_mask > 0.5).astype(int)
+    return final_mask