|
a |
|
b/bpnet/plot/profiles.py |
|
|
1 |
from tqdm import tqdm |
|
|
2 |
import seaborn as sns |
|
|
3 |
import pandas as pd |
|
|
4 |
import os |
|
|
5 |
import matplotlib.pyplot as plt |
|
|
6 |
import numpy as np |
|
|
7 |
from bpnet.plot.utils import simple_yaxis_format, strip_axis, spaced_xticks |
|
|
8 |
from bpnet.modisco.utils import bootstrap_mean, nan_like, ic_scale |
|
|
9 |
from bpnet.plot.utils import show_figure |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
# TODO - make it as a bar-plot with two standard colors: |
|
|
13 |
# #B23F49 (pos), #045CA8 (neg) |
|
|
14 |
def plot_stranded_profile(profile, ax=None, ymax=None, profile_std=None, flip_neg=True, set_ylim=True): |
|
|
15 |
"""Plot the stranded profile |
|
|
16 |
""" |
|
|
17 |
if ax is None: |
|
|
18 |
ax = plt.gca() |
|
|
19 |
|
|
|
20 |
if profile.ndim == 1: |
|
|
21 |
# also compatible with single dim |
|
|
22 |
profile = profile[:, np.newaxis] |
|
|
23 |
assert profile.ndim == 2 |
|
|
24 |
assert profile.shape[1] <= 2 |
|
|
25 |
labels = ['pos', 'neg'] |
|
|
26 |
|
|
|
27 |
# determine ymax if not specified |
|
|
28 |
if ymax is None: |
|
|
29 |
if profile_std is not None: |
|
|
30 |
ymax = (profile.max() - 2 * profile_std).max() |
|
|
31 |
else: |
|
|
32 |
ymax = profile.max() |
|
|
33 |
|
|
|
34 |
if set_ylim: |
|
|
35 |
if flip_neg: |
|
|
36 |
ax.set_ylim([-ymax, ymax]) |
|
|
37 |
else: |
|
|
38 |
ax.set_ylim([0, ymax]) |
|
|
39 |
|
|
|
40 |
ax.axhline(y=0, linewidth=1, linestyle='--', color='black') |
|
|
41 |
# strip_axis(ax) |
|
|
42 |
|
|
|
43 |
xvec = np.arange(1, len(profile) + 1) |
|
|
44 |
|
|
|
45 |
for i in range(profile.shape[1]): |
|
|
46 |
sign = 1 if not flip_neg or i == 0 else -1 |
|
|
47 |
ax.plot(xvec, sign * profile[:, i], label=labels[i]) |
|
|
48 |
|
|
|
49 |
# plot also the ribbons |
|
|
50 |
if profile_std is not None: |
|
|
51 |
ax.fill_between(xvec, |
|
|
52 |
sign * profile[:, i] - 2 * profile_std[:, i], |
|
|
53 |
sign * profile[:, i] + 2 * profile_std[:, i], |
|
|
54 |
alpha=0.1) |
|
|
55 |
# return ax |
|
|
56 |
|
|
|
57 |
|
|
|
58 |
def multiple_plot_stranded_profile(d_profile, figsize_tmpl=(4, 3), normalize=False): |
|
|
59 |
fig, axes = plt.subplots(1, len(d_profile), |
|
|
60 |
figsize=(figsize_tmpl[0] * len(d_profile), figsize_tmpl[1]), |
|
|
61 |
sharey=True) |
|
|
62 |
if len(d_profile)==1: #If only one task, then can't zip axes |
|
|
63 |
ax = axes |
|
|
64 |
task = [*d_profile][0] |
|
|
65 |
arr = d_profile[task].mean(axis=0) |
|
|
66 |
if normalize: |
|
|
67 |
arr = arr / arr.max() |
|
|
68 |
plot_stranded_profile(arr, ax=ax, set_ylim=False) |
|
|
69 |
ax.set_title(task) |
|
|
70 |
ax.set_ylabel("Avg. counts") |
|
|
71 |
ax.set_xlabel("Position") |
|
|
72 |
fig.subplots_adjust(wspace=0) # no space between plots |
|
|
73 |
return fig |
|
|
74 |
else: |
|
|
75 |
for i, (task, ax) in enumerate(zip(d_profile, axes)): |
|
|
76 |
arr = d_profile[task].mean(axis=0) |
|
|
77 |
if normalize: |
|
|
78 |
arr = arr / arr.max() |
|
|
79 |
plot_stranded_profile(arr, ax=ax, set_ylim=False) |
|
|
80 |
ax.set_title(task) |
|
|
81 |
if i == 0: |
|
|
82 |
ax.set_ylabel("Avg. counts") |
|
|
83 |
ax.set_xlabel("Position") |
|
|
84 |
fig.subplots_adjust(wspace=0) # no space between plots |
|
|
85 |
return fig |
|
|
86 |
|
|
|
87 |
|
|
|
88 |
def aggregate_profiles(profile_arr, n_bootstrap=None, only_idx=None): |
|
|
89 |
if only_idx is not None: |
|
|
90 |
return profile_arr[only_idx], None |
|
|
91 |
|
|
|
92 |
if n_bootstrap is not None: |
|
|
93 |
return bootstrap_mean(profile_arr, n=n_bootstrap) |
|
|
94 |
else: |
|
|
95 |
return profile_arr.mean(axis=0), None |
|
|
96 |
|
|
|
97 |
|
|
|
98 |
def extract_signal(x, seqlets, rc_fn=lambda x: x[::-1, ::-1]): |
|
|
99 |
def optional_rc(x, is_rc): |
|
|
100 |
if is_rc: |
|
|
101 |
return rc_fn(x) |
|
|
102 |
else: |
|
|
103 |
return x |
|
|
104 |
return np.stack([optional_rc(x[s['example'], s['start']:s['end']], s['rc']) |
|
|
105 |
for s in seqlets]) |
|
|
106 |
|
|
|
107 |
|
|
|
108 |
def plot_profiles(seqlets_by_pattern, |
|
|
109 |
x, |
|
|
110 |
tracks, |
|
|
111 |
contribution_scores={}, |
|
|
112 |
figsize=(20, 2), |
|
|
113 |
start_vec=None, |
|
|
114 |
width=20, |
|
|
115 |
legend=True, |
|
|
116 |
rotate_y=90, |
|
|
117 |
seq_height=1, |
|
|
118 |
ymax=None, # determine y-max |
|
|
119 |
n_limit=35, |
|
|
120 |
n_bootstrap=None, |
|
|
121 |
flip_neg=False, |
|
|
122 |
patterns=None, |
|
|
123 |
fpath_template=None, |
|
|
124 |
only_idx=None, |
|
|
125 |
mkdir=False, |
|
|
126 |
rc_fn=lambda x: x[::-1, ::-1]): |
|
|
127 |
""" |
|
|
128 |
Plot the sequence profiles |
|
|
129 |
Args: |
|
|
130 |
x: one-hot-encoded sequence |
|
|
131 |
tracks: dictionary of profile tracks |
|
|
132 |
contribution_scores: optional dictionary of contribution scores |
|
|
133 |
|
|
|
134 |
""" |
|
|
135 |
import matplotlib.pyplot as plt |
|
|
136 |
from concise.utils.plot import seqlogo_fig, seqlogo |
|
|
137 |
|
|
|
138 |
# Setup start-vec |
|
|
139 |
if start_vec is not None: |
|
|
140 |
if not isinstance(start_vec, list): |
|
|
141 |
start_vec = [start_vec] * len(patterns) |
|
|
142 |
else: |
|
|
143 |
start_vec = [0] * len(patterns) |
|
|
144 |
width = len(x) |
|
|
145 |
|
|
|
146 |
if patterns is None: |
|
|
147 |
patterns = list(seqlets_by_pattern) |
|
|
148 |
# aggregated profiles |
|
|
149 |
d_signal_patterns = {pattern: |
|
|
150 |
{k: aggregate_profiles( |
|
|
151 |
extract_signal(y, seqlets_by_pattern[pattern])[:, start_vec[ip]:(start_vec[ip] + width)], |
|
|
152 |
n_bootstrap=n_bootstrap, only_idx=only_idx) |
|
|
153 |
for k, y in tracks.items()} |
|
|
154 |
for ip, pattern in enumerate(patterns)} |
|
|
155 |
if ymax is None: |
|
|
156 |
# infer ymax |
|
|
157 |
def take_max(x, dx): |
|
|
158 |
if dx is None: |
|
|
159 |
return x.max() |
|
|
160 |
else: |
|
|
161 |
# HACK - hard-coded 2 |
|
|
162 |
return (x + 2 * dx).max() |
|
|
163 |
|
|
|
164 |
ymax = [max([take_max(*d_signal_patterns[pattern][k]) |
|
|
165 |
for pattern in patterns]) |
|
|
166 |
for k in tracks] # loop through all the tracks |
|
|
167 |
if not isinstance(ymax, list): |
|
|
168 |
ymax = [ymax] * len(tracks) |
|
|
169 |
|
|
|
170 |
figs = [] |
|
|
171 |
for i, pattern in enumerate(tqdm(patterns)): |
|
|
172 |
j = i |
|
|
173 |
# -------------- |
|
|
174 |
# extract signal |
|
|
175 |
seqs = extract_signal(x, seqlets_by_pattern[pattern])[:, start_vec[i]:(start_vec[i] + width)] |
|
|
176 |
ext_contribution_scores = {s: extract_signal(contrib, seqlets_by_pattern[pattern])[:, start_vec[i]:(start_vec[i] + width)] |
|
|
177 |
for s, contrib in contribution_scores.items()} |
|
|
178 |
d_signal = d_signal_patterns[pattern] |
|
|
179 |
# -------------- |
|
|
180 |
if only_idx is None: |
|
|
181 |
sequence = ic_scale(seqs.mean(axis=0)) |
|
|
182 |
else: |
|
|
183 |
sequence = seqs[only_idx] |
|
|
184 |
|
|
|
185 |
n = len(seqs) |
|
|
186 |
if n < n_limit: |
|
|
187 |
continue |
|
|
188 |
fig, ax = plt.subplots(1 + len(contribution_scores) + len(tracks), |
|
|
189 |
1, sharex=True, |
|
|
190 |
figsize=figsize, |
|
|
191 |
gridspec_kw={'height_ratios': [1] * len(tracks) + [seq_height] * (1 + len(contribution_scores))}) |
|
|
192 |
|
|
|
193 |
# signal |
|
|
194 |
ax[0].set_title(f"{pattern} ({n})") |
|
|
195 |
for i, (k, signal) in enumerate(d_signal.items()): |
|
|
196 |
signal_mean, signal_std = d_signal_patterns[pattern][k] |
|
|
197 |
plot_stranded_profile(signal_mean, ax=ax[i], ymax=ymax[i], |
|
|
198 |
profile_std=signal_std, flip_neg=flip_neg) |
|
|
199 |
simple_yaxis_format(ax[i]) |
|
|
200 |
strip_axis(ax[i]) |
|
|
201 |
ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5) |
|
|
202 |
|
|
|
203 |
if legend: |
|
|
204 |
ax[i].legend() |
|
|
205 |
|
|
|
206 |
# ----------- |
|
|
207 |
# contribution scores (seqlogo) |
|
|
208 |
# ----------- |
|
|
209 |
# average the contribution scores |
|
|
210 |
if only_idx is None: |
|
|
211 |
norm_contribution_scores = {k: v.mean(axis=0) |
|
|
212 |
for k, v in ext_contribution_scores.items()} |
|
|
213 |
else: |
|
|
214 |
norm_contribution_scores = {k: v[only_idx] |
|
|
215 |
for k, v in ext_contribution_scores.items()} |
|
|
216 |
|
|
|
217 |
max_scale = max([np.maximum(v, 0).sum(axis=-1).max() for v in norm_contribution_scores.values()]) |
|
|
218 |
min_scale = min([np.minimum(v, 0).sum(axis=-1).min() for v in norm_contribution_scores.values()]) |
|
|
219 |
for k, (contrib_score_name, logo) in enumerate(norm_contribution_scores.items()): |
|
|
220 |
ax_id = len(tracks) + k |
|
|
221 |
|
|
|
222 |
# Trim the pattern if necessary |
|
|
223 |
# plot |
|
|
224 |
ax[ax_id].set_ylim([min_scale, max_scale]) |
|
|
225 |
ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey') |
|
|
226 |
seqlogo(logo, ax=ax[ax_id]) |
|
|
227 |
|
|
|
228 |
# style |
|
|
229 |
simple_yaxis_format(ax[ax_id]) |
|
|
230 |
strip_axis(ax[ax_id]) |
|
|
231 |
# ax[ax_id].set_ylabel(contrib_score_name) |
|
|
232 |
ax[ax_id].set_ylabel(contrib_score_name, rotation=rotate_y, ha='right', labelpad=5) # va='bottom', |
|
|
233 |
|
|
|
234 |
# ----------- |
|
|
235 |
# information content (seqlogo) |
|
|
236 |
# ----------- |
|
|
237 |
# plot |
|
|
238 |
seqlogo(sequence, ax=ax[-1]) |
|
|
239 |
|
|
|
240 |
# style |
|
|
241 |
simple_yaxis_format(ax[-1]) |
|
|
242 |
strip_axis(ax[-1]) |
|
|
243 |
ax[-1].set_ylabel("Inf. content", rotation=rotate_y, ha='right', labelpad=5) |
|
|
244 |
ax[-1].set_xticks(list(range(0, len(sequence) + 1, 5))) |
|
|
245 |
|
|
|
246 |
figs.append(fig) |
|
|
247 |
# save to file |
|
|
248 |
if fpath_template is not None: |
|
|
249 |
pname = pattern.replace("/", ".") |
|
|
250 |
basepath = fpath_template.format(pname=pname, pattern=pattern) |
|
|
251 |
if mkdir: |
|
|
252 |
os.makedirs(os.path.dirname(basepath), exist_ok=True) |
|
|
253 |
plt.savefig(basepath + '.png', dpi=600) |
|
|
254 |
plt.savefig(basepath + '.pdf', dpi=600) |
|
|
255 |
plt.close(fig) # close the figure |
|
|
256 |
show_figure(fig) |
|
|
257 |
plt.show() |
|
|
258 |
return figs |
|
|
259 |
|
|
|
260 |
|
|
|
261 |
def plot_profiles_single(seqlet, |
|
|
262 |
x, |
|
|
263 |
tracks, |
|
|
264 |
contribution_scores={}, |
|
|
265 |
figsize=(20, 2), |
|
|
266 |
legend=True, |
|
|
267 |
rotate_y=90, |
|
|
268 |
seq_height=1, |
|
|
269 |
flip_neg=False, |
|
|
270 |
rc_fn=lambda x: x[::-1, ::-1]): |
|
|
271 |
""" |
|
|
272 |
Plot the sequence profiles |
|
|
273 |
Args: |
|
|
274 |
x: one-hot-encoded sequence |
|
|
275 |
tracks: dictionary of profile tracks |
|
|
276 |
contribution_scores: optional dictionary of contribution scores |
|
|
277 |
|
|
|
278 |
""" |
|
|
279 |
import matplotlib.pyplot as plt |
|
|
280 |
from concise.utils.plot import seqlogo_fig, seqlogo |
|
|
281 |
|
|
|
282 |
# -------------- |
|
|
283 |
# extract signal |
|
|
284 |
seq = seqlet.extract(x) |
|
|
285 |
ext_contribution_scores = {s: seqlet.extract(contrib) for s, contrib in contribution_scores.items()} |
|
|
286 |
|
|
|
287 |
fig, ax = plt.subplots(1 + len(contribution_scores) + len(tracks), |
|
|
288 |
1, sharex=True, |
|
|
289 |
figsize=figsize, |
|
|
290 |
gridspec_kw={'height_ratios': [1] * len(tracks) + [seq_height] * (1 + len(contribution_scores))}) |
|
|
291 |
|
|
|
292 |
# signal |
|
|
293 |
for i, (k, signal) in enumerate(tracks.items()): |
|
|
294 |
plot_stranded_profile(seqlet.extract(signal), ax=ax[i], |
|
|
295 |
flip_neg=flip_neg) |
|
|
296 |
simple_yaxis_format(ax[i]) |
|
|
297 |
strip_axis(ax[i]) |
|
|
298 |
ax[i].set_ylabel(f"{k}", rotation=rotate_y, ha='right', labelpad=5) |
|
|
299 |
|
|
|
300 |
if legend: |
|
|
301 |
ax[i].legend() |
|
|
302 |
|
|
|
303 |
# ----------- |
|
|
304 |
# contribution scores (seqlogo) |
|
|
305 |
# ----------- |
|
|
306 |
max_scale = max([np.maximum(v, 0).sum(axis=-1).max() for v in ext_contribution_scores.values()]) |
|
|
307 |
min_scale = min([np.minimum(v, 0).sum(axis=-1).min() for v in ext_contribution_scores.values()]) |
|
|
308 |
for k, (contrib_score_name, logo) in enumerate(ext_contribution_scores.items()): |
|
|
309 |
ax_id = len(tracks) + k |
|
|
310 |
# plot |
|
|
311 |
ax[ax_id].set_ylim([min_scale, max_scale]) |
|
|
312 |
ax[ax_id].axhline(y=0, linewidth=1, linestyle='--', color='grey') |
|
|
313 |
seqlogo(logo, ax=ax[ax_id]) |
|
|
314 |
|
|
|
315 |
# style |
|
|
316 |
simple_yaxis_format(ax[ax_id]) |
|
|
317 |
strip_axis(ax[ax_id]) |
|
|
318 |
# ax[ax_id].set_ylabel(contrib_score_name) |
|
|
319 |
ax[ax_id].set_ylabel(contrib_score_name, rotation=rotate_y, ha='right', labelpad=5) # va='bottom', |
|
|
320 |
|
|
|
321 |
# ----------- |
|
|
322 |
# information content (seqlogo) |
|
|
323 |
# ----------- |
|
|
324 |
# plot |
|
|
325 |
seqlogo(seq, ax=ax[-1]) |
|
|
326 |
|
|
|
327 |
# style |
|
|
328 |
simple_yaxis_format(ax[-1]) |
|
|
329 |
strip_axis(ax[-1]) |
|
|
330 |
ax[-1].set_ylabel("Inf. content", rotation=rotate_y, ha='right', labelpad=5) |
|
|
331 |
ax[-1].set_xticks(list(range(0, len(seq) + 1, 5))) |
|
|
332 |
return fig |
|
|
333 |
|
|
|
334 |
|
|
|
335 |
def hist_position(dfp, tasks): |
|
|
336 |
"""Make the positional histogram |
|
|
337 |
|
|
|
338 |
Args: |
|
|
339 |
dfp: pd.DataFrame with columns: peak_id, and center |
|
|
340 |
tasks: list of tasks for which to plot the different peak_id columns |
|
|
341 |
""" |
|
|
342 |
fig, axes = plt.subplots(1, len(tasks), figsize=(5 * len(tasks), 2), |
|
|
343 |
sharey=True, sharex=True) |
|
|
344 |
if len(tasks) == 1: |
|
|
345 |
axes = [axes] |
|
|
346 |
for i, (task, ax) in enumerate(zip(tasks, axes)): |
|
|
347 |
ax.hist(dfp[dfp.peak_id == task].center, bins=100) |
|
|
348 |
ax.set_title(task) |
|
|
349 |
ax.set_xlabel("Position") |
|
|
350 |
ax.set_xlim([0, 1000]) |
|
|
351 |
if i == 0: |
|
|
352 |
ax.set_ylabel("Frequency") |
|
|
353 |
plt.subplots_adjust(wspace=0) |
|
|
354 |
return fig |
|
|
355 |
|
|
|
356 |
|
|
|
357 |
def bar_seqlets_per_example(dfp, tasks): |
|
|
358 |
"""Make the positional histogram |
|
|
359 |
|
|
|
360 |
Args: |
|
|
361 |
dfp: pd.DataFrame with columns: peak_id, and center |
|
|
362 |
tasks: list of tasks for which to plot the different peak_id columns |
|
|
363 |
""" |
|
|
364 |
fig, axes = plt.subplots(1, len(tasks), figsize=(5 * len(tasks), 2), |
|
|
365 |
sharey=True, sharex=True) |
|
|
366 |
if len(tasks) == 1: |
|
|
367 |
axes = [axes] |
|
|
368 |
for i, (task, ax) in enumerate(zip(tasks, axes)): |
|
|
369 |
dfpp = dfp[dfp.peak_id == task] |
|
|
370 |
ax.set_title(task) |
|
|
371 |
ax.set_xlabel("Frequency") |
|
|
372 |
if i == 0: |
|
|
373 |
ax.set_ylabel("# per example") |
|
|
374 |
if not len(dfpp): |
|
|
375 |
continue |
|
|
376 |
dfpp.groupby("example_idx").\ |
|
|
377 |
size().value_counts().plot(kind="barh", ax=ax) |
|
|
378 |
plt.subplots_adjust(wspace=0) |
|
|
379 |
return fig |
|
|
380 |
|
|
|
381 |
|
|
|
382 |
def box_counts(total_counts, pattern_idx): |
|
|
383 |
"""Make a box-plot with total counts in the region |
|
|
384 |
|
|
|
385 |
Args: |
|
|
386 |
total_counts: dict per task |
|
|
387 |
pattern_idx: array with example_idx of the pattern |
|
|
388 |
""" |
|
|
389 |
dfs = pd.concat([total_counts.melt().assign(subset="all peaks"), |
|
|
390 |
total_counts.iloc[pattern_idx].melt().assign(subset="contains pattern")]) |
|
|
391 |
dfs.value = np.log10(1 + dfs.value) |
|
|
392 |
|
|
|
393 |
fig, ax = plt.subplots(figsize=(5, 5)) |
|
|
394 |
sns.boxplot("variable", "value", hue="subset", data=dfs, ax=ax) |
|
|
395 |
ax.set_xlabel("Task") |
|
|
396 |
ax.set_ylabel("log10(1+counts)") |
|
|
397 |
ax.set_title("Total number of counts in the region") |
|
|
398 |
return fig |