[735bb5]: / src / extensions / baal / model_wrapper.py

Download this file

191 lines (169 with data), 6.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# Base Dependencies
# -----------------
import sys
import structlog
from math import floor
from typing import Callable, Optional
# PyTorch Dependencies
# --------------------
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm
# Baal Dependencies
# ------------------
from baal.active.dataset.base import Dataset
from baal.modelwrapper import ModelWrapper
from baal.utils.iterutils import map_on_tensor
log = structlog.get_logger("ModelWrapper")
# Model Wrappers
# --------------
class MyModelWrapperBilstm(ModelWrapper):
    """
    Modification of Baal's ``ModelWrapper`` that allows a transform on a batch
    from a HF Dataset with several inputs (i.e. a dictionary of tensors).

    Batches are built by passing a ``BatchSampler`` as the DataLoader's
    ``sampler`` argument, so the underlying dataset is indexed with a *list*
    of indices in a single ``__getitem__`` call (dictionary-of-tensors
    friendly) instead of one index at a time.
    """

    def __init__(self, model, criterion, replicate_in_memory=True, min_train_passes: int = 10):
        """
        Args:
            model: Model to train and evaluate.
            criterion: Loss criterion, forwarded to ``ModelWrapper``.
            replicate_in_memory (bool): Forwarded to ``ModelWrapper``.
            min_train_passes (int): Minimum number of optimizer steps per
                epoch targeted by ``_compute_batch_size``.
        """
        super().__init__(model, criterion, replicate_in_memory)
        self.min_train_passes = min_train_passes
        # History of the batch size actually used on each training round.
        self.batch_sizes = []

    def _compute_batch_size(self, n_labelled: int, max_batch_size: int) -> int:
        """Return a batch size giving at least ``min_train_passes`` steps.

        The result is capped at ``max_batch_size`` and floored at 2 so that
        layers sensitive to batch statistics never see a single-sample batch.
        """
        # math.floor already returns an int; no extra int() cast needed.
        bs = min(floor(n_labelled / self.min_train_passes), max_batch_size)
        return max(2, bs)

    def train_on_dataset(
        self,
        dataset: Dataset,
        optimizer: torch.optim.Optimizer,
        batch_size: int,
        epoch: int,
        use_cuda: bool,
        workers: int = 2,
        collate_fn: Optional[Callable] = None,
        regularizer: Optional[Callable] = None,
    ):
        """
        Train for `epoch` epochs on a Dataset `dataset`.

        Args:
            dataset (Dataset): Pytorch Dataset to be trained on.
            optimizer (optim.Optimizer): Optimizer to use.
            batch_size (int): The batch size used in the DataLoader.
            epoch (int): Number of epoch to train for.
            use_cuda (bool): Use cuda or not.
            workers (int): Number of workers for the multiprocessing.
            collate_fn (Optional[Callable]): The collate function to use.
            regularizer (Optional[Callable]): The loss regularization for training.

        Returns:
            The training history (one train-loss entry per epoch).
        """
        dataset_size = len(dataset)
        # NOTE(review): adaptive sizing via _compute_batch_size is currently
        # disabled; the caller-provided batch size is used as-is.
        actual_batch_size = batch_size
        self.batch_sizes.append(actual_batch_size)
        self.train()
        self.set_dataset_size(dataset_size)
        history = []
        log.info("Starting training", epoch=epoch, dataset=dataset_size)
        collate_fn = collate_fn or default_collate
        # BatchSampler passed as `sampler` (not `batch_sampler`): each "sample"
        # is a list of indices, fetched from the dataset in one call.
        sampler = BatchSampler(
            RandomSampler(dataset), batch_size=actual_batch_size, drop_last=False
        )
        dataloader = DataLoader(
            dataset, sampler=sampler, num_workers=workers, collate_fn=collate_fn
        )
        for _ in range(epoch):
            self._reset_metrics("train")
            for data, target, *_ in dataloader:
                _ = self.train_on_batch(data, target, optimizer, use_cuda, regularizer)
            history.append(self.get_metrics("train")["train_loss"])
        optimizer.zero_grad()  # Assert that the gradient is flushed.
        log.info(
            "Training complete", train_loss=self.get_metrics("train")["train_loss"]
        )
        self.active_step(dataset_size, self.get_metrics("train"))
        return history

    def test_on_dataset(
        self,
        dataset: Dataset,
        batch_size: int,
        use_cuda: bool,
        workers: int = 2,
        collate_fn: Optional[Callable] = None,
        average_predictions: int = 1,
    ):
        """
        Test the model on a Dataset `dataset`.

        Args:
            dataset (Dataset): Dataset to evaluate on.
            batch_size (int): Batch size used for evaluation.
            use_cuda (bool): Use Cuda or not.
            workers (int): Number of workers to use.
            collate_fn (Optional[Callable]): The collate function to use.
            average_predictions (int): The number of predictions to average to
                compute the test loss.

        Returns:
            Average loss value over the dataset.
        """
        self.eval()
        log.info("Starting evaluating", dataset=len(dataset))
        self._reset_metrics("test")
        # Normalize collate_fn like the sibling methods do (DataLoader would
        # default None to default_collate anyway; kept for consistency).
        collate_fn = collate_fn or default_collate
        sampler = BatchSampler(
            RandomSampler(dataset), batch_size=batch_size, drop_last=False
        )
        dataloader = DataLoader(
            dataset, sampler=sampler, num_workers=workers, collate_fn=collate_fn
        )
        for data, target, *_ in dataloader:
            _ = self.test_on_batch(
                data, target, cuda=use_cuda, average_predictions=average_predictions
            )
        log.info("Evaluation complete", test_loss=self.get_metrics("test")["test_loss"])
        self.active_step(None, self.get_metrics("test"))
        return self.get_metrics("test")["test_loss"]

    def predict_on_dataset_generator(
        self,
        dataset: Dataset,
        batch_size: int,
        iterations: int,
        use_cuda: bool,
        workers: int = 2,
        collate_fn: Optional[Callable] = None,
        half=False,
        verbose=True,
    ):
        """
        Use the model to predict on a dataset `iterations` time.

        Args:
            dataset (Dataset): Dataset to predict on.
            batch_size (int): Batch size to use during prediction.
            iterations (int): Number of iterations per sample.
            use_cuda (bool): Use CUDA or not.
            workers (int): Number of workers to use.
            collate_fn (Optional[Callable]): The collate function to use.
            half (bool): If True use half precision.
            verbose (bool): If True use tqdm to display progress

        Notes:
            The "batch" is made of `batch_size` * `iterations` samples.

        Returns:
            Generators [batch_size, n_classes, ..., n_iterations].
        """
        self.eval()
        if len(dataset) == 0:
            # Returning inside a generator simply stops iteration.
            return None
        log.info("Start Predict", dataset=len(dataset))
        collate_fn = collate_fn or default_collate
        # NOTE(review): RandomSampler shuffles the prediction order, so the
        # yielded batches do NOT follow the dataset's index order — confirm
        # that downstream consumers (e.g. acquisition heuristics) do not rely
        # on positional alignment with the dataset.
        sampler = BatchSampler(
            RandomSampler(dataset), batch_size=batch_size, drop_last=False
        )
        loader = DataLoader(
            dataset, sampler=sampler, num_workers=workers, collate_fn=collate_fn
        )
        if verbose:
            loader = tqdm(loader, total=len(loader), file=sys.stdout)
        for idx, (data, *_) in enumerate(loader):
            pred = self.predict_on_batch(data, iterations, use_cuda)
            pred = map_on_tensor(lambda x: x.detach(), pred)
            if half:
                pred = map_on_tensor(lambda x: x.half(), pred)
            yield map_on_tensor(lambda x: x.cpu().numpy(), pred)