import h5py
import numpy as np
import os
import random
import matplotlib.pyplot as plt
import math
from functools import partial
import tensorflow as tf
from glob import glob
from Segmentation.utils.augmentation import crop_randomly_image_pair_2d, adjust_contrast_randomly_image_pair_2d
from Segmentation.utils.augmentation import adjust_brightness_randomly_image_pair_2d
from Segmentation.utils.augmentation import apply_centre_crop_3d, apply_valid_random_crop_3d
from Segmentation.utils.augmentation import apply_random_brightness_3d, apply_random_contrast_3d, apply_random_gamma_3d
from Segmentation.utils.augmentation import apply_flip_3d, apply_rotate_3d, normalise
def get_multiclass(label):
    """Append a background channel to a one-hot label batch.

    Args:
        label: np.ndarray of shape (batch_size, height, width, channels)
            holding per-class binary masks.

    Returns:
        np.ndarray of shape (batch_size, height, width, channels + 1); the
        appended last channel is 1 wherever no foreground class is set.
    """
    background = np.zeros((*label.shape[:3], 1))
    # a pixel belongs to the background exactly when all class channels are 0
    background[np.sum(label, axis=3) == 0] = 1
    return np.concatenate((label, background), axis=3)
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    # EagerTensors must be unwrapped first: BytesList won't unpack them.
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Returns a float_list from a float / double."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def create_OAI_dataset(data_folder, tfrecord_directory, get_train=True, use_2d=True, crop_size=None):
    """Convert paired OAI `<name>.im` / `<name>.seg` h5 files into TFRecord shards.

    Each `.im` file holds a (H, W, 160) float volume and each `.seg` file a
    (H, W, 160, 6) one-hot label volume under the h5 key 'data'
    (shapes asserted below for the cropped case).

    Args:
        data_folder: directory containing the `.im` / `.seg` pairs.
        tfrecord_directory: output directory; created if missing.
        get_train: kept for interface compatibility; currently unused
            (the original computed a 'train'/'valid' tag from it but never
            used the result).
        use_2d: if True, write one example per axial slice; otherwise one
            example per whole volume (with an extra 'depth' feature).
        crop_size: optional half-width of an in-plane centre crop applied
            before serialisation.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir
    os.makedirs(tfrecord_directory, exist_ok=True)
    files = glob(os.path.join(data_folder, '*.im'))
    for idx, f in enumerate(files):
        # os.path.basename is portable; the original '/'-split broke on Windows
        f_name = os.path.splitext(os.path.basename(f))[0]
        fname_img = f'{f_name}.im'
        fname_seg = f'{f_name}.seg'
        img_filepath = os.path.join(data_folder, fname_img)
        seg_filepath = os.path.join(data_folder, fname_seg)
        assert os.path.exists(seg_filepath), f"Seg file does not exist: {seg_filepath}"
        with h5py.File(img_filepath, 'r') as hf:
            img = np.array(hf['data'])
        with h5py.File(seg_filepath, 'r') as hf:
            seg = np.array(hf['data'])
        if crop_size is not None:
            # centre-crop both volumes in-plane to (2*crop_size)^2
            img_mid = (int(img.shape[0] / 2), int(img.shape[1] / 2))
            seg_mid = (int(seg.shape[0] / 2), int(seg.shape[1] / 2))
            assert img_mid == seg_mid, "We expect the mid shapes to be the same size"
            img = img[img_mid[0] - crop_size:img_mid[0] + crop_size,
                      img_mid[1] - crop_size:img_mid[1] + crop_size, :]
            seg = seg[seg_mid[0] - crop_size:seg_mid[0] + crop_size,
                      seg_mid[1] - crop_size:seg_mid[1] + crop_size, :, :]
            assert img.shape == (crop_size * 2, crop_size * 2, 160)
            assert seg.shape == (crop_size * 2, crop_size * 2, 160, 6)
        # move the depth axis to the front: (H, W, D[, C]) -> (D, H, W[, C])
        img = np.rollaxis(img, 2, 0)
        seg = np.rollaxis(seg, 2, 0)
        # prepend a background channel that is 1 where no class is present
        seg_temp = np.zeros((*seg.shape[0:3], 1), dtype=np.int8)
        seg_temp[np.sum(seg, axis=-1) == 0] = 1
        seg = np.concatenate([seg_temp, seg], axis=-1)
        img = np.expand_dims(img, axis=-1)
        assert img.shape[-1] == 1
        assert seg.shape[-1] == 7
        shard_dir = f'{idx:03d}-of-{len(files) - 1:03d}.tfrecords'
        tfrecord_filename = os.path.join(tfrecord_directory, shard_dir)
        target_shape, label_shape = None, None
        with tf.io.TFRecordWriter(tfrecord_filename) as writer:
            if use_2d:
                for k in range(len(img)):
                    img_slice = img[k, :, :, :]
                    seg_slice = seg[k, :, :, :]
                    target_shape = img_slice.shape
                    # was seg.shape (whole volume); log the per-slice shape to
                    # match target_shape
                    label_shape = seg_slice.shape
                    feature = {
                        'height': _int64_feature(img_slice.shape[0]),
                        'width': _int64_feature(img_slice.shape[1]),
                        'num_channels': _int64_feature(seg_slice.shape[-1]),
                        # tobytes() replaces the deprecated ndarray.tostring()
                        'image_raw': _bytes_feature(img_slice.tobytes()),
                        'label_raw': _bytes_feature(seg_slice.tobytes())
                    }
                    example = tf.train.Example(features=tf.train.Features(feature=feature))
                    writer.write(example.SerializeToString())
            else:
                target_shape = img.shape
                label_shape = seg.shape
                feature = {
                    'height': _int64_feature(img.shape[0]),
                    'width': _int64_feature(img.shape[1]),
                    'depth': _int64_feature(img.shape[2]),
                    'num_channels': _int64_feature(seg.shape[-1]),
                    'image_raw': _bytes_feature(img.tobytes()),
                    'label_raw': _bytes_feature(seg.tobytes())
                }
                example = tf.train.Example(features=tf.train.Features(feature=feature))
                writer.write(example.SerializeToString())
        print(f'{idx} out of {len(files) - 1} datasets have been processed. Target: {target_shape}, Label: {label_shape}')
def parse_fn_2d(example_proto, training, augmentation, multi_class=True, use_bfloat16=False, use_RGB=False):
    """Parse one serialized 2D example into an (image, seg) tensor pair.

    Args:
        example_proto: serialized tf.train.Example (one 384x384 slice).
        training: if True, apply the selected augmentation; otherwise a
            deterministic 288x288 centre crop.
        augmentation: one of 'random_crop', 'noise', 'crop_and_noise', or
            None for plain centre-cropping.
        multi_class: if False, collapse the 6 foreground channels into one
            binary mask (background channel dropped).
        use_bfloat16: cast outputs to bfloat16 instead of float32.
        use_RGB: replicate the grayscale image to 3 channels.

    Returns:
        Tuple (image, seg) tensors.

    Raises:
        ValueError: if `augmentation` is not a supported strategy.
    """
    dtype = tf.bfloat16 if use_bfloat16 else tf.float32
    features = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'num_channels': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'label_raw': tf.io.FixedLenFeature([], tf.string)
    }
    # Parse the input tf.Example proto using the dictionary above.
    image_features = tf.io.parse_single_example(example_proto, features)
    # slices were serialised as 384x384x1 float32 / 384x384x7 int16
    image_raw = tf.io.decode_raw(image_features['image_raw'], tf.float32)
    image = tf.cast(tf.reshape(image_raw, [384, 384, 1]), dtype)
    if use_RGB:
        image = tf.image.grayscale_to_rgb(image)
    seg_raw = tf.io.decode_raw(image_features['label_raw'], tf.int16)
    seg = tf.reshape(seg_raw, [384, 384, 7])
    seg = tf.cast(seg, dtype)
    if training:
        if augmentation == 'random_crop':
            image, seg = crop_randomly_image_pair_2d(image, seg)
        elif augmentation == 'noise':
            image, seg = adjust_brightness_randomly_image_pair_2d(image, seg)
            image, seg = adjust_contrast_randomly_image_pair_2d(image, seg)
        elif augmentation == 'crop_and_noise':
            image, seg = crop_randomly_image_pair_2d(image, seg)
            image, seg = adjust_brightness_randomly_image_pair_2d(image, seg)
            image, seg = adjust_contrast_randomly_image_pair_2d(image, seg)
        elif augmentation is None:
            image = tf.image.resize_with_crop_or_pad(image, 288, 288)
            seg = tf.image.resize_with_crop_or_pad(seg, 288, 288)
        else:
            # was a bare string expression (silent no-op); fail loudly instead
            raise ValueError("Augmentation strategy {} does not exist or is not supported!".format(augmentation))
    else:
        image = tf.image.resize_with_crop_or_pad(image, 288, 288)
        seg = tf.image.resize_with_crop_or_pad(seg, 288, 288)
    if not multi_class:
        # drop the background channel and merge foreground into a binary mask
        seg = tf.slice(seg, [0, 0, 1], [-1, -1, 6])
        seg = tf.math.reduce_sum(seg, axis=-1)
        seg = tf.expand_dims(seg, axis=-1)
        seg = tf.clip_by_value(seg, 0, 1)
    return (image, seg)
def parse_fn_3d(example_proto, training, multi_class=True, use_bfloat16=False, use_RGB=False):
    """Parse one serialized 3D example into a (volume, seg) tensor pair.

    Args:
        example_proto: serialized tf.train.Example (one 160x384x384 volume).
        training: if True, take a random 32x288x288 crop; otherwise a
            deterministic centre crop.
        multi_class: if False, collapse the 6 foreground channels into one
            binary mask (background channel dropped).
        use_bfloat16: cast outputs to bfloat16 instead of float32.
        use_RGB: accepted for signature parity with parse_fn_2d; not used
            for 3D volumes.

    Returns:
        Tuple (image, seg); image is [32, 288, 288, 1] and seg is
        [32, 288, 288, 7] (or [..., 1] when multi_class is False).
    """
    dtype = tf.bfloat16 if use_bfloat16 else tf.float32
    features = {
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
        'depth': tf.io.FixedLenFeature([], tf.int64),
        'num_channels': tf.io.FixedLenFeature([], tf.int64),
        'image_raw': tf.io.FixedLenFeature([], tf.string),
        'label_raw': tf.io.FixedLenFeature([], tf.string)
    }
    # Parse the input tf.Example proto using the dictionary above.
    image_features = tf.io.parse_single_example(example_proto, features)
    # volumes were serialised as 160x384x384, image float32 / label int16
    image_raw = tf.io.decode_raw(image_features['image_raw'], tf.float32)
    image = tf.reshape(image_raw, [160, 384, 384, 1])
    image = tf.cast(image, dtype)
    seg_raw = tf.io.decode_raw(image_features['label_raw'], tf.int16)
    seg = tf.reshape(seg_raw, [160, 384, 384, 7])
    seg = tf.cast(seg, dtype)
    if not multi_class:
        # drop the background channel and merge foreground into a binary mask
        seg = tf.slice(seg, [0, 0, 0, 1], [-1, -1, -1, 6])
        seg = tf.math.reduce_sum(seg, axis=-1)
        seg = tf.expand_dims(seg, axis=-1)
        seg = tf.clip_by_value(seg, 0, 1)
    if training:
        # random crop offsets; maxvals keep the 32/288/288 windows in bounds
        dx = tf.cast(tf.random.uniform(shape=[], minval=0, maxval=128), tf.int32)
        dy = tf.cast(tf.random.uniform(shape=[], minval=0, maxval=96), tf.int32)
        dz = tf.cast(tf.random.uniform(shape=[], minval=0, maxval=96), tf.int32)
        image = image[dx:dx + 32, dy:dy + 288, dz:dz + 288, :]
        seg = seg[dx:dx + 32, dy:dy + 288, dz:dz + 288, :]
    else:
        # deterministic centre crop for evaluation
        image = image[64:96, 48:336, 48:336, :]
        seg = seg[64:96, 48:336, 48:336, :]
    # the original always reshaped seg to 7 channels, which fails for
    # multi_class=False (seg then has a single channel)
    seg_channels = 7 if multi_class else 1
    image = tf.reshape(image, [32, 288, 288, 1])
    seg = tf.reshape(seg, [32, 288, 288, seg_channels])
    return (image, seg)
def read_tfrecord_2d(tfrecords_dir, batch_size, buffer_size, augmentation,
                     parse_fn=parse_fn_2d, multi_class=True,
                     is_training=False, use_bfloat16=False,
                     use_RGB=False):
    """Build a batched tf.data pipeline from TFRecord shards in a directory.

    Args:
        tfrecords_dir: directory holding `*-*` TFRecord shard files.
        batch_size: examples per batch (remainder batches dropped).
        buffer_size: shuffle buffer size used when `is_training` is True.
        augmentation: strategy forwarded to 2D parse functions only
            (`parse_fn_3d` does not accept it).
        parse_fn: per-example parser; `parse_fn_2d` or `parse_fn_3d`.
        multi_class / use_bfloat16 / use_RGB: forwarded to the parser.
        is_training: enables shard shuffling, interleaved reads and
            example-level shuffling.

    Returns:
        A batched, prefetched tf.data.Dataset.
    """
    file_list = tf.io.matching_files(os.path.join(tfrecords_dir, '*-*'))
    shards = tf.data.Dataset.from_tensor_slices(file_list)
    cycle_l = 1
    if is_training:
        shards = shards.shuffle(tf.cast(tf.shape(file_list)[0], tf.int64))
        cycle_l = 8
    # identity comparison: parse_fn is a function object, not a value
    if parse_fn is parse_fn_2d:
        shards = shards.repeat()
    dataset = shards.interleave(tf.data.TFRecordDataset,
                                cycle_length=cycle_l,
                                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if is_training:
        dataset = dataset.shuffle(buffer_size=buffer_size)
    parser_kwargs = dict(training=is_training,
                         multi_class=multi_class,
                         use_bfloat16=use_bfloat16,
                         use_RGB=use_RGB)
    # parse_fn_3d has no `augmentation` parameter; binding it unconditionally
    # (as the original did) raised TypeError for the 3D path
    if parse_fn is not parse_fn_3d:
        parser_kwargs['augmentation'] = augmentation
    parser = partial(parse_fn, **parser_kwargs)
    dataset = dataset.map(map_func=parser, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # single prefetch at the end of the pipeline (the original prefetched twice)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    if parse_fn is parse_fn_3d:
        dataset = dataset.repeat()
    # optimise dataset performance
    options = tf.data.Options()
    options.experimental_optimization.parallel_batch = True
    options.experimental_optimization.map_fusion = True
    options.experimental_optimization.map_vectorization.enabled = True
    options.experimental_optimization.map_parallelization = True
    dataset = dataset.with_options(options)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset
def read_tfrecord_3d(tfrecords_dir,
                     batch_size,
                     buffer_size,
                     is_training,
                     crop_size=None,
                     depth_crop_size=80,
                     aug=None,
                     predict_slice=False,
                     **kwargs):
    """Build a 3D volume pipeline with optional cropping and augmentation.

    Args:
        tfrecords_dir / batch_size / buffer_size / is_training: forwarded to
            the shard reader.
        crop_size: in-plane crop half-size; None disables all cropping and
            augmentation here.
        depth_crop_size: crop size along the depth axis.
        aug: iterable of augmentation names ("resize", "shift", "bright",
            "contrast", "gamma", "flip", "rotate"); defaults to none.
        predict_slice: forwarded as `output_slice` to the crop helpers.
        **kwargs: extra options for the shard reader.

    Returns:
        A normalised tf.data.Dataset of (volume, seg) batches.
    """
    # `aug=[]` as a literal default would be a shared mutable default argument
    if aug is None:
        aug = []
    # the original called the undefined name `read_tfrecord` (NameError)
    dataset = read_tfrecord_2d(tfrecords_dir=tfrecords_dir,
                               batch_size=batch_size,
                               buffer_size=buffer_size,
                               augmentation=None,
                               parse_fn=parse_fn_3d,
                               is_training=is_training,
                               **kwargs)
    if crop_size is not None:
        if is_training:
            parse_crop = partial(apply_valid_random_crop_3d,
                                 crop_size=crop_size,
                                 depth_crop_size=depth_crop_size,
                                 resize="resize" in aug,
                                 random_shift="shift" in aug,
                                 output_slice=predict_slice)
            dataset = dataset.map(map_func=parse_crop, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            if "bright" in aug:
                dataset = dataset.map(apply_random_brightness_3d)
            if "contrast" in aug:
                dataset = dataset.map(apply_random_contrast_3d)
            if "gamma" in aug:
                dataset = dataset.map(apply_random_gamma_3d)
            if "flip" in aug:
                dataset = dataset.map(apply_flip_3d)
            if "rotate" in aug:
                dataset = dataset.map(apply_rotate_3d)
        else:
            parse_crop = partial(apply_centre_crop_3d,
                                 crop_size=crop_size,
                                 depth_crop_size=depth_crop_size,
                                 output_slice=predict_slice)
            dataset = dataset.map(map_func=parse_crop, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    dataset = dataset.map(normalise)
    return dataset