Switch to unified view

a b/singlecellmultiomics/molecule/featureannotatedmolecule.py
1
from singlecellmultiomics.molecule.molecule import Molecule
2
import collections
3
import pandas as pd
4
5
class TranscriptMolecule(Molecule):
6
7
    def __init__(self, fragment,
8
             **kwargs):
9
        self.genes=set()
10
        Molecule.__init__(self, fragment, **kwargs)
11
12
13
    def _add_fragment(self, fragment):
14
15
        self.genes.update(fragment.genes)
16
        Molecule._add_fragment(self, fragment)
17
18
    def write_tags(self):
19
20
        for frag in self:
21
            frag.write_tags()
22
23
        Molecule.write_tags(self)
24
25
26
27
class FeatureAnnotatedMolecule(Molecule):
28
    """Molecule which is annotated with features (genes/exons/introns, .. )
29
    """
30
31
    def __init__(
32
            self,
33
            fragment,
34
            features,
35
            stranded=None,
36
            auto_set_intron_exon_features=False,
37
            capture_locations=False,
38
            **kwargs):
39
        """
40
            Args:
41
                fragments (singlecellmultiomics.fragment.Fragment): Fragments to associate to the molecule
42
                features (singlecellmultiomics.features.FeatureContainer) : container to use to obtain features from
43
                stranded : None; not stranded, False: same strand as R1, True: other strand
44
                capture_locations (bool) : Store information about the locations of the aligned features
45
                auto_set_intron_exon_features(bool) : obtain intron_exon_features upon initialising
46
                **kwargs: extra args
47
48
        """
49
        Molecule.__init__(self, fragment, **kwargs)
50
        self.features = features
51
        self.hits = collections.defaultdict(set)  # feature -> hit_bases
52
        self.stranded = stranded
53
        self.is_annotated = False
54
        self.capture_locations = capture_locations
55
        if capture_locations:
56
            self.feature_locations = {} #feature->locations (chrom,start,end, strand)
57
58
        self.junctions = set()
59
        self.genes = set()
60
        self.introns = set()
61
        self.exons = set()
62
        self.exon_hit_gene_names = set()  # readable names
63
        self.is_spliced = None
64
65
        if auto_set_intron_exon_features:
66
            self.set_intron_exon_features()
67
68
    def set_spliced(self, is_spliced):
69
        """ Set wether the transcript is spliced, False has priority over True """
70
        if self.is_spliced and not is_spliced:
71
            # has already been set
72
            self.is_spliced = False
73
        else:
74
            self.is_spliced = is_spliced
75
76
77
78
    def set_intron_exon_features(self):
79
        if not self.is_annotated:
80
            self.annotate()
81
82
        # Collect all hits:
83
        hits = self.hits.keys()
84
85
        # (gene, transcript) -> set( exon_id  .. )
86
        exon_hits = collections.defaultdict(set)
87
        intron_hits = collections.defaultdict(set)
88
89
        for hit, locations in self.hits.items():
90
            if not isinstance(hit, tuple):
91
                continue
92
93
            meta = dict(list(hit))
94
            if 'gene_id' not in meta:
95
                continue
96
            if meta.get('type') == 'exon':
97
                if 'transcript_id' not in meta:
98
                    continue
99
                self.genes.add(meta['gene_id'])
100
                self.exons.add(meta['exon_id'])
101
                if 'transcript_id' not in meta:
102
                    raise ValueError(
103
                        "Please use an Intron GTF file generated using -id 'transcript_id'")
104
                exon_hits[(meta['gene_id'], meta['transcript_id'])].add(
105
                    meta['exon_id'])
106
                if 'gene_name' in meta:
107
                    self.exon_hit_gene_names.add(meta['gene_name'])
108
            elif meta.get('type') == 'intron':
109
                self.genes.add(meta['gene_id'])
110
                self.introns.add(meta['gene_id'])
111
112
        # Find junctions and add all annotations to annotation sets
113
        debug = []
114
115
        for (gene, transcript), exons_overlapping in exon_hits.items():
116
            # If two exons are detected from the same gene we detected a
117
            # junction:
118
            if len(exons_overlapping) > 1:
119
                self.junctions.add(gene)
120
121
                # We found two exons and an intron:
122
                if gene in self.introns:
123
                    self.set_spliced(False)
124
                else:
125
                    self.set_spliced(True)
126
127
            debug.append(
128
                f'{gene}_{transcript}:' +
129
                ','.join(
130
                    list(exons_overlapping)))
131
132
        # Write exon dictionary:
133
        self.set_meta('DB', ';'.join(debug))
134
135
    def get_hit_df(self):
136
        """Obtain dataframe with hits
137
        Returns:
138
            pd.DataFrame
139
        """
140
        if not self.is_annotated:
141
            self.set_intron_exon_features()
142
143
        d = {}
144
        tabulated_hits = []
145
        for hit, locations in self.hits.items():
146
            if not isinstance(hit, tuple):
147
                continue
148
            meta = dict(list(hit))
149
            for location in locations:
150
                location_dict = {
151
                    'chromosome': location[0],
152
                    'start': location[1][0],
153
                    'end': location[1][1]}
154
                location_dict.update(meta)
155
                tabulated_hits.append(location_dict)
156
157
        return pd.DataFrame(tabulated_hits)
158
159
160
    def write_tags_to_psuedoreads(self, reads, call_super=True):
161
        # @ todo needs refactor; the psuedoread should just be a Fragment too, solves all issues
162
        if call_super:
163
            Molecule.write_tags_to_psuedoreads(self, reads)
164
165
        for read in reads:
166
            if len(self.exons) > 0:
167
                read.set_tag('EX', ','.join(sorted([str(x) for x in self.exons])))
168
            else:
169
                read.set_tag('EX', None)
170
171
            if len(self.introns) > 0:
172
                read.set_tag('IN', ','.join(
173
                    sorted([str(x) for x in self.introns])))
174
            else:
175
                read.set_tag('IN', None)
176
177
            if len(self.genes) > 0:
178
                read.set_tag('GN', ','.join(sorted([str(x) for x in self.genes])))
179
            else:
180
                read.set_tag('GN', None)
181
182
            if len(self.junctions) > 0:
183
                read.set_tag('JN', ','.join(
184
                    sorted([str(x) for x in self.junctions])))
185
                # Is transcriptome
186
                read.set_tag('IT', 1)
187
            elif len(self.genes) > 0:
188
                # Maps to gene but not junction
189
                read.set_tag('IT', 0.5)
190
                read.set_tag('JN', None)
191
            else:
192
                # Doesn't map to gene
193
                read.set_tag('IT', 0)
194
                read.set_tag('JN', None)
195
196
            if self.is_spliced is True:
197
                read.set_tag('SP', True)
198
            elif self.is_spliced is False:
199
                read.set_tag('SP', False)
200
            if len(self.exon_hit_gene_names):
201
                read.set_tag('gn', ';'.join(list(self.exon_hit_gene_names)))
202
            else:
203
                read.set_tag('gn', None)
204
205
    def write_tags(self):
206
        Molecule.write_tags(self)
207
208
        # Write cell ranger tags:
209
        if self.umi is not None:
210
            self.set_meta('UB', self.umi)
211
        bc = list(self.get_barcode_sequences())[0]
212
        self.set_meta('CB', bc)
213
214
        if len(self.exons) > 0:
215
            self.set_meta('EX', ','.join(sorted([str(x) for x in self.exons])))
216
        else:
217
            self.set_meta('EX',None)
218
219
        if len(self.introns) > 0:
220
            self.set_meta('IN', ','.join(
221
                sorted([str(x) for x in self.introns])))
222
        else:
223
            self.set_meta('IN',None)
224
225
        if len(self.genes) > 0:
226
            self.set_meta('GN', ','.join(sorted([str(x) for x in self.genes])))
227
        else:
228
            self.set_meta('GN',None)
229
230
        if len(self.junctions) > 0:
231
            self.set_meta('JN', ','.join(
232
                sorted([str(x) for x in self.junctions])))
233
            # Is transcriptome
234
            self.set_meta('IT', 1)
235
        elif len(self.genes) > 0:
236
            # Maps to gene but not junction
237
            self.set_meta('IT', 0.5)
238
            self.set_meta('JN',None)
239
        else:
240
            # Doesn't map to gene
241
            self.set_meta('IT', 0)
242
            self.set_meta('JN', None)
243
244
        if self.is_spliced is True:
245
            self.set_meta('SP', True)
246
        elif self.is_spliced is False:
247
            self.set_meta('SP', False)
248
        if len(self.exon_hit_gene_names):
249
            self.set_meta('gn', ';'.join(list(self.exon_hit_gene_names)))
250
        else:
251
            self.set_meta('gn',None)
252
253
    def annotate(self, method=0):
254
        """
255
            Args:
256
                method (int) : 0, obtain blocks and then obtain features. 1, try to obtain features for every aligned base
257
258
        """
259
        # When self.stranded is None, set to None strand. If self.stranded is
260
        # True reverse the strand, otherwise copy the strand
261
        strand = None if self.stranded is None else '+-'[
262
            (not self.strand if self.stranded else self.strand)]
263
        self.is_annotated = True
264
        if method == 0:
265
266
            # Obtain all blocks:
267
            try:
268
                for start, end in self.get_aligned_blocks():
269
                    for hit in self.features.findFeaturesBetween(
270
                            chromosome=self.chromosome, sampleStart=start, sampleEnd=end, strand=strand):
271
                        hit_start, hit_end, hit_id, hit_strand, hit_ids = hit
272
                        self.hits[hit_ids].add(
273
                            (self.chromosome, (hit_start, hit_end)))
274
275
                        if self.capture_locations:
276
                            if not hit_id in self.feature_locations:
277
                                self.feature_locations[hit_id] = []
278
                            self.feature_locations[hit_id].append( (hit_start, hit_end, hit_strand))
279
280
            except TypeError:
281
                # This happens when no reads map
282
                pass
283
        else:
284
285
            for read in self.iter_reads():
286
                for q_pos, ref_pos in read.get_aligned_pairs(
287
                        matches_only=True, with_seq=False):
288
                    for hit in self.features.findFeaturesAt(
289
                            chromosome=read.reference_name, lookupCoordinate=ref_pos, strand=strand):
290
                        hit_start, hit_end, hit_id, hit_strand, hit_ids = hit
291
                        self.hits[hit_ids].add((read.reference_name, ref_pos))
292
293
                        if self.capture_locations:
294
                            if not hit_id in self.feature_locations:
295
                                self.feature_locations[hit_id] = []
296
                            self.feature_locations[hit_id].append( (hit_start, hit_end, hit_strand))