--- /dev/null
+++ b/pathflowai/utils.py
@@ -0,0 +1,1141 @@
+"""
+utils.py
+=======================
+General utilities that still need to be broken up into preprocessing, machine learning input preparation, and output submodules.
+"""
+
+import glob
+import os, subprocess
+import sqlite3
+import pickle
+import copy
+from os.path import join
+from collections import defaultdict
+
+import numpy as np
+import pandas as pd
+import scipy.sparse as sps
+import scipy.ndimage as ndimage
+from scipy.ndimage import label as scilabel, distance_transform_edt
+from scipy.ndimage.morphology import binary_fill_holes as fill_holes
+
+from bs4 import BeautifulSoup
+import xmltodict as xd
+from shapely.geometry import Point, MultiPolygon, LineString
+from shapely.geometry.polygon import Polygon
+from shapely.ops import unary_union, polygonize
+
+import plotly.graph_objs as go
+import plotly.offline as py
+
+from PIL import Image, ImageDraw
+Image.MAX_IMAGE_PIXELS = 1e10
+
+import torch
+from torch.utils.data import Dataset#, DataLoader
+from sklearn.model_selection import train_test_split
+import pysnooper
+
+import dask
+import dask.array as da
+import dask_image.imread  # needed by svs2dask_array's ImageSlide fallback below
+import openslide
+from openslide import deepzoom
+#import xarray as xr, sparse
+import h5py
+import nonechucks as nc
+from nonechucks import SafeDataLoader as DataLoader
+
+import cv2
+from skimage.morphology import watershed
+from skimage.feature import peak_local_max
+from skimage import morphology as morph
+from skimage.filters import threshold_otsu, rank
+from skimage.morphology import convex_hull_image, remove_small_holes
+from skimage import measure
+
+
+def load_sql_df(sql_file, patch_size):
+    """Load pandas dataframe from SQL, accessing the table for a particular patch size.
+
+    Parameters
+    ----------
+    sql_file:str
+        SQL db.
+    patch_size:int
+        Patch size.
+
+    Returns
+    -------
+    dataframe
+        Patch level information.
+
+    """
+    conn = sqlite3.connect(sql_file)
+    df = pd.read_sql('select * from "{}";'.format(patch_size), con=conn)
+    conn.close()
+    return df
+
+def df2sql(df, sql_file, patch_size, mode='replace'):
+    """Write dataframe containing patch level information to SQL db.
+
+    Parameters
+    ----------
+    df:dataframe
+        Dataframe containing patch information.
+    sql_file:str
+        SQL database.
+    patch_size:int
+        Size of patches.
+    mode:str
+        Replace or append.
+
+    """
+    conn = sqlite3.connect(sql_file)
+    df.set_index('index').to_sql(str(patch_size), con=conn, if_exists=mode)
+    conn.close()
+
+
+#########
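+
+# Example (illustrative sketch, not part of the API): round-tripping patch-level
+# metadata through SQL. Assumes a db "patch_info.db" produced by
+# generate_patch_pipeline below, with a table for 224-pixel patches:
+#
+#   df = load_sql_df('patch_info.db', 224)   # read the "224" table
+#   df2sql(df, 'patch_info.db', 224)         # write it back (mode='replace')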
+
+# https://github.com/qupath/qupath/wiki/Supported-image-formats
+def svs2dask_array(svs_file, tile_size=1000, overlap=0, remove_last=True, allow_unknown_chunksizes=False, transpose=False):
+    """Convert SVS, TIF or TIFF to dask array.
+
+    Parameters
+    ----------
+    svs_file : str
+        Image file.
+    tile_size : int
+        Size of chunk to be read in.
+    overlap : int
+        Do not modify, overlap between neighboring tiles.
+    remove_last : bool
+        Remove last tile because it has a custom size.
+    allow_unknown_chunksizes : bool
+        Allow different chunk sizes, more flexible, but slowdown.
+
+    Returns
+    -------
+    arr : dask.array.Array
+        A Dask Array representing the contents of the image file.
+
+    Examples
+    --------
+    >>> arr = svs2dask_array(svs_file, tile_size=1000, overlap=0, remove_last=True, allow_unknown_chunksizes=False)
+    >>> arr2 = arr.compute()
+    >>> arr3 = to_pil(cv2.resize(arr2, dsize=(1440, 700), interpolation=cv2.INTER_CUBIC))
+    >>> arr3.save(test_image_name)
+    """
+    # https://github.com/jlevy44/PathFlowAI/blob/master/pathflowai/utils.py
+    img = openslide.open_slide(svs_file)
+    if type(img) is openslide.OpenSlide:
+        gen = deepzoom.DeepZoomGenerator(
+            img, tile_size=tile_size, overlap=overlap, limit_bounds=True)
+        max_level = len(gen.level_dimensions) - 1
+        n_tiles_x, n_tiles_y = gen.level_tiles[max_level]
+
+        @dask.delayed(pure=True)
+        def get_tile(level, column, row):
+            tile = gen.get_tile(level, (column, row))
+            return np.array(tile).transpose((1, 0, 2))
+
+        sample_tile_shape = get_tile(max_level, 0, 0).shape.compute()
+        rows = range(n_tiles_y - (0 if not remove_last else 1))
+        cols = range(n_tiles_x - (0 if not remove_last else 1))
+        arr = da.concatenate([da.concatenate([da.from_delayed(get_tile(max_level, col, row), sample_tile_shape, np.uint8) for row in rows],
+                                             allow_unknown_chunksizes=allow_unknown_chunksizes, axis=1) for col in cols],
+                             allow_unknown_chunksizes=allow_unknown_chunksizes)
+        if transpose:
+            arr = arr.transpose([1, 0, 2])
+        return arr
+    else:  # img is an instance of openslide.ImageSlide
+        return dask_image.imread.imread(svs_file)
+
+def img2npy_(input_dir, basename, svs_file):
+    """Convert SVS, TIF, TIFF to NPY.
+
+    Parameters
+    ----------
+    input_dir:str
+        Output file dir.
+    basename:str
+        Basename of output file.
+    svs_file:str
+        SVS, TIF, TIFF file input.
+
+    Returns
+    -------
+    str
+        NPY output file.
+    """
+    npy_out_file = join(input_dir, '{}.npy'.format(basename))
+    arr = svs2dask_array(svs_file)
+    np.save(npy_out_file, arr.compute())
+    return npy_out_file
+
+def load_image(svs_file):
+    """Load SVS, TIF, TIFF via PIL. NOTE: shadowed by the dask-based load_image defined below.
+
+    Parameters
+    ----------
+    svs_file:str
+        Path to the SVS, TIF or TIFF image.
+
+    Returns
+    -------
+    array
+        Transposed image array, plus the original PIL image size.
+    """
+    im = Image.open(svs_file)
+    return np.transpose(np.array(im), (1, 0)), im.size
+
+def create_purple_mask(arr, img_size=None, sparse=True):
+    """Create a grayscale intensity mask. This will be changed soon to support other thresholding QC methods.
+
+    Parameters
+    ----------
+    arr:dask.array
+        Dask array containing image information.
+    img_size:int
+        Deprecated.
+    sparse:bool
+        Deprecated.
+
+    Returns
+    -------
+    dask.array
+        Intensity, grayscale array over image.
+
+    """
+    r, g, b = arr[:, :, 0], arr[:, :, 1], arr[:, :, 2]  # channels are RGB; the original misassigned g and b
+    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+    #rb_avg = (r+b)/2
+    mask = (255. - gray)# >= threshold)#(r > g - 10) & (b > g - 10) & (rb_avg > g + 20)#np.vectorize(is_purple)(arr).astype(int)
+    if 0 and sparse:
+        mask = mask.nonzero()
+        mask = np.array([mask[0].compute(), mask[1].compute()]).T
+        #mask = (np.ones(len(mask[0])),mask)
+        #mask = sparse.COO.from_scipy_sparse(sps.coo_matrix(mask, img_size, dtype=np.uint8).tocsr())
+    return mask
+
+def add_purple_mask(arr):
+    """Optionally add the intensity mask to the dask array as a fourth channel.
+
+    Parameters
+    ----------
+    arr:dask.array
+        Image data.
+
+    Returns
+    -------
+    array
+        Image data with intensity added as fourth channel.
+
+    """
+    # the mask is 2-D, so give it a channel axis and stack it onto the RGB channels
+    # (the original concatenated along axis=0, which fails on a shape mismatch)
+    return np.concatenate((arr, create_purple_mask(arr)[..., None]), axis=2)
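+
+# Example (illustrative sketch): the intensity mask is computed lazily on the
+# dask array, so thresholding stays cheap until .compute() is called.
+# Assumes "slide.svs" exists:
+#
+#   arr = svs2dask_array('slide.svs')
+#   mask = create_purple_mask(arr)                   # 2-D intensity dask array
+#   frac = (mask >= 100.).mean().compute()           # fraction of non-background pixels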
+
+def create_sparse_annotation_arrays(xml_file, img_size, annotations=[], transpose_annotations=False):
+    """Convert annotation xml to shapely objects and store in dictionary.
+
+    Parameters
+    ----------
+    xml_file:str
+        XML file containing annotations.
+    img_size:int
+        Deprecated.
+    annotations:list
+        Annotations to look for in xml export.
+
+    Returns
+    -------
+    dict
+        Dictionary with annotation-shapely object pairs.
+
+    """
+    interior_points_dict = {annotation: parse_coord_return_boxes(xml_file, annotation_name=annotation, return_coords=False, transpose_annotations=transpose_annotations) for annotation in annotations}
+    return {annotation: interior_points_dict[annotation] for annotation in annotations}
+
+def load_image(svs_file):
+    """Load an image as a dask array; supersedes the PIL-based load_image above."""
+    return (npy2da(svs_file) if (svs_file.endswith('.npy') or svs_file.endswith('.h5')) else svs2dask_array(svs_file, tile_size=1000, overlap=0))
+
+def load_preprocessed_img(img_file):
+    """Load a preprocessed image (zarr/npy/h5) as a dask array."""
+    if img_file.endswith('.zarr') and not os.path.exists(f"{img_file}/.zarray"):
+        img_file = img_file.replace(".zarr", ".npy")
+    return npy2da(img_file) if (img_file.endswith('.npy') or img_file.endswith('.h5')) else da.from_zarr(img_file)
+
+def load_process_image(svs_file, xml_file=None, npy_mask=None, annotations=[], transpose_annotations=False):
+    """Load SVS-like image (including NPY), segmentation/classification annotations, generate dask array and dictionary of annotations.
+
+    Parameters
+    ----------
+    svs_file:str
+        Image file.
+    xml_file:str
+        Annotation file.
+    npy_mask:array
+        Numpy segmentation mask.
+    annotations:list
+        List of annotations in xml.
+
+    Returns
+    -------
+    array
+        Dask array of image.
+    dict
+        Annotation masks.
+
+    """
+    arr = load_image(svs_file)
+    img_size = arr.shape[:2]
+    masks = {}#{'purple': create_purple_mask(arr,img_size,sparse=False)}
+    if xml_file is not None:
+        masks.update(create_sparse_annotation_arrays(xml_file, img_size, annotations=annotations, transpose_annotations=transpose_annotations))
+    if npy_mask is not None:
+        masks.update({'annotations': npy_mask})
+    return arr, masks
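+
+# Example (illustrative sketch): load a slide plus its XML annotation export,
+# collecting shapely polygons per annotation label. Assumes "slide.svs" and
+# "slide.xml" exist and the XML groups include "tumor" and "benign":
+#
+#   arr, masks = load_process_image('slide.svs', xml_file='slide.xml',
+#                                   annotations=['tumor', 'benign'])
+#   print(arr.shape, len(masks['tumor']))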
+ """ + if not no_zarr: + arr.astype('uint8').to_zarr(out_zarr, overwrite=True) + pickle.dump(masks,open(out_pkl,'wb')) + + #dataset.to_netcdf(out_netcdf, compute=False) + #pickle.dump(dataset, open(out_pkl,'wb'), protocol=-1) + +def run_preprocessing_pipeline(svs_file, xml_file=None, npy_mask=None, annotations=[], out_zarr='output_zarr.zarr', out_pkl='output.pkl',no_zarr=False,transpose_annotations=False): + """Run preprocessing pipeline. Store image into zarr format, segmentations maintain as npy, and xml annotations as pickle. + + Parameters + ---------- + svs_file:str + Input image file. + xml_file:str + Input annotation file. + npy_mask:str + NPY segmentation mask. + annotations:list + List of annotations. + out_zarr:str + Output zarr for image. + out_pkl:str + Output pickle for annotations. + """ + #save_dataset(load_process_image(svs_file, xml_file, npy_mask, annotations), out_netcdf) + arr, masks = load_process_image(svs_file, xml_file, npy_mask, annotations, transpose_annotations) + save_dataset(arr, masks,out_zarr, out_pkl, no_zarr) + +################### + +def adjust_mask(mask_file, dask_img_array_file, out_npy, n_neighbors): + """Fixes segmentation masks to reduce coarse annotations over empty regions. + + Parameters + ---------- + mask_file:str + NPY segmentation mask. + dask_img_array_file:str + Dask image file. + out_npy:str + Output numpy file. + n_neighbors:int + Number nearest neighbors for dilation and erosion of mask from background to not background. + + Returns + ------- + str + Output numpy file. + + """ + from dask_image.ndmorph import binary_opening + from dask.distributed import Client + #c=Client() + dask_img_array=da.from_zarr(dask_img_array_file) + mask=npy2da(mask_file) + is_tissue_mask = mask>0. + is_tissue_mask_img=((dask_img_array[...,0]>200.) & (dask_img_array[...,1]>200.)& (dask_img_array[...,2]>200.)) == 0 + opening=binary_opening(is_tissue_mask_img,structure=da.ones((n_neighbors,n_neighbors)))#,mask=is_tissue_mask) + mask[(opening==0)&(is_tissue_mask==1)]=0 + np.save(out_npy,mask.compute()) + #c.close() + return out_npy + +def filter_grays(rgb, tolerance=15, output_type="bool"): + """ https://github.com/deroneriksson/python-wsi-preprocessing/blob/master/deephistopath/wsi/filter.py + Create a mask to filter out pixels where the red, green, and blue channel values are similar. + Args: + np_img: RGB image as a NumPy array. + tolerance: Tolerance value to determine how similar the values must be in order to be filtered out + output_type: Type of array to return (bool, float, or uint8). + Returns: + NumPy array representing a mask where pixels with similar red, green, and blue values have been masked out. 
+ """ + (h, w, c) = rgb.shape + rgb = rgb.astype(np.int) + rg_diff = np.abs(rgb[:, :, 0] - rgb[:, :, 1]) <= tolerance + rb_diff = np.abs(rgb[:, :, 0] - rgb[:, :, 2]) <= tolerance + gb_diff = np.abs(rgb[:, :, 1] - rgb[:, :, 2]) <= tolerance + result = ~(rg_diff & rb_diff & gb_diff) + if output_type == "bool": + pass + elif output_type == "float": + result = result.astype(float) + else: + result = result.astype("uint8") * 255 + return result + +def label_objects(img, + otsu=True, + min_object_size=100000, + threshold=240, + connectivity=8, + kernel=61, + keep_holes=False, + max_hole_size=0, + gray_before_close=False, + blur_size=0): + I=cv2.cvtColor(img,cv2.COLOR_RGB2GRAY) + gray_mask=filter_grays(img, output_type="bool") + if otsu: threshold = threshold_otsu(I) + BW = (I<threshold).astype(bool) + if gray_before_close: BW=BW&gray_mask + if kernel>0: BW = morph.binary_closing(BW, morph.disk(kernel))#square + if not gray_before_close: BW=BW&gray_mask + if blur_size: BW=(cv2.blur(BW.astype(np.uint8), (blur_size,blur_size))==1) + labels = scilabel(BW)[0] + labels=morph.remove_small_objects(labels, min_size=min_object_size, connectivity = connectivity, in_place=True) + if not keep_holes and max_hole_size: + BW=morph.remove_small_objects(labels==0, min_size=max_hole_size, connectivity = connectivity, in_place=True)==False#remove_small_holes(labels,area_threshold=max_hole_size, connectivity = connectivity, in_place=True)>0 + elif keep_holes: + BW=labels>0 + else: + BW=fill_holes(labels) + labels = scilabel(BW)[0] + return(BW!=0),labels + +def generate_tissue_mask(arr, + compression=8, + otsu=False, + threshold=220, + connectivity=8, + kernel=61, + min_object_size=100000, + return_convex_hull=False, + keep_holes=False, + max_hole_size=0, + gray_before_close=False, + blur_size=0): + img=cv2.resize(arr,None,fx=1/compression,fy=1/compression,interpolation=cv2.INTER_CUBIC) + WB, lbl=label_objects(img, otsu=otsu, min_object_size=min_object_size, threshold=threshold, connectivity=connectivity, kernel=kernel,keep_holes=keep_holes,max_hole_size=max_hole_size, gray_before_close=gray_before_close,blur_size=blur_size) + if return_convex_hull: + for i in range(1,lbl.max()+1): + WB=WB+convex_hull_image(lbl==i) + WB=WB>0 + WB=cv2.resize(WB.astype(np.uint8),arr.shape[:2][::-1],interpolation=cv2.INTER_CUBIC)>0 + return WB + +################### + +def process_svs(svs_file, xml_file, annotations=[], output_dir='./'): + """Store images into npy format and store annotations into pickle dictionary. + + Parameters + ---------- + svs_file:str + Image file. + xml_file:str + Annotations file. + annotations:list + List of annotations in image. + output_dir:str + Output directory. + """ + os.makedirs(output_dir,exist_ok=True) + basename = svs_file.split('/')[-1].split('.')[0] + arr, masks = load_process_image(svs_file, xml_file) + np.save(join(output_dir,'{}.npy'.format(basename)),arr) + pickle.dump(masks, open(join(output_dir,'{}.pkl'.format(basename)),'wb'), protocol=-1) + +#################### + +def load_dataset(in_zarr, in_pkl): + """Load ZARR image and annotations pickle. + + Parameters + ---------- + in_zarr:str + Input image. + in_pkl:str + Input annotations. + + Returns + ------- + dask.array + Image array. + dict + Annotations dictionary. 
+
+###################
+
+def process_svs(svs_file, xml_file, annotations=[], output_dir='./'):
+    """Store image into npy format and store annotations into pickle dictionary.
+
+    Parameters
+    ----------
+    svs_file:str
+        Image file.
+    xml_file:str
+        Annotations file.
+    annotations:list
+        List of annotations in image.
+    output_dir:str
+        Output directory.
+    """
+    os.makedirs(output_dir, exist_ok=True)
+    basename = svs_file.split('/')[-1].split('.')[0]
+    arr, masks = load_process_image(svs_file, xml_file)
+    np.save(join(output_dir, '{}.npy'.format(basename)), arr)
+    pickle.dump(masks, open(join(output_dir, '{}.pkl'.format(basename)), 'wb'), protocol=-1)
+
+####################
+
+def load_dataset(in_zarr, in_pkl):
+    """Load ZARR image and annotations pickle.
+
+    Parameters
+    ----------
+    in_zarr:str
+        Input image.
+    in_pkl:str
+        Input annotations.
+
+    Returns
+    -------
+    dask.array
+        Image array.
+    dict
+        Annotations dictionary.
+
+    """
+    if not os.path.exists(in_pkl):
+        annotations = {'annotations': ''}
+    else:
+        annotations = pickle.load(open(in_pkl, 'rb'))
+    return (da.from_zarr(in_zarr) if in_zarr.endswith('.zarr') else load_image(in_zarr)), annotations
+
+def is_valid_patch(xs, ys, patch_size, purple_mask, intensity_threshold, threshold=0.5):
+    """Deprecated, computes whether patch is valid."""
+    print(xs, ys)
+    return (purple_mask[xs:xs + patch_size, ys:ys + patch_size] >= intensity_threshold).mean() > threshold
+
+def fix_polygon(poly):
+    """Repair an invalid (e.g. self-intersecting) polygon by re-polygonizing its exterior ring; returns a list of valid polygons."""
+    if not poly.is_valid:
+        poly = LineString(np.vstack(poly.exterior.coords.xy).T)
+        poly = unary_union(LineString(poly.coords[:] + poly.coords[0:1]))
+        poly = [p for p in polygonize(poly)]
+    else:
+        poly = [poly]
+    return poly
+
+def replace(txt, d=dict()):
+    """Apply each key->value substitution in d to txt."""
+    for k in d:
+        txt = txt.replace(k, d[k])
+    return txt
+
+def xml2dict_ASAP(xml="", replace_d=dict()):
+    """Parse an ASAP annotation XML into a dict mapping group name to a list of coordinate arrays."""
+    print(xml)
+    with open(xml, "rb") as f:
+        d = xd.parse(f)
+    d_h = d['ASAP_Annotations']['AnnotationGroups']
+    d_final = defaultdict(list)
+    try:
+        for i, annotation in enumerate(d['ASAP_Annotations']["Annotations"]["Annotation"]):
+            try:
+                k = "{}".format(replace(annotation["@PartOfGroup"], replace_d))
+                d_final[k].append(np.array([(float(coord["@X"]), float(coord["@Y"])) for coord in annotation["Coordinates"]["Coordinate"]]))
+            except:
+                print(i)
+    except:
+        print(d['ASAP_Annotations']["Annotations"])
+    d_final = dict(d_final)
+    return d_final, d_h
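+
+# Example (illustrative sketch): fix_polygon splits a self-intersecting
+# "bowtie" ring into its two valid triangles:
+#
+#   from shapely.geometry import Polygon
+#   bowtie = Polygon([(0, 0), (2, 2), (2, 0), (0, 2)])  # invalid, self-intersecting
+#   parts = fix_polygon(bowtie)                         # two valid Polygons
+#   assert all(p.is_valid for p in parts)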
+
+#@pysnooper.snoop("extract_patch.log")
+def extract_patch_information(basename,
+                              input_dir='./',
+                              annotations=[],
+                              threshold=0.5,
+                              patch_size=224,
+                              generate_finetune_segmentation=False,
+                              target_class=0,
+                              intensity_threshold=100.,
+                              target_threshold=0.,
+                              adj_mask='',
+                              basic_preprocess=False,
+                              tries=0,
+                              entire_image=False,
+                              svs_file='',
+                              transpose_annotations=False,
+                              get_tissue_mask=False,
+                              otsu=False,
+                              compression=8.,
+                              return_convex_hull=False,
+                              keep_holes=False,
+                              max_hole_size=0,
+                              gray_before_close=False,
+                              kernel=61,
+                              min_object_size=100000,
+                              blur_size=0):
+    """Final step of preprocessing pipeline. Break up image into patches, include if not background and of a certain intensity, find area of each annotation type in patch, spatial information, image ID and dump data to SQL table.
+
+    Parameters
+    ----------
+    basename:str
+        Patient ID.
+    input_dir:str
+        Input directory.
+    annotations:list
+        List of annotations to record, these can be different tissue types, must correspond with XML labels.
+    threshold:float
+        Value between 0 and 1 that indicates the minimum fraction of a patch that mustn't be background for inclusion.
+    patch_size:int
+        Patch size of patches; this will become one of the tables.
+    generate_finetune_segmentation:bool
+        Deprecated.
+    target_class:int
+        Number of segmentation classes desired, from 0th class to target_class-1 will be annotated in SQL.
+    intensity_threshold:float
+        Value between 0 and 255 that represents minimum intensity to not include as background. Will be modified with new transforms.
+    target_threshold:float
+        Deprecated.
+    adj_mask:str
+        Adjusted mask if performed binary opening operations in previous preprocessing step.
+    basic_preprocess:bool
+        Do not store patch level information.
+    tries:int
+        Number of tries in case there is a Dask timeout, run again.
+    entire_image:bool
+        Keep a single row describing the entire image rather than one per patch.
+    svs_file:str
+        Fallback image file when no zarr exists for this basename.
+    get_tissue_mask:bool
+        Use the morphological tissue mask rather than the grayscale intensity mask.
+
+    Returns
+    -------
+    dataframe
+        Patch information.
+
+    """
+    import dask
+    import time
+    from dask import dataframe as dd
+    import dask.array as da
+    import multiprocessing
+    from shapely.ops import unary_union
+    from shapely.geometry import MultiPolygon
+    from itertools import product
+    from functools import reduce
+    in_zarr = join(input_dir, '{}.zarr'.format(basename))
+    in_zarr = (in_zarr if os.path.exists(in_zarr) else svs_file)
+    arr, masks = load_dataset(in_zarr, join(input_dir, '{}_mask.pkl'.format(basename)))
+    if 'annotations' in masks:
+        segmentation = True
+        mask = join(input_dir, '{}_mask.npy'.format(basename))
+        mask = (mask if os.path.exists(mask) else mask.replace('.npy', '.npz'))
+        segmentation_mask = (npy2da(mask) if not adj_mask else adj_mask)
+        if transpose_annotations:
+            segmentation_mask = segmentation_mask.transpose([1, 0, 2])
+    else:
+        segmentation = False
+    annotations = list(annotations)
+    print(annotations)
+    purple_mask = create_purple_mask(arr) if not get_tissue_mask else da.from_array(generate_tissue_mask(arr.compute(),
+                                                                                                         compression=compression,
+                                                                                                         otsu=otsu,
+                                                                                                         threshold=255 - intensity_threshold,
+                                                                                                         connectivity=8,
+                                                                                                         kernel=kernel,
+                                                                                                         min_object_size=min_object_size,
+                                                                                                         return_convex_hull=return_convex_hull,
+                                                                                                         keep_holes=keep_holes,
+                                                                                                         max_hole_size=max_hole_size,
+                                                                                                         gray_before_close=gray_before_close,
+                                                                                                         blur_size=blur_size))
+    if get_tissue_mask:
+        intensity_threshold = 0.5  # the tissue mask is boolean, so any positive mean passes this cutoff
+
+    x_max = float(arr.shape[0])
+    y_max = float(arr.shape[1])
+    x_steps = int((x_max - patch_size) / patch_size)
+    y_steps = int((y_max - patch_size) / patch_size)
+    for annotation in annotations:
+        if masks[annotation]:
+            masks[annotation] = list(reduce(lambda x, y: x + y, [fix_polygon(poly) for poly in masks[annotation]]))
+        try:
+            masks[annotation] = [unary_union(masks[annotation])] if masks[annotation] else []
+        except:
+            masks[annotation] = [MultiPolygon(masks[annotation])] if masks[annotation] else []
+    patch_info = pd.DataFrame([([basename, i * patch_size, j * patch_size, patch_size, 'NA'] + [0.] * (target_class if segmentation else len(annotations))) for i, j in product(range(x_steps + 1), range(y_steps + 1))],
+                              columns=(['ID', 'x', 'y', 'patch_size', 'annotation'] + (annotations if not segmentation else [str(i) for i in range(target_class)])))
+    if entire_image:
+        patch_info.iloc[:, 1:4] = np.nan
+        patch_info = patch_info.iloc[[0], :]  # keep one row for the whole image; pd.DataFrame(iloc[0,:]) would transpose it
+    else:
+        if basic_preprocess:
+            patch_info = patch_info.iloc[:, :4]
+        valid_patches = []
+        # a patch is kept when enough of it exceeds the background-intensity cutoff
+        for xs, ys in patch_info[['x', 'y']].values.tolist():
+            valid_patches.append(((purple_mask[xs:xs + patch_size, ys:ys + patch_size] >= intensity_threshold).mean() > threshold) if intensity_threshold > 0 else True)
+        valid_patches = np.array(da.compute(*valid_patches))
+        print('Valid Patches Complete')
+        patch_info = patch_info.loc[valid_patches]
+        if not basic_preprocess:
+            area_info = []
+            if segmentation:
+                patch_info.loc[:, 'annotation'] = 'segment'
+                for xs, ys in patch_info[['x', 'y']].values.tolist():
+                    xf = xs + patch_size
+                    yf = ys + patch_size
+                    # per-class pixel counts within the patch
+                    area_info.append(da.histogram(segmentation_mask[xs:xf, ys:yf], range=[0, target_class - 1], bins=target_class)[0])
+            else:
+                for xs, ys in patch_info[['x', 'y']].values.tolist():
+                    area_info.append([dask.delayed(is_coords_in_box)([xs, ys], patch_size, masks[annotation]) for annotation in annotations])
+            area_info = np.array(dask.compute(*area_info)).astype(float)
+            print('Area Info Complete')
+            area_info = area_info / (patch_size**2)  # convert absolute areas/counts to fraction of the patch
+            patch_info.iloc[:, 5:] = area_info
+            annot = list(patch_info.iloc[:, 5:])
+            patch_info.loc[:, 'annotation'] = np.vectorize(lambda i: annot[patch_info.iloc[i, 5:].values.argmax()])(np.arange(patch_info.shape[0]))
+    return patch_info
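+
+# Example (illustrative sketch): patch extraction for a slide whose zarr/pickle
+# pair was produced by run_preprocessing_pipeline into "./preprocessed":
+#
+#   patch_info = extract_patch_information('slide', input_dir='./preprocessed',
+#                                          annotations=['tumor', 'benign'],
+#                                          patch_size=224, threshold=0.5)
+#   print(patch_info.head())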
+
+def generate_patch_pipeline(basename,
+                            input_dir='./',
+                            annotations=[],
+                            threshold=0.5,
+                            patch_size=224,
+                            out_db='patch_info.db',
+                            generate_finetune_segmentation=False,
+                            target_class=0,
+                            intensity_threshold=100.,
+                            target_threshold=0.,
+                            adj_mask='',
+                            basic_preprocess=False,
+                            entire_image=False,
+                            svs_file='',
+                            transpose_annotations=False,
+                            get_tissue_mask=False,
+                            otsu=False,
+                            compression=8.,
+                            return_convex_hull=False,
+                            keep_holes=False,
+                            max_hole_size=0,
+                            gray_before_close=False,
+                            kernel=61,
+                            min_object_size=100000,
+                            blur_size=0):
+    """Find area coverage of each annotation in each patch and store patch information into SQL db.
+
+    Parameters
+    ----------
+    basename:str
+        Patient ID.
+    input_dir:str
+        Input directory.
+    annotations:list
+        List of annotations to record, these can be different tissue types, must correspond with XML labels.
+    threshold:float
+        Value between 0 and 1 that indicates the minimum fraction of a patch that mustn't be background for inclusion.
+    patch_size:int
+        Patch size of patches; this will become one of the tables.
+    out_db:str
+        Output SQL database.
+    generate_finetune_segmentation:bool
+        Deprecated.
+    target_class:int
+        Number of segmentation classes desired, from 0th class to target_class-1 will be annotated in SQL.
+    intensity_threshold:float
+        Value between 0 and 255 that represents minimum intensity to not include as background. Will be modified with new transforms.
+    target_threshold:float
+        Deprecated.
+    adj_mask:str
+        Adjusted mask if performed binary opening operations in previous preprocessing step.
+    basic_preprocess:bool
+        Do not store patch level information.
+    """
+    patch_info = extract_patch_information(basename,
+                                           input_dir,
+                                           annotations,
+                                           threshold,
+                                           patch_size,
+                                           generate_finetune_segmentation=generate_finetune_segmentation,
+                                           target_class=target_class,
+                                           intensity_threshold=intensity_threshold,
+                                           target_threshold=target_threshold,
+                                           adj_mask=adj_mask,
+                                           basic_preprocess=basic_preprocess,
+                                           entire_image=entire_image,
+                                           svs_file=svs_file,
+                                           transpose_annotations=transpose_annotations,
+                                           get_tissue_mask=get_tissue_mask,
+                                           otsu=otsu,
+                                           compression=compression,
+                                           return_convex_hull=return_convex_hull,
+                                           keep_holes=keep_holes,
+                                           max_hole_size=max_hole_size,
+                                           gray_before_close=gray_before_close,
+                                           kernel=kernel,
+                                           min_object_size=min_object_size,
+                                           blur_size=blur_size)
+    conn = sqlite3.connect(out_db)
+    patch_info.to_sql(str(patch_size), con=conn, if_exists='append')
+    conn.close()
+
+
+# now output csv
+def save_all_patch_info(basenames, input_dir='./', annotations=[], threshold=0.5, patch_size=224, output_pkl='patch_info.pkl'):
+    """Deprecated."""
+    df = pd.concat([extract_patch_information(basename, input_dir, annotations, threshold, patch_size) for basename in basenames]).reset_index(drop=True)
+    df.to_pickle(output_pkl)
+
+#########
+
+def create_zero_mask(npy_mask, in_zarr, in_pkl):
+    """Create an all-zero sparse segmentation mask for a slide and register it in the annotations pickle."""
+    from scipy.sparse import csr_matrix, save_npz
+    arr, annotations_dict = load_dataset(in_zarr, in_pkl)
+    annotations_dict.update({'annotations': npy_mask})
+    save_npz(file=npy_mask, matrix=csr_matrix(arr.shape[:-1]))
+    pickle.dump(annotations_dict, open(in_pkl, 'wb'))
+
+#########
+
+
+def create_train_val_test(train_val_test_pkl, input_info_db, patch_size):
+    """Create dataframe that splits slides into training, validation, and test sets.
+
+    Parameters
+    ----------
+    train_val_test_pkl:str
+        Pickle for training, validation, and test slides.
+    input_info_db:str
+        Patch information SQL database.
+    patch_size:int
+        Patch size looking to access.
+
+    Returns
+    -------
+    dataframe
+        Train test validation splits.
+
+    """
+    if os.path.exists(train_val_test_pkl):
+        IDs = pd.read_pickle(train_val_test_pkl)
+    else:
+        conn = sqlite3.connect(input_info_db)
+        df = pd.read_sql('select * from "{}";'.format(patch_size), con=conn)
+        conn.close()
+        IDs = df['ID'].unique()
+        IDs = pd.DataFrame(IDs, columns=['ID'])
+        IDs_train, IDs_test = train_test_split(IDs)
+        IDs_train, IDs_val = train_test_split(IDs_train)
+        IDs_train['set'] = 'train'
+        IDs_val['set'] = 'val'
+        IDs_test['set'] = 'test'
+        IDs = pd.concat([IDs_train, IDs_val, IDs_test])
+        IDs.to_pickle(train_val_test_pkl)
+    return IDs
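+
+# Example (illustrative sketch): the split is cached, so repeated calls reuse
+# the same pickle rather than reshuffling slides:
+#
+#   ids = create_train_val_test('splits.pkl', 'patch_info.db', 224)
+#   train_ids = ids.loc[ids['set'] == 'train', 'ID'].tolist()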
+
+def modify_patch_info(input_info_db='patch_info.db', slide_labels=pd.DataFrame(), pos_annotation_class='', patch_size=224, segmentation=False, other_annotations=[], target_segmentation_class=-1, target_threshold=0., classify_annotations=False, modify_patches=False):
+    """Modify the patch information to get ready for deep learning, incorporate whole slide labels if needed.
+
+    Parameters
+    ----------
+    input_info_db:str
+        SQL DB file.
+    slide_labels:dataframe
+        Dataframe with whole slide labels.
+    pos_annotation_class:str
+        Tissue/annotation label to label with whole slide image label, if not supplied, any slide's patches receive the whole slide label.
+    patch_size:int
+        Patch size.
+    segmentation:bool
+        Segmentation?
+    other_annotations:list
+        Other annotations to access from patch information.
+    target_segmentation_class:int
+        Segmentation class to threshold.
+    target_threshold:float
+        Include patch if patch has target area greater than this.
+    classify_annotations:bool
+        Classifying annotations for pretraining, or final model?
+    modify_patches:bool
+        Subset patches to the included annotations and compute their annotation area.
+
+    Returns
+    -------
+    dataframe
+        Modified patch information.
+
+    """
+    conn = sqlite3.connect(input_info_db)
+    df = pd.read_sql('select * from "{}";'.format(patch_size), con=conn)
+    conn.close()
+    df = df.drop_duplicates()
+    df = df.loc[np.isin(df['ID'], slide_labels.index)]
+    if not segmentation:
+        if classify_annotations:
+            targets = df['annotation'].unique().tolist()
+            if len(targets) == 1:
+                targets = list(df.iloc[:, 5:])
+        else:
+            targets = list(slide_labels)
+        if type(pos_annotation_class) == type(''):
+            included_annotations = [pos_annotation_class]
+        else:
+            included_annotations = copy.deepcopy(pos_annotation_class)
+        included_annotations.extend(other_annotations)
+        print(df.shape, included_annotations)
+        if modify_patches:
+            df = df[np.isin(df['annotation'], included_annotations)]
+        for target in targets:
+            df[target] = 0.
+        for slide in slide_labels.index:
+            # parenthesized comparison: & binds tighter than >, which the original missed
+            slide_bool = ((df['ID'] == slide) & (df[pos_annotation_class] > 0.)) if pos_annotation_class else (df['ID'] == slide)
+            if slide_bool.sum():
+                for target in targets:
+                    df.loc[slide_bool, target] = slide_labels.loc[slide, target]
+        df['area'] = np.vectorize(lambda i: df.iloc[i][df.iloc[i]['annotation']])(np.arange(df.shape[0])) if modify_patches else 1.
+        if 'area' in list(df) and target_threshold > 0.:
+            df = df.loc[df['area'] >= target_threshold]
+    else:
+        df['target'] = 0.
+        if target_segmentation_class >= 0:
+            df = df.loc[df[str(target_segmentation_class)] >= target_threshold]
+    print(df.shape)
+    return df
+
+def npy2da(npy_file):
+    """Numpy to dask array.
+
+    Parameters
+    ----------
+    npy_file:str
+        Input npy file.
+
+    Returns
+    -------
+    dask.array
+        Converted numpy array to dask.
+
+    """
+    if npy_file.endswith('.npy') and not os.path.exists(npy_file):
+        npy_file = npy_file.replace('.npy', '.npz')  # fall back to the sparse .npz; the original left arr unbound here
+    if npy_file.endswith('.npy'):
+        arr = da.from_array(np.load(npy_file, mmap_mode='r+'))
+    elif npy_file.endswith('.npz'):
+        from scipy.sparse import load_npz
+        arr = da.from_array(load_npz(npy_file).toarray())
+    elif npy_file.endswith('.h5'):
+        arr = da.from_array(h5py.File(npy_file, 'r')['dataset'])
+    return arr
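+
+# Example (illustrative sketch): dense .npy, sparse .npz, and .h5 masks all load
+# to the same lazy representation, so downstream slicing code is format-agnostic:
+#
+#   mask = npy2da('slide_mask.npy')   # falls back to slide_mask.npz if absent
+#   counts = da.histogram(mask[0:224, 0:224], bins=4, range=[0, 3])[0].compute()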
+
+def grab_interior_points(xml_file, img_size, annotations=[]):
+    """Deprecated."""
+    interior_point_dict = {}
+    for annotation in annotations:
+        try:
+            interior_point_dict[annotation] = parse_coord_return_boxes(xml_file, annotation, return_coords=False)
+        except:
+            interior_point_dict[annotation] = []
+    return interior_point_dict
+
+def boxes2interior(img_size, polygons):
+    """Deprecated."""
+    img = Image.new('L', img_size, 0)
+    for polygon in polygons:
+        ImageDraw.Draw(img).polygon(polygon, outline=1, fill=1)
+    mask = np.array(img).nonzero()
+    return mask
+
+def parse_coord_return_boxes(xml_file, annotation_name='', return_coords=False, transpose_annotations=False):
+    """Get list of shapely objects for each annotation in the XML object.
+
+    Parameters
+    ----------
+    xml_file:str
+        Annotation file.
+    annotation_name:str
+        Name of xml annotation.
+    return_coords:bool
+        Just return list of coords over shapes.
+
+    Returns
+    -------
+    list
+        List of shapely objects.
+
+    """
+    boxes = []
+    if xml_file.endswith(".xml"):
+        xml_data = BeautifulSoup(open(xml_file), 'html')
+        for annotation in xml_data.findAll('annotation'):
+            if annotation['partofgroup'] == annotation_name:
+                for coordinates in annotation.findAll('coordinates'):
+                    # FIXME may need to change x and y coordinates
+                    coords = np.array([(coordinate['x'], coordinate['y']) for coordinate in coordinates.findAll('coordinate')])
+                    if transpose_annotations:
+                        coords = coords[:, ::-1]
+                    coords = coords.tolist()
+                    if return_coords:
+                        boxes.append(coords)
+                    else:
+                        boxes.append(Polygon(np.array(coords).astype(float)))
+    else:
+        annotations = pickle.load(open(xml_file, 'rb')).get(annotation_name, [])
+        for annotation in annotations:
+            if transpose_annotations:
+                annotation = annotation[:, ::-1]
+            boxes.append(annotation.tolist() if return_coords else Polygon(annotation))
+    return boxes
+
+def is_coords_in_box(coords, patch_size, boxes):
+    """Get area of annotation in patch.
+
+    Parameters
+    ----------
+    coords:array
+        X,Y coordinates of patch.
+    patch_size:int
+        Patch size.
+    boxes:list
+        Shapely objects for annotations.
+
+    Returns
+    -------
+    float
+        Area of annotation type.
+
+    """
+    if len(boxes):
+        # boxes holds a single unary_union/MultiPolygon per annotation class, so
+        # intersecting with boxes[0] yields the overlap area for the whole class
+        points = Polygon(np.array([[0, 0], [1, 0], [1, 1], [0, 1]]) * patch_size + coords)
+        area = points.intersection(boxes[0]).area
+    else:
+        area = 0.
+    return area
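+
+# Example (illustrative sketch): overlap of a 224-pixel patch at (0, 0) with a
+# toy annotation rectangle spanning half the patch:
+#
+#   from shapely.geometry import Polygon
+#   rect = Polygon([(0, 0), (112, 0), (112, 224), (0, 224)])
+#   frac = is_coords_in_box(np.array([0, 0]), 224, [rect]) / 224**2  # -> 0.5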
+
+def is_image_in_boxes(image_coord_dict, boxes):
+    """Find if image intersects with annotations.
+
+    Parameters
+    ----------
+    image_coord_dict:dict
+        Dictionary of patches.
+    boxes:list
+        Shapely annotation shapes.
+
+    Returns
+    -------
+    dict
+        Dictionary of whether image intersects with any of the annotations.
+
+    """
+    return {image: any(list(map(lambda x: x.intersects(image_coord_dict[image]), boxes))) for image in image_coord_dict}
+
+def images2coord_dict(images, output_point=False):
+    """Deprecated."""
+    return {image: image2coords(image, output_point) for image in images}
+
+def dir2images(image_dir):
+    """Deprecated."""
+    return glob.glob(join(image_dir, '*.jpg'))
+
+def return_image_in_boxes_dict(image_dir, xml_file, annotation=''):
+    """Deprecated."""
+    boxes = parse_coord_return_boxes(xml_file, annotation)
+    images = dir2images(image_dir)
+    coord_dict = images2coord_dict(images)
+    return is_image_in_boxes(image_coord_dict=coord_dict, boxes=boxes)
+
+def image2coords(image_file, output_point=False):
+    """Deprecated."""
+    nx, ny, yi, xi = np.array(image_file.split('/')[-1].split('.')[0].split('_')[1:]).astype(int).tolist()
+    return return_image_coord(nx=nx, ny=ny, xi=xi, yi=yi, output_point=output_point)
+
+def retain_images(image_dir, xml_file, annotation=''):
+    """Deprecated."""
+    image_in_boxes_dict = return_image_in_boxes_dict(image_dir, xml_file, annotation)
+    return [img for img in image_in_boxes_dict if image_in_boxes_dict[img]]
+
+def return_image_coord(nx=0, ny=0, xl=3333, yl=3333, xi=0, yi=0, xc=3, yc=3, dimx=224, dimy=224, output_point=False):
+    """Deprecated."""
+    if output_point:
+        return np.array([xc, yc]) * np.array([nx * xl + xi + dimx / 2, ny * yl + yi + dimy / 2])
+    else:
+        static_point = np.array([nx * xl + xi, ny * yl + yi])
+        points = np.array([(np.array([xc, yc]) * (static_point + np.array(new_point))).tolist() for new_point in [[0, 0], [dimx, 0], [dimx, dimy], [0, dimy]]])
+        return Polygon(points)
+
+def fix_name(basename):
+    """Fixes illegitimate basename, deprecated."""
+    if len(basename) < 3:
+        return '{}0{}'.format(*basename)
+    return basename
+
+def fix_names(file_dir):
+    """Fixes basenames, deprecated."""
+    for filename in glob.glob(join(file_dir, '*')):
+        basename = filename.split('/')[-1]
+        basename, suffix = basename[:basename.rfind('.')], basename[basename.rfind('.'):]
+        if len(basename) < 3:
+            new_filename = join(file_dir, '{}0{}{}'.format(*basename, suffix))
+            print(filename, new_filename)
+            subprocess.call('mv {} {}'.format(filename, new_filename), shell=True)
+
+#######
+
+#@pysnooper.snoop('seg2npy.log')
+def segmentation_predictions2npy(y_pred, patch_info, segmentation_map, npy_output, original_patch_size=500, resized_patch_size=256, output_probs=False):
+    """Convert segmentation predictions from model to numpy masks.
+
+    Parameters
+    ----------
+    y_pred:list
+        List of patch segmentation masks.
+    patch_info:dataframe
+        Patch information from DB.
+    segmentation_map:array
+        Existing segmentation mask.
+    npy_output:str
+        Output npy file.
+    """
+    import cv2
+    import copy
+    print(output_probs)
+    seg_map_shape = segmentation_map.shape[-2:]
+    original_seg_shape = copy.deepcopy(seg_map_shape)
+    if resized_patch_size != original_patch_size:
+        seg_map_shape = [int(dim * resized_patch_size / original_patch_size) for dim in seg_map_shape]
+    segmentation_map = np.zeros(tuple(seg_map_shape)).astype(float)
+    for i in range(patch_info.shape[0]):
+        patch_info_i = patch_info.iloc[i]
+        ID = patch_info_i['ID']
+        xs = patch_info_i['x']
+        ys = patch_info_i['y']
+        patch_size = patch_info_i['patch_size']
+        if resized_patch_size != original_patch_size:
+            # rescale patch origin and size into the resized map's coordinates
+            xs = int(xs * resized_patch_size / original_patch_size)
+            ys = int(ys * resized_patch_size / original_patch_size)
+            patch_size = resized_patch_size
+        prediction = y_pred[i, ...]
+        segmentation_map[xs:xs + patch_size, ys:ys + patch_size] = prediction
+    if resized_patch_size != original_patch_size:
+        segmentation_map = cv2.resize(segmentation_map.astype(float), dsize=original_seg_shape, interpolation=cv2.INTER_NEAREST)
+    os.makedirs(npy_output[:npy_output.rfind('/')], exist_ok=True)
+    if not output_probs:
+        segmentation_map = segmentation_map.astype(np.uint8)
+    np.save(npy_output, segmentation_map)
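+
+# Example (illustrative sketch): stitch per-patch argmax predictions (y_pred, a
+# hypothetical model output array) back into a whole-slide mask:
+#
+#   patch_info = load_sql_df('patch_info.db', 256)
+#   seg = np.zeros((10000, 10000))   # template giving the output shape
+#   segmentation_predictions2npy(y_pred, patch_info, seg, 'out/slide_seg.npy',
+#                                original_patch_size=256, resized_patch_size=256)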