|
a |
|
b/ehragent/question_difficulty.py |
|
|
1 |
import json |
|
|
2 |
import re |
|
|
3 |
import matplotlib.pyplot as plt |
|
|
4 |
import collections |
|
|
5 |
import matplotlib |
|
|
6 |
import seaborn as sns |
|
|
7 |
from collections import defaultdict |
|
|
8 |
import os |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
# Headless rendering: select the non-interactive Agg backend before any
# figure is created, so the script works without a display server.
matplotlib.use("Agg")

# Times New Roman everywhere, and TrueType (type 42) font embedding so the
# text in exported PDF/PS figures remains selectable/editable.
matplotlib.rcParams.update({
    'font.family': 'Times New Roman',
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

# Shared seaborn theme for all plots: tick style, dotted grid, enlarged fonts.
sns.set_theme(style="ticks", font="Times New Roman", font_scale=2.1,
              rc={'grid.linestyle': ':', 'axes.grid': True})
|
|
17 |
|
|
|
18 |
for dataset in ["mimic_iii", "eicu"]:
    # Read the raw question/SQL records, annotate each with three difficulty
    # proxies (template-variable count, table count, column count), write the
    # enriched records back out as JSONL, bucket question ids per metric
    # value, and plot each metric's frequency distribution.
    # NOTE(review): the two paths below are placeholders; presumably they
    # should interpolate {dataset} -- confirm before running.
    with open(f"<YOUR_DATASET_PATH>", 'r') as f_in, \
         open(f"<YOUR_SAVED_DATASET_PATH>", "w") as f_out:
        list_num_q_tag_var = []
        list_num_tables = []
        list_num_columns = []
        num_q_tag_var_dict = defaultdict(list)
        num_tables_dict = defaultdict(list)
        num_columns_dict = defaultdict(list)
        for line in f_in:
            x = json.loads(line)

            # Number of variables in q_tag: every "{...}" or "[...]" slot in
            # the question template counts as one variable.
            num_q_tag_var = x["q_tag"].count("{") + x["q_tag"].count("[")

            # Number of different tables used in the sql.
            # (We only count the tables in the original datasets, but not the
            # new ones created in the sql.)
            # NOTE(review): matches lowercase "from" only and does not look at
            # JOIN clauses -- fine only if the queries are normalized; verify.
            tables = re.findall(r'\bfrom\s+(\w+)\b', x["query"])
            num_tables = len(set(tables))

            # Number of different columns used in the sql, counted via
            # "alias.column" tokens; tokens starting with t<digit> (aliases
            # of derived tables, e.g. t1.c1) are excluded so only columns of
            # the original datasets remain.
            columns = re.findall(r'\b\w*\.\w*\b', x["query"])
            columns = [item for item in columns if not re.match(r'^t\d', item)]
            num_columns = len(set(columns))

            x["num_q_tag_var"] = num_q_tag_var
            x["num_tables"] = num_tables
            x["num_columns"] = num_columns

            f_out.write(json.dumps(x) + "\n")

            list_num_q_tag_var.append(num_q_tag_var)
            list_num_tables.append(num_tables)
            list_num_columns.append(num_columns)

            num_q_tag_var_dict[num_q_tag_var].append(x["id"])
            num_tables_dict[num_tables].append(x["id"])
            num_columns_dict[num_columns].append(x["id"])

        # Persist the id buckets: one file per metric value, each holding the
        # list of question ids that have that value.
        # NOTE(review): each file contains a single JSON array despite the
        # .jsonl suffix -- kept as-is in case downstream tooling expects it.
        os.makedirs(f"{dataset}/num_q_tag_var/", exist_ok=True)
        for k, id_list in num_q_tag_var_dict.items():
            with open(f"{dataset}/num_q_tag_var/{k}.jsonl", "w") as f:
                json.dump(id_list, f)
        os.makedirs(f"{dataset}/num_tables/", exist_ok=True)
        for k, id_list in num_tables_dict.items():
            with open(f"{dataset}/num_tables/{k}.jsonl", "w") as f:
                json.dump(id_list, f)
        os.makedirs(f"{dataset}/num_columns/", exist_ok=True)
        for k, id_list in num_columns_dict.items():
            with open(f"{dataset}/num_columns/{k}.jsonl", "w") as f:
                json.dump(id_list, f)

        # Plot the distribution of each metric as a bar chart.
        all_list = [list_num_q_tag_var, list_num_tables, list_num_columns]
        xlabel = ["# q_tag variables", "# tables", "# columns"]
        # NOTE(review): "num_q_tag_ver" looks like a typo for "num_q_tag_var";
        # kept unchanged in case downstream consumers expect this filename.
        file_name = ["num_q_tag_ver_distri.pdf", "num_tables_distri.pdf", "num_columns_distri.pdf"]

        os.makedirs(f"{dataset}/figures/", exist_ok=True)
        for i in range(len(all_list)):
            c = collections.Counter(all_list[i])
            c = sorted(c.items())
            plt.figure(figsize=(6, 5.5), dpi=120)
            sns.barplot(x=[pair[0] for pair in c], y=[pair[1] for pair in c], color=sns.color_palette()[0])
            plt.ylabel("Frequency", size=30)
            plt.xlabel(xlabel[i], size=30)
            plt.tight_layout(rect=[-0.05, -0.05, 1.05, 1.05])
            plt.savefig(f"{dataset}/figures/{file_name[i]}")
            # BUG FIX: close the figure after saving. Previously all six
            # figures (3 metrics x 2 datasets) stayed open, leaking memory
            # and eventually triggering matplotlib's open-figure warning.
            plt.close()