Switch to unified view

a b/singlecellmultiomics/bamProcessing/estimateTapsConversionEfficiency.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
4
import matplotlib.pyplot as plt
5
from singlecellmultiomics.molecule import MoleculeIterator, CHICNLAMolecule, TAPSNlaIIIMolecule,TAPSCHICMolecule,TAPS
6
from singlecellmultiomics.fragment import CHICFragment, NlaIIIFragment
7
import pysam
8
from pysamiterators import CachedFasta
9
from singlecellmultiomics.variants.substitutions import conversion_dict_stranded
10
from collections import defaultdict
11
from singlecellmultiomics.utils import reverse_complement, complement
12
from glob import glob
13
from multiprocessing import Pool
14
from singlecellmultiomics.bamProcessing.bamFunctions import get_reference_path_from_bam
15
from collections import Counter
16
import pandas as pd
17
import matplotlib as mpl
18
from singlecellmultiomics.utils.sequtils import phred_to_prob, prob_to_phred
19
import seaborn as sns
20
import argparse
21
22
def update_mutation_dict(molecule, reference, conversions_per_library, context_obs):
    """Tally reference contexts and observed bases for a single molecule.

    A consensus is requested from the molecule and every consensus position
    is paired with the 3-base reference window centred on it. For molecules
    whose ``strand`` is falsy the window is reverse-complemented and the base
    complemented before counting.

    All observations of the molecule are stored under one key:
    ``(*molecule.sample.rsplit('_', 2), nm)``, where ``nm`` is the number of
    positions whose central reference base is not ``C`` and differs from the
    observed base, capped at 5.

    Args:
        molecule: Object providing ``get_consensus(...)``, ``strand`` and
            ``sample``.
        reference: Object providing ``fetch(chrom, start, end)`` returning
            the reference sequence as a string.
        conversions_per_library: Mapping of key -> ``{(context, base): count}``,
            updated in place. Combinations missing from the inner mapping
            are not counted.
        context_obs: Mapping of key -> counter of contexts, updated in place.
    """
    consensus = molecule.get_consensus(
        dove_safe=True,
        min_phred_score=22,
        skip_first_n_cycles_R1=10,
        skip_last_n_cycles_R1=20,
        skip_first_n_cycles_R2=10,
        skip_last_n_cycles_R2=20,
        dove_R2_distance=15,
        dove_R1_distance=15,
    )

    nm = 0
    contexts_to_add = []

    for (chrom, pos), base in consensus.items():
        # The first base of a contig has no upstream neighbour, so no 3-base
        # context exists; skipping also avoids fetching a negative start.
        if pos < 1:
            continue

        context = reference.fetch(chrom, pos - 1, pos + 2).upper()

        # Skip truncated windows (contig end) and windows containing N's in
        # the reference.
        if len(context) != 3 or 'N' in context:
            continue

        # Ignore germline variants:
        # if might_be_variant(chrom, pos, known):
        #     continue

        if not molecule.strand:  # reverse template
            context = reverse_complement(context)
            base = complement(base)

        # Mismatches at cytosines are not counted towards nm.
        if context[1] != 'C' and context[1] != base:
            nm += 1

        contexts_to_add.append((context, base))

    nm = min(nm, 5)

    # NOTE(review): rsplit('_', 2) can yield three parts when the sample name
    # holds two or more underscores, which makes the key a 4-tuple - confirm
    # this is intended for the sample naming scheme in use.
    k = (*molecule.sample.rsplit('_', 2), nm)

    for context, base in contexts_to_add:
        context_obs[k][context] += 1
        try:
            conversions_per_library[k][(context, base)] += 1
        except KeyError:
            # (context, base) is not a tracked combination (presumably an
            # ambiguous base call); the context observation is still kept.
            pass
77
78
79
def get_conversion_counts(args):
    """Count contexts and conversions for the molecules of one BAM file.

    Args:
        args: 5-tuple ``(bam, refpath, method, every_fragment_as_molecule,
            spikein_name)``. ``method == 'nla'`` selects the NlaIII fragment
            and molecule classes; any other value selects the CHIC classes.
            Only molecules on contig ``spikein_name`` are iterated.

    Returns:
        Tuple ``(conversions_per_library, context_obs)`` as filled in by
        ``update_mutation_dict``.
    """
    taps = TAPS()

    conversions_per_library = defaultdict(conversion_dict_stranded)
    context_obs = defaultdict(Counter)

    bam, refpath, method, every_fragment_as_molecule, spikein_name = args

    fragment_class, molecule_class = (
        (NlaIIIFragment, TAPSNlaIIIMolecule)
        if method == 'nla'
        else (CHICFragment, TAPSCHICMolecule)
    )

    with pysam.FastaFile(refpath) as fasta_handle:
        reference = CachedFasta(fasta_handle)

        print(f'Processing {bam}')

        with pysam.AlignmentFile(bam, threads=8) as alignments:
            molecule_args = {
                'reference': reference,
                'taps': taps,
                'taps_strand': 'R',
            }
            molecules = MoleculeIterator(
                alignments,
                fragment_class=fragment_class,
                molecule_class=molecule_class,
                molecule_class_args=molecule_args,
                every_fragment_as_molecule=every_fragment_as_molecule,
                fragment_class_args={},
                contig=spikein_name,
            )
            for molecule in molecules:
                update_mutation_dict(
                    molecule, reference, conversions_per_library, context_obs)

    return conversions_per_library, context_obs
121
122
123
def generate_taps_conversion_stats(bams, reference_path, prefix, method, every_fragment_as_molecule, spikein_name, n_threads=None):
    """Estimate CpG C>T conversion rates on the spike-in and write tables and a plot.

    Every BAM file is processed by ``get_conversion_counts`` in a process
    pool and the per-file counts are merged. Three files are written:

    * ``{prefix}_conversions_counts_raw_lambda.csv``: raw context counts.
    * ``{prefix}_conversions_lambda.csv``: C>T rates in ``NCG`` contexts for
      cells with more than 5000 observed contexts.
    * ``{prefix}_conversion_rate_phred_22.png``: box plot of those rates for
      the entries with ``nm == 0``, grouped by library.

    Args:
        bams: List of BAM file paths.
        reference_path: Reference FASTA path, or None to look it up from the
            first BAM file.
        prefix: Prefix of all output files.
        method: ``'nla'`` or anything else (handled as CHIC), forwarded to
            ``get_conversion_counts``.
        every_fragment_as_molecule: Forwarded to ``get_conversion_counts``.
        spikein_name: Name of the spike-in contig.
        n_threads: Number of worker processes (None: ``Pool`` default).

    Raises:
        ValueError: When no reference path is supplied or found.
    """
    if reference_path is None:
        reference_path = get_reference_path_from_bam(bams[0])

    # Validate before reporting, so a missing reference is never printed as
    # 'Reference at None'.
    if reference_path is None:
        raise ValueError('Please supply a reference fasta file')
    print(f'Reference at {reference_path}')

    conversions_per_library = defaultdict(conversion_dict_stranded)
    context_obs = defaultdict(Counter)

    jobs = [(bam, reference_path, method, every_fragment_as_molecule, spikein_name)
            for bam in bams]

    with Pool(n_threads) as workers:
        for cl, co in workers.imap(get_conversion_counts, jobs):
            for lib, obs in cl.items():
                for k, v in obs.items():
                    conversions_per_library[lib][k] += v

            for lib, obs in co.items():
                for k, v in obs.items():
                    context_obs[lib][k] += v

    qf = pd.DataFrame(context_obs)
    qf.to_csv(f'{prefix}_conversions_counts_raw_lambda.csv')

    # Cells with enough coverage, computed straight from the counters rather
    # than with DataFrame.sum(level=...), which newer pandas no longer has.
    indices = _cells_with_min_observations(context_obs, 5000)

    normed_conversions_per_library = defaultdict(conversion_dict_stranded)

    for INDEX in context_obs:
        for (context, base), obs in conversions_per_library[INDEX].items():
            n_context = context_obs[INDEX][context]
            # A context that was never observed has no defined rate.
            if n_context:
                normed_conversions_per_library[INDEX][(context, base)] = obs / n_context

    df = pd.DataFrame(normed_conversions_per_library)
    df = df[[INDEX for INDEX in df if (INDEX[0], INDEX[1]) in indices]]

    # Keep only C>T conversions in a CpG context.
    df = df.loc[[(context, base) for context, base in df.index
                 if context[1] == 'C' and base == 'T' and context.endswith('CG')]]
    df = df.T

    df.to_csv(f'{prefix}_conversions_lambda.csv')

    mpl.rcParams['figure.dpi'] = 300

    samples = []

    for (lib, cell, nm), row in df.iterrows():

        # Only entries without mismatches outside cytosines are plotted.
        if nm != 0:
            continue

        for context, base in [('ACG', 'T'),
                              ('CCG', 'T'),
                              ('GCG', 'T'),
                              ('TCG', 'T')]:

            samples.append({
                'lib': lib,
                'cell': cell,
                'nm': nm,
                'group': f'{nm},{context},{cell}',
                'context': f'{context}>{base}',
                'conversion rate': row[context, base],
            })

    plot_table = pd.DataFrame(samples)
    print(plot_table)

    if plot_table.empty:
        # Nothing passed the filters; the tables are written but there is
        # nothing to draw (sorting on 'lib' would raise a KeyError).
        print('No cells passed the coverage filter, skipping the conversion rate plot')
        return

    ph = 22
    fig, ax = plt.subplots(figsize=(12, 5))
    sns.boxplot(data=plot_table.sort_values('lib'), x='context', y='conversion rate', hue='lib', whis=6, ax=ax)

    plt.ylabel('Lambda Conversion rate')
    sns.despine()
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    plt.suptitle('Estimated TAPS conversion rate', y=1.05, fontsize=12)
    plt.title(f'Lambda spike-in, >{(1.0-(phred_to_prob(ph)))*100 : .2f}% accuracy base calls', fontsize=10)
    plt.tight_layout()
    plt.savefig(f'{prefix}_conversion_rate_phred_{ph}.png', bbox_inches='tight')
    plt.close()


def _cells_with_min_observations(context_obs, min_observations):
    """Return the set of ``(key[0], key[1])`` pairs with more than ``min_observations`` observed contexts in total."""
    totals = Counter()
    for key, counts in context_obs.items():
        totals[(key[0], key[1])] += sum(counts.values())
    return {cell for cell, total in totals.items() if total > min_observations}
222
223
224
if __name__ == '__main__':
    # Command line entry point: parse the options and run the estimation on
    # all supplied BAM files.
    parser = argparse.ArgumentParser(
        description='Estimate the conversion efficiency of a TAPS converted file. ',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('bams', nargs='+', type=str, help='Input bam files')
    parser.add_argument(
        '-o',
        required=True,
        type=str,
        help="output alias (Will be the prefix of the output files)")
    parser.add_argument(
        '-method',
        default='nla',
        type=str,
        help='Molecule class (nla or chic). Use chic when you are not sure or when another other protocol is used.')
    parser.add_argument(
        '--dedup',
        action='store_true',
        help='perform UMI deduplication and consensus calling. Do not use when the UMI\'s are (near) saturated')
    parser.add_argument('-t', type=int, help='Amount of threads')
    parser.add_argument(
        '-spikein_name',
        default='J02459.1',
        type=str,
        help='Name of spikein contig')
    parser.add_argument(
        '-ref',
        default=None,
        type=str,
        help="Path to reference fast (autodected if not supplied)")

    options = parser.parse_args()

    # Without --dedup every fragment is treated as its own molecule.
    generate_taps_conversion_stats(
        options.bams,
        options.ref,
        prefix=options.o,
        method=options.method,
        every_fragment_as_molecule=not options.dedup,
        spikein_name=options.spikein_name,
        n_threads=options.t)