[dec218]: / dataset_builder / generate_answer.py

Download this file

161 lines (125 with data), 5.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import re
import json
import pandas as pd
import logging
from tqdm import tqdm
def load_database(db_file_path):
    """Open the SQLite database at ``db_file_path`` and return a cursor on it.

    The underlying connection is not returned separately; it stays reachable
    (and alive) through the cursor's ``connection`` attribute.
    """
    import sqlite3

    return sqlite3.connect(db_file_path).cursor()
# Reference ranges substituted for ``<vital>_lower`` / ``<vital>_upper``
# placeholders in gold SQL. Keys are the placeholder prefix with underscores
# replaced by spaces (e.g. ``heart_rate_lower`` -> "heart rate").
DEFAULT_VITAL_RANGES = {
    "temperature": (35.5, 38.1),
    "sao2": (95.0, 100.0),
    "heart rate": (60.0, 100.0),
    "respiration": (12.0, 18.0),
    "systolic bp": (90.0, 120.0),
    "diastolic bp": (60.0, 90.0),
    "mean bp": (60.0, 110.0),
}

# A whole identifier token of the form ``<name>_lower`` or ``<name>_upper``.
_VITAL_BOUND_PATTERN = re.compile(r"(?<![A-Za-z0-9_])([A-Za-z0-9_]+)_(lower|upper)(?![A-Za-z0-9_])")


def post_process_sql(
    query,
    current_time="2105-12-31 23:59:00",
    precomputed_dict=None,
):
    """Turn a templated gold SQL program into an executable SQLite query.

    Args:
        query: SQL text that may contain the ``current_time`` placeholder and
            ``<vital>_lower`` / ``<vital>_upper`` placeholders.
        current_time: Timestamp substituted (as a quoted literal) for every
            occurrence of ``current_time``.
        precomputed_dict: Mapping from vital name (underscores replaced by
            spaces) to a ``(lower, upper)`` pair. ``None`` means
            ``DEFAULT_VITAL_RANGES``.

    Returns:
        The rewritten query string.
    """
    vital_ranges = DEFAULT_VITAL_RANGES if precomputed_dict is None else precomputed_dict

    # Pin "now" to a fixed timestamp literal.
    query = query.replace("current_time", f"'{current_time}'")

    def _substitute_bound(match):
        # Replace a known vital bound with its numeric value; leave any other
        # *_lower/*_upper identifier (e.g. a real column) untouched.
        vital_name = match.group(1).replace("_", " ")
        if vital_name not in vital_ranges:
            return match.group(0)
        vital_range = vital_ranges[vital_name]
        return f"{vital_range[0]}" if match.group(2) == "lower" else f"{vital_range[1]}"

    # Every placeholder is handled, so queries mixing several vitals, or using
    # only one bound, no longer keep unresolved identifiers.
    query = _VITAL_BOUND_PATTERN.sub(_substitute_bound, query)

    # Normalize formatting artifacts. Order matters: "''" is collapsed before
    # typographic quotes are converted to ASCII.
    query = query.replace("''", "'").replace("< =", "<=")
    # strftime format fixes: %y -> %Y and %j -> %J.
    # NOTE(review): presumably the gold programs use %j where a Julian day
    # number (%J) is intended — confirm.
    query = query.replace("%y", "%Y").replace("%j", "%J")
    query = query.replace("\u2018", "'").replace("\u2019", "'")
    query = query.replace("\u201c", '"').replace("\u201d", '"')
    return query
def post_process_answer(answer, round_digit=6, sorted_answer=False):
    """Normalize a query result into a flat, comparable list.

    Args:
        answer: Result rows, either as returned by ``cursor.fetchall()`` (a
            list of 1-tuples) or an already-flat list of values.
        round_digit: Number of decimal places each float value is rounded to.
        sorted_answer: If True and the first value is a string, return the
            values sorted.

    Returns:
        A flat list of values (``answer`` itself when it is empty).

    Raises:
        TypeError: If ``answer`` is not a list.
        ValueError: If rows are tuples with more than one column.
    """
    # Explicit raises instead of `assert`, which `python -O` strips.
    if not isinstance(answer, list):
        raise TypeError(f"answer must be a list, got {type(answer).__name__}")
    if not answer:
        return answer

    # Unwrap single-column rows: [(v1,), (v2,)] -> [v1, v2].
    if isinstance(answer[0], tuple):
        if len(answer[0]) != 1:
            raise ValueError(f"expected single-column rows, got {len(answer[0])} columns")
        answer = [row[0] for row in answer]

    # Round each float on its own. A SQLite column can mix floats with ints
    # or NULL (None), so the first value's type does not describe the rest.
    answer = [round(value, round_digit) if isinstance(value, float) else value for value in answer]

    if sorted_answer and isinstance(answer[0], str):
        answer = sorted(answer)
    return answer
def main(args):
    """Execute each item's gold SQL program and write the answered dataset.

    Reads the JSON list at ``args.json_file_path`` and runs every item's
    ``_gold_program`` (after ``post_process_sql``) against the SQLite database
    at ``args.db_file_path``. The normalized result rows are stored under
    ``"answer"``. ``_gold_program`` and ``_answer`` are removed, and the
    result is written to ``args.output_path``.

    With ``args.debug`` set, each computed answer is also compared with the
    item's ``_answer``; mismatches are counted and printed.

    Items that raise are reported, trigger ``breakpoint()``, and are left out
    of the output.
    """
    answer_error_cnt = 0

    with open(args.json_file_path) as f:
        dataset = json.load(f)

    cur = load_database(db_file_path=args.db_file_path)

    new_dataset = []
    try:
        for data in tqdm(dataset):
            # Bound before the try so the error handler never reads an unbound
            # name or a stale id left over from the previous item.
            db_id = None
            try:
                db_id = data["db_id"]
                # Explicit check instead of `assert`, which `python -O` strips.
                if db_id != "mimic_iv_cxr":
                    raise ValueError("db_id should be mimic_iv_cxr.")
                gold_program = post_process_sql(data["_gold_program"])
                answer = post_process_answer(cur.execute(gold_program).fetchall())
                data.pop("_gold_program")
                new_data = {
                    **data,
                    "answer": answer,
                }

                # Debugging block: compare against the reference answer.
                if args.debug:
                    _answer = data.get("_answer")
                    try:
                        _answer = post_process_answer(_answer)
                        matched = answer == _answer
                    except Exception:
                        # A missing or malformed reference counts as a mismatch.
                        matched = False
                    if not matched:
                        answer_error_cnt += 1
                        print(f"answer_error_cnt: {answer_error_cnt}")
                        print(f"Answer mismatch at {db_id}: retrieved answer {answer}, gold answer {_answer}")

                # Drop `_answer` if present without requiring it to exist.
                new_data.pop("_answer", None)
                leftover_keys = [k for k in new_data if k.startswith("_")]
                if leftover_keys:
                    raise ValueError(f"unexpected private keys in output: {leftover_keys}")
                new_dataset.append(new_data)
            except Exception as e:
                print(f"Error processing data {db_id}: {e}")
                # Drops into the debugger on failure; run with
                # PYTHONBREAKPOINT=0 to skip failing items without stopping.
                breakpoint()
    finally:
        cur.connection.close()

    print(f"Total answer errors: {answer_error_cnt}")

    # Store new dataset
    with open(args.output_path, "w") as f:
        json.dump(new_dataset, f, indent=4, default=str)
    print(f"Saved new dataset to {args.output_path}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true")
    # NOTE(review): the three directory arguments below are required but never
    # read by main() in this file — presumably kept for a CLI shared with
    # other builder scripts; confirm before relaxing them.
    parser.add_argument("--mimic_iv_dir", type=str, required=True)
    parser.add_argument("--mimic_cxr_jpg_dir", type=str, required=True)
    parser.add_argument("--chest_imagenome_dir", type=str, required=True)
    parser.add_argument("--json_file_path", type=str, default="../dataset/mimic_iv_cxr/_test.json")
    parser.add_argument("--db_file_path", type=str, default="../database/mimic_iv_cxr/test/mimic_iv_cxr.db")
    parser.add_argument("--output_path", type=str, default="../dataset/mimic_iv_cxr/test.json")
    args = parser.parse_args()

    # Split name is the text of the output file name before its first ".",
    # e.g. ".../test.json" -> "test".
    args.split = os.path.basename(args.output_path).split(".")[0]

    # In debug mode read the "_debug" variant of the input file. Only the
    # trailing ".json" suffix is rewritten, so ".json" elsewhere in the path
    # (e.g. a directory name) is left intact.
    json_suffix = ".json"
    if args.debug and "_debug" not in args.json_file_path and args.json_file_path.endswith(json_suffix):
        args.json_file_path = args.json_file_path[: -len(json_suffix)] + "_debug.json"

    # The validation split is queried against the training database.
    if args.split == "valid":
        args.db_file_path = args.db_file_path.replace("/valid/", "/train/")

    print(args)
    main(args)