Switch to unified view

a b/model/lavis/tasks/base_task.py
1
"""
2
 Copyright (c) 2022, salesforce.com, inc.
3
 All rights reserved.
4
 SPDX-License-Identifier: BSD-3-Clause
5
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
"""
7
8
import logging
9
import os
10
11
import torch
12
import torch.distributed as dist
13
from model.lavis.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized
14
from model.lavis.common.logger import MetricLogger, SmoothedValue
15
from model.lavis.common.registry import registry
16
from model.lavis.datasets.data_utils import prepare_sample
17
18
19
class BaseTask:
    """Base class for LAVIS tasks.

    A task bundles model construction, dataset building, and the
    train/evaluation loops. Subclasses override :meth:`valid_step`
    (and optionally :meth:`inference_step`) for task-specific behavior.
    """

    def __init__(self, **kwargs):
        super().__init__()

        # Annotation key that uniquely identifies a sample; intended for use
        # with save_result(remove_duplicate=...) when merging DDP results.
        self.inst_id_key = "instance_id"

    @classmethod
    def setup_task(cls, **kwargs):
        """Instantiate the task. Subclasses may consume kwargs (e.g. cfg)."""
        return cls()

    def build_model(self, cfg):
        """Build the model registered under ``cfg.model_cfg.arch`` from its config."""
        model_config = cfg.model_cfg

        model_cls = registry.get_model_class(model_config.arch)
        return model_cls.from_config(model_config)

    def build_datasets(self, cfg):
        """
        Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
        Download dataset and annotations automatically if not exist.

        Args:
            cfg (common.config.Config): configuration carrying ``datasets_cfg``.

        Returns:
            dict: Dictionary of torch.utils.data.Dataset objects by split.
        """
        datasets = dict()

        datasets_config = cfg.datasets_cfg

        assert len(datasets_config) > 0, "At least one dataset has to be specified."

        for name in datasets_config:
            dataset_config = datasets_config[name]

            # Each registered builder knows how to download/assemble its splits.
            builder = registry.get_builder_class(name)(dataset_config)
            datasets[name] = builder.build_datasets()

        return datasets

    def train_step(self, model, samples):
        """Forward one batch; return (total loss, dict of loss components).

        Every output key containing "loss" is collected so each component
        loss can be logged separately. The scalar under "loss" is the value
        actually backpropagated.
        """
        output = model(samples)
        loss_dict = {k: v for k, v in output.items() if "loss" in k}
        return output["loss"], loss_dict

    def valid_step(self, model, samples):
        raise NotImplementedError

    def before_evaluation(self, model, dataset, **kwargs):
        # Let the model prepare for evaluation, keyed by the concrete task type.
        model.before_evaluation(dataset=dataset, task_type=type(self))

    def after_evaluation(self, **kwargs):
        pass

    def inference_step(self):
        raise NotImplementedError

    def evaluation(self, model, data_loader, cuda_enabled=True):
        """Run :meth:`valid_step` over the loader; return the concatenated
        per-sample results from this rank."""
        metric_logger = MetricLogger(delimiter="  ")
        header = "Evaluation"
        # TODO make it configurable
        print_freq = 10

        results = []

        for samples in metric_logger.log_every(data_loader, print_freq, header):
            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)

            eval_output = self.valid_step(model=model, samples=samples)
            results.extend(eval_output)

        # Keep ranks in sync before the caller merges per-rank results.
        if is_dist_avail_and_initialized():
            dist.barrier()

        return results

    def train_epoch(
        self,
        epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        accum_grad_iters=1,
    ):
        """Epoch-based training: one pass over ``data_loader``."""
        return self._train_inner_loop(
            epoch=epoch,
            iters_per_epoch=len(data_loader),
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )

    def train_iters(
        self,
        epoch,
        start_iters,
        iters_per_inner_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        cuda_enabled=False,
        log_freq=50,
        accum_grad_iters=1,
    ):
        """Iteration-based training: run ``iters_per_inner_epoch`` steps
        starting at global iteration ``start_iters``."""
        return self._train_inner_loop(
            epoch=epoch,
            start_iters=start_iters,
            iters_per_epoch=iters_per_inner_epoch,
            model=model,
            data_loader=data_loader,
            optimizer=optimizer,
            scaler=scaler,
            lr_scheduler=lr_scheduler,
            log_freq=log_freq,
            cuda_enabled=cuda_enabled,
            accum_grad_iters=accum_grad_iters,
        )

    def _train_inner_loop(
        self,
        epoch,
        iters_per_epoch,
        model,
        data_loader,
        optimizer,
        lr_scheduler,
        scaler=None,
        start_iters=None,
        log_freq=50,
        cuda_enabled=False,
        accum_grad_iters=1,
    ):
        """
        An inner training loop compatible with both epoch-based and iter-based training.

        When using epoch-based, training stops after one epoch; when using iter-based,
        training stops after #iters_per_epoch iterations.

        Returns:
            dict: global-average of each logged meter, formatted as "{:.3f}" strings.
        """
        # AMP is enabled iff a GradScaler was supplied.
        use_amp = scaler is not None

        if not hasattr(data_loader, "__next__"):
            # convert to iterator if not already
            data_loader = iter(data_loader)

        metric_logger = MetricLogger(delimiter="  ")
        metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
        metric_logger.add_meter("loss", SmoothedValue(window_size=1, fmt="{value:.4f}"))

        # if iter-based runner, schedule lr based on inner epoch.
        logging.info(
            "Start training epoch {}, {} iters per inner epoch.".format(
                epoch, iters_per_epoch
            )
        )
        header = "Train: data epoch: [{}]".format(epoch)
        if start_iters is None:
            # epoch-based runner
            inner_epoch = epoch
        else:
            # In iter-based runner, we schedule the learning rate based on iterations.
            inner_epoch = start_iters // iters_per_epoch
            header = header + "; inner epoch [{}]".format(inner_epoch)

        # range() already bounds the loop to iters_per_epoch iterations for
        # both runner styles.
        for i in metric_logger.log_every(range(iters_per_epoch), log_freq, header):
            samples = next(data_loader)

            samples = prepare_sample(samples, cuda_enabled=cuda_enabled)
            samples.update(
                {
                    "epoch": inner_epoch,
                    "num_iters_per_epoch": iters_per_epoch,
                    "iters": i,
                }
            )

            lr_scheduler.step(cur_epoch=inner_epoch, cur_step=i)

            with torch.cuda.amp.autocast(enabled=use_amp):
                loss, loss_dict = self.train_step(model=model, samples=samples)
                # Scale the backprop loss for gradient accumulation; loss_dict
                # intentionally keeps the unscaled values for logging.
                loss /= accum_grad_iters

            # after_train_step()
            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # update gradients every accum_grad_iters iterations
            if (i + 1) % accum_grad_iters == 0:
                if use_amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()

            metric_logger.update(**loss_dict)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # after train_epoch()
        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        logging.info("Averaged stats: " + str(metric_logger.global_avg()))
        return {
            k: "{:.3f}".format(meter.global_avg)
            for k, meter in metric_logger.meters.items()
        }

    @staticmethod
    def save_result(result, result_dir, filename, remove_duplicate=""):
        """Write this rank's results to a JSON file, then (on the main
        process) merge all ranks' files into one, optionally de-duplicating
        by the ``remove_duplicate`` key.

        Returns:
            str: path of the merged result file.
        """
        import json

        result_file = os.path.join(
            result_dir, "%s_rank%d.json" % (filename, get_rank())
        )
        final_result_file = os.path.join(result_dir, "%s.json" % filename)

        # Use context managers so file handles are always closed.
        with open(result_file, "w") as f:
            json.dump(result, f)

        if is_dist_avail_and_initialized():
            dist.barrier()

        if is_main_process():
            logging.warning("rank %d starts merging results." % get_rank())
            # combine results from all processes
            result = []

            for rank in range(get_world_size()):
                result_file = os.path.join(
                    result_dir, "%s_rank%d.json" % (filename, rank)
                )
                with open(result_file, "r") as f:
                    result += json.load(f)

            if remove_duplicate:
                # Keep the first occurrence of each id; a set gives O(1)
                # membership tests instead of the original O(n) list scan.
                seen_ids = set()
                deduped = []
                for res in result:
                    if res[remove_duplicate] not in seen_ids:
                        seen_ids.add(res[remove_duplicate])
                        deduped.append(res)
                result = deduped

            with open(final_result_file, "w") as f:
                json.dump(result, f)
            print("result file saved to %s" % final_result_file)

        return final_result_file