|
a |
|
b/evaluate_transformation.py |
|
|
1 |
import sys |
|
|
2 |
import argparse |
|
|
3 |
import os |
|
|
4 |
import json |
|
|
5 |
import csv |
|
|
6 |
import numpy as np |
|
|
7 |
|
|
|
8 |
from utils.filemanager import get_points_paths |
|
|
9 |
from utils.logger import logger, pprint |
|
|
10 |
from utils.landmarks import get_landmarks_from_txt, write_landmarks_to_list |
|
|
11 |
from utils.metrics import compute_TRE |
|
|
12 |
|
|
|
13 |
if __name__ == "__main__": |
|
|
14 |
# optional arguments from the command line |
|
|
15 |
parser = argparse.ArgumentParser() |
|
|
16 |
|
|
|
17 |
parser.add_argument('--experiment_name', type=str, default='elastix_01', help='experiment name') |
|
|
18 |
parser.add_argument('--reg_params_key', type=str, default='Parameter.affine+Parameter.bsplines', help='registration parameters key generated by create_script.py') |
|
|
19 |
parser.add_argument('--output_path', type=str, default='output', help='root dir for output scripts') |
|
|
20 |
parser.add_argument("--generate_report", action='store_true', help='if True, an evaluation report .txt file will be generated. If not, only the transformed keypoints txt file will be generated for each test sample.') |
|
|
21 |
parser.add_argument('--dataset_path', type=str, default='dataset/train', help='root dir for nifti data to get the gt exhale landmarks') |
|
|
22 |
|
|
|
23 |
# parse the arguments |
|
|
24 |
args = parser.parse_args() |
|
|
25 |
|
|
|
26 |
# create experiment search path |
|
|
27 |
# points is the folder where the transformed points are saved using transformix |
|
|
28 |
args.exp_points_output = os.path.join(args.output_path, args.experiment_name, args.reg_params_key, 'points') |
|
|
29 |
|
|
|
30 |
# get a list of all the transformed keypoints files |
|
|
31 |
transformed_points = get_points_paths(args.exp_points_output, "outputpoints", num_occurrences=2) |
|
|
32 |
|
|
|
33 |
if len(transformed_points) == 0: |
|
|
34 |
logger.error(f"No transformed points found in {args.exp_points_output} directory.") |
|
|
35 |
sys.exit(1) |
|
|
36 |
|
|
|
37 |
# check if generate_report is True |
|
|
38 |
if args.generate_report: |
|
|
39 |
gt_points = get_points_paths(args.dataset_path, "_300_eBH_xyz_r1", num_occurrences=1) |
|
|
40 |
|
|
|
41 |
if len(gt_points) == 0: |
|
|
42 |
logger.error(f"No gt points found in {args.dataset_path} directory.") |
|
|
43 |
sys.exit(1) |
|
|
44 |
|
|
|
45 |
# Create a list to store the TRE results |
|
|
46 |
tre_results = [] |
|
|
47 |
|
|
|
48 |
else: |
|
|
49 |
gt_points = [0 for _ in range(len(transformed_points))] # the list has to have values for the zip(*) to return the values inside |
|
|
50 |
|
|
|
51 |
|
|
|
52 |
logger.info(f"Found {len(transformed_points)} transformed points files for subjects ({[subject.split('/')[-2] for subject in transformed_points]})") |
|
|
53 |
|
|
|
54 |
# extract the transformed points from the transformed_points transformix files and save them in a separate file |
|
|
55 |
for transformed_points_file, gt_point in zip(transformed_points, gt_points): |
|
|
56 |
print(f"Processing {transformed_points_file}...") |
|
|
57 |
|
|
|
58 |
# get the transformed points |
|
|
59 |
transformed_landmarks = get_landmarks_from_txt(transformed_points_file, search_key='OutputIndexFixed') |
|
|
60 |
|
|
|
61 |
# the transformed points has to be 300 |
|
|
62 |
assert len(transformed_landmarks) == 300, f"Transformed points file {transformed_points_file} has {len(transformed_landmarks)} points instead of 300." |
|
|
63 |
|
|
|
64 |
# write the transformed points to a file |
|
|
65 |
# the points are written inside the same directory as the transformed_points_file |
|
|
66 |
output_landmarks_path = os.path.join(transformed_points_file.replace('outputpoints.txt', ''), 'outputpoints_transformed.txt') |
|
|
67 |
write_landmarks_to_list(transformed_landmarks, output_landmarks_path) |
|
|
68 |
|
|
|
69 |
# generate the evaluation report if args.generate_report is True, this is when we have the ground truth exhale files |
|
|
70 |
if args.generate_report: |
|
|
71 |
sample_name = gt_point.split('/')[-1].split('_')[0] #copd1, copd2, ... |
|
|
72 |
|
|
|
73 |
# load the dataset dictionary, we remove the last path element because we want to get the description.json file |
|
|
74 |
with open(os.path.join(args.dataset_path.replace("train", "", 1).replace("test", "", 1),'description.json'), 'r') as json_file: |
|
|
75 |
dictionary = json.loads(json_file.read()) |
|
|
76 |
file_information = dictionary[args.dataset_path.replace('\\', '/').split("/")[-1]][sample_name] |
|
|
77 |
print(file_information) |
|
|
78 |
|
|
|
79 |
TRE_mean, TRE_std = compute_TRE(output_landmarks_path, gt_point, tuple(file_information['voxel_dim'])) |
|
|
80 |
print("TRE (After Registration):- ", f"(Mean TRE: {TRE_mean})", f"(STD TRE: {TRE_std}). \n") |
|
|
81 |
|
|
|
82 |
# Append TRE results to the list |
|
|
83 |
tre_results.append({'sample_name': sample_name, 'TRE_mean': TRE_mean, 'TRE_std': TRE_std}) |
|
|
84 |
|
|
|
85 |
# generate the evaluation report if args.generate_report is True, this is when we have the ground truth exhale files |
|
|
86 |
if args.generate_report: |
|
|
87 |
# write the TRE results to a csv file for each sample |
|
|
88 |
output_csv_path = os.path.join(args.exp_points_output, 'TRE_sample_results.csv') |
|
|
89 |
with open(output_csv_path, 'w', newline='') as csv_file: |
|
|
90 |
fieldnames = ['sample_name', 'TRE_mean', 'TRE_std'] |
|
|
91 |
writer = csv.DictWriter(csv_file, fieldnames=fieldnames) |
|
|
92 |
|
|
|
93 |
# Write the header |
|
|
94 |
writer.writeheader() |
|
|
95 |
|
|
|
96 |
# Write the data |
|
|
97 |
for result in tre_results: |
|
|
98 |
writer.writerow(result) |
|
|
99 |
|
|
|
100 |
# write the overall mean results |
|
|
101 |
TRE_mean_list = [result['TRE_mean'] for result in tre_results] |
|
|
102 |
TRE_std_list = [result['TRE_std'] for result in tre_results] |
|
|
103 |
|
|
|
104 |
output_csv_path = os.path.join(args.exp_points_output, 'TRE_overall_results.csv') |
|
|
105 |
|
|
|
106 |
with open(output_csv_path, 'w', newline='') as csv_file: |
|
|
107 |
fieldnames = ['Overall mean (TRE_mean)', 'Overall mean (TRE_std)'] |
|
|
108 |
writer = csv.DictWriter(csv_file, fieldnames=fieldnames) |
|
|
109 |
|
|
|
110 |
# Write the header |
|
|
111 |
writer.writeheader() |
|
|
112 |
|
|
|
113 |
# Write the data |
|
|
114 |
writer.writerow({'Overall mean (TRE_mean)': np.mean(TRE_mean_list), 'Overall mean (TRE_std)': np.mean(TRE_std_list)}) |