Switch to unified view

a b/singlecellmultiomics/bamProcessing/plotRegion.py
1
#!/usr/bin/env python
2
3
4
from multiprocessing import Pool
5
from singlecellmultiomics.bamProcessing.bamFunctions import mate_iter
6
import argparse
7
import pysam
8
from glob import glob
9
import pandas as pd
10
from singlecellmultiomics.bamProcessing import get_contig_sizes
11
from collections import Counter, defaultdict
12
from singlecellmultiomics.features import FeatureContainer
13
import os
14
from matplotlib.patches import Rectangle
15
import matplotlib as mpl
16
from scipy.ndimage import gaussian_filter
17
import seaborn as sns
18
import numpy as np
19
import matplotlib.pyplot as plt
20
from itertools import product
21
from singlecellmultiomics.bamProcessing import get_contigs_with_reads
22
23
24
25
26
27
def _generate_count_dict(args):
    """Count reads per genomic bin and per sample for one BAM file / region.

    The work description is packed into a single tuple so the function can be
    mapped over a list of jobs.

    Args:
        args: tuple ``(bam_path, bin_size, contig, start, stop)``. ``contig``,
            ``start`` and ``stop`` are forwarded to ``mate_iter`` unchanged.

    Returns:
        tuple ``(cut_counts, contig, bam_path)`` where ``cut_counts`` maps
        ``(contig, bin_start)`` to a ``Counter`` of ``SM`` tag value -> number
        of counted read pairs.
    """
    bam_path, bin_size, contig, start, stop = args

    per_bin_counts = defaultdict(Counter)

    with pysam.AlignmentFile(bam_path) as bam_handle:
        for read1, _read2 in mate_iter(bam_handle, contig=contig, start=start, stop=stop):
            # Only the first mate decides whether the pair is counted.
            if read1 is None:
                continue
            if read1.is_duplicate or not read1.has_tag('DS') or read1.is_qcfail:
                continue

            # The DS tag value is used as the cut position, SM as the sample.
            cut_site = read1.get_tag('DS')
            sample_name = read1.get_tag('SM')

            # Snap the cut position down to the start coordinate of its bin.
            bin_start = int(cut_site / bin_size) * bin_size
            per_bin_counts[(contig, bin_start)][sample_name] += 1

    return per_bin_counts, contig, bam_path
51
52
53
def get_binned_counts(bams, bin_size, regions=None):
    """Collect binned per-sample counts from several BAM files in parallel.

    Every (region, bam) combination becomes one job for
    ``_generate_count_dict``; the per-job results are summed into a single
    table.

    Args:
        bams: sequence of BAM file paths.
        bin_size: bin width passed on to ``_generate_count_dict``.
        regions: optional iterable of regions. Each entry is either a contig
            name (``str``, meaning the whole contig) or a
            ``(contig, start, end)`` tuple. When ``None``, every contig
            reported by ``get_contig_sizes`` for the first BAM file is used.
            The iterable is not modified.

    Returns:
        pandas.DataFrame built from the merged counts and transposed, i.e.
        one row per ``(contig, bin_start)`` key and one column per sample.
    """
    # Integer region starts are moved upstream by this many bp (clamped at 0).
    # NOTE(review): presumably so that reads starting just before the region
    # are still fetched - confirm. The end coordinate is not padded.
    flank_size = 1000

    if regions is None:
        normalized = [(contig, None, None) for contig in get_contig_sizes(bams[0]).keys()]
    else:
        # Build a fresh list instead of overwriting the caller's entries.
        normalized = []
        for region in regions:
            if isinstance(region, str):
                normalized.append((region, None, None))
                continue
            contig, start, end = region
            if isinstance(start, int):
                start = max(0, start - flank_size)
            normalized.append((contig, start, end))

    jobs = [(bam_path, bin_size, *region) for region, bam_path in product(normalized, bams)]

    cut_counts = defaultdict(Counter)
    with Pool() as workers:
        for finished, (job_counts, _contig, _bam_path) in enumerate(workers.imap(_generate_count_dict, jobs)):
            for key, sample_counts in job_counts.items():
                cut_counts[key] += sample_counts

            # Progress indicator, rewritten in place on a single terminal line.
            print(finished, '/', len(jobs), end='\r')

    return pd.DataFrame(cut_counts).T
84
85
86
87
def plot_region(counts, features, contig, start, end, sigma=2, target=None, caxlabel='Molecules per spike-in'):
    """Plot a smoothed heatmap of binned counts for one region and save it.

    The figure consists of a heatmap (one row per row of ``counts``, labelled
    as single cells), a bar plot of the per-bin mean above it, and a
    coordinate axis below it which optionally carries gene models.

    Args:
        counts: pandas.DataFrame whose column labels are ``(contig, position)``
            pairs; only columns on ``contig`` with ``start <= position <= end``
            are plotted.
        features: object exposing ``features.features[contig]`` as an iterable
            of ``(start, end, name, strand, metadata)`` records, or ``None``
            to skip gene models. Only records whose metadata has
            ``type == 'gene'`` are drawn.
        contig: contig to plot.
        start: first coordinate to plot; ``None`` means the smallest bin
            position present in ``counts`` for ``contig``.
        end: last coordinate to plot; ``None`` means the largest bin position
            present in ``counts`` for ``contig``.
        sigma: standard deviation of the Gaussian smoothing applied along the
            bin (column) axis.
        target: output path handed to ``plt.savefig``; defaults to
            ``'{contig}_{start}_{end}.png'``.
        caxlabel: label of the colour bar.
    """
    if target is None:
        target = f'{contig}_{start}_{end}.png'

    # Whole-contig requests arrive with open bounds; derive them from the data
    # so the comparisons and axis limits below receive real numbers.
    if start is None or end is None:
        contig_positions = [p for c, p in counts if c == contig]
        if start is None:
            start = min(contig_positions)
        if end is None:
            end = max(contig_positions)

    def create_gene_models(start, end, ax):
        """Draw gene bars with name labels for features overlapping the window."""
        exon_height = 0.010  # only used to place the gene label below its bar
        gene_height = 0.0002
        spacer = 0.035  # vertical distance between stacked gene tracks
        overlap_dist = 200_000  # genes closer than this go on separate tracks

        gene_y = {}  # feature name -> (start, end, track index)
        ymax = 0
        for fs, fe, name, strand, feature_meta in features.features[contig]:

            if not ((fs >= start or fe >= start) and (fs <= end or fe <= end)):
                continue
            feature_meta = dict(feature_meta)

            if feature_meta.get('type') != 'gene':
                continue

            gene_name = feature_meta.get('gene_name')
            if gene_name is None or gene_name.startswith('AC'):
                continue

            # Pick the lowest track not used by a nearby, already placed gene.
            occupied_tracks = set()
            for s, e, track in gene_y.values():
                if (s + overlap_dist >= fs and s - overlap_dist <= fe) or (e + overlap_dist >= fs and e - overlap_dist <= fe):
                    occupied_tracks.add(track)

            gy = 0
            while gy in occupied_tracks:
                gy += 1

            gene_y[name] = (fs, fe, gy)
            y_offset = gy * spacer
            ymax = max(y_offset + gene_height, ymax)

            ax.add_patch(Rectangle((fs, -gene_height * 0.5 + y_offset), fe - fs, gene_height, angle=0.0, color='k'))
            ax.text((fe + fs) * 0.5, -1.6 * exon_height + y_offset, gene_name,
                    horizontalalignment='center', verticalalignment='center', fontsize=3)

        ax.set_xlim(start, end)
        ax.set_ylim(-0.1, ymax)
        ax.set_yticks([])
        ax.set_xlabel(f'chr{contig} location bp', fontsize=6)
        ax.set_xticklabels(ax.get_xticks(), fontsize=4)
        ax.tick_params(length=0.5)

    # These settings are global matplotlib state and persist after the call.
    mpl.rcParams['figure.dpi'] = 300
    mpl.rc('font', family='helvetica', weight='normal', size=8)

    selected = [(c, p) for c, p in counts if c == contig and p >= start and p <= end]
    qf = counts.loc[:, selected].sort_index(axis=1).sort_index(axis=0)
    # The tiny sigma on the row axis leaves rows effectively unsmoothed; only
    # the bin axis is blurred with the requested sigma.
    qf = pd.DataFrame(gaussian_filter(qf, sigma=(0.00001, sigma)), index=qf.index, columns=qf.columns)

    cm = sns.clustermap(qf,
                        row_cluster=False,
                        col_cluster=False,
                        vmax=np.percentile(qf, 99.5),
                        dendrogram_ratio=0.1,
                        figsize=(8, 4), cmap='Greys', cbar_kws={"shrink": .1},
                        cbar_pos=(0.0, 0.5, 0.01, 0.16),)

    # Clustering is disabled, so the column dendrogram axis is free to hold
    # the per-bin mean profile.
    ax = cm.ax_col_dendrogram
    qf.mean().plot.bar(ax=ax, color='k', width=1)
    ax.set_yticks([])

    cm.ax_heatmap.set_xticks([])
    cm.ax_heatmap.set_yticks([])
    cm.ax_heatmap.set_ylabel(f'{qf.shape[0]} single cells', fontsize=8)
    cm.ax_heatmap.tick_params(length=0.5)
    cm.ax_heatmap.set_xlabel(None)

    ax.grid()
    cm.cax.set_ylabel(caxlabel, fontsize=6)
    cm.cax.tick_params(labelsize=4)

    fig = plt.gcf()
    # ``bounds`` is (x0, y0, width, height) in figure coordinates.
    heatmap_x0, heatmap_y0, heatmap_width, _heatmap_height = cm.ax_heatmap.get_position().bounds

    height = 0.2 if features is not None else 0.05
    # Coordinate axis directly below the heatmap, sharing its horizontal extent.
    ax = fig.add_axes((heatmap_x0, heatmap_y0 - height - 0.02, heatmap_width, height))
    ax.ticklabel_format(axis='x', style='sci')

    sns.despine(fig=fig, ax=ax)
    if features is not None:
        create_gene_models(start, end, ax=ax)
    else:
        ax.set_xlim(start, end)
        ax.set_yticks([])
        ax.set_xlabel(f'chr{contig} location bp', fontsize=6)
        plt.xticks(fontsize=4)
        ax.tick_params(length=0.5)

    plt.savefig(target)
    plt.close()
257
258
259
def parse_region_argument(region):
    """Parse one ``-regions`` entry into ``((contig, start, end), bin_size)``.

    Accepted forms are ``contig:start-end:bin_size`` and ``contig:bin_size``;
    in the second form ``start`` and ``end`` are ``None``.

    Raises:
        ValueError: if a coordinate or the bin size is not an integer.
    """
    fields = region.split(':')
    contig = fields[0]
    if '-' in region:
        start_text, end_text = fields[1].split('-')
        start, end = int(start_text), int(end_text)
    else:
        start, end = None, None
    bin_size = int(fields[-1])
    return (contig, start, end), bin_size


if __name__ == '__main__':
    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Plot a genomic region')

    argparser.add_argument('bams', type=str, nargs='+', help='(X) Training bam files')

    # The regions are dereferenced unconditionally below, so require them up
    # front instead of failing later on ``None.split``.
    argparser.add_argument('-regions', type=str, required=True, help='Regions to plot, with a bin size behind it, for example: 1:1000-100000:1000 , will be a single region plotted with a 1000bp bin size split regions by commas without a space')
    argparser.add_argument('-features', type=str, help='Gene models to plot (.gtf file or .gtf.gz)', required=False)
    argparser.add_argument('-norm', type=str, choices=('total-molecules', 'spike'), help='Normalize to, select from : total-molecules,spike', default='total-molecules')
    argparser.add_argument('-prefix', type=str, help='Prefix for output file', default='')
    argparser.add_argument('-format', type=str, help='png or svg', default='png')

    args = argparser.parse_args()

    regions = []
    contigs = set()
    for region_text in args.regions.split(','):
        (contig, start, end), bin_size = parse_region_argument(region_text)
        if start is not None:
            print(f'Region: {contig} from {start} to {end} with bin size : {bin_size}')
        else:
            print(f'Region: {contig} with bin size : {bin_size}')
        contigs.add(contig)
        regions.append(((contig, start, end), bin_size))

    contigs = list(contigs)
    bams = args.bams

    if args.features is not None:
        print('Reading features')
        features = FeatureContainer()
        if len(contigs) == 1:
            # Restrict loading to the single contig that will be plotted.
            print(f'Reading only features from {contigs[0]}')
            features.loadGTF(args.features, store_all=True, contig=contigs[0])
        else:
            features.loadGTF(args.features, store_all=True)
    else:
        features = None
    print('Counting')

    # Obtain the per-cell totals used as the normalisation denominator.
    # The mode comes from the command line; it used to be pinned to 'spike'.
    norm = args.norm
    if norm == 'spike':
        # Counts on the hard-coded spike-in contig.
        normalize_to_counts = get_binned_counts(bams, bin_size=10_000_000, regions=['J02459.1'])
    else:
        # 'total-molecules': counts over every contig of the first BAM file.
        normalize_to_counts = get_binned_counts(bams, bin_size=10_000_000)

    for region, region_bin_size in regions:
        print(f'Plotting {region}')
        contig, start, end = region
        region_counts = get_binned_counts(bams, region_bin_size, regions=[region])
        counts = (region_counts / normalize_to_counts.sum()).fillna(0).T.sort_index(axis=1).sort_index(axis=0)

        # Fill bins without any observation with zeros so that the plotted
        # bins are contiguous between the first and last observed bin.
        missing_bins = [
            (contig, bin_start)
            for bin_start in np.arange(counts.columns[0][1], counts.columns[-1][1], region_bin_size)
            if (contig, bin_start) not in counts.columns
        ]
        for missing_bin in missing_bins:
            counts[missing_bin] = 0

        counts = counts.sort_index(axis=1)

        target = args.prefix + f'{contig}_{start}-{end}_{region_bin_size}.{args.format}'
        plot_region(counts, features, contig, start, end, sigma=2, target=target, caxlabel='Molecules per spike-in' if norm == 'spike' else 'Molecules / total molecules')