a b/exseek/scripts/report.py
1
#! /usr/bin/env python
2
from __future__ import print_function
3
import argparse, sys, os, errno
4
import logging
5
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')
6
7
from scipy import interp
8
from sklearn.metrics import roc_curve, roc_auc_score
9
from sklearn.preprocessing import RobustScaler
10
import numpy as np
11
import pandas as pd
12
13
# Registry of sub-command implementations, keyed by function name.
command_handlers = {}


def command_handler(f):
    """Register ``f`` as the handler of the sub-command named ``f.__name__``.

    The function itself is returned unchanged, so this works as a decorator.
    """
    command_handlers.update({f.__name__: f})
    return f
17
18
def _feature_selection_axes(n_groups, figsize, sharey):
    """Create one row of ``n_groups`` subplots, each ``figsize`` inches wide and tall.

    ``squeeze=False`` makes the result a 1-D array of Axes even when there is a
    single compare group, so callers never special-case a lone Axes object.
    """
    import matplotlib.pyplot as plt

    _, axes = plt.subplots(1, n_groups,
                           figsize=(figsize*n_groups, figsize),
                           sharey=sharey, sharex=False, squeeze=False)
    return axes[0]


def _plot_auroc_scatter(ax, pivot, x_column, y_column, title):
    """Scatter the mean AUROC of pivot column ``x_column`` against ``y_column`` with a y = x guide."""
    ax.scatter(pivot.loc[:, x_column], pivot.loc[:, y_column], s=12)
    ax.set_xlabel('AUROC ({})'.format(x_column))
    ax.set_ylabel('AUROC ({})'.format(y_column))
    ax.set_xlim(0.5, 1)
    ax.set_ylim(0.5, 1)
    ax.plot([0.5, 1], [0.5, 1], linestyle='dashed', color='gray', linewidth=0.8)
    ax.set_title(title)


def _plot_auroc_lines(ax, pivot, xlabel, title):
    """Draw one line per pivot row, tracing its mean AUROC across the pivot's columns."""
    ax.plot(np.repeat(pivot.columns.values.reshape((-1, 1)), pivot.shape[0], axis=1),
            pivot.values.T)
    ax.set_ylim(0.5, 1)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('AUROC')
    ax.set_title(title)


def _compare_feature_selection_params(input_dir):
    """Collect test AUROCs of every parameter combination and plot comparisons.

    ``input_dir`` contains one sub-directory per compare group. Each holds result
    directories named ``{classifier}.{n_features}.{selector}.{resample_method}``
    with a ``metrics.{resample_method}.txt`` table that has a ``test_roc_auc``
    column.

    Four figures are created, each with one subplot per compare group. They
    compare resample methods, classifiers, number of features and feature
    selection methods.

    Returns:
        DataFrame with one row per result directory, including
        ``test_roc_auc_mean`` and ``test_roc_auc_std``. When no result
        directory is found, the (empty) DataFrame is returned without plotting.
    """
    from tqdm import tqdm

    records = []
    pbar = tqdm(unit='directory')
    for compare_group in os.listdir(input_dir):
        for path in os.listdir(os.path.join(input_dir, compare_group)):
            classifier, n_features, selector, resample_method = path.split('.')
            record = {
                'compare_group': compare_group,
                'classifier': classifier,
                'n_features': n_features,
                'selector': selector,
                'resample_method': resample_method
            }
            metrics = pd.read_table(os.path.join(input_dir, compare_group, path,
                                                 'metrics.{}.txt'.format(resample_method)))
            record['test_roc_auc_mean'] = metrics['test_roc_auc'].mean()
            if resample_method == 'leave_one_out':
                record['test_roc_auc_std'] = 0
            elif resample_method == 'stratified_shuffle_split':
                record['test_roc_auc_std'] = metrics['test_roc_auc'].std()
            pbar.update(1)
            records.append(record)
    pbar.close()
    records = pd.DataFrame.from_records(records)
    if records.empty:
        # no columns exist, and plt.subplots(1, 0) would fail
        return records
    # plain column assignment (rather than .loc[:, col] = ...) so the column
    # dtype really becomes integer instead of being written into the object column
    records['n_features'] = records['n_features'].astype('int')
    compare_groups = records.loc[:, 'compare_group'].unique()
    n_groups = len(compare_groups)

    figsize = 3.5
    # Compare resample methods
    axes = _feature_selection_axes(n_groups, figsize, sharey=True)
    for ax, compare_group in zip(axes, compare_groups):
        # boolean mask instead of DataFrame.query: group names containing
        # quotes cannot break the selection
        sub_df = records[records['compare_group'] == compare_group]
        pivot = sub_df.pivot_table(index=['classifier', 'n_features', 'selector'],
                                   columns=['resample_method'],
                                   values='test_roc_auc_mean')
        _plot_auroc_scatter(ax, pivot, 'leave_one_out', 'stratified_shuffle_split', compare_group)

    # Compare classifiers
    axes = _feature_selection_axes(n_groups, figsize, sharey=True)
    for ax, compare_group in zip(axes, compare_groups):
        sub_df = records[records['compare_group'] == compare_group]
        pivot = sub_df.pivot_table(index=['resample_method', 'n_features', 'selector'],
                                   columns=['classifier'],
                                   values='test_roc_auc_mean')
        _plot_auroc_scatter(ax, pivot, 'logistic_regression', 'random_forest', compare_group)

    # Compare number of features (this figure does not share y axes)
    axes = _feature_selection_axes(n_groups, figsize, sharey=False)
    for ax, compare_group in zip(axes, compare_groups):
        sub_df = records[records['compare_group'] == compare_group]
        pivot = sub_df.pivot_table(index=['classifier', 'selector', 'resample_method'],
                                   columns=['n_features'],
                                   values='test_roc_auc_mean')
        _plot_auroc_lines(ax, pivot, 'Number of features', compare_group)

    # Compare feature selection methods
    axes = _feature_selection_axes(n_groups, figsize, sharey=True)
    for ax, compare_group in zip(axes, compare_groups):
        sub_df = records[records['compare_group'] == compare_group]
        pivot = sub_df.pivot_table(index=['classifier', 'n_features', 'resample_method'],
                                   columns=['selector'],
                                   values='test_roc_auc_mean')
        _plot_auroc_lines(ax, pivot, 'Feature selection method', compare_group)
    return records
131
132
@command_handler
def compare_feature_selection_params(args):
    """Plot AUROC comparisons of the runs under ``args.input_dir`` and save to ``args.output_file``."""
    # pyplot is not imported at module level, so bring it into scope here
    import matplotlib.pyplot as plt

    _compare_feature_selection_params(args.input_dir)
    logger.info('save plot: ' + args.output_file)
    # NOTE(review): the helper draws several figures; plt.savefig writes only
    # the current (last) one. Confirm whether all figures should be saved.
    plt.savefig(args.output_file)
137
138
def _compare_features(input_dir, datasets):
    """Compare the features selected at different feature counts for each dataset.

    Only runs of ``random_forest`` with the ``robust`` selector,
    ``stratified_shuffle_split`` resampling and at most 10 features are used.
    They are read from ``{input_dir}/{dataset}/{compare_group}/{classifier}.{n_features}.{selector}.{resample_method}``.

    For every (dataset, compare_group) pair two matrices are built, each
    features x number of features:
    * an importance matrix, holding the value read from ``feature_importances.txt``;
    * a 0/1 indicator matrix.
    The indicator matrix is drawn as a heatmap, and the importance matrix is
    displayed as a styled table.

    NOTE(review): several names used here are not defined in this module:
    ``feature_fields``, ``transcript_table_by_transcript_id``,
    ``transcript_table_by_gene_id``, ``transcript_table_by_transcript_name``
    and ``display``. Presumably they come from an interactive notebook session.

    Returns:
        ``(metrics, feature_matrices, feature_indicator_matrices)``.
        ``metrics`` is a DataFrame with one row per selected run. Both dicts
        are keyed by ``(dataset, compare_group)``.
    """
    from functools import reduce
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    import seaborn as sns

    pbar = tqdm(unit='directory')
    records = []
    feature_matrices = {}
    feature_indicator_matrices = {}
    for dataset in datasets:
        for compare_group in os.listdir(os.path.join(input_dir, dataset)):
            feature_lists = {}
            for path in os.listdir(os.path.join(input_dir, dataset, compare_group)):
                classifier, n_features, selector, resample_method = path.split('.')
                if int(n_features) > 10:
                    continue
                if (classifier != 'random_forest') or (selector != 'robust'):
                    continue
                if resample_method != 'stratified_shuffle_split':
                    continue
                record = {
                    'compare_group': compare_group,
                    'classifier': classifier,
                    'n_features': n_features,
                    'selector': selector,
                    'resample_method': resample_method,
                    'dataset': dataset
                }
                result_dir = os.path.join(input_dir, dataset, compare_group, path)
                # first column becomes the index (feature names); the next
                # column holds the values used as importances
                feature_importance = pd.read_table(os.path.join(result_dir, 'feature_importances.txt'),
                                                   header=None, index_col=0).iloc[:, 0]
                feature_importance.index = feature_importance.index.astype('str')
                feature_lists[n_features] = feature_importance
                metrics = pd.read_table(os.path.join(result_dir, 'metrics.{}.txt'.format(resample_method)))
                record['test_roc_auc_mean'] = metrics['test_roc_auc'].mean()
                # only stratified_shuffle_split runs pass the filters above
                record['test_roc_auc_std'] = metrics['test_roc_auc'].std()
                pbar.update(1)
                records.append(record)
            if not feature_lists:
                # no run matched the filters; reduce() over an empty list would raise
                continue
            # union of the features selected at any feature count
            feature_set = reduce(np.union1d, [a.index.values for a in feature_lists.values()])
            columns = list(feature_lists.keys())
            # importance matrix (0 where a feature was not selected) and 0/1 indicator matrix
            feature_matrix = pd.DataFrame(np.zeros((len(feature_set), len(feature_lists))),
                                          index=feature_set, columns=columns)
            feature_indicator_matrix = pd.DataFrame(np.zeros((len(feature_set), len(feature_lists))),
                                                    index=feature_set, columns=columns)
            for n_features, feature_importance in feature_lists.items():
                feature_matrix.loc[feature_importance.index.values, n_features] = feature_importance.values
                feature_indicator_matrix.loc[feature_importance.index.values, n_features] = 1
            feature_matrix.columns = feature_matrix.columns.astype('int')
            feature_matrix.index = feature_matrix.index.astype('str')
            feature_matrix = feature_matrix.loc[:, feature_matrix.columns.sort_values().values]
            # stored before the annotation columns are inserted in place below,
            # so the stored object also receives them
            feature_matrices[(dataset, compare_group)] = feature_matrix
            feature_indicator_matrix.columns = feature_indicator_matrix.columns.astype('int')
            feature_indicator_matrix = feature_indicator_matrix.loc[:, feature_indicator_matrix.columns.sort_values().values]
            feature_indicator_matrices[(dataset, compare_group)] = feature_indicator_matrix

            if dataset in feature_fields:
                # split feature names on "|" into the dataset's named fields
                feature_meta = feature_matrix.index.to_series().str.split('|', expand=True)
                feature_meta.columns = feature_fields[dataset]
                fields = feature_fields[dataset]
                if 'transcript_id' in fields:
                    table, key = transcript_table_by_transcript_id, 'transcript_id'
                elif 'gene_id' in fields:
                    table, key = transcript_table_by_gene_id, 'gene_id'
                elif 'transcript_name' in fields:
                    table, key = transcript_table_by_transcript_name, 'transcript_name'
                else:
                    table = key = None
                if table is not None:
                    ids = feature_meta[key].values
                    feature_matrix.insert(0, 'gene_type', table.loc[ids, 'gene_type'].values)
                    feature_matrix.insert(0, 'gene_name', table.loc[ids, 'gene_name'].values)
                    # label heatmap rows as "gene_name|gene_type"
                    feature_indicator_matrix.index = (feature_matrix.loc[:, 'gene_name'].values
                                                      + '|' + feature_matrix.loc[:, 'gene_type'].values)

            fig, ax = plt.subplots(figsize=(6, 8))
            sns.heatmap(feature_indicator_matrix,
                        cmap=sns.light_palette('green', as_cmap=True), cbar=False, ax=ax, linewidth=1)
            ax.set_xlabel('Number of features')
            ax.set_ylabel('Features')
            ax.set_title('{}, {}'.format(dataset, compare_group))

            # NOTE(review): ``display`` is the IPython/Jupyter builtin, and
            # Styler.set_precision is deprecated in newer pandas
            # (Styler.format(precision=2) replaces it).
            display(feature_matrix.style
                    .background_gradient(cmap=sns.light_palette('green', as_cmap=True))
                    .set_precision(2)
                    .set_caption('{}, {}'.format(dataset, compare_group)))

    pbar.close()
    metrics = pd.DataFrame.from_records(records)
    return metrics, feature_matrices, feature_indicator_matrices
261
262
def plot_roc_curve_ci(y, is_train, predicted_scores, ax, title=None):
    """Plot the mean ROC curve over cross-validation splits with a +/- 1 SD band.

    Args:
        y: true labels, one per sample.
        is_train: per-split sample mask; row ``i`` marks the training samples
            of split ``i``, and the test samples are its complement. Boolean
            or integer 0/1 masks are accepted. (Assumed to be a mask rather
            than an index list; TODO confirm against the evaluation files.)
        predicted_scores: array indexed ``[split, sample]`` with the predicted
            score of every sample in every split.
        ax: matplotlib Axes to draw on.
        title: optional Axes title.
    """
    import seaborn as sns

    y = np.asarray(y)
    predicted_scores = np.asarray(predicted_scores)
    # ``~`` on an integer 0/1 array is bitwise NOT (-1/-2), which numpy would
    # treat as positional indices rather than a mask; force a boolean mask
    is_train = np.asarray(is_train, dtype=bool)
    n_splits = is_train.shape[0]
    # common FPR grid so TPRs from all splits can be aggregated pointwise
    all_fprs = np.linspace(0, 1, 100)
    roc_curves = np.zeros((n_splits, len(all_fprs), 2))
    roc_aucs = np.zeros(n_splits)
    for i in range(n_splits):
        is_test = ~is_train[i]
        y_test = y[is_test]
        scores_test = predicted_scores[i, is_test]
        fpr, tpr, _ = roc_curve(y_test, scores_test)
        roc_aucs[i] = roc_auc_score(y_test, scores_test)
        roc_curves[i, :, 0] = all_fprs
        # numpy.interp is the function the deprecated scipy.interp alias pointed to
        roc_curves[i, :, 1] = np.interp(all_fprs, fpr, tpr)
    roc_curves = pd.DataFrame(roc_curves.reshape((-1, 2)), columns=['fpr', 'tpr'])
    sns.lineplot(x='fpr', y='tpr', data=roc_curves, ci='sd', ax=ax,
                 label='Average ROAUC = {:.4f}'.format(roc_aucs.mean()))
    ax.set_xlabel('False positive rate')
    ax.set_ylabel('True positive rate')
    ax.plot([0, 1], [0, 1], linestyle='dashed', color='gray')
    if title:
        ax.set_title(title)
    ax.legend()
284
285
def _plot_10_features(input_dir, datasets, use_log=False, scale=False, title=None):
    """Plot cross-validated ROC curves of the 10-feature random-forest models.

    For each dataset, every ``random_forest.10.robust.stratified_shuffle_split``
    result directory under ``{input_dir}/{dataset}/{compare_group}`` is plotted.
    Each gets its own figure titled ``"{dataset}, {compare_group}"``.

    Args:
        input_dir: root of the per-dataset result directories.
        datasets: dataset names.
        use_log: log2-transform the CPM matrix (with a 0.001 pseudo-count).
        scale: robust-scale each row of the CPM matrix across columns.
        title: currently unused.

    NOTE(review): the prepared matrix ``X``, ``sample_classes`` and the
    selected ``features`` are not used yet. ``groups`` (dataset -> name of the
    sample-class file) is not defined in this module.
    """
    from tqdm import tqdm
    import h5py
    import matplotlib.pyplot as plt

    pbar = tqdm(unit='directory')
    for dataset in datasets:
        sample_classes = pd.read_table('metadata/sample_classes.{}.txt'.format(groups[dataset]),
                                       header=None, index_col=0).iloc[:, 0]
        cpm = pd.read_table('output/cpm_matrix/{}.txt'.format(dataset), index_col=0)
        if use_log:
            cpm = np.log2(cpm + 0.001)
        if scale:
            # RobustScaler scales columns, so transpose to scale each cpm row
            X = RobustScaler().fit_transform(cpm.T.values).T
        else:
            X = cpm.values
        X = pd.DataFrame(X, index=cpm.index.values, columns=cpm.columns.values)
        for compare_group in os.listdir(os.path.join(input_dir, dataset)):
            for path in os.listdir(os.path.join(input_dir, dataset, compare_group)):
                classifier, n_features, selector, resample_method = path.split('.')
                if int(n_features) != 10:
                    continue
                if (classifier != 'random_forest') or (selector != 'robust'):
                    continue
                if resample_method != 'stratified_shuffle_split':
                    continue
                result_dir = os.path.join(input_dir, dataset, compare_group, path)
                # explicit read-only mode: older h5py versions default to append mode
                with h5py.File(os.path.join(result_dir, 'evaluation.{}.h5'.format(resample_method)), 'r') as f:
                    train_index = f['train_index'][:]
                    predicted_scores = f['predictions'][:]
                    labels = f['labels'][:]
                fig, ax = plt.subplots(figsize=(8, 8))
                plot_roc_curve_ci(labels, train_index, predicted_scores, ax,
                                  title='{}, {}'.format(dataset, compare_group))

                features = pd.read_table(os.path.join(result_dir, 'features.txt'), header=None).iloc[:, 0].values
                pbar.update(1)

    pbar.close()
327
328
329
def _evaluate_preprocess_methods(input_dirs, preprocess_methods, title=None):
    """Compare preprocessing methods by the mean test AUROC of their models.

    ``input_dirs`` and ``preprocess_methods`` are parallel sequences. Each
    directory holds the results of the matching method, laid out as
    ``{input_dir}/{compare_group}/{classifier}.{n_features}.{selector}.{resample_method}``.
    Only ``random_forest`` + ``robust`` + ``stratified_shuffle_split`` runs
    with at most 50 features are used.

    For every compare group one figure is drawn: mean AUROC versus number of
    features, one line per method. Each legend label carries the method's
    average rank (1 = highest AUROC).

    Args:
        input_dirs: result directories, one per preprocessing method.
        preprocess_methods: names of the preprocessing methods.
        title: optional prefix of each figure title.
    """
    from tqdm import tqdm
    import matplotlib.pyplot as plt

    records = []
    pbar = tqdm(unit='directory')
    for preprocess_method, input_dir in zip(preprocess_methods, input_dirs):
        for compare_group in os.listdir(input_dir):
            for path in os.listdir(os.path.join(input_dir, compare_group)):
                classifier, n_features, selector, resample_method = path.split('.')
                if int(n_features) > 50:
                    continue
                if (classifier != 'random_forest') or (selector != 'robust'):
                    continue
                if resample_method != 'stratified_shuffle_split':
                    continue
                metrics = pd.read_table(os.path.join(input_dir, compare_group, path,
                                                     'metrics.{}.txt'.format(resample_method)))
                records.append({
                    'compare_group': compare_group,
                    'classifier': classifier,
                    'n_features': n_features,
                    'selector': selector,
                    'resample_method': resample_method,
                    'preprocess_method': preprocess_method,
                    'test_roc_auc_mean': metrics['test_roc_auc'].mean(),
                    # only stratified_shuffle_split runs pass the filters above
                    'test_roc_auc_std': metrics['test_roc_auc'].std(),
                })
                pbar.update(1)
    pbar.close()
    records = pd.DataFrame.from_records(records)
    if records.empty:
        # nothing to plot, and the column accesses below would raise KeyError
        return
    records['n_features'] = records.loc[:, 'n_features'].astype(np.int32)
    for compare_group, sub_df in records.groupby('compare_group'):
        pivot = sub_df.pivot_table(
            index='preprocess_method', columns='n_features', values='test_roc_auc_mean')
        # rank methods within each feature count (1 = highest AUROC) and average
        # over feature counts. Missing combinations (NaN) are skipped rather
        # than ranked as best, and ties share their average rank.
        mean_ranks = pivot.rank(axis=0, ascending=False).mean(axis=1).sort_values()
        rename_index = pd.Series(
            ['{} (rank = {:.1f})'.format(name, value) for name, value in mean_ranks.items()],
            index=mean_ranks.index.values)
        sub_df = sub_df.copy()
        sub_df['preprocess_method'] = rename_index[sub_df['preprocess_method'].values].values
        sub_df = sub_df.sort_values(['preprocess_method', 'n_features'], ascending=True)
        sub_df['n_features'] = sub_df['n_features'].astype('str')
        fig, ax = plt.subplots(figsize=(8, 8))
        for preprocess_method in rename_index.values:
            tmp_df = sub_df[sub_df['preprocess_method'] == preprocess_method]
            # evenly spaced positions, labelled with the actual feature counts
            x = np.arange(tmp_df.shape[0]) + 1
            ax.plot(x, tmp_df['test_roc_auc_mean'], label=preprocess_method)
            ax.set_xticks(x)
            ax.set_xticklabels(tmp_df['n_features'])
        ax.set_xlabel('Number of features')
        ax.set_ylabel('Average AUROC')
        if len(preprocess_methods) > 1:
            ax.legend(title='Preprocess method', bbox_to_anchor=(1.04, 0.5),
                      loc="center left", borderaxespad=0)
        ax.set_ylim(0.5, 1)
        if title:
            ax.set_title(title + ', ' + compare_group)
392
393
@command_handler
def evaluate_preprocessing_methods(args):
    """Compare preprocessing methods from parsed command-line arguments.

    Reads the result directories from ``args.input_dirs`` and the method names
    from ``args.preprocess_methods``. The misspelt ``args.precessing_methods``
    attribute is still accepted for compatibility.
    """
    preprocess_methods = getattr(args, 'preprocess_methods', None)
    if preprocess_methods is None:
        preprocess_methods = args.precessing_methods
    _evaluate_preprocess_methods(args.input_dirs, preprocess_methods)
396
397
def bigwig_fetch(filename, chrom, start, end, dtype='float'):
    """Return per-base values of ``filename`` over ``chrom:[start, end)``.

    Runs the external ``bigWigToBedGraph`` tool and fills a zero-initialised
    array of length ``end - start`` from the bedGraph records it prints.
    Positions not covered by any record stay 0.

    Raises:
        subprocess.CalledProcessError: if ``bigWigToBedGraph`` exits with a
            non-zero status.
    """
    import subprocess

    cmd = ['bigWigToBedGraph', filename, 'stdout',
           '-chrom={}'.format(chrom), '-start={}'.format(start), '-end={}'.format(end)]
    data = np.zeros(end - start, dtype=dtype)
    # the context manager closes the pipe and waits for the child process
    with subprocess.Popen(cmd, stdout=subprocess.PIPE) as p:
        for line in p.stdout:
            c = line.decode('ascii').strip().split('\t')
            # clamp at 0: a record starting before ``start`` must not become a
            # negative slice bound, which numpy would count from the array end
            data[max(int(c[1]) - start, 0):(int(c[2]) - start)] = float(c[3])
    if p.returncode != 0:
        raise subprocess.CalledProcessError(p.returncode, cmd)
    return data
408
    
409
410
def extract_feature_sequence(feature, genome_dir):
    """Return the upper-case reference sequence of a domain feature.

    ``feature`` is a ``|``-separated name
    ``gene_id|gene_type|gene_name|domain_id|transcript_id|start|end``. Anything
    after the first tab is ignored, so a whole tab-separated line can be passed.
    Sequences are read from ``{genome_dir}/fasta/{gene_type}.fa``, where
    ``genomic`` maps to ``genome``.

    For genomic features the coordinates come from ``gene_id``
    (``{chrom}_{start}_{end}_{strand}``), and minus-strand sequences are
    reverse-complemented. Other features use ``transcript_id[start:end]``.
    """
    from pyfaidx import Fasta
    from Bio.Seq import Seq

    feature = feature.split('\t')[0]
    gene_id, gene_type, gene_name, domain_id, transcript_id, start, end = feature.split('|')
    start = int(start)
    end = int(end)
    if gene_type == 'genomic':
        gene_type = 'genome'
    fasta = Fasta(os.path.join(genome_dir, 'fasta', gene_type + '.fa'))
    if gene_type == 'genome':
        # rsplit keeps chromosome names that themselves contain "_" intact
        chrom, gstart, gend, strand = gene_id.rsplit('_', 3)
        seq = fasta[chrom][int(gstart):int(gend)].seq
        if strand == '-':
            seq = str(Seq(seq).reverse_complement())
    else:
        seq = fasta[transcript_id][start:end].seq
    return seq.upper()
431
432
433
@command_handler
def visualize_domains(args):
    """Plot read coverage around each selected domain feature into a multi-page PDF.

    Every feature listed in ``args.features`` is a ``|``-separated name
    ``gene_id|gene_type|gene_name|domain_id|transcript_id|start|end``. Each one
    gets a page in ``args.output_file`` with:

    * a heatmap of log2(coverage + 1) per sample, with samples grouped by class
      (``args.sample_classes``) and ordered within each class by single-linkage
      clustering;
    * the mean coverage track, with the reference sequence as tick labels;
    * the annotated domain and the peaks called on the mean coverage.

    Coverage comes from per-sample BigWig files under ``args.output_dir``. The
    sequence comes from ``{args.genome_dir}/fasta/{gene_type}.fa``. The view
    is extended by ``args.flanking`` bases on each side. piRNA and miRNA
    features are skipped.
    """
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages
    from matplotlib.gridspec import GridSpec
    import seaborn as sns
    from tqdm import tqdm
    from pykent import BigWigFile
    from scipy.cluster.hierarchy import linkage, dendrogram
    from pyfaidx import Fasta
    from Bio.Seq import Seq
    from call_peak import call_peaks

    sns.set_style('white')
    plt.rcParams['figure.dpi'] = 96
    plt.rcParams['xtick.minor.visible'] = True
    plt.rcParams['xtick.minor.size'] = 4
    plt.rcParams['xtick.bottom'] = True
    plt.rcParams['xtick.labelsize'] = 8

    # read sample classes; coverage rows follow this class-sorted sample order
    logger.info('reads sample classes: ' + args.sample_classes)
    sample_classes = pd.read_table(args.sample_classes, sep='\t', index_col=0).iloc[:, 0]
    sample_classes = sample_classes.sort_values()
    sample_ids = sample_classes.index.values

    # read features and split their names into fields
    features = pd.read_table(args.features, header=None).iloc[:, 0]
    feature_info = features.str.split('|', expand=True)
    feature_info.columns = ['gene_id', 'gene_type', 'gene_name', 'domain_id', 'transcript_id', 'start', 'end']
    feature_info.index = features.values

    # FASTA files opened once per gene type instead of once per feature
    fasta_files = {}

    with PdfPages(args.output_file) as pdf:
        for feature_name, feature in tqdm(feature_info.iterrows(), unit='feature'):
            if feature['gene_type'] == 'genomic':
                # gene_id is "{chrom}_{start}_{end}_{strand}"; rsplit keeps
                # chromosome names that themselves contain "_" intact
                chrom, start, end, strand = feature['gene_id'].rsplit('_', 3)
                start = int(start)
                end = int(end)
                bigwig_file = os.path.join(args.output_dir, 'bigwig', '{{0}}.genome.{0}.bigWig'.format(strand))
            elif feature['gene_type'] in ('piRNA', 'miRNA'):
                continue
            else:
                start = int(feature['start'])
                end = int(feature['end'])
                chrom = feature.transcript_id
                bigwig_file = os.path.join(args.output_dir, 'tbigwig_normalized', '{0}.transcriptome.bigWig')

            # read coverage from BigWig files: one row per sample, log2(x + 1)
            coverage = None
            for i, sample_id in enumerate(sample_ids):
                bwf = BigWigFile(bigwig_file.format(sample_id))
                if coverage is None:
                    # interval to display, clipped to the bounds of the queried sequence
                    chrom_size = bwf.get_chrom_size(chrom)
                    if chrom_size == 0:
                        raise ValueError('cannot find transcript id {} in bigwig'.format(chrom))
                    view_start = max(start - args.flanking, 0)
                    view_end = min(end + args.flanking, chrom_size)
                    coverage = np.zeros((len(sample_ids), view_end - view_start), dtype='float')
                    logger.info('create_coverage_matrix: ({}, {})'.format(*coverage.shape))
                values = bwf.query(chrom, view_start, view_end, fillna=0)
                del bwf
                if values is not None:
                    coverage[i] = values
                coverage[i] = np.log2(coverage[i] + 1)

            # reference sequence of the displayed interval
            gene_type = feature['gene_type']
            if gene_type == 'genomic':
                gene_type = 'genome'
            if gene_type not in fasta_files:
                fasta_files[gene_type] = Fasta(os.path.join(args.genome_dir, 'fasta', gene_type + '.fa'))
            seq = fasta_files[gene_type][chrom][view_start:view_end].seq
            if (gene_type == 'genome') and (strand == '-'):
                # NOTE(review): the coverage columns are not reversed, so on the
                # minus strand the tick labels run opposite to the coverage; confirm intent
                seq = str(Seq(seq).reverse_complement())
            seq = seq.upper()

            # order samples within each class by single-linkage clustering.
            # feature_classes is a per-feature copy, so sample_classes keeps
            # matching the row order of sample_ids for the next feature.
            order = np.arange(coverage.shape[0], dtype='int')
            for label in np.unique(sample_classes):
                mask = (sample_classes == label).values
                if mask.sum() < 2:
                    # linkage needs at least two observations
                    continue
                Z = linkage(coverage[mask], 'single')
                R = dendrogram(Z, no_plot=True, labels=order[mask])
                order[mask] = R['ivl']
            feature_classes = sample_classes.iloc[order]
            coverage = coverage[order]

            fig = plt.figure(figsize=(20, 6))
            gs = GridSpec(4, 3, figure=fig, width_ratios=[0.95, 0.03, 0.02], height_ratios=[0.6, 0.15, 0.1, 0.1], hspace=0.2, wspace=0.15)
            ax_heatmap = plt.subplot(gs[0, 0])
            ax_colorbar = plt.subplot(gs[0, 1])
            ax_label = plt.subplot(gs[0, 2])
            ax_line = plt.subplot(gs[1, 0])
            ax_domain = plt.subplot(gs[2, 0])
            ax_refined_domain = plt.subplot(gs[3, 0])

            # coverage heatmap; major ticks shifted by 0.5 to sit on cell centres
            p = ax_heatmap.pcolormesh(coverage, cmap='Blues')
            ax_heatmap.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
            ax_heatmap.set_xlim(0, coverage.shape[1])
            xticks = ax_heatmap.get_xticks()
            ax_heatmap.set_xticks(xticks + 0.5)
            ax_heatmap.set_xticklabels(xticks.astype('int'))
            ax_heatmap.set_title(feature_name)

            fig.colorbar(p, cax=ax_colorbar, use_gridspec=False, orientation='vertical')

            # sample-class strip next to the heatmap rows
            for label in feature_classes.unique():
                ax_label.barh(y=np.arange(coverage.shape[0]), width=(feature_classes == label).astype('int'), height=1,
                              edgecolor='none', label=label)
            ax_label.set_xlim(0, 1)
            ax_label.set_ylim(0, coverage.shape[0])
            ax_label.tick_params(labelbottom=False, bottom=False)
            ax_label.set_xticks([])
            ax_label.set_yticks([])
            ax_label.legend(title='Class', bbox_to_anchor=(1.1, 0.5), loc="center left", borderaxespad=0)

            # mean coverage track labelled with one base per tick
            ax_line.fill_between(np.arange(coverage.shape[1]), coverage.mean(axis=0), step='pre', alpha=0.9)
            ax_line.set_xlim(0, coverage.shape[1])
            ax_line.set_xticks(np.arange(coverage.shape[1]) + 0.5)
            ax_line.set_xticks([], minor=True)
            ax_line.set_xticklabels(list(seq))
            ax_line.set_ylim(0, ax_line.get_ylim()[1])

            # annotated domain, in view coordinates
            ax_domain.hlines(y=0.5, xmin=start - view_start, xmax=end - view_start, linewidth=5, color='C0')
            ax_domain.set_ylim(0, 1)
            ax_domain.set_ylabel('Domain')
            ax_domain.set_yticks([])
            ax_domain.set_xticks([])
            ax_domain.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
            ax_domain.set_xlim(0, coverage.shape[1])
            ax_domain.spines['top'].set_visible(False)
            ax_domain.spines['right'].set_visible(False)

            # peaks called on the mean coverage
            coverage_mean = coverage.mean(axis=0)
            for _, peak_start, peak_end in call_peaks([coverage_mean], min_length=10):
                ax_refined_domain.hlines(y=0.5, xmin=peak_start, xmax=peak_end, linewidth=5, color='C0')
            ax_refined_domain.set_ylim(0, 1)
            ax_refined_domain.set_ylabel('Refined')
            ax_refined_domain.set_yticks([])
            ax_refined_domain.set_xticks([])
            ax_refined_domain.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
            ax_refined_domain.set_xlim(0, coverage.shape[1])
            ax_refined_domain.spines['top'].set_visible(False)
            ax_refined_domain.spines['right'].set_visible(False)

            fig.tight_layout()
            # save plot
            pdf.savefig(fig)
            plt.close(fig)
619
620
621
if __name__ == '__main__':
    main_parser = argparse.ArgumentParser(description='Preprocessing module')
    subparsers = main_parser.add_subparsers(dest='command')

    # sub-command: visualize_domains
    parser = subparsers.add_parser('visualize_domains', help='plot read coverage of domains as heatmaps')
    for flags, options in [
        (('--sample-classes',), dict(type=str, required=True, help='e.g. {data_dir}/sample_classes.txt')),
        (('--output-dir',), dict(type=str, required=True, help='e.g. output/scirep')),
        (('--features',), dict(type=str, required=True, help='list of selected features')),
        (('--output-file', '-o'), dict(type=str, required=True, help='output PDF file')),
        (('--flanking',), dict(type=int, default=20, help='flanking length for genomic domains')),
        (('--genome-dir',), dict(type=str, required=True, help='e.g. genome/hg38')),
    ]:
        parser.add_argument(*flags, **options)

    args = main_parser.parse_args()
    if not args.command:
        main_parser.print_help()
        sys.exit(1)
    # module-level name read by the command handlers
    logger = logging.getLogger('report.' + args.command)

    handler = command_handlers.get(args.command)
    handler(args)