|
a |
|
b/model/lavis/datasets/data_utils.py |
|
|
1 |
""" |
|
|
2 |
Copyright (c) 2022, salesforce.com, inc. |
|
|
3 |
All rights reserved. |
|
|
4 |
SPDX-License-Identifier: BSD-3-Clause |
|
|
5 |
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|
|
6 |
""" |
|
|
7 |
|
|
|
8 |
import gzip |
|
|
9 |
import logging |
|
|
10 |
import os |
|
|
11 |
import random as rnd |
|
|
12 |
import tarfile |
|
|
13 |
import zipfile |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
import numpy as np |
|
|
17 |
import torch |
|
|
18 |
from torch.utils.data.dataset import IterableDataset, ChainDataset |
|
|
19 |
from model.lavis.common.registry import registry |
|
|
20 |
from model.lavis.datasets.datasets.base_dataset import ConcatDataset |
|
|
21 |
from tqdm import tqdm |
|
|
22 |
|
|
|
23 |
# Value looked up at import time from the project registry under the key "MAX_INT".
# NOTE(review): presumably an integer upper bound registered elsewhere — confirm.
MAX_INT = registry.get("MAX_INT") |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
|
|
|
27 |
|
|
|
28 |
def apply_to_sample(f, sample):
    """
    Apply ``f`` to every tensor found inside a (possibly nested) sample.

    Dicts and lists are walked recursively and rebuilt; each tensor is replaced
    by ``f(tensor)``; every other value is passed through untouched.

    Args:
        f: callable invoked on each tensor.
        sample: container (or tensor) to process.

    Returns:
        ``{}`` when ``sample`` has length zero, otherwise a structure mirroring
        ``sample`` with every tensor mapped through ``f``.
    """
    if len(sample) == 0:
        return {}

    def _transform(node):
        if torch.is_tensor(node):
            return f(node)
        if isinstance(node, dict):
            return {name: _transform(child) for name, child in node.items()}
        if isinstance(node, list):
            return [_transform(child) for child in node]
        # Anything that is not a tensor, dict or list is returned as-is.
        return node

    return _transform(sample)
|
|
43 |
|
|
|
44 |
|
|
|
45 |
def move_to_cuda(sample):
    """
    Return ``sample`` with ``.cuda()`` called on every tensor it contains.

    The traversal of nested containers is delegated to ``apply_to_sample``.
    """
    return apply_to_sample(lambda tensor: tensor.cuda(), sample)
|
|
50 |
|
|
|
51 |
|
|
|
52 |
def prepare_sample(samples, cuda_enabled=True):
    """
    Prepare a batch of samples for the model.

    Args:
        samples: the batch to prepare.
        cuda_enabled: when True, the batch is passed through ``move_to_cuda``;
            when False, it is returned unchanged.

    Returns:
        The prepared samples.
    """
    # TODO fp16 support
    if not cuda_enabled:
        return samples
    return move_to_cuda(samples)
|
|
59 |
|
|
|
60 |
|
|
|
61 |
def reorg_datasets_by_split(datasets):
    """
    Regroup datasets so that they are keyed by split instead of by name.

    Args:
        datasets: dict mapping a dataset name to a dict of
            {split_name: dataset} for that dataset.

    Returns:
        Dict of the form {split_name: [dataset, ...]}, where each list holds
        the datasets of that split in the order they were encountered.
    """
    by_split = {}

    for splits in datasets.values():
        for split_name, split_dataset in splits.items():
            # Create the list on first sight of a split, then append to it.
            by_split.setdefault(split_name, []).append(split_dataset)

    return by_split
|
|
85 |
|
|
|
86 |
|
|
|
87 |
def concat_datasets(datasets):
    """
    Concatenate the datasets of each split into a single dataset.

    Only the "train" split may hold several datasets; every other split must
    hold exactly one, because metrics should not be computed on concatenated
    datasets. Generic ``IterableDataset`` objects are rejected because
    concatenating them would require separate samplers.

    The input dict is updated in place and also returned.

    Args:
        datasets: dict of the form {split_name: [dataset, ...]}.

    Returns:
        The same dict, where each non-"train" split maps to its single dataset
        and "train" maps to a ``ConcatDataset`` of the map-style training
        datasets (an empty tuple when there are none). If both a map-style and
        a chained iterable dataset were produced, "train" would map to the
        tuple (concatenated, chained); in this code the iterable list is never
        populated, so that case does not currently occur.

    Raises:
        AssertionError: if a non-"train" split does not hold exactly one dataset.
        NotImplementedError: if a training dataset is an ``IterableDataset``.
    """
    for split_name in datasets:
        split_datasets = datasets[split_name]

        if split_name != "train":
            assert (
                len(split_datasets) == 1
            ), "Do not support multiple {} datasets.".format(split_name)
            datasets[split_name] = split_datasets[0]
            continue

        # Sort training datasets into iterable-style and map-style buckets.
        iterable_datasets, map_datasets = [], []
        for dataset in split_datasets:
            if isinstance(dataset, IterableDataset):
                raise NotImplementedError(
                    "Do not support concatenation of generic IterableDataset."
                )
            map_datasets.append(dataset)

        # Build one combined dataset per non-empty bucket, map-style first.
        combined = []
        if len(map_datasets) > 0:
            combined.append(ConcatDataset(map_datasets))
        if len(iterable_datasets) > 0:
            combined.append(ChainDataset(iterable_datasets))

        # A single result is unwrapped; otherwise a tuple is returned.
        datasets[split_name] = combined[0] if len(combined) == 1 else tuple(combined)

    return datasets
|
|
145 |
|
|
|
146 |
|
|
|
147 |
def extract_archive(from_path, to_path=None, overwrite=False):
    """Extract archive.

    The archive type is chosen from the suffix of ``from_path``:
    ``.tar.gz`` / ``.tgz``, ``.zip`` or plain ``.gz``.

    Args:
        from_path: the path of the archive.
        to_path: the root path of the extracted files (directory of from_path).
            Not used for plain ``.gz`` archives, which are always decompressed
            next to the archive as ``from_path`` without its ``.gz`` suffix.
        overwrite: overwrite existing files (False). Only honoured for tar and
            zip archives; a plain ``.gz`` archive is always re-extracted.

    Returns:
        List of paths to extracted files even if not overwritten.

    Raises:
        NotImplementedError: if ``from_path`` has an unsupported suffix.

    Examples:
        >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        >>> from_path = './validation.tar.gz'
        >>> to_path = './'
        >>> torchtext.utils.download_from_url(url, from_path)
        >>> torchtext.utils.extract_archive(from_path, to_path)
        >>> ['.data/val.de', '.data/val.en']

    """

    if to_path is None:
        to_path = os.path.dirname(from_path)

    if from_path.endswith((".tar.gz", ".tgz")):
        logging.info("Opening tar file {} to {}.".format(from_path, to_path))
        with tarfile.open(from_path, "r") as tar:
            files = []
            for file_ in tqdm(tar):
                file_path = os.path.join(to_path, file_.name)
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("{} already extracted.".format(file_path))
                        if not overwrite:
                            continue
                # NOTE(review): member names are not validated here; for
                # archives from untrusted sources consider checking that the
                # target stays inside to_path before extracting.
                tar.extract(file_, to_path)
            logging.info("Finished extracting tar file {}.".format(from_path))
            return files

    elif from_path.endswith(".zip"):
        assert zipfile.is_zipfile(from_path), from_path
        logging.info("Opening zip file {} to {}.".format(from_path, to_path))
        with zipfile.ZipFile(from_path, "r") as zfile:
            files = []
            for file_ in tqdm(zfile.namelist()):
                file_path = os.path.join(to_path, file_)
                files.append(file_path)
                if os.path.exists(file_path):
                    logging.info("{} already extracted.".format(file_path))
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
        # Directory entries were collected too; keep regular files only.
        files = [f for f in files if os.path.isfile(f)]
        logging.info("Finished extracting zip file {}.".format(from_path))
        return files

    elif from_path.endswith(".gz"):
        logging.info("Opening gz file {} to {}.".format(from_path, to_path))
        default_block_size = 65536
        # Output path is the archive path with the ".gz" suffix removed.
        filename = from_path[:-3]
        files = [filename]
        with gzip.open(from_path, "rb") as gzfile, open(filename, "wb") as d_file:
            # Stream block by block until read() returns the empty sentinel.
            # Each block is written exactly once; the previous extra write
            # after the loop was redundant and has been removed.
            for block in iter(lambda: gzfile.read(default_block_size), b""):
                d_file.write(block)
        logging.info("Finished extracting gz file {}.".format(from_path))
        return files

    else:
        raise NotImplementedError(
            "We currently only support tar.gz, .tgz, .gz and zip achives."
        )
|
|
227 |
|
|
|
228 |
|
|
|
229 |
def save_frames_grid(img_array, out_path):
    """
    Tile RGB frames into a single grid image and save it to ``out_path``.

    Args:
        img_array: tensor of shape (c, h, w), (b, c, h, w) or (b, t, c, h, w),
            where c must be 3. A 3-D input is treated as a single frame and a
            5-D input has its first two dimensions flattened together.
        out_path: destination handed to ``PIL.Image.Image.save``.

    Raises:
        NotImplementedError: if ``img_array`` is not 3-, 4- or 5-dimensional.
        AssertionError: if the channel dimension is not 3.
    """
    import torch
    from PIL import Image
    from torchvision.utils import make_grid

    ndim = len(img_array.shape)
    if ndim == 3:
        img_array = img_array.unsqueeze(0)
    elif ndim == 5:
        c, h, w = img_array.shape[2:]
        # reshape instead of view: view raises on non-contiguous tensors
        # (e.g. after permute), while reshape copies only when it has to.
        img_array = img_array.reshape(-1, c, h, w)
    elif ndim != 4:
        raise NotImplementedError(
            "Supports only (b,t,c,h,w)-shaped inputs. First two dimensions can be ignored."
        )

    # The check is on dimension 1, i.e. a channels-first layout.
    assert img_array.shape[1] == 3, "Expecting input shape of (N, 3, H, W), i.e. RGB-only."

    grid = make_grid(img_array)
    # Channels-last for PIL. Values are cast to uint8 as-is, with no rescaling.
    # NOTE(review): presumably callers pass values already in 0-255 — confirm.
    ndarr = grid.permute(1, 2, 0).to("cpu", torch.uint8).numpy()

    img = Image.fromarray(ndarr)

    img.save(out_path)