Diff of /drunet/data.py [000000] .. [2824d6]

Switch to unified view

a b/drunet/data.py
1
import os
2
import pathlib
3
4
import tqdm
5
import cv2 as cv
6
import numpy as np
7
import tensorflow as tf
8
from tensorflow.keras import *
9
import matplotlib.pyplot as plt
10
import tensorflow.keras as keras
11
import utils
12
13
14
def return_inputs(inputs):
15
    """Returns the output value according to the input type, used for image path input"""
16
    all_image_paths = None
17
    if type(inputs) is str:
18
        if os.path.isfile(inputs):
19
            all_image_paths = [inputs]
20
        elif os.path.isdir(inputs):
21
            all_image_paths = utils.list_file(inputs)
22
    elif type(inputs) is list:
23
        all_image_paths = inputs
24
    return all_image_paths
25
26
27
# 1. make dataset
28
def get_path_name(data_dir, get_id=False, nums=-1):
29
    name_list = []
30
    path_list = []
31
    for path in pathlib.Path(data_dir).iterdir():
32
        path_list.append(str(path))
33
        if get_id:
34
            name_list.append(path.stem[-5:])
35
        else:
36
            name_list.append(path.stem)
37
    if nums != -1:
38
        name_list = name_list[:nums]
39
        path_list = path_list[:nums]
40
    name_list = sorted(name_list, key=lambda path_: int(pathlib.Path(path_).stem))
41
    path_list = sorted(path_list, key=lambda path_: int(pathlib.Path(path_).stem))
42
    return name_list, path_list
43
44
45
class TFData:
46
    def __init__(self, image_shape, image_dir=None, mask_dir=None,
47
                 out_name=None, out_dir='', zip_file=True, mask_gray=True):
48
        self.image_shape = image_shape
49
        self.zip_file = zip_file
50
        self.image_dir = image_dir
51
        self.mask_dir = mask_dir
52
        self.out_name = out_name
53
        self.out_dir = os.path.join(out_dir, out_name)
54
        self.mask_gray = mask_gray
55
56
        if len(image_shape) == 3 and image_shape[-1] != 1:
57
            self.image_gray = False
58
        else:
59
            self.image_gray = True
60
        if self.zip_file:
61
            self.options = tf.io.TFRecordOptions(compression_type='GZIP')
62
63
        if image_dir is not None and mask_dir is not None:
64
            self.image_name, self.image_list = get_path_name(self.image_dir, False)
65
            self.mask_name, self.mask_list = get_path_name(self.mask_dir, False)
66
            self.data_zip = zip(self.image_list, self.mask_list)
67
68
    def image_to_byte(self, path, gray_scale):
69
        image = cv.imread(path)
70
        if not gray_scale:
71
            image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
72
        elif len(image.shape) == 3:
73
            image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
74
        else:
75
            pass
76
        image = cv.resize(image, tuple(self.image_shape[:2]))
77
78
        return image.tobytes()
79
80
    def write_tfrecord(self):
81
        if not os.path.exists(self.out_dir):
82
            if self.zip_file:
83
                writer = tf.io.TFRecordWriter(self.out_dir, self.options)
84
            else:
85
                writer = tf.io.TFRecordWriter(self.out_dir)
86
87
            print(len(self.image_list))
88
            for image_path, mask_path in tqdm.tqdm(self.data_zip, total=len(self.image_list)):
89
                image = self.image_to_byte(image_path, self.image_gray)
90
                mask = self.image_to_byte(mask_path, self.mask_gray)
91
92
                example = tf.train.Example(features=tf.train.Features(
93
                    feature={
94
                        'mask': tf.train.Feature(bytes_list=tf.train.BytesList(value=[mask])),
95
                        'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image]))
96
                    }
97
                ))
98
                writer.write(example.SerializeToString())
99
            writer.close()
100
        print('Dataset finished!')
101
102
    def _parse_function(self, example_proto):
103
        features = tf.io.parse_single_example(
104
            example_proto,
105
            features={
106
                'mask': tf.io.FixedLenFeature([], tf.string),
107
                'image': tf.io.FixedLenFeature([], tf.string)
108
            }
109
        )
110
111
        image = features['image']
112
        image = tf.io.decode_raw(image, tf.uint8)
113
        if self.image_gray:
114
            image = tf.reshape(image, self.image_shape[:2])
115
            image = tf.expand_dims(image, -1)
116
        else:
117
            image = tf.reshape(image, self.image_shape)
118
119
        label = features['mask']
120
        label = tf.io.decode_raw(label, tf.uint8)
121
        if self.mask_gray:
122
            label = tf.reshape(label, self.image_shape[:2])
123
            label = tf.expand_dims(label, -1)
124
        else:
125
            label = tf.reshape(label, self.image_shape)
126
127
        return image, label
128
129
    def data_iterator(self, batch_size, data_name='', repeat=1, shuffle=True):
130
        if len(data_name) == 0:
131
            data_name = self.out_dir
132
        else:
133
            data_name = data_name
134
135
        if self.zip_file:
136
            dataset = tf.data.TFRecordDataset(data_name, compression_type='GZIP')
137
        else:
138
            dataset = tf.data.TFRecordDataset(data_name)
139
        dataset = dataset.map(self._parse_function)
140
141
        if shuffle:
142
            dataset = dataset.shuffle(buffer_size=100).repeat(repeat).batch(batch_size, drop_remainder=True)
143
        else:
144
            dataset = dataset.repeat(repeat).batch(batch_size, drop_remainder=True)
145
        return dataset
146
147
148
def data_preprocess(image, mask):
149
    """Normalize the image and mask data sets between 0-1"""
150
    image = tf.cast(image, np.float32)
151
    image = image / 127.5 - 1
152
    mask = tf.cast(mask, np.float32)
153
    mask = mask / 255.0
154
    return image, mask
155
156
157
def make_data(image_shape, image_dir, mask_dir, out_name=None, out_dir=''):
158
    tf_data = TFData(image_shape=image_shape, out_dir=out_dir, out_name=out_name,
159
                     image_dir=image_dir, mask_dir=mask_dir)
160
    tf_data.write_tfrecord()
161
    return
162
163
164
def get_tfrecord_data(tf_record_path, tf_record_name, data_shape, batch_size=32, repeat=1, shuffle=True):
165
    tf_data = TFData(image_shape=data_shape, out_dir=tf_record_path, out_name=tf_record_name)
166
    seg_data = tf_data.data_iterator(batch_size=batch_size, repeat=repeat, shuffle=shuffle)
167
    seg_data = seg_data.map(data_preprocess)
168
    return seg_data
169
170
171
def get_test_data(test_data_path, image_shape, image_nums=16):
172
    """
173
    :param test_data_path: test image path
174
    :param image_shape: Need to resize the shape of the test image, a tuple of length 3, [height, width, channel]
175
    :param image_nums: How many images need to be tested, the default is 16
176
    :return: normalized image collection
177
    """
178
    or_resize_shape = (1440, 1440)
179
    normalize_test_data = []
180
    original_test_data = []
181
    test_image_name = []
182
    test_data_paths = return_inputs(test_data_path)
183
184
    for path in test_data_paths:
185
        try:
186
            test_image_name.append(pathlib.Path(path).name)
187
            original_test_image = cv.imread(str(path))
188
            original_test_image = cv.resize(original_test_image, or_resize_shape)
189
            original_shape = original_test_image.shape
190
            if len(original_shape) == 0:
191
                print('Unable to read the {} file, please keep the path without Chinese! --First'.format(str(path)))
192
            else:
193
                original_test_data.append(original_test_image)
194
            if image_shape[-1] == 1:
195
                original_test_image = cv.cvtColor(original_test_image, cv.COLOR_BGR2GRAY)
196
            image = cv.resize(original_test_image, tuple(image_shape[:2]))
197
            image = image.astype(np.float32)
198
            image = image / 127.5 - 1
199
            normalize_test_data.append(image)
200
            if image_nums == -1:
201
                pass
202
            else:
203
                if len(normalize_test_data) == image_nums:
204
                    break
205
        except Exception as e:
206
            print('Unable to read the {} file, please keep the path without Chinese! --Second'.format(str(path)))
207
            print(e)
208
209
    normalize_test_array = np.array(normalize_test_data)
210
    if image_shape[-1] == 1:
211
        normalize_test_array = np.expand_dims(normalize_test_array, -1)
212
    original_test_array = np.array(original_test_data)
213
    if original_test_array.shape == 3:
214
        original_test_array = np.expand_dims(original_test_array, 0)
215
        normalize_test_array = np.expand_dims(normalize_test_array, 0)
216
    return test_image_name, original_test_array, normalize_test_array