--- /dev/null
+++ b/pathflowai/visualize.py
@@ -0,0 +1,651 @@
+"""
+visualize.py
+=======================
+Plots SHAP outputs, UMAP embeddings, and overlays predictions on top of WSI.
+"""
+
+import plotly.graph_objs as go
+import plotly.offline as py
+import pandas as pd, numpy as np
+import networkx as nx
+import dask.array as da
+from PIL import Image
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import seaborn as sns
+import sqlite3
+from os.path import join
+from pathflowai.utils import npy2da
+sns.set()
+
+class PlotlyPlot:
+	"""Creates plotly html plots."""
+	def __init__(self):
+		self.plots=[]
+
+	def add_plot(self, t_data_df, G=None, color_col='color', name_col='name', xyz_cols=['x','y','z'], size=2, opacity=1.0, custom_colors=[]):
+		"""Adds plotting data to be plotted.
+
+		Parameters
+		----------
+		t_data_df:dataframe
+			3-D transformed dataframe.
+		G:nx.Graph
+			Networkx graph.
+		color_col:str
+			Column to use to color points.
+		name_col:str
+			Column to use to name points.
+		xyz_cols:list
+			3 columns that denote x,y,z coords.
+		size:int
+			Marker size.
+		opacity:float
+			Marker opacity.
+		custom_colors:list
+			Custom colors to supply.
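+
+		Examples
+		--------
+		A minimal sketch (the dataframe is hypothetical; any frame with x/y/z and a color column works):
+
+		>>> df = pd.DataFrame({'x':[0.,1.],'y':[0.,1.],'z':[0.,1.],'color':['a','b']})
+		>>> pp = PlotlyPlot()
+		>>> pp.add_plot(df, color_col='color')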
+		"""
+		plots = []
+		x,y,z=tuple(xyz_cols)
+		if t_data_df[color_col].dtype == np.float64:
+			plots.append(
+				go.Scatter3d(x=t_data_df[x], y=t_data_df[y],
+							 z=t_data_df[z],
+							 name='', mode='markers',
+							 marker=dict(color=t_data_df[color_col], size=size, opacity=opacity, colorscale='Viridis',
+							 colorbar=dict(title='Colorbar')), text=t_data_df[color_col] if name_col not in list(t_data_df) else t_data_df[name_col]))
+		else:
+			colors = t_data_df[color_col].unique()
+			c = sns.color_palette('hls', len(colors))
+			c = np.array(['rgb({})'.format(','.join(((np.array(c_i)*255).astype(int).astype(str).tolist()))) for c_i in c])
+			if custom_colors:
+				c = custom_colors
+			color_dict = {name: c[i] for i,name in enumerate(sorted(colors))}
+
+			for name,col in color_dict.items():
+				plots.append(
+					go.Scatter3d(x=t_data_df[x][t_data_df[color_col]==name], y=t_data_df[y][t_data_df[color_col]==name],
+								 z=t_data_df[z][t_data_df[color_col]==name],
+								 name=str(name), mode='markers',
+								 marker=dict(color=col, size=size, opacity=opacity), text=t_data_df.index[t_data_df[color_col]==name] if name_col not in list(t_data_df) else t_data_df[name_col][t_data_df[color_col]==name]))
+		if G is not None:
+			Xed, Yed, Zed = [], [], []
+			for edge in G.edges():
+				if edge[0] in t_data_df.index.values and edge[1] in t_data_df.index.values:
+					Xed += [t_data_df.loc[edge[0],x], t_data_df.loc[edge[1],x], None]
+					Yed += [t_data_df.loc[edge[0],y], t_data_df.loc[edge[1],y], None]
+					Zed += [t_data_df.loc[edge[0],z], t_data_df.loc[edge[1],z], None]
+			plots.append(go.Scatter3d(x=Xed,
+					  y=Yed,
+					  z=Zed,
+					  mode='lines',
+					  line=go.scatter3d.Line(color='rgb(210,210,210)', width=2),
+					  hoverinfo='none'
+					  ))
+		self.plots.extend(plots)
+
+	def plot(self, output_fname, axes_off=False):
+		"""Plot embedding of patches to html file.
+
+		Parameters
+		----------
+		output_fname:str
+			Output html file.
+		axes_off:bool
+			Remove axes.
+
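+		Examples
+		--------
+		Continuing the add_plot sketch above (writes a standalone html file):
+
+		>>> pp.plot('embedding.html', axes_off=True)
+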
+		"""
+		if axes_off:
+			fig = go.Figure(data=self.plots,layout=go.Layout(scene=dict(xaxis=dict(title='',autorange=True,showgrid=False,zeroline=False,showline=False,ticks='',showticklabels=False),
+				yaxis=dict(title='',autorange=True,showgrid=False,zeroline=False,showline=False,ticks='',showticklabels=False),
+				zaxis=dict(title='',autorange=True,showgrid=False,zeroline=False,showline=False,ticks='',showticklabels=False))))
+		else:
+			fig = go.Figure(data=self.plots)
+		py.plot(fig, filename=output_fname, auto_open=False)
+
+def to_pil(arr):
+	"""Convert a numpy array to a PIL Image.
+
+	Parameters
+	----------
+	arr:array
+		Numpy array.
+
+	Returns
+	-------
+	Image
+		PIL Image.
+
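+	Examples
+	--------
+	A quick sanity check with a black 8x8 image:
+
+	>>> to_pil(np.zeros((8, 8, 3))).size
+	(8, 8)
+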
+	"""
+	return Image.fromarray(arr.astype('uint8'), 'RGB')
+
+def blend(arr1, arr2, alpha=0.5):
+	"""Blend 2 arrays together, mixing with alpha.
+
+	Parameters
+	----------
+	arr1:array
+		Image 1.
+	arr2:array
+		Image 2.
+	alpha:float
+		Higher alpha makes image more like image 1.
+
+	Returns
+	-------
+	array
+		Resulting image.
+
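+	Examples
+	--------
+	Blending white and black patches with equal weight (illustrative values):
+
+	>>> float(blend(np.full((4, 4, 3), 255.), np.zeros((4, 4, 3)), alpha=0.5)[0, 0, 0])
+	127.5
+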
+	"""
+	return alpha*arr1 + (1.-alpha)*arr2
+
+def prob2rbg(prob, palette, arr):
+	"""Convert probability score to rgb image.
+
+	Parameters
+	----------
+	prob:float
+		Probability score between 0 and 1.
+	palette:palette
+		Palette that maps probability to color.
+	arr:array
+		Original array.
+
+	Returns
+	-------
+	array
+		New image colored by prediction score.
+
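+	Examples
+	--------
+	A sketch using a matplotlib colormap as the palette (any continuous colormap works):
+
+	>>> cmap = plt.get_cmap('viridis')
+	>>> prob2rbg(0.8, cmap, np.zeros((2, 2, 3))).shape
+	(2, 2, 3)
+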
+	"""
+	col = palette(prob)
+	for i in range(3):
+		arr[...,i] = int(col[i]*255)
+	return arr
+
+def seg2rgb(seg, palette, n_segmentation_classes):
+	"""Color each pixel by segmentation class.
+
+	Parameters
+	----------
+	seg:array
+		Segmentation mask.
+	palette:palette
+		Color to RGB map.
+	n_segmentation_classes:int
+		Total number of segmentation classes.
+
+	Returns
+	-------
+	array
+		Returned segmentation image.
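+
+	Examples
+	--------
+	A sketch with a toy 2x2 mask and 4 classes:
+
+	>>> cmap = plt.get_cmap('viridis')
+	>>> seg2rgb(np.array([[0, 1], [2, 3]]), cmap, 4).shape
+	(2, 2, 3)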
+	"""
+	img=(palette(seg/n_segmentation_classes)[...,:3]*255).astype(int)
+	return img
+
+def annotation2rgb(i,palette,arr):
+	"""Go from annotation of patch to color.
+
+	Parameters
+	----------
+	i:int
+		Annotation index.
+	palette:palette
+		Index to color mapping.
+	arr:array
+		Image array.
+
+	Returns
+	-------
+	array
+		Resulting image.
+
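+	Examples
+	--------
+	Illustrative, with a seaborn palette indexed by annotation class:
+
+	>>> pal = sns.color_palette(n_colors=3)
+	>>> annotation2rgb(1, pal, np.zeros((2, 2, 3))).shape
+	(2, 2, 3)
+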
+	"""
+	col = palette[i]
+	for channel in range(3):
+		arr[...,channel] = int(col[channel]*255)
+	return arr
+
+def plot_image_(image_file, compression_factor=2., test_image_name='test.png'):
+	"""Plots entire SVS/other image.
+
+	Parameters
+	----------
+	image_file:str
+		Image file.
+	compression_factor:float
+		Amount to shrink each dimension of image.
+	test_image_name:str
+		Output image file.
+
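+	Examples
+	--------
+	A minimal sketch (the slide path is hypothetical):
+
+	>>> plot_image_('slide.svs', compression_factor=4., test_image_name='slide.png')
+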
+	"""
+	from pathflowai.utils import svs2dask_array, npy2da
+	import cv2
+	if image_file.endswith('.zarr'):
+		arr=da.from_zarr(image_file)
+	else:
+		arr=svs2dask_array(image_file, tile_size=1000, overlap=0, remove_last=True, allow_unknown_chunksizes=False) if (not image_file.endswith('.npy')) else npy2da(image_file)
+	arr2=to_pil(cv2.resize(arr.compute(), dsize=tuple((np.array(arr.shape[:2])[::-1]/compression_factor).astype(int).tolist()), interpolation=cv2.INTER_CUBIC)) # cv2.resize expects dsize as (width, height)
+	arr2.save(test_image_name)
+
+# for now binary output
+class PredictionPlotter:
+	"""Plots predictions over entire image.
+
+	Parameters
+	----------
+	dask_arr_dict:dict
+		Stores all dask arrays corresponding to all of the images.
+	patch_info_db:str
+		Patch level information, e.g. prediction.
+	compression_factor:float
+		How much to compress image by.
+	alpha:float
+		Low value assigns higher weight to prediction over original image.
+	patch_size:int
+		Patch size.
+	no_db:bool
+		Don't use patch information.
+	plot_annotation:bool
+		Plot annotations from patch information.
+	segmentation:bool
+		Plot segmentation mask.
+	n_segmentation_classes:int
+		Number of segmentation classes.
+	input_dir:str
+		Input directory.
+	annotation_col:str
+		Annotation column to plot.
+	scaling_factor:float
+		Multiplier applied to prediction scores so they appear darker in the overlay.
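+
+	Examples
+	--------
+	A minimal sketch (the dask array and database path are hypothetical):
+
+	>>> arrs = {'slide1': da.zeros((10000, 10000, 3), dtype='uint8')}
+	>>> plotter = PredictionPlotter(arrs, 'patch_info.db', compression_factor=8., patch_size=224)
+	>>> img = plotter.generate_image('slide1')
+	>>> plotter.output_image(img, 'slide1_predictions.png')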
+	"""
+	# TODO: some patches may have been filtered out, so patches and predictions are not one-to-one
+	def __init__(self, dask_arr_dict, patch_info_db, compression_factor=3, alpha=0.5, patch_size=224, no_db=False, plot_annotation=False, segmentation=False, n_segmentation_classes=4, input_dir='', annotation_col='annotation', scaling_factor=1.):
+
+		self.segmentation = segmentation
+		self.scaling_factor=scaling_factor
+		self.segmentation_maps = None
+		self.n_segmentation_classes=float(n_segmentation_classes)
+		self.pred_palette = sns.cubehelix_palette(start=0,as_cmap=True)
+		if not no_db:
+			self.compression_factor=compression_factor
+			self.alpha = alpha
+			self.patch_size = patch_size
+			conn = sqlite3.connect(patch_info_db)
+			patch_info=pd.read_sql('select * from "{}";'.format(patch_size),con=conn)
+			conn.close()
+			self.annotations = {str(a):i for i,a in enumerate(patch_info['annotation'].unique().tolist())}
+			self.plot_annotation=plot_annotation
+			self.palette=sns.color_palette(n_colors=len(list(self.annotations.keys())))
+			if 'y_pred' not in patch_info.columns:
+				patch_info['y_pred'] = 0.
+			self.patch_info=patch_info[['ID','x','y','patch_size','annotation',annotation_col]]
+			self.patch_info = self.patch_info[np.isin(self.patch_info['ID'],np.array(list(dask_arr_dict.keys())))]
+		if self.segmentation:
+			self.segmentation_maps = {slide:npy2da(join(input_dir,'{}_mask.npy'.format(slide))) for slide in dask_arr_dict.keys()}
+		self.dask_arr_dict = {k:v[...,:3] for k,v in dask_arr_dict.items()}
+
+	def add_custom_segmentation(self, basename, npy):
+		"""Replace segmentation mask with new custom segmentation.
+
+		Parameters
+		----------
+		basename:str
+			Patient ID.
+		npy:str
+			Path to numpy mask file.
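+
+		Examples
+		--------
+		Illustrative call (paths are hypothetical):
+
+		>>> plotter.add_custom_segmentation('slide1', 'slide1_mask.npy')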
+		"""
+		self.segmentation_maps[basename] = da.from_array(np.load(npy,mmap_mode='r+'))
+
+	def generate_image(self, ID):
+		"""Generate the image array for the whole slide image with predictions overlaid.
+
+		Parameters
+		----------
+		ID:str
+			patient ID.
+
+		Returns
+		-------
+		array
+			Resulting overlaid whole slide image.
+
+		"""
+		patch_info = self.patch_info[self.patch_info['ID']==ID]
+		dask_arr = self.dask_arr_dict[ID]
+		arr_shape = np.array(dask_arr.shape).astype(float)
+
+		arr_shape[:2]/=self.compression_factor
+
+		arr_shape=arr_shape.astype(int).tolist()
+
+		img = Image.new('RGB',arr_shape[:2],'white')
+
+		for i in range(patch_info.shape[0]):
+			ID,x,y,patch_size,annotation,pred = patch_info.iloc[i].tolist()
+			x_new,y_new = int(x/self.compression_factor),int(y/self.compression_factor)
+			image = np.zeros((patch_size,patch_size,3))
+			if self.segmentation:
+				image=seg2rgb(self.segmentation_maps[ID][x:x+patch_size,y:y+patch_size].compute(),self.pred_palette, self.n_segmentation_classes)
+			else:
+				image=prob2rbg(pred*self.scaling_factor, self.pred_palette, image) if not self.plot_annotation else annotation2rgb(self.annotations[str(pred)],self.palette,image) # annotation
+			arr=dask_arr[x:x+patch_size,y:y+patch_size].compute()
+			blended_patch=blend(arr,image, self.alpha).transpose((1,0,2))
+			blended_patch_pil = to_pil(blended_patch)
+			patch_size/=self.compression_factor
+			patch_size=int(patch_size)
+			blended_patch_pil=blended_patch_pil.resize((patch_size,patch_size))
+			img.paste(blended_patch_pil, box=(x_new,y_new), mask=None)
+		return img
+
+	def return_patch(self, ID, x, y, patch_size):
+		"""Return one single patch instead of entire image.
+
+		Parameters
+		----------
+		ID:str
+			Patient ID.
+		x:int
+			X coordinate.
+		y:int
+			Y coordinate.
+		patch_size:int
+			Patch size.
+
+		Returns
+		-------
+		array
+			Image.
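+
+		Examples
+		--------
+		Illustrative call (ID and coordinates are hypothetical):
+
+		>>> patch = plotter.return_patch('slide1', 0, 0, 224)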
+		"""
+		img=(self.dask_arr_dict[ID][x:x+patch_size,y:y+patch_size].compute() if not self.segmentation else seg2rgb(self.segmentation_maps[ID][x:x+patch_size,y:y+patch_size].compute(),self.pred_palette, self.n_segmentation_classes))
+		return to_pil(img)
+
+	def output_image(self, img, filename, tif=False):
+		"""Output calculated image to file.
+
+		Parameters
+		----------
+		img:array
+			Image.
+		filename:str
+			Output file name.
+		tif:bool
+			Store in TIF format?
+		"""
+		if tif:
+			from tifffile import imwrite
+			imwrite(filename, np.array(img), photometric='rgb')
+		else:
+			img.save(filename)
+
+def plot_shap(model, dataset_opts, transform_opts, batch_size, outputfilename, n_outputs=1, method='deep', local_smoothing=0.0, n_samples=20, pred_out='none'):
+	"""Plot shapley attributions overlaid on images for classification tasks.
+
+	Parameters
+	----------
+	model:nn.Module
+		Pytorch model.
+	dataset_opts:dict
+		Options used to configure the dataset.
+	transform_opts:dict
+		Options used to configure the image transforms.
+	batch_size:int
+		Batch size for training.
+	outputfilename:str
+		Output filename.
+	n_outputs:int
+		Number of top outputs.
+	method:str
+		Gradient or deep explainer.
+	local_smoothing:float
+		How much to smooth shapley map.
+	n_samples:int
+		Number of shapley samples to draw.
+	pred_out:str
+		Output transform applied to predictions before labeling images: 'sigmoid', 'softmax', or 'none'.
+
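+	Examples
+	--------
+	A minimal sketch (the model and option dicts are hypothetical):
+
+	>>> plot_shap(model, dataset_opts, transform_opts, batch_size=8,
+	...           outputfilename='shap_plot.png', method='gradient', pred_out='sigmoid')
+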
+	"""
+	import torch
+	from torch.nn import functional as F
+	import numpy as np
+	from torch.utils.data import DataLoader
+	import shap
+	from pathflowai.datasets import DynamicImageDataset
+	import matplotlib
+	from matplotlib import pyplot as plt
+	from pathflowai.sampler import ImbalancedDatasetSampler
+
+	out_transform=dict(sigmoid=torch.sigmoid,softmax=lambda x: F.softmax(x,dim=1),none=lambda x: x)
+	binary_threshold=dataset_opts.pop('binary_threshold')
+	num_targets=dataset_opts.pop('num_targets')
+
+	dataset = DynamicImageDataset(**dataset_opts)
+
+	if dataset_opts['classify_annotations']:
+		binarizer=dataset.binarize_annotations(num_targets=num_targets,binary_threshold=binary_threshold)
+		num_targets=len(dataset.targets)
+
+	dataloader_val = DataLoader(dataset,batch_size=batch_size, num_workers=10, shuffle=True if num_targets>1 else False, sampler=ImbalancedDatasetSampler(dataset) if num_targets==1 else None)
+
+	background,y_background=next(iter(dataloader_val))
+	if method=='gradient':
+		background=torch.cat([background,next(iter(dataloader_val))[0]],0)
+	X_test,y_test=next(iter(dataloader_val))
+
+	if torch.cuda.is_available():
+		background=background.cuda()
+		X_test=X_test.cuda()
+
+	if pred_out!='none':
+		model2=model.cuda() if torch.cuda.is_available() else model
+		y_test=out_transform[pred_out](model2(X_test)).detach().cpu()
+
+	y_test=y_test.numpy()
+
+	if method=='deep':
+		e = shap.DeepExplainer(model, background)
+		s=e.shap_values(X_test, ranked_outputs=n_outputs)
+	elif method=='gradient':
+		e = shap.GradientExplainer(model, background, batch_size=batch_size, local_smoothing=local_smoothing)
+		s=e.shap_values(X_test, ranked_outputs=n_outputs, nsamples=n_samples)
+
+	if len(y_test.shape)>1 and y_test.shape[1]>1:
+		y_test=y_test.argmax(axis=1)
+
+	if n_outputs>1:
+		shap_values, idx = s
+	else:
+		shap_values, idx = s, y_test
+
+	if num_targets == 1:
+		shap_numpy = [np.swapaxes(np.swapaxes(shap_values, 1, -1), 1, 2)]
+	else:
+		shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
+	X_test_numpy=X_test.detach().cpu().numpy()
+	X_test_numpy=X_test_numpy.transpose((0,2,3,1))
+	for i in range(X_test_numpy.shape[0]):
+		X_test_numpy[i,...]*=np.array(transform_opts['std'])
+		X_test_numpy[i,...]+=np.array(transform_opts['mean'])
+	X_test_numpy=X_test_numpy.transpose((0,3,1,2))
+	test_numpy = np.swapaxes(np.swapaxes(X_test_numpy, 1, -1), 1, 2)
+	if pred_out!='none':
+		labels=y_test.astype(str)
+	else:
+		labels = np.array([[(dataloader_val.dataset.targets[i[j]] if num_targets>1 else str(i)) for j in range(n_outputs)] for i in idx])
+
+	plt.figure()
+	shap.image_plot(shap_numpy, test_numpy, labels)
+	plt.savefig(outputfilename, dpi=300)
+
+def plot_umap_images(dask_arr_dict, embeddings_file, ID=None, cval=1., image_res=300., outputfname='output_embedding.png', mpl_scatter=True, remove_background_annotation='', max_background_area=0.01, zoom=0.05, n_neighbors=10, sort_col='', sort_mode='asc'):
+	"""Make UMAP embedding plot, overlaid with images.
+
+	Parameters
+	----------
+	dask_arr_dict:dict
+		Stored dask arrays for each WSI.
+	embeddings_file:str
+		Embeddings pickle file stored after training the model.
+	ID:str
+		Patient ID.
+	cval:float
+		Fill value for empty canvas regions in the non-matplotlib plot.
+	image_res:float
+		Image resolution.
+	outputfname:str
+		Output image file.
+	mpl_scatter:bool
+		Recommended: Use matplotlib for scatter plot.
+	remove_background_annotation:str
+		Name of the background annotation column; patches with too much background area are removed.
+	max_background_area:float
+		Maximum background area in each tile for inclusion.
+	zoom:float
+		How much to zoom in on each patch, less than 1 is zoom out.
+	n_neighbors:int
+		Number of neighbors for UMAP embedding.
+	sort_col:str
+		Patch info column to sort on.
+	sort_mode:str
+		Sort ascending or descending.
+
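+	Examples
+	--------
+	A minimal sketch (the embeddings file and dask arrays are hypothetical):
+
+	>>> plot_umap_images(arrs, 'embeddings.pkl', remove_background_annotation='BACKGROUND',
+	...                  max_background_area=0.3, zoom=0.1, n_neighbors=8)
+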
+	Inspired by: https://gist.github.com/lukemetz/be6123c7ee3b366e333a
+	WIP!! Needs testing."""
+	import torch
+	import dask
+	from dask.distributed import Client
+	from umap import UMAP
+	from pathflowai.visualize import PlotlyPlot
+	import pandas as pd, numpy as np
+	import skimage.io
+	from skimage.transform import resize
+	import matplotlib
+	matplotlib.use('Agg')
+	from matplotlib import pyplot as plt
+	sns.set(style='white')
+
+	def min_resize(img, size):
+		"""
+		Resize an image so that it is size along the minimum spatial dimension.
+		"""
+		w, h = map(float, img.shape[:2])
+		if min([w, h]) != size:
+			if w <= h:
+				img = resize(img, (int(size), int(round((h/w)*size))))
+			else:
+				img = resize(img, (int(round((w/h)*size)), int(size)))
+		return img
+
+	embeddings_dict=torch.load(embeddings_file)
+	embeddings=embeddings_dict['embeddings']
+	patch_info=embeddings_dict['patch_info']
+	if sort_col:
+		idx=np.argsort(patch_info[sort_col].values)
+		if sort_mode == 'desc':
+			idx=idx[::-1]
+		patch_info = patch_info.iloc[idx]
+		embeddings=embeddings.iloc[idx]
+	if ID:
+		removal_bool=(patch_info['ID']==ID).values
+		patch_info = patch_info.loc[removal_bool]
+		embeddings=embeddings.loc[removal_bool]
+	if remove_background_annotation:
+		removal_bool=(patch_info[remove_background_annotation]<=(1.-max_background_area)).values
+		patch_info=patch_info.loc[removal_bool]
+		embeddings=embeddings.loc[removal_bool]
+
+	umap=UMAP(n_components=2,n_neighbors=n_neighbors)
+	t_data=pd.DataFrame(umap.fit_transform(embeddings.iloc[:,:-1].values),columns=['x','y'],index=embeddings.index)
+
+	images=[]
+
+	for i in range(patch_info.shape[0]):
+		ID=patch_info.iloc[i]['ID']
+		x,y,patch_size=patch_info.iloc[i][['x','y','patch_size']].values.tolist()
+		arr=dask_arr_dict[ID][x:x+patch_size,y:y+patch_size]#.transpose((2,0,1))
+		images.append(arr)
+
+	c=Client()
+	images=dask.compute(images)[0]
+	c.close()
+
+	if mpl_scatter:
+		from matplotlib.offsetbox import OffsetImage, AnnotationBbox
+		def imscatter(x, y, ax, imageData, zoom):
+			images = []
+			for i in range(len(x)):
+				x0, y0 = x[i], y[i]
+				img = imageData[i]
+				image = OffsetImage(img, zoom=zoom)
+				ab = AnnotationBbox(image, (x0, y0), xycoords='data', frameon=False)
+				images.append(ax.add_artist(ab))
+
+			ax.update_datalim(np.column_stack([x, y]))
+			ax.autoscale()
+
+		fig, ax = plt.subplots()
+		imscatter(t_data['x'].values, t_data['y'].values, imageData=images, ax=ax, zoom=zoom)
+		sns.despine()
+		plt.savefig(outputfname,dpi=300)
+	else:
+		xx=t_data.iloc[:,0]
+		yy=t_data.iloc[:,1]
+
+		images = [min_resize(image, image_res) for image in images]
+		max_width = max([image.shape[0] for image in images])
+		max_height = max([image.shape[1] for image in images])
+
+		x_min, x_max = xx.min(), xx.max()
+		y_min, y_max = yy.min(), yy.max()
+		# Fix the ratios
+		sx = (x_max-x_min)
+		sy = (y_max-y_min)
+		if sx > sy:
+			res_x = int(sx/float(sy)*image_res)
+			res_y = int(image_res)
+		else:
+			res_x = int(image_res)
+			res_y = int(sy/float(sx)*image_res)
+
+		canvas = np.ones((res_x+max_width, res_y+max_height, 3))*cval
+		x_coords = np.linspace(x_min, x_max, res_x)
+		y_coords = np.linspace(y_min, y_max, res_y)
+		for x, y, image in zip(xx, yy, images):
+			w, h = image.shape[:2]
+			x_idx = np.argmin((x - x_coords)**2)
+			y_idx = np.argmin((y - y_coords)**2)
+			canvas[x_idx:x_idx+w, y_idx:y_idx+h] = image
+
+		skimage.io.imsave(outputfname, (canvas*255).astype(np.uint8))