lavis/runners/runner_base.py

"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import json
import logging
import os
import time
from pathlib import Path

import torch
import torch.distributed as dist
import webdataset as wds
from lavis.common.dist_utils import (
    download_cached_file,
    get_rank,
    get_world_size,
    is_main_process,
    main_process,
)
from lavis.common.registry import registry
from lavis.common.utils import is_url
from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
from lavis.datasets.datasets.dataloader_utils import (
    IterLoader,
    MultiIterLoader,
    PrefetchLoader,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import ChainDataset


@registry.register_runner("runner_base")
class RunnerBase:
    """
    A runner class to train and evaluate a model given a task and datasets.

    The runner uses PyTorch DistributedDataParallel by default. Future releases
    will support other distributed frameworks.
    """

    def __init__(self, cfg, task, model, datasets, job_id):
        self.config = cfg
        self.job_id = job_id

        self.task = task
        self.datasets = datasets

        self._model = model

        self._wrapped_model = None
        self._device = None
        self._optimizer = None
        self._scaler = None
        self._dataloaders = None
        self._lr_sched = None

        self.start_epoch = 0

        # self.setup_seeds()
        self.setup_output_dir()

    @property
    def device(self):
        if self._device is None:
            self._device = torch.device(self.config.run_cfg.device)

        return self._device

    @property
    def use_distributed(self):
        return self.config.run_cfg.distributed

    @property
    def model(self):
        """
        A property to get the DDP-wrapped model on the device.
        """
        # move model to device
        if self._model.device != self.device:
            self._model = self._model.to(self.device)

            # distributed training wrapper
            if self.use_distributed:
                if self._wrapped_model is None:
                    self._wrapped_model = DDP(
                        self._model, device_ids=[self.config.run_cfg.gpu]
                    )
            else:
                self._wrapped_model = self._model

        return self._wrapped_model

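    # Note: ``self._model`` is the raw model; accessing ``self.model`` (above) lazily
    # moves it to the configured device and, under distributed training, wraps it in
    # DistributedDataParallel on first access. The ``optimizer`` property below reads
    # parameter groups from the unwrapped ``self._model``.
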
    @property
    def optimizer(self):
        # TODO make optimizer class and configurations
        if self._optimizer is None:
            lr_scale = self.config.run_cfg.get("lr_layer_decay", 1)
            weight_decay = self.config.run_cfg.get("weight_decay", 0.05)
            optim_params = self._model.get_optimizer_params(weight_decay, lr_scale)

            num_parameters = 0
            for p_group in optim_params:
                for p in p_group["params"]:
                    num_parameters += p.data.nelement()
            logging.info("number of trainable parameters: {}".format(num_parameters))

            beta2 = self.config.run_cfg.get("beta2", 0.999)

            self._optimizer = torch.optim.AdamW(
                optim_params,
                lr=float(self.config.run_cfg.init_lr),
                betas=(0.9, beta2),
            )

        return self._optimizer

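    # ``get_optimizer_params(weight_decay, lr_scale)`` above is expected to return
    # parameter groups in torch.optim's param-group format: a list of dicts, each
    # carrying a "params" list (which the counting loop iterates) plus optional
    # per-group settings. A purely illustrative sketch (``needs_decay`` is hypothetical):
    #
    #   [
    #       {"params": [p for n, p in model.named_parameters() if needs_decay(n)]},
    #       {"params": [p for n, p in model.named_parameters() if not needs_decay(n)],
    #        "weight_decay": 0.0},
    #   ]
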
    @property
    def scaler(self):
        amp = self.config.run_cfg.get("amp", False)

        if amp:
            if self._scaler is None:
                self._scaler = torch.cuda.amp.GradScaler()

        return self._scaler

    @property
    def lr_scheduler(self):
        """
        A property to get and create the learning rate scheduler only when needed.
        """
        if self._lr_sched is None:
            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)

            # max_epoch = self.config.run_cfg.max_epoch
            max_epoch = self.max_epoch
            # min_lr = self.config.run_cfg.min_lr
            min_lr = self.min_lr
            # init_lr = self.config.run_cfg.init_lr
            init_lr = self.init_lr

            # optional parameters
            decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
            warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
            warmup_steps = self.config.run_cfg.get("warmup_steps", 0)

            self._lr_sched = lr_sched_cls(
                optimizer=self.optimizer,
                max_epoch=max_epoch,
                min_lr=min_lr,
                init_lr=init_lr,
                decay_rate=decay_rate,
                warmup_start_lr=warmup_start_lr,
                warmup_steps=warmup_steps,
            )

        return self._lr_sched

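    # The scheduler class above is looked up in the registry by name. The run_cfg keys
    # it consumes are those read in the property: lr_sched, max_epoch, init_lr, min_lr
    # and, optionally, lr_decay_rate, warmup_lr and warmup_steps. A hypothetical config
    # fragment (values are illustrative only):
    #
    #   run:
    #     lr_sched: "linear_warmup_cosine_lr"
    #     init_lr: 1e-4
    #     min_lr: 1e-5
    #     warmup_lr: 1e-6
    #     warmup_steps: 500
    #     max_epoch: 10
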
    @property
    def dataloaders(self) -> dict:
        """
        A property to get and create dataloaders by split, only when needed.

        If no train_dataset_ratio is provided, concatenate map-style datasets and
        chain wds.DataPipeline datasets separately. The training set becomes a tuple
        (ConcatDataset, ChainDataset); both are optional, but at least one of them is
        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.

        If train_dataset_ratio is provided, create a MultiIterLoader to sample
        each dataset by ratios during training.

        Multiple datasets for validation and test are currently not supported.

        Returns:
            dict: {split_name: (tuples of) dataloader}
        """
        if self._dataloaders is None:
            # reorganize datasets by split and concatenate/chain if necessary
            dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)

            # concatenate map-style datasets and chain wds.DataPipe datasets separately
            # training set becomes a tuple (ConcatDataset, ChainDataset), both are
            # optional but at least one of them is required. The resultant ConcatDataset
            # and ChainDataset will be sampled evenly.
            logging.info(
                "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
            )

            datasets = reorg_datasets_by_split(self.datasets)
            self.datasets = concat_datasets(datasets)

            # print dataset statistics after concatenation/chaining
            for split_name in self.datasets:
                if isinstance(self.datasets[split_name], tuple) or isinstance(
                    self.datasets[split_name], list
                ):
                    # mixed wds.DataPipeline and torch.utils.data.Dataset
                    num_records = sum(
                        [
                            len(d)
                            if not type(d) in [wds.DataPipeline, ChainDataset]
                            else 0
                            for d in self.datasets[split_name]
                        ]
                    )

                else:
                    if hasattr(self.datasets[split_name], "__len__"):
                        # a single map-style dataset
                        num_records = len(self.datasets[split_name])
                    else:
                        # a single wds.DataPipeline
                        num_records = -1
                        logging.info(
                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
                        )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders
            split_names = sorted(self.datasets.keys())

            datasets = [self.datasets[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            batch_sizes = [
                self.config.run_cfg.batch_size_train
                if split == "train"
                else self.config.run_cfg.batch_size_eval
                for split in split_names
            ]

            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, tuple) or isinstance(dataset, list):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))

            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
                dataset_ratios=dataset_ratios,
            )

            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}

        return self._dataloaders

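    # The dict returned above maps split names to ready-to-iterate loaders. Shape only,
    # e.g. (actual keys depend on the configured splits):
    #
    #   {
    #       "train": IterLoader(...) or MultiIterLoader(...),
    #       "val": PrefetchLoader(DataLoader(...)),
    #       "test": PrefetchLoader(DataLoader(...)),
    #   }
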
    @property
    def cuda_enabled(self):
        return self.device.type == "cuda"

    @property
    def max_epoch(self):
        return int(self.config.run_cfg.max_epoch)

    @property
    def log_freq(self):
        log_freq = self.config.run_cfg.get("log_freq", 50)
        return int(log_freq)

    @property
    def init_lr(self):
        return float(self.config.run_cfg.init_lr)

    @property
    def min_lr(self):
        return float(self.config.run_cfg.min_lr)

    @property
    def accum_grad_iters(self):
        return int(self.config.run_cfg.get("accum_grad_iters", 1))

    @property
    def valid_splits(self):
        valid_splits = self.config.run_cfg.get("valid_splits", [])

        if len(valid_splits) == 0:
            logging.info("No validation splits found.")

        return valid_splits

    @property
    def test_splits(self):
        test_splits = self.config.run_cfg.get("test_splits", [])

        return test_splits

    @property
    def train_splits(self):
        train_splits = self.config.run_cfg.get("train_splits", [])

        if len(train_splits) == 0:
            logging.info("Empty train splits.")

        return train_splits

    @property
    def evaluate_only(self):
        """
        Set to True to skip training.
        """
        return self.config.run_cfg.evaluate

    @property
    def use_dist_eval_sampler(self):
        return self.config.run_cfg.get("use_dist_eval_sampler", True)

    @property
    def resume_ckpt_path(self):
        return self.config.run_cfg.get("resume_ckpt_path", None)

    @property
    def train_loader(self):
        train_dataloader = self.dataloaders["train"]

        return train_dataloader

    def setup_output_dir(self):
        lib_root = Path(registry.get_path("library_root"))

        output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
        result_dir = output_dir / "result"

        output_dir.mkdir(parents=True, exist_ok=True)
        result_dir.mkdir(parents=True, exist_ok=True)

        registry.register_path("result_dir", str(result_dir))
        registry.register_path("output_dir", str(output_dir))

        self.result_dir = result_dir
        self.output_dir = output_dir

    def train(self):
        start_time = time.time()
        best_agg_metric = 0
        best_epoch = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        for cur_epoch in range(self.start_epoch, self.max_epoch):
            # training phase
            if not self.evaluate_only:
                logging.info("Start training")
                # See https://github.com/salesforce/LAVIS/issues/449
                # if cur_epoch == self.start_epoch:
                #     self.task.before_training(
                #         model=self.unwrap_dist_model(self.model),
                #         dataset=self.datasets["train"],
                #     )
                train_stats = self.train_epoch(cur_epoch)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if len(self.valid_splits) > 0:
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=cur_epoch
                    )
                    if val_log is not None:
                        if is_main_process():
                            assert (
                                "agg_metrics" in val_log
                            ), "No agg_metrics found in validation log."

                            agg_metrics = val_log["agg_metrics"]
                            if agg_metrics > best_agg_metric and split_name == "val":
                                best_epoch, best_agg_metric = cur_epoch, agg_metrics

                                self._save_checkpoint(cur_epoch, is_best=True)

                            val_log.update({"best_epoch": best_epoch})
                            self.log_stats(val_log, split_name)

            else:
                # if no validation split is provided, we just save the checkpoint at the end of each epoch.
                if not self.evaluate_only:
                    self._save_checkpoint(cur_epoch, is_best=False)

            if self.evaluate_only:
                break

            dist.barrier()

        # testing phase
        test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))

    def evaluate(self, cur_epoch="best", skip_reload=False):
        test_logs = dict()

        if len(self.test_splits) > 0:
            for split_name in self.test_splits:
                test_logs[split_name] = self.eval_epoch(
                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
                )

            return test_logs

    def train_epoch(self, epoch):
        # train
        self.model.train()

        return self.task.train_epoch(
            epoch=epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
        )

    @torch.no_grad()
    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
        """
        Evaluate the model on a given split.

        Args:
            split_name (str): name of the split to evaluate on.
            cur_epoch (int): current epoch.
            skip_reload (bool): whether to skip reloading the best checkpoint.
                During training, we will reload the best checkpoint for validation.
                During testing, we will use provided weights and skip reloading the best checkpoint.
        """
        data_loader = self.dataloaders.get(split_name, None)
        assert data_loader, "data_loader for split {} is None.".format(split_name)

        # TODO In validation, you need to compute loss as well as metrics
        # TODO consider moving to model.before_evaluation()
        model = self.unwrap_dist_model(self.model)
        if not skip_reload and cur_epoch == "best":
            model = self._reload_best_model(model)
        model.eval()

        self.task.before_evaluation(
            model=model,
            dataset=self.datasets[split_name],
        )
        results = self.task.evaluation(model, data_loader)

        if results is not None:
            return self.task.after_evaluation(
                val_result=results,
                split_name=split_name,
                epoch=cur_epoch,
            )

    def unwrap_dist_model(self, model):
        if self.use_distributed:
            return model.module
        else:
            return model

    def create_loaders(
        self,
        datasets,
        num_workers,
        batch_sizes,
        is_trains,
        collate_fns,
        dataset_ratios=None,
    ):
        """
        Create dataloaders for training and validation.
        """

        def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
            # create a single dataloader for each split
            if isinstance(dataset, ChainDataset) or isinstance(
                dataset, wds.DataPipeline
            ):
                # wds.WebDataset instances are chained together
                # webdataset.DataPipeline has its own sampler and collate_fn
                loader = iter(
                    DataLoader(
                        dataset,
                        batch_size=bsz,
                        num_workers=num_workers,
                        pin_memory=True,
                    )
                )
            else:
                # map-style datasets are concatenated together
                # setup distributed sampler
                if self.use_distributed:
                    sampler = DistributedSampler(
                        dataset,
                        shuffle=is_train,
                        num_replicas=get_world_size(),
                        rank=get_rank(),
                    )
                    if not self.use_dist_eval_sampler:
                        # e.g. retrieval evaluation
                        sampler = sampler if is_train else None
                else:
                    sampler = None

                loader = DataLoader(
                    dataset,
                    batch_size=bsz,
                    num_workers=num_workers,
                    pin_memory=True,
                    sampler=sampler,
                    shuffle=sampler is None and is_train,
                    collate_fn=collate_fn,
                    drop_last=True if is_train else False,
                )
                loader = PrefetchLoader(loader)

                if is_train:
                    loader = IterLoader(loader, use_distributed=self.use_distributed)

            return loader

        loaders = []

        for dataset, bsz, is_train, collate_fn in zip(
            datasets, batch_sizes, is_trains, collate_fns
        ):
            if isinstance(dataset, list) or isinstance(dataset, tuple):
                loader = MultiIterLoader(
                    loaders=[
                        _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
                        for i, d in enumerate(dataset)
                    ],
                    ratios=dataset_ratios,
                )
            else:
                loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)

            loaders.append(loader)

        return loaders

    @main_process
    def _save_checkpoint(self, cur_epoch, is_best=False):
        """
        Save the checkpoint at the current epoch.
        """
        model_no_ddp = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
        }
        state_dict = model_no_ddp.state_dict()
        for k in list(state_dict.keys()):
            if k in param_grad_dic.keys() and not param_grad_dic[k]:
                # delete parameters that do not require gradient
                del state_dict[k]

        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "epoch": cur_epoch,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
        )
        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
        torch.save(save_obj, save_to)

    def _reload_best_model(self, model):
        """
        Load the best checkpoint for evaluation.
        """
        checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")

        logging.info("Loading checkpoint from {}.".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        try:
            model.load_state_dict(checkpoint["model"])
        except RuntimeError as e:
            logging.warning(
                """
                Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
                Trying to load the model with strict=False.
                """
            )
            model.load_state_dict(checkpoint["model"], strict=False)
        return model

    def _load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location=self.device)
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.start_epoch = checkpoint["epoch"] + 1
        logging.info("Resume checkpoint from {}".format(url_or_filename))

    @main_process
    def log_stats(self, stats, split_name):
        if isinstance(stats, dict):
            log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")
        elif isinstance(stats, list):
            pass

    @main_process
    def log_config(self):
        with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
            f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")
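

# Illustrative usage sketch, assuming a parsed Config ``cfg``, a task object exposing
# build_model()/build_datasets() (e.g. from lavis.tasks.setup_task), and a ``job_id``
# string, which mirrors how LAVIS's training entry point wires the runner up:
#
#   task = tasks.setup_task(cfg)
#   datasets = task.build_datasets(cfg)
#   model = task.build_model(cfg)
#
#   runner = RunnerBase(
#       cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets
#   )
#   runner.train()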