Diff of /src/training/base.py [000000] .. [735bb5]

# Base Dependencies
# -----------------
import numpy as np
from abc import ABC, abstractmethod
from os.path import join as pjoin
from pathlib import Path
from typing import Dict, Optional, Union, List

# Package Dependencies
# --------------------
from .config import PLExperimentConfig, ALExperimentConfig
from .utils import compute_metrics

# Local Dependencies
# ------------------
from models.relation_collection import RelationCollection
from utils import ddi_binary_relation

# 3rd-Party Dependencies
# ----------------------
import torch
from sklearn.metrics import accuracy_score
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset

# Constants
# ---------
from constants import CHECKPOINTS_CACHE_DIR, DATASETS


# BaseTrainer
# -----------
class BaseTrainer(ABC):
    def __init__(
        self,
        dataset: str,
        train_dataset: Union[RelationCollection, Dataset],
        test_dataset: Union[RelationCollection, Dataset],
        relation_type: Optional[str] = None,
    ):
        """
        Args:
            dataset (str): name of the dataset, e.g., "n2c2".
            train_dataset (Union[RelationCollection, Dataset]): train split of the dataset.
            test_dataset (Union[RelationCollection, Dataset]): test split of the dataset.
            relation_type (str, optional): relation type. Defaults to None.

        Raises:
            ValueError: if the dataset name provided is not supported
        """
        if dataset not in DATASETS:
            raise ValueError("unsupported dataset '{}'".format(dataset))

        self.dataset = dataset
        self.relation_type = relation_type
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # datasets
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

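        # NOTE (assumption): when `train_dataset` is not a RelationCollection it is
        # expected to be a column-addressable dataset (e.g. a HuggingFace
        # `datasets.Dataset`) exposing integer "seq_length" and "char_length"
        # columns, which the `.sum().item()` calls below rely on.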
        # get total number of instances, tokens and characters
        if isinstance(self.train_dataset, RelationCollection):
            self.n_instances = self.train_dataset.n_instances
            self.n_tokens = self.train_dataset.n_tokens
            self.n_characters = self.train_dataset.n_characters
        else:
            self.n_instances = len(self.train_dataset)
            self.n_tokens = self.train_dataset["seq_length"].sum().item()
            self.n_characters = self.train_dataset["char_length"].sum().item()

    @property
    @abstractmethod
    def method_name(self) -> str:
        pass

    @property
    @abstractmethod
    def method_name_pretty(self) -> str:
        pass

    @property
    def use_cuda(self) -> bool:
        return self.device.type == "cuda"

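    # NOTE (assumption): n2c2 relations are scored as a binary task per relation
    # type (hence "binary" averaging and 2 classes), while DDI uses five labels
    # (no interaction plus four interaction types, hence "micro" averaging and
    # 5 classes); see also `compute_metrics_ddi` below.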
    @property
    def metrics_average(self) -> str:
        if self.dataset == "n2c2":
            avg = "binary"
        else:
            avg = "micro"

        return avg

    @property
    def num_classes(self) -> int:
        if self.dataset == "n2c2":
            n = 2
        else:
            n = 5
        return n

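    # Checkpoint layout: the two properties below resolve (and create on first
    # access) CHECKPOINTS_CACHE_DIR/{pl,al}/<method_name>/<dataset>[/<relation_type>]/.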
    @property
    def pl_checkpoint_path(self):
        """Passive Learning checkpoints directory path"""
        directory = Path(
            pjoin(CHECKPOINTS_CACHE_DIR, "pl", self.method_name, self.dataset)
        )
        if self.relation_type:
            directory = Path(pjoin(directory, self.relation_type))

        if not directory.is_dir():
            directory.mkdir(parents=True, exist_ok=False)

        return directory

    @property
    def al_checkpoint_path(self):
        """Active Learning checkpoints directory path"""
        directory = Path(
            pjoin(CHECKPOINTS_CACHE_DIR, "al", self.method_name, self.dataset)
        )
        if self.relation_type:
            directory = Path(pjoin(directory, self.relation_type))

        if not directory.is_dir():
            directory.mkdir(parents=True, exist_ok=False)

        return directory

    # Instance Methods
    # ----------------
    @abstractmethod
    def train_passive_learning(
        self, config: PLExperimentConfig, verbose: bool = True, logging: bool = True
    ):
        """Trains the model using Passive Learning"""
        raise NotImplementedError()

    @abstractmethod
    def train_active_learning(
        self,
        query_strategy: str,
        config: ALExperimentConfig,
        verbose: bool = True,
        logging: bool = True,
    ):
        """Trains the model using Active Learning"""
        raise NotImplementedError()

    def print_info_passive_learning(self) -> None:
        """Prints information about the Passive Learning training process"""
        print(f"\n\n**** {self.method_name_pretty} - Train Passive Learning ****")
        print(f" - Dataset: {self.dataset}")
        if self.relation_type:
            print(f" - Relation type: {self.relation_type}")

    def print_info_active_learning(
        self, q_strategy: str, pool_size: int, init_q_size: int, q_size: int
    ) -> None:
        """Prints information about the Active Learning training process"""
        print(f"\n\n**** {self.method_name_pretty} - Train Active Learning ****")
        print(f"  - Dataset: {self.dataset}")
        if self.relation_type:
            print(f"  - Relation type: {self.relation_type}")
        print(f"  - Strategy = {q_strategy}")
        print(f"  - Pool size = {pool_size}")
        print(f"  - Initial query size = {init_q_size}")
        print(f"  - Query size = {q_size}")

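    # NOTE: sklearn's "balanced" heuristic weights class i as
    # n_samples / (n_classes * count_i); e.g. for labels [0, 0, 0, 1] it yields
    # weights [0.67, 2.0]. Weights are only computed when every class occurs in
    # `labels`; otherwise None is returned and callers are expected to fall back
    # to an unweighted loss.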
    def compute_class_weights(self, labels: list) -> Optional[torch.Tensor]:
        """Computes the class weights for the given labels"""
        if len(np.unique(labels)) == self.num_classes:
            class_weights = compute_class_weight(
                class_weight="balanced",
                classes=np.array(range(self.num_classes)),
                y=labels,
            )
            class_weights = torch.from_numpy(class_weights).float().to(self.device)
        else:
            class_weights = None

        return class_weights

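    # Annotation-budget arithmetic for the three helpers below (illustrative
    # numbers only, not taken from any config): with n_instances = 1000,
    # initial_pool_perc = 0.05, query_size_perc = 0.05, max_query_size = 50 and
    # max_annotation = 0.8, the seed set and every query hold min(50, 50) = 50
    # instances, and compute_al_steps yields round(1000 * 0.8 / 50) - 1 = 15
    # further steps after the initial query.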
    def compute_init_q_size(self, config: ALExperimentConfig) -> int:
        """Computes the initial query (seed set) size for the given configuration"""
        return min(
            config.max_query_size,
            int(round(config.initial_pool_perc * self.n_instances)),
        )

    def compute_q_size(self, config: ALExperimentConfig) -> int:
        """Computes the query size for the given configuration"""
        return min(
            config.max_query_size,
            int(round(config.query_size_perc * self.n_instances)),
        )

    def compute_al_steps(self, config: ALExperimentConfig) -> int:
        """Computes the number of active learning steps for the given configuration"""
        query_size = self.compute_q_size(config)
        return int(round(self.n_instances * config.max_annotation / query_size)) - 1

    def compute_step_accuracy(self, y_true: list, y_pred: list) -> float:
        """Computes the accuracy for the given step"""
        return accuracy_score(y_true, y_pred, normalize=True)

    def compute_metrics(
        self,
        y_true: list,
        y_pred: list,
        labels: Optional[List[str]] = None,
        pos_label: int = 1,
    ) -> Dict:
        """Computes metrics

        Args:
            y_true (list): list of ground-truth labels
            y_pred (list): list of predicted labels
            labels (Optional[List[str]], optional): list of class names. Defaults to None.
            pos_label (int, optional): positive label. Defaults to 1.

        Returns:
            Dict: accuracy, precision, recall and F1-score
        """
        if self.dataset == "n2c2":
            metrics = self.compute_metrics_n2c2(y_true, y_pred, labels, pos_label)
        else:  # ddi
            metrics = self.compute_metrics_ddi(y_true, y_pred, labels, pos_label)

        return metrics

    def compute_metrics_n2c2(
        self,
        y_true: list,
        y_pred: list,
        labels: Optional[List[str]] = None,
        pos_label: int = 1,
    ):
        metrics = {}

        # accuracy
        metrics["acc"] = accuracy_score(y_true=y_true, y_pred=y_pred, normalize=True)

        # averaged precision, recall and F1
        avg_metrics = compute_metrics(
            y_true=y_true, y_pred=y_pred, average=self.metrics_average, pos_label=pos_label
        )
        for key, value in avg_metrics.items():
            metrics[key] = value

        return metrics

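    # NOTE (assumption): this follows the DDIExtraction-style evaluation: micro
    # precision/recall/F1 over the four interaction types, a binary "detection"
    # score (any interaction vs. none, via `ddi_binary_relation`), and per-class
    # scores for all five labels.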
    def compute_metrics_ddi(
        self,
        y_true: list,
        y_pred: list,
        labels: Optional[List[str]] = None,
        pos_label: int = 1,
    ):
        metrics = {}

        # accuracy
        metrics["acc"] = accuracy_score(y_true=y_true, y_pred=y_pred, normalize=True)

        # micro average over the relevant (interaction) classes
        relevant_classes = [1, 2, 3, 4]
        y_true_arr = np.asarray(y_true)
        y_pred_arr = np.asarray(y_pred)
        relevant_indices = np.isin(y_true_arr, relevant_classes)
        micro_metrics = compute_metrics(
            y_true=y_true_arr[relevant_indices],
            y_pred=y_pred_arr[relevant_indices],
            average="micro",
        )
        for key, value in micro_metrics.items():
            metrics[key] = value

        # ddi: "Detection" (any interaction vs. none)
        y_true_binary = [ddi_binary_relation(y) for y in y_true]
        y_pred_binary = [ddi_binary_relation(y) for y in y_pred]

        detection_metrics = compute_metrics(
            y_true=y_true_binary, y_pred=y_pred_binary, average="binary"
        )
        for key, value in detection_metrics.items():
            metrics["detect_" + key] = value

        # ddi: per class
        per_class_metrics = compute_metrics(
            y_true=y_true, y_pred=y_pred, average=None, labels=[0, 1, 2, 3, 4]
        )
        for key, values in per_class_metrics.items():
            for i, value in enumerate(values):
                if labels:
                    class_name = labels[i]
                else:
                    class_name = str(i)
                metrics["class_" + key + "_" + class_name] = value

        return metrics

    # Class methods
    # --------------
    @classmethod
    def print_al_iteration_metrics(cls, step: int, metrics: Dict[str, float]):
        print("\n** Iteration {} - Metrics **".format(step), flush=True)
        for key, value in metrics.items():
            print("  - {} = {}".format(key, value), flush=True)
        print("")

    @classmethod
    def print_val_metrics(cls, epoch: int, metrics: Dict[str, float]):
        print("\n** Epoch {} - Validation set - Metrics **".format(epoch), flush=True)
        for key, value in metrics.items():
            print("  - {} = {}".format(key, value), flush=True)
        print("")

    @classmethod
    def print_train_metrics(cls, metrics: Dict[str, float]):
        print("\n** Training set - Metrics **", flush=True)
        for key, value in metrics.items():
            print("  - {} = {}".format(key, value), flush=True)
        print("")

    @classmethod
    def print_test_metrics(cls, metrics: Dict[str, float]):
        print("\n** Test set - Metrics **", flush=True)
        for key, value in metrics.items():
            print("  - {} = {}".format(key, value), flush=True)
        print("")