a b/pathflowai/stain_norm.py
1
import cv2
2
import sys
3
import fire
4
import histomicstk
5
import histomicstk as htk
6
import openslide
7
import dask
8
import tqdm
9
import numpy as np
10
from dask.diagnostics import ProgressBar
11
from pathflowai.utils import generate_tissue_mask
12
from histomicstk.preprocessing.color_normalization.\
13
    deconvolution_based_normalization import deconvolution_based_normalization
14
15
# Default target stain matrix handed to the normalization routine as
# ``W_target``.
# NOTE(review): presumably each column is one stain vector in the
# HistomicsTK convention (hematoxylin, eosin, residual) -- confirm.
W_target = np.array([
    [0.6185391, 0.1576997, -0.01119131],
    [0.7012888, 0.8638838, 0.45586256],
    [0.3493163, 0.4657428, -0.85597752],
])
20
21
def return_norm_image(img, mask, W_source=None, W_target=None):
    """Stain-normalize one image tile with HistomicsTK.

    Thin wrapper around ``deconvolution_based_normalization`` that fixes the
    stain list to hematoxylin/eosin and the background intensity ``I_0`` to
    215.

    Args:
        img: Image tile to normalize.
        mask: Tissue mask for the tile; truthy entries mark tissue. Its
            logical inverse is passed as ``mask_out``.
        W_source: Source stain matrix, forwarded unchanged.
        W_target: Target stain matrix, forwarded unchanged.

    Returns:
        Whatever ``deconvolution_based_normalization`` returns for the tile.
    """
    # np.logical_not matches ``~mask`` for boolean masks, but also inverts
    # 0/1 integer masks correctly (bitwise ``~`` would yield -1/-2, i.e. all
    # truthy, and mask out the whole tile).
    background = np.logical_not(mask)
    return deconvolution_based_normalization(
        img,
        W_source=W_source,
        W_target=W_target,
        im_target=None,
        stains=['hematoxylin', 'eosin'],
        mask_out=background,
        stain_unmixing_routine_params={"I_0": 215},
    )
27
28
def check_ext(image_file):
    """Return True if ``image_file`` has a supported slide/image extension.

    The comparison is case-insensitive, so ``"scan.SVS"`` is accepted just
    like ``"scan.svs"``.

    Args:
        image_file: File name or path as a string.

    Returns:
        bool: True when the name ends with one of ``.svs``, ``.png``,
        ``.jpg``, ``.jpeg``, ``.tiff`` or ``.tif``.
    """
    supported = ('.svs', '.png', '.jpg', '.jpeg', '.tiff', '.tif')
    # str.endswith accepts a tuple of suffixes, so no explicit loop is needed.
    return image_file.lower().endswith(supported)
30
31
def stain_norm(image_file, compression=10, patch_size=1024):
    """Stain-normalize an image tile by tile and return the result.

    The image is loaded in full (level 0 for files accepted by ``check_ext``,
    ``np.load`` for ``.npy`` files). A tissue mask is computed, a single
    source stain matrix is estimated from a downsampled copy, and every
    ``patch_size`` x ``patch_size`` tile that contains tissue is normalized
    towards the module-level ``W_target`` in parallel dask processes.

    The tile grid covers the entire image: tiles on the bottom/right border
    may be smaller than ``patch_size``.

    Args:
        image_file: Path to the input image or ``.npy`` array.
        compression: Downsampling factor used for the tissue mask and for
            the thumbnail on which the source stain matrix is estimated.
        patch_size: Tile edge length in pixels; must be positive.

    Returns:
        numpy.ndarray: uint8 array with the same shape as the loaded image.
        Tiles without any tissue are left at 255.

    Raises:
        ValueError: If ``patch_size`` is not positive.
        NotImplementedError: If the file extension is not supported.
    """
    if patch_size <= 0:
        raise ValueError("patch_size must be a positive integer")

    if check_ext(image_file):
        # Release the slide handle as soon as level 0 has been copied.
        with openslide.open_slide(image_file) as slide:
            region = slide.read_region((0, 0), 0, slide.level_dimensions[0])
            image = np.array(region)[..., :3]  # drop the alpha channel
    elif image_file.endswith(".npy"):
        image = np.load(image_file)
    else:
        raise NotImplementedError

    mask = generate_tissue_mask(image, compression=compression, keep_holes=False)

    # Estimate one source stain matrix for the whole image from a thumbnail.
    # NOTE(review): the previous code also built a downsampled copy of `mask`
    # here but never used it; presumably it was meant to restrict the
    # estimation to tissue pixels -- confirm before wiring it in, since that
    # would change the estimated matrix.
    img_small = cv2.resize(image, None, fx=1 / compression, fy=1 / compression)
    color_deconv = htk.preprocessing.color_deconvolution
    W_source = color_deconv.rgb_separate_stains_macenko_pca(img_small, 215)
    W_source = color_deconv._reorder_stains(W_source)

    # Step over the full extent of both axes. Slicing clips at the border,
    # so edge tiles are simply smaller instead of being skipped.
    height, width = image.shape[:2]
    coords = []
    tasks = []
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            tile_mask = mask[i:i + patch_size, j:j + patch_size]
            if tile_mask.any():
                coords.append((i, j))
                tasks.append(dask.delayed(return_norm_image)(
                    image[i:i + patch_size, j:j + patch_size],
                    tile_mask,
                    W_source,
                    W_target,
                ))

    with ProgressBar():
        normalized = dask.compute(*tasks, scheduler="processes")

    # Allocate the white canvas directly as uint8; building it from float64
    # ones would need eight times the memory of the image itself.
    img_new = np.full(image.shape, 255, dtype=np.uint8)
    for (i, j), tile in tqdm.tqdm(zip(coords, normalized), total=len(coords)):
        img_new[i:i + patch_size, j:j + patch_size] = tile
    return img_new
57
58
def stain_norm_pipeline(image_file="stain_in.svs",
                        npy_out='stain_out.npy',
                        compression=10,
                        patch_size=1024):
    """Stain-normalize ``image_file`` and save the result with ``np.save``.

    Args:
        image_file: Path of the input image, forwarded to ``stain_norm``.
        npy_out: Destination path for the normalized array.
        compression: Forwarded to ``stain_norm``.
        patch_size: Forwarded to ``stain_norm``.
    """
    normalized = stain_norm(image_file, compression, patch_size)
    np.save(npy_out, normalized)
63
64
if __name__ == "__main__":
    # Expose stain_norm_pipeline's parameters as command-line arguments.
    fire.Fire(stain_norm_pipeline)