singlecellmultiomics/molecule/molecule.py
from singlecellmultiomics.utils.sequtils import hamming_distance
import pysamiterators.iterators
import singlecellmultiomics.bamProcessing
from singlecellmultiomics.fragment import Fragment
from array import array
import itertools
import numpy as np
from singlecellmultiomics.utils import style_str, prob_to_phred, phredscores_to_base_call, base_probabilities_to_likelihood, likelihood_to_prob
import textwrap
import singlecellmultiomics.alleleTools
import functools
import typing
import pysam
import pysamiterators
from singlecellmultiomics.utils import find_ranges, create_MD_tag
import pandas as pd
from uuid import uuid4
from cached_property import cached_property
from collections import Counter, defaultdict


###############

# Variant validation function
def detect_alleles(molecules,
                   contig,
                   position,
                   min_cell_obs=3,
                   base_confidence_threshold=None,
                   classifier=None):  # [ alleles]
    """
    Detect the alleles (variable bases) present at the selected location

    Args:

        molecules : generator to extract molecules from

        contig (str) : contig of the location to test

        position (int) : zero-based position of the location to test

        min_cell_obs (int) : minimum number of cells in which the allele has to be observed for it to be emitted

        base_confidence_threshold (float) : minimum confidence of the consensus base-call to be taken into account

        classifier (obj) : classifier used for the consensus call; when no classifier is supplied a majority vote is used

    """
    observed_alleles = defaultdict(set)  # base_call -> { cell, .. }
    for molecule in molecules:
        base_call = molecule.get_consensus_base(contig, position, classifier=classifier)

        # confidence = molecule.get_mean_base_quality(*variant_location, base_call)
        if base_call is not None:
            observed_alleles[base_call].add(molecule.sample)

    return [allele for allele, cells in observed_alleles.items() if len(cells) >= min_cell_obs]


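# Example (hypothetical sketch): detecting alleles at a single genomic position.
# The BAM path, contig and coordinate below are placeholders; MoleculeIterator is the
# molecule generator shipped with this package, but the exact iterator arguments may
# differ per workflow.
#
#   import pysam
#   from singlecellmultiomics.molecule import MoleculeIterator
#
#   with pysam.AlignmentFile('tagged.bam') as alignments:
#       molecules = MoleculeIterator(alignments, contig='1', start=1_000_000, end=1_000_001)
#       alleles = detect_alleles(molecules, contig='1', position=1_000_000, min_cell_obs=3)
#       print(alleles)  # e.g. ['A', 'G'] when two alleles pass the cell threshold
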
def get_variant_phase(molecules, contig, position, variant_base, allele_resolver,
                      phasing_ratio_threshold=None):  # (location, base) -> [(location, base, identifier)]
    """Determine to which allele the supplied variant base is phased, using the allele resolver"""
    alleles = [variant_base]
    phases = defaultdict(Counter)  # base -> allele id -> observation count
    for molecule in molecules:
        # allele_obs = molecule.get_allele(return_allele_informative_base_dict=True,allele_resolver=allele_resolver)
        allele = list(molecule.get_allele(allele_resolver))
        if allele is None or len(allele) > 1 or len(allele) == 0:
            continue
        allele = allele[0]

        base = molecule.get_consensus_base(contig, position)
        if base in alleles:
            phases[base][allele] += 1
        else:
            pass

    if len(phases[variant_base]) == 0:
        raise ValueError("Phasing not established, no gSNVs available")

    phased_allele_id = phases[variant_base].most_common(1)[0][0]

    # Check if the phasing noise is below the threshold:
    if phasing_ratio_threshold is not None:
        correct = phases[variant_base].most_common(1)[0][1]
        total = sum(phases[variant_base].values())
        phasing_ratio = correct / total
        if phasing_ratio < phasing_ratio_threshold:
            raise ValueError(f'Phasing ratio not met. ({phasing_ratio}) < {phasing_ratio_threshold}')
    # Check if the other allele i

    return phased_allele_id


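# Example (hypothetical sketch): establishing which allele a variant base is linked to, using
# germline SNVs known to the allele resolver. The inputs below are placeholders.
#
#   phased_allele = get_variant_phase(molecules, contig='1', position=1_000_000,
#                                     variant_base='T', allele_resolver=resolver,
#                                     phasing_ratio_threshold=0.9)
#   # Raises ValueError when no informative gSNVs are covered or the phasing ratio is too low.
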
###############

@functools.lru_cache(maxsize=1000)
def might_be_variant(chrom, pos, variants, ref_base=None):
    """Returns True if a variant exists at the given coordinate"""
    if ref_base == 'N':
        return False
    try:
        for record in variants.fetch(chrom, pos, pos + 1):
            return True
    except ValueError:
        return False  # Happens when the contig does not exist
    return False

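# Example (hypothetical sketch): `variants` is typically an open pysam.VariantFile. Note that the
# function is lru-cached, so the handle has to stay open while the cache is in use.
#
#   import pysam
#   known = pysam.VariantFile('known_sites.vcf.gz')
#   if might_be_variant('1', 1_000_000, known):
#       pass  # for example: exclude this site from consensus training
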
def consensii_default_vector():
    """A numpy vector with 5 elements initialised as zeros"""
    return np.zeros(5)

# primer_read: 1 = read1, 2 = read2
def molecule_to_random_primer_dict(
        molecule,
        primer_length=6,
        primer_read=2,
        max_N_distance=0):
    rp = defaultdict(list)

    # First add all reactions without an N in the sequence:
    for fragment in molecule:

        h_contig, hstart, hseq = fragment.get_random_primer_hash()
        if hseq is None:
            # This should really not happen with freshly demultiplexed data; it means we cannot extract the random primer sequence,
            # which should be present as a tag (rS) in the record
            rp[None, None, None].append(fragment)
        elif 'N' not in hseq:
            rp[h_contig, hstart, hseq].append(fragment)

    # Try to match reactions with an N to reactions without an N
    for fragment in molecule:

        h_contig, hstart, hseq = fragment.get_random_primer_hash()
        if hseq is not None and 'N' in hseq:
            # find nearest
            for other_contig, other_start, other_seq in rp:
                if other_contig != h_contig or other_start != hstart:
                    continue

                if hseq.count('N') > max_N_distance:
                    continue

                if 'N' in other_seq:
                    continue

                if hamming_distance(hseq, other_seq) == 0:
                    rp[other_contig, other_start, other_seq].append(fragment)
    return rp


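# Example (hypothetical sketch): grouping the fragments of a molecule by their random-primer
# start site, for instance to count distinct RT reactions.
#
#   primer_groups = molecule_to_random_primer_dict(molecule)
#   n_rt_reactions = len(primer_groups)
#   for (contig, primer_start, primer_seq), fragments in primer_groups.items():
#       pass  # fragments sharing a random primer originate from the same RT reaction
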
class Molecule():
    """Molecule class, contains one or more associated fragments

    Attributes:
        fragments (list): associated fragments

        spanStart (int): starting coordinate of molecule, None if not available

        spanEnd (int): ending coordinate of molecule, None if not available

        chromosome (str): mapping chromosome of molecule, None if not available

        cache_size (int): radius of molecule assignment cache

        reference (pysam.FastaFile) : reference file, used to obtain base contexts
            and correct aligned_pairs iteration when the MD tag is not correct

        strand (bool): mapping strand.
            True when strand is REVERSE
            False when strand is FORWARD
            None when strand is not determined
    """
    def get_empty_clone(self, fragments=None):
        return type(self)(fragments,
                          cache_size=self.cache_size,
                          reference=self.reference,
                          min_max_mapping_quality=self.min_max_mapping_quality,
                          allele_assingment_method=self.allele_assingment_method,
                          allele_resolver=self.allele_resolver,
                          mapability_reader=self.mapability_reader,
                          max_associated_fragments=self.max_associated_fragments,
                          **self.kwargs)

    def __init__(self,
                 fragments: typing.Optional[typing.Iterable] = None,
                 cache_size: int = 10_000,
                 reference: typing.Union[pysam.FastaFile, pysamiterators.CachedFasta] = None,
                 # When all fragments have a mapping quality below this value
                 # the is_valid method will return False,
                 min_max_mapping_quality: typing.Optional[int] = None,
                 mapability_reader: typing.Optional[singlecellmultiomics.bamProcessing.MapabilityReader] = None,
                 allele_resolver: typing.Optional[singlecellmultiomics.alleleTools.AlleleResolver] = None,
                 max_associated_fragments=None,
                 allele_assingment_method=1,  # 0: all variants from the same allele, 1: likelihood
                 **kwargs

                 ):
        """Initialise Molecule

        Parameters
        ----------
        fragments :  list(singlecellmultiomics.fragment.Fragment)
            Fragment to assign to Molecule. More fragments can be added later

        min_max_mapping_quality : When all fragments have a mapping quality below this value the is_valid method will return False

        allele_resolver : alleleTools.AlleleResolver or None. Supply an allele resolver in order to assign an allele to the molecule

        mapability_reader : singlecellmultiomics.bamProcessing.MapabilityReader, supply a mapability_reader to set the mapping quality to 0 for molecules mapping to locations which do not map uniquely during in-silico library generation.

        cache_size (int): radius of molecule assignment cache

        max_associated_fragments(int) : Maximum number of fragments associated to the molecule. If more fragments are added using add_fragment() they are not added to the molecule

        """
        self.kwargs = kwargs
        self.reference = reference
        self.fragments = []
        self.spanStart = None
        self.spanEnd = None
        self.chromosome = None
        self.cache_size = cache_size
        self.strand = None
        self.umi = None
        self.overflow_fragments = 0
        self.umi_hamming_distance = None
        # when set, a fragment to be added has to match this hash
        # when it is compared to this molecule
        self.fragment_match = None
        self.min_max_mapping_quality = min_max_mapping_quality
        self.umi_counter = Counter()  # Observations of umis
        self.max_associated_fragments = max_associated_fragments
        if fragments is not None:
            if isinstance(fragments, list):
                for frag in fragments:
                    self.add_fragment(frag)
            else:
                self.add_fragment(fragments)

        self.allele_resolver = allele_resolver
        self.mapability_reader = mapability_reader
        self.allele_assingment_method = allele_assingment_method
        self.methylation_call_dict = None
        self.finalised = False
        self.obtained_allele_likelihoods = None

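    # Example (hypothetical sketch): building a molecule from a fragment and adding more later.
    # `fragment_a` and `fragment_b` are placeholders for singlecellmultiomics Fragment objects.
    #
    #   molecule = Molecule(fragment_a, min_max_mapping_quality=20)
    #   molecule.add_fragment(fragment_b)
    #   molecule.write_tags()
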
    @cached_property
    def can_be_split_into_allele_molecules(self):
        """True when likelihoods for more than one allele are available, so the molecule can be split"""
        l = self.allele_likelihoods
        if l is None or len(l) <= 1:
            return False
        return True

    def split_into_allele_molecules(self):
        """
        Split this molecule into multiple molecules, associated to multiple alleles
        Returns:
            list_of_molecules: list
        """
        # Perform allele based clustering
        allele_clustered_frags = {}
        for fragment in self:
            n = self.get_empty_clone(fragment)
            if n.allele not in allele_clustered_frags:
                allele_clustered_frags[n.allele] = []
            allele_clustered_frags[n.allele].append(n)

        allele_clustered = {}
        for allele, assigned_molecules in allele_clustered_frags.items():
            for i, m in enumerate(assigned_molecules):
                if i == 0:
                    allele_clustered[allele] = m
                else:
                    allele_clustered[allele].add_molecule(m)

        if len(allele_clustered) > 1:
            for m in allele_clustered.values():
                m.set_meta('cr', 'SplitUponAlleleClustering')

        return list(allele_clustered.values())

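    # Example (hypothetical sketch): resolving a molecule that mixes fragments from both alleles.
    #
    #   if molecule.can_be_split_into_allele_molecules:
    #       for allele_molecule in molecule.split_into_allele_molecules():
    #           allele_molecule.write_tags()
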
    @cached_property
    def allele(self):
        if self.allele_resolver is None:
            return None
        if self.allele_assingment_method == 0:
            # Obtain allele if available
            try:
                hits = self.get_allele(self.allele_resolver)
                # Only return when we have a unique single hit:
                if len(hits) == 1:
                    return list(hits)[0]
            except ValueError:
                # This happens when a consensus can not be obtained
                pass
            return None
        elif self.allele_assingment_method == 1:
            al = Counter(self.allele_likelihoods)
            if al is None or len(al) < 1:
                return None
            return al.most_common(1)[0][0]

        raise NotImplementedError(f'allele_assingment_method {self.allele_assingment_method} is not defined')

    def __finalise__(self):
        """This function is called when all associated fragments have been gathered"""

        # Perform allele assignment based on likelihood:
        # this is now only generated upon demand, see the .allele property

        if self.mapability_reader is not None:
            self.update_mapability()

        self.finalised = True

    def update_mapability(self, set_mq_zero=False):
        """ Update mapability of this molecule.
        Mapping qualities are set to 0 if the mapability_reader returns False
        for site_is_mapable

        The mapability_reader can be set when initiating the molecule, or added later.

        Args:
            set_mq_zero(bool) : set the mapping quality of associated reads to 0 when the
                mapability reader returns a bad verdict

        Tip:
            Use `createMapabilityIndex.py` to create an index to feed to the mapability_reader
        """

        mapable = None
        try:
            mapable = self.mapability_reader.site_is_mapable(
                *self.get_cut_site())
        except TypeError:
            pass
        except Exception:
            raise

        if mapable is False:
            self.set_meta('mp', 'bad')
            if set_mq_zero:
                for read in self.iter_reads():
                    read.mapping_quality = 0
        elif mapable is True:
            self.set_meta('mp', 'unique')
        else:
            self.set_meta('mp', 'unknown')

    def calculate_consensus(self, consensus_model, molecular_identifier, out, **model_kwargs):
        """
        Create consensus read for molecule

        Args:

            consensus_model : classifier used to generate the consensus base calls

            molecular_identifier (str) : identifier for this molecule, will be suffixed to the reference_id

            out(pysam.AlignmentFile) : target bam file

            **model_kwargs : arguments passed to the consensus model

        """
        try:
            consensus_reads = self.deduplicate_to_single_CIGAR_spaced(
                out,
                f'c_{self.get_a_reference_id()}_{molecular_identifier}',
                consensus_model,
                NUC_RADIUS=model_kwargs['consensus_k_rad']
            )
            for consensus_read in consensus_reads:
                consensus_read.set_tag('RG', self[0].get_read_group())
                consensus_read.set_tag('mi', molecular_identifier)
                out.write(consensus_read)

        except Exception:

            self.set_rejection_reason('CONSENSUS_FAILED', set_qcfail=True)
            self.write_pysam(out)

    def get_a_reference_id(self):
        """
        Obtain the reference id of an arbitrary associated mapped read, or -1 when no mapped read is present
        """
        for read in self.iter_reads():
            if not read.is_unmapped:
                return read.reference_id
        return -1

    def get_consensus_read(self, target_file,
                           read_name, consensus=None,
                           phred_scores=None,
                           cigarstring=None,
                           mdstring=None,
                           start=None,
                           supplementary=False

                           ):
        """get pysam.AlignedSegment containing aggregated molecule information

        Args:
            target_file(pysam.AlignmentFile) : File to create the read for

            read_name(str) : name of the read to write
        Returns:
            read(pysam.AlignedSegment)
        """
        if start is None:
            start = self.spanStart
        if consensus is None:
            try:  # Obtain the consensus sequence
                consensus = self.get_consensus()
            except Exception:
                raise
        if isinstance(consensus, str):
            sequence = consensus
        else:
            sequence = ''.join(
                (consensus.get(
                    (self.chromosome, ref_pos), 'N') for ref_pos in range(
                    self.spanStart, self.spanEnd + 1)))

        if isinstance(phred_scores, dict):
            phred_score_array = list(
                phred_scores.get(
                    (self.chromosome, ref_pos), 0) for ref_pos in range(
                    self.spanStart, self.spanEnd + 1))
        else:
            phred_score_array = phred_scores

        # Construct the consensus read
        cread = pysam.AlignedSegment(header=target_file.header)
        cread.reference_name = self.chromosome
        cread.reference_start = start
        cread.query_name = read_name
        cread.query_sequence = sequence
        cread.query_qualities = phred_score_array
        cread.is_supplementary = supplementary
        if cigarstring is not None:
            cread.cigarstring = cigarstring
        else:
            cread.cigarstring = f'{len(sequence)}M'
        cread.mapping_quality = self.get_max_mapping_qual()

        cread.is_reverse = self.strand
        if mdstring is not None:
            cread.set_tag('MD', mdstring)

        self.write_tags_to_psuedoreads((cread,))

        return cread

    """ method = 1
        sequence = []
        cigar = []
        if method==0:
            prev_end = None
            for block_start,block_end in molecule.get_aligned_blocks():
                if molecule.strand:
                    print(block_end>block_start,block_start, block_end)
                if prev_end is not None:
                    cigar.append(f'{block_start - prev_end}D')

                block_len = block_end-block_start+1
                cigar.append(f'{block_len}M')
                for ref_pos in range(block_start,block_end+1):
                    call = consensus.get((molecule.chromosome, ref_pos),'N')
                    sequence.append(call)
                prev_end = block_end+1

            cigarstring = ''.join(cigar)
        """

    def get_feature_vector(self, window_size=90):
        """ Obtain a feature vector representation of the molecule

        Returns:
            feature_vector(np.array)
        """

        return np.array([
            self.get_strand(),
            self.has_valid_span(),
            self.get_umi_error_rate(),
            self.get_consensus_gc_ratio(),
            len(self.get_raw_barcode_sequences()),
            self.get_safely_aligned_length(),
            self.get_max_mapping_qual(),
            (self.alleles is None),
            self.contains_valid_fragment(),
            self.is_multimapped(),
            self.get_feature_window(window_size=window_size)
        ])

    def get_tag_counter(self):
        """
        Obtain a dictionary with tag -> value -> frequency

        Returns:
            tag_obs (defaultdict(Counter)):
                { tag(str) : { value(int/str) : frequency (int) } }

        """
        tags_obs = defaultdict(Counter)
        for tag, value in itertools.chain(
                *[r.tags for r in self.iter_reads()]):
            try:
                tags_obs[tag][value] += 1
            except TypeError:
                # Don't count unhashable values such as arrays
                pass
        return tags_obs

    def write_tags(self):
        """ Write BAM tags to all reads associated to this molecule

        This function sets the following tags:
            - mI : most common umi
            - DA : allele
            - af : number of associated fragments
            - rt : rt_reaction_index
            - rd : rt_duplicate_index
            - TR : total RT reactions
            - ap : phasing information (if allele_resolver is set)
            - TF : total fragments
            - ms : size of the molecule (largest fragment)
        """
        self.is_valid(set_rejection_reasons=True)
        if self.umi is not None:
            self.set_meta('mI', self.umi)
        if self.allele is not None:
            self.set_meta('DA', str(self.allele))

        # Set total number of associated fragments
        self.set_meta('TF', len(self.fragments) + self.overflow_fragments)
        try:
            self.set_meta('ms', self.estimated_max_length)
        except Exception:
            # There is no properly defined aligned length
            pass
        # associatedFragmentCount :
        self.set_meta('af', len(self))
        for rc, frag in enumerate(self):
            frag.set_meta('RC', rc)
            if rc > 0:
                # Set duplicate bit
                for read in frag:
                    if read is not None:
                        read.is_duplicate = True

        # Write RT reaction tags (rt: rt reaction index, rd: rt duplicate index)
        # This is only required for fragments which have defined random primers
        rt_reaction_index = None
        for rt_reaction_index, ((contig, random_primer_start, random_primer_sequence), frags) in enumerate(
                self.get_rt_reactions().items()):

            for rt_duplicate_index, frag in enumerate(frags):
                frag.set_meta('rt', rt_reaction_index)
                frag.set_meta('rd', rt_duplicate_index)
                frag.set_meta('rp', random_primer_start)
        self.set_meta('TR', 0 if (rt_reaction_index is None) else rt_reaction_index + 1)

        if self.allele_resolver is not None:
            self.write_allele_phasing_information_tag()

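    # Example (hypothetical sketch): write the molecule-level tags and then write all associated
    # reads to an output BAM file (`out_bam` is an open pysam.AlignmentFile in write mode).
    #
    #   molecule.write_tags()
    #   molecule.write_pysam(out_bam)
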
    def write_tags_to_psuedoreads(self, reads):
        """
        Write molecule information to the supplied reads as BAM tags
        """
        # write methylation tags to new reads if applicable:
        if self.methylation_call_dict is not None:
            self.set_methylation_call_tags(
                self.methylation_call_dict, reads=reads)

        for read in reads:
            read.set_tag('SM', self.sample)
            if hasattr(self, 'get_cut_site'):
                read.set_tag('DS', self.get_cut_site()[1])

            if self.umi is not None:
                read.set_tag('RX', self.umi)
                bc = list(self.get_barcode_sequences())[0]
                read.set_tag('BC', bc)
                read.set_tag('MI', bc + self.umi)

            # Store total number of RT reactions:
            read.set_tag('TR', len(self.get_rt_reactions()))
            read.set_tag('TF', len(self.fragments) + self.overflow_fragments)

            if self.allele is not None:
                read.set_tag('DA', self.allele)

        if self.allele_resolver is not None:
            self.write_allele_phasing_information_tag(
                self.allele_resolver, reads=reads)

    def deduplicate_to_single(
            self,
            target_bam,
            read_name,
            classifier,
            reference=None):
        """
        Deduplicate all reads associated to this molecule to a single pseudoread

        Args:
            target_bam (pysam.AlignmentFile) : file to associate the read with
            read_name (str) : name of the pseudoread
            classifier (sklearn classifier) : classifier for consensus prediction

        Returns:
            read (pysam.AlignedSegment) : Pseudo-read containing aggregated information
        """
        # Set all associated reads to duplicate
        for read in self.iter_reads():
            read.is_duplicate = True

        features = self.get_base_calling_feature_matrix(reference=reference)

        # We only use the predicted class probabilities:
        base_calling_probs = classifier.predict_proba(features)
        predicted_sequence = ['ACGT'[i] for i in np.argmax(base_calling_probs, 1)]
        phred_scores = np.rint(-10 * np.log10(np.clip(1 - base_calling_probs.max(1), 0.000000001, 0.999999))).astype(
            'B')

        read = self.get_consensus_read(
            read_name=read_name,
            target_file=target_bam,
            consensus=''.join(predicted_sequence),
            phred_scores=phred_scores)
        read.is_read1 = True
        return read

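    # Example (hypothetical sketch): collapsing a molecule into one consensus pseudoread with a
    # trained base-calling classifier. `clf` is a fitted sklearn classifier and `out_bam` an open
    # pysam.AlignmentFile in write mode.
    #
    #   consensus = molecule.deduplicate_to_single(out_bam, f'consensus_{molecule.umi}', clf)
    #   out_bam.write(consensus)
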
    def deduplicate_to_single_CIGAR_spaced(
            self,
            target_bam,
            read_name,
            classifier=None,
            max_N_span=300,
            reference=None,
            **feature_matrix_args
    ):
        """
        Deduplicate all associated reads to a single pseudoread; when the span is larger than max_N_span
        the read is split up into multiple segments. Uncovered locations are spaced using N's in the CIGAR.

        Args:
            target_bam (pysam.AlignmentFile) : file to associate the read with
            read_name (str) : name of the pseudoread
            classifier (sklearn classifier) : classifier for consensus prediction
        Returns:
            reads( list [ pysam.AlignedSegment ] )

        """
        # Set all associated reads to duplicate
        for read in self.iter_reads():
            read.is_duplicate = True

        if classifier is not None:
            features, reference_bases, CIGAR, alignment_start, alignment_end = self.get_base_calling_feature_matrix_spaced(
                True, reference=reference, **feature_matrix_args)

            base_calling_probs = classifier.predict_proba(features)
            predicted_sequence = ['ACGT'[i] for i in np.argmax(base_calling_probs, 1)]

            reference_sequence = ''.join(
                [base for chrom, pos, base in reference_bases])
            # predicted_sequence[ features[:, [ x*8 for x in range(4) ] ].sum(1)==0 ] ='N'
            predicted_sequence = ''.join(predicted_sequence)

            phred_scores = np.rint(
                -10 * np.log10(np.clip(1 - base_calling_probs.max(1),
                                       0.000000001,
                                       0.999999)
                               )).astype('B')

        reads = []

        query_index_start = 0
        query_index_end = 0
        reference_position = alignment_start  # pointer to current position
        reference_start = alignment_start  # pointer to alignment start of current read
        supplementary = False
        partial_CIGAR = []
        partial_MD = []

        for operation, amount in CIGAR:
            if operation == 'M':  # Consume query and reference
                query_index_end += amount
                reference_position += amount
                partial_CIGAR.append(f'{amount}{operation}')

            if operation == 'N':
                # Consume reference:
                reference_position += amount
                if amount > max_N_span:  # Split up in a supplementary alignment
                    # Eject previous
                    # reference_seq =

                    consensus_read = self.get_consensus_read(
                        read_name=read_name,
                        target_file=target_bam,
                        consensus=predicted_sequence[query_index_start:query_index_end],
                        phred_scores=phred_scores[query_index_start:query_index_end],
                        cigarstring=''.join(partial_CIGAR),
                        mdstring=create_MD_tag(
                            reference_sequence[query_index_start:query_index_end],
                            predicted_sequence[query_index_start:query_index_end]
                        ),
                        start=reference_start,
                        supplementary=supplementary
                    )
                    reads.append(consensus_read)
                    if not supplementary:
                        consensus_read.is_read1 = True

                    supplementary = True
                    # Start new:
                    query_index_start = query_index_end
                    reference_start = reference_position
                    partial_CIGAR = []
                else:
                    partial_CIGAR.append(f'{amount}{operation}')

        reads.append(self.get_consensus_read(
            read_name=read_name,
            target_file=target_bam,
            consensus=''.join(predicted_sequence[query_index_start:query_index_end]),
            phred_scores=phred_scores[query_index_start:query_index_end],
            cigarstring=''.join(partial_CIGAR),
            mdstring=create_MD_tag(
                reference_sequence[query_index_start:query_index_end],
                predicted_sequence[query_index_start:query_index_end]

            ),
            start=reference_start,
            supplementary=supplementary
        ))

        # Write last index tag to last read ..
        if supplementary:
            reads[-1].is_read2 = True

        # Write NH tag (the number of records with the same query read):
        for read in reads:
            read.set_tag('NH', len(reads))

        return reads

    def extract_stretch_from_dict(self, base_call_dict, alignment_start, alignment_end):
        """Extract the base calls and Phred scores for [alignment_start, alignment_end) from a (contig, pos) -> (base, confidence) dictionary"""
        base_calling_probs = np.array(
            [base_call_dict.get((self.chromosome, pos), ('N', 0))[1] for pos in range(alignment_start, alignment_end)])
        predicted_sequence = [base_call_dict.get((self.chromosome, pos), ('N', 0))[0] for pos in
                              range(alignment_start, alignment_end)]
        predicted_sequence = ''.join(predicted_sequence)
        phred_scores = np.rint(
            -10 * np.log10(np.clip(1 - base_calling_probs,
                                   0.000000001,
                                   0.999999999)
                           )).astype('B')
        return predicted_sequence, phred_scores

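    # The conversion above maps the probability of a correct call to a Phred score,
    # Q = -10 * log10(1 - p), clipped to avoid infinities. A minimal, self-contained check:
    #
    #   import numpy as np
    #   p = np.array([0.9, 0.99, 0.999])
    #   q = np.rint(-10 * np.log10(np.clip(1 - p, 1e-9, 0.999999999))).astype('B')
    #   # q == array([10, 20, 30], dtype=uint8)
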
    def get_base_confidence_dict(self):
        """
        Get dictionary containing base calls per position and the corresponding confidences

        Returns:
            obs (dict) : (contig (str), position (int)) -> base (str) -> list of probabilities that the call is correct
        """
        # Convert (contig, position) -> (base_call) into:
        # (contig, position) -> (base_call, confidence)
        obs = defaultdict(lambda: defaultdict(list))
        for read in self.iter_reads():
            for qpos, rpos in read.get_aligned_pairs(matches_only=True):
                qbase = read.seq[qpos]
                qqual = read.query_qualities[qpos]
                # @todo: reads which span multiple chromosomes
                obs[(self.chromosome, rpos)][qbase].append(1 - np.power(10, -qqual / 10))
        return obs

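    # Example (hypothetical sketch): inspecting the per-position evidence of a molecule. Each
    # entry maps (contig, position) to the observed bases and the probability that each
    # observation is correct, derived from the read Phred scores.
    #
    #   obs = molecule.get_base_confidence_dict()
    #   for (contig, pos), base_probs in obs.items():
    #       base, confidence = phredscores_to_base_call(base_probs)
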
    def deduplicate_majority(self, target_bam, read_name, max_N_span=None):
        """Deduplicate all reads associated to this molecule into pseudoread(s) using a majority vote per position"""

        obs = self.get_base_confidence_dict()

        reads = list(self.get_dedup_reads(read_name,
                                          target_bam,
                                          obs={reference_position: phredscores_to_base_call(probs)
                                               for reference_position, probs in obs.items()},
                                          max_N_span=max_N_span))
        self.write_tags_to_psuedoreads([read for read in reads if read is not None])
        return reads

    def generate_partial_reads(self, obs, max_N_span=None):
        """Yield (reference_start, reference_end, partial_sequence, partial_phred, partial_CIGAR, partial_MD) tuples describing the consensus read segments of this molecule"""
        CIGAR, alignment_start, alignment_end = self.get_CIGAR()
        query_index_start = 0
        query_index_end = 0
        reference_position = alignment_start  # pointer to current position
        reference_start = alignment_start  # pointer to alignment start of current read
        reference_end = None
        partial_CIGAR = []
        partial_MD = []
        partial_sequence = []
        partial_phred = []

        for operation, amount in CIGAR:
            if operation == 'N':
                if max_N_span is not None and amount > max_N_span:
                    yield reference_start, reference_end, partial_sequence, partial_phred, partial_CIGAR, partial_MD
                    # Clear all
                    partial_CIGAR = []
                    partial_MD = []
                    partial_sequence = []
                    partial_phred = []
                else:
                    # Increment
                    partial_CIGAR.append(f'{amount}{operation}')
                    query_index_start += sum((len(s) for s in partial_sequence))

                reference_position += amount
            elif operation == 'M':  # Consume query and reference

                query_index_end += amount
                if len(partial_CIGAR) == 0:
                    reference_start = reference_position
                start_fetch = reference_position

                reference_position += amount
                reference_end = reference_position
                partial_CIGAR.append(f'{amount}{operation}')

                predicted_sequence, phred_scores = self.extract_stretch_from_dict(obs, start_fetch, reference_end)  # [start .. end)

                partial_sequence.append(predicted_sequence)
                partial_phred.append(phred_scores)

        yield reference_start, reference_end, partial_sequence, partial_phred, partial_CIGAR, partial_MD

    def get_dedup_reads(self, read_name, target_bam, obs, max_N_span=None):
        if self.chromosome is None:
            return None  # We cannot perform this action
        for reference_start, reference_end, partial_sequence, partial_phred, partial_CIGAR, partial_MD in self.generate_partial_reads(
                obs, max_N_span=max_N_span):
            consensus_read = self.get_consensus_read(
                read_name=read_name,
                target_file=target_bam,
                consensus=''.join(partial_sequence),
                phred_scores=array('B', np.concatenate(partial_phred)),  # Needs to be cast to array
                cigarstring=''.join(partial_CIGAR),
                mdstring=create_MD_tag(
                    self.reference.fetch(self.chromosome, reference_start, reference_end),
                    ''.join(partial_sequence)
                ),
                start=reference_start,
                supplementary=False
            )

            consensus_read.is_reverse = self.strand
            yield consensus_read

    def deduplicate_to_single_CIGAR_spaced_from_dict(
            self,
            target_bam,
            read_name,
            base_call_dict,  # (contig, position) -> (base_call, confidence)
            max_N_span=300,
    ):
        """
        Deduplicate all associated reads to a single pseudoread; when the span is larger than max_N_span
        the read is split up into multiple segments. Uncovered locations are spaced using N's in the CIGAR.

        Args:
            target_bam (pysam.AlignmentFile) : file to associate the read with
            read_name (str) : name of the pseudoread
            base_call_dict (dict) : (contig, position) -> (base_call, confidence) mapping used for the consensus
        Returns:
            reads( list [ pysam.AlignedSegment ] )

        """
        # Set all associated reads to duplicate
        for read in self.iter_reads():
            read.is_duplicate = True

        CIGAR, alignment_start, alignment_end = self.get_CIGAR()

        reads = []

        query_index_start = 0
        query_index_end = 0
        reference_position = alignment_start  # pointer to current position
        reference_start = alignment_start  # pointer to alignment start of current read
        reference_end = None
        supplementary = False
        partial_CIGAR = []
        partial_MD = []

        partial_sequence = []
        partial_phred = []

        for operation, amount in CIGAR:

            if operation == 'N':
                # Pop the previous read..
                if len(partial_sequence):
                    assert reference_end is not None
                    consensus_read = self.get_consensus_read(
                        read_name=read_name,
                        target_file=target_bam,
                        consensus=''.join(partial_sequence),
                        phred_scores=array('B', np.concatenate(partial_phred)),
                        cigarstring=''.join(partial_CIGAR),
                        mdstring=create_MD_tag(
                            self.reference.fetch(self.chromosome, reference_start, reference_end),
                            ''.join(partial_sequence)
                        ),
                        start=reference_start,
                        supplementary=supplementary
                    )
                    reads.append(consensus_read)
                    if not supplementary:
                        consensus_read.is_read1 = True

                    supplementary = True
                    reference_start = reference_position
                    partial_CIGAR = []
                    partial_phred = []
                    partial_sequence = []

                # Consume reference:
                reference_position += amount
                partial_CIGAR.append(f'{amount}{operation}')

            if operation == 'M':  # Consume query and reference
                query_index_end += amount
                # This should only be reset upon a new read:
                if len(partial_CIGAR) == 0:
                    reference_start = reference_position
                reference_position += amount
                reference_end = reference_position

                partial_CIGAR.append(f'{amount}{operation}')

                predicted_sequence, phred_scores = self.extract_stretch_from_dict(base_call_dict, reference_start,
                                                                                  reference_end)  # [start .. end)

                partial_sequence.append(predicted_sequence)
                partial_phred.append(phred_scores)

        consensus_read = self.get_consensus_read(
            read_name=read_name,
            target_file=target_bam,
            consensus=''.join(partial_sequence),
            phred_scores=array('B', np.concatenate(partial_phred)),
            cigarstring=''.join(partial_CIGAR),
            mdstring=create_MD_tag(
                self.reference.fetch(self.chromosome, reference_start, reference_end),
                ''.join(partial_sequence)
            ),
            start=reference_start,
            supplementary=supplementary
        )
        reads.append(consensus_read)
        if not supplementary:
            consensus_read.is_read1 = True

        supplementary = True
        reference_start = reference_position
        partial_CIGAR = []

        # Write last index tag to last read ..
        if supplementary:
            reads[-1].is_read2 = True
            reads[0].is_read1 = True

        # Write NH tag (the number of records with the same query read):
        for read in reads:
            read.set_tag('NH', len(reads))

        return reads

    def get_base_calling_feature_matrix(
            self,
            return_ref_info=False,
            start=None,
            end=None,
            reference=None,
            NUC_RADIUS=1,
            USE_RT=True,
            select_read_groups=None):
        """
        Obtain feature matrix for base calling

        Args:
            return_ref_info (bool) : return both X and an array with feature information
            start (int) : start of range, genomic position
            end (int) : end of range (inclusive), genomic position
            reference(pysam.FastaFile) : reference to fetch reference bases from, if not supplied the MD tag is used
            NUC_RADIUS(int) : include the features of this many neighbouring positions around the target nucleotide (k-mer context)
            USE_RT(bool) : use RT reaction features
            select_read_groups(set) : only use reads from these read groups to generate features
        """
        if start is None:
            start = self.spanStart
        if end is None:
            end = self.spanEnd

        with np.errstate(divide='ignore', invalid='ignore'):
            BASE_COUNT = 5
            RT_INDEX = 7 if USE_RT else None
            STRAND_INDEX = 0
            PHRED_INDEX = 1
            RC_INDEX = 2
            MATE_INDEX = 3
            CYCLE_INDEX = 4
            MQ_INDEX = 5
            FS_INDEX = 6

            COLUMN_OFFSET = 0
            features_per_block = 8 - (not USE_RT)

            origin_start = start
            origin_end = end

            end += NUC_RADIUS
            start -= NUC_RADIUS

            features = np.zeros(
                (end - start + 1, (features_per_block * BASE_COUNT) + COLUMN_OFFSET))

            if return_ref_info:
                ref_bases = {}

            for rt_id, fragments in self.get_rt_reactions().items():
                # we need to keep track of which positions were covered by this RT
                # reaction
                RT_reaction_coverage = set()  # (pos, base_call)
                for fragment in fragments:
                    for read in fragment:
                        if select_read_groups is not None:
                            if not read.has_tag('RG'):
                                raise ValueError(
                                    "Not all reads in the BAM file have a read group defined.")
                            if not read.get_tag('RG') in select_read_groups:
                                continue
                        # Skip reads outside range
                        if read is None or read.reference_start > (
                                end + 1) or read.reference_end < start:
                            continue
                        for cycle, q_pos, ref_pos, ref_base in pysamiterators.ReadCycleIterator(
                                read, matches_only=True, with_seq=True, reference=reference):

                            row_index = ref_pos - start
                            if row_index < 0 or row_index >= features.shape[0]:
                                continue

                            query_base = read.seq[q_pos]
                            # Base index block:
                            block_index = 'ACGTN'.index(query_base)

                            # Update rt_reactions
                            if USE_RT:
                                if (ref_pos, query_base) not in RT_reaction_coverage:
                                    features[row_index][RT_INDEX +
                                                        COLUMN_OFFSET +
                                                        features_per_block *
                                                        block_index] += 1
                                RT_reaction_coverage.add((ref_pos, query_base))

                            # Update total phred score
                            features[row_index][PHRED_INDEX +
                                                COLUMN_OFFSET +
                                                features_per_block *
                                                block_index] += read.query_qualities[q_pos]

                            # Update total reads

                            features[row_index][RC_INDEX + COLUMN_OFFSET +
                                                features_per_block * block_index] += 1

                            # Update mate index
                            features[row_index][MATE_INDEX +
                                                COLUMN_OFFSET +
                                                features_per_block *
                                                block_index] += read.is_read2

                            # Update fragment sizes:
                            features[row_index][FS_INDEX +
                                                COLUMN_OFFSET +
                                                features_per_block *
                                                block_index] += abs(fragment.span[1] -
                                                                    fragment.span[2])

                            # Update cycle
                            features[row_index][CYCLE_INDEX +
                                                COLUMN_OFFSET +
                                                features_per_block *
                                                block_index] += cycle

                            # Update MQ:
                            features[row_index][MQ_INDEX +
                                                COLUMN_OFFSET +
                                                features_per_block *
                                                block_index] += read.mapping_quality

                            # update strand:
                            features[row_index][STRAND_INDEX +
                                                COLUMN_OFFSET +
                                                features_per_block *
                                                block_index] += read.is_reverse

                            if return_ref_info:
                                row_index_in_output = ref_pos - origin_start
                                if row_index_in_output < 0 or row_index_in_output >= origin_end - origin_start + 1:
                                    continue

                                ref_bases[ref_pos] = ref_base.upper()

            # Normalize all and return

            for block_index in range(BASE_COUNT):  # ACGTN
                for index in (
                        PHRED_INDEX,
                        MATE_INDEX,
                        CYCLE_INDEX,
                        MQ_INDEX,
                        FS_INDEX,
                        STRAND_INDEX):
                    features[:, index +
                                COLUMN_OFFSET +
                                features_per_block *
                                block_index] /= features[:, RC_INDEX +
                                                            COLUMN_OFFSET +
                                                            features_per_block *
                                                            block_index]
            # np.nan_to_num( features, nan=-1, copy=False )
            features[np.isnan(features)] = -1

            if NUC_RADIUS > 0:
                # duplicate columns in a shifted manner
                x = features
                features = np.zeros(
                    (x.shape[0] - NUC_RADIUS * 2, x.shape[1] * (1 + NUC_RADIUS * 2)))
                for offset in range(0, NUC_RADIUS * 2 + 1):
                    slice_start = offset
                    slice_end = -(NUC_RADIUS * 2) + offset
                    if slice_end == 0:
                        features[:, features_per_block *
                                    BASE_COUNT *
                                    offset:features_per_block *
                                           BASE_COUNT *
                                           (offset +
                                            1)] = x[slice_start:, :]
                    else:
                        features[:, features_per_block *
                                    BASE_COUNT *
                                    offset:features_per_block *
                                           BASE_COUNT *
                                           (offset +
                                            1)] = x[slice_start:slice_end, :]

            if return_ref_info:
                ref_info = [
                    (self.chromosome, ref_pos, ref_bases.get(ref_pos, 'N'))
                    for ref_pos in range(origin_start, origin_end + 1)]
                return features, ref_info
            return features

    def get_CIGAR(self, reference=None):
        """ Get the CIGAR representation of the aligned blocks of this molecule

        Returns:
            CIGAR (list) : list of (operation, count) tuples, where 'M' blocks are spaced by 'N' gaps
            alignment_start (int) : first aligned reference position
            alignment_end (int) : last aligned reference position
        """

        CIGAR = []
        prev_end = None
        alignment_start = None
        alignment_end = None
        for start, end in self.get_aligned_blocks():

            if prev_end is not None:
                CIGAR.append(('N', start - prev_end - 1))
            CIGAR.append(('M', (end - start + 1)))
            prev_end = end

            if alignment_start is None:
                alignment_start = start
                alignment_end = end
            else:
                alignment_start = min(alignment_start, start)
                alignment_end = max(alignment_end, end)

        return CIGAR, alignment_start, alignment_end

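    # Illustration (hypothetical numbers): two aligned blocks [100, 149] and [160, 209] yield
    # CIGAR tuples [('M', 50), ('N', 10), ('M', 50)] with alignment_start=100 and
    # alignment_end=209; the 'N' operation spaces the uncovered reference positions.
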
    @functools.lru_cache(maxsize=4)
    def get_base_calling_feature_matrix_spaced(
            self,
            return_ref_info=False,
            reference=None,
            **feature_matrix_args):
        """
        Obtain a base-calling feature matrix for all reference aligned bases.

        Args:
            reference(pysam.FastaFile) : reference to fetch reference bases from, if not supplied the MD tag is used

        Returns:
            X : feature matrix
            y : reference bases (only returned when return_ref_info is True)
            CIGAR : alignment of the feature matrix to the reference, as (operation, count) tuples
            alignment_start (int) : first aligned reference position
            alignment_end (int) : last aligned reference position
        """

        X = None
        if return_ref_info:
            y = []
        CIGAR = []
        prev_end = None
        alignment_start = None
        alignment_end = None
        for start, end in self.get_aligned_blocks():
            if return_ref_info:
                x, y_ = self.get_base_calling_feature_matrix(
                    return_ref_info=return_ref_info, start=start, end=end,
                    reference=reference, **feature_matrix_args
                )
                y += y_
            else:
                x = self.get_base_calling_feature_matrix(
                    return_ref_info=return_ref_info,
                    start=start,
                    end=end,
                    reference=reference,
                    **feature_matrix_args)
            if X is None:
                X = x
            else:
                X = np.append(X, x, axis=0)

            if prev_end is not None:
                CIGAR.append(('N', start - prev_end - 1))
            CIGAR.append(('M', (end - start + 1)))
            prev_end = end

            if alignment_start is None:
                alignment_start = start
                alignment_end = end
            else:
                alignment_start = min(alignment_start, start)
                alignment_end = max(alignment_end, end)

        if return_ref_info:
            return X, y, CIGAR, alignment_start, alignment_end
        else:
            return X, CIGAR, alignment_start, alignment_end

    def get_base_calling_training_data(
            self,
            mask_variants=None,
            might_be_variant_function=None,
            reference=None,
            **feature_matrix_args):
        """Obtain an (X, y) training pair for base calling, where y contains the reference bases and positions overlapping mask_variants are excluded"""
        if mask_variants is not None and might_be_variant_function is None:
            might_be_variant_function = might_be_variant

        features, feature_info, _CIGAR, _alignment_start, _alignment_end = self.get_base_calling_feature_matrix_spaced(
            True, reference=reference, **feature_matrix_args)

        # Edge case: it can happen that not a single base can be used for base calling;
        # in that case features will be None and None is returned
        if features is None or len(features) == 0:
            return None

        # check which bases should not be used
        use_indices = [
            mask_variants is None or
            not might_be_variant_function(chrom, pos, mask_variants, base)
            for chrom, pos, base in feature_info]

        X_molecule = features[use_indices]
        y_molecule = [
            base for use, (chrom, pos, base) in
            zip(use_indices, feature_info) if use
        ]
        return X_molecule, y_molecule

    def has_valid_span(self):
        """Check if the span of the molecule is determined

        Returns:
            has_valid_span (bool)
        """
        if self.spanStart is not None and self.spanEnd is not None:
            return True
        return False

    def get_strand_repr(self, unknown='?'):
        """Get string representation of mapping strand

        Args:
            unknown (str) : set what character/string to return
                            when the strand is not available

        Returns:
            strand_repr (str) : + forward, - reverse, ? unknown
        """
        s = self.get_strand()
        if s is None:
            return unknown
        if s:
            return '-'
        else:
            return '+'

1319
    def set_rejection_reason(self, reason, set_qcfail=False):
1320
        """ Add rejection reason to all fragments associated to this molecule
1321
1322
        Args:
1323
            reason (str) : rejection reason to set
1324
1325
            set_qcfail(bool) : set qcfail bit to True for all associated reads
1326
        """
1327
        for fragment in self:
1328
            fragment.set_rejection_reason(reason, set_qcfail=set_qcfail)
1329
1330
    def is_valid(self, set_rejection_reasons=False):
1331
        """Check if the molecule is valid
1332
        All of the following requirements should be met:
1333
        - no multimapping
1334
        - no low mapping quality (change molecule.min_max_mapping_quality to set the threshold)
1335
        - molecule is associated with at least one valid fragment
1336
1337
        Args:
1338
            set_rejection_reasons (bool) : When set to True, all reads get a
1339
            rejection reason (RR tag) written to them if the molecule is rejected.
1340
1341
        Returns:
1342
            is_valid (bool) : True when all requirements are met, False otherwise
1343
1344
        """
1345
        if self.is_multimapped():
1346
            if set_rejection_reasons:
1347
                self.set_rejection_reason('multimapping')
1348
            return False
1349
1350
        if self.min_max_mapping_quality is not None and \
1351
                self.get_max_mapping_qual() < self.min_max_mapping_quality:
1352
            if set_rejection_reasons:
1353
                self.set_rejection_reason('MQ')
1354
            return False
1355
1356
        if not self.contains_valid_fragment():
1357
            if set_rejection_reasons:
1358
                self.set_rejection_reason('invalid_fragments')
1359
            return False
1360
1361
        return True
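
    # Usage sketch (hypothetical; assumes `molecules` is an iterable of Molecule
    # objects): keep only molecules passing the checks above and write a
    # rejection reason (RR tag) to the reads of rejected molecules.
    #
    #   valid = [m for m in molecules if m.is_valid(set_rejection_reasons=True)]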
1362
1363
    def get_aligned_blocks(self):
1364
        """ get all consecutive blocks of aligned reference positions
1365
1366
        Returns:
1367
            sorted list of aligned blocks (list) : [ (start, end), (start, end) ]
1368
        """
1369
        return find_ranges(
1370
            sorted(list(set(
1371
                (ref_pos
1372
                 for read in self.iter_reads()
1373
                 for q_pos, ref_pos in read.get_aligned_pairs(matches_only=True, with_seq=False)))))
1374
        )
1375
1376
    def __len__(self):
1377
        """Obtain the amount of fragments associated to the molecule"""
1378
        return len(self.fragments)
1379
1380
    def get_consensus_base_frequencies(self, allow_N=False):
1381
        """Obtain the frequency of bases in the molecule consensus sequence
1382
1383
        Returns:
1384
            base_frequencies (Counter) : Counter containing base frequencies, for example: { 'A':10, 'T':3, 'C':4 }
1385
        """
1386
        return Counter(
1387
            self.get_consensus(
1388
                allow_N=allow_N).values())
1389
1390
    def get_feature_vector(self):
1391
        """ Obtain a feature vector representation of the molecule
1392
1393
        Returns:
1394
            feature_vector(np.array)
1395
        """
1396
1397
        return np.array([
1398
            len(self),
1399
            self.get_strand(),
1400
            self.has_valid_span(),
1401
            self.get_umi_error_rate(),
1402
            self.get_consensus_gc_ratio(),
1403
            len(self.get_raw_barcode_sequences()),
1404
            self.get_safely_aligned_length(),
1405
            self.get_max_mapping_qual(),
1406
            (self.allele is None),
1407
            self.contains_valid_fragment(),
1408
            self.is_multimapped(),
1409
            self.get_undigested_site_count(),
1410
            self.is_valid()
1411
        ])
1412
1413
    def get_alignment_tensor(self,
1414
                             max_reads,
1415
                             window_radius=20,
1416
                             centroid=None,
1417
                             mask_centroid=False,
1418
                             refence_backed=False,
1419
                             skip_missing_reads=False
1420
                             ):
1421
        """ Obtain a tensor representation of the molecule alignment around the given centroid
1422
1423
        Args:
1424
            max_reads (int) : maximum amount of reads encoded in the tensor, this will be the amount of rows/5 of the returned feature matrix
1425
1426
            window_radius (int) : radius of bp around centroid
1427
1428
            centroid(int) : center of extracted window, when not specified the cut location of the molecule is used
1429
1430
            mask_centroid(bool) : when True, mask reference base at centroid with N
1431
1432
            refence_backed(bool) : when True the molecule's reference is used to emit reference bases instead of the MD tag
1433
1434
        Returns:
1435
            tensor_repr(np.array) : feature matrix of shape (5 * max_reads, 2 * window_radius)
1436
        """
1437
        reference = None
1438
        if refence_backed:
1439
            reference = self.reference
1440
            if self.reference is None:
1441
                raise ValueError(
1442
                    "refence_backed set to True, but the molecule has no reference assigned. Assing one using pysam.FastaFile()")
1443
1444
        height = max_reads
1445
        chromosome = self.chromosome
1446
        if centroid is None:
1447
            _, centroid, strand = self.get_cut_site()
1448
        span_start = centroid - window_radius
1449
        span_end = centroid + window_radius
1450
        span_len = abs(span_start - span_end)
1451
        base_content_table = np.zeros((height, span_len))
1452
        base_mismatches_table = np.zeros((height, span_len))
1453
        base_indel_table = np.zeros((height, span_len))
1454
        base_qual_table = np.zeros((height, span_len))
1455
        base_clip_table = np.zeros((height, span_len))
1456
        pointer = 0
1457
1458
        mask = None
1459
        if mask_centroid:
1460
            mask = {(chromosome, centroid)}  # set holding a single (chromosome, position) tuple
1461
1462
        for _, frags in self.get_rt_reactions().items():
1463
            for frag in frags:
1464
                pointer = frag.write_tensor(
1465
                    chromosome,
1466
                    span_start,
1467
                    span_end,
1468
                    index_start=pointer,
1469
                    base_content_table=base_content_table,
1470
                    base_mismatches_table=base_mismatches_table,
1471
                    base_indel_table=base_indel_table,
1472
                    base_qual_table=base_qual_table,
1473
                    base_clip_table=base_clip_table,
1474
                    height=height,
1475
                    mask_reference_bases=mask,
1476
                    reference=reference,
1477
                    skip_missing_reads=skip_missing_reads)
1478
        x = np.vstack(
1479
            [
1480
                base_content_table,
1481
                base_mismatches_table,
1482
                base_indel_table,
1483
                base_qual_table,
1484
                base_clip_table
1485
            ])
1486
1487
        return x
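
    # Usage sketch (hypothetical; assumes `molecule` has a cut site and at most
    # 8 associated reads): obtain a fixed-size matrix for a model. The result
    # stacks 5 tables, each of shape (max_reads, 2 * window_radius).
    #
    #   tensor = molecule.get_alignment_tensor(max_reads=8, window_radius=20)
    #   flat = tensor.flatten()  # flatten if the model expects a 1D feature vector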
1488
1489
    def get_consensus_gc_ratio(self):
1490
        """Obtain the GC ratio of the molecule consensus sequence
1491
1492
        Returns:
1493
            gc_ratio(float) : GC ratio
1494
        """
1495
        bf = self.get_consensus_base_frequencies()
1496
        return (bf['G'] + bf['C']) / sum(bf.values())
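
    # Worked example: with consensus base frequencies {'A': 10, 'T': 3, 'C': 4},
    # the GC ratio is (0 + 4) / (10 + 3 + 4) ≈ 0.235.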
1497
1498
    def get_umi_error_rate(self):
1499
        """Obtain fraction of fragments that are associated
1500
        to the molecule with an exactly matching UMI vs the total amount of associated fragments
1501
        Returns:
1502
            exact_matching_fraction (float)
1503
        """
1504
        mc = 0
1505
        other = 0
1506
        for i, (umi, obs) in enumerate(self.umi_counter.most_common()):
1507
            if i == 0:
1508
                mc = obs
1509
            else:
1510
                other += obs
1511
1512
        return mc / (other + mc)
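
    # Worked example: with umi_counter == Counter({'AACG': 5, 'AACT': 1}) the
    # most common UMI accounts for 5 of 6 fragments, so the returned
    # exact-matching fraction is 5 / (5 + 1) ≈ 0.83.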
1513
1514
    def get_barcode_sequences(self):
1515
        """Obtain (Cell) barcode sequences associated to molecule
1516
1517
        Returns:
1518
            barcode sequences (set) : barcode sequence(s)
1519
        """
1520
        return set(read.get_tag('BC') for read in self.iter_reads())
1521
1522
    def get_raw_barcode_sequences(self):
1523
        """Obtain (Cell) barcode sequences associated to molecule, not hamming corrected
1524
1525
        Returns:
1526
            barcode sequences (set) : barcode sequence(s)
1527
        """
1528
        return set(read.get_tag('bc') for read in self.iter_reads())
1529
1530
    def get_strand(self):
1531
        """Obtain mapping strand of molecule
1532
1533
        Returns:
1534
            strand : True,False,None
1535
                True when strand is REVERSE
1536
                False when strand is FORWARD
1537
                None when strand is not determined
1538
        """
1539
        return self.strand
1540
1541
    def __repr__(self):
1542
1543
        max_show = 6  # maximum amount of fragments to show
1544
        frag_repr = '\n\t'.join([textwrap.indent(str(fragment), ' ' * 4)
1545
                                 for fragment in self.fragments[:max_show]])
1546
1547
        return f"""{self.__class__.__name__}
1548
        with {len(self.fragments)} assigned fragments
1549
        {"Allele :" + (self.allele if self.allele is not None else "No allele assigned")}
1550
        """ + frag_repr + (
1551
            '' if len(self.fragments) < max_show else f'... {len(self.fragments) - max_show} fragments not shown')
1552
1553
    def update_umi(self):
1554
        """Set UMI
1555
        sets self.umi (str) to the most common UMI associated with the molecule
1556
        """
1557
        self.umi = self.umi_counter.most_common(1)[0][0]
1558
1559
    def get_umi(self):
1560
        """Obtain umi of molecule
1561
1562
        Returns:
1563
            umi (str):
1564
                return main umi associated with this molecule
1565
        """
1566
1567
        return self.umi
1568
1569
    def get_sample(self):
1570
        """Obtain sample
1571
1572
        Returns:
1573
            sample (str):
1574
                Sample associated with the molecule. Usually extracted from SM tag.
1575
                Calls fragment.get_sample() to obtain the sample
1576
        """
1577
        for fragment in self.fragments:
1578
            return fragment.get_sample()
1579
1580
    def get_cut_site(self):
1581
        """For restriction based protocol data, obtain genomic location of cut site
1582
1583
        Returns:
1584
            None if site is not available
1585
1586
            chromosome (str)
1587
            position (int)
1588
            strand (bool)
1589
        """
1590
1591
        for fragment in self.fragments:
1592
            try:
1593
                site = fragment.get_site_location()
1594
            except AttributeError:
1595
                return None
1596
            if site is not None:
1597
                return tuple((*site, fragment.get_strand()))
1598
        return None
1599
1600
    def get_mean_mapping_qual(self):
1601
        """Get mean mapping quality of the molecule
1602
1603
        Returns:
1604
            mean_mapping_qual (float)
1605
        """
1606
        return np.mean([fragment.mapping_quality for fragment in self])
1607
1608
    def get_max_mapping_qual(self):
1609
        """Get max mapping quality of the molecule
1610
        Returns:
1611
            max_mapping_qual (float)
1612
        """
1613
        return max([fragment.mapping_quality for fragment in self])
1614
1615
    def get_min_mapping_qual(self) -> float:
1616
        """Get min mapping quality of the molecule
1617
        Returns:
1618
            min_mapping_qual (float)
1619
        """
1620
        return min([fragment.mapping_quality for fragment in self])
1621
1622
    def contains_valid_fragment(self):
1623
        """Check if an associated fragment exists which returns True for is_valid()
1624
1625
        Returns:
1626
            contains_valid_fragment (bool) : True when any associated fragment is_valid()
1627
        """
1628
        return any(
1629
            (hasattr(fragment, 'is_valid') and fragment.is_valid()
1630
             for fragment in self.fragments))
1631
1632
    def is_multimapped(self):
1633
        """Check if the molecule is multimapping
1634
1635
        Returns:
1636
            is_multimapped (bool) : True when multimapping
1637
        """
1638
        for fragment in self.fragments:
1639
            if not fragment.is_multimapped:
1640
                return False
1641
        return True
1642
1643
    def add_molecule(self, other):
1644
        """
1645
        Merge other molecule into this molecule.
1646
        Merges by assigning all fragments in other to this molecule.
1647
        """
1648
        for fragment in other:
1649
            self._add_fragment(fragment)
1650
1651
1652
    def get_span_sequence(self, reference=None):
1653
        """Obtain the sequence between the start and end of the molecule
1654
        Args:
1655
            reference(pysam.FastaFile) : reference  to use.
1656
                If not specified `self.reference` is used
1657
        Returns:
1658
            sequence (str)
1659
        """
1660
        if self.chromosome is None:
1661
            return ''
1662
1663
        if reference is None:
1664
            if self.reference is None:
1665
                raise ValueError('Please supply a reference (PySAM.FastaFile)')
1666
            reference = self.reference
1667
        return reference.fetch(
1668
            self.chromosome,
1669
            self.spanStart,
1670
            self.spanEnd).upper()
1671
1672
    def get_fragment_span_sequence(self, reference=None):
1673
        return self.get_span_sequence(reference)
1674
1675
    def _add_fragment(self, fragment):
1676
1677
        # Do not process the fragment when the max_associated_fragments threshold is exceeded
1678
        if self.max_associated_fragments is not None and len(self.fragments) >= (self.max_associated_fragments):
1679
            self.overflow_fragments += 1
1680
            raise OverflowError()
1681
1682
        self.match_hash = fragment.match_hash
1683
1684
        # if we already had a fragment, this fragment is a duplicate:
1685
        if len(self.fragments) > 1:
1686
            fragment.set_duplicate(True)
1687
1688
        self.fragments.append(fragment)
1689
1690
        # Update span:
1691
        add_span = fragment.get_span()
1692
1693
        # It is possible that the span is not defined, then set the respective keys to None
1694
        # This indicates the molecule is qcfail
1695
1696
        #if not self.has_valid_span():
1697
        #    self.spanStart, self.spanEnd, self.chromosome = None,None, None
1698
        #else:
1699
        self.spanStart = add_span[1] if self.spanStart is None else min(
1700
            add_span[1], self.spanStart)
1701
        self.spanEnd = add_span[2] if self.spanEnd is None else max(
1702
            add_span[2], self.spanEnd)
1703
        self.chromosome = add_span[0]
1704
1705
        self.span = (self.chromosome, self.spanStart, self.spanEnd)
1706
        if fragment.strand is not None:
1707
            self.strand = fragment.strand
1708
        self.umi_counter[fragment.umi] += 1
1709
        self.umi_hamming_distance = fragment.umi_hamming_distance
1710
        self.saved_base_obs = None
1711
        self.update_umi()
1712
        return True
1713
1714
    @property
1715
    def aligned_length(self) -> int:
1716
        if self.has_valid_span():
1717
            return self.spanEnd - self.spanStart
1718
        else:
1719
            return None
1720
1721
    @property
1722
    def is_completely_matching(self) -> bool:
1723
        """
1724
        Checks if all associated reads are completely mapped:
1725
        checks if all cigar operations are M,
1726
        Returns True when all cigar operations are M, False otherwise
1727
        """
1728
1729
        return all(
1730
                (
1731
                     all(
1732
                     [ (operation==0)
1733
                        for operation, amount in read.cigartuples] )
1734
                for read in self.iter_reads()))
1735
1736
1737
    @property
1738
    def estimated_max_length(self) -> int:
1739
        """
1740
        Obtain the estimated maximum fragment size of the molecule,
1741
        returns None when estimation is not possible
1742
        Takes into account removed bases (R2)
1743
        Assumes inwards sequencing orientation
1744
        """
1745
        max_size = None
1746
        for frag in self:
1747
            r = frag.estimated_length
1748
            if r is None :
1749
                continue
1750
            if max_size is None:
1751
                max_size = r
1752
            elif r>max_size:
1753
                max_size = r
1754
        return max_size
1755
1756
    def get_safely_aligned_length(self):
1757
        """Get the amount of safely aligned bases (excludes primers)
1758
        Returns:
1759
            aligned_bases (int) : Amount of safely aligned bases
1760
             or None when this cannot be determined because both mates are not mapped
1761
        """
1762
        if self.spanStart is None or self.spanEnd is None:
1763
            return None
1764
1765
        start = None
1766
        end = None
1767
        contig = None
1768
        for fragment in self:
1769
            if not fragment.safe_span:
1770
                continue
1771
1772
            if contig is None:
1773
                contig = fragment.span[0]
1774
            if contig == fragment.span[0]:
1775
                f_start, f_end = fragment.get_safe_span()
1776
                if start is None:
1777
                    start = f_start
1778
                    end = f_end
1779
                else:
1780
                    start = min(f_start, start)
1781
                    end = min(f_end, end)
1782
1783
        if end is None:
1784
            raise ValueError('Not safe')
1785
        return abs(end - start)
1786
1787
    def add_fragment(self, fragment, use_hash=True):
1788
        """Associate a fragment with this Molecule
1789
1790
        Args:
1791
            fragment (singlecellmultiomics.fragment.Fragment) : Fragment to associate
1792
        Returns:
1793
            has_been_added (bool) : Returns False when the fragments which have already been associated to the molecule refuse the fragment
1794
1795
        Raises:
1796
            OverflowError : when too many fragments have been associated already
1797
                            control this with .max_associated_fragments attribute
1798
        """
1799
1800
        if len(self.fragments) == 0:
1801
            self._add_fragment(fragment)
1802
            self.sample = fragment.sample
1803
            return True
1804
1805
        if use_hash:
1806
            if self == fragment:
1807
                self._add_fragment(fragment)
1808
                return True
1809
1810
        else:
1811
            for f in self.fragments:
1812
                if f == fragment:
1813
                    # it matches this molecule:
1814
                    self._add_fragment(fragment)
1815
                    return True
1816
1817
        return False
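
    # Usage sketch (hypothetical; assumes `fragments` yields Fragment objects in
    # coordinate order): greedily assign each fragment to the first molecule
    # that accepts it, otherwise start a new molecule. The package's
    # MoleculeIterator wraps this pattern with buffering and yielding logic.
    #
    #   molecules = []
    #   for fragment in fragments:
    #       if not any(m.add_fragment(fragment) for m in molecules):
    #           molecules.append(Molecule(fragment))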
1818
1819
    def can_be_yielded(self, chromosome, position):
1820
        """Check if the molecule is far enough away from the supplied location to be ejected from a buffer.
1821
1822
        Args:
1823
            chromosome (str) : chromosome / contig of location to test
1824
            position (int) : genomic location of location to test
1825
1826
        Returns:
1827
            can_be_yielded (bool) : True when the molecule is far enough away from the supplied location to be ejected from a buffer.
1828
        """
1829
1830
        if chromosome is None:
1831
            return False
1832
        if chromosome != self.chromosome:
1833
            return True
1834
        return position < (
1835
                self.spanStart -
1836
                self.cache_size *
1837
                0.5) or position > (
1838
                       self.spanEnd +
1839
                       self.cache_size *
1840
                       0.5)
1841
1842
    def get_rt_reactions(self) -> dict:
1843
        """Obtain RT reaction dictionary
1844
1845
        returns:
1846
            rt_dict (dict):  {(primer,pos) : [fragment, fragment..] }
1847
        """
1848
        return molecule_to_random_primer_dict(self)
1849
1850
    def get_rt_reaction_fragment_sizes(self, max_N_distance=1):
1851
        """Obtain all RT reaction fragment sizes
1852
        Returns:
1853
            rt_sizes (list of ints)
1854
        """
1855
1856
        rt_reactions = molecule_to_random_primer_dict(
1857
            self, max_N_distance=max_N_distance)
1858
        amount_of_rt_reactions = len(rt_reactions)
1859
1860
        # this obtains the maximum fragment size:
1861
        frag_chrom, frag_start, frag_end = pysamiterators.iterators.getListSpanningCoordinates(
1862
            [v for v in itertools.chain.from_iterable(self) if v is not None])
1863
1864
        # Obtain the fragment sizes of all RT reactions:
1865
        rt_sizes = []
1866
        for (rt_contig, rt_end, hexamer), fragments in rt_reactions.items():
1867
1868
            if rt_end is None:
1869
                continue
1870
1871
            rt_chrom, rt_start, rt_end = pysamiterators.iterators.getListSpanningCoordinates(
1872
                itertools.chain.from_iterable(
1873
                    [fragment for fragment in fragments if
1874
                     fragment is not None and fragment.get_random_primer_hash()[0] is not None]))
1875
1876
            rt_sizes.append([rt_end - rt_start])
1877
        return rt_sizes
1878
1879
    def get_mean_rt_fragment_size(self):
1880
        """Obtain the mean RT reaction fragment size
1881
1882
        Returns:
1883
            mean_rt_size (float)
1884
        """
1885
        return np.nanmean(
1886
            self.get_rt_reaction_fragment_sizes()
1887
        )
1888
1889
    def write_pysam(self, target_file, consensus=False, no_source_reads=False, consensus_name=None, consensus_read_callback=None, consensus_read_callback_kwargs=None):
1890
        """Write all associated reads to the target file
1891
1892
        Args:
1893
            target_file (pysam.AlignmentFile) : Target file
1894
            consensus (bool) : write consensus
1895
            no_source_reads (bool) : while in consensus mode, don't write original reads
1896
            consensus_read_callback (function) : this function is called with every consensus read as an arguments
1897
            consensus_read_callback_kwargs (dict) : arguments to pass to the callback function
1898
        """
1899
        if consensus:
1900
            reads = self.deduplicate_majority(target_file,
1901
                                              f'molecule_{uuid4()}' if consensus_name is None else consensus_name)
1902
            if consensus_read_callback is not None:
1903
                if consensus_read_callback_kwargs is not None:
1904
                    consensus_read_callback(reads, **consensus_read_callback_kwargs)
1905
                else:
1906
                    consensus_read_callback(reads)
1907
1908
            for read in reads:
1909
                target_file.write(read)
1910
1911
            if not no_source_reads:
1912
                for read in self.iter_reads():
1913
                    read.is_duplicate=True
1914
                for fragment in self:
1915
                    fragment.write_pysam(target_file)
1916
1917
        else:
1918
            for fragment in self:
1919
                fragment.write_pysam(target_file)
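
    # Usage sketch (hypothetical file names; assumes `molecule` is a valid
    # Molecule and the output BAM shares the input header): write a
    # majority-vote consensus read plus the duplicate-flagged source reads.
    #
    #   with pysam.AlignmentFile('in.bam', 'rb') as fin, \
    #        pysam.AlignmentFile('consensus.bam', 'wb', header=fin.header) as fout:
    #       molecule.write_pysam(fout, consensus=True, no_source_reads=False)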
1920
1921
    def set_methylation_call_tags(self,
1922
                                  call_dict, bismark_call_tag='XM',
1923
                                  total_methylated_tag='MC',
1924
                                  total_unmethylated_tag='uC',
1925
                                  total_methylated_CPG_tag='sZ',
1926
                                  total_unmethylated_CPG_tag='sz',
1927
                                  total_methylated_CHH_tag='sH',
1928
                                  total_unmethylated_CHH_tag='sh',
1929
                                  total_methylated_CHG_tag='sX',
1930
                                  total_unmethylated_CHG_tag='sx',
1931
                                  reads=None
1932
                                  ):
1933
        """Set methylation call tags given a methylation dictionary
1934
1935
        This method sets multiple tags in every read associated to the molecule.
1936
        The tags being set are the bismark_call_tag, in which every aligned base is annotated
1937
        with one of zZxXhH or ".", and tags for both the total methylated and unmethylated C's
1938
1939
        Args:
1940
            call_dict (dict) : Dictionary containing bismark calls (chrom,pos) :
1941
                        {'context': letter, 'reference_base': letter, 'consensus': letter, optional: 'qual': phred_score (int) }
1942
1943
            bismark_call_tag (str) : tag to write bismark call string
1944
1945
            total_methylated_tag (str) : tag to write total methylated bases
1946
1947
            total_unmethylated_tag (str) : tag to write total unmethylated bases
1948
1949
            reads (iterable) : reads to write the tags to, when not supplied, the tags are written to all associated reads
1950
        Returns:
1951
            None
1952
        """
1953
        self.methylation_call_dict = call_dict
1954
1955
        # molecule_XM dictionary containing count of contexts
1956
        molecule_XM = Counter(
1957
            list(
1958
                d.get(
1959
                    'context',
1960
                    '.') for d in self.methylation_call_dict.values()))
1961
        # Contruct XM strings
1962
        if reads is None:
1963
            reads = self.iter_reads()
1964
        for read in reads:
1965
1966
            bis_met_call_string = ''.join([
1967
                call_dict.get(
1968
                    (read.reference_name, rpos), {}).get('context', '.')
1969
                # Obtain all aligned positions from the call dict
1970
                # iterate all positions in the alignment
1971
                for qpos, rpos in read.get_aligned_pairs(matches_only=True)
1972
                if qpos is not None and rpos is not None])
1973
            # make sure to ignore non-matching positions ? is this necessary?
1974
1975
            read.set_tag(
1976
                # Write the methylation tag to the read
1977
                bismark_call_tag,
1978
                bis_met_call_string
1979
            )
1980
1981
            # Set total methylated bases
1982
            read.set_tag(
1983
                total_methylated_tag,
1984
                molecule_XM['Z'] + molecule_XM['X'] + molecule_XM['H'])
1985
1986
            # Set total unmethylated bases
1987
            read.set_tag(
1988
                total_unmethylated_tag,
1989
                molecule_XM['z'] + molecule_XM['x'] + molecule_XM['h'])
1990
1991
            # Set total CPG methylated and unmethylated:
1992
            read.set_tag(
1993
                total_methylated_CPG_tag,
1994
                molecule_XM['Z'])
1995
1996
            read.set_tag(
1997
                total_unmethylated_CPG_tag,
1998
                molecule_XM['z'])
1999
2000
            # Set total CHG methylated and unmethylated:
2001
            read.set_tag(
2002
                total_methylated_CHG_tag,
2003
                molecule_XM['X'])
2004
2005
            read.set_tag(
2006
                total_unmethylated_CHG_tag,
2007
                molecule_XM['x'])
2008
2009
            # Set total CHH methylated and unmethylated:
2010
            read.set_tag(
2011
                total_methylated_CHH_tag,
2012
                molecule_XM['H'])
2013
2014
            read.set_tag(
2015
                total_unmethylated_CHH_tag,
2016
                molecule_XM['h'])
2017
2018
2019
            # Set XR (Read conversion string)
2020
            # @todo: this is TAPS specific, inefficient, ugly
2021
            try:
2022
                fwd = 0
2023
                rev = 0
2024
                for (qpos, rpos, ref_base), call in zip(
2025
                    read.get_aligned_pairs(matches_only=True,with_seq=True),
2026
                    bis_met_call_string):
2027
                    qbase = read.query_sequence[qpos]
2028
                    if call.isupper():
2029
                        if qbase=='A':
2030
                            rev+=1
2031
                        elif qbase=='T':
2032
                            fwd+=1
2033
2034
                # Set XG (genome conversion string)
2035
                if rev>fwd:
2036
                    read.set_tag('XR','CT')
2037
                    read.set_tag('XG','GA')
2038
                else:
2039
                    read.set_tag('XR','CT')
2040
                    read.set_tag('XG','CT')
2041
            except ValueError:
2042
                # Dont set the tag
2043
                pass
2044
2045
    def set_meta(self, tag, value):
2046
        """Set meta information to all fragments
2047
2048
        Args:
2049
            tag (str):
2050
                2 letter tag
2051
            value: value to set
2052
2053
        """
2054
        for f in self:
2055
            f.set_meta(tag, value)
2056
2057
    def __getitem__(self, index):
2058
        """Obtain a fragment belonging to this molecule.
2059
2060
        Args:
2061
            index (int):
2062
                index of the fragment [0 ,1 , 2 ..]
2063
2064
        Returns:
2065
            fragment (singlecellmultiomics.fragment.Fragment)
2066
        """
2067
        return self.fragments[index]
2068
2069
    def get_alignment_stats(self):
2070
        """Get dictionary containing mean amount clip/insert/deletion/matches per base
2071
2072
        Returns:
2073
            cigar_stats(dict): dictionary {
2074
                clips_per_bp(int),
2075
                deletions_per_bp(int),
2076
                matches_per_bp(int),
2077
                inserts_per_bp(int),
2078
                alternative_hits_per_read(int),
2079
2080
                }
2081
        """
2082
        clips = 0
2083
        matches = 0
2084
        inserts = 0
2085
        deletions = 0
2086
        totalbases = 0
2087
        total_reads = 0
2088
        total_alts = 0
2089
        for read in self.iter_reads():
2090
            totalbases += read.query_length
2091
            total_reads += 1
2092
            for operation, amount in read.cigartuples:
2093
                if operation == 4:
2094
                    clips += amount
2095
                elif operation == 2:
2096
                    deletions += amount
2097
                elif operation == 0:
2098
                    matches += amount
2099
                elif operation == 1:
2100
                    inserts += amount
2101
            if read.has_tag('XA'):
2102
                total_alts += len(read.get_tag('XA').split(';'))
2103
2104
        clips_per_bp = clips / totalbases
2105
        inserts_per_bp = inserts / totalbases
2106
        deletions_per_bp = deletions / totalbases
2107
        matches_per_bp = matches / totalbases
2108
2109
        alt_per_read = total_alts / total_reads
2110
2111
        return {
2112
            'clips_per_bp': clips_per_bp,
2113
            'inserts_per_bp': inserts_per_bp,
2114
            'deletions_per_bp': deletions_per_bp,
2115
            'matches_per_bp': matches_per_bp,
2116
            'alt_per_read': alt_per_read,
2117
            'total_bases':totalbases,
2118
            'total_reads':total_reads,
2119
        }
2120
2121
    def get_mean_cycle(
2122
            self,
2123
            chromosome,
2124
            position,
2125
            base=None,
2126
            not_base=None):
2127
        """Get the mean cycle at the supplied coordinate and base-call
2128
2129
        Args:
2130
            chromosome (str)
2131
            position (int)
2132
            base (str) : select only reads with this base
2133
            not_base(str) : select only reads without this base
2134
2135
        Returns:
2136
            mean_cycles (tuple): mean cycle for R1 and R2
2137
        """
2138
        assert (base is not None or not_base is not None), "Supply base or not_base"
2139
2140
        cycles_R1 = []
2141
        cycles_R2 = []
2142
        for read in self.iter_reads():
2143
2144
            if read is None or read.reference_name != chromosome:
2145
                continue
2146
2147
2148
            for cycle, query_pos, ref_pos in pysamiterators.iterators.ReadCycleIterator(
2149
                    read, with_seq=False):
2150
2151
                if query_pos is None or ref_pos != position:
2152
                    continue
2153
2154
                if not_base is not None and read.seq[query_pos] == not_base:
2155
                    continue
2156
                if base is not None and read.seq[query_pos] != base:
2157
                    continue
2158
2159
                if read.is_read2:
2160
                    cycles_R2.append(cycle)
2161
                else:
2162
                    cycles_R1.append(cycle)
2163
        if len(cycles_R2) == 0 and len(cycles_R1)==0:
2164
            raise IndexError(
2165
                "There are no observations if the supplied base/location combination")
2166
        return (np.mean(cycles_R1) if len(cycles_R1) else np.nan),  (np.mean(cycles_R2) if len(cycles_R2) else np.nan)
2167
2168
    def get_mean_base_quality(
2169
                self,
2170
                chromosome,
2171
                position,
2172
                base=None,
2173
                not_base=None):
2174
            """Get the mean phred score at the supplied coordinate and base-call
2175
2176
            Args:
2177
                chromosome (str)
2178
                position (int)
2179
                base (str) : select only reads with this base
2180
                not_base(str) : select only reads without this base
2181
2182
            Returns:
2183
                mean_phred_score (float)
2184
            """
2185
            assert (base is not None or not_base is not None), "Supply base or not_base"
2186
2187
            qualities = []
2188
            for read in self.iter_reads():
2189
2190
                if read is None or read.reference_name != chromosome:
2191
                    continue
2192
2193
                for query_pos, ref_pos in read.get_aligned_pairs(
2194
                        with_seq=False, matches_only=True):
2195
2196
                    if query_pos is None or ref_pos != position:
2197
                        continue
2198
2199
                    if not_base is not None and read.seq[query_pos] == not_base:
2200
                        continue
2201
                    if base is not None and read.seq[query_pos] != base:
2202
                        continue
2203
2204
                    qualities.append(read.query_qualities[query_pos])  # Phred score, without the ASCII offset
2205
            if len(qualities) == 0:
2206
                raise IndexError(
2207
                    "There are no observations if the supplied base/location combination")
2208
            return np.mean(qualities)
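
    # Usage sketch (hypothetical coordinates): compare the support for a
    # candidate variant base against all other base calls at the same position.
    #
    #   alt_qual = molecule.get_mean_base_quality('chr1', 1_000_000, base='T')
    #   other_qual = molecule.get_mean_base_quality('chr1', 1_000_000, not_base='T')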
2209
2210
    @cached_property
2211
    def allele_likelihoods(self):
2212
        """
2213
        Per allele likelihood
2214
2215
        Returns:
2216
            likelihoods (dict) : {allele_name : likelihood}
2217
2218
        """
2219
        return self.get_allele_likelihoods()[0]
2220
2221
    @property
2222
    def library(self):
2223
        """
2224
        Associated library
2225
2226
        Returns:
2227
           library (str) : Library associated with the first read of this molecule
2228
2229
        """
2230
        for read in self.iter_reads():
2231
            if read.has_tag('LY'):
2232
                return read.get_tag('LY')
2233
2234
    @cached_property
2235
    def allele_probabilities(self):
2236
        """
2237
        Per allele probability
2238
2239
        Returns:
2240
            likelihoods (dict) : {allele_name : prob}
2241
2242
        """
2243
        return likelihood_to_prob( self.get_allele_likelihoods()[0] )
2244
2245
2246
    @cached_property
2247
    def allele_confidence(self) -> int:
2248
        """
2249
        Returns:
2250
            confidence(int) : a Phred-scaled confidence value for the allele
2251
            assignment, returns zero when no allele is associated to the molecule
2252
        """
2253
        l = self.allele_probabilities
2254
        if l is None or len(l) == 0 :
2255
            return 0
2256
        return int(prob_to_phred( Counter(l).most_common(1)[0][1] ))
2257
2258
    @cached_property
2259
    def base_confidences(self):
2260
        return self.get_base_confidence_dict()
2261
2262
    @cached_property
2263
    def base_likelihoods(self):
2264
        return {(chrom, pos):base_probabilities_to_likelihood(probs) for (chrom, pos),probs in self.base_confidences.items()}
2265
2266
    @cached_property
2267
    def base_probabilities(self):
2268
        # Optimization which is equal to {location:likelihood_to_prob(liks) for location,liks in self.base_likelihoods.items()}
2269
        obs = {}
2270
        for read in self.iter_reads():
2271
            for qpos, rpos in read.get_aligned_pairs(matches_only=True):
2272
                qbase = read.seq[qpos]
2273
                qqual = read.query_qualities[qpos]
2274
                if qbase=='N':
2275
                    continue
2276
                # @ todo reads which span multiple chromosomes
2277
                k = (self.chromosome, rpos)
2278
                p = 1 - np.power(10, -qqual / 10)
2279
2280
                if not k in obs:
2281
                    obs[k] = {}
2282
                if not qbase in obs[k]:
2283
                    obs[k][qbase] = [p,1] # likelihood, n
2284
                    obs[k]['N'] = [1-p,1] # likelihood, n
2285
                else:
2286
                    obs[k][qbase][0] *= p
2287
                    obs[k][qbase][1] += 1
2288
2289
                    obs[k]['N'][0] *= 1-p # likelihood, n
2290
                    obs[k]['N'][1] += 1 # likelihood, n
2291
        # Perform likelihood conversion and convert to probs
2292
        return { location: likelihood_to_prob({
2293
            base:likelihood/np.power(0.25,n-1)
2294
                    for base,(likelihood,n) in base_likelihoods.items() })
2295
                    for location,base_likelihoods in obs.items()}
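
    # Worked example of the per-base probability model above (a sketch of the
    # arithmetic, not additional functionality): two independent 'T' calls with
    # phred 30 each give p = 1 - 10**(-30/10) = 0.999 per call, a joint
    # likelihood of 0.999 * 0.999 divided by 0.25**(n-1) = 0.25, after which
    # likelihood_to_prob normalises over all candidate bases at that location.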
2296
2297
    ## This is a duplicate of the above but only calculates for allele informative positions
2298
    @cached_property
2299
    def allele_informative_base_probabilities(self):
2300
        # Optimization which is equal to {location:likelihood_to_prob(liks) for location,liks in self.base_likelihoods.items()}
2301
        obs = {}
2302
        for read in self.iter_reads():
2303
            for qpos, rpos in read.get_aligned_pairs(matches_only=True):
2304
                if not self.allele_resolver.has_location( read.reference_name, rpos ):
2305
                    continue
2306
                qbase = read.seq[qpos]
2307
                qqual = read.query_qualities[qpos]
2308
                if qbase=='N':
2309
                    continue
2310
                # @ todo reads which span multiple chromosomes
2311
                k = (self.chromosome, rpos)
2312
                p = 1 - np.power(10, -qqual / 10)
2313
2314
                if not k in obs:
2315
                    obs[k] = {}
2316
                if not qbase in obs[k]:
2317
                    obs[k][qbase] = [p,1] # likelihood, n
2318
                    obs[k]['N'] = [1-p,1] # likelihood, n
2319
                else:
2320
                    obs[k][qbase][0] *= p
2321
                    obs[k][qbase][1] += 1
2322
2323
                    obs[k]['N'][0] *= 1-p # likelihood, n
2324
                    obs[k]['N'][1] += 1 # likelihood, n
2325
        # Perform likelihood conversion and convert to probs
2326
        return { location: likelihood_to_prob({
2327
            base:likelihood/np.power(0.25,n-1)
2328
                    for base,(likelihood,n) in base_likelihoods.items() })
2329
                    for location,base_likelihoods in obs.items()}
2330
2331
2332
2333
    def calculate_allele_likelihoods(self):
2334
        self.aibd = defaultdict(list)
2335
        self.obtained_allele_likelihoods = Counter()  # Allele -> [prob, prob, prob]
2336
2337
        for (chrom, pos), base_probs in self.allele_informative_base_probabilities.items():
2338
2339
            for base, p in base_probs.items():
2340
                if base == 'N':
2341
                    continue
2342
2343
                assoc_alleles = self.allele_resolver.getAllelesAt(chrom, pos, base)
2344
                if assoc_alleles is not None and len(assoc_alleles) == 1:
2345
                    allele = list(assoc_alleles)[0]
2346
                    self.obtained_allele_likelihoods[allele] += p
2347
2348
                    self.aibd[allele].append((chrom, pos, base, p))
2349
2350
2351
2352
    def get_allele_likelihoods(self,):
2353
        """Obtain the allele(s) this molecule maps to
2354
2355
        Args:
2356
            allele_resolver(singlecellmultiomics.alleleTools.AlleleResolver)  : resolver used
2357
            return_allele_informative_base_dict(bool) : return dictionary containing the bases used for allele determination
2358
            defaultdict(list,
2359
            {'allele1': [
2360
              ('chr18', 410937, 'T'),
2361
              ('chr18', 410943, 'G'),
2362
              ('chr18', 410996, 'G'),
2363
              ('chr18', 411068, 'A')]})
2364
2365
        Returns:
2366
            likelihoods (Counter), aibd (dict) : { 'allele_a': likelihood, 'allele_b': likelihood } and {allele : [(chrom, pos, base, prob), ...]}
2367
        """
2368
        if self.obtained_allele_likelihoods is None:
2369
            self.calculate_allele_likelihoods()
2370
2371
        return self.obtained_allele_likelihoods, self.aibd
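
    # Usage sketch (hypothetical allele names; assumes molecule.allele_resolver
    # is set): pick the most supported allele and the bases supporting it.
    #
    #   likelihoods, informative_bases = molecule.get_allele_likelihoods()
    #   if len(likelihoods) > 0:
    #       best_allele, support = likelihoods.most_common(1)[0]
    #       evidence = informative_bases[best_allele]  # [(chrom, pos, base, prob), ...]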
2372
2373
2374
2375
    def get_allele(
2376
            self,
2377
            allele_resolver=None,
2378
            return_allele_informative_base_dict=False):
2379
        """Obtain the allele(s) this molecule maps to
2380
2381
        Args:
2382
            allele_resolver(singlecellmultiomics.alleleTools.AlleleResolver)  : resolver used
2383
            return_allele_informative_base_dict(bool) : return dictionary containing the bases used for allele determination
2384
            defaultdict(list,
2385
            {'allele1': [('chr18', 410937, 'T'),
2386
              ('chr18', 410943, 'G'),
2387
              ('chr18', 410996, 'G'),
2388
              ('chr18', 411068, 'A')]})
2389
2390
        Returns:
2391
            alleles(set( str )) : Set of strings containing associated alleles
2392
        """
2393
2394
        if allele_resolver is None:
2395
            if self.allele_resolver is not None:
2396
                allele_resolver = self.allele_resolver
2397
            else:
2398
                raise ValueError(
2399
                    "Supply allele resolver or set it to molecule.allele_resolver")
2400
2401
        alleles = set()
2402
        if return_allele_informative_base_dict:
2403
            aibd = defaultdict(list)
2404
        try:
2405
            for (chrom, pos), base in self.get_consensus(
2406
                    base_obs=self.get_base_observation_dict_NOREF()).items():
2407
                c = allele_resolver.getAllelesAt(chrom, pos, base)
2408
                if c is not None and len(c) == 1:
2409
                    alleles.update(c)
2410
                    if return_allele_informative_base_dict:
2411
                        aibd[list(c)[0]].append((chrom, pos, base))
2412
2413
        except Exception as e:
2414
            if return_allele_informative_base_dict:
2415
                return dict()
2416
            else:
2417
                return {}
2418
2419
        if return_allele_informative_base_dict:
2420
            return aibd
2421
        return alleles
2422
2423
    def write_allele_phasing_information_tag(
2424
            self, allele_resolver=None, tag='ap', reads=None):
2425
        """
2426
        Write allele phasing information to ap tag
2427
2428
        For every associated read a tag will be written containing:
2429
        chromosome,position,base,allele_name|chromosome,position,base,allele_name|...
2430
        for all variants found by the AlleleResolver
2431
        """
2432
        if reads is None:
2433
            reads = self.iter_reads()
2434
2435
        use_likelihood = (self.allele_assingment_method==1)
2436
2437
        if not use_likelihood:
2438
            haplotype = self.get_allele(
2439
                return_allele_informative_base_dict=True,
2440
                allele_resolver=allele_resolver)
2441
2442
            phased_locations = [
2443
                (allele, chromosome, position, base)
2444
                for allele, bps in haplotype.items()
2445
                for chromosome, position, base in bps]
2446
2447
            phase_str = '|'.join(
2448
                [
2449
                    f'{chromosome},{position},{base},{allele}' for allele,
2450
                                                                   chromosome,
2451
                                                                   position,
2452
                                                                   base in phased_locations])
2453
        else:
2454
2455
            allele_likelihoods, aibd = self.get_allele_likelihoods()
2456
            allele_likelihoods = likelihood_to_prob(allele_likelihoods)
2457
2458
            phased_locations = [
2459
                (allele, chromosome, position, base, confidence)
2460
                for allele, bps in aibd.items()
2461
                for chromosome, position, base, confidence in bps]
2462
2463
            phase_str = '|'.join(
2464
                [
2465
                    f'{chromosome},{position},{base},{allele},{ prob_to_phred(confidence) }' for allele,
2466
                                                                   chromosome,
2467
                                                                   position,
2468
                                                                   base,
2469
                                                                   confidence in phased_locations])
2470
2471
2472
2473
        if len(phase_str) > 0:
2474
            for read in reads:
2475
                read.set_tag(tag, phase_str)
2476
                if use_likelihood:
2477
                    read.set_tag('al', self.allele_confidence)
2478
2479
    def get_base_observation_dict_NOREF(self, allow_N=False):
2480
        '''
2481
        identical to get_base_observation_dict but does not obtain reference bases,
2482
        has to be used when no MD tag is present
2483
        Args:
2484
            allow_N ( bool ):
                keep N base calls in the observations

        Returns:
            { genome_location (tuple) : base (string) : obs (int) }
2491
        '''
2492
2493
        base_obs = defaultdict(Counter)
2494
2495
        used = 0  # some alignments yielded valid calls
2496
        ignored = 0
2497
        for fragment in self:
2498
            _, start, end = fragment.span
2499
            for read in fragment:
2500
                if read is None:
2501
                    continue
2502
2503
                for cycle, query_pos, ref_pos in pysamiterators.iterators.ReadCycleIterator(
2504
                        read, with_seq=False):
2505
2506
                    if query_pos is None or ref_pos is None or ref_pos < start or ref_pos > end:
2507
                        continue
2508
                    query_base = read.seq[query_pos]
2509
                    if query_base == 'N' and not allow_N:
2510
                        continue
2511
                    base_obs[(read.reference_name, ref_pos)][query_base] += 1
2512
2513
        if used == 0 and ignored > 0:
2514
            raise ValueError('Could not extract any safe data from molecule')
2515
2516
        return base_obs
2517
2518
    def get_base_observation_dict(self, return_refbases=False, allow_N=False,
2519
        allow_unsafe=True, one_call_per_frag=False, min_cycle_r1=None,
2520
         max_cycle_r1=None, min_cycle_r2=None, max_cycle_r2=None, use_cache=True, min_bq=None):
2521
        '''
2522
        Obtain observed bases at reference aligned locations
2523
2524
        Args:
2525
            return_refbases ( bool ):
2526
                return both observed bases and reference bases
2527
            allow_N (bool): Keep N base calls in observations
2528
2529
            min_cycle_r1(int) : Exclude read 1 base calls with a cycle smaller than this value (excludes bases which are trimmed before mapping)
2530
2531
            max_cycle_r1(int) : Exclude read 1 base calls with a cycle larger than this value (excludes bases which are trimmed before mapping)
2532
2533
            min_cycle_r2(int) : Exclude read 2 base calls with a cycle smaller than this value (excludes bases which are trimmed before mapping)
2534
2535
            max_cycle_r2(int) : Exclude read 2 base calls with a cycle larger than this value (excludes bases which are trimmed before mapping)
2536
2537
2538
        Returns:
2539
            { genome_location (tuple) : base (string) : obs (int) }
2540
            and
2541
            { genome_location (tuple) : base (string) if return_refbases is True }
2542
        '''
2543
2544
        # Check if cached is available
2545
        if use_cache:
2546
            if self.saved_base_obs is not None :
2547
                if not return_refbases:
2548
                    return self.saved_base_obs[0]
2549
                else:
2550
                    if self.saved_base_obs[1] is not None:
2551
                        return self.saved_base_obs
2552
2553
        base_obs = defaultdict(Counter)
2554
2555
        ref_bases = {}
2556
        used = 0  # some alignments yielded valid calls
2557
        ignored = 0
2558
        error = None
2559
        for fragment in self:
2560
            _, start, end = fragment.span
2561
2562
            used += 1
2563
2564
            if one_call_per_frag:
2565
                frag_location_obs = set()
2566
2567
            for read in fragment:
2568
                if read is None:
2569
                    continue
2570
2571
                if allow_unsafe:
2572
                    for query_pos, ref_pos, ref_base in read.get_aligned_pairs(matches_only=True, with_seq=True):
2573
                        if query_pos is None or ref_pos is None:  # or ref_pos < start or ref_pos > end:
2574
                            continue
2575
2576
                        query_base = read.seq[query_pos]
2577
                        # query_qual = read.qual[query_pos]
2578
                        if min_bq is not None and read.query_qualities[query_pos]<min_bq:
2579
                            continue
2580
2581
                        if query_base == 'N':
2582
                            continue
2583
2584
                        k = (read.reference_name, ref_pos)
2585
2586
                        if one_call_per_frag:
2587
                            if k in frag_location_obs:
2588
                                continue
2589
                            frag_location_obs.add(k)
2590
2591
                        base_obs[k][query_base] += 1
2592
2593
                        if return_refbases:
2594
                            ref_bases[(read.reference_name, ref_pos)
2595
                            ] = ref_base.upper()
2596
2597
2598
                else:
2599
                    for cycle, query_pos, ref_pos, ref_base in pysamiterators.iterators.ReadCycleIterator(
2600
                            read, with_seq=True, reference=self.reference):
2601
2602
                        if query_pos is None or ref_pos is None:  # or ref_pos < start or ref_pos > end:
2603
                            continue
2604
2605
                        # Verify cycle filters:
2606
                        if (not read.is_paired or read.is_read1) and (
2607
                                ( min_cycle_r1 is not None and cycle <  min_cycle_r1 ) or
2608
                                ( max_cycle_r1 is not None and  cycle >  max_cycle_r1 )):
2609
                            continue
2610
2611
                        if (read.is_paired and read.is_read2) and (
2612
                                ( min_cycle_r2 is not None and cycle <  min_cycle_r2 ) or
2613
                                ( max_cycle_r2 is not None and cycle >  max_cycle_r2 )):
2614
                            continue
2615
2616
                        query_base = read.seq[query_pos]
2617
                        # Skip bases with low bq:
2618
                        if min_bq is not None and read.query_qualities[query_pos]<min_bq:
2619
                            continue
2620
2621
                        if query_base == 'N':
2622
                            continue
2623
2624
                        k = (read.reference_name, ref_pos)
2625
                        if one_call_per_frag:
2626
                            if k in frag_location_obs:
2627
                                continue
2628
                            frag_location_obs.add(k)
2629
2630
                        base_obs[(read.reference_name, ref_pos)][query_base] += 1
2631
2632
                        if return_refbases:
2633
                            ref_bases[(read.reference_name, ref_pos)
2634
                            ] = ref_base.upper()
2635
2636
        if used == 0 and ignored > 0:
2637
            raise ValueError(f'Could not extract any safe data from molecule {error}')
2638
2639
        self.saved_base_obs = (base_obs, ref_bases)
2640
2641
        if return_refbases:
2642
            return base_obs, ref_bases
2643
2644
        return base_obs
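
    # Usage sketch (hypothetical coordinates): count how often each base was
    # observed at a position of interest, restricted to base qualities >= 20.
    #
    #   base_obs, ref_bases = molecule.get_base_observation_dict(
    #       return_refbases=True, min_bq=20)
    #   counts = base_obs.get(('chr1', 1_000_000), Counter())  # e.g. Counter({'T': 3})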
2645
2646
    def get_match_mismatch_frequency(self, ignore_locations=None):
2647
        """Get amount of base-calls matching and mismatching the reference sequence,
2648
           mismatches in every read are counted
2649
2650
        Args:
2651
            ignore_locations (iterable(tuple([chrom(str),pos(int)])) ) :
2652
                Locations not to take into account for the match and mismatch frequency
2653
2654
        Returns:
2655
            matches(int), mismatches(int)
2656
        """
2657
        matches = 0
2658
        mismatches = 0
2659
2660
        base_obs, ref_bases = self.get_base_observation_dict(
2661
            return_refbases=True)
2662
        for location, obs in base_obs.items():
2663
            if ignore_locations is not None and location in ignore_locations:
2664
                continue
2665
2666
            if location in ref_bases:
2667
                ref = ref_bases[location]
2668
                if ref not in 'ACTG':  # don't count weird bases in the reference @warn
2669
                    continue
2670
                matches += obs[ref]
2671
                mismatches += sum((base_obs for base,
2672
                                                base_obs in obs.most_common() if base != ref))
2673
2674
        return matches, mismatches
2675
2676
    def get_consensus(self,
2677
                    dove_safe: bool = False,
2678
                    only_include_refbase: str = None,
2679
                    allow_N=False,
2680
                    with_probs_and_obs=False, **get_consensus_dictionaries_kwargs):
2681
        """
2682
        Obtain consensus base-calls for the molecule
2683
        """
2684
2685
        if allow_N:
2686
            raise NotImplementedError()
2687
2688
        consensii = defaultdict(consensii_default_vector)  # location -> obs (A,C,G,T,N)
2689
        if with_probs_and_obs:
2690
            phred_scores = defaultdict(lambda:defaultdict(list))
2691
        for fragment in self:
2692
            if dove_safe and (not fragment.has_R2() or not fragment.has_R1()):
2693
                continue
2694
2695
            try:
                for position, (q_base, phred_score) in fragment.get_consensus(
                        dove_safe=dove_safe, only_include_refbase=only_include_refbase,
                        **get_consensus_dictionaries_kwargs).items():

                    if q_base == 'N':
                        continue

                    if with_probs_and_obs:
                        phred_scores[position][q_base].append(phred_score)

                    consensii[position]['ACGTN'.index(q_base)] += 1
            except ValueError:
                # For example: ValueError('This method only works for inwards facing reads')
                pass

        if len(consensii) == 0:
            if with_probs_and_obs:
                return dict(), None, None
            else:
                return dict()

        locations = np.empty(len(consensii), dtype=object)
        locations[:] = sorted(consensii.keys())

        v = np.vstack([consensii[location] for location in locations])
        majority_base_indices = np.argmax(v, axis=1)

        # Check for ties: a tie yields multiple maxima in a row of v, in which case no
        # reliable majority call can be made and the position is dropped.
        proper = (v == v[np.arange(v.shape[0]), majority_base_indices][:, np.newaxis]).sum(1) == 1

        if with_probs_and_obs:
            return (
                dict(zip(locations[proper], ['ACGTN'[idx] for idx in majority_base_indices[proper]])),
                phred_scores,
                consensii
            )
        else:
            return dict(zip(locations[proper], ['ACGTN'[idx] for idx in majority_base_indices[proper]]))
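
    # Illustrative sketch (not part of the original module): majority-vote consensus
    # calling with access to the underlying observations.
    #
    #   calls, phreds, observations = molecule.get_consensus(with_probs_and_obs=True)
    #   for (chrom, pos), base in calls.items():
    #       support = observations[(chrom, pos)]['ACGTN'.index(base)]
    #       print(chrom, pos, base, support)
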
    def get_consensus_old(
            self,
            base_obs=None,
            classifier=None,
            store_consensus=True,
            reuse_cached_consensus=True,
            allow_unsafe=False,
            allow_N=False):
        """Get dictionary containing consensus calls with respect to the reference.
        By default majority voting is used to determine the consensus base.
        If a classifier is supplied, the classifier is used to determine the consensus base.

        Args:
            base_obs (defaultdict(Counter)) :
                { genome_location (tuple) : base (string) : obs (int) }

            classifier : fitted classifier to use for consensus calling. When no classifier is provided the consensus is determined by majority voting

            store_consensus (bool) : Store the generated consensus for re-use

        Returns:
            consensus (dict)  :  {location : base}
        """
        consensus = {}  # position -> base, the key is not set when no call could be made

        if classifier is not None:

            if reuse_cached_consensus and hasattr(
                    self, 'classifier_consensus') and self.classifier_consensus is not None:
                return self.classifier_consensus

            features, reference_bases, CIGAR, alignment_start, alignment_end = self.get_base_calling_feature_matrix_spaced(
                True)

            if features is None:
                # We cannot determine the consensus as there are no features
                return dict()

            base_calling_probs = classifier.predict_proba(features)
            predicted_sequence = ['ACGT'[i] for i in np.argmax(base_calling_probs, 1)]

            reference_sequence = ''.join(
                [base for chrom, pos, base in reference_bases])

            phred_scores = np.rint(
                -10 * np.log10(np.clip(1 - base_calling_probs.max(1),
                                       0.000000001,
                                       0.999999999)
                               )).astype('B')

            consensus = {(chrom, pos): consensus_base
                         for (chrom, pos, ref_base), consensus_base in
                         zip(reference_bases, predicted_sequence)}

            if store_consensus:
                self.classifier_consensus = consensus
                self.classifier_phred_scores = phred_scores
            return consensus

        if base_obs is None:
            try:
                base_obs, ref_bases = self.get_base_observation_dict(
                    return_refbases=True, allow_N=allow_N, allow_unsafe=allow_unsafe)
            except ValueError:
                # We cannot determine safe regions
                raise

        for location, obs in base_obs.items():
            votes = obs.most_common()
            if len(votes) == 1 or votes[1][1] < votes[0][1]:
                consensus[location] = votes[0][0]

        if store_consensus:
            self.majority_consensus = consensus

        return consensus
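
    # Illustrative sketch (not part of the original module): comparing majority-vote and
    # classifier-based consensus calls. `clf` is an assumed classifier fitted on the base
    # calling feature matrix (any model exposing predict_proba).
    #
    #   majority_calls = molecule.get_consensus_old()
    #   classifier_calls = molecule.get_consensus_old(classifier=clf)
    #   disagreements = {loc for loc, base in classifier_calls.items()
    #                    if majority_calls.get(loc, base) != base}
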
    def get_consensus_base(self, contig, position, classifier=None):
        """Obtain the base call at a single position of the molecule

        Args:
            contig (str) : contig to extract the base call from

            position (int) : position to extract the base call from (zero based)

            classifier (obj) : base calling classifier

        Returns:

            base_call (str) : base call, or None when no base call could be made
        """

        try:
            # Pass the classifier by keyword so it is not consumed as a positional argument:
            if classifier is not None:
                c = self.get_consensus_old(classifier=classifier)
            else:
                c = self.get_consensus()
        except ValueError:
            return None
        return c.get((contig, position), None)
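
    # Illustrative sketch (not part of the original module): querying a single position,
    # mirroring how detect_alleles() uses this method.
    #
    #   base = molecule.get_consensus_base('chr1', 1_000_000)
    #   if base is not None:
    #       print('consensus base call:', base)
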
    def check_variants(self, variants, exclude_other_calls=True):
        """Verify variants in molecule

        Args:
            variants (pysam.VariantFile) : Variant file handle to extract variants from

            exclude_other_calls (bool) : when enabled, base calls that match neither the
                reference nor an alternative allele are not counted; otherwise they are
                counted under the key None

        Returns:
            variant_calls (defaultdict( Counter )) : { (chrom, pos) : { call (tuple) : observations (int) } }
        """
        variant_dict = {}
        for variant in variants.fetch(
                self.chromosome,
                self.spanStart,
                self.spanEnd):
            variant_dict[(variant.chrom, variant.pos - 1)] = (variant.ref, variant.alts)

        variant_calls = defaultdict(Counter)
        for fragment in self:

            _, start, end = fragment.span
            for read in fragment:
                if read is None:
                    continue

                for cycle, query_pos, ref_pos in pysamiterators.iterators.ReadCycleIterator(
                        read):

                    if query_pos is None or ref_pos is None or ref_pos < start or ref_pos > end:
                        continue
                    query_base = read.seq[query_pos]

                    k = (read.reference_name, ref_pos)
                    if k in variant_dict:
                        call = None
                        ref, alts = variant_dict[k]
                        if query_base == ref:
                            call = ('ref', query_base)
                        elif query_base in alts:
                            call = ('alt', query_base)

                        if not exclude_other_calls or call is not None:
                            variant_calls[k][call] += 1

        return variant_calls
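
    # Illustrative sketch (not part of the original module): counting reference and
    # alternative allele support for variants overlapping a molecule. `variants_path`
    # is an assumed path to a tabix-indexed VCF file.
    #
    #   with pysam.VariantFile(variants_path) as variants:
    #       for (chrom, pos), calls in molecule.check_variants(variants).items():
    #           print(chrom, pos, dict(calls))  # e.g. {('ref', 'A'): 4, ('alt', 'T'): 1}
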
    def get_aligned_reference_bases_dict(self):
        """Get dictionary containing all reference bases to which this molecule aligns
        Returns:
            aligned_reference_positions (dict) :  { (chrom, pos) : 'A', (chrom, pos) : 'T', .. }
        """
        aligned_reference_positions = {}
        for read in self.iter_reads():
            for read_pos, ref_pos, ref_base in read.get_aligned_pairs(
                    with_seq=True, matches_only=True):
                aligned_reference_positions[(
                    read.reference_name, ref_pos)] = ref_base.upper()
        return aligned_reference_positions
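
    # Illustrative sketch (not part of the original module): listing positions where the
    # consensus call deviates from the aligned reference base.
    #
    #   reference_bases = molecule.get_aligned_reference_bases_dict()
    #   consensus = molecule.get_consensus()
    #   mismatch_positions = [loc for loc, base in consensus.items()
    #                         if reference_bases.get(loc, base) != base]
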
    def iter_reads(self):
        """Iterate over all associated reads

        Yields:
            read (pysam.AlignedSegment)
        """

        for fragment in self.fragments:
            for read in fragment:
                if read is not None:
                    yield read

    def __iter__(self):
        """Iterate over all associated fragments

        Yields:
            singlecellmultiomics.fragment.Fragment
        """
        for fragment in self.fragments:
            yield fragment
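
    # Illustrative sketch (not part of the original module): the iteration helpers above
    # make it easy to summarise a molecule at the read or fragment level.
    #
    #   read_names = {read.query_name for read in molecule.iter_reads()}
    #   fragment_count = sum(1 for fragment in molecule)
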
    @property
    def span_len(self):
        return abs(self.spanEnd - self.spanStart)

    def get_methylated_count(self, context=3):
        """Get the total amount of methylated bases

        Args:
            context (int) : 3 or 4 base context

        Returns:
            r (Counter) : sum of methylated bases in contexts
        """

        r = Counter()

    def get_html(
            self,
            reference=None,
            consensus=None,
            show_reference_sequence=True,
            show_consensus_sequence=True,
            reference_bases=None):
        """Get html representation of the molecule
        Returns:
            html_rep(str) : Html representation of the molecule
        """

        html = f"""<h3>{self.chromosome}:{self.spanStart}-{self.spanEnd}
            sample:{self.get_sample()}  {'valid molecule' if self[0].is_valid() else 'Non valid molecule'}</h3>
            <h5>UMI:{self.get_umi()} Mapping qual:{round(self.get_mean_mapping_qual(), 1)} Cut loc: {"%s:%s" % self[0].get_site_location()} </h5>
            <div style="white-space:nowrap; font-family:monospace; color:#888">"""
        # undigested:{self.get_undigested_site_count()}

        # Obtain reference bases dictionary:
        if reference_bases is None:
            if reference is None:
                reference_bases = self.get_aligned_reference_bases_dict()

            else:
                # obtain reference_bases from reference file
                raise NotImplementedError()

        for fragment in itertools.chain(*self.get_rt_reactions().values()):
            html += f'<h5>{fragment.get_R1().query_name}</h5>'
            for readid, read in [
                (1, fragment.get_R1()),
                (2, fragment.get_R2())]:  # go over R1 and R2:
                # This is just the sequence:
                if read is None:
                    continue
                html += fragment.get_html(
                    self.chromosome,
                    self.spanStart,
                    self.spanEnd,
                    show_read1=(readid == 1),
                    show_read2=(readid == 2)
                ) + '<br />'

        # Obtain the consensus sequence when it was not supplied:
        if consensus is None:
            consensus = self.get_consensus()

        span_len = self.spanEnd - self.spanStart
        visualized = ['.'] * span_len
        reference_vis = ['?'] * span_len
        for location, query_base in consensus.items():
            try:
                if reference_bases is None or reference_bases.get(
                        location, '?') == query_base:
                    visualized[location[1] - self.spanStart] = query_base
                    if reference_bases is not None:
                        reference_vis[location[1] -
                                      self.spanStart] = query_base
                else:
                    visualized[location[1] -
                               self.spanStart] = style_str(query_base, color='red', weight=800)
                    if reference_bases is not None:
                        reference_vis[location[1] - self.spanStart] = style_str(
                            reference_bases.get(location, '?'), color='black', weight=800)
            except IndexError:
                pass  # Tried to visualize a base outside the view

        if show_consensus_sequence:
            html += ''.join(visualized) + '<br />'

        if show_reference_sequence:
            html += ''.join(reference_vis) + '<br />'

        html += "</div>"
        return html
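
    # Illustrative sketch (not part of the original module): rendering the HTML
    # representation in a Jupyter notebook.
    #
    #   from IPython.display import HTML, display
    #   display(HTML(molecule.get_html()))
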
    def get_methylation_dict(self):
        """Obtain methylation dictionary

        Returns:
            methylated_positions (Counter):
                (read.reference_name, rpos) : times seen methylated

            methylated_state (dict):
                {(read.reference_name, rpos) : 1/0/-1 }
                1 for methylated
                0 for unmethylated
                -1 for unknown

        """
        methylated_positions = Counter()  # chrom-pos -> count
        methylated_state = dict()  # chrom-pos -> 1, 0, -1
        for fragment in self:
            for read in fragment:
                if read is None or not read.has_tag('XM'):
                    continue
                methylation_status_string = read.get_tag('XM')
                for qpos, rpos, ref_base in read.get_aligned_pairs(
                        with_seq=True):
                    if qpos is None:
                        continue
                    if ref_base is None:
                        continue
                    if rpos is None:
                        continue
                    # The XM tag contains one character per query base, so the query
                    # position is used to look up the methylation status:
                    methylation_status = methylation_status_string[qpos]
                    if methylation_status.isupper():
                        methylated_positions[(read.reference_name, rpos)] += 1
                        if methylated_state.get(
                                (read.reference_name, rpos), 1) == 1:
                            methylated_state[(read.reference_name, rpos)] = 1
                        else:
                            methylated_state[(read.reference_name, rpos)] = -1
                    else:
                        if methylation_status == '.':
                            pass
                        else:
                            if methylated_state.get(
                                    (read.reference_name, rpos), 0) == 0:
                                methylated_state[(
                                    read.reference_name, rpos)] = 0
                            else:
                                methylated_state[(
                                    read.reference_name, rpos)] = -1
        return methylated_positions, methylated_state
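
    # Illustrative sketch (not part of the original module): summarising the methylation
    # state of a molecule (requires reads carrying a Bismark-style XM tag).
    #
    #   methylated_counts, methylation_state = molecule.get_methylation_dict()
    #   n_methylated = sum(1 for state in methylation_state.values() if state == 1)
    #   n_unmethylated = sum(1 for state in methylation_state.values() if state == 0)
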
    def _get_allele_from_reads(self) -> str:
        """
        Obtain the associated allele based on the DA tag of the associated reads;
        returns None when no read carries the tag.
        """
        for frag in self:
            for read in frag:
                if read is None or not read.has_tag('DA'):
                    continue
                return read.get_tag('DA')
        return None
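
    # Illustrative sketch (not part of the original module): reading back the allele
    # assignment that was written to the DA tag, falling back to None when absent.
    #
    #   allele = molecule._get_allele_from_reads()
    #   if allele is None:
    #       print('molecule could not be assigned to an allele')
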