|
a |
|
b/lavis/datasets/builders/classification_builder.py |
|
|
1 |
""" |
|
|
2 |
Copyright (c) 2022, salesforce.com, inc. |
|
|
3 |
All rights reserved. |
|
|
4 |
SPDX-License-Identifier: BSD-3-Clause |
|
|
5 |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
from lavis.common.registry import registry |
|
|
9 |
from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder |
|
|
10 |
from lavis.datasets.datasets.nlvr_datasets import NLVRDataset, NLVREvalDataset |
|
|
11 |
from lavis.datasets.datasets.snli_ve_datasets import SNLIVisualEntialmentDataset |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
@registry.register_builder("nlvr") |
|
|
15 |
class NLVRBuilder(BaseDatasetBuilder): |
|
|
16 |
train_dataset_cls = NLVRDataset |
|
|
17 |
eval_dataset_cls = NLVREvalDataset |
|
|
18 |
|
|
|
19 |
DATASET_CONFIG_DICT = {"default": "configs/datasets/nlvr/defaults.yaml"} |
|
|
20 |
|
|
|
21 |
|
|
|
22 |
@registry.register_builder("snli_ve") |
|
|
23 |
class SNLIVisualEntailmentBuilder(BaseDatasetBuilder): |
|
|
24 |
train_dataset_cls = SNLIVisualEntialmentDataset |
|
|
25 |
eval_dataset_cls = SNLIVisualEntialmentDataset |
|
|
26 |
|
|
|
27 |
DATASET_CONFIG_DICT = {"default": "configs/datasets/snli_ve/defaults.yaml"} |