Switch to side-by-side view

--- a
+++ b/singlecellmultiomics/bamProcessing/bamMethylationCutDistance.py
@@ -0,0 +1,247 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import matplotlib
+matplotlib.rcParams['figure.dpi'] = 160
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+import multiprocessing
+from singlecellmultiomics.bamProcessing.bamBinCounts import generate_commands, count_methylation_binned
+import argparse
+from colorama import Fore, Style
+from singlecellmultiomics.utils import dataframe_to_wig
+from singlecellmultiomics.methylation import MethylationCountMatrix
+from singlecellmultiomics.bamProcessing.bamFunctions import get_reference_from_pysam_alignmentFile
+from colorama import Fore,Style
+from collections import defaultdict, Counter
+from multiprocessing import Pool
+from datetime import datetime
+import pysam
+from singlecellmultiomics.bamProcessing import get_contig_sizes, get_contig_size
+from singlecellmultiomics.bamProcessing.bamBinCounts import generate_commands, read_counts
+
+
+def sample_dict():
+    return defaultdict(Counter)
+
+
+
+def methylation_to_cut_histogram(args):
+    (alignments_path, bin_size, max_fragment_size, \
+     contig, start, end, \
+     min_mq, alt_spans, key_tags, dedup, kwargs) = args
+
+
+    distance_methylation = defaultdict(sample_dict) # sample - > distance -> context(ZzHhXx) : obs
+    max_dist = 1000
+
+    # Define which reads we want to count:
+    known =  set()
+    if 'known' in kwargs and kwargs['known'] is not None:
+        # Only ban the very specific TAPS conversions:
+        try:
+            with pysam.VariantFile(kwargs['known']) as variants:
+                for record in variants.fetch(contig, start, end):
+                    if record.ref=='C' and 'T' in record.alts:
+                        known.add( record.pos)
+                    if record.ref=='G' and 'A' in record.alts:
+                        known.add(record.pos)
+        except ValueError:
+            # This happends on contigs not present in the vcf
+            pass
+
+    p = 0
+
+    start_time = datetime.now()
+    with pysam.AlignmentFile(alignments_path, threads=4) as alignments:
+        # Obtain size of selected contig:
+        contig_size = get_contig_size(alignments, contig)
+        if contig_size is None:
+            raise ValueError('Unknown contig')
+
+        # Determine where we start looking for fragments:
+        f_start = max(0, start - max_fragment_size)
+        f_end = min(end + max_fragment_size, contig_size)
+
+        for p, read in enumerate(alignments.fetch(contig=contig, start=f_start,
+                                                  stop=f_end)):
+
+
+
+
+            if p%50==0 and 'maxtime' in kwargs and kwargs['maxtime'] is not None:
+                if (datetime.now() - start_time).total_seconds() > kwargs['maxtime']:
+                    print(f'Gave up on {contig}:{start}-{end}')
+
+                    break
+
+            if not read_counts(read, min_mq=min_mq, dedup=dedup):
+                continue
+
+
+            tags = dict(read.tags)
+            for i, (qpos, methylation_pos) in enumerate(read.get_aligned_pairs(matches_only=True)):
+
+                # Don't count sites outside the selected bounds
+                if methylation_pos < start or methylation_pos >= end:
+                    continue
+
+                call = tags['XM'][i]
+                if call=='.':
+                    continue
+
+                sample = read.get_tag('SM')
+
+
+                distance = abs(read.get_tag('DS') - methylation_pos)
+                if distance>max_dist:
+                    continue
+
+                distance_methylation[sample][(read.is_read1, read.is_reverse, distance)][call] +=1
+
+    return distance_methylation
+
+
+threads = None
+
+def get_distance_methylation(bam_path,
+                                 bp_per_job: int,
+                                 min_mapping_qual: int = None,
+                                 skip_contigs: set = None,
+                                 known_variants: str = None,
+                                 maxtime: int = None,
+                                 head: int=None,
+                                 threads: int = None,
+                                **kwargs
+                                 ):
+
+
+    all_kwargs = {'known': known_variants,
+            'maxtime': maxtime,
+            'threads':threads
+            }
+    all_kwargs.update(kwargs)
+    commands = generate_commands(
+        alignments_path=bam_path,
+        key_tags=None,
+        max_fragment_size=0,
+        dedup=True,
+        head=head,
+        bin_size=bp_per_job,
+        bins_per_job= 1, min_mq=min_mapping_qual,
+        kwargs=all_kwargs,
+        skip_contigs=skip_contigs
+    )
+
+
+    distance_methylation = defaultdict(sample_dict) # sample - > distance -> context(ZzHhXx) : obs
+
+    with Pool(threads) as workers:
+
+        for result in workers.imap_unordered(methylation_to_cut_histogram, commands):
+            for sample, data_for_sample in result.items():
+                for distance, context_obs in data_for_sample.items():
+                    distance_methylation[sample][distance] += context_obs
+    return distance_methylation
+
+
+if __name__ == '__main__':
+    argparser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        description="""Extract methylation levels relative to cut site (DS tag) from bam file""")
+
+    argparser.add_argument('bamfile', metavar='bamfile', type=str)
+    argparser.add_argument('-bp_per_job', default=5_000_000, type=int, help='Amount of basepairs to be processed per thread per chunk')
+    argparser.add_argument('-threads', default=None, type=int, help='Amount of threads to use for counting, None to use the amount of available threads')
+
+    fi = argparser.add_argument_group("Filters")
+    fi.add_argument('-min_mapping_qual', default=40, type=int)
+    fi.add_argument('-head', default=None, type=int,help='Process the first n bins')
+    fi.add_argument('-skip_contigs', type=str, help='Comma separated contigs to skip', default='MT,chrM')
+    fi.add_argument('-known_variants',
+                           help='VCF file with known variants, will be not taken into account as methylated/unmethylated',
+                           type=str)
+
+    og = argparser.add_argument_group("Output")
+    og.add_argument('-prefix', default='distance_calls', type=str, help='Prefix for output files')
+
+    args = argparser.parse_args()
+
+
+
+    print('Obtaining counts ', end="")
+    r = get_distance_methylation(bam_path = args.bamfile,
+                                 bp_per_job = args.bp_per_job,
+                                 known_variants = args.known_variants,
+                                 skip_contigs = args.skip_contigs.split(','),
+                                 min_mapping_qual=args.min_mapping_qual,
+                                 head = args.head,
+                                 threads=args.threads,
+    )
+    print(f" [ {Fore.GREEN}OK{Style.RESET_ALL} ] ")
+
+
+    for ctx in 'zhx':
+
+        beta = {}
+        met = {}
+        un = {}
+        for sample, sample_data in r.items():
+            beta[sample] = {}
+            met[sample] = {}
+            un[sample] = {}
+            for distance, contexts in sample_data.items():
+                if ctx in contexts or ctx.upper() in contexts:
+                    beta[sample][distance] = contexts[ctx.upper()]/(contexts[ctx.upper()]+contexts[ctx])
+                    met[sample][distance] = contexts[ctx.upper()]
+                    un[sample][distance] = contexts[ctx]
+
+        pd.DataFrame(beta).sort_index().T.sort_index().to_csv(f'{args.prefix}_beta_{ctx}.csv')
+        pd.DataFrame(beta).sort_index().T.sort_index().to_csv(f'{args.prefix}_beta_{ctx}.pickle.gz')
+        pd.DataFrame(met).sort_index().T.sort_index().to_csv(f'{args.prefix}_counts_{ctx.upper()}.csv')
+        pd.DataFrame(met).sort_index().T.sort_index().to_csv(f'{args.prefix}_counts_{ctx.upper()}.pickle.gz')
+        pd.DataFrame(un).sort_index().T.sort_index().to_csv(f'{args.prefix}_counts_{ctx}.csv')
+        pd.DataFrame(un).sort_index().T.sort_index().to_csv(f'{args.prefix}_counts_{ctx}.pickle.gz')
+
+        # Make plots
+        beta = {}
+        met = {}
+        un = {}
+        for sample, sample_data in r.items():
+            beta[sample] = {}
+            met[sample] = {}
+            un[sample] = {}
+            for distance, contexts in sample_data.items():
+                if distance[-1] > 500 or distance[-1] < 4: # Clip in sane region
+                    continue
+                if ctx in contexts or ctx.upper() in contexts:
+                    beta[sample][distance] = contexts[ctx.upper()] / (contexts[ctx.upper()] + contexts[ctx])
+                    met[sample][distance] = contexts[ctx.upper()]
+                    un[sample][distance] = contexts[ctx]
+
+        beta = pd.DataFrame(beta).sort_index().T.sort_index()
+        met = pd.DataFrame(met).sort_index().T.sort_index()
+        un = pd.DataFrame(un).sort_index().T.sort_index()
+
+        for mate in [True, False]:
+            for strand in [True, False]:
+                fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
+
+                un[mate, strand].sum().rename('Unmethylated').plot(ax=ax1)
+                met[mate, strand].sum().rename('Methylated').plot(ax=ax1)
+
+                ax1.set_xlabel('distance to cut')
+                ax1.set_ylabel('# molecules')
+                ax1.legend()
+
+                (met[mate, strand].sum() / (un[mate, strand].sum() + met[mate, strand].sum())).rename('Beta').plot(
+                    ax=ax2)
+                # ax2.set_ylim(0,0)
+
+                sns.despine()
+                ax1.set_title(f'Mate {"R1" if mate else "R2"}, strand:{"reverse" if strand else "forward"}')
+                ax2.set_ylabel('Beta')
+                plt.savefig(f'{args.prefix}_{ctx}_{"R1" if mate else "R2"}_{"reverse" if strand else "forward"}.png')
+                plt.close('all')