--- a +++ b/src/training/bilstm.py @@ -0,0 +1,646 @@ +# Base Dependencies +# ----------------- +import numpy as np +import time +from copy import deepcopy +from functools import partial +from tqdm import tqdm +from typing import Dict, Optional +from pathlib import Path +from os.path import join + +# Package Dependencies +# -------------------- +from .base import BaseTrainer +from .config import PLExperimentConfig, BaalExperimentConfig +from .early_stopping import EarlyStopping +from .utils import get_baal_query_strategy + +# Local Dependencies +# ------------------- +from extensions.baal import ( + MyModelWrapperBilstm, + MyActiveLearningDatasetBilstm, + MyActiveLearningLoop, +) +from extensions.torchmetrics import ( + DetectionF1Score, + DetectionPrecision, + DetectionRecall, +) +from ml_models.bilstm import ( + HasanModel, + EmbeddingConfig, + LSTMConfig, + RDEmbeddingConfig, +) +from re_datasets.bilstm_utils import pad_and_sort_batch, custom_collate +from vocabulary import Vocabulary, read_list_from_file + +# 3rd-Party Dependencies +# ---------------------- +import neptune +import torch + +from baal.bayesian.dropout import patch_module +from datasets import Dataset +from torch.optim import Adam +from torch.nn import CrossEntropyLoss, Module +from torch.utils.data import DataLoader +from torch.utils.data.sampler import BatchSampler, RandomSampler +from torchmetrics import Accuracy +from torchmetrics.classification import F1Score, Precision, Recall + +# Constants +# --------- +from constants import ( + N2C2_VOCAB_PATH, + DDI_VOCAB_PATH, + N2C2_IOB_TAGS, + DDI_IOB_TAGS, + N2C2_RD_MAX, + DDI_RD_MAX, + RD_EMB_DIM, + IOB_EMB_DIM, + BIOWV_EMB_DIM, + POS_EMB_DIM, + DEP_EMB_DIM, + BIOWORD2VEC_PATH, + U_POS_TAGS, + DEP_TAGS, + BaalQueryStrategy, +) +from config import NEPTUNE_API_TOKEN, NEPTUNE_PROJECT + + +class BilstmTrainer(BaseTrainer): + """Trainer for BiLSTM method.""" + + def __init__( + self, + dataset: str, + train_dataset: Dataset, + test_dataset: Dataset, + relation_type: Optional[str] = None, + ): + """ + Args: + dataset (str): name of the dataset, e.g., "n2c2". + train_dataset (Dataset): train split of the dataset. + test_dataset (Dataset): test split of the dataset. + relation_type (str, optional): relation type. 
+
+        Raises:
+            ValueError: if the dataset name provided is not supported
+        """
+        super().__init__(dataset, train_dataset, test_dataset, relation_type)
+
+        # vocabulary
+        self.vocab = self._init_vocab()
+
+        # transform datasets
+        self.transform = partial(
+            pad_and_sort_batch, padding_idx=self.vocab.pad_index, rd_max=self.RD_MAX
+        )
+
+    @property
+    def method_name(self) -> str:
+        return "bilstm"
+
+    @property
+    def method_name_pretty(self) -> str:
+        return "BiLSTM"
+
+    @property
+    def task(self) -> str:
+        if self.dataset == "n2c2":
+            task = "binary"
+        else:
+            task = "multiclass"
+        return task
+
+    @property
+    def model_class(self) -> type:
+        return HasanModel
+
+    @property
+    def RD_MAX(self) -> int:
+        if self.dataset == "n2c2":
+            rd_max = N2C2_RD_MAX
+        else:
+            rd_max = DDI_RD_MAX
+        return rd_max
+
+    @property
+    def IOB_TAGS(self) -> list:
+        if self.dataset == "n2c2":
+            iob_tags = N2C2_IOB_TAGS
+        else:
+            iob_tags = DDI_IOB_TAGS
+        return iob_tags
+
+    def _init_optimizer(self, model: Module):
+        return Adam(model.parameters(), lr=0.0001)
+
+    def _init_vocab(self):
+        """Loads the vocabulary of the dataset"""
+        if self.dataset == "n2c2":
+            vocab_path = N2C2_VOCAB_PATH
+        else:
+            vocab_path = DDI_VOCAB_PATH
+
+        return Vocabulary(read_list_from_file(vocab_path))
+
+    def _init_model(self, patch: bool = False) -> HasanModel:
+        """Builds the BiLSTM model with the configuration appropriate for the chosen dataset"""
+        # word embedding configuration
+        biowv_config = EmbeddingConfig(
+            embedding_dim=BIOWV_EMB_DIM,
+            vocab_size=len(self.vocab),
+            emb_path=BIOWORD2VEC_PATH,
+            freeze=True,
+            padding_idx=self.vocab.pad_index,
+        )
+
+        # relative-distance embedding configuration
+        rd_config = RDEmbeddingConfig(
+            input_dim=self.RD_MAX, embedding_dim=RD_EMB_DIM, freeze=False
+        )
+
+        # IOB embedding configuration
+        iob_config = EmbeddingConfig(
+            embedding_dim=IOB_EMB_DIM, vocab_size=(len(self.IOB_TAGS) + 1), freeze=False
+        )
+
+        # Part-of-Speech tag embedding configuration
+        pos_config = EmbeddingConfig(
+            embedding_dim=POS_EMB_DIM, vocab_size=(len(U_POS_TAGS) + 1), freeze=False
+        )
+
+        # dependency-tag embedding configuration
+        dep_config = EmbeddingConfig(
+            embedding_dim=DEP_EMB_DIM, vocab_size=(len(DEP_TAGS) + 1), freeze=False
+        )
+
+        # BiLSTM configuration
+        lstm_config = LSTMConfig(
+            emb_size=(
+                BIOWV_EMB_DIM + 2 * RD_EMB_DIM + POS_EMB_DIM + DEP_EMB_DIM + IOB_EMB_DIM
+            )
+        )
+
+        model = self.model_class(
+            vocab=self.vocab,
+            lstm_config=lstm_config,
+            bioword2vec_config=biowv_config,
+            rd_config=rd_config,
+            pos_config=pos_config,
+            dep_config=dep_config,
+            iob_config=iob_config,
+            num_classes=self.num_classes,
+        )
+
+        if patch:
+            model = patch_module(model)
+
+        return model
+
+    def _reset_trainer(self):
+        self.train_dataset.reset_format()
+        self.test_dataset.reset_format()
+
+    def create_dataloader(self, dataset: Dataset, batch_size: int = 6) -> DataLoader:
+        """Creates a dataloader from a dataset with the appropriate configuration
+
+        Args:
+            dataset (Dataset): dataset to load
+            batch_size (int, optional): batch size. Defaults to 6.
+
+        Returns:
+            DataLoader: dataloader for the given dataset
+        """
+        dataset.set_transform(self.transform)
+
+        # create dataloader
+        sampler = BatchSampler(
+            RandomSampler(dataset), batch_size=batch_size, drop_last=False
+        )
+        dataloader = DataLoader(dataset, sampler=sampler, collate_fn=custom_collate)
+
+        return dataloader
+
+    def eval_model(
+        self,
+        model: Module,
+        dataloader: DataLoader,
+        criterion: Module,
+    ) -> Dict[str, float]:
+        """Evaluates the current model on the dev or test set
+
+        Args:
+            model (Module): model to use for evaluation.
+            dataloader (DataLoader): dataloader of the evaluation dataset
+            criterion (Module): loss function used to compute the evaluation loss
+
+        Returns:
+            Dict: metrics including loss (`loss`), precision (`p`), recall (`r`) and F1-score (`f1`)
+        """
+
+        y_true = np.array([], dtype=np.int8)
+        y_pred = np.array([], dtype=np.int8)
+
+        val_loss = 0.0
+        n_samples = 0
+
+        # disable dropout and other train-only layers during evaluation
+        model.eval()
+        with torch.no_grad():
+            for inputs, labels in dataloader:
+                # send (inputs, labels) to device
+                labels = labels.to(self.device)
+                for key, value in inputs.items():
+                    inputs[key] = value.to(self.device)
+
+                # calculate outputs
+                outputs = model(inputs)
+                loss = criterion(outputs, labels)
+
+                # accumulate the loss weighted by the batch size
+                val_loss += labels.size(0) * loss.item()
+                n_samples += labels.size(0)
+
+                # calculate predictions
+                _, predicted = torch.max(outputs.data, 1)
+
+                # store labels and predictions
+                y_true = np.append(y_true, labels.cpu().detach().numpy())
+                y_pred = np.append(y_pred, predicted.cpu().detach().numpy())
+
+        metrics = self.compute_metrics(y_true, y_pred)
+        metrics["loss"] = val_loss / n_samples
+
+        return metrics
+
+    def train_passive_learning(
+        self, config: PLExperimentConfig, verbose: bool = True, logging: bool = True
+    ):
+        """Trains the BiLSTM model using passive learning and early stopping
+
+        Args:
+            config (PLExperimentConfig): configuration
+            verbose (bool): determines if information is printed during training. Defaults to True.
+            logging (bool): log the test metrics on Neptune. Defaults to True.
+        """
+        self._reset_trainer()
+
+        # setup
+        train_val_split = self.train_dataset.train_test_split(
+            test_size=config.val_size, stratify_by_column="label"
+        )
+        labels = np.array(train_val_split["train"]["label"])
+
+        train_dataloader = self.create_dataloader(
+            train_val_split["train"], batch_size=config.batch_size
+        )
+
+        val_dataloader = self.create_dataloader(
+            train_val_split["test"], batch_size=config.batch_size
+        )
+        test_dataloader = self.create_dataloader(
+            self.test_dataset, batch_size=config.batch_size
+        )
+
+        if logging:
+            run = neptune.init_run(project=NEPTUNE_PROJECT, api_token=NEPTUNE_API_TOKEN)
+
+        model = self._init_model()
+        model = model.to(self.device)
+        criterion = CrossEntropyLoss(weight=self.compute_class_weights(labels))
+        optimizer = self._init_optimizer(model)
+
+        # print info
+        if verbose:
+            self.print_info_passive_learning()
+
+        # early stopper
+        ES = EarlyStopping(
+            patience=config.es_patience,
+            verbose=True,
+            path=Path(join(self.pl_checkpoint_path, "best_model.pt")),
+        )
+
+        # training loop
+        for epoch in range(config.max_epoch):
+            # re-enable dropout and other train-only layers after evaluation
+            model.train()
+
+            running_loss = 0.0
+            for inputs, labels in tqdm(train_dataloader):
+                # send (inputs, labels) to device
+                labels = labels.to(self.device)
+                for key, value in inputs.items():
+                    inputs[key] = value.to(self.device)
+
+                # zero the parameter gradients
+                optimizer.zero_grad()
+
+                # forward + backward + optimize
+                outputs = model(inputs)
+                loss = criterion(outputs, labels)
+                loss.backward()
+                optimizer.step()
+
+                # accumulate the running loss
+                running_loss += loss.item()
+
+            # evaluate model on validation set
+            val_metrics = self.eval_model(model, val_dataloader, criterion)
+            train_loss = running_loss / len(train_dataloader)
+            val_loss = val_metrics["loss"]
+            if logging:
+                run["loss/train"].append(train_loss)
+                run["loss/val"].append(val_loss)
+
+                for key, value in val_metrics.items():
+                    if key != "loss":
+                        run[f"val/{key}"].append(value)
+
+            if verbose:
+                self.print_val_metrics(epoch + 1, val_metrics)
+
+            # check early stopping
+            ES(val_loss, model)
+            if ES.early_stop:
+                break
+
+        # load best model
+        model.load_state_dict(
+            
torch.load(Path(join(self.pl_checkpoint_path, "best_model.pt"))) + ) + + # evaluate model on test dataset + test_metrics = self.eval_model(model, test_dataloader, criterion) + if verbose: + self.print_test_metrics(test_metrics) + if logging: + run["method"] = self.method_name + run["dataset"] = self.dataset + run["relation"] = self.relation_type + run["strategy"] = "passive learning" + run["config"] = config.__dict__ + run["epochs"] = epoch + + for key, value in test_metrics.items(): + run["test/" + key] = value + run.stop() + + return model + + def set_al_metrics(self, baal_model: MyModelWrapperBilstm): + """ + Configures the metrics that are to be computed during the active learning experiment + + Args: + baal_model (MyModelWrapperBilstm): model wrapper + + """ + # accuracy + baal_model.add_metric( + name="acc", + initializer=lambda: Accuracy(task=self.task, average="micro").to( + self.device + ), + ) + + if self.dataset == "n2c2": + f1 = F1Score(num_classes=self.num_classes, ignore_index=0).to(self.device) + p = Precision(num_classes=self.num_classes, ignore_index=0).to(self.device) + r = Recall(num_classes=self.num_classes, ignore_index=0).to(self.device) + baal_model.add_metric(name="f1", initializer=lambda: f1) + baal_model.add_metric(name="p", initializer=lambda: p) + baal_model.add_metric(name="r", initializer=lambda: r) + + else: # self.dataset == "ddi": + # detection + classification metrics + cla_f1_micro = F1Score( + num_classes=self.num_classes, average="micro", ignore_index=0 + ).to(self.device) + + cla_p_micro = Precision( + num_classes=self.num_classes, average="micro", ignore_index=0 + ).to(self.device) + + cla_r_micro = Recall( + num_classes=self.num_classes, average="micro", ignore_index=0 + ).to(self.device) + + cla_f1_macro = F1Score( + num_classes=self.num_classes, average="macro", ignore_index=0 + ).to(self.device) + + cla_p_macro = Precision( + num_classes=self.num_classes, average="macro", ignore_index=0 + ).to(self.device) + + cla_r_macro = Recall( + num_classes=self.num_classes, average="macro", ignore_index=0 + ).to(self.device) + + baal_model.add_metric(name="micro_f1", initializer=lambda: cla_f1_micro) + baal_model.add_metric(name="micro_p", initializer=lambda: cla_p_micro) + baal_model.add_metric(name="micro_r", initializer=lambda: cla_r_micro) + baal_model.add_metric(name="macro_f1", initializer=lambda: cla_f1_macro) + baal_model.add_metric(name="macro_p", initializer=lambda: cla_p_macro) + baal_model.add_metric(name="macro_r", initializer=lambda: cla_r_macro) + + # detection metrics + detect_f1 = DetectionF1Score().to(self.device) + detect_p = DetectionPrecision().to(self.device) + detect_r = DetectionRecall().to(self.device) + + baal_model.add_metric(name="detect_f1", initializer=lambda: detect_f1) + baal_model.add_metric(name="detect_p", initializer=lambda: detect_p) + baal_model.add_metric(name="detect_r", initializer=lambda: detect_r) + + # per class metrics + per_class_f1 = F1Score(num_classes=self.num_classes, average="none").to( + self.device + ) + + per_class_p = Precision(num_classes=self.num_classes, average="none").to( + self.device + ) + + per_class_r = Recall(num_classes=self.num_classes, average="none").to( + self.device + ) + + baal_model.add_metric(name="class_f1", initializer=lambda: per_class_f1) + baal_model.add_metric(name="class_p", initializer=lambda: per_class_p) + baal_model.add_metric(name="class_r", initializer=lambda: per_class_r) + + return baal_model + + def train_active_learning( + self, + query_strategy: BaalQueryStrategy, + 
config: BaalExperimentConfig,
+        verbose: bool = True,
+        logging: bool = True,
+    ):
+        """Trains the BiLSTM model using active learning
+
+        Args:
+            query_strategy (BaalQueryStrategy): query strategy to be used in the experiment.
+            config (BaalExperimentConfig): experiment configuration.
+            verbose (bool): determines if information is printed during training. Defaults to True.
+            logging (bool): log the test metrics on Neptune. Defaults to True.
+        """
+        self._reset_trainer()
+
+        if logging:
+            run = neptune.init_run(project=NEPTUNE_PROJECT, api_token=NEPTUNE_API_TOKEN)
+
+        # setup querying
+        INIT_QUERY_SIZE = self.compute_init_q_size(config)
+        QUERY_SIZE = self.compute_q_size(config)
+        AL_STEPS = self.compute_al_steps(config)
+
+        f_query_strategy = get_baal_query_strategy(
+            name=query_strategy.value,
+            shuffle_prop=config.shuffle_prop,
+            query_size=QUERY_SIZE,
+        )
+
+        if verbose:
+            self.print_info_active_learning(
+                q_strategy=query_strategy.value,
+                pool_size=self.n_instances,
+                init_q_size=INIT_QUERY_SIZE,
+                q_size=QUERY_SIZE,
+            )
+
+        # setup active set
+        self.train_dataset.set_transform(self.transform)
+        self.test_dataset.set_transform(self.transform)
+        active_set = MyActiveLearningDatasetBilstm(self.train_dataset)
+        active_set.can_label = False
+        active_set.label_randomly(INIT_QUERY_SIZE)
+
+        # setup model
+        PATCH = config.all_bayesian or (query_strategy == BaalQueryStrategy.BATCH_BALD)
+        if not PATCH:
+            config.iterations = 1
+        model = self._init_model(PATCH)
+        model = model.to(self.device)
+        criterion = CrossEntropyLoss(self.compute_class_weights(active_set.labels))
+        optimizer = self._init_optimizer(model)
+
+        baal_model = MyModelWrapperBilstm(
+            model,
+            criterion,
+            replicate_in_memory=False,
+            min_train_passes=config.min_train_passes,
+        )
+        baal_model = self.set_al_metrics(baal_model)
+
+        # active loop
+        active_loop = MyActiveLearningLoop(
+            dataset=active_set,
+            get_probabilities=baal_model.predict_on_dataset,
+            heuristic=f_query_strategy,
+            query_size=QUERY_SIZE,
+            batch_size=config.batch_size,
+            iterations=config.iterations,
+            use_cuda=self.use_cuda,
+            verbose=False,
+            workers=2,
+            collate_fn=custom_collate,
+        )
+
+        # We will reset the weights at each active learning step, so we make a copy.
+        init_weights = deepcopy(baal_model.state_dict())
+
+        if logging:
+            run["model"] = self.method_name
+            run["dataset"] = self.dataset
+            run["relation"] = self.relation_type
+            run["bayesian"] = config.all_bayesian or (
+                query_strategy == BaalQueryStrategy.BATCH_BALD
+            )
+            run["strategy"] = query_strategy.value
+            run["config"] = config.__dict__
+            run["annotation/intance_ann"].append(active_set.n_labelled / self.n_instances)
+            run["annotation/token_ann"].append(
+                active_set.n_labelled_tokens / self.n_tokens
+            )
+            run["annotation/char_ann"].append(
+                active_set.n_labelled_chars / self.n_characters
+            )
+
+        # Active learning loop
+        for step in tqdm(range(AL_STEPS)):
+            init_step_time = time.time()
+
+            # Load the initial weights.
+            baal_model.load_state_dict(init_weights)
+
+            # Train the model on the currently labelled dataset.
+            init_train_time = time.time()
+            _ = baal_model.train_on_dataset(
+                dataset=active_set,
+                optimizer=optimizer,
+                batch_size=config.batch_size,
+                use_cuda=self.use_cuda,
+                epoch=config.max_epoch,
+                collate_fn=custom_collate,
+            )
+            train_time = time.time() - init_train_time
+
+            # Test the model on the test set.
+ baal_model.test_on_dataset( + dataset=self.test_dataset, + batch_size=config.batch_size, + use_cuda=self.use_cuda, + average_predictions=config.iterations, + collate_fn=custom_collate, + ) + + if verbose: + self.print_al_iteration_metrics(step + 1, baal_model.get_metrics()) + + # query new instances to be labelled + init_query_time = time.time() + should_continue = active_loop.step() + query_time = time.time() - init_query_time + step_time = time.time() - init_step_time + + if logging: + run["times/step_time"].append(step_time) + run["times/train_time"].append(train_time) + run["times/query_time"].append(query_time) + run["annotation/intance_ann"].append( + active_set.n_labelled / self.n_instances + ) + run["annotation/token_ann"].append( + active_set.n_labelled_tokens / self.n_tokens + ) + run["annotation/char_ann"].append( + active_set.n_labelled_chars / self.n_characters + ) + + if not should_continue: + break + + # adjust class weights + baal_model.criterion = CrossEntropyLoss( + self.compute_class_weights(active_set.labels) + ) + # end of active learning loop + + if logging: + for metrics in baal_model.active_learning_metrics.values(): + for key, value in metrics.items(): + f_key = key.replace("test_", "test/").replace("train_", "train/") + + if "class" in key: + for i, class_value in enumerate(value): + run[f_key + "_" + str(i)].append(class_value) + else: + run[f_key].append(value) + + run["train/step_acc"].extend(active_loop.step_acc) + run["train/step_score"].extend(active_loop.step_score) + + run.stop()
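+
+
+if __name__ == "__main__":
+    # Illustrative usage sketch only: it shows how `BilstmTrainer` is expected to be
+    # wired up with the passive- and active-learning entry points defined above.
+    # The dataset paths, the relation type, and the default-constructed experiment
+    # configs below are assumptions for illustration, not part of the pipeline.
+    from datasets import load_from_disk
+
+    train_split = load_from_disk("data/n2c2/train")  # hypothetical path
+    test_split = load_from_disk("data/n2c2/test")    # hypothetical path
+
+    trainer = BilstmTrainer(
+        dataset="n2c2",
+        train_dataset=train_split,
+        test_dataset=test_split,
+        relation_type="Strength-Drug",  # hypothetical relation type
+    )
+
+    # passive learning with early stopping (Neptune logging disabled for the sketch)
+    trainer.train_passive_learning(config=PLExperimentConfig(), logging=False)
+
+    # active learning with one of the Baal query strategies
+    trainer.train_active_learning(
+        query_strategy=BaalQueryStrategy.BATCH_BALD,
+        config=BaalExperimentConfig(),
+        logging=False,
+    )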