Switch to side-by-side view

--- a
+++ b/singlecellmultiomics/bamProcessing/estimateTapsConversionEfficiency.py
@@ -0,0 +1,250 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+import matplotlib.pyplot as plt
+from singlecellmultiomics.molecule import MoleculeIterator, CHICNLAMolecule, TAPSNlaIIIMolecule,TAPSCHICMolecule,TAPS
+from singlecellmultiomics.fragment import CHICFragment, NlaIIIFragment
+import pysam
+from pysamiterators import CachedFasta
+from singlecellmultiomics.variants.substitutions import conversion_dict_stranded
+from collections import defaultdict
+from singlecellmultiomics.utils import reverse_complement, complement
+from glob import glob
+from multiprocessing import Pool
+from singlecellmultiomics.bamProcessing.bamFunctions import get_reference_path_from_bam
+from collections import Counter
+import pandas as pd
+import matplotlib as mpl
+from singlecellmultiomics.utils.sequtils import phred_to_prob, prob_to_phred
+import seaborn as sns
+import argparse
+
+def update_mutation_dict(molecule,reference, conversions_per_library, context_obs):
+
+    consensus = molecule.get_consensus(dove_safe=True,
+                                       min_phred_score=22,
+                                       skip_first_n_cycles_R1=10,
+                                       skip_last_n_cycles_R1=20,
+                                       skip_first_n_cycles_R2=10,
+                                       skip_last_n_cycles_R2=20,
+                                       dove_R2_distance=15,
+                                       dove_R1_distance=15
+
+
+                                      )
+
+    nm = 0
+
+
+    contexts_to_add = []
+
+
+    for (chrom,pos), base in consensus.items():
+        context = reference.fetch(chrom, pos-1, pos+2).upper()
+
+        if len(context)!=3:
+            continue
+
+        # Check if the base matches or the refence contains N's
+        if 'N' in context or len(context)!=3:
+            continue
+
+        # Ignore germline variants:
+        #if might_be_variant(chrom, pos,  known):
+        #    continue
+
+        if not molecule.strand: # reverse template
+            context = reverse_complement(context)
+            base = complement(base)
+
+        if context[1]!='C' and context[1]!=base:
+            nm+=1
+
+        contexts_to_add.append((context,base))
+
+
+    if nm>5:
+        nm=5
+
+    k = tuple((*molecule.sample.rsplit('_',2), nm))
+    for (context, base) in contexts_to_add:
+
+        context_obs[ k ][context] += 1
+        try:
+            conversions_per_library[k][(context, base)] += 1
+        except:
+            pass
+
+
+def get_conversion_counts(args):
+
+
+    taps = TAPS()
+
+    conversions_per_library = defaultdict( conversion_dict_stranded )
+    context_obs = defaultdict( Counter )
+
+    bam,refpath,method,every_fragment_as_molecule, spikein_name  = args
+
+
+    if method=='nla':
+        fragment_class=NlaIIIFragment
+        molecule_class=TAPSNlaIIIMolecule
+    else:
+        fragment_class=CHICFragment
+        molecule_class=TAPSCHICMolecule
+
+    with pysam.FastaFile(refpath) as ref:
+        reference = CachedFasta(ref)
+
+
+        print(f'Processing {bam}')
+
+        with pysam.AlignmentFile(bam, threads=8) as al:
+
+            for molecule in MoleculeIterator(
+                al,
+                fragment_class=fragment_class,
+                molecule_class=molecule_class,
+                molecule_class_args={
+                     'reference':reference,
+                     'taps':taps,
+                    'taps_strand':'R'
+                },
+                every_fragment_as_molecule=every_fragment_as_molecule,
+                fragment_class_args={},
+                contig = spikein_name
+            ):
+                update_mutation_dict(molecule, reference ,conversions_per_library, context_obs)
+
+    return conversions_per_library, context_obs
+
+
+def generate_taps_conversion_stats(bams, reference_path, prefix, method, every_fragment_as_molecule, spikein_name, n_threads=None):
+    if reference_path is None:
+        reference_path = get_reference_path_from_bam(bams[0])
+
+    print(f'Reference at {reference_path}')
+    if reference_path is None:
+        raise ValueError('Please supply a reference fasta file')
+
+    conversions_per_library = defaultdict( conversion_dict_stranded )
+    context_obs = defaultdict( Counter )
+
+    with Pool(n_threads) as workers:
+
+        for cl, co in workers.imap(get_conversion_counts, [(bam, reference_path, method, every_fragment_as_molecule, spikein_name) for bam in bams] ):
+
+            for lib, obs in cl.items():
+                for k,v in obs.items():
+                    conversions_per_library[lib][k] +=v
+
+            for lib, obs in co.items():
+                for k,v in obs.items():
+                    context_obs[lib][k] += v
+
+
+    qf = pd.DataFrame(context_obs)
+    qf.to_csv(f'{prefix}_conversions_counts_raw_lambda.csv')
+    ###
+    indices = []
+    for lib, qqf in qf.groupby(level=0,axis=1):
+
+        ser = qqf.sum(level=(0,1),axis=1).sum().sort_values(ascending=False)
+        ser = ser[ser>5000][::-1]
+        indices += list(ser.index)
+
+
+    ###
+
+    normed_conversions_per_library = defaultdict( conversion_dict_stranded )
+
+    for INDEX in context_obs:
+        for (context, base),obs in conversions_per_library[INDEX].items():
+            try:
+                normed_conversions_per_library[INDEX][(context,base)] = obs/ context_obs[INDEX][context]
+            except Exception:
+                pass
+
+    df = pd.DataFrame(normed_conversions_per_library)
+    df = df[ [INDEX for INDEX in df if (INDEX[0],  INDEX[1]) in indices] ]
+
+    df = df.loc[ [(context, base)for context, base in df.index if context[1]=='C' and base=='T' and context.endswith('CG')]  ]
+    df = df.T
+
+    df.to_csv(f'{prefix}_conversions_lambda.csv')
+
+    mpl.rcParams['figure.dpi'] = 300
+
+    samples = []
+
+    for (lib, cell, nm), row in df.iterrows():
+
+        if nm!=0:
+            continue
+
+        for context, base in [('ACG', 'T'),
+            ('CCG', 'T'),
+            ('GCG', 'T'),
+            ('TCG', 'T')]:
+
+            r = {
+            'lib':lib,
+            'cell':cell,
+            'nm':nm,
+                #'plate': int(lib.split('-')[-1].replace('pl','')),
+            'group':  f'{nm},{context},{cell}',
+            'context': f'{context}>{base}',
+            'conversion rate':row[context,base]
+            }
+
+            samples.append(r)
+
+    plot_table = pd.DataFrame(samples)
+    print(plot_table)
+
+    ph = 22
+    fig, ax = plt.subplots(figsize=(12,5))
+    sns.boxplot(data=plot_table.sort_values('lib'),x='context', y='conversion rate',hue='lib',whis=6, ax=ax)
+
+    #ax = sns.swarmplot(data=plot_table,x='nm', y='conversion rate',hue='lib',)
+
+    plt.legend()
+    plt.ylabel('Lambda Conversion rate')
+    sns.despine()
+    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5 ))
+
+    plt.suptitle('Estimated TAPS conversion rate', y=1.05, fontsize=12)
+    plt.title(f'Lambda spike-in, >{(1.0-(phred_to_prob(22)))*100 : .2f}% accuracy base calls', fontsize=10)
+    plt.tight_layout()
+    plt.savefig(f'{prefix}_conversion_rate_phred_{ph}.png', bbox_inches='tight')
+    plt.close()
+
+
+if __name__=='__main__':
+    argparser = argparse.ArgumentParser(
+      formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+      description='Estimate the conversion efficiency of a TAPS converted file. ')
+    argparser.add_argument('bams', type=str, nargs='+',help='Input bam files')
+    argparser.add_argument('-o', type=str, help="output alias (Will be the prefix of the output files)", required=True)
+    argparser.add_argument('-method', type=str, default='nla', help='Molecule class (nla or chic). Use chic when you are not sure or when another other protocol is used.')
+    argparser.add_argument('--dedup', action='store_true',help='perform UMI deduplication and consensus calling. Do not use when the UMI\'s are (near) saturated')
+    argparser.add_argument('-t', type=int, help='Amount of threads')
+    argparser.add_argument('-spikein_name', type=str, help='Name of spikein contig',default='J02459.1')
+
+
+    argparser.add_argument(
+        '-ref',
+        type=str,
+        default=None,
+        help="Path to reference fast (autodected if not supplied)")
+    args = argparser.parse_args()
+
+
+    generate_taps_conversion_stats(args.bams,
+                                   args.ref,
+                                   prefix=args.o,
+                                   method=args.method,
+                                   every_fragment_as_molecule=not args.dedup,
+                                   spikein_name=args.spikein_name,
+                                   n_threads=args.t)