Diff of /bpnet/plot/profiles.py [000000] .. [d45a3a]

Switch to unified view

a b/bpnet/plot/profiles.py
1
from tqdm import tqdm
2
import seaborn as sns
3
import pandas as pd
4
import os
5
import matplotlib.pyplot as plt
6
import numpy as np
7
from bpnet.plot.utils import simple_yaxis_format, strip_axis, spaced_xticks
8
from bpnet.modisco.utils import bootstrap_mean, nan_like, ic_scale
9
from bpnet.plot.utils import show_figure
10
11
12
# TODO - make it as a bar-plot with two standard colors:
13
# #B23F49 (pos), #045CA8 (neg)
14
def plot_stranded_profile(profile, ax=None, ymax=None, profile_std=None, flip_neg=True, set_ylim=True):
    """Plot the stranded profile as one line per strand.

    Args:
      profile: array of shape (length,) or (length, n_strands) with
        n_strands <= 2. Column 0 is labelled 'pos', column 1 'neg'.
      ax: axis to draw on. Defaults to the current matplotlib axis.
      ymax: upper y-limit. If None, it is inferred from `profile`, or from
        the upper edge of the ribbon when `profile_std` is given.
      profile_std: optional array with the same shape as `profile`. When
        given, a ribbon of +/- 2 * profile_std is drawn around each line.
      flip_neg: if True, the second column is mirrored below zero.
      set_ylim: if True, set the y-limits to [-ymax, ymax] when `flip_neg`
        is True and to [0, ymax] otherwise.
    """
    if ax is None:
        ax = plt.gca()

    if profile.ndim == 1:
        # also compatible with single dim
        profile = profile[:, np.newaxis]
    assert profile.ndim == 2, "profile has to be 1- or 2-dimensional"
    assert profile.shape[1] <= 2, "profile can have at most two strands"

    if profile_std is not None:
        profile_std = np.asarray(profile_std)
        if profile_std.ndim == 1:
            # keep the std aligned with the promoted 2-D profile,
            # it is indexed per column below
            profile_std = profile_std[:, np.newaxis]
    labels = ['pos', 'neg']

    # determine ymax if not specified
    if ymax is None:
        if profile_std is not None:
            # the ribbon reaches up to profile + 2 * std, so that is
            # the value the axis has to accommodate
            ymax = (profile + 2 * profile_std).max()
        else:
            ymax = profile.max()

    if set_ylim:
        if flip_neg:
            ax.set_ylim([-ymax, ymax])
        else:
            ax.set_ylim([0, ymax])

    ax.axhline(y=0, linewidth=1, linestyle='--', color='black')

    # positions are 1-based on the x-axis
    xvec = np.arange(1, len(profile) + 1)

    for i in range(profile.shape[1]):
        # only the second strand is mirrored, and only if requested
        sign = -1 if (flip_neg and i > 0) else 1
        line = sign * profile[:, i]
        ax.plot(xvec, line, label=labels[i])

        # plot also the ribbons
        if profile_std is not None:
            ax.fill_between(xvec,
                            line - 2 * profile_std[:, i],
                            line + 2 * profile_std[:, i],
                            alpha=0.1)
56
57
58
def multiple_plot_stranded_profile(d_profile, figsize_tmpl=(4, 3), normalize=False):
    """Draw the example-averaged stranded profile of every task side by side.

    Args:
      d_profile: dictionary task -> profile array. Each array is averaged
        over its first axis before plotting.
      figsize_tmpl: (width, height) of a single panel. The figure width is
        the panel width times the number of tasks.
      normalize: if True, divide each averaged profile by its maximum.

    Returns:
      The matplotlib figure, one panel per task sharing the y-axis.
    """
    n_tasks = len(d_profile)
    # squeeze=False always yields a 2-D grid of axes, so a single task
    # needs no special handling
    fig, grid = plt.subplots(1, n_tasks,
                             figsize=(figsize_tmpl[0] * n_tasks, figsize_tmpl[1]),
                             sharey=True,
                             squeeze=False)

    for col, (task, panel) in enumerate(zip(d_profile, grid[0])):
        mean_profile = d_profile[task].mean(axis=0)
        if normalize:
            mean_profile = mean_profile / mean_profile.max()

        plot_stranded_profile(mean_profile, ax=panel, set_ylim=False)
        panel.set_title(task)

        # only the left-most panel carries the axis labels
        if col == 0:
            panel.set_ylabel("Avg. counts")
            panel.set_xlabel("Position")

    fig.subplots_adjust(wspace=0)  # no space between plots
    return fig
86
87
88
def aggregate_profiles(profile_arr, n_bootstrap=None, only_idx=None):
    """Reduce a stack of profiles to a pair of two values.

    Args:
      profile_arr: array whose first axis enumerates the profiles
      n_bootstrap: if not None, delegate to `bootstrap_mean` with
        `n=n_bootstrap` and return its result unchanged
      only_idx: if not None, return the single profile at this index
        instead of aggregating. Takes precedence over `n_bootstrap`.

    Returns:
      `(profile, None)` for the single-profile and plain-mean cases,
      otherwise whatever `bootstrap_mean` returns (callers in this file
      unpack it as a (mean, std) pair).
    """
    # a single requested profile wins over any aggregation
    if only_idx is not None:
        selected = profile_arr[only_idx]
        return selected, None

    if n_bootstrap is None:
        averaged = profile_arr.mean(axis=0)
        return averaged, None

    return bootstrap_mean(profile_arr, n=n_bootstrap)
96
97
98
def extract_signal(x, seqlets, rc_fn=lambda x: x[::-1, ::-1]):
    """Cut the seqlet windows out of `x` and stack them.

    Args:
      x: array indexed as `x[example, start:end]`
      seqlets: iterable of mappings with the keys 'example', 'start',
        'end' and 'rc'
      rc_fn: function applied to a window whose seqlet has a truthy 'rc'.
        The default reverses the first two axes of the window.

    Returns:
      np.ndarray with one stacked window per seqlet, in seqlet order.
    """
    windows = []
    for seqlet in seqlets:
        window = x[seqlet['example'], seqlet['start']:seqlet['end']]
        windows.append(rc_fn(window) if seqlet['rc'] else window)
    return np.stack(windows)
106
107
108
def plot_profiles(seqlets_by_pattern,
                  x,
                  tracks,
                  contribution_scores=None,
                  figsize=(20, 2),
                  start_vec=None,
                  width=20,
                  legend=True,
                  rotate_y=90,
                  seq_height=1,
                  ymax=None,  # determine y-max
                  n_limit=35,
                  n_bootstrap=None,
                  flip_neg=False,
                  patterns=None,
                  fpath_template=None,
                  only_idx=None,
                  mkdir=False,
                  rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot the sequence profiles

    One figure is produced per pattern. From top to bottom it contains one
    panel per track, one seqlogo panel per contribution score and a final
    seqlogo panel for the sequence.

    Args:
      seqlets_by_pattern: dictionary pattern -> list of seqlets as accepted
        by `extract_signal`
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      contribution_scores: optional dictionary of contribution scores
      figsize: figure size passed to `plt.subplots`
      start_vec: start position of the plotted window within each seqlet.
        Either a list with one entry per pattern or a single value used for
        all patterns. If None, seqlets are not trimmed and `width` is ignored.
      width: number of positions plotted from `start_vec` on
      legend: if True, add a legend to every track panel
      rotate_y: rotation of the y-axis labels
      seq_height: height ratio of the seqlogo panels relative to a track panel
      ymax: upper y-limit of the track panels, a list with one value per
        track or a single value used for all tracks. If None, it is inferred
        per track as the maximum across all patterns.
      n_limit: patterns with fewer than `n_limit` seqlets are skipped
      n_bootstrap: forwarded to `aggregate_profiles`
      flip_neg: forwarded to `plot_stranded_profile`
      patterns: patterns to plot. Defaults to all keys of `seqlets_by_pattern`.
      fpath_template: if not None, each figure is saved to
        `fpath_template.format(pname=..., pattern=...)` with '.png' and
        '.pdf' appended; `pname` is the pattern with '/' replaced by '.'
      only_idx: if not None, plot only the seqlet at this index instead of
        the aggregate over all seqlets
      mkdir: if True, create the output directory before saving
      rc_fn: function applied to the windows of reverse-complemented seqlets

    Returns:
      list of the created figures (skipped patterns are not included)
    """
    from concise.utils.plot import seqlogo

    if contribution_scores is None:
        contribution_scores = {}

    # Resolve the patterns first - the per-pattern windows below are
    # sized from them.
    if patterns is None:
        patterns = list(seqlets_by_pattern)
    if len(patterns) == 0:
        return []

    # Setup the window plotted for each pattern
    if start_vec is not None:
        if not isinstance(start_vec, list):
            start_vec = [start_vec] * len(patterns)
        windows = [slice(start, start + width) for start in start_vec]
    else:
        # no trimming requested -> keep the whole seqlet
        windows = [slice(None)] * len(patterns)

    def extract(arr, ip):
        # seqlet windows of pattern number `ip`, trimmed to its plot window
        seqlets = seqlets_by_pattern[patterns[ip]]
        return extract_signal(arr, seqlets, rc_fn=rc_fn)[:, windows[ip]]

    # aggregated profiles
    d_signal_patterns = {pattern:
                         {k: aggregate_profiles(extract(y, ip),
                                                n_bootstrap=n_bootstrap,
                                                only_idx=only_idx)
                          for k, y in tracks.items()}
                         for ip, pattern in enumerate(patterns)}
    if ymax is None:
        # infer ymax
        def take_max(x, dx):
            if dx is None:
                return x.max()
            else:
                # HACK - hard-coded 2
                return (x + 2 * dx).max()

        ymax = [max([take_max(*d_signal_patterns[pattern][k])
                     for pattern in patterns])
                for k in tracks]  # loop through all the tracks
    if not isinstance(ymax, list):
        ymax = [ymax] * len(tracks)

    n_tracks = len(tracks)
    n_logos = 1 + len(contribution_scores)

    figs = []
    for ip, pattern in enumerate(tqdm(patterns)):
        # --------------
        # extract signal
        seqs = extract(x, ip)
        n = len(seqs)
        if n < n_limit:
            # too few seqlets - skip before doing any further work
            continue
        ext_contribution_scores = {s: extract(contrib, ip)
                                   for s, contrib in contribution_scores.items()}
        d_signal = d_signal_patterns[pattern]
        # --------------
        if only_idx is None:
            sequence = ic_scale(seqs.mean(axis=0))
            # average the contribution scores
            norm_contribution_scores = {k: v.mean(axis=0)
                                        for k, v in ext_contribution_scores.items()}
        else:
            sequence = seqs[only_idx]
            norm_contribution_scores = {k: v[only_idx]
                                        for k, v in ext_contribution_scores.items()}

        # squeeze=False keeps the axes indexable even with a single row
        fig, ax = plt.subplots(n_tracks + n_logos,
                               1, sharex=True, squeeze=False,
                               figsize=figsize,
                               gridspec_kw={'height_ratios': [1] * n_tracks + [seq_height] * n_logos})
        ax = ax[:, 0]

        # signal
        ax[0].set_title(f"{pattern} ({n})")
        for it, (k, (signal_mean, signal_std)) in enumerate(d_signal.items()):
            plot_stranded_profile(signal_mean, ax=ax[it], ymax=ymax[it],
                                  profile_std=signal_std, flip_neg=flip_neg)
            simple_yaxis_format(ax[it])
            strip_axis(ax[it])
            ax[it].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5)

            if legend:
                ax[it].legend()

        # -----------
        # contribution scores (seqlogo)
        # -----------
        if norm_contribution_scores:
            # common y-range for all contribution score panels
            # (max()/min() would fail on an empty dictionary)
            max_scale = max([np.maximum(v, 0).sum(axis=-1).max()
                             for v in norm_contribution_scores.values()])
            min_scale = min([np.minimum(v, 0).sum(axis=-1).min()
                             for v in norm_contribution_scores.values()])
        for ic, (contrib_score_name, logo) in enumerate(norm_contribution_scores.items()):
            ax_id = n_tracks + ic

            # plot
            ax[ax_id].set_ylim([min_scale, max_scale])
            ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey')
            seqlogo(logo, ax=ax[ax_id])

            # style
            simple_yaxis_format(ax[ax_id])
            strip_axis(ax[ax_id])
            ax[ax_id].set_ylabel(contrib_score_name, rotation=rotate_y, ha='right', labelpad=5)

        # -----------
        # information content (seqlogo)
        # -----------
        # plot
        seqlogo(sequence, ax=ax[-1])

        # style
        simple_yaxis_format(ax[-1])
        strip_axis(ax[-1])
        ax[-1].set_ylabel("Inf. content", rotation=rotate_y, ha='right', labelpad=5)
        ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5)))

        figs.append(fig)
        # save to file
        if fpath_template is not None:
            pname = pattern.replace("/", ".")
            basepath = fpath_template.format(pname=pname, pattern=pattern)
            out_dir = os.path.dirname(basepath)
            if mkdir and out_dir:
                # os.makedirs('') raises, so skip it for bare file names
                os.makedirs(out_dir, exist_ok=True)
            # save through the figure itself rather than the pyplot
            # "current figure" state
            fig.savefig(basepath + '.png', dpi=600)
            fig.savefig(basepath + '.pdf', dpi=600)
            plt.close(fig)    # close the figure
            # NOTE(review): presumably show_figure re-attaches the closed
            # figure so that it can still be displayed - confirm
            show_figure(fig)
            plt.show()
    return figs
259
260
261
def plot_profiles_single(seqlet,
                         x,
                         tracks,
                         contribution_scores=None,
                         figsize=(20, 2),
                         legend=True,
                         rotate_y=90,
                         seq_height=1,
                         flip_neg=False,
                         rc_fn=lambda x: x[::-1, ::-1]):
    """
    Plot the sequence profiles for a single seqlet

    From top to bottom the figure contains one panel per track, one seqlogo
    panel per contribution score and a final seqlogo panel for the sequence.

    Args:
      seqlet: object with an `extract(array)` method returning the seqlet
        window of the array
      x: one-hot-encoded sequence
      tracks: dictionary of profile tracks
      contribution_scores: optional dictionary of contribution scores
      figsize: figure size passed to `plt.subplots`
      legend: if True, add a legend to every track panel
      rotate_y: rotation of the y-axis labels
      seq_height: height ratio of the seqlogo panels relative to a track panel
      flip_neg: forwarded to `plot_stranded_profile`
      rc_fn: kept for signature compatibility; not used by this function

    Returns:
      the created matplotlib figure
    """
    from concise.utils.plot import seqlogo

    if contribution_scores is None:
        contribution_scores = {}

    # --------------
    # extract signal
    seq = seqlet.extract(x)
    ext_contribution_scores = {s: seqlet.extract(contrib) for s, contrib in contribution_scores.items()}

    n_tracks = len(tracks)
    n_logos = 1 + len(contribution_scores)

    # squeeze=False keeps the axes indexable even with a single row
    fig, ax = plt.subplots(n_tracks + n_logos,
                           1, sharex=True, squeeze=False,
                           figsize=figsize,
                           gridspec_kw={'height_ratios': [1] * n_tracks + [seq_height] * n_logos})
    ax = ax[:, 0]

    # signal
    for i, (k, signal) in enumerate(tracks.items()):
        plot_stranded_profile(seqlet.extract(signal), ax=ax[i],
                              flip_neg=flip_neg)
        simple_yaxis_format(ax[i])
        strip_axis(ax[i])
        ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5)

        if legend:
            ax[i].legend()

    # -----------
    # contribution scores (seqlogo)
    # -----------
    if ext_contribution_scores:
        # common y-range for all contribution score panels
        # (max()/min() would fail on an empty dictionary)
        max_scale = max([np.maximum(v, 0).sum(axis=-1).max() for v in ext_contribution_scores.values()])
        min_scale = min([np.minimum(v, 0).sum(axis=-1).min() for v in ext_contribution_scores.values()])
    for k, (contrib_score_name, logo) in enumerate(ext_contribution_scores.items()):
        ax_id = n_tracks + k
        # plot
        ax[ax_id].set_ylim([min_scale, max_scale])
        ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey')
        seqlogo(logo, ax=ax[ax_id])

        # style
        simple_yaxis_format(ax[ax_id])
        strip_axis(ax[ax_id])
        ax[ax_id].set_ylabel(contrib_score_name, rotation=rotate_y, ha='right', labelpad=5)

    # -----------
    # information content (seqlogo)
    # -----------
    # plot
    seqlogo(seq, ax=ax[-1])

    # style
    simple_yaxis_format(ax[-1])
    strip_axis(ax[-1])
    ax[-1].set_ylabel("Inf. content", rotation=rotate_y, ha='right', labelpad=5)
    ax[-1].set_xticks(list(range(0, len(seq) + 1, 5)))
    return fig
333
334
335
def hist_position(dfp, tasks):
    """Plot one histogram of the `center` column per task

    Args:
      dfp: pd.DataFrame with columns: peak_id, and center
      tasks: list of tasks; each panel shows the rows whose peak_id equals
        the task

    Returns:
      the matplotlib figure with one panel per task, sharing both axes
    """
    n_tasks = len(tasks)
    # squeeze=False always yields a 2-D grid, also for a single task
    fig, grid = plt.subplots(1, n_tasks, figsize=(5 * n_tasks, 2),
                             sharey=True, sharex=True, squeeze=False)

    for col, (task, panel) in enumerate(zip(tasks, grid[0])):
        centers = dfp[dfp.peak_id == task].center
        panel.hist(centers, bins=100)
        panel.set_title(task)
        panel.set_xlabel("Position")
        panel.set_xlim([0, 1000])
        # only the left-most panel carries the y-label
        if col == 0:
            panel.set_ylabel("Frequency")

    fig.subplots_adjust(wspace=0)
    return fig
355
356
357
def bar_seqlets_per_example(dfp, tasks):
    """Plot, per task, how often each number of rows per example occurs

    Rows are grouped by `example_idx`; the horizontal bars show how many
    examples have a given group size.

    Args:
      dfp: pd.DataFrame with columns: peak_id, and example_idx
      tasks: list of tasks; each panel uses the rows whose peak_id equals
        the task

    Returns:
      the matplotlib figure with one panel per task, sharing both axes
    """
    n_tasks = len(tasks)
    # squeeze=False always yields a 2-D grid, also for a single task
    fig, grid = plt.subplots(1, n_tasks, figsize=(5 * n_tasks, 2),
                             sharey=True, sharex=True, squeeze=False)

    for col, (task, panel) in enumerate(zip(tasks, grid[0])):
        panel.set_title(task)
        panel.set_xlabel("Frequency")
        if col == 0:
            panel.set_ylabel("# per example")

        task_rows = dfp[dfp.peak_id == task]
        if len(task_rows) == 0:
            # nothing to draw - leave the labelled panel empty
            continue

        rows_per_example = task_rows.groupby("example_idx").size()
        rows_per_example.value_counts().plot(kind="barh", ax=panel)

    fig.subplots_adjust(wspace=0)
    return fig
380
381
382
def box_counts(total_counts, pattern_idx):
    """Make a box-plot with total counts in the region

    Compares, per task, the log10(1 + counts) distribution of all rows with
    the distribution of the rows selected by `pattern_idx`.

    Args:
      total_counts: pd.DataFrame with one column per task
      pattern_idx: array with example_idx of the pattern, used as positional
        row index into `total_counts`

    Returns:
      the matplotlib figure
    """
    def to_long(df, subset):
        # name the melted columns explicitly - by default pandas would use
        # `df.columns.name` instead of "variable" when it is set
        return df.melt(var_name="variable", value_name="value").assign(subset=subset)

    dfs = pd.concat([to_long(total_counts, "all peaks"),
                     to_long(total_counts.iloc[pattern_idx], "contains pattern")])
    dfs["value"] = np.log10(1 + dfs["value"])

    fig, ax = plt.subplots(figsize=(5, 5))
    # x and y have to be passed by keyword: recent seaborn versions
    # no longer accept them positionally
    sns.boxplot(x="variable", y="value", hue="subset", data=dfs, ax=ax)
    ax.set_xlabel("Task")
    ax.set_ylabel("log10(1+counts)")
    ax.set_title("Total number of counts in the region")
    return fig