# evaluate_transformation.py
1
import sys
import argparse
import os
import json
import csv
import numpy as np

from utils.filemanager import get_points_paths
from utils.logger import logger, pprint
from utils.landmarks import get_landmarks_from_txt, write_landmarks_to_list
from utils.metrics import compute_TRE

# Every transformix output file is expected to contain exactly this many landmarks.
EXPECTED_NUM_LANDMARKS = 300


def _parse_args():
    """Parse the optional command-line arguments for the evaluation script."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--experiment_name', type=str, default='elastix_01', help='experiment name')
    parser.add_argument('--reg_params_key', type=str, default='Parameter.affine+Parameter.bsplines', help='registration parameters key generated by create_script.py')
    parser.add_argument('--output_path', type=str, default='output', help='root dir for output scripts')
    parser.add_argument("--generate_report", action='store_true', help='if True, an evaluation report .txt file will be generated. If not, only the transformed keypoints txt file will be generated for each test sample.')
    parser.add_argument('--dataset_path', type=str, default='dataset/train', help='root dir for nifti data to get the gt exhale landmarks')

    return parser.parse_args()


def _load_dataset_description(dataset_path):
    """Load description.json from the dataset root.

    The root is obtained by stripping the trailing 'train'/'test' split segment
    from *dataset_path* (first occurrence only, matching the original logic).
    """
    dataset_root = dataset_path.replace("train", "", 1).replace("test", "", 1)
    with open(os.path.join(dataset_root, 'description.json'), 'r') as json_file:
        # json.load reads straight from the file object; no need for read()+loads.
        return json.load(json_file)


def _write_tre_reports(exp_points_output, tre_results):
    """Write per-sample and overall TRE results as CSV files in *exp_points_output*.

    tre_results: list of {'sample_name', 'TRE_mean', 'TRE_std'} dicts.
    """
    # Per-sample TRE results, one row per registered subject.
    sample_csv_path = os.path.join(exp_points_output, 'TRE_sample_results.csv')
    with open(sample_csv_path, 'w', newline='') as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=['sample_name', 'TRE_mean', 'TRE_std'])
        writer.writeheader()
        writer.writerows(tre_results)

    # Overall means across all samples.
    overall_csv_path = os.path.join(exp_points_output, 'TRE_overall_results.csv')
    with open(overall_csv_path, 'w', newline='') as csv_file:
        fieldnames = ['Overall mean (TRE_mean)', 'Overall mean (TRE_std)']
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerow({
            'Overall mean (TRE_mean)': np.mean([result['TRE_mean'] for result in tre_results]),
            'Overall mean (TRE_std)': np.mean([result['TRE_std'] for result in tre_results]),
        })


def main():
    """Extract transformix-transformed landmarks and optionally report TRE metrics."""
    args = _parse_args()

    # 'points' is the folder where the transformed points were saved by transformix.
    args.exp_points_output = os.path.join(args.output_path, args.experiment_name, args.reg_params_key, 'points')

    # All transformed keypoints files produced by transformix for this experiment.
    transformed_points = get_points_paths(args.exp_points_output, "outputpoints", num_occurrences=2)
    if not transformed_points:
        logger.error(f"No transformed points found in {args.exp_points_output} directory.")
        sys.exit(1)

    if args.generate_report:
        # Ground-truth exhale landmark files matching each transformed sample.
        gt_points = get_points_paths(args.dataset_path, "_300_eBH_xyz_r1", num_occurrences=1)
        if not gt_points:
            logger.error(f"No gt points found in {args.dataset_path} directory.")
            sys.exit(1)

        # Accumulates one {'sample_name', 'TRE_mean', 'TRE_std'} dict per sample.
        tre_results = []

        # description.json is loop-invariant: load it once instead of per sample.
        dictionary = _load_dataset_description(args.dataset_path)
        # Last path segment of dataset_path selects the split section, e.g. 'train'.
        split_name = args.dataset_path.replace('\\', '/').split("/")[-1]
    else:
        # Placeholder values so the zip() below still yields every transformed file.
        gt_points = [0] * len(transformed_points)

    # Subject id = name of the directory containing each outputpoints file.
    subjects = [os.path.basename(os.path.dirname(subject)) for subject in transformed_points]
    logger.info(f"Found {len(transformed_points)} transformed points files for subjects ({subjects})")

    # Extract the transformed points from each transformix file and save them separately.
    for transformed_points_file, gt_point in zip(transformed_points, gt_points):
        print(f"Processing {transformed_points_file}...")

        # Fixed-image coordinates of the transformed landmarks.
        transformed_landmarks = get_landmarks_from_txt(transformed_points_file, search_key='OutputIndexFixed')

        # Validate explicitly instead of `assert` (asserts are stripped under `python -O`).
        if len(transformed_landmarks) != EXPECTED_NUM_LANDMARKS:
            logger.error(f"Transformed points file {transformed_points_file} has {len(transformed_landmarks)} points instead of {EXPECTED_NUM_LANDMARKS}.")
            sys.exit(1)

        # Write the cleaned landmark list into the same directory as the raw
        # transformix output (dirname is safer than replacing 'outputpoints.txt'
        # as a substring, which could corrupt a path containing it elsewhere).
        output_landmarks_path = os.path.join(os.path.dirname(transformed_points_file), 'outputpoints_transformed.txt')
        write_landmarks_to_list(transformed_landmarks, output_landmarks_path)

        # Per-sample TRE only when the ground-truth exhale files are available.
        if args.generate_report:
            # Sample id, e.g. 'copd1', from the gt filename prefix (basename is
            # portable vs. splitting on '/').
            sample_name = os.path.basename(gt_point).split('_')[0]

            file_information = dictionary[split_name][sample_name]
            print(file_information)

            TRE_mean, TRE_std = compute_TRE(output_landmarks_path, gt_point, tuple(file_information['voxel_dim']))
            print("TRE (After Registration):- ", f"(Mean TRE: {TRE_mean})", f"(STD TRE: {TRE_std}). \n")

            tre_results.append({'sample_name': sample_name, 'TRE_mean': TRE_mean, 'TRE_std': TRE_std})

    # Write the CSV evaluation report once all samples are processed.
    if args.generate_report:
        _write_tre_reports(args.exp_points_output, tre_results)


if __name__ == "__main__":
    main()