|
# bpnet/modisco/pattern_instances.py
"""Code for working with the pattern instances table
produced by `bpnet.cli.modisco.modisco_score2`,
which calls `pattern.get_instances`.
"""
from collections import OrderedDict

import numpy as np
import pandas as pd
from tqdm import tqdm

from bpnet.cli.modisco import get_nonredundant_example_idx
from bpnet.modisco.core import resize_seqlets, Seqlet
from bpnet.modisco.utils import longer_pattern, shorten_pattern, trim_pssm_idx
from bpnet.plot.profiles import extract_signal
from bpnet.stats import quantile_norm
|
|
def get_motif_pairs(motifs):
    """Generate all unordered pairs of motifs, including self-pairs."""
    keys = list(motifs)
    pairs = []
    for i in range(len(keys)):
        for j in range(i, len(keys)):
            pairs.append([keys[i], keys[j]])
    return pairs
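
# Example (hypothetical motif names): get_motif_pairs({"Oct4": "m0_p1", "Sox2": "m0_p2"})
# returns [['Oct4', 'Oct4'], ['Oct4', 'Sox2'], ['Sox2', 'Sox2']]:
# all unordered pairs, self-pairs included.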
|
|
# strand combination of the same motif pair viewed from the
# reverse-complement strand (both strands flip)
comp_strand_combination = {
    "++": "--",
    "--": "++",
    "-+": "+-",
    "+-": "-+"
}

strand_combinations = ["++", "--", "+-", "-+"]
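
# For example, comp_strand_combination["++"] == "--": viewing a pair from the
# reverse-complement strand flips both strands. `motif_pair_dfi` below uses
# this mapping to orient pairs so that the first motif always comes first,
# making the strand combination independent of genomic order.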
|
|
# TODO - allow these to also be of other types?
def load_instances(parq_file, motifs=None, dedup=True, verbose=True):
    """Load pattern instances from the parquet file.

    Args:
      parq_file: parquet file of motif instances
      motifs: dictionary of motifs of interest.
        key=custom motif name, value=short pattern name (e.g. {'Nanog': 'm0_p3'})
      dedup: if True, drop instances duplicated at the same genomic
        location and strand
      verbose: if True, print the number of dropped duplicates
    """
    if motifs is not None:
        incl_motifs = {longer_pattern(m) for m in motifs.values()}
    else:
        incl_motifs = None

    if isinstance(parq_file, pd.DataFrame):
        dfi = parq_file
    else:
        if motifs is not None:
            from fastparquet import ParquetFile

            # selectively load only the relevant patterns
            pf = ParquetFile(str(parq_file))
            patterns = [shorten_pattern(pn) for pn in incl_motifs]
            dfi = pf.to_pandas(filters=[("pattern_short", "in", patterns)])
        else:
            dfi = pd.read_parquet(str(parq_file), engine='fastparquet')
            if 'pattern' not in dfi:
                # assumes a hive-stored file
                dfi['pattern'] = dfi['dir0'].str.replace("pattern=", "").astype(str) + "/" + dfi['dir1'].astype(str)

    # filter
    if motifs is not None:
        dfi = dfi[dfi.pattern.isin(incl_motifs)]  # NOTE this should already be removed
        if 'pattern_short' not in dfi:
            dfi['pattern_short'] = dfi['pattern'].map({k: shorten_pattern(k) for k in incl_motifs})
        dfi['pattern_name'] = dfi['pattern_short'].map({v: k for k, v in motifs.items()})
    else:
        dfi['pattern_short'] = dfi['pattern'].map({k: shorten_pattern(k)
                                                   for k in dfi.pattern.unique()})

    # add some columns if they don't yet exist
    if 'pattern_start_abs' not in dfi:
        dfi['pattern_start_abs'] = dfi['example_start'] + dfi['pattern_start']
    if 'pattern_end_abs' not in dfi:
        dfi['pattern_end_abs'] = dfi['example_start'] + dfi['pattern_end']

    if dedup:
        # de-duplicate
        dfi_dedup = dfi.drop_duplicates(['pattern',
                                         'example_chrom',
                                         'pattern_start_abs',
                                         'pattern_end_abs',
                                         'strand'])

        # number of removed duplicates
        d = len(dfi) - len(dfi_dedup)
        if verbose:
            print("number of de-duplicated instances:", d, f"({d / len(dfi) * 100:.2f}%)")

        # use de-duplicated instances from now on
        dfi = dfi_dedup
    return dfi
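
# A minimal usage sketch (hypothetical file path and motif names):
#
#   dfi = load_instances("modisco/instances.parq",
#                        motifs={"Nanog": "m0_p3"}, dedup=True)
#   dfi[["pattern_name", "example_chrom",
#        "pattern_start_abs", "pattern_end_abs", "strand"]].head()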
|
|
def multiple_load_instances(paths, motifs):
    """
    Args:
      paths: dictionary <tf> -> instances.parq
      motifs: dictionary with <motif_name> -> pattern name of
        the form `<TF>/m0_p1`
    """
    from bpnet.utils import pd_col_prepend

    # load instances per TF and prefix the pattern columns with the TF name
    dfi = pd.concat([
        load_instances(path,
                       motifs=OrderedDict([(motif, pn.split("/", 1)[1])
                                           for motif, pn in motifs.items()
                                           if pn.split("/", 1)[0] == tf]),
                       dedup=False)
        .assign(tf=tf)
        .pipe(pd_col_prepend, ['pattern', 'pattern_short'], prefix=tf + "/")
        for tf, path in tqdm(paths.items())
    ])
    return dfi
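
# Sketch (hypothetical TF names and file paths):
#
#   dfi = multiple_load_instances(
#       paths={"Oct4": "Oct4/instances.parq", "Nanog": "Nanog/instances.parq"},
#       motifs={"Oct4": "Oct4/m0_p1", "Nanog": "Nanog/m0_p3"})
#
# The `pattern` and `pattern_short` columns get prefixed with the TF name.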
|
|
def dfi_add_ranges(dfi, ranges, dedup=False):
    """Merge example ranges into dfi and add absolute pattern coordinates."""
    dfi = dfi.merge(ranges, on="example_idx", how='left')
    dfi['pattern_start_abs'] = dfi['example_start'] + dfi['pattern_start']
    dfi['pattern_end_abs'] = dfi['example_start'] + dfi['pattern_end']

    if dedup:
        # de-duplicate
        dfi_dedup = dfi.drop_duplicates(['pattern',
                                         'example_chrom',
                                         'pattern_start_abs',
                                         'pattern_end_abs',
                                         'strand'])

        # number of removed duplicates
        d = len(dfi) - len(dfi_dedup)
        print("number of de-duplicated instances:", d, f"({d / len(dfi) * 100:.2f}%)")

        # use de-duplicated instances from now on
        dfi = dfi_dedup
    return dfi
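
# Sketch: attach example-level coordinates to instances loaded without them
# (assumes `ranges` carries `example_idx` plus the example coordinate columns):
#
#   dfi = dfi_add_ranges(dfi, ranges, dedup=True)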
|
|
def dfi2pyranges(dfi):
    """Convert dfi to pyranges

    Args:
      dfi: pd.DataFrame returned by `load_instances`
    """
    import pyranges as pr
    dfi = dfi.copy()
    dfi['Chromosome'] = dfi['example_chrom']
    dfi['Start'] = dfi['pattern_start_abs']
    dfi['End'] = dfi['pattern_end_abs']
    dfi['Name'] = dfi['pattern']
    dfi['Score'] = dfi['contrib_weighted_p']
    dfi['Strand'] = dfi['strand']
    return pr.PyRanges(dfi)
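
# Sketch: with the instances as a PyRanges object, standard genomic-interval
# operations apply, e.g. intersecting with peaks (hypothetical file name):
#
#   import pyranges as pr
#   gr = dfi2pyranges(dfi)
#   gr.join(pr.read_bed("peaks.bed"))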
|
|
def align_instance_center(dfi, original_patterns, aligned_patterns, trim_frac=0.08):
    """Align the center of the seqlets using aligned patterns.

    Args:
      dfi: pd.DataFrame returned by `load_instances`
      original_patterns: un-trimmed patterns that were trimmed using
        trim_frac before scanning
      aligned_patterns: patterns that are all lined up and that contain
        'align': {"use_rc", "offset"} information in their attrs
      trim_frac: original trim_frac used to trim the motifs

    Returns:
      dfi with 2 new columns: `pattern_center_aln` and `pattern_strand_aln`
    """
    # NOTE - it would be nice to be able to give trimmed patterns instead of
    # `original_patterns` + `trim_frac` and just extract the trim stats from the pattern
    # TODO - shall we split this function into two -> one for dealing with
    # pattern trimming and one for dealing with aligning patterns?
    trim_shift_pos = {p.name: p._trim_center_shift(trim_frac=trim_frac)[0]
                      for p in original_patterns}
    trim_shift_neg = {p.name: p._trim_center_shift(trim_frac=trim_frac)[1]
                      for p in original_patterns}
    shift = {p.name: (p.attrs['align']['use_rc'] * 2 - 1) * p.attrs['align']['offset']
             for p in aligned_patterns}
    strand_shift = {p.name: p.attrs['align']['use_rc'] for p in aligned_patterns}

    strand_vec = dfi.strand.map({"+": 1, "-": -1})
    dfi['pattern_center_aln'] = (dfi.pattern_center -
                                 # - trim_shift since we are going from trimmed to non-trimmed
                                 np.where(dfi.strand == '-',
                                          dfi.pattern.map(trim_shift_neg),
                                          dfi.pattern.map(trim_shift_pos)) +
                                 # NOTE: `strand` would be better named `pattern_strand`
                                 dfi.pattern.map(shift) * strand_vec)

    def flip_strand(x):
        return x.map({"+": "-", "-": "+"})

    # flip the strand
    dfi['pattern_strand_aln'] = np.where(dfi.pattern.map(strand_shift),
                                         # if True, then we are on the other strand
                                         flip_strand(dfi.strand),
                                         dfi.strand)
    return dfi
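
# Worked example (hypothetical numbers): for an instance on the "-" strand
# with pattern_center=100, a negative-strand trim shift of -2 and an aligned
# pattern with use_rc=False, offset=3 (so shift = (0*2 - 1)*3 = -3), the
# aligned center is 100 - (-2) + (-3)*(-1) = 105.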
|
|
def extract_ranges(dfi):
    """Extract the unique example ranges from dfi."""
    ranges = dfi[['example_chrom', 'example_start',
                  'example_end', 'example_idx']].drop_duplicates()
    ranges.columns = ['chrom', 'start', 'end', 'example_idx']
    return ranges


def filter_nonoverlapping_intervals(dfi):
    """Keep only instances located in a non-redundant set of examples."""
    ranges = extract_ranges(dfi)
    keep_idx = get_nonredundant_example_idx(ranges, 200)  # width=200
    return dfi[dfi.example_idx.isin(keep_idx)]
|
|
def plot_coocurence_matrix(dfi, total_examples, signif_threshold=1e-5, ax=None):
    """Test for motif co-occurrence in example regions and plot the
    log2 odds-ratio matrix.

    Args:
      dfi: pattern instance DataFrame produced by `load_instances`
      total_examples: total number of examples
      signif_threshold: p-value threshold below which a cell is marked
        as significant
      ax: matplotlib axes to plot on. If None, use the current axes.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    from statsmodels.stats.contingency_tables import Table2x2
    if ax is None:
        ax = plt.gca()

    # per-example instance counts -> presence/absence matrix
    counts = pd.pivot_table(dfi, values='pattern_len', index="example_idx",
                            columns="pattern_name", aggfunc=len, fill_value=0)
    ndxs = list(counts)
    c = counts > 0

    o = np.zeros((len(ndxs), len(ndxs)))   # log2 odds-ratios
    op = np.zeros((len(ndxs), len(ndxs)))  # odds-ratio p-values

    for i, xn in enumerate(ndxs):
        for j, yn in enumerate(ndxs):
            if xn == yn:
                continue
            ct = pd.crosstab(c[xn], c[yn])
            # add the examples without any motif instance (not counted above):
            ct.iloc[0, 0] += total_examples - len(c)
            t22 = Table2x2(ct)
            o[i, j] = np.log2(t22.oddsratio)
            op[i, j] = t22.oddsratio_pvalue()
    signif = op < signif_threshold
    a = np.zeros_like(signif).astype(str)
    a[signif] = "*"
    a[~signif] = ""
    np.fill_diagonal(a, '')

    sns.heatmap(pd.DataFrame(o, columns=ndxs, index=ndxs),
                annot=a, fmt="", vmin=-4, vmax=4,
                cmap='RdBu_r', ax=ax)
    ax.set_title(f"Log2 odds-ratio. (*: p<{signif_threshold})")
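
# Sketch (assumes `dfi` from `load_instances` with a `pattern_name` column
# and the total number of scanned examples, here hypothetically 10000):
#
#   import matplotlib.pyplot as plt
#   fig, ax = plt.subplots(figsize=(6, 5))
#   plot_coocurence_matrix(dfi, total_examples=10000, ax=ax)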
|
|
def construct_motif_pairs(dfi, motif_pair,
                          features=['match_weighted_p',
                                    'contrib_weighted_p',
                                    'contrib_weighted']):
    """Construct the motif-pair table.

    Args:
      dfi: pattern instance DataFrame produced by `load_instances`
      motif_pair: list of two pattern names
      features: per-instance feature columns to keep in the wide table
    """
    dfi_filtered = dfi.set_index('example_idx', drop=False)
    counts = pd.pivot_table(dfi_filtered,
                            values='pattern_center', index="example_idx",
                            columns="pattern_name", aggfunc=len, fill_value=0)

    # keep examples containing exactly one instance of each motif
    # (exactly two instances of the motif for a self-pair)
    if motif_pair[0] != motif_pair[1]:
        relevant_examples_idx = counts.index[np.all(counts[motif_pair] == 1, axis=1)]
    else:
        relevant_examples_idx = counts.index[np.all(counts[motif_pair] == 2, axis=1)]

    dft = dfi_filtered.loc[relevant_examples_idx]
    dft = dft[dft.pattern_name.isin(motif_pair)]

    dft = dft.sort_values(['example_idx', 'pattern_center'])
    dft['pattern_order'] = dft.index.duplicated().astype(int)
    if motif_pair[0] == motif_pair[1]:
        # distinguish the two instances of the same motif
        dft['pattern_name'] = dft['pattern_name'] + dft['pattern_order'].astype(str)
        motif_pair = [motif_pair[0] + '0', motif_pair[1] + '1']

    dftw = dft.set_index(['pattern_name'], append=True)[['pattern_center',
                                                         'strand'] + features].unstack()

    dftw['center_diff'] = dftw['pattern_center'][motif_pair].diff(axis=1).iloc[:, 1]

    # keep only pairs whose centers are more than 10 bp apart;
    # .copy() avoids SettingWithCopyWarning on the assignments below
    dftw_filt = dftw[np.abs(dftw.center_diff) > 10].copy()

    dftw_filt['distance'] = np.abs(dftw_filt['center_diff'])
    dftw_filt['strand_combination'] = dftw_filt['strand'][motif_pair].sum(axis=1)
    return dftw_filt
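
# Sketch (hypothetical motif names): the wide table has one row per example
# containing exactly one instance of each motif in the pair:
#
#   dftw = construct_motif_pairs(dfi, ["Oct4", "Sox2"])
#   dftw[["distance", "strand_combination"]].head()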
|
|
def dfi_row2seqlet(row, short_name=False):
    """Convert a single dfi row to a Seqlet."""
    return Seqlet(row.example_idx,
                  row.pattern_start,
                  row.pattern_end,
                  name=shorten_pattern(row.pattern) if short_name else row.pattern,
                  strand=row.strand)


def dfi2seqlets(dfi, short_name=False):
    """Convert the data-frame produced by `pattern.get_instances()`
    to a list of Seqlets.

    Args:
      dfi: pd.DataFrame returned by `pattern.get_instances()`
      short_name: if True, the short pattern name will be used for the seqlet name

    Returns:
      Seqlet list
    """
    return [dfi_row2seqlet(row, short_name=short_name)
            for i, row in dfi.iterrows()]
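
# Sketch: convert the instances of a single pattern into Seqlet objects,
# e.g. for signal extraction with `extract_signal` (hypothetical pattern name):
#
#   seqlets = dfi2seqlets(dfi[dfi.pattern == "metacluster_0/pattern_3"],
#                         short_name=True)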
|
|
def profile_features(seqlets, ref_seqlets, profile, profile_width=70):
    """Compute profile-match features for `seqlets` relative to the
    average profile of the reference seqlets."""
    from bpnet.simulate import profile_sim_metrics
    # resize to a common width
    seqlets = resize_seqlets(seqlets, profile_width, seqlen=profile.shape[1])
    seqlets_ref = resize_seqlets(ref_seqlets, profile_width, seqlen=profile.shape[1])

    # extract the profile signal
    seqlet_profile = extract_signal(profile, seqlets)
    seqlet_profile_ref = extract_signal(profile, seqlets_ref)

    # compute the average reference profile
    avg_profile = seqlet_profile_ref.mean(axis=0)

    # the 1e-6 pseudo-count avoids division by zero in the similarity metrics
    metrics = pd.DataFrame([profile_sim_metrics(avg_profile + 1e-6, cp + 1e-6)
                            for cp in seqlet_profile])
    metrics_ref = pd.DataFrame([profile_sim_metrics(avg_profile + 1e-6, cp + 1e-6)
                                for cp in seqlet_profile_ref])

    assert len(metrics) == len(seqlets)  # needs to be the same length

    if metrics.simmetric_kl.min() == np.inf or \
            metrics_ref.simmetric_kl.min() == np.inf:
        profile_match_p = None
    else:
        profile_match_p = quantile_norm(metrics.simmetric_kl, metrics_ref.simmetric_kl)
    return pd.DataFrame(OrderedDict([
        ("profile_match", metrics.simmetric_kl),
        ("profile_match_p", profile_match_p),
        ("profile_counts", metrics['counts']),
        ("profile_counts_p", quantile_norm(metrics['counts'], metrics_ref['counts'])),
        ("profile_max", metrics['max']),
        ("profile_max_p", quantile_norm(metrics['max'], metrics_ref['max'])),
        ("profile_counts_max_ref", metrics['counts_max_ref']),
        ("profile_counts_max_ref_p", quantile_norm(metrics['counts_max_ref'],
                                                   metrics_ref['counts_max_ref'])),
    ]))
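
# Note: the `*_p` columns above quantile-normalize each metric against the
# reference (modisco) seqlet distribution via `bpnet.stats.quantile_norm`,
# putting instance scores on the scale of the reference distribution.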
|
|
def dfi_filter_valid(df, profile_width, seqlen):
    """Keep only instances where the full profile window fits within the example."""
    return df[(df.pattern_center.round() - profile_width > 0)
              & (df.pattern_center + profile_width < seqlen)]
|
|
def annotate_profile_single(dfi, pattern_name, mr, profiles, profile_width=70, trim_frac=0.08):
    """Append profile match columns to dfi for a single pattern.

    Like `annotate_profile`, but assumes all rows of `dfi` belong
    to `pattern_name`.
    """
    seqlen = profiles[list(profiles)[0]].shape[1]

    dfi = dfi_filter_valid(dfi.copy(), profile_width, seqlen)
    dfi['id'] = np.arange(len(dfi))
    assert np.all(dfi.pattern == pattern_name)

    dfp_pattern_list = []
    dfi_subset = dfi
    ref_seqlets = mr._get_seqlets(pattern_name, trim_frac=trim_frac)
    dfi_seqlets = dfi2seqlets(dfi_subset)
    for task in profiles:
        dfp = profile_features(dfi_seqlets,
                               ref_seqlets=ref_seqlets,
                               profile=profiles[task],
                               profile_width=profile_width)
        assert len(dfi_subset) == len(dfp)
        dfp.columns = [f'{task}/{c}' for c in dfp.columns]  # prepend task
        dfp_pattern_list.append(dfp)

    dfp_pattern = pd.concat(dfp_pattern_list, axis=1)
    dfp_pattern['id'] = dfi_subset['id'].values
    assert len(dfp_pattern) == len(dfi)
    return pd.merge(dfi, dfp_pattern, on='id')
|
|
def annotate_profile(dfi, mr, profiles, profile_width=70, trim_frac=0.08, pattern_map=None):
    """Append profile match columns to dfi.

    Args:
      dfi[pd.DataFrame]: motif instances
      mr[ModiscoFile]: loaded modisco result
      profiles: dictionary of profiles with shape: (n_examples, seqlen, strand)
      profile_width: width of the profile to extract
      trim_frac: trim fraction to use when computing the values for the
        modisco seqlets
      pattern_map[dict]: mapping from the pattern name in `dfi` to the corresponding
        pattern in `mr`. Used when dfi was, for example, not derived from modisco.
    """
    seqlen = profiles[list(profiles)[0]].shape[1]

    dfi = dfi_filter_valid(dfi.copy(), profile_width, seqlen)
    dfi['id'] = np.arange(len(dfi))
    # TODO - remove invalid variables
    dfp_list = []
    for pattern in tqdm(dfi.pattern.unique()):
        dfp_pattern_list = []
        dfi_subset = dfi[dfi.pattern == pattern]
        # map the dfi pattern name to the corresponding modisco pattern
        if pattern_map is not None:
            modisco_pattern = pattern_map[pattern]
        else:
            modisco_pattern = pattern
        for task in profiles:
            dfp = profile_features(dfi2seqlets(dfi_subset),
                                   ref_seqlets=mr._get_seqlets(modisco_pattern,
                                                               trim_frac=trim_frac),
                                   profile=profiles[task],
                                   profile_width=profile_width)
            assert len(dfi_subset) == len(dfp)
            dfp.columns = [f'{task}/{c}' for c in dfp.columns]  # prepend task
            dfp_pattern_list.append(dfp)

        dfp_pattern = pd.concat(dfp_pattern_list, axis=1)
        dfp_pattern['id'] = dfi_subset['id'].values
        dfp_list.append(dfp_pattern)
    out = pd.concat(dfp_list, axis=0)
    assert len(out) == len(dfi)
    return pd.merge(dfi, out, on='id')
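
# Sketch (hypothetical task name; `mr` is a loaded modisco result and each
# profile array has shape (n_examples, seqlen, 2)):
#
#   dfi_annot = annotate_profile(dfi, mr, profiles={"Nanog": nanog_profile},
#                                profile_width=70)
#   dfi_annot.filter(regex="profile_match").head()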
|
|
def motif_pair_dfi(dfi_filtered, motif_pair):
    """Construct the matrix of motif pairs.

    Args:
      dfi_filtered: dfi filtered to the desired property
      motif_pair: tuple of two pattern_name's

    Returns:
      pd.DataFrame with columns from dfi_filtered with _x and _y suffixes
    """
    dfa = dfi_filtered[dfi_filtered.pattern_name == motif_pair[0]]
    dfb = dfi_filtered[dfi_filtered.pattern_name == motif_pair[1]]

    # all combinations of the two motifs co-occurring in the same example
    dfab = pd.merge(dfa, dfb, on='example_idx', how='outer')
    dfab = dfab[~dfab[['pattern_center_x', 'pattern_center_y']].isnull().any(axis=1)]

    dfab['center_diff'] = dfab.pattern_center_y - dfab.pattern_center_x
    if "pattern_center_aln_x" in dfab:
        dfab['center_diff_aln'] = dfab.pattern_center_aln_y - dfab.pattern_center_aln_x
    dfab['strand_combination'] = dfab.strand_x + dfab.strand_y
    # ensure the right strand combination when the second motif lies upstream
    dfab.loc[dfab.center_diff < 0, 'strand_combination'] = dfab[dfab.center_diff < 0]['strand_combination'].map(comp_strand_combination).values

    if motif_pair[0] == motif_pair[1]:
        dfab.loc[dfab['strand_combination'] == "--", 'strand_combination'] = "++"
        dfab = dfab[dfab.center_diff > 0]
    else:
        dfab.center_diff = np.abs(dfab.center_diff)
        if "center_diff_aln" in dfab:
            dfab.center_diff_aln = np.abs(dfab.center_diff_aln)
    if "center_diff_aln" in dfab:
        dfab = dfab[dfab.center_diff_aln != 0]  # exclude perfect matches
    return dfab
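
# Sketch: distances and strand combinations for one motif pair (hypothetical
# motif names), with edge instances removed via `remove_edge_instances` below:
#
#   dfab = motif_pair_dfi(dfi, ("Oct4", "Sox2"))
#   dfab = remove_edge_instances(dfab, profile_width=70, total_width=1000)
#   dfab.groupby("strand_combination").center_diff.median()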
|
|
def remove_edge_instances(dfab, profile_width=70, total_width=1000):
    """Drop pairs whose profile window would extend past the example boundary."""
    half = profile_width // 2 + profile_width % 2  # ceil(profile_width / 2)
    return dfab[(dfab.pattern_center_x - half > 0) & (dfab.pattern_center_x + half < total_width) &
                (dfab.pattern_center_y - half > 0) & (dfab.pattern_center_y + half < total_width)]