--- a +++ b/src/experiments/bert.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +Experiments on the BERT model and the different datasets (i.e. n2c2, DDI) +""" + +# Base Dependencies +# ----------------- +from copy import deepcopy +from pathlib import Path +from os.path import join as pjoin + +# Package Dependencies +# -------------------- +from .common import final_repetition + +# Local Dependencies +# ------------------ +from training.config import PLExperimentConfig, BaalExperimentConfig +from training.bert import BertTrainer +from utils import set_seed + +# 3rd-Party Dependencies +# ---------------------- +from datasets import load_from_disk + +# Constants +# ---------- +from constants import ( + DDI_HF_TEST_PATH, + DDI_HF_TRAIN_PATH, + N2C2_HF_TRAIN_PATH, + N2C2_HF_TEST_PATH, + N2C2_REL_TYPES, + EXP_RANDOM_SEEDS, + BaalQueryStrategy +) + +MODEL_NAME = "bert" + + +def bert_passive_learning_n2c2(init_repetition: int = 0, n_repetitions: int = 5, pairs: bool = False, logging: bool = True): + + config = PLExperimentConfig( + max_epoch=25, batch_size=32, val_size=0.2, es_patience=3 + ) + + for repetition in range(init_repetition, final_repetition(init_repetition, n_repetitions)): + # set random seed + random_seed: int = EXP_RANDOM_SEEDS[repetition] + set_seed(random_seed) + config.seed = random_seed + + for rel_type in N2C2_REL_TYPES: + # load datasets + train_dataset = load_from_disk( + str(Path(pjoin(N2C2_HF_TRAIN_PATH, MODEL_NAME, rel_type))) + ) + test_dataset = load_from_disk( + str(Path(pjoin(N2C2_HF_TEST_PATH, MODEL_NAME, rel_type))) + ) + + # create trainer + trainer = BertTrainer( + dataset="n2c2", + train_dataset=train_dataset, + test_dataset=test_dataset, + pairs=pairs, + relation_type=rel_type, + ) + + # train passive learning + trainer.train_passive_learning(config=config, logging=logging) + + +def bert_active_learning_n2c2(init_repetition: int = 0, n_repetitions: int = 5, pairs: bool = False, logging: bool = True): + + config = BaalExperimentConfig(max_epoch=10, batch_size=32) + + for repetition in range(init_repetition, final_repetition(init_repetition, n_repetitions)): + # set random seed + random_seed: int = EXP_RANDOM_SEEDS[repetition] + set_seed(random_seed) + config.seed = random_seed + + for rel_type in N2C2_REL_TYPES: + + # load datasets + train_dataset = load_from_disk( + str(Path(pjoin(N2C2_HF_TRAIN_PATH, MODEL_NAME, rel_type))) + ) + test_dataset = load_from_disk( + str(Path(pjoin(N2C2_HF_TEST_PATH, MODEL_NAME, rel_type))) + ) + + # create trainer + trainer = BertTrainer( + dataset="n2c2", + train_dataset=train_dataset, + test_dataset=test_dataset, + pairs=pairs, + relation_type=rel_type, + ) + + for query_strategy in BaalQueryStrategy: + exp_config = deepcopy(config) + trainer.train_active_learning(query_strategy, exp_config, logging=logging) + + +def bert_passive_learning_ddi(init_repetition: int = 0, n_repetitions: int = 5, pairs: bool = False, logging: bool = True): + + config = PLExperimentConfig( + max_epoch=25, batch_size=32, val_size=0.2, es_patience=3 + ) + + for repetition in range(init_repetition, final_repetition(init_repetition, n_repetitions)): + # set random seed + random_seed: int = EXP_RANDOM_SEEDS[repetition] + set_seed(random_seed) + config.seed = random_seed + + # load datasets + train_dataset = load_from_disk(str(Path(pjoin(DDI_HF_TRAIN_PATH, MODEL_NAME)))) + test_dataset = load_from_disk(str(Path(pjoin(DDI_HF_TEST_PATH, MODEL_NAME)))) + + # create trainer + trainer = BertTrainer( + dataset="ddi", + train_dataset=train_dataset, + test_dataset=test_dataset, + pairs=pairs, + ) + + # train passive learning + trainer.train_passive_learning(config=config, logging=logging) + + +def bert_active_learning_ddi(init_repetition: int = 0, n_repetitions: int = 5, pairs: bool = False, logging: bool = True): + + config = BaalExperimentConfig(max_epoch=15, batch_size=32,) + + for repetition in range(init_repetition, final_repetition(init_repetition, n_repetitions)): + # set random seed + random_seed: int = EXP_RANDOM_SEEDS[repetition] + set_seed(random_seed) + config.seed = random_seed + + # load datasets + train_dataset = load_from_disk(str(Path(pjoin(DDI_HF_TRAIN_PATH, MODEL_NAME)))) + test_dataset = load_from_disk(str(Path(pjoin(DDI_HF_TEST_PATH, MODEL_NAME)))) + + # create trainer + trainer = BertTrainer( + dataset="ddi", + train_dataset=train_dataset, + test_dataset=test_dataset, + pairs=pairs, + ) + + for query_strategy in BaalQueryStrategy: + exp_config = deepcopy(config) + trainer.train_active_learning(query_strategy, exp_config, logging=logging)