Diff of /drunet/module.py [000000] .. [2824d6]

Switch to unified view

a b/drunet/module.py
1
from collections import Counter
2
import pathlib
3
import math
4
import os
5
6
import matplotlib.pyplot as plt
7
import tensorflow as tf
8
import numpy as np
9
import cv2 as cv
10
import utils
11
import tqdm
12
13
14
def binary_image_from_dri(input_dir, threshold=128, save_dir=None):
15
    if os.path.isdir(input_dir):
16
        paths = utils.list_file(input_dir)
17
        utils.check_file([save_dir])
18
    else:
19
        paths = [input_dir]
20
21
    for path in paths:
22
        path_stem = pathlib.Path(path).stem
23
        image = cv.imread(path, 0)
24
        bin_image = binary_image(image, threshold)
25
        if save_dir is not None:
26
            cv.imwrite(os.path.join(save_dir, '{}.jpg'.format(path_stem)), bin_image)
27
    return
28
29
30
def binary_image(image, threshold):
31
    shape = image.shape
32
    if len(shape) == 3:
33
        image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
34
    th, bin_image = cv.threshold(image, threshold, 255, cv.THRESH_BINARY)
35
    return bin_image
36
37
38
def reverse_pred_image(normalize_pred_image):
39
    reverse_image = normalize_pred_image.squeeze() * 255.0
40
    reverse_image = np.array(reverse_image, dtype=np.uint8)
41
    return reverse_image
42
43
44
def save_images(pred, index, save_path, image_shape, split=False):
45
    image_numbers = int(np.sqrt(pred.shape[0]))
46
    if not split:
47
        h = image_shape[0]
48
        w = image_shape[1]
49
        H = int(image_numbers * image_shape[0])
50
        W = int(image_numbers * image_shape[1])
51
        big_image = np.zeros(shape=(H, W, image_shape[-1]), dtype=np.uint8).squeeze()
52
53
        for i in range(pow(image_numbers, 2)):
54
            image = (pred[i, :, :] * 255.0)
55
            image = np.array(image, dtype=np.uint8)
56
            image = image.squeeze()
57
            j = i % image_numbers
58
            k = i // image_numbers
59
            if image_shape[-1] == 1 and len(image_shape) == 3:
60
                big_image[k * h:(k + 1) * h, j * w:(j + 1) * w] = image
61
            else:
62
                big_image[k * h:(k + 1) * h, j * w:(j + 1) * w, :] = image
63
        path = os.path.join(save_path, 'Segment_train_pred_{}.png'.format(index))
64
        plt.imsave(path, big_image, cmap='gray')
65
    else:
66
        for i in range(image_numbers ** 2):
67
            image = (pred[i, :, :] * 255.0)
68
            image = np.array(image, dtype=np.uint8)
69
            image = image.squeeze()
70
            path = os.path.join(save_path, 'Segment_pred_{}_{}.png'.format(index, i))
71
            plt.imsave(path, image, cmap='gray')
72
    return
73
74
75
def get_area(image):
76
    """Count the area of bleeding area"""
77
    image = cv.cvtColor(image, cv.COLOR_BGR2GRAY)
78
    _, bin_image = cv.threshold(image, 0, 255, cv.THRESH_BINARY)
79
    count_result = Counter(list(bin_image.reshape(-1, )))
80
    area = count_result.get(255)
81
    return area
82
83
84
def pixel_to_ml(pixel_area, dpi=96):
85
    if pixel_area is None:
86
        pixel_area = 0.0
87
    return pixel_area / pow(dpi, 2) * pow(25.4, 2) / 100
88
89
90
def draw_contours(image, mask, max_count=8, dpi=96):
91
    image = np.array(image)
92
    mask = np.array(mask)
93
    height, width = image.shape[:2]
94
    copy_image = image.copy()
95
96
    if len(mask.shape) == 3 and mask.shape[-1] != 1:
97
        mask = cv.cvtColor(mask, cv.COLOR_BGR2GRAY)
98
    else:
99
        mask = mask
100
    mask = cv.resize(mask, (height, width))
101
    th, bin_mask = cv.threshold(mask, 0, 255, cv.THRESH_BINARY)
102
103
    con_list = []
104
    blood_area = []
105
    contours, _ = cv.findContours(bin_mask, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
106
    for index, contour in enumerate(contours):
107
        area = cv.contourArea(contour)
108
        if 200 < area < height * width * 0.94:
109
            con_list.append(index)
110
            blood_area.append(pixel_to_ml(area, dpi))
111
112
    if len(con_list) > max_count:
113
        blood_area = [0]
114
    else:
115
        for index in con_list:
116
            copy_image = cv.drawContours(copy_image, contours, index, (0, 0, 255), 5)
117
    return copy_image, sum(blood_area)
118
119
120
def save_invalid_data(origin_images, drawed_images, pred_mask_images, image_names, save_dir, reshape=True):
121
    """
122
     :param image_names: the image file names of the original images, in the form of a list
123
     :param origin_images: original bleeding images
124
     :param drawed_images: draw a contour map of the bleeding area on the original image according to the mask
125
     :param pred_mask_images: predicted mask image
126
     :param save_dir: save path of all images
127
     :param reshape: restore all images to the original image size
128
     """
129
    origin_save_dir = os.path.join(save_dir, 'origin')
130
    drawed_save_dir = os.path.join(save_dir, 'drawed')
131
    mask_save_dir = os.path.join(save_dir, 'pred_mask')
132
    utils.check_file([origin_save_dir, drawed_save_dir, mask_save_dir])
133
134
    for index in range(len(origin_images)):
135
        origin_image = origin_images[index]
136
        drawed_image = drawed_images[index]
137
        mask_image = pred_mask_images[index]
138
        _, bin_mask_image = cv.threshold(mask_image, 0, 255, cv.THRESH_BINARY)
139
140
        if reshape:
141
            origin_image = cv.resize(origin_image, (256, 256))
142
            drawed_image = cv.resize(drawed_image, (256, 256))
143
            save_name = '{}'.format(image_names[index])
144
            save_mask_path = os.path.join(mask_save_dir, save_name)
145
            save_origin_path = os.path.join(origin_save_dir, save_name)
146
            save_drawed_path = os.path.join(drawed_save_dir, save_name)
147
            cv.imwrite(save_mask_path, bin_mask_image)
148
            cv.imwrite(save_origin_path, origin_image)
149
            cv.imwrite(save_drawed_path, drawed_image)
150
    return
151
152
153
def count_volume(areas, thickness=0.45):
154
    for area in areas:
155
        if area == 0:
156
            areas.remove(area)
157
    areas_count = len(areas)
158
    volume = [areas[index] * thickness for index in range(areas_count)]
159
    return sum(volume)
160
161
162
def calculate_volume(mask_dir, real_shape=(1440, 1440), thickness=0.4, dpi=96):
163
    all_areas = []
164
    for path in tqdm.tqdm(pathlib.Path(mask_dir).iterdir()):
165
        mask_image = cv.imread(str(path))
166
        mask_image = cv.resize(mask_image, real_shape)
167
        area = pixel_to_ml(get_area(mask_image), dpi=dpi)
168
        all_areas.append(area)
169
    volume = count_volume(all_areas, thickness)
170
    return volume