--- a +++ b/ehragent/question_difficulty.py @@ -0,0 +1,86 @@ +import json +import re +import matplotlib.pyplot as plt +import collections +import matplotlib +import seaborn as sns +from collections import defaultdict +import os + + +matplotlib.use("Agg") +matplotlib.rcParams.update({'font.family': 'Times New Roman'}) +matplotlib.rcParams['pdf.fonttype'] = 42 +matplotlib.rcParams['ps.fonttype'] = 42 + +sns.set_theme(style="ticks", font="Times New Roman", font_scale=2.1, rc={'grid.linestyle': ':', 'axes.grid': True}) + +for dataset in ["mimic_iii", "eicu"]: + with open(f"<YOUR_DATASET_PATH>", 'r') as f_in, \ + open(f"<YOUR_SAVED_DATASET_PATH>", "w") as f_out: + list_num_q_tag_var = [] + list_num_tables = [] + list_num_columns = [] + num_q_tag_var_dict = defaultdict(list) + num_tables_dict = defaultdict(list) + num_columns_dict = defaultdict(list) + for lines in f_in: + x = json.loads(lines) + + # number of variables in q_tag + num_q_tag_var = x["q_tag"].count("{") + x["q_tag"].count("[") + + # number of different tables used in the sql + # (we only count the tables in the original datasets, but not the new ones created in the sql) + tables = re.findall(r'\bfrom\s+(\w+)\b', x["query"]) + num_tables = len(set(tables)) + + # number of different columns used in the sql + # (we only count the columns in the original datasets, but not the new ones created in the sql) + columns = re.findall(r'\b\w*\.\w*\b', x["query"]) + columns = [item for item in columns if not re.match(r'^t\d', item)] + num_columns = len(set(columns)) + + x["num_q_tag_var"] = num_q_tag_var + x["num_tables"] = num_tables + x["num_columns"] = num_columns + + f_out.write(json.dumps(x) + "\n") + + list_num_q_tag_var.append(num_q_tag_var) + list_num_tables.append(num_tables) + list_num_columns.append(num_columns) + + num_q_tag_var_dict[num_q_tag_var].append(x["id"]) + num_tables_dict[num_tables].append(x["id"]) + num_columns_dict[num_columns].append(x["id"]) + + os.makedirs(f"{dataset}/num_q_tag_var/", exist_ok=True) + for k, id_list in num_q_tag_var_dict.items(): + with open(f"{dataset}/num_q_tag_var/{k}.jsonl", "w") as f: + json.dump(id_list, f) + os.makedirs(f"{dataset}/num_tables/", exist_ok=True) + for k, id_list in num_tables_dict.items(): + with open(f"{dataset}/num_tables/{k}.jsonl", "w") as f: + json.dump(id_list, f) + os.makedirs(f"{dataset}/num_columns/", exist_ok=True) + for k, id_list in num_columns_dict.items(): + with open(f"{dataset}/num_columns/{k}.jsonl", "w") as f: + json.dump(id_list, f) + + + # plot the distribution + all_list = [list_num_q_tag_var, list_num_tables, list_num_columns] + xlabel = ["# q_tag variables", "# tables", "# columns"] + file_name = ["num_q_tag_ver_distri.pdf", "num_tables_distri.pdf", "num_columns_distri.pdf"] + + os.makedirs(f"{dataset}/figures/", exist_ok=True) + for i in range(len(all_list)): + c = collections.Counter(all_list[i]) + c = sorted(c.items()) + plt.figure(figsize=(6, 5.5), dpi=120) + ax = sns.barplot(x=[i[0] for i in c], y=[i[1] for i in c], color=sns.color_palette()[0]) + plt.ylabel("Frequency", size=30) + plt.xlabel(xlabel[i], size=30) + plt.tight_layout(rect=[-0.05, -0.05, 1.05, 1.05]) + plt.savefig(f"{dataset}/figures/{file_name[i]}")