--- /dev/null
+++ b/drunet/data.py
@@ -0,0 +1,206 @@
+import os
+import pathlib
+
+import tqdm
+import cv2 as cv
+import numpy as np
+import tensorflow as tf
+
+import utils
+
+
+def return_inputs(inputs):
+    """Normalize an image-path input (single file, directory, or list) to a list of paths."""
+    all_image_paths = None
+    if isinstance(inputs, str):
+        if os.path.isfile(inputs):
+            all_image_paths = [inputs]
+        elif os.path.isdir(inputs):
+            all_image_paths = utils.list_file(inputs)
+    elif isinstance(inputs, list):
+        all_image_paths = inputs
+    return all_image_paths
+
+
+# 1. make dataset
+def get_path_name(data_dir, get_id=False, nums=-1):
+    """Return (stems, paths) for the files in ``data_dir``, sorted by numeric stem.
+
+    Sorting happens before truncation so that ``nums`` selects the first
+    ``nums`` files in numeric order rather than directory-listing order.
+    """
+    path_list = sorted((str(path) for path in pathlib.Path(data_dir).iterdir()),
+                       key=lambda path_: int(pathlib.Path(path_).stem))
+    if get_id:
+        name_list = [pathlib.Path(path).stem[-5:] for path in path_list]
+    else:
+        name_list = [pathlib.Path(path).stem for path in path_list]
+    if nums != -1:
+        name_list = name_list[:nums]
+        path_list = path_list[:nums]
+    return name_list, path_list
+
+
+class TFData:
+    def __init__(self, image_shape, image_dir=None, mask_dir=None,
+                 out_name=None, out_dir='', zip_file=True, mask_gray=True):
+        self.image_shape = image_shape
+        self.zip_file = zip_file
+        self.image_dir = image_dir
+        self.mask_dir = mask_dir
+        self.out_name = out_name
+        self.out_dir = os.path.join(out_dir, out_name) if out_name else out_dir
+        self.mask_gray = mask_gray
+
+        # Grayscale unless the shape declares more than one channel.
+        self.image_gray = not (len(image_shape) == 3 and image_shape[-1] != 1)
+        if self.zip_file:
+            self.options = tf.io.TFRecordOptions(compression_type='GZIP')
+
+        if image_dir is not None and mask_dir is not None:
+            self.image_name, self.image_list = get_path_name(self.image_dir, False)
+            self.mask_name, self.mask_list = get_path_name(self.mask_dir, False)
+            self.data_zip = zip(self.image_list, self.mask_list)
+
+    def image_to_byte(self, path, gray_scale):
+        image = cv.imread(path)
+        if image is None:
+            raise IOError('Unable to read {}'.format(path))
+        if not gray_scale:
+            image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
+        elif len(image.shape) == 3:
+            image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
+        # cv.resize expects (width, height); image_shape is (height, width, ...).
+        image = cv.resize(image, (self.image_shape[1], self.image_shape[0]))
+        return image.tobytes()
+
+    def write_tfrecord(self):
+        if os.path.exists(self.out_dir):
+            print('{} already exists, skipping.'.format(self.out_dir))
+            return
+        if self.zip_file:
+            writer = tf.io.TFRecordWriter(self.out_dir, self.options)
+        else:
+            writer = tf.io.TFRecordWriter(self.out_dir)
+
+        print(len(self.image_list))
+        for image_path, mask_path in tqdm.tqdm(self.data_zip, total=len(self.image_list)):
+            image = self.image_to_byte(image_path, self.image_gray)
+            mask = self.image_to_byte(mask_path, self.mask_gray)
+
+            example = tf.train.Example(features=tf.train.Features(
+                feature={
+                    'mask': tf.train.Feature(bytes_list=tf.train.BytesList(value=[mask])),
+                    'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
+                }
+            ))
+            writer.write(example.SerializeToString())
+        writer.close()
+        print('Dataset finished!')
+
+    def _parse_function(self, example_proto):
+        features = tf.io.parse_single_example(
+            example_proto,
+            features={
+                'mask': tf.io.FixedLenFeature([], tf.string),
+                'image': tf.io.FixedLenFeature([], tf.string)
+            }
+        )
+
+        image = tf.io.decode_raw(features['image'], tf.uint8)
+        if self.image_gray:
+            image = tf.reshape(image, self.image_shape[:2])
+            image = tf.expand_dims(image, -1)
+        else:
+            image = tf.reshape(image, self.image_shape)
+
+        label = tf.io.decode_raw(features['mask'], tf.uint8)
+        if self.mask_gray:
+            label = tf.reshape(label, self.image_shape[:2])
+            label = tf.expand_dims(label, -1)
+        else:
+            label = tf.reshape(label, self.image_shape)
+
+        return image, label
+
+    def data_iterator(self, batch_size, data_name='', repeat=1, shuffle=True):
+        if len(data_name) == 0:
+            data_name = self.out_dir
+
+        if self.zip_file:
+            dataset = tf.data.TFRecordDataset(data_name, compression_type='GZIP')
+        else:
+            dataset = tf.data.TFRecordDataset(data_name)
+        dataset = dataset.map(self._parse_function)
+
+        if shuffle:
+            dataset = dataset.shuffle(buffer_size=100)
+        dataset = dataset.repeat(repeat).batch(batch_size, drop_remainder=True)
+        return dataset
+
+
+def data_preprocess(image, mask):
+    """Scale images to [-1, 1] and masks to [0, 1]."""
+    image = tf.cast(image, tf.float32)
+    image = image / 127.5 - 1
+    mask = tf.cast(mask, tf.float32)
+    mask = mask / 255.0
+    return image, mask
+
+
+def make_data(image_shape, image_dir, mask_dir, out_name=None, out_dir=''):
+    tf_data = TFData(image_shape=image_shape, out_dir=out_dir, out_name=out_name,
+                     image_dir=image_dir, mask_dir=mask_dir)
+    tf_data.write_tfrecord()
+
+
+def get_tfrecord_data(tf_record_path, tf_record_name, data_shape, batch_size=32, repeat=1, shuffle=True):
+    tf_data = TFData(image_shape=data_shape, out_dir=tf_record_path, out_name=tf_record_name)
+    seg_data = tf_data.data_iterator(batch_size=batch_size, repeat=repeat, shuffle=shuffle)
+    seg_data = seg_data.map(data_preprocess)
+    return seg_data
+
+
+def get_test_data(test_data_path, image_shape, image_nums=16):
+    """
+    :param test_data_path: test image path (file, directory, or list of paths)
+    :param image_shape: shape to resize the test images to, a tuple of length 3, (height, width, channel)
+    :param image_nums: how many images to test; -1 means all, default 16
+    :return: image names, original images, and normalized images
+    """
+    or_resize_shape = (1440, 1440)  # fixed size for the "original" copies
+    normalize_test_data = []
+    original_test_data = []
+    test_image_name = []
+    test_data_paths = return_inputs(test_data_path)
+
+    for path in test_data_paths:
+        try:
+            original_test_image = cv.imread(str(path))
+            if original_test_image is None:
+                print('Unable to read {}; make sure the path contains no non-ASCII characters. --First'.format(path))
+                continue
+            test_image_name.append(pathlib.Path(path).name)
+            original_test_image = cv.resize(original_test_image, or_resize_shape)
+            original_test_data.append(original_test_image)
+            if image_shape[-1] == 1:
+                original_test_image = cv.cvtColor(original_test_image, cv.COLOR_BGR2GRAY)
+            # cv.resize expects (width, height); image_shape is (height, width, channel).
+            image = cv.resize(original_test_image, (image_shape[1], image_shape[0]))
+            image = image.astype(np.float32)
+            image = image / 127.5 - 1
+            normalize_test_data.append(image)
+            if image_nums != -1 and len(normalize_test_data) == image_nums:
+                break
+        except Exception as e:
+            print('Unable to read {}; make sure the path contains no non-ASCII characters. --Second'.format(path))
+            print(e)
+
+    normalize_test_array = np.array(normalize_test_data)
+    if image_shape[-1] == 1:
+        normalize_test_array = np.expand_dims(normalize_test_array, -1)
+    original_test_array = np.array(original_test_data)
+    if original_test_array.ndim == 3:  # a single image: add a batch dimension
+        original_test_array = np.expand_dims(original_test_array, 0)
+        normalize_test_array = np.expand_dims(normalize_test_array, 0)
+    return test_image_name, original_test_array, normalize_test_array
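For reference, a minimal sketch of how these helpers compose end to end. The directory names (`dataset/images`, `dataset/masks`, `dataset/test`), the record name `train.tfrecord`, and the shape `(256, 256, 1)` are illustrative assumptions, not paths from this repo; the module is assumed importable as `data`.

```python
# Usage sketch only; all paths and shapes below are hypothetical.
import data

IMAGE_SHAPE = (256, 256, 1)  # (height, width, channel), grayscale

# 1. Serialize paired image/mask folders into one GZIP-compressed TFRecord.
data.make_data(image_shape=IMAGE_SHAPE,
               image_dir='dataset/images',   # hypothetical image folder
               mask_dir='dataset/masks',     # hypothetical mask folder
               out_name='train.tfrecord',
               out_dir='dataset')

# 2. Reload it as a batched, shuffled, normalized tf.data pipeline:
#    images come back in [-1, 1], masks in [0, 1].
train_ds = data.get_tfrecord_data(tf_record_path='dataset',
                                  tf_record_name='train.tfrecord',
                                  data_shape=IMAGE_SHAPE,
                                  batch_size=8, repeat=1, shuffle=True)
for images, masks in train_ds.take(1):
    print(images.shape, masks.shape)  # (8, 256, 256, 1) (8, 256, 256, 1)

# 3. Load raw test images for inference.
names, originals, normalized = data.get_test_data('dataset/test', IMAGE_SHAPE, image_nums=4)
```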