Diff of /src/experiments/bert.py [000000] .. [735bb5]

Switch to unified view

a b/src/experiments/bert.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Experiments on the BERT model and the different datasets (i.e. n2c2, DDI)
6
"""
7
8
# Base Dependencies
9
# -----------------
10
from copy import deepcopy
11
from pathlib import Path
12
from os.path import join as pjoin
13
14
# Package Dependencies
15
# --------------------
16
from .common import final_repetition
17
18
# Local Dependencies
19
# ------------------
20
from training.config import PLExperimentConfig, BaalExperimentConfig
21
from training.bert import BertTrainer
22
from utils import set_seed
23
24
# 3rd-Party Dependencies
25
# ----------------------
26
from datasets import load_from_disk
27
28
# Constants
29
# ----------
30
from constants import (
31
    DDI_HF_TEST_PATH,
32
    DDI_HF_TRAIN_PATH,
33
    N2C2_HF_TRAIN_PATH,
34
    N2C2_HF_TEST_PATH,
35
    N2C2_REL_TYPES,
36
    EXP_RANDOM_SEEDS,
37
    BaalQueryStrategy
38
)
39
40
MODEL_NAME = "bert"
41
42
43
def bert_passive_learning_n2c2(init_repetition: int = 0, n_repetitions: int = 5, pairs: bool = False, logging: bool = True):
    """Run the passive learning BERT experiments on the n2c2 dataset.

    For every repetition the random seed is fixed, and then one trainer is
    built and trained for each relation type in ``N2C2_REL_TYPES``.

    Args:
        init_repetition: index of the first repetition to run; it is also the
            first index read from ``EXP_RANDOM_SEEDS``.
        n_repetitions: passed together with ``init_repetition`` to
            ``final_repetition`` to obtain the exclusive upper bound of the
            repetition range.
        pairs: forwarded unchanged to ``BertTrainer``.
        logging: forwarded unchanged to ``train_passive_learning``.
    """
    # One config object is shared by all runs; only its seed changes per repetition.
    config = PLExperimentConfig(max_epoch=25, batch_size=32, val_size=0.2, es_patience=3)
    stop = final_repetition(init_repetition, n_repetitions)

    for rep_idx in range(init_repetition, stop):
        # Fix the seed of this repetition before anything else is done.
        seed = EXP_RANDOM_SEEDS[rep_idx]
        set_seed(seed)
        config.seed = seed

        for relation in N2C2_REL_TYPES:
            # The train and test splits are stored per model and relation type.
            train_dir = Path(pjoin(N2C2_HF_TRAIN_PATH, MODEL_NAME, relation))
            train_split = load_from_disk(str(train_dir))
            test_dir = Path(pjoin(N2C2_HF_TEST_PATH, MODEL_NAME, relation))
            test_split = load_from_disk(str(test_dir))

            # A fresh trainer is created for every relation type.
            bert_trainer = BertTrainer(
                dataset="n2c2",
                train_dataset=train_split,
                test_dataset=test_split,
                pairs=pairs,
                relation_type=relation,
            )
            bert_trainer.train_passive_learning(config=config, logging=logging)
75
76
77
def bert_active_learning_n2c2(init_repetition: int = 0, n_repetitions: int = 5, pairs: bool = False, logging: bool = True):
    """Run the active learning BERT experiments on the n2c2 dataset.

    For every repetition the random seed is fixed. Then, for each relation
    type in ``N2C2_REL_TYPES``, one trainer is built and run once per member
    of ``BaalQueryStrategy``.

    Args:
        init_repetition: index of the first repetition to run; it is also the
            first index read from ``EXP_RANDOM_SEEDS``.
        n_repetitions: passed together with ``init_repetition`` to
            ``final_repetition`` to obtain the exclusive upper bound of the
            repetition range.
        pairs: forwarded unchanged to ``BertTrainer``.
        logging: forwarded unchanged to ``train_active_learning``.
    """
    # Template config: every active learning run receives its own deep copy.
    config = BaalExperimentConfig(max_epoch=10, batch_size=32)
    stop = final_repetition(init_repetition, n_repetitions)

    for rep_idx in range(init_repetition, stop):
        # Fix the seed of this repetition before anything else is done.
        seed = EXP_RANDOM_SEEDS[rep_idx]
        set_seed(seed)
        config.seed = seed

        for relation in N2C2_REL_TYPES:
            # The train and test splits are stored per model and relation type.
            train_dir = Path(pjoin(N2C2_HF_TRAIN_PATH, MODEL_NAME, relation))
            train_split = load_from_disk(str(train_dir))
            test_dir = Path(pjoin(N2C2_HF_TEST_PATH, MODEL_NAME, relation))
            test_split = load_from_disk(str(test_dir))

            # The same trainer instance is reused for all query strategies.
            bert_trainer = BertTrainer(
                dataset="n2c2",
                train_dataset=train_split,
                test_dataset=test_split,
                pairs=pairs,
                relation_type=relation,
            )

            for strategy in BaalQueryStrategy:
                bert_trainer.train_active_learning(strategy, deepcopy(config), logging=logging)
109
110
111
def bert_passive_learning_ddi(init_repetition: int = 0, n_repetitions: int = 5, pairs: bool = False, logging: bool = True):
    """Run the passive learning BERT experiments on the DDI dataset.

    For every repetition the random seed is fixed, the train and test
    datasets are loaded from disk, and a new trainer is built and trained.

    Args:
        init_repetition: index of the first repetition to run; it is also the
            first index read from ``EXP_RANDOM_SEEDS``.
        n_repetitions: passed together with ``init_repetition`` to
            ``final_repetition`` to obtain the exclusive upper bound of the
            repetition range.
        pairs: forwarded unchanged to ``BertTrainer``.
        logging: forwarded unchanged to ``train_passive_learning``.
    """
    # One config object is shared by all runs; only its seed changes per repetition.
    config = PLExperimentConfig(max_epoch=25, batch_size=32, val_size=0.2, es_patience=3)
    stop = final_repetition(init_repetition, n_repetitions)

    for rep_idx in range(init_repetition, stop):
        # Fix the seed of this repetition before anything else is done.
        seed = EXP_RANDOM_SEEDS[rep_idx]
        set_seed(seed)
        config.seed = seed

        # The datasets are reloaded from disk on every repetition.
        train_dir = Path(pjoin(DDI_HF_TRAIN_PATH, MODEL_NAME))
        train_split = load_from_disk(str(train_dir))
        test_dir = Path(pjoin(DDI_HF_TEST_PATH, MODEL_NAME))
        test_split = load_from_disk(str(test_dir))

        # A fresh trainer is created for every repetition.
        bert_trainer = BertTrainer(
            dataset="ddi",
            train_dataset=train_split,
            test_dataset=test_split,
            pairs=pairs,
        )
        bert_trainer.train_passive_learning(config=config, logging=logging)
137
138
139
def bert_active_learning_ddi(init_repetition: int = 0, n_repetitions: int = 5, pairs: bool = False, logging: bool = True):
    """Run the active learning BERT experiments on the DDI dataset.

    For every repetition the random seed is fixed, the train and test
    datasets are loaded from disk, and one trainer is built and run once per
    member of ``BaalQueryStrategy``.

    Args:
        init_repetition: index of the first repetition to run; it is also the
            first index read from ``EXP_RANDOM_SEEDS``.
        n_repetitions: passed together with ``init_repetition`` to
            ``final_repetition`` to obtain the exclusive upper bound of the
            repetition range.
        pairs: forwarded unchanged to ``BertTrainer``.
        logging: forwarded unchanged to ``train_active_learning``.
    """
    # Template config: every active learning run receives its own deep copy.
    config = BaalExperimentConfig(max_epoch=15, batch_size=32)
    stop = final_repetition(init_repetition, n_repetitions)

    for rep_idx in range(init_repetition, stop):
        # Fix the seed of this repetition before anything else is done.
        seed = EXP_RANDOM_SEEDS[rep_idx]
        set_seed(seed)
        config.seed = seed

        # The datasets are reloaded from disk on every repetition.
        train_dir = Path(pjoin(DDI_HF_TRAIN_PATH, MODEL_NAME))
        train_split = load_from_disk(str(train_dir))
        test_dir = Path(pjoin(DDI_HF_TEST_PATH, MODEL_NAME))
        test_split = load_from_disk(str(test_dir))

        # The same trainer instance is reused for all query strategies.
        bert_trainer = BertTrainer(
            dataset="ddi",
            train_dataset=train_split,
            test_dataset=test_split,
            pairs=pairs,
        )

        for strategy in BaalQueryStrategy:
            bert_trainer.train_active_learning(strategy, deepcopy(config), logging=logging)