
openomics/database/ontology.py
import os
import warnings
from collections.abc import Iterable
from io import TextIOWrapper, StringIO
from typing import Tuple, List, Dict, Union, Callable, Optional

import dask.dataframe as dd
import networkx as nx
import numpy as np
import obonet
import pandas as pd
import scipy.sparse as ssp
from logzero import logger
from networkx import NetworkXError
from pandas import DataFrame

from openomics.io.read_gaf import read_gaf
from openomics.transforms.adj import slice_adj
from openomics.transforms.agg import get_agg_func
from .base import Database

__all__ = ['GeneOntology', 'UniProtGOA', 'InterPro', 'HumanPhenotypeOntology', ]


class Ontology(Database):
    annotations: pd.DataFrame

    def __init__(self,
                 path,
                 file_resources=None,
                 blocksize=None,
                 **kwargs):
        """
        Manages dataset input processing from tables and constructs an ontology network from an .obo file. The
        ontology network is G(V, E), where an edge e_ij from child i to parent j represents "node i is_a node j".

        Args:
            path:
            file_resources:
            blocksize:
            verbose:
        """
        super().__init__(path=path, file_resources=file_resources, blocksize=blocksize, **kwargs)

        self.network, self.node_list = self.load_network(self.file_resources)
        self.annotations = self.load_annotation(self.file_resources, self.blocksize)

        self.close()

    def load_network(self, file_resources) -> Tuple[nx.MultiDiGraph, List[str]]:
        raise NotImplementedError()

    def load_annotation(self, file_resources, blocksize) -> Union[pd.DataFrame, dd.DataFrame]:
        pass

    def filter_network(self, namespace) -> None:
        """
        Filter the subgraph and node_list to only `namespace` terms.
        Args:
            namespace: one of {"biological_process", "cellular_component", "molecular_function"}
        """
        terms = self.data[self.data["namespace"] == namespace]["go_id"].unique()
        if self.verbose:
            print("{} terms: {}".format(namespace, len(terms)))
        self.network = self.network.subgraph(nodes=list(terms))
        self.node_list = np.array(list(terms))
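
    # Illustrative usage (a sketch; assumes a loaded `GeneOntology` instance named `go`):
    #   go.filter_network(namespace="biological_process")
    #   go.network    # subgraph restricted to biological_process terms
    #   go.node_list  # np.array of the remaining GO term IDs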

    def adj(self, node_list):
        adj_mtx = nx.adjacency_matrix(self.network, nodelist=node_list)

        if node_list is None or list(node_list) == list(self.node_list):
            return adj_mtx
        elif set(node_list) < set(self.node_list):
            return slice_adj(adj_mtx, list(self.node_list), node_list,
                             None)
        elif not (set(node_list) < set(self.node_list)):
            raise Exception("A node in node_list is not in self.node_list.")

        return adj_mtx

    def filter_annotation(self, annotation: pd.Series):
        go_terms = set(self.node_list)
        filtered_annotation = annotation.map(lambda x: list(set(x) & go_terms)
                                             if isinstance(x, list) else [])

        return filtered_annotation

    def get_child_nodes(self):
        adj = self.adj(self.node_list)
        leaf_terms = self.node_list[np.nonzero(adj.sum(axis=0) == 0)[1]]
        return leaf_terms

    def get_root_nodes(self):
        adj = self.adj(self.node_list)
        parent_terms = self.node_list[np.nonzero(adj.sum(axis=1) == 0)[0]]
        return parent_terms

    def get_dfs_paths(self, root_nodes: list, filter_duplicates=False):
        """
        Return all depth-first search paths from the root node(s) to leaf nodes by traversing the ontology directed graph.
        Args:
            root_nodes (list): ["GO:0008150"] for biological_process, ["GO:0003674"] for molecular_function, or ["GO:0005575"] for cellular_component
            filter_duplicates (bool): whether to remove duplicated paths that end up at the same leaf nodes

        Returns: pd.DataFrame of all paths starting from the root nodes.
        """
        if not isinstance(root_nodes, list):
            root_nodes = list(root_nodes)

        paths = list(dfs_path(self.network, root_nodes))
        paths = list(flatten_list(paths))
        paths_df = pd.DataFrame(paths)

        if filter_duplicates:
            paths_df = paths_df[~paths_df.duplicated(keep="first")]
            paths_df = filter_dfs_paths(paths_df)

        return paths_df
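
    # Illustrative usage (a sketch; "GO:0008150" is the biological_process root term):
    #   paths_df = go.get_dfs_paths(root_nodes=["GO:0008150"], filter_duplicates=True)
    #   # each row is one root-to-leaf path, one term per column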

    def remove_predecessor_terms(self, annotation: pd.Series, sep=r"\||;"):
        # leaf_terms = self.get_child_nodes()
        # if not annotation.map(lambda x: isinstance(x, (list, np.ndarray))).any() and sep:
        #     annotation = annotation.str.split(sep)
        #
        # parent_terms = annotation.map(lambda x: list(
        #     set(x) & set(leaf_terms)) if isinstance(x, (list, np.ndarray)) else None)
        # return parent_terms
        raise NotImplementedError

    def get_subgraph(self, edge_types: Union[str, List[str]]) -> Union[nx.MultiDiGraph, nx.DiGraph]:
        if not hasattr(self, "_subgraphs"):
            self._subgraphs = {}
        elif edge_types in self._subgraphs:
            return self._subgraphs[edge_types]

        if edge_types and isinstance(self.network, (nx.MultiGraph, nx.MultiDiGraph)):
            # Build a new graph because iterating over .edge_subgraph() is too slow
            g = nx.from_edgelist([(u, v) for u, v, k in self.network.edges if k in edge_types],
                                 create_using=nx.DiGraph if self.network.is_directed() else nx.Graph)
        else:
            raise Exception("Must provide `edge_types` keys for a nx.MultiGraph type.")

        self._subgraphs[edge_types] = g

        return g
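
    # Illustrative usage (a sketch; assumes the obonet graph uses edge keys such as
    # "is_a" and "part_of", and `some_go_id` is a placeholder term ID):
    #   is_a_dag = go.get_subgraph(edge_types="is_a")
    #   ancestors = nx.ancestors(is_a_dag, some_go_id)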

    def add_predecessor_terms(self, anns: pd.Series, edge_type: Union[str, List[str]] = 'is_a', sep=r"\||;"):
        anns_w_parents = anns.map(lambda x: [] if not isinstance(x, (list, np.ndarray)) else x) + \
                         get_predecessor_terms(anns, self.get_subgraph(edge_type))

        return anns_w_parents

    @staticmethod
    def get_node_color(file="~/Bioinformatics_ExternalData/GeneOntology/go_colors_biological.csv", ):
        go_colors = pd.read_csv(file)

        def selectgo(x):
            terms = [term for term in x if isinstance(term, str)]
            if len(terms) > 0:
                return terms[-1]
            else:
                return None

        go_colors["node"] = go_colors[[
            col for col in go_colors.columns if col.isdigit()
        ]].apply(selectgo, axis=1)
        go_id_colors = go_colors[go_colors["node"].notnull()].set_index("node")["HCL.color"]
        go_id_colors = go_id_colors[~go_id_colors.index.duplicated(keep="first")]

        print(go_id_colors.unique().shape, go_colors["HCL.color"].unique().shape)
        return go_id_colors

    def split_annotations(self, src_node_col="gene_name", dst_node_col="go_id", groupby: List[str] = ["Qualifier"],
                          train_date="2017-06-15", valid_date="2017-11-15", test_date="2021-12-31",
                          query: Optional[str] = "Evidence in ['EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC']",
                          filter_src_nodes: pd.Index = None, filter_dst_nodes: pd.Index = None,
                          agg: Union[Callable, str] = "unique") -> Tuple[DataFrame, DataFrame, DataFrame]:
        """
        Split annotations into train/valid/test sets by annotation date.

        Args:
            src_node_col (str): Name of the column containing the src node types.
            dst_node_col (str): Name of the column containing the dst node types.
            train_date (str): A date before which the annotations belong in the training set.
            valid_date (str): A date before which the annotations belong in the validation set.
            test_date (str): A date before which the annotations belong in the testing set.
            groupby (List[str]): A list of column names to group annotations by, default [`src_node_col`, "Qualifier"].
            query (str, optional): A pandas query string to filter annotations. By default, only select ['EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC'] annotations.
            filter_src_nodes (pd.Index): Subset annotations to these values on `src_node_col`.
            filter_dst_nodes (pd.Index): Subset annotations to these values on `dst_node_col`.
            agg (str): Either "unique" or "add_parent", a callable function, or a dd.Aggregation() for aggregating the `dst_node_col` column after grouping by `groupby`.
        """
        raise NotImplementedError


class GeneOntology(Ontology):
    """Loads the GeneOntology database from http://geneontology.org .

    Default path: "http://geneontology.org/gene-associations/".

    Default file_resources: {
        "go-basic.obo": "http://purl.obolibrary.org/obo/go/go-basic.obo",
        "goa_human.gaf": "goa_human.gaf.gz",
        "goa_human_rna.gaf": "goa_human_rna.gaf.gz",
        "goa_human_isoform.gaf": "goa_human_isoform.gaf.gz",
    }
    """
    COLUMNS_RENAME_DICT = {
        "DB_Object_Symbol": "gene_name",
        "DB_Object_ID": "gene_id",
        "GO_ID": "go_id",
        "Taxon_ID": 'species_id',
    }

    DROP_COLS = {'DB:Reference', 'With', 'Annotation_Extension', 'Gene_Product_Form_ID'}

    def __init__(
        self,
        path="http://geneontology.org/gene-associations/",
        species='human',
        file_resources=None,
        index_col='DB_Object_Symbol',
        keys=None,
        col_rename=COLUMNS_RENAME_DICT,
        blocksize=None,
        **kwargs
    ):
        """
        Loads the GeneOntology database from http://geneontology.org .

            Default path: "http://geneontology.org/gene-associations/" .
            Default file_resources: {
                "go-basic.obo": "http://purl.obolibrary.org/obo/go/go-basic.obo",
                "goa_human.gaf": "goa_human.gaf.gz",
                "goa_human_rna.gaf": "goa_human_rna.gaf.gz",
                "goa_human_isoform.gaf": "goa_human_isoform.gaf.gz",
            }

        Data for GO term annotations in .gpi files are already included in the .obo file, so this module doesn't make use of .gpi files.

        Handles downloading the latest Gene Ontology obo and annotation data and preprocesses them. It provides
        functionality to create a directed acyclic graph of GO terms, filter terms, and filter annotations.
        """
        if species and not hasattr(self, 'species'):
            self.species = species.lower()
        elif species is None:
            self.species = 'uniprot'

        if file_resources is None:
            file_resources = {
                f"goa_{self.species}.gaf.gz": f"goa_{self.species}.gaf.gz",
            }
            if species != 'uniprot':
                file_resources[f"goa_{self.species}_rna.gaf.gz"] = f"goa_{self.species}_rna.gaf.gz"
                file_resources[f"goa_{self.species}_isoform.gaf.gz"] = f"goa_{self.species}_isoform.gaf.gz"

        if not any('.obo' in file for file in file_resources):
            warnings.warn(
                f'No .obo file provided in `file_resources`, so automatically adding "http://purl.obolibrary.org/obo/go/go-basic.obo"')
            file_resources["go-basic.obo"] = "http://purl.obolibrary.org/obo/go/go-basic.obo"

        super().__init__(path, file_resources, index_col=index_col, keys=keys, col_rename=col_rename,
                         blocksize=blocksize, **kwargs)

        # By default, the __init__ constructor runs load_dataframe() before load_network(), but for GeneOntology,
        # we get node data from the nx.Graph, so we must run load_dataframe() again once self.network is not None.
        self.data = self.load_dataframe(self.file_resources, self.blocksize)
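
    # Illustrative usage (a sketch; downloads the default resources on first use, and
    # passing `blocksize` is assumed to switch annotation loading to dask):
    #   go = GeneOntology(path="http://geneontology.org/gene-associations/",
    #                     species="human", blocksize=None)
    #   go.data          # one row of metadata per GO term (name, namespace, depth, ...)
    #   go.annotations   # parsed GAF annotations, renamed via COLUMNS_RENAME_DICT
    #   go.network       # obonet MultiDiGraph of GO terms (edges reversed: parent -> child)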

    def info(self):
        print("network {}".format(nx.info(self.network)))

    def load_dataframe(self, file_resources: Dict[str, TextIOWrapper], blocksize=None) -> DataFrame:
        if hasattr(self, 'network') and self.network is not None:
            # Node attributes for each GO term, taken from the NetworkX graph created from the .obo file
            go_terms = pd.DataFrame.from_dict(dict(self.network.nodes(data=True)), orient='index')
            go_terms["def"] = go_terms["def"].apply(
                lambda x: x.split('"')[1] if isinstance(x, str) else None)
            go_terms.index.name = "go_id"

            # Get the depth of each term node in its ontology
            hierarchy = nx.subgraph_view(self.network, filter_edge=lambda u, v, e: e == 'is_a')
            for namespace in go_terms['namespace'].unique():
                root_term = go_terms.query(f'name == "{namespace}"').index.item()
                namespace_terms = go_terms.query(f'namespace == "{namespace}"').index
                shortest_paths = nx.shortest_path_length(hierarchy.subgraph(namespace_terms), root_term)
                go_terms.loc[namespace_terms, 'depth'] = namespace_terms.map(shortest_paths)
            go_terms['depth'] = go_terms['depth'].fillna(0).astype(int)
        else:
            go_terms = None

        return go_terms

    def load_network(self, file_resources) -> Tuple[nx.Graph, np.ndarray]:
        network, node_list = None, None
        fn = next((fn for fn in file_resources if fn.endswith(".obo")), None)
        if fn:
            network: nx.MultiDiGraph = obonet.read_obo(file_resources[fn])
            network = network.reverse(copy=True)
            node_list = np.array(network.nodes)

        return network, node_list
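
    # Illustrative sketch of what load_network() does with the .obo resource
    # (obonet and networkx calls as used above; the URL is the module default):
    #   network = obonet.read_obo("http://purl.obolibrary.org/obo/go/go-basic.obo")
    #   network = network.reverse(copy=True)   # edges now point parent -> child
    #   node_list = np.array(network.nodes)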

    def load_annotation(self, file_resources, blocksize=None) -> Union[pd.DataFrame, dd.DataFrame]:
        # Handle .gaf annotation files
        dfs = {}
        for filename, filepath_or_buffer in file_resources.items():
            gaf_name = filename.split(".")[0]
            # Ensure no duplicate GAF file (e.g. when an uncompressed file with the same prefix exists)
            if gaf_name in dfs: continue

            if blocksize and isinstance(filepath_or_buffer, str):
                if filename.endswith(".processed.parquet"):
                    # Parsed and filtered gaf file
                    dfs[gaf_name] = dd.read_parquet(filepath_or_buffer, chunksize=blocksize)
                    if dfs[gaf_name].index.name != self.index_col and self.index_col in dfs[gaf_name].columns:
                        dfs[gaf_name] = dfs[gaf_name].set_index(self.index_col, sorted=True)
                    if not dfs[gaf_name].known_divisions:
                        dfs[gaf_name].divisions = dfs[gaf_name].compute_current_divisions()

                elif (filename.endswith(".parquet") or filename.endswith(".gaf")):
                    # .parquet from .gaf.gz file, unfiltered, with raw str values
                    dfs[gaf_name] = read_gaf(filepath_or_buffer, blocksize=blocksize, index_col=self.index_col,
                                             keys=self.keys, usecols=self.usecols)

                elif filename.endswith(".gaf.gz"):
                    # Compressed .gaf file downloaded
                    dfs[gaf_name] = read_gaf(filepath_or_buffer, blocksize=blocksize, index_col=self.index_col,
                                             keys=self.keys, usecols=self.usecols, compression='gzip')

            else:
                if filename.endswith(".processed.parquet"):
                    dfs[gaf_name] = pd.read_parquet(filepath_or_buffer)
                if filename.endswith(".gaf"):
                    dfs[gaf_name] = read_gaf(filepath_or_buffer, index_col=self.index_col, keys=self.keys,
                                             usecols=self.usecols)

        if len(dfs):
            annotations = dd.concat(list(dfs.values()), interleave_partitions=True) \
                if blocksize else pd.concat(dfs.values())

            annotations = annotations.rename(columns=UniProtGOA.COLUMNS_RENAME_DICT)
            if annotations.index.name in UniProtGOA.COLUMNS_RENAME_DICT:
                annotations.index = annotations.index.rename(
                    UniProtGOA.COLUMNS_RENAME_DICT[annotations.index.name])
        else:
            annotations = None

        return annotations

    def split_annotations(self, src_node_col="gene_name", dst_node_col="go_id", groupby: List[str] = ["Qualifier"],
                          train_date="2017-06-15", valid_date="2017-11-15", test_date="2021-12-31",
                          query: str = "Evidence in ['EXP', 'IDA', 'IPI', 'IMP', 'IGI', 'IEP', 'TAS', 'IC']",
                          filter_src_nodes: pd.Index = None, filter_dst_nodes: pd.Index = None,
                          agg: Union[str, Callable, dd.Aggregation] = "unique") \
        -> Tuple[DataFrame, DataFrame, DataFrame]:
        assert isinstance(groupby, list) and groupby, f"`groupby` must be a nonempty list of strings. Got {groupby}"

        # Set the source column (i.e. protein_id or gene_name) to be the first in groupby
        if src_node_col not in groupby:
            groupby = [src_node_col] + groupby
        if "Qualifier" not in groupby and "Qualifier" in self.annotations.columns:
            groupby.append("Qualifier")

        # Aggregator function
        if agg == "add_parent":
            subgraph = self.get_subgraph(edge_types="is_a")
            node_ancestors = {node: nx.ancestors(subgraph, node) for node in subgraph.nodes}

            if isinstance(self.annotations, dd.DataFrame):
                agg = dd.Aggregation(name='_unique_add_parent',
                                     chunk=lambda s: s.unique(),
                                     agg=lambda s0: s0.apply(get_predecessor_terms, node_ancestors, keep_terms=True),
                                     finalize=lambda s1: s1.apply(lambda li: np.hstack(li) if li else None))
            else:
                agg = lambda s: get_predecessor_terms(s, g=node_ancestors, join_groups=True, keep_terms=True)

        elif agg == 'unique' and isinstance(self.annotations, dd.DataFrame):
            agg = get_agg_func('unique', use_dask=True)

        elif isinstance(self.annotations, dd.DataFrame) and not isinstance(agg, dd.Aggregation):
            raise Exception("`agg` must be a dd.Aggregation for groupby.agg() on columns of a dask DataFrame")

        def _remove_dup_neg_go_id(s: pd.Series) -> pd.Series:
            if s.isna().any():
                return s
            elif isinstance(s[neg_dst_col], Iterable) and isinstance(s[dst_node_col], Iterable):
                rm_dups_go_id = [go_id for go_id in s[neg_dst_col] if go_id not in s[dst_node_col]]
                if len(rm_dups_go_id) == 0:
                    rm_dups_go_id = None
                s[neg_dst_col] = rm_dups_go_id
            return s

        neg_dst_col = f"neg_{dst_node_col}"

        # Filter annotations
        annotations = self.annotations
        if query:
            annotations = annotations.query(query)
        if filter_src_nodes is not None:
            annotations = annotations[annotations[src_node_col].isin(filter_src_nodes)]
        if filter_dst_nodes is not None:
            annotations = annotations[annotations[dst_node_col].isin(filter_dst_nodes)]
        if annotations.index.name in groupby:
            annotations = annotations.reset_index()

        # Split train/valid/test annotations
        train_anns = annotations.loc[annotations["Date"] <= pd.to_datetime(train_date)]
        valid_anns = annotations.loc[(annotations["Date"] <= pd.to_datetime(valid_date)) & \
                                     (annotations["Date"] > pd.to_datetime(train_date))]
        test_anns = annotations.loc[(annotations["Date"] <= pd.to_datetime(test_date)) & \
                                    (annotations["Date"] > pd.to_datetime(valid_date))]

        outputs = []
        for anns in [train_anns, valid_anns, test_anns]:
            # Keep track of which annotations have a "NOT" Qualifier
            is_neg_ann = anns.loc[:, "Qualifier"].map(lambda li: "NOT" in li)

            # Convert `Qualifier` entries from lists of strings to strings
            args = dict(meta=pd.Series([""])) if isinstance(anns, dd.DataFrame) else {}
            anns.loc[:, 'Qualifier'] = anns.loc[:, 'Qualifier'].apply(
                lambda li: "".join([i for i in li if i != "NOT"]), **args)

            # Aggregate gene-GO annotations
            if isinstance(anns, pd.DataFrame) and len(anns.index):
                pos_anns = anns[~is_neg_ann].groupby(groupby).agg({dst_node_col: agg})
                neg_anns = anns[is_neg_ann].groupby(groupby).agg(**{neg_dst_col: (dst_node_col, agg)})
                pos_neg_anns = pd.concat([pos_anns, neg_anns], axis=1)

            elif isinstance(anns, dd.DataFrame) and len(anns.index) and dst_node_col in anns.columns:
                pos_anns = anns[~is_neg_ann].groupby(groupby).agg({dst_node_col: agg})
                if False and len(is_neg_ann.index):
                    neg_anns = anns[is_neg_ann].groupby(groupby).agg({dst_node_col: agg})
                    neg_anns.columns = [neg_dst_col]
                    pos_neg_anns = dd.concat([pos_anns, neg_anns], axis=1)
                else:
                    pos_neg_anns = pos_anns
                    pos_neg_anns[neg_dst_col] = None

            else:
                pos_neg_anns = pd.DataFrame(
                    columns=[dst_node_col, neg_dst_col],
                    index=pd.MultiIndex(levels=[[] for i in range(len(groupby))],
                                        codes=[[] for i in range(len(groupby))], names=groupby))
                outputs.append(pos_neg_anns)
                continue

            if isinstance(pos_neg_anns, pd.DataFrame):
                pos_neg_anns = pos_neg_anns.drop([""], axis='index', errors="ignore")

            # Remove "GO:0005515" (protein binding) annotations for a gene if it's the gene's only annotation
            _exclude_single_fn = lambda li: None \
                if isinstance(li, Iterable) and len(li) == 1 and "GO:0005515" in li else li
            args = dict(meta=pd.Series([list()])) if isinstance(anns, dd.DataFrame) else {}
            pos_neg_anns.loc[:, dst_node_col] = pos_neg_anns[dst_node_col].apply(_exclude_single_fn, **args)

            # Drop rows with all nan values
            if isinstance(pos_neg_anns, pd.DataFrame):
                pos_neg_anns = pos_neg_anns.drop(pos_neg_anns.index[pos_neg_anns.isna().all(1)], axis='index')

            # Ensure no negative term duplicates a positive annotation
            if len(is_neg_ann.index):
                args = dict(meta=pd.DataFrame({dst_node_col: [], neg_dst_col: []})) \
                    if isinstance(anns, dd.DataFrame) else {}
                pos_neg_anns = pos_neg_anns.apply(_remove_dup_neg_go_id, axis=1, **args)

            outputs.append(pos_neg_anns)

        return tuple(outputs)
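
    # Illustrative usage (a sketch; column names follow the GAF fields used above,
    # and the dates are the method defaults):
    #   train, valid, test = go.split_annotations(src_node_col="gene_name", dst_node_col="go_id",
    #                                             train_date="2017-06-15", valid_date="2017-11-15",
    #                                             test_date="2021-12-31", agg="unique")
    #   # each output is indexed by (gene_name, Qualifier) with "go_id" and "neg_go_id" columns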


def get_predecessor_terms(anns: Union[pd.Series, Iterable], g: Union[Dict[str, List[str]], nx.MultiDiGraph],
                          join_groups=False, keep_terms=True, exclude={'GO:0005575', 'GO:0008150', 'GO:0003674'}) \
    -> Union[pd.Series, List[str]]:
    """

    Args:
        anns (): a pd.Series of term lists, or a single iterable of terms.
        g (nx.MultiDiGraph, Dict[str,Set[str]]): Either a NetworkX DAG or a precomputed lookup table of node to ancestors.
        join_groups (): whether to concatenate the per-entry ancestor lists into a single flat list.
        keep_terms (): whether to keep the original terms alongside their ancestors.
        exclude (): terms (by default the three root terms) to exclude from the returned ancestors.

    Returns:
        A pd.Series of ancestor-term lists (or a single list if `anns` is not a pd.Series).
    """
    if exclude is None:
        exclude = {}

    def _get_ancestors(terms: Iterable):
        try:
            if isinstance(terms, Iterable):
                if isinstance(g, dict):
                    parents = {parent \
                               for term in terms if term in g \
                               for parent in g[term] if parent not in exclude}

                elif isinstance(g, nx.Graph):
                    parents = {parent \
                               for term in terms if term in g.nodes \
                               for parent in nx.ancestors(g, term) if parent not in exclude}
                else:
                    raise Exception("Provided `g` arg must be either an nx.Graph or a Dict")
            else:
                parents = []

            if keep_terms and isinstance(terms, list):
                terms.extend(parents)
                out = terms
            else:
                out = list(parents)

        except (NetworkXError, KeyError) as nxe:
            if "foo" not in nxe.__str__():
                logger.error(f"{nxe.__class__.__name__} get_predecessor_terms._get_ancestors: {nxe}")
            out = terms if keep_terms else []

        return out

    if isinstance(anns, pd.Series):
        if (anns.map(type) == str).any():
            anns = anns.map(lambda s: [s])

        parent_terms = anns.map(_get_ancestors)
        if join_groups:
            parent_terms = sum(parent_terms, [])

    elif isinstance(anns, Iterable):
        parent_terms = _get_ancestors(anns)

    else:
        parent_terms = []

    return parent_terms
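
# Illustrative sketch of get_predecessor_terms() with a precomputed ancestor lookup
# (the GO IDs below are placeholders, not real annotations):
#   node_ancestors = {"GO:child": ["GO:parent", "GO:grandparent"]}
#   get_predecessor_terms(["GO:child"], g=node_ancestors, keep_terms=True)
#   # -> ["GO:child"] plus its ancestors (order of the ancestors is not guaranteed)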

class UniProtGOA(GeneOntology):
    """Loads the UniProt GOA database from https://www.ebi.ac.uk/GOA/ .

    Default path: "ftp://ftp.ebi.ac.uk/pub/databases/GO/goa/UNIPROT/" .
    Default file_resources: {
        "goa_uniprot_all.gaf": "goa_uniprot_all.gaf.gz",
    }
    """

    COLUMNS_RENAME_DICT = {
        "DB_Object_ID": "protein_id",
        "DB_Object_Symbol": "gene_name",
        "GO_ID": "go_id",
        "Taxon_ID": 'species_id',
    }
    def __init__(
        self,
        path="ftp://ftp.ebi.ac.uk/pub/databases/GO/goa/",
        species="HUMAN",
        file_resources=None,
        index_col='DB_Object_ID', keys=None,
        col_rename=COLUMNS_RENAME_DICT,
        blocksize=None,
        **kwargs,
    ):
        """
        Loads the UniProtGOA database from https://www.ebi.ac.uk/GOA/ .

            Default path: "ftp://ftp.ebi.ac.uk/pub/databases/GO/goa/UNIPROT/" .
            Default file_resources: {
                "goa_uniprot_all.gaf.gz": "goa_uniprot_all.gaf.gz",
                "go.obo": "http://current.geneontology.org/ontology/go.obo",
            }

        Handles downloading the latest Gene Ontology obo and GOA annotation data and preprocesses them. It provides
        functionality to create a directed acyclic graph of GO terms, filter terms, and filter annotations.

        Args:
            path ():
            species ():
            file_resources ():
            index_col ():
            keys ():
            col_rename ():
            blocksize ():
            **kwargs ():
        """
        if species is None:
            self.species = species = 'UNIPROT'
            substr = 'uniprot_all'
        else:
            self.species = species.upper()
            substr = species.lower()

        if file_resources is None:
            file_resources = {
                "go.obo": "http://current.geneontology.org/ontology/go.obo",
                f"goa_{self.species.lower()}.gaf.gz": os.path.join(species, f"goa_{substr}.gaf.gz"),
                # f"goa_{self.species.lower()}_isoform.gaf.gz": os.path.join(species, f"goa_{substr}_isoform.gaf.gz"),
                # f"goa_{self.species.lower()}_complex.gaf.gz": os.path.join(species, f"goa_{substr}_complex.gaf.gz"),
            }

        if not any('.obo' in file for file in file_resources):
            warnings.warn(f'No .obo file provided in `file_resources`, '
                          f'so automatically adding "http://purl.obolibrary.org/obo/go/go-basic.obo"')
            file_resources["go-basic.obo"] = "http://purl.obolibrary.org/obo/go/go-basic.obo"

        super().__init__(path=path, file_resources=file_resources, index_col=index_col, keys=keys,
                         col_rename=col_rename,
                         blocksize=blocksize, **kwargs)


class InterPro(Ontology):
    """
    Default parameters
    path="https://ftp.ebi.ac.uk/pub/databases/interpro/current_release/"
    file_resources = {
        "entry.list": "entry.list",
        "protein2ipr.dat.gz": "protein2ipr.dat.gz",
        "interpro2go": "interpro2go",
        "ParentChildTreeFile.txt": "ParentChildTreeFile.txt",
    }
    """

    def __init__(self, path="https://ftp.ebi.ac.uk/pub/databases/interpro/current_release/", index_col='UniProtKB-AC',
                 keys=None, file_resources=None, col_rename=None, **kwargs):
        """

        Args:
            path ():
            index_col (str): Default 'UniProtKB-AC'.
            keys ():
            file_resources ():
            col_rename ():
            **kwargs ():
        """
        assert index_col is not None

        if file_resources is None:
            file_resources = {}
            file_resources["entry.list"] = os.path.join(path, "entry.list")
            file_resources["protein2ipr.dat.gz"] = os.path.join(path, "protein2ipr.dat.gz")
            file_resources["interpro2go"] = os.path.join(path, "interpro2go")
            file_resources["ParentChildTreeFile.txt"] = os.path.join(path, "ParentChildTreeFile.txt")

        if any('protein2ipr' in fn for fn in file_resources):
            assert keys is not None, "Processing protein2ipr requires user to provide `keys` as a list of " \
                                     "`UniProtKB-AC` values"

        super().__init__(path=path, file_resources=file_resources, index_col=index_col, keys=keys,
                         col_rename=col_rename, **kwargs)
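
    # Illustrative usage (a sketch; `uniprot_ids` is a placeholder list of
    # UniProtKB-AC accessions, required because protein2ipr is filtered by `keys`,
    # and `blocksize` is assumed to be forwarded to the dask-based loading):
    #   ipr = InterPro(path="https://ftp.ebi.ac.uk/pub/databases/interpro/current_release/",
    #                  keys=uniprot_ids, blocksize="100MB")
    #   ipr.data         # InterPro entries joined with interpro2go GO mappings
    #   ipr.annotations  # sparse protein x InterPro-entry membership matrix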

    def load_dataframe(self, file_resources: Dict[str, TextIOWrapper], blocksize=None):
        ipr_entries = pd.read_table(file_resources["entry.list"], index_col="ENTRY_AC")

        ipr2go_fn = next((fn for fn in file_resources if 'interpro2go' in fn), None)
        if ipr2go_fn:
            ipr2go = self.parse_interpro2go(file_resources[ipr2go_fn])
            if ipr2go is not None:
                ipr_entries = ipr_entries.join(ipr2go.groupby('ENTRY_AC')["go_id"].unique(), on="ENTRY_AC")

        ipr2go_fn = next((fn for fn in file_resources if 'interpro.xml' in fn), None)
        if ipr2go_fn:
            ipr_xml = pd.read_xml(file_resources[ipr2go_fn], xpath='//interpro',
                                  compression='gzip' if ipr2go_fn.endswith('.gz') else None) \
                .dropna(axis=1, how='all') \
                .rename(columns={'id': 'ENTRY_AC'}) \
                .set_index('ENTRY_AC')
            ipr_entries = ipr_entries.join(ipr_xml, on="ENTRY_AC")

        return ipr_entries

    def load_network(self, file_resources) -> Tuple[nx.Graph, np.ndarray]:
        network, node_list = None, None
        for filename in file_resources:
            if 'ParentChildTreeFile' in filename and isinstance(file_resources[filename], str):
                network: nx.MultiDiGraph = self.parse_ipr_treefile(file_resources[filename])
                node_list = np.array(network.nodes)

        return network, node_list

    def load_annotation(self, file_resources, blocksize=None):
        if not any('protein2ipr' in fn for fn in file_resources):
            return None

        ipr_entries = self.data

        # Use Dask
        args = dict(names=['UniProtKB-AC', 'ENTRY_AC', 'ENTRY_NAME', 'accession', 'start', 'stop'],
                    usecols=['UniProtKB-AC', 'ENTRY_AC', 'start', 'stop'],
                    dtype={'UniProtKB-AC': 'category', 'ENTRY_AC': 'category', 'start': 'int8', 'stop': 'int8'},
                    low_memory=True,
                    blocksize=None if isinstance(blocksize, bool) else blocksize)
        if 'protein2ipr.parquet' in file_resources:
            annotations = dd.read_parquet(file_resources["protein2ipr.parquet"])
        else:
            annotations = dd.read_table(file_resources["protein2ipr.dat"], **args)
        if self.keys is not None and self.index_col in annotations.columns:
            annotations = annotations.loc[annotations[self.index_col].isin(self.keys)]
        elif self.keys is not None and self.index_col == annotations.index.name:
            annotations = annotations.loc[annotations.index.isin(self.keys)]
        # if annotations.index.name != self.index_col:
        #     annotations = annotations.set_index(self.index_col, sorted=True)
        # if not annotations.known_divisions:
        #     annotations.divisions = annotations.compute_current_divisions()
        # Set ordering for rows and columns
        row_order = self.keys
        col_order = ipr_entries.index
        row2idx = {node: i for i, node in enumerate(row_order)}
        col2idx = {node: i for i, node in enumerate(col_order)}
705
        def edgelist2coo(edgelist_df: DataFrame, source='UniProtKB-AC', target='ENTRY_AC') -> Optional[ssp.coo_matrix]:
706
            if edgelist_df.shape[0] == 1 and edgelist_df.iloc[0, 0] == 'foo':
707
                return None
708
709
            if edgelist_df.index.name == source:
710
                source_nodes = edgelist_df.index
711
            else:
712
                source_nodes = edgelist_df[source]
713
714
            edgelist_df = edgelist_df.assign(row=source_nodes.map(row2idx).astype('int'),
715
                                             col=edgelist_df[target].map(col2idx).astype('int'))
716
717
            edgelist_df = edgelist_df.dropna(subset=['row', 'col'])
718
            if edgelist_df.shape[0] == 0:
719
                return None
720
721
            values = np.ones(edgelist_df.index.size)
722
            coo = ssp.coo_matrix((values, (edgelist_df['row'], edgelist_df['col'])),
723
                                 shape=(len(row2idx), ipr_entries.index.size))
724
            return coo
725
726
        # Create a sparse adjacency matrix each partition, then combine them
727
        adj = annotations.reduction(chunk=edgelist2coo,
728
                                    aggregate=lambda x: x.dropna().sum() if not x.isna().all() else None,
729
                                    meta=pd.Series([ssp.coo_matrix])).compute()
730
        assert len(adj) == 1, f"len(adj) = {len(adj)}"
731
        # Create a sparse matrix of UniProtKB-AC x ENTRY_AC
732
        annotations = pd.DataFrame.sparse.from_spmatrix(adj[0], index=row_order, columns=col_order)
733
734
        return annotations
735
    def parse_interpro2go(self, file: str) -> pd.DataFrame:
        def _process_line(line: str) -> Tuple[str, str, str]:
            pos = line.find('> GO')
            interpro_terms, go_term = line[:pos], line[pos:]
            interpro_id, interpro_name = interpro_terms.strip().split(' ', 1)
            go_name, go_id = go_term.split(';')
            go_desc = go_name.strip('> GO:')

            return (interpro_id.strip().split(':')[1], go_id.strip(), go_desc)

        if isinstance(file, str):
            with open(os.path.expanduser(file), 'r') as file:
                tuples = [_process_line(line.strip()) for line in file if line[0] != '!']

            ipr2go = pd.DataFrame(tuples, columns=['ENTRY_AC', "go_id", "go_desc"])
            return ipr2go

    def parse_ipr_treefile(self, lines: Union[List[str], StringIO]) -> nx.MultiDiGraph:
        """Parse the InterPro Tree from the given file.
        Args:
            lines: A readable file or file-like
        """
        if isinstance(lines, str):
            lines = open(os.path.expanduser(lines), 'r')

        graph = nx.MultiDiGraph()
        previous_depth, previous_name = 0, None
        stack = [previous_name]

        def count_front(s: str) -> int:
            """Count the number of leading dashes on a string."""
            for position, element in enumerate(s):
                if element != '-':
                    return position

        for line in lines:
            depth = count_front(line)
            interpro_id, name, *_ = line[depth:].split('::')

            if depth == 0:
                stack.clear()
                stack.append(interpro_id)

                graph.add_node(interpro_id, interpro_id=interpro_id, name=name)

            else:
                if depth > previous_depth:
                    stack.append(previous_name)

                elif depth < previous_depth:
                    del stack[-1]

                parent = stack[-1]

                graph.add_node(interpro_id, interpro_id=interpro_id, parent=parent, name=name)
                graph.add_edge(parent, interpro_id, key="is_a")

            previous_depth, previous_name = depth, interpro_id

        lines.close()
        return graph
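
    # Illustrative sketch of the ParentChildTreeFile format handled above
    # (entry IDs and names are made-up placeholders; depth is encoded by leading dashes):
    #   IPR000001::Some root entry::
    #   --IPR000002::A child of IPR000001::
    #   ----IPR000003::A grandchild entry::
    # parse_ipr_treefile() adds one node per entry and an "is_a" edge from parent to child.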


class HumanPhenotypeOntology(Ontology):
    """Loads the Human Phenotype Ontology database from https://hpo.jax.org/app/ .

        Default path: "https://hpo.jax.org/" .
        Default file_resources: {
            "hp.obo": "http://purl.obolibrary.org/obo/hp.obo",
        }
        """

    COLUMNS_RENAME_DICT = {}

    def __init__(
        self,
        path="https://hpo.jax.org/",
        file_resources=None,
        col_rename=COLUMNS_RENAME_DICT,
        blocksize=0,
        verbose=False,
    ):
        """
        Handles downloading the latest Human Phenotype Ontology obo and annotation data and preprocesses them. It provides
        functionality to create a directed acyclic graph of ontology terms, filter terms, and filter annotations.
        """
        if file_resources is None:
            file_resources = {
                "hp.obo": "http://purl.obolibrary.org/obo/hp.obo",
            }
        super().__init__(
            path,
            file_resources,
            col_rename=col_rename,
            blocksize=blocksize,
            verbose=verbose,
        )

    def info(self):
        print("network {}".format(nx.info(self.network)))

    def load_network(self, file_resources):
        for file in file_resources:
            if ".obo" in file:
                network = obonet.read_obo(file_resources[file])
                network = network.reverse(copy=True)
                node_list = np.array(network.nodes)
        return network, node_list


def traverse_predecessors(network, seed_node, type=["is_a", "part_of"]):
    """
    Returns all predecessor terms of seed_node by traversing the ontology network along edges of type `type`.
    Args:
        seed_node: seed node of the traversal
        type: the edge type(s) to include
    Returns:
        generator of lists of lists, one per DFS branch.
    """
    parents = dict(network.pred[seed_node])
    for parent, v in parents.items():
        if list(v.keys())[0] in type:
            yield [parent] + list(traverse_predecessors(network, parent, type))


def flatten(lst):
    return sum(([x] if not isinstance(x, list) else flatten(x) for x in lst),
               [])


def dfs_path(graph, path):
    node = path[-1]
    successors = list(graph.successors(node))
    if len(successors) > 0:
        for child in successors:
            yield list(dfs_path(graph, path + [child]))
    else:
        yield path
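
# Illustrative sketch of dfs_path() with a toy graph (node names are placeholders;
# get_dfs_paths() combines dfs_path() and flatten_list() the same way, then
# deduplicates rows in the resulting DataFrame):
#   g = nx.DiGraph([("root", "a"), ("root", "b"), ("a", "leaf")])
#   paths = list(flatten_list(list(dfs_path(g, ["root"]))))
#   # each item is a root-to-leaf path such as ["root", "a", "leaf"]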


def flatten_list(list_in):
    if isinstance(list_in, list):
        for l in list_in:
            if isinstance(list_in[0], list):
                for y in flatten_list(l):
                    yield y
            elif isinstance(list_in[0], str):
                yield list_in
    else:
        yield list_in


def filter_dfs_paths(paths_df: pd.DataFrame):
    idx = {}
    for col in sorted(paths_df.columns[:-1], reverse=True):
        idx[col] = ~(paths_df[col].notnull()
                     & paths_df[col].duplicated(keep="first")
                     & paths_df[col + 1].isnull())

    idx = pd.DataFrame(idx)

    paths_df = paths_df[idx.all(axis=1)]
    return paths_df

def write_taxonomy(network, root_nodes, file_path):
    """

    Args:
        network: A network with edge(i, j) where i is a node and j is a child of i.
        root_nodes (list): a list of node names
        file_path (str):
    """
    file = open(file_path, "a")
    file.write("Root\t" + "\t".join(root_nodes) + "\n")

    for root_node in root_nodes:
        for node, children in nx.traversal.bfs_successors(network, root_node):
            if len(children) > 0:
                file.write(node + "\t" + "\t".join(children) + "\n")
    file.close()