import Augmentor
import random
import numpy as np
from Augmentor.Operations import Operation
import os
from itertools import cycle
from os.path import join
from random import randint
import random
import cv2
from skimage import exposure
from skimage.exposure import match_histograms
from mediaug.image_utils import is_greyscale, rotate, soften_mask, image_on_image_alpha, get_blank_mask
from mediaug.dataset import Dataset
from mediaug.image_utils import pil_to_np
# import some operations
from Augmentor.Operations import Crop, Rotate, Flip, Distort, Zoom, RandomBrightness, HistogramEqualisation, Scale
def perform_operation(dp, op):
    """ Apply one augmentation operation to a DataPoint
    Args:
        dp (DataPoint): data point exposing ``pil_img`` and ``pil_mask``
        op (Operation): augmentation operation to apply
    Returns:
        generator yielding ``pil_to_np`` of every item the operation returns
    """
    # Image and mask are handed to the operation together, in a single call.
    augmented = op.perform_operation([dp.pil_img, dp.pil_mask])
    # Conversion stays lazy: nothing is converted until the caller iterates.
    return (pil_to_np(item) for item in augmented)
class Pipeline(Augmentor.DataPipeline):
    """ Augmentor data pipeline fed with the image/mask pairs of a dataset """

    def __init__(self, ds):
        """ Collect every data point of ``ds`` as an ``[img, mask]`` pair
        Args:
            ds: dataset exposing ``classes`` and indexable by class name;
                every item must provide ``img`` and ``mask`` attributes
        """
        pairs, class_ids = [], []
        # A data point is labelled with the position of its class in ds.classes.
        for class_id, class_name in enumerate(ds.classes):
            for point in ds[class_name]:
                pairs.append([point.img, point.mask])
                class_ids.append(class_id)
        Augmentor.DataPipeline.__init__(self, pairs, class_ids)

    def generator(self, batch_size=1):
        """ Endlessly yield augmented images, masks and class labels
        Args:
            batch_size (int): number of samples requested per batch
        Yields:
            (image, mask, label) when ``batch_size`` is 1, otherwise
            (image_list, mask_list, labels) for the whole batch
        """
        batches = Augmentor.DataPipeline.generator(self, batch_size=batch_size)
        unwrap_single = batch_size == 1
        while True:
            samples, labels = next(batches)
            # Each sample is indexed as [image, mask]; split them into two lists.
            image_list = [sample[0] for sample in samples]
            mask_list = [sample[1] for sample in samples]
            if unwrap_single:
                yield image_list[0], mask_list[0], labels[0]
            else:
                yield image_list, mask_list, labels
def get_data_generator(image_path, mask_path, batch_size=1):
    """ Build an Augmentor pipeline over a directory and return its keras generator
    Args:
        image_path: location handed to ``Augmentor.Pipeline``
        mask_path: location registered via ``ground_truth``; skipped when None
        batch_size (int): batch size forwarded to ``keras_generator``
    Returns:
        the object returned by ``Pipeline.keras_generator``
    """
    augmentor = Augmentor.Pipeline(image_path)
    if mask_path is not None:
        augmentor.ground_truth(mask_path)

    # Operations are registered one after another in this fixed order.
    augmentor.rotate(probability=0.5, max_left_rotation=25, max_right_rotation=25)
    augmentor.flip_left_right(probability=0.5)
    augmentor.zoom_random(probability=0.5, percentage_area=0.6)
    augmentor.flip_top_bottom(probability=0.5)
    augmentor.random_distortion(probability=.3, grid_width=8, grid_height=8, magnitude=5)
    augmentor.crop_random(.05, .85)

    return augmentor.keras_generator(batch_size=batch_size)
def randomly_insert_cells(img: np.ndarray, mask: np.ndarray,
                          ds: Dataset, cell_names_to_add: list,
                          num_cell_range: tuple) -> tuple:
    """ Randomly inserts cells into image
    Args:
        img (np.array): image to insert the cells into
        mask (np.array): mask belonging to ``img``
        ds (Dataset): dataset indexable by cell name; every entry must
            provide ``img`` and ``mask`` attributes
        cell_names_to_add (list): names of the cell classes to draw from
        num_cell_range (tuple): inclusive (min, max) number of cells to
            insert, ex (1, 5)
    Returns:
        new_img (np.array)
        new_mask (np.array)
    Raises:
        IndexError: if at least one cell must be inserted but none of the
            requested classes contains a cell
    """
    border = 5  # insert positions stay at least this far from the image edge
    min_scale = 0.1  # lower bound for the random scale factor
    h, w = img.shape[:2]

    # Pool the candidate cells of every requested class.
    candidates = []
    for cell_name in cell_names_to_add:
        candidates.extend(ds[cell_name])

    num_cells_to_insert = randint(*num_cell_range)
    for _ in range(num_cells_to_insert):
        cell = random.choice(candidates)
        # NOTE(review): the first coordinate is bounded by the height and the
        # second by the width, while add_cell documents pos as (x, y) - confirm
        # which order image_on_image_alpha expects.
        pos = (randint(border, h - border), randint(border, w - border))
        angle = randint(0, 360)
        # The Gaussian draw is unbounded, so it can (rarely) come out zero or
        # negative, which cv2.resize cannot handle; clamp it to stay positive.
        scale = max(random.normalvariate(1, .2), min_scale)
        img, mask = add_cell(img, mask, cell.img, cell.mask, pos, angle, scale)
    print(f'Adding slide with {num_cells_to_insert}')
    return img, mask
def _match_histograms_color(src, reference):
    """ Match the histogram of ``src`` to ``reference``, channel by channel """
    try:
        return match_histograms(src, reference, channel_axis=-1)
    except TypeError:
        # Older scikit-image releases only accept the ``multichannel`` keyword.
        return match_histograms(src, reference, multichannel=True)


def add_cell(bg, bg_mask, fg, orig_fg_mask, pos, angle=0, scale=1,
             blend_method=None, blend_edge_amount=1):
    """ adds a cell to base image
    Args:
        bg (np.array): background img
        bg_mask (np.array): background img mask
        fg (np.array): foreground img to add
        orig_fg_mask (np.array): foreground img mask; it is converted with
            cv2.COLOR_BGR2GRAY, so a 3-channel image is expected
        pos (x,y): where to put the fg
        angle (float): angle of fg in degrees
        scale (float): resize factor applied to fg and its mask
        blend_method (str): 'hist' matches the fg histogram to bg first;
            any other value leaves fg untouched
        blend_edge_amount (int): the amount to blend from the mask
    Returns:
        img (np.array)
        mask (np.array)
    """
    fg = cv2.resize(fg, (0, 0), fx=scale, fy=scale)
    orig_fg_mask = cv2.resize(orig_fg_mask, (0, 0), fx=scale, fy=scale)
    # Single-channel copy of the mask, used as the alpha channel below.
    fg_mask = cv2.cvtColor(orig_fg_mask, cv2.COLOR_BGR2GRAY)

    # blend mode
    if blend_method == 'hist':
        fg = _match_histograms_color(fg, bg)

    # prep the mask : TODO: improve this
    # Grey levels 51 and 170 are raised to 160 and 250 before the mask is used
    # as alpha (presumably the label values of the source masks - TODO confirm).
    fg_mask[fg_mask == 51] = 160
    fg_mask[fg_mask == 170] = 250

    fg = rotate(fg, angle)
    fg_mask = rotate(fg_mask, angle)
    # The label image must follow the same rotation as the cell and its alpha
    # mask, otherwise the composited mask does not line up with the image.
    fg_label = rotate(orig_fg_mask, angle)

    fg_mask = soften_mask(fg_mask, amount=blend_edge_amount)  # FIXME: this also blurs nucleus
    new_img = image_on_image_alpha(bg, fg, fg_mask, pos)

    # For the mask composite every non-zero alpha value becomes fully opaque.
    fg_mask[fg_mask != 0] = 255
    new_mask = image_on_image_alpha(bg_mask, fg_label, fg_mask, pos)
    return new_img, new_mask