Diff of /fetal/evaluate.py [000000] .. [ccb1dd]

Switch to unified view

a b/fetal/evaluate.py
1
import numpy as np
2
import nibabel as nib
3
import os
4
import glob
5
import pandas as pd
6
import matplotlib
7
8
matplotlib.use('agg')
9
import matplotlib.pyplot as plt
10
11
12
def get_fetal_envelope_mask(data):
13
    return data > 0
14
15
16
def dice_coefficient(truth, prediction):
17
    return 2 * np.sum(truth * prediction) / (np.sum(truth) + np.sum(prediction))
18
19
20
def main():
21
    header = ("FetalEnvelope",)
22
    masking_functions = (get_fetal_envelope_mask,)
23
    rows = list()
24
    subject_ids = list()
25
    for case_folder in glob.glob("prediction/*"):
26
        if not os.path.isdir(case_folder):
27
            continue
28
        subject_ids.append(os.path.basename(case_folder))
29
        truth_file = os.path.join(case_folder, "truth.nii.gz")
30
        truth_image = nib.load(truth_file)
31
        truth = truth_image.get_data()
32
        prediction_file = os.path.join(case_folder, "prediction.nii.gz")
33
        prediction_image = nib.load(prediction_file)
34
        prediction = prediction_image.get_data()
35
        rows.append([dice_coefficient(func(truth), func(prediction)) for func in masking_functions])
36
37
    df = pd.DataFrame.from_records(rows, columns=header, index=subject_ids)
38
    df.to_csv("./prediction/brats_scores.csv")
39
40
    scores = dict()
41
    for index, score in enumerate(df.columns):
42
        values = df.values.T[index]
43
        scores[score] = values[np.isnan(values) == False]
44
45
    plt.boxplot(list(scores.values()), labels=list(scores.keys()))
46
    plt.ylabel("Dice Coefficient")
47
    plt.savefig("validation_scores_boxplot.png")
48
    plt.close()
49
50
    if os.path.exists("./training.log"):
51
        training_df = pd.read_csv("./training.log").set_index('epoch')
52
53
        plt.plot(training_df['loss'].values, label='training loss')
54
        plt.plot(training_df['val_loss'].values, label='validation loss')
55
        plt.ylabel('Loss')
56
        plt.xlabel('Epoch')
57
        plt.xlim((0, len(training_df.index)))
58
        plt.legend(loc='upper right')
59
        plt.savefig('loss_graph.png')
60
61
62
if __name__ == "__main__":
63
    main()