Diff of /docs/tutorial.tasks.rst [000000] .. [dc40d0]

Switch to unified view

a b/docs/tutorial.tasks.rst
1
Adding Tasks
2
####################################
3
4
This is a tutorial on adding new machine learning tasks using ``lavis.tasks`` module.
5
6
The LAVIS library includes a standard task module that centralizes the model training and evaluation procedure of machine learning tasks. 
7
The ``lavis.tasks`` module is designed such that any new tasks can be added and integrated, catering to any customization in the training and testing procedures. 
8
In this tutorial, we will replicate the steps to add a new task into LAVIS for the `video-grounded dialogue tasks <https://arxiv.org/pdf/1901.09107.pdf>`_. 
9
10
Base Task ``lavis.tasks.base_task``
11
********************************************************************************
12
13
Note that any new model definition should inherit the base task class ``BaseTask``:
14
15
.. code-block:: python
16
17
    import logging
18
    import os
19
    
20
    import torch.distributed as dist
21
    from lavis.common.dist_utils import get_rank, get_world_size, is_main_process
22
    from lavis.common.logger import MetricLogger, SmoothedValue
23
    from lavis.common.registry import registry
24
    from lavis.datasets.data_utils import prepare_sample
25
    
26
    class BaseTask:
27
        def __init__(self, **kwargs):
28
            super().__init__()
29
    
30
            self.inst_id_key = "instance_id"
31
    
32
        @classmethod
33
        def setup_task(cls, **kwargs):
34
            return cls()
35
    
36
        def build_model(self, cfg):
37
            model_config = cfg.model_cfg
38
    
39
            model_cls = registry.get_model_class(model_config.arch)
40
            return model_cls.from_config(model_config)
41
    
42
        def build_datasets(self, cfg):
43
            """
44
            Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'.
45
            Download dataset and annotations automatically if not exist.
46
    
47
            Args:
48
                cfg (common.config.Config): _description_
49
    
50
            Returns:
51
                dict: Dictionary of torch.utils.data.Dataset objects by split.
52
            """
53
    
54
            datasets = dict()
55
    
56
            datasets_config = cfg.datasets_cfg
57
    
58
            assert len(datasets_config) > 0, "At least one dataset has to be specified."
59
    
60
            for name in datasets_config:
61
                dataset_config = datasets_config[name]
62
    
63
                builder = registry.get_builder_class(name)(dataset_config)
64
                dataset = builder.build_datasets()
65
    
66
                datasets[name] = dataset
67
    
68
            return datasets
69
    
70
        def train_step(self, model, samples):
71
            loss = model(samples)["loss"]
72
            return loss
73
    
74
        ...
75
76
In this base task, we already declare and standardize many common methods such as ``train_step``, ``build_model``, and ``build_datasets``. 
77
Inheriting this base task class allows us to standardize operations of tasks across all task classes.
78
We recommend users not change the implementation of the base task class as this will have an impact on all existing task subclasses.
79
80
Dialogue Task ``lavis.tasks.dialogue``
81
********************************************************************************
82
83
In this step, we can define a new task class, e.g. under ``lavis.tasks.dialogue``, for video-grounded dialogues.
84
For instance, we define a new task class ``DialogueTask`` that inherits the super task class ``BaseTask``.
85
86
.. code-block:: python
87
88
    import json
89
    import os
90
    
91
    from lavis.common.dist_utils import main_process
92
    from lavis.common.logger import MetricLogger
93
    from lavis.common.registry import registry
94
    from lavis.tasks.base_task import BaseTask
95
    from lavis.datasets.data_utils import prepare_sample
96
    
97
    import numpy as np 
98
    
99
    @registry.register_task("dialogue")
100
    class DialogueTask(BaseTask):
101
        def __init__(self, num_beams, max_len, min_len, evaluate, report_metric=True):
102
            super().__init__()
103
    
104
            self.num_beams = num_beams
105
            self.max_len = max_len
106
            self.min_len = min_len
107
            self.evaluate = evaluate
108
    
109
            self.report_metric = report_metric
110
    
111
        @classmethod
112
        def setup_task(cls, cfg):
113
            run_cfg = cfg.run_cfg
114
    
115
            num_beams = run_cfg.num_beams
116
            max_len = run_cfg.max_len
117
            min_len = run_cfg.min_len
118
            evaluate = run_cfg.evaluate
119
    
120
            report_metric = run_cfg.get("report_metric", True)
121
    
122
            return cls(
123
                num_beams=num_beams,
124
                max_len=max_len,
125
                min_len=min_len,
126
                evaluate=evaluate,
127
                report_metric=report_metric,
128
            )
129
    
130
        def valid_step(self, model, samples):
131
            results = []        
132
            loss = model(samples)["loss"].item() 
133
            
134
            return [loss] 
135
        ...
136
137
Note that for any new task, we advise the users to review carefully the functions implemented within ``BaseTask`` and consider which methods should be modified. 
138
For instance, the base task class already contains a standard implementation of model training steps that are common among machine learning steps. 
139
Some major methods we want to emphasize and should be customized by each task are the ``valid_step`` and ``evaluation``. 
140
These operations were not fully implemented in the base task class due to the differences in evaluation procedures among many machine learning tasks. 
141
Another method that should be considered is the ``setup_task`` method. 
142
This method will receive configurations that set task-specific parameters to initialize any task instance.
143
144
Registering New Task ``lavis.tasks.__init__`` 
145
********************************************************************************
146
147
Any new task must be officially registered as part of the ``lavis.tasks`` module. For instance, to add a new task for video-grounded dialogues, we can modify the ``__init__.py`` as follows:
148
149
.. code-block:: python
150
151
    from lavis.tasks.dialogue import DialogueTask
152
    
153
    ...
154
    __all__ = [
155
        ...
156
        "DialogueTask"
157
    ]
158
159
Assigning Task 
160
***************
161
162
From the above example of task class, note that we define a ``setup_task`` method for each task class. 
163
This method will process a configuration file and pass specific parameters e.g. ``num_beams`` (for beam search generative tasks during the inference stage), to initialize the task classes properly. 
164
To assign and associate any task, we need to specify the correct registry of task classes in a configuration file. 
165
For instance, the following should be specified in a configuration file e.g. ``dialogue_avsd_ft.yaml``:
166
167
.. code-block:: yaml
168
169
    run:
170
      task: dialogue # name of the task 
171
      
172
      # optimizer
173
      ...
174
    
175
      max_len: 20
176
      min_len: 5
177
      num_beams: 3    
178
      ...
179
    
180
Subsequently, any processes (e.g. training) should load this configuration file to assign the correct task.
181
182
.. code-block:: sh
183
184
    python train.py --cfg-path dialogue_avsd_ft.yaml