Diff of /exseek/scripts/report.py [000000] .. [4c33d4]

--- a
+++ b/exseek/scripts/report.py
@@ -0,0 +1,640 @@
+#! /usr/bin/env python
+from __future__ import print_function
+import argparse, sys, os, errno
+import logging
+logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')
+
+from numpy import interp  # scipy.interp was a deprecated alias of numpy.interp
+from sklearn.metrics import roc_curve, roc_auc_score
+from sklearn.preprocessing import RobustScaler
+import numpy as np
+import pandas as pd
+
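+# Registry of subcommands: the @command_handler decorator below registers each
+# decorated function under its own name for dispatch in __main__.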
+command_handlers = {}
+def command_handler(f):
+    command_handlers[f.__name__] = f
+    return f
+
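+# Expected layout under input_dir:
+#   <compare_group>/<classifier>.<n_features>.<selector>.<resample_method>/
+#       metrics.<resample_method>.txt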
+def _compare_feature_selection_params(input_dir):
+    from tqdm import tqdm
+    import pandas as pd
+    import matplotlib.pyplot as plt
+
+    records = []
+    pbar = tqdm(unit='directory')
+    for compare_group in os.listdir(input_dir):
+        for path in os.listdir(os.path.join(input_dir, compare_group)):
+            classifier, n_features, selector, resample_method = path.split('.')
+            record = {
+                'compare_group': compare_group,
+                'classifier': classifier,
+                'n_features': n_features,
+                'selector': selector,
+                'resample_method': resample_method
+            }
+            metrics = pd.read_table(os.path.join(input_dir, compare_group, path, 'metrics.{}.txt'.format(resample_method)))
+            record['test_roc_auc_mean'] = metrics['test_roc_auc'].mean()
+            if resample_method == 'leave_one_out':
+                record['test_roc_auc_std'] = 0
+            elif resample_method == 'stratified_shuffle_split':
+                record['test_roc_auc_std'] = metrics['test_roc_auc'].std()
+            pbar.update(1)
+            records.append(record)
+    pbar.close()
+    records = pd.DataFrame.from_records(records)
+    records.loc[:, 'n_features'] = records.loc[:, 'n_features'].astype('int')
+    compare_groups = records.loc[:, 'compare_group'].unique()
+
+    figsize = 3.5
+    # Compare resample methods
+    fig, axes = plt.subplots(1, len(compare_groups), 
+                             figsize=(figsize*len(compare_groups), figsize),
+                             sharey=True, sharex=False)
+    for i, compare_group in enumerate(compare_groups):
+        if len(compare_groups) > 1:
+            ax = axes[i]
+        else:
+            ax = axes
+        sub_df = records.query('compare_group == "{}"'.format(compare_group))
+        pivot = sub_df.pivot_table(index=['classifier', 'n_features', 'selector'], 
+                  columns=['resample_method'], 
+                  values='test_roc_auc_mean')
+        ax.scatter(pivot.loc[:, 'leave_one_out'], pivot.loc[:, 'stratified_shuffle_split'], s=12)
+        ax.set_xlabel('AUROC (leave_one_out)')
+        ax.set_ylabel('AUROC (stratified_shuffle_split)')
+        ax.set_xlim(0.5, 1)
+        ax.set_ylim(0.5, 1)
+        ax.plot([0.5, 1], [0.5, 1], linestyle='dashed', color='gray', linewidth=0.8)
+        ax.set_title(compare_group)
+
+    # Compare classifiers
+    fig, axes = plt.subplots(1, len(compare_groups), 
+                             figsize=(figsize*len(compare_groups), figsize),
+                             sharey=True, sharex=False)
+    for i, compare_group in enumerate(compare_groups):
+        if len(compare_groups) > 1:
+            ax = axes[i]
+        else:
+            ax = axes
+        sub_df = records.query('compare_group == "{}"'.format(compare_group))
+        pivot = sub_df.pivot_table(index=['resample_method', 'n_features', 'selector'], 
+                  columns=['classifier'], 
+                  values='test_roc_auc_mean')
+        ax.scatter(pivot.loc[:, 'logistic_regression'], pivot.loc[:, 'random_forest'], s=12)
+        ax.set_xlabel('AUROC (logistic_regression)')
+        ax.set_ylabel('AUROC (random_forest)')
+        ax.set_xlim(0.5, 1)
+        ax.set_ylim(0.5, 1)
+        ax.plot([0.5, 1], [0.5, 1], linestyle='dashed', color='gray', linewidth=0.8)
+        ax.set_title(compare_group)
+
+    # Compare number of features
+    fig, axes = plt.subplots(1, len(compare_groups), 
+                             figsize=(figsize*len(compare_groups), figsize),
+                             sharey=False, sharex=False)
+    for i, compare_group in enumerate(compare_groups):
+        if len(compare_groups) > 1:
+            ax = axes[i]
+        else:
+            ax = axes
+        sub_df = records.query('compare_group == "{}"'.format(compare_group))
+        pivot = sub_df.pivot_table(index=['classifier', 'selector', 'resample_method'], 
+                  columns=['n_features'], 
+                  values='test_roc_auc_mean')
+        ax.plot(np.repeat(pivot.columns.values.reshape((-1, 1)), pivot.shape[0], axis=1),
+                pivot.values.T)
+        ax.set_ylim(0.5, 1)
+        ax.set_xlabel('Number of features')
+        ax.set_ylabel('AUROC')
+        ax.set_title(compare_group)
+
+    # Compare feature selection methods
+    fig, axes = plt.subplots(1, len(compare_groups), 
+                             figsize=(figsize*len(compare_groups), figsize),
+                             sharey=True, sharex=False)
+    for i, compare_group in enumerate(compare_groups):
+        if len(compare_groups) > 1:
+            ax = axes[i]
+        else:
+            ax = axes
+        sub_df = records.query('compare_group == "{}"'.format(compare_group))
+        pivot = sub_df.pivot_table(index=['classifier', 'n_features', 'resample_method'], 
+                  columns=['selector'], 
+                  values='test_roc_auc_mean')
+        ax.plot(np.repeat(pivot.columns.values.reshape((-1, 1)), pivot.shape[0], axis=1),
+                pivot.values.T)
+        ax.set_ylim(0.5, 1)
+        ax.set_xlabel('Feature selection method')
+        ax.set_ylabel('AUROC')
+        ax.set_title(compare_group)
+    return records
+
+@command_handler
+def compare_feature_selection_params(args):
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib.pyplot as plt
+
+    _compare_feature_selection_params(args.input_dir)
+    logger.info('save plot: ' + args.output_file)
+    # NOTE: the helper creates several figures; savefig writes the most recent one
+    plt.savefig(args.output_file)
+
+def _compare_features(input_dir, datasets):
+    from functools import reduce
+    import pandas as pd
+    from tqdm import tqdm
+    import seaborn as sns
+    import matplotlib.pyplot as plt
+    from IPython.display import display
+
+    pbar = tqdm(unit='directory')
+    records = []
+    feature_matrices = {}
+    #feature_support_matrices = {}
+    feature_indicator_matrices = {}
+    for dataset in datasets:
+        cpm = pd.read_table('output/cpm_matrix/{}.txt'.format(dataset), index_col=0)
+        for compare_group in os.listdir(os.path.join(input_dir, dataset)):
+            feature_lists = {}
+            #feature_supports = {}
+            for path in os.listdir(os.path.join(input_dir, dataset, compare_group)):
+                classifier, n_features, selector, resample_method = path.split('.')
+                if int(n_features) > 10:
+                    continue
+                if (classifier != 'random_forest') or (selector != 'robust'):
+                    continue
+                if resample_method != 'stratified_shuffle_split':
+                    continue
+                record = {
+                    'compare_group': compare_group,
+                    'classifier': classifier,
+                    'n_features': n_features,
+                    'selector': selector,
+                    'resample_method': resample_method,
+                    'dataset': dataset
+                }
+                # feature importance
+                feature_lists[n_features] = pd.read_table(os.path.join(input_dir, dataset, compare_group,
+                    path, 'feature_importances.txt'), header=None, index_col=0).iloc[:, 0]
+                feature_lists[n_features].index = feature_lists[n_features].index.astype('str')
+                # feature support
+                #with h5py.File(os.path.join(input_dir, dataset, compare_group,
+                #    path, 'evaluation.{}.h5'.format(resample_method)), 'r') as f:
+                #    feature_support = np.mean(f['feature_selection'][:], axis=0)
+                #    feature_support = pd.Series(feature_support, index=cpm.index.values)
+                #    feature_support = feature_support[feature_lists[n_features].index.values]
+                #    feature_supports[n_features] = feature_support
+                # read metrics
+                metrics = pd.read_table(os.path.join(input_dir, dataset, compare_group, 
+                    path, 'metrics.{}.txt'.format(resample_method)))
+                record['test_roc_auc_mean'] = metrics['test_roc_auc'].mean()
+                if resample_method == 'leave_one_out':
+                    record['test_roc_auc_std'] = 0
+                elif resample_method == 'stratified_shuffle_split':
+                    record['test_roc_auc_std'] = metrics['test_roc_auc'].std()
+                pbar.update(1)
+                records.append(record)
+            # feature union set
+            feature_set = reduce(np.union1d, [a.index.values for a in feature_lists.values()])
+            # build feature importance matrix
+            feature_matrix = pd.DataFrame(np.zeros((len(feature_set), len(feature_lists))),
+                                          index=feature_set, columns=list(feature_lists.keys()))
+            for n_features, feature_importance in feature_lists.items():
+                feature_matrix.loc[feature_importance.index.values, n_features] = feature_importance.values
+            feature_matrix.columns = feature_matrix.columns.astype('int')
+            feature_matrix.index = feature_matrix.index.astype('str')
+            feature_matrix = feature_matrix.loc[:, feature_matrix.columns.sort_values().values]
+                    
+            feature_matrices[(dataset, compare_group)] = feature_matrix
+            # build feature indicator matrix
+            feature_indicator_matrix = pd.DataFrame(np.zeros((len(feature_set), len(feature_lists))),
+                                          index=feature_set, columns=list(feature_lists.keys()))
+            for n_features, feature_importance in feature_lists.items():
+                feature_indicator_matrix.loc[feature_importance.index.values, n_features] = 1
+            feature_indicator_matrix.columns = feature_indicator_matrix.columns.astype('int')
+            feature_indicator_matrix = feature_indicator_matrix.loc[:, feature_indicator_matrix.columns.sort_values().values]
+            feature_indicator_matrices[(dataset, compare_group)] = feature_indicator_matrix
+            
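+            # NOTE: feature_fields and the transcript_table_by_* lookup tables
+            # used below are expected to be defined at module level (annotation
+            # tables mapping transcript/gene identifiers to gene_name/gene_type).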
+            if dataset in feature_fields:
+                feature_meta = feature_matrix.index.to_series().str.split('|', expand=True)
+                feature_meta.columns = feature_fields[dataset]
+                if 'transcript_id' in feature_fields[dataset]:
+                    feature_matrix.insert(
+                        0, 'gene_type', 
+                        transcript_table_by_transcript_id.loc[feature_meta['transcript_id'].values, 'gene_type'].values)
+                    feature_matrix.insert(
+                        0, 'gene_name', 
+                        transcript_table_by_transcript_id.loc[feature_meta['transcript_id'].values, 'gene_name'].values)
+                    
+                elif 'gene_id' in feature_fields[dataset]:
+                    feature_matrix.insert(
+                        0, 'gene_type', 
+                        transcript_table_by_gene_id.loc[feature_meta['gene_id'].values, 'gene_type'].values)
+                    feature_matrix.insert(
+                        0, 'gene_name', 
+                        transcript_table_by_gene_id.loc[feature_meta['gene_id'].values, 'gene_name'].values)
+                elif 'transcript_name' in feature_fields[dataset]:
+                    feature_matrix.insert(
+                        0, 'gene_type', 
+                        transcript_table_by_transcript_name.loc[feature_meta['transcript_name'].values, 'gene_type'].values)
+                    feature_matrix.insert(
+                        0, 'gene_name', 
+                        transcript_table_by_transcript_name.loc[feature_meta['transcript_name'].values, 'gene_name'].values)
+                    
+                feature_indicator_matrix.index = feature_matrix.loc[:, 'gene_name'].values + '|' + feature_matrix.loc[:, 'gene_type'].values
+            # build feature support matrix
+            #feature_support_matrix = pd.DataFrame(np.zeros((len(feature_set), len(feature_lists))),
+            #                              index=feature_set, columns=list(feature_lists.keys()))
+            #for n_features, feature_support in feature_supports.items():
+            #    feature_support_matrix.loc[feature_support.index.values, n_features] = feature_support.values
+            #feature_support_matrix.columns = feature_support_matrix.columns.astype('int')
+            #feature_support_matrix = feature_matrix.loc[:, feature_support_matrix.columns.sort_values().values]
+            #feature_support_matrices[(dataset, compare_group)] = feature_support_matrix
+            fig, ax = plt.subplots(figsize=(6, 8))
+            sns.heatmap(feature_indicator_matrix,
+                        cmap=sns.light_palette('green', as_cmap=True), cbar=False, ax=ax, linewidth=1)
+            ax.set_xlabel('Number of features')
+            ax.set_ylabel('Features')
+            ax.set_title('{}, {}'.format(dataset, compare_group))
+
+            display(feature_matrix.style\
+                .background_gradient(cmap=sns.light_palette('green', as_cmap=True))\
+                .set_precision(2)\
+                .set_caption('{}, {}'.format(dataset, compare_group)))
+
+    pbar.close()
+    metrics = pd.DataFrame.from_records(records)
+    return metrics, feature_matrices, feature_indicator_matrices
+
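+# plot_roc_curve_ci: draw a mean ROC curve with a standard-deviation band.
+# Each resampling split's ROC curve is interpolated onto a common 100-point
+# FPR grid; seaborn's lineplot (ci='sd') then aggregates the stacked curves.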
+def plot_roc_curve_ci(y, is_train, predicted_scores, ax, title=None):
+    import seaborn as sns
+    # ROC curve
+    n_splits = is_train.shape[0]
+    all_fprs = np.linspace(0, 1, 100)
+    roc_curves = np.zeros((n_splits, len(all_fprs), 2))
+    roc_aucs = np.zeros(n_splits)
+    for i in range(n_splits):
+        fpr, tpr, thresholds = roc_curve(y[~is_train[i]], predicted_scores[i, ~is_train[i]])
+        roc_aucs[i] = roc_auc_score(y[~is_train[i]], predicted_scores[i, ~is_train[i]])
+        roc_curves[i, :, 0] = all_fprs
+        roc_curves[i, :, 1] = interp(all_fprs, fpr, tpr)
+    roc_curves = pd.DataFrame(roc_curves.reshape((-1, 2)), columns=['fpr', 'tpr'])
+    sns.lineplot(x='fpr', y='tpr', data=roc_curves, ci='sd', ax=ax,
+                 label='Average AUROC = {:.4f}'.format(roc_aucs.mean()))
+    #ax.plot(fpr, tpr, label='ROAUC = {:.4f}'.format(roc_auc_score(y_test, y_score[:, 1])))
+    #ax.plot([0, 1], [0, 1], linestyle='dashed')
+    ax.set_xlabel('False positive rate')
+    ax.set_ylabel('True positive rate')
+    ax.plot([0, 1], [0, 1], linestyle='dashed', color='gray')
+    if title:
+        ax.set_title(title)
+    ax.legend()
+
+def _plot_10_features(input_dir, datasets, use_log=False, scale=False, title=None):
+    import h5py
+    import matplotlib.pyplot as plt
+    from tqdm import tqdm_notebook
+
+    pbar = tqdm_notebook(unit='directory')
+    for dataset in datasets:
+        # NOTE: `groups` is expected to be a module-level dict mapping each
+        # dataset name to its sample-class group
+        sample_classes = pd.read_table('metadata/sample_classes.{}.txt'.format(groups[dataset]),
+                                       header=None, index_col=0).iloc[:, 0]
+        cpm = pd.read_table('output/cpm_matrix/{}.txt'.format(dataset), index_col=0)
+        if use_log:
+            cpm = np.log2(cpm + 0.001)
+        if scale:
+            X = RobustScaler().fit_transform(cpm.T.values).T
+        else:
+            X = cpm.values
+        X = pd.DataFrame(X, index=cpm.index.values, columns=cpm.columns.values)
+        for compare_group in os.listdir(os.path.join(input_dir, dataset)):
+            for path in os.listdir(os.path.join(input_dir, dataset, compare_group)):
+                classifier, n_features, selector, resample_method = path.split('.')
+                if int(n_features) != 10:
+                    continue
+                if (classifier != 'random_forest') or (selector != 'robust'):
+                    continue
+                if resample_method != 'stratified_shuffle_split':
+                    continue
+                record = {
+                    'compare_group': compare_group,
+                    'classifier': classifier,
+                    'n_features': n_features,
+                    'selector': selector,
+                    'resample_method': resample_method,
+                    'dataset': dataset
+                }
+                result_dir = os.path.join(input_dir, dataset, compare_group, path)
+                with h5py.File(os.path.join(result_dir, 'evaluation.{}.h5'.format(resample_method)), 'r') as f:
+                    train_index = f['train_index'][:]
+                    predicted_scores = f['predictions'][:]
+                    labels = f['labels'][:]
+                fig, ax = plt.subplots(figsize=(8, 8))
+                plot_roc_curve_ci(labels, train_index, predicted_scores, ax, 
+                                  title='{}, {}'.format(dataset, compare_group))
+                
+                features = pd.read_table(os.path.join(result_dir, 'features.txt'), header=None).iloc[:, 0].values
+                pbar.update(1)
+
+    pbar.close()
+
+
+def _evaluate_preprocess_methods(input_dirs, preprocess_methods, title=None):
+    import matplotlib.pyplot as plt
+    from tqdm import tqdm_notebook
+
+    records = []
+    pbar = tqdm_notebook(unit='directory')
+    for preprocess_method, input_dir in zip(preprocess_methods, input_dirs):
+        for compare_group in os.listdir(input_dir):
+            for path in os.listdir(os.path.join(input_dir, compare_group)):
+                classifier, n_features, selector, resample_method = path.split('.')
+                if int(n_features) > 50:
+                    continue
+                if (classifier != 'random_forest') or (selector != 'robust'):
+                    continue
+                if resample_method != 'stratified_shuffle_split':
+                    continue
+                record = {
+                    'compare_group': compare_group,
+                    'classifier': classifier,
+                    'n_features': n_features,
+                    'selector': selector,
+                    'resample_method': resample_method,
+                    'preprocess_method': preprocess_method
+                }
+                metrics = pd.read_table(os.path.join(input_dir, compare_group, path, 'metrics.{}.txt'.format(resample_method)))
+                record['test_roc_auc_mean'] = metrics['test_roc_auc'].mean()
+                if resample_method == 'leave_one_out':
+                    record['test_roc_auc_std'] = 0
+                elif resample_method == 'stratified_shuffle_split':
+                    record['test_roc_auc_std'] = metrics['test_roc_auc'].std()
+                pbar.update(1)
+                records.append(record)
+    pbar.close()
+    records = pd.DataFrame.from_records(records)
+    records['n_features'] = records.loc[:, 'n_features'].astype(np.int32)
+    for compare_group, sub_df in records.groupby('compare_group'):
+        pivot = sub_df.pivot_table(
+            index='preprocess_method', columns='n_features', values='test_roc_auc_mean')
+        #print(pivot.iloc[:, 0])
+        #print(np.argsort(np.argsort(pivot.values, axis=0), axis=0)[:, 0])
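+        # Double argsort converts each column's AUROC values into ranks
+        # (0 = worst); subtracting from the number of methods makes 1 the best.
+        # Averaging over feature counts gives each preprocess method a mean rank.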
+        mean_ranks = np.mean(pivot.shape[0] - np.argsort(np.argsort(pivot.values, axis=0), axis=0), axis=1)
+        mean_ranks = pd.Series(mean_ranks, index=pivot.index.values)
+        mean_ranks = mean_ranks.sort_values()
+        rename_index = ['{} (rank = {:.1f})'.format(name, value) for name, value in zip(mean_ranks.index, mean_ranks.values)]
+        rename_index = pd.Series(rename_index, index=mean_ranks.index.values)
+        sub_df = sub_df.copy()
+        sub_df['preprocess_method'] = rename_index[sub_df['preprocess_method'].values].values
+        sub_df['n_features'] = sub_df['n_features'].astype('int')
+        sub_df = sub_df.sort_values(['preprocess_method', 'n_features'], ascending=True)
+        sub_df['n_features'] = sub_df['n_features'].astype('str')
+        fig, ax = plt.subplots(figsize=(8, 8))                      
+        #sns.lineplot('n_features', 'test_roc_auc_mean', hue='preprocess_method', data=sub_df, 
+        #          ci=None, ax=ax, markers='o', hue_order=rename_index.values, sort=False)
+        for preprocess_method in rename_index.values:
+            tmp_df = sub_df[sub_df['preprocess_method'] == preprocess_method]
+            ax.plot(np.arange(tmp_df.shape[0]) + 1, tmp_df['test_roc_auc_mean'], label=preprocess_method)
+            ax.set_xticks(np.arange(tmp_df.shape[0]) + 1)
+            ax.set_xticklabels(tmp_df['n_features'])
+        ax.set_xlabel('Number of features')
+        ax.set_ylabel('Average AUROC')
+        if len(preprocess_methods) > 1:
+            ax.legend(title='Preprocess method', bbox_to_anchor=(1.04,0.5), 
+                      loc="center left", borderaxespad=0)
+        ax.set_ylim(0.5, 1)
+        if title:
+            ax.set_title(title + ', ' + compare_group)
+
+@command_handler
+def evaluate_preprocessing_methods(args):
+    _evaluate_preprocess_methods(args.input_dirs, args.preprocess_methods)
+
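+# bigwig_fetch: read per-base coverage over chrom:start-end from a bigWig file
+# by piping the UCSC bigWigToBedGraph tool and expanding each bedGraph interval
+# into a dense numpy array.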
+def bigwig_fetch(filename, chrom, start, end, dtype='float'):
+    import subprocess
+    p = subprocess.Popen(['bigWigToBedGraph', filename, 'stdout',
+                      '-chrom={}'.format(chrom), '-start={}'.format(start), '-end={}'.format(end)],
+                    stdout=subprocess.PIPE)
+    data = np.zeros(end - start, dtype=dtype)
+    for line in p.stdout:
+        line = str(line, encoding='ascii')
+        c = line.strip().split('\t')
+        data[(int(c[1]) - start):(int(c[2]) - start)] = float(c[3])
+    return data
+    
+
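+# extract_feature_sequence: recover the nucleotide sequence of a feature named
+# with the 7-field convention
+# gene_id|gene_type|gene_name|domain_id|transcript_id|start|end.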
+def extract_feature_sequence(feature, genome_dir):
+    from pyfaidx import Fasta
+    from Bio.Seq import Seq
+
+    feature = feature.split('\t')[0]
+    gene_id, gene_type, gene_name, domain_id, transcript_id, start, end = feature.split('|')
+    start = int(start)
+    end = int(end)
+    if gene_type == 'genomic':
+        gene_type = 'genome'
+    fasta = Fasta(os.path.join(genome_dir, 'fasta', gene_type + '.fa'))
+    if gene_type == 'genome':
+        # genomic features encode their location in the gene_id field
+        chrom, gstart, gend, strand = gene_id.split('_')
+        gstart = int(gstart)
+        gend = int(gend)
+        seq = fasta[chrom][gstart:gend].seq
+        if strand == '-':
+            seq = str(Seq(seq).reverse_complement())
+    else:
+        seq = fasta[transcript_id][start:end].seq
+    seq = seq.upper()
+    return seq
+
+
+@command_handler
+def visualize_domains(args):
+    import numpy as np
+    import matplotlib
+    matplotlib.use('Agg')
+    import matplotlib.pyplot as plt
+    from matplotlib.backends.backend_pdf import PdfPages
+    from matplotlib.gridspec import GridSpec
+    import seaborn as sns
+    sns.set_style('white')
+    import pandas as pd
+    plt.rcParams['figure.dpi'] = 96
+    from tqdm import tqdm
+    from pykent import BigWigFile
+    from scipy.cluster.hierarchy import linkage, dendrogram
+    from pyfaidx import Fasta
+    from Bio.Seq import Seq
+    from call_peak import call_peaks
+
+    # read sample ids
+    #logger.info('read sample ids: ' + args.sample_ids)
+    #sample_ids = open(args.sample_ids_file, 'r').read().split()
+    logger.info('read sample classes: ' + args.sample_classes)
+    sample_classes = pd.read_table(args.sample_classes, sep='\t', index_col=0).iloc[:, 0]
+    sample_classes = sample_classes.sort_values()
+    sample_ids = sample_classes.index.values
+
+    # read features
+    features = pd.read_table(args.features, header=None).iloc[:, 0]
+    feature_info = features.str.split('|', expand=True)
+    feature_info.columns = ['gene_id', 'gene_type', 'gene_name', 'domain_id', 'transcript_id', 'start', 'end']
+    feature_info.index = features.values
+    # read count matrix to get read depth
+    #counts = pd.read_table(args.count_matrix, index_col=0)
+    #read_depth = counts.sum(axis=0)
+    #del counts
+    # read chrom sizes
+    #chrom_sizes = pd.read_table(args.chrom_sizes, sep='\t', index_col=0, header=None).iloc[:, 0]
+
+    with PdfPages(args.output_file) as pdf:
+        for feature_name, feature in tqdm(feature_info.iterrows(), unit='feature'):
+            #logger.info('plot feature: {}'.format(feature_name))
+            if feature['gene_type'] == 'genomic':
+                chrom, start, end, strand = feature['gene_id'].split('_')
+                start = int(start)
+                end = int(end)
+                bigwig_file = os.path.join(args.output_dir, 'bigwig', '{{0}}.genome.{0}.bigWig'.format(strand))
+            elif feature['gene_type'] in ('piRNA', 'miRNA'):
+                continue
+            else:
+                start = int(feature['start'])
+                end = int(feature['end'])
+                chrom = feature.transcript_id
+                bigwig_file = os.path.join(args.output_dir, 'tbigwig_normalized', '{0}.transcriptome.bigWig')
+            # read coverage from BigWig files
+            coverage = None
+            for i, sample_id in enumerate(sample_ids):
+                bwf = BigWigFile(bigwig_file.format(sample_id))
+                if coverage is None:
+                    # interval to display coverage
+                    # use `chrom` so genomic features query the chromosome
+                    # rather than the transcript id
+                    chrom_size = bwf.get_chrom_size(chrom)
+                    if chrom_size == 0:
+                        raise ValueError('cannot find sequence {} in bigwig'.format(chrom))
+                    view_start = max(start - args.flanking, 0)
+                    view_end = min(end + args.flanking, chrom_size)
+                    coverage = np.zeros((len(sample_ids), view_end - view_start), dtype='float')
+                    logger.info('create_coverage_matrix: ({}, {})'.format(*coverage.shape))
+                #logger.info('bigWigQuery: {}:{}-{}'.format(chrom, view_start, view_end))
+                values = bwf.query(chrom, view_start, view_end, fillna=0)
+                del bwf
+                if values is not None:
+                    coverage[i] = values
+                #coverage[i] = bigwig_fetch(bigwig_file.format(sample_id), chrom, view_start, view_end, dtype='int')
+                # normalize coverage by read depth
+                #coverage[i] *= 1e6/read_depth[sample_id]
+                # log2 transformation
+                coverage[i] = np.log2(coverage[i] + 1)
+            
+            # get sequence
+            gene_type = feature['gene_type']
+            if gene_type == 'genomic':
+                gene_type = 'genome'
+            fasta = Fasta(os.path.join(args.genome_dir, 'fasta', gene_type + '.fa'))
+            seq = fasta[chrom][view_start:view_end].seq  # chrom equals transcript_id for transcript features
+            if (gene_type == 'genome') and (strand == '-'):
+                seq = str(Seq(seq).reverse_complement())
+            seq = seq.upper()
+
+            # draw heatmap
+            '''
+            plot_data = pd.DataFrame(coverage)
+            cmap = sns.light_palette('blue', as_cmap=True, n_colors=6)
+            g = sns.clustermap(plot_data, figsize=(20, 8), col_cluster=False, row_colors=None, cmap='Blues')
+            g.ax_heatmap.set_yticklabels([])
+            g.ax_heatmap.set_yticks([])
+            xticks = np.arange(0, coverage.shape[1], 10)
+            g.ax_heatmap.set_xticks(xticks)
+            g.ax_heatmap.set_xticklabels(xticks, rotation=0)
+            g.ax_heatmap.vlines(x=domain_start, ymin=0, ymax=g.ax_heatmap.get_ylim()[0], linestyle='dashed', linewidth=1.0)
+            g.ax_heatmap.vlines(x=domain_end, ymin=0, ymax=g.ax_heatmap.get_ylim()[0], linestyle='dashed', linewidth=1.0)
+            g.ax_heatmap.set_title(feature_name)
+            '''
+            # hierarchical clustering within each class to order heatmap rows
+            order = np.arange(coverage.shape[0], dtype='int')
+            for label in np.unique(sample_classes):
+                mask = (sample_classes == label)
+                Z = linkage(coverage[mask], 'single')
+                R = dendrogram(Z, no_plot=True, labels=order[mask])
+                order[mask] = R['ivl']
+            # reorder a copy so sample_classes stays aligned with sample_ids
+            # (and with the coverage matrix) on the next feature
+            ordered_sample_classes = sample_classes.iloc[order]
+            coverage = coverage[order]
+
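+            # Figure layout: row 1 holds the coverage heatmap, its colorbar and
+            # a class-label bar; rows 2-4 share the x axis and show the mean
+            # coverage profile, the annotated domain and the peak-refined domain.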
+            plt.rcParams['xtick.minor.visible'] = True
+            plt.rcParams['xtick.minor.size'] = 4
+            plt.rcParams['xtick.bottom'] = True
+            plt.rcParams['xtick.labelsize'] = 8
+            fig = plt.figure(figsize=(20, 6))
+            gs = GridSpec(4, 3, figure=fig, width_ratios=[0.95, 0.03, 0.02], height_ratios=[0.6, 0.15, 0.1, 0.1], hspace=0.2, wspace=0.15)
+            #fig, axes = plt.subplots(1, 2, figsize=(20, 3), sharey=True, 
+            #    gridspec_kw={'width_ratios': [0.98, 0.02], 'hspace': 0})
+            ax_heatmap = plt.subplot(gs[0, 0])
+            ax_colorbar = plt.subplot(gs[0, 1])
+            ax_label = plt.subplot(gs[0, 2])
+            ax_line = plt.subplot(gs[1, 0])
+            ax_domain = plt.subplot(gs[2, 0])
+            ax_refined_domain = plt.subplot(gs[3, 0])
+
+            p = ax_heatmap.pcolormesh(coverage, cmap='Blues')
+            ax_heatmap.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
+            ax_heatmap.set_xlim(0, coverage.shape[1])
+            xticks = ax_heatmap.get_xticks()
+            ax_heatmap.set_xticks(xticks + 0.5)
+            ax_heatmap.set_xticklabels(xticks.astype('int'))
+            ax_heatmap.set_title(feature_name)
+
+            fig.colorbar(p, cax=ax_colorbar, use_gridspec=False, orientation='vertical')
+
+            for label in ordered_sample_classes.unique():
+                ax_label.barh(y=np.arange(coverage.shape[0]), width=(ordered_sample_classes == label).astype('int'), height=1,
+                    edgecolor='none', label=label)
+            ax_label.set_xlim(0, 1)
+            ax_label.set_ylim(0, coverage.shape[0])
+            ax_label.tick_params(labelbottom=False, bottom=False)
+            ax_label.set_xticks([])
+            ax_label.set_yticks([])
+            ax_label.legend(title='Class', bbox_to_anchor=(1.1, 0.5), loc="center left", borderaxespad=0)
+
+            ax_line.fill_between(np.arange(coverage.shape[1]), coverage.mean(axis=0), step='pre', alpha=0.9)
+            ax_line.set_xlim(0, coverage.shape[1])
+            #ax_line.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
+            #xticks = ax_line.get_xticks()
+            #ax_line.set_xticks(xticks + 0.5)
+            #ax_line.set_xticklabels(xticks.astype('int'))
+            ax_line.set_xticks(np.arange(coverage.shape[1]) + 0.5)
+            ax_line.set_xticks([], minor=True)
+            ax_line.set_xticklabels(list(seq))
+            ax_line.set_ylim(0, ax_line.get_ylim()[1])
+            #ax_line.vlines(x=start - view_start + 0.5, ymin=0, ymax=ax_line.get_ylim()[1], linestyle='dashed', linewidth=1.0)
+            #ax_line.vlines(x=end - view_start + 0.5, ymin=0, ymax=ax_line.get_ylim()[1], linestyle='dashed', linewidth=1.0)
+
+            ax_domain.hlines(y=0.5, xmin=start - view_start, xmax=end - view_start, linewidth=5, color='C0')
+            ax_domain.set_ylim(0, 1)
+            ax_domain.set_ylabel('Domain')
+            ax_domain.set_yticks([])
+            ax_domain.set_xticks([])
+            ax_domain.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
+            ax_domain.set_xlim(0, coverage.shape[1])
+            ax_domain.spines['top'].set_visible(False)
+            ax_domain.spines['right'].set_visible(False)
+
+            coverage_mean = coverage.mean(axis=0)
+            for _, peak_start, peak_end in call_peaks([coverage_mean], min_length=10):
+                ax_refined_domain.hlines(y=0.5, xmin=peak_start, xmax=peak_end, linewidth=5, color='C0')
+            ax_refined_domain.set_ylim(0, 1)
+            ax_refined_domain.set_ylabel('Refined')
+            ax_refined_domain.set_yticks([])
+            ax_refined_domain.set_xticks([])
+            ax_refined_domain.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
+            ax_refined_domain.set_xlim(0, coverage.shape[1])
+            ax_refined_domain.spines['top'].set_visible(False)
+            ax_refined_domain.spines['right'].set_visible(False)
+
+            fig.tight_layout()
+            # save plot
+            pdf.savefig(fig)
+            plt.close()
+
+
+if __name__ == '__main__':
+    main_parser = argparse.ArgumentParser(description='Report module')
+    subparsers = main_parser.add_subparsers(dest='command')
+
+    parser = subparsers.add_parser('visualize_domains', help='plot read coverage of domains as heatmaps')
+    parser.add_argument('--sample-classes', type=str, required=True, help='e.g. {data_dir}/sample_classes.txt')
+    parser.add_argument('--output-dir', type=str, required=True, help='e.g. output/scirep')
+    parser.add_argument('--features', type=str, required=True, help='list of selected features')
+    #parser.add_argument('--count-matrix', type=str, required=True, help='count matrix')
+    parser.add_argument('--output-file', '-o', type=str, required=True, help='output PDF file')
+    parser.add_argument('--flanking', type=int, default=20, help='flanking length for genomic domains')
+    parser.add_argument('--genome-dir', type=str, required=True, help='e.g. genome/hg38')
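+
+    # Subparsers for the remaining registered commands, inferred from the
+    # attributes each handler reads (without them those commands are unreachable)
+    parser = subparsers.add_parser('compare_feature_selection_params',
+                                   help='compare feature selection parameters across compare groups')
+    parser.add_argument('--input-dir', type=str, required=True,
+                        help='layout: <input_dir>/<compare_group>/<classifier>.<n_features>.<selector>.<resample_method>')
+    parser.add_argument('--output-file', '-o', type=str, required=True, help='output plot file')
+
+    parser = subparsers.add_parser('evaluate_preprocessing_methods',
+                                   help='compare preprocessing methods by average AUROC')
+    parser.add_argument('--input-dirs', type=str, nargs='+', required=True,
+                        help='one input directory per preprocessing method')
+    parser.add_argument('--preprocess-methods', type=str, nargs='+', required=True,
+                        help='names of the preprocessing methods, matching --input-dirs')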
+    
+    args = main_parser.parse_args()
+    if not args.command:
+        main_parser.print_help()
+        sys.exit(1)
+    logger = logging.getLogger('report.' + args.command)
+
+    command_handlers.get(args.command)(args)
\ No newline at end of file