|
a |
|
b/UNET.py |
|
|
1 |
""" |
|
|
2 |
data: |
|
|
3 |
CT : used |
|
|
4 |
mask : used |
|
|
5 |
labels(txt): not used |
|
|
6 |
labelsJson : not used |
|
|
7 |
""" |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
# U-Net structure (reference U-Net architecture):
# 1->64->64...............................................................->(*/)128->64->64=>2 # 1: input image, 2: output segmentation map
# (-+)64->128->128...........................................->(*/)256->128->128
# (-+)128->256->256.......................->(*/)512->256->256
# (-+)256->512->512..->(*/)1024->512->512 # 1024: 512+512
# (-+)512->1024->1024
# -> : conv 3x3, ReLU
# ..->: copy & crop
# (-+): max pool 2x2
# (*/): up-conv 2x2
# => : conv 1x1
# Note: u_net() in this file deviates from the diagram. It uses "same" padding, so no
# cropping is needed. It uses plain UpSampling2D instead of a learned 2x2 up-conv, so the
# concatenated tensors are e.g. 1024 + 512 channels rather than 512 + 512. It ends in a
# single sigmoid output channel instead of 2.
|
|
21 |
|
|
|
22 |
''' |
|
|
23 |
Issues: |
|
|
24 |
1. 要不直接把preprocessing_tmp1 经过"data_generation"分割90% 出来用作训练集, 并存为dataset放在外面目录下 |
|
|
25 |
----> 看下之后test部分的图片怎么预处理, 如果处理方式一样那就dataset拿出来放到外面去 |
|
|
26 |
''' |
|
|
27 |
|
|
|
28 |
from keras._tf_keras import keras # CPU - keras > 3.* |
|
|
29 |
from keras._tf_keras.keras.layers import * # CPU - keras > 3.* |
|
|
30 |
from keras._tf_keras.keras.preprocessing.image import ( |
|
|
31 |
ImageDataGenerator, |
|
|
32 |
) # CPU - keras > 3.* |
|
|
33 |
|
|
|
34 |
# from keras.layers import * # GPU - keras > 2.* |
|
|
35 |
# from keras.callbacks import ModelCheckpoint # GPU - keras > 2.* |
|
|
36 |
# from keras.preprocessing.image import ImageDataGenerator # GPU - keras > 2.* |
|
|
37 |
|
|
|
38 |
from keras import Model |
|
|
39 |
from keras import backend as K |
|
|
40 |
|
|
|
41 |
import os |
|
|
42 |
import numpy as np |
|
|
43 |
# from data_preparation import draw_image |
|
|
44 |
import matplotlib.pyplot as plt |
|
|
45 |
import cv2 |
|
|
46 |
|
|
|
47 |
|
|
|
48 |
# img & mask |
|
|
49 |
# 3. data augmentation: (import tensorflow.keras.preprocessing.Image) |
|
|
50 |
# [1] define an image_generator -> ImageDataGenerator() |
|
|
51 |
# [2] image data augmentation -> flow_from_directory() |
|
|
52 |
# [3] image normalization |
|
|
53 |
# 问题: |
|
|
54 |
# [1] 先进行.nii -> png/json/txt, 后进一步keras数据增强 |
|
|
55 |
# 有个问题: images/mask增强后随之的json/txt是否也要发生改变 ---> ??? |
|
|
56 |
# [2] tensorflow和torch一起用 ---> 可以 |
|
|
57 |
# model -> Yolo 使用的是pytorch |
|
|
58 |
# data augmentation 使用的是 tensorflow->keras |
|
|
59 |
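# [3] Observed: unless fit() is called on the generator (feature-wise centering), this warning appears:
# F:\AI_Outils\Anaconda\1\envs\opencv_CPU\Lib\site-packages\keras\src\legacy\preprocessing\image.py:1263: UserWarning: This ImageDataGenerator specifies `featurewise_center`, but it hasn't been fit on any training data. Fit it first by calling `.fit(numpy_data)`.
#     Root cause: calling `ImageDataGenerator(generator_args)` with the dict as a positional
#     argument binds it to the first parameter, `featurewise_center` (a non-empty dict is truthy).
#     All augmentation settings in the dict are then ignored. Unpack it instead:
#     `ImageDataGenerator(**generator_args)`. fit() is only needed when feature-wise
#     centering/normalization is actually wanted.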
|
|
61 |
|
|
|
62 |
|
|
|
63 |
# data augmentation for train |
|
|
64 |
def train_generator(dataset_path, type):
    """Yield augmented, normalized (image, mask) batches for U-Net training.

    Grayscale PNGs are read from ``<dataset_path>/<type>/images`` and
    ``<dataset_path>/<type>/masks`` by two ``ImageDataGenerator`` flows. The flows
    share the same augmentation settings and seed, so an image and its mask get
    the same random transform. Every augmented batch is also written as PNG into
    ``<dataset_path>/<type>_generator/images`` and ``.../masks``.

    Args:
        dataset_path: root folder of the dataset (e.g. "./UNETDataset").
        type: split sub-folder name, e.g. "train" or "val".

    Yields:
        (images, masks) batches (batch_size=2, resized to 512x512, one channel).
        ``images`` are min-max normalized over the whole batch by
        ``normalization``. ``masks`` are binarized to {0.0, 1.0} by
        ``binarization`` (pixel / 255 > 0.5 -> 1.0). The masks must be binary
        because the model is trained with ``binary_crossentropy``.

    The Keras directory iterators cycle indefinitely, so this generator does not
    end on its own. The caller (e.g. ``model.fit(steps_per_epoch=...)``) decides
    how many batches to draw.
    """
    data_path = os.path.join(dataset_path, type)
    data_pre_path = os.path.join(dataset_path, f"{type}_generator")
    img_png_path = os.path.join(data_pre_path, "images")
    mask_png_path = os.path.join(data_pre_path, "masks")

    # Output folders for the augmented PNGs written through save_to_dir.
    for path in (data_pre_path, img_png_path, mask_png_path):
        os.makedirs(path, exist_ok=True)

    # 3.1 Augmentation settings shared by the image and mask generators.
    generator_args = dict(
        rotation_range=0.1,  # degrees
        width_shift_range=0.05,
        height_shift_range=0.05,
        shear_range=0.05,
        zoom_range=0.05,
        horizontal_flip=False,
        vertical_flip=False,
    )
    # Unpack with **. Passing the dict positionally binds it to the first
    # parameter (featurewise_center), which drops every setting above and
    # triggers the "hasn't been fit on any training data" UserWarning.
    generator_image = ImageDataGenerator(**generator_args)
    generator_mask = ImageDataGenerator(**generator_args)

    # 3.2 Identical flow settings and seed keep shuffling order and random
    # transforms of images and masks in lock-step.
    flow_args = dict(
        directory=data_path,
        class_mode=None,
        color_mode="grayscale",
        target_size=(512, 512),
        batch_size=2,
        seed=123,
    )
    generation_image = generator_image.flow_from_directory(
        classes=["images"], save_to_dir=img_png_path, **flow_args
    )
    generation_mask = generator_mask.flow_from_directory(
        classes=["masks"], save_to_dir=mask_png_path, **flow_args
    )

    print("2--------------------------")
    # 3.3 Normalize images to [0, 1]. Binarize masks, whose interpolated
    # augmentation leaves values anywhere in 0-255, to valid {0, 1} targets.
    for image, mask in zip(generation_image, generation_mask):
        yield (normalization(image), binarization(mask))
    # Only reached if the underlying flows are ever exhausted.
    print("Further data augmentation was completed successfully.")
|
|
156 |
|
|
|
157 |
|
|
|
158 |
def binarization(data):
    """Threshold 0-255 data into a two-valued {0.0, 1.0} array.

    Each value is first scaled by 1/255. Anything strictly above 0.5 becomes 1.0,
    anything at or below 0.5 becomes 0.0. NaN values satisfy neither comparison
    and are left unchanged.

    Args:
        data: array-like of pixel values, nominally in [0, 255].

    Returns:
        A new array (the input is not modified) of the scaled, thresholded values.
    """
    scaled = data / 255.0
    # Both masks are taken before any assignment; values set to 1.0 could never
    # match the "<= 0.5" mask anyway, so this is order-independent.
    is_foreground = scaled > 0.5
    is_background = scaled <= 0.5
    scaled[is_foreground] = 1.0
    scaled[is_background] = 0.0
    return scaled
|
|
169 |
|
|
|
170 |
|
|
|
171 |
def standardization(data):
    """Standardize data to zero mean and unit standard deviation.

    Standardization rescales every value the same way, putting features with
    very different ranges or units on a comparable scale. It does not change
    the shape of the distribution.

    Processing: x' = (x - mean) / std

    Args:
        data: numeric array-like.

    Returns:
        The standardized array. If the input is constant (std == 0), the
        centred data (all zeros) is returned instead of dividing 0 by 0, which
        would produce NaN everywhere.
    """
    mean = np.mean(data)
    std = np.std(data)
    centered = data - mean
    if std == 0:
        return centered
    return centered / std
|
|
182 |
|
|
|
183 |
|
|
|
184 |
def normalization(data):
    """Min-max scale data into the range [0, 1].

    Processing: x' = (x - min) / (max - min), with min/max computed by
    ``np.nanmin``/``np.nanmax`` so NaNs are ignored when finding the range.
    Scaling inputs to a fixed range is common for medical images. It tends to
    speed up convergence and stabilize training.

    Args:
        data: numeric array-like (e.g. a batch of images).

    Returns:
        The scaled array. If all non-NaN values are equal (max == min, e.g. an
        all-black slice), every finite value maps to 0.0 instead of 0/0 = NaN,
        which would otherwise poison training.
    """
    data_min = np.nanmin(data)
    data_max = np.nanmax(data)
    data_range = data_max - data_min
    shifted = data - data_min
    if data_range == 0:
        # Multiplying (rather than np.zeros) keeps NaN positions and the same
        # float dtype the division would have produced.
        return shifted * 0.0
    return shifted / data_range
|
|
196 |
|
|
|
197 |
|
|
|
198 |
def u_net(input_size=(512, 512, 1), path=None):
    """Build and compile a 2-D U-Net for binary segmentation.

    Encoder: four stages with 64, 128, 256 and 512 filters. Each stage is two
    3x3 ReLU convolutions ("same" padding, He-normal init) followed by 2x2 max
    pooling.
    Bottleneck: two 3x3 ReLU convolutions with 1024 filters.
    Decoder: four stages with 512, 256, 128 and 64 filters. Each stage does a 2x
    UpSampling2D, concatenates with the matching encoder output along axis 3,
    then applies two 3x3 ReLU convolutions.
    Head: a 1x1 convolution with sigmoid activation producing one channel.

    The model is compiled with the Adam optimizer, binary cross-entropy loss
    and an accuracy metric.

    Args:
        input_size: shape of a single input sample, passed to ``Input``.
        path: accepted for the caller's convenience but not used here.

    Returns:
        The compiled Keras ``Model``.
    """

    def double_conv(tensor, filters):
        # Two stacked 3x3 ReLU convolutions with "same" padding.
        for _ in range(2):
            tensor = Conv2D(
                filters,
                3,
                activation="relu",
                padding="same",
                kernel_initializer="he_normal",
            )(tensor)
        return tensor

    inputs = Input(input_size)

    # Contracting path: remember each stage's output for the skip connections.
    skip_tensors = []
    features = inputs
    for filters in (64, 128, 256, 512):
        features = double_conv(features, filters)
        skip_tensors.append(features)
        features = MaxPool2D(pool_size=(2, 2))(features)

    # Bottleneck.
    features = double_conv(features, 1024)

    # Expanding path: upsample, then fuse with the encoder map of equal size.
    for filters, skip in zip((512, 256, 128, 64), reversed(skip_tensors)):
        features = UpSampling2D(size=(2, 2))(features)
        features = double_conv(concatenate([features, skip], axis=3), filters)

    outputs = Conv2D(1, 1, activation="sigmoid")(features)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    return model
|
|
257 |
|
|
|
258 |
|
|
|
259 |
class ShowMask(keras.callbacks.Callback):
    """Save an (input, ground truth, prediction) comparison figure after each epoch.

    At the end of every epoch the next batch is drawn from the configured
    generator. The model being trained (``self.model``) predicts a mask for the
    first sample, and the input image, true mask and predicted mask are saved
    side by side as ``compare_<epoch>.png``.
    """

    def __init__(self, generator=None, save_dir="."):
        """Create the callback.

        Args:
            generator: iterable of ``(images, masks)`` batches. ``None`` (the
                default) means the module-level ``gene`` is used, as before.
            save_dir: folder where the comparison PNGs are written
                (default: the current working directory).
        """
        super().__init__()
        self.generator = generator
        self.save_dir = save_dir

    def on_epoch_end(self, epoch, logs=None):
        print()
        # Fall back to the module-level training generator for backward compatibility.
        source = gene if self.generator is None else self.generator
        batch = next(iter(source), None)
        if batch is None:  # empty generator: nothing to show
            return
        img, mask = batch

        # Use the model attached by fit() instead of a global variable. img[:1]
        # keeps the batch axis, so any input size works, not only 512x512x1.
        prediction = self.model.predict(img[:1])[0]
        compare_list = [img[0], mask[0], prediction]

        fig = plt.figure()
        for position, panel in enumerate(compare_list, start=1):
            plt.subplot(1, 3, position)
            plt.imshow(panel, cmap="gray")
            plt.axis(False)
        # One file per epoch. The previous counter was reset each epoch, so
        # compare_0.png was overwritten every time.
        fig.savefig(os.path.join(self.save_dir, f"compare_{epoch}.png"))
        # Close the figure so images do not pile up across epochs.
        plt.close(fig)
|
|
277 |
|
|
|
278 |
|
|
|
279 |
# --- Training -----------------------------------------------------------------

# Training data: augmented (image, mask) batches read from ./UNETDataset/train,
# consumed directly by model.fit below.
UNETDataset_path = "./UNETDataset"
gene = train_generator(UNETDataset_path, "train")

# Hyper-parameters; both also appear in the checkpoint file name.
steps_per_epoch = 50
epochs = 100
models_path = "./models/"
model_name = f"u_net-512-512-1-pneumonia_{epochs}_{steps_per_epoch}.keras"
model_path = os.path.join(models_path, model_name)

# save_best_only=False: the checkpoint file is rewritten after every epoch.
model_ckpt = keras.callbacks.ModelCheckpoint(
    model_path,
    save_best_only=False,
    verbose=1,
)

# Reset Keras' global state, then build a fresh model.
K.clear_session()
model = u_net(path=model_path)
# print(model.summary())

# To also save a comparison image each epoch, use callbacks=[model_ckpt, ShowMask()].
model.fit(
    gene,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    callbacks=[model_ckpt],
)
|
|
306 |
|
|
|
307 |
|
|
|
308 |
|
|
|
309 |
# Troubleshooting ideas (GPU training failures):
# 1. Small dataset (Dataset_mini): works on CPU, fails on GPU
# 2. cuDNN v8: still fails on GPU
# 3. Reduce input_size: 512 -> 256
# 4. Shrink the U-Net architecture
|
|
314 |
|
|
|
315 |
# --- Evaluation ---------------------------------------------------------------
# Output folders for evaluation: <dataset>/val_generator and its "compare"
# sub-folder, where comparison figures are saved.
data_val_generator_path = os.path.join(UNETDataset_path, "val_generator")
compare_path = os.path.join(data_val_generator_path, "compare")

PATH = {data_val_generator_path, compare_path}
# Parent sorts before child; exist_ok makes reruns harmless.
for path in sorted(PATH):
    os.makedirs(path, exist_ok=True)
|
|
321 |
|
|
|
322 |
# --- Disabled evaluation code (kept for reference) ----------------------------
# Everything below is either commented out or wrapped in bare ''' string
# literals, so none of it executes.
#
# Snippet 1 predicts with `model_test` (loaded by the commented line right
# below). It draws batches from train_generator(UNETDataset_path, "val"), i.e.
# the *augmented* validation flow, which also writes augmented PNGs into
# val_generator/. For each batch it thresholds the true and predicted masks,
# draws their contours on the first image and saves a 1x4 figure to
# compare_path.
# NOTE(review): `cv2.threshold(mask[0], 127, ...)` assumes masks in 0-255;
# confirm against the value range train_generator actually yields.
# NOTE(review): the `for ... in gene` loop has no break. If the generator is
# endless (Keras directory iterators cycle indefinitely) it never stops; confirm.
# NOTE(review): figures are never closed inside the loop.
# model_test= keras.models.load_model(model_path)
'''
gene = train_generator(UNETDataset_path, "val")
idx = 0
for img, mask in gene:
    predict_mask = model_test.predict(img)[0]
    # predict_mask_np = (predict_mask * 255).numpy()

    _, real_mask = cv2.threshold(mask[0], 127, 255, 0)
    real_mask = (real_mask).astype('uint8')
    real_contours, _ = cv2.findContours(real_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    real_overlap_img = cv2.drawContours(img[0].copy(), real_contours, -1, (0, 255, 0), 2)

    _, pred_mask = cv2.threshold((predict_mask * 255).astype("uint8"), 127, 255, 0)
    pred_mask = (pred_mask).astype('uint8')
    pred_contours, _ = cv2.findContours(pred_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    pred_overlap_img = cv2.drawContours(img[0].copy(), pred_contours, -1, (255, 0, 0), 2)

    compare_list = [img[0], pred_mask, real_overlap_img, pred_overlap_img]

    for i in range(0, len(compare_list)):
        plt.subplot(1, 4, i+1)
        plt.imshow(compare_list[i], cmap="gray")
        plt.axis(False)
    # plt.show()
    save_path = os.path.join(compare_path, f"compare_{idx}.png")
    plt.savefig(save_path)

    idx = idx + 1
'''

# Snippet 2: for every file in val_generator/images, load it together with the
# same-named file from val_generator/masks. It draws the thresholded
# ground-truth mask contours ((0, 255, 0), thickness 2) on a copy of the image
# and saves a 1x2 figure (image, overlay) to compare_path. The prediction line
# is commented out, so no model is used here.
'''
# test save compare_png

images_path = os.path.join(data_val_generator_path, "images")
masks_path = os.path.join(data_val_generator_path, "masks")

idx = 0
for png_name in os.listdir(images_path):
    # predict_mask = model_test.predict(img)[0]

    img = cv2.imread(os.path.join(images_path, png_name))
    mask = cv2.imread(os.path.join(masks_path, png_name), cv2.IMREAD_GRAYSCALE)

    _, real_mask = cv2.threshold(mask, 127, 255, 0)
    real_mask = (real_mask).astype("uint8")
    real_contours, _ = cv2.findContours(
        real_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
    )
    real_overlap_img = cv2.drawContours(
        img.copy(), real_contours, -1, (0, 255, 0), 2
    )

    compare_list = [img, real_overlap_img]

    for i in range(0, len(compare_list)):
        plt.subplot(1, 2, i + 1)
        plt.imshow(compare_list[i], cmap="gray")
        plt.axis(False)
    # plt.show()
    save_path = os.path.join(compare_path, f"compare_{idx}.png")
    plt.savefig(save_path)

    idx = idx + 1
'''