a b/pathflowai/cli_visualizations.py
1
import click
2
from pathflowai.visualize import PredictionPlotter, plot_image_
3
import glob, os
4
from utils import load_preprocessed_img
5
import dask.array as da
6
7
8
# Shared click context: accept both -h and --help, and wrap help output at 90 columns.
CONTEXT_SETTINGS = {
    'help_option_names': ['-h', '--help'],
    'max_content_width': 90,
}
9
10
# Root click group; every command below registers itself via @visualize.command().
# No docstring on purpose: click would print it as the group's help text.
@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(version='0.1')
def visualize():
    return None
14
15
# CLI command: pull a single patch out of a preprocessed slide and write it to an image file.
@visualize.command()
@click.option('-i', '--input_dir', default='./inputs/', show_default=True,
              help='Input directory for patches.', type=click.Path(exists=False))
@click.option('-b', '--basename', default='A01', show_default=True,
              help='Basename of patches.', type=click.Path(exists=False))
@click.option('-p', '--patch_info_file', default='patch_info.db', show_default=True,
              help='Datbase containing all patches', type=click.Path(exists=False))
@click.option('-ps', '--patch_size', default=224, show_default=True,
              help='Patch size.')
@click.option('-x', '--x', default=0, show_default=True,
              help='X Coordinate of patch.')
@click.option('-y', '--y', default=0, show_default=True,
              help='Y coordinate of patch.')
@click.option('-o', '--outputfname', default='./output_image.png', show_default=True,
              help='Output extracted image.', type=click.Path(exists=False))
@click.option('-s', '--segmentation', is_flag=True, show_default=True,
              help='Plot segmentations.')
@click.option('-sc', '--n_segmentation_classes', default=4, show_default=True,
              help='Number segmentation classes')
@click.option('-c', '--custom_segmentation', default='', show_default=True,
              help='Add custom segmentation map from prediction, in npy')
def extract_patch(input_dir, basename, patch_info_file, patch_size, x, y, outputfname, segmentation, n_segmentation_classes, custom_segmentation):
    """Extract image of patch of any size/location and output to image file"""
    # Prefer zarr stores when the input directory has any; otherwise fall back to <basename>.npy.
    zarr_paths = glob.glob(os.path.join(input_dir, '*.zarr'))
    if zarr_paths:
        dask_arr_dict = {}
        for zarr_path in zarr_paths:
            slide_id = os.path.basename(zarr_path).split('.zarr')[0]
            if slide_id == basename:
                dask_arr_dict[slide_id] = da.from_zarr(zarr_path)
    else:
        npy_path = os.path.join(input_dir, '{}.npy'.format(basename))
        dask_arr_dict = {basename: load_preprocessed_img(npy_path)}

    # Plotter is built with no_db=True and fixed compression/alpha for patch extraction.
    plotter_options = dict(
        compression_factor=3,
        alpha=0.5,
        patch_size=patch_size,
        no_db=True,
        segmentation=segmentation,
        n_segmentation_classes=n_segmentation_classes,
        input_dir=input_dir,
    )
    pred_plotter = PredictionPlotter(dask_arr_dict, patch_info_file, **plotter_options)
    if custom_segmentation:
        pred_plotter.add_custom_segmentation(basename, custom_segmentation)

    patch_img = pred_plotter.return_patch(basename, x, y, patch_size)
    pred_plotter.output_image(patch_img, outputfname)
37
38
# CLI command: thin wrapper that forwards its options to pathflowai.visualize.plot_image_.
@visualize.command()
@click.option('-i', '--image_file', default='./inputs/a.svs', show_default=True,
              help='Input image file.', type=click.Path(exists=False))
@click.option('-cf', '--compression_factor', default=3., show_default=True,
              help='How much compress image.')
@click.option('-o', '--outputfname', default='./output_image.png', show_default=True,
              help='Output extracted image.', type=click.Path(exists=False))
def plot_image(image_file, compression_factor, outputfname):
    """Plots the whole slide image supplied."""
    render_options = {
        'compression_factor': compression_factor,
        'test_image_name': outputfname,
    }
    plot_image_(image_file, **render_options)
45
46
# CLI command: render a mask stored as .npy to an image file using matplotlib.
@visualize.command()
@click.option('-i', '--mask_file', default='./inputs/a_mask.npy', show_default=True,
              help='Input mask file.', type=click.Path(exists=False))
@click.option('-o', '--outputfname', default='./output_image.png', show_default=True,
              help='Output extracted image.', type=click.Path(exists=False))
def plot_mask_mpl(mask_file, outputfname):
    """Plots the whole slide mask supplied."""
    import numpy as np
    import matplotlib
    # The Agg backend must be selected before pyplot is imported.
    matplotlib.use('Agg')
    from matplotlib import pyplot

    mask = np.load(mask_file)
    pyplot.imshow(mask)
    pyplot.axis('off')
    pyplot.savefig(outputfname, dpi=500)
59
60
# CLI command: build a PredictionPlotter for one slide and write the overlay image it generates.
@visualize.command()
@click.option('-i', '--input_dir', default='./inputs/', show_default=True,
              help='Input directory for patches.', type=click.Path(exists=False))
@click.option('-b', '--basename', default='A01', show_default=True,
              help='Basename of patches.', type=click.Path(exists=False))
@click.option('-p', '--patch_info_file', default='patch_info.db', show_default=True,
              help='Datbase containing all patches', type=click.Path(exists=False))
@click.option('-ps', '--patch_size', default=224, show_default=True,
              help='Patch size.')
@click.option('-o', '--outputfname', default='./output_image.png', show_default=True,
              help='Output extracted image.', type=click.Path(exists=False))
@click.option('-an', '--annotations', is_flag=True, show_default=True,
              help='Plot annotations instead of predictions.')
@click.option('-cf', '--compression_factor', default=3., show_default=True,
              help='How much compress image.')
@click.option('-al', '--alpha', default=0.8, show_default=True,
              help='How much to give annotations/predictions versus original image.')
@click.option('-s', '--segmentation', is_flag=True, show_default=True,
              help='Plot segmentations.')
@click.option('-sc', '--n_segmentation_classes', default=4, show_default=True,
              help='Number segmentation classes')
@click.option('-c', '--custom_segmentation', default='', show_default=True,
              help='Add custom segmentation map from prediction, npy format.')
@click.option('-ac', '--annotation_col', default='annotation', show_default=True,
              help='Column of annotations', type=click.Path(exists=False))
@click.option('-sf', '--scaling_factor', default=1., show_default=True,
              help='Multiply all prediction scores by this amount.')
@click.option('-tif', '--tif_file', is_flag=True, show_default=True,
              help='Write to tiff file.')
def plot_predictions(input_dir,basename,patch_info_file,patch_size,outputfname,annotations, compression_factor, alpha, segmentation, n_segmentation_classes, custom_segmentation, annotation_col, scaling_factor, tif_file):
    """Overlays classification, regression and segmentation patch level predictions on top of whole slide image."""
    # Prefer zarr stores when the input directory has any; otherwise fall back to <basename>.npy.
    zarr_paths = glob.glob(os.path.join(input_dir, '*.zarr'))
    if zarr_paths:
        dask_arr_dict = {}
        for zarr_path in zarr_paths:
            slide_id = os.path.basename(zarr_path).split('.zarr')[0]
            if slide_id == basename:
                dask_arr_dict[slide_id] = da.from_zarr(zarr_path)
    else:
        npy_path = os.path.join(input_dir, '{}.npy'.format(basename))
        dask_arr_dict = {basename: load_preprocessed_img(npy_path)}

    plotter_options = dict(
        compression_factor=compression_factor,
        alpha=alpha,
        patch_size=patch_size,
        no_db=False,
        plot_annotation=annotations,
        segmentation=segmentation,
        n_segmentation_classes=n_segmentation_classes,
        input_dir=input_dir,
        annotation_col=annotation_col,
        scaling_factor=scaling_factor,
    )
    pred_plotter = PredictionPlotter(dask_arr_dict, patch_info_file, **plotter_options)
    if custom_segmentation:
        pred_plotter.add_custom_segmentation(basename, custom_segmentation)

    overlay_img = pred_plotter.generate_image(basename)
    pred_plotter.output_image(overlay_img, outputfname, tif_file)
86
87
# CLI command: draw polygon outlines read from a text file on top of an image and save the figure.
# The function docstring is kept verbatim because click prints it as the command's help text.
@visualize.command()
@click.option('-i', '--img_file', default='image.txt', help='Input image.', type=click.Path(exists=False), show_default=True)
@click.option('-a', '--annotation_txt', default='annotation.txt', help='Column of annotations', type=click.Path(exists=False), show_default=True)
@click.option('-ocf', '--original_compression_factor', default=1., help='How much compress image.',  show_default=True)
@click.option('-cf', '--compression_factor', default=3., help='How much compress image.',  show_default=True)
@click.option('-o', '--outputfilename', default='./output_image.png', help='Output extracted image.', type=click.Path(exists=False), show_default=True)
def overlay_new_annotations(img_file,annotation_txt, original_compression_factor,compression_factor, outputfilename):
    """Custom annotations, in format [Point: x, y, Point: x, y ... ] one line like this per polygon, overlap these polygons on top of WSI."""
    print("Experimental, in development")
    import re
    import matplotlib
    # The Agg backend must be selected before pyplot is imported.
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    import numpy as np
    import cv2
    # The previous `from scipy.misc import imresize` was unused and raises
    # ImportError on SciPy >= 1.3 (the function was removed), so it is dropped
    # together with the other unused imports (PIL.Image, to_pil).

    # .npy inputs are memory-mapped; anything else is read through matplotlib.
    if img_file.endswith('.npy'):
        im = np.load(img_file, mmap_mode='r+')
    else:
        im = plt.imread(img_file)
    print(im.shape)

    # Downscale only when the input is declared to be at original resolution.
    if compression_factor > 1 and original_compression_factor == 1.:
        # cv2 expects dsize as (width, height), hence shape[1] before shape[0].
        target_size = (int(im.shape[1] / compression_factor), int(im.shape[0] / compression_factor))
        im = cv2.resize(im, dsize=target_size, interpolation=cv2.INTER_CUBIC)
    print(im.shape)

    # Swap the first two axes before display (requires a 3-D array).
    # NOTE(review): presumably aligns the image with the annotation x/y convention - confirm.
    im = np.array(im).transpose((1, 0, 2))
    plt.imshow(im)

    # Each line describes one polygon: strip brackets and spaces, then split on 'Point:'.
    strip_pattern = re.compile(r'[\[\] ]')
    polygons = []
    with open(annotation_txt) as f:
        for line in f:
            cleaned = strip_pattern.sub('', line).rstrip()
            points = [
                [float(value) for value in chunk.split(',') if value]
                for chunk in cleaned.split('Point:')
                if chunk
            ]
            if not points:
                # Blank lines carry no polygon; indexing an empty array below would fail.
                continue
            # Coordinates are always divided by compression_factor, as in the original code.
            polygons.append(np.array(points) / compression_factor)

    for polygon in polygons:
        plt.plot(polygon[:, 0], polygon[:, 1], color='blue')
    plt.axis('off')
    plt.savefig(outputfilename, dpi=500)
121
122
# CLI command: load saved embeddings, optionally filter/relabel them, run a 3-component UMAP
# and write the result through PlotlyPlot.
# The function docstring is kept verbatim because click prints it as the command's help text.
@visualize.command()
@click.option('-i', '--embeddings_file', default='predictions/embeddings.pkl', help='Embeddings.', type=click.Path(exists=False), show_default=True)
@click.option('-o', '--plotly_output_file', default='predictions/embeddings.html', help='Plotly output file.', type=click.Path(exists=False), show_default=True)
@click.option('-a', '--annotations', default=[], multiple=True, help='Multiple annotations to color image.', show_default=True)
@click.option('-rb', '--remove_background_annotation', default='', help='If selected, removes 100% background patches based on this annotation.', type=click.Path(exists=False), show_default=True)
@click.option('-ma', '--max_background_area', default=0.05, help='Max background area before exclusion.',  show_default=True)
@click.option('-b', '--basename', default='', help='Basename of patches.', type=click.Path(exists=False), show_default=True)
@click.option('-nn', '--n_neighbors', default=8, help='Number nearest neighbors.',  show_default=True)
def plot_embeddings(embeddings_file,plotly_output_file, annotations, remove_background_annotation , max_background_area, basename, n_neighbors):
    """Perform UMAP embeddings of patches and plot using plotly."""
    import torch
    from umap import UMAP
    from pathflowai.visualize import PlotlyPlot
    import pandas as pd, numpy as np

    # NOTE(review): torch.load unpickles arbitrary objects - only open trusted embeddings files.
    embeddings_dict = torch.load(embeddings_file)
    embeddings = embeddings_dict['embeddings']
    patch_info = embeddings_dict['patch_info']

    # Keep only rows whose value in the chosen column is at most 1 - max_background_area.
    if remove_background_annotation:
        removal_bool = (patch_info[remove_background_annotation] <= (1. - max_background_area)).values
        patch_info = patch_info.loc[removal_bool]
        embeddings = embeddings.loc[removal_bool]

    # Restrict to a single slide when a basename is given.
    if basename:
        removal_bool = (patch_info['ID'] == basename).values
        patch_info = patch_info.loc[removal_bool]
        embeddings = embeddings.loc[removal_bool]

    # Recolor points by annotation column(s); the label is written into the 'ID' column.
    if annotations:
        annotations = np.array(annotations)
        if len(annotations) > 1:
            # One argmax over all rows replaces the per-row np.vectorize/iloc loop:
            # same labels, no Python-level row access, and no failure on zero rows.
            dominant = np.argmax(patch_info[annotations].values, axis=1)
            embeddings.loc[:, 'ID'] = annotations[dominant]
        else:
            # Single annotation: assign that column's values as a 1-D array.
            embeddings.loc[:, 'ID'] = patch_info[annotations[0]].values

    # Every column except the last is fed to UMAP.
    # NOTE(review): this relies on 'ID' being the last column of `embeddings` - confirm upstream.
    umap = UMAP(n_components=3, n_neighbors=n_neighbors)
    t_data = pd.DataFrame(umap.fit_transform(embeddings.iloc[:, :-1].values), columns=['x', 'y', 'z'], index=embeddings.index)
    t_data['color'] = embeddings['ID'].values
    t_data['name'] = embeddings.index.values

    pp = PlotlyPlot()
    pp.add_plot(t_data, size=8)
    pp.plot(plotly_output_file, axes_off=True)
160
161
# CLI command: load a saved model bundle and hand it to pathflowai.visualize.plot_shap.
# The function docstring is kept verbatim because click prints it as the command's help text.
# The -m help previously read 'Plotly output file.' (copy-paste); it now describes the option.
@visualize.command()
@click.option('-m', '--model_pkl', default='', help='Pickled model file (loaded with torch.load).', type=click.Path(exists=False), show_default=True)
@click.option('-bs', '--batch_size', default=32, help='Batch size.',  show_default=True)
@click.option('-o', '--outputfilename', default='predictions/shap_plots.png', help='SHAPley visualization.', type=click.Path(exists=False), show_default=True)
@click.option('-mth', '--method', default='deep', help='Method of explaining.', type=click.Choice(['deep','gradient']), show_default=True)
@click.option('-l', '--local_smoothing', default=0.0, help='Local smoothing of SHAP scores.',  show_default=True)
@click.option('-ns', '--n_samples', default=32, help='Number shapley samples for shapley regression (gradient explainer).',  show_default=True)
@click.option('-p', '--pred_out', default='none', help='If not none, output prediction as shap label.', type=click.Choice(['none','sigmoid','softmax']), show_default=True)
def shapley_plot(model_pkl, batch_size, outputfilename, method='deep', local_smoothing=0.0, n_samples=20, pred_out='none'):
    """Run SHAPley attribution method on patches after classification task to see where model made prediction based on."""
    from pathflowai.visualize import plot_shap
    import torch
    from pathflowai.datasets import get_data_transforms

    # The loaded object is indexed by 'model', 'dataset_opts' and 'transform_opts'.
    # NOTE(review): torch.load unpickles arbitrary objects - only open trusted model files.
    model_dict = torch.load(model_pkl)

    # Rebuild the transformers from the stored transform options before plotting.
    model_dict['dataset_opts']['transformers'] = get_data_transforms(**model_dict['transform_opts'])

    plot_shap(
        model_dict['model'],
        model_dict['dataset_opts'],
        model_dict['transform_opts'],
        batch_size,
        outputfilename,
        method=method,
        local_smoothing=local_smoothing,
        n_samples=n_samples,
        pred_out=pred_out,
    )
177
178
# CLI command: collect the slide arrays from input_dir and forward everything to
# pathflowai.visualize.plot_umap_images.
# The function docstring is kept verbatim because click prints it as the command's help text.
# NOTE(review): the -mpl help ('Plot segmentations.') and the -o default
# ('predictions/shap_plots.png') look copy-pasted from other commands - confirm and correct.
@visualize.command()
@click.option('-i', '--input_dir', default='./inputs/', help='Input directory for patches.', type=click.Path(exists=False), show_default=True)
@click.option('-e', '--embeddings_file', default='predictions/embeddings.pkl', help='Embeddings.', type=click.Path(exists=False), show_default=True)
@click.option('-b', '--basename', default='', help='Basename of patches.', type=click.Path(exists=False), show_default=True)
@click.option('-o', '--outputfilename', default='predictions/shap_plots.png', help='Embedding visualization.', type=click.Path(exists=False), show_default=True)
@click.option('-mpl', '--mpl_scatter', is_flag=True, help='Plot segmentations.', show_default=True)
@click.option('-rb', '--remove_background_annotation', default='', help='If selected, removes 100% background patches based on this annotation.', type=click.Path(exists=False), show_default=True)
@click.option('-ma', '--max_background_area', default=0.05, help='Max background area before exclusion.',  show_default=True)
@click.option('-z', '--zoom', default=0.05, help='Size of images.',  show_default=True)
@click.option('-nn', '--n_neighbors', default=8, help='Number nearest neighbors.',  show_default=True)
@click.option('-sc', '--sort_col', default='', help='Sort samples on this column.', type=click.Path(exists=False), show_default=True)
@click.option('-sm', '--sort_mode', default='asc', help='Sort ascending or descending.', type=click.Choice(['asc','desc']), show_default=True)
def plot_image_umap_embeddings(input_dir,embeddings_file,basename,outputfilename,mpl_scatter, remove_background_annotation, max_background_area, zoom, n_neighbors, sort_col='', sort_mode='asc'):
    """Plots a UMAP embedding with each point as its corresponding patch image."""
    from pathflowai.visualize import plot_umap_images

    zarr_paths = glob.glob(os.path.join(input_dir, '*.zarr'))
    if zarr_paths:
        # An empty basename means "load every zarr store in the directory".
        dask_arr_dict = {}
        for zarr_path in zarr_paths:
            slide_id = os.path.basename(zarr_path).split('.zarr')[0]
            if (not basename) or slide_id == basename:
                dask_arr_dict[slide_id] = da.from_zarr(zarr_path)
    else:
        if basename:
            slide_ids = [basename]
        else:
            # Only .npy files can be loaded below, so glob for those alone. The old
            # '*.*' pattern tried to load '<stem>.npy' for every file in the directory
            # and failed on any unrelated file. Sorted for a deterministic order.
            slide_ids = sorted({
                os.path.basename(os.path.splitext(npy_path)[0])
                for npy_path in glob.glob(os.path.join(input_dir, '*.npy'))
            })
        dask_arr_dict = {
            slide_id: load_preprocessed_img(os.path.join(input_dir, '{}.npy'.format(slide_id)))
            for slide_id in slide_ids
        }

    plot_umap_images(
        dask_arr_dict,
        embeddings_file,
        ID=basename,
        cval=1.,
        image_res=300.,
        outputfname=outputfilename,
        mpl_scatter=mpl_scatter,
        remove_background_annotation=remove_background_annotation,
        max_background_area=max_background_area,
        zoom=zoom,
        n_neighbors=n_neighbors,
        sort_col=sort_col,
        sort_mode=sort_mode,
    )
198
199
# Script entry point: dispatch to the click group when the module is run directly.
if __name__ == "__main__":
    visualize()