--- a +++ b/evaluate_transformation.py @@ -0,0 +1,114 @@ +import sys +import argparse +import os +import json +import csv +import numpy as np + +from utils.filemanager import get_points_paths +from utils.logger import logger, pprint +from utils.landmarks import get_landmarks_from_txt, write_landmarks_to_list +from utils.metrics import compute_TRE + +if __name__ == "__main__": + # optional arguments from the command line + parser = argparse.ArgumentParser() + + parser.add_argument('--experiment_name', type=str, default='elastix_01', help='experiment name') + parser.add_argument('--reg_params_key', type=str, default='Parameter.affine+Parameter.bsplines', help='registration parameters key generated by create_script.py') + parser.add_argument('--output_path', type=str, default='output', help='root dir for output scripts') + parser.add_argument("--generate_report", action='store_true', help='if True, an evaluation report .txt file will be generated. If not, only the transformed keypoints txt file will be generated for each test sample.') + parser.add_argument('--dataset_path', type=str, default='dataset/train', help='root dir for nifti data to get the gt exhale landmarks') + + # parse the arguments + args = parser.parse_args() + + # create experiment search path + # points is the folder where the transformed points are saved using transformix + args.exp_points_output = os.path.join(args.output_path, args.experiment_name, args.reg_params_key, 'points') + + # get a list of all the transformed keypoints files + transformed_points = get_points_paths(args.exp_points_output, "outputpoints", num_occurrences=2) + + if len(transformed_points) == 0: + logger.error(f"No transformed points found in {args.exp_points_output} directory.") + sys.exit(1) + + # check if generate_report is True + if args.generate_report: + gt_points = get_points_paths(args.dataset_path, "_300_eBH_xyz_r1", num_occurrences=1) + + if len(gt_points) == 0: + logger.error(f"No gt points found in {args.dataset_path} directory.") + sys.exit(1) + + # Create a list to store the TRE results + tre_results = [] + + else: + gt_points = [0 for _ in range(len(transformed_points))] # the list has to have values for the zip(*) to return the values inside + + + logger.info(f"Found {len(transformed_points)} transformed points files for subjects ({[subject.split('/')[-2] for subject in transformed_points]})") + + # extract the transformed points from the transformed_points transformix files and save them in a separate file + for transformed_points_file, gt_point in zip(transformed_points, gt_points): + print(f"Processing {transformed_points_file}...") + + # get the transformed points + transformed_landmarks = get_landmarks_from_txt(transformed_points_file, search_key='OutputIndexFixed') + + # the transformed points has to be 300 + assert len(transformed_landmarks) == 300, f"Transformed points file {transformed_points_file} has {len(transformed_landmarks)} points instead of 300." + + # write the transformed points to a file + # the points are written inside the same directory as the transformed_points_file + output_landmarks_path = os.path.join(transformed_points_file.replace('outputpoints.txt', ''), 'outputpoints_transformed.txt') + write_landmarks_to_list(transformed_landmarks, output_landmarks_path) + + # generate the evaluation report if args.generate_report is True, this is when we have the ground truth exhale files + if args.generate_report: + sample_name = gt_point.split('/')[-1].split('_')[0] #copd1, copd2, ... + + # load the dataset dictionary, we remove the last path element because we want to get the description.json file + with open(os.path.join(args.dataset_path.replace("train", "", 1).replace("test", "", 1),'description.json'), 'r') as json_file: + dictionary = json.loads(json_file.read()) + file_information = dictionary[args.dataset_path.replace('\\', '/').split("/")[-1]][sample_name] + print(file_information) + + TRE_mean, TRE_std = compute_TRE(output_landmarks_path, gt_point, tuple(file_information['voxel_dim'])) + print("TRE (After Registration):- ", f"(Mean TRE: {TRE_mean})", f"(STD TRE: {TRE_std}). \n") + + # Append TRE results to the list + tre_results.append({'sample_name': sample_name, 'TRE_mean': TRE_mean, 'TRE_std': TRE_std}) + + # generate the evaluation report if args.generate_report is True, this is when we have the ground truth exhale files + if args.generate_report: + # write the TRE results to a csv file for each sample + output_csv_path = os.path.join(args.exp_points_output, 'TRE_sample_results.csv') + with open(output_csv_path, 'w', newline='') as csv_file: + fieldnames = ['sample_name', 'TRE_mean', 'TRE_std'] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + + # Write the header + writer.writeheader() + + # Write the data + for result in tre_results: + writer.writerow(result) + + # write the overall mean results + TRE_mean_list = [result['TRE_mean'] for result in tre_results] + TRE_std_list = [result['TRE_std'] for result in tre_results] + + output_csv_path = os.path.join(args.exp_points_output, 'TRE_overall_results.csv') + + with open(output_csv_path, 'w', newline='') as csv_file: + fieldnames = ['Overall mean (TRE_mean)', 'Overall mean (TRE_std)'] + writer = csv.DictWriter(csv_file, fieldnames=fieldnames) + + # Write the header + writer.writeheader() + + # Write the data + writer.writerow({'Overall mean (TRE_mean)': np.mean(TRE_mean_list), 'Overall mean (TRE_std)': np.mean(TRE_std_list)})