--- a
+++ b/exseek/scripts/count_reads.py
@@ -0,0 +1,244 @@
+#! /usr/bin/env python
+import argparse, sys, os, errno
+import logging
+logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')
+
+command_handlers = {}
+def command_handler(f):
+    command_handlers[f.__name__] = f
+    return f
+
+@command_handler
+def count_transcript(args):
+    import pysam
+    import numpy as np
+    from ioutils import open_file_or_stdout
+    from collections import OrderedDict
+
+    logger.info('read input BAM/SAM file: ' + args.input_file)
+    sam = pysam.AlignmentFile(args.input_file, "rb")
+    counts = OrderedDict()
+    min_mapping_quality = args.min_mapping_quality
+    strandness = {'no': 0, 'forward': 1, 'reverse': 2}.get(args.strandness, 0)
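+    # strandness codes: 0 = count both strands, 1 = forward strand only, 2 = reverse strand only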
+    for read in sam:
+        if read.is_unmapped:
+            continue
+        if read.mapping_quality < min_mapping_quality:
+            continue
+        if (strandness == 1) and read.is_reverse:
+            continue
+        if (strandness == 2) and (not read.is_reverse):
+            continue
+        if read.reference_name not in counts:
+            counts[read.reference_name] = 0
+        counts[read.reference_name] += 1
+    
+    with open_file_or_stdout(args.output_file) as f:
+        # write one row per reference listed in the header so the output always contains
+        # the full set of transcripts, including those with zero counts
+        if sam.references:
+            for name in sam.references:
+                f.write('{}\t{}\n'.format(name, counts.get(name, 0)))
+        else:
+            for name, count in counts.items():
+                f.write('{}\t{}\n'.format(name, count))
+
+@command_handler
+def count_circrna(args):
+    import HTSeq
+    import numpy as np
+    import pandas as pd
+    from collections import OrderedDict, defaultdict
+    from ioutils import open_file_or_stdout
+
+    logger.info('read input BAM/SAM file: ' + args.input_file)
+    if args.input_file.endswith('.sam'):
+        sam = HTSeq.SAM_Reader(args.input_file)
+    elif args.input_file.endswith('.bam'):
+        sam = HTSeq.BAM_Reader(args.input_file)
+    else:
+        raise ValueError('unsupported file extension')
+    
+    # extract junction positions from SAM header
+    logger.info('extract junction positions')
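+    # assumption: each reference is a back-spliced junction sequence built from two equal-length
+    # arms, so the junction is taken to be at the midpoint of the reference length (LN // 2)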
+    junction_positions = OrderedDict()
+    for sq in sam.get_header_dict()['SQ']:
+        junction_positions[sq['SN']] = sq['LN']//2
+    # initialize counts
+    gene_ids = list(junction_positions.keys())
+    counts = pd.Series(np.zeros(len(gene_ids), dtype='int'), index=gene_ids)
+    # count reads
+    min_mapping_quality = args.min_mapping_quality
+    strandness = args.strandness
+    if args.paired_end:
+        logger.info('count paired-end fragments')
+        stats = defaultdict(int)
+        for bundle in HTSeq.pair_SAM_alignments(sam, bundle=True):
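+            # each bundle holds the (mate1, mate2) alignment tuples of one read pair;
+            # a single-element bundle means the pair aligned to exactly one location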
+            stats['total_pairs'] += 1
+            # ignore multi-mapped pairs
+            if len(bundle) != 1:
+                stats['multi_mapping'] += 1
+                continue
+            read1, read2 = bundle[0]
+            # ignore singletons
+            if (read1 is None) or (read2 is None):
+                stats['singleton'] += 1
+                continue
+            # ignore unmapped reads
+            if not (read1.aligned and read2.aligned):
+                stats['unmapped'] += 1
+                continue
+            # ignore pairs with mapping quality below threshold
+            if (read1.aQual < min_mapping_quality) or (read2.aQual < min_mapping_quality):
+                stats['low_mapping_quality'] += 1
+                continue
+            if (strandness == 'forward') and (not ((read1.iv.strand == '+') and (read2.iv.strand == '-'))):
+                stats['improper_strand'] += 1
+                continue
+            if (strandness == 'reverse') and (not ((read1.iv.strand == '-') and (read2.iv.strand == '+'))):
+                stats['improper_strand'] += 1
+                continue
+            # ignore pairs on different chromosomes
+            if read1.iv.chrom != read2.iv.chrom:
+                stats['diff_chrom'] += 1
+                continue
+            pos = junction_positions[read1.iv.chrom]
+            # check that the fragment spans the junction using the leftmost start and the
+            # rightmost end, since read1 (mate 1) is not necessarily the upstream read
+            frag_start = min(read1.iv.start, read2.iv.start)
+            frag_end = max(read1.iv.end, read2.iv.end)
+            if frag_start < pos <= frag_end:
+                counts[read1.iv.chrom] += 1
+        for key, val in stats.items():
+            logger.info('{}: {}'.format(key, val))
+    else:
+        logger.info('count single-end reads')
+        for read in sam:
+            # ignore unmapped read
+            if not read.aligned:
+                continue
+            # ignore reads with mapping quality below threshold
+            if read.aQual < min_mapping_quality:
+                continue
+            if (strandness == 'forward') and (read.iv.strand == '-'):
+                continue
+            if (strandness == 'reverse') and (read.iv.strand == '+'):
+                continue
+            pos = junction_positions[read.iv.chrom]
+            if read.iv.start < pos <= read.iv.end:
+                counts[read.iv.chrom] += 1
+    # output counts
+    logger.info('count fragments: {}'.format(counts.sum()))
+    logger.info('write counts to file: ' + args.output_file)
+    with open_file_or_stdout(args.output_file) as fout:
+        counts.to_csv(fout, sep='\t', header=False, index=True, na_rep='NA')
+
+@command_handler
+def count_mature_mirna(args):
+    from collections import OrderedDict, defaultdict
+    from ioutils import open_file_or_stdin, open_file_or_stdout
+    import pysam
+    from utils import read_gff
+
+    logger.info('read input GFF file: ' + args.annotation)
+    fin = open(args.annotation, 'r')
+    # key: precursor_id, value: precursor record
+    precursors = OrderedDict()
+    # key: precursor_id, value: list of mature records
+    matures = defaultdict(list)
+    mature_names = []
+    # read features from GFF file
+    for record in read_gff(fin):
+        if record.feature == 'miRNA_primary_transcript':
+            precursors[record.attr['ID']] = record
+        elif record.feature == 'miRNA':
+            matures[record.attr['Derives_from']].append(record)
+            mature_names.append(record.attr['Name'])
+    fin.close()
+    # get locations of mature miRNAs relative to their precursor (0-based, end-exclusive,
+    # in the precursor's 5'->3' orientation)
+    # key: precursor_name, value: dict of mature_name -> (start, end)
+    mature_locations = defaultdict(dict)
+    for precursor_id, precursor in precursors.items():
+        precursor_name = precursor.attr['Name']
+        for mature in matures[precursor_id]:
+            if mature.strand == '+':
+                mature_locations[precursor_name][mature.attr['Name']] = (
+                    mature.start - precursor.start,
+                    mature.end - precursor.start + 1)
+            else:
+                mature_locations[precursor_name][mature.attr['Name']] = (
+                    precursor.end - mature.end,
+                    precursor.end - mature.start + 1)
+
+    logger.info('read input BAM/SAM file: ' + args.input_file)
+    sam = pysam.AlignmentFile(args.input_file, "rb")
+    counts = defaultdict(int)
+    min_mapping_quality = args.min_mapping_quality
+    for read in sam:
+        if read.is_unmapped:
+            continue
+        if read.mapping_quality < min_mapping_quality:
+            continue
+        if read.is_reverse:
+            continue
+        # find mature miRNA with maximum overlap with the read
+        max_overlap = 0
+        matched_mature_name = None
+        for mature_name, mature_location in mature_locations[read.reference_name].items():
+            # get overlap
+            overlap = (mature_location[1] - mature_location[0]) \
+                + (read.reference_end - read.reference_start) \
+                - (max(read.reference_end, mature_location[1]) - min(read.reference_start, mature_location[0]))
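+            # this equals len(read) + len(mature) minus the span from the leftmost start to the
+            # rightmost end; it is the intersection length when they overlap and negative otherwise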
+            if overlap > max_overlap:
+                max_overlap = overlap
+                matched_mature_name = mature_name
+        if max_overlap <= 0:
+            continue
+        # count the read
+        counts[matched_mature_name] += 1
+    
+    logger.info('open output file: ' + args.output_file)
+    with open_file_or_stdout(args.output_file) as f:
+        for mature_name in mature_names:
+            f.write('{}\t{}\n'.format(mature_name, counts[mature_name]))
+
+if __name__ == '__main__':
+    main_parser = argparse.ArgumentParser(description='Count reads in BAM files')
+    subparsers = main_parser.add_subparsers(dest='command')
+
+    parser = subparsers.add_parser('count_transcript', 
+        help='count reads in a BAM/SAM file mapped to transcript coordinates')
+    parser.add_argument('--input-file', '-i', type=str, required=True, help='input BAM/SAM file')
+    parser.add_argument('--min-mapping-quality', '-q', type=int, default=0,
+        help='only count reads with mapping quality greater than or equal to this number')
+    parser.add_argument('--strandness', '-s', type=str, default='no',
+        choices=('forward', 'reverse', 'no'),
+        help='forward: only count reads on the forward strand. reverse: only count reads on the reverse strand. no: count reads on both strands')
+    parser.add_argument('--output-file', '-o', type=str, default='-',
+        help='output file')
+
+    parser = subparsers.add_parser('count_circrna', 
+        help='count reads/fragments mapped to circRNA junctions')
+    parser.add_argument('--input-file', '-i', type=str, required=True, help='input BAM/SAM file')
+    parser.add_argument('--paired-end', '-p', action='store_true', help='count reads as paired-end')
+    parser.add_argument('--min-mapping-quality', '-q', type=int, default=0,
+        help='only count reads with mapping quality greater than or equal to this number')
+    parser.add_argument('--strandness', '-s', type=str, default='no',
+        choices=('forward', 'reverse', 'no'),
+        help='forward: only count reads on the forward strand. reverse: only count reads on the reverse strand. no: count reads on both strands')
+    parser.add_argument('--output-file', '-o', type=str, default='-', 
+        help='output tab-delimited file with two columns: gene_id, count')
+    
+    parser = subparsers.add_parser('count_mature_mirna',
+        help='count reads mapped to mature miRNA')
+    parser.add_argument('--input-file', '-i', type=str, required=True, 
+        help='input BAM/SAM file mapped to miRBase hairpin sequences')
+    parser.add_argument('--annotation', '-a', type=str, required=True,
+        help='GFF3 file containing mature miRNA locations in precursor miRNA')
+    parser.add_argument('--min-mapping-quality', '-q', type=int, default=0,
+        help='only count reads with mapping quality greater than or equal to this number')
+    parser.add_argument('--output-file', '-o', type=str, default='-',
+        help='output file')
+    
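+    # example invocations (hypothetical file names):
+    #   count_reads.py count_transcript -i sample.bam -s forward -o transcript_counts.txt
+    #   count_reads.py count_circrna -i circrna_junctions.bam -p -o circrna_counts.txt
+    #   count_reads.py count_mature_mirna -i hairpin_mapped.bam -a mirbase.gff3 -o mirna_counts.txt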
+    args = main_parser.parse_args()
+    if args.command is None:
+        main_parser.print_help()
+        sys.exit(1)
+    logger = logging.getLogger('count_reads.' + args.command)
+
+    command_handlers.get(args.command)(args)
\ No newline at end of file