Diff of /openomics/io/files.py [000000] .. [548210]


--- a
+++ b/openomics/io/files.py
@@ -0,0 +1,196 @@
+import gzip
+import os
+import zipfile
+from os.path import exists, getsize
+from typing import Tuple, Union, TextIO, Optional, Dict, List
+from urllib.error import URLError
+from logzero import logger
+
+import dask.dataframe as dd
+import filetype
+import rarfile
+import requests
+import sqlalchemy as sa
+import validators
+from astropy.utils import data
+from requests.adapters import HTTPAdapter, Retry
+
+import openomics
+
+
+# @astropy.config.set_temp_cache(openomics.config["cache_dir"])
+def get_pkg_data_filename(baseurl: str, filepath: str):
+    """Downloads a remote file given the url, then caches it to the user's home
+    folder.
+
+    Args:
+        baseurl: Url to the download path, excluding the file name
+        filepath: The file path to download
+
+    Returns:
+        filename (str): A file path on the local file system corresponding to
+        the data requested in `filepath`.
+    """
+    # Split data url and file name if the user provided a whole path in `file_resources`
+    if validators.url(filepath):
+        base, filepath = os.path.split(filepath)
+        base = base + "/"
+    else:
+        base = baseurl
+
+    try:
+        # TODO doesn't yet save files to 'cache_dir' but to astropy's default cache dir
+        # logger.debug(f"Fetching file from: {base}{filepath}, saving to {openomics.config['cache_dir']}")
+
+        with data.conf.set_temp("dataurl", base), data.conf.set_temp("remote_timeout", 30):
+            return data.get_pkg_data_filename(filepath, package="openomics.database", show_progress=True)
+
+    except (URLError, ValueError) as e:
+        raise Exception(f"Unable to download file at {os.path.join(base, filepath)}. "
+                        f"Please try manually downloading the files and add path to `file_resources` arg. \n{e}")
+
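+# Example usage (an illustrative sketch; the URL and file name below are
+# hypothetical placeholders, not resources this module provides):
+#
+#   path = get_pkg_data_filename("https://example.org/annotations/", "genes.gff3.gz")
+#   # `path` points to a locally cached copy of the downloaded file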
+
+def decompress_file(filepath: str, filename: str, file_ext: filetype.Type, write_uncompressed=False) \
+    -> Tuple[Union[gzip.GzipFile, TextIO, str], str]:
+    """
+    Decompress the `filepath` corresponding to its `file_ext` compression type, then return the uncompressed data (or its path) and
+    the `filename` without the `file_ext` suffix.
+
+    Args:
+        filepath (str): The file path to the data file
+        filename (str): The filename of the data file
+        file_ext (filetype.Type): The file extension of the data file
+        write_uncompressed (bool): Whether to write the uncompressed file to disk
+
+    Returns:
+        uncompressed_file (): The uncompressed file path
+        updated_filename (str): The filename without the `file_ext` suffix
+    """
+    data = filepath
+
+    if file_ext is None:
+        return data, filename
+
+    elif file_ext.extension == "gz":
+        data = gzip.open(filepath, "rt")
+
+    elif file_ext.extension == "zip":
+        with zipfile.ZipFile(filepath, "r") as zf:
+            for subfile in zf.infolist():
+                # Select first file with matching file extension
+                subfile_name = os.path.splitext(subfile.filename)[-1]
+                if subfile_name == os.path.splitext(filename.replace(".zip", ""))[-1]:
+                    data = zf.open(subfile.filename, mode="r")
+
+    elif file_ext.extension == "rar":
+        with rarfile.RarFile(filepath, "r") as rf:
+
+            for subfile in rf.infolist():
+                # If the file extension matches
+                subfile_name = os.path.splitext(subfile.filename)[-1]
+                if subfile_name == os.path.splitext(filename.replace(".rar", ""))[-1]:
+                    data = rf.open(subfile.filename, mode="r")
+
+    else:
+        logger.warning(f"file_ext.extension {file_ext.extension} not supported.")
+        return data, filename
+
+    filename = get_uncompressed_filepath(filename)
+    uncompressed_path = get_uncompressed_filepath(filepath)
+
+    if write_uncompressed and hasattr(data, "read") and not exists(uncompressed_path):
+        with open(uncompressed_path, 'w', encoding='utf8') as f_out:
+            logger.info(f"Writing uncompressed {filename} file to {uncompressed_path}")
+            content = data.read()
+            # Zip/rar members are opened in binary mode; decode bytes before writing text
+            f_out.write(content.decode('utf8') if isinstance(content, bytes) else content)
+
+    if exists(uncompressed_path) and getsize(uncompressed_path) > 0:
+        data = uncompressed_path
+
+    return data, filename
+
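+# Example usage (an illustrative sketch; "data/genes.gtf.gz" is a hypothetical path):
+#
+#   ext = filetype.guess("data/genes.gtf.gz")  # detect compression from magic bytes
+#   handle, name = decompress_file("data/genes.gtf.gz", "genes.gtf.gz", ext)
+#   # `handle` is a text-mode file object (or a local path, if an uncompressed
+#   # copy already exists on disk) and `name` == "genes.gtf"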
+
+def get_uncompressed_filepath(filepath: str) -> str:
+    """Return the uncompressed filepath by removing the file extension suffix.
+
+    Args:
+        filepath (str): File path to the compressed file
+
+    Returns:
+        uncompressed_path (str): File path to the uncompressed file
+    """
+    uncompressed_path = ''
+    if filepath.endswith(".gz"):
+        uncompressed_path = filepath.removesuffix(".gz")
+    elif filepath.endswith(".zip"):
+        uncompressed_path = filepath.removesuffix(".zip")
+    elif filepath.endswith(".rar"):
+        uncompressed_path = filepath.removesuffix(".rar")
+    else:
+        uncompressed_path = filepath + ".uncompressed"
+
+    if uncompressed_path and filepath != uncompressed_path:
+        return uncompressed_path
+    else:
+        return ''
+
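+# Example behavior (illustrative paths):
+#
+#   get_uncompressed_filepath("data/genes.gtf.gz")  # -> "data/genes.gtf"
+#   get_uncompressed_filepath("data/genes.gtf")     # -> "data/genes.gtf.uncompressed"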
+
+def select_files_with_ext(file_resources: Dict[str, str], ext: str, contains: Optional[str] = None) -> Dict[str, str]:
+    """Return a list of file paths with the specified file extension. Only string values are considered as file paths.
+
+    Args:
+        file_resources (dict): A dictionary of file names and their corresponding file paths
+        ext (str): The file extension to filter the file names by
+        contains (str): If not None, only return file paths that contain the specified string
+
+    Returns:
+        file_paths (dict): A dict of file names and corresponding paths with the specified file extension
+    """
+    subset_file_resources = {}
+    for filename, filepath in file_resources.items():
+        if not isinstance(filepath, str):
+            continue
+        if filename.endswith(ext) and (contains is None or contains in filename):
+            subset_file_resources[filename] = filepath
+
+    return subset_file_resources
+
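+# Example usage (illustrative; the file names below are hypothetical):
+#
+#   resources = {"genes.gtf.gz": "/cache/genes.gtf.gz",
+#                "proteins.fa.gz": "/cache/proteins.fa.gz",
+#                "genes.parquet": None}
+#   select_files_with_ext(resources, ext=".gtf.gz")
+#   # -> {"genes.gtf.gz": "/cache/genes.gtf.gz"}; non-string values are skipped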
+
+def read_db(path, table, index_col):
+    """Read a SQL table into a Dask DataFrame.
+
+    Args:
+        path (str): A SQLAlchemy database URI, e.g. "sqlite:///path/to/data.db"
+        table (str): Name of the table to read
+        index_col (str): Name of the column to use as the DataFrame index
+    """
+    engine = sa.create_engine(path)
+    m = sa.MetaData()
+    # Reflect the table up front to fail early if it does not exist
+    sa.Table(table, m, autoload_with=engine)
+
+    daskDF = dd.read_sql_table(table, path, index_col=index_col, parse_dates={'datetime': '%Y-%m-%d %H:%M:%S'})
+    return daskDF
+
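+# Example usage (a sketch assuming a local SQLite file "data.db" containing a
+# table "expression" with an "id" column; all names are hypothetical):
+#
+#   ddf = read_db("sqlite:///data.db", table="expression", index_col="id")
+#   ddf.head()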
+
+def retry(num=5):
+    """retry connection.
+
+    define max tries num if the backoff_factor is 0.1, then sleep() will
+    sleep for [0.1s, 0.2s, 0.4s, ...] between retries. It will also force a
+    retry if the status code returned is 500, 502, 503 or 504.
+
+    Args:
+        num:
+    """
+    s = requests.Session()
+    retries = Retry(total=num, backoff_factor=0.1,
+                    status_forcelist=[500, 502, 503, 504])
+    s.mount('http://', HTTPAdapter(max_retries=retries))
+    s.mount('https://', HTTPAdapter(max_retries=retries))
+
+    return s
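+
+# Example usage (illustrative; the URL is a placeholder):
+#
+#   session = retry(num=3)
+#   resp = session.get("https://example.org/annotations.gtf.gz", timeout=30)
+#   resp.raise_for_status()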