Switch to side-by-side view

--- a
+++ b/pathflowai/cli_visualizations.py
@@ -0,0 +1,200 @@
+import click
+from pathflowai.visualize import PredictionPlotter, plot_image_
+import glob, os
+from utils import load_preprocessed_img
+import dask.array as da
+
+
+CONTEXT_SETTINGS = dict(help_option_names=['-h','--help'], max_content_width=90)
+
+@click.group(context_settings= CONTEXT_SETTINGS)
+@click.version_option(version='0.1')
+def visualize():
+	pass
+
+@visualize.command()
+@click.option('-i', '--input_dir', default='./inputs/', help='Input directory for patches.', type=click.Path(exists=False), show_default=True)
+@click.option('-b', '--basename', default='A01', help='Basename of patches.', type=click.Path(exists=False), show_default=True)
+@click.option('-p', '--patch_info_file', default='patch_info.db', help='Datbase containing all patches', type=click.Path(exists=False), show_default=True)
+@click.option('-ps', '--patch_size', default=224, help='Patch size.',  show_default=True)
+@click.option('-x', '--x', default=0, help='X Coordinate of patch.',  show_default=True)
+@click.option('-y', '--y', default=0, help='Y coordinate of patch.',  show_default=True)
+@click.option('-o', '--outputfname', default='./output_image.png', help='Output extracted image.', type=click.Path(exists=False), show_default=True)
+@click.option('-s', '--segmentation', is_flag=True, help='Plot segmentations.', show_default=True)
+@click.option('-sc', '--n_segmentation_classes', default=4, help='Number segmentation classes',  show_default=True)
+@click.option('-c', '--custom_segmentation', default='', help='Add custom segmentation map from prediction, in npy',  show_default=True)
+def extract_patch(input_dir, basename, patch_info_file, patch_size, x, y, outputfname, segmentation, n_segmentation_classes, custom_segmentation):
+	"""Extract image of patch of any size/location and output to image file"""
+	if glob.glob(os.path.join(input_dir,'*.zarr')):
+		dask_arr_dict = {os.path.basename(f).split('.zarr')[0]:da.from_zarr(f) for f in glob.glob(os.path.join(input_dir,'*.zarr')) if os.path.basename(f).split('.zarr')[0] == basename}
+	else:
+		dask_arr_dict = {basename:load_preprocessed_img(os.path.join(input_dir,'{}.npy'.format(basename)))}
+	pred_plotter = PredictionPlotter(dask_arr_dict, patch_info_file, compression_factor=3, alpha=0.5, patch_size=patch_size, no_db=True, segmentation=segmentation,n_segmentation_classes=n_segmentation_classes, input_dir=input_dir)
+	if custom_segmentation:
+		pred_plotter.add_custom_segmentation(basename,custom_segmentation)
+	img = pred_plotter.return_patch(basename, x, y, patch_size)
+	pred_plotter.output_image(img,outputfname)
+
+@visualize.command()
+@click.option('-i', '--image_file', default='./inputs/a.svs', help='Input image file.', type=click.Path(exists=False), show_default=True)
+@click.option('-cf', '--compression_factor', default=3., help='How much compress image.',  show_default=True)
+@click.option('-o', '--outputfname', default='./output_image.png', help='Output extracted image.', type=click.Path(exists=False), show_default=True)
+def plot_image(image_file, compression_factor, outputfname):
+	"""Plots the whole slide image supplied."""
+	plot_image_(image_file, compression_factor=compression_factor, test_image_name=outputfname)
+
+@visualize.command()
+@click.option('-i', '--mask_file', default='./inputs/a_mask.npy', help='Input mask file.', type=click.Path(exists=False), show_default=True)
+@click.option('-o', '--outputfname', default='./output_image.png', help='Output extracted image.', type=click.Path(exists=False), show_default=True)
+def plot_mask_mpl(mask_file, outputfname):
+	"""Plots the whole slide mask supplied."""
+	import matplotlib
+	matplotlib.use('Agg')
+	import matplotlib.pyplot as plt
+	import numpy as np
+	#plt.figure()
+	plt.imshow(np.load(mask_file))
+	plt.axis('off')
+	plt.savefig(outputfname,dpi=500)
+
+@visualize.command()
+@click.option('-i', '--input_dir', default='./inputs/', help='Input directory for patches.', type=click.Path(exists=False), show_default=True)
+@click.option('-b', '--basename', default='A01', help='Basename of patches.', type=click.Path(exists=False), show_default=True)
+@click.option('-p', '--patch_info_file', default='patch_info.db', help='Datbase containing all patches', type=click.Path(exists=False), show_default=True)
+@click.option('-ps', '--patch_size', default=224, help='Patch size.',  show_default=True)
+@click.option('-o', '--outputfname', default='./output_image.png', help='Output extracted image.', type=click.Path(exists=False), show_default=True)
+@click.option('-an', '--annotations', is_flag=True, help='Plot annotations instead of predictions.', show_default=True)
+@click.option('-cf', '--compression_factor', default=3., help='How much compress image.',  show_default=True)
+@click.option('-al', '--alpha', default=0.8, help='How much to give annotations/predictions versus original image.',  show_default=True)
+@click.option('-s', '--segmentation', is_flag=True, help='Plot segmentations.', show_default=True)
+@click.option('-sc', '--n_segmentation_classes', default=4, help='Number segmentation classes',  show_default=True)
+@click.option('-c', '--custom_segmentation', default='', help='Add custom segmentation map from prediction, npy format.',  show_default=True)
+@click.option('-ac', '--annotation_col', default='annotation', help='Column of annotations', type=click.Path(exists=False), show_default=True)
+@click.option('-sf', '--scaling_factor', default=1., help='Multiply all prediction scores by this amount.',  show_default=True)
+@click.option('-tif', '--tif_file', is_flag=True, help='Write to tiff file.',  show_default=True)
+def plot_predictions(input_dir,basename,patch_info_file,patch_size,outputfname,annotations, compression_factor, alpha, segmentation, n_segmentation_classes, custom_segmentation, annotation_col, scaling_factor, tif_file):
+	"""Overlays classification, regression and segmentation patch level predictions on top of whole slide image."""
+	if glob.glob(os.path.join(input_dir,'*.zarr')):
+		dask_arr_dict = {os.path.basename(f).split('.zarr')[0]:da.from_zarr(f) for f in glob.glob(os.path.join(input_dir,'*.zarr')) if os.path.basename(f).split('.zarr')[0] == basename}
+	else:
+		dask_arr_dict = {basename:load_preprocessed_img(os.path.join(input_dir,'{}.npy'.format(basename)))}
+	pred_plotter = PredictionPlotter(dask_arr_dict, patch_info_file, compression_factor=compression_factor, alpha=alpha, patch_size=patch_size, no_db=False, plot_annotation=annotations, segmentation=segmentation, n_segmentation_classes=n_segmentation_classes, input_dir=input_dir, annotation_col=annotation_col, scaling_factor=scaling_factor)
+	if custom_segmentation:
+		pred_plotter.add_custom_segmentation(basename,custom_segmentation)
+	img = pred_plotter.generate_image(basename)
+	pred_plotter.output_image(img, outputfname, tif_file)
+
+@visualize.command()
+@click.option('-i', '--img_file', default='image.txt', help='Input image.', type=click.Path(exists=False), show_default=True)
+@click.option('-a', '--annotation_txt', default='annotation.txt', help='Column of annotations', type=click.Path(exists=False), show_default=True)
+@click.option('-ocf', '--original_compression_factor', default=1., help='How much compress image.',  show_default=True)
+@click.option('-cf', '--compression_factor', default=3., help='How much compress image.',  show_default=True)
+@click.option('-o', '--outputfilename', default='./output_image.png', help='Output extracted image.', type=click.Path(exists=False), show_default=True)
+def overlay_new_annotations(img_file,annotation_txt, original_compression_factor,compression_factor, outputfilename):
+	"""Custom annotations, in format [Point: x, y, Point: x, y ... ] one line like this per polygon, overlap these polygons on top of WSI."""
+	#from shapely.ops import unary_union, polygonize
+	#from shapely.geometry import MultiPolygon, LineString, MultiPoint, box, Point
+	#from shapely.geometry.polygon import Polygon
+	print("Experimental, in development")
+	import matplotlib
+	matplotlib.use('Agg')
+	import matplotlib.pyplot as plt
+	import re, numpy as np
+	from PIL import Image
+	import cv2
+	from pathflowai.visualize import to_pil
+	from scipy.misc import imresize
+	im=plt.imread(img_file) if not img_file.endswith('.npy') else np.load(img_file,mmap_mode='r+')
+	print(im.shape)
+	if compression_factor>1 and original_compression_factor == 1.:
+		im=cv2.resize(im,dsize=(int(im.shape[1]/compression_factor),int(im.shape[0]/compression_factor)),interpolation=cv2.INTER_CUBIC)#im.resize((int(im.shape[0]/compression_factor),int(im.shape[1]/compression_factor)))
+	print(im.shape)
+	im=np.array(im)
+	im=im.transpose((1,0,2))##[::-1,...]#
+	plt.imshow(im)
+	with open(annotation_txt) as f:
+		polygons=[np.array([list(map(float,filter(None,coords.strip(' ').split(',')))) for coords in re.sub('\]|\[|\ ','',line).rstrip().split('Point:') if coords])/compression_factor for line in f]
+	for polygon in polygons:
+		plt.plot(polygon[:,0],polygon[:,1],color='blue')
+	plt.axis('off')
+	plt.savefig(outputfilename,dpi=500)
+
+@visualize.command()
+@click.option('-i', '--embeddings_file', default='predictions/embeddings.pkl', help='Embeddings.', type=click.Path(exists=False), show_default=True)
+@click.option('-o', '--plotly_output_file', default='predictions/embeddings.html', help='Plotly output file.', type=click.Path(exists=False), show_default=True)
+@click.option('-a', '--annotations', default=[], multiple=True, help='Multiple annotations to color image.', show_default=True)
+@click.option('-rb', '--remove_background_annotation', default='', help='If selected, removes 100\% background patches based on this annotation.', type=click.Path(exists=False), show_default=True)
+@click.option('-ma', '--max_background_area', default=0.05, help='Max background area before exclusion.',  show_default=True)
+@click.option('-b', '--basename', default='', help='Basename of patches.', type=click.Path(exists=False), show_default=True)
+@click.option('-nn', '--n_neighbors', default=8, help='Number nearest neighbors.',  show_default=True)
+def plot_embeddings(embeddings_file,plotly_output_file, annotations, remove_background_annotation , max_background_area, basename, n_neighbors):
+	"""Perform UMAP embeddings of patches and plot using plotly."""
+	import torch
+	from umap import UMAP
+	from pathflowai.visualize import PlotlyPlot
+	import pandas as pd, numpy as np
+	embeddings_dict=torch.load(embeddings_file)
+	embeddings=embeddings_dict['embeddings']
+	patch_info=embeddings_dict['patch_info']
+	if remove_background_annotation:
+		removal_bool=(patch_info[remove_background_annotation]<=(1.-max_background_area)).values
+		patch_info=patch_info.loc[removal_bool]
+		embeddings=embeddings.loc[removal_bool]
+	if basename:
+		removal_bool=(patch_info['ID']==basename).values
+		patch_info=patch_info.loc[removal_bool]
+		embeddings=embeddings.loc[removal_bool]
+	if annotations:
+		annotations=np.array(annotations)
+		if len(annotations)>1:
+			embeddings.loc[:,'ID']=np.vectorize(lambda i: annotations[np.argmax(patch_info.iloc[i][annotations].values)])(np.arange(embeddings.shape[0]))
+		else:
+			embeddings.loc[:,'ID']=patch_info[annotations].values
+	umap=UMAP(n_components=3,n_neighbors=n_neighbors)
+	t_data=pd.DataFrame(umap.fit_transform(embeddings.iloc[:,:-1].values),columns=['x','y','z'],index=embeddings.index)
+	t_data['color']=embeddings['ID'].values
+	t_data['name']=embeddings.index.values
+	pp=PlotlyPlot()
+	pp.add_plot(t_data,size=8)
+	pp.plot(plotly_output_file,axes_off=True)
+
+@visualize.command()
+@click.option('-m', '--model_pkl', default='', help='Plotly output file.', type=click.Path(exists=False), show_default=True)
+@click.option('-bs', '--batch_size', default=32, help='Batch size.',  show_default=True)
+@click.option('-o', '--outputfilename', default='predictions/shap_plots.png', help='SHAPley visualization.', type=click.Path(exists=False), show_default=True)
+@click.option('-mth', '--method', default='deep', help='Method of explaining.', type=click.Choice(['deep','gradient']), show_default=True)
+@click.option('-l', '--local_smoothing', default=0.0, help='Local smoothing of SHAP scores.',  show_default=True)
+@click.option('-ns', '--n_samples', default=32, help='Number shapley samples for shapley regression (gradient explainer).',  show_default=True)
+@click.option('-p', '--pred_out', default='none', help='If not none, output prediction as shap label.', type=click.Choice(['none','sigmoid','softmax']), show_default=True)
+def shapley_plot(model_pkl, batch_size, outputfilename, method='deep', local_smoothing=0.0, n_samples=20, pred_out='none'):
+	"""Run SHAPley attribution method on patches after classification task to see where model made prediction based on."""
+	from pathflowai.visualize import plot_shap
+	import torch
+	from pathflowai.datasets import get_data_transforms
+	model_dict=torch.load(model_pkl)
+	model_dict['dataset_opts']['transformers']=get_data_transforms(**model_dict['transform_opts'])
+	plot_shap(model_dict['model'], model_dict['dataset_opts'], model_dict['transform_opts'], batch_size, outputfilename, method=method, local_smoothing=local_smoothing, n_samples=n_samples, pred_out=pred_out)
+
+@visualize.command()
+@click.option('-i', '--input_dir', default='./inputs/', help='Input directory for patches.', type=click.Path(exists=False), show_default=True)
+@click.option('-e', '--embeddings_file', default='predictions/embeddings.pkl', help='Embeddings.', type=click.Path(exists=False), show_default=True)
+@click.option('-b', '--basename', default='', help='Basename of patches.', type=click.Path(exists=False), show_default=True)
+@click.option('-o', '--outputfilename', default='predictions/shap_plots.png', help='Embedding visualization.', type=click.Path(exists=False), show_default=True)
+@click.option('-mpl', '--mpl_scatter', is_flag=True, help='Plot segmentations.', show_default=True)
+@click.option('-rb', '--remove_background_annotation', default='', help='If selected, removes 100\% background patches based on this annotation.', type=click.Path(exists=False), show_default=True)
+@click.option('-ma', '--max_background_area', default=0.05, help='Max background area before exclusion.',  show_default=True)
+@click.option('-z', '--zoom', default=0.05, help='Size of images.',  show_default=True)
+@click.option('-nn', '--n_neighbors', default=8, help='Number nearest neighbors.',  show_default=True)
+@click.option('-sc', '--sort_col', default='', help='Sort samples on this column.', type=click.Path(exists=False), show_default=True)
+@click.option('-sm', '--sort_mode', default='asc', help='Sort ascending or descending.', type=click.Choice(['asc','desc']), show_default=True)
+def plot_image_umap_embeddings(input_dir,embeddings_file,basename,outputfilename,mpl_scatter, remove_background_annotation, max_background_area, zoom, n_neighbors, sort_col='', sort_mode='asc'):
+	"""Plots a UMAP embedding with each point as its corresponding patch image."""
+	from pathflowai.visualize import plot_umap_images
+	if glob.glob(os.path.join(input_dir,'*.zarr')):
+		dask_arr_dict = {os.path.basename(f).split('.zarr')[0]:da.from_zarr(f) for f in glob.glob(os.path.join(input_dir,'*.zarr')) if (not basename) or os.path.basename(f).split('.zarr')[0] == basename}
+	else:
+		dask_arr_dict = {basename:load_preprocessed_img(os.path.join(input_dir,'{}.npy'.format(basename))) for basename in ([basename] if basename else set(list(map(lambda x: os.path.basename(os.path.splitext(x)[0]),glob.glob(os.path.join(input_dir,"*.*"))))))}
+	plot_umap_images(dask_arr_dict, embeddings_file, ID=basename, cval=1., image_res=300., outputfname=outputfilename, mpl_scatter=mpl_scatter, remove_background_annotation=remove_background_annotation, max_background_area=max_background_area, zoom=zoom, n_neighbors=n_neighbors, sort_col=sort_col, sort_mode=sort_mode)
+
+if __name__ == '__main__':
+	visualize()