"""Execute gold SQL programs and attach their answers to a dataset.

Source path: dataset_builder/generate_answer.py

The script reads a JSON list of records, runs each record's
``_gold_program`` against a SQLite database, post-processes the fetched rows
and writes the records (without their ``_``-prefixed keys) plus an
``answer`` field to a new JSON file.
"""
import os
import re
import json
import pandas as pd
import logging

try:
    from tqdm import tqdm
except ImportError:  # the progress bar is cosmetic; iterate plainly without it

    def tqdm(iterable, *args, **kwargs):
        """Fallback used when tqdm is not installed: return *iterable* as is."""
        return iterable


# Default (lower, upper) bounds substituted for ``<vital>_lower`` /
# ``<vital>_upper`` placeholders. Treated as read-only.
_DEFAULT_VITAL_RANGES = {
    "temperature": (35.5, 38.1),
    "sao2": (95.0, 100.0),
    "heart rate": (60.0, 100.0),
    "respiration": (12.0, 18.0),
    "systolic bp": (90.0, 120.0),
    "diastolic bp": (60.0, 90.0),
    "mean bp": (60.0, 110.0),
}

# A placeholder is an identifier ending in ``_lower`` / ``_upper`` that is
# preceded by at least one space or newline.
_LOWER_PLACEHOLDER = re.compile(r"[ \n]+([a-zA-Z0-9_]+_lower)")
_UPPER_PLACEHOLDER = re.compile(r"[ \n]+([a-zA-Z0-9_]+_upper)")


def load_database(db_file_path):
    """Open the SQLite database at *db_file_path* and return a cursor on it.

    The underlying connection stays open; it is reachable (and closable) via
    the returned cursor's ``connection`` attribute.
    """
    import sqlite3

    conn = sqlite3.connect(db_file_path)
    cur = conn.cursor()
    return cur


def post_process_sql(
    query,
    current_time="2105-12-31 23:59:00",
    precomputed_dict=None,
):
    """Return *query* with placeholders resolved and common typos normalised.

    Args:
        query: SQL text to rewrite.
        current_time: Timestamp substituted (single-quoted) for every
            occurrence of the literal text ``current_time``.
        precomputed_dict: Mapping of vital-sign name (words separated by
            spaces) to a ``(lower, upper)`` pair. ``None`` selects the
            built-in table.

    Returns:
        The rewritten query string.
    """
    if precomputed_dict is None:
        precomputed_dict = _DEFAULT_VITAL_RANGES

    # handle current_time
    query = query.replace("current_time", f"'{current_time}'")

    # handle vital signs: only the first lower and first upper placeholder are
    # considered, and only when both refer to the same vital sign.
    lower_match = _LOWER_PLACEHOLDER.search(query)
    upper_match = _UPPER_PLACEHOLDER.search(query)
    if lower_match and upper_match:
        lower_expr = lower_match.group(1)
        upper_expr = upper_match.group(1)
        vital_names = {
            lower_expr[: -len("_lower")],
            upper_expr[: -len("_upper")],
        }
        if len(vital_names) == 1:
            vital_name = vital_names.pop().replace("_", " ")
            if vital_name in precomputed_dict:
                lower_bound, upper_bound = precomputed_dict[vital_name]
                query = query.replace(lower_expr, f"{lower_bound}")
                query = query.replace(upper_expr, f"{upper_bound}")

    # handle etc.: doubled quotes, split "<=", lowercase strftime codes and
    # typographic quotes.
    query = query.replace("''", "'").replace("< =", "<=")
    query = query.replace("%y", "%Y").replace("%j", "%J")
    query = query.replace("\u2018", "'").replace("\u2019", "'")
    query = query.replace("\u201c", '"').replace("\u201d", '"')

    return query


def post_process_answer(answer, round_digit=6, sorted_answer=False):
    """Normalise a fetched result list into a flat list of values.

    Args:
        answer: List of values or of 1-tuples (as returned by ``fetchall``
            for a single-column query).
        round_digit: Number of decimals floats are rounded to.
        sorted_answer: When true, string answers are returned sorted.

    Returns:
        The flattened list. The type of the first element decides the
        treatment: floats are rounded, strings are optionally sorted, anything
        else is returned untouched.

    Raises:
        AssertionError: If *answer* is not a list, or rows are tuples with
            more than one column.
    """
    assert isinstance(answer, list)

    if not answer:  # empty answer
        return answer

    # tuple data preprocessing
    if isinstance(answer[0], tuple):
        assert len(answer[0]) == 1
        answer = [row[0] for row in answer]

    first = answer[0]
    if isinstance(first, float):  # float-type answer
        answer = [round(value, round_digit) for value in answer]
    elif isinstance(first, str) and sorted_answer:  # string-type answer
        answer = sorted(answer)

    return answer


def main(args):
    """Build the answered dataset described by *args* and write it to disk.

    Reads ``args.json_file_path``, executes every record's ``_gold_program``
    on the database at ``args.db_file_path`` and writes the resulting records
    to ``args.output_path``. With ``args.debug`` set, each computed answer is
    also compared against the record's stored ``_answer``.

    A record that cannot be processed is reported and left out of the output;
    processing continues with the next record.
    """
    answer_error_cnt = 0
    failed_cnt = 0

    # Read json file
    with open(args.json_file_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Load database
    cur = load_database(db_file_path=args.db_file_path)

    new_dataset = []
    try:
        for data in tqdm(dataset):
            db_id = None  # keeps the error message usable if the key is missing
            try:
                db_id = data["db_id"]
                if db_id != "mimic_iv_cxr":
                    raise ValueError("db_id should be mimic_iv_cxr.")

                gold_program = post_process_sql(data["_gold_program"])
                answer = post_process_answer(cur.execute(gold_program).fetchall())

                new_data = {
                    key: value
                    for key, value in data.items()
                    if key != "_gold_program"
                }
                new_data["answer"] = answer

                # Debugging block
                if args.debug:
                    expected = None
                    try:
                        expected = data["_answer"]
                        expected = post_process_answer(expected)
                        matched = answer == expected
                    except Exception:
                        # missing or malformed stored answer counts as mismatch
                        matched = False
                    if not matched:
                        answer_error_cnt += 1
                        print(f"answer_error_cnt: {answer_error_cnt}")
                        print(f"Answer mismatch at {db_id}: retrieved answer {answer}, gold answer {expected}")

                    # NOTE(review): the original nesting of this pop was lost in
                    # extraction; it is assumed to belong to the debug branch.
                    new_data.pop("_answer", None)

                private_keys = [key for key in new_data if key.startswith("_")]
                if private_keys:
                    raise ValueError(f"unexpected private keys left in record: {private_keys}")
                new_dataset.append(new_data)

            except Exception as e:
                failed_cnt += 1
                print(f"Error processing data {db_id}: {e}")
    finally:
        cur.connection.close()

    print(f"Total answer errors: {answer_error_cnt}")
    if failed_cnt:
        print(f"Total records skipped due to errors: {failed_cnt}")

    # Store new dataset
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(new_dataset, f, indent=4, default=str)

    print(f"Saved new dataset to {args.output_path}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--mimic_iv_dir", type=str, required=True)
    parser.add_argument("--mimic_cxr_jpg_dir", type=str, required=True)
    parser.add_argument("--chest_imagenome_dir", type=str, required=True)
    parser.add_argument("--json_file_path", type=str, default="../dataset/mimic_iv_cxr/_test.json")
    parser.add_argument("--db_file_path", type=str, default="../database/mimic_iv_cxr/test/mimic_iv_cxr.db")
    parser.add_argument("--output_path", type=str, default="../dataset/mimic_iv_cxr/test.json")
    args = parser.parse_args()

    # split: the output file's base name up to its first dot
    args.split = os.path.basename(args.output_path).split(".")[0]

    # postprocess json_file_path: debug runs read the "*_debug.json" variant
    if args.debug and "_debug" not in args.json_file_path:
        args.json_file_path = args.json_file_path.replace(".json", "_debug.json")

    # postprocess db_file_path: the valid split uses the train database path
    if args.split == "valid":
        args.db_file_path = args.db_file_path.replace("/valid/", "/train/")

    print(args)
    main(args)