Switch to unified view

a b/ehragent/question_difficulty.py
1
import json
2
import re
3
import matplotlib.pyplot as plt
4
import collections
5
import matplotlib
6
import seaborn as sns
7
from collections import defaultdict
8
import os
9
10
11
matplotlib.use("Agg")
12
matplotlib.rcParams.update({'font.family': 'Times New Roman'})
13
matplotlib.rcParams['pdf.fonttype'] = 42
14
matplotlib.rcParams['ps.fonttype'] = 42
15
16
sns.set_theme(style="ticks", font="Times New Roman", font_scale=2.1, rc={'grid.linestyle': ':', 'axes.grid': True})
17
18
for dataset in ["mimic_iii", "eicu"]:
19
    with open(f"<YOUR_DATASET_PATH>", 'r') as f_in, \
20
            open(f"<YOUR_SAVED_DATASET_PATH>", "w") as f_out:
21
        list_num_q_tag_var = []
22
        list_num_tables = []
23
        list_num_columns = []
24
        num_q_tag_var_dict = defaultdict(list)
25
        num_tables_dict = defaultdict(list)
26
        num_columns_dict = defaultdict(list)
27
        for lines in f_in:
28
            x = json.loads(lines)
29
30
            # number of variables in q_tag
31
            num_q_tag_var = x["q_tag"].count("{") + x["q_tag"].count("[")
32
33
            # number of different tables used in the sql
34
            # (we only count the tables in the original datasets, but not the new ones created in the sql)
35
            tables = re.findall(r'\bfrom\s+(\w+)\b', x["query"])
36
            num_tables = len(set(tables))
37
38
            # number of different columns used in the sql
39
            # (we only count the columns in the original datasets, but not the new ones created in the sql)
40
            columns = re.findall(r'\b\w*\.\w*\b', x["query"])
41
            columns = [item for item in columns if not re.match(r'^t\d', item)]
42
            num_columns = len(set(columns))
43
44
            x["num_q_tag_var"] = num_q_tag_var
45
            x["num_tables"] = num_tables
46
            x["num_columns"] = num_columns
47
48
            f_out.write(json.dumps(x) + "\n")
49
50
            list_num_q_tag_var.append(num_q_tag_var)
51
            list_num_tables.append(num_tables)
52
            list_num_columns.append(num_columns)
53
54
            num_q_tag_var_dict[num_q_tag_var].append(x["id"])
55
            num_tables_dict[num_tables].append(x["id"])
56
            num_columns_dict[num_columns].append(x["id"])
57
58
    os.makedirs(f"{dataset}/num_q_tag_var/", exist_ok=True)
59
    for k, id_list in num_q_tag_var_dict.items():
60
        with open(f"{dataset}/num_q_tag_var/{k}.jsonl", "w") as f:
61
            json.dump(id_list, f)
62
    os.makedirs(f"{dataset}/num_tables/", exist_ok=True)
63
    for k, id_list in num_tables_dict.items():
64
        with open(f"{dataset}/num_tables/{k}.jsonl", "w") as f:
65
            json.dump(id_list, f)
66
    os.makedirs(f"{dataset}/num_columns/", exist_ok=True)
67
    for k, id_list in num_columns_dict.items():
68
        with open(f"{dataset}/num_columns/{k}.jsonl", "w") as f:
69
            json.dump(id_list, f)
70
71
72
    # plot the distribution
73
    all_list = [list_num_q_tag_var, list_num_tables, list_num_columns]
74
    xlabel = ["# q_tag variables", "# tables", "# columns"]
75
    file_name = ["num_q_tag_ver_distri.pdf", "num_tables_distri.pdf", "num_columns_distri.pdf"]
76
77
    os.makedirs(f"{dataset}/figures/", exist_ok=True)
78
    for i in range(len(all_list)):
79
        c = collections.Counter(all_list[i])
80
        c = sorted(c.items())
81
        plt.figure(figsize=(6, 5.5), dpi=120)
82
        ax = sns.barplot(x=[i[0] for i in c], y=[i[1] for i in c], color=sns.color_palette()[0])
83
        plt.ylabel("Frequency", size=30)
84
        plt.xlabel(xlabel[i], size=30)
85
        plt.tight_layout(rect=[-0.05, -0.05, 1.05, 1.05])
86
        plt.savefig(f"{dataset}/figures/{file_name[i]}")