Switch to side-by-side view

--- a
+++ b/ehragent/question_difficulty.py
@@ -0,0 +1,86 @@
+import json
+import re
+import matplotlib.pyplot as plt
+import collections
+import matplotlib
+import seaborn as sns
+from collections import defaultdict
+import os
+
+
+matplotlib.use("Agg")
+matplotlib.rcParams.update({'font.family': 'Times New Roman'})
+matplotlib.rcParams['pdf.fonttype'] = 42
+matplotlib.rcParams['ps.fonttype'] = 42
+
+sns.set_theme(style="ticks", font="Times New Roman", font_scale=2.1, rc={'grid.linestyle': ':', 'axes.grid': True})
+
+for dataset in ["mimic_iii", "eicu"]:
+    with open(f"<YOUR_DATASET_PATH>", 'r') as f_in, \
+            open(f"<YOUR_SAVED_DATASET_PATH>", "w") as f_out:
+        list_num_q_tag_var = []
+        list_num_tables = []
+        list_num_columns = []
+        num_q_tag_var_dict = defaultdict(list)
+        num_tables_dict = defaultdict(list)
+        num_columns_dict = defaultdict(list)
+        for lines in f_in:
+            x = json.loads(lines)
+
+            # number of variables in q_tag
+            num_q_tag_var = x["q_tag"].count("{") + x["q_tag"].count("[")
+
+            # number of different tables used in the sql
+            # (we only count the tables in the original datasets, but not the new ones created in the sql)
+            tables = re.findall(r'\bfrom\s+(\w+)\b', x["query"])
+            num_tables = len(set(tables))
+
+            # number of different columns used in the sql
+            # (we only count the columns in the original datasets, but not the new ones created in the sql)
+            columns = re.findall(r'\b\w*\.\w*\b', x["query"])
+            columns = [item for item in columns if not re.match(r'^t\d', item)]
+            num_columns = len(set(columns))
+
+            x["num_q_tag_var"] = num_q_tag_var
+            x["num_tables"] = num_tables
+            x["num_columns"] = num_columns
+
+            f_out.write(json.dumps(x) + "\n")
+
+            list_num_q_tag_var.append(num_q_tag_var)
+            list_num_tables.append(num_tables)
+            list_num_columns.append(num_columns)
+
+            num_q_tag_var_dict[num_q_tag_var].append(x["id"])
+            num_tables_dict[num_tables].append(x["id"])
+            num_columns_dict[num_columns].append(x["id"])
+
+    os.makedirs(f"{dataset}/num_q_tag_var/", exist_ok=True)
+    for k, id_list in num_q_tag_var_dict.items():
+        with open(f"{dataset}/num_q_tag_var/{k}.jsonl", "w") as f:
+            json.dump(id_list, f)
+    os.makedirs(f"{dataset}/num_tables/", exist_ok=True)
+    for k, id_list in num_tables_dict.items():
+        with open(f"{dataset}/num_tables/{k}.jsonl", "w") as f:
+            json.dump(id_list, f)
+    os.makedirs(f"{dataset}/num_columns/", exist_ok=True)
+    for k, id_list in num_columns_dict.items():
+        with open(f"{dataset}/num_columns/{k}.jsonl", "w") as f:
+            json.dump(id_list, f)
+
+
+    # plot the distribution
+    all_list = [list_num_q_tag_var, list_num_tables, list_num_columns]
+    xlabel = ["# q_tag variables", "# tables", "# columns"]
+    file_name = ["num_q_tag_ver_distri.pdf", "num_tables_distri.pdf", "num_columns_distri.pdf"]
+
+    os.makedirs(f"{dataset}/figures/", exist_ok=True)
+    for i in range(len(all_list)):
+        c = collections.Counter(all_list[i])
+        c = sorted(c.items())
+        plt.figure(figsize=(6, 5.5), dpi=120)
+        ax = sns.barplot(x=[i[0] for i in c], y=[i[1] for i in c], color=sns.color_palette()[0])
+        plt.ylabel("Frequency", size=30)
+        plt.xlabel(xlabel[i], size=30)
+        plt.tight_layout(rect=[-0.05, -0.05, 1.05, 1.05])
+        plt.savefig(f"{dataset}/figures/{file_name[i]}")