model/lavis/common/utils.py
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import io
import json
import logging
import os
import pickle
import re
import shutil
import urllib
import urllib.error
import urllib.request
from typing import Optional
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import yaml
from iopath.common.download import download
from iopath.common.file_io import file_lock, g_pathmgr
from model.lavis.common.registry import registry
from torch.utils.model_zoo import tqdm
from torchvision.datasets.utils import (
    check_integrity,
    download_file_from_google_drive,
    extract_archive,
)


def now():
    from datetime import datetime

    # "%Y%m%d%H%M" with the final character sliced off, i.e. a timestamp at
    # ten-minute resolution (2022-01-31 14:52 -> "20220131145").
    return datetime.now().strftime("%Y%m%d%H%M")[:-1]


def is_url(input_url):
    """
    Check whether an input string is a URL, looking for http(s):// and
    ignoring case.
    """
    return re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None


def get_cache_path(rel_path):
    return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))


def get_abs_path(rel_path):
    return os.path.join(registry.get_path("library_root"), rel_path)


def load_json(filename):
    with open(filename, "r") as f:
        return json.load(f)


# The following are adapted from torchvision and vissl
# torchvision: https://github.com/pytorch/vision
# vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py


def makedir(dir_path):
    """
    Create the directory if it does not exist.
    """
    is_success = False
    try:
        if not g_pathmgr.exists(dir_path):
            g_pathmgr.mkdirs(dir_path)
        is_success = True
    except BaseException:
        logging.info(f"Error creating directory: {dir_path}")
    return is_success


def get_redirected_url(url: str):
    """
    Given a URL, return the URL it redirects to, or the
    original URL if there is no redirection.
    """
    import requests

    with requests.Session() as session:
        with session.get(url, stream=True, allow_redirects=True) as response:
            if response.history:
                return response.url
            else:
                return url
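

# Usage sketch (not part of the original module): the URL below is
# hypothetical and resolving it requires network access.
def _example_get_redirected_url():
    final_url = get_redirected_url("https://example.com/short-link")
    print(final_url)  # the resolved URL, or the input if nothing redirected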


def to_google_drive_download_url(view_url: str) -> str:
    """
    Utility function to transform a view URL of Google Drive
    into a download URL for Google Drive.
    Example input:
        https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
    Example output:
        https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
    """
    splits = view_url.split("/")
    assert splits[-1] == "view"
    file_id = splits[-2]
    return f"https://drive.google.com/uc?export=download&id={file_id}"


def download_google_drive_url(url: str, output_path: str, output_file_name: str):
    """
    Download a file from Google Drive.
    Downloading a URL from Google Drive requires confirmation when
    the file is too large (Google Drive notifies that antivirus
    checks cannot be performed on such files).
    """
    import requests

    with requests.Session() as session:

        # First get the confirmation token and append it to the URL
        with session.get(url, stream=True, allow_redirects=True) as response:
            for k, v in response.cookies.items():
                if k.startswith("download_warning"):
                    url = url + "&confirm=" + v

        # Then download the content of the file
        with session.get(url, stream=True, verify=True) as response:
            makedir(output_path)
            path = os.path.join(output_path, output_file_name)
            total_size = int(response.headers.get("Content-length", 0))
            with open(path, "wb") as file:
                from tqdm import tqdm

                with tqdm(total=total_size) as progress_bar:
                    for block in response.iter_content(
                        chunk_size=io.DEFAULT_BUFFER_SIZE
                    ):
                        file.write(block)
                        progress_bar.update(len(block))
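

# Usage sketch (hypothetical output path and file name; the Drive id is the
# one from the docstring example; requires network access):
def _example_download_google_drive_url():
    view_url = "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view"
    download_google_drive_url(
        to_google_drive_download_url(view_url),
        output_path="/tmp/lavis_downloads",  # hypothetical
        output_file_name="weights.pth",  # hypothetical
    )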


def _get_google_drive_file_id(url: str) -> Optional[str]:
    parts = urlparse(url)

    if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
        return None

    match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
    if match is None:
        return None

    return match.group("id")
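

# Usage sketch: extracts the file id from a Drive view URL, and returns None
# for non-Drive hosts. Pure string parsing, no network access needed.
def _example_get_google_drive_file_id():
    url = "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view"
    assert _get_google_drive_file_id(url) == "137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp"
    assert _get_google_drive_file_id("https://example.com/file") is None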


def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
    with open(filename, "wb") as fh:
        with urllib.request.urlopen(
            urllib.request.Request(url, headers={"User-Agent": "vissl"})
        ) as response:
            with tqdm(total=response.length) as pbar:
                # response.read() returns bytes, so the iter() sentinel must
                # be b"" for the loop to terminate on EOF.
                for chunk in iter(lambda: response.read(chunk_size), b""):
                    if not chunk:
                        break
                    # update by the actual bytes read; the final chunk may be
                    # shorter than chunk_size
                    pbar.update(len(chunk))
                    fh.write(chunk)


def download_url(
    url: str,
    root: str,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
) -> None:
    """Download a file from a url and place it in root.
    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under.
                                  If None, use the basename of the URL.
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir(root)

    # check if file is already present locally
    if check_integrity(fpath, md5):
        print("Using downloaded and verified file: " + fpath)
        return

    # expand redirect chain if needed
    url = get_redirected_url(url)

    # check if file is located on Google Drive
    file_id = _get_google_drive_file_id(url)
    if file_id is not None:
        return download_file_from_google_drive(file_id, root, filename, md5)

    # download the file
    try:
        print("Downloading " + url + " to " + fpath)
        _urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
        if url[:5] == "https":
            url = url.replace("https:", "http:")
            print(
                "Failed download. Trying https -> http instead."
                " Downloading " + url + " to " + fpath
            )
            _urlretrieve(url, fpath)
        else:
            raise e

    # check integrity of downloaded file
    if not check_integrity(fpath, md5):
        raise RuntimeError("File not found or corrupted.")
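

# Usage sketch (hypothetical URL and paths; md5 is omitted so no checksum is
# verified; requires network access):
def _example_download_url():
    download_url(
        "https://example.com/data/annotations.json",  # hypothetical
        root="/tmp/lavis_downloads",
        filename="annotations.json",
    )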


def download_and_extract_archive(
    url: str,
    download_root: str,
    extract_root: Optional[str] = None,
    filename: Optional[str] = None,
    md5: Optional[str] = None,
    remove_finished: bool = False,
) -> None:
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    extract_archive(archive, extract_root, remove_finished)
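

# Usage sketch (hypothetical archive URL and paths; remove_finished deletes
# the downloaded archive once extraction succeeds):
def _example_download_and_extract_archive():
    download_and_extract_archive(
        "https://example.com/data/coco_ann.zip",  # hypothetical
        download_root="/tmp/lavis_downloads",
        extract_root="/tmp/lavis_data",
        remove_finished=True,
    )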


def cache_url(url: str, cache_dir: str) -> str:
    """
    This implementation downloads the remote resource and caches it locally.
    The resource will only be downloaded if not previously requested.
    """
    parsed_url = urlparse(url)
    dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
    makedir(dirname)
    filename = url.split("/")[-1]
    cached = os.path.join(dirname, filename)
    with file_lock(cached):
        if not os.path.isfile(cached):
            logging.info(f"Downloading {url} to {cached} ...")
            cached = download(url, dirname, filename=filename)
    logging.info(f"URL {url} cached in {cached}")
    return cached
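

# Usage sketch (hypothetical URL and cache dir): a second call with the same
# URL finds the cached file and skips the download.
def _example_cache_url():
    path = cache_url("https://example.com/assets/vocab.txt", "/tmp/lavis_cache")
    print(path)  # local path; repeated calls reuse this copy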


# TODO (prigoyal): convert this into RAII-style API
def create_file_symlink(file1, file2):
    """
    Simply create a symlink from file1 to file2.
    Useful during model checkpointing to symlink to the
    latest successful checkpoint.
    """
    try:
        if g_pathmgr.exists(file2):
            g_pathmgr.rm(file2)
        g_pathmgr.symlink(file1, file2)
    except Exception as e:
        logging.info(f"Could NOT create symlink. Error: {e}")
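

# Usage sketch (hypothetical checkpoint paths): repoint a stable name at the
# newest checkpoint after each save.
def _example_create_file_symlink():
    create_file_symlink(
        "/tmp/checkpoints/checkpoint_0010.pth",  # hypothetical
        "/tmp/checkpoints/checkpoint_latest.pth",
    )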


def save_file(data, filename, append_to_json=True, verbose=True):
    """
    Common i/o utility to handle saving data to various file formats.
    Supported:
        .pkl, .pickle, .npy, .json, .yaml
    Specifically for .json, users have the option to either append (default)
    or rewrite by passing a Boolean value to append_to_json.
    """
    if verbose:
        logging.info(f"Saving data to file: {filename}")
    file_ext = os.path.splitext(filename)[1]
    if file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "wb") as fopen:
            pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
    elif file_ext == ".npy":
        with g_pathmgr.open(filename, "wb") as fopen:
            np.save(fopen, data)
    elif file_ext == ".json":
        if append_to_json:
            with g_pathmgr.open(filename, "a") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
        else:
            with g_pathmgr.open(filename, "w") as fopen:
                fopen.write(json.dumps(data, sort_keys=True) + "\n")
                fopen.flush()
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "w") as fopen:
            dump = yaml.dump(data)
            fopen.write(dump)
            fopen.flush()
    else:
        raise Exception(f"Saving {file_ext} is not supported yet")

    if verbose:
        logging.info(f"Saved data to file: {filename}")
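

# Usage sketch (hypothetical payload and path): append one JSON record per
# line (the default), then overwrite the file instead.
def _example_save_file():
    record = {"epoch": 1, "loss": 0.25}  # hypothetical payload
    save_file(record, "/tmp/metrics.json")  # appends a line
    save_file(record, "/tmp/metrics.json", append_to_json=False)  # rewrites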


def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
    """
    Common i/o utility to handle loading data from various file formats.
    Supported:
        .txt, .pkl, .pickle, .npy, .json, .yaml, .csv
    For the npy files, we support reading the files in mmap_mode.
    If the mmap_mode of reading is not successful, we load data without the
    mmap_mode.
    """
    if verbose:
        logging.info(f"Loading data from file: {filename}")

    file_ext = os.path.splitext(filename)[1]
    if file_ext == ".txt":
        with g_pathmgr.open(filename, "r") as fopen:
            data = fopen.readlines()
    elif file_ext in [".pkl", ".pickle"]:
        with g_pathmgr.open(filename, "rb") as fopen:
            data = pickle.load(fopen, encoding="latin1")
    elif file_ext == ".npy":
        if mmap_mode:
            try:
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(
                        fopen,
                        allow_pickle=allow_pickle,
                        encoding="latin1",
                        mmap_mode=mmap_mode,
                    )
            except ValueError as e:
                logging.info(
                    f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
                )
                data = np.load(
                    filename,
                    allow_pickle=allow_pickle,
                    encoding="latin1",
                    mmap_mode=mmap_mode,
                )
                logging.info("Successfully loaded without g_pathmgr")
            except Exception:
                logging.info("Could not mmap without g_pathmgr. Trying without mmap")
                with g_pathmgr.open(filename, "rb") as fopen:
                    data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
        else:
            with g_pathmgr.open(filename, "rb") as fopen:
                data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
    elif file_ext == ".json":
        with g_pathmgr.open(filename, "r") as fopen:
            data = json.load(fopen)
    elif file_ext == ".yaml":
        with g_pathmgr.open(filename, "r") as fopen:
            data = yaml.load(fopen, Loader=yaml.FullLoader)
    elif file_ext == ".csv":
        with g_pathmgr.open(filename, "r") as fopen:
            data = pd.read_csv(fopen)
    else:
        raise Exception(f"Reading from {file_ext} is not supported yet")
    return data
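

# Usage sketch (hypothetical path): round-trip a dict through YAML with
# save_file/load_file.
def _example_load_file():
    save_file({"lr": 1e-4}, "/tmp/config.yaml")  # hypothetical config
    config = load_file("/tmp/config.yaml")
    assert config == {"lr": 1e-4}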


def abspath(resource_path: str):
    """
    Make a path absolute, but take into account prefixes like
    "http://" or "manifold://"
    """
    regex = re.compile(r"^\w+://")
    if regex.match(resource_path) is None:
        return os.path.abspath(resource_path)
    else:
        return resource_path
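

# Usage sketch: plain paths are made absolute against the current working
# directory, while URI-style paths pass through unchanged.
def _example_abspath():
    print(abspath("configs/default.yaml"))  # resolved against the CWD
    print(abspath("http://example.com/x"))  # returned unchanged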


def cleanup_dir(dir):
    """
    Utility for deleting a directory. Useful for cleaning the storage space
    that contains various training artifacts like checkpoints, data etc.
    """
    if os.path.exists(dir):
        logging.info(f"Deleting directory: {dir}")
        shutil.rmtree(dir)
        logging.info(f"Deleted contents of directory: {dir}")


def get_file_size(filename):
    """
    Given a file, return its size in MB.
    """
    size_in_mb = os.path.getsize(filename) / float(1024**2)
    return size_in_mb
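

# Usage sketch (hypothetical paths): report a checkpoint's size, then clean
# up its directory.
def _example_disk_utils():
    size_mb = get_file_size("/tmp/checkpoints/checkpoint_latest.pth")  # hypothetical
    print(f"{size_mb:.1f} MB")
    cleanup_dir("/tmp/checkpoints")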