a b/BioSeqNet/resnest/utils.py
1
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2
## Created by: Hang Zhang
3
## Email: zhanghang0704@gmail.com
4
## Copyright (c) 2020
5
##
6
## See the LICENSE file in the root directory of this source tree.
7
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8
import os
9
from pathlib import Path
10
import requests
11
import errno
12
import shutil
13
import hashlib
14
import zipfile
15
import logging
16
from tqdm import tqdm
17
18
# Module-level logger; download() reports each fetch through it at INFO level.
logger = logging.getLogger(__name__)
19
20
# Public names exported by ``from <this module> import *``.
# NOTE(review): 'raise_num_file' is not defined in this part of the file --
# presumably it is defined further down; verify it exists.
__all__ = ['unzip', 'download', 'mkdir', 'check_sha1', 'raise_num_file']
21
22
def unzip(zip_file_path, root=os.path.expanduser('./')):
    """Extract the archive at `zip_file_path` into the directory `root`.

    Parameters
    ----------
    zip_file_path : str
        Path to the ``.zip`` archive to extract.
    root : str, optional
        Destination directory; defaults to the current directory.

    Returns
    -------
    str or tuple of str
        The distinct top-level entries of the archive (the first path
        component of every member), in the order they first appear.
        A single ``str`` when there is exactly one such entry, otherwise a
        tuple (empty for an empty archive).
    """
    folders = []
    # A set makes the duplicate test O(1); the list alone keeps first-seen
    # order. Scanning the list for membership was O(n * k) and degenerated to
    # O(n^2) for archives whose members sit at the top level.
    seen = set()
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
        for name in zf.namelist():
            folder = Path(name).parts[0]
            if folder not in seen:
                seen.add(folder)
                folders.append(folder)
    return folders[0] if len(folders) == 1 else tuple(folders)
34
35
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download files from a given URL.

    Parameters
    ----------
    url : str
        URL where file is located
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if one already exists at this location.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits (will ignore existing file when hash is specified
        but doesn't match).

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server responds with a status code other than 200.
    UserWarning
        If `sha1_hash` is given and the freshly downloaded file does not
        match it. The mismatching file is still left at the returned path.
    """
    if path is None:
        fname = url.split('/')[-1]
    else:
        path = os.path.expanduser(path)
        if os.path.isdir(path):
            fname = os.path.join(path, url.split('/')[-1])
        else:
            fname = path

    # Reuse an existing file unless told to overwrite it or its hash is wrong.
    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids a race with another process creating the directory.
        os.makedirs(dirname, exist_ok=True)

        logger.info('Downloading %s from %s...', fname, url)
        # Stream into a sibling ".part" file and move it into place only once
        # the transfer has finished. Writing straight to `fname` left a
        # truncated file behind on any interruption, which a later call
        # without `sha1_hash` would silently accept as already downloaded.
        tmp_fname = fname + '.part'
        try:
            # Using the response as a context manager releases the streamed
            # connection on every exit path, including the error below.
            with requests.get(url, stream=True) as r:
                if r.status_code != 200:
                    raise RuntimeError("Failed downloading url %s" % url)
                chunks = r.iter_content(chunk_size=1024)
                total_length = r.headers.get('content-length')
                if total_length is not None:
                    # The bar counts 1 KB chunks; round up so the final
                    # partial chunk is included in the total.
                    chunks = tqdm(chunks,
                                  total=(int(total_length) + 1023) // 1024,
                                  unit='KB', unit_scale=False, dynamic_ncols=True)
                with open(tmp_fname, 'wb') as f:
                    for chunk in chunks:
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
            os.replace(tmp_fname, fname)
        finally:
            # After a successful os.replace the temp file is gone, so this
            # only fires when the download failed part-way.
            if os.path.exists(tmp_fname):
                try:
                    os.remove(tmp_fname)
                except OSError:
                    pass  # best-effort cleanup; never mask the original error

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(fname))

    return fname
94
95
96
def check_sha1(filename, sha1_hash):
    """Tell whether a file's SHA-1 digest equals an expected value.

    Parameters
    ----------
    filename : str
        File to hash. It is read in binary mode, one 1 MiB piece at a time.
    sha1_hash : str
        Expected digest as a hexadecimal string. It is compared verbatim
        against ``hashlib``'s (lowercase) ``hexdigest()``.

    Returns
    -------
    bool
        ``True`` on an exact match, ``False`` otherwise.
    """
    block_size = 1024 * 1024  # bounded reads keep memory flat for large files
    digest = hashlib.sha1()
    with open(filename, 'rb') as stream:
        # iter() with a b'' sentinel stops at end of file.
        for block in iter(lambda: stream.read(block_size), b''):
            digest.update(block)
    actual = digest.hexdigest()
    return actual == sha1_hash
120
121
122
def mkdir(path):
    """Create the directory `path`, including any missing parent directories.

    Does nothing if `path` already exists as a directory. Raises ``OSError``
    when the directory cannot be created, e.g. ``FileExistsError`` if `path`
    already exists as a regular file, or ``PermissionError``.
    """
    # exist_ok=True is the stdlib form of the old "catch EEXIST and check
    # isdir" idiom: makedirs still raises if `path` exists as a non-directory.
    os.makedirs(path, exist_ok=True)