--- a +++ b/singlecellmultiomics/molecule/featureannotatedmolecule.py @@ -0,0 +1,296 @@ +from singlecellmultiomics.molecule.molecule import Molecule +import collections +import pandas as pd + +class TranscriptMolecule(Molecule): + + def __init__(self, fragment, + **kwargs): + self.genes=set() + Molecule.__init__(self, fragment, **kwargs) + + + def _add_fragment(self, fragment): + + self.genes.update(fragment.genes) + Molecule._add_fragment(self, fragment) + + def write_tags(self): + + for frag in self: + frag.write_tags() + + Molecule.write_tags(self) + + + +class FeatureAnnotatedMolecule(Molecule): + """Molecule which is annotated with features (genes/exons/introns, .. ) + """ + + def __init__( + self, + fragment, + features, + stranded=None, + auto_set_intron_exon_features=False, + capture_locations=False, + **kwargs): + """ + Args: + fragments (singlecellmultiomics.fragment.Fragment): Fragments to associate to the molecule + features (singlecellmultiomics.features.FeatureContainer) : container to use to obtain features from + stranded : None; not stranded, False: same strand as R1, True: other strand + capture_locations (bool) : Store information about the locations of the aligned features + auto_set_intron_exon_features(bool) : obtain intron_exon_features upon initialising + **kwargs: extra args + + """ + Molecule.__init__(self, fragment, **kwargs) + self.features = features + self.hits = collections.defaultdict(set) # feature -> hit_bases + self.stranded = stranded + self.is_annotated = False + self.capture_locations = capture_locations + if capture_locations: + self.feature_locations = {} #feature->locations (chrom,start,end, strand) + + self.junctions = set() + self.genes = set() + self.introns = set() + self.exons = set() + self.exon_hit_gene_names = set() # readable names + self.is_spliced = None + + if auto_set_intron_exon_features: + self.set_intron_exon_features() + + def set_spliced(self, is_spliced): + """ Set wether the transcript is spliced, False has priority over True """ + if self.is_spliced and not is_spliced: + # has already been set + self.is_spliced = False + else: + self.is_spliced = is_spliced + + + + def set_intron_exon_features(self): + if not self.is_annotated: + self.annotate() + + # Collect all hits: + hits = self.hits.keys() + + # (gene, transcript) -> set( exon_id .. ) + exon_hits = collections.defaultdict(set) + intron_hits = collections.defaultdict(set) + + for hit, locations in self.hits.items(): + if not isinstance(hit, tuple): + continue + + meta = dict(list(hit)) + if 'gene_id' not in meta: + continue + if meta.get('type') == 'exon': + if 'transcript_id' not in meta: + continue + self.genes.add(meta['gene_id']) + self.exons.add(meta['exon_id']) + if 'transcript_id' not in meta: + raise ValueError( + "Please use an Intron GTF file generated using -id 'transcript_id'") + exon_hits[(meta['gene_id'], meta['transcript_id'])].add( + meta['exon_id']) + if 'gene_name' in meta: + self.exon_hit_gene_names.add(meta['gene_name']) + elif meta.get('type') == 'intron': + self.genes.add(meta['gene_id']) + self.introns.add(meta['gene_id']) + + # Find junctions and add all annotations to annotation sets + debug = [] + + for (gene, transcript), exons_overlapping in exon_hits.items(): + # If two exons are detected from the same gene we detected a + # junction: + if len(exons_overlapping) > 1: + self.junctions.add(gene) + + # We found two exons and an intron: + if gene in self.introns: + self.set_spliced(False) + else: + self.set_spliced(True) + + debug.append( + f'{gene}_{transcript}:' + + ','.join( + list(exons_overlapping))) + + # Write exon dictionary: + self.set_meta('DB', ';'.join(debug)) + + def get_hit_df(self): + """Obtain dataframe with hits + Returns: + pd.DataFrame + """ + if not self.is_annotated: + self.set_intron_exon_features() + + d = {} + tabulated_hits = [] + for hit, locations in self.hits.items(): + if not isinstance(hit, tuple): + continue + meta = dict(list(hit)) + for location in locations: + location_dict = { + 'chromosome': location[0], + 'start': location[1][0], + 'end': location[1][1]} + location_dict.update(meta) + tabulated_hits.append(location_dict) + + return pd.DataFrame(tabulated_hits) + + + def write_tags_to_psuedoreads(self, reads, call_super=True): + # @ todo needs refactor; the psuedoread should just be a Fragment too, solves all issues + if call_super: + Molecule.write_tags_to_psuedoreads(self, reads) + + for read in reads: + if len(self.exons) > 0: + read.set_tag('EX', ','.join(sorted([str(x) for x in self.exons]))) + else: + read.set_tag('EX', None) + + if len(self.introns) > 0: + read.set_tag('IN', ','.join( + sorted([str(x) for x in self.introns]))) + else: + read.set_tag('IN', None) + + if len(self.genes) > 0: + read.set_tag('GN', ','.join(sorted([str(x) for x in self.genes]))) + else: + read.set_tag('GN', None) + + if len(self.junctions) > 0: + read.set_tag('JN', ','.join( + sorted([str(x) for x in self.junctions]))) + # Is transcriptome + read.set_tag('IT', 1) + elif len(self.genes) > 0: + # Maps to gene but not junction + read.set_tag('IT', 0.5) + read.set_tag('JN', None) + else: + # Doesn't map to gene + read.set_tag('IT', 0) + read.set_tag('JN', None) + + if self.is_spliced is True: + read.set_tag('SP', True) + elif self.is_spliced is False: + read.set_tag('SP', False) + if len(self.exon_hit_gene_names): + read.set_tag('gn', ';'.join(list(self.exon_hit_gene_names))) + else: + read.set_tag('gn', None) + + def write_tags(self): + Molecule.write_tags(self) + + # Write cell ranger tags: + if self.umi is not None: + self.set_meta('UB', self.umi) + bc = list(self.get_barcode_sequences())[0] + self.set_meta('CB', bc) + + if len(self.exons) > 0: + self.set_meta('EX', ','.join(sorted([str(x) for x in self.exons]))) + else: + self.set_meta('EX',None) + + if len(self.introns) > 0: + self.set_meta('IN', ','.join( + sorted([str(x) for x in self.introns]))) + else: + self.set_meta('IN',None) + + if len(self.genes) > 0: + self.set_meta('GN', ','.join(sorted([str(x) for x in self.genes]))) + else: + self.set_meta('GN',None) + + if len(self.junctions) > 0: + self.set_meta('JN', ','.join( + sorted([str(x) for x in self.junctions]))) + # Is transcriptome + self.set_meta('IT', 1) + elif len(self.genes) > 0: + # Maps to gene but not junction + self.set_meta('IT', 0.5) + self.set_meta('JN',None) + else: + # Doesn't map to gene + self.set_meta('IT', 0) + self.set_meta('JN', None) + + if self.is_spliced is True: + self.set_meta('SP', True) + elif self.is_spliced is False: + self.set_meta('SP', False) + if len(self.exon_hit_gene_names): + self.set_meta('gn', ';'.join(list(self.exon_hit_gene_names))) + else: + self.set_meta('gn',None) + + def annotate(self, method=0): + """ + Args: + method (int) : 0, obtain blocks and then obtain features. 1, try to obtain features for every aligned base + + """ + # When self.stranded is None, set to None strand. If self.stranded is + # True reverse the strand, otherwise copy the strand + strand = None if self.stranded is None else '+-'[ + (not self.strand if self.stranded else self.strand)] + self.is_annotated = True + if method == 0: + + # Obtain all blocks: + try: + for start, end in self.get_aligned_blocks(): + for hit in self.features.findFeaturesBetween( + chromosome=self.chromosome, sampleStart=start, sampleEnd=end, strand=strand): + hit_start, hit_end, hit_id, hit_strand, hit_ids = hit + self.hits[hit_ids].add( + (self.chromosome, (hit_start, hit_end))) + + if self.capture_locations: + if not hit_id in self.feature_locations: + self.feature_locations[hit_id] = [] + self.feature_locations[hit_id].append( (hit_start, hit_end, hit_strand)) + + except TypeError: + # This happens when no reads map + pass + else: + + for read in self.iter_reads(): + for q_pos, ref_pos in read.get_aligned_pairs( + matches_only=True, with_seq=False): + for hit in self.features.findFeaturesAt( + chromosome=read.reference_name, lookupCoordinate=ref_pos, strand=strand): + hit_start, hit_end, hit_id, hit_strand, hit_ids = hit + self.hits[hit_ids].add((read.reference_name, ref_pos)) + + if self.capture_locations: + if not hit_id in self.feature_locations: + self.feature_locations[hit_id] = [] + self.feature_locations[hit_id].append( (hit_start, hit_end, hit_strand))