Diff of /drunet/segment.py [000000] .. [2824d6]

import os
import time
import argparse
import pathlib

import tqdm
import cv2 as cv
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Custom packages
import data
import loss
import utils
import module
import performance
from model import dr_unet

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

# 1. Parameter settings
parser = argparse.ArgumentParser(description='DR_UNet segmentation arguments')
parser.add_argument('--model-name', default='DR_UNet', type=str)
parser.add_argument('--dims', default=32, type=int)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--batch-size', default=16, type=int)
parser.add_argument('--lr', default=2e-4, type=float)

# Training, testing, and validation data settings
parser.add_argument('--height', default=256, type=int)
parser.add_argument('--width', default=256, type=int)
parser.add_argument('--channel', default=1, type=int)
parser.add_argument('--pred-height', default=4 * 256, type=int)
parser.add_argument('--pred-width', default=4 * 256, type=int)
parser.add_argument('--total-samples', default=5000, type=int)
parser.add_argument('--invalid-samples', default=1000, type=int)
# NOTE: argparse's type=bool treats any non-empty string as True, so expose a flag instead
parser.add_argument('--regularize', action='store_true')
parser.add_argument('--record-dir', default=r'', type=str, help='directory where the tfrecord files are saved')
parser.add_argument('--train-record-name', type=str, default=r'train_data', help='name of the training record')
parser.add_argument('--test-image-dir', default=r'', type=str, help='path of the test image directory')
parser.add_argument('--invalid-record-name', type=str, default=r'test_data', help='name of the validation record')
parser.add_argument('--gt-mask-dir', default=r'', type=str, help='ground-truth mask directory of the validation set')
parser.add_argument('--invalid-volume-dir', default=r'', type=str, help='directory of CT series for bleeding-volume estimation')
args = parser.parse_args()

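# Example invocation (the paths below are placeholders, not from the original repo):
#   python segment.py --record-dir ./records --train-record-name train_data \
#       --invalid-record-name test_data --test-image-dir ./test_images \
#       --gt-mask-dir ./gt_masks --batch-size 16 --lr 2e-4
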
class Segmentation:
    def __init__(self, params):
        self.params = params
        self.input_shape = [params.height, params.width, params.channel]
        self.mask_shape = [params.height, params.width, 1]
        self.model_name = params.model_name
        self.crop_height = params.pred_height
        self.crop_width = params.pred_width
        self.regularize = params.regularize

        # Build the segmentation model
        self.seg_model = dr_unet.dr_unet(input_shape=self.input_shape, dims=params.dims)
        self.seg_model.summary()

        # Optimizer
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=params.lr)

        # After every epoch, the validation images are predicted to track
        # segmentation performance; results go to the directories below
        # (utils.check_file presumably creates them if they do not exist)
        self.save_dir = str(params.model_name).upper()
        self.weight_save_dir = os.path.join(self.save_dir, 'checkpoint')
        self.pred_invalid_save_dir = os.path.join(self.save_dir, 'invalid_pred')
        self.invalid_crop_save_dir = os.path.join(self.save_dir, 'invalid_pred_crop')
        self.pred_test_save_dir = os.path.join(self.save_dir, 'test_pred')
        utils.check_file([
            self.save_dir, self.weight_save_dir, self.pred_invalid_save_dir,
            self.pred_test_save_dir, self.invalid_crop_save_dir
        ])

        # Checkpointing of model parameters
        train_steps = tf.Variable(0, dtype=tf.int32)
        self.save_ckpt = tf.train.Checkpoint(
            train_steps=train_steps, seg_model=self.seg_model, model_optimizer=self.optimizer)
        self.save_manager = tf.train.CheckpointManager(
            self.save_ckpt, directory=self.weight_save_dir, max_to_keep=1)

        # Loss function
        self.loss_fun = loss.bce_dice_loss

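    # For reference: bce_dice_loss lives in the custom `loss` module, which is not
    # part of this diff. Judging by its name, it is presumably the usual sum of
    # binary cross-entropy and Dice loss, along the lines of:
    #   dice = 1 - (2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)
    #   bce_dice = binary_crossentropy(y_true, y_pred) + dice
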
    def load_model(self):
        # Restore the latest checkpoint if one exists, otherwise train from scratch
        if self.save_manager.latest_checkpoint:
            self.save_ckpt.restore(self.save_manager.latest_checkpoint)
            print('Loading model: {}'.format(self.save_manager.latest_checkpoint))
        else:
            print('No checkpoint found, training from scratch!')
        return

    @tf.function
    def train_step(self, inputs, target):
        with tf.GradientTape() as tape:
            # Run the model in training mode (enables dropout / batch-norm updates)
            pred_mask = self.seg_model(inputs, training=True)
            # Use a local name distinct from the imported `loss` module
            loss_value = self.loss_fun(target, pred_mask)
            if self.regularize:
                loss_value = tf.reduce_sum(loss_value) + tf.reduce_sum(self.seg_model.losses)
        gradient = tape.gradient(loss_value, self.seg_model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradient, self.seg_model.trainable_variables))
        return tf.reduce_mean(loss_value)

    @tf.function
    def inference(self, inputs):
        # Run the model in inference mode (disables dropout, uses the moving
        # batch-norm statistics)
        pred = self.seg_model(inputs, training=False)
        return pred

    @staticmethod
    def calculate_volume_by_mask(mask_dir, save_dir, model_name, dpi=96, thickness=0.45):
        all_mask_file_paths = utils.list_file(mask_dir)

        pd_record = pd.DataFrame(columns=['file_name', 'Volume'])
        for file_dir in tqdm.tqdm(all_mask_file_paths):
            file_name = pathlib.Path(file_dir).stem

            each_blood_volume = module.calculate_volume(file_dir, thickness=thickness, dpi=dpi)
            # pandas.DataFrame.append was removed in pandas 2.0; build rows with pd.concat
            pd_record = pd.concat(
                [pd_record, pd.DataFrame([{'file_name': file_name, 'Volume': each_blood_volume}])],
                ignore_index=True)
            pd_record.to_csv(
                os.path.join(save_dir, '{}_{}.csv'.format(model_name, file_name)), index=True, header=True)
        return
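
    # module.calculate_volume is defined in the custom `module` package (not in
    # this diff). Given its dpi and thickness parameters, it presumably converts
    # mask pixels to physical area and multiplies by the slice thickness, roughly:
    #   cm_per_pixel = 2.54 / dpi                              # 1 inch = 2.54 cm
    #   area_cm2 = np.count_nonzero(mask) * cm_per_pixel ** 2
    #   volume_ml = area_cm2 * thickness                       # thickness presumably in cm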

    def predict_blood_volume(self, input_dir, save_dir, calc_nums=-1, dpi=96, thickness=0.45):
        """Predict hematoma masks per patient and estimate the bleeding volume.
        :param input_dir: directory of bleeding-volume test images; it contains one
                          sub-folder per patient, each holding that patient's CT slices
        :param save_dir: directory where the predicted segmentation images are saved
        :param calc_nums: how many images of each folder to predict (-1 means all)
        :param dpi: image resolution, used to convert pixel counts to cm^2
        :param thickness: slice thickness
        """
        # Load the model weights
        self.load_model()
        save_pred_images_dir = os.path.join(save_dir, 'pred_images')
        save_pred_csv_dir = os.path.join(save_dir, 'pred_csv')
        utils.check_file([save_pred_images_dir, save_pred_csv_dir])
        all_file_dirs = utils.list_file(input_dir)

        cost_time_list = []
        total_images = 0
        for file_dir in tqdm.tqdm(all_file_dirs):
            file_name = pathlib.Path(file_dir).stem

            image_names, ori_images, normed_images = data.get_test_data(
                test_data_path=file_dir, image_shape=self.input_shape, image_nums=calc_nums)
            total_images += len(image_names)

            start_time = time.time()
            pred_mask = self.inference(normed_images)
            end_time = time.time()
            print('FPS: {:.2f}, shape: {}, time: {:.2f} s'.format(
                pred_mask.shape[0] / (end_time - start_time), pred_mask.shape, end_time - start_time))

            denorm_pred_mask = module.reverse_pred_image(pred_mask.numpy())  # (image_nums, 256, 256, 1)
            if denorm_pred_mask.ndim == 2 and self.input_shape[-1] == 1:
                denorm_pred_mask = np.expand_dims(denorm_pred_mask, 0)

            drawed_images = []
            blood_areas = []
            pd_record = pd.DataFrame(columns=['image_name', 'Square Centimeter', 'Volume'])
            for index in range(denorm_pred_mask.shape[0]):
                # Draw the hematoma contour on the original slice and measure its area
                drawed_image, blood_area = module.draw_contours(ori_images[index], denorm_pred_mask[index], dpi=dpi)
                drawed_images.append(drawed_image)
                blood_areas.append(blood_area)
                pd_record = pd.concat(
                    [pd_record, pd.DataFrame([{'image_name': image_names[index], 'Square Centimeter': blood_area}])],
                    ignore_index=True)

            one_pred_save_dir = os.path.join(save_pred_images_dir, file_name)
            module.save_invalid_data(ori_images, drawed_images, denorm_pred_mask,
                                     image_names, reshape=True, save_dir=one_pred_save_dir)

            # Estimate the bleeding volume from the hematoma area of each slice
            blood_volume = module.count_volume(blood_areas, thickness=thickness)
            pd_record = pd.concat([pd_record, pd.DataFrame([{'Volume': blood_volume}])], ignore_index=True)
            pd_record.to_csv(os.path.join(save_pred_csv_dir, '{}_{}.csv'.format(self.model_name, file_name)),
                             index=True, header=True)
            cost_time_list.append(end_time - start_time)
            print('FileName: {}, time: {:.2f} s'.format(file_name, end_time - start_time))
        print('total_time: {:.2f}, mean_time: {:.2f}, total_images: {}'.format(
            np.sum(cost_time_list), np.mean(cost_time_list), total_images))
        return

    def predict_and_save(self, input_dir, save_dir, calc_nums=-1, batch_size=16):
        """Predict bleeding images and save the results.
        :param input_dir: folder containing the images to be tested
        :param save_dir: directory where the predicted segmentation images are saved
        :param calc_nums: how many images of the directory take part in the calculation (-1 means all)
        :param batch_size: how many images to test at a time
        :return:
        """
        mask_save_dir = os.path.join(save_dir, 'pred_mask')
        drawed_save_dir = os.path.join(save_dir, 'drawed_image')
        utils.check_file([mask_save_dir, drawed_save_dir])
        self.load_model()

        # Process the image list in chunks of 128 to bound memory usage
        test_image_list = utils.list_file(input_dir)
        for index in range(len(test_image_list) // 128 + 1):
            input_test_list = test_image_list[index * 128:(index + 1) * 128]

            image_names, ori_images, normed_images = data.get_test_data(
                test_data_path=input_test_list, image_shape=self.input_shape, image_nums=-1,
            )
            if calc_nums != -1:
                ori_images = ori_images[:calc_nums]
                normed_images = normed_images[:calc_nums]
                image_names = image_names[:calc_nums]

            # Run inference batch by batch
            inference_times = normed_images.shape[0] // batch_size + 1
            for inference_time in range(inference_times):
                this_normed_images = normed_images[
                                     inference_time * batch_size:(inference_time + 1) * batch_size, ...]
                this_ori_images = ori_images[
                                  inference_time * batch_size:(inference_time + 1) * batch_size, ...]
                this_image_names = image_names[
                                   inference_time * batch_size:(inference_time + 1) * batch_size]
                # The division above yields an empty tail batch when the image
                # count divides evenly by batch_size; skip it
                if this_normed_images.shape[0] == 0:
                    continue

                this_pred_mask = self.inference(this_normed_images)
                this_denorm_pred_mask = module.reverse_pred_image(this_pred_mask.numpy())
                # Restore the batch axis when reverse_pred_image squeezed a
                # single-image batch down to 2-D (mirrors predict_blood_volume)
                if this_denorm_pred_mask.ndim == 2:
                    this_denorm_pred_mask = np.expand_dims(this_denorm_pred_mask, 0)

                for i in range(this_denorm_pred_mask.shape[0]):
                    bin_denorm_pred_mask = this_denorm_pred_mask[i]
                    this_drawed_image, this_blood_area = module.draw_contours(
                        this_ori_images[i], bin_denorm_pred_mask, dpi=96
                    )
                    # Save the binary mask and the contour-annotated image
                    cv.imwrite(os.path.join(
                        mask_save_dir, '{}'.format(this_image_names[i])), bin_denorm_pred_mask
                    )
                    cv.imwrite(os.path.join(
                        drawed_save_dir, '{}'.format(this_image_names[i])), this_drawed_image
                    )
        return

    def train(self, start_epoch=1):
        # Get the training dataset
        train_data = data.get_tfrecord_data(
            self.params.record_dir, self.params.train_record_name,
            self.input_shape, batch_size=self.params.batch_size)
        self.load_model()

        pd_record = pd.DataFrame(columns=['Epoch', 'Iteration', 'Loss', 'Time'])
        data_name, original_test_image, norm_test_image = data.get_test_data(
            test_data_path=self.params.test_image_dir, image_shape=self.input_shape, image_nums=-1
        )

        start_time = time.time()
        best_dice = 0.0
        # Epochs are 1-indexed; include the final epoch so exactly `epochs` epochs run
        for epoch in range(start_epoch, self.params.epochs + 1):
            for train_image, gt_mask in tqdm.tqdm(
                    train_data, total=self.params.total_samples // self.params.batch_size):
                self.save_ckpt.train_steps.assign_add(1)
                iteration = self.save_ckpt.train_steps.numpy()

                # Training step
                train_loss = self.train_step(train_image, gt_mask)
                if iteration % 100 == 0:
                    print('Epoch: {}, Iteration: {}, Loss: {:.2f}, Time: {:.2f} s'.format(
                        epoch, iteration, train_loss, time.time() - start_time))

                    # Test step
                    test_pred = self.inference(norm_test_image)
                    module.save_images(
                        image_shape=self.mask_shape, pred=test_pred,
                        save_path=self.pred_test_save_dir, index=iteration, split=False
                    )
                    pd_record = pd.concat([pd_record, pd.DataFrame([{
                        'Epoch': epoch, 'Iteration': iteration, 'Loss': train_loss.numpy(),
                        'Time': time.time() - start_time}])], ignore_index=True
                    )
                    pd_record.to_csv(os.path.join(
                        self.save_dir, '{}_record.csv'.format(self.params.model_name)), index=True, header=True
                    )

            # Validate after every epoch and keep only the best checkpoint by Dice
            m_dice = self.invalid(epoch)
            if m_dice > best_dice:
                best_dice = m_dice
                print('Best Dice: {}'.format(best_dice))
                self.save_manager.save(checkpoint_number=epoch)
        return

    def invalid(self, epoch):
        """Run the model on the validation record and return the mean Dice score."""
        invalid_data = data.get_tfrecord_data(
            self.params.record_dir, self.params.invalid_record_name,
            self.input_shape, batch_size=self.params.batch_size, shuffle=False)

        epoch_pred_save_dir = None
        for index, (invalid_image, invalid_mask) in enumerate(
                tqdm.tqdm(invalid_data, total=self.params.invalid_samples // self.params.batch_size + 1)):
            invalid_pred = self.inference(invalid_image)
            epoch_pred_save_dir = os.path.join(self.pred_invalid_save_dir, f'epoch_{epoch}')
            module.save_images(
                image_shape=self.mask_shape, pred=invalid_pred,
                save_path=epoch_pred_save_dir, index=f'{index}', split=False
            )

        # Measure performance against the ground-truth masks
        epoch_cropped_save_dir = os.path.join(
            self.invalid_crop_save_dir, f'epoch_{epoch}'
        )
        utils.crop_image(epoch_pred_save_dir, epoch_cropped_save_dir,
                         self.crop_width, self.crop_height,
                         self.input_shape[0], self.input_shape[1]
                         )
        m_dice, m_iou, m_precision, m_recall = performance.save_performace_to_csv(
            pred_dir=epoch_cropped_save_dir, gt_dir=self.params.gt_mask_dir,
            img_resize=(self.params.height, self.params.width),
            csv_save_name=f'{self.model_name}_epoch_{epoch}',
            csv_save_path=epoch_cropped_save_dir
        )
        return m_dice
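

# The file defines the Segmentation class without an entry point. A minimal
# driver sketch follows; the choice of method to run is an assumption, not
# part of the original diff:
if __name__ == '__main__':
    segment = Segmentation(args)
    segment.train(start_epoch=1)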