|
a |
|
b/lavis/tasks/__init__.py |
|
|
1 |
""" |
|
|
2 |
Copyright (c) 2022, salesforce.com, inc. |
|
|
3 |
All rights reserved. |
|
|
4 |
SPDX-License-Identifier: BSD-3-Clause |
|
|
5 |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
from lavis.common.registry import registry |
|
|
9 |
from lavis.tasks.base_task import BaseTask |
|
|
10 |
from lavis.tasks.captioning import CaptionTask |
|
|
11 |
from lavis.tasks.image_text_pretrain import ImageTextPretrainTask |
|
|
12 |
|
|
|
13 |
from lavis.tasks.vqa import VQATask |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def setup_task(cfg):
    """Build the task object named by ``cfg.run_cfg.task``.

    Looks up the task class registered under that name via
    ``registry.get_task_class`` and delegates construction to that class's
    ``setup_task(cfg=cfg)``.

    Args:
        cfg: Run configuration; ``cfg.run_cfg`` must contain a ``task`` entry.

    Returns:
        The task object returned by the registered class's ``setup_task``.

    Raises:
        ValueError: If no task name is configured, or if no task class is
            found for that name.
        RuntimeError: If the registered class's ``setup_task`` returns None.
    """
    # Explicit raises instead of ``assert`` so validation is not stripped
    # when Python runs with ``-O``.
    if "task" not in cfg.run_cfg:
        raise ValueError("Task name must be provided.")

    task_name = cfg.run_cfg.task

    # Check the lookup result *before* dereferencing it. Presumably
    # ``get_task_class`` returns None for an unknown name (verify against the
    # registry); calling ``.setup_task`` on None would raise an opaque
    # AttributeError and the "not properly registered" message would never
    # be shown.
    task_cls = registry.get_task_class(task_name)
    if task_cls is None:
        raise ValueError("Task {} not properly registered.".format(task_name))

    task = task_cls.setup_task(cfg=cfg)
    if task is None:
        raise RuntimeError("Task {} not properly registered.".format(task_name))

    return task
|
|
24 |
|
|
|
25 |
|
|
|
26 |
# Names exported by ``from lavis.tasks import *``.
__all__ = [
    "BaseTask",
    "CaptionTask",
    "VQATask",
    "ImageTextPretrainTask",
]