Switch to side-by-side view

--- a
+++ b/BioSeqNet/resnest/utils.py
@@ -0,0 +1,131 @@
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+## Created by: Hang Zhang
+## Email: zhanghang0704@gmail.com
+## Copyright (c) 2020
+##
+## LICENSE file in the root directory of this source tree 
+##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+import os
+from pathlib import Path
+import requests
+import errno
+import shutil
+import hashlib
+import zipfile
+import logging
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+__all__ = ['unzip', 'download', 'mkdir', 'check_sha1', 'raise_num_file']
+
+def unzip(zip_file_path, root=os.path.expanduser('./')):
+    """Unzips files located at `zip_file_path` into parent directory specified by `root`.
+    """
+    folders = []
+    with zipfile.ZipFile(zip_file_path) as zf:
+        zf.extractall(root)
+        for name in zf.namelist():
+            folder = Path(name).parts[0]
+            if folder not in folders:
+                folders.append(folder)
+    folders = folders[0] if len(folders) == 1 else tuple(folders)
+    return folders
+
+def download(url, path=None, overwrite=False, sha1_hash=None):
+    """Download files from a given URL.
+
+    Parameters
+    ----------
+    url : str
+        URL where file is located
+    path : str, optional
+        Destination path to store downloaded file. By default stores to the
+        current directory with same name as in url.
+    overwrite : bool, optional
+        Whether to overwrite destination file if one already exists at this location.
+    sha1_hash : str, optional
+        Expected sha1 hash in hexadecimal digits (will ignore existing file when hash is specified
+        but doesn't match).
+
+    Returns
+    -------
+    str
+        The file path of the downloaded file.
+    """
+    if path is None:
+        fname = url.split('/')[-1]
+    else:
+        path = os.path.expanduser(path)
+        if os.path.isdir(path):
+            fname = os.path.join(path, url.split('/')[-1])
+        else:
+            fname = path
+
+    if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
+        dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
+        if not os.path.exists(dirname):
+            os.makedirs(dirname)
+
+        logger.info('Downloading %s from %s...'%(fname, url))
+        r = requests.get(url, stream=True)
+        if r.status_code != 200:
+            raise RuntimeError("Failed downloading url %s"%url)
+        total_length = r.headers.get('content-length')
+        with open(fname, 'wb') as f:
+            if total_length is None: # no content length header
+                for chunk in r.iter_content(chunk_size=1024):
+                    if chunk: # filter out keep-alive new chunks
+                        f.write(chunk)
+            else:
+                total_length = int(total_length)
+                for chunk in tqdm(r.iter_content(chunk_size=1024),
+                                  total=int(total_length / 1024. + 0.5),
+                                  unit='KB', unit_scale=False, dynamic_ncols=True):
+                    f.write(chunk)
+
+        if sha1_hash and not check_sha1(fname, sha1_hash):
+            raise UserWarning('File {} is downloaded but the content hash does not match. ' \
+                              'The repo may be outdated or download may be incomplete. ' \
+                              'If the "repo_url" is overridden, consider switching to ' \
+                              'the default repo.'.format(fname))
+
+    return fname
+
+
+def check_sha1(filename, sha1_hash):
+    """Check whether the sha1 hash of the file content matches the expected hash.
+
+    Parameters
+    ----------
+    filename : str
+        Path to the file.
+    sha1_hash : str
+        Expected sha1 hash in hexadecimal digits.
+
+    Returns
+    -------
+    bool
+        Whether the file content matches the expected hash.
+    """
+    sha1 = hashlib.sha1()
+    with open(filename, 'rb') as f:
+        while True:
+            data = f.read(1048576)
+            if not data:
+                break
+            sha1.update(data)
+
+    return sha1.hexdigest() == sha1_hash
+
+
+def mkdir(path):
+    """Make directory at the specified local path with special error handling.
+    """
+    try:
+        os.makedirs(path)
+    except OSError as exc:  # Python >2.5
+        if exc.errno == errno.EEXIST and os.path.isdir(path):
+            pass
+        else:
+            raise