--- a +++ b/bpnet/plot/profiles.py @@ -0,0 +1,398 @@ +from tqdm import tqdm +import seaborn as sns +import pandas as pd +import os +import matplotlib.pyplot as plt +import numpy as np +from bpnet.plot.utils import simple_yaxis_format, strip_axis, spaced_xticks +from bpnet.modisco.utils import bootstrap_mean, nan_like, ic_scale +from bpnet.plot.utils import show_figure + + +# TODO - make it as a bar-plot with two standard colors: +# #B23F49 (pos), #045CA8 (neg) +def plot_stranded_profile(profile, ax=None, ymax=None, profile_std=None, flip_neg=True, set_ylim=True): + """Plot the stranded profile + """ + if ax is None: + ax = plt.gca() + + if profile.ndim == 1: + # also compatible with single dim + profile = profile[:, np.newaxis] + assert profile.ndim == 2 + assert profile.shape[1] <= 2 + labels = ['pos', 'neg'] + + # determine ymax if not specified + if ymax is None: + if profile_std is not None: + ymax = (profile.max() - 2 * profile_std).max() + else: + ymax = profile.max() + + if set_ylim: + if flip_neg: + ax.set_ylim([-ymax, ymax]) + else: + ax.set_ylim([0, ymax]) + + ax.axhline(y=0, linewidth=1, linestyle='--', color='black') + # strip_axis(ax) + + xvec = np.arange(1, len(profile) + 1) + + for i in range(profile.shape[1]): + sign = 1 if not flip_neg or i == 0 else -1 + ax.plot(xvec, sign * profile[:, i], label=labels[i]) + + # plot also the ribbons + if profile_std is not None: + ax.fill_between(xvec, + sign * profile[:, i] - 2 * profile_std[:, i], + sign * profile[:, i] + 2 * profile_std[:, i], + alpha=0.1) + # return ax + + +def multiple_plot_stranded_profile(d_profile, figsize_tmpl=(4, 3), normalize=False): + fig, axes = plt.subplots(1, len(d_profile), + figsize=(figsize_tmpl[0] * len(d_profile), figsize_tmpl[1]), + sharey=True) + if len(d_profile)==1: #If only one task, then can't zip axes + ax = axes + task = [*d_profile][0] + arr = d_profile[task].mean(axis=0) + if normalize: + arr = arr / arr.max() + plot_stranded_profile(arr, ax=ax, set_ylim=False) + ax.set_title(task) + ax.set_ylabel("Avg. counts") + ax.set_xlabel("Position") + fig.subplots_adjust(wspace=0) # no space between plots + return fig + else: + for i, (task, ax) in enumerate(zip(d_profile, axes)): + arr = d_profile[task].mean(axis=0) + if normalize: + arr = arr / arr.max() + plot_stranded_profile(arr, ax=ax, set_ylim=False) + ax.set_title(task) + if i == 0: + ax.set_ylabel("Avg. counts") + ax.set_xlabel("Position") + fig.subplots_adjust(wspace=0) # no space between plots + return fig + + +def aggregate_profiles(profile_arr, n_bootstrap=None, only_idx=None): + if only_idx is not None: + return profile_arr[only_idx], None + + if n_bootstrap is not None: + return bootstrap_mean(profile_arr, n=n_bootstrap) + else: + return profile_arr.mean(axis=0), None + + +def extract_signal(x, seqlets, rc_fn=lambda x: x[::-1, ::-1]): + def optional_rc(x, is_rc): + if is_rc: + return rc_fn(x) + else: + return x + return np.stack([optional_rc(x[s['example'], s['start']:s['end']], s['rc']) + for s in seqlets]) + + +def plot_profiles(seqlets_by_pattern, + x, + tracks, + contribution_scores={}, + figsize=(20, 2), + start_vec=None, + width=20, + legend=True, + rotate_y=90, + seq_height=1, + ymax=None, # determine y-max + n_limit=35, + n_bootstrap=None, + flip_neg=False, + patterns=None, + fpath_template=None, + only_idx=None, + mkdir=False, + rc_fn=lambda x: x[::-1, ::-1]): + """ + Plot the sequence profiles + Args: + x: one-hot-encoded sequence + tracks: dictionary of profile tracks + contribution_scores: optional dictionary of contribution scores + + """ + import matplotlib.pyplot as plt + from concise.utils.plot import seqlogo_fig, seqlogo + + # Setup start-vec + if start_vec is not None: + if not isinstance(start_vec, list): + start_vec = [start_vec] * len(patterns) + else: + start_vec = [0] * len(patterns) + width = len(x) + + if patterns is None: + patterns = list(seqlets_by_pattern) + # aggregated profiles + d_signal_patterns = {pattern: + {k: aggregate_profiles( + extract_signal(y, seqlets_by_pattern[pattern])[:, start_vec[ip]:(start_vec[ip] + width)], + n_bootstrap=n_bootstrap, only_idx=only_idx) + for k, y in tracks.items()} + for ip, pattern in enumerate(patterns)} + if ymax is None: + # infer ymax + def take_max(x, dx): + if dx is None: + return x.max() + else: + # HACK - hard-coded 2 + return (x + 2 * dx).max() + + ymax = [max([take_max(*d_signal_patterns[pattern][k]) + for pattern in patterns]) + for k in tracks] # loop through all the tracks + if not isinstance(ymax, list): + ymax = [ymax] * len(tracks) + + figs = [] + for i, pattern in enumerate(tqdm(patterns)): + j = i + # -------------- + # extract signal + seqs = extract_signal(x, seqlets_by_pattern[pattern])[:, start_vec[i]:(start_vec[i] + width)] + ext_contribution_scores = {s: extract_signal(contrib, seqlets_by_pattern[pattern])[:, start_vec[i]:(start_vec[i] + width)] + for s, contrib in contribution_scores.items()} + d_signal = d_signal_patterns[pattern] + # -------------- + if only_idx is None: + sequence = ic_scale(seqs.mean(axis=0)) + else: + sequence = seqs[only_idx] + + n = len(seqs) + if n < n_limit: + continue + fig, ax = plt.subplots(1 + len(contribution_scores) + len(tracks), + 1, sharex=True, + figsize=figsize, + gridspec_kw={'height_ratios': [1] * len(tracks) + [seq_height] * (1 + len(contribution_scores))}) + + # signal + ax[0].set_title(f"{pattern} ({n})") + for i, (k, signal) in enumerate(d_signal.items()): + signal_mean, signal_std = d_signal_patterns[pattern][k] + plot_stranded_profile(signal_mean, ax=ax[i], ymax=ymax[i], + profile_std=signal_std, flip_neg=flip_neg) + simple_yaxis_format(ax[i]) + strip_axis(ax[i]) + ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5) + + if legend: + ax[i].legend() + + # ----------- + # contribution scores (seqlogo) + # ----------- + # average the contribution scores + if only_idx is None: + norm_contribution_scores = {k: v.mean(axis=0) + for k, v in ext_contribution_scores.items()} + else: + norm_contribution_scores = {k: v[only_idx] + for k, v in ext_contribution_scores.items()} + + max_scale = max([np.maximum(v, 0).sum(axis=-1).max() for v in norm_contribution_scores.values()]) + min_scale = min([np.minimum(v, 0).sum(axis=-1).min() for v in norm_contribution_scores.values()]) + for k, (contrib_score_name, logo) in enumerate(norm_contribution_scores.items()): + ax_id = len(tracks) + k + + # Trim the pattern if necessary + # plot + ax[ax_id].set_ylim([min_scale, max_scale]) + ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey') + seqlogo(logo, ax=ax[ax_id]) + + # style + simple_yaxis_format(ax[ax_id]) + strip_axis(ax[ax_id]) + # ax[ax_id].set_ylabel(contrib_score_name) + ax[ax_id].set_ylabel(contrib_score_name, rotation=rotate_y, ha='right', labelpad=5) # va='bottom', + + # ----------- + # information content (seqlogo) + # ----------- + # plot + seqlogo(sequence, ax=ax[-1]) + + # style + simple_yaxis_format(ax[-1]) + strip_axis(ax[-1]) + ax[-1].set_ylabel("Inf. content", rotation=rotate_y, ha='right', labelpad=5) + ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5))) + + figs.append(fig) + # save to file + if fpath_template is not None: + pname = pattern.replace("/", ".") + basepath = fpath_template.format(pname=pname, pattern=pattern) + if mkdir: + os.makedirs(os.path.dirname(basepath), exist_ok=True) + plt.savefig(basepath + '.png', dpi=600) + plt.savefig(basepath + '.pdf', dpi=600) + plt.close(fig) # close the figure + show_figure(fig) + plt.show() + return figs + + +def plot_profiles_single(seqlet, + x, + tracks, + contribution_scores={}, + figsize=(20, 2), + legend=True, + rotate_y=90, + seq_height=1, + flip_neg=False, + rc_fn=lambda x: x[::-1, ::-1]): + """ + Plot the sequence profiles + Args: + x: one-hot-encoded sequence + tracks: dictionary of profile tracks + contribution_scores: optional dictionary of contribution scores + + """ + import matplotlib.pyplot as plt + from concise.utils.plot import seqlogo_fig, seqlogo + + # -------------- + # extract signal + seq = seqlet.extract(x) + ext_contribution_scores = {s: seqlet.extract(contrib) for s, contrib in contribution_scores.items()} + + fig, ax = plt.subplots(1 + len(contribution_scores) + len(tracks), + 1, sharex=True, + figsize=figsize, + gridspec_kw={'height_ratios': [1] * len(tracks) + [seq_height] * (1 + len(contribution_scores))}) + + # signal + for i, (k, signal) in enumerate(tracks.items()): + plot_stranded_profile(seqlet.extract(signal), ax=ax[i], + flip_neg=flip_neg) + simple_yaxis_format(ax[i]) + strip_axis(ax[i]) + ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5) + + if legend: + ax[i].legend() + + # ----------- + # contribution scores (seqlogo) + # ----------- + max_scale = max([np.maximum(v, 0).sum(axis=-1).max() for v in ext_contribution_scores.values()]) + min_scale = min([np.minimum(v, 0).sum(axis=-1).min() for v in ext_contribution_scores.values()]) + for k, (contrib_score_name, logo) in enumerate(ext_contribution_scores.items()): + ax_id = len(tracks) + k + # plot + ax[ax_id].set_ylim([min_scale, max_scale]) + ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey') + seqlogo(logo, ax=ax[ax_id]) + + # style + simple_yaxis_format(ax[ax_id]) + strip_axis(ax[ax_id]) + # ax[ax_id].set_ylabel(contrib_score_name) + ax[ax_id].set_ylabel(contrib_score_name, rotation=rotate_y, ha='right', labelpad=5) # va='bottom', + + # ----------- + # information content (seqlogo) + # ----------- + # plot + seqlogo(seq, ax=ax[-1]) + + # style + simple_yaxis_format(ax[-1]) + strip_axis(ax[-1]) + ax[-1].set_ylabel("Inf. content", rotation=rotate_y, ha='right', labelpad=5) + ax[-1].set_xticks(list(range(0, len(seq) + 1, 5))) + return fig + + +def hist_position(dfp, tasks): + """Make the positional histogram + + Args: + dfp: pd.DataFrame with columns: peak_id, and center + tasks: list of tasks for which to plot the different peak_id columns + """ + fig, axes = plt.subplots(1, len(tasks), figsize=(5 * len(tasks), 2), + sharey=True, sharex=True) + if len(tasks) == 1: + axes = [axes] + for i, (task, ax) in enumerate(zip(tasks, axes)): + ax.hist(dfp[dfp.peak_id == task].center, bins=100) + ax.set_title(task) + ax.set_xlabel("Position") + ax.set_xlim([0, 1000]) + if i == 0: + ax.set_ylabel("Frequency") + plt.subplots_adjust(wspace=0) + return fig + + +def bar_seqlets_per_example(dfp, tasks): + """Make the positional histogram + + Args: + dfp: pd.DataFrame with columns: peak_id, and center + tasks: list of tasks for which to plot the different peak_id columns + """ + fig, axes = plt.subplots(1, len(tasks), figsize=(5 * len(tasks), 2), + sharey=True, sharex=True) + if len(tasks) == 1: + axes = [axes] + for i, (task, ax) in enumerate(zip(tasks, axes)): + dfpp = dfp[dfp.peak_id == task] + ax.set_title(task) + ax.set_xlabel("Frequency") + if i == 0: + ax.set_ylabel("# per example") + if not len(dfpp): + continue + dfpp.groupby("example_idx").\ + size().value_counts().plot(kind="barh", ax=ax) + plt.subplots_adjust(wspace=0) + return fig + + +def box_counts(total_counts, pattern_idx): + """Make a box-plot with total counts in the region + + Args: + total_counts: dict per task + pattern_idx: array with example_idx of the pattern + """ + dfs = pd.concat([total_counts.melt().assign(subset="all peaks"), + total_counts.iloc[pattern_idx].melt().assign(subset="contains pattern")]) + dfs.value = np.log10(1 + dfs.value) + + fig, ax = plt.subplots(figsize=(5, 5)) + sns.boxplot("variable", "value", hue="subset", data=dfs, ax=ax) + ax.set_xlabel("Task") + ax.set_ylabel("log10(1+counts)") + ax.set_title("Total number of counts in the region") + return fig