
model/lavis/datasets/data_utils.py
"""
2
 Copyright (c) 2022, salesforce.com, inc.
3
 All rights reserved.
4
 SPDX-License-Identifier: BSD-3-Clause
5
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
"""
7
8
import gzip
9
import logging
10
import os
11
import random as rnd
12
import tarfile
13
import zipfile
14
15
16
import numpy as np
17
import torch
18
from torch.utils.data.dataset import IterableDataset, ChainDataset
19
from model.lavis.common.registry import registry
20
from model.lavis.datasets.datasets.base_dataset import ConcatDataset
21
from tqdm import tqdm
22
23
MAX_INT = registry.get("MAX_INT")
24
25
26
27
28
def apply_to_sample(f, sample):
    """Recursively apply f to every tensor in sample; dicts and lists are traversed, other values pass through unchanged."""
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        else:
            return x

    return _apply(sample)

def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)

def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples

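# Illustration (not part of the original module): `prepare_sample` is typically
# called on a collated batch right before the forward pass; `apply_to_sample`
# walks arbitrarily nested dicts/lists and only touches tensors. A minimal
# sketch, with a made-up batch layout:
#
#     batch = {
#         "image": torch.randn(4, 3, 224, 224),
#         "text_input": ["a photo of a cat"] * 4,  # non-tensors pass through unchanged
#     }
#     batch = prepare_sample(batch, cuda_enabled=torch.cuda.is_available())
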
def reorg_datasets_by_split(datasets):
    """
    Organizes datasets by split.

    Args:
        datasets: dict of {dataset_name: {split_name: torch.utils.data.Dataset}}.

    Returns:
        Dict of datasets by split {split_name: List[Datasets]}.
    """
    # if len(datasets) == 1:
    #     return datasets[list(datasets.keys())[0]]
    # else:
    reorg_datasets = dict()

    # reorganize by split
    for _, dataset in datasets.items():
        for split_name, dataset_split in dataset.items():
            if split_name not in reorg_datasets:
                reorg_datasets[split_name] = [dataset_split]
            else:
                reorg_datasets[split_name].append(dataset_split)

    return reorg_datasets

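# Illustration (sketch with hypothetical dataset names): the reorganization turns a
# {dataset_name: {split_name: Dataset}} mapping into {split_name: [Dataset, ...]}, e.g.
#
#     {"coco_caption": {"train": ds_a, "val": ds_b}, "vg_caption": {"train": ds_c}}
#         -> {"train": [ds_a, ds_c], "val": [ds_b]}
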
def concat_datasets(datasets):
    """
    Concatenates multiple datasets into a single dataset.

    It supports map-style datasets and DataPipelines from WebDataset. Currently it does
    not support generic IterableDataset because that would require creating separate samplers.

    Only training datasets are concatenated; the validation and testing splits are assumed
    to contain a single dataset each, because metrics should not be computed on
    concatenated datasets.

    Args:
        datasets: dict of lists of torch.utils.data.Dataset objects by split
            (as returned by reorg_datasets_by_split).

    Returns:
        Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
        "val" and "test" remain the same.

        If the input training datasets contain both map-style and DataPipeline datasets, returns
        a tuple, where the first element is a concatenated map-style dataset and the second
        element is a chained DataPipeline dataset.

    """
    # concatenate datasets in the same split
    for split_name in datasets:
        if split_name != "train":
            assert (
                len(datasets[split_name]) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = datasets[split_name][0]
        else:
            iterable_datasets, map_datasets = [], []
            for dataset in datasets[split_name]:
                if isinstance(dataset, IterableDataset):
                    raise NotImplementedError(
                        "Do not support concatenation of generic IterableDataset."
                    )
                else:
                    map_datasets.append(dataset)

            # concatenate map-style datasets and iterable-style datasets separately
            chained_datasets = (
                ChainDataset(iterable_datasets) if len(iterable_datasets) > 0 else None
            )
            concat_datasets = (
                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
            )

            train_datasets = concat_datasets, chained_datasets
            train_datasets = tuple([x for x in train_datasets if x is not None])
            train_datasets = (
                train_datasets[0] if len(train_datasets) == 1 else train_datasets
            )

            datasets[split_name] = train_datasets

    return datasets

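# Illustration (sketch, assuming the reorg_datasets_by_split output shown above):
#
#     datasets = reorg_datasets_by_split(datasets)  # {"train": [ds_a, ds_c], "val": [ds_b]}
#     datasets = concat_datasets(datasets)
#     # datasets["train"] is now a single ConcatDataset; "val"/"test" are unwrapped as-is.
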
def extract_archive(from_path, to_path=None, overwrite=False):
    """Extract archive.

    Args:
        from_path: the path of the archive.
        to_path: the root path of the extracted files (directory of from_path)
        overwrite: overwrite existing files (False)

    Returns:
        List of paths to extracted files even if not overwritten.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchtext.utils.download_from_url(url, from_path)
        >>> torchtext.utils.extract_archive(from_path, to_path)
        ['.data/val.de', '.data/val.en']

    """

    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith((".tar.gz", ".tgz")):
        logging.info("Opening tar file {} to {}.".format(from_path, to_path))
        with tarfile.open(from_path, "r") as tar:
            files = []
            for file_ in tqdm(tar):
                file_path = os.path.join(to_path, file_.name)
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("{} already extracted.".format(file_path))
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            logging.info("Finished extracting tar file {}.".format(from_path))
            return files

    elif from_path.endswith(".zip"):
        assert zipfile.is_zipfile(from_path), from_path
        logging.info("Opening zip file {} to {}.".format(from_path, to_path))
        with zipfile.ZipFile(from_path, "r") as zfile:
            files = []
            for file_ in tqdm(zfile.namelist()):
                file_path = os.path.join(to_path, file_)
                files.append(file_path)
                if os.path.exists(file_path):
                    logging.info("{} already extracted.".format(file_path))
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
        files = [f for f in files if os.path.isfile(f)]
        logging.info("Finished extracting zip file {}.".format(from_path))
        return files

    elif from_path.endswith(".gz"):
        logging.info("Opening gz file {} to {}.".format(from_path, to_path))
        default_block_size = 65536
        filename = from_path[:-3]
        files = [filename]
        with gzip.open(from_path, "rb") as gzfile, open(filename, "wb") as d_file:
            while True:
                block = gzfile.read(default_block_size)
                if not block:
                    break
                d_file.write(block)
        logging.info("Finished extracting gz file {}.".format(from_path))
        return files

    else:
        raise NotImplementedError(
            "We currently only support .tar.gz, .tgz, .gz and .zip archives."
        )

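# Illustration (sketch with a hypothetical local path): extraction defaults to the
# archive's own directory and skips already-existing files unless overwrite=True.
#
#     files = extract_archive("./annotations.tar.gz")  # extracts next to the archive
#     files = extract_archive("./annotations.tar.gz", to_path="./data", overwrite=True)
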
def save_frames_grid(img_array, out_path):
    import torch
    from PIL import Image
    from torchvision.utils import make_grid

    if len(img_array.shape) == 3:
        img_array = img_array.unsqueeze(0)
    elif len(img_array.shape) == 5:
        b, t, c, h, w = img_array.shape
        img_array = img_array.view(-1, c, h, w)
    elif len(img_array.shape) == 4:
        pass
    else:
        raise NotImplementedError(
            "Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored."
        )

    assert img_array.shape[1] == 3, "Expecting 3-channel (RGB) input of shape (b, 3, h, w)."

    grid = make_grid(img_array)
    # convert the (c, h, w) grid to an (h, w, c) uint8 array for PIL
    ndarr = grid.permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    img = Image.fromarray(ndarr)

    img.save(out_path)
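
# Illustration (sketch with made-up shapes): save a grid of video frames, where the
# input is a (batch, time, channel, height, width) uint8 tensor.
#
#     frames = torch.randint(0, 256, (2, 8, 3, 224, 224), dtype=torch.uint8)
#     save_frames_grid(frames, "frames_grid.png")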