Diff of /src/experiments/bilstm.py [000000] .. [735bb5]

Switch to unified view

a b/src/experiments/bilstm.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
"""
5
Experiments on the BiLSTM model and the different datasets (i.e. n2c2, DDI)
6
"""
7
8
# Base Dependencies
9
# -----------------
10
from copy import deepcopy
11
from pathlib import Path
12
from os.path import join as pjoin
13
14
# Package Dependencies
15
# --------------------
16
from .common import final_repetition
17
18
# Local Dependencies
19
# ------------------
20
from training.config import PLExperimentConfig, BaalExperimentConfig
21
from training.bilstm import BilstmTrainer
22
from utils import set_seed
23
24
# 3rd-Party Dependencies
25
# ----------------------
26
from datasets import load_from_disk
27
28
# Constants
29
# ----------
30
from constants import (
31
    DDI_HF_TEST_PATH,
32
    DDI_HF_TRAIN_PATH,
33
    N2C2_HF_TRAIN_PATH,
34
    N2C2_HF_TEST_PATH,
35
    N2C2_REL_TYPES,
36
    BaalQueryStrategy,
37
    EXP_RANDOM_SEEDS
38
)
39
40
MODEL_NAME = "bilstm"
41
42
43
def bilstm_passive_learning_n2c2(init_repetition: int = 0, n_repetitions: int = 5, logging: bool = True):
44
45
    config = PLExperimentConfig(
46
        max_epoch=25,
47
        batch_size=32
48
    )
49
  
50
    for rel_type in N2C2_REL_TYPES:
51
        # load datasets
52
        train_dataset = load_from_disk(
53
            str(Path(pjoin(N2C2_HF_TRAIN_PATH, MODEL_NAME, rel_type)))
54
        )
55
        test_dataset = load_from_disk(
56
            str(Path(pjoin(N2C2_HF_TEST_PATH, MODEL_NAME, rel_type)))
57
        )
58
59
        # create trainer
60
        trainer = BilstmTrainer("n2c2", train_dataset, test_dataset, rel_type)
61
            
62
        for repetition in range(init_repetition, final_repetition(init_repetition, n_repetitions)):
63
            # set random seed
64
            random_seed: int = EXP_RANDOM_SEEDS[repetition]
65
            set_seed(random_seed)
66
            config.seed = random_seed
67
            # train passive learning
68
            trainer.train_passive_learning(config=config, logging=logging)
69
70
71
def bilstm_passive_learning_ddi(init_repetition:int = 0, n_repetitions: int = 5, logging: bool = True):
72
    config = PLExperimentConfig(
73
        max_epoch=25,
74
        batch_size=32
75
    )
76
77
    # load datasets
78
    train_dataset = load_from_disk(str(Path(pjoin(DDI_HF_TRAIN_PATH, MODEL_NAME))))
79
    test_dataset = load_from_disk(str(Path(pjoin(DDI_HF_TEST_PATH, MODEL_NAME))))
80
81
    # create trainer
82
    trainer = BilstmTrainer("ddi", train_dataset, test_dataset)
83
84
    # train passive learing
85
    for repetition in range(init_repetition, final_repetition(init_repetition, n_repetitions)):
86
        # set random seed
87
        random_seed: int = EXP_RANDOM_SEEDS[repetition]
88
        set_seed(random_seed)
89
        config.seed = random_seed
90
        
91
        trainer.train_passive_learning(config=config, logging=logging)
92
93
94
def bilstm_active_learning_n2c2(init_repetition:int = 0, n_repetitions: int = 5, logging: bool = True):
95
96
    config = BaalExperimentConfig(
97
        max_epoch=15,
98
        batch_size=32,
99
        all_bayesian=False
100
    )
101
    
102
    
103
    for repetition in range(init_repetition, final_repetition(init_repetition, n_repetitions)):
104
        # set random seed
105
        random_seed: int = EXP_RANDOM_SEEDS[repetition]
106
        set_seed(random_seed)
107
        config.seed = random_seed
108
109
        for rel_type in N2C2_REL_TYPES:
110
111
            # load datasets
112
            train_dataset = load_from_disk(
113
                str(Path(pjoin(N2C2_HF_TRAIN_PATH, MODEL_NAME, rel_type)))
114
            )
115
            test_dataset = load_from_disk(
116
                str(Path(pjoin(N2C2_HF_TEST_PATH, MODEL_NAME, rel_type)))
117
            )
118
119
            # create trainer
120
            trainer = BilstmTrainer(
121
                dataset="n2c2",
122
                train_dataset=train_dataset,
123
                test_dataset=test_dataset,
124
                relation_type=rel_type,
125
            )
126
127
            for query_strategy in BaalQueryStrategy:
128
                exp_config = deepcopy(config)
129
                trainer.train_active_learning(query_strategy, config=exp_config, logging=logging)
130
131
132
def bilstm_active_learning_ddi(init_repetition:int = 0, n_repetitions: int = 5, logging: bool = True):
133
    config = BaalExperimentConfig(max_epoch=15,  batch_size=32, all_bayesian=False)
134
135
    # load datasets
136
    train_dataset = load_from_disk(str(Path(pjoin(DDI_HF_TRAIN_PATH, MODEL_NAME))))
137
    test_dataset = load_from_disk(str(Path(pjoin(DDI_HF_TEST_PATH, MODEL_NAME))))
138
139
    # create trainer
140
    trainer = BilstmTrainer("ddi", train_dataset, test_dataset)
141
142
    for repetition in range(init_repetition, final_repetition(init_repetition, n_repetitions)):
143
        # set random seed
144
        random_seed: int = EXP_RANDOM_SEEDS[repetition]
145
        set_seed(random_seed)
146
        config.seed = random_seed
147
148
        for query_strategy in BaalQueryStrategy:
149
            exp_config = deepcopy(config)
150
            trainer.train_active_learning(query_strategy, config=exp_config, logging=logging)