src/preprocessing/generate_statistics.py
# coding: utf-8

# Base Dependencies
# ------------------
import numpy as np
from collections import Counter
from typing import Dict
from pathlib import Path
from os.path import join as pjoin
from tqdm import tqdm

# Local Dependencies
# ------------------
from models.relation_collection import RelationCollection

# 3rd-Party Dependencies
# ----------------------
import pandas as pd
from tabulate import tabulate

# Constants
# ---------
from constants import N2C2_REL_TYPES, DDI_ALL_TYPES, N2C2_PATH, DDI_PATH

TABLE_FORMAT = "latex"
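# "latex" makes tabulate emit LaTeX table source; other tabulate formats such
# as "github", "grid", or "plain" also work here if console-friendly output is
# preferred.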


# Main Functions
# ---------------
def generate_statistics(dataset: str, collections: Dict[str, RelationCollection]) -> None:
    """Dispatches to the dataset-specific statistics generator."""
    if dataset == "n2c2":
        return generate_statistics_n2c2(collections)
    elif dataset == "ddi":
        return generate_statistics_ddi(collections)
    else:
        raise ValueError("unsupported dataset '{}'".format(dataset))


def generate_statistics_n2c2(collections: Dict[str, RelationCollection]) -> None:
    """Generates the statistics for the n2c2 dataset."""

    df_counts = {
        "relation": [],
        "train_positive": [],
        "train_negative": [],
        "test_positive": [],
        "test_negative": [],
    }
    df_seq_lengths = {
        "relation": [],
        "train_min": [],
        "train_avg": [],
        "train_max": [],
        "test_min": [],
        "test_avg": [],
        "test_max": [],
    }

    # number of relations per type of relation
    for rel_type in tqdm(N2C2_REL_TYPES):
        df_counts["relation"].append(rel_type)
        df_seq_lengths["relation"].append(rel_type)

        for split, collection in collections.items():
            subcollection = collection.type_subcollection(rel_type)

            # add counts of negative (0) and positive (1) instances
            count_labels = Counter(subcollection.labels)

            df_counts[split + "_negative"].append(count_labels[0])
            df_counts[split + "_positive"].append(count_labels[1])

            # add sequence lengths (whitespace-tokenized) to the dataframe
            seq_lengths = list(
                map(lambda rel: len(rel.text.split()), subcollection.relations)
            )

            df_seq_lengths[split + "_min"].append(min(seq_lengths))
            df_seq_lengths[split + "_avg"].append(sum(seq_lengths) / len(seq_lengths))
            df_seq_lengths[split + "_max"].append(max(seq_lengths))

    df_counts = pd.DataFrame(df_counts)
    df_seq_lengths = pd.DataFrame(df_seq_lengths)

    # add totals to counts
    df_counts["train_total"] = df_counts["train_positive"] + df_counts["train_negative"]
    df_counts["test_total"] = df_counts["test_positive"] + df_counts["test_negative"]
    df_counts["total_positive"] = (
        df_counts["train_positive"] + df_counts["test_positive"]
    )
    df_counts["total_negative"] = (
        df_counts["train_negative"] + df_counts["test_negative"]
    )
    df_counts["total"] = df_counts["total_positive"] + df_counts["total_negative"]
    df_counts.loc[len(df_counts)] = ["Total"] + [
        df_counts[col].sum() for col in df_counts.columns[1:]
    ]

    all_train_seq_lengths = list(
        map(lambda rel: len(rel.text.split()), collections["train"].relations)
    )
    all_test_seq_lengths = list(
        map(lambda rel: len(rel.text.split()), collections["test"].relations)
    )
    # DataFrame.append was removed in pandas 2.0; append the overall row with
    # pd.concat instead
    overall_row = pd.DataFrame(
        [
            {
                "relation": "Overall",
                "train_min": min(all_train_seq_lengths),
                "train_avg": sum(all_train_seq_lengths) / len(all_train_seq_lengths),
                "train_max": max(all_train_seq_lengths),
                "test_min": min(all_test_seq_lengths),
                "test_avg": sum(all_test_seq_lengths) / len(all_test_seq_lengths),
                "test_max": max(all_test_seq_lengths),
            }
        ]
    )
    df_seq_lengths = pd.concat([df_seq_lengths, overall_row], ignore_index=True)

    # select and reorder columns
    df_counts = df_counts.loc[
        :,
        [
            "relation",
            "train_positive",
            "train_negative",
            "train_total",
            "test_positive",
            "test_negative",
            "test_total",
            "total",
        ],
    ]

    # save data to csv
    df_counts.to_csv(Path(pjoin(N2C2_PATH, "counts.csv")), index=False)
    df_seq_lengths.to_csv(Path(pjoin(N2C2_PATH, "seq_length.csv")), index=False)

    # print statistics
    print("\n **** Statistics of the N2C2 Dataset ****")
    print("Counts:")
    print(tabulate(df_counts, headers="keys", tablefmt=TABLE_FORMAT))
    print("Seq Length:")
    print(tabulate(df_seq_lengths, headers="keys", tablefmt=TABLE_FORMAT))
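
# Note: the function above writes two artifacts under N2C2_PATH:
#   counts.csv     -- per-relation positive/negative counts for each split
#   seq_length.csv -- min/avg/max whitespace-token sequence lengths per split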


def generate_statistics_ddi(collections: Dict[str, RelationCollection]) -> None:
    """Generates the statistics of the DDI dataset."""

    df_counts = {"relation": [], "train": [], "test": []}
    df_seq_lengths = {
        "relation": [],
        "train_min": [],
        "train_avg": [],
        "train_max": [],
        "test_min": [],
        "test_avg": [],
        "test_max": [],
    }

    # number of relations and sequence lengths per type of relation
    for rel_type in DDI_ALL_TYPES:
        df_counts["relation"].append(rel_type)
        df_seq_lengths["relation"].append(rel_type)

        for split, collection in collections.items():
            subcollection = collection.type_subcollection(rel_type)

            df_counts[split].append(len(subcollection))

            seq_lengths = list(
                map(lambda rel: len(rel.text.split()), subcollection.relations)
            )
            df_seq_lengths[split + "_min"].append(min(seq_lengths))
            df_seq_lengths[split + "_avg"].append(sum(seq_lengths) / len(seq_lengths))
            df_seq_lengths[split + "_max"].append(max(seq_lengths))

    # convert to dataframes
    df_counts = pd.DataFrame(df_counts)
    df_seq_lengths = pd.DataFrame(df_seq_lengths)

    # compute totals; "NO-REL" instances are the negatives
    train_negative = df_counts.loc[(df_counts["relation"] == "NO-REL"), "train"].values[0]
    train_positive = df_counts.loc[(df_counts["relation"] != "NO-REL"), "train"].sum()
    test_negative = df_counts.loc[(df_counts["relation"] == "NO-REL"), "test"].values[0]
    test_positive = df_counts.loc[(df_counts["relation"] != "NO-REL"), "test"].sum()
    train_total = train_positive + train_negative
    test_total = test_positive + test_negative
    total = train_total + test_total

    # add positive row
    df_counts.loc[len(df_counts)] = ["Total Positive", train_positive, test_positive]

    # add totals
    df_counts["total"] = df_counts["train"] + df_counts["test"]
    df_counts.loc[len(df_counts)] = ["Total", train_total, test_total, total]

    all_train_seq_lengths = list(
        map(lambda rel: len(rel.text.split()), collections["train"].relations)
    )
    all_test_seq_lengths = list(
        map(lambda rel: len(rel.text.split()), collections["test"].relations)
    )
    # DataFrame.append was removed in pandas 2.0; append the overall row with
    # pd.concat instead
    overall_row = pd.DataFrame(
        [
            {
                "relation": "Overall",
                "train_min": min(all_train_seq_lengths),
                "train_avg": sum(all_train_seq_lengths) / len(all_train_seq_lengths),
                "train_max": max(all_train_seq_lengths),
                "test_min": min(all_test_seq_lengths),
                "test_avg": sum(all_test_seq_lengths) / len(all_test_seq_lengths),
                "test_max": max(all_test_seq_lengths),
            }
        ]
    )
    df_seq_lengths = pd.concat([df_seq_lengths, overall_row], ignore_index=True)

    # save data to csv
    df_counts.to_csv(Path(pjoin(DDI_PATH, "counts.csv")), index=False)
    df_seq_lengths.to_csv(Path(pjoin(DDI_PATH, "seq_length.csv")), index=False)

    # print statistics
    print("\n **** Statistics of the DDI Dataset ****")
    print("Counts:")
    print(tabulate(df_counts, headers="keys", tablefmt=TABLE_FORMAT))
    print("Seq Length:")
    print(tabulate(df_seq_lengths, headers="keys", tablefmt=TABLE_FORMAT))
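
# Usage sketch (hypothetical): how the train/test collections are built
# depends on RelationCollection, whose loading API is not shown in this file.
# The `from_dataset` name below is an assumption, not the repo's actual API.
#
# if __name__ == "__main__":
#     collections = {
#         "train": RelationCollection.from_dataset("n2c2", split="train"),  # hypothetical loader
#         "test": RelationCollection.from_dataset("n2c2", split="test"),
#     }
#     generate_statistics("n2c2", collections)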