a b/bpnet/plot/vdom.py
1
"""Vdom visualization for modisco
2
"""
3
import pandas as pd
4
from bpnet.plot.heatmaps import multiple_heatmap_stranded_profile, multiple_heatmap_contribution_profile, heatmap_sequence
5
from bpnet.cli.contrib import ContribFile
6
from collections import OrderedDict
7
from bpnet.plot.profiles import extract_signal, multiple_plot_stranded_profile, hist_position, bar_seqlets_per_example, box_counts
8
from bpnet.functions import mean
9
import numpy as np
10
import pandas as pd
11
from vdom.helpers import (h1, p, li, img, div, b, br, ul, img,
12
                          details, summary,
13
                          table, thead, th, tr, tbody, td, ol)
14
import io
15
import base64
16
import urllib
17
import matplotlib.pyplot as plt
18
import os
19
20
21
def fig2vdom(fig, **kwargs):
    """Convert a matplotlib figure to an inline (base64-embedded) image

    The figure is rendered to an in-memory PNG, closed, and wrapped into
    a vdom `img` element whose `src` is a `data:image/png;base64,...` URI.

    Args:
      fig: matplotlib figure to render
      **kwargs: additional attributes passed to the vdom `img` element

    Returns:
      VDOM `img` object containing the image
    """
    # `import urllib` alone does not guarantee that the `parse` submodule
    # is loaded, so import it explicitly where it is used.
    import urllib.parse

    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    encoded = base64.b64encode(buf.getvalue())
    # Close the figure that was just rendered. A bare `plt.close()` closes
    # whichever figure happens to be current, which may not be `fig`.
    plt.close(fig)
    return img(src='data:image/png;base64,' + urllib.parse.quote(encoded), **kwargs)
30
31
32
def vdom_pssm(pssm, letter_width=0.2, letter_height=0.8, **kwargs):
    """Draw the sequence logo of a pssm and return it as a vdom image

    Args:
      pssm: matrix handed to `concise.utils.plot.seqlogo`; its length
        determines the figure width
      letter_width: figure width reserved per position of `pssm`
      letter_height: total figure height
      **kwargs: additional kwargs passed to `fig2vdom`

    Returns:
      VDOM object containing the image
    """
    import matplotlib.pyplot as plt
    from concise.utils.plot import seqlogo_fig, seqlogo  # seqlogo_fig is not used here

    logo_size = (letter_width * len(pssm), letter_height)
    fig, ax = plt.subplots(figsize=logo_size)
    # hide the axes decorations, only the letters should be visible
    ax.axison = False
    seqlogo(pssm, ax=ax)
    # let the logo fill the whole canvas
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    return fig2vdom(fig, **kwargs)
42
43
44
def vdom_footprint(arr, r_height=None, text=None,
                   fontsize=32, figsize=(3, 1), **kwargs):
    """Plot the sparkline for the footprint

    Args:
      arr: np.array of shape (seq_len, 2); both columns are drawn as lines
      r_height: if not None, add a shaded rectangle with height = r_height
        together with a dashed horizontal line at the same value
      text: add additional text to the top right corner
      fontsize: size of the additional font
      figsize: figure size
      **kwargs: additional kwargs passed to `fig2vdom`

    Returns:
      VDOM object containing the image
    """
    import matplotlib.patches as patches

    fig, ax = plt.subplots(figsize=figsize)

    if r_height is not None:
        # shaded reference box spanning the whole x-range up to r_height
        reference_box = patches.Rectangle((0, 0), len(arr),
                                          r_height,
                                          linewidth=1,
                                          edgecolor=None,
                                          alpha=0.3,
                                          facecolor='lightgrey')
        ax.add_patch(reference_box)
        # make sure both the box and the tallest signal value are visible
        y_top = max(r_height, arr.max())
        ax.set_ylim([0, y_top])
        ax.axhline(r_height, alpha=0.3, color='black', linestyle='dashed')

    # one line per column of `arr`
    for column in (0, 1):
        ax.plot(arr[:, column])

    if text is not None:
        # Annotate text in the top-right corner (axes coordinates)
        ax.text(1, 1, text,
                fontsize=fontsize,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right')

    # no axes decorations, plot fills the whole canvas
    ax.axison = False
    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    return fig2vdom(fig, **kwargs)
88
89
90
def template_vdom_pattern(name, n_seqlets, trimmed_motif,
                          full_motif, figures_url, add_plots=None, metacluster=""):
    """Assemble the collapsible vdom section describing a single pattern

    Args:
      name: pattern name, shown in the summary line
      n_seqlets: number of seqlets, shown in the summary line
      trimmed_motif: vdom element displayed directly in the summary line
      full_motif: vdom element displayed in the "Sequence" sub-section
      figures_url: prefix prepended to the image file names referenced
        by the sub-sections (e.g. `<figures_url>/heatmap_seq.png`)
      add_plots: optional mapping `section title -> list of vdom elements`;
        each entry becomes an additional collapsible sub-section
      metacluster: metacluster name; the element id is `<metacluster>/<name>`

    Returns:
      VDOM `details` element
    """
    # Avoid a shared mutable default argument.
    if add_plots is None:
        add_plots = {}

    extra_sections = [details(summary(title), *elements)
                      for title, elements in add_plots.items()]

    # NOTE(review): the two "Aggregated ..." titles end with an unmatched ')'
    # and the first file name reads "contribcores" - both are kept verbatim
    # because the text/paths are runtime strings; confirm before changing.
    return details(summary(name, f": # seqlets: {n_seqlets}",
                           trimmed_motif),
                   details(summary("Aggregated profiles and contribution scores)"),
                           img(src=figures_url + "/agg_profile_contribcores.png", width=840),
                           ),
                   details(summary("Aggregated hypothetical contribution scores)"),
                           img(src=figures_url + "/agg_profile_hypcontribscores.png", width=840),
                           ),
                   details(summary("Sequence"),
                           full_motif,
                           br(),
                           img(src=figures_url + "/heatmap_seq.png", width=840 // 2),
                           ),
                   details(summary("ChIP-nexus counts"),
                           img(src=figures_url + "/profile_aggregated.png", width=840),
                           img(src=figures_url + "/profile_heatmap.png", width=840),
                           ),
                   details(summary("Contribution scores (profile)"),
                           img(src=figures_url + "/contrib_profile.png", width=840),
                           ),
                   details(summary("Contribution scores (counts)"),
                           img(src=figures_url + "/contrib_counts.png", width=840),
                           ),
                   *extra_sections,
                   id=metacluster + "/" + name
                   )
120
121
122
def vdom_pattern(mr, metacluster, pattern,
                 figdir,
                 total_counts,
                 dfp,
                 trim_frac=0.05,
                 letter_width=0.2, height=0.8):
    """Build the vdom section for one pattern of a metacluster

    Args:
      mr: modisco result object; `get_pssm` and `n_seqlets` are called on it
      metacluster: metacluster name
      pattern: pattern name within the metacluster
      figdir: directory / URL prefix under which
        `<metacluster>/<pattern>/*.png` figures are referenced
      total_counts: per-region total counts; its length is used as the
        number of regions and it is passed to `box_counts`
      dfp: pandas.DataFrame of pattern instances with columns `pattern`,
        `peak_id` and `example_idx`, or None to skip the instance-based
        plots
      trim_frac: trimming fraction used for the motif shown in the summary
      letter_width: width of a single letter in the motif logos
      height: height of the motif logos

    Returns:
      VDOM `details` element produced by `template_vdom_pattern`
    """
    pattern_id = metacluster + '/' + pattern

    # `vdom_pssm` names its height argument `letter_height`. Passing it as
    # `height` would silently end up as an attribute of the <img> element
    # instead of controlling the size of the logo.
    trimmed_motif = vdom_pssm(mr.get_pssm(pattern_id,
                                          rc=False, trim_frac=trim_frac),
                              letter_width=letter_width,
                              letter_height=height)
    full_motif = vdom_pssm(mr.get_pssm(pattern_id,
                                       rc=False, trim_frac=0),
                           letter_width=letter_width,
                           letter_height=height)

    # ----------------
    # add new plots here
    add_plots = OrderedDict()
    if dfp is not None:
        # instances of this particular pattern
        dfpp = dfp[dfp.pattern == pattern_id]
        tasks = dfp.peak_id.unique()
        pattern_idx = dfpp.example_idx.unique()
        add_plots["Positional distribution"] = [
            fig2vdom(hist_position(dfpp, tasks=tasks)),
            fig2vdom(bar_seqlets_per_example(dfpp, tasks=tasks)),
        ]
        add_plots["Total count distribution"] = [
            p(f"Pattern occurs in {len(pattern_idx)} / {len(total_counts)} regions"
              f" ({100*len(pattern_idx)/len(total_counts):.1f}%)"),
            fig2vdom(box_counts(total_counts, pattern_idx)),
        ]
    # ----------------

    return template_vdom_pattern(name=pattern,
                                 n_seqlets=mr.n_seqlets(pattern_id),
                                 trimmed_motif=trimmed_motif,
                                 full_motif=full_motif,
                                 figures_url=os.path.join(figdir, f"{metacluster}/{pattern}"),
                                 add_plots=add_plots,
                                 metacluster=metacluster,
                                 )
165
166
167
def template_vdom_metacluster(name, n_patterns, n_seqlets, important_for, patterns, is_open=False):
    """Assemble the collapsible vdom section describing one metacluster

    Args:
      name: metacluster name (shown in bold, also used as element id)
      n_patterns: number of patterns shown in the summary line
      n_seqlets: number of seqlets shown in the summary line
      important_for: text shown in bold after "important for: "
      patterns: iterable of vdom elements, one list item is created for each
      is_open: value of the `open` attribute of the `details` element

    Returns:
      VDOM `details` element
    """
    stats_text = (f", # patterns: {n_patterns},"
                  f" # seqlets: {n_seqlets}, "
                  "important for: ")
    header = summary(b(name), stats_text, b(important_for))
    # NOTE(review): `start` is normally an attribute of ordered lists;
    # `ul` is kept to leave the rendered output unchanged - confirm intent.
    items = list(map(li, patterns))
    return details(header,
                   ul(items, start=0),
                   id=name,
                   open=is_open)
174
175
176
def vdom_metacluster(mr, metacluster, figdir, total_counts, dfp=None, is_open=True,
                     **kwargs):
    """Build the vdom section for a whole metacluster

    Args:
      mr: modisco result object; `pattern_names`, `n_seqlets`,
        `metacluster_activity` and `tasks` are called on it
      metacluster: metacluster name
      figdir, total_counts, dfp: forwarded to `vdom_pattern`
      is_open: whether the metacluster section is rendered expanded
      **kwargs: additional kwargs passed to `vdom_pattern`

    Returns:
      VDOM `details` element produced by `template_vdom_metacluster`
    """
    def render_act(task, act):
        """Render a single entry of the activity vector
        """
        # omit weighted or profile
        short_task = task.replace("/weighted", "").replace("/profile", "")
        if act == 0:
            return ""
        return f"-{short_task}" if act < 0 else short_task

    patterns = mr.pattern_names(metacluster)
    n_seqlets = sum(mr.n_seqlets("/".join((metacluster, pattern)))
                    for pattern in patterns)
    n_patterns = len(patterns)

    activities = mr.metacluster_activity(metacluster)

    # TODO - one could prettify this here by using Task, and cTask
    # tasks with a non-zero activity, negative ones prefixed with "-"
    active_tasks = []
    for task, act in zip(mr.tasks(), activities):
        if act != 0:
            active_tasks.append(render_act(task, act))
    important_for = ",".join(active_tasks)

    pattern_vdoms = []
    for pattern in patterns:
        pattern_vdoms.append(vdom_pattern(mr, metacluster, pattern, figdir,
                                          total_counts, dfp, **kwargs))

    return template_vdom_metacluster(metacluster,
                                     n_patterns,
                                     n_seqlets,
                                     important_for,
                                     pattern_vdoms,
                                     is_open=is_open
                                     )
212
213
214
def vdom_modisco(mr, figdir, total_counts, dfp=None, is_open=True, **kwargs):
    """Build the vdom report covering all non-empty metaclusters

    Args:
      mr: modisco result object; `metaclusters` and `pattern_names`
        are called on it
      figdir, total_counts, dfp, is_open: forwarded to `vdom_metacluster`
      **kwargs: additional kwargs passed to `vdom_metacluster`

    Returns:
      VDOM `div` wrapping one section per metacluster with >= 1 pattern
    """
    sections = []
    for metacluster in mr.metaclusters():
        # metaclusters without any pattern are left out of the report
        if len(mr.pattern_names(metacluster)) > 0:
            sections.append(vdom_metacluster(mr, metacluster, figdir, total_counts,
                                             dfp=dfp, is_open=is_open, **kwargs))
    return div(sections)
219
220
221
def get_signal(seqlets, d: ContribFile, tasks, resize_width=200):
    """Extract sequence, profile and contribution-score windows at seqlets

    Args:
      seqlets: list of seqlets; `start`, `end`, `center()` and `resize()`
        are used
      d: ContribFile providing `get_seq`, `get_profiles`, `get_contrib`
        and `contains_contrib_score`
      tasks: task names used as keys into the profile / contribution dicts
      resize_width: width every seqlet is resized to. If None, the width
        of the first seqlet is used

    Returns:
      tuple `(ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx)`:
        ex_signal: dict `task -> extract_signal(profile, valid_seqlets)`
        ex_contrib_profile: dict `task -> ` extracted default contribution
          scores summed over the last axis
        ex_contrib_counts: same for the 'count' (or 'counts/pre-act')
          contribution scores, None if the file contains neither
        ex_seq: extracted sequence
        sort_idx: indices ordering the examples by decreasing total signal
          summed across all tasks
    """
    thr_one_hot = d.get_seq()

    if resize_width is None:
        # width = first seqlets
        resize_width = seqlets[0].end - seqlets[0].start

    # get valid seqlets: only those whose resized window stays strictly
    # within the bounds given by the second dimension of the sequence array
    start_pad = np.ceil(resize_width / 2)
    end_pad = thr_one_hot.shape[1] - start_pad
    valid_seqlets = [s.resize(resize_width)
                     for s in seqlets
                     if (s.center() > start_pad) and (s.center() < end_pad)]

    # prepare data
    # Fetch each data dictionary once instead of once per task.
    profiles = d.get_profiles()
    ex_signal = {task: extract_signal(profiles[task], valid_seqlets)
                 for task in tasks}

    contrib_profile = d.get_contrib()
    ex_contrib_profile = {task: extract_signal(contrib_profile[task],
                                               valid_seqlets).sum(axis=-1)
                          for task in tasks}

    # count contribution scores can be stored under either of the two names
    counts_key = None
    for candidate in ('count', 'counts/pre-act'):
        if d.contains_contrib_score(candidate):
            counts_key = candidate
            break

    if counts_key is None:
        ex_contrib_counts = None
    else:
        contrib_counts = d.get_contrib(counts_key)
        ex_contrib_counts = {task: extract_signal(contrib_counts[task],
                                                  valid_seqlets).sum(axis=-1)
                             for task in tasks}

    ex_seq = extract_signal(thr_one_hot, valid_seqlets)

    # NOTE: the previous implementation additionally called `d.get_all()`
    # and discarded the result; that unused load has been removed.

    total_counts = sum([x.sum(axis=-1).sum(axis=-1) for x in ex_signal.values()])
    sort_idx = np.argsort(-total_counts)
    return ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx
259
260
261
def vdm_heatmaps(seqlets, d, included_samples, tasks, pattern, top_n=None, pssm_fig=None, opened=False, resize_width=200):
    """Render sequence / profile / contribution heatmaps as inline vdom images

    Args:
      seqlets: list of seqlets passed to `get_signal`
      d: ContribFile passed to `get_signal`
      included_samples: unused, kept for backward compatibility
      tasks: task names passed to `get_signal`
      pattern: unused, kept for backward compatibility
      top_n: if not None, only the first `top_n` entries of the sort
        order returned by `get_signal` are displayed
      pssm_fig: vdom element displayed above the sequence heatmap
      opened: value of the `open` attribute of every section
      resize_width: seqlet width passed to `get_signal`

    Returns:
      VDOM `div` containing one collapsible section per plot type
    """
    # `get_signal` accepts (seqlets, d, tasks, resize_width=...). Passing
    # `included_samples` as an extra positional argument shifted `tasks`
    # into `resize_width` and raised a TypeError.
    ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx = get_signal(seqlets, d, tasks,
                                                                                    resize_width=resize_width)

    if top_n is not None:
        sort_idx = sort_idx[:top_n]

    sections = [
        details(summary("Sequence:"),
                pssm_fig,
                br(),
                fig2vdom(heatmap_sequence(ex_seq, sort_idx=sort_idx, figsize_tmpl=(10, 15), aspect='auto')),
                open=opened
                ),
        details(summary("ChIP-nexus counts:"),
                fig2vdom(multiple_plot_stranded_profile(ex_signal, figsize_tmpl=(20 / len(ex_signal), 3))),
                # TODO - change
                fig2vdom(multiple_heatmap_stranded_profile(ex_signal, sort_idx=sort_idx, figsize=(20, 20))),
                open=opened
                ),
        details(summary("Contribution scores (profile)"),
                fig2vdom(multiple_heatmap_contribution_profile(ex_contrib_profile, sort_idx=sort_idx, figsize=(20, 20))),
                open=opened
                ),
    ]

    # `get_signal` returns None when the file contains no count contribution
    # scores; skip the section in that case (same guard as write_heatmap_pngs)
    if ex_contrib_counts is not None:
        sections.append(
            details(summary("Contribution scores (counts)"),
                    fig2vdom(multiple_heatmap_contribution_profile(ex_contrib_counts, sort_idx=sort_idx, figsize=(20, 20))),
                    open=opened
                    ))
    return div(*sections)
289
290
291
def write_heatmap_pngs(seqlets, d, tasks, pattern, output_dir, resize_width=200):
    """Write out the heatmap / profile png's of a pattern

    The following files are written into `output_dir`: heatmap_seq.png,
    profile_aggregated.png, profile_heatmap.png, contrib_profile.png and,
    if count contribution scores are available, contrib_counts.png.

    Args:
      seqlets: list of seqlets passed to `get_signal`
      d: ContribFile passed to `get_signal`
      tasks: task names passed to `get_signal`
      pattern: unused
      output_dir: existing directory the png files are written to
      resize_width: seqlet width passed to `get_signal`
    """
    # get the data
    ex_signal, ex_contrib_profile, ex_contrib_counts, ex_seq, sort_idx = get_signal(seqlets, d, tasks,
                                                                                    resize_width=resize_width)
    # get the plots
    figs = dict(
        heatmap_seq=heatmap_sequence(ex_seq, sort_idx=sort_idx, figsize_tmpl=(10, 15), aspect='auto'),
        profile_aggregated=multiple_plot_stranded_profile(ex_signal, figsize_tmpl=(20 / len(ex_signal), 3)),
        profile_heatmap=multiple_heatmap_stranded_profile(ex_signal, sort_idx=sort_idx, figsize=(20, 20)),
        contrib_profile=multiple_heatmap_contribution_profile(ex_contrib_profile, sort_idx=sort_idx, figsize=(20, 20)),
    )

    if ex_contrib_counts is not None:
        figs['contrib_counts'] = multiple_heatmap_contribution_profile(ex_contrib_counts, sort_idx=sort_idx, figsize=(20, 20))

    # write the figures
    for k, fig in figs.items():
        fig.savefig(os.path.join(output_dir, k + ".png"), bbox_inches='tight')
        # Release the figure once it is on disk; otherwise every call
        # leaves several large figures registered with pyplot.
        plt.close(fig)
310
311
312
def df2html(df, uuid='table', style='width:100%'):
    """Render a DataFrame as a colored HTML table

    Cells are shaded with a light-green gradient using the pandas Styler
    (https://pandas.pydata.org/pandas-docs/stable/style.html), values are
    shown with precision 3 and the index is hidden.

    Args:
      df: pandas.DataFrame to render
      uuid: uuid passed to the Styler; the table id becomes `T_<uuid>`
      style: inline css added to the `<table>` tag

    Returns:
      HTML string
    """
    import seaborn as sns

    green_cmap = sns.light_palette("green", as_cmap=True)

    # NOTE(review): `set_precision`, `hide_index` and `render` were
    # deprecated in later pandas releases - confirm the supported version.
    styler = df.style.background_gradient(cmap=green_cmap)
    styler = styler.set_precision(3)
    styler = styler.hide_index()
    rendered = styler.render(uuid=uuid)

    # add the DataTables css classes and the inline style to the <table> tag
    opening_tag = f'<table id="T_{uuid}"'
    styled_tag = f'{opening_tag} class="compact hover nowrap" style="{style}"'
    return rendered.replace(opening_tag, styled_tag)
320
321
322
def df2html_old(df, style='width:100%'):
    """Render a DataFrame as a plain (unstyled) HTML table

    Cell content is not escaped and not truncated, floats are formatted
    with '%.2g' and the index is omitted.

    Args:
      df: pandas.DataFrame to render
      style: inline css added to the `<table>` tag

    Returns:
      HTML string of a table with `id="table_id"`
    """
    add_tags = f'id="table_id" style="{style}"'

    # "No limit" for display.max_colwidth is spelled `None` in current
    # pandas, which rejects the legacy `-1`. Older releases only accept
    # the integer, so probe which value this installation supports.
    try:
        with pd.option_context('display.max_colwidth', None):
            pass
    except ValueError:
        unlimited = -1
    else:
        unlimited = None

    with pd.option_context('display.max_colwidth', unlimited):
        html = df.to_html(escape=False,
                          classes='display nowrap',
                          float_format='%.2g',
                          index=False)
    # inject the id and the inline style into the <table> tag
    return html.replace(' class="dataframe', f' {add_tags} class="dataframe')
330
331
# <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/1.10.13/css/jquery.dataTables.css">
332
333
def get_datatable_header():
    """Return the HTML snippet loading the scripts and stylesheets for DataTables

    The snippet references jQuery, the DataTables core together with its
    ColReorder and FixedColumns extensions, and the Jupyter notebook
    stylesheet, all served from public CDNs.

    Returns:
      HTML string with four `<script>` and four `<link>` tags
    """
    # The lines are spelled out one by one (including their trailing
    # whitespace) so that the returned string stays byte-identical.
    snippet_lines = [
        '',
        '      <script type="text/javascript" src="https://code.jquery.com/jquery-3.3.1.js"></script>',
        '      <script type="text/javascript"  src="https://cdn.datatables.net/1.10.19/js/jquery.dataTables.min.js"></script>',
        '      <script type="text/javascript"  src="https://cdn.datatables.net/colreorder/1.5.1/js/dataTables.colReorder.min.js"></script>',
        '      <script type="text/javascript"  src="https://cdn.datatables.net/fixedcolumns/3.2.6/js/dataTables.fixedColumns.min.js"></script>',
        '      ',
        '      <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/1.10.19/css/jquery.dataTables.min.css">     ',
        '      <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/colreorder/1.5.1/css/colReorder.dataTables.min.css">',
        '      <link rel="stylesheet" type="text/css" href="https://cdn.datatables.net/fixedcolumns/3.2.6/css/fixedColumns.dataTables.min.css">',
        '      <link rel="stylesheet" href="https://cdn.jupyter.org/notebook/5.1.0/style/style.min.css">',
        '    ',
    ]
    return '\n'.join(snippet_lines)
345
346
347
def style_html_table_datatable(html_str):
    """Wrap an HTML table into a standalone page using DataTables

    The returned document loads the resources from `get_datatable_header`
    and initializes DataTables on the element with id `T_table`
    (horizontal / vertical scrolling, no paging, re-orderable columns,
    3 fixed left columns and row selection on click).

    Args:
      html_str: HTML placed into the body; expected to contain a table
        with `id="T_table"`

    Returns:
      string with the complete HTML document
    """
    # The head section has to be closed and the body opened before the
    # table markup; the original template emitted `<head>` / `</body>` here.
    header = f'''
    <!DOCTYPE html>
    <html lang="en">
    <head>
    {get_datatable_header()}
    </head>
    <body>
        '''
    # `colReorder` used to be listed twice in the options object; only the
    # second (object) value took effect, so the shadowed `true` was dropped.
    # NOTE(review): `ordering: [[ 1, 'asc' ]]` looks like it was meant to be
    # the `order` option - left unchanged, confirm against the DataTables docs.
    script = '''
    <script>
    $(document).ready( function () {
    var table = $('#T_table').DataTable({
         scrollX: true,
         scrollY: '80vh',
         scrollCollapse: true,
         paging: false,
         columnDefs: [
            { orderable: false, targets: 0 },
            { orderable: false, targets: 1 }
        ],
        ordering: [[ 1, 'asc' ]],
        colReorder: {
            fixedColumnsLeft: 1,
            fixedColumnsRight: 0
        }
    });

    new $.fn.dataTable.FixedColumns( table, {
        leftColumns: 3,
        rightColumns: 0
    } );

    // Select rows
    $('#T_table tbody').on( 'click', 'tr', function () {
        $(this).toggleClass('selected');
    } );

    } );
    </script>
    </body>
    </html>
    '''

    return header + html_str + script
395
396
def write_datatable_html(df, output_file, other=""):
    """Write a DataFrame as a standalone DataTables HTML page

    Args:
      df: pandas.DataFrame rendered with `df2html`
      output_file: path of the HTML file to write
      other: additional HTML appended after the table
    """
    table_html = df2html(df)
    page = style_html_table_datatable(table_html + other)
    with open(output_file, "w") as out:
        out.write(page)
400
401
402
def render_datatable(df):
    """Display a DataFrame in the notebook together with the DataTables resources

    Args:
      df: pandas.DataFrame rendered with `df2html`
    """
    from IPython.display import HTML, Javascript, display

    markup = get_datatable_header() + df2html(df)
    display(HTML(markup))
    # NOTE: DataTables is not initialized here. An earlier attempt called
    # `$('#T_table').DataTable()` through `display(Javascript(...))` and
    # was left disabled.
408
409
410
def footprint_df(footprints, dfl=None, width=120, **kwargs):
    """Draw footprints sparklines into a pandas.DataFrame

    Args:
      footprints: footprint dict with `<pattern>/<task>` nested structure
        each node contains an array of shape (seq_len, 2)
      dfl: optional pandas.DataFrame of labels. Contains columns:
        pattern <task>/l
      width: width of the final plot
      **kwargs: additional kwargs to pass to vdom_footprint

    Returns:
      pandas.DataFrame with one row per pattern: one column per task
      holding the `<img>` HTML of the sparkline and a `pattern` column
      holding the shortened pattern name. Empty if `footprints` is empty.
    """
    if not footprints:
        # Nothing to draw; also avoids indexing the first element below.
        return pd.DataFrame([])

    from tqdm import tqdm
    from bpnet.modisco.utils import shorten_pattern

    def map_label(l):
        """Label -> short-name
        """
        # TODO - get rid of this function
        # missing labels (None, NaN, empty string) are rendered as "/"
        if not isinstance(l, str) or not l:
            return "/"
        return l[0].upper()

    # tasks are taken from the first pattern
    tasks = list(next(iter(footprints.values())).keys())
    # per task: median (across patterns) of the maximal footprint value,
    # drawn as the reference height in every sparkline
    profile_max_median = {task: np.median([np.max(v[task]) for v in footprints.values()]) for task in tasks}
    out = []

    for pattern_name, arr_d in tqdm(footprints.items()):
        short_name = shorten_pattern(pattern_name)

        # default: no label for any task
        labels = {t + "/l": None for t in tasks}
        if dfl is not None:
            try:
                labels = dfl[dfl.pattern == short_name].iloc[0].to_dict()
            except (AttributeError, IndexError, KeyError):
                # no `pattern` column or no row for this pattern:
                # keep the empty labels
                pass

        row = {task: vdom_footprint(arr_d[task],
                                    r_height=profile_max_median[task],
                                    text=map_label(labels.get(task + "/l")),
                                    **kwargs).to_html().replace("<img",
                                                                f"<img width={width}")
               for task in tasks}
        row['pattern'] = short_name
        out.append(row)
    return pd.DataFrame(out)