--- /dev/null
+++ b/lavis/datasets/data_utils.py
@@ -0,0 +1,284 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+import gzip
+import logging
+import os
+import random as rnd
+import tarfile
+import zipfile
+
+import decord
+import webdataset as wds
+import numpy as np
+import torch
+from torch.utils.data.dataset import IterableDataset, ChainDataset
+from decord import VideoReader
+from lavis.common.registry import registry
+from lavis.datasets.datasets.base_dataset import ConcatDataset
+from tqdm import tqdm
+
+decord.bridge.set_bridge("torch")
+MAX_INT = registry.get("MAX_INT")
+
+
+def load_video(video_path, n_frms=MAX_INT, height=-1, width=-1, sampling="uniform"):
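+    """
+    Load a video with decord and return its frames as a (C, T, H, W) float tensor.
+
+    Args:
+        video_path (str): path to the video file.
+        n_frms (int): number of frames to sample, capped at the video length.
+        height (int): target height; -1 keeps the original size.
+        width (int): target width; -1 keeps the original size.
+        sampling (str): "uniform" for evenly spaced frames, or "headtail" for
+            random frames drawn from the first and second halves of the video.
+
+    Example (a sketch; the path is hypothetical):
+        >>> frms = load_video("video.mp4", n_frms=8, height=224, width=224)
+        >>> frms.shape[0]  # channels first: (C, T, H, W)
+        3
+    """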
+    vr = VideoReader(uri=video_path, height=height, width=width)
+
+    vlen = len(vr)
+    start, end = 0, vlen
+
+    n_frms = min(n_frms, vlen)
+
+    if sampling == "uniform":
+        # evenly spaced indices; truncate to n_frms in case the floating-point
+        # step yields one extra index
+        indices = np.arange(start, end, vlen / n_frms).astype(int)[:n_frms]
+    elif sampling == "headtail":
+        # draw half of the frames from each half of the video
+        indices_h = sorted(rnd.sample(range(vlen // 2), n_frms // 2))
+        indices_t = sorted(rnd.sample(range(vlen // 2, vlen), n_frms - n_frms // 2))
+        indices = indices_h + indices_t
+    else:
+        raise NotImplementedError("Unsupported sampling method: {}".format(sampling))
+
+    # get_batch -> T, H, W, C
+    frms = vr.get_batch(indices).permute(3, 0, 1, 2).float()  # (C, T, H, W)
+
+    return frms
+
+
+def apply_to_sample(f, sample):
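+    """
+    Recursively apply ``f`` to every tensor in ``sample``, which may be a
+    tensor, a dict, a list, or any nesting of these; non-tensor leaves are
+    returned unchanged.
+    """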
+    if len(sample) == 0:
+        return {}
+
+    def _apply(x):
+        if torch.is_tensor(x):
+            return f(x)
+        elif isinstance(x, dict):
+            return {key: _apply(value) for key, value in x.items()}
+        elif isinstance(x, list):
+            return [_apply(item) for item in x]
+        else:
+            return x
+
+    return _apply(sample)
+
+
+def move_to_cuda(sample):
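+    """Recursively move every tensor in ``sample`` to the default CUDA device."""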
+    def _move_to_cuda(tensor):
+        return tensor.cuda()
+
+    return apply_to_sample(_move_to_cuda, sample)
+
+
+def prepare_sample(samples, cuda_enabled=True):
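+    """
+    Prepare a collated batch for the model, moving all contained tensors to the
+    GPU when ``cuda_enabled`` is True.
+
+    Example (a sketch; assumes ``samples`` is a collated batch dict):
+        >>> samples = prepare_sample(samples, cuda_enabled=torch.cuda.is_available())
+    """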
+    if cuda_enabled:
+        samples = move_to_cuda(samples)
+
+    # TODO fp16 support
+
+    return samples
+
+
+def reorg_datasets_by_split(datasets):
+    """
+    Organizes datasets by split.
+
+    Args:
+        datasets: dict of torch.utils.data.Dataset objects by name.
+
+    Returns:
+        Dict of datasets by split {split_name: List[Datasets]}.
+    """
+    reorg_datasets = dict()
+
+    # reorganize by split
+    for _, dataset in datasets.items():
+        for split_name, dataset_split in dataset.items():
+            if split_name not in reorg_datasets:
+                reorg_datasets[split_name] = [dataset_split]
+            else:
+                reorg_datasets[split_name].append(dataset_split)
+
+    return reorg_datasets
+
+
+def concat_datasets(datasets):
+    """
+    Concatenates multiple datasets into a single dataset.
+
+    It supports map-style datasets and DataPipeline from WebDataset. Generic
+    IterableDataset is currently not supported because it would require creating
+    separate samplers.
+
+    Only training datasets are concatenated; validation and testing are assumed
+    to have a single dataset each, because metrics should not be computed on
+    concatenated datasets.
+
+    Args:
+        datasets: dict of torch.utils.data.Dataset objects by split.
+
+    Returns:
+        Dict of datasets by split: "train" is the concatenation of the training
+        datasets, while "val" and "test" remain unchanged.
+
+        If the input training datasets contain both map-style and DataPipeline datasets, returns
+        a tuple, where the first element is a concatenated map-style dataset and the second
+        element is a chained DataPipeline dataset.
+
+    """
+    # concatenate datasets in the same split
+    for split_name in datasets:
+        if split_name != "train":
+            assert (
+                len(datasets[split_name]) == 1
+            ), "Do not support multiple {} datasets.".format(split_name)
+            datasets[split_name] = datasets[split_name][0]
+        else:
+            iterable_datasets, map_datasets = [], []
+            for dataset in datasets[split_name]:
+                if isinstance(dataset, wds.DataPipeline):
+                    logging.info(
+                        "Dataset {} is IterableDataset, can't be concatenated.".format(
+                            dataset
+                        )
+                    )
+                    iterable_datasets.append(dataset)
+                elif isinstance(dataset, IterableDataset):
+                    raise NotImplementedError(
+                        "Do not support concatenation of generic IterableDataset."
+                    )
+                else:
+                    map_datasets.append(dataset)
+
+            # concatenate map-style datasets and iterable-style datasets separately
+            chained_datasets = (
+                ChainDataset(iterable_datasets) if len(iterable_datasets) > 0 else None
+            )
+            concatenated_datasets = (
+                ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
+            )
+
+            train_datasets = concatenated_datasets, chained_datasets
+            train_datasets = tuple([x for x in train_datasets if x is not None])
+            train_datasets = (
+                train_datasets[0] if len(train_datasets) == 1 else train_datasets
+            )
+
+            datasets[split_name] = train_datasets
+
+    return datasets
+
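+# Usage sketch (hypothetical dataset names): given per-dataset split dicts from
+# the builders, the two helpers above are typically chained as
+#
+#   datasets = reorg_datasets_by_split({"coco_caption": coco, "vg_caption": vg})
+#   datasets = concat_datasets(datasets)
+#
+# after which datasets["train"] is a single ConcatDataset and/or ChainDataset,
+# while "val" and "test" each remain a single dataset.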
+
+def extract_archive(from_path, to_path=None, overwrite=False):
+    """Extract archive.
+
+    Args:
+        from_path: the path of the archive.
+        to_path: the root path of the extracted files (defaults to the directory of from_path)
+        overwrite: whether to overwrite existing files (default: False)
+
+    Returns:
+        List of paths to extracted files even if not overwritten.
+
+    Examples:
+        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
+        >>> from_path = './validation.tar.gz'
+        >>> to_path = './'
+        >>> torchtext.utils.download_from_url(url, from_path)
+        >>> extract_archive(from_path, to_path)
+        ['.data/val.de', '.data/val.en']
+        >>> extract_archive(from_path, to_path)  # already extracted; paths still returned
+        ['.data/val.de', '.data/val.en']
+
+    """
+
+    if to_path is None:
+        to_path = os.path.dirname(from_path)
+
+    if from_path.endswith((".tar.gz", ".tgz")):
+        logging.info("Opening tar file {} to {}.".format(from_path, to_path))
+        with tarfile.open(from_path, "r") as tar:
+            files = []
+            for file_ in tqdm(tar):
+                file_path = os.path.join(to_path, file_.name)
+                if file_.isfile():
+                    files.append(file_path)
+                    if os.path.exists(file_path):
+                        logging.info("{} already extracted.".format(file_path))
+                        if not overwrite:
+                            continue
+                tar.extract(file_, to_path)
+            logging.info("Finished extracting tar file {}.".format(from_path))
+            return files
+
+    elif from_path.endswith(".zip"):
+        assert zipfile.is_zipfile(from_path), from_path
+        logging.info("Opening zip file {} to {}.".format(from_path, to_path))
+        with zipfile.ZipFile(from_path, "r") as zfile:
+            files = []
+            for file_ in tqdm(zfile.namelist()):
+                file_path = os.path.join(to_path, file_)
+                files.append(file_path)
+                if os.path.exists(file_path):
+                    logging.info("{} already extracted.".format(file_path))
+                    if not overwrite:
+                        continue
+                zfile.extract(file_, to_path)
+        files = [f for f in files if os.path.isfile(f)]
+        logging.info("Finished extracting zip file {}.".format(from_path))
+        return files
+
+    elif from_path.endswith(".gz"):
+        logging.info("Opening gz file {} to {}.".format(from_path, to_path))
+        default_block_size = 65536
+        filename = from_path[:-3]
+        files = [filename]
+        with gzip.open(from_path, "rb") as gzfile, open(filename, "wb") as d_file:
+            # stream the decompressed payload in fixed-size blocks
+            while True:
+                block = gzfile.read(default_block_size)
+                if not block:
+                    break
+                d_file.write(block)
+        logging.info("Finished extracting gz file {}.".format(from_path))
+        return files
+
+    else:
+        raise NotImplementedError(
+            "We currently only support tar.gz, .tgz, .gz and zip achives."
+        )
+
+
+def save_frames_grid(img_array, out_path):
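+    """
+    Save a batch of frames as a single tiled image grid.
+
+    ``img_array`` may be shaped (c, h, w), (b, c, h, w) or (b, t, c, h, w),
+    with RGB values already in the [0, 255] range.
+    """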
+    from PIL import Image
+    from torchvision.utils import make_grid
+
+    if len(img_array.shape) == 3:
+        img_array = img_array.unsqueeze(0)
+    elif len(img_array.shape) == 5:
+        b, t, c, h, w = img_array.shape
+        img_array = img_array.view(-1, c, h, w)
+    elif len(img_array.shape) == 4:
+        pass
+    else:
+        raise NotImplementedError(
+            "Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored."
+        )
+
+    assert img_array.shape[1] == 3, "Expecting RGB input, i.e. a channel dimension of size 3."
+
+    grid = make_grid(img_array)
+    ndarr = grid.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+
+    img = Image.fromarray(ndarr)
+
+    img.save(out_path)