# UNET.py
"""
data:
    CT         :     used
    mask       :     used
    labels(txt): not used
    labelsJson : not used
"""
8
9
10
# U-Net Structure:
# 1->64->64...............................................................->(*/)128->64->64=>2      # 1: input image, 2: output segmentation map
#    (-+)64->128->128...........................................->(*/)256->128->128
#             (-+)128->256->256.......................->(*/)512->256->256
#                       (-+)256->512->512..->(*/)1024->512->512     # 1024: 512+512
#                                 (-+)512->1024->1024
# ->  : conv 3x3, ReLU
# ..->: copy & crop
# (-+): max pool 2x2
# (*/): up-conv 2x2
# =>  : conv 1x1
21
22
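'''
Issues:
1. Consider splitting off 90% of preprocessing_tmp1 via "data_generation" as the training set and saving it as a dataset in the outer directory.
 ----> Check how the test images are preprocessed later; if the preprocessing is identical, move the dataset out to the outer directory.
'''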
27
28
import os

from keras._tf_keras import keras  # CPU - keras > 3.*
from keras._tf_keras.keras.layers import *  # CPU - keras > 3.*
from keras._tf_keras.keras.preprocessing.image import (
    ImageDataGenerator,
)  # CPU - keras > 3.*

# from keras.layers import *  # GPU - keras > 2.*
# from keras.callbacks import ModelCheckpoint  # GPU - keras > 2.*
# from keras.preprocessing.image import ImageDataGenerator  # GPU - keras > 2.*

from keras import Model
from keras import backend as K

import numpy as np
# from data_preparation import draw_image
import matplotlib.pyplot as plt
import cv2
46
47
48
# img & mask
# 3. data augmentation: (import tensorflow.keras.preprocessing.image)
#         [1] define an image_generator -> ImageDataGenerator()
#         [2] image data augmentation -> flow_from_directory()
#         [3] image normalization
#     Open questions:
#         [1] Convert .nii -> png/json/txt first, then apply the Keras data augmentation.
#             Question: once the images/masks are augmented, do the matching json/txt labels have to change as well? ---> ???
#         [2] Using TensorFlow and PyTorch together ---> possible
#             model -> YOLO uses PyTorch
#             data augmentation uses tensorflow -> keras
#         [3] The generator must be followed by fit() (to compute the mean), otherwise this warning is shown:
#               F:\AI_Outils\Anaconda\1\envs\opencv_CPU\Lib\site-packages\keras\src\legacy\preprocessing\image.py:1263: UserWarning: This ImageDataGenerator specifies `featurewise_center`, but it hasn't been fit on any training data. Fit it first by calling `.fit(numpy_data)`.
61
62
63
# data augmentation for train
64
def train_generator(dataset_path, type):
    """Yield augmented ``(image, mask)`` batches read from ``dataset_path/type``.

    Images are read from the ``images`` sub-directory and masks from the
    ``masks`` sub-directory of ``os.path.join(dataset_path, type)``. Both
    streams use the same augmentation settings and the same seed so that an
    image and its mask receive the same random transform. Every generated
    batch is also written to ``<dataset_path>/<type>_generator/images`` and
    ``.../masks`` through Keras' ``save_to_dir`` option.

    Args:
        dataset_path: Root directory of the dataset.
        type: Name of the split sub-directory (e.g. ``"train"``). The name
            shadows the builtin but is kept for backward compatibility.

    Yields:
        Tuples ``(image, mask)``: ``image`` is min-max normalized over the
        whole batch via ``normalization``; ``mask`` is passed through as
        loaded.
    """
    data_path = os.path.join(dataset_path, type)
    data_pre_path = os.path.join(dataset_path, f"{type}_generator")
    img_png_path = os.path.join(data_pre_path, "images")
    mask_png_path = os.path.join(data_pre_path, "masks")

    # Create the output directories, parent first.
    for path in (data_pre_path, img_png_path, mask_png_path):
        os.makedirs(path, exist_ok=True)

    # 3.1 define an image_generator: to perform various transformations on object
    generator_args = dict(
        rotation_range=0.1,
        width_shift_range=0.05,
        height_shift_range=0.05,
        shear_range=0.05,
        zoom_range=0.05,
        horizontal_flip=False,
        vertical_flip=False,
    )
    # The settings must be unpacked as keyword arguments. Passing the dict
    # positionally binds it to ``featurewise_center``: no augmentation is
    # configured and Keras warns that the generator has not been fit.
    generator_image = ImageDataGenerator(**generator_args)
    generator_mask = ImageDataGenerator(**generator_args)

    # 3.2 implement further data augmentation for image & mask
    # Options shared by both streams; the identical seed keeps them in sync.
    flow_args = dict(
        directory=data_path,
        class_mode=None,
        color_mode="grayscale",
        target_size=(512, 512),
        batch_size=2,
        seed=123,
    )
    generation_image = generator_image.flow_from_directory(
        classes=["images"],
        save_to_dir=img_png_path,
        # save_prefix='ct_',
        **flow_args,
    )
    generation_mask = generator_mask.flow_from_directory(
        classes=["masks"],
        save_to_dir=mask_png_path,
        # save_prefix='mask_',
        **flow_args,
    )

    print("2--------------------------")
    # 3.3 image normalization
    # NOTE(review): masks are yielded exactly as loaded. If the mask PNGs are
    # stored as 0/255, they presumably need `binarization(mask)` before being
    # used with a sigmoid output and binary cross-entropy -- confirm against
    # the dataset.
    for image, mask in zip(generation_image, generation_mask):
        yield (normalization(image), mask)

    # Only reached if the underlying iterators stop; Keras directory
    # iterators normally loop indefinitely.
    print("Further data augmentation was completed successfully.")
156
157
158
def binarization(data):
    """Binarize 0-255 data into the two values 0.0 and 1.0.

    Used to highlight certain features in an image, e.g. a segmentation mask.

    Processing: x'[x/255.0 >  0.5] = 1.0
                x'[x/255.0 <= 0.5] = 0.0

    Args:
        data: Array-like of values on a 0-255 scale (NumPy array, nested
            list or scalar). The input is not modified.

    Returns:
        A new NumPy array of the same shape holding only 0.0 and 1.0. NaN
        entries compare false against the threshold and map to 0.0.
    """
    scaled = np.asanyarray(data) / 255.0
    binary = np.where(scaled > 0.5, 1.0, 0.0)
    # Keep the floating dtype produced by the division (e.g. float32 input).
    return binary.astype(scaled.dtype, copy=False)
169
170
171
def standardization(data):
    """Standardize data to zero mean and unit standard deviation.

    Makes features with very different value ranges/units comparable.
    Standardization does not change the shape of the data distribution.

    Processing: x' = (x - mean) / std

    Args:
        data: Array-like of numeric values. Mean and standard deviation are
            computed over all elements.

    Returns:
        The standardized array. Constant input (std == 0) is returned as all
        zeros instead of NaN from a 0/0 division.
    """
    data = np.asanyarray(data)
    mean = np.mean(data)
    std = np.std(data)
    if std == 0:
        # Constant input: every (x - mean) is already 0, so divide by 1.
        std = 1
    return (data - mean) / std
182
183
184
def normalization(data):
    """Min-max normalize data into the range 0-1.

    Makes the influence of each feature on the target variable consistent.
    Normalization changes the distribution of the feature data.

    Processing: x' = (x - min) / (max - min)

    Note: in the medical field normalization is generally performed to
    accelerate the convergence of the network and make the model more stable.

    Args:
        data: Array-like of numeric values. Minimum and maximum are taken
            over all elements, ignoring NaN.

    Returns:
        The normalized array. Constant input (max == min) is returned as all
        zeros instead of NaN from a 0/0 division.
    """
    data = np.asanyarray(data)
    low = np.nanmin(data)
    high = np.nanmax(data)
    span = high - low
    if span == 0:
        # Constant input: every (x - low) is already 0, so divide by 1.
        span = 1
    return (data - low) / span
196
197
198
def u_net(input_size=(512, 512, 1), path=None):
    """Build and compile a U-Net for single-channel binary segmentation.

    Structure: four encoder levels (64, 128, 256, 512 filters), a 1024-filter
    bottleneck and four decoder levels (512, 256, 128, 64 filters). Every
    level applies two 3x3 ReLU convolutions; encoder levels end with a 2x2
    max pool, decoder levels start with a 2x2 ``UpSampling2D`` whose output
    is concatenated with the matching encoder feature map along the channel
    axis. A final 1x1 sigmoid convolution produces a one-channel map.

    Note: upsampling uses ``UpSampling2D`` only; no convolution follows it
    before the concatenation.

    Args:
        input_size: Shape of one input sample as (height, width, channels).
            With four 2x2 poolings, height and width should be divisible by
            16 so the skip connections line up.
        path: Accepted for backward compatibility; currently unused.

    Returns:
        The model, compiled with the Adam optimizer, binary cross-entropy
        loss and the accuracy metric.
    """

    def conv_pair(tensor, filters):
        """Apply two 3x3 same-padded ReLU convolutions (he_normal init)."""
        for _ in range(2):
            tensor = Conv2D(
                filters,
                3,
                activation="relu",
                padding="same",
                kernel_initializer="he_normal",
            )(tensor)
        return tensor

    inputs = Input(input_size)

    # Contracting path: keep each level's output for the skip connections.
    skips = []
    features = inputs
    for filters in (64, 128, 256, 512):
        features = conv_pair(features, filters)
        skips.append(features)
        features = MaxPool2D(pool_size=(2, 2))(features)

    # Bottleneck (layer 5).
    features = conv_pair(features, 1024)

    # Expanding path: upsample, concatenate with the skip, then convolve.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        upsampled = UpSampling2D(size=(2, 2))(features)
        features = conv_pair(concatenate([upsampled, skip], axis=3), filters)

    outputs = Conv2D(1, 1, activation="sigmoid")(features)

    # build model
    model = Model(inputs=inputs, outputs=outputs)

    # compile model
    #   Loss     : 0-1 binary cross-entropy (binary_crossentropy)
    #   Optimizer: Adaptive Descent (Adam)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

    return model
257
258
259
class ShowMask(keras.callbacks.Callback):
    """Save an image / mask / prediction comparison figure after each epoch.

    At the end of every epoch one batch is drawn from the batch source, the
    first sample is run through the model being trained, and the input, the
    reference mask and the prediction are saved side by side as
    ``compare_<epoch>.png`` in the current working directory.

    Args:
        generator: Optional iterable yielding ``(image_batch, mask_batch)``
            pairs. When omitted, the module-level ``gene`` is used.
    """

    def __init__(self, generator=None):
        super().__init__()
        self._generator = generator

    def on_epoch_end(self, epoch, logs=None):
        print()
        batches = gene if self._generator is None else self._generator
        for img, mask in batches:
            # Slice instead of reshaping to a fixed 512x512 so any input size
            # works; `self.model` is the model this callback is attached to.
            prediction = self.model.predict(img[:1])[0]
            panels = [img[0], mask[0], prediction]

            # Use a fresh figure per epoch and close it afterwards so images
            # do not pile up on one shared figure.
            figure = plt.figure()
            for position, panel in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                plt.imshow(np.squeeze(panel), cmap="gray")
                plt.axis(False)
            # plt.show()
            # Name by epoch; a counter reset on every call always wrote
            # compare_0.png.
            plt.savefig(f"compare_{epoch}.png")
            plt.close(figure)
            break  # only the first batch is visualised
277
278
279
# Test
280
281
# dataset: data augmentation for train data
UNETDataset_path = "./UNETDataset"
# The generator can be fed directly to model.fit() as the training input.
gene = train_generator(UNETDataset_path, "train")

# model params
steps_per_epoch = 50
epochs = 100
model_name = f"u_net-512-512-1-pneumonia_{epochs}_{steps_per_epoch}.keras"
models_path = "./models/"
model_path = os.path.join(models_path, model_name)
# save_best_only=False: the checkpoint file is rewritten after every epoch.
model_ckpt = keras.callbacks.ModelCheckpoint(model_path, save_best_only=False, verbose=1)

# train
K.clear_session()  # keras

model = u_net(path=model_path)  # structure
# print(model.summary())

# model.fit(gene, steps_per_epoch=steps_per_epoch, epochs=epochs, callbacks=[model_ckpt, ShowMask()]) # train
model.fit(
    gene,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    callbacks=[model_ckpt],
)  # train


# Troubleshooting ideas:
# 1. Dataset_mini -> CPU OK, GPU KO
# 2. cudnn v8 -> GPU KO
# 3. reduce input_size -> 512->256
# 4. shrink the U-Net structure

# evaluation
data_val_generator_path = os.path.join(UNETDataset_path, "val_generator")
compare_path = os.path.join(data_val_generator_path, "compare")
# Tuple keeps the creation order deterministic (parent first).
PATH = (data_val_generator_path, compare_path)
for path in PATH:
    os.makedirs(path, exist_ok=True)
321
322
# model_test= keras.models.load_model(model_path)
323
'''
324
gene = train_generator(UNETDataset_path, "val")
325
idx = 0
326
for img, mask in gene:
327
    predict_mask = model_test.predict(img)[0]
328
    # predict_mask_np = (predict_mask * 255).numpy()
329
330
    _, real_mask = cv2.threshold(mask[0], 127, 255, 0)
331
    real_mask = (real_mask).astype('uint8')
332
    real_contours, _ = cv2.findContours(real_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
333
    real_overlap_img = cv2.drawContours(img[0].copy(), real_contours, -1, (0, 255, 0), 2)
334
335
    _, pred_mask = cv2.threshold((predict_mask * 255).astype("uint8"), 127, 255, 0)
336
    pred_mask = (pred_mask).astype('uint8')
337
    pred_contours, _ = cv2.findContours(pred_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
338
    pred_overlap_img = cv2.drawContours(img[0].copy(), pred_contours, -1, (255, 0, 0), 2)
339
340
    compare_list = [img[0], pred_mask, real_overlap_img, pred_overlap_img]
341
342
    for i in range(0, len(compare_list)):
343
        plt.subplot(1, 4, i+1)
344
        plt.imshow(compare_list[i], cmap="gray")
345
        plt.axis(False)
346
    # plt.show()
347
    save_path = os.path.join(compare_path, f"compare_{idx}.png")
348
    plt.savefig(save_path)
349
350
    idx = idx + 1
351
'''
352
353
'''
354
# test save compare_png
355
356
images_path = os.path.join(data_val_generator_path, "images")
357
masks_path = os.path.join(data_val_generator_path, "masks")
358
359
idx = 0
360
for png_name in os.listdir(images_path):
361
    # predict_mask = model_test.predict(img)[0]
362
    
363
    img = cv2.imread(os.path.join(images_path, png_name))
364
    mask = cv2.imread(os.path.join(masks_path, png_name), cv2.IMREAD_GRAYSCALE)
365
366
    _, real_mask = cv2.threshold(mask, 127, 255, 0)
367
    real_mask = (real_mask).astype("uint8")
368
    real_contours, _ = cv2.findContours(
369
        real_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
370
    )
371
    real_overlap_img = cv2.drawContours(
372
        img.copy(), real_contours, -1, (0, 255, 0), 2
373
    )
374
375
    compare_list = [img, real_overlap_img]
376
377
    for i in range(0, len(compare_list)):
378
        plt.subplot(1, 2, i + 1)
379
        plt.imshow(compare_list[i], cmap="gray")
380
        plt.axis(False)
381
    # plt.show()
382
    save_path = os.path.join(compare_path, f"compare_{idx}.png")
383
    plt.savefig(save_path)
384
385
    idx = idx + 1
386
'''