
lavis/datasets/data_utils.py
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import gzip
import logging
import os
import random as rnd
import tarfile
import zipfile

import decord
import webdataset as wds
import numpy as np
import torch
from torch.utils.data.dataset import IterableDataset, ChainDataset
from decord import VideoReader
from lavis.common.registry import registry
from lavis.datasets.datasets.base_dataset import ConcatDataset
from tqdm import tqdm

decord.bridge.set_bridge("torch")
MAX_INT = registry.get("MAX_INT")

def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform"):
    vr = VideoReader(uri=video_path, height=height, width=width)

    vlen = len(vr)
    start, end = 0, vlen

    n_frms = min(n_frms, vlen)

    if sampling == "uniform":
        indices = np.arange(start, end, vlen / n_frms).astype(int)
    elif sampling == "headtail":
        # sample half of the frames from each half of the video
        indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))
        indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms // 2))
        indices = indices_h + indices_t
    else:
        raise NotImplementedError

    # get_batch -> T, H, W, C
    frms = vr.get_batch(indices).permute(3, 0, 1, 2).float()  # (C, T, H, W)

    return frms

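# Usage sketch ("video.mp4" is a hypothetical local path); with the torch
# bridge set above, decord's get_batch returns a torch tensor directly:
#
#   clip = load_video("video.mp4", n_frms=8, height=224, width=224)
#   clip.shape  # (3, T, H, W) with T <= 8, dtype float32
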
def apply_to_sample(f, sample):
    if len(sample) == 0:
        return {}

    def _apply(x):
        if torch.is_tensor(x):
            return f(x)
        elif isinstance(x, dict):
            return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
        else:
            return x

    return _apply(sample)

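# apply_to_sample recurses through nested dicts/lists, applying f to every
# tensor leaf and passing other values through unchanged. A small sketch:
#
#   sample = {"image": torch.zeros(3, 224, 224), "ids": [torch.tensor(1)], "name": "x"}
#   halved = apply_to_sample(lambda t: t / 2, sample)  # "name" is untouched
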
def move_to_cuda(sample):
    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)


def prepare_sample(samples, cuda_enabled=True):
    if cuda_enabled:
        samples = move_to_cuda(samples)

    # TODO fp16 support

    return samples

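# Typical call site in a training loop (a sketch; `dataloader` and `model`
# are assumed to exist):
#
#   for samples in dataloader:
#       samples = prepare_sample(samples, cuda_enabled=torch.cuda.is_available())
#       loss = model(samples)
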
def reorg_datasets_by_split(datasets):
    """
    Organizes datasets by split.

    Args:
        datasets: dict keyed by dataset name; each value is a dict mapping
            split_name to a torch.utils.data.Dataset.

    Returns:
        Dict of datasets by split {split_name: List[Datasets]}.
    """
    reorg_datasets = dict()

    # reorganize by split
    for _, dataset in datasets.items():
        for split_name, dataset_split in dataset.items():
            if split_name not in reorg_datasets:
                reorg_datasets[split_name] = [dataset_split]
            else:
                reorg_datasets[split_name].append(dataset_split)

    return reorg_datasets

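# Sketch of the reorganization (dataset names and variables are illustrative):
#
#   datasets = {
#       "coco_caption": {"train": coco_train, "val": coco_val},
#       "vg_caption": {"train": vg_train},
#   }
#   reorg_datasets_by_split(datasets)
#   # -> {"train": [coco_train, vg_train], "val": [coco_val]}
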
def concat_datasets(datasets):
    """
    Concatenates multiple datasets into a single dataset.

    It supports map-style datasets and DataPipeline from WebDataset. It does not
    currently support generic IterableDataset, because that would require creating
    separate samplers.

    Only training datasets are concatenated; validation and testing splits are
    assumed to contain a single dataset each, because metrics should not be
    computed on concatenated datasets.

    Args:
        datasets: dict of torch.utils.data.Dataset objects by split.

    Returns:
        Dict of concatenated datasets by split: "train" is the concatenation of
        multiple datasets, while "val" and "test" remain unchanged.

        If the input training datasets contain both map-style and DataPipeline
        datasets, "train" is a tuple whose first element is a concatenated
        map-style dataset and whose second element is a chained DataPipeline
        dataset.
    """
    # concatenate datasets in the same split
    for split_name in datasets:
        if split_name != "train":
            assert (
                len(datasets[split_name]) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = datasets[split_name][0]
        else:
            iterable_datasets, map_datasets = [], []
            for dataset in datasets[split_name]:
                if isinstance(dataset, wds.DataPipeline):
                    logging.info(
                        "Dataset {} is IterableDataset, can't be concatenated.".format(
                            dataset
                        )
                    )
                    iterable_datasets.append(dataset)
                elif isinstance(dataset, IterableDataset):
                    raise NotImplementedError(
                        "Do not support concatenation of generic IterableDataset."
                    )
                else:
                    map_datasets.append(dataset)

            # concatenate map-style datasets and iterable-style datasets separately
            chained_datasets = (
                ChainDataset(iterable_datasets) if len(iterable_datasets) > 0 else None
            )
            concat_datasets = (
                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
            )

            train_datasets = concat_datasets, chained_datasets
            train_datasets = tuple([x for x in train_datasets if x is not None])
            train_datasets = (
                train_datasets[0] if len(train_datasets) == 1 else train_datasets
            )

            datasets[split_name] = train_datasets

    return datasets

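# Sketch, continuing the example above; typically chained after
# reorg_datasets_by_split:
#
#   datasets = concat_datasets(reorg_datasets_by_split(datasets))
#   # datasets["train"] -> ConcatDataset, ChainDataset, or a tuple of both
#   # datasets["val"]   -> coco_val (unwrapped from its one-element list)
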
def extract_archive(from_path, to_path=None, overwrite=False):
    """Extract archive.

    Args:
        from_path: the path of the archive.
        to_path: the root path of the extracted files (directory of from_path)
        overwrite: overwrite existing files (False)

    Returns:
        List of paths to extracted files even if not overwritten.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchtext.utils.download_from_url(url, from_path)
        >>> torchtext.utils.extract_archive(from_path, to_path)
        >>> ['.data/val.de', '.data/val.en']
        >>> torchtext.utils.download_from_url(url, from_path)
        >>> torchtext.utils.extract_archive(from_path, to_path)
        >>> ['.data/val.de', '.data/val.en']
    """

    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith((".tar.gz", ".tgz")):
        logging.info("Opening tar file {} to {}.".format(from_path, to_path))
        with tarfile.open(from_path, "r") as tar:
            files = []
            for file_ in tqdm(tar):
                file_path = os.path.join(to_path, file_.name)
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("{} already extracted.".format(file_path))
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            logging.info("Finished extracting tar file {}.".format(from_path))
            return files

    elif from_path.endswith(".zip"):
        assert zipfile.is_zipfile(from_path), from_path
        logging.info("Opening zip file {} to {}.".format(from_path, to_path))
        with zipfile.ZipFile(from_path, "r") as zfile:
            files = []
            for file_ in tqdm(zfile.namelist()):
                file_path = os.path.join(to_path, file_)
                files.append(file_path)
                if os.path.exists(file_path):
                    logging.info("{} already extracted.".format(file_path))
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
        files = [f for f in files if os.path.isfile(f)]
        logging.info("Finished extracting zip file {}.".format(from_path))
        return files

    elif from_path.endswith(".gz"):
        logging.info("Opening gz file {} to {}.".format(from_path, to_path))
        default_block_size = 65536
        filename = from_path[:-3]
        files = [filename]
        with gzip.open(from_path, "rb") as gzfile, open(filename, "wb") as d_file:
            # decompress in fixed-size blocks until the stream is exhausted
            while True:
                block = gzfile.read(default_block_size)
                if not block:
                    break
                d_file.write(block)
        logging.info("Finished extracting gz file {}.".format(from_path))
        return files

    else:
        raise NotImplementedError(
            "We currently only support .tar.gz, .tgz, .gz and .zip archives."
        )

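# Sketch (hypothetical archive path); unlike the torchtext examples in the
# docstring above, this calls the local function directly:
#
#   files = extract_archive("data/annotations.zip")
#   # files lists extracted file paths; pass overwrite=True to force re-extraction
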
def save_frames_grid(img_array, out_path):
    from PIL import Image
    from torchvision.utils import make_grid

    if len(img_array.shape) == 3:
        img_array = img_array.unsqueeze(0)
    elif len(img_array.shape) == 5:
        b, t, c, h, w = img_array.shape
        img_array = img_array.view(-1, c, h, w)
    elif len(img_array.shape) == 4:
        pass
    else:
        raise NotImplementedError(
            "Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored."
        )

    assert (
        img_array.shape[1] == 3
    ), "Expecting 3-channel (RGB) input, i.e. shape (..., 3, H, W)."

    grid = make_grid(img_array)
    ndarr = grid.permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    img = Image.fromarray(ndarr)

    img.save(out_path)
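

# Sketch: tile frames sampled by load_video into one image ("video.mp4" and
# "frames.png" are assumed paths); load_video returns (C, T, H, W), so permute
# to the (T, C, H, W) layout save_frames_grid expects for 4-dim inputs:
#
#   clip = load_video("video.mp4", n_frms=8, height=224, width=224)
#   save_frames_grid(clip.permute(1, 0, 2, 3), "frames.png")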