Diff of /bpnet/plot/vdom.py [000000] .. [d45a3a]

--- a
+++ b/bpnet/plot/vdom.py
@@ -0,0 +1,449 @@
+"""Vdom visualization for modisco
+"""
+import pandas as pd
+import numpy as np
+from collections import OrderedDict
+from bpnet.plot.heatmaps import multiple_heatmap_stranded_profile, multiple_heatmap_contribution_profile, heatmap_sequence
+from bpnet.cli.contrib import ContribFile
+from bpnet.plot.profiles import extract_signal, multiple_plot_stranded_profile, hist_position, bar_seqlets_per_example, box_counts
+from vdom.helpers import (h1, p, li, img, div, b, br, ul,
+                          details, summary,
+                          table, thead, th, tr, tbody, td, ol)
+import io
+import base64
+import urllib
+import matplotlib.pyplot as plt
+import os
+
+
+def fig2vdom(fig, **kwargs):
+    """Convert a matplotlib figure to an inline, base64-encoded vdom image
+    """
+    buf = io.BytesIO()
+    fig.savefig(buf, format='png', bbox_inches='tight')
+    buf.seek(0)
+    string = base64.b64encode(buf.read())
+    plt.close(fig)
+    return img(src='data:image/png;base64,' + urllib.parse.quote(string), **kwargs)
+
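+# Example usage of `fig2vdom` (illustrative sketch): embed any matplotlib
+# figure as an inline <img> in a vdom tree, e.g. inside a Jupyter notebook:
+#
+#   fig, ax = plt.subplots()
+#   ax.plot([0, 1, 2], [0, 1, 4])
+#   div(p("example figure"), fig2vdom(fig, width=300))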
+
+def vdom_pssm(pssm, letter_width=0.2, letter_height=0.8, **kwargs):
+    """Plot the PSSM as a sequence logo and return it as a vdom image
+    """
+    from concise.utils.plot import seqlogo
+    fig, ax = plt.subplots(figsize=(letter_width * len(pssm), letter_height))
+    ax.axison = False
+    seqlogo(pssm, ax=ax)
+    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+    return fig2vdom(fig, **kwargs)
+
+
+def vdom_footprint(arr, r_height=None, text=None,
+                   fontsize=32, figsize=(3, 1), **kwargs):
+    """Plot the sparkline for the footprint
+
+    Args:
+      arr: np.array of shape (seq_len, 2)
+      r_height: if not None, add a rectangle with height = r_height
+      text: add additional text to top right corner
+      fontsize: size of the additional font
+      figsize: figure size
+      **kwargs: additional kwargs passed to `fig2vdom`
+
+    Returns:
+      VDOM object containing the image
+    """
+    import matplotlib.patches as patches
+
+    fig, ax = plt.subplots(figsize=figsize)
+    if r_height is not None:
+        rect = patches.Rectangle((0, 0), len(arr),
+                                 r_height,  # / arr.max(),
+                                 linewidth=1,
+                                 edgecolor=None,
+                                 alpha=0.3,
+                                 facecolor='lightgrey')
+        ax.add_patch(rect)
+        ax.set_ylim([0, max(r_height, arr.max())])
+        ax.axhline(r_height, alpha=0.3, color='black', linestyle='dashed')
+    ax.plot(arr[:, 0])
+    ax.plot(arr[:, 1])
+
+    if text is not None:
+        # annotate text in the top-right corner
+        ax.text(1, 1, text,
+                fontsize=fontsize,
+                transform=ax.transAxes,
+                verticalalignment='top',
+                horizontalalignment='right')
+
+    ax.axison = False
+    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+    return fig2vdom(fig, **kwargs)
+
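+# Example usage of `vdom_footprint` (illustrative sketch with random data):
+#
+#   arr = np.abs(np.random.randn(200, 2)).cumsum(axis=0) / 50  # (seq_len, 2)
+#   vdom_footprint(arr, r_height=0.8 * arr.max(), text="A")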
+
+def template_vdom_pattern(name, n_seqlets, trimmed_motif,
+                          full_motif, figures_url, add_plots={}, metacluster=""):
+
+    return details(summary(name, f": # seqlets: {n_seqlets}",
+                           # br(),
+                           trimmed_motif),  # ", rc: ",  motif_rc),
+                   details(summary("Aggregated profiles and contribution scores"),
+                           img(src=figures_url + "/agg_profile_contribcores.png", width=840),
+                           ),
+                   details(summary("Aggregated hypothetical contribution scores"),
+                           img(src=figures_url + "/agg_profile_hypcontribscores.png", width=840),
+                           ),
+                   details(summary("Sequence"),
+                           full_motif,
+                           br(),
+                           img(src=figures_url + "/heatmap_seq.png", width=840 // 2),
+                           ),
+                   details(summary("ChIP-nexus counts"),
+                           img(src=figures_url + "/profile_aggregated.png", width=840),
+                           img(src=figures_url + "/profile_heatmap.png", width=840),
+                           ),
+                   details(summary("Contribution scores (profile)"),
+                           img(src=figures_url + "/contrib_profile.png", width=840),
+                           ),
+                   details(summary("Contribution scores (counts)"),
+                           img(src=figures_url + "/contrib_counts.png", width=840),
+                           ),
+                   *[details(summary(k), *v) for k, v in add_plots.items()],
+                   id=metacluster + "/" + name
+                   )
+
+
+def vdom_pattern(mr, metacluster, pattern,
+                 figdir,
+                 total_counts,
+                 dfp,
+                 trim_frac=0.05,
+                 letter_width=0.2, height=0.8):
+
+    # get the trimmed motifs
+    trimmed_motif = vdom_pssm(mr.get_pssm(metacluster + '/' + pattern,
+                                          rc=False, trim_frac=trim_frac),
+                              letter_width=letter_width,
+                              letter_height=height)
+    full_motif = vdom_pssm(mr.get_pssm(metacluster + '/' + pattern,
+                                       rc=False, trim_frac=0),
+                           letter_width=letter_width,
+                           letter_height=height)
+
+    # ----------------
+    # add new plots here
+    dfpp = dfp[dfp.pattern == (metacluster + "/" + pattern)]
+    tasks = dfp.peak_id.unique()
+    pattern_idx = dfpp.example_idx.unique()
+    add_plots = OrderedDict([
+        ("Positional distribution",
+         [fig2vdom(hist_position(dfpp, tasks=tasks)),
+          fig2vdom(bar_seqlets_per_example(dfpp, tasks=tasks))
+          ]),
+        ("Total count distribution",
+         [p(f"Pattern occurs in {len(pattern_idx)} / {len(total_counts)} regions"
+            f" ({100*len(pattern_idx)/len(total_counts):.1f}%)"),
+          fig2vdom(box_counts(total_counts, pattern_idx))]
+         )
+    ])
+    # ----------------
+
+    return template_vdom_pattern(name=pattern,
+                                 n_seqlets=mr.n_seqlets(metacluster + "/" + pattern),
+                                 trimmed_motif=trimmed_motif,
+                                 full_motif=full_motif,
+                                 figures_url=os.path.join(figdir, f"{metacluster}/{pattern}"),
+                                 add_plots=add_plots,
+                                 metacluster=metacluster,
+                                 )
+
+
+def template_vdom_metacluster(name, n_patterns, n_seqlets, important_for, patterns, is_open=False):
+    return details(summary(b(name), f", # patterns: {n_patterns},"
+                           f" # seqlets: {n_seqlets}, "
+                           "important for: ", b(important_for)),
+                   ul([li(pattern) for pattern in patterns]),
+                   id=name,
+                   open=is_open)
+
+
+def vdom_metacluster(mr, metacluster, figdir, total_counts, dfp=None, is_open=True,
+                     **kwargs):
+    patterns = mr.pattern_names(metacluster)
+    n_seqlets = sum([mr.n_seqlets(metacluster + "/" + pattern)
+                     for pattern in patterns])
+    n_patterns = len(patterns)
+
+    def render_act(task, act):
+        """Render the activity vector
+        """
+        task = task.replace("/weighted", "").replace("/profile", "")  # omit weighted or profile
+        if act == 0:
+            return ""
+        elif act < 0:
+            return f"-{task}"
+        else:
+            return task
+    activities = mr.metacluster_activity(metacluster)
+
+    # tasks = mr.tasks()
+    # tasks = unique_list([task.split("/")[0] for task in tasks])  # HACK. For some
+    # TODO - one could prettify this here by using Task, and cTask
+
+    important_for = ",".join([render_act(task, act)
+                              for task, act in zip(mr.tasks(), activities)
+                              if act != 0])
+    pattern_vdoms = [vdom_pattern(mr, metacluster, pattern, figdir, total_counts,
+                                  dfp, **kwargs)
+                     for pattern in patterns]
+    return template_vdom_metacluster(metacluster,
+                                     n_patterns,
+                                     n_seqlets,
+                                     important_for,
+                                     pattern_vdoms,
+                                     is_open=is_open
+                                     )
+
+
+def vdom_modisco(mr, figdir, total_counts, dfp=None, is_open=True, **kwargs):
+    return div([vdom_metacluster(mr, metacluster, figdir, total_counts, dfp=dfp,
+                                 is_open=is_open, **kwargs)
+                for metacluster in mr.metaclusters()
+                if len(mr.pattern_names(metacluster)) > 0])
+
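+# Example usage of `vdom_modisco` (sketch; assumes `mr` is the modisco result
+# wrapper used throughout this module, i.e. it provides metaclusters(),
+# pattern_names(), n_seqlets(), get_pssm(), metacluster_activity() and tasks(),
+# and that `total_counts` and `dfp` are the same objects expected by
+# `vdom_pattern` above):
+#
+#   report = vdom_modisco(mr, figdir="figures", total_counts=total_counts,
+#                         dfp=dfp, is_open=False)
+#   report  # renders the collapsible per-metacluster report in a notebook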
+
+def get_signal(seqlets, d: ContribFile, tasks, resize_width=200):
+    thr_one_hot = d.get_seq()
+
+    if resize_width is None:
+        # width = first seqlets
+        resize_width = seqlets[0].end - seqlets[0].start
+
+    # get valid seqlets
+    start_pad = np.ceil(resize_width / 2)
+    end_pad = thr_one_hot.shape[1] - start_pad
+    valid_seqlets = [s.resize(resize_width)
+                     for s in seqlets
+                     if (s.center() > start_pad) and (s.center() < end_pad)]
+
+    # prepare data
+    ex_signal = {task: extract_signal(d.get_profiles()[task], valid_seqlets)
+                 for task in tasks}
+
+    ex_contrib_profile = {task: extract_signal(d.get_contrib()[task],
+                                               valid_seqlets).sum(axis=-1)
+                          for task in tasks}
+
+    if d.contains_contrib_score('count'):
+        ex_contrib_counts = {task: extract_signal(d.get_contrib("count")[task],
+                                                  valid_seqlets).sum(axis=-1) for task in tasks}
+    elif d.contains_contrib_score('counts/pre-act'):
+        ex_contrib_counts = {task: extract_signal(d.get_contrib("counts/pre-act")[task],
+                                                  valid_seqlets).sum(axis=-1) for task in tasks}
+    else:
+        ex_contrib_counts = None
+
+    ex_seq = extract_signal(thr_one_hot, valid_seqlets)
+
+    total_counts = sum([x.sum(axis=-1).sum(axis=-1) for x in ex_signal.values()])
+    sort_idx = np.argsort(-total_counts)
+    return ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx
+
+
+def vdm_heatmaps(seqlets, d, included_samples, tasks, pattern, top_n=None, pssm_fig=None, opened=False, resize_width=200):
+    ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx = get_signal(seqlets, d, tasks,
+                                                                                    resize_width=resize_width)
+
+    if top_n is not None:
+        sort_idx = sort_idx[:top_n]
+    return div(details(summary("Sequence:"),
+                       pssm_fig,
+                       br(),
+                       fig2vdom(heatmap_sequence(ex_seq, sort_idx=sort_idx, figsize_tmpl=(10, 15), aspect='auto')),
+                       open=opened
+                       ),
+
+               details(summary("ChIP-nexus counts:"),
+                       fig2vdom(multiple_plot_stranded_profile(ex_signal, figsize_tmpl=(20 / len(ex_signal), 3))),
+                       # TODO - change
+                       fig2vdom(multiple_heatmap_stranded_profile(ex_signal, sort_idx=sort_idx, figsize=(20, 20))),
+                       open=opened
+                       ),
+               details(summary("Contribution scores (profile)"),
+                       fig2vdom(multiple_heatmap_contribution_profile(ex_contrib_profile, sort_idx=sort_idx, figsize=(20, 20))),
+                       open=opened
+                       ),
+               details(summary("Contribution scores (counts)"),
+                       fig2vdom(multiple_heatmap_contribution_profile(ex_contrib_counts, sort_idx=sort_idx, figsize=(20, 20))),
+                       open=opened
+                       )
+               )
+
+
+def write_heatmap_pngs(seqlets, d, tasks, pattern, output_dir, resize_width=200):
+    """Write out histogram png's
+    """
+    # get the data
+    ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx = get_signal(seqlets, d, tasks,
+                                                                                    resize_width=resize_width)
+    # get the plots
+    figs = dict(
+        heatmap_seq=heatmap_sequence(ex_seq, sort_idx=sort_idx, figsize_tmpl=(10, 15), aspect='auto'),
+        profile_aggregated=multiple_plot_stranded_profile(ex_signal, figsize_tmpl=(20 / len(ex_signal), 3)),
+        profile_heatmap=multiple_heatmap_stranded_profile(ex_signal, sort_idx=sort_idx, figsize=(20, 20)),
+        contrib_profile=multiple_heatmap_contribution_profile(ex_contrib_profile, sort_idx=sort_idx, figsize=(20, 20)),
+    )
+
+    if ex_contrib_counts is not None:
+        figs['contrib_counts'] = multiple_heatmap_contribution_profile(ex_contrib_counts, sort_idx=sort_idx, figsize=(20, 20))
+    # write the figures
+    for k, fig in figs.items():
+        fig.savefig(os.path.join(output_dir, k + ".png"), bbox_inches='tight')
+
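+# Note: the file names written above (heatmap_seq, profile_aggregated,
+# profile_heatmap, contrib_profile, contrib_counts) correspond to the images
+# referenced by `template_vdom_pattern` via `figures_url`; the `agg_profile_*`
+# images referenced there are expected to be produced elsewhere.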
+
+def df2html(df, uuid='table', style='width:100%'):
+    import seaborn as sns
+    cm = sns.light_palette("green", as_cmap=True)
+    # leverage pandas style to color cells according to values
+    # https://pandas.pydata.org/pandas-docs/stable/style.html
+    s = df.style.background_gradient(cmap=cm).set_precision(3).hide_index()
+    return s.render(uuid=uuid).replace(f'<table id="T_{uuid}"',
+                                       f'<table id="T_{uuid}" class="compact hover nowrap" style="{style}"')
+
+
+def df2html_old(df, style='width:100%'):
+    add_tags = f'id="table_id" style="{style}"'
+    with pd.option_context('display.max_colwidth', -1):
+        table = df.to_html(escape=False,
+                           classes='display nowrap',
+                           float_format='%.2g',
+                           index=False).replace(' class="dataframe', f' {add_tags} class="dataframe')
+    return table
+
+# <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/1.10.13/css/jquery.dataTables.css">
+
+def get_datatable_header():
+    return '''
+      <script type="text/javascript" src="https://code.jquery.com/jquery-3.3.1.js"></script>
+      <script type="text/javascript"  src="https://cdn.datatables.net/1.10.19/js/jquery.dataTables.min.js"></script>
+      <script type="text/javascript"  src="https://cdn.datatables.net/colreorder/1.5.1/js/dataTables.colReorder.min.js"></script>
+      <script type="text/javascript"  src="https://cdn.datatables.net/fixedcolumns/3.2.6/js/dataTables.fixedColumns.min.js"></script>
+      
+      <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/1.10.19/css/jquery.dataTables.min.css">     
+      <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/colreorder/1.5.1/css/colReorder.dataTables.min.css">
+      <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/fixedcolumns/3.2.6/css/fixedColumns.dataTables.min.css">
+      <link rel="stylesheet" href="https://cdn.jupyter.org/notebook/5.1.0/style/style.min.css">
+    '''
+
+
+def style_html_table_datatable(html_str):
+    from IPython.display import HTML, Javascript
+
+    header = f'''
+    <!DOCTYPE html>
+    <html lang="en">
+    <head>
+    {get_datatable_header()}
+    </head>
+    <body>
+        '''
+    script = '''
+    <script>
+    $(document).ready( function () {
+    var table = $('#T_table').DataTable({
+         scrollX: true,
+         scrollY: '80vh',
+         scrollCollapse: true,
+         paging: false,
+         columnDefs: [
+            { orderable: false, targets: 0 },
+            { orderable: false, targets: 1 }
+        ],
+        order: [[ 1, 'asc' ]],
+        colReorder: {
+            fixedColumnsLeft: 1,
+            fixedColumnsRight: 0
+        }
+    });
+
+    new $.fn.dataTable.FixedColumns( table, {
+        leftColumns: 3,
+        rightColumns: 0
+    } );
+
+    // Select rows
+    $('#T_table tbody').on( 'click', 'tr', function () {
+        $(this).toggleClass('selected');
+    } );
+
+    } );
+    </script>
+    </body>
+    </html>
+    '''
+
+    return header + html_str + script
+
+
+def write_datatable_html(df, output_file, other=""):
+    html = style_html_table_datatable(df2html(df) + other)
+    with open(output_file, "w") as f:
+        f.write(html)
+
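+# Example usage of `write_datatable_html` (sketch with a toy DataFrame):
+#
+#   df = pd.DataFrame({"pattern": ["m0_p0", "m0_p1"], "n_seqlets": [120, 45]})
+#   write_datatable_html(df, "patterns.html")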
+
+def render_datatable(df):
+    from IPython.display import HTML, Javascript, display
+    display(HTML(get_datatable_header() + df2html(df)))
+    # display(Javascript(""" $(document).ready( function () {
+    # $('#T_table').DataTable();
+    # } );"""))
+
+
+def footprint_df(footprints, dfl=None, width=120, **kwargs):
+    """Draw footprint sparklines into a pandas.DataFrame
+
+    Args:
+      footprints: nested dict of footprints (`footprints[pattern][task]`),
+        where each leaf is an array of shape (seq_len, 2)
+      dfl: optional pandas.DataFrame of labels with columns:
+        `pattern` and `<task>/l`
+      width: width (in pixels) of each rendered sparkline image
+      **kwargs: additional kwargs to pass to `vdom_footprint`
+    """
+    from tqdm import tqdm
+    from bpnet.modisco.utils import shorten_pattern
+
+    def map_label(l):
+        """Label -> short-name
+        """
+        # TODO - get rid of this function
+        if l is None:
+            return "/"
+        else:
+            return l[0].upper()
+    tasks = list(footprints[list(footprints)[0]].keys())
+    profile_max_median = {task: np.median([np.max(v[task]) for v in footprints.values()]) for task in tasks}
+    out = []
+
+    for p, arr_d in tqdm(footprints.items()):
+        try:
+            labels = dfl[dfl.pattern == shorten_pattern(p)].iloc[0].to_dict()
+        except Exception:
+            labels = {t + "/l": None for t in tasks}
+        d = {task: vdom_footprint(arr_d[task],
+                                  r_height=profile_max_median[task],
+                                  text=map_label(labels[task + "/l"]),
+                                  **kwargs).to_html().replace("<img",
+                                                              f"<img width={width}")
+             for task in tasks}
+        d['pattern'] = shorten_pattern(p)
+        out.append(d)
+    return pd.DataFrame(out)
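+
+
+# Example usage of `footprint_df` (sketch; pattern and task names below are
+# hypothetical):
+#
+#   footprints = {"metacluster_0/pattern_0": {"Oct4": np.random.rand(200, 2)},
+#                 "metacluster_0/pattern_1": {"Oct4": np.random.rand(200, 2)}}
+#   df = footprint_df(footprints, width=100)
+#   write_datatable_html(df, "footprints.html")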