from PIL import Image
import numpy as np
import tensorflow as tf
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
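

# A minimal sketch (not part of the original module) of what the two helpers
# above produce: tf.train.Feature protos wrapping a single bytes or int64
# value, ready to be placed into a tf.train.Features map. The values used
# here are arbitrary examples.
def _example_feature_helpers():
    return tf.train.Features(feature={
        'height': _int64_feature(480),
        'image_raw': _bytes_feature(b'\x00\x01\x02'),
    })
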
def write_image_annotation_pairs_to_tfrecord(filename_pairs, tfrecords_filename):
"""Writes given image/annotation pairs to the tfrecords file.
The function reads each image/annotation pair given filenames
of image and respective annotation and writes it to the tfrecord
file.
Parameters
----------
filename_pairs : array of tuples (img_filepath, annotation_filepath)
Array of tuples of image/annotation filenames
tfrecords_filename : string
Tfrecords filename to write the image/annotation pairs
"""
writer = tf.python_io.TFRecordWriter(tfrecords_filename)
for img_path, annotation_path in filename_pairs:
img = np.array(Image.open(img_path))
annotation = np.array(Image.open(annotation_path))
# Unomment this one when working with surgical data
# annotation = annotation[:, :, 0]
# The reason to store image sizes was demonstrated
# in the previous example -- we have to know sizes
# of images to later read raw serialized string,
# convert to 1d array and convert to respective
# shape that image used to have.
height = img.shape[0]
width = img.shape[1]
img_raw = img.tostring()
annotation_raw = annotation.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'height': _int64_feature(height),
'width': _int64_feature(width),
'image_raw': _bytes_feature(img_raw),
'mask_raw': _bytes_feature(annotation_raw)}))
writer.write(example.SerializeToString())
writer.close()
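

# A minimal usage sketch (not part of the original module), assuming two
# image/annotation PNG pairs exist at the hypothetical paths below; the
# output filename 'pairs.tfrecords' is likewise only an example.
def _example_write_usage():
    filename_pairs = [
        ('images/frame_000.png', 'annotations/frame_000.png'),
        ('images/frame_001.png', 'annotations/frame_001.png'),
    ]
    write_image_annotation_pairs_to_tfrecord(filename_pairs,
                                             tfrecords_filename='pairs.tfrecords')
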
def read_image_annotation_pairs_from_tfrecord(tfrecords_filename):
"""Return image/annotation pairs from the tfrecords file.
The function reads the tfrecords file and returns image
and respective annotation matrices pairs.
Parameters
----------
tfrecords_filename : string
filename of .tfrecords file to read from
Returns
-------
image_annotation_pairs : array of tuples (img, annotation)
The image and annotation that were read from the file
"""
image_annotation_pairs = []
record_iterator = tf.python_io.tf_record_iterator(path=tfrecords_filename)
for string_record in record_iterator:
example = tf.train.Example()
example.ParseFromString(string_record)
height = int(example.features.feature['height']
.int64_list
.value[0])
width = int(example.features.feature['width']
.int64_list
.value[0])
img_string = (example.features.feature['image_raw']
.bytes_list
.value[0])
annotation_string = (example.features.feature['mask_raw']
.bytes_list
.value[0])
img_1d = np.fromstring(img_string, dtype=np.uint8)
img = img_1d.reshape((height, width, -1))
annotation_1d = np.fromstring(annotation_string, dtype=np.uint8)
# Annotations don't have depth (3rd dimension)
# TODO: check if it works for other datasets
annotation = annotation_1d.reshape((height, width))
image_annotation_pairs.append((img, annotation))
return image_annotation_pairs
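

# A minimal sketch (not part of the original module), assuming 'pairs.tfrecords'
# was produced by the writer above; it reads the records back into numpy arrays
# and prints the recovered shapes as a sanity check.
def _example_read_usage():
    pairs = read_image_annotation_pairs_from_tfrecord('pairs.tfrecords')
    for img, annotation in pairs:
        # img is (height, width, channels); annotation is (height, width)
        print(img.shape, annotation.shape)
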
def read_tfrecord_and_decode_into_image_annotation_pair_tensors(tfrecord_filenames_queue):
"""Return image/annotation tensors that are created by reading tfrecord file.
The function accepts tfrecord filenames queue as an input which is usually
can be created using tf.train.string_input_producer() where filename
is specified with desired number of epochs. This function takes queue
produced by aforemention tf.train.string_input_producer() and defines
tensors converted from raw binary representations into
reshaped image/annotation tensors.
Parameters
----------
tfrecord_filenames_queue : tfrecord filename queue
String queue object from tf.train.string_input_producer()
Returns
-------
image, annotation : tuple of tf.int32 (image, annotation)
Tuple of image/annotation tensors
"""
reader = tf.TFRecordReader()
_, serialized_example = reader.read(tfrecord_filenames_queue)
features = tf.parse_single_example(
serialized_example,
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
'mask_raw': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
annotation = tf.decode_raw(features['mask_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
image_shape = tf.stack([height, width, 1])
annotation_shape = tf.stack([height, width, 1])
image = tf.reshape(image, image_shape)
image = tf.concat([image, image, image], axis=2)
annotation = tf.reshape(annotation, annotation_shape)
return image, annotation
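

# A minimal TF 1.x usage sketch (not part of the original module): the
# tfrecords filename and num_epochs are assumptions used only for illustration.
# It wires the decoder above into a filename queue, starts the queue runners,
# and pulls a single image/annotation pair out of the graph.
def _example_queue_usage():
    filename_queue = tf.train.string_input_producer(['pairs.tfrecords'],
                                                    num_epochs=1)
    image, annotation = \
        read_tfrecord_and_decode_into_image_annotation_pair_tensors(filename_queue)

    with tf.Session() as sess:
        # The local variables initializer is needed for the num_epochs counter.
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        # Queue runners fill the filename queue in background threads.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        img_np, annotation_np = sess.run([image, annotation])
        print(img_np.shape, annotation_np.shape)

        coord.request_stop()
        coord.join(threads)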