Diff of /drunet/performance.py [000000] .. [2824d6]

Switch to unified view

a b/drunet/performance.py
1
import os
2
import time
3
import math
4
import pathlib
5
from functools import reduce
6
from collections import Counter
7
8
import utils
9
import cv2 as cv
10
import numpy as np
11
import pandas as pd
12
from tqdm import tqdm
13
import matplotlib.pyplot as plt
14
15
16
def prepro_image(img_path, img_resize, threshold=128):
17
    image = cv.imread(img_path, 0)
18
    if len(image.shape) != 2:
19
        image = cv.cvtColor(image, cv.COLOR_RGB2GRAY)
20
    image = cv.resize(image, img_resize)
21
    _, bin_image = cv.threshold(image, threshold, 255, cv.THRESH_BINARY)
22
    bin_image = np.array(bin_image / 255, dtype=np.int)
23
    return bin_image
24
25
26
def true_positive(pred, gt):
27
    assert pred.shape == gt.shape
28
    tp_bool = np.logical_and(pred, gt)
29
    tp_int = np.array(tp_bool, dtype=np.int)
30
    tp = np.sum(tp_int)
31
    return tp
32
33
34
def true_negative(pred, gt):
35
    assert pred.shape == gt.shape
36
    no_pred = np.array(np.logical_not(pred), dtype=np.int)
37
    no_gt = np.array(np.logical_not(gt), dtype=np.int)
38
    tn = true_positive(no_pred, no_gt)
39
    return tn
40
41
42
def false_positive(pred, gt):
43
    assert pred.shape == gt.shape
44
    no_gt = np.array(np.logical_not(gt), dtype=np.int)
45
    fp = true_positive(pred, no_gt)
46
    return fp
47
48
49
def false_negative(pred, gt):
50
    assert pred.shape == gt.shape
51
    no_pred = np.array(np.logical_not(pred), dtype=np.int)
52
    fn = true_positive(no_pred, gt)
53
    return fn
54
55
56
def f1_score(precision, recall):
57
    return (2 * precision * recall) / (precision + recall)
58
59
60
def calc_performance(pred_path, gt_path, img_resize, threshold=128):
61
    """ Used to calculate the difference between the segmented image and Ground Truth for statistical prediction
62
        to evaluate the performance of the model, but currently only gray-scale images can be calculated
63
     :param pred_path: path of predicted image
64
     :param gt_path: real mask path
65
     :param img_resize: the size to resize the image
66
     :param threshold: Used to binaries images
67
     :return: pix-accuracy, precision, recall, VOE, RVD, Dice, IOU evaluation indicators
68
     """
69
    total_pix = reduce(lambda x, y: x * y, img_resize)
70
    pred_image = prepro_image(pred_path, img_resize, threshold)
71
    mask_image = prepro_image(gt_path, img_resize, threshold)
72
73
    tp = true_positive(pred_image, mask_image)
74
    tn = true_negative(pred_image, mask_image)
75
    fp = false_positive(pred_image, mask_image)
76
    fn = false_negative(pred_image, mask_image)
77
78
    accuracy = (tp + tn) / total_pix
79
    precision = tp / (tp + fp + 1e-10)
80
    recall = tp / (tp + fn + 1e-10)
81
    iou = tp / (tp + fp + fn + 1e-10)
82
    dice = 2 * tp / (fn + tp + tp + fp + 1e-10)
83
    voe = 1 - tp / (tp + fn + fp + 1e-10)
84
    rvd = (fp - fn) / (fn + tp + 1e-10)
85
    specificity = tn / (tn + fp + 1e-10)
86
    return tp, tn, fp, fn, accuracy, precision, recall, iou, dice, voe, rvd, specificity
87
88
89
def save_performance_to_csv(pred_dir, gt_dir, img_resize, csv_save_name, csv_save_path='', threshold=128):
90
    gt_paths, gt_names = utils.get_path(gt_dir)
91
    pred_paths, pred_names = utils.get_path(pred_dir)
92
93
    record_pd = pd.DataFrame(columns=[
94
        'pred_name', 'gt_name', 'TP', 'FP', 'FN', 'TN',
95
        'accuracy', 'precision', 'recall', 'IOU', 'DICE', 'VOE', 'RVD', 'specificity',
96
    ])
97
98
    total_file_nums = len(gt_paths)
99
    for file_index in tqdm(range(total_file_nums), total=total_file_nums):
100
        TP, TN, FP, FN, accuracy, precision, recall, IOU, DICE, VOE, RVD, specificity = calc_performance(
101
            pred_paths[file_index], gt_paths[file_index], img_resize, threshold)
102
103
        record_pd = record_pd.append({
104
            'pred_name': pred_names[file_index],
105
            'gt_name': gt_names[file_index],
106
            'accuracy': accuracy,
107
            'precision': precision,
108
            'recall': recall,
109
            'specificity': specificity,
110
            'IOU': IOU,
111
            'DICE': DICE,
112
            'VOE': VOE,
113
            'RVD': RVD,
114
            'TP': TP, 'FP': FP, 'FN': FN, 'TN': TN
115
        }, ignore_index=True)
116
117
    record_pd.to_csv(
118
        os.path.join(csv_save_path, '{}.csv'.format(csv_save_name)), index=True, header=True)
119
120
    m_accuracy, m_precision, m_recall, m_iou, m_dice, m_voe, m_rvd, m_spec = analysis_performance(
121
        os.path.join(csv_save_path, '{}.csv'.format(csv_save_name)))
122
    analysis_pd = pd.DataFrame(columns=[
123
        'm_accu', 'm_prec', 'm_recall', 'm_iou', 'm_dice', 'm_voe', 'm_rvd', 'm_spec'
124
    ])
125
    analysis_pd = analysis_pd.append({
126
        'm_accu': m_accuracy, 'm_prec': m_precision, 'm_recall': m_recall, 'm_iou': m_iou,
127
        'm_dice': m_dice, 'm_voe': m_voe, 'm_rvd': m_rvd, 'm_spec': m_spec,
128
    }, ignore_index=True)
129
    analysis_pd.to_csv(
130
        os.path.join(csv_save_path, 'analysis_{}.csv'.format(csv_save_name)), index=True, header=True)
131
    return m_dice, m_iou, m_precision, m_recall
132
133
134
def analysis_performance(csv_file_path):
135
    data_frame = pd.read_csv(csv_file_path, header=None)
136
137
    m_accuracy = np.mean(np.array(data_frame.loc[1:, 7], dtype=np.float32))
138
    m_precision = np.mean(np.array(data_frame.loc[1:, 8], dtype=np.float32))
139
    m_recall = np.mean(np.array(data_frame.loc[1:, 9], dtype=np.float32))
140
    m_iou = np.mean(np.array(data_frame.loc[1:, 10], dtype=np.float32))
141
    m_dice = np.mean(np.array(data_frame.loc[1:, 11], dtype=np.float32))
142
    m_voe = np.mean(np.array(data_frame.loc[1:, 12], dtype=np.float32))
143
    m_rvd = np.mean(np.array(data_frame.loc[1:, 13], dtype=np.float32))
144
    m_spec = np.mean(np.array(data_frame.loc[1:, 14], dtype=np.float32))
145
    print(
146
        ' accuracy: {},\n precision: {},\n recall: {},\n iou: {},\n dice: {},\n voe: {},\n rvd: {},\n spec: {}.\n'.format(
147
            m_accuracy, m_precision, m_recall, m_iou, m_dice, m_voe, m_rvd, m_spec))
148
    return m_accuracy, m_precision, m_recall, m_iou, m_dice, m_voe, m_rvd, m_spec