--- a
+++ b/src/training/bert.py
@@ -0,0 +1,366 @@
+# Base Dependencies
+# -----------------
+import numpy as np
+import re
+import time
+from copy import deepcopy
+from functools import partial
+from os.path import join
+from pathlib import Path
+from typing import Optional, Dict
+
+# Package Dependencies
+# --------------------
+from .base import BaseTrainer
+from .config import PLExperimentConfig, BaalExperimentConfig
+from .utils import get_baal_query_strategy, tokenize, tokenize_pairs
+
+# Local Dependencies
+# ------------------
+from extensions.baal import my_active_huggingface_dataset, MyActiveLearningLoop
+from extensions.transformers import WeightedLossTrainer
+from ml_models.bert import ClinicalBERT, ClinicalBERTTokenizer, ClinicalBERTConfig
+
+# 3rd-Party Dependencies
+# ----------------------
+import neptune
+
+from baal.transformers_trainer_wrapper import BaalTransformersTrainer
+from baal.bayesian.dropout import patch_module
+from torch.utils.data import Dataset
+from transformers import (
+    EarlyStoppingCallback,
+    EvalPrediction,
+    IntervalStrategy,
+    TrainingArguments,
+)
+
+# Constants
+# ---------
+from constants import BaalQueryStrategy
+from config import NEPTUNE_API_TOKEN, NEPTUNE_PROJECT
+
+
+class BertTrainer(BaseTrainer):
+    """Trainer for the BERT method"""
+
+    def __init__(
+        self,
+        dataset: str,
+        train_dataset: Dataset,
+        test_dataset: Dataset,
+        pairs: bool = False,
+        relation_type: Optional[str] = None,
+    ):
+        """
+        Args:
+            dataset (str): name of the dataset, e.g., "n2c2".
+            train_dataset (Dataset): train split of the dataset.
+            test_dataset (Dataset): test split of the dataset.
+            pairs (bool): if True, instances are tokenized as sentence pairs
+                (Paired Clinical BERT). Defaults to False.
+            relation_type (str, optional): relation type.
+
+        Raises:
+            ValueError: if the dataset name provided is not supported
+        """
+        super().__init__(dataset, train_dataset, test_dataset, relation_type)
+
+        self.pairs = pairs
+        # tokenizer
+        self.tokenizer = ClinicalBERTTokenizer()
+
+        # tokenize datasets
+        if not pairs:
+            self.train_dataset = tokenize(self.tokenizer, self.train_dataset)
+            self.test_dataset = tokenize(self.tokenizer, self.test_dataset)
+        else:
+            self.train_dataset = tokenize_pairs(self.tokenizer, self.train_dataset)
+            self.test_dataset = tokenize_pairs(self.tokenizer, self.test_dataset)
+
+    @property
+    def method_name(self) -> str:
+        if self.pairs:
+            name = "bert-pairs"
+        else:
+            name = "bert"
+        return name
+
+    @property
+    def method_name_pretty(self) -> str:
+        if self.pairs:
+            name = "Paired Clinical BERT"
+        else:
+            name = "Clinical BERT"
+        return name
+
+    def _init_model(self, patch: bool = False) -> ClinicalBERT:
+        config = ClinicalBERTConfig
+        config.num_labels = self.num_classes
+        model = ClinicalBERT(config=config)
+        if patch:
+            # patch dropout layers so they stay active at inference time (MC Dropout)
+            model = patch_module(model)
+        return model
+
+    def compute_metrics_transformer(self, eval_preds: EvalPrediction) -> Dict:
+        """Computes metrics from a Transformer's prediction.
+
+        Args:
+            eval_preds (EvalPrediction): transformer's prediction
+
+        Returns:
+            Dict: precision, recall and F1-score
+        """
+        logits, labels = eval_preds
+        predictions = np.argmax(logits, axis=-1)
+
+        return self.compute_metrics(y_true=labels, y_pred=predictions)
+
+    def train_passive_learning(
+        self, config: PLExperimentConfig, verbose: bool = True, logging: bool = True
+    ):
+        """Trains the BERT model using passive learning and early stopping
+
+        Args:
+            config (PLExperimentConfig): experiment configuration.
+            verbose (bool): determines if information is printed during training. Defaults to True.
+            logging (bool): log the test metrics on Neptune. Defaults to True.
+ """ + if logging: + run = neptune.init_run(project=NEPTUNE_PROJECT, api_token=NEPTUNE_API_TOKEN) + + # setup + train_val_split = self.train_dataset.train_test_split( + test_size=config.val_size, stratify_by_column="label" + ) + train_set = train_val_split["train"] + val_set = train_val_split["test"] + test_set = self.test_dataset + + model = self._init_model() + + training_args = TrainingArguments( + output_dir=self.pl_checkpoint_path, # output directory + optim="adamw_torch", # optimizer + weight_decay=0.01, # strength of weight decay + learning_rate=5e-5, # learning rate + evaluation_strategy=IntervalStrategy.EPOCH, + save_strategy=IntervalStrategy.EPOCH, + num_train_epochs=config.max_epoch, + per_device_train_batch_size=config.batch_size, + per_device_eval_batch_size=config.batch_size, # batch size for evaluation + log_level="warning", # logging level + logging_dir=".logs/n2c2/bert/", # directory for storing logs + report_to="none", + metric_for_best_model="f1", + load_best_model_at_end=True, + ) + + trainer = WeightedLossTrainer( + model=model, + args=training_args, + seed=config.seed, + train_dataset=train_set, + eval_dataset=val_set, + tokenizer=self.tokenizer, + compute_metrics=self.compute_metrics_transformer, + callbacks=[ + EarlyStoppingCallback(early_stopping_patience=config.es_patience) + ], + ) + labels = train_set["label"].numpy() + trainer.class_weights = self.compute_class_weights(labels) + + # print info + if verbose: + self.print_info_passive_learning() + + # train model + trainer.train() + eval_loss_values = trainer.eval_loss + train_loss_values = trainer.training_loss + + # evaluate model on test set + test_metrics = trainer.evaluate(test_set) + + if verbose: + self.print_test_metrics(test_metrics) + + # log to Neptune + if logging: + run["method"] = self.method_name + run["dataset"] = self.dataset + run["relation"] = self.relation_type + run["strategy"] = "passive learning" + run["config"] = config.__dict__ + run["epoch"] = len(eval_loss_values) + + for loss in train_loss_values: + run["loss/train"].append(loss) + + for loss in eval_loss_values: + run["loss/val"].append(loss) + + for key, value in test_metrics.items(): + key2 = re.sub(r"eval_", "", key) + run["test/" + key2] = value + + run.stop() + + return model + + def train_active_learning( + self, + query_strategy: BaalQueryStrategy, + config: BaalExperimentConfig, + verbose: bool = True, + save_models: bool = False, + logging: bool = True, + ): + """Trains the BiLSTM model using active learning + + Args: + query_strategy (str): name of the query strategy to be used in the experiment. + config (BaalExperimentConfig): experiment configuration. + verbose (bool): determines if information is printed during trainig or not. Defaults to True.s + logging (bool): log the test metrics on Neptune. Defaults to True. 
+ """ + + if logging: + run = neptune.init_run(project=NEPTUNE_PROJECT, api_token=NEPTUNE_API_TOKEN) + run["model"] = self.method_name + run["dataset"] = self.dataset + run["relation"] = self.relation_type + run["strategy"] = query_strategy.value + run["bayesian"] = config.all_bayesian or ( + query_strategy == BaalQueryStrategy.BATCH_BALD + ) + run["params"] = config.__dict__ + + # setup quering + INIT_QUERY_SIZE = self.compute_init_q_size(config) + QUERY_SIZE = self.compute_q_size(config) + AL_STEPS = self.compute_al_steps(config) + + f_query_strategy = get_baal_query_strategy( + name=query_strategy.value, + shuffle_prop=config.shuffle_prop, + query_size=QUERY_SIZE, + ) + + # setup model + PATCH = config.all_bayesian or (query_strategy == BaalQueryStrategy.BATCH_BALD) + if not PATCH: + config.iterations = 1 + + # setup active set + active_set = my_active_huggingface_dataset(self.train_dataset) + active_set.can_label = False + active_set.label_randomly(INIT_QUERY_SIZE) + + # print info + if verbose: + self.print_info_active_learning( + q_strategy=query_strategy.value, + pool_size=self.n_instances, + init_q_size=INIT_QUERY_SIZE, + q_size=QUERY_SIZE, + ) + + training_args = TrainingArguments( + output_dir=self.al_checkpoint_path, + optim="adamw_torch", # optimizer + weight_decay=0.01, # strength of weight decay + learning_rate=5e-5, # learning rate + num_train_epochs=config.max_epoch, + per_device_train_batch_size=config.batch_size, + per_device_eval_batch_size=config.batch_size, # batch size for evaluation + log_level="warning", # logging level + logging_dir=".logs/n2c2/bert/", # directory for storing logs + report_to="none", + ) + + # create the trainer through Baal Wrapper + baal_trainer = BaalTransformersTrainer( + model_init=partial(self._init_model, PATCH), + seed=config.seed, + args=training_args, + train_dataset=active_set, + tokenizer=None, + compute_metrics=self.compute_metrics_transformer, + ) + + + # create Active Learning loop + active_loop = MyActiveLearningLoop( + dataset=active_set, + get_probabilities=baal_trainer.predict_on_dataset, + heuristic=f_query_strategy, + query_size=QUERY_SIZE, + iterations=config.iterations, + max_sample=config.max_sample, + ) + + init_weights = deepcopy(baal_trainer.model.state_dict()) + + # Active Learning loop + for step in range(AL_STEPS): + init_step_time = time.time() + + # reset the model to the initial state + baal_trainer.model.load_state_dict(init_weights) + + # train model on current active set + init_train_time = time.time() + baal_trainer.train() + train_time = time.time() - init_train_time + + if save_models: + # save model + path = Path(join(self.al_checkpoint_path, "model_{}.ck".format(step))) + baal_trainer.model.save_pretrained(path) + + # evaluate model on test set + metrics = baal_trainer.evaluate(self.test_dataset) + metrics["dataset_size"] = active_set.n_labelled + + # print step metrics + if verbose: + self.print_al_iteration_metrics(step + 1, metrics) + + # query new instances + init_query_time = time.time() + should_continue = active_loop.step() + query_time = time.time() - init_query_time + step_time = time.time() - init_step_time + + if logging: + run["times/step_time"].append(step_time) + run["times/train_time"].append(train_time) + run["times/query_time"].append(query_time) + run["annotation/instance_ann"].append( + active_set.n_labelled / self.n_instances + ) + run["annotation/token_ann"].append( + active_set.n_labelled_tokens / self.n_tokens + ) + run["annotation/char_ann"].append( + active_set.n_labelled_chars / 
+                )
+                for key, value in metrics.items():
+                    f_key = key.replace("test_", "test/").replace("train_", "train/")
+                    run[f_key].append(value)
+
+            if not should_continue:
+                break
+
+            # We reset the model weights to relearn from the new train set.
+            baal_trainer.load_state_dict(init_weights)
+            baal_trainer.lr_scheduler = None
+
+        # log to Neptune
+        if logging:
+            for step_acc in active_loop.step_acc:
+                run["train/step_acc"].append(step_acc)
+
+            for step_score in active_loop.step_score:
+                run["train/step_score"].append(step_score)
+
+            run.stop()
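
A minimal usage sketch of the BertTrainer class added above, for reviewers. The dataset-loading call, import paths, and config defaults are illustrative assumptions rather than part of this diff; the trainer expects HuggingFace datasets with a "label" column, as used by train_test_split above.

# Hypothetical driver script; paths, data loading, and config defaults are assumed.
from datasets import load_from_disk

from constants import BaalQueryStrategy
from src.training.bert import BertTrainer  # module path assumed from the diff
from src.training.config import PLExperimentConfig, BaalExperimentConfig

# assumed: preprocessed n2c2 splits saved as HuggingFace datasets with a "label" column
train_split = load_from_disk("data/n2c2/train")
test_split = load_from_disk("data/n2c2/test")

trainer = BertTrainer("n2c2", train_split, test_split, pairs=False)

# passive learning with early stopping (config defaults assumed to exist)
model = trainer.train_passive_learning(PLExperimentConfig(), verbose=True, logging=False)

# active learning with BatchBALD; train_active_learning patches the model with MC Dropout
trainer.train_active_learning(
    query_strategy=BaalQueryStrategy.BATCH_BALD,
    config=BaalExperimentConfig(),
    verbose=True,
    logging=False,
)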