
lavis/runners/runner_base.py

"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import json
import logging
import os
import time
from pathlib import Path

import torch
import torch.distributed as dist
import webdataset as wds
from lavis.common.dist_utils import (
    download_cached_file,
    get_rank,
    get_world_size,
    is_main_process,
    main_process,
)
from lavis.common.registry import registry
from lavis.common.utils import is_url
from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
from lavis.datasets.datasets.dataloader_utils import (
    IterLoader,
    MultiIterLoader,
    PrefetchLoader,
)
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data.dataset import ChainDataset

@registry.register_runner("runner_base")
class RunnerBase:
    """
    A runner class to train and evaluate a model given a task and datasets.

    The runner uses PyTorch DistributedDataParallel (DDP) by default. Future
    releases will support other distributed frameworks.
    """
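
    # Typical usage (illustrative sketch; cfg, task, model, datasets, and job_id are
    # assumed to be built by the LAVIS training/evaluation entry points and are not
    # defined in this file):
    #   runner = RunnerBase(cfg=cfg, task=task, model=model, datasets=datasets, job_id=job_id)
    #   runner.train()                       # train, validate, then test
    #   runner.evaluate(skip_reload=True)    # evaluation-only runs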
    def __init__(self, cfg, task, model, datasets, job_id):
        self.config = cfg
        self.job_id = job_id

        self.task = task
        self.datasets = datasets

        self._model = model

        self._wrapped_model = None
        self._device = None
        self._optimizer = None
        self._scaler = None
        self._dataloaders = None
        self._lr_sched = None

        self.start_epoch = 0

        # self.setup_seeds()
        self.setup_output_dir()
    @property
    def device(self):
        if self._device is None:
            self._device = torch.device(self.config.run_cfg.device)

        return self._device

    @property
    def use_distributed(self):
        return self.config.run_cfg.distributed

    @property
    def model(self):
        """
        A property to get the DDP-wrapped model on the device.
        """
        # move model to device
        if self._model.device != self.device:
            self._model = self._model.to(self.device)

            # distributed training wrapper
            if self.use_distributed:
                if self._wrapped_model is None:
                    self._wrapped_model = DDP(
                        self._model, device_ids=[self.config.run_cfg.gpu]
                    )
            else:
                self._wrapped_model = self._model

        return self._wrapped_model
    @property
    def optimizer(self):
        # TODO make optimizer class and configurations
        if self._optimizer is None:
            lr_scale = self.config.run_cfg.get("lr_layer_decay", 1)
            weight_decay = self.config.run_cfg.get("weight_decay", 0.05)
            optim_params = self._model.get_optimizer_params(weight_decay, lr_scale)

            num_parameters = 0
            for p_group in optim_params:
                for p in p_group["params"]:
                    num_parameters += p.data.nelement()
            logging.info("number of trainable parameters: {}".format(num_parameters))

            beta2 = self.config.run_cfg.get("beta2", 0.999)

            self._optimizer = torch.optim.AdamW(
                optim_params,
                lr=float(self.config.run_cfg.init_lr),
                betas=(0.9, beta2),
            )

        return self._optimizer
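
    # The optimizer reads these run_cfg keys (values are illustrative except where
    # they match the defaults set in the property above):
    #   run:
    #     init_lr: 1e-4            # required
    #     weight_decay: 0.05       # optional, default 0.05
    #     beta2: 0.999             # optional, default 0.999
    #     lr_layer_decay: 0.9      # optional, default 1; forwarded to
    #                              # model.get_optimizer_params() as lr_scale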
    @property
    def scaler(self):
        amp = self.config.run_cfg.get("amp", False)

        if amp:
            if self._scaler is None:
                self._scaler = torch.cuda.amp.GradScaler()

        return self._scaler
    @property
    def lr_scheduler(self):
        """
        A property to lazily create the learning rate scheduler when first needed.
        """
        if self._lr_sched is None:
            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run_cfg.lr_sched)

            # max_epoch = self.config.run_cfg.max_epoch
            max_epoch = self.max_epoch
            # min_lr = self.config.run_cfg.min_lr
            min_lr = self.min_lr
            # init_lr = self.config.run_cfg.init_lr
            init_lr = self.init_lr

            # optional parameters
            decay_rate = self.config.run_cfg.get("lr_decay_rate", None)
            warmup_start_lr = self.config.run_cfg.get("warmup_lr", -1)
            warmup_steps = self.config.run_cfg.get("warmup_steps", 0)

            self._lr_sched = lr_sched_cls(
                optimizer=self.optimizer,
                max_epoch=max_epoch,
                min_lr=min_lr,
                init_lr=init_lr,
                decay_rate=decay_rate,
                warmup_start_lr=warmup_start_lr,
                warmup_steps=warmup_steps,
            )

        return self._lr_sched
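
    # Scheduler configuration is also pulled from run_cfg, e.g. (illustrative values;
    # lr_sched must name a scheduler registered with the LAVIS registry):
    #   run:
    #     lr_sched: "linear_warmup_cosine_lr"
    #     min_lr: 1e-5
    #     warmup_lr: 1e-6          # optional, default -1
    #     warmup_steps: 1000       # optional, default 0
    #     lr_decay_rate: 0.9       # optional; only used by step-style schedulers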
    @property
    def dataloaders(self) -> dict:
        """
        A property to get and create dataloaders by split, built lazily when first needed.

        If no train_dataset_ratios is provided, concatenate map-style datasets and
        chain wds.DataPipe datasets separately. The training set becomes a tuple
        (ConcatDataset, ChainDataset); both are optional, but at least one of them is
        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.

        If train_dataset_ratios is provided, create a MultiIterLoader to sample
        each dataset by ratios during training.

        Multiple datasets for validation and test are not currently supported.

        Returns:
            dict: {split_name: (tuples of) dataloader}
        """
        if self._dataloaders is None:
            # reorganize datasets by split and concatenate/chain if necessary
            dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)

            # concatenate map-style datasets and chain wds.DataPipe datasets separately
            # training set becomes a tuple (ConcatDataset, ChainDataset), both are
            # optional but at least one of them is required. The resultant ConcatDataset
            # and ChainDataset will be sampled evenly.
            logging.info(
                "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
            )

            datasets = reorg_datasets_by_split(self.datasets)
            self.datasets = concat_datasets(datasets)

            # print dataset statistics after concatenation/chaining
            for split_name in self.datasets:
                if isinstance(self.datasets[split_name], tuple) or isinstance(
                    self.datasets[split_name], list
                ):
                    # mixed wds.DataPipeline and torch.utils.data.Dataset
                    num_records = sum(
                        [
                            len(d)
                            if not type(d) in [wds.DataPipeline, ChainDataset]
                            else 0
                            for d in self.datasets[split_name]
                        ]
                    )

                else:
                    if hasattr(self.datasets[split_name], "__len__"):
                        # a single map-style dataset
                        num_records = len(self.datasets[split_name])
                    else:
                        # a single wds.DataPipeline
                        num_records = -1
                        logging.info(
                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
                        )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders
            split_names = sorted(self.datasets.keys())

            datasets = [self.datasets[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            batch_sizes = [
                self.config.run_cfg.batch_size_train
                if split == "train"
                else self.config.run_cfg.batch_size_eval
                for split in split_names
            ]

            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, tuple) or isinstance(dataset, list):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))

            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
                dataset_ratios=dataset_ratios,
            )

            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}

        return self._dataloaders
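
    # Batching and sampling above are driven by run_cfg, e.g. (illustrative values only):
    #   run:
    #     batch_size_train: 16
    #     batch_size_eval: 8
    #     num_workers: 4
    #     train_dataset_ratios: ...  # optional; forwarded to MultiIterLoader as sampling ratios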
    @property
    def cuda_enabled(self):
        return self.device.type == "cuda"

    @property
    def max_epoch(self):
        return int(self.config.run_cfg.max_epoch)

    @property
    def log_freq(self):
        log_freq = self.config.run_cfg.get("log_freq", 50)
        return int(log_freq)

    @property
    def init_lr(self):
        return float(self.config.run_cfg.init_lr)

    @property
    def min_lr(self):
        return float(self.config.run_cfg.min_lr)

    @property
    def accum_grad_iters(self):
        return int(self.config.run_cfg.get("accum_grad_iters", 1))

    @property
    def valid_splits(self):
        valid_splits = self.config.run_cfg.get("valid_splits", [])

        if len(valid_splits) == 0:
            logging.info("No validation splits found.")

        return valid_splits

    @property
    def test_splits(self):
        test_splits = self.config.run_cfg.get("test_splits", [])

        return test_splits

    @property
    def train_splits(self):
        train_splits = self.config.run_cfg.get("train_splits", [])

        if len(train_splits) == 0:
            logging.info("Empty train splits.")

        return train_splits

    @property
    def evaluate_only(self):
        """
        Set to True to skip training.
        """
        return self.config.run_cfg.evaluate

    @property
    def use_dist_eval_sampler(self):
        return self.config.run_cfg.get("use_dist_eval_sampler", True)

    @property
    def resume_ckpt_path(self):
        return self.config.run_cfg.get("resume_ckpt_path", None)

    @property
    def train_loader(self):
        train_dataloader = self.dataloaders["train"]

        return train_dataloader
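
    # The split/schedule properties above map directly to run_cfg keys, e.g.
    # (illustrative values only):
    #   run:
    #     max_epoch: 10
    #     accum_grad_iters: 1
    #     log_freq: 50
    #     train_splits: ["train"]
    #     valid_splits: ["val"]
    #     test_splits: ["test"]
    #     evaluate: False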
    def setup_output_dir(self):
        lib_root = Path(registry.get_path("library_root"))

        output_dir = lib_root / self.config.run_cfg.output_dir / self.job_id
        result_dir = output_dir / "result"

        output_dir.mkdir(parents=True, exist_ok=True)
        result_dir.mkdir(parents=True, exist_ok=True)

        registry.register_path("result_dir", str(result_dir))
        registry.register_path("output_dir", str(output_dir))

        self.result_dir = result_dir
        self.output_dir = output_dir
    def train(self):
        start_time = time.time()
        best_agg_metric = 0
        best_epoch = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        for cur_epoch in range(self.start_epoch, self.max_epoch):
            # training phase
            if not self.evaluate_only:
                logging.info("Start training")
                # See https://github.com/salesforce/LAVIS/issues/449
                # if cur_epoch == self.start_epoch:
                #     self.task.before_training(
                #         model=self.unwrap_dist_model(self.model),
                #         dataset=self.datasets["train"],
                #     )
                train_stats = self.train_epoch(cur_epoch)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if len(self.valid_splits) > 0:
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=cur_epoch
                    )
                    if val_log is not None:
                        if is_main_process():
                            assert (
                                "agg_metrics" in val_log
                            ), "No agg_metrics found in validation log."

                            agg_metrics = val_log["agg_metrics"]
                            if agg_metrics > best_agg_metric and split_name == "val":
                                best_epoch, best_agg_metric = cur_epoch, agg_metrics

                                self._save_checkpoint(cur_epoch, is_best=True)

                            val_log.update({"best_epoch": best_epoch})
                            self.log_stats(val_log, split_name)

            else:
                # if no validation split is provided, we just save the checkpoint at the end of each epoch.
                if not self.evaluate_only:
                    self._save_checkpoint(cur_epoch, is_best=False)

            if self.evaluate_only:
                break

            dist.barrier()

        # testing phase
        test_epoch = "best" if len(self.valid_splits) > 0 else cur_epoch
        self.evaluate(cur_epoch=test_epoch, skip_reload=self.evaluate_only)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))
    def evaluate(self, cur_epoch="best", skip_reload=False):
        test_logs = dict()

        if len(self.test_splits) > 0:
            for split_name in self.test_splits:
                test_logs[split_name] = self.eval_epoch(
                    split_name=split_name, cur_epoch=cur_epoch, skip_reload=skip_reload
                )

            return test_logs
    def train_epoch(self, epoch):
        # train
        self.model.train()

        return self.task.train_epoch(
            epoch=epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
        )
    @torch.no_grad()
    def eval_epoch(self, split_name, cur_epoch, skip_reload=False):
        """
        Evaluate the model on a given split.

        Args:
            split_name (str): name of the split to evaluate on.
            cur_epoch (int): current epoch.
            skip_reload (bool): whether to skip reloading the best checkpoint.
                During training, we will reload the best checkpoint for validation.
                During testing, we will use provided weights and skip reloading the best checkpoint.
        """
        data_loader = self.dataloaders.get(split_name, None)
        assert data_loader, "data_loader for split {} is None.".format(split_name)

        # TODO In validation, you need to compute loss as well as metrics
        # TODO consider moving to model.before_evaluation()
        model = self.unwrap_dist_model(self.model)
        if not skip_reload and cur_epoch == "best":
            model = self._reload_best_model(model)
        model.eval()

        self.task.before_evaluation(
            model=model,
            dataset=self.datasets[split_name],
        )
        results = self.task.evaluation(model, data_loader)

        if results is not None:
            return self.task.after_evaluation(
                val_result=results,
                split_name=split_name,
                epoch=cur_epoch,
            )
    def unwrap_dist_model(self, model):
        if self.use_distributed:
            return model.module
        else:
            return model
    def create_loaders(
        self,
        datasets,
        num_workers,
        batch_sizes,
        is_trains,
        collate_fns,
        dataset_ratios=None,
    ):
        """
        Create dataloaders for training and validation.
        """

        def _create_loader(dataset, num_workers, bsz, is_train, collate_fn):
            # create a single dataloader for each split
            if isinstance(dataset, ChainDataset) or isinstance(
                dataset, wds.DataPipeline
            ):
                # wds.WebDataset instances are chained together
                # webdataset.DataPipeline has its own sampler and collate_fn
                loader = iter(
                    DataLoader(
                        dataset,
                        batch_size=bsz,
                        num_workers=num_workers,
                        pin_memory=True,
                    )
                )
            else:
                # map-style datasets are concatenated together
                # setup distributed sampler
                if self.use_distributed:
                    sampler = DistributedSampler(
                        dataset,
                        shuffle=is_train,
                        num_replicas=get_world_size(),
                        rank=get_rank(),
                    )
                    if not self.use_dist_eval_sampler:
                        # e.g. retrieval evaluation
                        sampler = sampler if is_train else None
                else:
                    sampler = None

                loader = DataLoader(
                    dataset,
                    batch_size=bsz,
                    num_workers=num_workers,
                    pin_memory=True,
                    sampler=sampler,
                    shuffle=sampler is None and is_train,
                    collate_fn=collate_fn,
                    drop_last=True if is_train else False,
                )
                loader = PrefetchLoader(loader)

                if is_train:
                    loader = IterLoader(loader, use_distributed=self.use_distributed)

            return loader

        loaders = []

        for dataset, bsz, is_train, collate_fn in zip(
            datasets, batch_sizes, is_trains, collate_fns
        ):
            if isinstance(dataset, list) or isinstance(dataset, tuple):
                loader = MultiIterLoader(
                    loaders=[
                        _create_loader(d, num_workers, bsz, is_train, collate_fn[i])
                        for i, d in enumerate(dataset)
                    ],
                    ratios=dataset_ratios,
                )
            else:
                loader = _create_loader(dataset, num_workers, bsz, is_train, collate_fn)

            loaders.append(loader)

        return loaders
    @main_process
    def _save_checkpoint(self, cur_epoch, is_best=False):
        """
        Save the checkpoint at the current epoch.
        """
        model_no_ddp = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
        }
        state_dict = model_no_ddp.state_dict()
        for k in list(state_dict.keys()):
            if k in param_grad_dic.keys() and not param_grad_dic[k]:
                # delete parameters that do not require gradient
                del state_dict[k]

        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config.to_dict(),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "epoch": cur_epoch,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
        )
        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
        torch.save(save_obj, save_to)
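
    # Checkpoints land in self.output_dir as checkpoint_<epoch>.pth, or
    # checkpoint_best.pth when is_best=True; _reload_best_model below loads the
    # latter for validation/testing. Parameters whose requires_grad is False are
    # dropped from the saved state dict, hence the strict=False fallback on reload.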
    def _reload_best_model(self, model):
        """
        Load the best checkpoint for evaluation.
        """
        checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")

        logging.info("Loading checkpoint from {}.".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        try:
            model.load_state_dict(checkpoint["model"])
        except RuntimeError as e:
            logging.warning(
                """
                Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
                Trying to load the model with strict=False.
                """
            )
            model.load_state_dict(checkpoint["model"], strict=False)
        return model
    def _load_checkpoint(self, url_or_filename):
        """
        Resume from a checkpoint.
        """
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location=self.device)
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.start_epoch = checkpoint["epoch"] + 1
        logging.info("Resumed checkpoint from {}".format(url_or_filename))
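
    # Resuming is requested through run_cfg, e.g. (illustrative; the path is a
    # placeholder, either a local file or a downloadable URL):
    #   run:
    #     resume_ckpt_path: "path/to/checkpoint_4.pth"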
    @main_process
    def log_stats(self, stats, split_name):
        if isinstance(stats, dict):
            log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")
        elif isinstance(stats, list):
            pass
    @main_process
    def log_config(self):
        with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
            f.write(json.dumps(self.config.to_dict(), indent=4) + "\n")