#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Estimate TAPS conversion efficiency from lambda spike-in molecules in bam files.

Counts consensus base observations per trinucleotide context on the spike-in
contig, normalises C>T conversions per context, and writes CSV tables plus a
boxplot of the per-library CpG conversion rate.
"""

import matplotlib.pyplot as plt
from singlecellmultiomics.molecule import MoleculeIterator, CHICNLAMolecule, TAPSNlaIIIMolecule, TAPSCHICMolecule, TAPS
from singlecellmultiomics.fragment import CHICFragment, NlaIIIFragment
import pysam
from pysamiterators import CachedFasta
from singlecellmultiomics.variants.substitutions import conversion_dict_stranded
from collections import defaultdict, Counter
from singlecellmultiomics.utils import reverse_complement, complement
from glob import glob
from multiprocessing import Pool
from singlecellmultiomics.bamProcessing.bamFunctions import get_reference_path_from_bam
import pandas as pd
import matplotlib as mpl
from singlecellmultiomics.utils.sequtils import phred_to_prob, prob_to_phred
import seaborn as sns
import argparse


def update_mutation_dict(molecule, reference, conversions_per_library, context_obs):
    """Tally consensus base observations per trinucleotide context for one molecule.

    Parameters
    ----------
    molecule : Molecule
        Molecule to obtain a high-confidence consensus from.
    reference : CachedFasta
        Reference genome handle used to fetch trinucleotide contexts.
    conversions_per_library : defaultdict
        (library, cell, nm) -> {(context, base): count}; updated in place.
    context_obs : defaultdict(Counter)
        (library, cell, nm) -> {context: count}; updated in place.
    """
    # High-confidence consensus: drop low-quality calls, read edges and dovetail overlap
    consensus = molecule.get_consensus(
        dove_safe=True,
        min_phred_score=22,
        skip_first_n_cycles_R1=10,
        skip_last_n_cycles_R1=20,
        skip_first_n_cycles_R2=10,
        skip_last_n_cycles_R2=20,
        dove_R2_distance=15,
        dove_R1_distance=15,
    )

    # nm: mismatches at non-cytosine positions; a proxy for sequencing/mapping
    # errors, since only C positions can legitimately change under TAPS
    nm = 0
    contexts_to_add = []

    for (chrom, pos), base in consensus.items():
        context = reference.fetch(chrom, pos - 1, pos + 2).upper()

        # Skip truncated contexts (contig edge) and ambiguous reference bases
        if len(context) != 3 or 'N' in context:
            continue

        # Ignore germline variants:
        # if might_be_variant(chrom, pos, known):
        #     continue

        if not molecule.strand:  # reverse template: report on the template strand
            context = reverse_complement(context)
            base = complement(base)

        # A mismatch at a non-C reference position cannot be a TAPS conversion
        if context[1] != 'C' and context[1] != base:
            nm += 1

        contexts_to_add.append((context, base))

    # Cap the mismatch stratum at 5 so the key space stays bounded
    nm = min(nm, 5)

    # Key: (library, ..., cell, nm); sample is presumably "<library>_<cell>"
    # or "<library>_<x>_<cell>" -- TODO confirm against the demultiplexer
    k = (*molecule.sample.rsplit('_', 2), nm)
    for context, base in contexts_to_add:
        context_obs[k][context] += 1
        try:
            conversions_per_library[k][(context, base)] += 1
        except KeyError:
            # conversion_dict_stranded presumably pre-populates a fixed set of
            # (context, base) combinations; ignore anything outside that set
            pass


def get_conversion_counts(args):
    """Worker: count conversions for a single bam file, spike-in contig only.

    Parameters
    ----------
    args : tuple
        (bam_path, reference_path, method, every_fragment_as_molecule,
        spikein_name); packed into one tuple for use with Pool.imap.

    Returns
    -------
    tuple
        (conversions_per_library, context_obs) partial count dictionaries.
    """
    taps = TAPS()

    conversions_per_library = defaultdict(conversion_dict_stranded)
    context_obs = defaultdict(Counter)

    bam, refpath, method, every_fragment_as_molecule, spikein_name = args

    # Select fragment/molecule classes matching the library protocol
    if method == 'nla':
        fragment_class = NlaIIIFragment
        molecule_class = TAPSNlaIIIMolecule
    else:
        fragment_class = CHICFragment
        molecule_class = TAPSCHICMolecule

    with pysam.FastaFile(refpath) as ref:
        reference = CachedFasta(ref)

        print(f'Processing {bam}')

        with pysam.AlignmentFile(bam, threads=8) as al:
            for molecule in MoleculeIterator(
                al,
                fragment_class=fragment_class,
                molecule_class=molecule_class,
                molecule_class_args={
                    'reference': reference,
                    'taps': taps,
                    'taps_strand': 'R'
                },
                every_fragment_as_molecule=every_fragment_as_molecule,
                fragment_class_args={},
                contig=spikein_name  # restrict to the spike-in contig
            ):
                update_mutation_dict(molecule, reference, conversions_per_library, context_obs)

    return conversions_per_library, context_obs


def generate_taps_conversion_stats(bams, reference_path, prefix, method,
                                   every_fragment_as_molecule, spikein_name,
                                   n_threads=None):
    """Aggregate conversion counts over all bams and write tables and a plot.

    Output files:
        {prefix}_conversions_counts_raw_lambda.csv : raw context observation counts
        {prefix}_conversions_lambda.csv            : normalised CpG conversion rates
        {prefix}_conversion_rate_phred_22.png      : boxplot per library/context

    Parameters
    ----------
    bams : list of str
        Input bam file paths.
    reference_path : str or None
        Reference fasta; autodetected from the first bam when None.
    prefix : str
        Prefix for all output file names.
    method : str
        'nla' or 'chic'; selects the molecule/fragment classes.
    every_fragment_as_molecule : bool
        When True, skip UMI deduplication.
    spikein_name : str
        Contig name of the (lambda) spike-in.
    n_threads : int or None
        Worker pool size (None: one worker per CPU).

    Raises
    ------
    ValueError
        When no reference path can be determined.
    """
    if reference_path is None:
        reference_path = get_reference_path_from_bam(bams[0])

    print(f'Reference at {reference_path}')
    if reference_path is None:
        raise ValueError('Please supply a reference fasta file')

    conversions_per_library = defaultdict(conversion_dict_stranded)
    context_obs = defaultdict(Counter)

    # Fan out one worker per bam and merge the partial count dictionaries
    with Pool(n_threads) as workers:
        jobs = [(bam, reference_path, method, every_fragment_as_molecule, spikein_name)
                for bam in bams]
        for cl, co in workers.imap(get_conversion_counts, jobs):
            for lib, obs in cl.items():
                for k, v in obs.items():
                    conversions_per_library[lib][k] += v
            for lib, obs in co.items():
                for k, v in obs.items():
                    context_obs[lib][k] += v

    qf = pd.DataFrame(context_obs)
    qf.to_csv(f'{prefix}_conversions_counts_raw_lambda.csv')

    # Keep only (library, cell) pairs with enough (>5000) context observations
    indices = []
    for lib in qf.columns.get_level_values(0).unique():
        qqf = qf.xs(lib, axis=1, level=0, drop_level=False)
        # Total observations per (library, cell), summed over nm strata and contexts.
        # (.T.groupby replaces the removed DataFrame.sum(level=..., axis=1) API.)
        ser = qqf.T.groupby(level=[0, 1]).sum().sum(axis=1).sort_values(ascending=False)
        ser = ser[ser > 5000][::-1]
        indices += list(ser.index)

    # Normalise conversion counts by how often each context was observed
    normed_conversions_per_library = defaultdict(conversion_dict_stranded)
    for key in context_obs:
        for (context, base), obs in conversions_per_library[key].items():
            try:
                normed_conversions_per_library[key][(context, base)] = obs / context_obs[key][context]
            except ZeroDivisionError:
                pass  # context never observed for this cell: leave undefined

    df = pd.DataFrame(normed_conversions_per_library)
    df = df[[col for col in df if (col[0], col[1]) in indices]]

    # Restrict to CpG cytosines read as T: the TAPS conversion signal
    df = df.loc[[(context, base) for context, base in df.index
                 if context[1] == 'C' and base == 'T' and context.endswith('CG')]]
    df = df.T

    df.to_csv(f'{prefix}_conversions_lambda.csv')

    mpl.rcParams['figure.dpi'] = 300

    samples = []
    for (lib, cell, nm), row in df.iterrows():
        if nm != 0:  # only molecules without non-C mismatches
            continue
        for context, base in [('ACG', 'T'),
                              ('CCG', 'T'),
                              ('GCG', 'T'),
                              ('TCG', 'T')]:
            samples.append({
                'lib': lib,
                'cell': cell,
                'nm': nm,
                # 'plate': int(lib.split('-')[-1].replace('pl','')),
                'group': f'{nm},{context},{cell}',
                'context': f'{context}>{base}',
                'conversion rate': row[context, base],
            })

    plot_table = pd.DataFrame(samples)
    print(plot_table)

    # Phred threshold used for the consensus calls (see update_mutation_dict)
    ph = 22
    fig, ax = plt.subplots(figsize=(12, 5))
    sns.boxplot(data=plot_table.sort_values('lib'), x='context', y='conversion rate',
                hue='lib', whis=6, ax=ax)

    # ax = sns.swarmplot(data=plot_table, x='nm', y='conversion rate', hue='lib')

    plt.ylabel('Lambda Conversion rate')
    sns.despine()
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    plt.suptitle('Estimated TAPS conversion rate', y=1.05, fontsize=12)
    plt.title(f'Lambda spike-in, >{(1.0-(phred_to_prob(ph)))*100 : .2f}% accuracy base calls',
              fontsize=10)
    plt.tight_layout()
    plt.savefig(f'{prefix}_conversion_rate_phred_{ph}.png', bbox_inches='tight')
    plt.close()


if __name__ == '__main__':
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Estimate the conversion efficiency of a TAPS converted file. ')
    argparser.add_argument('bams', type=str, nargs='+', help='Input bam files')
    argparser.add_argument('-o', type=str,
                           help="output alias (Will be the prefix of the output files)",
                           required=True)
    argparser.add_argument('-method', type=str, default='nla',
                           help='Molecule class (nla or chic). Use chic when you are not sure or when another protocol is used.')
    argparser.add_argument('--dedup', action='store_true',
                           help='perform UMI deduplication and consensus calling. Do not use when the UMI\'s are (near) saturated')
    argparser.add_argument('-t', type=int, help='Amount of threads')
    argparser.add_argument('-spikein_name', type=str, help='Name of spikein contig',
                           default='J02459.1')
    argparser.add_argument(
        '-ref',
        type=str,
        default=None,
        help="Path to reference fasta (autodetected if not supplied)")
    args = argparser.parse_args()

    generate_taps_conversion_stats(
        args.bams,
        args.ref,
        prefix=args.o,
        method=args.method,
        every_fragment_as_molecule=not args.dedup,
        spikein_name=args.spikein_name,
        n_threads=args.t)