#! /usr/bin/env python
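"""Reporting utilities for the feature selection benchmark.

Summarizes cross-validation AUROC across feature selection parameters, classifiers
and preprocessing methods, plots ROC curves with confidence bands, and visualizes
per-sample read coverage over selected domains. Subcommands are registered with the
@command_handler decorator and dispatched from the command line.
"""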
from __future__ import print_function
import argparse, sys, os, errno
import logging
logging.basicConfig(level=logging.INFO, format='[%(asctime)s] [%(levelname)s] %(name)s: %(message)s')
logger = logging.getLogger('report')
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.preprocessing import RobustScaler
import numpy as np
import pandas as pd
command_handlers = {}
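# Registry of subcommands: functions decorated with @command_handler are stored
# by name here and dispatched in __main__ via command_handlers[args.command].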
def command_handler(f):
command_handlers[f.__name__] = f
return f
def _compare_feature_selection_params(input_dir):
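    """Collect test AUROC metrics from
    input_dir/<compare_group>/<classifier>.<n_features>.<selector>.<resample_method>/metrics.<resample_method>.txt
    and draw scatter/line plots comparing resample methods, classifiers,
    numbers of features and feature selection methods. Returns the metrics table.
    """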
from tqdm import tqdm
import pandas as pd
import matplotlib.pyplot as plt
records = []
pbar = tqdm(unit='directory')
for compare_group in os.listdir(input_dir):
for path in os.listdir(os.path.join(input_dir, compare_group)):
classifier, n_features, selector, resample_method = path.split('.')
record = {
'compare_group': compare_group,
'classifier': classifier,
'n_features': n_features,
'selector': selector,
'resample_method': resample_method
}
metrics = pd.read_table(os.path.join(input_dir, compare_group, path, 'metrics.{}.txt'.format(resample_method)))
record['test_roc_auc_mean'] = metrics['test_roc_auc'].mean()
if resample_method == 'leave_one_out':
record['test_roc_auc_std'] = 0
elif resample_method == 'stratified_shuffle_split':
record['test_roc_auc_std'] = metrics['test_roc_auc'].std()
pbar.update(1)
records.append(record)
pbar.close()
records = pd.DataFrame.from_records(records)
records.loc[:, 'n_features'] = records.loc[:, 'n_features'].astype('int')
compare_groups = records.loc[:, 'compare_group'].unique()
figsize = 3.5
# Compare resample methods
fig, axes = plt.subplots(1, len(compare_groups),
figsize=(figsize*len(compare_groups), figsize),
sharey=True, sharex=False)
for i, compare_group in enumerate(compare_groups):
if len(compare_groups) > 1:
ax = axes[i]
else:
ax = axes
sub_df = records.query('compare_group == "{}"'.format(compare_group))
pivot = sub_df.pivot_table(index=['classifier', 'n_features', 'selector'],
columns=['resample_method'],
values='test_roc_auc_mean')
ax.scatter(pivot.loc[:, 'leave_one_out'], pivot.loc[:, 'stratified_shuffle_split'], s=12)
ax.set_xlabel('AUROC (leave_one_out)')
ax.set_ylabel('AUROC (stratified_shuffle_split)')
ax.set_xlim(0.5, 1)
ax.set_ylim(0.5, 1)
ax.plot([0.5, 1], [0.5, 1], linestyle='dashed', color='gray', linewidth=0.8)
ax.set_title(compare_group)
# Compare classifiers
fig, axes = plt.subplots(1, len(compare_groups),
figsize=(figsize*len(compare_groups), figsize),
sharey=True, sharex=False)
for i, compare_group in enumerate(compare_groups):
if len(compare_groups) > 1:
ax = axes[i]
else:
ax = axes
sub_df = records.query('compare_group == "{}"'.format(compare_group))
pivot = sub_df.pivot_table(index=['resample_method', 'n_features', 'selector'],
columns=['classifier'],
values='test_roc_auc_mean')
ax.scatter(pivot.loc[:, 'logistic_regression'], pivot.loc[:, 'random_forest'], s=12)
ax.set_xlabel('AUROC (logistic_regression)')
ax.set_ylabel('AUROC (random_forest)')
ax.set_xlim(0.5, 1)
ax.set_ylim(0.5, 1)
ax.plot([0.5, 1], [0.5, 1], linestyle='dashed', color='gray', linewidth=0.8)
ax.set_title(compare_group)
# Compare number of features
fig, axes = plt.subplots(1, len(compare_groups),
figsize=(figsize*len(compare_groups), figsize),
sharey=False, sharex=False)
for i, compare_group in enumerate(compare_groups):
if len(compare_groups) > 1:
ax = axes[i]
else:
ax = axes
sub_df = records.query('compare_group == "{}"'.format(compare_group))
pivot = sub_df.pivot_table(index=['classifier', 'selector', 'resample_method'],
columns=['n_features'],
values='test_roc_auc_mean')
ax.plot(np.repeat(pivot.columns.values.reshape((-1, 1)), pivot.shape[0], axis=1),
pivot.values.T)
ax.set_ylim(0.5, 1)
ax.set_xlabel('Number of features')
ax.set_ylabel('AUROC')
ax.set_title(compare_group)
# Compare feature selection methods
fig, axes = plt.subplots(1, len(compare_groups),
figsize=(figsize*len(compare_groups), figsize),
sharey=True, sharex=False)
for i, compare_group in enumerate(compare_groups):
if len(compare_groups) > 1:
ax = axes[i]
else:
ax = axes
sub_df = records.query('compare_group == "{}"'.format(compare_group))
pivot = sub_df.pivot_table(index=['classifier', 'n_features', 'resample_method'],
columns=['selector'],
values='test_roc_auc_mean')
ax.plot(np.repeat(pivot.columns.values.reshape((-1, 1)), pivot.shape[0], axis=1),
pivot.values.T)
ax.set_ylim(0.5, 1)
ax.set_xlabel('Feature selection method')
ax.set_ylabel('AUROC')
ax.set_title(compare_group)
return records
@command_handler
def compare_feature_selection_params(args):
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    _compare_feature_selection_params(args.input_dir)
    logger.info('save plot: ' + args.output_file)
    # note: savefig saves the current (most recently created) figure
    plt.savefig(args.output_file)
def _compare_features(input_dir, datasets):
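    """Summarize which features are selected across runs for each (dataset, compare_group).

    Reads feature_importances.txt and metrics tables from
    input_dir/<dataset>/<compare_group>/<classifier>.<n_features>.<selector>.<resample_method>/,
    builds a feature-importance matrix and a 0/1 indicator matrix (features x n_features),
    plots the indicator matrix as a heatmap and returns (metrics, feature_matrices,
    feature_indicator_matrices).
    """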
    import pandas as pd
    from functools import reduce
    from tqdm import tqdm
    import matplotlib.pyplot as plt
    import seaborn as sns
    from IPython.display import display
pbar = tqdm(unit='directory')
records = []
feature_matrices = {}
#feature_support_matrices = {}
feature_indicator_matrices = {}
for dataset in datasets:
cpm = pd.read_table('output/cpm_matrix/{}.txt'.format(dataset), index_col=0)
for compare_group in os.listdir(os.path.join(input_dir, dataset)):
feature_lists = {}
#feature_supports = {}
for path in os.listdir(os.path.join(input_dir, dataset, compare_group)):
classifier, n_features, selector, resample_method = path.split('.')
if int(n_features) > 10:
continue
if (classifier != 'random_forest') or (selector != 'robust'):
continue
if resample_method != 'stratified_shuffle_split':
continue
record = {
'compare_group': compare_group,
'classifier': classifier,
'n_features': n_features,
'selector': selector,
'resample_method': resample_method,
'dataset': dataset
}
# feature importance
feature_lists[n_features] = pd.read_table(os.path.join(input_dir, dataset, compare_group,
path, 'feature_importances.txt'), header=None, index_col=0).iloc[:, 0]
feature_lists[n_features].index = feature_lists[n_features].index.astype('str')
# feature support
#with h5py.File(os.path.join(input_dir, dataset, compare_group,
# path, 'evaluation.{}.h5'.format(resample_method)), 'r') as f:
# feature_support = np.mean(f['feature_selection'][:], axis=0)
# feature_support = pd.Series(feature_support, index=cpm.index.values)
# feature_support = feature_support[feature_lists[n_features].index.values]
# feature_supports[n_features] = feature_support
# read metrics
metrics = pd.read_table(os.path.join(input_dir, dataset, compare_group,
path, 'metrics.{}.txt'.format(resample_method)))
record['test_roc_auc_mean'] = metrics['test_roc_auc'].mean()
if resample_method == 'leave_one_out':
record['test_roc_auc_std'] = 0
elif resample_method == 'stratified_shuffle_split':
record['test_roc_auc_std'] = metrics['test_roc_auc'].std()
pbar.update(1)
records.append(record)
# feature union set
feature_set = reduce(np.union1d, [a.index.values for a in feature_lists.values()])
# build feature importance matrix
feature_matrix = pd.DataFrame(np.zeros((len(feature_set), len(feature_lists))),
index=feature_set, columns=list(feature_lists.keys()))
for n_features, feature_importance in feature_lists.items():
feature_matrix.loc[feature_importance.index.values, n_features] = feature_importance.values
feature_matrix.columns = feature_matrix.columns.astype('int')
feature_matrix.index = feature_matrix.index.astype('str')
feature_matrix = feature_matrix.loc[:, feature_matrix.columns.sort_values().values]
feature_matrices[(dataset, compare_group)] = feature_matrix
# build feature indicator matrix
feature_indicator_matrix = pd.DataFrame(np.zeros((len(feature_set), len(feature_lists))),
index=feature_set, columns=list(feature_lists.keys()))
for n_features, feature_importance in feature_lists.items():
feature_indicator_matrix.loc[feature_importance.index.values, n_features] = 1
feature_indicator_matrix.columns = feature_indicator_matrix.columns.astype('int')
feature_indicator_matrix = feature_indicator_matrix.loc[:, feature_indicator_matrix.columns.sort_values().values]
feature_indicator_matrices[(dataset, compare_group)] = feature_indicator_matrix
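            # NOTE: feature_fields and the transcript_table_by_* lookup tables used below
            # are assumed to be provided by the calling context (e.g. notebook globals);
            # they are not defined in this module.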
if dataset in feature_fields:
feature_meta = feature_matrix.index.to_series().str.split('|', expand=True)
feature_meta.columns = feature_fields[dataset]
if 'transcript_id' in feature_fields[dataset]:
feature_matrix.insert(
0, 'gene_type',
transcript_table_by_transcript_id.loc[feature_meta['transcript_id'].values, 'gene_type'].values)
feature_matrix.insert(
0, 'gene_name',
transcript_table_by_transcript_id.loc[feature_meta['transcript_id'].values, 'gene_name'].values)
elif 'gene_id' in feature_fields[dataset]:
feature_matrix.insert(
0, 'gene_type',
transcript_table_by_gene_id.loc[feature_meta['gene_id'].values, 'gene_type'].values)
feature_matrix.insert(
0, 'gene_name',
transcript_table_by_gene_id.loc[feature_meta['gene_id'].values, 'gene_name'].values)
elif 'transcript_name' in feature_fields[dataset]:
feature_matrix.insert(
0, 'gene_type',
transcript_table_by_transcript_name.loc[feature_meta['transcript_name'].values, 'gene_type'].values)
feature_matrix.insert(
0, 'gene_name',
transcript_table_by_transcript_name.loc[feature_meta['transcript_name'].values, 'gene_name'].values)
feature_indicator_matrix.index = feature_matrix.loc[:, 'gene_name'].values + '|' + feature_matrix.loc[:, 'gene_type'].values
# build feature support matrix
#feature_support_matrix = pd.DataFrame(np.zeros((len(feature_set), len(feature_lists))),
# index=feature_set, columns=list(feature_lists.keys()))
#for n_features, feature_support in feature_supports.items():
# feature_support_matrix.loc[feature_support.index.values, n_features] = feature_support.values
#feature_support_matrix.columns = feature_support_matrix.columns.astype('int')
#feature_support_matrix = feature_matrix.loc[:, feature_support_matrix.columns.sort_values().values]
#feature_support_matrices[(dataset, compare_group)] = feature_support_matrix
fig, ax = plt.subplots(figsize=(6, 8))
sns.heatmap(feature_indicator_matrix,
cmap=sns.light_palette('green', as_cmap=True), cbar=False, ax=ax, linewidth=1)
ax.set_xlabel('Number of features')
            ax.set_ylabel('Features')
ax.set_title('{}, {}'.format(dataset, compare_group))
display(feature_matrix.style\
.background_gradient(cmap=sns.light_palette('green', as_cmap=True))\
.set_precision(2)\
.set_caption('{}, {}'.format(dataset, compare_group)))
pbar.close()
metrics = pd.DataFrame.from_records(records)
return metrics, feature_matrices, feature_indicator_matrices
def plot_roc_curve_ci(y, is_train, predicted_scores, ax, title=None):
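    """Plot the mean ROC curve with a standard-deviation band across resampling splits.

    y: true binary labels, shape (n_samples,).
    is_train: boolean matrix marking training samples per split, shape (n_splits, n_samples).
    predicted_scores: predicted scores per split, shape (n_splits, n_samples).
    """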
    import seaborn as sns
    # ROC curve
n_splits = is_train.shape[0]
all_fprs = np.linspace(0, 1, 100)
roc_curves = np.zeros((n_splits, len(all_fprs), 2))
roc_aucs = np.zeros(n_splits)
for i in range(n_splits):
fpr, tpr, thresholds = roc_curve(y[~is_train[i]], predicted_scores[i, ~is_train[i]])
roc_aucs[i] = roc_auc_score(y[~is_train[i]], predicted_scores[i, ~is_train[i]])
roc_curves[i, :, 0] = all_fprs
        roc_curves[i, :, 1] = np.interp(all_fprs, fpr, tpr)
roc_curves = pd.DataFrame(roc_curves.reshape((-1, 2)), columns=['fpr', 'tpr'])
    sns.lineplot(x='fpr', y='tpr', data=roc_curves, ci='sd', ax=ax,
                 label='Average AUROC = {:.4f}'.format(roc_aucs.mean()))
#ax.plot(fpr, tpr, label='ROAUC = {:.4f}'.format(roc_auc_score(y_test, y_score[:, 1])))
#ax.plot([0, 1], [0, 1], linestyle='dashed')
ax.set_xlabel('False positive rate')
ax.set_ylabel('True positive rate')
ax.plot([0, 1], [0, 1], linestyle='dashed', color='gray')
if title:
ax.set_title(title)
ax.legend()
def _plot_10_features(input_dir, datasets, use_log=False, scale=False, title=None):
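    """Plot ROC curves with confidence bands for the 10-feature
    random_forest/robust/stratified_shuffle_split runs of each dataset and compare group.

    NOTE: relies on a `groups` mapping (dataset -> sample-class file suffix) that is
    assumed to be provided by the calling context; it is not defined in this module.
    """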
    import h5py
    import matplotlib.pyplot as plt
    from tqdm import tqdm_notebook
    pbar = tqdm_notebook(unit='directory')
for dataset in datasets:
sample_classes = pd.read_table('metadata/sample_classes.{}.txt'.format(groups[dataset]),
header=None, index_col=0).iloc[:, 0]
cpm = pd.read_table('output/cpm_matrix/{}.txt'.format(dataset), index_col=0)
if use_log:
cpm = np.log2(cpm + 0.001)
        if scale:
            X = RobustScaler().fit_transform(cpm.T.values).T
        else:
            X = cpm.values
X = pd.DataFrame(X, index=cpm.index.values, columns=cpm.columns.values)
for compare_group in os.listdir(os.path.join(input_dir, dataset)):
for path in os.listdir(os.path.join(input_dir, dataset, compare_group)):
classifier, n_features, selector, resample_method = path.split('.')
if int(n_features) != 10:
continue
if (classifier != 'random_forest') or (selector != 'robust'):
continue
if resample_method != 'stratified_shuffle_split':
continue
record = {
'compare_group': compare_group,
'classifier': classifier,
'n_features': n_features,
'selector': selector,
'resample_method': resample_method,
'dataset': dataset
}
result_dir = os.path.join(input_dir, dataset, compare_group, path)
                with h5py.File(os.path.join(result_dir, 'evaluation.{}.h5'.format(resample_method)), 'r') as f:
train_index = f['train_index'][:]
predicted_scores = f['predictions'][:]
labels = f['labels'][:]
fig, ax = plt.subplots(figsize=(8, 8))
plot_roc_curve_ci(labels, train_index, predicted_scores, ax,
title='{}, {}'.format(dataset, compare_group))
features = pd.read_table(os.path.join(result_dir, 'features.txt'), header=None).iloc[:, 0].values
pbar.update(1)
pbar.close()
def _evaluate_preprocess_methods(input_dirs, preprocess_methods, title=None):
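    """Compare preprocessing methods by mean test AUROC as a function of the number
    of selected features.

    Each entry of input_dirs holds results for the matching entry of preprocess_methods;
    methods are ranked by their average AUROC rank across feature counts and one figure
    is drawn per compare_group.
    """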
    import matplotlib.pyplot as plt
    from tqdm import tqdm_notebook
    records = []
    pbar = tqdm_notebook(unit='directory')
for preprocess_method, input_dir in zip(preprocess_methods, input_dirs):
for compare_group in os.listdir(input_dir):
for path in os.listdir(os.path.join(input_dir, compare_group)):
classifier, n_features, selector, resample_method = path.split('.')
if int(n_features) > 50:
continue
if (classifier != 'random_forest') or (selector != 'robust'):
continue
if resample_method != 'stratified_shuffle_split':
continue
record = {
'compare_group': compare_group,
'classifier': classifier,
'n_features': n_features,
'selector': selector,
'resample_method': resample_method,
'preprocess_method': preprocess_method
}
metrics = pd.read_table(os.path.join(input_dir, compare_group, path, 'metrics.{}.txt'.format(resample_method)))
record['test_roc_auc_mean'] = metrics['test_roc_auc'].mean()
if resample_method == 'leave_one_out':
record['test_roc_auc_std'] = 0
elif resample_method == 'stratified_shuffle_split':
record['test_roc_auc_std'] = metrics['test_roc_auc'].std()
pbar.update(1)
records.append(record)
pbar.close()
records = pd.DataFrame.from_records(records)
records['n_features'] = records.loc[:, 'n_features'].astype(np.int32)
for compare_group, sub_df in records.groupby('compare_group'):
pivot = sub_df.pivot_table(
index='preprocess_method', columns='n_features', values='test_roc_auc_mean')
#print(pivot.iloc[:, 0])
#print(np.argsort(np.argsort(pivot.values, axis=0), axis=0)[:, 0])
mean_ranks = np.mean(pivot.shape[0] - np.argsort(np.argsort(pivot.values, axis=0), axis=0), axis=1)
mean_ranks = pd.Series(mean_ranks, index=pivot.index.values)
mean_ranks = mean_ranks.sort_values()
rename_index = ['{} (rank = {:.1f})'.format(name, value) for name, value in zip(mean_ranks.index, mean_ranks.values)]
rename_index = pd.Series(rename_index, index=mean_ranks.index.values)
sub_df = sub_df.copy()
sub_df['preprocess_method'] = rename_index[sub_df['preprocess_method'].values].values
sub_df['n_features'] = sub_df['n_features'].astype('int')
sub_df = sub_df.sort_values(['preprocess_method', 'n_features'], ascending=True)
sub_df['n_features'] = sub_df['n_features'].astype('str')
fig, ax = plt.subplots(figsize=(8, 8))
#sns.lineplot('n_features', 'test_roc_auc_mean', hue='preprocess_method', data=sub_df,
# ci=None, ax=ax, markers='o', hue_order=rename_index.values, sort=False)
for preprocess_method in rename_index.values:
tmp_df = sub_df[sub_df['preprocess_method'] == preprocess_method]
ax.plot(np.arange(tmp_df.shape[0]) + 1, tmp_df['test_roc_auc_mean'], label=preprocess_method)
ax.set_xticks(np.arange(tmp_df.shape[0]) + 1)
ax.set_xticklabels(tmp_df['n_features'])
ax.set_xlabel('Number of features')
ax.set_ylabel('Average AUROC')
if len(preprocess_methods) > 1:
ax.legend(title='Preprocess method', bbox_to_anchor=(1.04,0.5),
loc="center left", borderaxespad=0)
ax.set_ylim(0.5, 1)
if title:
ax.set_title(title + ', ' + compare_group)
@command_handler
def evaluate_preprocessing_methods(args):
    _evaluate_preprocess_methods(args.input_dirs, args.preprocess_methods)
def bigwig_fetch(filename, chrom, start, end, dtype='float'):
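    """Return per-base coverage over [start, end) of `chrom` from a bigWig file,
    extracted by piping the output of bigWigToBedGraph; uncovered positions stay 0.
    """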
import subprocess
p = subprocess.Popen(['bigWigToBedGraph', filename, 'stdout',
'-chrom={}'.format(chrom), '-start={}'.format(start), '-end={}'.format(end)],
stdout=subprocess.PIPE)
data = np.zeros(end - start, dtype=dtype)
for line in p.stdout:
line = str(line, encoding='ascii')
c = line.strip().split('\t')
data[(int(c[1]) - start):(int(c[2]) - start)] = float(c[3])
return data
def extract_feature_sequence(feature, genome_dir):
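    """Extract the sequence of a feature named
    gene_id|gene_type|gene_name|domain_id|transcript_id|start|end
    from genome_dir/fasta/<gene_type>.fa, reverse-complementing genomic
    features on the minus strand.
    """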
from pyfaidx import Fasta
from Bio.Seq import Seq
    feature = feature.split('\t')[0]
gene_id, gene_type, gene_name, domain_id, transcript_id, start, end = feature.split('|')
start = int(start)
end = int(end)
if gene_type == 'genomic':
gene_type = 'genome'
    fasta = Fasta(os.path.join(genome_dir, 'fasta', gene_type + '.fa'))
if gene_type == 'genome':
chrom, gstart, gend, strand = gene_id.split('_')
gstart = int(gstart)
gend = int(gend)
seq = fasta[chrom][gstart:gend].seq
if strand == '-':
seq = str(Seq(seq).reverse_complement())
else:
seq = fasta[transcript_id][start:end].seq
    seq = seq.upper()
    return seq
@command_handler
def visualize_domains(args):
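    """Plot per-sample read coverage over each selected domain as a heatmap,
    with sample-class labels, the mean coverage profile over the sequence,
    and the annotated vs. peak-refined domain extent; one PDF page per feature.
    """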
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.gridspec import GridSpec
import seaborn as sns
sns.set_style('white')
import pandas as pd
plt.rcParams['figure.dpi'] = 96
from tqdm import tqdm
from pykent import BigWigFile
from scipy.cluster.hierarchy import linkage, dendrogram
from pyfaidx import Fasta
from Bio.Seq import Seq
from call_peak import call_peaks
# read sample ids
#logger.info('read sample ids: ' + args.sample_ids)
#sample_ids = open(args.sample_ids_file, 'r').read().split()
    logger.info('read sample classes: ' + args.sample_classes)
sample_classes = pd.read_table(args.sample_classes, sep='\t', index_col=0).iloc[:, 0]
sample_classes = sample_classes.sort_values()
sample_ids = sample_classes.index.values
# read features
features = pd.read_table(args.features, header=None).iloc[:, 0]
feature_info = features.str.split('|', expand=True)
feature_info.columns = ['gene_id', 'gene_type', 'gene_name', 'domain_id', 'transcript_id', 'start', 'end']
feature_info.index = features.values
# read count matrix to get read depth
#counts = pd.read_table(args.count_matrix, index_col=0)
#read_depth = counts.sum(axis=0)
#del counts
# read chrom sizes
#chrom_sizes = pd.read_table(args.chrom_sizes, sep='\t', index_col=0, header=None).iloc[:, 0]
with PdfPages(args.output_file) as pdf:
for feature_name, feature in tqdm(feature_info.iterrows(), unit='feature'):
#logger.info('plot feature: {}'.format(feature_name))
if feature['gene_type'] == 'genomic':
chrom, start, end, strand = feature['gene_id'].split('_')
start = int(start)
end = int(end)
bigwig_file = os.path.join(args.output_dir, 'bigwig', '{{0}}.genome.{0}.bigWig'.format(strand))
elif feature['gene_type'] in ('piRNA', 'miRNA'):
continue
else:
start = int(feature['start'])
end = int(feature['end'])
chrom = feature.transcript_id
bigwig_file = os.path.join(args.output_dir, 'tbigwig_normalized', '{0}.transcriptome.bigWig')
# read coverage from BigWig files
coverage = None
for i, sample_id in enumerate(sample_ids):
bwf = BigWigFile(bigwig_file.format(sample_id))
if coverage is None:
# interval to display coverage
chrom_size = bwf.get_chrom_size(feature['transcript_id'])
if chrom_size == 0:
raise ValueError('cannot find transcript id {} in bigwig'.format(feature['transcript_id']))
view_start = max(start - args.flanking, 0)
view_end = min(end + args.flanking, chrom_size)
coverage = np.zeros((len(sample_ids), view_end - view_start), dtype='float')
logger.info('create_coverage_matrix: ({}, {})'.format(*coverage.shape))
#logger.info('bigWigQuery: {}:{}-{}'.format(chrom, view_start, view_end))
values = bwf.query(chrom, view_start, view_end, fillna=0)
del bwf
if values is not None:
coverage[i] = values
#coverage[i] = bigwig_fetch(bigwig_file.format(sample_id), chrom, view_start, view_end, dtype='int')
# normalize coverage by read depth
#coverage[i] *= 1e6/read_depth[sample_id]
# log2 transformation
coverage[i] = np.log2(coverage[i] + 1)
# get sequence
gene_type = feature['gene_type']
if gene_type == 'genomic':
gene_type = 'genome'
fasta = Fasta(os.path.join(args.genome_dir, 'fasta', gene_type + '.fa'))
seq = fasta[feature['transcript_id']][view_start:view_end].seq
if (gene_type == 'genome') and (strand == '-'):
seq = str(Seq(seq).reverse_complement())
seq = seq.upper()
# draw heatmap
'''
plot_data = pd.DataFrame(coverage)
cmap = sns.light_palette('blue', as_cmap=True, n_colors=6)
g = sns.clustermap(plot_data, figsize=(20, 8), col_cluster=False, row_colors=None, cmap='Blues')
g.ax_heatmap.set_yticklabels([])
g.ax_heatmap.set_yticks([])
xticks = np.arange(0, coverage.shape[1], 10)
g.ax_heatmap.set_xticks(xticks)
g.ax_heatmap.set_xticklabels(xticks, rotation=0)
g.ax_heatmap.vlines(x=domain_start, ymin=0, ymax=g.ax_heatmap.get_ylim()[0], linestyle='dashed', linewidth=1.0)
g.ax_heatmap.vlines(x=domain_end, ymin=0, ymax=g.ax_heatmap.get_ylim()[0], linestyle='dashed', linewidth=1.0)
g.ax_heatmap.set_title(feature_name)
'''
# hierarchical clustering
order = np.arange(coverage.shape[0], dtype='int')
for label in np.unique(sample_classes):
mask = (sample_classes == label)
Z = linkage(coverage[mask], 'single')
R = dendrogram(Z, no_plot=True, labels=order[mask])
order[mask] = R['ivl']
            # apply the clustering order to a copy so the original sample_classes
            # (aligned with sample_ids) is not modified across features
            sample_classes_ordered = sample_classes.iloc[order]
            coverage = coverage[order]
plt.rcParams['xtick.minor.visible'] = True
plt.rcParams['xtick.minor.size'] = 4
plt.rcParams['xtick.bottom'] = True
plt.rcParams['xtick.labelsize'] = 8
fig = plt.figure(figsize=(20, 6))
gs = GridSpec(4, 3, figure=fig, width_ratios=[0.95, 0.03, 0.02], height_ratios=[0.6, 0.15, 0.1, 0.1], hspace=0.2, wspace=0.15)
#fig, axes = plt.subplots(1, 2, figsize=(20, 3), sharey=True,
# gridspec_kw={'width_ratios': [0.98, 0.02], 'hspace': 0})
ax_heatmap = plt.subplot(gs[0, 0])
ax_colorbar = plt.subplot(gs[0, 1])
ax_label = plt.subplot(gs[0, 2])
ax_line = plt.subplot(gs[1, 0])
ax_domain = plt.subplot(gs[2, 0])
ax_refined_domain = plt.subplot(gs[3, 0])
p = ax_heatmap.pcolormesh(coverage, cmap='Blues')
ax_heatmap.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
ax_heatmap.set_xlim(0, coverage.shape[1])
xticks = ax_heatmap.get_xticks()
ax_heatmap.set_xticks(xticks + 0.5)
ax_heatmap.set_xticklabels(xticks.astype('int'))
ax_heatmap.set_title(feature_name)
fig.colorbar(p, cax=ax_colorbar, use_gridspec=False, orientation='vertical')
            for label in sample_classes_ordered.unique():
                ax_label.barh(y=np.arange(coverage.shape[0]), width=(sample_classes_ordered == label).astype('int'), height=1,
                              edgecolor='none', label=label)
ax_label.set_xlim(0, 1)
ax_label.set_ylim(0, coverage.shape[0])
ax_label.tick_params(labelbottom=False, bottom=False)
ax_label.set_xticks([])
ax_label.set_yticks([])
ax_label.legend(title='Class', bbox_to_anchor=(1.1, 0.5), loc="center left", borderaxespad=0)
ax_line.fill_between(np.arange(coverage.shape[1]), coverage.mean(axis=0), step='pre', alpha=0.9)
ax_line.set_xlim(0, coverage.shape[1])
#ax_line.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
#xticks = ax_line.get_xticks()
#ax_line.set_xticks(xticks + 0.5)
#ax_line.set_xticklabels(xticks.astype('int'))
ax_line.set_xticks(np.arange(coverage.shape[1]) + 0.5)
ax_line.set_xticks([], minor=True)
ax_line.set_xticklabels(list(seq))
ax_line.set_ylim(0, ax_line.get_ylim()[1])
#ax_line.vlines(x=start - view_start + 0.5, ymin=0, ymax=ax_line.get_ylim()[1], linestyle='dashed', linewidth=1.0)
#ax_line.vlines(x=end - view_start + 0.5, ymin=0, ymax=ax_line.get_ylim()[1], linestyle='dashed', linewidth=1.0)
ax_domain.hlines(y=0.5, xmin=start - view_start, xmax=end - view_start, linewidth=5, color='C0')
ax_domain.set_ylim(0, 1)
ax_domain.set_ylabel('Domain')
ax_domain.set_yticks([])
ax_domain.set_xticks([])
ax_domain.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
ax_domain.set_xlim(0, coverage.shape[1])
ax_domain.spines['top'].set_visible(False)
ax_domain.spines['right'].set_visible(False)
coverage_mean = coverage.mean(axis=0)
for _, peak_start, peak_end in call_peaks([coverage_mean], min_length=10):
ax_refined_domain.hlines(y=0.5, xmin=peak_start, xmax=peak_end, linewidth=5, color='C0')
ax_refined_domain.set_ylim(0, 1)
ax_refined_domain.set_ylabel('Refined')
ax_refined_domain.set_yticks([])
ax_refined_domain.set_xticks([])
ax_refined_domain.set_xticks(np.arange(coverage.shape[1]) + 0.5, minor=True)
ax_refined_domain.set_xlim(0, coverage.shape[1])
ax_refined_domain.spines['top'].set_visible(False)
ax_refined_domain.spines['right'].set_visible(False)
fig.tight_layout()
# save plot
pdf.savefig(fig)
plt.close()
if __name__ == '__main__':
    main_parser = argparse.ArgumentParser(description='Reporting and visualization module')
subparsers = main_parser.add_subparsers(dest='command')
parser = subparsers.add_parser('visualize_domains', help='plot read coverage of domains as heatmaps')
parser.add_argument('--sample-classes', type=str, required=True, help='e.g. {data_dir}/sample_classes.txt')
parser.add_argument('--output-dir', type=str, required=True, help='e.g. output/scirep')
parser.add_argument('--features', type=str, required=True, help='list of selected features')
#parser.add_argument('--count-matrix', type=str, required=True, help='count matrix')
parser.add_argument('--output-file', '-o', type=str, required=True, help='output PDF file')
parser.add_argument('--flanking', type=int, default=20, help='flanking length for genomic domains')
parser.add_argument('--genome-dir', type=str, required=True, help='e.g. genome/hg38')
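    # Subparsers for the other registered handlers. The option names below are
    # inferred from how the handlers read `args` (input_dir/output_file and
    # input_dirs/preprocess_methods) and may need to be adjusted to match the
    # upstream pipeline.
    parser = subparsers.add_parser('compare_feature_selection_params',
        help='summarize AUROC across feature selection parameters')
    parser.add_argument('--input-dir', type=str, required=True,
        help='directory containing <compare_group>/<classifier>.<n_features>.<selector>.<resample_method> results')
    parser.add_argument('--output-file', '-o', type=str, required=True, help='output figure file')
    parser = subparsers.add_parser('evaluate_preprocessing_methods',
        help='compare preprocessing methods by mean AUROC')
    parser.add_argument('--input-dirs', type=str, nargs='+', required=True,
        help='one result directory per preprocessing method')
    parser.add_argument('--preprocess-methods', type=str, nargs='+', required=True,
        help='names of the preprocessing methods, in the same order as --input-dirs')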
args = main_parser.parse_args()
if not args.command:
main_parser.print_help()
sys.exit(1)
logger = logging.getLogger('report.' + args.command)
command_handlers.get(args.command)(args)