# openomics/database/sequence.py
import os
import re
import traceback
from abc import abstractmethod
from collections import defaultdict, OrderedDict
from typing import Union, List, Callable, Dict, Tuple, Optional, Iterable

import numpy as np
import pandas as pd
import tqdm
from Bio import SeqIO
from Bio.SeqFeature import ExactPosition
from dask import dataframe as dd
from logzero import logger
from pyfaidx import Fasta
from six.moves import intern

import openomics
from openomics.io.read_gtf import read_gtf
from .base import Database
from ..io.files import select_files_with_ext
from ..transforms.agg import get_agg_func
from ..transforms.df import drop_duplicate_columns

__all__ = ['GENCODE', 'UniProt', 'MirBase', 'RNAcentral']

SEQUENCE_COL = 'sequence'


class SequenceDatabase(Database):
    """Provides a series of methods to extract sequence data from
    SequenceDataset.
    """

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: keyword arguments passed to the parent `Database` constructor.
        """
        super().__init__(**kwargs)
        self.close()

    @abstractmethod
    def load_sequences(self, fasta_file: str, index=None, keys: Union[pd.Index, List[str]] = None, blocksize=None) \
        -> pd.DataFrame:
        """Returns a pandas DataFrame containing the fasta sequence entries,
        with a column named 'sequence'.

        Args:
            index (str): name of the id column to index the returned entries by.
            fasta_file (str): path to the fasta file, usually as
                self.file_resources[<file_name>]
            keys (pd.Index): list of keys to filter the fasta sequence entries by.
            blocksize: if truthy, return a dask DataFrame instead of a pandas DataFrame.
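
        Example (illustrative only; `gencode` stands for any concrete subclass instance,
        e.g. a hypothetical GENCODE object given an uncompressed transcripts fasta):
            >>> seq_df = gencode.load_sequences("gencode.v32.lncRNA_transcripts.fa")  # doctest: +SKIP
            >>> "sequence" in seq_df.columns  # doctest: +SKIP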
        """
        raise NotImplementedError

    @abstractmethod
    def get_sequences(self, index: str, omic: str, agg: str, **kwargs) -> Union[pd.Series, Dict]:
        """Returns a dictionary or Series mapping each 'index' value to its
        sequence(s).

        Args:
            index (str): {"gene_id", "gene_name", "transcript_id",
                "transcript_name"}
            omic (str): {"lncRNA", "microRNA", "messengerRNA"}
            agg (str): {"all", "shortest", "longest"}
            **kwargs: any additional argument to pass to
                SequenceDataset.get_sequences()
        """
        raise NotImplementedError

    @staticmethod
    def aggregator_fn(agg: Union[str, Callable] = None) -> Callable:
        """Returns a function used to aggregate a list of sequences from a groupby
        on a given key.

        Args:
            agg: One of ("all", "shortest", "longest", "first") or a callable, default "all".
                If "all", then return a list of sequences.
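
        Example (illustrative; the other string options behave analogously):
            >>> agg_fn = SequenceDatabase.aggregator_fn("longest")
            >>> agg_fn(["AAA", "AAAAA", "AA"])
            'AAAAA'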
        """
        if agg == "all":
            agg_func = lambda x: list(x) if not isinstance(x, str) else x
        elif agg == "shortest":
            agg_func = lambda x: min(x, key=len) if isinstance(x, list) else x
        elif agg == "longest":
            agg_func = lambda x: max(x, key=len) if isinstance(x, list) else x
        elif agg == 'first':
            agg_func = lambda x: x[0] if isinstance(x, list) else x
        elif callable(agg):
            return agg
        else:
            raise Exception(
                "agg_sequences argument must be one of {'all', 'shortest', 'longest', 'first'} or a callable"
            )
        return agg_func


class GENCODE(SequenceDatabase):
    """Loads the GENCODE database from https://www.gencodegenes.org/ .

    Default path: ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_32/ .
    Default file_resources: {
        "basic.annotation.gtf.gz": "gencode.v32.basic.annotation.gtf.gz",
        "long_noncoding_RNAs.gtf.gz": "gencode.v32.long_noncoding_RNAs.gtf.gz",
        "lncRNA_transcripts.fa.gz": "gencode.v32.lncRNA_transcripts.fa.gz",
        "transcripts.fa.gz": "gencode.v32.transcripts.fa.gz",
    }
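
    Example (a minimal usage sketch; it assumes the decompressed "lncRNA_transcripts.fa"
    is made available among `file_resources` so that `get_sequences` can index it):
        >>> gencode = GENCODE(remove_version_num=True)  # doctest: +SKIP
        >>> lnc_seqs = gencode.get_sequences(index="gene_id", omic=openomics.LncRNA.name(), agg="all")  # doctest: +SKIP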
    """
    def __init__(
        self,
        path="ftp://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_32/",
        file_resources=None,
        col_rename=None,
        blocksize=0,
        remove_version_num=False,
        **kwargs
    ):
        """
        Args:
            path: URL or directory path of the GENCODE release to load.
            file_resources: dict of file name keys to file paths/URLs.
            col_rename: dict to rename columns of the annotation dataframe.
            blocksize: if nonzero, load the annotations as a dask DataFrame.
            remove_version_num (bool): Whether to drop the version number on the
                Ensembl ID.
        """
        if file_resources is None:
            file_resources = {
                "basic.annotation.gtf.gz": "gencode.v32.basic.annotation.gtf.gz",
                "long_noncoding_RNAs.gtf.gz": "gencode.v32.long_noncoding_RNAs.gtf.gz",
                "lncRNA_transcripts.fa.gz": "gencode.v32.lncRNA_transcripts.fa.gz",
                "transcripts.fa.gz": "gencode.v32.transcripts.fa.gz",
            }

        self.remove_version_num = remove_version_num

        super().__init__(path=path, file_resources=file_resources, col_rename=col_rename, blocksize=blocksize, **kwargs)

    def load_dataframe(self, file_resources, blocksize=None):
        """
        Loads the GENCODE annotation files into a pandas or dask DataFrame.

        Args:
            file_resources: dict of file name keys to file paths/URLs.
            blocksize: if truthy, load the .gtf files as dask DataFrames.

        Returns:
            annotation_df (pd.DataFrame or dd.DataFrame): the concatenated GTF annotations.
        """
        gtf_files = select_files_with_ext(file_resources, ".gtf")
        if not gtf_files:
            gtf_files = select_files_with_ext(file_resources, ".gtf.gz")

        ddfs = []
        for filename, filepath in gtf_files.items():
            if blocksize and not isinstance(filepath, str): continue
            ddf = read_gtf(filepath, blocksize=blocksize,
                           compression="gzip" if filename.endswith(".gz") else None)

            ddfs.append(ddf)

        annotation_df = dd.concat(ddfs, interleave_partitions=True,
                                  ignore_unknown_divisions=True) if blocksize else pd.concat(ddfs)

        if self.remove_version_num:
            annotation_df = annotation_df.assign(
                gene_id=annotation_df["gene_id"].str.replace(r"[.]\d*", "", regex=True),
                transcript_id=annotation_df["transcript_id"].str.replace(r"[.]\d*", "", regex=True))

        return annotation_df

    def load_sequences(self, fasta_file: str, index=None, keys: pd.Index = None, blocksize=None):
        """
        Args:
            fasta_file: path to an uncompressed GENCODE transcripts fasta file.
            index (): unused; kept for API compatibility with SequenceDatabase.
            keys (): optional set of transcript/gene ids to keep.
            blocksize: if truthy, return a dask DataFrame with this chunksize.
        """
        if hasattr(self, '_seq_df_dict') and fasta_file in self._seq_df_dict:
            return self._seq_df_dict[fasta_file]

        def get_transcript_id(x):
            key = x.split('|')[0]  # transcript_id
            if self.remove_version_num:
                return re.sub(r"[.]\d*", "", key)
            else:
                return key

        fa = Fasta(fasta_file, key_function=get_transcript_id, as_raw=True)

        entries = []
        for key, record in tqdm.tqdm(fa.items(), desc=str(fasta_file)):
            if keys is not None and key not in keys: continue

            attrs = record.long_name.split("|")
            record_dict = {
                "transcript_id": attrs[0],
                "gene_id": attrs[1],
                "gene_name": attrs[5],
                "transcript_name": attrs[4],
                "transcript_length": attrs[6],
                "transcript_biotype": intern(attrs[7]),
                SEQUENCE_COL: str(record),
            }

            entries.append(record_dict)

        seq_df = pd.DataFrame(entries)
        if blocksize:
            seq_df = dd.from_pandas(seq_df, chunksize=blocksize)

        if self.remove_version_num:
            seq_df["gene_id"] = seq_df["gene_id"].str.replace(r"[.]\d*", "", regex=True)
            seq_df["transcript_id"] = seq_df["transcript_id"].str.replace(r"[.]\d*", "", regex=True)

        # Cache the seq_df, but only when it was built from the full fasta file, so that a
        # subset filtered by `keys` is never returned for later calls with different keys.
        if not hasattr(self, '_seq_df_dict'):
            self._seq_df_dict = {}
        if keys is None:
            self._seq_df_dict[fasta_file] = seq_df

        return seq_df

    def get_sequences(self, index: Union[str, Tuple[str]], omic: str, agg: str = 'all', biotypes: List[str] = None):
        """
        Args:
            index (str): one of {"gene_id", "gene_name", "transcript_id", "transcript_name"}.
            omic (str): either the MessengerRNA or LncRNA omic name.
            agg (str): one of {"all", "shortest", "longest", "first"}.
            biotypes (List[str]): optional list of transcript biotypes to keep.
        """
        agg_func = self.aggregator_fn(agg)

        # Parse lncRNA & mRNA fasta
        if omic == openomics.MessengerRNA.name():
            fasta_file = self.file_resources["transcripts.fa"]
        elif omic == openomics.LncRNA.name():
            fasta_file = self.file_resources["lncRNA_transcripts.fa"]
        else:
            raise Exception("omic argument must be one of {'MessengerRNA', 'LncRNA'}")

        assert isinstance(fasta_file, str), \
            f"Fasta file provided in `file_resources` must be an uncompressed .fa file. Given {fasta_file}."

        seq_df = self.load_sequences(fasta_file)

        if "gene" in index:
            if biotypes:
                seq_df = seq_df[seq_df["transcript_biotype"].isin(biotypes)]
            else:
                logger.info("You can pass in a list of transcript biotypes to filter using the argument 'biotypes'.")

            return seq_df.groupby(index)[SEQUENCE_COL].agg(agg_func)

        else:
            return seq_df.groupby(index)[SEQUENCE_COL].first()

    def get_rename_dict(self, from_index="gene_id", to_index="gene_name"):
        """
        Args:
            from_index: column in `self.data` to map from.
            to_index: column in `self.data` to map to.
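
        Example (illustrative; `gencode` is a loaded GENCODE instance):
            >>> rename_dict = gencode.get_rename_dict(from_index="gene_id", to_index="gene_name")  # doctest: +SKIP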
        """
        ensembl_id_to_gene_name = pd.Series(
            self.data[to_index].values, index=self.data[from_index]).to_dict()
        return ensembl_id_to_gene_name


class UniProt(SequenceDatabase):
    COLUMNS_RENAME_DICT = {
        # idmapping_selected.tab
        "UniProtKB-AC": 'protein_id',
        "UniProtKB-ID": 'protein_name',
        "Ensembl": "gene_id",
        "Ensembl_TRS": "transcript_id",
        "Ensembl_PRO": "protein_embl_id",
        "NCBI-taxon": "species_id",
        "GeneID (EntrezGene)": "entrezgene_id",
        "GO": "go_id",
        # FASTA headers
        "OS": 'species', "OX": 'species_id', 'GN': 'gene_name', 'PE': 'ProteinExistence', 'SV': "version",
        # UniProt XML headers
        "accession": "UniProtKB-AC", "name": "protein_name", "gene": "gene_name", "keyword": "keywords",
        "geneLocation": "subcellular_location",
    }

    SPECIES_ID_NAME = {
        '10090': 'MOUSE', '10116': 'RAT', '226900': 'BACCR', '243273': 'MYCGE', '284812': 'SCHPO', '287': 'PSEAI',
        '3702': 'ARATH', '99287': 'SALTY', '44689': 'DICDI', '4577': 'MAIZE', '559292': 'YEAST', '6239': 'CAEEL',
        '7227': 'DROME', '7955': 'DANRE', '83333': 'ECOLI', '9606': 'HUMAN', '9823': 'PIG', }

    SPECIES_ID_TAXONOMIC = {
        'HUMAN': 'human', 'MOUSE': 'rodents', 'RAT': 'rodents', 'BACCR': 'bacteria', 'MYCGE': 'bacteria',
        'SCHPO': 'fungi', 'PSEAI': 'bacteria', 'ARATH': 'plants', 'SALTY': 'bacteria', 'DICDI': 'bacteria',
        'MAIZE': 'plants', 'YEAST': 'fungi', 'CAEEL': 'vertebrates', 'DROME': 'invertebrates', 'DANRE': 'vertebrates',
        'ECOLI': 'bacteria', 'PIG': 'mammals',
    }

    def __init__(self, path="https://ftp.uniprot.org/pub/databases/uniprot/current_release/",
                 file_resources: Dict[str, str] = None,
                 species_id: str = "9606", remove_version_num=True,
                 index_col='UniProtKB-AC', keys=None,
                 col_rename=COLUMNS_RENAME_DICT,
                 **kwargs):
        """
        Loads the UniProt database from https://uniprot.org/ .

        Default path: 'https://ftp.uniprot.org/pub/databases/uniprot/current_release/'
        Default file_resources: {
            file_resources['uniprot_sprot.xml.gz'] = "knowledgebase/complete/uniprot_sprot.xml.gz"
            file_resources['uniprot_trembl.xml.gz'] = "knowledgebase/complete/uniprot_trembl.xml.gz"
            file_resources["idmapping_selected.tab.gz"] = "knowledgebase/idmapping/idmapping_selected.tab.gz"
            file_resources["proteomes.tsv"] = "https://rest.uniprot.org/proteomes/stream?compressed=true&
                fields=upid%2Corganism%2Corganism_id&format=tsv&query=%28%2A%29%20AND%20%28proteome_type%3A1%29"
            file_resources['speclist.txt'] = 'https://ftp.uniprot.org/pub/databases/uniprot/current_release/
                knowledgebase/complete/docs/speclist'
        }

        Args:
            path: URL or directory path of the UniProt release to load.
            file_resources: dict of file name keys to file paths/URLs.
            species_id (str): NCBI taxonomy id used to select per-species resources, e.g. "9606".
            remove_version_num (bool): whether to strip the version suffix from Ensembl ids.
            index_col: column to index the dataframe by, default 'UniProtKB-AC'.
            keys: optional list of accession keys to filter rows by.
            col_rename: dict to rename columns of the loaded dataframe.
            blocksize: passed through **kwargs; if truthy, load the data as dask DataFrames.
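
        Example (a minimal sketch; assumes the default human idmapping and speclist
        resources can be downloaded or are already cached locally):
            >>> uniprot = UniProt(species_id="9606")  # doctest: +SKIP
            >>> uniprot.data.head()  # doctest: +SKIP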
        """
        self.species_id = species_id
        self.species = UniProt.SPECIES_ID_NAME[species_id] if species_id in UniProt.SPECIES_ID_NAME else None
        self.taxonomic_id = UniProt.SPECIES_ID_TAXONOMIC[
            self.species] if self.species in UniProt.SPECIES_ID_TAXONOMIC else None
        self.remove_version_num = remove_version_num

        if file_resources is None:
            file_resources = {}

            file_resources['uniprot_sprot.xml.gz'] = os.path.join(path, "knowledgebase/complete/uniprot_sprot.xml.gz")
            file_resources['uniprot_trembl.xml.gz'] = os.path.join(path, "knowledgebase/complete/uniprot_trembl.xml.gz")
            file_resources["idmapping_selected.tab.gz"] = os.path.join(
                path, "knowledgebase/idmapping/idmapping_selected.tab.gz")

            if self.species:
                file_resources['uniprot_sprot.xml.gz'] = os.path.join(
                    path, "knowledgebase/taxonomic_divisions/", f'uniprot_sprot_{self.taxonomic_id}.xml.gz')
                file_resources['uniprot_trembl.xml.gz'] = os.path.join(
                    path, "knowledgebase/taxonomic_divisions/", f'uniprot_trembl_{self.taxonomic_id}.xml.gz')
                file_resources["idmapping_selected.tab.gz"] = os.path.join(
                    path, "knowledgebase/idmapping/by_organism/",
                    f'{self.species}_{self.species_id}_idmapping_selected.tab.gz')

            file_resources["proteomes.tsv"] = \
                "https://rest.uniprot.org/proteomes/stream?compressed=true&fields=upid%2Corganism%2Corganism_id&format=tsv&query=%28%2A%29%20AND%20%28proteome_type%3A1%29"

        super().__init__(path=path, file_resources=file_resources, index_col=index_col, keys=keys,
                         col_rename=col_rename,
                         **kwargs)

    def load_dataframe(self, file_resources, blocksize=None):
        """
        Args:
            file_resources: dict of file name keys to file paths/URLs.
            blocksize: if truthy, load the id mapping as a dask DataFrame.
        """
        # Load idmapping_selected.tab
        args = dict(
            names=['UniProtKB-AC', 'UniProtKB-ID', 'GeneID (EntrezGene)', 'RefSeq', 'GI', 'PDB', 'GO', 'UniRef100',
                   'UniRef90', 'UniRef50', 'UniParc', 'PIR', 'NCBI-taxon', 'MIM', 'UniGene', 'PubMed', 'EMBL',
                   'EMBL-CDS', 'Ensembl', 'Ensembl_TRS', 'Ensembl_PRO', 'Additional PubMed'],
            usecols=['UniProtKB-AC', 'UniProtKB-ID', 'GeneID (EntrezGene)', 'RefSeq', 'GI', 'PDB', 'GO',
                     'NCBI-taxon', 'Ensembl', 'Ensembl_TRS', 'Ensembl_PRO'],
            dtype='str')

        if blocksize:
            if "idmapping_selected.parquet" in file_resources and \
                isinstance(file_resources["idmapping_selected.parquet"], str):
                idmapping = dd.read_parquet(file_resources["idmapping_selected.parquet"])

            elif "idmapping_selected.tab" in file_resources and \
                isinstance(file_resources["idmapping_selected.tab"], str):
                idmapping = dd.read_table(file_resources["idmapping_selected.tab"], blocksize=blocksize, **args)
            else:
                idmapping = dd.read_table(file_resources["idmapping_selected.tab.gz"], compression="gzip", **args)

            idmapping: dd.DataFrame
        else:
            if "idmapping_selected.parquet" in file_resources and \
                isinstance(file_resources["idmapping_selected.parquet"], str):
                idmapping = pd.read_parquet(file_resources["idmapping_selected.parquet"])
            else:
                idmapping = pd.read_table(file_resources["idmapping_selected.tab"], index_col=self.index_col, **args)

        # Filter UniProt accession keys
        if self.keys is not None and idmapping.index.name == self.index_col:
            idmapping = idmapping.loc[idmapping.index.isin(self.keys)]
        elif self.keys is not None and idmapping.index.name != self.index_col:
            idmapping = idmapping.loc[idmapping[self.index_col].isin(self.keys)]

        if isinstance(idmapping, dd.DataFrame):
            # `sorted=` and divisions only apply to dask DataFrames
            if idmapping.index.name != self.index_col:
                idmapping = idmapping.set_index(self.index_col, sorted=False)
            if not idmapping.known_divisions:
                idmapping.divisions = idmapping.compute_current_divisions()
        elif idmapping.index.name != self.index_col:
            idmapping = idmapping.set_index(self.index_col)

        # Transform list columns
        idmapping = idmapping.assign(**self.assign_transforms(idmapping))

        # Join metadata from uniprot_sprot.parquet
        if any(fn.startswith('uniprot') and fn.endswith('.parquet') for fn in file_resources):
            uniprot_anns = self.load_uniprot_parquet(file_resources, blocksize=blocksize)
            uniprot_anns = uniprot_anns[uniprot_anns.columns.difference(idmapping.columns)]
            uniprot_anns = drop_duplicate_columns(uniprot_anns)
            assert idmapping.index.name == uniprot_anns.index.name, f"{idmapping.index.name} != {uniprot_anns.index.name}"
            idmapping = idmapping.join(uniprot_anns, on=idmapping.index.name, how='left')

        # Load proteomes.tsv
        if "proteomes.tsv" in file_resources:
            proteomes = pd.read_table(file_resources["proteomes.tsv"],
                                      usecols=['Organism Id', 'Proteome Id'],
                                      dtype={'Organism Id': 'str', 'Proteome Id': 'str'}) \
                .rename(columns={'Organism Id': 'NCBI-taxon', 'Proteome Id': 'proteome_id'}) \
                .dropna().set_index('NCBI-taxon')
            idmapping = idmapping.join(proteomes, on='NCBI-taxon')

        # Load species info from speclist.txt
        if 'speclist.txt' in file_resources:
            speclist = self.get_species_list(file_resources['speclist.txt'])
            idmapping = idmapping.join(speclist, on='NCBI-taxon')

        return idmapping

    def assign_transforms(self, idmapping: pd.DataFrame) -> Dict[str, Union[dd.Series, pd.Series]]:
        # Convert a list of string elements to a np.array
        list2array = lambda x: np.array(x) if isinstance(x, Iterable) else x
        assign_fn = {}
        for col in {'PDB', 'GI', 'GO', 'RefSeq'}.intersection(idmapping.columns):
            try:
                # Split "; "-delimited strings into arrays of values
                assign_fn[col] = idmapping[col].str.split("; ").map(list2array)
            except Exception:
                continue

        for col in {'Ensembl', 'Ensembl_TRS', 'Ensembl_PRO'}.intersection(idmapping.columns):
            # Remove the ".#" Ensembl version number at the end of each id
            try:
                if self.remove_version_num:
                    series = idmapping[col].str.replace(r"[.]\d*", "", regex=True)
                else:
                    series = idmapping[col]

                assign_fn[col] = series.str.split("; ").map(list2array)

                if col == 'Ensembl_PRO':
                    # Prepend species_id to Ensembl protein ids (e.g. "9606.ENSP...") to match STRING PPI ids
                    concat = dd.concat([idmapping["NCBI-taxon"], assign_fn[col]], axis=1) \
                        if isinstance(idmapping, dd.DataFrame) else \
                        pd.concat([idmapping["NCBI-taxon"], assign_fn[col]], axis=1)

                    assign_fn['protein_external_id'] = concat.apply(
                        lambda row: np.char.add(row['NCBI-taxon'] + ".", row['Ensembl_PRO']) \
                            if isinstance(row['Ensembl_PRO'], Iterable) else None,
                        axis=1)
            except Exception:
                continue

        return assign_fn

    def load_uniprot_parquet(self, file_resources: Dict[str, str], blocksize=None) -> Union[dd.DataFrame, pd.DataFrame]:
        dfs = []
        for filename, file_path in file_resources.items():
            if not ('uniprot' in filename and filename.endswith('.parquet')): continue
            if blocksize:
                df: dd.DataFrame = dd.read_parquet(file_path) \
                    .rename(columns=UniProt.COLUMNS_RENAME_DICT)
                if df.index.name in UniProt.COLUMNS_RENAME_DICT:
                    df.index = df.index.rename(UniProt.COLUMNS_RENAME_DICT[df.index.name])

                if self.keys is not None:
                    if self.index_col in df.columns:
                        df = df.loc[df[self.index_col].isin(self.keys)]
                    elif df.index.name == self.index_col:
                        df = df.loc[df.index.isin(self.keys)]

                if df.index.size.compute() == 0: continue

                if df.index.name != self.index_col:
                    try:
                        df = df.set_index(self.index_col, sorted=True)
                    except Exception as e:
                        print(file_path, e)
                        df = df.set_index(self.index_col, sorted=False)

                if not df.known_divisions:
                    df.divisions = df.compute_current_divisions()

            else:
                df = pd.read_parquet(file_path).rename(columns=UniProt.COLUMNS_RENAME_DICT).set_index(self.index_col)

                if self.keys is not None:
                    df_keys = df.index if df.index.name == self.index_col else df[self.index_col]
                    df = df.loc[df_keys.isin(self.keys)]
                if df.index.size == 0: continue

            dfs.append(df)

        if dfs:
            dfs = dd.concat(dfs, interleave_partitions=True) if blocksize else pd.concat(dfs)

            return dfs
        else:
            return None

    def load_uniprot_xml(self, file_path: str, keys=None, blocksize=None) -> pd.DataFrame:
        records = []
        seqfeats = []
        if isinstance(keys, str):
            index = keys
            keys_set = self.data.index if keys == self.data.index.name else self.data[keys]
        elif isinstance(keys, (dd.Index, dd.Series)):
            index = keys.name
            keys_set = keys.compute()
        else:
            index = keys_set = None

        for record in tqdm.tqdm(SeqIO.parse(file_path, "uniprot-xml"), desc=str(file_path)):
            # Sequence features
            annotations = defaultdict(lambda: None, record.annotations)
            record_dict = {
                'protein_id': record.id,
                "protein_name": record.name,
                'gene_name': annotations['gene_name_primary'],
                'description': record.description,
                'molecule_type': annotations['molecule_type'],
                'created': annotations['created'],
                'ec_id': annotations['type'],
                'subcellular_location': annotations['comment_subcellularlocation_location'],
                'taxonomy': annotations['taxonomy'],
                'keywords': annotations['keywords'],
                'sequence_mass': annotations['sequence_mass'],
                SEQUENCE_COL: str(record.seq),
            }
            if index is not None:
                if record_dict[index] not in keys_set: continue

            records.append(record_dict)

            # Sequence interval features
            _parse_interval = lambda sf: pd.Interval(left=sf.location.start, right=sf.location.end, )
            feature_type_intervals = defaultdict(lambda: [])
            for sf in record.features:
                if isinstance(sf.location.start, ExactPosition) and isinstance(sf.location.end, ExactPosition):
                    feature_type_intervals[sf.type].append(_parse_interval(sf))

            features_dict = {type: pd.IntervalIndex(intervals, name=type) \
                             for type, intervals in feature_type_intervals.items()}
            seqfeats.append({"protein_id": record.id, **features_dict})

        records_df = pd.DataFrame(records) if not blocksize else dd.from_pandas(pd.DataFrame(records), chunksize=blocksize)
        records_df = records_df.set_index(['protein_id'])

        seqfeats_df = pd.DataFrame(seqfeats) if not blocksize else dd.from_pandas(pd.DataFrame(seqfeats), chunksize=blocksize)
        seqfeats_df = seqfeats_df.set_index(['protein_id'])
        seqfeats_df.columns = [f"seq/{col}" for col in seqfeats_df.columns]

        # Join new metadata to self.data
        if SEQUENCE_COL not in self.data.columns:
            exclude_cols = records_df.columns.intersection(self.data.columns)
            self.data = self.data.join(records_df.drop(columns=exclude_cols, errors="ignore"),
                                       on='protein_id', how="left")
        else:
            self.data.update(records_df, overwrite=False)

        # Add new seq features
        if len(seqfeats_df.columns.difference(self.data.columns)):
            self.data = self.data.join(seqfeats_df.drop(columns=seqfeats_df.columns.intersection(self.data.columns)),
                                       on='protein_id', how="left")
        # Fillna seq features
        if len(seqfeats_df.columns.intersection(self.data.columns)):
            self.data.update(seqfeats_df.filter(seqfeats_df.columns.intersection(self.data.columns), axis='columns'),
                             overwrite=False)

        return records_df

    @classmethod
    def get_species_list(cls, file_path):
        """Parses UniProt's `speclist.txt` controlled vocabulary into a DataFrame indexed by NCBI taxon id."""
        speclist = pd.read_fwf(file_path,
                               names=['species_code', 'Taxon', 'species_id', 'attr'],
                               comment="==", skipinitialspace=True, skiprows=59, skipfooter=4)
        speclist = speclist.drop(index=speclist.index[~speclist['attr'].str.contains("=")])
        speclist['species_id'] = speclist['species_id'].str.rstrip(":")
        speclist = speclist.fillna(method='ffill')
        speclist = speclist.groupby(speclist.columns[:3].tolist())['attr'] \
            .apply('|'.join) \
            .apply(lambda s: dict(map(str.strip, sub.split('=', 1)) for sub in s.split("|") if '=' in sub)) \
            .apply(pd.Series)
        speclist = speclist.rename(columns={'N': 'Official (scientific) name', 'C': 'Common name', 'S': 'Synonym'}) \
            .reset_index() \
            .set_index('species_id')
        speclist['Taxon'] = speclist['Taxon'].replace(
            {'A': 'archaea', 'B': 'bacteria', 'E': 'eukaryota', 'V': 'viruses', 'O': 'others'})
        speclist.index.name = 'NCBI-taxon'
        return speclist

    def load_sequences(self, fasta_file: str, index=None, keys: Union[pd.Index, List[str]] = None, blocksize=None) \
        -> OrderedDict:
        """Returns a dict-like mapping of fasta keys (chosen by `index`) to raw sequences."""
        def get_id(s: str):
            # UniProt fasta headers look like "db|UniqueIdentifier|EntryName ..."
            if index == 'protein_id':
                return s.split('|')[1]
            elif index == 'protein_name':
                return s.split('|')[2]
            else:
                return s.split('|')[1]

        fa = Fasta(fasta_file, key_function=get_id, as_raw=True, )

        return fa.records

    def get_sequences(self, index: str, omic: str = None, agg: str = "first", **kwargs):
        assert index, '`index` must be either "protein_id" or "protein_name"'

        # Parse the SwissProt (and optionally TrEMBL) protein fasta files
        seq_df = self.load_sequences(self.file_resources["uniprot_sprot.fasta"], index=index, blocksize=self.blocksize)
        if "uniprot_trembl.fasta" in self.file_resources:
            trembl_seq_df = self.load_sequences(self.file_resources["uniprot_trembl.fasta"], index=index,
                                                blocksize=self.blocksize)
            seq_df.update(trembl_seq_df)

        return seq_df


class MirBase(SequenceDatabase):
    """Loads the miRBase database from https://mirbase.org .

    Default path: "http://mirbase.org/ftp/CURRENT/" .
    Default file_resources: {
        "aliases.txt.gz": "aliases.txt.gz",
        "mature.fa.gz": "mature.fa.gz",
        "hairpin.fa.gz": "hairpin.fa.gz",
        "rnacentral.mirbase.tsv": "ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/id_mapping/database_mappings/mirbase.tsv",
    }
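
    Example (an illustrative sketch; assumes the decompressed "aliases.txt", "hairpin.fa" and
    "mature.fa" files end up available among `file_resources` so sequences can be parsed):
        >>> mirbase = MirBase(species_id='9606')  # doctest: +SKIP
        >>> mir_seqs = mirbase.get_sequences(index="gene_name", agg="shortest")  # doctest: +SKIP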
    """

    def __init__(
        self,
        path="http://mirbase.org/ftp/CURRENT/",
        file_resources=None,
        species_id: Optional[str] = '9606',
        index_col: str = "mirbase_id",
        col_rename=None,
        **kwargs,
    ):
        """
        Args:
            path: URL or directory path of the miRBase release to load.
            file_resources: dict of file name keys to file paths/URLs.
            species_id (str): NCBI species code to filter for, e.g., '9606' for human.
            index_col (str): column to index the dataframe by, default "mirbase_id".
            col_rename: dict to rename columns of the loaded dataframe.
        """
        if file_resources is None:
            file_resources = {}
            file_resources["aliases.txt.gz"] = "aliases.txt.gz"
            file_resources["mature.fa.gz"] = "mature.fa.gz"
            file_resources["hairpin.fa.gz"] = "hairpin.fa.gz"

        if 'rnacentral.mirbase.tsv' not in file_resources:
            file_resources["rnacentral.mirbase.tsv"] = "ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/" \
                                                       "id_mapping/database_mappings/mirbase.tsv"

        self.species_id = species_id
        super().__init__(path=path, file_resources=file_resources, index_col=index_col, col_rename=col_rename, **kwargs)

    def load_dataframe(self, file_resources, blocksize=None):
        """
        Args:
            file_resources: dict of file name and path
            blocksize: unused here; kept for API compatibility.
        """
        rnacentral_mirbase = pd.read_table(
            file_resources["rnacentral.mirbase.tsv"], low_memory=True, header=None,
            names=["RNAcentral id", "database", "mirbase_id", "species_id", "RNA type", "NA"],
            usecols=["RNAcentral id", "database", "mirbase_id", "species_id", "RNA type"],
            index_col="RNAcentral id",
            dtype={'mirbase_id': 'str', "species_id": "category", 'database': 'category', 'RNA type': 'category'})

        if isinstance(self.species_id, str):
            rnacentral_mirbase = rnacentral_mirbase[rnacentral_mirbase["species_id"] == self.species_id]
        elif isinstance(self.species_id, Iterable):
            rnacentral_mirbase = rnacentral_mirbase[rnacentral_mirbase["species_id"].isin(set(self.species_id))]

        mirbase_df = pd.read_table(file_resources["aliases.txt"], low_memory=True, header=None,
                                   names=["mirbase_id", "mirbase_name"], index_col=self.index_col,
                                   dtype='str', )
        if mirbase_df.index.name == 'mirbase id':
            mirbase_df = mirbase_df.join(rnacentral_mirbase, on=self.index_col, how="left", rsuffix='_rnacentral')
        else:
            mirbase_df = mirbase_df.merge(rnacentral_mirbase, on=self.index_col, how="left")

        # Expand the ";"-delimited miRNA names under each miRBase accession ID
        mirbase_df['mirbase_name'] = mirbase_df['mirbase_name'].str.rstrip(";").str.split(";")

        seq_dfs = []
        for filename in file_resources:
            if filename.endswith('.fa') or filename.endswith('.fasta'):
                assert isinstance(file_resources[filename], str), f"Must provide a path to an uncompressed .fa file. " \
                                                                  f"Given {file_resources[filename]}"
                df = self.load_sequences(file_resources[filename], index=self.index_col, keys=self.keys)
                seq_dfs.append(df)

        if len(seq_dfs):
            seq_dfs = pd.concat(seq_dfs, axis=0)
            mirbase_df = mirbase_df.join(seq_dfs, how='left', on=self.index_col)
        else:
            logger.info('Missing sequence data because no "hairpin.fa" or "mature.fa" files were given.')

        # mirbase_df = mirbase_df.explode(column='gene_name')
        # mirbase_name["miRNA name"] = mirbase_name["miRNA name"].str.lower()
        # mirbase_name["miRNA name"] = mirbase_name["miRNA name"].str.replace("-3p.*|-5p.*", "")

        return mirbase_df

    def load_sequences(self, fasta_file, index=None, keys=None, blocksize=None):
        """
        Args:
            fasta_file: path to an uncompressed miRBase fasta file (e.g. "hairpin.fa" or "mature.fa").
            index (): optional column name to set as the index of the returned DataFrame.
            keys (): optional set of `index` values to keep.
            blocksize:
        """
        if hasattr(self, '_seq_df_dict') and fasta_file in self._seq_df_dict:
            return self._seq_df_dict[fasta_file]

        fa = Fasta(fasta_file, read_long_names=True, as_raw=True)
        mirna_types = {'stem-loop', 'stem', 'type', 'loop'}

        entries = []
        for key, record in tqdm.tqdm(fa.items(), desc=str(fasta_file)):
            attrs: List[str] = record.long_name.split(" ")

            # The last token(s) of a miRBase header may denote the miRNA type, e.g. "stem-loop"
            if attrs[-1] in mirna_types:
                if attrs[-2] in mirna_types:
                    mirna_type = intern(' '.join(attrs[-2:]))
                    gene_name_idx = -3
                else:
                    mirna_type = intern(attrs[-1])
                    gene_name_idx = -2
            else:
                mirna_type = None
                gene_name_idx = -1

            record_dict = {
                "gene_id": attrs[0],
                "mirbase_id": attrs[1],
                "species": intern(" ".join(attrs[2:gene_name_idx])),
                "gene_name": attrs[gene_name_idx],
                "mirna_type": mirna_type,
                SEQUENCE_COL: str(record),
            }
            if keys is not None and index:
                if record_dict[index] not in keys:
                    del record_dict
                    continue

            entries.append(record_dict)

        df = pd.DataFrame(entries)
        if index:
            df = df.set_index(index)
        # if blocksize:
        #     df = dd.from_pandas(df, chunksize=blocksize)

        # Cache the parsed fasta entries
        if not hasattr(self, '_seq_df_dict'):
            self._seq_df_dict = {}
        self._seq_df_dict[fasta_file] = df

        return df

    def get_sequences(self,
                      index="gene_name",
                      omic=None,
                      agg="all",
                      **kwargs):
        """
        Args:
            index: column to group the sequences by, e.g. "gene_name" or "mirbase_id".
            omic: unused for miRBase.
            agg: one of {"all", "shortest", "longest", "first"}.
            **kwargs:
        """
        dfs = []
        for filename in self.file_resources:
            if filename.endswith('.fa'):
                seq_df = self.load_sequences(self.file_resources[filename])
                dfs.append(seq_df)
        seq_df = pd.concat(dfs, axis=0)

        seq_df = seq_df.groupby(index)[SEQUENCE_COL].agg(self.aggregator_fn(agg))

        return seq_df


class RNAcentral(SequenceDatabase):
    """
    Loads the RNAcentral database from https://rnacentral.org/ and provides a series of methods to extract
    sequence data from it.

    Default path: https://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/ .
    Default file_resources: {
        "rnacentral_rfam_annotations.tsv.gz": "go_annotations/rnacentral_rfam_annotations.tsv.gz",
        "database_mappings/ensembl_gencode.tsv": "id_mapping/database_mappings/ensembl_gencode.tsv",
        "database_mappings/mirbase.tsv": "id_mapping/database_mappings/mirbase.tsv",
        "gencode.fasta": "sequences/by-database/gencode.fasta",
        ...
    }
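
    Example (a minimal sketch; only the id-mapping defaults above are loaded by default, and the
    optional matching "<database>.fasta" files are what enable the 'sequence' column):
        >>> rnacentral = RNAcentral(species_id='9606')  # doctest: +SKIP
        >>> rnacentral.data.head()  # doctest: +SKIP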
    """
    COLUMNS_RENAME_DICT = {
        'ensembl_gene_id': 'gene_id',
        'external id': 'transcript_id',
        'GO terms': 'go_id',
        'gene symbol': 'gene_id',
    }

    def __init__(self, path="https://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/", file_resources=None,
                 col_rename=COLUMNS_RENAME_DICT, species_id: Union[List[str], str, None] = None,
                 index_col="RNAcentral id", keys=None,
                 remove_version_num=True, remove_species_suffix=True, **kwargs):
        """
        Loads RNAcentral id mappings (and optionally sequences) for one or more member databases.

        Args:
            path (): URL or directory path of the RNAcentral release to load.
            file_resources (): dict of file name keys to file paths/URLs.
            col_rename (): dict to rename columns of the loaded dataframe.
            species_id (): NCBI species id(s) to filter for, e.g. '9606'; None keeps all species.
            index_col (): column to index the dataframe by, default "RNAcentral id".
            keys (): optional list of index values to filter rows by.
            remove_version_num (): whether to strip the ".#" version suffix from gene symbols.
            remove_species_suffix (): whether to strip the "_<species_id>" suffix from RNAcentral ids.
            **kwargs (): additional arguments passed to the parent `Database` constructor.
        """
        self.species_id = species_id
        self.remove_version_num = remove_version_num
        self.remove_species_suffix = remove_species_suffix

        if file_resources is None:
            file_resources = {}
            file_resources["rnacentral_rfam_annotations.tsv.gz"] = "go_annotations/rnacentral_rfam_annotations.tsv.gz"
            file_resources["database_mappings/ensembl_gencode.tsv"] = "id_mapping/database_mappings/ensembl_gencode.tsv"
            file_resources["database_mappings/mirbase.tsv"] = "id_mapping/database_mappings/mirbase.tsv"

        super().__init__(path=path, file_resources=file_resources, col_rename=col_rename, index_col=index_col,
                         keys=keys, **kwargs)

    def load_dataframe(self, file_resources, blocksize=None):
        """
        Args:
            file_resources: dict of file name keys to file paths/URLs.
            blocksize: if truthy, load the id mappings as dask DataFrames.
        """
        transcripts_df = []
        # Concatenate transcript ids by combining `database_mappings/` files from multiple RNAcentral databases
        for filename in (fname for fname in file_resources if "database_mappings" in fname):
            args = dict(low_memory=True, header=None,
                        names=["RNAcentral id", "database", "external id", "species_id", "RNA type", "gene symbol"],
                        dtype={'gene symbol': 'str',
                               'database': 'category', 'species_id': 'category', 'RNA type': 'category', })

            if blocksize:
                if filename.endswith('.tsv'):
                    id_mapping: dd.DataFrame = dd.read_table(
                        file_resources[filename], blocksize=None if isinstance(blocksize, bool) else blocksize, **args)
                elif filename.endswith('.parquet'):
                    id_mapping: dd.DataFrame = dd.read_parquet(
                        file_resources[filename], blocksize=None if isinstance(blocksize, bool) else blocksize, )
                else:
                    id_mapping = None
            else:
                if filename.endswith('.tsv'):
                    id_mapping = pd.read_table(file_resources[filename], **args)
                elif filename.endswith('.parquet'):
                    id_mapping = pd.read_parquet(file_resources[filename])
                else:
                    id_mapping = None

            if id_mapping is None:
                raise Exception("Must provide a file with 'database_mappings/(*).tsv' in file_resources")

            # Filter by species
            if isinstance(self.species_id, str):
                id_mapping = id_mapping.where(id_mapping["species_id"] == self.species_id)
            elif isinstance(self.species_id, Iterable):
                id_mapping = id_mapping.where(id_mapping["species_id"].isin(self.species_id))

            # Filter by index
            if self.keys is not None and id_mapping.index.name == self.index_col:
                id_mapping = id_mapping.loc[id_mapping.index.isin(self.keys)]
            elif self.keys is not None and id_mapping.index.name != self.index_col:
                id_mapping = id_mapping.loc[id_mapping[self.index_col].isin(self.keys)]

            # Add the species_id suffix to index values to match the sequence ids
            if "RNAcentral id" in id_mapping.columns:
                id_mapping["RNAcentral id"] = id_mapping["RNAcentral id"] + "_" + id_mapping["species_id"].astype(str)
            elif "RNAcentral id" == id_mapping.index.name:
                id_mapping.index = id_mapping.index + "_" + id_mapping["species_id"].astype(str)

            # Add sequence column if a FASTA file is provided for the database
            fasta_filename = f"{filename.split('/')[-1].split('.')[0]}.fasta"
            if fasta_filename in file_resources:
                seq_df = self.load_sequences(file_resources[fasta_filename])
                # pandas does not allow passing both `left_on` and `left_index=True`, so pick one
                merge_on = dict(left_index=True) if id_mapping.index.name == "RNAcentral id" else dict(left_on="RNAcentral id")
                id_mapping = id_mapping.merge(seq_df, how='left', right_index=True, **merge_on)
            else:
                logger.info(f"{fasta_filename} not provided for `{filename}`, so sequence data will be missing")

            if self.remove_version_num and 'gene symbol' in id_mapping.columns:
                id_mapping["gene symbol"] = id_mapping["gene symbol"].str.replace(r"[.]\d*", "", regex=True)
            if self.remove_species_suffix:
                id_mapping["RNAcentral id"] = id_mapping["RNAcentral id"].str.replace(r"_(\d*)", '', regex=True)

            # Set index
            set_index_args = dict(sorted=True) if blocksize else {}
            id_mapping = id_mapping.set_index(self.index_col, **set_index_args)
            if isinstance(id_mapping, dd.DataFrame) and not id_mapping.known_divisions:
                id_mapping.divisions = id_mapping.compute_current_divisions()

            transcripts_df.append(id_mapping)

        # Concatenate multiple `database_mappings` files from different databases
        if blocksize:
            transcripts_df = dd.concat(transcripts_df, axis=0, interleave_partitions=True, join='outer')
        else:
            transcripts_df = pd.concat(transcripts_df, axis=0, join='outer')

        # Join go_id and Rfams annotations to each "RNAcentral id" from 'rnacentral_rfam_annotations.tsv'
        try:
            transcripts_df = self.add_rfam_annotation(transcripts_df, file_resources, blocksize)
        except Exception as e:
            logger.warning(f"Failed to add Rfam annotations to transcripts_df: {e}")
            traceback.print_exc()

        return transcripts_df

    def add_rfam_annotation(self, transcripts_df: Union[pd.DataFrame, dd.DataFrame],
                            file_resources, blocksize=None) -> Union[pd.DataFrame, dd.DataFrame]:
        """Joins GO term and Rfam annotations from 'rnacentral_rfam_annotations.tsv' onto `transcripts_df`."""
        args = dict(low_memory=True, names=["RNAcentral id", "GO terms", "Rfams"])

        if blocksize:
            if 'rnacentral_rfam_annotations.tsv' in file_resources and isinstance(
                file_resources['rnacentral_rfam_annotations.tsv'], str):
                anns = dd.read_table(file_resources["rnacentral_rfam_annotations.tsv"], **args)
            else:
                anns = dd.read_table(file_resources["rnacentral_rfam_annotations.tsv.gz"], compression="gzip", **args)
            anns = anns.set_index("RNAcentral id", sorted=True)

            # Filter annotations by "RNAcentral id" in `transcripts_df`
            anns = anns.loc[anns.index.isin(transcripts_df.index.compute())]

            if not anns.known_divisions:
                anns.divisions = anns.compute_current_divisions()

            # Groupby on index
            anns_groupby: dd.DataFrame = anns \
                .groupby(by=lambda idx: idx) \
                .agg({col: get_agg_func('unique', use_dask=True) for col in ["GO terms", 'Rfams']})

        else:
            anns = pd.read_table(file_resources["rnacentral_rfam_annotations.tsv"], index_col='RNAcentral id', **args)
            idx = transcripts_df.index.compute() if isinstance(transcripts_df, dd.DataFrame) else transcripts_df.index
            anns = anns.loc[anns.index.isin(set(idx))]
            anns_groupby = anns.groupby("RNAcentral id").agg({col: 'unique' for col in ["GO terms", 'Rfams']})

        transcripts_df = transcripts_df.merge(anns_groupby, how='left', left_index=True, right_index=True)
        return transcripts_df

    def load_sequences(self, fasta_file: str, index=None, keys=None, blocksize=None):
        """
        Args:
            fasta_file: path to an uncompressed RNAcentral fasta file.
            index (): unused; the returned DataFrame is always indexed by "RNAcentral id".
            keys (): optional set of RNAcentral ids to keep.
            blocksize:
        """
        fa = Fasta(fasta_file, as_raw=True)

        entries = []
        for key, record in tqdm.tqdm(fa.items(), desc=str(fasta_file)):
            rnacentral_id = re.sub(r"_(\d*)", '', key) if self.remove_species_suffix else key
            if keys is not None and self.index_col == 'RNAcentral id' and rnacentral_id not in keys:
                continue
            desc = record.long_name.split(" ", maxsplit=1)[-1]

            record_dict = {
                'RNAcentral id': key,
                'description': desc,
                SEQUENCE_COL: str(record),
            }

            entries.append(record_dict)

        df = pd.DataFrame(entries).set_index("RNAcentral id")

        return df

    def get_sequences(self,
                      index="RNAcentral id",
                      omic=None,
                      agg="all",
                      **kwargs):
        """
        Args:
            index: column to group the sequences by, default "RNAcentral id".
            omic: unused for RNAcentral.
            agg: one of {"all", "shortest", "longest", "first"}.
            **kwargs:
        """
        dfs = []
        for filename in self.file_resources:
            if filename.endswith('.fa') or filename.endswith('.fasta'):
                seq_df = self.load_sequences(self.file_resources[filename])
                dfs.append(seq_df)
        seq_df = pd.concat(dfs, axis=0)

        seq_df = seq_df.groupby(index)[SEQUENCE_COL].agg(self.aggregator_fn(agg))

        return seq_df