|
"""
visualize.py
=======================
Plots SHAP outputs, UMAP embeddings, and overlays predictions on top of WSIs.
"""

import plotly.graph_objs as go
import plotly.offline as py
import pandas as pd, numpy as np
import networkx as nx
import dask.array as da
from PIL import Image
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
import sqlite3
from os.path import join
from pathflowai.utils import npy2da

sns.set()
|
|
|
class PlotlyPlot:
    """Creates plotly html plots."""

    def __init__(self):
        self.plots=[]

    def add_plot(self, t_data_df, G=None, color_col='color', name_col='name', xyz_cols=['x','y','z'], size=2, opacity=1.0, custom_colors=[]):
        """Adds plotting data to be plotted.

        Parameters
        ----------
        t_data_df:dataframe
            3-D transformed dataframe.
        G:nx.Graph
            NetworkX graph.
        color_col:str
            Column to use to color points.
        name_col:str
            Column to use to name points.
        xyz_cols:list
            3 columns that denote x,y,z coords.
        size:int
            Marker size.
        opacity:float
            Marker opacity.
        custom_colors:list
            Custom colors to supply.
        """
        plots = []
        x,y,z=tuple(xyz_cols)
        if t_data_df[color_col].dtype == np.float64:
            plots.append(
                go.Scatter3d(x=t_data_df[x], y=t_data_df[y], z=t_data_df[z],
                             name='', mode='markers',
                             marker=dict(color=t_data_df[color_col], size=size, opacity=opacity, colorscale='Viridis',
                                         colorbar=dict(title='Colorbar')),
                             text=t_data_df[color_col] if name_col not in list(t_data_df) else t_data_df[name_col]))
        else:
            colors = t_data_df[color_col].unique()
            c = sns.color_palette('hls', len(colors))
            c = np.array(['rgb({})'.format(','.join(((np.array(c_i)*255).astype(int).astype(str).tolist()))) for c_i in c])
            if custom_colors:
                c = custom_colors
            color_dict = {name: c[i] for i,name in enumerate(sorted(colors))}

            for name,col in color_dict.items():
                plots.append(
                    go.Scatter3d(x=t_data_df[x][t_data_df[color_col]==name], y=t_data_df[y][t_data_df[color_col]==name],
                                 z=t_data_df[z][t_data_df[color_col]==name],
                                 name=str(name), mode='markers',
                                 marker=dict(color=col, size=size, opacity=opacity),
                                 text=t_data_df.index[t_data_df[color_col]==name] if name_col not in list(t_data_df) else t_data_df[name_col][t_data_df[color_col]==name]))
        if G is not None:
            Xed, Yed, Zed = [], [], []
            for edge in G.edges():
                if edge[0] in t_data_df.index.values and edge[1] in t_data_df.index.values:
                    Xed += [t_data_df.loc[edge[0],x], t_data_df.loc[edge[1],x], None]
                    Yed += [t_data_df.loc[edge[0],y], t_data_df.loc[edge[1],y], None]
                    Zed += [t_data_df.loc[edge[0],z], t_data_df.loc[edge[1],z], None]
            plots.append(go.Scatter3d(x=Xed,
                                      y=Yed,
                                      z=Zed,
                                      mode='lines',
                                      line=go.scatter3d.Line(color='rgb(210,210,210)', width=2),
                                      hoverinfo='none'))
        self.plots.extend(plots)

    def plot(self, output_fname, axes_off=False):
        """Plot embedding of patches to html file.

        Parameters
        ----------
        output_fname:str
            Output html file.
        axes_off:bool
            Remove axes.
        """
        if axes_off:
            axis_opts = dict(title='', autorange=True, showgrid=False, zeroline=False, showline=False, ticks='', showticklabels=False)
            fig = go.Figure(data=self.plots, layout=go.Layout(scene=dict(xaxis=axis_opts, yaxis=axis_opts, zaxis=axis_opts)))
        else:
            fig = go.Figure(data=self.plots)
        py.plot(fig, filename=output_fname, auto_open=False)
|
|
|
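# Usage sketch (a minimal, hypothetical example; `t_data_df` is any dataframe
# with x/y/z coordinate columns plus a color column):
#   pp = PlotlyPlot()
#   pp.add_plot(t_data_df, color_col='annotation', size=3)
#   pp.plot('embedding.html', axes_off=True)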
def to_pil(arr):
    """Numpy array to PIL image.

    Parameters
    ----------
    arr:array
        Numpy array.

    Returns
    -------
    Image
        PIL Image.
    """
    return Image.fromarray(arr.astype('uint8'), 'RGB')

def blend(arr1, arr2, alpha=0.5):
    """Blend 2 arrays together, mixing with alpha.

    Parameters
    ----------
    arr1:array
        Image 1.
    arr2:array
        Image 2.
    alpha:float
        Higher alpha makes the result more like image 1.

    Returns
    -------
    array
        Resulting image.
    """
    return alpha*arr1 + (1.-alpha)*arr2

def prob2rgb(prob, palette, arr):
    """Convert probability score to rgb image.

    Parameters
    ----------
    prob:float
        Score between 0 and 1.
    palette:palette
        Palette that converts a probability to a color.
    arr:array
        Original array.

    Returns
    -------
    array
        New image colored by prediction score.
    """
    col = palette(prob)
    for i in range(3):
        arr[...,i] = int(col[i]*255)
    return arr

def seg2rgb(seg, palette, n_segmentation_classes):
    """Color each pixel by segmentation class.

    Parameters
    ----------
    seg:array
        Segmentation mask.
    palette:palette
        Color to RGB map.
    n_segmentation_classes:int
        Total number of segmentation classes.

    Returns
    -------
    array
        Returned segmentation image.
    """
    img=(palette(seg/n_segmentation_classes)[...,:3]*255).astype(int)
    return img

def annotation2rgb(i, palette, arr):
    """Go from annotation of patch to color.

    Parameters
    ----------
    i:int
        Annotation index.
    palette:palette
        Index to color mapping.
    arr:array
        Image array.

    Returns
    -------
    array
        Resulting image.
    """
    col = palette[i]
    for ch in range(3):
        arr[...,ch] = int(col[ch]*255)
    return arr
|
|
|
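# Example of the overlay helpers above (a minimal sketch, not from the original
# source; the palette matches the one PredictionPlotter uses below):
#   palette = sns.cubehelix_palette(start=0, as_cmap=True)
#   patch = np.zeros((224, 224, 3))
#   heat = prob2rgb(0.8, palette, patch)              # solid color encoding p=0.8
#   overlay = blend(np.ones_like(heat) * 255., heat)  # mix with a white "slide"
#   to_pil(overlay).save('overlay_demo.png')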
def plot_image_(image_file, compression_factor=2., test_image_name='test.png'):
    """Plots entire SVS/other image.

    Parameters
    ----------
    image_file:str
        Image file.
    compression_factor:float
        Amount to shrink each dimension of the image.
    test_image_name:str
        Output image file.
    """
    from pathflowai.utils import svs2dask_array, npy2da
    import cv2
    if image_file.endswith('.zarr'):
        arr=da.from_zarr(image_file)
    else:
        arr=svs2dask_array(image_file, tile_size=1000, overlap=0, remove_last=True, allow_unknown_chunksizes=False) if (not image_file.endswith('.npy')) else npy2da(image_file)
    # cv2.resize takes dsize as (width, height), so reverse the (rows, cols) shape order
    arr2=to_pil(cv2.resize(arr.compute(), dsize=tuple((np.array(arr.shape[:2])[::-1]/compression_factor).astype(int).tolist()), interpolation=cv2.INTER_CUBIC))
    arr2.save(test_image_name)
|
|
|
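# Usage sketch (hypothetical file names):
#   plot_image_('slide.zarr', compression_factor=4., test_image_name='slide_small.png')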
# for now binary output
class PredictionPlotter:
    """Plots predictions over the entire image.

    Parameters
    ----------
    dask_arr_dict:dict
        Stores all dask arrays corresponding to all of the images.
    patch_info_db:str
        Patch-level information, e.g. predictions.
    compression_factor:float
        How much to compress the image by.
    alpha:float
        A low value assigns higher weight to the prediction over the original image.
    patch_size:int
        Patch size.
    no_db:bool
        Don't use patch information.
    plot_annotation:bool
        Plot annotations from patch information.
    segmentation:bool
        Plot segmentation mask.
    n_segmentation_classes:int
        Number of segmentation classes.
    input_dir:str
        Input directory.
    annotation_col:str
        Annotation column to plot.
    scaling_factor:float
        Multiplies the prediction scores to make them appear darker on the images when predicting.
    """
|
|
    # NOTE: some patches may have been filtered out, so patch info and slides are not guaranteed one-to-one
    def __init__(self, dask_arr_dict, patch_info_db, compression_factor=3, alpha=0.5, patch_size=224, no_db=False, plot_annotation=False, segmentation=False, n_segmentation_classes=4, input_dir='', annotation_col='annotation', scaling_factor=1.):

        self.segmentation = segmentation
        self.scaling_factor=scaling_factor
        self.segmentation_maps = None
        self.n_segmentation_classes=float(n_segmentation_classes)
        self.pred_palette = sns.cubehelix_palette(start=0,as_cmap=True)
        if not no_db:
            self.compression_factor=compression_factor
            self.alpha = alpha
            self.patch_size = patch_size
            conn = sqlite3.connect(patch_info_db)
            patch_info=pd.read_sql('select * from "{}";'.format(patch_size),con=conn)
            conn.close()
            self.annotations = {str(a):i for i,a in enumerate(patch_info['annotation'].unique().tolist())}
            self.plot_annotation=plot_annotation
            self.palette=sns.color_palette(n_colors=len(list(self.annotations.keys())))
            if 'y_pred' not in patch_info.columns:
                patch_info['y_pred'] = 0.
            self.patch_info=patch_info[['ID','x','y','patch_size','annotation',annotation_col]]
            self.patch_info = self.patch_info[np.isin(self.patch_info['ID'],np.array(list(dask_arr_dict.keys())))]
        if self.segmentation:
            self.segmentation_maps = {slide:npy2da(join(input_dir,'{}_mask.npy'.format(slide))) for slide in dask_arr_dict.keys()}
        self.dask_arr_dict = {k:v[...,:3] for k,v in dask_arr_dict.items()}
|
|
|
    def add_custom_segmentation(self, basename, npy):
        """Replace segmentation mask with new custom segmentation.

        Parameters
        ----------
        basename:str
            Patient ID.
        npy:str
            Path to numpy mask file.
        """
        self.segmentation_maps[basename] = da.from_array(np.load(npy,mmap_mode='r+'))
|
|
|
    def generate_image(self, ID):
        """Generate the image array for the whole slide image with predictions overlaid.

        Parameters
        ----------
        ID:str
            Patient ID.

        Returns
        -------
        Image
            Resulting overlaid whole slide image.
        """
        patch_info = self.patch_info[self.patch_info['ID']==ID]
        dask_arr = self.dask_arr_dict[ID]
        arr_shape = np.array(dask_arr.shape).astype(float)

        arr_shape[:2]/=self.compression_factor

        arr_shape=arr_shape.astype(int).tolist()

        img = Image.new('RGB',arr_shape[:2],'white')

        for i in range(patch_info.shape[0]):
            ID,x,y,patch_size,annotation,pred = patch_info.iloc[i].tolist()
            x_new,y_new = int(x/self.compression_factor),int(y/self.compression_factor)
            image = np.zeros((patch_size,patch_size,3))
            if self.segmentation:
                image=seg2rgb(self.segmentation_maps[ID][x:x+patch_size,y:y+patch_size].compute(),self.pred_palette, self.n_segmentation_classes)
            else:
                image=prob2rgb(pred*self.scaling_factor, self.pred_palette, image) if not self.plot_annotation else annotation2rgb(self.annotations[str(pred)],self.palette,image)
            arr=dask_arr[x:x+patch_size,y:y+patch_size].compute()
            blended_patch=blend(arr,image, self.alpha).transpose((1,0,2))
            blended_patch_pil = to_pil(blended_patch)
            patch_size/=self.compression_factor
            patch_size=int(patch_size)
            blended_patch_pil=blended_patch_pil.resize((patch_size,patch_size))
            img.paste(blended_patch_pil, box=(x_new,y_new), mask=None)
        return img
|
|
|
    def return_patch(self, ID, x, y, patch_size):
        """Return one single patch instead of the entire image.

        Parameters
        ----------
        ID:str
            Patient ID.
        x:int
            X coordinate.
        y:int
            Y coordinate.
        patch_size:int
            Patch size.

        Returns
        -------
        Image
            PIL image of the patch.
        """
        img=(self.dask_arr_dict[ID][x:x+patch_size,y:y+patch_size].compute() if not self.segmentation else seg2rgb(self.segmentation_maps[ID][x:x+patch_size,y:y+patch_size].compute(),self.pred_palette, self.n_segmentation_classes))
        return to_pil(img)
|
|
|
    def output_image(self, img, filename, tif=False):
        """Output calculated image to file.

        Parameters
        ----------
        img:Image
            PIL image.
        filename:str
            Output file name.
        tif:bool
            Store in TIFF format?
        """
        if tif:
            from tifffile import imwrite
            imwrite(filename, np.array(img), photometric='rgb')
        else:
            img.save(filename)
|
|
|
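# Usage sketch (a minimal example with hypothetical paths/IDs; assumes the dask
# arrays and patch database were produced by the pathflowai preprocessing pipeline):
#   dask_arr_dict = {'slide1': da.from_zarr('slide1.zarr')}
#   plotter = PredictionPlotter(dask_arr_dict, 'patch_info.db', compression_factor=8, patch_size=224)
#   img = plotter.generate_image('slide1')
#   plotter.output_image(img, 'slide1_predictions.png')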
def plot_shap(model, dataset_opts, transform_opts, batch_size, outputfilename, n_outputs=1, method='deep', local_smoothing=0.0, n_samples=20, pred_out='none'):
    """Plot Shapley attributions overlaid on images for classification tasks.

    Parameters
    ----------
    model:nn.Module
        Pytorch model.
    dataset_opts:dict
        Options used to configure the dataset.
    transform_opts:dict
        Options used to configure the transforms.
    batch_size:int
        Batch size for the explainer dataloader.
    outputfilename:str
        Output filename.
    n_outputs:int
        Number of top outputs.
    method:str
        Explainer to use: 'gradient' or 'deep'.
    local_smoothing:float
        How much to smooth the Shapley map.
    n_samples:int
        Number of Shapley samples to draw.
    pred_out:str
        Transform used to label images with the model's prediction: 'none', 'sigmoid' or 'softmax'.
    """
|
|
    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    import shap
    from pathflowai.datasets import DynamicImageDataset
    from pathflowai.sampler import ImbalancedDatasetSampler

    out_transform=dict(sigmoid=torch.sigmoid, softmax=lambda x: F.softmax(x, dim=1), none=lambda x: x)
    binary_threshold=dataset_opts.pop('binary_threshold')
    num_targets=dataset_opts.pop('num_targets')

    dataset = DynamicImageDataset(**dataset_opts)

    if dataset_opts['classify_annotations']:
        binarizer=dataset.binarize_annotations(num_targets=num_targets,binary_threshold=binary_threshold)
        num_targets=len(dataset.targets)

    dataloader_val = DataLoader(dataset,batch_size=batch_size, num_workers=10, shuffle=True if num_targets>1 else False, sampler=ImbalancedDatasetSampler(dataset) if num_targets==1 else None)

    background,y_background=next(iter(dataloader_val))
    if method=='gradient':
        background=torch.cat([background,next(iter(dataloader_val))[0]],0)
    X_test,y_test=next(iter(dataloader_val))

    if torch.cuda.is_available():
        background=background.cuda()
        X_test=X_test.cuda()

    if pred_out!='none':
        model2=model.cuda() if torch.cuda.is_available() else model
        y_test=out_transform[pred_out](model2(X_test)).detach().cpu()

    y_test=y_test.numpy()
|
|
|
    if method=='deep':
        e = shap.DeepExplainer(model, background)
        s=e.shap_values(X_test, ranked_outputs=n_outputs)
    elif method=='gradient':
        e = shap.GradientExplainer(model, background, batch_size=batch_size, local_smoothing=local_smoothing)
        s=e.shap_values(X_test, ranked_outputs=n_outputs, nsamples=n_samples)

    if y_test.ndim>1 and y_test.shape[1]>1:
        y_test=y_test.argmax(axis=1)

    if n_outputs>1:
        shap_values, idx = s
    else:
        shap_values, idx = s, y_test

    if num_targets == 1:
        shap_numpy = [np.swapaxes(np.swapaxes(shap_values, 1, -1), 1, 2)]
    else:
        shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
    # undo the normalization transform so the original images display correctly
    X_test_numpy=X_test.detach().cpu().numpy()
    X_test_numpy=X_test_numpy.transpose((0,2,3,1))
    for i in range(X_test_numpy.shape[0]):
        X_test_numpy[i,...]*=np.array(transform_opts['std'])
        X_test_numpy[i,...]+=np.array(transform_opts['mean'])
    X_test_numpy=X_test_numpy.transpose((0,3,1,2))
    test_numpy = np.swapaxes(np.swapaxes(X_test_numpy, 1, -1), 1, 2)
    if pred_out!='none':
        labels=y_test.astype(str)
    else:
        labels = np.array([[(dataloader_val.dataset.targets[i[j]] if num_targets>1 else str(i)) for j in range(n_outputs)] for i in idx])

    plt.figure()
    shap.image_plot(shap_numpy, test_numpy, labels)
    plt.savefig(outputfilename, dpi=300)
|
|
|
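# Usage sketch (hypothetical options; `dataset_opts`/`transform_opts` are assumed
# to follow the conventions of pathflowai.datasets.DynamicImageDataset and the
# training CLI, including the 'binary_threshold' and 'num_targets' keys popped above):
#   plot_shap(trained_model, dataset_opts, transform_opts, batch_size=32,
#             outputfilename='shap_plot.png', method='gradient', n_samples=20)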
def plot_umap_images(dask_arr_dict, embeddings_file, ID=None, cval=1., image_res=300., outputfname='output_embedding.png', mpl_scatter=True, remove_background_annotation='', max_background_area=0.01, zoom=0.05, n_neighbors=10, sort_col='', sort_mode='asc'):
    """Make UMAP embedding plot, overlaid with images.

    Parameters
    ----------
    dask_arr_dict:dict
        Stored dask arrays for each WSI.
    embeddings_file:str
        Embeddings pickle file stored after training the model.
    ID:str
        Patient ID.
    cval:float
        Canvas fill value; only used by the non-matplotlib branch.
    image_res:float
        Image resolution.
    outputfname:str
        Output image file.
    mpl_scatter:bool
        Recommended: use matplotlib for the scatter plot.
    remove_background_annotation:str
        Name of the background annotation column; patches above the background-area threshold are removed.
    max_background_area:float
        Maximum background area in each tile for inclusion.
    zoom:float
        How much to zoom in on each patch; less than 1 zooms out.
    n_neighbors:int
        Number of neighbors for UMAP embedding.
    sort_col:str
        Patch info column to sort on.
    sort_mode:str
        Sort ascending ('asc') or descending ('desc').

    Inspired by: https://gist.github.com/lukemetz/be6123c7ee3b366e333a
    WIP!! Needs testing."""
|
|
    import torch
    import dask
    from dask.distributed import Client
    from umap import UMAP
    import skimage.io
    from skimage.transform import resize
    sns.set(style='white')

    def min_resize(img, size):
        """
        Resize an image so that it is size along the minimum spatial dimension.
        """
        w, h = map(float, img.shape[:2])
        if min([w, h]) != size:
            if w <= h:
                img = resize(img, (int(round((h/w)*size)), int(size)))
            else:
                img = resize(img, (int(size), int(round((w/h)*size))))
        return img
|
|
|
    embeddings_dict=torch.load(embeddings_file)
    embeddings=embeddings_dict['embeddings']
    patch_info=embeddings_dict['patch_info']
    if sort_col:
        idx=np.argsort(patch_info[sort_col].values)
        if sort_mode == 'desc':
            idx=idx[::-1]
        patch_info = patch_info.iloc[idx]
        embeddings=embeddings.iloc[idx]
    if ID:
        keep_bool=(patch_info['ID']==ID).values
        patch_info = patch_info.loc[keep_bool]
        embeddings=embeddings.loc[keep_bool]
    if remove_background_annotation:
        keep_bool=(patch_info[remove_background_annotation]<=(1.-max_background_area)).values
        patch_info=patch_info.loc[keep_bool]
        embeddings=embeddings.loc[keep_bool]

    umap=UMAP(n_components=2,n_neighbors=n_neighbors)
    t_data=pd.DataFrame(umap.fit_transform(embeddings.iloc[:,:-1].values),columns=['x','y'],index=embeddings.index)

    images=[]

    for i in range(patch_info.shape[0]):
        ID=patch_info.iloc[i]['ID']
        x,y,patch_size=patch_info.iloc[i][['x','y','patch_size']].values.tolist()
        arr=dask_arr_dict[ID][x:x+patch_size,y:y+patch_size]
        images.append(arr)

    c=Client()
    images=dask.compute(images)[0]  # dask.compute returns a tuple; unpack the list of arrays
    c.close()
|
|
|
    if mpl_scatter:
        from matplotlib.offsetbox import OffsetImage, AnnotationBbox
        def imscatter(x, y, ax, imageData, zoom):
            artists = []
            for i in range(len(x)):
                x0, y0 = x[i], y[i]
                img = imageData[i]
                image = OffsetImage(img, zoom=zoom)
                ab = AnnotationBbox(image, (x0, y0), xycoords='data', frameon=False)
                artists.append(ax.add_artist(ab))
            ax.update_datalim(np.column_stack([x, y]))
            ax.autoscale()

        fig, ax = plt.subplots()
        imscatter(t_data['x'].values, t_data['y'].values, imageData=images, ax=ax, zoom=zoom)
        sns.despine()
        plt.savefig(outputfname,dpi=300)
|
|
|
    else:
        xx=t_data.iloc[:,0]
        yy=t_data.iloc[:,1]

        images = [min_resize(image, image_res) for image in images]
        max_width = max([image.shape[0] for image in images])
        max_height = max([image.shape[1] for image in images])

        x_min, x_max = xx.min(), xx.max()
        y_min, y_max = yy.min(), yy.max()
        # Fix the ratios
        sx = (x_max-x_min)
        sy = (y_max-y_min)
        if sx > sy:
            res_x = int(sx/float(sy)*image_res)
            res_y = int(image_res)
        else:
            res_x = int(image_res)
            res_y = int(sy/float(sx)*image_res)

        canvas = np.ones((res_x+max_width, res_y+max_height, 3))*cval
        x_coords = np.linspace(x_min, x_max, res_x)
        y_coords = np.linspace(y_min, y_max, res_y)
        for x, y, image in zip(xx, yy, images):
            w, h = image.shape[:2]
            x_idx = np.argmin((x - x_coords)**2)
            y_idx = np.argmin((y - y_coords)**2)
            canvas[x_idx:x_idx+w, y_idx:y_idx+h] = image

        skimage.io.imsave(outputfname, canvas)
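# Usage sketch (hypothetical paths; the embeddings file is assumed to be the
# torch pickle saved by the pathflowai embedding-extraction step):
#   dask_arr_dict = {'slide1': da.from_zarr('slide1.zarr')}
#   plot_umap_images(dask_arr_dict, 'embeddings.pkl', zoom=0.1, n_neighbors=15,
#                    outputfname='umap_patches.png')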