Diff of /drunet/data.py [000000] .. [2824d6]

Switch to side-by-side view

--- a
+++ b/drunet/data.py
@@ -0,0 +1,216 @@
+import os
+import pathlib
+
+import tqdm
+import cv2 as cv
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import *
+import matplotlib.pyplot as plt
+import tensorflow.keras as keras
+import utils
+
+
+def return_inputs(inputs):
+    """Returns the output value according to the input type, used for image path input"""
+    all_image_paths = None
+    if type(inputs) is str:
+        if os.path.isfile(inputs):
+            all_image_paths = [inputs]
+        elif os.path.isdir(inputs):
+            all_image_paths = utils.list_file(inputs)
+    elif type(inputs) is list:
+        all_image_paths = inputs
+    return all_image_paths
+
+
+# 1. make dataset
+def get_path_name(data_dir, get_id=False, nums=-1):
+    name_list = []
+    path_list = []
+    for path in pathlib.Path(data_dir).iterdir():
+        path_list.append(str(path))
+        if get_id:
+            name_list.append(path.stem[-5:])
+        else:
+            name_list.append(path.stem)
+    if nums != -1:
+        name_list = name_list[:nums]
+        path_list = path_list[:nums]
+    name_list = sorted(name_list, key=lambda path_: int(pathlib.Path(path_).stem))
+    path_list = sorted(path_list, key=lambda path_: int(pathlib.Path(path_).stem))
+    return name_list, path_list
+
+
+class TFData:
+    def __init__(self, image_shape, image_dir=None, mask_dir=None,
+                 out_name=None, out_dir='', zip_file=True, mask_gray=True):
+        self.image_shape = image_shape
+        self.zip_file = zip_file
+        self.image_dir = image_dir
+        self.mask_dir = mask_dir
+        self.out_name = out_name
+        self.out_dir = os.path.join(out_dir, out_name)
+        self.mask_gray = mask_gray
+
+        if len(image_shape) == 3 and image_shape[-1] != 1:
+            self.image_gray = False
+        else:
+            self.image_gray = True
+        if self.zip_file:
+            self.options = tf.io.TFRecordOptions(compression_type='GZIP')
+
+        if image_dir is not None and mask_dir is not None:
+            self.image_name, self.image_list = get_path_name(self.image_dir, False)
+            self.mask_name, self.mask_list = get_path_name(self.mask_dir, False)
+            self.data_zip = zip(self.image_list, self.mask_list)
+
+    def image_to_byte(self, path, gray_scale):
+        image = cv.imread(path)
+        if not gray_scale:
+            image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
+        elif len(image.shape) == 3:
+            image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
+        else:
+            pass
+        image = cv.resize(image, tuple(self.image_shape[:2]))
+
+        return image.tobytes()
+
+    def write_tfrecord(self):
+        if not os.path.exists(self.out_dir):
+            if self.zip_file:
+                writer = tf.io.TFRecordWriter(self.out_dir, self.options)
+            else:
+                writer = tf.io.TFRecordWriter(self.out_dir)
+
+            print(len(self.image_list))
+            for image_path, mask_path in tqdm.tqdm(self.data_zip, total=len(self.image_list)):
+                image = self.image_to_byte(image_path, self.image_gray)
+                mask = self.image_to_byte(mask_path, self.mask_gray)
+
+                example = tf.train.Example(features=tf.train.Features(
+                    feature={
+                        'mask': tf.train.Feature(bytes_list=tf.train.BytesList(value=[mask])),
+                        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
+                    }
+                ))
+                writer.write(example.SerializeToString())
+            writer.close()
+        print('Dataset finished!')
+
+    def _parse_function(self, example_proto):
+        features = tf.io.parse_single_example(
+            example_proto,
+            features={
+                'mask': tf.io.FixedLenFeature([], tf.string),
+                'image': tf.io.FixedLenFeature([], tf.string)
+            }
+        )
+
+        image = features['image']
+        image = tf.io.decode_raw(image, tf.uint8)
+        if self.image_gray:
+            image = tf.reshape(image, self.image_shape[:2])
+            image = tf.expand_dims(image, -1)
+        else:
+            image = tf.reshape(image, self.image_shape)
+
+        label = features['mask']
+        label = tf.io.decode_raw(label, tf.uint8)
+        if self.mask_gray:
+            label = tf.reshape(label, self.image_shape[:2])
+            label = tf.expand_dims(label, -1)
+        else:
+            label = tf.reshape(label, self.image_shape)
+
+        return image, label
+
+    def data_iterator(self, batch_size, data_name='', repeat=1, shuffle=True):
+        if len(data_name) == 0:
+            data_name = self.out_dir
+        else:
+            data_name = data_name
+
+        if self.zip_file:
+            dataset = tf.data.TFRecordDataset(data_name, compression_type='GZIP')
+        else:
+            dataset = tf.data.TFRecordDataset(data_name)
+        dataset = dataset.map(self._parse_function)
+
+        if shuffle:
+            dataset = dataset.shuffle(buffer_size=100).repeat(repeat).batch(batch_size, drop_remainder=True)
+        else:
+            dataset = dataset.repeat(repeat).batch(batch_size, drop_remainder=True)
+        return dataset
+
+
+def data_preprocess(image, mask):
+    """Normalize the image and mask data sets between 0-1"""
+    image = tf.cast(image, np.float32)
+    image = image / 127.5 - 1
+    mask = tf.cast(mask, np.float32)
+    mask = mask / 255.0
+    return image, mask
+
+
+def make_data(image_shape, image_dir, mask_dir, out_name=None, out_dir=''):
+    tf_data = TFData(image_shape=image_shape, out_dir=out_dir, out_name=out_name,
+                     image_dir=image_dir, mask_dir=mask_dir)
+    tf_data.write_tfrecord()
+    return
+
+
+def get_tfrecord_data(tf_record_path, tf_record_name, data_shape, batch_size=32, repeat=1, shuffle=True):
+    tf_data = TFData(image_shape=data_shape, out_dir=tf_record_path, out_name=tf_record_name)
+    seg_data = tf_data.data_iterator(batch_size=batch_size, repeat=repeat, shuffle=shuffle)
+    seg_data = seg_data.map(data_preprocess)
+    return seg_data
+
+
+def get_test_data(test_data_path, image_shape, image_nums=16):
+    """
+    :param test_data_path: test image path
+    :param image_shape: Need to resize the shape of the test image, a tuple of length 3, [height, width, channel]
+    :param image_nums: How many images need to be tested, the default is 16
+    :return: normalized image collection
+    """
+    or_resize_shape = (1440, 1440)
+    normalize_test_data = []
+    original_test_data = []
+    test_image_name = []
+    test_data_paths = return_inputs(test_data_path)
+
+    for path in test_data_paths:
+        try:
+            test_image_name.append(pathlib.Path(path).name)
+            original_test_image = cv.imread(str(path))
+            original_test_image = cv.resize(original_test_image, or_resize_shape)
+            original_shape = original_test_image.shape
+            if len(original_shape) == 0:
+                print('Unable to read the {} file, please keep the path without Chinese! --First'.format(str(path)))
+            else:
+                original_test_data.append(original_test_image)
+            if image_shape[-1] == 1:
+                original_test_image = cv.cvtColor(original_test_image, cv.COLOR_BGR2GRAY)
+            image = cv.resize(original_test_image, tuple(image_shape[:2]))
+            image = image.astype(np.float32)
+            image = image / 127.5 - 1
+            normalize_test_data.append(image)
+            if image_nums == -1:
+                pass
+            else:
+                if len(normalize_test_data) == image_nums:
+                    break
+        except Exception as e:
+            print('Unable to read the {} file, please keep the path without Chinese! --Second'.format(str(path)))
+            print(e)
+
+    normalize_test_array = np.array(normalize_test_data)
+    if image_shape[-1] == 1:
+        normalize_test_array = np.expand_dims(normalize_test_array, -1)
+    original_test_array = np.array(original_test_data)
+    if original_test_array.shape == 3:
+        original_test_array = np.expand_dims(original_test_array, 0)
+        normalize_test_array = np.expand_dims(normalize_test_array, 0)
+    return test_image_name, original_test_array, normalize_test_array