#! /usr/bin/env python
# exseek/scripts/call_peak.py
import argparse, sys, os, errno
import logging
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')

import numpy as np
'''
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
sns.set()
'''
import pandas as pd
from pandas import DataFrame, Series
from scipy.fftpack import fft
from scipy.signal import convolve
import numba

command_handlers = {}
def command_handler(f):
    command_handlers[f.__name__] = f
    return f

def read_coverage(filename):
    '''Read a tab-separated coverage file.

    Each line contains a gene/transcript ID followed by per-position coverage
    values. Returns (gene_ids, coverage), where coverage is a list of 1D float arrays.
    '''
    coverage = []
    gene_ids = []
    with open(filename, 'r') as f:
        for line in f:
            c = line.strip().split('\t')
            gene_id = c[0]
            values = np.array(c[1:]).astype(np.float64)
            gene_ids.append(gene_id)
            coverage.append(values)
    return gene_ids, coverage
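
# Illustrative input for read_coverage() (a hypothetical two-transcript file;
# IDs, column counts and values are made up):
#
#   tRNA-Gly-1<TAB>0<TAB>2<TAB>5<TAB>7<TAB>3
#   miR-21<TAB>1<TAB>1<TAB>0<TAB>12<TAB>9<TAB>4
#
# read_coverage() would return (['tRNA-Gly-1', 'miR-21'],
# [array of 5 floats, array of 6 floats]).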

@numba.jit('int64(int32[:], int32[:], float64, float64, float64)')
def icm_update(x, y, h=0.0, beta=1.0, eta=2.1):
    '''One sweep of iterated conditional modes over the label array x, given the
    observed array y. Modifies x in place and returns the number of labels flipped.'''
    n_changes = 0
    N = x.shape[0]
    for i in range(N):
        dx = -2*x[i]
        dE = 0
        if i > 0:
            dE += h*dx - beta*dx*x[i - 1] - eta*dx*y[i]
        if i < (N - 1):
            dE += h*dx - beta*dx*x[i + 1] - eta*dx*y[i]
        if dE < 0:
            x[i] = -x[i]
            n_changes += 1
    return n_changes

def icm_smooth(x, h=0.0, beta=1.0, eta=2.1):
    '''Smooth a binary signal using iterated conditional modes

    Args:
        x: 1D binary (0/1) signal
    Returns:
        Smoothed binary signal of the same length as x
    '''
    x = x*2 - 1
    y = x.copy()
    #E = h*np.sum(x) - beta*x[:-1]*x[1:] - eta*x*y
    n_updates = icm_update(x, y, h=h, beta=beta, eta=eta)
    while n_updates > 0:
        n_updates = icm_update(x, y, h=h, beta=beta, eta=eta)
    x = (x > 0).astype(np.int32)
    return x
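
# Minimal usage sketch for icm_smooth(): the mask and parameter values below are
# hypothetical, chosen only to illustrate that with a strong coupling term (beta)
# relative to the data term (eta), ICM tends to fill short gaps inside a run of
# 1s, i.e. merge adjacent peaks in a 0/1 peak mask.
#
#   >>> noisy_mask = np.array([0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0], dtype=np.int32)
#   >>> icm_smooth(noisy_mask, h=-2.0, beta=4.0, eta=2.0)
#
# The single-position gap between the first two runs would be expected to be
# filled; the exact output depends on the parameter choice.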

def call_peak_gene(sig, local_bg_width=3, local_bg_weight=0.5, bg_global=None, smooth=False):
    if bg_global is None:
        bg_global = np.mean(sig)

    # local background: moving average over local_bg_width positions
    bg_filter = np.full(local_bg_width, 1.0/local_bg_width)
    bg_local = convolve(sig, bg_filter, mode='same')
    bg = local_bg_weight*bg_local + (1.0 - local_bg_weight)*bg_global
    bg[np.isclose(bg, 0)] = 1
    snr = sig/bg
    peaks = (snr > 1.0).astype(np.int32)
    if smooth:
        peaks = icm_smooth(peaks, h=-2.0, beta=4.0, eta=2.0)
    # extract (start, end) intervals from the binary peak mask
    x = np.zeros(len(peaks) + 2, dtype=np.int32)
    x[1:-1] = peaks
    starts = np.nonzero(x[1:] > x[:-1])[0]
    ends = np.nonzero(x[:-1] > x[1:])[0]
    peaks = np.column_stack([starts, ends])
    return peaks
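
# Toy example for call_peak_gene() (hypothetical coverage values): positions whose
# signal exceeds the blended local/global background are marked, and contiguous
# runs are reported as half-open (start, end) intervals.
#
#   >>> sig = np.array([1., 1., 2., 9., 10., 8., 1., 1.], dtype=np.float64)
#   >>> call_peak_gene(sig, local_bg_width=3, local_bg_weight=0.5)
#
# A single interval over the high-coverage stretch at positions 3-5 would be
# expected, roughly [[3, 6]]; exact boundaries depend on the background weighting.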

def estimate_bg_global(signals):
    '''Estimate a global background level as the median of per-signal means.
    '''
    signals_mean = np.asarray([np.mean(s) for s in signals])
    bg = np.median(signals_mean)
    return bg

def call_peaks(signals, min_length=2):
    '''Call peaks on a list of 1D signals.

    Returns a list of tuples (signal_index, start, end, mean_signal),
    keeping only peaks spanning at least min_length positions.
    '''
    #bg_global = estimate_bg_global(signals)
    bg_global = None
    peaks = []
    for i, signal in enumerate(signals):
        peak_locations = call_peak_gene(signal, bg_global=bg_global, smooth=True)
        #print(signal)
        for start, end in peak_locations:
            #print(peak_locations)
            if (min_length is None) or ((end - start) >= min_length):
                peaks.append((i, start, end, signal[start:end].mean()))
    return peaks
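
# Sketch of call_peaks() on two hypothetical signals: each returned tuple refers
# back to the input list by index, so results from different genes stay separate.
#
#   >>> signals = [np.array([0., 0., 5., 6., 5., 0.]), np.array([2., 2., 2., 2.])]
#   >>> call_peaks(signals, min_length=2)
#
# Only the first signal would be expected to yield a peak, roughly (0, 2, 5, ~5.33);
# the flat second signal has no positions above its own background.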

@command_handler
def call_peak(args):
    from tqdm import trange

    logger.info('read input file: ' + args.input_file)
    gene_ids, signals = read_coverage(args.input_file)
    if args.use_log:
        signals = [np.log10(np.maximum(1e-3, a)) + 3 for a in signals]

    signals_mean = np.asarray([np.mean(a) for a in signals])
    bg_global = np.median(signals_mean)

    logger.info('create output file: ' + args.output_file)
    with open(args.output_file, 'w') as fout:
        for i in trange(len(signals), unit='gene'):
            peaks_locations = call_peak_gene(signals[i], bg_global=bg_global, local_bg_weight=args.local_bg_weight,
                local_bg_width=args.local_bg_width, smooth=args.smooth)
            for j in range(peaks_locations.shape[0]):
                fout.write('{}\t{}\t{}\n'.format(gene_ids[i], peaks_locations[j, 0], peaks_locations[j, 1]))

@command_handler
def refine_peaks(args):
    import pandas as pd
    import numpy as np
    from tqdm import tqdm
    import re
    import h5py
    #from bx.bbi.bigwig_file import BigWigFile
    import pyBigWig
    from ioutils import open_file_or_stdout

    logger.info('read input peaks file: ' + args.peaks)
    #matrix = pd.read_table(args.matrix, sep='\t', index_col=0)
    #feature_info = matrix.index.to_series().str.split('|', expand=True)
    #feature_info.columns = ['gene_id', 'gene_type', 'gene_name', 'domain_id', 'transcript_id', 'start', 'end']
    #feature_info['start'] = feature_info['start'].astype('int')
    #feature_info['end'] = feature_info['end'].astype('int')
    peaks = pd.read_table(args.peaks, sep='\t', header=None, dtype=str)
    if peaks.shape[1] < 6:
        raise ValueError('less than 6 columns in peak file')
    peaks.columns = ['chrom', 'start', 'end', 'name', 'score', 'strand'] + ['c%d'%i for i in range(6, peaks.shape[1])]
    peaks['start'] = peaks['start'].astype('int')
    peaks['end'] = peaks['end'].astype('int')

    logger.info('read chrom sizes: ' + args.chrom_sizes)
    chrom_sizes = pd.read_table(args.chrom_sizes, sep='\t', header=None, names=['chrom', 'size'])
    chrom_sizes = chrom_sizes.drop_duplicates('chrom')
    chrom_sizes = chrom_sizes.set_index('chrom').iloc[:, 0]

    logger.info('read input transcript bigwig file: ' + args.tbigwig)
    tbigwig = pyBigWig.open(args.tbigwig)
    #chrom_sizes.update(dict(tbigwig.get_chrom_sizes()))

    gbigwig = {}
    logger.info('read input genomic bigwig (+) file: ' + args.gbigwig_plus)
    gbigwig['+'] = pyBigWig.open(args.gbigwig_plus)
    logger.info('read input genomic bigwig (-) file: ' + args.gbigwig_minus)
    gbigwig['-'] = pyBigWig.open(args.gbigwig_minus)
    #chrom_sizes.update(dict(gbigwig['+'].get_chrom_sizes()))

    flanking = args.flanking
    signals = []
    signals_mean = []
    windows = []
    #pat_gene_id = re.compile('^(.*)_([0-9]+)_([0-9]+)_([+-])$')
    for _, peak in peaks.iterrows():
        if peak['chrom'].startswith('chr'):
        #if feature['gene_type'] == 'genomic':
            #chrom, start, end, strand = pat_gene_id.match(feature['gene_id']).groups()
            #start = int(start)
            #end = int(end)
            # genomic peak: extract strand-specific coverage around the peak
            window_start = max(0, peak['start'] - flanking)
            window_end = min(peak['end'] + flanking, chrom_sizes[peak['chrom']])
            data = np.nan_to_num(gbigwig[peak['strand']].values(peak['chrom'], window_start, window_end))
        else:
            # transcriptomic peak: extract coverage from the transcript bigwig
            window_start = max(0, peak['start'] - flanking)
            window_end = min(peak['end'] + flanking, chrom_sizes[peak['chrom']])
            data = np.nan_to_num(tbigwig.values(peak['chrom'], window_start, window_end))
        if data is None:
            data = np.zeros((window_end - window_start))
            #logger.info('no coverage data found for peak: {}'.format(feature['domain_id']))
        if args.use_log:
            data = np.log2(np.maximum(data, 0.25)) + 2
        signals.append(data)
        signals_mean.append(np.mean(data))
        windows.append((peak['chrom'], window_start, window_end, peak['start'], peak['end'], peak['strand']))
    tbigwig.close()
    gbigwig['+'].close()
    gbigwig['-'].close()
    windows = pd.DataFrame.from_records(windows)
    windows.columns = ['chrom', 'window_start', 'window_end', 'start', 'end', 'strand']

    logger.info('call peaks')
    refined_peaks = call_peaks(signals, min_length=args.min_length)
    with open_file_or_stdout(args.output_file) as fout:
        for i, start, end, mean_signal in refined_peaks:
            # map peak coordinates from the extracted window back to the chromosome/transcript
            peak = [windows['chrom'][i],
                start + windows['window_start'][i],
                end + windows['window_start'][i],
                'peak_%d'%(i + 1),
                '%.4f'%mean_signal,
                windows['strand'][i]
            ]
            # skip refined peaks that do not overlap the original peak
            if (peak[1] > windows['end'][i]) or (peak[2] < windows['start'][i]):
                continue
            fout.write('\t'.join(map(str, peak)) + '\n')
        #print('%s\t%d\t%d => %s\t%d\t%d'%(
        #    windows['chrom'][i], windows['start'][i], windows['end'][i],
        #    windows['chrom'][i], start, end))

def _call_peaks_localmax(x, min_peak_length=10, bin_width=10, min_cov=5, decay=0.5):
    '''Call peaks by extending from local maxima

    Parameters
    ----------
    x: array-like, (length,)
        Input signal values

    min_peak_length: int
        Minimum length required for each peak

    bin_width: int
        Bin width for searching bins with mean coverage higher than min_cov

    min_cov: float
        Minimum coverage to define a peak

    decay: float
        Stop extending a peak after signal values fall below decay*peak_summit

    Returns
    -------
    peaks: list of list
        Peaks found. Each element is a list: [start, end, local_max]
    '''
    # average signal over bins with 50% overlap
    half_bin_width = bin_width//2
    length = x.shape[0]
    n_bins = max(1, length//half_bin_width)
    bin_cov = np.zeros(n_bins)
    for i in range(n_bins):
        bin_cov[i] = np.mean(x[(i*half_bin_width):min((i + 2)*half_bin_width, length)])
    cand_bins = np.nonzero(bin_cov > min_cov)[0]
    n_cand_bins = cand_bins.shape[0]
    cand_bin_index = 0
    left_bound = 0
    peaks = []
    while cand_bin_index < n_cand_bins:
        i = cand_bins[cand_bin_index]*half_bin_width
        start = i
        end = i
        # find local max
        while (start > left_bound) and (x[start - 1] >= x[start]) and (x[start - 1] >= min_cov):
            start -= 1
        while (end < (length - 1)) and (x[end + 1] >= x[end]) and (x[end + 1] >= min_cov):
            end += 1
        max_index = 0
        if x[start] >= x[end]:
            local_max = x[start]
            max_index = start
        else:
            local_max = x[end]
            max_index = end
        if local_max > min_cov:
            # find bounds where the input signal drops below decay*local_max
            start = max_index
            while (start > left_bound) and (x[start - 1] >= decay*local_max):
                start -= 1
                local_max = max(local_max, x[start])
            end = max_index
            while (end < (len(x) - 1)) and (x[end + 1] >= decay*local_max):
                end += 1
                local_max = max(local_max, x[end])
            # add current peak to results
            if (end - start) >= min_peak_length:
                #print((start, end, local_max))
                peaks.append([start, end, local_max])
            # find next candidate bin
            left_bound = end
            next_cand_bin = end//half_bin_width
            while (cand_bin_index < n_cand_bins) and (cand_bins[cand_bin_index] < next_cand_bin):
                cand_bin_index += 1
        cand_bin_index += 1
    return peaks
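
# Hypothetical example for _call_peaks_localmax(): a short signal with one covered
# region. Parameters are scaled down from the defaults so the toy signal can
# produce a peak at all; real coverage tracks would use the defaults.
#
#   >>> x = np.array([0., 0., 2., 8., 12., 9., 6., 3., 1., 0.], dtype=np.float64)
#   >>> _call_peaks_localmax(x, min_peak_length=3, bin_width=4, min_cov=1, decay=0.5)
#
# The summit near position 4 would be extended left and right until coverage
# falls below decay*summit, giving roughly [[3, 6, 12.0]] for this toy input
# (exact boundaries depend on decay and the binning).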

@command_handler
def call_peaks_localmax(args):
    import pyBigWig
    import numpy as np

    logger.info('read input file: ' + args.input_file)
    bigwig = pyBigWig.open(args.input_file)
    logger.info('write output file: ' + args.output_file)
    bed = open(args.output_file, 'w')
    chroms = bigwig.chroms()
    n_peaks = 0
    for chrom, size in chroms.items():
        # skip genomic chromosomes (names starting with 'chr'); only
        # transcriptomic sequences are processed here
        if chrom.startswith('chr'):
            continue
        x = np.nan_to_num(bigwig.values(chrom, 0, size))
        peaks_chrom = _call_peaks_localmax(x,
            min_peak_length=args.min_peak_length, bin_width=args.bin_width,
            min_cov=args.min_cov, decay=args.decay)
        for peak in peaks_chrom:
            n_peaks += 1
            #peaks.append([chrom, peak[0], peak[1], 'peak_%d'%n_peaks, peak[2], '+'])
            bed.write('%s\t%d\t%d\tpeak_%d\t%d\t+\n'%(chrom, peak[0], peak[1], n_peaks, peak[2]))
    bigwig.close()
    bed.close()

if __name__ == '__main__':
    main_parser = argparse.ArgumentParser(description='Call peaks from exRNA signals')
    subparsers = main_parser.add_subparsers(dest='command')

    parser = subparsers.add_parser('call_peak')
    parser.add_argument('--input-file', '-i', type=str, required=True,
        help='input file of exRNA signals for each transcript')
    parser.add_argument('--use-log', action='store_true',
        help='use log10-transformed signals instead of raw signals')
    parser.add_argument('--smooth', action='store_true',
        help='smooth the peak mask and merge adjacent peaks')
    parser.add_argument('--local-bg-width', type=int, default=3,
        help='number of nearby bins for estimation of local background')
    parser.add_argument('--local-bg-weight', type=float, default=0.5,
        help='weight for local background (0.0-1.0)')
    parser.add_argument('--output-file', '-o', type=str, required=True,
        help='output peaks in BED format')

    parser = subparsers.add_parser('call_peaks_localmax')
    parser.add_argument('--input-file', '-i', type=str, required=True,
        help='input BigWig file of raw reads coverage')
    parser.add_argument('--min-peak-length', type=int, default=10,
        help='minimum length required for a peak')
    parser.add_argument('--decay', type=float, default=0.5,
        help='decay factor of peak summit to define peak boundary')
    parser.add_argument('--min-cov', type=float, default=5,
        help='minimum coverage required to define a peak')
    parser.add_argument('--bin-width', type=int, default=10,
        help='bin width to search enriched bins')
    parser.add_argument('--output-file', '-o', type=str, required=True,
        help='output peaks in BED format')

    parser = subparsers.add_parser('refine_peaks')
    parser.add_argument('--peaks', type=str, required=True,
        help='input candidate peaks in BED format (at least 6 columns)')
    parser.add_argument('--tbigwig', type=str, required=True,
        help='transcript BigWig file')
    parser.add_argument('--gbigwig-plus', type=str, required=True,
        help='genomic BigWig (+) file')
    parser.add_argument('--gbigwig-minus', type=str, required=True,
        help='genomic BigWig (-) file')
    parser.add_argument('--chrom-sizes', type=str, required=True,
        help='chrom sizes file (tab-separated: chrom, size)')
    parser.add_argument('--output-file', '-o', type=str, default='-',
        help='output refined peaks in BED format (default: stdout)')
    parser.add_argument('--use-log', action='store_true',
        help='use log-transformed signals instead of raw signals')
    parser.add_argument('--smooth', action='store_true',
        help='smooth the peak mask and merge adjacent peaks')
    parser.add_argument('--local-bg-width', type=int, default=3,
        help='number of nearby bins for estimation of local background')
    parser.add_argument('--local-bg-weight', type=float, default=0.5,
        help='weight for local background (0.0-1.0)')
    parser.add_argument('--flanking', type=int, default=20,
        help='number of bases to extend on both sides of each peak')
    parser.add_argument('--min-length', type=int, default=10,
        help='minimum length of a refined peak')

    args = main_parser.parse_args()
    if args.command is None:
        raise ValueError('empty command')
    logger = logging.getLogger('call_peak.' + args.command)

    command_handlers.get(args.command)(args)
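
# Illustrative command lines (file names are hypothetical; flags are those defined
# by the argument parsers above):
#
#   python call_peak.py call_peaks_localmax -i coverage.bigwig -o peaks.bed \
#       --min-peak-length 10 --min-cov 5 --decay 0.5
#
#   python call_peak.py refine_peaks --peaks peaks.bed --tbigwig transcript.bigwig \
#       --gbigwig-plus genome_plus.bigwig --gbigwig-minus genome_minus.bigwig \
#       --chrom-sizes chrom_sizes.txt -o refined_peaks.bed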
389