Diff of /src/training/bert.py [000000] .. [735bb5]

Switch to unified view

a b/src/training/bert.py
1
# Base Dependencies
2
# -----------------
3
import numpy as np
4
import re
5
import time
6
from copy import deepcopy
7
from functools import partial
8
from os.path import join
9
from pathlib import Path
10
from typing import Optional, Dict
11
12
# Package Dependencies
13
# --------------------
14
from .base import BaseTrainer
15
from .config import PLExperimentConfig, BaalExperimentConfig
16
from .utils import get_baal_query_strategy, tokenize, tokenize_pairs
17
18
# Local Dependencies
19
# ------------------
20
from extensions.baal import my_active_huggingface_dataset, MyActiveLearningLoop
21
from extensions.transformers import WeightedLossTrainer
22
from ml_models.bert import ClinicalBERT, ClinicalBERTTokenizer, ClinicalBERTConfig
23
24
# 3rd-Party Dependencies
25
# ----------------------
26
import neptune
27
28
from baal.transformers_trainer_wrapper import BaalTransformersTrainer
29
from baal.bayesian.dropout import patch_module
30
from torch.utils.data import Dataset
31
from transformers import (
32
    EarlyStoppingCallback,
33
    EvalPrediction,
34
    IntervalStrategy,
35
    TrainingArguments,
36
)
37
38
# Constants
39
# ---------
40
from constants import BaalQueryStrategy
41
from config import NEPTUNE_API_TOKEN, NEPTUNE_PROJECT
42
43
44
class BertTrainer(BaseTrainer):
45
    """Trainer for the BERT method"""
46
47
    def __init__(
48
        self,
49
        dataset: str,
50
        train_dataset: Dataset,
51
        test_dataset: Dataset,
52
        pairs: bool = False,
53
        relation_type: Optional[str] = None,
54
    ):
55
        """
56
        dataset (str): name of the dataset, e.g., "n2c2".
57
        train_dataset (Dataset): train split of the dataset.
58
        test_dataset (Dataset): test split of the dataset.
59
        relation_type (str, optional): relation type.
60
61
        Raises:
62
            ValueError: if the name dataset provided is not supported
63
        """
64
        super().__init__(dataset, train_dataset, test_dataset, relation_type)
65
66
        self.pairs = pairs
67
        # tokenizer
68
        self.tokenizer = ClinicalBERTTokenizer()
69
70
        # tokenize datasets
71
        if not pairs:
72
            self.train_dataset = tokenize(self.tokenizer, self.train_dataset)
73
            self.test_dataset = tokenize(self.tokenizer, self.test_dataset)
74
        else:
75
            self.train_dataset = tokenize_pairs(self.tokenizer, self.train_dataset)
76
            self.test_dataset = tokenize_pairs(self.tokenizer, self.test_dataset)
77
78
    @property
79
    def method_name(self) -> str:
80
        if self.pairs:
81
            name = "bert-pairs"
82
        else:
83
            name = "bert"
84
        return name
85
86
    @property
87
    def method_name_pretty(self) -> str:
88
        if self.pairs:
89
            name = "Paired Clinical BERT"
90
        else:
91
            name = "Clinical BERT"
92
        return name
93
94
    def _init_model(self, patch: bool = False) -> ClinicalBERT:
95
        config = ClinicalBERTConfig
96
        config.num_labels = self.num_classes
97
        model = ClinicalBERT(config=ClinicalBERTConfig)
98
        if patch:
99
            model = patch_module(model)
100
        return model
101
102
    def compute_metrics_transformer(self, eval_preds: EvalPrediction) -> Dict:
103
        """Computes metrics from a Transformer's prediction.
104
105
        Args:
106
            eval_preds (EvalPrediction): transformer's prediction
107
108
        Returns:
109
            Dict: precision, recall and F1-score
110
        """
111
        logits, labels = eval_preds
112
        predictions = np.argmax(logits, axis=-1)
113
114
        return self.compute_metrics(y_true=labels, y_pred=predictions)
115
116
    def train_passive_learning(
117
        self, config: PLExperimentConfig, verbose: bool = True, logging: bool = True
118
    ):
119
        """Trains the BiLSTM model using passive learning and early stopping
120
121
        Args:
122
            config (PLExperimentConfig): cofiguration
123
            verbose (bool): determines if information is printed during training. Daults to True.
124
            logging (bool): log the test metrics on Neptune. Defaults to True.
125
        """
126
        if logging:
127
            run = neptune.init_run(project=NEPTUNE_PROJECT, api_token=NEPTUNE_API_TOKEN)
128
129
        # setup
130
        train_val_split = self.train_dataset.train_test_split(
131
            test_size=config.val_size, stratify_by_column="label"
132
        )
133
        train_set = train_val_split["train"]
134
        val_set = train_val_split["test"]
135
        test_set = self.test_dataset
136
137
        model = self._init_model()
138
139
        training_args = TrainingArguments(
140
            output_dir=self.pl_checkpoint_path,  # output directory
141
            optim="adamw_torch",  # optimizer
142
            weight_decay=0.01,  # strength of weight decay
143
            learning_rate=5e-5,  # learning rate
144
            evaluation_strategy=IntervalStrategy.EPOCH,
145
            save_strategy=IntervalStrategy.EPOCH,
146
            num_train_epochs=config.max_epoch,
147
            per_device_train_batch_size=config.batch_size,
148
            per_device_eval_batch_size=config.batch_size,  # batch size for evaluation
149
            log_level="warning",  # logging level
150
            logging_dir=".logs/n2c2/bert/",  # directory for storing logs
151
            report_to="none",
152
            metric_for_best_model="f1",
153
            load_best_model_at_end=True,
154
        )
155
156
        trainer = WeightedLossTrainer(
157
            model=model,
158
            args=training_args,
159
            seed=config.seed,
160
            train_dataset=train_set,
161
            eval_dataset=val_set,
162
            tokenizer=self.tokenizer,
163
            compute_metrics=self.compute_metrics_transformer,
164
            callbacks=[
165
                EarlyStoppingCallback(early_stopping_patience=config.es_patience)
166
            ],
167
        )
168
        labels = train_set["label"].numpy()
169
        trainer.class_weights = self.compute_class_weights(labels)
170
171
        # print info
172
        if verbose:
173
            self.print_info_passive_learning()
174
175
        # train model
176
        trainer.train()
177
        eval_loss_values = trainer.eval_loss
178
        train_loss_values = trainer.training_loss
179
180
        # evaluate model on test set
181
        test_metrics = trainer.evaluate(test_set)
182
183
        if verbose:
184
            self.print_test_metrics(test_metrics)
185
186
        # log to Neptune
187
        if logging:
188
            run["method"] = self.method_name
189
            run["dataset"] = self.dataset
190
            run["relation"] = self.relation_type
191
            run["strategy"] = "passive learning"
192
            run["config"] = config.__dict__
193
            run["epoch"] = len(eval_loss_values)
194
195
            for loss in train_loss_values:
196
                run["loss/train"].append(loss)
197
198
            for loss in eval_loss_values:
199
                run["loss/val"].append(loss)
200
201
            for key, value in test_metrics.items():
202
                key2 = re.sub(r"eval_", "", key)
203
                run["test/" + key2] = value
204
205
            run.stop()
206
207
        return model
208
209
    def train_active_learning(
210
        self,
211
        query_strategy: BaalQueryStrategy,
212
        config: BaalExperimentConfig,
213
        verbose: bool = True,
214
        save_models: bool = False,
215
        logging: bool = True,
216
    ):
217
        """Trains the BiLSTM model using active learning
218
219
        Args:
220
            query_strategy (str): name of the query strategy to be used in the experiment.
221
            config (BaalExperimentConfig): experiment configuration.
222
            verbose (bool): determines if information is printed during trainig or not. Defaults to True.s
223
            logging (bool): log the test metrics on Neptune. Defaults to True.
224
        """
225
226
        if logging:
227
            run = neptune.init_run(project=NEPTUNE_PROJECT, api_token=NEPTUNE_API_TOKEN)
228
            run["model"] = self.method_name
229
            run["dataset"] = self.dataset
230
            run["relation"] = self.relation_type
231
            run["strategy"] = query_strategy.value
232
            run["bayesian"] = config.all_bayesian or (
233
                query_strategy == BaalQueryStrategy.BATCH_BALD
234
            )
235
            run["params"] = config.__dict__
236
237
        # setup quering 
238
        INIT_QUERY_SIZE = self.compute_init_q_size(config)
239
        QUERY_SIZE = self.compute_q_size(config)
240
        AL_STEPS = self.compute_al_steps(config)
241
242
        f_query_strategy = get_baal_query_strategy(
243
            name=query_strategy.value,
244
            shuffle_prop=config.shuffle_prop,
245
            query_size=QUERY_SIZE,
246
        )
247
248
        # setup model
249
        PATCH = config.all_bayesian or (query_strategy == BaalQueryStrategy.BATCH_BALD)
250
        if not PATCH:
251
            config.iterations = 1     
252
253
        # setup active set
254
        active_set = my_active_huggingface_dataset(self.train_dataset)
255
        active_set.can_label = False
256
        active_set.label_randomly(INIT_QUERY_SIZE)
257
258
        # print info
259
        if verbose:
260
            self.print_info_active_learning(
261
                q_strategy=query_strategy.value,
262
                pool_size=self.n_instances,
263
                init_q_size=INIT_QUERY_SIZE,
264
                q_size=QUERY_SIZE,
265
            )
266
267
        training_args = TrainingArguments(
268
            output_dir=self.al_checkpoint_path,
269
            optim="adamw_torch",  # optimizer
270
            weight_decay=0.01,  # strength of weight decay
271
            learning_rate=5e-5,  # learning rate
272
            num_train_epochs=config.max_epoch,
273
            per_device_train_batch_size=config.batch_size,
274
            per_device_eval_batch_size=config.batch_size,  # batch size for evaluation
275
            log_level="warning",  # logging level
276
            logging_dir=".logs/n2c2/bert/",  # directory for storing logs
277
            report_to="none",
278
        )
279
280
        # create the trainer through Baal Wrapper
281
        baal_trainer = BaalTransformersTrainer(
282
            model_init=partial(self._init_model, PATCH),
283
            seed=config.seed,
284
            args=training_args,
285
            train_dataset=active_set,
286
            tokenizer=None,
287
            compute_metrics=self.compute_metrics_transformer,
288
        )
289
290
291
        # create Active Learning loop
292
        active_loop = MyActiveLearningLoop(
293
            dataset=active_set,
294
            get_probabilities=baal_trainer.predict_on_dataset,
295
            heuristic=f_query_strategy,
296
            query_size=QUERY_SIZE,
297
            iterations=config.iterations,
298
            max_sample=config.max_sample,
299
        )
300
301
        init_weights = deepcopy(baal_trainer.model.state_dict())
302
303
        # Active Learning loop
304
        for step in range(AL_STEPS):
305
            init_step_time = time.time()
306
307
            # reset the model to the initial state
308
            baal_trainer.model.load_state_dict(init_weights)
309
310
            # train model on current active set
311
            init_train_time = time.time()
312
            baal_trainer.train()
313
            train_time = time.time() - init_train_time
314
315
            if save_models:
316
                # save model
317
                path = Path(join(self.al_checkpoint_path, "model_{}.ck".format(step)))
318
                baal_trainer.model.save_pretrained(path)
319
                
320
            # evaluate model on test set
321
            metrics = baal_trainer.evaluate(self.test_dataset)
322
            metrics["dataset_size"] = active_set.n_labelled
323
324
            # print step metrics
325
            if verbose:
326
                self.print_al_iteration_metrics(step + 1, metrics)
327
328
            # query new instances
329
            init_query_time = time.time()
330
            should_continue = active_loop.step()
331
            query_time = time.time() - init_query_time
332
            step_time = time.time() - init_step_time
333
334
            if logging:
335
                run["times/step_time"].append(step_time)
336
                run["times/train_time"].append(train_time)
337
                run["times/query_time"].append(query_time)
338
                run["annotation/instance_ann"].append(
339
                    active_set.n_labelled / self.n_instances
340
                )
341
                run["annotation/token_ann"].append(
342
                    active_set.n_labelled_tokens / self.n_tokens
343
                )
344
                run["annotation/char_ann"].append(
345
                    active_set.n_labelled_chars / self.n_characters
346
                )
347
                for key, value in metrics.items():
348
                    f_key = key.replace("test_", "test/").replace("train_", "train/")
349
                    run[f_key].append(value)
350
351
            if not should_continue:
352
                break
353
354
            # We reset the model weights to relearn from the new train set.
355
            baal_trainer.load_state_dict(init_weights)
356
            baal_trainer.lr_scheduler = None
357
358
        # log to Neptune
359
        if logging:
360
            for step_acc in active_loop.step_acc:
361
                run["train/step_acc"].append(step_acc)
362
363
            for step_score in active_loop.step_score:
364
                run["train/step_score"].append(step_score)
365
366
            run.stop()