Switch to unified view

a b/singlecellmultiomics/bamProcessing/bamToRNACounts.py
1
#!/usr/bin/env python3
2
# -*- coding: utf-8 -*-
3
import scanpy as sc
4
import matplotlib.pyplot as plt
5
import os
6
import sys
7
import pysam
8
import collections
9
import argparse
10
import gzip
11
import pickle
12
import matplotlib
13
import numpy as np
14
import singlecellmultiomics
15
import singlecellmultiomics.molecule
16
import singlecellmultiomics.fragment
17
import singlecellmultiomics.features
18
import pysamiterators.iterators
19
import pysam
20
import pandas as pd
21
import scipy.sparse
22
import gzip
23
from singlecellmultiomics.molecule import MoleculeIterator
24
from singlecellmultiomics.alleleTools import alleleTools
25
import multiprocessing
26
from singlecellmultiomics.bamProcessing.bamFunctions import sort_and_index
27
28
# Select matplotlib's non-interactive Agg backend (renders to files, needs no
# display).
matplotlib.use('Agg')
29
# Default resolution, in dots per inch, for figures created by matplotlib.
matplotlib.rcParams['figure.dpi'] = 160
30
31
32
def get_gene_id_to_gene_name_conversion_table(annotation_path_exons,
                                              featureTypes=('gene_name',)):
    """Create a dictionary converting a gene id to other gene features,
    such as gene_name/gene_biotype etc.

    Blank lines and lines starting with '#' are skipped. When the same
    gene_id occurs on several lines, the last line that carries at least one
    of the requested features wins.

    Arguments:
        annotation_path_exons(str) : path to GTF file (can be gzipped, which
            is detected from a '.gz' suffix)
        featureTypes(iterable) : attribute keys to convert to, for example
            ['gene_name','gene_biotype']

    Returns:
        conversion_dict(dict) : { gene_id : 'firstFeature_secondFeature'}
            A requested feature that is absent for a gene is written as the
            string 'None'. Genes carrying none of the requested features are
            left out of the table.
    """
    # Materialise once so a generator argument can be iterated per line.
    feature_types = tuple(featureTypes)
    conversion_table = {}

    opener = gzip.open if annotation_path_exons.endswith('.gz') else open
    with opener(annotation_path_exons, 'rt') as handle:
        for line in handle:
            line = line.strip()
            # Blank lines would otherwise raise IndexError on parts[-1];
            # '#' lines are GTF headers/comments and carry no attributes.
            if not line or line.startswith('#'):
                continue

            # The attribute column is the last (9th) whitespace-separated
            # field; everything after the 8th split stays in one piece.
            parts = line.split(None, 8)

            key_values = {}
            for part in parts[-1].split(';'):
                # Split on the first whitespace only so quoted values that
                # contain spaces (gene_name "My Gene") are kept.
                kv = part.strip().split(None, 1)
                if len(kv) == 2:
                    key_values[kv[0]] = kv[1].strip().replace('"', '')

            # Only register the gene when at least one requested feature is
            # actually present on this line.
            if 'gene_id' in key_values and any(
                    feat in key_values for feat in feature_types):
                conversion_table[key_values['gene_id']] = '_'.join(
                    key_values.get(feature, 'None')
                    for feature in feature_types)

    return conversion_table
63
64
65
def _danio_contig_mapping():
    """Return the contig name table selected by `-contigmapping danio`.

    Chromosome numbers '1'..'25' map to consecutive accession-style names
    'CM002885.2'..'CM002909.2'.
    """
    return {str(n): f'CM{2884 + n:06d}.2' for n in range(1, 26)}


def _add_counts(counts_per_cell, sample, names, allele, weight, seen_names):
    """Add `weight` to counts_per_cell[sample] for every name in `names`.

    When `allele` is not None the key becomes '<allele>_<name>'. Every key
    that was counted is also recorded in `seen_names`.
    """
    for name in names:
        key = name if allele is None else f'{allele}_{name}'
        counts_per_cell[sample][key] += weight
        seen_names.add(key)


def count_transcripts(cargs):
    """Count molecules per cell for a single contig.

    Intended to be run as a worker job: all input comes in as one picklable
    tuple and all results are returned as one tuple.

    Arguments:
        cargs(tuple) : (args, contig). `args` is the parsed command line
            namespace (attributes read here: alleles, loadAllelesToMem,
            contigmapping, gtfexon, gtfintron, hf, method, producebam, o,
            alignmentfiles, ref, minmq, umi_hamming_distance, annotmethod,
            head); `contig` is the name of the contig to process.

    Returns:
        tuple : (gene_set, sample_set, gene_counts_per_cell,
            junction_counts_per_cell, exon_counts_per_cell,
            intron_counts_per_cell, annotated_molecules, read_molecules,
            contig). The four *_counts_per_cell objects are
            defaultdict(Counter) structured as cell->feature->count.
            gene_set holds every gene/intron/exon/junction key counted and
            sample_set every sample that received a gene count.

    Raises:
        ValueError : when args.method is not one of 'nla', 'vasa', 'cs'.
    """
    args, contig = cargs

    # What is used for assignment of molecules? Validated first so an
    # invalid method fails before the GTF files are loaded.
    pooling_method = 1
    if args.method == 'nla':
        molecule_class = singlecellmultiomics.molecule.AnnotatedNLAIIIMolecule
        fragment_class = singlecellmultiomics.fragment.NlaIIIFragment
        stranded = None  # data is not stranded
    elif args.method in ('vasa', 'cs'):
        molecule_class = singlecellmultiomics.molecule.VASA
        fragment_class = singlecellmultiomics.fragment.SingleEndTranscriptFragment
        stranded = 1  # data is stranded, mapping to other strand
    else:
        raise ValueError("Supply a valid method")

    allele_resolver = None
    if args.alleles is not None:
        allele_resolver = alleleTools.AlleleResolver(
            args.alleles, lazyLoad=(not args.loadAllelesToMem))

    contig_mapping = None
    if args.contigmapping == 'danio':
        contig_mapping = _danio_contig_mapping()

    # Load features. The selected mapping is applied here rather than being
    # reset to None, so -contigmapping takes effect.
    # NOTE(review): how FeatureContainer.remapKeys interacts with the
    # `contig=` filter of loadGTF is not visible here - confirm on danio data.
    features = singlecellmultiomics.features.FeatureContainer()
    if contig_mapping is not None:
        features.remapKeys = contig_mapping
    features.loadGTF(
        args.gtfexon,
        select_feature_type=['exon'],
        identifierFields=('exon_id', 'transcript_id'),
        store_all=True,
        head=args.hf,
        contig=contig)
    features.loadGTF(
        args.gtfintron,
        select_feature_type=['intron'],
        identifierFields=['transcript_id'],
        store_all=True,
        head=args.hf,
        contig=contig)

    # COUNT: all four tables are cell->feature->umiCount
    exon_counts_per_cell = collections.defaultdict(collections.Counter)
    intron_counts_per_cell = collections.defaultdict(collections.Counter)
    junction_counts_per_cell = collections.defaultdict(collections.Counter)
    gene_counts_per_cell = collections.defaultdict(collections.Counter)

    gene_set = set()
    sample_set = set()
    annotated_molecules = 0
    read_molecules = 0

    if args.producebam:
        bam_path_produced = f'{args.o}/output_bam_{contig}.unsorted.bam'
        # The first input file is opened only to copy its header.
        with pysam.AlignmentFile(args.alignmentfiles[0]) as alignments:
            output_bam = pysam.AlignmentFile(
                bam_path_produced, "wb", header=alignments.header)

    ref = None
    if args.ref is not None:
        ref = pysamiterators.iterators.CachedFasta(pysam.FastaFile(args.ref))

    for alignmentfile_path in args.alignmentfiles:

        # Number of molecules the iterator yielded for this file; stays 0
        # when the file has no molecules on this contig.
        molecules_in_file = 0
        with pysam.AlignmentFile(alignmentfile_path) as alignments:
            molecule_iterator = MoleculeIterator(
                alignments=alignments,
                check_eject_every=5000,
                molecule_class=molecule_class,
                molecule_class_args={
                    'features': features,
                    'stranded': stranded,
                    'min_max_mapping_quality': args.minmq,
                    'reference': ref,
                    'allele_resolver': allele_resolver
                },
                fragment_class=fragment_class,
                fragment_class_args={
                    'umi_hamming_distance': args.umi_hamming_distance,
                    'features': features
                },
                # when the reads have not been tagged yet, this flag is very
                # much required
                perform_qflag=True,
                pooling_method=pooling_method,
                contig=contig
            )

            for i, molecule in enumerate(molecule_iterator):
                molecules_in_file = i + 1

                if not molecule.is_valid():
                    # Invalid molecules are still written to the output BAM,
                    # but they are neither annotated nor counted.
                    if args.producebam:
                        molecule.write_tags()
                        molecule.write_pysam(output_bam)
                    continue

                molecule.annotate(args.annotmethod)
                molecule.set_intron_exon_features()

                if args.producebam:
                    molecule.write_tags()
                    molecule.write_pysam(output_bam)

                allele = None
                if allele_resolver is not None:
                    allele = molecule.allele
                    if allele is None:
                        allele = 'noAllele'

                total_count_for_molecule = len(molecule.genes)
                if total_count_for_molecule == 0:
                    continue  # we didn't find any gene counts

                # Distribute the count over the amount of gene hits so the
                # gene counts of one molecule sum to 1. The same weight is
                # used for its introns, exons and junctions.
                count_to_add = 1 / total_count_for_molecule
                sample = molecule.sample

                _add_counts(gene_counts_per_cell, sample, molecule.genes,
                            allele, count_to_add, gene_set)
                # At least one gene was counted at this point, so the sample
                # is registered exactly when it received a gene count.
                sample_set.add(molecule.get_sample())

                # Introns/exons/splice junction information:
                _add_counts(intron_counts_per_cell, sample, molecule.introns,
                            allele, count_to_add, gene_set)
                _add_counts(exon_counts_per_cell, sample, molecule.exons,
                            allele, count_to_add, gene_set)
                _add_counts(junction_counts_per_cell, sample,
                            molecule.junctions, allele, count_to_add, gene_set)

                annotated_molecules += 1
                if args.head and (i + 1) > args.head:
                    print(
                        f"-head was supplied, {i} molecules discovered, stopping")
                    break

        # Add the number of molecules seen, not the last enumerate index
        # (which is one too low for every file that yielded molecules).
        read_molecules += molecules_in_file

    if args.producebam:
        output_bam.close()
        final_bam_path = bam_path_produced.replace('.unsorted', '')
        sort_and_index(bam_path_produced, final_bam_path, remove_unsorted=True)

    return (
        gene_set,
        sample_set,
        gene_counts_per_cell,
        junction_counts_per_cell,
        exon_counts_per_cell,
        intron_counts_per_cell,
        annotated_molecules,
        read_molecules,
        contig
    )
278
279
280
# Contig name prefixes that are never counted.
_SKIPPED_CONTIG_PREFIXES = ('ERCC', 'chrUn', 'GL', 'JH', 'KN', 'KZ')
# Contig names treated as mitochondrial when --ignoreMT is supplied.
_MITOCHONDRIAL_CONTIGS = frozenset(('mt', 'MT', 'chrMT', 'chrM'))


def _is_skipped_contig(chrom):
    """Return True for contigs that are excluded from counting.

    Excluded are names starting with one of _SKIPPED_CONTIG_PREFIXES, names
    ending in '_random' and names containing 'ERCC'.
    """
    return (chrom.startswith(_SKIPPED_CONTIG_PREFIXES)
            or chrom.endswith('_random')
            or 'ERCC' in chrom)


def _counts_to_sparse(counts_per_cell, sample_order, gene_index):
    """Convert a cell->feature->count mapping into a CSC matrix.

    Rows follow `sample_order`; columns follow `gene_index`, a dict mapping
    feature name to column index (a dict lookup replaces a linear
    list.index() search per entry).

    NOTE(review): the counts can be fractional (a molecule is distributed
    over its gene hits) while the matrix dtype is int64, so values are
    presumably truncated on assignment - confirm whether a float dtype is
    intended.
    """
    matrix = scipy.sparse.dok_matrix(
        (len(sample_order), len(gene_index)), dtype=np.int64)
    for sample_idx, sample in enumerate(sample_order):
        # .get() avoids creating empty entries in a defaultdict
        for gene, counts in counts_per_cell.get(sample, {}).items():
            matrix[sample_idx, gene_index[gene]] = counts
    return matrix.tocsc()


if __name__ == '__main__':

    argparser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description='Create count tables from BAM file.')
    argparser.add_argument(
        '-o',
        type=str,
        help="output data folder",
        default='./rna_counts/')
    argparser.add_argument('alignmentfiles', type=str, nargs='+')
    argparser.add_argument(
        '-gtfexon',
        type=str,
        required=True,
        help="exon GTF file containing the features to plot")
    argparser.add_argument(
        '-gtfintron',
        type=str,
        required=True,
        help="intron GTF file containing the features to plot")
    argparser.add_argument('-umi_hamming_distance', type=int, default=1)
    argparser.add_argument(
        '-contigmapping',
        type=str,
        help="Use this when the GTF chromosome names do not match the ones in you bam file")
    argparser.add_argument(
        '-minmq',
        type=int,
        help="Minimum molcule mapping quality",
        default=20)
    argparser.add_argument(
        '-method',
        type=str,
        help="Data type: vasa,nla,cs",
        required=True)
    argparser.add_argument(
        '-head',
        type=int,
        help="Process this amount of molecules and export tables, also set -hf to be really fast")
    argparser.add_argument(
        '-hf',
        type=int,
        help="headfeatures Process this amount features and then continue, for a quick test set this to 1000 or so.")
    argparser.add_argument('-ref', type=str, help="Reference file (FASTA)")
    argparser.add_argument('-alleles', type=str, help="Allele file (VCF)")
    argparser.add_argument(
        '--loadAllelesToMem',
        action='store_true',
        help='Load allele data completely into memory')
    argparser.add_argument(
        '--producebam',
        action='store_true',
        help='Produce bam file with counts tagged')
    argparser.add_argument(
        '--ignoreMT',
        action='store_true',
        help='Ignore mitochondria')
    argparser.add_argument(
        '-t',
        type=int,
        default=8,
        help="Amount of chromosomes processed in parallel")
    argparser.add_argument(
        '-annotmethod',
        type=int,
        default=1,
        help="Annotation resolving method. 0: molecule consensus aligned blocks. 1: per read per aligned base")

    args = argparser.parse_args()

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(args.o, exist_ok=True)

    # One job per contig of the first alignment file.
    jobs = []
    contigs_todo = []
    with pysam.AlignmentFile(args.alignmentfiles[0]) as g:
        # sort by size and do big ones first.. this will be faster
        for _, chrom in sorted(zip(g.lengths, g.references), reverse=True):
            if _is_skipped_contig(chrom):
                continue
            if args.ignoreMT and chrom in _MITOCHONDRIAL_CONTIGS:
                print("Ignoring mitochondria")
                continue
            jobs.append((args, chrom))
            contigs_todo.append(chrom)

    # Running totals over all finished contigs; cell->feature->umiCount
    gene_counts_per_cell = collections.defaultdict(collections.Counter)
    exon_counts_per_cell = collections.defaultdict(collections.Counter)
    intron_counts_per_cell = collections.defaultdict(collections.Counter)
    junction_counts_per_cell = collections.defaultdict(collections.Counter)
    gene_set = set()
    sample_set = set()
    read_molecules = 0
    annotated_molecules = 0

    # Results of the most recent table export; they stay None when no contig
    # was processed or when the scanpy export failed.
    sparse_gene_matrix = None
    sparse_intron_matrix = None
    sparse_exon_matrix = None
    sparse_junction_matrix = None
    sample_order = []
    gene_order = []
    adata = None

    # The context manager shuts the pool down once all results are consumed.
    with multiprocessing.Pool(args.t) as workers:
        for (
            result_gene_set,
            result_sample_set,
            result_gene_counts_per_cell,
            result_junction_counts_per_cell,
            result_exon_counts_per_cell,
            result_intron_counts_per_cell,
            result_annotated_molecules,
            result_read_molecules,
            result_contig
        ) in workers.imap_unordered(count_transcripts, jobs):
            # Update all (Counter.update adds the counts):
            gene_set.update(result_gene_set)
            sample_set.update(result_sample_set)

            for cell, counts in result_gene_counts_per_cell.items():
                gene_counts_per_cell[cell].update(counts)
            for cell, counts in result_junction_counts_per_cell.items():
                junction_counts_per_cell[cell].update(counts)
            for cell, counts in result_exon_counts_per_cell.items():
                exon_counts_per_cell[cell].update(counts)
            for cell, counts in result_intron_counts_per_cell.items():
                intron_counts_per_cell[cell].update(counts)
            read_molecules += result_read_molecules
            annotated_molecules += result_annotated_molecules

            # Now we finished counting
            contigs_todo = [x for x in contigs_todo if x != result_contig]
            print(
                f'Finished {result_contig}, so far found {read_molecules} molecules, annotated {annotated_molecules}, {len(sample_set)} samples')
            print(f"Remaining contigs:{','.join(contigs_todo)}")

            # The tables are rewritten after every finished contig, so
            # intermediate results are available on disk.
            print('writing current matrices')

            # freeze order of samples and genes:
            sample_order = sorted(sample_set)
            gene_order = sorted(gene_set)
            gene_index = {gene: idx for idx, gene in enumerate(gene_order)}

            # Construct the sparse matrices:
            sparse_gene_matrix = _counts_to_sparse(
                gene_counts_per_cell, sample_order, gene_index)
            sparse_intron_matrix = _counts_to_sparse(
                intron_counts_per_cell, sample_order, gene_index)
            sparse_exon_matrix = _counts_to_sparse(
                exon_counts_per_cell, sample_order, gene_index)
            sparse_junction_matrix = _counts_to_sparse(
                junction_counts_per_cell, sample_order, gene_index)

            # Write matrices to disk
            scipy.sparse.save_npz(
                f'{args.o}/sparse_gene_matrix.npz',
                sparse_gene_matrix)
            scipy.sparse.save_npz(
                f'{args.o}/sparse_intron_matrix.npz',
                sparse_intron_matrix)
            scipy.sparse.save_npz(
                f'{args.o}/sparse_exon_matrix.npz',
                sparse_exon_matrix)
            scipy.sparse.save_npz(
                f'{args.o}/sparse_junction_matrix.npz',
                sparse_junction_matrix)

            # Forget the object of the previous round so a failed export
            # can never leave a stale table behind for the final write.
            adata = None
            try:
                # Write scanpy file vanilla
                adata = sc.AnnData(sparse_gene_matrix.toarray())
                adata.var_names = gene_order
                adata.obs_names = sample_order
                adata.write(f'{args.o}/scanpy_vanilla.h5ad')

                # Write scanpy file, with introns
                adata = sc.AnnData(
                    sparse_gene_matrix,
                    layers={
                        'spliced': sparse_junction_matrix,
                        'unspliced': sparse_intron_matrix,
                        'exon': sparse_exon_matrix
                    }
                )
                adata.var_names = gene_order
                adata.obs_names = sample_order
                adata.write(f'{args.o}/scanpy_complete.h5ad')
            except Exception as e:
                # Deliberately best effort: the npz/csv tables do not depend
                # on the scanpy export.
                print("Could not (yet?) write the scanpy files, error below")
                print(e)

    if sparse_gene_matrix is None:
        # No job ran, so there are no tables to write.
        sys.exit("No contigs were processed, no count tables were written")

    print("Writing final tables to dense csv files")
    for csv_name, matrix in (
            ('genes.csv.gz', sparse_gene_matrix),
            ('introns.csv.gz', sparse_intron_matrix),
            ('exons.csv.gz', sparse_exon_matrix),
            ('junctions.csv.gz', sparse_junction_matrix)):
        pd.DataFrame(matrix.toarray(), columns=gene_order,
                     index=sample_order).to_csv(f'{args.o}/{csv_name}')

    # Write as plaintext:
    if adata is not None:
        adata.to_df().to_csv(f'{args.o}/counts.csv')
    else:
        # The scanpy export failed for the final tables; fall back to the
        # gene matrix so counts.csv is still produced.
        print("No scanpy object available, writing counts.csv from the gene matrix")
        pd.DataFrame(sparse_gene_matrix.toarray(), columns=gene_order,
                     index=sample_order).to_csv(f'{args.o}/counts.csv')