bpnet/modisco/pattern_instances.py

"""Code for working with the pattern instances table
produced by `bpnet.cli.modisco.modisco_score2`,
which calls `pattern.get_instances`.
"""
from collections import OrderedDict

import numpy as np
import pandas as pd
from tqdm import tqdm

from bpnet.stats import quantile_norm
from bpnet.cli.modisco import get_nonredundant_example_idx
from bpnet.modisco.core import resize_seqlets, Seqlet
from bpnet.modisco.utils import longer_pattern, shorten_pattern, trim_pssm_idx
from bpnet.plot.profiles import extract_signal


def get_motif_pairs(motifs):
    """Generate all unordered motif pairs, including self-pairs.
    """
    motif_list = list(motifs)
    pairs = []
    for i in range(len(motif_list)):
        for j in range(i, len(motif_list)):
            pairs.append([motif_list[i], motif_list[j]])
    return pairs
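
# Illustrative usage of `get_motif_pairs` (the motif names here are hypothetical):
# >>> get_motif_pairs({'Oct4': 'm0_p1', 'Sox2': 'm0_p2'})
# [['Oct4', 'Oct4'], ['Oct4', 'Sox2'], ['Sox2', 'Sox2']]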


comp_strand_combination = {
    "++": "--",
    "--": "++",
    "-+": "+-",
    "+-": "-+"
}

strand_combinations = ["++", "--", "+-", "-+"]
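
# `comp_strand_combination` maps a strand combination to its equivalent when the
# order of the two motifs in a pair is reversed (used in `motif_pair_dfi` below):
# >>> comp_strand_combination["+-"]
# '-+'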


# TODO - allow these to also be of other types?
def load_instances(parq_file, motifs=None, dedup=True, verbose=True):
    """Load pattern instances from the parquet file

    Args:
      parq_file: parquet file of motif instances
      motifs: dictionary of motifs of interest.
        key=custom motif name, value=short pattern name (e.g. {'Nanog': 'm0_p3'})
      dedup: if True, de-duplicate instances that map to the same genomic
        location, pattern and strand
      verbose: if True, print how many instances were de-duplicated
    """
    if motifs is not None:
        incl_motifs = {longer_pattern(m) for m in motifs.values()}
    else:
        incl_motifs = None

    if isinstance(parq_file, pd.DataFrame):
        dfi = parq_file
    else:
        if motifs is not None:
            from fastparquet import ParquetFile

            # selectively load only the relevant patterns
            pf = ParquetFile(str(parq_file))
            patterns = [shorten_pattern(pn) for pn in incl_motifs]
            dfi = pf.to_pandas(filters=[("pattern_short", "in", patterns)])
        else:
            dfi = pd.read_parquet(str(parq_file), engine='fastparquet')
            if 'pattern' not in dfi:
                # assumes a hive-partitioned file
                dfi['pattern'] = dfi['dir0'].str.replace("pattern=", "").astype(str) + "/" + dfi['dir1'].astype(str)

    # filter to the patterns of interest
    if motifs is not None:
        dfi = dfi[dfi.pattern.isin(incl_motifs)]  # NOTE: a no-op when the selective load above was used
        if 'pattern_short' not in dfi:
            dfi['pattern_short'] = dfi['pattern'].map({k: shorten_pattern(k) for k in incl_motifs})
        dfi['pattern_name'] = dfi['pattern_short'].map({v: k for k, v in motifs.items()})
    else:
        dfi['pattern_short'] = dfi['pattern'].map({k: shorten_pattern(k)
                                                   for k in dfi.pattern.unique()})

    # add the absolute coordinate columns if they don't exist yet
    if 'pattern_start_abs' not in dfi:
        dfi['pattern_start_abs'] = dfi['example_start'] + dfi['pattern_start']
    if 'pattern_end_abs' not in dfi:
        dfi['pattern_end_abs'] = dfi['example_start'] + dfi['pattern_end']

    if dedup:
        # de-duplicate instances mapping to the same genomic location
        dfi_dedup = dfi.drop_duplicates(['pattern',
                                         'example_chrom',
                                         'pattern_start_abs',
                                         'pattern_end_abs',
                                         'strand'])

        # number of removed duplicates
        d = len(dfi) - len(dfi_dedup)
        if verbose:
            print(f"number of de-duplicated instances: {d} ({d / len(dfi) * 100:.2f}%)")

        # use the de-duplicated instances from now on
        dfi = dfi_dedup
    return dfi
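
# Illustrative usage of `load_instances` (the path and motif mapping are hypothetical):
# >>> dfi = load_instances('output/instances.parq', motifs={'Nanog': 'm0_p3'})
# >>> dfi[['pattern_name', 'example_chrom', 'pattern_start_abs',
# ...      'pattern_end_abs', 'strand']].head()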


def multiple_load_instances(paths, motifs):
    """Load pattern instances for multiple TFs and concatenate them.

    Args:
      paths: dictionary <tf> -> instances.parq
      motifs: dictionary with <motif_name> -> pattern name of
        the form `<TF>/m0_p1`
    """
    from bpnet.utils import pd_col_prepend

    # load the instances per TF and prefix the pattern names with the TF name
    dfi = pd.concat([load_instances(path,
                                    motifs=OrderedDict([(motif, pn.split("/", 1)[1])
                                                        for motif, pn in motifs.items()
                                                        if pn.split("/", 1)[0] == tf]),
                                    dedup=False)
                     .assign(tf=tf)
                     .pipe(pd_col_prepend, ['pattern', 'pattern_short'], prefix=tf + "/")
                     for tf, path in tqdm(paths.items())])
    return dfi
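
# Illustrative usage of `multiple_load_instances` (paths and names are hypothetical):
# >>> dfi = multiple_load_instances(
# ...     paths={'Oct4': 'Oct4/instances.parq', 'Sox2': 'Sox2/instances.parq'},
# ...     motifs={'Oct4-motif': 'Oct4/m0_p0', 'Sox2-motif': 'Sox2/m0_p1'})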


def dfi_add_ranges(dfi, ranges, dedup=False):
    """Merge the example ranges into dfi and add absolute pattern coordinates.

    Args:
      dfi: pattern instance DataFrame
      ranges: DataFrame keyed by `example_idx`, carrying the example
        coordinates used below (`example_chrom`, `example_start`, ...)
      dedup: if True, de-duplicate instances as in `load_instances`
    """
    # add absolute locations
    dfi = dfi.merge(ranges, on="example_idx", how='left')
    dfi['pattern_start_abs'] = dfi['example_start'] + dfi['pattern_start']
    dfi['pattern_end_abs'] = dfi['example_start'] + dfi['pattern_end']

    if dedup:
        # de-duplicate instances mapping to the same genomic location
        dfi_dedup = dfi.drop_duplicates(['pattern',
                                         'example_chrom',
                                         'pattern_start_abs',
                                         'pattern_end_abs',
                                         'strand'])

        # number of removed duplicates
        d = len(dfi) - len(dfi_dedup)
        print(f"number of de-duplicated instances: {d} ({d / len(dfi) * 100:.2f}%)")

        # use the de-duplicated instances from now on
        dfi = dfi_dedup
    return dfi
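
# A minimal sketch of the expected `ranges` input (the values are hypothetical):
# >>> ranges = pd.DataFrame({'example_idx': [0, 1],
# ...                        'example_chrom': ['chr1', 'chr1'],
# ...                        'example_start': [1000, 5000],
# ...                        'example_end': [2000, 6000]})
# >>> dfi = dfi_add_ranges(dfi, ranges)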


def dfi2pyranges(dfi):
    """Convert dfi to a PyRanges object

    Args:
      dfi: pd.DataFrame returned by `load_instances`
    """
    import pyranges as pr
    dfi = dfi.copy()
    dfi['Chromosome'] = dfi['example_chrom']
    dfi['Start'] = dfi['pattern_start_abs']
    dfi['End'] = dfi['pattern_end_abs']
    dfi['Name'] = dfi['pattern']
    dfi['Score'] = dfi['contrib_weighted_p']
    dfi['Strand'] = dfi['strand']
    return pr.PyRanges(dfi)
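
# Illustrative usage (a sketch; assumes `dfi` comes from `load_instances`):
# >>> gr = dfi2pyranges(dfi)
# >>> gr.df.head()  # standard pyranges columns: Chromosome, Start, End, ...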


def align_instance_center(dfi, original_patterns, aligned_patterns, trim_frac=0.08):
    """Align the center of the seqlets using aligned patterns

    Args:
      dfi: pd.DataFrame returned by `load_instances`
      original_patterns: un-trimmed patterns that were trimmed using
        `trim_frac` before scanning
      aligned_patterns: patterns that are all lined up and that contain
        'align': {"use_rc", "offset"} information in their attrs
      trim_frac: original trim_frac used to trim the motifs

    Returns:
      dfi with 2 new columns: `pattern_center_aln` and `pattern_strand_aln`
    """
    # NOTE - it would be nice to be able to pass trimmed patterns instead of
    # `original_patterns` + `trim_frac` and just extract the trim stats from the pattern
    # TODO - shall we split this function into two -> one for dealing with
    #        pattern trimming and one for dealing with aligning patterns?
    trim_shift_pos = {p.name: p._trim_center_shift(trim_frac=trim_frac)[0]
                      for p in original_patterns}
    trim_shift_neg = {p.name: p._trim_center_shift(trim_frac=trim_frac)[1]
                      for p in original_patterns}
    # use_rc determines the sign of the alignment offset: +offset when the
    # pattern was reverse-complemented during alignment, -offset otherwise
    shift = {p.name: (p.attrs['align']['use_rc'] * 2 - 1) * p.attrs['align']['offset']
             for p in aligned_patterns}
    strand_shift = {p.name: p.attrs['align']['use_rc'] for p in aligned_patterns}

    strand_vec = dfi.strand.map({"+": 1, "-": -1})
    dfi['pattern_center_aln'] = (dfi.pattern_center -
                                 # subtract trim_shift since we go from trimmed to non-trimmed coordinates
                                 np.where(dfi.strand == '-',
                                          dfi.pattern.map(trim_shift_neg),
                                          dfi.pattern.map(trim_shift_pos)) +
                                 # NOTE: `strand` would better be named `pattern_strand`
                                 dfi.pattern.map(shift) * strand_vec)

    def flip_strand(x):
        return x.map({"+": "-", "-": "+"})

    # flip the strand where the aligned pattern was reverse-complemented
    dfi['pattern_strand_aln'] = np.where(dfi.pattern.map(strand_shift),
                                         # if True, the instance lies on the other strand
                                         flip_strand(dfi.strand),
                                         dfi.strand)
    return dfi
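
# Worked example of the shift sign convention above (the numbers are hypothetical):
# a pattern aligned with use_rc=False and offset=3 gets shift = (0 * 2 - 1) * 3 = -3,
# while use_rc=True and offset=3 gives shift = (1 * 2 - 1) * 3 = +3; the shift is then
# applied relative to the instance strand via strand_vec (+1 for '+', -1 for '-').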


def extract_ranges(dfi):
    """Extract the example ranges from dfi
    """
    ranges = dfi[['example_chrom', 'example_start',
                  'example_end', 'example_idx']].drop_duplicates()
    ranges.columns = ['chrom', 'start', 'end', 'example_idx']
    return ranges


def filter_nonoverlapping_intervals(dfi):
    """Keep only instances from non-redundant example regions."""
    ranges = extract_ranges(dfi)
    keep_idx = get_nonredundant_example_idx(ranges, 200)
    return dfi[dfi.example_idx.isin(keep_idx)]


def plot_coocurence_matrix(dfi, total_examples, signif_threshold=1e-5, ax=None):
    """Test and plot motif co-occurrence across example regions

    Args:
      dfi: pattern instance DataFrame returned by `load_instances`
      total_examples: total number of examples
      signif_threshold: p-value threshold below which a cell is marked with '*'
      ax: matplotlib axis to plot on
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    from statsmodels.stats.contingency_tables import Table2x2
    if ax is None:
        ax = plt.gca()

    counts = pd.pivot_table(dfi, values='pattern_len', index="example_idx",
                            columns="pattern_name", aggfunc=len, fill_value=0)
    ndxs = list(counts)
    c = counts > 0

    o = np.zeros((len(ndxs), len(ndxs)))   # log2 odds-ratios
    op = np.zeros((len(ndxs), len(ndxs)))  # odds-ratio p-values

    for i, xn in enumerate(ndxs):
        for j, yn in enumerate(ndxs):
            if xn == yn:
                continue
            ct = pd.crosstab(c[xn], c[yn])
            # account for the examples containing no motif instance at all:
            ct.iloc[0, 0] += total_examples - len(c)
            t22 = Table2x2(ct)
            o[i, j] = np.log2(t22.oddsratio)
            op[i, j] = t22.oddsratio_pvalue()
    signif = op < signif_threshold
    a = np.zeros_like(signif).astype(str)
    a[signif] = "*"
    a[~signif] = ""
    np.fill_diagonal(a, '')

    sns.heatmap(pd.DataFrame(o, columns=ndxs, index=ndxs),
                annot=a, fmt="", vmin=-4, vmax=4,
                cmap='RdBu_r', ax=ax)
    ax.set_title(f"Log2 odds-ratio. (*: p<{signif_threshold})")
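
# Illustrative usage (a sketch; `dfi` as returned by `load_instances` with a
# `pattern_name` column, `total_examples` being the number of scanned regions):
# >>> import matplotlib.pyplot as plt
# >>> fig, ax = plt.subplots(figsize=(5, 5))
# >>> plot_coocurence_matrix(dfi, total_examples=10000, ax=ax)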


def construct_motif_pairs(dfi, motif_pair,
                          features=['match_weighted_p',
                                    'contrib_weighted_p',
                                    'contrib_weighted']):
    """Construct the motif-pair table
    """
    dfi_filtered = dfi.set_index('example_idx', drop=False)
    counts = pd.pivot_table(dfi_filtered,
                            values='pattern_center', index="example_idx",
                            columns="pattern_name",
                            aggfunc=len, fill_value=0)

    # keep examples containing exactly one instance of each motif
    # (exactly two instances for a self-pair)
    if motif_pair[0] != motif_pair[1]:
        relevant_examples_idx = counts.index[np.all(counts[motif_pair] == 1, axis=1)]
    else:
        relevant_examples_idx = counts.index[np.all(counts[motif_pair] == 2, axis=1)]

    dft = dfi_filtered.loc[relevant_examples_idx]
    dft = dft[dft.pattern_name.isin(motif_pair)]

    dft = dft.sort_values(['example_idx', 'pattern_center'])
    dft['pattern_order'] = dft.index.duplicated().astype(int)
    if motif_pair[0] == motif_pair[1]:
        # disambiguate the two instances of the same motif
        dft['pattern_name'] = dft['pattern_name'] + dft['pattern_order'].astype(str)
        motif_pair = [motif_pair[0] + '0', motif_pair[1] + '1']

    dftw = dft.set_index(['pattern_name'], append=True)[['pattern_center',
                                                         'strand'] + features].unstack()

    dftw['center_diff'] = dftw['pattern_center'][motif_pair].diff(axis=1).iloc[:, 1]

    # drop pairs whose centers lie within 10 bp of each other
    dftw_filt = dftw[np.abs(dftw.center_diff) > 10].copy()

    dftw_filt['distance'] = np.abs(dftw_filt['center_diff'])
    dftw_filt['strand_combination'] = dftw_filt['strand'][motif_pair].sum(axis=1)
    return dftw_filt
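
# Illustrative usage (the motif names are hypothetical):
# >>> dftw = construct_motif_pairs(dfi, ['Oct4', 'Sox2'])
# >>> dftw[['distance', 'strand_combination']].head()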


def dfi_row2seqlet(row, short_name=False):
    """Convert a single dfi row to a `Seqlet`"""
    return Seqlet(row.example_idx,
                  row.pattern_start,
                  row.pattern_end,
                  name=shorten_pattern(row.pattern) if short_name else row.pattern,
                  strand=row.strand)


def dfi2seqlets(dfi, short_name=False):
    """Convert the DataFrame produced by `pattern.get_instances()`
    to a list of Seqlets

    Args:
      dfi: pd.DataFrame returned by `pattern.get_instances()`
      short_name: if True, use the short pattern name for the seqlet name

    Returns:
      list of Seqlet objects
    """
    return [dfi_row2seqlet(row, short_name=short_name)
            for i, row in dfi.iterrows()]


def profile_features(seqlets, ref_seqlets, profile, profile_width=70):
    """Compute profile-similarity features for seqlets relative to reference seqlets
    """
    from bpnet.simulate import profile_sim_metrics
    # resize to a common width
    seqlets = resize_seqlets(seqlets, profile_width, seqlen=profile.shape[1])
    seqlets_ref = resize_seqlets(ref_seqlets, profile_width, seqlen=profile.shape[1])

    # extract the profile under each seqlet
    seqlet_profile = extract_signal(profile, seqlets)
    seqlet_profile_ref = extract_signal(profile, seqlets_ref)

    # compute the average reference profile
    avg_profile = seqlet_profile_ref.mean(axis=0)

    # similarity metrics against the average reference profile
    # (the small pseudo-count is added for numerical stability)
    metrics = pd.DataFrame([profile_sim_metrics(avg_profile + 1e-6, cp + 1e-6)
                            for cp in seqlet_profile])
    metrics_ref = pd.DataFrame([profile_sim_metrics(avg_profile + 1e-6, cp + 1e-6)
                                for cp in seqlet_profile_ref])

    assert len(metrics) == len(seqlets)  # needs to be the same length

    # quantile-normalize each metric against the reference seqlet distribution
    if metrics.simmetric_kl.min() == np.inf or \
            metrics_ref.simmetric_kl.min() == np.inf:
        profile_match_p = None
    else:
        profile_match_p = quantile_norm(metrics.simmetric_kl, metrics_ref.simmetric_kl)
    return pd.DataFrame(OrderedDict([
        ("profile_match", metrics.simmetric_kl),
        ("profile_match_p", profile_match_p),
        ("profile_counts", metrics['counts']),
        ("profile_counts_p", quantile_norm(metrics['counts'], metrics_ref['counts'])),
        ("profile_max", metrics['max']),
        ("profile_max_p", quantile_norm(metrics['max'], metrics_ref['max'])),
        ("profile_counts_max_ref", metrics['counts_max_ref']),
        ("profile_counts_max_ref_p", quantile_norm(metrics['counts_max_ref'],
                                                   metrics_ref['counts_max_ref'])),
    ]))


def dfi_filter_valid(df, profile_width, seqlen):
    """Keep instances whose profile window fits within the example sequence"""
    return df[(df.pattern_center.round() - profile_width > 0)
              & (df.pattern_center + profile_width < seqlen)]
def annotate_profile_single(dfi, pattern_name, mr, profiles, profile_width=70, trim_frac=0.08):
    """Append profile match columns to dfi for a single pattern
    (see `annotate_profile` for the argument descriptions)
    """
    seqlen = profiles[list(profiles)[0]].shape[1]

    dfi = dfi_filter_valid(dfi.copy(), profile_width, seqlen)
    dfi['id'] = np.arange(len(dfi))
    assert np.all(dfi.pattern == pattern_name)

    dfp_pattern_list = []
    dfi_subset = dfi
    ref_seqlets = mr._get_seqlets(pattern_name, trim_frac=trim_frac)
    dfi_seqlets = dfi2seqlets(dfi_subset)
    for task in profiles:
        dfp = profile_features(dfi_seqlets,
                               ref_seqlets=ref_seqlets,
                               profile=profiles[task],
                               profile_width=profile_width)
        assert len(dfi_subset) == len(dfp)
        dfp.columns = [f'{task}/{c}' for c in dfp.columns]  # prepend the task name
        dfp_pattern_list.append(dfp)

    dfp_pattern = pd.concat(dfp_pattern_list, axis=1)
    dfp_pattern['id'] = dfi_subset['id'].values
    assert len(dfp_pattern) == len(dfi)
    return pd.merge(dfi, dfp_pattern, on='id')


def annotate_profile(dfi, mr, profiles, profile_width=70, trim_frac=0.08, pattern_map=None):
    """Append profile match columns to dfi

    Args:
      dfi[pd.DataFrame]: motif instances
      mr[ModiscoFile]: modisco result file
      profiles: dictionary of profiles with shape: (n_examples, seqlen, strand)
      profile_width: width of the profile to extract
      trim_frac: trim fraction to use when computing the values for the modisco
        seqlets
      pattern_map[dict]: mapping from the pattern name in `dfi` to the corresponding
        pattern in `mr`. Useful when dfi was, for example, not derived from modisco.
    """
    seqlen = profiles[list(profiles)[0]].shape[1]

    dfi = dfi_filter_valid(dfi.copy(), profile_width, seqlen)
    dfi['id'] = np.arange(len(dfi))
    # TODO - remove invalid variables
    dfp_list = []
    for pattern in tqdm(dfi.pattern.unique()):
        dfp_pattern_list = []
        dfi_subset = dfi[dfi.pattern == pattern]
        for task in profiles:
            if pattern_map is not None:
                modisco_pattern = pattern_map[pattern]
            else:
                modisco_pattern = pattern
            dfp = profile_features(dfi2seqlets(dfi_subset),
                                   ref_seqlets=mr._get_seqlets(modisco_pattern,
                                                               trim_frac=trim_frac),
                                   profile=profiles[task],
                                   profile_width=profile_width)
            assert len(dfi_subset) == len(dfp)
            dfp.columns = [f'{task}/{c}' for c in dfp.columns]  # prepend the task name
            dfp_pattern_list.append(dfp)

        dfp_pattern = pd.concat(dfp_pattern_list, axis=1)
        dfp_pattern['id'] = dfi_subset['id'].values
        dfp_list.append(dfp_pattern)
    out = pd.concat(dfp_list, axis=0)
    assert len(out) == len(dfi)
    return pd.merge(dfi, out, on='id')
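
# Illustrative usage (a sketch; the variable names are hypothetical):
# >>> profiles = {task: preds[task] for task in tasks}  # (n_examples, seqlen, strand)
# >>> dfi = annotate_profile(dfi, mr, profiles, profile_width=70)
# >>> dfi.filter(regex='profile_match').head()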


def motif_pair_dfi(dfi_filtered, motif_pair):
    """Construct the matrix of motif pairs

    Args:
      dfi_filtered: dfi filtered to the desired property
      motif_pair: tuple of two pattern_name's

    Returns:
      pd.DataFrame with columns from dfi_filtered with _x and _y suffixes
    """
    dfa = dfi_filtered[dfi_filtered.pattern_name == motif_pair[0]]
    dfb = dfi_filtered[dfi_filtered.pattern_name == motif_pair[1]]

    dfab = pd.merge(dfa, dfb, on='example_idx', how='outer')
    dfab = dfab[~dfab[['pattern_center_x', 'pattern_center_y']].isnull().any(axis=1)]

    dfab['center_diff'] = dfab.pattern_center_y - dfab.pattern_center_x
    if "pattern_center_aln_x" in dfab:
        dfab['center_diff_aln'] = dfab.pattern_center_aln_y - dfab.pattern_center_aln_x
    dfab['strand_combination'] = dfab.strand_x + dfab.strand_y
    # ensure the right strand combination when the pair appears in the opposite order
    dfab.loc[dfab.center_diff < 0, 'strand_combination'] = dfab[dfab.center_diff < 0]['strand_combination'].map(comp_strand_combination).values

    if motif_pair[0] == motif_pair[1]:
        dfab.loc[dfab['strand_combination'] == "--", 'strand_combination'] = "++"
        dfab = dfab[dfab.center_diff > 0]
    else:
        dfab.center_diff = np.abs(dfab.center_diff)
        if "center_diff_aln" in dfab:
            dfab.center_diff_aln = np.abs(dfab.center_diff_aln)
    if "center_diff_aln" in dfab:
        dfab = dfab[dfab.center_diff_aln != 0]  # exclude perfect matches
    return dfab
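
# Illustrative usage (the motif names are hypothetical):
# >>> dfab = motif_pair_dfi(dfi, ('Oct4', 'Sox2'))
# >>> dfab = remove_edge_instances(dfab, profile_width=70, total_width=1000)
# >>> dfab.groupby('strand_combination').center_diff.describe()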


def remove_edge_instances(dfab, profile_width=70, total_width=1000):
    """Drop pairs whose profile windows would extend past the example boundaries"""
    half = profile_width // 2 + profile_width % 2  # ceil(profile_width / 2)
    return dfab[(dfab.pattern_center_x - half > 0) & (dfab.pattern_center_x + half < total_width) &
                (dfab.pattern_center_y - half > 0) & (dfab.pattern_center_y + half < total_width)]