import json
import re
import matplotlib.pyplot as plt
import collections
import matplotlib
import seaborn as sns
from collections import defaultdict
import os
matplotlib.use("Agg")
matplotlib.rcParams.update({'font.family': 'Times New Roman'})
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
sns.set_theme(style="ticks", font="Times New Roman", font_scale=2.1, rc={'grid.linestyle': ':', 'axes.grid': True})
for dataset in ["mimic_iii", "eicu"]:
with open(f"<YOUR_DATASET_PATH>", 'r') as f_in, \
open(f"<YOUR_SAVED_DATASET_PATH>", "w") as f_out:
list_num_q_tag_var = []
list_num_tables = []
list_num_columns = []
num_q_tag_var_dict = defaultdict(list)
num_tables_dict = defaultdict(list)
num_columns_dict = defaultdict(list)
for lines in f_in:
x = json.loads(lines)
# number of variables in q_tag
num_q_tag_var = x["q_tag"].count("{") + x["q_tag"].count("[")
# number of different tables used in the sql
# (we only count the tables in the original datasets, but not the new ones created in the sql)
tables = re.findall(r'\bfrom\s+(\w+)\b', x["query"])
num_tables = len(set(tables))
# number of different columns used in the sql
# (we only count the columns in the original datasets, but not the new ones created in the sql)
columns = re.findall(r'\b\w*\.\w*\b', x["query"])
columns = [item for item in columns if not re.match(r'^t\d', item)]
num_columns = len(set(columns))
x["num_q_tag_var"] = num_q_tag_var
x["num_tables"] = num_tables
x["num_columns"] = num_columns
f_out.write(json.dumps(x) + "\n")
list_num_q_tag_var.append(num_q_tag_var)
list_num_tables.append(num_tables)
list_num_columns.append(num_columns)
num_q_tag_var_dict[num_q_tag_var].append(x["id"])
num_tables_dict[num_tables].append(x["id"])
num_columns_dict[num_columns].append(x["id"])
os.makedirs(f"{dataset}/num_q_tag_var/", exist_ok=True)
for k, id_list in num_q_tag_var_dict.items():
with open(f"{dataset}/num_q_tag_var/{k}.jsonl", "w") as f:
json.dump(id_list, f)
os.makedirs(f"{dataset}/num_tables/", exist_ok=True)
for k, id_list in num_tables_dict.items():
with open(f"{dataset}/num_tables/{k}.jsonl", "w") as f:
json.dump(id_list, f)
os.makedirs(f"{dataset}/num_columns/", exist_ok=True)
for k, id_list in num_columns_dict.items():
with open(f"{dataset}/num_columns/{k}.jsonl", "w") as f:
json.dump(id_list, f)
# plot the distribution
all_list = [list_num_q_tag_var, list_num_tables, list_num_columns]
xlabel = ["# q_tag variables", "# tables", "# columns"]
file_name = ["num_q_tag_ver_distri.pdf", "num_tables_distri.pdf", "num_columns_distri.pdf"]
os.makedirs(f"{dataset}/figures/", exist_ok=True)
for i in range(len(all_list)):
c = collections.Counter(all_list[i])
c = sorted(c.items())
plt.figure(figsize=(6, 5.5), dpi=120)
ax = sns.barplot(x=[i[0] for i in c], y=[i[1] for i in c], color=sns.color_palette()[0])
plt.ylabel("Frequency", size=30)
plt.xlabel(xlabel[i], size=30)
plt.tight_layout(rect=[-0.05, -0.05, 1.05, 1.05])
plt.savefig(f"{dataset}/figures/{file_name[i]}")