import sys
import argparse
import os
import json
import csv
import numpy as np
from utils.filemanager import get_points_paths
from utils.logger import logger, pprint
from utils.landmarks import get_landmarks_from_txt, write_landmarks_to_list
from utils.metrics import compute_TRE
# Script: extract transformix-transformed landmarks for every subject of an
# experiment, write them to ``outputpoints_transformed.txt`` next to each
# transformix output, and (optionally, with --generate_report) compute the
# TRE against the ground-truth exhale landmarks, writing per-sample and
# overall CSV reports into the experiment's ``points`` directory.

# Number of landmarks every transformed points file is required to contain.
EXPECTED_NUM_LANDMARKS = 300


def get_dataset_root(dataset_path):
    """Return the parent directory of the dataset split directory.

    ``description.json`` lives one level above the split directory
    (e.g. ``dataset/train`` -> ``dataset``). Backslashes are treated as
    path separators, and a trailing separator is ignored.

    Args:
        dataset_path: Path to the split directory (e.g. ``dataset/train``).

    Returns:
        The path with its last component removed ("" for a bare name).
    """
    return os.path.dirname(os.path.normpath(dataset_path.replace('\\', '/')))


def get_split_name(dataset_path):
    """Return the last component of *dataset_path* (e.g. ``train`` / ``test``).

    This is used as the top-level key into ``description.json``. Backslashes
    are treated as separators, and a trailing separator is ignored so that
    ``dataset/train/`` still yields ``train``.
    """
    return os.path.basename(os.path.normpath(dataset_path.replace('\\', '/')))


def get_sample_name(gt_points_path):
    """Return the sample name prefix of a GT landmarks file.

    For example, ``.../copd1_300_eBH_xyz_r1.txt`` yields ``copd1``; the name
    is the file name up to its first underscore.
    """
    return os.path.basename(gt_points_path.replace('\\', '/')).split('_')[0]


def get_subject_name(points_path):
    """Return the name of the directory that directly contains *points_path*."""
    normalized = os.path.normpath(points_path.replace('\\', '/'))
    return os.path.basename(os.path.dirname(normalized))


if __name__ == "__main__":
    # optional arguments from the command line
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name', type=str, default='elastix_01', help='experiment name')
    parser.add_argument('--reg_params_key', type=str, default='Parameter.affine+Parameter.bsplines', help='registration parameters key generated by create_script.py')
    parser.add_argument('--output_path', type=str, default='output', help='root dir for output scripts')
    parser.add_argument("--generate_report", action='store_true', help='if True, an evaluation report .txt file will be generated. If not, only the transformed keypoints txt file will be generated for each test sample.')
    parser.add_argument('--dataset_path', type=str, default='dataset/train', help='root dir for nifti data to get the gt exhale landmarks')
    args = parser.parse_args()

    # 'points' is the folder where transformix saved the transformed points
    args.exp_points_output = os.path.join(args.output_path, args.experiment_name, args.reg_params_key, 'points')

    # list of all the transformix output points files of this experiment
    transformed_points = get_points_paths(args.exp_points_output, "outputpoints", num_occurrences=2)
    if not transformed_points:
        logger.error(f"No transformed points found in {args.exp_points_output} directory.")
        sys.exit(1)

    if args.generate_report:
        gt_points = get_points_paths(args.dataset_path, "_300_eBH_xyz_r1", num_occurrences=1)
        if not gt_points:
            logger.error(f"No gt points found in {args.dataset_path} directory.")
            sys.exit(1)
        # Transformed and GT files are paired by position below (zip), which
        # also silently drops the surplus of the longer list; make that visible.
        if len(gt_points) != len(transformed_points):
            logger.warning(
                f"Found {len(transformed_points)} transformed points files but {len(gt_points)} gt points files; "
                "they are paired by position, so some samples may be skipped or mismatched."
            )
        # description.json sits next to the split directory; load it once
        # instead of re-reading it for every sample.
        description_path = os.path.join(get_dataset_root(args.dataset_path), 'description.json')
        with open(description_path, 'r', encoding='utf-8') as json_file:
            dataset_description = json.load(json_file)
        split_information = dataset_description[get_split_name(args.dataset_path)]
        # per-sample TRE results, written to CSV after the loop
        tre_results = []
    else:
        # No ground truth: one placeholder per transformed file so that the
        # zip() below still visits every transformed file.
        gt_points = [None] * len(transformed_points)

    logger.info(f"Found {len(transformed_points)} transformed points files for subjects ({[get_subject_name(path) for path in transformed_points]})")

    for transformed_points_file, gt_point in zip(transformed_points, gt_points):
        print(f"Processing {transformed_points_file}...")
        transformed_landmarks = get_landmarks_from_txt(transformed_points_file, search_key='OutputIndexFixed')
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(transformed_landmarks) != EXPECTED_NUM_LANDMARKS:
            logger.error(f"Transformed points file {transformed_points_file} has {len(transformed_landmarks)} points instead of {EXPECTED_NUM_LANDMARKS}.")
            sys.exit(1)

        # write the extracted points into the same directory as the transformix output file
        output_landmarks_path = os.path.join(os.path.dirname(transformed_points_file), 'outputpoints_transformed.txt')
        write_landmarks_to_list(transformed_landmarks, output_landmarks_path)

        # evaluation only when the ground-truth exhale files are available
        if args.generate_report:
            sample_name = get_sample_name(gt_point)  # copd1, copd2, ...
            file_information = split_information[sample_name]
            print(file_information)
            TRE_mean, TRE_std = compute_TRE(output_landmarks_path, gt_point, tuple(file_information['voxel_dim']))
            print("TRE (After Registration):- ", f"(Mean TRE: {TRE_mean})", f"(STD TRE: {TRE_std}). \n")
            tre_results.append({'sample_name': sample_name, 'TRE_mean': TRE_mean, 'TRE_std': TRE_std})

    if args.generate_report:
        # per-sample TRE results
        sample_csv_path = os.path.join(args.exp_points_output, 'TRE_sample_results.csv')
        with open(sample_csv_path, 'w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=['sample_name', 'TRE_mean', 'TRE_std'])
            writer.writeheader()
            writer.writerows(tre_results)

        # overall means across samples (mean of the per-sample means and of the per-sample stds)
        overall_csv_path = os.path.join(args.exp_points_output, 'TRE_overall_results.csv')
        with open(overall_csv_path, 'w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=['Overall mean (TRE_mean)', 'Overall mean (TRE_std)'])
            writer.writeheader()
            writer.writerow({
                'Overall mean (TRE_mean)': np.mean([result['TRE_mean'] for result in tre_results]),
                'Overall mean (TRE_std)': np.mean([result['TRE_std'] for result in tre_results]),
            })