Diff of /bpnet/plot/profiles.py [000000] .. [d45a3a]

Switch to side-by-side view

--- a
+++ b/bpnet/plot/profiles.py
@@ -0,0 +1,398 @@
+from tqdm import tqdm
+import seaborn as sns
+import pandas as pd
+import os
+import matplotlib.pyplot as plt
+import numpy as np
+from bpnet.plot.utils import simple_yaxis_format, strip_axis, spaced_xticks
+from bpnet.modisco.utils import bootstrap_mean, nan_like, ic_scale
+from bpnet.plot.utils import show_figure
+
+
+# TODO - make it as a bar-plot with two standard colors:
+# #B23F49 (pos), #045CA8 (neg)
+def plot_stranded_profile(profile, ax=None, ymax=None, profile_std=None, flip_neg=True, set_ylim=True):
+    """Plot the stranded profile
+    """
+    if ax is None:
+        ax = plt.gca()
+
+    if profile.ndim == 1:
+        # also compatible with single dim
+        profile = profile[:, np.newaxis]
+    assert profile.ndim == 2
+    assert profile.shape[1] <= 2
+    labels = ['pos', 'neg']
+
+    # determine ymax if not specified
+    if ymax is None:
+        if profile_std is not None:
+            ymax = (profile.max() - 2 * profile_std).max()
+        else:
+            ymax = profile.max()
+
+    if set_ylim:
+        if flip_neg:
+            ax.set_ylim([-ymax, ymax])
+        else:
+            ax.set_ylim([0, ymax])
+
+    ax.axhline(y=0, linewidth=1, linestyle='--', color='black')
+    # strip_axis(ax)
+
+    xvec = np.arange(1, len(profile) + 1)
+
+    for i in range(profile.shape[1]):
+        sign = 1 if not flip_neg or i == 0 else -1
+        ax.plot(xvec, sign * profile[:, i], label=labels[i])
+
+        # plot also the ribbons
+        if profile_std is not None:
+            ax.fill_between(xvec,
+                            sign * profile[:, i] - 2 * profile_std[:, i],
+                            sign * profile[:, i] + 2 * profile_std[:, i],
+                            alpha=0.1)
+    # return ax
+
+
+def multiple_plot_stranded_profile(d_profile, figsize_tmpl=(4, 3), normalize=False):
+    fig, axes = plt.subplots(1, len(d_profile),
+                             figsize=(figsize_tmpl[0] * len(d_profile), figsize_tmpl[1]),
+                             sharey=True)
+    if len(d_profile)==1: #If only one task, then can't zip axes
+        ax = axes
+        task = [*d_profile][0]
+        arr = d_profile[task].mean(axis=0)
+        if normalize:
+            arr = arr / arr.max()
+        plot_stranded_profile(arr, ax=ax, set_ylim=False)
+        ax.set_title(task)
+        ax.set_ylabel("Avg. counts")
+        ax.set_xlabel("Position")
+        fig.subplots_adjust(wspace=0)  # no space between plots
+        return fig
+    else:
+        for i, (task, ax) in enumerate(zip(d_profile, axes)):
+            arr = d_profile[task].mean(axis=0)
+            if normalize:
+                arr = arr / arr.max()
+            plot_stranded_profile(arr, ax=ax, set_ylim=False)
+            ax.set_title(task)
+            if i == 0:
+                ax.set_ylabel("Avg. counts")
+                ax.set_xlabel("Position")
+        fig.subplots_adjust(wspace=0)  # no space between plots
+        return fig
+
+
+def aggregate_profiles(profile_arr, n_bootstrap=None, only_idx=None):
+    if only_idx is not None:
+        return profile_arr[only_idx], None
+
+    if n_bootstrap is not None:
+        return bootstrap_mean(profile_arr, n=n_bootstrap)
+    else:
+        return profile_arr.mean(axis=0), None
+
+
+def extract_signal(x, seqlets, rc_fn=lambda x: x[::-1, ::-1]):
+    def optional_rc(x, is_rc):
+        if is_rc:
+            return rc_fn(x)
+        else:
+            return x
+    return np.stack([optional_rc(x[s['example'], s['start']:s['end']], s['rc'])
+                     for s in seqlets])
+
+
+def plot_profiles(seqlets_by_pattern,
+                  x,
+                  tracks,
+                  contribution_scores={},
+                  figsize=(20, 2),
+                  start_vec=None,
+                  width=20,
+                  legend=True,
+                  rotate_y=90,
+                  seq_height=1,
+                  ymax=None,  # determine y-max
+                  n_limit=35,
+                  n_bootstrap=None,
+                  flip_neg=False,
+                  patterns=None,
+                  fpath_template=None,
+                  only_idx=None,
+                  mkdir=False,
+                  rc_fn=lambda x: x[::-1, ::-1]):
+    """
+    Plot the sequence profiles
+    Args:
+      x: one-hot-encoded sequence
+      tracks: dictionary of profile tracks
+      contribution_scores: optional dictionary of contribution scores
+
+    """
+    import matplotlib.pyplot as plt
+    from concise.utils.plot import seqlogo_fig, seqlogo
+
+    # Setup start-vec
+    if start_vec is not None:
+        if not isinstance(start_vec, list):
+            start_vec = [start_vec] * len(patterns)
+    else:
+        start_vec = [0] * len(patterns)
+        width = len(x)
+
+    if patterns is None:
+        patterns = list(seqlets_by_pattern)
+    # aggregated profiles
+    d_signal_patterns = {pattern:
+                         {k: aggregate_profiles(
+                             extract_signal(y, seqlets_by_pattern[pattern])[:, start_vec[ip]:(start_vec[ip] + width)],
+                             n_bootstrap=n_bootstrap, only_idx=only_idx)
+                          for k, y in tracks.items()}
+                         for ip, pattern in enumerate(patterns)}
+    if ymax is None:
+        # infer ymax
+        def take_max(x, dx):
+            if dx is None:
+                return x.max()
+            else:
+                # HACK - hard-coded 2
+                return (x + 2 * dx).max()
+
+        ymax = [max([take_max(*d_signal_patterns[pattern][k])
+                     for pattern in patterns])
+                for k in tracks]  # loop through all the tracks
+    if not isinstance(ymax, list):
+        ymax = [ymax] * len(tracks)
+
+    figs = []
+    for i, pattern in enumerate(tqdm(patterns)):
+        j = i
+        # --------------
+        # extract signal
+        seqs = extract_signal(x, seqlets_by_pattern[pattern])[:, start_vec[i]:(start_vec[i] + width)]
+        ext_contribution_scores = {s: extract_signal(contrib, seqlets_by_pattern[pattern])[:, start_vec[i]:(start_vec[i] + width)]
+                                   for s, contrib in contribution_scores.items()}
+        d_signal = d_signal_patterns[pattern]
+        # --------------
+        if only_idx is None:
+            sequence = ic_scale(seqs.mean(axis=0))
+        else:
+            sequence = seqs[only_idx]
+
+        n = len(seqs)
+        if n < n_limit:
+            continue
+        fig, ax = plt.subplots(1 + len(contribution_scores) + len(tracks),
+                               1, sharex=True,
+                               figsize=figsize,
+                               gridspec_kw={'height_ratios': [1] * len(tracks) + [seq_height] * (1 + len(contribution_scores))})
+
+        # signal
+        ax[0].set_title(f"{pattern} ({n})")
+        for i, (k, signal) in enumerate(d_signal.items()):
+            signal_mean, signal_std = d_signal_patterns[pattern][k]
+            plot_stranded_profile(signal_mean, ax=ax[i], ymax=ymax[i],
+                                  profile_std=signal_std, flip_neg=flip_neg)
+            simple_yaxis_format(ax[i])
+            strip_axis(ax[i])
+            ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5)
+
+            if legend:
+                ax[i].legend()
+
+        # -----------
+        # contribution scores (seqlogo)
+        # -----------
+        # average the contribution scores
+        if only_idx is None:
+            norm_contribution_scores = {k: v.mean(axis=0)
+                                        for k, v in ext_contribution_scores.items()}
+        else:
+            norm_contribution_scores = {k: v[only_idx]
+                                        for k, v in ext_contribution_scores.items()}
+
+        max_scale = max([np.maximum(v, 0).sum(axis=-1).max() for v in norm_contribution_scores.values()])
+        min_scale = min([np.minimum(v, 0).sum(axis=-1).min() for v in norm_contribution_scores.values()])
+        for k, (contrib_score_name, logo) in enumerate(norm_contribution_scores.items()):
+            ax_id = len(tracks) + k
+
+            # Trim the pattern if necessary
+            # plot
+            ax[ax_id].set_ylim([min_scale, max_scale])
+            ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey')
+            seqlogo(logo, ax=ax[ax_id])
+
+            # style
+            simple_yaxis_format(ax[ax_id])
+            strip_axis(ax[ax_id])
+            # ax[ax_id].set_ylabel(contrib_score_name)
+            ax[ax_id].set_ylabel(contrib_score_name, rotation=rotate_y, ha='right', labelpad=5)  # va='bottom',
+
+        # -----------
+        # information content (seqlogo)
+        # -----------
+        # plot
+        seqlogo(sequence, ax=ax[-1])
+
+        # style
+        simple_yaxis_format(ax[-1])
+        strip_axis(ax[-1])
+        ax[-1].set_ylabel("Inf. content", rotation=rotate_y, ha='right', labelpad=5)
+        ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))
+
+        figs.append(fig)
+        # save to file
+        if fpath_template is not None:
+            pname = pattern.replace("/", ".")
+            basepath = fpath_template.format(pname=pname, pattern=pattern)
+            if mkdir:
+                os.makedirs(os.path.dirname(basepath), exist_ok=True)
+            plt.savefig(basepath + '.png', dpi=600)
+            plt.savefig(basepath + '.pdf', dpi=600)
+            plt.close(fig)    # close the figure
+            show_figure(fig)
+            plt.show()
+    return figs
+
+
+def plot_profiles_single(seqlet,
+                         x,
+                         tracks,
+                         contribution_scores={},
+                         figsize=(20, 2),
+                         legend=True,
+                         rotate_y=90,
+                         seq_height=1,
+                         flip_neg=False,
+                         rc_fn=lambda x: x[::-1, ::-1]):
+    """
+    Plot the sequence profiles
+    Args:
+      x: one-hot-encoded sequence
+      tracks: dictionary of profile tracks
+      contribution_scores: optional dictionary of contribution scores
+
+    """
+    import matplotlib.pyplot as plt
+    from concise.utils.plot import seqlogo_fig, seqlogo
+
+    # --------------
+    # extract signal
+    seq = seqlet.extract(x)
+    ext_contribution_scores = {s: seqlet.extract(contrib) for s, contrib in contribution_scores.items()}
+
+    fig, ax = plt.subplots(1 + len(contribution_scores) + len(tracks),
+                           1, sharex=True,
+                           figsize=figsize,
+                           gridspec_kw={'height_ratios': [1] * len(tracks) + [seq_height] * (1 + len(contribution_scores))})
+
+    # signal
+    for i, (k, signal) in enumerate(tracks.items()):
+        plot_stranded_profile(seqlet.extract(signal), ax=ax[i],
+                              flip_neg=flip_neg)
+        simple_yaxis_format(ax[i])
+        strip_axis(ax[i])
+        ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5)
+
+        if legend:
+            ax[i].legend()
+
+    # -----------
+    # contribution scores (seqlogo)
+    # -----------
+    max_scale = max([np.maximum(v, 0).sum(axis=-1).max() for v in ext_contribution_scores.values()])
+    min_scale = min([np.minimum(v, 0).sum(axis=-1).min() for v in ext_contribution_scores.values()])
+    for k, (contrib_score_name, logo) in enumerate(ext_contribution_scores.items()):
+        ax_id = len(tracks) + k
+        # plot
+        ax[ax_id].set_ylim([min_scale, max_scale])
+        ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey')
+        seqlogo(logo, ax=ax[ax_id])
+
+        # style
+        simple_yaxis_format(ax[ax_id])
+        strip_axis(ax[ax_id])
+        # ax[ax_id].set_ylabel(contrib_score_name)
+        ax[ax_id].set_ylabel(contrib_score_name, rotation=rotate_y, ha='right', labelpad=5)  # va='bottom',
+
+    # -----------
+    # information content (seqlogo)
+    # -----------
+    # plot
+    seqlogo(seq, ax=ax[-1])
+
+    # style
+    simple_yaxis_format(ax[-1])
+    strip_axis(ax[-1])
+    ax[-1].set_ylabel("Inf. content", rotation=rotate_y, ha='right', labelpad=5)
+    ax[-1].set_xticks(list(range(0, len(seq) + 1, 5)))
+    return fig
+
+
+def hist_position(dfp, tasks):
+    """Make the positional histogram
+
+    Args:
+      dfp: pd.DataFrame with columns: peak_id, and center
+      tasks: list of tasks for which to plot the different peak_id columns
+    """
+    fig, axes = plt.subplots(1, len(tasks), figsize=(5 * len(tasks), 2),
+                             sharey=True, sharex=True)
+    if len(tasks) == 1:
+        axes = [axes]
+    for i, (task, ax) in enumerate(zip(tasks, axes)):
+        ax.hist(dfp[dfp.peak_id == task].center, bins=100)
+        ax.set_title(task)
+        ax.set_xlabel("Position")
+        ax.set_xlim([0, 1000])
+        if i == 0:
+            ax.set_ylabel("Frequency")
+    plt.subplots_adjust(wspace=0)
+    return fig
+
+
+def bar_seqlets_per_example(dfp, tasks):
+    """Make the positional histogram
+
+    Args:
+      dfp: pd.DataFrame with columns: peak_id, and center
+      tasks: list of tasks for which to plot the different peak_id columns
+    """
+    fig, axes = plt.subplots(1, len(tasks), figsize=(5 * len(tasks), 2),
+                             sharey=True, sharex=True)
+    if len(tasks) == 1:
+        axes = [axes]
+    for i, (task, ax) in enumerate(zip(tasks, axes)):
+        dfpp = dfp[dfp.peak_id == task]
+        ax.set_title(task)
+        ax.set_xlabel("Frequency")
+        if i == 0:
+            ax.set_ylabel("# per example")
+        if not len(dfpp):
+            continue
+        dfpp.groupby("example_idx").\
+            size().value_counts().plot(kind="barh", ax=ax)
+    plt.subplots_adjust(wspace=0)
+    return fig
+
+
+def box_counts(total_counts, pattern_idx):
+    """Make a box-plot with total counts in the region
+
+    Args:
+      total_counts: dict per task
+      pattern_idx: array with example_idx of the pattern
+    """
+    dfs = pd.concat([total_counts.melt().assign(subset="all peaks"),
+                     total_counts.iloc[pattern_idx].melt().assign(subset="contains pattern")])
+    dfs.value = np.log10(1 + dfs.value)
+
+    fig, ax = plt.subplots(figsize=(5, 5))
+    sns.boxplot("variable", "value", hue="subset", data=dfs, ax=ax)
+    ax.set_xlabel("Task")
+    ax.set_ylabel("log10(1+counts)")
+    ax.set_title("Total number of counts in the region")
+    return fig