Switch to side-by-side view

--- a
+++ b/dataset_builder/generate_answer.py
@@ -0,0 +1,160 @@
+import os
+import re
+import json
+import pandas as pd
+import logging
+from tqdm import tqdm
+
+
+def load_database(db_file_path):
+    import sqlite3
+
+    conn = sqlite3.connect(db_file_path)
+    cur = conn.cursor()
+    return cur
+
+
+def post_process_sql(
+    query,
+    current_time="2105-12-31 23:59:00",
+    precomputed_dict={
+        "temperature": (35.5, 38.1),
+        "sao2": (95.0, 100.0),
+        "heart rate": (60.0, 100.0),
+        "respiration": (12.0, 18.0),
+        "systolic bp": (90.0, 120.0),
+        "diastolic bp": (60.0, 90.0),
+        "mean bp": (60.0, 110.0),
+    },
+):
+    # handle current_time
+    if "current_time" in query:
+        query = query.replace("current_time", f"'{current_time}'")
+
+    # handle vital signs
+    if re.search("[ \n]+([a-zA-Z0-9_]+_lower)", query) and re.search("[ \n]+([a-zA-Z0-9_]+_upper)", query):
+        vital_lower_expr = re.findall("[ \n]+([a-zA-Z0-9_]+_lower)", query)[0]
+        vital_upper_expr = re.findall("[ \n]+([a-zA-Z0-9_]+_upper)", query)[0]
+        vital_name_list = list(set(re.findall("([a-zA-Z0-9_]+)_lower", vital_lower_expr) + re.findall("([a-zA-Z0-9_]+)_upper", vital_upper_expr)))
+        if len(vital_name_list) == 1:
+            processed_vital_name = vital_name_list[0].replace("_", " ")
+            if processed_vital_name in precomputed_dict:
+                vital_range = precomputed_dict[processed_vital_name]
+                query = query.replace(vital_lower_expr, f"{vital_range[0]}").replace(vital_upper_expr, f"{vital_range[1]}")
+
+    # handle etc.
+    query = query.replace("''", "'").replace("< =", "<=")
+    query = query.replace("%y", "%Y").replace("%j", "%J")
+    query = query.replace("‘", "'").replace("’", "'")
+    query = query.replace("\u201c", '"').replace("\u201d", '"')
+
+    return query
+
+
+def post_process_answer(answer, round_digit=6, sorted_answer=False):
+    assert isinstance(answer, list)
+
+    if len(answer) == 0:  # empty answer
+        assert answer == []
+    else:
+        # tuple data preprocessing
+        if isinstance(answer[0], tuple):
+            assert len(answer[0]) == 1
+            answer = [ans[0] for ans in answer]
+
+        if isinstance(answer[0], float):  # float-type answer
+            answer = [round(ans, round_digit) for ans in answer]
+        elif isinstance(answer[0], str):  # string-type answer
+            if sorted_answer:
+                answer = sorted(answer)
+        else:
+            pass
+
+    return answer
+
+
+def main(args):
+    # Initialize error count
+    answer_error_cnt = 0
+
+    # Read json file
+    dataset = json.load(open(args.json_file_path))
+
+    # Load database
+    cur = load_database(db_file_path=args.db_file_path)
+
+    new_dataset = []
+    for data in tqdm(dataset):
+        try:
+            db_id = data["db_id"]
+            assert db_id == "mimic_iv_cxr", "db_id should be mimic_iv_cxr."
+
+            gold_program = data["_gold_program"]
+            gold_program = post_process_sql(gold_program)
+
+            res = cur.execute(gold_program)
+            answer = res.fetchall()
+            answer = post_process_answer(answer)
+
+            data.pop("_gold_program")
+            new_data = {
+                **data,
+                "answer": answer,
+            }
+
+            # Debugging block
+            if args.debug:
+                try:
+                    _answer = data["_answer"]
+                    _answer = post_process_answer(_answer)
+                    assert answer == _answer, f"answer: {answer}, _answer: {_answer}"
+                except:
+                    answer_error_cnt += 1
+                    print(f"answer_error_cnt: {answer_error_cnt}")
+                    print(f"Answer mismatch at {db_id}: retrieved answer {answer}, gold answer {_answer}")
+
+                new_data.pop("_answer")
+
+            assert not any([k.startswith("_") for k in new_data.keys()])
+            new_dataset.append(new_data)
+
+        except Exception as e:
+            print(f"Error processing data {db_id}: {e}")
+            breakpoint()
+            # continue
+
+    print(f"Total answer errors: {answer_error_cnt}")
+
+    # Store new dataset
+    with open(args.output_path, "w") as f:
+        json.dump(new_dataset, f, indent=4, default=str)
+
+    print(f"Saved new dataset to {args.output_path}")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--debug", action="store_true")
+    parser.add_argument("--mimic_iv_dir", type=str, required=True)
+    parser.add_argument("--mimic_cxr_jpg_dir", type=str, required=True)
+    parser.add_argument("--chest_imagenome_dir", type=str, required=True)
+    parser.add_argument("--json_file_path", type=str, default="../dataset/mimic_iv_cxr/_test.json")
+    parser.add_argument("--db_file_path", type=str, default="../database/mimic_iv_cxr/test/mimic_iv_cxr.db")
+    parser.add_argument("--output_path", type=str, default="../dataset/mimic_iv_cxr/test.json")
+    args = parser.parse_args()
+
+    # split
+    args.split = os.path.basename(args.output_path).split(".")[0]
+
+    # postprocess json_file_path
+    if args.debug and "_debug" not in args.json_file_path:
+        args.json_file_path = args.json_file_path.replace(".json", "_debug.json")
+
+    # postprocess db_file_path
+    if args.split == "valid":
+        args.db_file_path = args.db_file_path.replace("/valid/", "/train/")
+
+    print(args)
+    main(args)