--- /dev/null
+++ b/evaluate_transformation.py
@@ -0,0 +1,114 @@
+import sys
+import argparse
+import os
+import json
+import csv
+import numpy as np
+
+from utils.filemanager import get_points_paths
+from utils.logger import logger, pprint
+from utils.landmarks import get_landmarks_from_txt, write_landmarks_to_list
+from utils.metrics import compute_TRE
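+
+# NOTE: the helper signatures below are inferred from their usage in this script:
+#   get_points_paths(root, pattern, num_occurrences) -> list of matching file paths
+#   get_landmarks_from_txt(path, search_key)         -> list of landmark coordinates
+#   write_landmarks_to_list(landmarks, out_path)     -> writes one landmark per line
+#   compute_TRE(pred_path, gt_path, voxel_dim)       -> (mean, std), presumably in mm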
+
+if __name__ == "__main__":
+    # optional command-line arguments
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--experiment_name', type=str, default='elastix_01', help='experiment name')
+    parser.add_argument('--reg_params_key', type=str, default='Parameter.affine+Parameter.bsplines', help='registration parameters key generated by create_script.py')
+    parser.add_argument('--output_path', type=str, default='output', help='root dir for output scripts')
+    parser.add_argument("--generate_report", action='store_true', help='if True, an evaluation report .txt file will be generated. If not, only the transformed keypoints txt file will be generated for each test sample.')
+    parser.add_argument('--dataset_path', type=str, default='dataset/train', help='root dir for nifti data to get the gt exhale landmarks')
+
+    # parse the arguments
+    args = parser.parse_args()
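+
+    # example invocation (paths/keys shown are the argparse defaults; --generate_report is optional):
+    #   python evaluate_transformation.py --experiment_name elastix_01 \
+    #       --reg_params_key Parameter.affine+Parameter.bsplines --generate_report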
+
+    # build the experiment search path;
+    # 'points' is the folder where transformix saved the transformed points
+    args.exp_points_output = os.path.join(args.output_path, args.experiment_name, args.reg_params_key, 'points')
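+    # e.g. with the defaults: output/elastix_01/Parameter.affine+Parameter.bsplines/points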
+    
+    # get a list of all the transformed keypoints files
+    transformed_points = get_points_paths(args.exp_points_output, "outputpoints", num_occurrences=2)
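+    # num_occurrences=2 presumably requires the "outputpoints" pattern to appear twice
+    # in a path, e.g. .../<subject>/outputpoints/outputpoints.txt (an assumption)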
+
+    if len(transformed_points) == 0:
+        logger.error(f"No transformed points found in {args.exp_points_output} directory.")
+        sys.exit(1)
+
+    # collect the ground-truth landmark files only when a report was requested
+    if args.generate_report:
+        gt_points = get_points_paths(args.dataset_path, "_300_eBH_xyz_r1", num_occurrences=1)
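+        # ground-truth files are assumed to follow the DIR-Lab COPDgene naming, e.g.
+        # copd1_300_eBH_xyz_r1.txt (300 exhale breath-hold landmarks, xyz order)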
+
+        if len(gt_points) == 0:
+            logger.error(f"No gt points found in {args.dataset_path} directory.")
+            sys.exit(1)
+
+        # Create a list to store the TRE results
+        tre_results = []
+
+    else:
+        gt_points = [None] * len(transformed_points)  # placeholders so zip() below still yields a pair for every transformed file
+
+
+    logger.info(f"Found {len(transformed_points)} transformed points files for subjects ({[subject.split('/')[-2] for subject in transformed_points]})")
+
+    # extract the transformed points from the transformed_points transformix files and save them in a separate file
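+    # NOTE: zip() pairs the two lists by position, which assumes get_points_paths
+    # returns both lists sorted in the same subject order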
+    for transformed_points_file, gt_point in zip(transformed_points, gt_points):
+        print(f"Processing {transformed_points_file}...")
+
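+        # transformix writes one line per point to outputpoints.txt, roughly:
+        #   Point 0 ; InputIndex = [ ... ] ; InputPoint = [ ... ] ; OutputIndexFixed = [ ... ] ; OutputPoint = [ ... ] ; Deformation = [ ... ]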
+        # get the transformed points
+        transformed_landmarks = get_landmarks_from_txt(transformed_points_file, search_key='OutputIndexFixed')
+
+        # each test sample must have exactly 300 transformed landmarks
+        assert len(transformed_landmarks) == 300, f"Transformed points file {transformed_points_file} has {len(transformed_landmarks)} points instead of 300."
+        
+        # write the transformed points to a file in the same directory as transformed_points_file
+        output_landmarks_path = os.path.join(os.path.dirname(transformed_points_file), 'outputpoints_transformed.txt')
+        write_landmarks_to_list(transformed_landmarks, output_landmarks_path)
+
+        # when ground-truth exhale landmarks are available, compute the TRE for this sample
+        if args.generate_report:
+            sample_name = os.path.basename(gt_point).split('_')[0]  # copd1, copd2, ...
+
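+            # description.json is assumed to map split -> sample -> metadata, e.g.:
+            #   {"train": {"copd1": {"voxel_dim": [x, y, z], ...}, ...}}
+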
+            # load the dataset dictionary; description.json sits one level above the train/test split dir
+            with open(os.path.join(os.path.dirname(args.dataset_path.rstrip('/\\')), 'description.json'), 'r') as json_file:
+                dictionary = json.load(json_file)
+            split_name = os.path.basename(args.dataset_path.replace('\\', '/').rstrip('/'))  # 'train' or 'test'
+            file_information = dictionary[split_name][sample_name]
+            pprint(file_information)
+
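+            # compute_TRE presumably implements per-landmark Euclidean distance in mm,
+            # along the lines of this sketch:
+            #   diff_mm = (np.asarray(pred) - np.asarray(gt)) * voxel_dim
+            #   tre = np.linalg.norm(diff_mm, axis=1)
+            #   TRE_mean, TRE_std = tre.mean(), tre.std()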
+            TRE_mean, TRE_std = compute_TRE(output_landmarks_path, gt_point, tuple(file_information['voxel_dim']))
+            print("TRE (After Registration):- ", f"(Mean TRE: {TRE_mean})", f"(STD TRE: {TRE_std}). \n")
+
+            # Append TRE results to the list
+            tre_results.append({'sample_name': sample_name, 'TRE_mean': TRE_mean, 'TRE_std': TRE_std})
+
+    # once all samples are processed, write the evaluation report (requires ground truth)
+    if args.generate_report:
+        # write the per-sample TRE results to a CSV file
+        output_csv_path = os.path.join(args.exp_points_output, 'TRE_sample_results.csv')
+        with open(output_csv_path, 'w', newline='') as csv_file:
+            fieldnames = ['sample_name', 'TRE_mean', 'TRE_std']
+            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
+
+            # Write the header
+            writer.writeheader()
+
+            # Write the data rows, one per sample
+            writer.writerows(tre_results)
+
+        # write the overall mean results
+        TRE_mean_list = [result['TRE_mean'] for result in tre_results]
+        TRE_std_list = [result['TRE_std'] for result in tre_results]
+
+        output_csv_path = os.path.join(args.exp_points_output, 'TRE_overall_results.csv')
+
+        with open(output_csv_path, 'w', newline='') as csv_file:
+            fieldnames = ['Overall mean (TRE_mean)', 'Overall mean (TRE_std)']
+            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
+
+            # Write the header
+            writer.writeheader()
+
+            # Write the data
+            writer.writerow({'Overall mean (TRE_mean)': np.mean(TRE_mean_list), 'Overall mean (TRE_std)': np.mean(TRE_std_list)})
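+
+        # TRE_overall_results.csv ends up with a single data row: the mean of the
+        # per-sample TRE means and the mean of the per-sample TRE stds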