|
a |
|
b/openomics/io/files.py |
|
|
1 |
import gzip |
|
|
2 |
import os |
|
|
3 |
import zipfile |
|
|
4 |
from os.path import exists, getsize |
|
|
5 |
from typing import Tuple, Union, TextIO, Optional, Dict, List |
|
|
6 |
from urllib.error import URLError |
|
|
7 |
from logzero import logger |
|
|
8 |
|
|
|
9 |
import dask.dataframe as dd |
|
|
10 |
import filetype |
|
|
11 |
import rarfile |
|
|
12 |
import requests |
|
|
13 |
import sqlalchemy as sa |
|
|
14 |
import validators |
|
|
15 |
from astropy.utils import data |
|
|
16 |
from requests.adapters import HTTPAdapter, Retry |
|
|
17 |
|
|
|
18 |
import openomics |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
# @astropy.config.set_temp_cache(openomics.config["cache_dir"])
def get_pkg_data_filename(baseurl: str, filepath: str):
    """Download a remote file given its url, caching it to the user's home
    folder.

    Args:
        baseurl: Url to the download path, excluding the file name
        filepath: The file path to download

    Returns:
        filename (str): A file path on the local file system corresponding to
            the data requested in data_name.
    """
    # If the caller provided a full URL in `file_resources`, split it into a
    # base url and a file name; otherwise use the supplied `baseurl` as-is.
    if validators.url(filepath):
        url_dir, filepath = os.path.split(filepath)
        base = url_dir + "/"
    else:
        base = baseurl

    try:
        # TODO doesn't yet save files to 'cache_dir' but to astropy's default cache dir
        # logger.debug(f"Fetching file from: {base}{filepath}, saving to {openomics.config['cache_dir']}")
        with data.conf.set_temp("dataurl", base), data.conf.set_temp("remote_timeout", 30):
            return data.get_pkg_data_filename(filepath, package="openomics.database", show_progress=True)

    except (URLError, ValueError) as e:
        raise Exception(f"Unable to download file at {os.path.join(base, filepath)}. "
                        f"Please try manually downloading the files and add path to `file_resources` arg. \n{e}")
|
|
51 |
|
|
|
52 |
|
|
|
53 |
def decompress_file(filepath: str, filename: str, file_ext: Optional['filetype.Type'], write_uncompressed=False) \
    -> Tuple[Union[gzip.GzipFile, TextIO], str]:
    """
    Decompress the `filepath` corresponding to its `file_ext` compression type, then return the uncompressed data (or its path) and
    the `filename` without the `file_ext` suffix.

    Args:
        filepath (str): The file path to the data file
        filename (str): The filename of the data file
        file_ext (filetype.Type): The detected compression type of the data file, or None if uncompressed
        write_uncompressed (bool): Whether to write the uncompressed file to disk

    Returns:
        uncompressed_file (): The uncompressed file handle, or a file path if an uncompressed copy exists on disk
        updated_filename (str): The filename without the `file_ext` suffix
    """
    data = filepath

    if file_ext is None:
        # Not compressed: pass the original path straight through.
        return data, filename

    elif file_ext.extension == "gz":
        data = gzip.open(filepath, "rt")

    elif file_ext.extension == "zip":
        # BUGFIX: the archive must stay open (no `with` block) — closing the
        # ZipFile also invalidates the member handle returned to the caller.
        zf = zipfile.ZipFile(filepath, "r")
        target_ext = os.path.splitext(filename.replace(".zip", ""))[-1]
        for subfile in zf.infolist():
            # Select first file with matching file extension
            if os.path.splitext(subfile.filename)[-1] == target_ext:
                data = zf.open(subfile.filename, mode="r")
                break

    elif file_ext.extension == "rar":
        # Same as the zip branch: keep the archive open so the member stream
        # remains readable after this function returns.
        rf = rarfile.RarFile(filepath, "r")
        target_ext = os.path.splitext(filename.replace(".rar", ""))[-1]
        for subfile in rf.infolist():
            # Select first file with matching file extension
            if os.path.splitext(subfile.filename)[-1] == target_ext:
                data = rf.open(subfile.filename, mode="r")
                break

    else:
        # `logger.warning` — `warn` is a deprecated alias in the logging module.
        logger.warning(f"WARNING: filepath_ext.extension {file_ext.extension} not supported.")
        return data, filename

    # Strip the compression suffix from both the display name and the path.
    filename = get_uncompressed_filepath(filename)
    uncompressed_path = get_uncompressed_filepath(filepath)

    if write_uncompressed and not exists(uncompressed_path):
        with open(uncompressed_path, 'w', encoding='utf8') as f_out:
            logger.info(f"Writing uncompressed (unknown) file to {uncompressed_path}")
            f_out.write(data.read())

    # Prefer an existing non-empty on-disk uncompressed copy over the handle.
    if exists(uncompressed_path) and getsize(uncompressed_path) > 0:
        data = uncompressed_path

    return data, filename
|
|
110 |
|
|
|
111 |
|
|
|
112 |
def get_uncompressed_filepath(filepath: str) -> str:
    """Return the uncompressed filepath by removing the file extension suffix.

    Args:
        filepath (str): File path to the compressed file

    Returns:
        uncompressed_path (str): File path to the uncompressed file, or '' when
            stripping the suffix yields nothing usable.
    """
    # Strip a known compression suffix; otherwise tag the path as uncompressed.
    for suffix in (".gz", ".zip", ".rar"):
        if filepath.endswith(suffix):
            stripped = filepath[:-len(suffix)]
            break
    else:
        stripped = filepath + ".uncompressed"

    # Guard against degenerate results (empty or unchanged path).
    if stripped and stripped != filepath:
        return stripped
    return ''
|
|
135 |
|
|
|
136 |
|
|
|
137 |
def select_files_with_ext(file_resources: Dict[str, str], ext: str, contains: Optional[str] = None) -> Dict[str, str]: |
|
|
138 |
"""Return a list of file paths with the specified file extension. Only string values are considered as file paths. |
|
|
139 |
|
|
|
140 |
Args: |
|
|
141 |
file_resources (dict): A dictionary of file names and their corresponding file paths |
|
|
142 |
ext (str): The file extension to filter the file names by |
|
|
143 |
contains (str): If not None, only return file paths that contain the specified string |
|
|
144 |
|
|
|
145 |
Returns: |
|
|
146 |
file_paths (dict): A dict of file names and corresponding paths with the specified file extension |
|
|
147 |
""" |
|
|
148 |
subset_file_resources = {} |
|
|
149 |
for filename, filepath in file_resources.items(): |
|
|
150 |
if not isinstance(filepath, str): continue |
|
|
151 |
if filename.endswith(ext) and (contains is None or contains in filename): |
|
|
152 |
subset_file_resources[filename] = filepath |
|
|
153 |
|
|
|
154 |
return subset_file_resources |
|
|
155 |
|
|
|
156 |
|
|
|
157 |
def read_db(path, table, index_col):
    """Read a database table into a Dask DataFrame.

    Args:
        path: SQLAlchemy connection URI to the database.
        table: Name of the table to read.
        index_col: Column name to use as the DataFrame index.

    Returns:
        dask.dataframe.DataFrame: the table contents, partitioned on `index_col`.
    """
    engine = sa.create_engine(path)
    m = sa.MetaData()
    # Reflect the table schema from the live database. `autoload_with=engine`
    # implies autoload; the separate `autoload=True` flag is deprecated in
    # SQLAlchemy 1.4 and removed in 2.0.
    table = sa.Table(table, m, autoload_with=engine)

    # NOTE(review): `parse_dates` presumably targets a 'datetime' column in the
    # source table — confirm against the schemas this is called with.
    daskDF = dd.read_sql_table(table, path, index_col=index_col, parse_dates={'datetime': '%Y-%m-%d %H:%M:%S'})
    return daskDF
|
|
179 |
|
|
|
180 |
|
|
|
181 |
def retry(num=5):
    """Build a `requests.Session` that retries failed connections.

    Allows up to `num` tries; with a backoff_factor of 0.1, sleep() will
    sleep for [0.1s, 0.2s, 0.4s, ...] between retries. It will also force a
    retry if the status code returned is 500, 502, 503 or 504.

    Args:
        num: Maximum number of retries.

    Returns:
        requests.Session: session with the retry adapter mounted.
    """
    s = requests.Session()
    retries = Retry(total=num, backoff_factor=0.1,
                    status_forcelist=[500, 502, 503, 504])
    adapter = HTTPAdapter(max_retries=retries)
    # BUGFIX: mount on both schemes — previously only 'http://' was covered,
    # so HTTPS requests silently received no retry behavior.
    s.mount('http://', adapter)
    s.mount('https://', adapter)

    return s