--- a +++ b/cell_magic_wand.py @@ -0,0 +1,225 @@ +########################################################################### +# +# Python implementation of the ImageJ Cell Magic Wand plugin +# (http://www.maxplanckflorida.org/fitzpatricklab/software/cellMagicWand/) +# with modifications to reduce variability due to seed point selection +# and to support edge detection using all z slices of a 3D image +# +# Author: Noah Apthorpe (apthorpe@cs.princeton.edu) +# +# Description: Draws a border within a specified radius +# around a specified center "seed" point +# using a polar transform and a dynamic +# programming edge-following algorithm +# +# Usage: Import and call the cell_magic_wand() function +# or cell_magic_wand_3d () function with +# a source image, radius window, and location of center +# point. Other parameters set as optional arguments. +# Returns a binary mask with 1s inside the detected edge and +# a list of points along the detected edge. +# +########################################################################### + +import numpy as np +from scipy.ndimage.interpolation import zoom +from scipy.ndimage.morphology import binary_fill_holes + + +def coord_polar_to_cart(r, theta, center): + '''Converts polar coordinates around center to Cartesian''' + x = r * np.cos(theta) + center[0] + y = r * np.sin(theta) + center[1] + return x, y + + +def coord_cart_to_polar(x, y, center): + '''Converts Cartesian coordinates to polar''' + r = np.sqrt((x-center[0])**2 + (y-center[1])**2) + theta = np.arctan2((y-center[1]), (x-center[0])) + return r, theta + + +def image_cart_to_polar(image, center, min_radius, max_radius, phase_width, zoom_factor=1): + '''Converts an image from cartesian to polar coordinates around center''' + + # Upsample image + if zoom_factor != 1: + image = zoom(image, (zoom_factor, zoom_factor), order=4) + center = (center[0]*zoom_factor + zoom_factor/2, center[1]*zoom_factor + zoom_factor/2) + min_radius = min_radius * zoom_factor + max_radius = max_radius * zoom_factor + + # pad if necessary + max_x, max_y = image.shape[0], image.shape[1] + pad_dist_x = np.max([(center[0] + max_radius) - max_x, -(center[0] - max_radius)]) + pad_dist_y = np.max([(center[1] + max_radius) - max_y, -(center[1] - max_radius)]) + pad_dist = int(np.max([0, pad_dist_x, pad_dist_y])) + if pad_dist != 0: + image = np.pad(image, pad_dist, 'constant') + + # coordinate conversion + theta, r = np.meshgrid(np.linspace(0, 2*np.pi, phase_width), + np.arange(min_radius, max_radius)) + x, y = coord_polar_to_cart(r, theta, center) + x, y = np.round(x), np.round(y) + x, y = x.astype(int), y.astype(int) + + polar = image[x, y] + polar.reshape((max_radius - min_radius, phase_width)) + + return polar + + +def mask_polar_to_cart(mask, center, min_radius, max_radius, output_shape, zoom_factor=1): + '''Converts a polar binary mask to Cartesian and places in an image of zeros''' + + # Account for upsampling + if zoom_factor != 1: + center = (center[0]*zoom_factor + zoom_factor/2, center[1]*zoom_factor + zoom_factor/2) + min_radius = min_radius * zoom_factor + max_radius = max_radius * zoom_factor + output_shape = map(lambda a: a * zoom_factor, output_shape) + + # new image + image = np.zeros(output_shape) + + # coordinate conversion + theta, r = np.meshgrid(np.linspace(0, 2*np.pi, mask.shape[1]), + np.arange(0, max_radius)) + x, y = coord_polar_to_cart(r, theta, center) + x, y = np.round(x), np.round(y) + x, y = x.astype(int), y.astype(int) + + x = np.clip(x, 0, image.shape[0]-1) + y = np.clip(y, 0, image.shape[1]-1) + ix,iy = np.meshgrid(np.arange(0,mask.shape[1]), np.arange(0,mask.shape[0])) + image[x,y] = mask + + # downsample image + if zoom_factor != 1: + zf = 1/float(zoom_factor) + image = zoom(image, (zf, zf), order=4) + + # ensure image remains a filled binary mask + image = (image > 0.5).astype(int) + image = binary_fill_holes(image) + return image + + +def find_edge_2d(polar, min_radius): + '''Dynamic programming algorithm to find edge given polar image''' + if len(polar.shape) != 2: + raise ValueError("argument to find_edge_2d must be 2D") + + # Dynamic programming phase + values_right_shift = np.pad(polar, ((0, 0), (0, 1)), 'constant', constant_values=0)[:, 1:] + values_closeright_shift = np.pad(polar, ((1, 0),(0, 1)), 'constant', constant_values=0)[:-1, 1:] + values_awayright_shift = np.pad(polar, ((0, 1), (0, 1)), 'constant', constant_values=0)[1:, 1:] + + values_move = np.zeros((polar.shape[0], polar.shape[1], 3)) + values_move[:, :, 2] = np.add(polar, values_closeright_shift) # closeright + values_move[:, :, 1] = np.add(polar, values_right_shift) # right + values_move[:, :, 0] = np.add(polar, values_awayright_shift) # awayright + values = np.amax(values_move, axis=2) + + directions = np.argmax(values_move, axis=2) + directions = np.subtract(directions, 1) + directions = np.negative(directions) + + # Edge following phase + edge = [] + mask = np.zeros(polar.shape) + r_max, r = 0, 0 + for i,v in enumerate(values[:,0]): + if v >= r_max: + r, r_max = i, v + edge.append((r+min_radius, 0)) + mask[0:r+1, 0] = 1 + for t in range(1,polar.shape[1]): + r += directions[r, t-1] + if r >= directions.shape[0]: r = directions.shape[0]-1 + if r < 0: r = 0 + edge.append((r+min_radius, t)) + mask[0:r+1, t] = 1 + + # add to inside of mask accounting for min_radius + new_mask = np.ones((min_radius+mask.shape[0], mask.shape[1])) + new_mask[min_radius:, :] = mask + + return np.array(edge), new_mask + + +def edge_polar_to_cart(edge, center): + '''Converts a list of polar edge points to a list of cartesian edge points''' + cart_edge = [] + for (r,t) in edge: + x, y = coord_polar_to_cart(r, t, center) + cart_edge.append((round(x), round(y))) + return cart_edge + + +def cell_magic_wand_single_point(image, center, min_radius, max_radius, + roughness=2, zoom_factor=1): + '''Draws a border within a specified radius around a specified center "seed" point + using a polar transform and a dynamic programming edge-following algorithm. + + Returns a binary mask with 1s inside the detected edge and + a list of points along the detected edge.''' + if roughness < 1: + roughness = 1 + print("roughness must be >= 1, setting roughness to 1") + if min_radius < 0: + min_radius = 0 + print("min_radius must be >=0, setting min_radius to 0") + if max_radius <= min_radius: + max_radius = min_radius + 1 + print("max_radius must be larger than min_radius, setting max_radius to " + str(max_radius)) + if zoom_factor <= 0: + zoom_factor = 1 + print("negative zoom_factor not allowed, setting zoom_factor to 1") + phase_width = int(2 * np.pi * max_radius * roughness) + polar_image = image_cart_to_polar(image, center, min_radius, max_radius, + phase_width=phase_width, zoom_factor=zoom_factor) + polar_edge, polar_mask = find_edge_2d(polar_image, min_radius) + cart_edge = edge_polar_to_cart(polar_edge, center) + cart_mask = mask_polar_to_cart(polar_mask, center, min_radius, max_radius, + image.shape, zoom_factor=zoom_factor) + return cart_mask, cart_edge + + +def cell_magic_wand(image, center, min_radius, max_radius, + roughness=2, zoom_factor=1, center_range=2): + '''Runs the cell magic wand tool on multiple points near the supplied center and + combines the results for a more robust edge detection then provided by the vanilla wand tool. + + Returns a binary mask with 1s inside detected edge''' + + centers = [] + for i in [-center_range, 0, center_range]: + for j in [-center_range, 0, center_range]: + centers.append((center[0]+i, center[1]+j)) + masks = np.zeros((image.shape[0], image.shape[1], len(centers))) + for i, c in enumerate(centers): + mask, edge = cell_magic_wand_single_point(image, c, min_radius, max_radius, + roughness=roughness, zoom_factor=zoom_factor) + masks[:,:,i] = mask + mean_mask = np.mean(masks, axis=2) + final_mask = (mean_mask > 0.5).astype(int) + return final_mask + + +def cell_magic_wand_3d(image_3d, center, min_radius, max_radius, + roughness=2, zoom_factor=1, center_range=2, z_step=1): + '''Robust cell magic wand tool for 3D images with dimensions (z, x, y) - default for tifffile.load. + This functions runs the robust wand tool on each z slice in the image and returns the mean mask + thresholded to 0.5''' + masks = np.zeros((int(image_3d.shape[0]/z_step), image_3d.shape[1], image_3d.shape[2])) + for s in range(int(image_3d.shape[0]/z_step)): + mask = cell_magic_wand(image_3d[s*z_step,:,:], center, min_radius, max_radius, + roughness=roughness, zoom_factor=zoom_factor, + center_range=center_range) + masks[s,:,:] = mask + mean_mask = np.mean(masks, axis=0) + final_mask = (mean_mask > 0.5).astype(int) + return final_mask