b/drunet/segment.py

import os
import time
import argparse
import pathlib

import tqdm
import cv2 as cv
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# Custom package
import data
import loss
import utils
import module
import performance
from model import dr_unet

config = ConfigProto()
config.gpu_options.allow_growth = True
# Assumed session creation: the allow_growth option only takes effect once a session is
# built with this config (InteractiveSession is imported above for this purpose).
session = InteractiveSession(config=config)

# 1. Parameter settings
parser = argparse.ArgumentParser(description="Segment Use Args")
parser.add_argument('--model-name', default='DR_UNet', type=str)
parser.add_argument('--dims', default=32, type=int)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--batch-size', default=16, type=int)
parser.add_argument('--lr', default=2e-4, type=float)

# Training, testing, and validation parameter settings
parser.add_argument('--height', default=256, type=int)
parser.add_argument('--width', default=256, type=int)
parser.add_argument('--channel', default=1, type=int)
parser.add_argument('--pred-height', default=4 * 256, type=int)
parser.add_argument('--pred-width', default=4 * 256, type=int)
parser.add_argument('--total-samples', default=5000, type=int)
parser.add_argument('--invalid-samples', default=1000, type=int)
# Use a store_true flag: argparse's type=bool would treat any non-empty string as True
parser.add_argument('--regularize', default=False, action='store_true')
parser.add_argument('--record-dir', default=r'', type=str, help='directory where the tfrecord files are stored')
parser.add_argument('--train-record-name', type=str, default=r'train_data', help='file name of the training tfrecord')
parser.add_argument('--test-image-dir', default=r'', type=str, help='directory of the test images')
parser.add_argument('--invalid-record-name', type=str, default=r'test_data', help='file name of the validation tfrecord')
parser.add_argument('--gt-mask-dir', default=r'', type=str, help='ground-truth mask directory of the validation set')
parser.add_argument('--invalid-volume-dir', default=r'', type=str, help='directory of CT series used for bleeding-volume estimation')
args = parser.parse_args()
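
# Example invocation (illustrative only; the directory paths below are placeholders, not taken from the repo):
#   python segment.py --model-name DR_UNet --epochs 50 --batch-size 16 \
#       --record-dir ./records --train-record-name train_data --invalid-record-name test_data \
#       --test-image-dir ./test_images --gt-mask-dir ./gt_masks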


class Segmentation:
    def __init__(self, params):
        self.params = params
        self.input_shape = [params.height, params.width, params.channel]
        self.mask_shape = [params.height, params.width, 1]
        self.model_name = params.model_name
        self.crop_height = params.pred_height
        self.crop_width = params.pred_width
        self.regularize = params.regularize

        # Build the segmentation model
        self.seg_model = dr_unet.dr_unet(input_shape=self.input_shape, dims=params.dims)
        self.seg_model.summary()

        # Optimizer
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=params.lr)

        # After every epoch, predict the validation images to monitor segmentation performance
        self.save_dir = str(params.model_name).upper()
        self.weight_save_dir = os.path.join(self.save_dir, 'checkpoint')
        self.pred_invalid_save_dir = os.path.join(self.save_dir, 'invalid_pred')
        self.invalid_crop_save_dir = os.path.join(self.save_dir, 'invalid_pred_crop')
        self.pred_test_save_dir = os.path.join(self.save_dir, 'test_pred')
        utils.check_file([
            self.save_dir, self.weight_save_dir, self.pred_invalid_save_dir,
            self.pred_test_save_dir, self.invalid_crop_save_dir]
        )

        # Checkpoint for saving the model parameters
        train_steps = tf.Variable(0, dtype=tf.int32)
        self.save_ckpt = tf.train.Checkpoint(
            train_steps=train_steps, seg_model=self.seg_model, model_optimizer=self.optimizer)
        self.save_manager = tf.train.CheckpointManager(
            self.save_ckpt, directory=self.weight_save_dir, max_to_keep=1)

        # Set the loss function
        self.loss_fun = loss.bce_dice_loss
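        # Assumed definition (loss.py is not shown here): bce_dice_loss is typically
        # binary cross-entropy plus a soft Dice term, e.g.
        #   L = BCE(y, p) + 1 - 2 * sum(y * p) / (sum(y) + sum(p))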

    def load_model(self):
        if self.save_manager.latest_checkpoint:
            self.save_ckpt.restore(self.save_manager.latest_checkpoint)
            print('Loading model: {}'.format(self.save_manager.latest_checkpoint))
        else:
            print('No checkpoint found; training the model from scratch!')
        return

    @tf.function
    def train_step(self, inputs, target):
        with tf.GradientTape() as tape:
            # Forward pass in training mode (dropout active, batch-norm updates its statistics)
            pred_mask = self.seg_model(inputs, training=True)
            # Segmentation loss between the ground-truth mask and the prediction
            seg_loss = self.loss_fun(target, pred_mask)
            if self.regularize:
                seg_loss = tf.reduce_sum(seg_loss) + tf.reduce_sum(self.seg_model.losses)
        gradient = tape.gradient(seg_loss, self.seg_model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradient, self.seg_model.trainable_variables))
        return tf.reduce_mean(seg_loss)

    @tf.function
    def inference(self, inputs):
        # Forward pass in inference mode (dropout disabled, batch-norm uses moving statistics)
        pred = self.seg_model(inputs, training=False)
        return pred

    @staticmethod
    def calculate_volume_by_mask(mask_dir, save_dir, model_name, dpi=96, thickness=0.45):
        all_mask_file_paths = utils.list_file(mask_dir)

        pd_record = pd.DataFrame(columns=['file_name', 'Volume'])
        for file_dir in tqdm.tqdm(all_mask_file_paths):
            file_name = pathlib.Path(file_dir).stem

            each_blood_volume = module.calculate_volume(file_dir, thickness=thickness, dpi=dpi)
            pd_record = pd_record.append({'file_name': file_name, 'Volume': each_blood_volume}, ignore_index=True)
        pd_record.to_csv(
            os.path.join(save_dir, '{}_{}.csv'.format(model_name, file_name)), index=True, header=True)
        return

    def predict_blood_volume(self, input_dir, save_dir, calc_nums=-1, dpi=96, thickness=0.45):
        """
        :param input_dir: directory of the bleeding-volume test images; it contains one
            sub-folder per patient, each holding that patient's CT slices
        :param save_dir: directory where the predicted segmentation images are saved
        :param calc_nums: number of images to predict from each folder (-1 for all)
        :param dpi: image resolution in dots per inch
        :param thickness: CT slice thickness
        """
        # Load the model weights
        self.load_model()
        save_pred_images_dir = os.path.join(save_dir, 'pred_images')
        save_pred_csv_dir = os.path.join(save_dir, 'pred_csv')
        utils.check_file([save_pred_images_dir, save_pred_csv_dir])
        all_file_dirs = utils.list_file(input_dir)

        cost_time_list = []
        total_images = 0
        for file_dir in tqdm.tqdm(all_file_dirs):
            file_name = pathlib.Path(file_dir).stem

            image_names, ori_images, normed_images = data.get_test_data(
                test_data_path=file_dir, image_shape=self.input_shape, image_nums=calc_nums)
            total_images += len(image_names)

            start_time = time.time()
            pred_mask = self.inference(normed_images)
            end_time = time.time()
            print('FPS: {}'.format(pred_mask.shape[0] / (end_time - start_time)),
                  pred_mask.shape, end_time - start_time)

            denorm_pred_mask = module.reverse_pred_image(pred_mask.numpy())  # (image_nums, 256, 256, 1)
            if denorm_pred_mask.ndim == 2 and self.input_shape[-1] == 1:
                denorm_pred_mask = np.expand_dims(denorm_pred_mask, 0)

            drawed_images = []
            blood_areas = []
            pd_record = pd.DataFrame(columns=['image_name', 'Square Centimeter', 'Volume'])
            for index in range(denorm_pred_mask.shape[0]):
                drawed_image, blood_area = module.draw_contours(ori_images[index], denorm_pred_mask[index], dpi=dpi)
                drawed_images.append(drawed_image)
                blood_areas.append(blood_area)
                pd_record = pd_record.append({'image_name': image_names[index], 'Square Centimeter': blood_area},
                                             ignore_index=True)

            one_pred_save_dir = os.path.join(save_pred_images_dir, file_name)
            module.save_invalid_data(ori_images, drawed_images, denorm_pred_mask,
                                     image_names, reshape=True, save_dir=one_pred_save_dir)

            # Estimate the bleeding volume from the per-slice hematoma areas
            blood_volume = module.count_volume(blood_areas, thickness=thickness)
            pd_record = pd_record.append({'Volume': blood_volume}, ignore_index=True)
            pd_record.to_csv(os.path.join(save_pred_csv_dir, '{}_{}.csv'.format(self.model_name, file_name)),
                             index=True, header=True)
            cost_time_list.append(end_time - start_time)
            print('FileName: {} time: {}'.format(file_name, end_time - start_time))
        print('total_time: {:.2f}, mean_time: {:.2f}, total_images: {}'.format(
            np.sum(cost_time_list), np.mean(cost_time_list), total_images))
        return

    def predict_and_save(self, input_dir, save_dir, calc_nums=-1, batch_size=16):
        """ Predict bleeding images and save the results
        :param input_dir: directory containing the images to be segmented
        :param save_dir: directory where the segmentation results predicted by the model are saved
        :param calc_nums: how many images from the directory take part in the calculation (-1 for all)
        :param batch_size: how many images to run through the model at a time
        :return:
        """
        mask_save_dir = os.path.join(save_dir, 'pred_mask')
        drawed_save_dir = os.path.join(save_dir, 'drawed_image')
        utils.check_file([mask_save_dir, drawed_save_dir])
        self.load_model()

        test_image_list = utils.list_file(input_dir)
        # Process the image list in chunks of 128 files to bound memory use
        for index in range(len(test_image_list) // 128 + 1):
            input_test_list = test_image_list[index * 128:(index + 1) * 128]
            if not input_test_list:
                continue

            image_names, ori_images, normed_images = data.get_test_data(
                test_data_path=input_test_list, image_shape=self.input_shape, image_nums=-1,
            )
            if calc_nums != -1:
                ori_images = ori_images[:calc_nums]
                normed_images = normed_images[:calc_nums]
                image_names = image_names[:calc_nums]

            # Number of mini-batches needed to cover this chunk
            inference_times = (normed_images.shape[0] + batch_size - 1) // batch_size
            for inference_time in range(inference_times):
                this_normed_images = normed_images[
                    inference_time * batch_size:(inference_time + 1) * batch_size, ...]
                this_ori_images = ori_images[
                    inference_time * batch_size:(inference_time + 1) * batch_size, ...]
                this_image_names = image_names[
                    inference_time * batch_size:(inference_time + 1) * batch_size]

                this_pred_mask = self.inference(this_normed_images)
                this_denorm_pred_mask = module.reverse_pred_image(this_pred_mask.numpy())
                if ori_images.shape[0] == 1:
                    this_denorm_pred_mask = np.expand_dims(this_denorm_pred_mask, 0)

                for i in range(this_denorm_pred_mask.shape[0]):
                    bin_denorm_pred_mask = this_denorm_pred_mask[i]
                    this_drawed_image, this_blood_area = module.draw_contours(
                        this_ori_images[i], bin_denorm_pred_mask, dpi=96
                    )
                    # Save the binary mask and the contour-overlaid image under the original file name
                    cv.imwrite(os.path.join(
                        mask_save_dir, '{}'.format(this_image_names[i])), bin_denorm_pred_mask
                    )
                    cv.imwrite(os.path.join(
                        drawed_save_dir, '{}'.format(this_image_names[i])), this_drawed_image
                    )
        return

    def train(self, start_epoch=1):
        # Get the training dataset
        train_data = data.get_tfrecord_data(
            self.params.record_dir, self.params.train_record_name,
            self.input_shape, batch_size=self.params.batch_size)
        self.load_model()

        pd_record = pd.DataFrame(columns=['Epoch', 'Iteration', 'Loss', 'Time'])
        data_name, original_test_image, norm_test_image = data.get_test_data(
            test_data_path=self.params.test_image_dir, image_shape=self.input_shape, image_nums=-1
        )

        start_time = time.time()
        best_dice = 0.0
        # Epochs are counted from 1, so run through params.epochs inclusive
        for epoch in range(start_epoch, self.params.epochs + 1):
            for train_image, gt_mask in tqdm.tqdm(
                    train_data, total=self.params.total_samples // self.params.batch_size):
                self.save_ckpt.train_steps.assign_add(1)
                iteration = self.save_ckpt.train_steps.numpy()

                # Training step
                train_loss = self.train_step(train_image, gt_mask)
                if iteration % 100 == 0:
                    print('Epoch: {}, Iteration: {}, Loss: {:.2f}, Time: {:.2f} s'.format(
                        epoch, iteration, train_loss, time.time() - start_time))

                    # Test step: predict the fixed test images and log the training record
                    test_pred = self.inference(norm_test_image)
                    module.save_images(
                        image_shape=self.mask_shape, pred=test_pred,
                        save_path=self.pred_test_save_dir, index=iteration, split=False
                    )
                    pd_record = pd_record.append({
                        'Epoch': epoch, 'Iteration': iteration, 'Loss': train_loss.numpy(),
                        'Time': time.time() - start_time}, ignore_index=True
                    )
                    pd_record.to_csv(os.path.join(
                        self.save_dir, '{}_record.csv'.format(self.params.model_name)), index=True, header=True
                    )

            # Validate at the end of every epoch and keep only the best checkpoint
            m_dice = self.invalid(epoch)
            if m_dice > best_dice:
                best_dice = m_dice
                print('Best Dice:{}'.format(best_dice))
                self.save_manager.save(checkpoint_number=epoch)
        return

    def invalid(self, epoch):
        invalid_data = data.get_tfrecord_data(
            self.params.record_dir, self.params.invalid_record_name,
            self.input_shape, batch_size=self.params.batch_size, shuffle=False)

        epoch_pred_save_dir = None
        for index, (invalid_image, invalid_mask) in enumerate(
                tqdm.tqdm(invalid_data, total=self.params.invalid_samples // self.params.batch_size + 1)):
            invalid_pred = self.inference(invalid_image)
            epoch_pred_save_dir = os.path.join(self.pred_invalid_save_dir, f'epoch_{epoch}')
            module.save_images(
                image_shape=self.mask_shape, pred=invalid_pred,
                save_path=epoch_pred_save_dir, index=f'{index}', split=False
            )

        # Test model performance
        epoch_cropped_save_dir = os.path.join(
            self.invalid_crop_save_dir, f'epoch_{epoch}'
        )
        utils.crop_image(epoch_pred_save_dir, epoch_cropped_save_dir,
                         self.crop_width, self.crop_height,
                         self.input_shape[0], self.input_shape[1]
                         )
        m_dice, m_iou, m_precision, m_recall = performance.save_performace_to_csv(
            pred_dir=epoch_cropped_save_dir, gt_dir=self.params.gt_mask_dir,
            img_resize=(self.params.height, self.params.width),
            csv_save_name=f'{self.model_name}_epoch_{epoch}',
            csv_save_path=epoch_cropped_save_dir
        )
        return m_dice
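

# Minimal driver sketch (assumption: the original entry point is not shown in this file).
# It wires the parsed command-line args into the Segmentation class; pick the method for the task.
if __name__ == '__main__':
    segmentation = Segmentation(args)
    # Train on the tfrecord data configured above
    segmentation.train(start_epoch=1)
    # Or run prediction on a folder of test images, e.g.:
    # segmentation.predict_and_save(args.test_image_dir, segmentation.save_dir,
    #                               calc_nums=-1, batch_size=args.batch_size)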