a b/singlecellmultiomics/bamProcessing/bamMethylationCutDistance.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import matplotlib
5
matplotlib.rcParams['figure.dpi'] = 160
6
matplotlib.use('Agg')
7
import matplotlib.pyplot as plt
8
import pandas as pd
9
import seaborn as sns
10
import multiprocessing
11
from singlecellmultiomics.bamProcessing.bamBinCounts import generate_commands, count_methylation_binned
12
import argparse
13
from colorama import Fore, Style
14
from singlecellmultiomics.utils import dataframe_to_wig
15
from singlecellmultiomics.methylation import MethylationCountMatrix
16
from singlecellmultiomics.bamProcessing.bamFunctions import get_reference_from_pysam_alignmentFile
17
from colorama import Fore,Style
18
from collections import defaultdict, Counter
19
from multiprocessing import Pool
20
from datetime import datetime
21
import pysam
22
from singlecellmultiomics.bamProcessing import get_contig_sizes, get_contig_size
23
from singlecellmultiomics.bamProcessing.bamBinCounts import generate_commands, read_counts
24
25
26
def sample_dict() -> defaultdict:
    """Return a fresh ``defaultdict`` whose missing keys start as an empty ``Counter``.

    This is the per-sample container of the histogram
    (``(is_read1, is_reverse, distance) -> Counter`` of methylation call characters).
    It is a named module-level function rather than a lambda because it serves
    as the ``default_factory`` of dictionaries that are sent back from pool
    workers, and pickling requires the factory to be importable by name.
    """
    return defaultdict(Counter)
28
29
30
31
def methylation_to_cut_histogram(args):
    """Histogram methylation calls by distance to the cut site for one genomic window.

    Written as a ``multiprocessing.Pool`` worker: everything it needs arrives in
    a single packed tuple.

    Args:
        args: tuple ``(alignments_path, bin_size, max_fragment_size, contig,
            start, end, min_mq, alt_spans, key_tags, dedup, kwargs)``.
            ``bin_size``, ``alt_spans`` and ``key_tags`` are unpacked only to
            match the tuple layout and are not used here.
            ``kwargs`` is an optional dict; the keys read are:

            * ``'known'``: path to a VCF. Positions carrying a C>T or G>A
              variant are excluded from the counts.
            * ``'maxtime'``: give up on the window after this many seconds
              (checked every 50 reads); counts gathered so far are returned.
            * ``'max_dist'``: largest distance to the cut site that is still
              counted (default 1000).

    Returns:
        ``defaultdict``: ``sample -> (is_read1, is_reverse, distance) -> Counter``
        mapping each methylation call character (context ZzHhXx) to its number
        of observations.

    Raises:
        ValueError: when ``get_contig_size`` returns ``None`` for ``contig``.
    """
    (alignments_path, bin_size, max_fragment_size,
     contig, start, end,
     min_mq, alt_spans, key_tags, dedup, kwargs) = args

    kwargs = kwargs or {}

    # NOTE(review): read_counts is defined elsewhere; presumably it compares the
    # read's mapping quality with min_mq. 0 expresses "no filter" without
    # handing it None.
    if min_mq is None:
        min_mq = 0

    max_dist = kwargs.get('max_dist')
    if max_dist is None:
        max_dist = 1000

    # 0-based reference positions which must not be counted
    known = set()
    known_path = kwargs.get('known')
    if known_path is not None:
        # Only ban the very specific TAPS conversions:
        try:
            with pysam.VariantFile(known_path) as variants:
                for record in variants.fetch(contig, start, end):
                    alts = record.alts or ()  # alts is None for records without ALT allele
                    if (record.ref == 'C' and 'T' in alts) or \
                            (record.ref == 'G' and 'A' in alts):
                        # record.start is 0-based, like the reference positions
                        # yielded by get_aligned_pairs (record.pos is 1-based)
                        known.add(record.start)
        except ValueError:
            # This happens on contigs not present in the vcf
            pass

    distance_methylation = defaultdict(sample_dict)  # sample -> distance -> context(ZzHhXx) : obs
    maxtime = kwargs.get('maxtime')
    start_time = datetime.now()

    with pysam.AlignmentFile(alignments_path, threads=4) as alignments:
        # Obtain size of selected contig:
        contig_size = get_contig_size(alignments, contig)
        if contig_size is None:
            raise ValueError('Unknown contig')

        # Determine where we start looking for fragments:
        f_start = max(0, start - max_fragment_size)
        f_end = min(end + max_fragment_size, contig_size)

        for p, read in enumerate(alignments.fetch(contig=contig, start=f_start,
                                                  stop=f_end)):

            if maxtime is not None and p % 50 == 0:
                if (datetime.now() - start_time).total_seconds() > maxtime:
                    print(f'Gave up on {contig}:{start}-{end}')
                    break

            if not read_counts(read, min_mq=min_mq, dedup=dedup):
                continue

            # A read without calls, cut site or sample cannot contribute
            if not (read.has_tag('XM') and read.has_tag('DS') and read.has_tag('SM')):
                continue

            # These are constant for the read, look them up once instead of per base
            calls = read.get_tag('XM')
            cut_site = read.get_tag('DS')
            sample = read.get_tag('SM')
            is_read1 = read.is_read1
            is_reverse = read.is_reverse

            # The i-th character of XM is used as the call of the i-th matched pair
            for i, (_, methylation_pos) in enumerate(read.get_aligned_pairs(matches_only=True)):

                # Don't count sites outside the selected bounds
                if methylation_pos < start or methylation_pos >= end:
                    continue

                call = calls[i]
                if call == '.':
                    continue

                # Known variants are neither methylated nor unmethylated evidence
                if methylation_pos in known:
                    continue

                distance = abs(cut_site - methylation_pos)
                if distance > max_dist:
                    continue

                distance_methylation[sample][(is_read1, is_reverse, distance)][call] += 1

    return distance_methylation
105
106
107
# Module-level default. None of the code visible in this file reads this global:
# get_distance_methylation takes its own ``threads`` parameter, which shadows it.
threads = None
108
109
def get_distance_methylation(bam_path,
                             bp_per_job: int,
                             min_mapping_qual: int = None,
                             skip_contigs: set = None,
                             known_variants: str = None,
                             maxtime: int = None,
                             head: int = None,
                             threads: int = None,
                             **kwargs
                             ):
    """Collect, for a whole bam file, methylation calls by distance to the cut site.

    The work is split into jobs by ``generate_commands`` (one bin of
    ``bp_per_job`` bases per job), every job is handed to
    ``methylation_to_cut_histogram`` in a process pool, and the per-job
    histograms are summed.

    Args:
        bam_path: path of the alignment file, passed on as ``alignments_path``.
        bp_per_job: bin size in base pairs of a single job.
        min_mapping_qual: passed on as ``min_mq``.
        skip_contigs: passed on unchanged to ``generate_commands``.
        known_variants: path to a VCF, made available to the workers as
            ``kwargs['known']``.
        maxtime: made available to the workers as ``kwargs['maxtime']``.
        head: passed on unchanged to ``generate_commands``.
        threads: size of the process pool (``None`` lets ``Pool`` decide); it is
            also placed in the worker kwargs under ``'threads'``.
        **kwargs: extra entries for the worker kwargs. These are applied last
            and therefore override ``'known'``, ``'maxtime'`` and ``'threads'``.

    Returns:
        ``defaultdict``: ``sample -> (is_read1, is_reverse, distance) -> Counter``
        of methylation call characters (context ZzHhXx), summed over all jobs.
    """
    all_kwargs = {
        'known': known_variants,
        'maxtime': maxtime,
        'threads': threads
    }
    all_kwargs.update(kwargs)

    commands = generate_commands(
        alignments_path=bam_path,
        key_tags=None,
        max_fragment_size=0,
        dedup=True,
        head=head,
        bin_size=bp_per_job,
        bins_per_job=1,
        min_mq=min_mapping_qual,
        kwargs=all_kwargs,
        skip_contigs=skip_contigs
    )

    distance_methylation = defaultdict(sample_dict)  # sample -> distance -> context(ZzHhXx) : obs

    with Pool(threads) as workers:
        # The order in which jobs finish is irrelevant as the counts are only summed
        for result in workers.imap_unordered(methylation_to_cut_histogram, commands):
            for sample, data_for_sample in result.items():
                for distance, context_obs in data_for_sample.items():
                    distance_methylation[sample][distance] += context_obs

    return distance_methylation
148
149
150
if __name__ == '__main__':
    # Command line entry point: count methylation calls by distance to the cut
    # site, write the tables per context (z, h, x) and plot them per mate/strand.

    def _tabulate(counts, ctx, clip=None):
        """Split the histogram into beta, methylated and unmethylated tables for one context.

        Args:
            counts: sample -> (is_read1, is_reverse, distance) -> Counter of calls.
            ctx: lower case context letter; its upper case form is the methylated call.
            clip: optional ``(lowest, highest)`` distance (inclusive); keys whose
                distance lies outside are left out.

        Returns:
            tuple ``(beta, met, un)`` of ``sample -> key -> value`` dicts.
        """
        beta, met, un = {}, {}, {}
        for sample, sample_data in counts.items():
            beta[sample], met[sample], un[sample] = {}, {}, {}
            for distance, contexts in sample_data.items():
                if clip is not None and not (clip[0] <= distance[-1] <= clip[1]):
                    continue
                n_met = contexts[ctx.upper()]
                n_un = contexts[ctx]
                total = n_met + n_un
                if total == 0:
                    # No call of this context here (also prevents a division by zero)
                    continue
                beta[sample][distance] = n_met / total
                met[sample][distance] = n_met
                un[sample][distance] = n_un
        return beta, met, un

    def _frame(nested):
        """Turn a sample -> key -> value dict into a frame with sorted sample rows and sorted key columns."""
        return pd.DataFrame(nested).sort_index().T.sort_index()

    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""Extract methylation levels relative to cut site (DS tag) from bam file""")

    argparser.add_argument('bamfile', metavar='bamfile', type=str)
    argparser.add_argument('-bp_per_job', default=5_000_000, type=int, help='Amount of basepairs to be processed per thread per chunk')
    argparser.add_argument('-threads', default=None, type=int, help='Amount of threads to use for counting, None to use the amount of available threads')

    fi = argparser.add_argument_group("Filters")
    fi.add_argument('-min_mapping_qual', default=40, type=int)
    fi.add_argument('-head', default=None, type=int,help='Process the first n bins')
    fi.add_argument('-skip_contigs', type=str, help='Comma separated contigs to skip', default='MT,chrM')
    fi.add_argument('-known_variants',
                    help='VCF file with known variants, will be not taken into account as methylated/unmethylated',
                    type=str)

    og = argparser.add_argument_group("Output")
    og.add_argument('-prefix', default='distance_calls', type=str, help='Prefix for output files')

    args = argparser.parse_args()

    print('Obtaining counts ', end="")
    r = get_distance_methylation(bam_path=args.bamfile,
                                 bp_per_job=args.bp_per_job,
                                 known_variants=args.known_variants,
                                 skip_contigs=args.skip_contigs.split(','),
                                 min_mapping_qual=args.min_mapping_qual,
                                 head=args.head,
                                 threads=args.threads,
                                 )
    print(f" [ {Fore.GREEN}OK{Style.RESET_ALL} ] ")

    for ctx in 'zhx':

        # Tables over all distances, each written as csv and as gzipped pickle.
        # NOTE: the methylated and unmethylated count files differ only in the
        # letter case of the context in their name.
        beta, met, un = (_frame(table) for table in _tabulate(r, ctx))

        beta.to_csv(f'{args.prefix}_beta_{ctx}.csv')
        beta.to_pickle(f'{args.prefix}_beta_{ctx}.pickle.gz')
        met.to_csv(f'{args.prefix}_counts_{ctx.upper()}.csv')
        met.to_pickle(f'{args.prefix}_counts_{ctx.upper()}.pickle.gz')
        un.to_csv(f'{args.prefix}_counts_{ctx}.csv')
        un.to_pickle(f'{args.prefix}_counts_{ctx}.pickle.gz')

        # Make plots, clipped to a sane region of distances
        _, met_clipped, un_clipped = _tabulate(r, ctx, clip=(4, 500))

        # (is_read1, is_reverse) combinations for which there is something to plot
        observed = {
            distance[:2]
            for sample_data in met_clipped.values()
            for distance in sample_data
        }

        met = _frame(met_clipped)
        un = _frame(un_clipped)

        for mate in [True, False]:
            for strand in [True, False]:
                mate_name = "R1" if mate else "R2"
                strand_name = "reverse" if strand else "forward"

                if (mate, strand) not in observed:
                    # Selecting a missing combination from the frames raises a KeyError
                    print(f'No {ctx} calls for {mate_name} {strand_name}, skipping plot')
                    continue

                # Totals over all samples per distance
                unmethylated = un[mate, strand].sum()
                methylated = met[mate, strand].sum()

                fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

                unmethylated.rename('Unmethylated').plot(ax=ax1)
                methylated.rename('Methylated').plot(ax=ax1)

                ax1.set_xlabel('distance to cut')
                ax1.set_ylabel('# molecules')
                ax1.legend()

                (methylated / (unmethylated + methylated)).rename('Beta').plot(ax=ax2)

                sns.despine()
                ax1.set_title(f'Mate {mate_name}, strand:{strand_name}')
                ax2.set_ylabel('Beta')
                plt.savefig(f'{args.prefix}_{ctx}_{mate_name}_{strand_name}.png')
                plt.close('all')