a b/pathflowai/utils.py
1
"""
2
utils.py
3
=======================
4
General utilities that still need to be broken up into preprocessing, machine learning input preparation, and output submodules.
5
"""
6
7
import numpy as np
8
from bs4 import BeautifulSoup
9
from shapely.geometry import Point
10
from shapely.geometry.polygon import Polygon
11
import glob
12
from os.path import join
13
import plotly.graph_objs as go
14
import plotly.offline as py
15
import pandas as pd, numpy as np
16
import scipy.sparse as sps
17
from PIL import Image, ImageDraw
18
Image.MAX_IMAGE_PIXELS=1e10
19
import numpy as np
20
import scipy.sparse as sps
21
from os.path import join
22
import os, subprocess, pandas as pd
23
import sqlite3
24
import torch
25
from torch.utils.data import Dataset#, DataLoader
26
from sklearn.model_selection import train_test_split
27
import pysnooper
28
from shapely.ops import unary_union, polygonize
29
from shapely.geometry import MultiPolygon, LineString
30
import numpy as np
31
import dask.array as da
32
import dask
33
import openslide
34
from openslide import deepzoom
35
#import xarray as xr, sparse
36
import pickle
37
import copy
38
import h5py
39
import nonechucks as nc
40
from nonechucks import SafeDataLoader as DataLoader
41
42
import cv2
43
import numpy as np
44
from skimage.morphology import watershed
45
from skimage.feature import peak_local_max
46
from scipy.ndimage import label as scilabel, distance_transform_edt
47
import scipy.ndimage as ndimage
48
from skimage import morphology as morph
49
from scipy.ndimage.morphology import binary_fill_holes as fill_holes
50
from skimage.filters import threshold_otsu, rank
51
from skimage.morphology import convex_hull_image, remove_small_holes
52
from skimage import measure
53
import xmltodict as xd
54
from collections import defaultdict
55
56
57
def load_sql_df(sql_file, patch_size):
    """Load pandas dataframe from SQL, accessing particular patch size within SQL.

    Each patch size lives in its own table, named after the patch size.

    Parameters
    ----------
    sql_file:str
        SQL db.
    patch_size:int
        Patch size; used as the table name.

    Returns
    -------
    dataframe
        Patch level information.

    """
    # Table names cannot be bound as SQL parameters, so quote the identifier
    # explicitly (doubling any embedded double quotes).
    table = str(patch_size).replace('"', '""')
    conn = sqlite3.connect(sql_file)
    try:
        df = pd.read_sql('select * from "{}";'.format(table), con=conn)
    finally:
        # Close the connection even when the query fails (e.g. missing table).
        conn.close()
    return df
77
78
def df2sql(df, sql_file, patch_size, mode='replace'):
    """Write dataframe containing patch level information to SQL db.

    The frame's ``index`` column becomes the table index, and the table is
    named after ``patch_size``.

    Parameters
    ----------
    df:dataframe
        Dataframe containing patch information; must have an ``index`` column.
    sql_file:str
        SQL database.
    patch_size:int
        Size of patches; used as the table name.
    mode:str
        Passed to ``DataFrame.to_sql(if_exists=...)``: 'replace' or 'append'.

    """
    conn = sqlite3.connect(sql_file)
    try:
        df.set_index('index').to_sql(str(patch_size), con=conn, if_exists=mode)
    finally:
        # Always release the connection, even if the write fails.
        conn.close()
96
97
98
#########
99
100
# https://github.com/qupath/qupath/wiki/Supported-image-formats
101
def svs2dask_array(svs_file, tile_size=1000, overlap=0, remove_last=True, allow_unknown_chunksizes=False, transpose=False):
    """Convert SVS, TIF or TIFF to dask array.

    Parameters
    ----------
    svs_file : str
            Image file.
    tile_size : int
            Size of chunk to be read in.
    overlap : int
            Do not modify, overlap between neighboring tiles.
    remove_last : bool
            Remove last tile row/column because it has a custom size
            (at least one tile is always kept per axis).
    allow_unknown_chunksizes : bool
            Allow different chunk sizes, more flexible, but slowdown.
    transpose : bool
            Swap the first two axes of the assembled array (OpenSlide path only).

    Returns
    -------
    arr : dask.array.Array
            A Dask Array representing the contents of the image file.

    Examples
    --------
    >>> arr = svs2dask_array(svs_file, tile_size=1000, overlap=0, remove_last=True, allow_unknown_chunksizes=False)  # doctest: +SKIP
    >>> arr2 = arr.compute()  # doctest: +SKIP
    >>> arr3 = to_pil(cv2.resize(arr2, dsize=(1440, 700), interpolation=cv2.INTER_CUBIC))  # doctest: +SKIP
    >>> arr3.save(test_image_name)  # doctest: +SKIP
    """
    # https://github.com/jlevy44/PathFlowAI/blob/master/pathflowai/utils.py
    img = openslide.open_slide(svs_file)
    if isinstance(img, openslide.OpenSlide):
        gen = deepzoom.DeepZoomGenerator(
            img, tile_size=tile_size, overlap=overlap, limit_bounds=True)
        max_level = len(gen.level_dimensions) - 1
        n_tiles_x, n_tiles_y = gen.level_tiles[max_level]

        @dask.delayed(pure=True)
        def get_tile(level, column, row):
            tile = gen.get_tile(level, (column, row))
            # Swap the first two axes so columns index axis 0.
            return np.array(tile).transpose((1, 0, 2))

        sample_tile_shape = get_tile(max_level, 0, 0).shape.compute()
        # Dropping the ragged last row/column would leave nothing to
        # concatenate when a level is a single tile wide/high, so keep >= 1.
        # With one tile the sample tile *is* that tile, so shapes still match.
        n_rows = max(n_tiles_y - 1, 1) if remove_last else n_tiles_y
        n_cols = max(n_tiles_x - 1, 1) if remove_last else n_tiles_x
        columns = [
            da.concatenate(
                [da.from_delayed(get_tile(max_level, col, row), sample_tile_shape, np.uint8)
                 for row in range(n_rows)],
                allow_unknown_chunksizes=allow_unknown_chunksizes, axis=1)
            for col in range(n_cols)]
        arr = da.concatenate(columns, allow_unknown_chunksizes=allow_unknown_chunksizes)
        if transpose:
            arr = arr.transpose([1, 0, 2])
        return arr
    # img is an openslide.ImageSlide: fall back to dask-image. Imported lazily
    # because it is only needed here (it was previously referenced without
    # ever being imported, so this branch raised NameError).
    import dask_image.imread
    return dask_image.imread.imread(svs_file)
147
148
def img2npy_(input_dir, basename, svs_file):
    """Materialize a slide image (SVS/TIF/TIFF) and save it as a ``.npy`` file.

    Parameters
    ----------
    input_dir:str
        Directory the NPY file is written to.
    basename:str
        Stem of the output file name; ``.npy`` is appended.
    svs_file:str
        SVS, TIF or TIFF image to convert.

    Returns
    -------
    str
        Path of the written NPY file.
    """
    out_path = join(input_dir, '{}.npy'.format(basename))
    lazy_image = svs2dask_array(svs_file)
    np.save(out_path, lazy_image.compute())
    return out_path
169
170
def load_image(svs_file):
    """Load an image with PIL and return its pixels with the first two axes swapped.

    NOTE: a second ``load_image`` defined later in this module rebinds the
    name, so this version is shadowed once the module finishes importing.

    Parameters
    ----------
    svs_file:str
        Path to an image file readable by PIL.

    Returns
    -------
    numpy.ndarray
        Pixel array with axes 0 and 1 swapped; any further axes (e.g. colour
        channels) are left in place.
    tuple
        ``im.size`` as reported by PIL.
    """
    # Context manager closes the underlying file once pixels are copied out.
    with Image.open(svs_file) as im:
        pixels = np.array(im)
        size = im.size
    # swapaxes handles 2-D (grayscale) and 3-D (multi-channel) images alike;
    # transpose with axes (1, 0) raised for any 3-D image.
    return np.swapaxes(pixels, 0, 1), size
185
186
def create_purple_mask(arr, img_size=None, sparse=True):
    """Create a gray scale intensity mask. This will be changed soon to support other thresholding QC methods.

    Parameters
    ----------
    arr:dask.array
        Image array indexed as (row, col, channel) with channels in R, G, B order.
    img_size:int
        Deprecated; ignored.
    sparse:bool
        Deprecated; ignored.

    Returns
    -------
    dask.array
        ``255 - gray`` per pixel, where ``gray`` is the luma-weighted sum of
        the R, G and B channels; larger values mean darker pixels.

    """
    # Channels must be unpacked in R, G, B order so each luma weight hits the
    # right channel (previously G and B were swapped).
    r, g, b = arr[:, :, 0], arr[:, :, 1], arr[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return 255. - gray
214
215
def add_purple_mask(arr):
    """Optional add intensity mask to the dask array.

    Parameters
    ----------
    arr:dask.array
        Image data indexed as (row, col, channel).

    Returns
    -------
    array
        Image data with the intensity mask from ``create_purple_mask`` appended
        as an extra (fourth for RGB) channel.

    """
    # The mask is 2-D (row, col); give it a trailing channel axis so it is
    # stacked after the colour channels. Concatenating along axis 0 always
    # failed because the ranks differ.
    mask = create_purple_mask(arr)[:, :, None]
    return np.concatenate((arr, mask), axis=2)
230
231
def create_sparse_annotation_arrays(xml_file, img_size, annotations=None, transpose_annotations=False):
    """Convert annotation xml to shapely objects and store in dictionary.

    Parameters
    ----------
    xml_file:str
        XML file containing annotations.
    img_size:int
        Deprecated; ignored.
    annotations:list
        Annotations to look for in xml export. ``None`` (default) means none.
    transpose_annotations:bool
        Forwarded to ``parse_coord_return_boxes``.

    Returns
    -------
    dict
        One entry per requested annotation, mapping its name to the value of
        ``parse_coord_return_boxes(..., return_coords=False)``.

    """
    if annotations is None:
        annotations = []
    # Single pass: iterating ``annotations`` twice (as before) silently
    # returned {} whenever a one-shot iterator was passed.
    return {annotation: parse_coord_return_boxes(xml_file,
                                                 annotation_name=annotation,
                                                 return_coords=False,
                                                 transpose_annotations=transpose_annotations)
            for annotation in annotations}
251
252
def load_image(svs_file):
    """Load an image file lazily as a dask array.

    ``.npy`` and ``.h5`` files are read with ``npy2da``; any other file is
    treated as a slide and read with ``svs2dask_array`` (tile size 1000,
    no overlap).
    """
    if svs_file.endswith(('.npy', '.h5')):
        return npy2da(svs_file)
    return svs2dask_array(svs_file, tile_size=1000, overlap=0)
254
255
def load_preprocessed_img(img_file):
    """Load a preprocessed image as a dask array.

    If ``img_file`` names a ``.zarr`` store that has no ``.zarray`` metadata,
    the ``.npy`` file with the same stem is loaded instead.

    Parameters
    ----------
    img_file:str
        Path to a ``.zarr`` store, ``.npy`` or ``.h5`` file.

    Returns
    -------
    dask.array
        The image.
    """
    if img_file.endswith('.zarr') and not os.path.exists(f"{img_file}/.zarray"):
        # Swap only the trailing extension; str.replace would also rewrite a
        # '.zarr' appearing earlier in the path (e.g. in a directory name).
        img_file = img_file[:-len('.zarr')] + '.npy'
    if img_file.endswith(('.npy', '.h5')):
        return npy2da(img_file)
    return da.from_zarr(img_file)
259
260
def load_process_image(svs_file, xml_file=None, npy_mask=None, annotations=[], transpose_annotations=False):
    """Load an image plus its annotations.

    Parameters
    ----------
    svs_file:str
        Image file (slide formats, ``.npy`` or ``.h5``), read via ``load_image``.
    xml_file:str
        Optional annotation XML; parsed with ``create_sparse_annotation_arrays``.
    npy_mask:array
        Optional segmentation mask, stored under the ``'annotations'`` key.
    annotations:list
        Annotation names to extract from the XML.
    transpose_annotations:bool
        Forwarded to ``create_sparse_annotation_arrays``.

    Returns
    -------
    array
        Dask array of the image.
    dict
        Annotation name -> shapes, plus ``'annotations'`` when ``npy_mask`` is given.

    """
    image = load_image(svs_file)
    shape_2d = image.shape[:2]
    annotation_masks = {}
    if xml_file is not None:
        parsed = create_sparse_annotation_arrays(xml_file,
                                                 shape_2d,
                                                 annotations=annotations,
                                                 transpose_annotations=transpose_annotations)
        annotation_masks.update(parsed)
    if npy_mask is not None:
        annotation_masks['annotations'] = npy_mask
    return image, annotation_masks
296
297
def save_dataset(arr, masks, out_zarr, out_pkl, no_zarr):
    """Saves dask array image, dictionary of annotations to zarr and pickle respectively.

    Parameters
    ----------
    arr:array
        Image; cast to uint8 before writing. Not accessed when ``no_zarr`` is True.
    masks:dict
        Dictionary of annotation shapes.
    out_zarr:str
        Zarr output file for image (overwritten if it exists).
    out_pkl:str
        Pickle output file.
    no_zarr:bool
        Skip writing the image; only the annotations pickle is written.
    """
    if not no_zarr:
        arr.astype('uint8').to_zarr(out_zarr, overwrite=True)
    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(out_pkl, 'wb') as f:
        pickle.dump(masks, f)
317
318
def run_preprocessing_pipeline(svs_file, xml_file=None, npy_mask=None, annotations=[], out_zarr='output_zarr.zarr', out_pkl='output.pkl', no_zarr=False, transpose_annotations=False):
    """Load an image with its annotations and persist both.

    The image is written to zarr (unless ``no_zarr``) and the annotation
    dictionary to a pickle.

    Parameters
    ----------
    svs_file:str
        Input image file.
    xml_file:str
        Optional input annotation XML.
    npy_mask:str
        Optional segmentation mask, stored under ``'annotations'``.
    annotations:list
        Annotation names to extract from the XML.
    out_zarr:str
        Output zarr for the image.
    out_pkl:str
        Output pickle for the annotations.
    no_zarr:bool
        Skip writing the image zarr.
    transpose_annotations:bool
        Forwarded to the XML annotation parser.
    """
    image, annotation_masks = load_process_image(svs_file,
                                                 xml_file=xml_file,
                                                 npy_mask=npy_mask,
                                                 annotations=annotations,
                                                 transpose_annotations=transpose_annotations)
    save_dataset(image, annotation_masks, out_zarr, out_pkl, no_zarr)
339
340
###################
341
342
def adjust_mask(mask_file, dask_img_array_file, out_npy, n_neighbors):
    """Fixes segmentation masks to reduce coarse annotations over empty regions.

    Annotated pixels that lie on background (after a binary opening of the
    image's tissue mask) are reset to 0.

    Parameters
    ----------
    mask_file:str
        NPY segmentation mask.
    dask_img_array_file:str
        Zarr image store.
    out_npy:str
        Output numpy file.
    n_neighbors:int
        Side length of the square structuring element used for the opening.

    Returns
    -------
    str
        Output numpy file.

    """
    # Imported lazily: only this function needs dask-image. The previously
    # imported (and unused) dask.distributed.Client made this function fail
    # whenever the `distributed` package was not installed.
    from dask_image.ndmorph import binary_opening
    image = da.from_zarr(dask_img_array_file)
    mask = npy2da(mask_file)
    is_tissue_mask = mask > 0.
    # A pixel counts as tissue unless all three channels exceed 200.
    is_tissue_mask_img = ((image[..., 0] > 200.) & (image[..., 1] > 200.) & (image[..., 2] > 200.)) == 0
    opening = binary_opening(is_tissue_mask_img, structure=da.ones((n_neighbors, n_neighbors)))
    # Clear annotated pixels that the opened tissue mask marks as background.
    mask[(opening == 0) & (is_tissue_mask == 1)] = 0
    np.save(out_npy, mask.compute())
    return out_npy
374
375
def filter_grays(rgb, tolerance=15, output_type="bool"):
    """Mask out pixels whose red, green and blue values are all similar (grays).

    Adapted from
    https://github.com/deroneriksson/python-wsi-preprocessing/blob/master/deephistopath/wsi/filter.py

    Args:
      rgb: RGB image as a NumPy array of shape (h, w, c) with c >= 3.
      tolerance: Maximum pairwise channel difference for a pixel to count as gray.
      output_type: "bool", "float", or anything else for uint8 (0/255).

    Returns:
      Mask that is True / 1.0 / 255 where the pixel is NOT gray and
      False / 0.0 / 0 where it is.

    Raises:
      ValueError: if ``rgb`` is not 3-dimensional.
    """
    if rgb.ndim != 3:
        raise ValueError("expected an (h, w, c) array, got shape {}".format(rgb.shape))
    # Widen to a signed type so differences of unsigned (e.g. uint8) channels
    # cannot wrap around. `np.int` was removed in NumPy 1.24.
    rgb = rgb.astype(np.int64)
    rg_diff = np.abs(rgb[:, :, 0] - rgb[:, :, 1]) <= tolerance
    rb_diff = np.abs(rgb[:, :, 0] - rgb[:, :, 2]) <= tolerance
    gb_diff = np.abs(rgb[:, :, 1] - rgb[:, :, 2]) <= tolerance
    result = ~(rg_diff & rb_diff & gb_diff)
    if output_type == "bool":
        return result
    if output_type == "float":
        return result.astype(float)
    return result.astype("uint8") * 255
398
399
def label_objects(img,
                  otsu=True,
                  min_object_size=100000,
                  threshold=240,
                  connectivity=8,
                  kernel=61,
                  keep_holes=False,
                  max_hole_size=0,
                  gray_before_close=False,
                  blur_size=0):
    """Threshold an RGB image into a foreground mask and label its components.

    Pipeline: grayscale -> threshold -> optional closing -> gray-pixel filter
    -> optional box blur -> drop small objects -> hole handling -> relabel.

    Parameters
    ----------
    img : numpy.ndarray
        RGB image (converted with ``cv2.COLOR_RGB2GRAY``).
    otsu : bool
        Replace ``threshold`` with Otsu's threshold of the grayscale image.
    min_object_size : int
        Passed as ``min_size`` to ``remove_small_objects`` on the labels.
    threshold : int
        Grayscale values strictly below this are foreground.
    connectivity : int
        Passed to ``remove_small_objects``.
    kernel : int
        Radius of the disk used for binary closing; 0 disables closing.
    keep_holes : bool
        Keep holes inside objects.
    max_hole_size : int
        When > 0 and not ``keep_holes``, only background regions smaller than
        this are filled; when 0, all holes are filled.
    gray_before_close : bool
        Apply the gray-pixel filter before (True) or after (False) closing.
    blur_size : int
        Box-blur kernel size applied to the mask; 0 disables it.

    Returns
    -------
    numpy.ndarray
        Boolean foreground mask.
    numpy.ndarray
        Integer label image of the mask's connected components.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    gray_mask = filter_grays(img, output_type="bool")
    if otsu:
        threshold = threshold_otsu(gray)
    BW = (gray < threshold).astype(bool)
    if gray_before_close:
        BW = BW & gray_mask
    if kernel > 0:
        BW = morph.binary_closing(BW, morph.disk(kernel))
    if not gray_before_close:
        BW = BW & gray_mask
    if blur_size:
        BW = (cv2.blur(BW.astype(np.uint8), (blur_size, blur_size)) == 1)
    labels = scilabel(BW)[0]
    # `in_place` was deprecated in scikit-image 0.19 and later removed; using
    # the returned array works on every version.
    labels = morph.remove_small_objects(labels, min_size=min_object_size, connectivity=connectivity)
    if not keep_holes and max_hole_size:
        # Background components smaller than max_hole_size are removed from
        # the background mask; inverting it yields foreground with those holes filled.
        BW = ~morph.remove_small_objects(labels == 0, min_size=max_hole_size, connectivity=connectivity)
    elif keep_holes:
        BW = labels > 0
    else:
        BW = fill_holes(labels)
    labels = scilabel(BW)[0]
    return (BW != 0), labels
427
428
def generate_tissue_mask(arr,
                         compression=8,
                         otsu=False,
                         threshold=220,
                         connectivity=8,
                         kernel=61,
                         min_object_size=100000,
                         return_convex_hull=False,
                         keep_holes=False,
                         max_hole_size=0,
                         gray_before_close=False,
                         blur_size=0):
    """Compute a full-resolution boolean tissue mask from a downscaled copy.

    The image is shrunk by ``1/compression`` (cubic), segmented with
    ``label_objects`` (all remaining keyword arguments are forwarded), then
    optionally each labelled component is OR-ed with its convex hull. The
    result is resized back to ``arr``'s height and width and thresholded at > 0.

    Returns
    -------
    numpy.ndarray
        Boolean mask with shape ``arr.shape[:2]``.
    """
    scale = 1 / compression
    thumbnail = cv2.resize(arr, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
    tissue, component_labels = label_objects(thumbnail,
                                             otsu=otsu,
                                             min_object_size=min_object_size,
                                             threshold=threshold,
                                             connectivity=connectivity,
                                             kernel=kernel,
                                             keep_holes=keep_holes,
                                             max_hole_size=max_hole_size,
                                             gray_before_close=gray_before_close,
                                             blur_size=blur_size)
    if return_convex_hull:
        for component_id in range(1, component_labels.max() + 1):
            tissue = tissue + convex_hull_image(component_labels == component_id)
        tissue = tissue > 0
    height, width = arr.shape[:2]
    upscaled = cv2.resize(tissue.astype(np.uint8), (width, height), interpolation=cv2.INTER_CUBIC)
    return upscaled > 0
448
449
###################
450
451
def process_svs(svs_file, xml_file, annotations=[], output_dir='./'):
    """Store images into npy format and store annotations into pickle dictionary.

    Writes ``<basename>.npy`` and ``<basename>.pkl`` into ``output_dir``,
    where ``basename`` is the file name of ``svs_file`` up to its first dot.

    Parameters
    ----------
    svs_file:str
        Image file.
    xml_file:str
        Annotations file.
    annotations:list
        List of annotations in image.
    output_dir:str
        Output directory (created if missing).
    """
    os.makedirs(output_dir, exist_ok=True)
    basename = os.path.basename(svs_file).split('.')[0]
    # Forward `annotations`: it used to be dropped, so no XML annotation
    # shapes were ever extracted.
    arr, masks = load_process_image(svs_file, xml_file, annotations=annotations)
    np.save(join(output_dir, '{}.npy'.format(basename)), arr)
    with open(join(output_dir, '{}.pkl'.format(basename)), 'wb') as f:
        pickle.dump(masks, f, protocol=-1)
470
471
####################
472
473
def load_dataset(in_zarr, in_pkl):
    """Load ZARR image and annotations pickle.

    Parameters
    ----------
    in_zarr:str
        Input image; ``.zarr`` stores use ``da.from_zarr``, anything else ``load_image``.
    in_pkl:str
        Input annotations pickle.

    Returns
    -------
    dask.array
        Image array.
    dict
        Annotations dictionary; ``{'annotations': ''}`` when ``in_pkl`` does not exist.

    """
    if not os.path.exists(in_pkl):
        annotations = {'annotations': ''}
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # pickles produced by this pipeline / trusted sources.
        with open(in_pkl, 'rb') as f:
            annotations = pickle.load(f)
    image = da.from_zarr(in_zarr) if in_zarr.endswith('.zarr') else load_image(in_zarr)
    return image, annotations
496
497
def is_valid_patch(xs, ys, patch_size, purple_mask, intensity_threshold, threshold=0.5):
    """Deprecated: decide whether a patch contains enough foreground.

    Prints ``xs ys``, then returns True when the fraction of pixels in the
    ``patch_size`` x ``patch_size`` window at (xs, ys) whose mask value is
    >= ``intensity_threshold`` is strictly greater than ``threshold``.
    """
    print(xs, ys)
    window = purple_mask[xs:xs + patch_size, ys:ys + patch_size]
    foreground_fraction = (window >= intensity_threshold).mean()
    return foreground_fraction > threshold
501
502
def fix_polygon(poly):
    """Return a list of polygons equivalent to ``poly`` that are valid.

    A valid polygon is returned as a one-element list. For an invalid one,
    only the exterior ring is used: it is explicitly closed, noded with
    ``unary_union`` and re-polygonized, which may yield zero or more polygons.
    """
    if poly.is_valid:
        return [poly]
    ring = LineString(np.vstack(poly.exterior.coords.xy).T)
    closed_ring = LineString(ring.coords[:] + ring.coords[0:1])
    noded = unary_union(closed_ring)
    return list(polygonize(noded))
513
514
def replace(txt, d=dict()):
    """Apply every ``old -> new`` substitution in ``d`` to ``txt``.

    Substitutions run in the dict's iteration order, so later ones operate
    on the output of earlier ones.
    """
    for old, new in d.items():
        txt = txt.replace(old, new)
    return txt
518
519
def xml2dict_ASAP(xml="", replace_d=None):
    """Parse an ASAP annotation XML file into per-group coordinate arrays.

    Parameters
    ----------
    xml:str
        Path to the ASAP XML file (printed before parsing).
    replace_d:dict
        Substring substitutions applied (via ``replace``) to each annotation's
        ``PartOfGroup`` name before it is used as a key. ``None`` means none.

    Returns
    -------
    dict
        Group name -> list of ``(n_points, 2)`` float arrays of (X, Y)
        coordinates, one array per annotation.
    object
        The parsed ``AnnotationGroups`` element, unchanged.
    """
    if replace_d is None:
        replace_d = {}
    print(xml)
    with open(xml, "rb") as f:
        d = xd.parse(f)
    d_h = d['ASAP_Annotations']['AnnotationGroups']

    d_final = defaultdict(list)
    try:
        annotation_list = d['ASAP_Annotations']["Annotations"]["Annotation"]
    except (KeyError, TypeError):
        # No annotations present (e.g. a missing or empty <Annotations/> element).
        print(d['ASAP_Annotations'].get("Annotations"))
        annotation_list = []
    # xmltodict yields a single dict (not a list) when an element occurs once;
    # iterating that dict would walk its keys and silently drop the annotation.
    if isinstance(annotation_list, dict):
        annotation_list = [annotation_list]
    for i, annotation in enumerate(annotation_list):
        try:
            k = str(replace(annotation["@PartOfGroup"], replace_d))
            coords = annotation["Coordinates"]["Coordinate"]
            if isinstance(coords, dict):
                coords = [coords]
            d_final[k].append(np.array([(float(c["@X"]), float(c["@Y"])) for c in coords]))
        except (KeyError, TypeError, ValueError, AttributeError):
            # Best effort: skip a malformed annotation but report its index.
            print(i)
    return dict(d_final), d_h
538
539
#@pysnooper.snoop("extract_patch.log")
540
def extract_patch_information(basename,
                              input_dir='./',
                              annotations=[],
                              threshold=0.5,
                              patch_size=224,
                              generate_finetune_segmentation=False,
                              target_class=0,
                              intensity_threshold=100.,
                              target_threshold=0.,
                              adj_mask='',
                              basic_preprocess=False,
                              tries=0,
                              entire_image=False,
                              svs_file='',
                              transpose_annotations=False,
                              get_tissue_mask=False,
                              otsu=False,
                              compression=8.,
                              return_convex_hull=False,
                              keep_holes=False,
                              max_hole_size=0,
                              gray_before_close=False,
                              kernel=61,
                              min_object_size=100000,
                              blur_size=0):
    """Final step of preprocessing pipeline. Break up image into patches, include if not background and of a certain intensity, find area of each annotation type in patch, spatial information, image ID and dump data to SQL table.

    Parameters
    ----------
    basename:str
        Patient ID.
    input_dir:str
        Input directory.
    annotations:list
        List of annotations to record, these can be different tissue types, must correspond with XML labels.
    threshold:float
        Value between 0 and 1 that indicates the minimum amount of patch that mustn't be background for inclusion.
    patch_size:int
        Patch size of patches; this will become one of the tables.
    generate_finetune_segmentation:bool
        Deprecated.
    target_class:int
        Number of segmentation classes desired, from 0th class to target_class-1 will be annotated in SQL.
    intensity_threshold:float
        Value between 0 and 255 that represents minimum intensity to not include as background. Will be modified with new transforms.
    target_threshold:float
        Deprecated.
    adj_mask:str
        Path to an adjusted segmentation mask (e.g. written by ``adjust_mask``)
        to use instead of ``<basename>_mask.npy``; an already-loaded array is
        also accepted.
    basic_preprocess:bool
        Do not store patch level information.
    tries:int
        Unused (the retry logic is disabled).
    entire_image:bool
        Return a single row for the whole image with NaN coordinates.
    svs_file:str
        Image used when ``<basename>.zarr`` does not exist in ``input_dir``.
    transpose_annotations:bool
        Swap the first two axes of the segmentation mask.
    get_tissue_mask:bool
        Use ``generate_tissue_mask`` instead of the intensity mask; the
        remaining tissue-mask parameters are forwarded to it.

    Returns
    -------
    dataframe
        Patch information: ``ID``, ``x``, ``y``, ``patch_size``, ``annotation``
        and one area column per class (segmentation) or per annotation (XML).
        With ``basic_preprocess`` only the first four columns are returned.

    """
    from itertools import product

    # ---- load image and annotations ------------------------------------
    in_zarr = join(input_dir, '{}.zarr'.format(basename))
    in_zarr = in_zarr if os.path.exists(in_zarr) else svs_file
    arr, masks = load_dataset(in_zarr, join(input_dir, '{}_mask.pkl'.format(basename)))
    if 'annotations' in masks:
        segmentation = True
        if adj_mask is not None and not isinstance(adj_mask, str):
            # Caller supplied an already-loaded mask array.
            segmentation_mask = adj_mask
        else:
            if adj_mask:
                mask_file = adj_mask
            else:
                mask_file = join(input_dir, '{}_mask.npy'.format(basename))
                if not os.path.exists(mask_file):
                    mask_file = mask_file[:-len('.npy')] + '.npz'
            # Previously a non-empty adj_mask path was used as if it were the
            # mask array itself, failing as soon as it was sliced.
            segmentation_mask = npy2da(mask_file)
        if transpose_annotations:
            segmentation_mask = segmentation_mask.transpose([1, 0, 2])
    else:
        segmentation = False
        annotations = list(annotations)
        print(annotations)

    # ---- foreground mask used to decide which patches to keep ----------
    if get_tissue_mask:
        tissue = generate_tissue_mask(arr.compute(),
                                      compression=compression,
                                      otsu=otsu,
                                      threshold=255 - intensity_threshold,
                                      connectivity=8,
                                      kernel=kernel,
                                      min_object_size=min_object_size,
                                      return_convex_hull=return_convex_hull,
                                      keep_holes=keep_holes,
                                      max_hole_size=max_hole_size,
                                      gray_before_close=gray_before_close,
                                      blur_size=blur_size)
        purple_mask = da.from_array(tissue)
        # The tissue mask is boolean, so ">= 0.5" selects tissue pixels.
        intensity_threshold = 0.5
    else:
        purple_mask = create_purple_mask(arr)

    # ---- merge XML annotation shapes into one geometry per annotation --
    for annotation in annotations:
        if masks[annotation]:
            # fix_polygon returns a list of valid polygons per shape; flatten.
            masks[annotation] = [fixed for poly in masks[annotation] for fixed in fix_polygon(poly)]
        try:
            masks[annotation] = [unary_union(masks[annotation])] if masks[annotation] else []
        except Exception:
            # unary_union can fail on problematic geometry; fall back to a MultiPolygon.
            masks[annotation] = [MultiPolygon(masks[annotation])] if masks[annotation] else []

    # ---- candidate patch grid ------------------------------------------
    x_steps = int((float(arr.shape[0]) - patch_size) / patch_size)
    y_steps = int((float(arr.shape[1]) - patch_size) / patch_size)
    if segmentation:
        area_columns = [str(i) for i in range(target_class)]
    else:
        area_columns = list(annotations)
    patch_info = pd.DataFrame(
        [[basename, i * patch_size, j * patch_size, patch_size, 'NA'] + [0.] * len(area_columns)
         for i, j in product(range(x_steps + 1), range(y_steps + 1))],
        columns=['ID', 'x', 'y', 'patch_size', 'annotation'] + area_columns)

    if entire_image:
        # One row describing the whole image; coordinates are meaningless.
        # (pd.DataFrame(patch_info.iloc[0, :]) used to produce a transposed,
        # single-column frame here.)
        patch_info = patch_info.iloc[[0]].copy()
        patch_info[['x', 'y', 'patch_size']] = np.nan
        return patch_info

    # ---- keep patches with enough foreground ---------------------------
    if basic_preprocess:
        patch_info = patch_info.iloc[:, :4]
    coords = patch_info[['x', 'y']].values.tolist()
    if intensity_threshold > 0:
        lazy_checks = [(purple_mask[xs:xs + patch_size, ys:ys + patch_size] >= intensity_threshold).mean() > threshold
                       for xs, ys in coords]
        valid_patches = np.array(da.compute(*lazy_checks), dtype=bool)
    else:
        valid_patches = np.ones(len(coords), dtype=bool)
    print('Valid Patches Complete')
    patch_info = patch_info.loc[valid_patches].copy()

    if basic_preprocess:
        return patch_info
    if segmentation:
        patch_info['annotation'] = 'segment'
    # Nothing to measure: no patch survived, or there are no area columns
    # (both previously crashed in np.vectorize / argmax on empty input).
    if patch_info.empty or not area_columns:
        return patch_info

    # ---- per-patch area of each class / annotation ---------------------
    coords = patch_info[['x', 'y']].values.tolist()
    if segmentation:
        # Histogram of mask values over [0, target_class - 1] with one bin per class.
        area_info = [da.histogram(segmentation_mask[xs:xs + patch_size, ys:ys + patch_size],
                                  range=[0, target_class - 1], bins=target_class)[0]
                     for xs, ys in coords]
    else:
        area_info = [[dask.delayed(is_coords_in_box)([xs, ys], patch_size, masks[annotation])
                      for annotation in annotations]
                     for xs, ys in coords]
    area_info = np.array(dask.compute(*area_info)).astype(float)
    print('Area Info Complete')
    # Normalise by the patch area.
    area_info = area_info / (patch_size ** 2)
    patch_info.iloc[:, 5:] = area_info
    # Label each patch with the column covering the largest area (vectorised).
    patch_info['annotation'] = np.array(area_columns, dtype=object)[area_info.argmax(axis=1)]
    return patch_info
711
712
def generate_patch_pipeline(basename,
                            input_dir='./',
                            annotations=None,
                            threshold=0.5,
                            patch_size=224,
                            out_db='patch_info.db',
                            generate_finetune_segmentation=False,
                            target_class=0,
                            intensity_threshold=100.,
                            target_threshold=0.,
                            adj_mask='',
                            basic_preprocess=False,
                            entire_image=False,
                            svs_file='',
                            transpose_annotations=False,
                            get_tissue_mask=False,
                            otsu=False,
                            compression=8.,
                            return_convex_hull=False,
                            keep_holes=False,
                            max_hole_size=0,
                            gray_before_close=False,
                            kernel=61,
                            min_object_size=100000,
                            blur_size=0):
    """Find area coverage of each annotation in each patch and store patch information into SQL db.

    Parameters
    ----------
    basename:str
        Patient ID.
    input_dir:str
        Input directory.
    annotations:list
        List of annotations to record, these can be different tissue types, must correspond with XML labels.
        Defaults to an empty list.
    threshold:float
        Value between 0 and 1 that indicates the minimum amount of patch that musn't be background for inclusion.
    patch_size:int
        Patch size of patches; this will become one of the tables.
    out_db:str
        Output SQL database. Rows are appended to the table named ``str(patch_size)``.
    generate_finetune_segmentation:bool
        Deprecated.
    target_class:int
        Number of segmentation classes desired, from 0th class to target_class-1 will be annotated in SQL.
    intensity_threshold:float
        Value between 0 and 255 that represents minimum intensity to not include as background. Will be modified with new transforms.
    target_threshold:float
        Deprecated.
    adj_mask:str
        Adjusted mask if performed binary opening operations in previous preprocessing step.
    basic_preprocess:bool
        Do not store patch level information.
    entire_image, svs_file, transpose_annotations, get_tissue_mask, otsu, compression,
    return_convex_hull, keep_holes, max_hole_size, gray_before_close, kernel,
    min_object_size, blur_size
        Forwarded unchanged to ``extract_patch_information``.
    """
    if annotations is None:
        annotations = []
    patch_info = extract_patch_information(basename,
                                           input_dir,
                                           annotations,
                                           threshold,
                                           patch_size,
                                           generate_finetune_segmentation=generate_finetune_segmentation,
                                           target_class=target_class,
                                           intensity_threshold=intensity_threshold,
                                           target_threshold=target_threshold,
                                           adj_mask=adj_mask,
                                           basic_preprocess=basic_preprocess,
                                           entire_image=entire_image,
                                           svs_file=svs_file,
                                           transpose_annotations=transpose_annotations,
                                           get_tissue_mask=get_tissue_mask,
                                           otsu=otsu,
                                           compression=compression,
                                           return_convex_hull=return_convex_hull,
                                           keep_holes=keep_holes,
                                           max_hole_size=max_hole_size,
                                           gray_before_close=gray_before_close,
                                           kernel=kernel,
                                           min_object_size=min_object_size,
                                           blur_size=blur_size)
    conn = sqlite3.connect(out_db)
    try:
        # One table per patch size; repeated runs append to it.
        patch_info.to_sql(str(patch_size), con=conn, if_exists='append')
    finally:
        # Release the database even if the write fails.
        conn.close()
793
794
795
# Deprecated: gather patch information for several slides and save it as a pickle.
796
def save_all_patch_info(basenames, input_dir='./', annotations=None, threshold=0.5, patch_size=224, output_pkl='patch_info.pkl'):
    """Deprecated. Extract patch information for several slides and pickle the combined table.

    Parameters
    ----------
    basenames:list
        Slide/patient IDs passed one by one to ``extract_patch_information``.
    input_dir:str
        Input directory.
    annotations:list
        Annotations to record; defaults to an empty list.
    threshold:float
        Minimum non-background fraction for patch inclusion.
    patch_size:int
        Patch size.
    output_pkl:str
        Output pickle file for the concatenated dataframe.
    """
    if annotations is None:
        annotations = []
    per_slide = [extract_patch_information(basename, input_dir, annotations, threshold, patch_size) for basename in basenames]
    # Renumber rows so the combined table has a fresh 0..n-1 index.
    df = pd.concat(per_slide).reset_index(drop=True)
    df.to_pickle(output_pkl)
800
801
#########
802
803
def create_zero_mask(npy_mask, in_zarr, in_pkl):
    """Write an all-zero sparse mask for an image and register it in the annotation pickle.

    Parameters
    ----------
    npy_mask:str
        Output path of the sparse zero mask (written with ``scipy.sparse.save_npz``);
        also stored under the ``'annotations'`` key of the annotation dictionary.
    in_zarr:str
        Image dataset, passed to ``load_dataset``.
    in_pkl:str
        Annotation pickle, passed to ``load_dataset`` and overwritten with the updated dictionary.
    """
    from scipy.sparse import csr_matrix, save_npz
    arr, annotations_dict = load_dataset(in_zarr, in_pkl)
    annotations_dict['annotations'] = npy_mask
    # Mask spans every axis of the image except the last (presumably channels — confirm);
    # csr_matrix(shape) builds an empty (all-zero) matrix of that 2-D shape.
    save_npz(file=npy_mask, matrix=csr_matrix(arr.shape[:-1]))
    # Use a context manager so the pickle is flushed and closed before returning.
    with open(in_pkl, 'wb') as handle:
        pickle.dump(annotations_dict, handle)
810
811
#########
812
813
814
def create_train_val_test(train_val_test_pkl, input_info_db, patch_size):
    """Create dataframe that splits slides into training validation and test.

    If ``train_val_test_pkl`` already exists it is loaded and returned as-is;
    otherwise a new random split of the slide IDs is made and cached there.

    Parameters
    ----------
    train_val_test_pkl:str
        Pickle for training validation and test slides.
    input_info_db:str
        Patch information SQL database.
    patch_size:int
        Patch size looking to access.

    Returns
    -------
    dataframe
        Train test validation splits, with columns ``ID`` and ``set``.
    """
    if os.path.exists(train_val_test_pkl):
        return pd.read_pickle(train_val_test_pkl)
    conn = sqlite3.connect(input_info_db)
    try:
        df = pd.read_sql('select * from "{}";'.format(patch_size), con=conn)
    finally:
        # Close the database even when the query fails.
        conn.close()
    IDs = pd.DataFrame(df['ID'].unique(), columns=['ID'])
    # Split at slide level using train_test_split's default held-out fraction,
    # first for test, then again on the remainder for validation.
    IDs_train, IDs_test = train_test_split(IDs)
    IDs_train, IDs_val = train_test_split(IDs_train)
    # assign() returns new frames, avoiding chained assignment on split outputs.
    IDs = pd.concat([IDs_train.assign(set='train'),
                     IDs_val.assign(set='val'),
                     IDs_test.assign(set='test')])
    IDs.to_pickle(train_val_test_pkl)
    return IDs
848
849
def modify_patch_info(input_info_db='patch_info.db', slide_labels=None, pos_annotation_class='', patch_size=224, segmentation=False, other_annotations=None, target_segmentation_class=-1, target_threshold=0., classify_annotations=False, modify_patches=False):
    """Modify the patch information to get ready for deep learning, incorporate whole slide labels if needed.

    Parameters
    ----------
    input_info_db:str
        SQL DB file.
    slide_labels:dataframe
        Dataframe with whole slide labels, indexed by slide ID with one column per target.
        Only patches whose ``ID`` is in this index are kept. Defaults to an empty dataframe.
    pos_annotation_class:str or list
        Tissue/annotation label(s) to label with whole slide image label: a patch gets its
        slide's labels when any of these annotation columns is > 0. If not supplied, any
        slide's patches receive the whole slide label.
    patch_size:int
        Patch size; selects the SQL table.
    segmentation:bool
        Segmentation?
    other_annotations:list
        Other annotations to access from patch information. Defaults to an empty list.
    target_segmentation_class:int
        Segmentation class to threshold.
    target_threshold:float
        Include patch if patch has target area greater than or equal to this.
    classify_annotations:bool
        Classifying annotations for pretraining, or final model?
    modify_patches:bool
        Keep only patches whose ``annotation`` is among the included annotations (when not
        classifying annotations) and set ``area`` to the patch's value in its own annotation
        column; otherwise ``area`` is 1.

    Returns
    -------
    dataframe
        Modified patch information.
    """
    if slide_labels is None:
        slide_labels = pd.DataFrame()
    if other_annotations is None:
        other_annotations = []
    conn = sqlite3.connect(input_info_db)
    try:
        df = pd.read_sql('select * from "{}";'.format(patch_size), con=conn)
    finally:
        conn.close()
    df = df.drop_duplicates()
    # Copy so later column writes go to an independent frame rather than a filtered view.
    df = df.loc[np.isin(df['ID'], slide_labels.index)].copy()
    if not segmentation:
        if classify_annotations:
            targets = df['annotation'].unique().tolist()
            if len(targets) == 1:
                targets = list(df.iloc[:, 5:])
        else:
            targets = list(slide_labels)
            if isinstance(pos_annotation_class, str):
                pos_classes = [pos_annotation_class]
            else:
                pos_classes = list(pos_annotation_class)
            included_annotations = pos_classes + list(other_annotations)
            print(df.shape, included_annotations)
            if modify_patches:
                df = df[np.isin(df['annotation'], included_annotations)].copy()
            for target in targets:
                df[target] = 0.
            for slide in slide_labels.index:
                slide_bool = df['ID'] == slide
                if pos_annotation_class:
                    # Comparison must be parenthesised: `a & b > 0.` parses as `(a & b) > 0.`.
                    # any(axis=1) lets a list of positive classes work as well as a single one.
                    slide_bool = slide_bool & (df[pos_classes] > 0.).any(axis=1)
                if slide_bool.sum():
                    for target in targets:
                        df.loc[slide_bool, target] = slide_labels.loc[slide, target]
        if modify_patches:
            # Each patch's area in the column named by its own `annotation`; a list
            # comprehension also handles an empty frame (np.vectorize raises on size 0).
            df['area'] = np.array([df.iat[i, df.columns.get_loc(annotation)]
                                   for i, annotation in enumerate(df['annotation'])], dtype=float)
        else:
            df['area'] = 1.
        if target_threshold > 0.:
            df = df.loc[df['area'] >= target_threshold]
    else:
        df['target'] = 0.
        if target_segmentation_class >= 0:
            df = df.loc[df[str(target_segmentation_class)] >= target_threshold]
    print(df.shape)
    return df
917
918
def npy2da(npy_file):
    """Load an array file as a dask array.

    Parameters
    ----------
    npy_file:str
        Input file. ``.npy`` files are memory-mapped; ``.npz`` files are loaded as
        scipy sparse matrices and densified; ``.h5`` files contribute their
        ``'dataset'`` entry. If a ``.npy`` path does not exist, the same path
        with a ``.npz`` extension is loaded instead.

    Returns
    -------
    dask.array
        Converted array.

    Raises
    ------
    ValueError
        If the file extension is not ``.npy``, ``.npz`` or ``.h5``.
    """
    if npy_file.endswith('.npy'):
        if os.path.exists(npy_file):
            return da.from_array(np.load(npy_file, mmap_mode='r+'))
        # Fall back to a sparse copy saved under the same stem, then load it below.
        npy_file = npy_file[:-len('.npy')] + '.npz'
    if npy_file.endswith('.npz'):
        from scipy.sparse import load_npz
        return da.from_array(load_npz(npy_file).toarray())
    if npy_file.endswith('.h5'):
        # The HDF5 file is not closed here: the returned dask array reads from the dataset lazily.
        return da.from_array(h5py.File(npy_file, 'r')['dataset'])
    raise ValueError('Unsupported file type for npy2da: {}'.format(npy_file))
943
944
def grab_interior_points(xml_file, img_size, annotations=None):
    """Deprecated. Map each annotation name to its parsed shapes, or [] when parsing fails.

    Parameters
    ----------
    xml_file:str
        Annotation file passed to ``parse_coord_return_boxes``.
    img_size:tuple
        Unused; kept for backward compatibility.
    annotations:list
        Annotation names to look up. Defaults to an empty list.

    Returns
    -------
    dict
        Annotation name -> list of shapes.
    """
    if annotations is None:
        annotations = []
    interior_point_dict = {}
    for annotation in annotations:
        try:
            interior_point_dict[annotation] = parse_coord_return_boxes(xml_file, annotation, return_coords=False)
        except Exception:
            # Best effort: an unreadable file or missing annotation yields no shapes.
            # (Exception, unlike a bare except, lets KeyboardInterrupt/SystemExit through.)
            interior_point_dict[annotation] = []
    return interior_point_dict
953
954
def boxes2interior(img_size, polygons):
    """Deprecated. Rasterise polygons onto a blank image and return the covered pixel indices.

    Parameters
    ----------
    img_size:tuple
        Size of the blank single-channel ('L' mode) PIL image to draw on.
    polygons:list
        Polygons in a form accepted by ``ImageDraw.Draw.polygon``.

    Returns
    -------
    tuple
        ``nonzero()`` indices of the pixels filled by any polygon.
    """
    canvas = Image.new('L', img_size, 0)
    pen = ImageDraw.Draw(canvas)
    for poly in polygons:
        pen.polygon(poly, outline=1, fill=1)
    return np.array(canvas).nonzero()
962
963
def parse_coord_return_boxes(xml_file, annotation_name = '', return_coords = False, transpose_annotations=False):
    """Get list of shapely objects for each annotation in the XML object.

    Parameters
    ----------
    xml_file:str
        Annotation file. Files ending in ``.xml`` are parsed for ``annotation``
        elements whose ``partofgroup`` attribute equals ``annotation_name``;
        any other file is unpickled as a dict mapping annotation names to
        coordinate arrays.
    annotation_name:str
        Name of xml annotation.
    return_coords:bool
        Just return list of coords over shapes.
    transpose_annotations:bool
        Reverse the order of the two coordinate columns before use.

    Returns
    -------
    list
        List of shapely objects (or coordinate lists if ``return_coords``).
    """
    boxes = []
    if xml_file.endswith(".xml"):
        with open(xml_file) as handle:
            xml_data = BeautifulSoup(handle, 'html')
        # The HTML parser lower-cases tag/attribute names, hence the lowercase lookups.
        for annotation in xml_data.findAll('annotation'):
            if annotation['partofgroup'] != annotation_name:
                continue
            for coordinates in annotation.findAll('coordinates'):
                # FIXME may need to change x and y coordinates
                coords = np.array([(coordinate['x'], coordinate['y']) for coordinate in coordinates.findAll('coordinate')])
                if transpose_annotations:
                    coords = coords[:, ::-1]
                coords = coords.tolist()
                if return_coords:
                    boxes.append(coords)
                else:
                    # Attribute values are strings; the builtin float replaces np.float (removed in NumPy 1.24).
                    boxes.append(Polygon(np.array(coords).astype(float)))
    else:
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted annotation files.
        with open(xml_file, 'rb') as handle:
            annotations = pickle.load(handle).get(annotation_name, [])
        for annotation in annotations:
            if transpose_annotations:
                annotation = annotation[:, ::-1]
            boxes.append(annotation.tolist() if return_coords else Polygon(annotation))
    return boxes
1005
1006
def is_coords_in_box(coords, patch_size, boxes):
    """Get area of annotation in patch.

    Parameters
    ----------
    coords:array
        X,Y coordinates anchoring the square patch.
    patch_size:int
        Side length of the patch.
    boxes:list
        Shapely objects for annotations; only the first one is intersected.

    Returns
    -------
    float
        Area of the intersection between the patch square and ``boxes[0]``,
        or 0. when ``boxes`` is empty.
    """
    if len(boxes) == 0:
        return 0.
    unit_square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]])
    patch_polygon = Polygon(unit_square * patch_size + coords)
    # NOTE(review): only boxes[0] is used — presumably callers pass one merged shape per annotation; confirm.
    return patch_polygon.intersection(boxes[0]).area
1030
1031
def is_image_in_boxes(image_coord_dict, boxes):
    """Find if image intersects with annotations.

    Parameters
    ----------
    image_coord_dict:dict
        Image name -> shape of that image's patch.
    boxes:list
        Shapely annotation shapes.

    Returns
    -------
    dict
        Image name -> True if any annotation shape intersects the image's shape.
    """
    membership = {}
    for image, shape in image_coord_dict.items():
        # Evaluate every intersection test (no short-circuit), then reduce.
        hits = [box.intersects(shape) for box in boxes]
        membership[image] = any(hits)
    return membership
1048
1049
def images2coord_dict(images, output_point=False):
    """Deprecated. Map each image path to the coordinates ``image2coords`` derives from it."""
    coord_dict = {}
    for image_path in images:
        coord_dict[image_path] = image2coords(image_path, output_point)
    return coord_dict
1052
1053
def dir2images(image_dir):
    """Deprecated. List the ``.jpg`` files directly inside ``image_dir``."""
    pattern = join(image_dir, '*.jpg')
    return glob.glob(pattern)
1056
1057
def return_image_in_boxes_dict(image_dir, xml_file, annotation=''):
    """Deprecated. For each ``.jpg`` in ``image_dir``, whether it intersects the named annotation."""
    shapes = parse_coord_return_boxes(xml_file, annotation)
    image_shapes = images2coord_dict(dir2images(image_dir))
    return is_image_in_boxes(image_coord_dict=image_shapes, boxes=shapes)
1063
1064
def image2coords(image_file, output_point=False):
    """Deprecated. Derive patch coordinates from an image file name.

    The file name (after the last '/', up to its first '.') is split on '_';
    the four fields after the first are read as integers nx, ny, yi, xi.
    """
    stem = image_file.split('/')[-1].split('.')[0]
    fields = stem.split('_')[1:]
    nx, ny, yi, xi = np.array(fields).astype(int).tolist()
    return return_image_coord(nx=nx, ny=ny, xi=xi, yi=yi, output_point=output_point)
1068
1069
def retain_images(image_dir, xml_file, annotation=''):
    """Deprecated. Return the images in ``image_dir`` that intersect the named annotation."""
    in_boxes = return_image_in_boxes_dict(image_dir, xml_file, annotation)
    kept = []
    for image_path, inside in in_boxes.items():
        if inside:
            kept.append(image_path)
    return kept
1073
1074
def return_image_coord(nx=0,ny=0,xl=3333,yl=3333,xi=0,yi=0,xc=3,yc=3,dimx=224,dimy=224, output_point=False):
    """Deprecated. Compute a patch's position from tile indices and in-tile offsets.

    The origin is ``(nx*xl + xi, ny*yl + yi)``, scaled per axis by ``(xc, yc)``.
    With ``output_point`` the scaled patch centre is returned as an array;
    otherwise a shapely Polygon of the scaled ``dimx`` x ``dimy`` patch corners.
    """
    scale = np.array([xc, yc])
    if output_point:
        centre = np.array([nx * xl + xi + dimx / 2, ny * yl + yi + dimy / 2])
        return scale * centre
    origin = np.array([nx * xl + xi, ny * yl + yi])
    offsets = [[0, 0], [dimx, 0], [dimx, dimy], [0, dimy]]
    corners = np.array([(scale * (origin + np.array(offset))).tolist() for offset in offsets])
    return Polygon(corners)
1082
1083
def fix_name(basename):
    """Fixes illegitimate basename, deprecated.

    Names shorter than three characters get a '0' inserted after their first
    character (e.g. ``'A1'`` -> ``'A01'``); longer names are returned unchanged.
    """
    return '{}0{}'.format(*basename) if len(basename) < 3 else basename
1088
1089
def fix_names(file_dir):
    """Fixes basenames, deprecated.

    Every entry of ``file_dir`` whose name stem (extension removed) is exactly
    two characters long is renamed by inserting a '0' between them, e.g.
    ``A1.svs`` -> ``A01.svs``. Each old/new path pair is printed; a failed
    rename is reported and skipped.
    """
    for filename in glob.glob(join(file_dir, '*')):
        # splitext leaves extensionless names intact (rfind('.') == -1 used to chop their last char).
        stem, suffix = os.path.splitext(os.path.basename(filename))
        # Only two-character stems fit the '{}0{}' padding pattern; shorter ones used to crash.
        if len(stem) == 2:
            new_filename = join(file_dir, '{}0{}{}'.format(stem[0], stem[1], suffix))
            print(filename, new_filename)
            try:
                # Rename directly instead of a shell `mv`, which broke on spaces/metacharacters.
                os.rename(filename, new_filename)
            except OSError as err:
                print(err)
1098
1099
#######
1100
1101
#@pysnooper.snoop('seg2npy.log')
1102
def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_output, original_patch_size=500, resized_patch_size=256, output_probs=False):
    """Convert segmentation predictions from model to numpy masks.

    Each patch prediction is pasted into a blank mosaic at the patch's ``x``/``y``
    position and the mosaic is written with ``np.save``.

    Parameters
    ----------
    y_pred:array
        Patch segmentation masks, indexable as ``y_pred[i, ...]``; entry ``i``
        belongs to row ``i`` of ``patch_info``.
    patch_info:dataframe
        Patch information from DB; the ``x``, ``y`` and ``patch_size`` columns are used.
    segmentation_map:array
        Existing segmentation mask; only the size of its last two axes is used.
    npy_output:str
        Output npy file. Missing parent directories are created.
    original_patch_size:int
        Patch size in which ``x``/``y`` of ``patch_info`` are expressed.
    resized_patch_size:int
        Patch size of the predictions. When it differs from ``original_patch_size``,
        coordinates are rescaled, the mosaic is built at that scale and then resized
        back to the original mask size with nearest-neighbour interpolation.
    output_probs:bool
        Save the float mosaic as-is instead of casting it to ``uint8``.
    """
    print(output_probs)
    original_seg_shape = tuple(segmentation_map.shape[-2:])
    resize = resized_patch_size != original_patch_size
    if resize:
        seg_map_shape = tuple(int(dim * resized_patch_size / original_patch_size) for dim in original_seg_shape)
    else:
        seg_map_shape = original_seg_shape
    mosaic = np.zeros(seg_map_shape, dtype=float)
    for i in range(patch_info.shape[0]):
        patch_info_i = patch_info.iloc[i]
        xs = patch_info_i['x']
        ys = patch_info_i['y']
        patch_size = patch_info_i['patch_size']
        if resize:
            xs = int(xs * resized_patch_size / original_patch_size)
            ys = int(ys * resized_patch_size / original_patch_size)
            patch_size = resized_patch_size
        region = mosaic[xs:xs + patch_size, ys:ys + patch_size]
        # Border patches can extend past the mosaic (e.g. after integer rescaling);
        # clip the prediction's last two axes to the part that fits.
        rows, cols = region.shape
        region[...] = y_pred[i, ...][..., :rows, :cols]
    if resize:
        import cv2
        # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
        mosaic = cv2.resize(mosaic, dsize=(original_seg_shape[1], original_seg_shape[0]), interpolation=cv2.INTER_NEAREST)
    out_dir = os.path.dirname(npy_output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if not output_probs:
        mosaic = mosaic.astype(np.uint8)
    np.save(npy_output, mosaic)