#!/usr/bin/env python
from multiprocessing import Pool
from singlecellmultiomics.bamProcessing.bamFunctions import mate_iter
import argparse
import pysam
from glob import glob
import pandas as pd
from singlecellmultiomics.bamProcessing import get_contig_sizes
from collections import Counter, defaultdict
from singlecellmultiomics.features import FeatureContainer
import os
from matplotlib.patches import Rectangle
import matplotlib as mpl
from scipy.ndimage import gaussian_filter
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from singlecellmultiomics.bamProcessing import get_contigs_with_reads
def _generate_count_dict(args):
    """Tally reads per (contig, bin) and per sample for one BAM region.

    Takes a single tuple so it can be mapped over a worker pool.

    Args:
        args: ``(bam_path, bin_size, contig, start, stop)``. ``contig``,
            ``start`` and ``stop`` are forwarded unchanged to ``mate_iter``.

    Returns:
        ``(cut_counts, contig, bam_path)``, where ``cut_counts`` maps
        ``(contig, bin_start)`` to a ``Counter`` keyed by the read's ``SM``
        tag value.
    """
    bam_path, bin_size, contig, start, stop = args
    per_bin = defaultdict(Counter)

    with pysam.AlignmentFile(bam_path) as bam:
        for read1, _read2 in mate_iter(bam, contig=contig, start=start, stop=stop):
            # Only the first mate is inspected; pairs without one are skipped.
            if read1 is None:
                continue
            if read1.is_duplicate or read1.is_qcfail:
                continue
            if not read1.has_tag('DS'):
                continue

            # The DS tag value is used as the cut position of the molecule.
            position = read1.get_tag('DS')
            sample_name = read1.get_tag('SM')
            # Snap the position down to the start coordinate of its bin.
            bin_start = int(position / bin_size) * bin_size
            per_bin[(contig, bin_start)][sample_name] += 1

    return per_bin, contig, bam_path
def get_binned_counts(bams, bin_size, regions=None):
    """Count reads per bin and per sample across BAM files, in parallel.

    Args:
        bams: Sequence of BAM file paths. When ``regions`` is None, the
            contigs of the first file define the regions to count.
        bin_size: Bin width passed on to the per-region worker.
        regions: Optional iterable whose items are either a contig name
            (``str``) or a ``(contig, start, end)`` tuple. The iterable is
            not modified.

    Returns:
        ``pandas.DataFrame`` with one row per ``(contig, bin_start)`` key and
        one column per sample, built from the merged worker results.
    """
    # Integer region starts are moved this many bp upstream (clamped at 0)
    # before reads are fetched.
    flank = 1000

    if regions is None:
        job_regions = [(c, None, None) for c in get_contig_sizes(bams[0]).keys()]
    else:
        # Build a fresh list rather than rewriting the caller's list in place.
        job_regions = []
        for region in regions:
            if isinstance(region, str):
                job_regions.append((region, None, None))
                continue
            contig, start, end = region
            if isinstance(start, int):
                start = max(0, start - flank)
            job_regions.append((contig, start, end))

    jobs = [
        (bam_path, bin_size, *region)
        for region, bam_path in product(job_regions, bams)
    ]

    cut_counts = defaultdict(Counter)
    with Pool() as workers:
        results = workers.imap(_generate_count_dict, jobs)
        # Count from 1 so the progress line ends at "len(jobs) / len(jobs)".
        for done, (partial, _contig, _bam_path) in enumerate(results, start=1):
            for key, sample_counts in partial.items():
                cut_counts[key] += sample_counts
            print(done, '/', len(jobs), end='\r')

    return pd.DataFrame(cut_counts).T
def plot_region(counts, features, contig, start, end, sigma=2, target=None, caxlabel='Molecules per spike-in'):
    """Render a smoothed per-cell heatmap of binned counts for one region.

    The figure holds a heatmap (one row per row of ``counts``), a bar plot of
    the per-bin mean above it, and a coordinate axis below it that also shows
    gene models when ``features`` is given. The figure is saved to ``target``
    and then closed.

    Args:
        counts: ``pandas.DataFrame`` whose columns are ``(contig, position)``
            tuples.
        features: Object exposing ``features[contig]`` as an iterable of
            ``(start, end, name, strand, metadata)`` records, or None to draw
            no gene models.
        contig: Contig to plot.
        start: First position to include, or None to use the smallest
            position available for ``contig`` in ``counts``.
        end: Last position to include, or None to use the largest position
            available for ``contig`` in ``counts``.
        sigma: Gaussian smoothing sigma applied along the bin axis.
        target: Output path; defaults to ``'{contig}_{start}_{end}.png'``
            built from the arguments as passed.
        caxlabel: Label for the colour bar.

    Raises:
        ValueError: If ``start`` or ``end`` is None and ``counts`` has no
            column for ``contig``.
    """
    if target is None:
        target = f'{contig}_{start}_{end}.png'

    # Open-ended bounds cannot be compared with positions, so derive them
    # from the bins that are actually present.
    if start is None or end is None:
        positions = [p for c, p in counts if c == contig]
        if not positions:
            raise ValueError(f'No bins available for contig {contig}')
        if start is None:
            start = min(positions)
        if end is None:
            end = max(positions)

    def create_gene_models(start, end, ax):
        """Draw gene bars and names for features overlapping start..end on ax."""
        exon_height = 0.010
        gene_height = 0.0002
        spacer = 0.035  # vertical distance between stacked gene rows
        overlap_dist = 200_000  # genes closer than this get separate rows
        draw_exons = False  # exon drawing is disabled

        gene_y = {}  # gene name -> (start, end, row index)
        ymax = 0
        for fs, fe, name, strand, feature_meta in features.features[contig]:
            if not ((fs >= start or fe >= start) and (fs <= end or fe <= end)):
                continue
            feature_meta = dict(feature_meta)
            if feature_meta.get('type') == 'gene':
                # Skip unnamed genes and those whose name starts with 'AC'.
                if 'gene_name' not in feature_meta or feature_meta.get('gene_name').startswith('AC'):
                    continue

                # Determine the row: the lowest one not used by a nearby gene.
                gy_not_avail = set()
                for gene, (s, e, loc) in gene_y.items():
                    if (s + overlap_dist >= fs and s - overlap_dist <= fe) or \
                            (e + overlap_dist >= fs and e - overlap_dist <= fe):
                        gy_not_avail.add(loc)
                gy = 0
                while gy in gy_not_avail:
                    gy += 1

                gene_y[name] = (fs, fe, gy)
                y_offset = gy * spacer
                ymax = max(y_offset + gene_height, ymax)
                r = Rectangle((fs, -gene_height * 0.5 + y_offset), fe - fs, gene_height, angle=0.0, color='k')
                ax.add_patch(r)
                ax.text((fe + fs) * 0.5, -1.6 * exon_height + y_offset, feature_meta.get('gene_name'),
                        horizontalalignment='center', verticalalignment='center', fontsize=3)

        if draw_exons:
            for xx in range(3):
                for fs, fe, name, strand, feature_meta in features.features[contig]:
                    if not ((fs >= start or fe >= start) and (fs <= end or fe <= end)):
                        continue
                    feature_meta = dict(feature_meta)
                    if name not in gene_y:
                        continue
                    if feature_meta.get('type') == 'exon':
                        y_offset = gene_y[name][2] * spacer
                        ymax = max(y_offset + exon_height, ymax)
                        r = Rectangle((fs, -exon_height * 0.5 + y_offset), fe - fs, exon_height,
                                      angle=0.0, color='k', lw=0)
                        ax.add_patch(r)

        ax.set_xlim(start, end)
        ax.set_ylim(-0.1, ymax)
        ax.set_yticks([])
        ax.set_xlabel(f'chr{contig} location bp', fontsize=6)
        ax.set_xticklabels(ax.get_xticks(), fontsize=4)
        ax.tick_params(length=0.5)

    mpl.rcParams['figure.dpi'] = 300
    font = {'family': 'helvetica',
            'weight': 'normal',
            'size': 8}
    mpl.rc('font', **font)

    # Select the bins of this contig inside [start, end], sorted on both axes.
    in_region = [(c, p) for c, p in counts if c == contig and p >= start and p <= end]
    qf = counts.loc[:, in_region].sort_index(axis=1).sort_index(axis=0)
    # The first sigma is tiny, so smoothing is effectively along the bin axis
    # only. Labels are reattached unchanged, so the frame stays sorted.
    qf = pd.DataFrame(gaussian_filter(qf, sigma=(0.00001, sigma)), index=qf.index, columns=qf.columns)

    cm = sns.clustermap(qf,
                        row_cluster=False,
                        col_cluster=False,
                        vmax=np.percentile(qf, 99.5),
                        dendrogram_ratio=0.1,
                        figsize=(8, 4), cmap='Greys', cbar_kws={"shrink": .1},
                        cbar_pos=(0.0, 0.5, 0.01, 0.16),)

    # Reuse the (empty) column dendrogram axis for the per-bin mean.
    ax = cm.ax_col_dendrogram
    qf.mean().plot.bar(ax=ax, color='k', width=1)
    ax.set_yticks([])
    cm.ax_heatmap.set_xticks([])
    cm.ax_heatmap.set_yticks([])
    cm.ax_heatmap.set_ylabel(f'{qf.shape[0]} single cells', fontsize=8)
    cm.ax_heatmap.tick_params(length=0.5)
    cm.ax_heatmap.set_xlabel(None)
    ax.grid()
    cm.cax.set_ylabel(caxlabel, fontsize=6)
    cm.cax.tick_params(labelsize=4)

    fig = plt.gcf()
    # Bbox.bounds is (x0, y0, width, height); place the coordinate axis
    # directly under the heatmap with the same width.
    heatmap_x0, heatmap_y0, heatmap_width, _heatmap_height = cm.ax_heatmap.get_position().bounds
    height = 0.2 if features is not None else 0.05
    ax = fig.add_axes((heatmap_x0, heatmap_y0 - height - 0.02, heatmap_width, height))
    ax.ticklabel_format(axis='x', style='sci')
    sns.despine(fig=fig, ax=ax)

    if features is not None:
        create_gene_models(start, end, ax=ax)
    else:
        ax.set_xlim(start, end)
        ax.set_yticks([])
        ax.set_xlabel(f'chr{contig} location bp', fontsize=6)
        plt.xticks(fontsize=4)
        ax.tick_params(length=0.5)

    plt.savefig(target)
    plt.close()
def parse_region_spec(spec):
    """Parse one ``-regions`` item into ``((contig, start, end), bin_size)``.

    Accepted forms are ``contig:bin_size`` (whole contig, ``start`` and
    ``end`` are None) and ``contig:start-end:bin_size``. The spec is split on
    ``:`` only, so contig names may contain ``-``.

    Raises:
        ValueError: If the spec does not have two or three ``:``-separated
            fields, or a numeric field is not an integer.
    """
    parts = spec.split(':')
    if len(parts) == 2:
        contig, bin_text = parts
        start, end = None, None
    elif len(parts) == 3:
        contig, span, bin_text = parts
        start_text, end_text = span.split('-')
        start, end = int(start_text), int(end_text)
    else:
        raise ValueError(
            f'Cannot parse region {spec!r}, expected contig:bin_size or contig:start-end:bin_size')
    return (contig, start, end), int(bin_text)


if __name__ == '__main__':
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Plot a genomic region')
    argparser.add_argument('bams', type=str, nargs='+', help='(X) Training bam files')
    # Required: every later step iterates over the parsed regions.
    argparser.add_argument('-regions', type=str, required=True, help='Regions to plot, with a bin size behind it, for example: 1:1000-100000:1000 , will be a single region plotted with a 1000bp bin size split regions by commas without a space')
    argparser.add_argument('-features', type=str, help='Gene models to plot (.gtf file or .gtf.gz)', required=False)
    argparser.add_argument('-norm', type=str, choices=('total-molecules', 'spike'), help='Normalize to, select from : total-molecules,spike', default='total-molecules')
    argparser.add_argument('-prefix', type=str, help='Prefix for output file',default='')
    argparser.add_argument('-format', type=str, help='png or svg',default='png')
    args = argparser.parse_args()

    regions = []
    contigs = set()
    for spec in args.regions.split(','):
        (contig, start, end), bin_size = parse_region_spec(spec)
        if start is not None:
            print(f'Region: {contig} from {start} to {end} with bin size : {bin_size}')
        else:
            print(f'Region: {contig} with bin size : {bin_size}')
        contigs.add(contig)
        regions.append(((contig, start, end), bin_size))
    contigs = list(contigs)

    bams = args.bams

    if args.features is not None:
        print('Reading features')
        features = FeatureContainer()
        if len(contigs) == 1:
            # Only one contig is plotted, so only its features are loaded.
            print(f'Reading only features from {contigs[0]}')
            features.loadGTF(args.features, store_all=True, contig=contigs[0])
        else:
            features.loadGTF(args.features, store_all=True)
    else:
        features = None

    print('Counting')
    # Obtain counts per cell, used as the normalisation denominator.
    # NOTE: the -norm option is honoured here; it was previously overridden
    # by a hard-coded 'spike'.
    norm = args.norm
    if norm == 'spike':
        normalize_to_counts = get_binned_counts(bams, bin_size=10_000_000, regions=['J02459.1'])
    else:
        normalize_to_counts = get_binned_counts(bams, bin_size=10_000_000)
    totals_per_sample = normalize_to_counts.sum()

    for region, region_bin_size in regions:
        print(f'Plotting {region}')
        contig, start, end = region
        region_counts = get_binned_counts(bams, region_bin_size, regions=[region])
        counts = (region_counts / totals_per_sample).fillna(0).T.sort_index(axis=1).sort_index(axis=0)

        if counts.shape[1] == 0:
            # Nothing was counted in this region, so there is nothing to plot.
            print(f'No counts found for {region}, skipping')
            continue

        # Fill non-initialised bins between the first and last bin with zeros.
        present = set(counts.columns)
        first_bin = int(counts.columns[0][1])
        last_bin = int(counts.columns[-1][1])
        missing = [
            (contig, bin_start)
            for bin_start in range(first_bin, last_bin, region_bin_size)
            if (contig, bin_start) not in present
        ]
        for column in missing:
            counts[column] = 0
        counts = counts.sort_index(axis=1)

        target = args.prefix + f'{contig}_{start}-{end}_{region_bin_size}.{args.format}'
        plot_region(counts, features, contig, start, end, sigma=2, target=target, caxlabel='Molecules per spike-in' if norm =='spike' else 'Molecules / total molecules')