##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## Email: zhanghang0704@gmail.com
## Copyright (c) 2020
##
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""File helpers: archive extraction, HTTP download, SHA-1 check, mkdir."""
import os
from pathlib import Path
import errno
import shutil
import hashlib
import zipfile
import logging

# NOTE: `requests` and `tqdm` are imported inside `download`, the only
# function that needs them, so the purely local helpers in this module stay
# importable in environments without the networking dependencies.

logger = logging.getLogger(__name__)

# 'raise_num_file' used to be listed here although no such function is
# defined in this module, which made `from <module> import *` fail with
# AttributeError.
__all__ = ['unzip', 'download', 'mkdir', 'check_sha1']


def unzip(zip_file_path, root=os.path.expanduser('./')):
    """Extract the archive at `zip_file_path` into the directory `root`.

    Parameters
    ----------
    zip_file_path : str
        Path of the zip archive to extract.
    root : str, optional
        Directory to extract into. A leading ``~`` is expanded to the user's
        home directory. Defaults to the current directory.

    Returns
    -------
    str or tuple of str
        The top-level entry names of the archive, in archive order. A single
        name is returned as a plain string; zero or several names are
        returned as a tuple.
    """
    # Expand at call time: the expanduser() in the signature only ever sees
    # the default value, not what the caller passes.
    root = os.path.expanduser(root)
    with zipfile.ZipFile(zip_file_path) as zf:
        zf.extractall(root)
        member_names = zf.namelist()

    seen = set()
    top_level = []
    for member in member_names:
        first = Path(member).parts[0]
        if first not in seen:
            seen.add(first)
            top_level.append(first)
    return top_level[0] if len(top_level) == 1 else tuple(top_level)


def download(url, path=None, overwrite=False, sha1_hash=None):
    """Download files from a given URL.

    Parameters
    ----------
    url : str
        URL where file is located
    path : str, optional
        Destination path to store downloaded file. By default stores to the
        current directory with same name as in url.
    overwrite : bool, optional
        Whether to overwrite destination file if one already exists at this location.
    sha1_hash : str, optional
        Expected sha1 hash in hexadecimal digits (will ignore existing file when hash is specified
        but doesn't match).

    Returns
    -------
    str
        The file path of the downloaded file.

    Raises
    ------
    RuntimeError
        If the server answers with a status code other than 200.
    UserWarning
        If `sha1_hash` is given and the downloaded content does not match it.
    """
    url_basename = url.split('/')[-1]
    if path is None:
        fname = url_basename
    else:
        path = os.path.expanduser(path)
        fname = os.path.join(path, url_basename) if os.path.isdir(path) else path

    # Reuse an existing file unless the caller forces a refresh or its
    # content does not match the expected hash.
    if (not overwrite and os.path.exists(fname)
            and (not sha1_hash or check_sha1(fname, sha1_hash))):
        return fname

    # Third-party dependencies, needed only when something is fetched.
    import requests
    from tqdm import tqdm

    mkdir(os.path.dirname(os.path.abspath(os.path.expanduser(fname))))

    logger.info('Downloading %s from %s...', fname, url)
    # Stream into a sibling temporary file and move it into place only once
    # the transfer has finished, so an interrupted download never leaves a
    # truncated file that a later call would mistake for a complete one.
    partial_name = fname + '.part'
    r = requests.get(url, stream=True)
    try:
        if r.status_code != 200:
            raise RuntimeError("Failed downloading url %s" % url)
        chunks = r.iter_content(chunk_size=1024)
        total_length = r.headers.get('content-length')
        if total_length is not None:
            # The size is known: show a progress bar counted in KB.
            chunks = tqdm(chunks,
                          total=int(int(total_length) / 1024. + 0.5),
                          unit='KB', unit_scale=False, dynamic_ncols=True)
        with open(partial_name, 'wb') as f:
            for chunk in chunks:
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
        os.replace(partial_name, fname)
    finally:
        r.close()
        if os.path.exists(partial_name):
            os.remove(partial_name)

    if sha1_hash and not check_sha1(fname, sha1_hash):
        raise UserWarning('File {} is downloaded but the content hash does not match. '
                          'The repo may be outdated or download may be incomplete. '
                          'If the "repo_url" is overridden, consider switching to '
                          'the default repo.'.format(fname))

    return fname


def check_sha1(filename, sha1_hash):
    """Check whether the sha1 hash of the file content matches the expected hash.

    Parameters
    ----------
    filename : str
        Path to the file.
    sha1_hash : str
        Expected sha1 hash in hexadecimal digits (compared case-insensitively).

    Returns
    -------
    bool
        Whether the file content matches the expected hash.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Hash in 1 MiB pieces so large files are never loaded whole.
        for block in iter(lambda: f.read(1048576), b''):
            sha1.update(block)

    # hexdigest() is lower-case; hex digits themselves are case-insensitive.
    expected = sha1_hash.lower() if isinstance(sha1_hash, str) else sha1_hash
    return sha1.hexdigest() == expected


def mkdir(path):
    """Create the directory `path`, including any missing parents.

    An already existing directory is not an error. An OSError is still
    raised when `path` exists but is not a directory, or when creation
    fails for any other reason.
    """
    os.makedirs(path, exist_ok=True)