|
a |
|
b/BioSeqNet/resnest/utils.py |
|
|
1 |
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
|
2 |
## Created by: Hang Zhang |
|
|
3 |
## Email: zhanghang0704@gmail.com |
|
|
4 |
## Copyright (c) 2020 |
|
|
5 |
## |
|
|
6 |
## LICENSE file in the root directory of this source tree |
|
|
7 |
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|
|
8 |
import os |
|
|
9 |
from pathlib import Path |
|
|
10 |
import requests |
|
|
11 |
import errno |
|
|
12 |
import shutil |
|
|
13 |
import hashlib |
|
|
14 |
import zipfile |
|
|
15 |
import logging |
|
|
16 |
from tqdm import tqdm |
|
|
17 |
|
|
|
18 |
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__) |
|
|
19 |
|
|
|
20 |
# Names exported by a wildcard import of this module.
# NOTE(review): 'raise_num_file' is listed here but no definition is visible
# in this view of the file -- confirm it is defined elsewhere in the module.
__all__ = ['unzip', 'download', 'mkdir', 'check_sha1', 'raise_num_file'] |
|
|
21 |
|
|
|
22 |
def unzip(zip_file_path, root=os.path.expanduser('./')):
    """Extract the archive at `zip_file_path` into the directory `root`.

    Returns the distinct leading path components of the archive members, in
    order of first appearance: a single string when there is exactly one,
    otherwise a tuple (which is empty for an archive without members).
    """
    with zipfile.ZipFile(zip_file_path) as archive:
        archive.extractall(root)
        member_names = archive.namelist()

    # Collect each member's first path component once, keeping archive order.
    top_level = []
    for member in member_names:
        head = Path(member).parts[0]
        if head in top_level:
            continue
        top_level.append(head)

    if len(top_level) == 1:
        return top_level[0]
    return tuple(top_level)
|
|
34 |
|
|
|
35 |
def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download files from a given URL.

    Parameters
    ----------
    url : str
        URL where file is located
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url. If `path` is an existing
        directory, the file is stored inside it under the url's last segment.
    overwrite : bool, optional
        Whether to overwrite destination file if one already exists at this location.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits (will ignore existing file when hash is specified
        but doesn't match).

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    UserWarning
        If `sha1_hash` is given and the downloaded content does not match it.
    """
    url_basename = url.split('/')[-1]
    if path is None:
        fname = url_basename
    else:
        path = os.path.expanduser(path)
        fname = os.path.join(path, url_basename) if os.path.isdir(path) else path

    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
        # exist_ok avoids the check-then-create race when callers run in parallel.
        os.makedirs(dirname, exist_ok=True)

        logger.info('Downloading %s from %s...', fname, url)
        # Stream into a sibling ".part" file so an interrupted transfer never
        # leaves a truncated file at `fname` that a later call (without
        # `sha1_hash`) would mistake for a complete download.
        part_fname = fname + '.part'
        r = requests.get(url, stream=True)
        try:
            if r.status_code != 200:
                raise RuntimeError("Failed downloading url %s" % url)

            chunks = r.iter_content(chunk_size=1024)
            total_length = r.headers.get('content-length')
            if total_length is not None:
                # Content length is known: show progress in 1 KB steps.
                chunks = tqdm(chunks,
                              total=int(int(total_length) / 1024. + 0.5),
                              unit='KB', unit_scale=False, dynamic_ncols=True)

            try:
                with open(part_fname, 'wb') as f:
                    for chunk in chunks:
                        if chunk:  # filter out keep-alive new chunks
                            f.write(chunk)
            except BaseException:
                # Also covers KeyboardInterrupt: drop the partial file, then re-raise.
                if os.path.exists(part_fname):
                    os.remove(part_fname)
                raise
            os.replace(part_fname, fname)
        finally:
            # A streamed response holds its connection until it is closed.
            r.close()

        if sha1_hash and not check_sha1(fname, sha1_hash):
            raise UserWarning('File {} is downloaded but the content hash does not match. '
                              'The repo may be outdated or download may be incomplete. '
                              'If the "repo_url" is overridden, consider switching to '
                              'the default repo.'.format(fname))

    return fname
|
|
94 |
|
|
|
95 |
|
|
|
96 |
def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits.

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    digest = hashlib.sha1()
    with open(filename, 'rb') as stream:
        # Read in 1 MiB (1048576-byte) pieces until read() returns b''.
        for block in iter(lambda: stream.read(1048576), b''):
            digest.update(block)

    return digest.hexdigest() == sha1_hash
|
|
120 |
|
|
|
121 |
|
|
|
122 |
def mkdir(path):
    """Create the directory `path`, including any missing parent directories.

    An `OSError` with errno `EEXIST` is ignored when `path` is already a
    directory; any other `OSError` is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_directory = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_directory:
            raise