Diff of /pathflowai/visualize.py [000000] .. [e9500f]

Switch to unified view

a b/pathflowai/visualize.py
1
"""
2
visualize.py
3
=======================
4
Plots SHAP outputs, UMAP embeddings, and overlays predictions on top of WSI.
5
"""
6
7
import plotly.graph_objs as go
8
import plotly.offline as py
9
import pandas as pd, numpy as np
10
import networkx as nx
11
import dask.array as da
12
from PIL import Image
13
import matplotlib
14
matplotlib.use('Agg')
15
import matplotlib.pyplot as plt
16
import seaborn as sns
17
import sqlite3
18
import seaborn as sns
19
from os.path import join
20
from pathflowai.utils import npy2da
21
sns.set()
22
23
class PlotlyPlot:
24
    """Creates plotly html plots."""
25
    def __init__(self):
26
        self.plots=[]
27
28
    def add_plot(self, t_data_df, G=None, color_col='color', name_col='name', xyz_cols=['x','y','z'], size=2, opacity=1.0, custom_colors=[]):
29
        """Adds plotting data to be plotted.
30
31
        Parameters
32
        ----------
33
        t_data_df:dataframe
34
            3-D transformed dataframe.
35
        G:nx.Graph
36
            Networkx graph.
37
        color_col:str
38
            Column to use to color points.
39
        name_col:str
40
            Column to use to name points.
41
        xyz_cols:list
42
            3 columns that denote x,y,z coords.
43
        size:int
44
            Marker size.
45
        opacity:float
46
            Marker opacity.
47
        custom_colors:list
48
            Custom colors to supply.
49
        """
50
        plots = []
51
        x,y,z=tuple(xyz_cols)
52
        if t_data_df[color_col].dtype == np.float64:
53
            plots.append(
54
                go.Scatter3d(x=t_data_df[x], y=t_data_df[y],
55
                             z=t_data_df[z],
56
                             name='', mode='markers',
57
                             marker=dict(color=t_data_df[color_col], size=size, opacity=opacity, colorscale='Viridis',
58
                             colorbar=dict(title='Colorbar')), text=t_data_df[color_col] if name_col not in list(t_data_df) else t_data_df[name_col]))
59
        else:
60
            colors = t_data_df[color_col].unique()
61
            c = sns.color_palette('hls', len(colors))
62
            c = np.array(['rgb({})'.format(','.join(((np.array(c_i)*255).astype(int).astype(str).tolist()))) for c_i in c])#c = ['hsl(' + str(h) + ',50%' + ',50%)' for h in np.linspace(0, 360, len(colors) + 2)]
63
            if custom_colors:
64
                c = custom_colors
65
            color_dict = {name: c[i] for i,name in enumerate(sorted(colors))}
66
67
            for name,col in color_dict.items():
68
                plots.append(
69
                    go.Scatter3d(x=t_data_df[x][t_data_df[color_col]==name], y=t_data_df[y][t_data_df[color_col]==name],
70
                                 z=t_data_df[z][t_data_df[color_col]==name],
71
                                 name=str(name), mode='markers',
72
                                 marker=dict(color=col, size=size, opacity=opacity), text=t_data_df.index[t_data_df[color_col]==name] if 'name' not in list(t_data_df) else t_data_df[name_col][t_data_df[color_col]==name]))
73
        if G is not None:
74
            #pos = nx.spring_layout(G,dim=3,iterations=0,pos={i: tuple(t_data.loc[i,['x','y','z']]) for i in range(len(t_data))})
75
            Xed, Yed, Zed = [], [], []
76
            for edge in G.edges():
77
                if edge[0] in t_data_df.index.values and edge[1] in t_data_df.index.values:
78
                    Xed += [t_data_df.loc[edge[0],x], t_data_df.loc[edge[1],x], None]
79
                    Yed += [t_data_df.loc[edge[0],y], t_data_df.loc[edge[1],y], None]
80
                    Zed += [t_data_df.loc[edge[0],z], t_data_df.loc[edge[1],z], None]
81
            plots.append(go.Scatter3d(x=Xed,
82
                      y=Yed,
83
                      z=Zed,
84
                      mode='lines',
85
                      line=go.scatter3d.Line(color='rgb(210,210,210)', width=2),
86
                      hoverinfo='none'
87
                      ))
88
        self.plots.extend(plots)
89
90
    def plot(self, output_fname, axes_off=False):
91
        """Plot embedding of patches to html file.
92
93
        Parameters
94
        ----------
95
        output_fname:str
96
            Output html file.
97
        axes_off:bool
98
            Remove axes.
99
100
        """
101
        if axes_off:
102
            fig = go.Figure(data=self.plots,layout=go.Layout(scene=dict(xaxis=dict(title='',autorange=True,showgrid=False,zeroline=False,showline=False,ticks='',showticklabels=False),
103
                yaxis=dict(title='',autorange=True,showgrid=False,zeroline=False,showline=False,ticks='',showticklabels=False),
104
                zaxis=dict(title='',autorange=True,showgrid=False,zeroline=False,showline=False,ticks='',showticklabels=False))))
105
        else:
106
            fig = go.Figure(data=self.plots)
107
        py.plot(fig, filename=output_fname, auto_open=False)
108
109
def to_pil(arr):
110
    """Numpy array to pil.
111
112
    Parameters
113
    ----------
114
    arr:array
115
        Numpy array.
116
117
    Returns
118
    -------
119
    Image
120
        PIL Image.
121
122
    """
123
    return Image.fromarray(arr.astype('uint8'), 'RGB')
124
125
def blend(arr1, arr2, alpha=0.5):
126
    """Blend 2 arrays together, mixing with alpha.
127
128
    Parameters
129
    ----------
130
    arr1:array
131
        Image 1.
132
    arr2:array
133
        Image 2.
134
    alpha:float
135
        Higher alpha makes image more like image 1.
136
137
    Returns
138
    -------
139
    array
140
        Resulting image.
141
142
    """
143
    return alpha*arr1 + (1.-alpha)*arr2
144
145
def prob2rbg(prob, palette, arr):
146
    """Convert probability score to rgb image.
147
148
    Parameters
149
    ----------
150
    prob:float
151
        Between 0 and 1 score.
152
    palette:palette
153
        Pallet converts between prob and color.
154
    arr:array
155
        Original array.
156
157
    Returns
158
    -------
159
    array
160
        New image colored by prediction score.
161
162
    """
163
    col = palette(prob)
164
    for i in range(3):
165
        arr[...,i] = int(col[i]*255)
166
    return arr
167
168
def seg2rgb(seg, palette, n_segmentation_classes):
169
    """Color each pixel by segmentation class.
170
171
    Parameters
172
    ----------
173
    seg:array
174
        Segmentation mask.
175
    palette:palette
176
        Color to RGB map.
177
    n_segmentation_classes:int
178
        Total number segmentation classes.
179
180
    Returns
181
    -------
182
    array
183
        Returned segmentation image.
184
    """
185
    #print(seg.shape)
186
    #print((seg/n_segmentation_classes))
187
    img=(palette(seg/n_segmentation_classes)[...,:3]*255).astype(int)
188
    #print(img.shape)
189
    return img
190
191
def annotation2rgb(i,palette,arr):
192
    """Go from annotation of patch to color.
193
194
    Parameters
195
    ----------
196
    i:int
197
        Annotation index.
198
    palette:palette
199
        Index to color mapping.
200
    arr:array
201
        Image array.
202
203
    Returns
204
    -------
205
    array
206
        Resulting image.
207
208
    """
209
    col = palette[i]
210
    for i in range(3):
211
        arr[...,i] = int(col[i]*255)
212
    return arr
213
214
def plot_image_(image_file, compression_factor=2., test_image_name='test.png'):
215
    """Plots entire SVS/other image.
216
217
    Parameters
218
    ----------
219
    image_file:str
220
        Image file.
221
    compression_factor:float
222
        Amount to shrink each dimension of image.
223
    test_image_name:str
224
        Output image file.
225
226
    """
227
    from pathflowai.utils import svs2dask_array, npy2da
228
    import cv2
229
    if image_file.endswith('.zarr'):
230
        arr=da.from_zarr(image_file)
231
    else:
232
        arr=svs2dask_array(image_file, tile_size=1000, overlap=0, remove_last=True, allow_unknown_chunksizes=False) if (not image_file.endswith('.npy')) else npy2da(image_file)
233
    arr2=to_pil(cv2.resize(arr.compute(), dsize=tuple((np.array(arr.shape[:2])/compression_factor).astype(int).tolist()), interpolation=cv2.INTER_CUBIC))
234
    arr2.save(test_image_name)
235
236
# for now binary output
237
class PredictionPlotter:
238
    """Plots predictions over entire image.
239
240
    Parameters
241
    ----------
242
    dask_arr_dict:dict
243
        Stores all dask arrays corresponding to all of the images.
244
    patch_info_db:str
245
        Patch level information, eg. prediction.
246
    compression_factor:float
247
        How much to compress image by.
248
    alpha:float
249
        Low value assigns higher weight to prediction over original image.
250
    patch_size:int
251
        Patch size.
252
    no_db:bool
253
        Don't use patch information.
254
    plot_annotation:bool
255
        Plot annotations from patch information.
256
    segmentation:bool
257
        Plot segmentation mask.
258
    n_segmentation_classes:int
259
        Number segmentation classes.
260
    input_dir:str
261
        Input directory.
262
    annotation_col:str
263
        Annotation column to plot.
264
    scaling_factor:float
265
        Multiplies the prediction scores to make them appear darker on the images when predicting.
266
    """
267
    # some patches have been filtered out, not one to one!!! figure out
268
    def __init__(self, dask_arr_dict, patch_info_db, compression_factor=3, alpha=0.5, patch_size=224, no_db=False, plot_annotation=False, segmentation=False, n_segmentation_classes=4, input_dir='', annotation_col='annotation', scaling_factor=1.):
269
270
        self.segmentation = segmentation
271
        self.scaling_factor=scaling_factor
272
        self.segmentation_maps = None
273
        self.n_segmentation_classes=float(n_segmentation_classes)
274
        self.pred_palette = sns.cubehelix_palette(start=0,as_cmap=True)
275
        if not no_db:
276
            self.compression_factor=compression_factor
277
            self.alpha = alpha
278
            self.patch_size = patch_size
279
            conn = sqlite3.connect(patch_info_db)
280
            patch_info=pd.read_sql('select * from "{}";'.format(patch_size),con=conn)
281
            conn.close()
282
            self.annotations = {str(a):i for i,a in enumerate(patch_info['annotation'].unique().tolist())}
283
            self.plot_annotation=plot_annotation
284
            self.palette=sns.color_palette(n_colors=len(list(self.annotations.keys())))
285
            #print(self.palette)
286
            if 'y_pred' not in patch_info.columns:
287
                patch_info['y_pred'] = 0.
288
            self.patch_info=patch_info[['ID','x','y','patch_size','annotation',annotation_col]] # y_pred
289
            if 0:
290
                for ID in predictions:
291
                    patch_info.loc[patch_info["ID"]==ID,'y_pred'] = predictions[ID]
292
            self.patch_info = self.patch_info[np.isin(self.patch_info['ID'],np.array(list(dask_arr_dict.keys())))]
293
        if self.segmentation:
294
            self.segmentation_maps = {slide:npy2da(join(input_dir,'{}_mask.npy'.format(slide))) for slide in dask_arr_dict.keys()}
295
        #self.patch_info[['x','y','patch_size']]/=self.compression_factor
296
        self.dask_arr_dict = {k:v[...,:3] for k,v in dask_arr_dict.items()}
297
298
    def add_custom_segmentation(self, basename, npy):
299
        """Replace segmentation mask with new custom segmentation.
300
301
        Parameters
302
        ----------
303
        basename:str
304
            Patient ID
305
        npy:str
306
            Numpy mask.
307
        """
308
        self.segmentation_maps[basename] = da.from_array(np.load(npy,mmap_mode='r+'))
309
310
    def generate_image(self, ID):
311
        """Generate the image array for the whole slide image with predictions overlaid.
312
313
        Parameters
314
        ----------
315
        ID:str
316
            patient ID.
317
318
        Returns
319
        -------
320
        array
321
            Resulting overlaid whole slide image.
322
323
        """
324
        patch_info = self.patch_info[self.patch_info['ID']==ID]
325
        dask_arr = self.dask_arr_dict[ID]
326
        arr_shape = np.array(dask_arr.shape).astype(float)
327
328
        #image=da.zeros_like(dask_arr)
329
330
        arr_shape[:2]/=self.compression_factor
331
332
        arr_shape=arr_shape.astype(int).tolist()
333
334
        img = Image.new('RGB',arr_shape[:2],'white')
335
336
        for i in range(patch_info.shape[0]):
337
            ID,x,y,patch_size,annotation,pred = patch_info.iloc[i].tolist()
338
            #print(x,y,annotation)
339
            x_new,y_new = int(x/self.compression_factor),int(y/self.compression_factor)
340
            image = np.zeros((patch_size,patch_size,3))
341
            if self.segmentation:
342
                image=seg2rgb(self.segmentation_maps[ID][x:x+patch_size,y:y+patch_size].compute(),self.pred_palette, self.n_segmentation_classes)
343
            else:
344
                image=prob2rbg(pred*self.scaling_factor, self.pred_palette, image) if not self.plot_annotation else annotation2rgb(self.annotations[str(pred)],self.palette,image) # annotation
345
            arr=dask_arr[x:x+patch_size,y:y+patch_size].compute()
346
            #print(image.shape)
347
            blended_patch=blend(arr,image, self.alpha).transpose((1,0,2))
348
            blended_patch_pil = to_pil(blended_patch)
349
            patch_size/=self.compression_factor
350
            patch_size=int(patch_size)
351
            blended_patch_pil=blended_patch_pil.resize((patch_size,patch_size))
352
            img.paste(blended_patch_pil, box=(x_new,y_new), mask=None)
353
        return img
354
355
    def return_patch(self, ID, x, y, patch_size):
356
        """Return one single patch instead of entire image.
357
358
        Parameters
359
        ----------
360
        ID:str
361
            Patient ID
362
        x:int
363
            X coordinate.
364
        y:int
365
            Y coordinate.
366
        patch_size:int
367
            Patch size.
368
369
        Returns
370
        -------
371
        array
372
            Image.
373
        """
374
        img=(self.dask_arr_dict[ID][x:x+patch_size,y:y+patch_size].compute() if not self.segmentation else seg2rgb(self.segmentation_maps[ID][x:x+patch_size,y:y+patch_size].compute(),self.pred_palette, self.n_segmentation_classes))
375
        return to_pil(img)
376
377
    def output_image(self, img, filename, tif=False):
378
        """Output calculated image to file.
379
380
        Parameters
381
        ----------
382
        img:array
383
            Image.
384
        filename:str
385
            Output file name.
386
        tif:bool
387
            Store in TIF format?
388
        """
389
        if tif:
390
            from tifffile import imwrite
391
            imwrite(filename, np.array(img), photometric='rgb')
392
        else:
393
            img.save(filename)
394
395
def plot_shap(model, dataset_opts, transform_opts, batch_size, outputfilename, n_outputs=1, method='deep', local_smoothing=0.0, n_samples=20, pred_out=False):
396
    """Plot shapley attributions overlaid on images for classification tasks.
397
398
    Parameters
399
    ----------
400
    model:nn.Module
401
        Pytorch model.
402
    dataset_opts:dict
403
        Options used to configure dataset
404
    transform_opts:dict
405
        Options used to configure transformers.
406
    batch_size:int
407
        Batch size for training.
408
    outputfilename:str
409
        Output filename.
410
    n_outputs:int
411
        Number of top outputs.
412
    method:str
413
        Gradient or deep explainer.
414
    local_smoothing:float
415
        How much to smooth shapley map.
416
    n_samples:int
417
        Number shapley samples to draw.
418
    pred_out:bool
419
        Label images with binary prediction score?
420
421
    """
422
    import torch
423
    from torch.nn import functional as F
424
    import numpy as np
425
    from torch.utils.data import DataLoader
426
    import shap
427
    from pathflowai.datasets import DynamicImageDataset
428
    import matplotlib
429
    from matplotlib import pyplot as plt
430
    from pathflowai.sampler import ImbalancedDatasetSampler
431
432
    out_transform=dict(sigmoid=F.sigmoid,softmax=F.softmax,none=lambda x: x)
433
    binary_threshold=dataset_opts.pop('binary_threshold')
434
    num_targets=dataset_opts.pop('num_targets')
435
436
    dataset = DynamicImageDataset(**dataset_opts)
437
438
    if dataset_opts['classify_annotations']:
439
        binarizer=dataset.binarize_annotations(num_targets=num_targets,binary_threshold=binary_threshold)
440
        num_targets=len(dataset.targets)
441
442
    dataloader_val = DataLoader(dataset,batch_size=batch_size, num_workers=10, shuffle=True if num_targets>1 else False, sampler=ImbalancedDatasetSampler(dataset) if num_targets==1 else None)
443
    #dataloader_test = DataLoader(dataset,batch_size=batch_size,num_workers=10, shuffle=False)
444
445
    background,y_background=next(iter(dataloader_val))
446
    if method=='gradient':
447
        background=torch.cat([background,next(iter(dataloader_val))[0]],0)
448
    X_test,y_test=next(iter(dataloader_val))
449
450
    if torch.cuda.is_available():
451
        background=background.cuda()
452
        X_test=X_test.cuda()
453
454
    if pred_out!='none':
455
        if torch.cuda.is_available():
456
            model2=model.cuda()
457
        y_test=out_transform[pred_out](model2(X_test)).detach().cpu()
458
459
    y_test=y_test.numpy()
460
461
    if method=='deep':
462
        e = shap.DeepExplainer(model, background)
463
        s=e.shap_values(X_test, ranked_outputs=n_outputs)
464
    elif method=='gradient':
465
        e = shap.GradientExplainer(model, background, batch_size=batch_size, local_smoothing=local_smoothing)
466
        s=e.shap_values(X_test, ranked_outputs=n_outputs, nsamples=n_samples)
467
468
    if y_test.shape[1]>1:
469
        y_test=y_test.argmax(axis=1)
470
471
    if n_outputs>1:
472
        shap_values, idx = s
473
    else:
474
        shap_values, idx = s, y_test
475
476
    #print(shap_values) # .detach().cpu()
477
478
    if num_targets == 1:
479
        shap_numpy = [np.swapaxes(np.swapaxes(shap_values, 1, -1), 1, 2)]
480
    else:
481
        shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
482
        #print(shap_numpy.shape)
483
    X_test_numpy=X_test.detach().cpu().numpy()
484
    X_test_numpy=X_test_numpy.transpose((0,2,3,1))
485
    for i in range(X_test_numpy.shape[0]):
486
        X_test_numpy[i,...]*=np.array(transform_opts['std'])
487
        X_test_numpy[i,...]+=np.array(transform_opts['mean'])
488
    X_test_numpy=X_test_numpy.transpose((0,3,1,2))
489
    test_numpy = np.swapaxes(np.swapaxes(X_test_numpy, 1, -1), 1, 2)
490
    if pred_out!='none':
491
        labels=y_test.astype(str)
492
    else:
493
        labels = np.array([[(dataloader_val.dataset.targets[i[j]] if num_targets>1 else str(i)) for j in range(n_outputs)] for i in idx])#[:,np.newaxis] # y_test
494
    if 0 and (len(labels.shape)<2 or labels.shape[1]==1):
495
        labels=labels.flatten()#[:np.newaxis]
496
497
    #print(labels.shape,shap_numpy.shape[0])
498
    plt.figure()
499
    shap.image_plot(shap_numpy, test_numpy, labels)# if num_targets!=1 else shap_values -test_numpy , labels=dataloader_test.dataset.targets)
500
    plt.savefig(outputfilename, dpi=300)
501
502
def plot_umap_images(dask_arr_dict, embeddings_file, ID=None, cval=1., image_res=300., outputfname='output_embedding.png', mpl_scatter=True, remove_background_annotation='', max_background_area=0.01, zoom=0.05, n_neighbors=10, sort_col='', sort_mode='asc'):
503
    """Make UMAP embedding plot, overlaid with images.
504
505
    Parameters
506
    ----------
507
    dask_arr_dict:dict
508
        Stored dask arrays for each WSI.
509
    embeddings_file:str
510
        Embeddings pickle file stored from running using after trainign the model.
511
    ID:str
512
        Patient ID.
513
    cval:float
514
        Deprecated
515
    image_res:float
516
        Image resolution.
517
    outputfname:str
518
        Output image file.
519
    mpl_scatter:bool
520
        Recommended: Use matplotlib for scatter plot.
521
    remove_background_annotation:str
522
        Remove the background annotations. Enter for annotation to remove.
523
    max_background_area:float
524
        Maximum backgrund area in each tile for inclusion.
525
    zoom:float
526
        How much to zoom in on each patch, less than 1 is zoom out.
527
    n_neighbors:int
528
        Number of neighbors for UMAP embedding.
529
    sort_col:str
530
        Patch info column to sort on.
531
    sort_mode:str
532
        Sort ascending or descending.
533
534
    Returns
535
    -------
536
    type
537
        Description of returned object.
538
539
    Inspired by: https://gist.github.com/lukemetz/be6123c7ee3b366e333a
540
    WIP!! Needs testing."""
541
    import torch
542
    import dask
543
    from dask.distributed import Client
544
    from umap import UMAP
545
    from pathflowai.visualize import PlotlyPlot
546
    import pandas as pd, numpy as np
547
    import skimage.io
548
    from skimage.transform import resize
549
    import matplotlib
550
    matplotlib.use('Agg')
551
    from matplotlib import pyplot as plt
552
    sns.set(style='white')
553
554
    def min_resize(img, size):
555
        """
556
        Resize an image so that it is size along the minimum spatial dimension.
557
        """
558
        w, h = map(float, img.shape[:2])
559
        if min([w, h]) != size:
560
            if w <= h:
561
                img = resize(img, (int(round((h/w)*size)), int(size)))
562
            else:
563
                img = resize(img, (int(size), int(round((w/h)*size))))
564
        return img
565
566
    #dask_arr = dask_arr_dict[ID]
567
568
    embeddings_dict=torch.load(embeddings_file)
569
    embeddings=embeddings_dict['embeddings']
570
    patch_info=embeddings_dict['patch_info']
571
    if sort_col:
572
        idx=np.argsort(patch_info[sort_col].values)
573
        if sort_mode == 'desc':
574
            idx=idx[::-1]
575
        patch_info = patch_info.iloc[idx]
576
        embeddings=embeddings.iloc[idx]
577
    if ID:
578
        removal_bool=(patch_info['ID']==ID).values
579
        patch_info = patch_info.loc[removal_bool]
580
        embeddings=embeddings.loc[removal_bool]
581
    if remove_background_annotation:
582
        removal_bool=(patch_info[remove_background_annotation]<=(1.-max_background_area)).values
583
        patch_info=patch_info.loc[removal_bool]
584
        embeddings=embeddings.loc[removal_bool]
585
586
    umap=UMAP(n_components=2,n_neighbors=n_neighbors)
587
    t_data=pd.DataFrame(umap.fit_transform(embeddings.iloc[:,:-1].values),columns=['x','y'],index=embeddings.index)
588
589
    images=[]
590
591
    for i in range(patch_info.shape[0]):
592
        ID=patch_info.iloc[i]['ID']
593
        x,y,patch_size=patch_info.iloc[i][['x','y','patch_size']].values.tolist()
594
        arr=dask_arr_dict[ID][x:x+patch_size,y:y+patch_size]#.transpose((2,0,1))
595
        images.append(arr)
596
597
    c=Client()
598
    images=dask.compute(images)
599
    c.close()
600
601
    if mpl_scatter:
602
        from matplotlib.offsetbox import OffsetImage, AnnotationBbox
603
        def imscatter(x, y, ax, imageData, zoom):
604
            images = []
605
            for i in range(len(x)):
606
                x0, y0 = x[i], y[i]
607
                img = imageData[i]
608
                #print(img.shape)
609
                image = OffsetImage(img, zoom=zoom)
610
                ab = AnnotationBbox(image, (x0, y0), xycoords='data', frameon=False)
611
                images.append(ax.add_artist(ab))
612
613
            ax.update_datalim(np.column_stack([x, y]))
614
            ax.autoscale()
615
616
        fig, ax = plt.subplots()
617
        imscatter(t_data['x'].values, t_data['y'].values, imageData=images[0], ax=ax, zoom=zoom)
618
        sns.despine()
619
        plt.savefig(outputfname,dpi=300)
620
621
622
    else:
623
        xx=t_data.iloc[:,0]
624
        yy=t_data.iloc[:,1]
625
626
        images = [min_resize(image, img_res) for image in images]
627
        max_width = max([image.shape[0] for image in images])
628
        max_height = max([image.shape[1] for image in images])
629
630
        x_min, x_max = xx.min(), xx.max()
631
        y_min, y_max = yy.min(), yy.max()
632
        # Fix the ratios
633
        sx = (x_max-x_min)
634
        sy = (y_max-y_min)
635
        if sx > sy:
636
            res_x = sx/float(sy)*res
637
            res_y = res
638
        else:
639
            res_x = res
640
            res_y = sy/float(sx)*res
641
642
        canvas = np.ones((res_x+max_width, res_y+max_height, 3))*cval
643
        x_coords = np.linspace(x_min, x_max, res_x)
644
        y_coords = np.linspace(y_min, y_max, res_y)
645
        for x, y, image in zip(xx, yy, images):
646
            w, h = image.shape[:2]
647
            x_idx = np.argmin((x - x_coords)**2)
648
            y_idx = np.argmin((y - y_coords)**2)
649
            canvas[x_idx:x_idx+w, y_idx:y_idx+h] = image
650
651
        skimage.io.imsave(outputfname, canvas)