Diff of /drunet/segment.py [000000] .. [2824d6]

--- a
+++ b/drunet/segment.py
@@ -0,0 +1,320 @@
+import os
+import time
+import argparse
+import pathlib
+
+import tqdm
+import cv2 as cv
+import numpy as np
+import pandas as pd
+import tensorflow as tf
+from tensorflow import keras
+import matplotlib.pyplot as plt
+from tensorflow.compat.v1 import ConfigProto
+from tensorflow.compat.v1 import InteractiveSession
+# Custom package
+import data
+import loss
+import utils
+import module
+import performance
+from model import dr_unet
+
+config = ConfigProto()
+config.gpu_options.allow_growth = True
+# Create the session so the GPU memory-growth setting actually takes effect
+session = InteractiveSession(config=config)
+
+# 1. Parameter settings
+parser = argparse.ArgumentParser(description="Segment Use Args")
+parser.add_argument('--model-name', default='DR_UNet', type=str)
+parser.add_argument('--dims', default=32, type=int)
+parser.add_argument('--epochs', default=50, type=int)
+parser.add_argument('--batch-size', default=16, type=int)
+parser.add_argument('--lr', default=2e-4, type=float)
+
+# Training, testing, and validation parameter settings
+parser.add_argument('--height', default=256, type=int)
+parser.add_argument('--width', default=256, type=int)
+parser.add_argument('--channel', default=1, type=int)
+parser.add_argument('--pred-height', default=4 * 256, type=int)
+parser.add_argument('--pred-width', default=4 * 256, type=int)
+parser.add_argument('--total-samples', default=5000, type=int)
+parser.add_argument('--invalid-samples', default=1000, type=int)
+parser.add_argument('--regularize', action='store_true',
+                    help='apply the model regularization losses during training')
+parser.add_argument('--record-dir', default=r'', type=str, help='directory where the tfrecord files are stored')
+parser.add_argument('--train-record-name', type=str, default=r'train_data', help='filename of the training tfrecord')
+parser.add_argument('--test-image-dir', default=r'', type=str, help='directory of test images')
+parser.add_argument('--invalid-record-name', type=str, default=r'test_data', help='filename of the validation tfrecord')
+parser.add_argument('--gt-mask-dir', default=r'', type=str, help='directory of ground-truth masks for the validation set')
+parser.add_argument('--invalid-volume-dir', default=r'', type=str, help='directory of CT cases for bleeding-volume estimation')
+args = parser.parse_args()
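+
+# A hypothetical invocation sketch (the paths below are placeholders, not
+# values from this repository):
+#   python segment.py --model-name DR_UNet --dims 32 --epochs 50 \
+#       --batch-size 16 --lr 2e-4 \
+#       --record-dir ./records --train-record-name train_data \
+#       --invalid-record-name test_data --test-image-dir ./test_images \
+#       --gt-mask-dir ./gt_masks --invalid-volume-dir ./volume_cases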
+
+
+class Segmentation:
+    def __init__(self, params):
+        self.params = params
+        self.input_shape = [params.height, params.width, params.channel]
+        self.mask_shape = [params.height, params.width, 1]
+        self.model_name = params.model_name
+        self.crop_height = params.pred_height
+        self.crop_width = params.pred_width
+        self.regularize = params.regularize
+
+        # Obtain a segmentation model
+        self.seg_model = dr_unet.dr_unet(input_shape=self.input_shape, dims=params.dims)
+        self.seg_model.summary()
+
+        # Optimization function
+        self.optimizer = tf.keras.optimizers.Adam(learning_rate=params.lr)
+
+        # After every epoch, predict the validation images to evaluate the segmentation performance of the model
+        self.save_dir = str(params.model_name).upper()
+        self.weight_save_dir = os.path.join(self.save_dir, 'checkpoint')
+        self.pred_invalid_save_dir = os.path.join(self.save_dir, 'invalid_pred')
+        self.invalid_crop_save_dir = os.path.join(self.save_dir, 'invalid_pred_crop')
+        self.pred_test_save_dir = os.path.join(self.save_dir, 'test_pred')
+        utils.check_file([
+            self.save_dir, self.weight_save_dir, self.pred_invalid_save_dir,
+            self.pred_test_save_dir, self.invalid_crop_save_dir]
+        )
+
+        # Checkpointing: track the training step, model, and optimizer state
+        train_steps = tf.Variable(0, dtype=tf.int32)
+        self.save_ckpt = tf.train.Checkpoint(
+            train_steps=train_steps, seg_model=self.seg_model, model_optimizer=self.optimizer)
+        self.save_manager = tf.train.CheckpointManager(
+            self.save_ckpt, directory=self.weight_save_dir, max_to_keep=1)
+
+        # Set the loss function
+        self.loss_fun = loss.bce_dice_loss
+
+    def load_model(self):
+        if self.save_manager.latest_checkpoint:
+            self.save_ckpt.restore(self.save_manager.latest_checkpoint)
+            print('Loading model: {}'.format(self.save_manager.latest_checkpoint))
+        else:
+            print('No checkpoint found, training the model from scratch!')
+        return
+
+    @tf.function
+    def train_step(self, inputs, target):
+        with tf.GradientTape() as tape:
+            pred_mask = self.seg_model(inputs, training=True)
+            seg_loss = self.loss_fun(target, pred_mask)
+            if self.regularize:
+                seg_loss = tf.reduce_sum(seg_loss) + tf.reduce_sum(self.seg_model.losses)
+        gradient = tape.gradient(seg_loss, self.seg_model.trainable_variables)
+        self.optimizer.apply_gradients(zip(gradient, self.seg_model.trainable_variables))
+        return tf.reduce_mean(seg_loss)
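+
+    # For reference: bce_dice_loss is assumed (not verified against loss.py) to
+    # combine binary cross-entropy with a soft Dice term, roughly:
+    #   bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
+    #   dice = (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)
+    #   return bce + (1 - dice)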
+
+    @tf.function
+    def inference(self, inputs):
+        # Run the model in inference mode (dropout and batch-norm updates disabled)
+        pred = self.seg_model(inputs, training=False)
+        return pred
+
+    @staticmethod
+    def calculate_volume_by_mask(mask_dir, save_dir, model_name, dpi=96, thickness=0.45):
+        all_mask_file_paths = utils.list_file(mask_dir)
+
+        pd_record = pd.DataFrame(columns=['file_name', 'Volume'])
+        for file_dir in tqdm.tqdm(all_mask_file_paths):
+            file_name = pathlib.Path(file_dir).stem
+            each_blood_volume = module.calculate_volume(file_dir, thickness=thickness, dpi=dpi)
+            pd_record.loc[len(pd_record)] = {'file_name': file_name, 'Volume': each_blood_volume}
+        # Save the accumulated records once, after every mask has been processed
+        pd_record.to_csv(
+            os.path.join(save_dir, '{}_volume.csv'.format(model_name)), index=True, header=True)
+        return
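+
+    # A minimal sketch of the volume arithmetic assumed above (hypothetical
+    # helper; the actual computation lives in module.calculate_volume): pixel
+    # counts are converted to cm^2 via the dpi (1 inch = 2.54 cm) and then
+    # integrated over the slice thickness, which is assumed to be in cm so
+    # the result is in cm^3 (ml).
+    @staticmethod
+    def _volume_arithmetic_sketch(mask_pixel_counts, dpi=96, thickness=0.45):
+        cm_per_pixel = 2.54 / dpi  # edge length of one pixel in cm
+        slice_areas = [count * cm_per_pixel ** 2 for count in mask_pixel_counts]
+        return sum(slice_areas) * thickness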
+
+    def predict_blood_volume(self, input_dir, save_dir, calc_nums=-1, dpi=96, thickness=0.45):
+        """Predict hematoma masks and estimate the bleeding volume per patient.
+
+        :param input_dir: directory of bleeding-volume test images; it holds one
+                          sub-folder per patient, each containing that patient's CT slices
+        :param save_dir: directory where the predicted segmentation results are saved
+        :param calc_nums: number of images per folder to predict (-1 for all)
+        :param dpi: image resolution, in dots per inch
+        :param thickness: CT slice thickness used for volume integration
+        """
+        # Load the trained model weights
+        self.load_model()
+        save_pred_images_dir = os.path.join(save_dir, 'pred_images')
+        save_pred_csv_dir = os.path.join(save_dir, 'pred_csv')
+        utils.check_file([save_pred_images_dir, save_pred_csv_dir])
+        all_file_dirs = utils.list_file(input_dir)
+
+        cost_time_list = []
+        total_images = 0
+        for file_dir in tqdm.tqdm(all_file_dirs):
+            file_name = pathlib.Path(file_dir).stem
+
+            image_names, ori_images, normed_images = data.get_test_data(
+                test_data_path=file_dir, image_shape=self.input_shape, image_nums=calc_nums)
+            total_images += len(image_names)
+
+            start_time = time.time()
+            pred_mask = self.inference(normed_images)
+            end_time = time.time()
+            print('FPS: {}'.format(pred_mask.shape[0] / (end_time - start_time)),
+                  pred_mask.shape, end_time - start_time)
+
+            denorm_pred_mask = module.reverse_pred_image(pred_mask.numpy())  # (image_nums, 256, 256, 1)
+            if denorm_pred_mask.ndim == 2 and self.input_shape[-1] == 1:
+                denorm_pred_mask = np.expand_dims(denorm_pred_mask, 0)
+
+            drawed_images = []
+            blood_areas = []
+            pd_record = pd.DataFrame(columns=['image_name', 'Square Centimeter', 'Volume'])
+            for index in range(denorm_pred_mask.shape[0]):
+                drawed_image, blood_area = module.draw_contours(ori_images[index], denorm_pred_mask[index], dpi=dpi)
+                drawed_images.append(drawed_image)
+                blood_areas.append(blood_area)
+                pd_record.loc[len(pd_record)] = {
+                    'image_name': image_names[index], 'Square Centimeter': blood_area}
+
+            one_pred_save_dir = os.path.join(save_pred_images_dir, file_name)
+            module.save_invalid_data(ori_images, drawed_images, denorm_pred_mask,
+                                     image_names, reshape=True, save_dir=one_pred_save_dir)
+
+            # Calculate the total bleeding volume from the hematoma area of each slice
+            blood_volume = module.count_volume(blood_areas, thickness=thickness)
+            pd_record.loc[len(pd_record)] = {'Volume': blood_volume}
+            pd_record.to_csv(os.path.join(save_pred_csv_dir, '{}_{}.csv'.format(self.model_name, file_name)),
+                             index=True, header=True)
+            cost_time_list.append(end_time - start_time)
+            print('FileName: {} time: {}'.format(file_name, end_time - start_time))
+        print('total_time: {:.2f}, mean_time: {:.2f}, total_images: {}'.format(
+            np.sum(cost_time_list), np.mean(cost_time_list), total_images))
+        return
+
+    def predict_and_save(self, input_dir, save_dir, calc_nums=-1, batch_size=16):
+        """Predict bleeding masks for a directory of images and save the results.
+
+        :param input_dir: directory containing the images to be tested
+        :param save_dir: directory where the predicted masks and contour overlays are saved
+        :param calc_nums: number of images taken from the directory (-1 for all)
+        :param batch_size: number of images to run through the model per batch
+        """
+        mask_save_dir = os.path.join(save_dir, 'pred_mask')
+        drawed_save_dir = os.path.join(save_dir, 'drawed_image')
+        utils.check_file([mask_save_dir, drawed_save_dir])
+        self.load_model()
+
+        test_image_list = utils.list_file(input_dir)
+        chunk_size = 128  # load at most 128 images into memory at a time
+        for index in range((len(test_image_list) + chunk_size - 1) // chunk_size):
+            input_test_list = test_image_list[index * chunk_size:(index + 1) * chunk_size]
+
+            image_names, ori_images, normed_images = data.get_test_data(
+                test_data_path=input_test_list, image_shape=self.input_shape, image_nums=-1,
+            )
+            if calc_nums != -1:
+                ori_images = ori_images[:calc_nums]
+                normed_images = normed_images[:calc_nums]
+                image_names = image_names[:calc_nums]
+
+            inference_times = (normed_images.shape[0] + batch_size - 1) // batch_size
+            for inference_time in range(inference_times):
+                this_normed_images = normed_images[
+                                     inference_time * batch_size:(inference_time + 1) * batch_size, ...]
+                this_ori_images = ori_images[
+                                  inference_time * batch_size:(inference_time + 1) * batch_size, ...]
+                this_image_names = image_names[
+                                   inference_time * batch_size:(inference_time + 1) * batch_size]
+
+                this_pred_mask = self.inference(this_normed_images)
+                this_denorm_pred_mask = module.reverse_pred_image(this_pred_mask.numpy())
+                if this_denorm_pred_mask.ndim == 2:
+                    this_denorm_pred_mask = np.expand_dims(this_denorm_pred_mask, 0)
+
+                for i in range(this_denorm_pred_mask.shape[0]):
+                    bin_denorm_pred_mask = this_denorm_pred_mask[i]
+                    this_drawed_image, this_blood_area = module.draw_contours(
+                        this_ori_images[i], bin_denorm_pred_mask, dpi=96
+                    )
+                    cv.imwrite(os.path.join(
+                        mask_save_dir, '{}'.format(this_image_names[i])), bin_denorm_pred_mask
+                    )
+                    cv.imwrite(os.path.join(
+                        drawed_save_dir, '{}'.format(this_image_names[i])), this_drawed_image
+                    )
+        return
+
+    def train(self, start_epoch=1):
+        # get training dataset
+        train_data = data.get_tfrecord_data(
+            self.params.record_dir, self.params.train_record_name,
+            self.input_shape, batch_size=self.params.batch_size)
+        self.load_model()
+
+        pd_record = pd.DataFrame(columns=['Epoch', 'Iteration', 'Loss', 'Time'])
+        data_name, original_test_image, norm_test_image = data.get_test_data(
+            test_data_path=self.params.test_image_dir, image_shape=self.input_shape, image_nums=-1
+        )
+
+        start_time = time.time()
+        best_dice = 0.0
+        for epoch in range(start_epoch, self.params.epochs + 1):
+            for train_image, gt_mask in tqdm.tqdm(
+                    train_data, total=self.params.total_samples // self.params.batch_size):
+                self.save_ckpt.train_steps.assign_add(1)
+                iteration = self.save_ckpt.train_steps.numpy()
+
+                # training step
+                train_loss = self.train_step(train_image, gt_mask)
+                if iteration % 100 == 0:
+                    print('Epoch: {}, Iteration: {}, Loss: {:.2f}, Time: {:.2f} s'.format(
+                        epoch, iteration, train_loss, time.time() - start_time))
+
+                    # test step
+                    test_pred = self.inference(norm_test_image)
+                    module.save_images(
+                        image_shape=self.mask_shape, pred=test_pred,
+                        save_path=self.pred_test_save_dir, index=iteration, split=False
+                    )
+                    pd_record.loc[len(pd_record)] = {
+                        'Epoch': epoch, 'Iteration': iteration, 'Loss': train_loss.numpy(),
+                        'Time': time.time() - start_time}
+                    pd_record.to_csv(os.path.join(
+                        self.save_dir, '{}_record.csv'.format(self.params.model_name)), index=True, header=True
+                    )
+
+            m_dice = self.invalid(epoch)
+            if m_dice > best_dice:
+                best_dice = m_dice
+                print('Best Dice: {}'.format(best_dice))
+                self.save_manager.save(checkpoint_number=epoch)
+        return
+
+    def invalid(self, epoch):
+        """Predict the validation set for this epoch and return the mean Dice."""
+        invalid_data = data.get_tfrecord_data(
+            self.params.record_dir, self.params.invalid_record_name,
+            self.input_shape, batch_size=self.params.batch_size, shuffle=False)
+
+        epoch_pred_save_dir = os.path.join(self.pred_invalid_save_dir, f'epoch_{epoch}')
+        for index, (invalid_image, invalid_mask) in enumerate(
+                tqdm.tqdm(invalid_data, total=self.params.invalid_samples // self.params.batch_size + 1)):
+            invalid_pred = self.inference(invalid_image)
+            module.save_images(
+                image_shape=self.mask_shape, pred=invalid_pred,
+                save_path=epoch_pred_save_dir, index=f'{index}', split=False
+            )
+
+        # Test model performance
+        epoch_cropped_save_dir = os.path.join(
+            self.invalid_crop_save_dir, f'epoch_{epoch}'
+        )
+        utils.crop_image(epoch_pred_save_dir, epoch_cropped_save_dir,
+                         self.crop_width, self.crop_height,
+                         self.input_shape[0], self.input_shape[1]
+                         )
+        m_dice, m_iou, m_precision, m_recall = performance.save_performace_to_csv(
+            pred_dir=epoch_cropped_save_dir, gt_dir=self.params.gt_mask_dir,
+            img_resize=(self.params.height, self.params.width),
+            csv_save_name=f'{self.model_name}_epoch_{epoch}',
+            csv_save_path=epoch_cropped_save_dir
+        )
+        return m_dice
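+
+
+# A minimal entry-point sketch, assuming the intended workflow is to train
+# first and then run test prediction and volume estimation; the save
+# directories below are hypothetical placeholders.
+if __name__ == '__main__':
+    segmentation = Segmentation(args)
+    segmentation.train(start_epoch=1)
+    # e.g. segmentation.predict_and_save(args.test_image_dir, 'pred_out')
+    # e.g. segmentation.predict_blood_volume(args.invalid_volume_dir, 'volume_out')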
+