[735bb5]: / src / scripts / save_bert_al_models.py

Download this file

96 lines (78 with data), 2.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Train and save BERT active-learning models on the different datasets (i.e. n2c2, DDI)
"""
# Base Dependencies
# -----------------
from copy import deepcopy
from pathlib import Path
from os.path import join as pjoin
# Local Dependencies
# ------------------
from training.config import BaalExperimentConfig
from training.bert import BertTrainer
from utils import set_seed
# 3rd-Party Dependencies
# ----------------------
from datasets import load_from_disk
# Constants
# ----------
from constants import (
DDI_HF_TEST_PATH,
DDI_HF_TRAIN_PATH,
N2C2_HF_TRAIN_PATH,
N2C2_HF_TEST_PATH,
N2C2_REL_TYPES,
EXP_RANDOM_SEEDS,
BaalQueryStrategy
)
def main():
    """Run BERT active-learning training on DDI and n2c2 with model saving on.

    One run is launched for the DDI corpus and one for each of the n2c2
    relation types "Reason-Drug", "Duration-Drug" and "ADE-Drug". Every run
    uses the ``BaalQueryStrategy.LC`` query strategy and its own deep copy of
    a shared ``BaalExperimentConfig(max_epoch=15, batch_size=32)``, and is
    started with ``save_models=True`` and ``logging=False``.

    The random seed is set to ``EXP_RANDOM_SEEDS[0]`` once before the DDI run
    and once before the n2c2 loop (not once per relation type).
    """
    config = BaalExperimentConfig(max_epoch=15, batch_size=32)
    query_strategy = BaalQueryStrategy.LC

    # DDI
    # set random seed
    set_seed(EXP_RANDOM_SEEDS[0])

    # load datasets
    train_dataset = load_from_disk(Path(pjoin(DDI_HF_TRAIN_PATH, "bert")))
    test_dataset = load_from_disk(Path(pjoin(DDI_HF_TEST_PATH, "bert")))

    # create trainer
    trainer = BertTrainer(
        dataset="ddi",
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        pairs=False,
    )

    # Hand the trainer a private copy so the shared base config stays as built.
    exp_config = deepcopy(config)
    trainer.train_active_learning(
        query_strategy=query_strategy,
        config=exp_config,
        verbose=True,
        save_models=True,
        logging=False,
    )

    # n2c2
    set_seed(EXP_RANDOM_SEEDS[0])

    for rel_type in ["Reason-Drug", "Duration-Drug", "ADE-Drug"]:
        # load datasets (stored in one sub-directory per relation type)
        train_dataset = load_from_disk(
            Path(pjoin(N2C2_HF_TRAIN_PATH, "bert", rel_type))
        )
        test_dataset = load_from_disk(
            Path(pjoin(N2C2_HF_TEST_PATH, "bert", rel_type))
        )

        # create trainer
        trainer = BertTrainer(
            dataset="n2c2",
            train_dataset=train_dataset,
            test_dataset=test_dataset,
            pairs=False,
            relation_type=rel_type,
        )

        exp_config = deepcopy(config)
        # Keyword arguments mirror the DDI call above.
        # NOTE(review): `verbose` is not passed here, unlike the DDI run, so
        # the trainer's default applies — confirm this is intended.
        trainer.train_active_learning(
            query_strategy=query_strategy,
            config=exp_config,
            save_models=True,
            logging=False,
        )


# Script entry point: without this guard, executing the file only defines
# `main` and exits without training anything.
if __name__ == "__main__":
    main()