Switch to side-by-side view

--- a
+++ b/singlecellmultiomics/bamProcessing/plotRegion.py
@@ -0,0 +1,333 @@
+#!/usr/bin/env python
+
+
+from multiprocessing import Pool
+from singlecellmultiomics.bamProcessing.bamFunctions import mate_iter
+import argparse
+import pysam
+from glob import glob
+import pandas as pd
+from singlecellmultiomics.bamProcessing import get_contig_sizes
+from collections import Counter, defaultdict
+from singlecellmultiomics.features import FeatureContainer
+import os
+from matplotlib.patches import Rectangle
+import matplotlib as mpl
+from scipy.ndimage import gaussian_filter
+import seaborn as sns
+import numpy as np
+import matplotlib.pyplot as plt
+from itertools import product
+from singlecellmultiomics.bamProcessing import get_contigs_with_reads
+
+
+
+
+
+def _generate_count_dict(args):
+
+    bam_path, bin_size, contig, start, stop = args #reference_path  = args
+
+    #reference_handle = pysam.FastaFile(reference_path)
+    #reference = CachedFasta(reference_handle)
+
+
+    cut_counts = defaultdict(Counter )
+    i = 0
+    with pysam.AlignmentFile(bam_path) as alignments:
+
+        for R1,R2 in mate_iter(alignments, contig=contig, start=start, stop=stop):
+
+            if R1 is None or R1.is_duplicate or not R1.has_tag('DS') or R1.is_qcfail:
+                continue
+
+            cut_pos = R1.get_tag('DS')
+            sample = R1.get_tag('SM')
+
+            bin_idx=int(cut_pos/bin_size)*bin_size
+            cut_counts[(contig,bin_idx)][sample] += 1
+
+    return cut_counts, contig, bam_path
+
+
+def get_binned_counts(bams, bin_size, regions=None):
+
+    fs = 1000
+    if regions is None:
+        regions = [(c,None,None) for c in get_contig_sizes(bams[0]).keys()]
+
+    else:
+        for i,r in enumerate(regions):
+            if type(r)==str:
+                regions[i] = (r,None,None)
+            else:
+                contig, start, end =r
+                if type(start)==int:
+                    start = max(0,start-fs)
+
+                regions[i] = (contig,start,end)
+
+    jobs = [(bam_path, bin_size, *region) for region, bam_path in product(regions, bams)]
+
+
+    cut_counts = defaultdict(Counter)
+    with Pool() as workers:
+
+        for i, (cc, contig, bam_path) in enumerate(workers.imap(_generate_count_dict,jobs)):
+
+            for k,v in cc.items():
+                cut_counts[k] += v
+
+            print(i,'/', len(jobs), end='\r')
+
+    return pd.DataFrame(cut_counts).T
+
+
+
+def plot_region(counts, features, contig, start, end, sigma=2, target=None, caxlabel='Molecules per spike-in'):
+
+    if target is None:
+        target = f'{contig}_{start}_{end}.png'
+
+    def create_gene_models(start,end,ax):
+
+        exon_height = 0.010
+        gene_height = 0.0002
+        spacer = 0.035
+
+        overlap_dist = 200_000
+
+
+        gene_y = {}
+        ymax = 0
+        for fs,fe,name,strand, feature_meta in features.features[contig]:
+
+
+            if not (((fs>=start or fe>=start) and (fs<=end or fe<=end))):
+                continue
+            feature_meta = dict(feature_meta)
+
+            if feature_meta.get('type') == 'gene':
+
+                if not 'gene_name' in  feature_meta or feature_meta.get('gene_name').startswith('AC'):
+                    continue
+
+                # Determine g-y coordinate:
+
+                gy_not_avail = set()
+                for gene,(s,e,loc) in gene_y.items():
+                    if (s+overlap_dist>=fs and s-overlap_dist<=fe) or (e+overlap_dist>=fs and e-overlap_dist<=fe):
+                        # Overlap:
+                        gy_not_avail.add(loc)
+
+                gy = 0
+                while gy in gy_not_avail:
+                    gy+=1
+
+
+                gene_y[name] = (fs,fe,gy)
+
+                y_offset = gy * spacer
+
+                ymax = max(y_offset+gene_height,ymax)
+
+                r = Rectangle((fs,-gene_height*0.5 + y_offset), fe-fs, gene_height, angle=0.0, color='k')
+
+                ax.add_patch( r )
+                ax.text((fe+fs)*0.5,-1.6*exon_height + y_offset,feature_meta.get('gene_name'),horizontalalignment='center',
+              verticalalignment='center',fontsize=3)
+                #print(feature_meta)
+
+
+
+        if False:
+
+            for xx in range(3):
+                for fs,fe,name,strand, feature_meta in features.features[contig]:
+
+                    if not (((fs>=start or fe>=start) and (fs<=end or fe<=end))):
+                        continue
+
+                    feature_meta = dict(feature_meta)
+                    if not name in gene_y:
+                        continue
+
+                    if feature_meta.get('type') == 'exon':
+                        y_offset = gene_y[name][2]*spacer
+                        ymax = max(y_offset+exon_height,ymax)
+                        r = Rectangle((fs,-exon_height*0.5 + y_offset), fe-fs, exon_height, angle=0.0,color='k', lw=0)
+                        ax.add_patch( r )
+
+
+
+        ax.set_xlim(start,end)
+        ax.set_ylim(-0.1,ymax)
+        #ax.axis('off')
+        ax.set_yticks([])
+        ax.set_xlabel(f'chr{contig} location bp', fontsize=6)
+
+        #print([t.get_text() for t in ax.get_xticklabels()])
+        #ax.set_xticklabels([t.get_text() for t in ax.get_xticklabels()],fontsize=4)
+        ax.set_xticklabels(ax.get_xticks(), fontsize=4)
+
+
+        ax.tick_params(length=0.5)
+
+
+    for sigma in range(2,3):
+
+        mpl.rcParams['figure.dpi'] = 300
+
+        font = {'family' : 'helvetica',
+                'weight' : 'normal',
+                'size'   : 8}
+
+        mpl.rc('font', **font)
+
+        if end - start < 3_000_000:
+            mode ='k'
+            stepper = 100_000
+            res = 100
+        else:
+            mode='M'
+            stepper=1_000_000
+            res = 1
+
+
+
+        qf = counts.loc[:, [(c,p) for c,p in counts if c==contig and p>=start and p<=end] ].sort_index()
+        qf = qf.sort_index(1).sort_index(0)
+        qf = pd.DataFrame(gaussian_filter(qf, sigma=(0.00001,sigma)), index=qf.index, columns=qf.columns)
+        qf = qf.sort_index(1).sort_index(0)
+
+        cm = sns.clustermap(qf,
+               #z_score=0,
+                row_cluster=False,
+                col_cluster=False,
+                vmax=np.percentile(qf,99.5),#0.0005,
+                #vmax=10,
+                dendrogram_ratio=0.1,
+                #row_colors=row_colors.loc[qf.index].drop('LOWESS_STAGE',1),
+                figsize=(8,4), cmap='Greys', cbar_kws={"shrink": .1},
+                cbar_pos=(0.0, 0.5, 0.01, 0.16),)
+
+        ax = cm.ax_col_dendrogram
+        qf.mean().plot.bar(ax=ax,color='k',width=1)
+
+
+        ax.set_yticks([])
+
+        cm.ax_heatmap.set_xticks([]) #np.arange(start,end, 1_000_000))
+        cm.ax_heatmap.set_yticks([])
+
+
+        cm.ax_heatmap.set_ylabel(f'{qf.shape[0]} single cells', fontsize=8)
+        cm.ax_heatmap.tick_params(length=0.5)
+        cm.ax_heatmap.set_xlabel(None)
+
+        ax.grid()
+        cm.cax.set_ylabel(caxlabel,fontsize=6)
+        cm.cax.tick_params(labelsize=4)
+        #plt.suptitle(mark, x=0.05)
+
+
+        fig = plt.gcf()
+        heatmap_start_x,heatmap_start_y, heatmap_end_x, heatmap_end_y = cm.ax_heatmap.get_position().bounds
+
+        width = heatmap_end_x #-heatmap_start_x
+        height = 0.2 if features is not None else 0.05
+        ax = fig.add_axes(  (heatmap_start_x, heatmap_start_y-height-0.02, width, height)  )
+        ax.ticklabel_format(axis='x',style='sci')
+
+        sns.despine(fig=fig, ax=ax)
+        if features is not None:
+            create_gene_models(start,end,ax=ax)
+        else:
+            ax.set_xlim(start,end)
+            #ax.axis('off')
+            ax.set_yticks([])
+            ax.set_xlabel(f'chr{contig} location bp', fontsize=6)
+
+            #ax.set_xticklabels(ax.get_xticks(), fontsize=4)
+            plt.xticks(fontsize=4)
+            ax.tick_params(length=0.5)
+
+        plt.savefig(target)
+        plt.close()
+
+
+if __name__=='__main__':
+    argparser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        description='Plot a genomic region')
+
+    argparser.add_argument('bams', type=str, nargs='+', help='(X) Training bam files')
+
+    argparser.add_argument('-regions', type=str, help='Regions to plot, with a bin size behind it, for example: 1:1000-100000:1000 , will be a single region plotted with a 1000bp bin size split regions by commas without a space')
+    argparser.add_argument('-features', type=str, help='Gene models to plot (.gtf file or .gtf.gz)', required=False)
+    argparser.add_argument('-norm', type=str, help='Normalize to, select from : total-molecules,spike', default='total-molecules')
+    argparser.add_argument('-prefix', type=str, help='Prefix for output file',default='')
+    argparser.add_argument('-format', type=str, help='png or svg',default='png')
+
+    args = argparser.parse_args()
+
+    regions = []
+    contigs = set()
+    for region in args.regions.split(','):
+        contig = region.split(':')[0]
+        if not '-' in region:
+            start, end = None, None
+        else:
+            start, end = region.split(':')[1].split('-')
+            start = int(start)
+            end = int(end)
+        bin_size = int(region.split(':')[-1])
+        if start is not None:
+            print(f'Region: {contig} from {start} to {end} with bin size : {bin_size}')
+        else:
+            print(f'Region: {contig} with bin size : {bin_size}')
+        contigs.add(contig)
+        regions.append( ((contig,start,end), bin_size))
+
+    contigs=list(contigs)
+    bams = args.bams
+
+    if args.features is not None:
+        print('Reading features')
+        features = FeatureContainer()
+        if len(contigs)==1:
+            print(f'Reading only features from {contigs[0]}')
+            features.loadGTF(args.features,store_all=True,contig=contigs[0])
+        else:
+            features.loadGTF(args.features,store_all=True)
+    else:
+        features = None
+    print('Counting')
+
+    # Obtain counts per cell
+    norm = 'spike'
+    if norm == 'spike':
+        normalize_to_counts = get_binned_counts(bams, bin_size=10_000_000, regions=['J02459.1'])
+    elif norm=='total-molecules':
+        normalize_to_counts = get_binned_counts(bams, bin_size=10_000_000)
+
+    for region, region_bin_size in regions:
+        print(f'Plotting {region}')
+        contig, start, end = region
+        region_counts = get_binned_counts(bams, region_bin_size, regions=[ region ] )
+        counts = (region_counts/normalize_to_counts.sum()).fillna(0).T.sort_index(1).sort_index(0)
+
+        # Fill non intialized bins with zeros:
+        add = []
+        for i in np.arange(counts.columns[0][1], counts.columns[-1][1], region_bin_size):
+            if not (contig,i) in counts.columns:
+                add.append((contig,i))
+
+        for a in add:
+            counts[a] = 0
+
+        counts = counts.sort_index(1)
+
+
+        target = args.prefix+f'{contig}_{start}-{end}_{region_bin_size}.{args.format}'
+        plot_region(counts, features, contig, start, end, sigma=2, target=target, caxlabel='Molecules per spike-in' if norm =='spike' else 'Molecules / total molecules')