Diff of /utils/metrics.py [000000] .. [6fe801]

Switch to unified view

a b/utils/metrics.py
1
import numpy as np
2
import os
3
import pandas as pd
4
import matplotlib.pyplot as plt
5
from glob import glob
6
7
def plot_boxplot(experiment_name, output_dir, exclude=[], title="Boxplot"):
8
    '''
9
    Plot boxplot for the given data.
10
11
    Args:
12
        experiment_name (str): Name of the experiment.
13
        output_dir (str): Path to the output directory.
14
        exclude (list): List of columns to exclude from the boxplot.
15
        title (str): Title of the plot.
16
17
    Note:
18
        The dataframe holds the data in columns. Each column represents an experiment (single box plot) that we want to plot.
19
        We add the data to specific column of the experiment in the datafraame.
20
21
        Each row in the dataframe represents the result obtained from each subject in the experiment.
22
23
        >> df.head()
24
        >>          Par0003.affine  Par0003.bs-R1-fg  Par0003.bs-R6-ug  experiment_name
25
        >> copd1    10.62             26.25              1.34            ..
26
        >> copd2    10.07             21.45              2.68            ..
27
        >> copd3    03.57             12.04              1.27            ..
28
        >> copd4    07.48             29.45              1.53            ..
29
30
        If we describe the dataframe, we get the following:
31
32
        >> stats = df.describe()
33
        >> stats
34
        >>       Par0003.affine  Par0003.bs-R1-fg  Par0003.bs-R6-ug
35
        >> count  4.000000        4.000000          4.000000
36
        >> mean   7.935000        22.795000         1.705000
37
        >> std    3.417692        7.221071          0.700713
38
        >> min    3.570000        12.040000         1.270000
39
        >> 25%    6.345000        19.522500         1.330000
40
        >> 50%    8.775000        23.850000         1.435000
41
        >> 75%    10.365000       27.122500         2.060000
42
        >> max    10.620000       29.450000         2.680000
43
44
    Returns:
45
        None. The function generates and displays the box plot.
46
    '''
47
48
    # Get the data
49
    columns = os.listdir(f'../output/{experiment_name}/')
50
    TRE_sample_results = [path.replace('\\', '/') for path in glob(os.path.join(output_dir, experiment_name, "***", "points", "TRE_sample_results.csv"), recursive=True)]
51
52
    # Remove the excluded columns
53
    for column in exclude:
54
        if column in columns:
55
            columns.remove(column)
56
57
        TRE_sample_results = [item for item in TRE_sample_results if column not in item]
58
59
    # debugging
60
    # columns = columns[:5]
61
    # TRE_sample_results = TRE_sample_results[:5]
62
    # print(columns)
63
    # print(TRE_sample_results)
64
65
    # assert len(columns) == len(TRE_sample_results)
66
    assert  len(columns) == len(TRE_sample_results), f"Number of columns ({len(columns)}) does not match number of results ({len(TRE_sample_results)})"
67
68
    # Create a dataframe
69
    df = pd.DataFrame(columns=columns)
70
71
    for i, path in enumerate(TRE_sample_results):
72
        # Read the csv file
73
        data = pd.read_csv(path, index_col=0)
74
75
        # Add the data to the dataframe
76
        columns[i] = columns[i] + f" ({data['TRE_mean'].mean():.3f})"
77
        df[columns[i]] = data['TRE_mean']
78
        
79
    # Plot the boxplot
80
    boxplot = df.boxplot(column=columns, rot=90)
81
82
    # Get the lowest values for each column
83
    lowest_values = df.mean()
84
85
    # Get the column with the overall lowest minimum value
86
    lowest_column = lowest_values.idxmin()
87
88
    # Highlight the entire boxplot for the column with the lowest minimum value in red
89
    position = columns.index(lowest_column) + 1
90
91
    # 7 is the number of data that represents a single boxplot (divide len(boxplot.get_lines())//len(columns) to get the number of data per boxplot)
92
    boxplot.get_lines()[position * 7 - 7].set(color='red', linewidth=3) 
93
94
    # Set plot title
95
    plt.title(title)
96
97
    # Show the plot
98
    plt.show()
99
100
def compute_TRE(pts_exhale_file, pts_inhale_file, voxel_size):
101
    """
102
    Computes the Target Registration Error (TRE) to quantify the accuracy of the registration process. The TRE is calculated using 3D Euclidean 
103
    distance between the keypoints in the reference image (File 1) and the transformed keypoints in the registered image (File 2).
104
105
    Args:
106
        pts_exhale_file (str): path to the file containing the coordinates of the moving points
107
        pts_inhale_file (str): path to the file containing the coordinates of the fixed points
108
        voxel_size (tuple): voxel size in mm
109
110
    Returns:
111
        mean_TRE (float): mean TRE in mm
112
        std_TRE (float): standard deviation of the TRE in mm
113
    """
114
    # load the files if paths are provided
115
    pts_inhale = np.loadtxt(pts_inhale_file)
116
    pts_exhale = np.loadtxt(pts_exhale_file)
117
118
    # Check if the number of points in both files is the same
119
    if len(pts_inhale) != len(pts_exhale):
120
        raise ValueError("The number of points in the fixed and moving files must be the same.")
121
122
    # Compute the TRE - square root of the sum of the squared elements
123
    TRE = np.linalg.norm((pts_inhale - pts_exhale) * voxel_size, axis=1)
124
125
    return np.round(np.mean(TRE),2), np.round(np.std(TRE),2)