--- a
+++ b/bpnet/plot/vdom.py
@@ -0,0 +1,449 @@
+"""Vdom visualization for modisco
+"""
+import pandas as pd
+from bpnet.plot.heatmaps import multiple_heatmap_stranded_profile, multiple_heatmap_contribution_profile, heatmap_sequence
+from bpnet.cli.contrib import ContribFile
+from collections import OrderedDict
+from bpnet.plot.profiles import extract_signal, multiple_plot_stranded_profile, hist_position, bar_seqlets_per_example, box_counts
+from bpnet.functions import mean
+import numpy as np
+from vdom.helpers import (h1, p, li, img, div, b, br, ul,
+                          details, summary,
+                          table, thead, th, tr, tbody, td, ol)
+import io
+import base64
+import urllib.parse
+import matplotlib.pyplot as plt
+import os
+
+
+def fig2vdom(fig, **kwargs):
+    """Convert a matplotlib figure to an inline (base64-encoded) image
+    """
+    buf = io.BytesIO()
+    fig.savefig(buf, format='png', bbox_inches='tight')
+    buf.seek(0)
+    string = base64.b64encode(buf.read())
+    plt.close(fig)
+    return img(src='data:image/png;base64,' + urllib.parse.quote(string), **kwargs)
+
+
+def vdom_pssm(pssm, letter_width=0.2, letter_height=0.8, **kwargs):
+    """Nicely plot the pssm as a sequence logo
+    """
+    import matplotlib.pyplot as plt
+    from concise.utils.plot import seqlogo_fig, seqlogo
+    fig, ax = plt.subplots(figsize=(letter_width * len(pssm), letter_height))
+    ax.axison = False
+    seqlogo(pssm, ax=ax)
+    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+    return fig2vdom(fig, **kwargs)
+
+
+def vdom_footprint(arr, r_height=None, text=None,
+                   fontsize=32, figsize=(3, 1), **kwargs):
+    """Plot the sparkline for the footprint
+
+    Args:
+      arr: np.array of shape (seq_len, 2)
+      r_height: if not None, add a background rectangle with height = r_height
+      text: additional text to place in the top-right corner
+      fontsize: font size of the additional text
+      figsize: figure size
+      **kwargs: additional kwargs passed to `fig2vdom`
+
+    Returns:
+      VDOM object containing the image
+    """
+    import matplotlib.patches as patches
+
+    fig, ax = plt.subplots(figsize=figsize)
+    if r_height is not None:
+        rect = patches.Rectangle((0, 0), len(arr), r_height,
+                                 linewidth=1,
+                                 edgecolor=None,
+                                 alpha=0.3,
+                                 facecolor='lightgrey')
+        ax.add_patch(rect)
+        ax.set_ylim([0, max(r_height, arr.max())])
+        ax.axhline(r_height, alpha=0.3, color='black', linestyle='dashed')
+    ax.plot(arr[:, 0])
+    ax.plot(arr[:, 1])
+
+    if text is not None:
+        # annotate the top-right corner
+        ax.text(1, 1, text,
+                fontsize=fontsize,
+                transform=ax.transAxes,
+                verticalalignment='top',
+                horizontalalignment='right')
+
+    ax.axison = False
+    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+    return fig2vdom(fig, **kwargs)
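Usage sketch (illustrative only, not part of the patched module): `fig2vdom` wraps any matplotlib figure as an inline `<img>` vdom node, so the helpers above compose directly with the `vdom.helpers` layout elements. This assumes only numpy/matplotlib and the functions defined in this file; `to_html()` is the standard vdom serialization also used in `footprint_df` further down.

import numpy as np
import matplotlib.pyplot as plt
from vdom.helpers import details, div, summary
from bpnet.plot.vdom import fig2vdom, vdom_footprint

fig, ax = plt.subplots()
ax.plot(np.sin(np.linspace(0, 10, 100)))
block = div(
    details(summary("Example figure"), fig2vdom(fig, width=400)),
    details(summary("Example footprint"),
            vdom_footprint(np.abs(np.random.randn(200, 2)), r_height=1.0, text="X")),
)
html = block.to_html()  # displays inline in Jupyter; embeddable in any report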
"/contrib_profile.png", width=840), + ), + details(summary("Contribution scores (counts)"), + img(src=figures_url + "/contrib_counts.png", width=840), + ), + *[details(summary(k), *v) for k, v in add_plots.items()], + id=metacluster + "/" + name + ) + + +def vdom_pattern(mr, metacluster, pattern, + figdir, + total_counts, + dfp, + trim_frac=0.05, + letter_width=0.2, height=0.8): + + # get the trimmed motifs + trimmed_motif = vdom_pssm(mr.get_pssm(metacluster + '/' + pattern, + rc=False, trim_frac=trim_frac), + letter_width=letter_width, + height=height) + full_motif = vdom_pssm(mr.get_pssm(metacluster + '/' + pattern, + rc=False, trim_frac=0), + letter_width=letter_width, + height=height) + + # ---------------- + # add new plots here + dfpp = dfp[dfp.pattern == (metacluster + "/" + pattern)] + tasks = dfp.peak_id.unique() + pattern_idx = dfpp.example_idx.unique() + add_plots = OrderedDict([ + ("Positional distribution", + [fig2vdom(hist_position(dfpp, tasks=tasks)), + fig2vdom(bar_seqlets_per_example(dfpp, tasks=tasks)) + ]), + ("Total count distribution", + [p(f"Pattern occurs in {len(pattern_idx)} / {len(total_counts)} regions" + f" ({100*len(pattern_idx)/len(total_counts):.1f}%)"), + fig2vdom(box_counts(total_counts, pattern_idx))] + ) + ]) + # ---------------- + + return template_vdom_pattern(name=pattern, + n_seqlets=mr.n_seqlets(metacluster + "/" + pattern), + trimmed_motif=trimmed_motif, + full_motif=full_motif, + figures_url=os.path.join(figdir, f"{metacluster}/{pattern}"), + add_plots=add_plots, + metacluster=metacluster, + ) + + +def template_vdom_metacluster(name, n_patterns, n_seqlets, important_for, patterns, is_open=False): + return details(summary(b(name), f", # patterns: {n_patterns}," + f" # seqlets: {n_seqlets}, " + "important for: ", b(important_for)), + ul([li(pattern) for pattern in patterns], start=0), + id=name, + open=is_open) + + +def vdom_metacluster(mr, metacluster, figdir, total_counts, dfp=None, is_open=True, + **kwargs): + patterns = mr.pattern_names(metacluster) + n_seqlets = sum([mr.n_seqlets(metacluster + "/" + pattern) + for pattern in patterns]) + n_patterns = len(patterns) + + def render_act(task, act): + """Render the activity vector + """ + task = task.replace("/weighted", "").replace("/profile", "") # omit weighted or profile + if act == 0: + return "" + elif act < 0: + return f"-{task}" + else: + return task + activities = mr.metacluster_activity(metacluster) + + # tasks = mr.tasks() + # tasks = unique_list([task.split("/")[0] for task in tasks]) # HACK. 
+
+
+def template_vdom_metacluster(name, n_patterns, n_seqlets, important_for, patterns, is_open=False):
+    return details(summary(b(name), f", # patterns: {n_patterns},"
+                           f" # seqlets: {n_seqlets}, "
+                           "important for: ", b(important_for)),
+                   ul([li(pattern) for pattern in patterns], start=0),
+                   id=name,
+                   open=is_open)
+
+
+def vdom_metacluster(mr, metacluster, figdir, total_counts, dfp=None, is_open=True,
+                     **kwargs):
+    patterns = mr.pattern_names(metacluster)
+    n_seqlets = sum([mr.n_seqlets(metacluster + "/" + pattern)
+                     for pattern in patterns])
+    n_patterns = len(patterns)
+
+    def render_act(task, act):
+        """Render one entry of the activity vector
+        """
+        task = task.replace("/weighted", "").replace("/profile", "")  # omit weighted or profile
+        if act == 0:
+            return ""
+        elif act < 0:
+            return f"-{task}"
+        else:
+            return task
+
+    activities = mr.metacluster_activity(metacluster)
+
+    # tasks = mr.tasks()
+    # tasks = unique_list([task.split("/")[0] for task in tasks])  # HACK. For some
+    # TODO - one could prettify this here by using Task and cTask
+
+    important_for = ",".join([render_act(task, act)
+                              for task, act in zip(mr.tasks(), activities)
+                              if act != 0])
+    pattern_vdoms = [vdom_pattern(mr, metacluster, pattern, figdir, total_counts,
+                                  dfp, **kwargs)
+                     for pattern in patterns]
+    return template_vdom_metacluster(metacluster,
+                                     n_patterns,
+                                     n_seqlets,
+                                     important_for,
+                                     pattern_vdoms,
+                                     is_open=is_open
+                                     )
+
+
+def vdom_modisco(mr, figdir, total_counts, dfp=None, is_open=True, **kwargs):
+    return div([vdom_metacluster(mr, metacluster, figdir, total_counts, dfp=dfp,
+                                 is_open=is_open, **kwargs)
+                for metacluster in mr.metaclusters()
+                if len(mr.pattern_names(metacluster)) > 0])
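How the pieces compose: `vdom_modisco` -> `vdom_metacluster` -> `vdom_pattern` produces one nested `<details>` tree per metacluster. The sketch below is deliberately schematic; `mr`, `dfp`, `total_counts` and `figdir` are placeholders produced elsewhere in the bpnet modisco workflow, and the comment lists the only methods this module actually calls on `mr`.

# mr must provide: metaclusters(), pattern_names(), n_seqlets(),
# metacluster_activity(), tasks() and get_pssm() -- exactly the calls used above.
report = vdom_modisco(mr,                        # modisco result handle (placeholder)
                      figdir="figures",          # directory containing the per-pattern pngs
                      total_counts=total_counts, # per-region total counts (placeholder)
                      dfp=dfp)                   # seqlet table with 'pattern', 'peak_id', 'example_idx'
with open("modisco_report.html", "w") as f:
    f.write(report.to_html())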
+
+
+def get_signal(seqlets, d: ContribFile, tasks, resize_width=200):
+    """Extract the profile, contribution and sequence tracks for the given seqlets
+    """
+    thr_one_hot = d.get_seq()
+
+    if resize_width is None:
+        # default to the width of the first seqlet
+        resize_width = seqlets[0].end - seqlets[0].start
+
+    # keep only seqlets that can be resized without falling outside of the sequence
+    start_pad = np.ceil(resize_width / 2)
+    end_pad = thr_one_hot.shape[1] - start_pad
+    valid_seqlets = [s.resize(resize_width)
+                     for s in seqlets
+                     if (s.center() > start_pad) and (s.center() < end_pad)]
+
+    # prepare data
+    ex_signal = {task: extract_signal(d.get_profiles()[task], valid_seqlets)
+                 for task in tasks}
+
+    ex_contrib_profile = {task: extract_signal(d.get_contrib()[task],
+                                               valid_seqlets).sum(axis=-1)
+                          for task in tasks}
+
+    if d.contains_contrib_score('count'):
+        ex_contrib_counts = {task: extract_signal(d.get_contrib("count")[task],
+                                                  valid_seqlets).sum(axis=-1)
+                             for task in tasks}
+    elif d.contains_contrib_score('counts/pre-act'):
+        ex_contrib_counts = {task: extract_signal(d.get_contrib("counts/pre-act")[task],
+                                                  valid_seqlets).sum(axis=-1)
+                             for task in tasks}
+    else:
+        ex_contrib_counts = None
+
+    ex_seq = extract_signal(thr_one_hot, valid_seqlets)
+
+    # sort the seqlets by their total signal (descending)
+    total_counts = sum([x.sum(axis=-1).sum(axis=-1) for x in ex_signal.values()])
+    sort_idx = np.argsort(-total_counts)
+    return ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx
+
+
+def vdm_heatmaps(seqlets, d, included_samples, tasks, pattern, top_n=None, pssm_fig=None, opened=False, resize_width=200):
+    # NOTE: included_samples is currently unused
+    ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx = get_signal(seqlets, d, tasks,
+                                                                                    resize_width=resize_width)
+
+    if top_n is not None:
+        sort_idx = sort_idx[:top_n]
+    return div(details(summary("Sequence:"),
+                       pssm_fig,
+                       br(),
+                       fig2vdom(heatmap_sequence(ex_seq, sort_idx=sort_idx, figsize_tmpl=(10, 15), aspect='auto')),
+                       open=opened
+                       ),
+               details(summary("ChIP-nexus counts:"),
+                       fig2vdom(multiple_plot_stranded_profile(ex_signal, figsize_tmpl=(20 / len(ex_signal), 3))),
+                       # TODO - change
+                       fig2vdom(multiple_heatmap_stranded_profile(ex_signal, sort_idx=sort_idx, figsize=(20, 20))),
+                       open=opened
+                       ),
+               details(summary("Contribution scores (profile)"),
+                       fig2vdom(multiple_heatmap_contribution_profile(ex_contrib_profile, sort_idx=sort_idx, figsize=(20, 20))),
+                       open=opened
+                       ),
+               details(summary("Contribution scores (counts)"),
+                       fig2vdom(multiple_heatmap_contribution_profile(ex_contrib_counts, sort_idx=sort_idx, figsize=(20, 20))),
+                       open=opened
+                       )
+               )
+
+
+def write_heatmap_pngs(seqlets, d, tasks, pattern, output_dir, resize_width=200):
+    """Write the heatmap and aggregated profile png's for one pattern
+    """
+    # get the data
+    ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx = get_signal(seqlets, d, tasks,
+                                                                                    resize_width=resize_width)
+    # get the plots
+    figs = dict(
+        heatmap_seq=heatmap_sequence(ex_seq, sort_idx=sort_idx, figsize_tmpl=(10, 15), aspect='auto'),
+        profile_aggregated=multiple_plot_stranded_profile(ex_signal, figsize_tmpl=(20 / len(ex_signal), 3)),
+        profile_heatmap=multiple_heatmap_stranded_profile(ex_signal, sort_idx=sort_idx, figsize=(20, 20)),
+        contrib_profile=multiple_heatmap_contribution_profile(ex_contrib_profile, sort_idx=sort_idx, figsize=(20, 20)),
+    )
+
+    if ex_contrib_counts is not None:
+        figs['contrib_counts'] = multiple_heatmap_contribution_profile(ex_contrib_counts, sort_idx=sort_idx, figsize=(20, 20))
+    # write the figures
+    for k, fig in figs.items():
+        fig.savefig(os.path.join(output_dir, k + ".png"), bbox_inches='tight')
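The per-pattern figures consumed by `template_vdom_pattern` come from `write_heatmap_pngs`. A hedged sketch of the call: the ContribFile path and the `seqlets` list are placeholders (assuming a ContribFile can be opened directly from the contribution-score HDF5 file), and `get_signal` only requires that each seqlet exposes `.resize()` and `.center()`.

import os
from bpnet.cli.contrib import ContribFile
from bpnet.plot.vdom import write_heatmap_pngs

d = ContribFile("output/contrib.scores.h5")  # placeholder path to a contribution-score file
seqlets = ...                                # modisco seqlets for one pattern (need .resize()/.center())
output_dir = "figures/metacluster_0/pattern_0"
os.makedirs(output_dir, exist_ok=True)
write_heatmap_pngs(seqlets, d, tasks=["Oct4", "Sox2"],
                   pattern="pattern_0", output_dir=output_dir)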
+
+
+def df2html(df, uuid='table', style='width:100%'):
+    import seaborn as sns
+    cm = sns.light_palette("green", as_cmap=True)
+    # leverage pandas style to color cells according to their values
+    # https://pandas.pydata.org/pandas-docs/stable/style.html
+    s = df.style.background_gradient(cmap=cm).set_precision(3).hide_index()
+    return s.render(uuid=uuid).replace(f'<table id="T_{uuid}"',
+                                       f'<table id="T_{uuid}" class="compact hover nowrap" style="{style}"')
+
+
+def df2html_old(df, style='width:100%'):
+    add_tags = f'id="table_id" style="{style}"'
+    with pd.option_context('display.max_colwidth', -1):
+        table = df.to_html(escape=False,
+                           classes='display nowrap',
+                           float_format='%.2g',
+                           index=False).replace(' class="dataframe', f' {add_tags} class="dataframe')
+    return table
+
+
+# <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/1.10.13/css/jquery.dataTables.css">
+
+def get_datatable_header():
+    return '''
+    <script type="text/javascript" src="https://code.jquery.com/jquery-3.3.1.js"></script>
+    <script type="text/javascript" src="https://cdn.datatables.net/1.10.19/js/jquery.dataTables.min.js"></script>
+    <script type="text/javascript" src="https://cdn.datatables.net/colreorder/1.5.1/js/dataTables.colReorder.min.js"></script>
+    <script type="text/javascript" src="https://cdn.datatables.net/fixedcolumns/3.2.6/js/dataTables.fixedColumns.min.js"></script>
+
+    <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/1.10.19/css/jquery.dataTables.min.css">
+    <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/colreorder/1.5.1/css/colReorder.dataTables.min.css">
+    <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/fixedcolumns/3.2.6/css/fixedColumns.dataTables.min.css">
+    <link rel="stylesheet" href="https://cdn.jupyter.org/notebook/5.1.0/style/style.min.css">
+    '''
+
+
+def style_html_table_datatable(html_str):
+    from IPython.display import HTML, Javascript
+
+    header = f'''
+    <!DOCTYPE html>
+    <html lang="en">
+    <head>
+    {get_datatable_header()}
+    </head>
+    <body>
+    '''
+    script = '''
+    <script>
+    $(document).ready( function () {
+        var table = $('#T_table').DataTable({
+            scrollX: true,
+            scrollY: '80vh',
+            scrollCollapse: true,
+            paging: false,
+            columnDefs: [
+                { orderable: false, targets: 0 },
+                { orderable: false, targets: 1 }
+            ],
+            order: [[ 1, 'asc' ]],
+            colReorder: {
+                fixedColumnsLeft: 1,
+                fixedColumnsRight: 0
+            }
+        });
+
+        new $.fn.dataTable.FixedColumns( table, {
+            leftColumns: 3,
+            rightColumns: 0
+        } );
+
+        // Select rows
+        $('#T_table tbody').on( 'click', 'tr', function () {
+            $(this).toggleClass('selected');
+        } );
+
+    } );
+    </script>
+    </body>
+    </html>
+    '''
+
+    return header + html_str + script
+
+
+def write_datatable_html(df, output_file, other=""):
+    html = style_html_table_datatable(df2html(df) + other)
+    with open(output_file, "w") as f:
+        f.write(html)
+
+
+def render_datatable(df):
+    from IPython.display import HTML, Javascript, display
+    display(HTML(get_datatable_header() + df2html(df)))
+    # display(Javascript(""" $(document).ready( function () {
+    #     $('#T_table').DataTable();
+    # } );"""))
+
+
+def footprint_df(footprints, dfl=None, width=120, **kwargs):
+    """Draw footprint sparklines into a pandas.DataFrame
+
+    Args:
+      footprints: footprint dict with a `<pattern>/<task>` nested structure;
+        each leaf node contains an array of shape (seq_len, 2)
+      dfl: optional pandas.DataFrame of labels. Contains columns:
+        `pattern`, `<task>/l`
+      width: width of the rendered sparkline image in pixels
+      **kwargs: additional kwargs passed to `vdom_footprint`
+    """
+    from tqdm import tqdm
+    from bpnet.modisco.utils import shorten_pattern
+
+    def map_label(l):
+        """Label -> short name
+        """
+        # TODO - get rid of this function
+        if l is None:
+            return "/"
+        else:
+            return l[0].upper()
+
+    tasks = list(footprints[list(footprints)[0]].keys())
+    profile_max_median = {task: np.median([np.max(v[task]) for v in footprints.values()])
+                          for task in tasks}
+    out = []
+
+    for p, arr_d in tqdm(footprints.items()):
+        try:
+            labels = dfl[dfl.pattern == shorten_pattern(p)].iloc[0].to_dict()
+        except Exception:
+            labels = {t + "/l": None for t in tasks}
+        d = {task: vdom_footprint(arr_d[task],
+                                  r_height=profile_max_median[task],
+                                  text=map_label(labels[task + "/l"]),
+                                  **kwargs).to_html().replace("<img",
+                                                              f"<img width={width}")
+             for task in tasks}
+        d['pattern'] = shorten_pattern(p)
+        out.append(d)
+    return pd.DataFrame(out)
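Finally, `footprint_df` turns a nested footprint dictionary into a table with one inline sparkline per task. A self-contained sketch with synthetic data (pattern and task names are made up; the `<metacluster>/<pattern>` naming is assumed so that `shorten_pattern` can abbreviate it):

import numpy as np
from bpnet.plot.vdom import footprint_df

footprints = {
    "metacluster_0/pattern_0": {"Oct4": np.random.rand(200, 2),
                                "Sox2": np.random.rand(200, 2)},
    "metacluster_0/pattern_1": {"Oct4": np.random.rand(200, 2),
                                "Sox2": np.random.rand(200, 2)},
}
df = footprint_df(footprints)  # one row per pattern, one HTML sparkline per task
# render_datatable(df) can then display it in a notebook, and
# write_datatable_html(df, "footprints.html") writes the standalone DataTables page.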