
lavis/runners/runner_iter.py

"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import logging
import os
import time

import torch
import torch.distributed as dist
import webdataset as wds
from lavis.common.dist_utils import download_cached_file, is_main_process, main_process
from lavis.common.registry import registry
from lavis.common.utils import is_url
from lavis.datasets.data_utils import concat_datasets, reorg_datasets_by_split
from lavis.runners.runner_base import RunnerBase
from torch.utils.data.dataset import ChainDataset


@registry.register_runner("runner_iter")
class RunnerIter(RunnerBase):
    """
    Run training based on the number of iterations. This is common when the
    training dataset size is large. Under the hood, the logic is similar to
    epoch-based training: every #iters_per_inner_epoch steps is treated as an
    inner epoch.

    In the iter-based runner, after every #iters_per_inner_epoch steps, we

        1) run a validation epoch;
        2) schedule the learning rate;
        3) save a checkpoint.

    We refer to every #iters_per_inner_epoch steps as an inner epoch.
    """
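
    # Illustrative run_cfg fragment (a sketch, not taken from the repo configs;
    # key names mirror the lookups in __init__ below, values are placeholders):
    #
    #   run:
    #     runner: runner_iter
    #     max_iters: 100000            # total number of training iterations
    #     iters_per_inner_epoch: 5000  # validate / checkpoint every 5000 iters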

    def __init__(self, cfg, task, model, datasets, job_id):
        super().__init__(cfg, task, model, datasets, job_id)

        self.start_iters = 0

        self.max_iters = int(self.config.run_cfg.get("max_iters", -1))
        assert self.max_iters > 0, "max_iters must be greater than 0."

        self.iters_per_inner_epoch = int(
            self.config.run_cfg.get("iters_per_inner_epoch", -1)
        )
        assert (
            self.iters_per_inner_epoch > 0
        ), "iters_per_inner_epoch must be greater than 0."

    @property
    def max_epoch(self):
        return int(self.max_iters / self.iters_per_inner_epoch)

    @property
    def cur_epoch(self):
        try:
            return self.train_loader.epoch
        except AttributeError:
            # pipeline data (e.g. LAION) is streamed and has no concept of epoch
            return 0

    def _progress(self, cur_iters):
        return "{}_iters={}".format(self.cur_epoch, cur_iters)
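
    # _progress() produces strings like "0_iters=5000" (current epoch, then
    # iteration count); train() below passes it to eval_epoch() as cur_epoch in
    # place of a plain epoch number.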

    def train(self):
        start_time = time.time()
        best_agg_metric = 0
        best_iters = 0

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        for start_iters in range(
            self.start_iters, self.max_iters, self.iters_per_inner_epoch
        ):
            end_iters = start_iters + self.iters_per_inner_epoch

            # training phase
            if not self.evaluate_only:
                logging.info(
                    "Start training, max_iters={}, in total {} inner epochs.".format(
                        self.max_iters, int(self.max_iters / self.iters_per_inner_epoch)
                    )
                )
                if start_iters == self.start_iters:
                    self.task.before_training(
                        model=self.unwrap_dist_model(self.model),
                        dataset=self.datasets,
                    )
                train_stats = self.train_iters(self.cur_epoch, start_iters)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if len(self.valid_splits) > 0:
                for split_name in self.valid_splits:
                    logging.info("Evaluating on {}.".format(split_name))

                    val_log = self.eval_epoch(
                        split_name=split_name, cur_epoch=self._progress(end_iters)
                    )
                    if val_log is not None:
                        if is_main_process():
                            assert (
                                "agg_metrics" in val_log
                            ), "No agg_metrics found in validation log."

                            agg_metrics = val_log["agg_metrics"]
                            if agg_metrics > best_agg_metric and split_name == "val":
                                best_iters, best_agg_metric = end_iters, agg_metrics

                                self._save_checkpoint(end_iters, is_best=True)

                            val_log.update({"best_iters": best_iters})
                            self.log_stats(val_log, split_name)

            else:
                # if no validation split is provided, we just save the checkpoint at the end of each inner epoch.
                if not self.evaluate_only:
                    self._save_checkpoint(end_iters, is_best=False)

            if self.evaluate_only:
                break
            dist.barrier()

        # testing phase
        self.evaluate(cur_epoch=self.cur_epoch)

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        logging.info("Training time {}".format(total_time_str))
    def train_iters(self, epoch, start_iters):
142
        # train by iterations
143
        self.model.train()
144
145
        return self.task.train_iters(
146
            epoch=epoch,
147
            start_iters=start_iters,
148
            iters_per_inner_epoch=self.iters_per_inner_epoch,
149
            model=self.model,
150
            data_loader=self.train_loader,
151
            optimizer=self.optimizer,
152
            scaler=self.scaler,
153
            lr_scheduler=self.lr_scheduler,
154
            cuda_enabled=self.cuda_enabled,
155
            log_freq=self.log_freq,
156
            accum_grad_iters=self.accum_grad_iters,
157
        )
158
159
    @main_process
160
    def _save_checkpoint(self, cur_iters, is_best=False):
161
        model_no_ddp = self.unwrap_dist_model(self.model)
162
        param_grad_dic = {
163
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
164
        }
165
166
        state_dict = model_no_ddp.state_dict()
167
        for k in list(state_dict.keys()):
168
            if k in param_grad_dic.keys() and not param_grad_dic[k]:
169
                # delete parameters that do not require gradient
170
                del state_dict[k]
171
172
        save_obj = {
173
            "model": state_dict,
174
            "optimizer": self.optimizer.state_dict(),
175
            "config": self.config.to_dict(),
176
            "scaler": self.scaler.state_dict() if self.scaler else None,
177
            "iters": cur_iters,
178
        }
179
        save_to = os.path.join(
180
            self.output_dir,
181
            "checkpoint_{}.pth".format("best" if is_best else cur_iters),
182
        )
183
        logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to))
184
        torch.save(save_obj, save_to)
185
186
    def _load_checkpoint(self, url_or_filename):
187
        """
188
        Resume from a checkpoint.
189
        """
190
        if is_url(url_or_filename):
191
            cached_file = download_cached_file(
192
                url_or_filename, check_hash=False, progress=True
193
            )
194
            checkpoint = torch.load(cached_file, map_location=self.device)
195
        elif os.path.isfile(url_or_filename):
196
            checkpoint = torch.load(url_or_filename, map_location=self.device)
197
        else:
198
            raise RuntimeError("checkpoint url or path is invalid")
199
200
        state_dict = checkpoint["model"]
201
        self.unwrap_dist_model(self.model).load_state_dict(state_dict)
202
203
        self.optimizer.load_state_dict(checkpoint["optimizer"])
204
        if self.scaler and "scaler" in checkpoint:
205
            self.scaler.load_state_dict(checkpoint["scaler"])
206
207
        self.start_iters = checkpoint["iters"] + 1
208
        logging.info("Resume checkpoint from {}".format(url_or_filename))
209
210
    @property
211
    def dataloaders(self) -> dict:
212
        """
213
        A property to get and create dataloaders by split just in need.
214
215
        If no train_dataset_ratio is provided, concatenate map-style datasets and
216
        chain wds.DataPipe datasets separately. Training set becomes a tuple
217
        (ConcatDataset, ChainDataset), both are optional but at least one of them is
218
        required. The resultant ConcatDataset and ChainDataset will be sampled evenly.
219
220
        If train_dataset_ratio is provided, create a MultiIterLoader to sample
221
        each dataset by ratios during training.
222
223
        Currently do not support multiple datasets for validation and test.
224
225
        Returns:
226
            dict: {split_name: (tuples of) dataloader}
227
        """
228
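        # Illustrative train_dataset_ratios fragment (a sketch only; dataset
        # names are placeholders, and the keys must exactly match the keys of
        # self.datasets, as enforced by the checks below):
        #
        #   run:
        #     train_dataset_ratios:
        #       dataset_a: 2.0
        #       dataset_b: 1.0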
        if self._dataloaders is None:
            # reorganize datasets by split and concatenate/chain if necessary
            dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None)

            if dataset_ratios is None:
                # concatenate map-style datasets and chain wds.DataPipeline datasets separately
                # training set becomes a tuple (ConcatDataset, ChainDataset), both are
                # optional but at least one of them is required. The resultant ConcatDataset
                # and ChainDataset will be sampled evenly.
                logging.info(
                    "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)."
                )

                datasets = reorg_datasets_by_split(self.datasets)
                self.datasets = concat_datasets(datasets)
            else:
                # create multi-loader with the provided ratios, without concatenating or chaining
                missing_keys = [k for k in dataset_ratios if k not in self.datasets]
                if len(missing_keys) > 0:
                    raise ValueError(
                        "Datasets with the following split names are not found: {}".format(
                            missing_keys
                        )
                    )

                unexpected_keys = [k for k in self.datasets if k not in dataset_ratios]
                if len(unexpected_keys) > 0:
                    raise ValueError(
                        "Datasets with the following split names are not expected: {}".format(
                            unexpected_keys
                        )
                    )

                dataset_ratios = [float(dataset_ratios[k]) for k in self.datasets]
                self.datasets = reorg_datasets_by_split(self.datasets)
                # to keep the same structure as return value of concat_datasets
                self.datasets = {
                    k: v[0] if len(v) == 1 else v for k, v in self.datasets.items()
                }

            # print dataset statistics after concatenation/chaining
            for split_name in self.datasets:
                if isinstance(self.datasets[split_name], (tuple, list)):
                    # mixed wds.DataPipeline and torch.utils.data.Dataset
                    num_records = sum(
                        [
                            len(d)
                            if type(d) not in [wds.DataPipeline, ChainDataset]
                            else 0
                            for d in self.datasets[split_name]
                        ]
                    )

                else:
                    try:
                        # a single map-style dataset
                        num_records = len(self.datasets[split_name])
                    except TypeError:
                        # a single wds.DataPipeline or ChainDataset
                        num_records = -1
                        logging.info(
                            "Only a single wds.DataPipeline dataset, no __len__ attribute."
                        )

                if num_records >= 0:
                    logging.info(
                        "Loaded {} records for {} split from the dataset.".format(
                            num_records, split_name
                        )
                    )

            # create dataloaders
            split_names = sorted(self.datasets.keys())

            datasets = [self.datasets[split] for split in split_names]
            is_trains = [split in self.train_splits for split in split_names]

            batch_sizes = [
                self.config.run_cfg.batch_size_train
                if split == "train"
                else self.config.run_cfg.batch_size_eval
                for split in split_names
            ]

            collate_fns = []
            for dataset in datasets:
                if isinstance(dataset, (tuple, list)):
                    collate_fns.append([getattr(d, "collater", None) for d in dataset])
                else:
                    collate_fns.append(getattr(dataset, "collater", None))

            dataloaders = self.create_loaders(
                datasets=datasets,
                num_workers=self.config.run_cfg.num_workers,
                batch_sizes=batch_sizes,
                is_trains=is_trains,
                collate_fns=collate_fns,
                dataset_ratios=dataset_ratios,
            )

            self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)}

        return self._dataloaders
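
# Usage sketch (illustrative, not part of this module): a training entry point is
# assumed to look the runner up through the registry and drive it, roughly as:
#
#   runner_cls = registry.get_runner_class("runner_iter")
#   runner = runner_cls(cfg=cfg, task=task, model=model, datasets=datasets, job_id=job_id)
#   runner.train()
#
# registry.get_runner_class is an assumption about the registry API; the
# constructor signature matches __init__ above.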