|
a |
|
b/minigpt4/tasks/__init__.py |
|
|
1 |
""" |
|
|
2 |
Copyright (c) 2022, salesforce.com, inc. |
|
|
3 |
All rights reserved. |
|
|
4 |
SPDX-License-Identifier: BSD-3-Clause |
|
|
5 |
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
from minigpt4.common.registry import registry |
|
|
9 |
from minigpt4.tasks.base_task import BaseTask |
|
|
10 |
from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask |
|
|
11 |
from minigpt4.tasks.mimic_generate_then_refine import MIMICGenerateThenRefine |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
def setup_task(cfg):
    """Build the task object named by ``cfg.run_cfg.task``.

    The task name is resolved to a class through
    ``registry.get_task_class`` and that class's ``setup_task(cfg=cfg)`` is
    called to construct the task instance.

    Args:
        cfg: Config object exposing a ``run_cfg`` section that must contain a
            ``task`` entry (presumably an OmegaConf node -- TODO confirm).

    Returns:
        The task instance produced by the registered class's ``setup_task``.

    Raises:
        ValueError: if ``run_cfg`` has no ``task`` entry, if no class is
            registered under that name, or if the class's ``setup_task``
            returns ``None``.
    """
    # Explicit checks rather than ``assert`` so validation still runs when
    # Python is started with ``-O`` (which strips assert statements).
    if "task" not in cfg.run_cfg:
        raise ValueError("Task name must be provided.")

    task_name = cfg.run_cfg.task

    # Resolve the class before using it. For an unknown name (assuming
    # get_task_class returns None in that case -- verify against registry),
    # calling ``.setup_task`` directly would raise an opaque AttributeError
    # instead of the registration error below.
    task_cls = registry.get_task_class(task_name)
    if task_cls is None:
        raise ValueError("Task {} not properly registered.".format(task_name))

    task = task_cls.setup_task(cfg=cfg)
    if task is None:
        raise ValueError("Task {} not properly registered.".format(task_name))

    return task
|
|
22 |
|
|
|
23 |
|
|
|
24 |
# Names exported by ``from minigpt4.tasks import *``.
# NOTE(review): ``setup_task`` is not listed, so a star-import does not
# export it (attribute access ``minigpt4.tasks.setup_task`` still works);
# confirm this is intended.
__all__ = ["BaseTask", "ImageTextPretrainTask", "MIMICGenerateThenRefine"]