src/training/bilstm.py
# Base Dependencies
# -----------------
import numpy as np
import time
from copy import deepcopy
from functools import partial
from tqdm import tqdm
from typing import Dict, Optional
from pathlib import Path
from os.path import join

# Package Dependencies
# --------------------
from .base import BaseTrainer
from .config import PLExperimentConfig, BaalExperimentConfig
from .early_stopping import EarlyStopping
from .utils import get_baal_query_strategy

# Local Dependencies
# -------------------
from extensions.baal import (
    MyModelWrapperBilstm,
    MyActiveLearningDatasetBilstm,
    MyActiveLearningLoop,
)
from extensions.torchmetrics import (
    DetectionF1Score,
    DetectionPrecision,
    DetectionRecall,
)
from ml_models.bilstm import (
    HasanModel,
    EmbeddingConfig,
    LSTMConfig,
    RDEmbeddingConfig,
)
from re_datasets.bilstm_utils import pad_and_sort_batch, custom_collate
from vocabulary import Vocabulary, read_list_from_file

# 3rd-Party Dependencies
# ----------------------
import neptune
import torch

from baal.bayesian.dropout import patch_module
from datasets import Dataset
from torch.optim import Adam
from torch.nn import CrossEntropyLoss, Module
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler
from torchmetrics import Accuracy
from torchmetrics.classification import F1Score, Precision, Recall

# Constants
# ---------
from constants import (
    N2C2_VOCAB_PATH,
    DDI_VOCAB_PATH,
    N2C2_IOB_TAGS,
    DDI_IOB_TAGS,
    N2C2_RD_MAX,
    DDI_RD_MAX,
    RD_EMB_DIM,
    IOB_EMB_DIM,
    BIOWV_EMB_DIM,
    POS_EMB_DIM,
    DEP_EMB_DIM,
    BIOWORD2VEC_PATH,
    U_POS_TAGS,
    DEP_TAGS,
    BaalQueryStrategy,
)
from config import NEPTUNE_API_TOKEN, NEPTUNE_PROJECT


class BilstmTrainer(BaseTrainer):
    """Trainer for BiLSTM method."""

    def __init__(
        self,
        dataset: str,
        train_dataset: Dataset,
        test_dataset: Dataset,
        relation_type: Optional[str] = None,
    ):
        """
        Args:
            dataset (str): name of the dataset, e.g., "n2c2".
            train_dataset (Dataset): train split of the dataset.
            test_dataset (Dataset): test split of the dataset.
            relation_type (str, optional): relation type.

        Raises:
            ValueError: if the dataset name provided is not supported
        """
        super().__init__(dataset, train_dataset, test_dataset, relation_type)

        # vocabulary
        self.vocab = self._init_vocab()

        # batch transform applied to the datasets (padding and sorting)
        self.transform = partial(
            pad_and_sort_batch, padding_idx=self.vocab.pad_index, rd_max=self.RD_MAX
        )

    @property
    def method_name(self) -> str:
        return "bilstm"

    @property
    def method_name_pretty(self) -> str:
        return "BiLSTM"

    @property
    def task(self) -> str:
        if self.dataset == "n2c2":
            task = "binary"
        else:
            task = "multiclass"
        return task

    @property
    def model_class(self) -> type:
        return HasanModel

    @property
    def RD_MAX(self) -> int:
        if self.dataset == "n2c2":
            rd_max = N2C2_RD_MAX
        else:
            rd_max = DDI_RD_MAX
        return rd_max

    @property
    def IOB_TAGS(self) -> list:
        if self.dataset == "n2c2":
            iob_tags = N2C2_IOB_TAGS
        else:
            iob_tags = DDI_IOB_TAGS
        return iob_tags

    def _init_optimizer(self, model: Module):
        return Adam(model.parameters(), lr=0.0001)

    def _init_vocab(self):
        """Loads the vocabulary of the dataset"""
        if self.dataset == "n2c2":
            vocab_path = N2C2_VOCAB_PATH
        else:
            vocab_path = DDI_VOCAB_PATH

        return Vocabulary(read_list_from_file(vocab_path))

    def _init_model(self, patch: bool = False) -> HasanModel:
        """Builds the BiLSTM model with the right configuration for the chosen dataset"""
        # word embedding configuration
        biowv_config = EmbeddingConfig(
            embedding_dim=BIOWV_EMB_DIM,
            vocab_size=len(self.vocab),
            emb_path=BIOWORD2VEC_PATH,
            freeze=True,
            padding_idx=self.vocab.pad_index,
        )

        # relative-distance embedding configuration
        rd_config = RDEmbeddingConfig(
            input_dim=self.RD_MAX, embedding_dim=RD_EMB_DIM, freeze=False
        )

        # IOB embedding configuration
        iob_config = EmbeddingConfig(
            embedding_dim=IOB_EMB_DIM, vocab_size=(len(self.IOB_TAGS) + 1), freeze=False
        )

        # Part-of-Speech tag embedding configuration
        pos_config = EmbeddingConfig(
            embedding_dim=POS_EMB_DIM, vocab_size=(len(U_POS_TAGS) + 1), freeze=False
        )
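
        # dependency tag embedding configuration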
        dep_config = EmbeddingConfig(
            embedding_dim=DEP_EMB_DIM, vocab_size=(len(DEP_TAGS) + 1), freeze=False
        )

        # BiLSTM configuration
        lstm_config = LSTMConfig(
            emb_size=(
                BIOWV_EMB_DIM + 2 * RD_EMB_DIM + POS_EMB_DIM + DEP_EMB_DIM + IOB_EMB_DIM
            )
        )

        model = self.model_class(
            vocab=self.vocab,
            lstm_config=lstm_config,
            bioword2vec_config=biowv_config,
            rd_config=rd_config,
            pos_config=pos_config,
            dep_config=dep_config,
            iob_config=iob_config,
            num_classes=self.num_classes,
        )
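
        # baal's `patch_module` swaps Dropout layers for variants that stay active at
        # inference time, enabling MC-dropout sampling for the Bayesian query strategies.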
        if patch:
            model = patch_module(model)

        return model

    def _reset_trainer(self):
        self.train_dataset.reset_format()
        self.test_dataset.reset_format()

    def create_dataloader(self, dataset: Dataset, batch_size: int = 6) -> DataLoader:
        """Creates a dataloader from a dataset with the appropriate configuration

        Args:
            dataset (Dataset): dataset to load
            batch_size (int): batch size. Defaults to 6.

        Returns:
            DataLoader: dataloader for the given dataset
        """
        dataset.set_transform(self.transform)

        # create dataloader
        sampler = BatchSampler(
            RandomSampler(dataset), batch_size=batch_size, drop_last=False
        )
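        # Passing the BatchSampler as `sampler` makes the DataLoader index the
        # HuggingFace dataset with a whole list of indices at once, so the batch
        # transform runs once per batch instead of once per row.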
        dataloader = DataLoader(dataset, sampler=sampler, collate_fn=custom_collate)

        return dataloader

    def eval_model(
        self,
        model: Module,
        dataloader: DataLoader,
        criterion: Module,
    ) -> Dict[str, float]:
        """Evaluates the current model on the dev or test set

        Args:
            model (Module): model to use for evaluation.
            dataloader (DataLoader): dataloader of the evaluation dataset.
            criterion (Module): loss function.

        Returns:
            Dict: metrics including loss (`loss`), precision (`p`), recall (`r`) and F1-score (`f1`)
        """

        y_true = np.array([], dtype=np.int8)
        y_pred = np.array([], dtype=np.int8)

        val_loss = 0.0

        with torch.no_grad():
            for inputs, labels in dataloader:
                # send (inputs, labels) to device
                labels = labels.to(self.device)
                for key, value in inputs.items():
                    inputs[key] = value.to(self.device)

                # calculate outputs
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += labels.size(0) * loss.item()  # weight the batch loss by its size

                # calculate predictions
                _, predicted = torch.max(outputs.data, 1)

                # store labels and predictions
                y_true = np.append(y_true, labels.cpu().detach().numpy())
                y_pred = np.append(y_pred, predicted.cpu().detach().numpy())

        metrics = self.compute_metrics(y_true, y_pred)
        metrics["loss"] = val_loss / len(y_true)

        return metrics

    def train_passive_learning(
        self, config: PLExperimentConfig, verbose: bool = True, logging: bool = True
    ):
        """Trains the BiLSTM model using passive learning and early stopping

        Args:
            config (PLExperimentConfig): configuration
            verbose (bool): determines if information is printed during training. Defaults to True.
            logging (bool): log the test metrics on Neptune. Defaults to True.
        """
        self._reset_trainer()

        # setup
        train_val_split = self.train_dataset.train_test_split(
            test_size=config.val_size, stratify_by_column="label"
        )
        labels = np.array(train_val_split["train"]["label"])

        train_dataloader = self.create_dataloader(
            train_val_split["train"], batch_size=config.batch_size
        )

        val_dataloader = self.create_dataloader(
            train_val_split["test"], batch_size=config.batch_size
        )
        test_dataloader = self.create_dataloader(
            self.test_dataset, batch_size=config.batch_size
        )

        if logging:
            run = neptune.init_run(project=NEPTUNE_PROJECT, api_token=NEPTUNE_API_TOKEN)

        model = self._init_model()
        model = model.to(self.device)
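        # class weights computed from the labelled training split counter label imbalance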
        criterion = CrossEntropyLoss(weight=self.compute_class_weights(labels))
        optimizer = self._init_optimizer(model)

        # print info
        if verbose:
            self.print_info_passive_learning()

        # early stopper
        ES = EarlyStopping(
            patience=config.es_patience,
            verbose=True,
            path=Path(join(self.pl_checkpoint_path, "best_model.pt")),
        )

        # training loop
        for epoch in range(config.max_epoch):
            running_loss = 0.0
            for i, (inputs, labels) in tqdm(enumerate(train_dataloader, 0)):
                # send (inputs, labels) to device
                labels = labels.to(self.device)
                for key, value in inputs.items():
                    inputs[key] = value.to(self.device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # accumulate statistics
                running_loss += loss.item()

            # evaluate model on validation set
            val_metrics = self.eval_model(model, val_dataloader, criterion)
            train_loss = running_loss / len(train_dataloader)
            val_loss = val_metrics["loss"]
            running_loss = 0.0
            if logging:
                run["loss/train"].append(train_loss)
                run["loss/val"].append(val_loss)

                for key, value in val_metrics.items():
                    if key != "loss":
                        run[f"val/{key}"].append(value)

            if verbose:
                self.print_val_metrics(epoch + 1, val_metrics)

            # check early stopping
            ES(val_loss, model)
            if ES.early_stop:
                break

        # load best model
        model.load_state_dict(
            torch.load(Path(join(self.pl_checkpoint_path, "best_model.pt")))
        )

        # evaluate model on test dataset
        test_metrics = self.eval_model(model, test_dataloader, criterion)
        if verbose:
            self.print_test_metrics(test_metrics)
        if logging:
            run["method"] = self.method_name
            run["dataset"] = self.dataset
            run["relation"] = self.relation_type
            run["strategy"] = "passive learning"
            run["config"] = config.__dict__
            run["epochs"] = epoch

            for key, value in test_metrics.items():
                run["test/" + key] = value
            run.stop()

        return model

    def set_al_metrics(self, baal_model: MyModelWrapperBilstm):
        """
        Configures the metrics that are to be computed during the active learning experiment

        Args:
            baal_model (MyModelWrapperBilstm): model wrapper

        """
        # accuracy
        baal_model.add_metric(
            name="acc",
            initializer=lambda: Accuracy(task=self.task, average="micro").to(
                self.device
            ),
        )
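
        # The metric objects below are created once and captured by the initializer
        # lambdas, so the same device-placed instances are returned whenever baal
        # calls the initializers.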
        if self.dataset == "n2c2":
            f1 = F1Score(num_classes=self.num_classes, ignore_index=0).to(self.device)
            p = Precision(num_classes=self.num_classes, ignore_index=0).to(self.device)
            r = Recall(num_classes=self.num_classes, ignore_index=0).to(self.device)
            baal_model.add_metric(name="f1", initializer=lambda: f1)
            baal_model.add_metric(name="p", initializer=lambda: p)
            baal_model.add_metric(name="r", initializer=lambda: r)

        else:  # self.dataset == "ddi":
            # detection + classification metrics
            cla_f1_micro = F1Score(
                num_classes=self.num_classes, average="micro", ignore_index=0
            ).to(self.device)

            cla_p_micro = Precision(
                num_classes=self.num_classes, average="micro", ignore_index=0
            ).to(self.device)

            cla_r_micro = Recall(
                num_classes=self.num_classes, average="micro", ignore_index=0
            ).to(self.device)

            cla_f1_macro = F1Score(
                num_classes=self.num_classes, average="macro", ignore_index=0
            ).to(self.device)

            cla_p_macro = Precision(
                num_classes=self.num_classes, average="macro", ignore_index=0
            ).to(self.device)

            cla_r_macro = Recall(
                num_classes=self.num_classes, average="macro", ignore_index=0
            ).to(self.device)

            baal_model.add_metric(name="micro_f1", initializer=lambda: cla_f1_micro)
            baal_model.add_metric(name="micro_p", initializer=lambda: cla_p_micro)
            baal_model.add_metric(name="micro_r", initializer=lambda: cla_r_micro)
            baal_model.add_metric(name="macro_f1", initializer=lambda: cla_f1_macro)
            baal_model.add_metric(name="macro_p", initializer=lambda: cla_p_macro)
            baal_model.add_metric(name="macro_r", initializer=lambda: cla_r_macro)

            # detection metrics
            detect_f1 = DetectionF1Score().to(self.device)
            detect_p = DetectionPrecision().to(self.device)
            detect_r = DetectionRecall().to(self.device)

            baal_model.add_metric(name="detect_f1", initializer=lambda: detect_f1)
            baal_model.add_metric(name="detect_p", initializer=lambda: detect_p)
            baal_model.add_metric(name="detect_r", initializer=lambda: detect_r)

            # per class metrics
            per_class_f1 = F1Score(num_classes=self.num_classes, average="none").to(
                self.device
            )

            per_class_p = Precision(num_classes=self.num_classes, average="none").to(
                self.device
            )

            per_class_r = Recall(num_classes=self.num_classes, average="none").to(
                self.device
            )

            baal_model.add_metric(name="class_f1", initializer=lambda: per_class_f1)
            baal_model.add_metric(name="class_p", initializer=lambda: per_class_p)
            baal_model.add_metric(name="class_r", initializer=lambda: per_class_r)

        return baal_model

    def train_active_learning(
        self,
        query_strategy: BaalQueryStrategy,
        config: BaalExperimentConfig,
        verbose: bool = True,
        logging: bool = True,
    ):
        """Trains the BiLSTM model using active learning

        Args:
            query_strategy (BaalQueryStrategy): query strategy to be used in the experiment.
            config (BaalExperimentConfig): experiment configuration.
            verbose (bool): determines if information is printed during training. Defaults to True.
            logging (bool): log the test metrics on Neptune. Defaults to True.
        """
        self._reset_trainer()

        if logging:
            run = neptune.init_run(project=NEPTUNE_PROJECT, api_token=NEPTUNE_API_TOKEN)

        # setup querying
        INIT_QUERY_SIZE = self.compute_init_q_size(config)
        QUERY_SIZE = self.compute_q_size(config)
        AL_STEPS = self.compute_al_steps(config)

        f_query_strategy = get_baal_query_strategy(
            name=query_strategy.value,
            shuffle_prop=config.shuffle_prop,
            query_size=QUERY_SIZE,
        )

        if verbose:
            self.print_info_active_learning(
                q_strategy=query_strategy.value,
                pool_size=self.n_instances,
                init_q_size=INIT_QUERY_SIZE,
                q_size=QUERY_SIZE,
            )

        # setup active set
        self.train_dataset.set_transform(self.transform)
        self.test_dataset.set_transform(self.transform)
        active_set = MyActiveLearningDatasetBilstm(self.train_dataset)
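        # The pool already carries gold labels, so "labelling" an instance only moves it
        # from the unlabelled pool into the labelled set; seed it with a random initial batch.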
        active_set.can_label = False
        active_set.label_randomly(INIT_QUERY_SIZE)

        # setup model
        PATCH = config.all_bayesian or (query_strategy == BaalQueryStrategy.BATCH_BALD)
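        # Without MC dropout there is no stochasticity to average over, so a single
        # forward pass per prediction is enough.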
        if not PATCH:
            config.iterations = 1
        model = self._init_model(PATCH)
        model = model.to(self.device)
        criterion = CrossEntropyLoss(self.compute_class_weights(active_set.labels))
        optimizer = self._init_optimizer(model)

        baal_model = MyModelWrapperBilstm(
            model,
            criterion,
            replicate_in_memory=False,
            min_train_passes=config.min_train_passes,
        )
        baal_model = self.set_al_metrics(baal_model)

        # active loop
        active_loop = MyActiveLearningLoop(
            dataset=active_set,
            get_probabilities=baal_model.predict_on_dataset,
            heuristic=f_query_strategy,
            query_size=QUERY_SIZE,
            batch_size=config.batch_size,
            iterations=config.iterations,
            use_cuda=self.use_cuda,
            verbose=False,
            workers=2,
            collate_fn=custom_collate,
        )
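        # At each step the loop scores the unlabelled pool with `predict_on_dataset`
        # and lets the heuristic (query strategy) pick the next QUERY_SIZE instances.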

        # We will reset the weights at each active learning step so we make a copy.
        init_weights = deepcopy(baal_model.state_dict())

        if logging:
            run["model"] = self.method_name
            run["dataset"] = self.dataset
            run["relation"] = self.relation_type
            run["bayesian"] = config.all_bayesian or (
                query_strategy == BaalQueryStrategy.BATCH_BALD
            )
            run["strategy"] = query_strategy.value
            run["config"] = config.__dict__
            run["annotation/instance_ann"].append(active_set.n_labelled / self.n_instances)
            run["annotation/token_ann"].append(
                active_set.n_labelled_tokens / self.n_tokens
            )
            run["annotation/char_ann"].append(
                active_set.n_labelled_chars / self.n_characters
            )

        step_acc = []

        # Active learning loop
        for step in tqdm(range(AL_STEPS)):
            init_step_time = time.time()

            # Load the initial weights.
            baal_model.load_state_dict(init_weights)

            # Train the model on the currently labelled dataset.
            init_train_time = time.time()
            _ = baal_model.train_on_dataset(
                dataset=active_set,
                optimizer=optimizer,
                batch_size=config.batch_size,
                use_cuda=self.use_cuda,
                epoch=config.max_epoch,
                collate_fn=custom_collate,
            )
            train_time = time.time() - init_train_time

            # test the model on the test set.
            baal_model.test_on_dataset(
                dataset=self.test_dataset,
                batch_size=config.batch_size,
                use_cuda=self.use_cuda,
                average_predictions=config.iterations,
                collate_fn=custom_collate,
            )

            if verbose:
                self.print_al_iteration_metrics(step + 1, baal_model.get_metrics())

            # query new instances to be labelled
            init_query_time = time.time()
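            # `step()` labels the next batch of queried instances and returns False
            # once the unlabelled pool is exhausted.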
            should_continue = active_loop.step()
            query_time = time.time() - init_query_time
            step_time = time.time() - init_step_time

            if logging:
                run["times/step_time"].append(step_time)
                run["times/train_time"].append(train_time)
                run["times/query_time"].append(query_time)
                run["annotation/instance_ann"].append(
                    active_set.n_labelled / self.n_instances
                )
                run["annotation/token_ann"].append(
                    active_set.n_labelled_tokens / self.n_tokens
                )
                run["annotation/char_ann"].append(
                    active_set.n_labelled_chars / self.n_characters
                )

            if not should_continue:
                break

            # adjust class weights
            baal_model.criterion = CrossEntropyLoss(
                self.compute_class_weights(active_set.labels)
            )
        # end of active learning loop

        if logging:
            for metrics in baal_model.active_learning_metrics.values():
                for key, value in metrics.items():
                    f_key = key.replace("test_", "test/").replace("train_", "train/")

                    if "class" in key:
                        for i, class_value in enumerate(value):
                            run[f_key + "_" + str(i)].append(class_value)
                    else:
                        run[f_key].append(value)

            run["train/step_acc"].extend(active_loop.step_acc)
            run["train/step_score"].extend(active_loop.step_score)

            run.stop()