--- a +++ b/pathaia/semantic/functional_api.py @@ -0,0 +1,156 @@ +# coding: utf8 +""" +A module to produce semantic segmentations. + +Can be used visualize model perception of an image. +""" +import numpy +from ..util.images import unlabeled_regular_grid_list +from ..util.paths import imfiles_in_folder +from ..util import NDImage, NDIntMask2d, NDIntMask3d, NDIntMask4d +from ..deep.dense import Vocabulary +import os +import tifffile +from skimage.io import imread +from typing import Union, Tuple + + +def coarse(image: NDImage, model: Vocabulary) -> NDIntMask2d: + """Compute a coarse semantic segmentation of an image. + + Place a window on every pixel, pixel value in segmentation + is its prediction (by a model) made on the window. + + Args: + image: an image. + model: a ml model. + + Returns: + Segmentation mask of the given image. + + """ + psize = model.context + img = image.astype(float) + spaceshape = (image.shape[0], image.shape[1]) + segmentation = numpy.zeros(spaceshape, int) + positions = unlabeled_regular_grid_list(spaceshape, psize) + patches = [img[i : i + psize, j : j + psize].reshape(-1) for i, j in positions] + preds = model.predict(numpy.array(patches)) + for idx, pos in enumerate(positions): + pred = preds[idx] + i, j = pos + segmentation[i : i + psize, j : j + psize] = pred + return segmentation + + +def coarse_sep_channels_classif(image: NDImage, model: Vocabulary) -> NDIntMask3d: + """Compute a coarse semantic segmentation of an image. + + Place a window on every pixel, pixel value in segmentation + is its prediction (by a model) made on the window. + + Args: + image: an image. + model: a ml model. + + Returns: + Segmentation mask of the given image. + + """ + psize = model.context + img = image.astype(float) + spaceshape = (image.shape[0], image.shape[1]) + channels = image.shape[-1] + segmentations = numpy.zeros(img.shape, int) + positions = unlabeled_regular_grid_list(spaceshape, psize) + patchlists = [] + for c in range(channels): + patchlists.append( + numpy.array( + [ + img[:, :, c][i : i + psize, j : j + psize].reshape(-1) + for i, j in positions + ] + ) + ) + preds = model.predict(patchlists) + for idx, pos in enumerate(positions): + for c in range(channels): + i, j = pos + segmentations[i : i + psize, j : j + psize, c] = preds[c][idx] + return segmentations + + +def coarse_sep_channels_desc( + image: NDImage, model: Vocabulary, outrag: bool = False +) -> Union[NDIntMask4d, Tuple[NDIntMask4d, NDIntMask2d]]: + """Compute a coarse semantic segmentation of an image. + + Place a window on every pixel, pixel value in segmentation + is its prediction (by a model) made on the window. + + Args: + image: an image. + model: a ml model. + rag: whether to return rag + + Returns: + Segmentation mask of the given image, optionally with rag + + """ + psize = model.context + img = image.astype(float) + spaceshape = (image.shape[0], image.shape[1]) + channels = image.shape[-1] + segmentations = numpy.zeros(img.shape + (model.n_words,), int) + rag = numpy.zeros((img.shape[0], img.shape[1]), int) + positions = unlabeled_regular_grid_list(spaceshape, psize) + patchlists = [] + for c in range(channels): + patchlists.append( + numpy.array( + [ + img[:, :, c][i : i + psize, j : j + psize].reshape(-1) + for i, j in positions + ] + ) + ) + preds = model.predict(patchlists, fuzzy=True) + for idx, pos in enumerate(positions): + for c in range(channels): + i, j = pos + segmentations[i : i + psize, j : j + psize, c] = preds[c][idx] + rag[i : i + psize, j : j + psize] = idx + if outrag: + return segmentations, rag + return segmentations + + +def partition_slide_coarse( + slidefolder: str, level: int, outfolder: str, model: Vocabulary, sep_channels: bool +): + """Compute a coarse semantic segmentation of all patches of a slide. + + Do the coarse image segmentation on every patch extracted at a given level. + + Args: + slidefolder: absolute path to a pathaia slide folder. + level: level of patches in the pyramid. + outfolder: absolute path to an output folder for segmentations. + model: a ml predictor. + sep_channels: whether the model is multchannel. + + """ + imfolder = os.path.join(slidefolder, "level_{}".format(level)) + if sep_channels: + partition_ = coarse_sep_channels_classif + else: + partition_ = coarse + for imfile in imfiles_in_folder(imfolder, forbiden=["thumbnail", "semantic"]): + img = imread(imfile) + seg = partition_(img, model) + imgbase = os.path.basename(imfile) + imgbase, _ = os.path.splitext(imgbase) + outfile = "{}_semantic.tiff".format(imgbase) + outfile = os.path.join(outfolder, outfile) + tifffile.imsave(outfile, seg)