[0ce940]: / tests / modisco / test_pattern_instances.py

Download this file

71 lines (54 with data), 2.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""Test pattern_instances.py
"""
import pytest
from pytest import fixture
import pandas as pd
import numpy as np
from bpnet.modisco.utils import longer_pattern
import bpnet
from bpnet.modisco.core import Pattern
from concise.preprocessing import encodeDNA
def scan_seq(p, seq):
    """Scan sequences with a pattern and report the best hit per sequence.

    Args:
      p: pattern object exposing ``scan_seq(seqs, n_jobs=...)`` and ``name``.
      seq: batch of encoded sequences handed to ``p.scan_seq``.
        Presumably one-hot encoded DNA -- TODO confirm against callers.

    Returns:
      pd.DataFrame with one row per sequence and the columns
      ``pattern``, ``strand``, ``center`` and ``seq_idx``.
    """
    # Scan the sequences that were passed in. Previously the module-level
    # `seqs_one_hot` was used here and the `seq` argument was silently ignored.
    # NOTE(review): the indexing below assumes the scan result is laid out as
    # (n_seqs, n_positions, n_strands) -- confirm against Pattern.scan_seq.
    scanned = p.scan_seq(seq, n_jobs=1)

    # Best-scoring strand at every position.
    strands = scanned.argmax(-1)
    # Position with the highest score (taken over the last axis) per sequence.
    positions = scanned.max(axis=-1).argmax(-1)
    # Keep only the strand found at the best position of each sequence.
    strands = strands[np.arange(len(positions)), positions]

    return pd.DataFrame({"pattern": p.name,
                         "strand": strands,
                         "center": positions,
                         "seq_idx": np.arange(len(seq))})
# 'TTTACAATTT' # seq1
# 'TTTACAATT' # seq2
# ' AACAAA ' # m1
# ' AAACAA ' # m1
# ' ACAAT ' # m2
# Sequences that get scanned in the tests below (the second one is the
# first one with its last base removed).
seqs = [
    'TTTACAATTT',
    'TTTACAATT',
]
seqs_one_hot = encodeDNA(seqs)

# Two motif sets that differ only in their first entry: 'AACAAA' is the
# reverse complement of 'TTTGTT'.
motif_seqs_1 = [
    'TTTGTT',
    'AAACAA',
    'TTGTTT',
    'ACAATT',
    'TATTGT',
]
motif_seqs_2 = [
    'AACAAA',
    'AAACAA',
    'TTGTTT',
    'ACAATT',
    'TATTGT',
]
def create_patterns(motif_seqs):
    """Build one Pattern per motif string plus a copy aligned to the first.

    Args:
      motif_seqs: list of motif strings; the position of each string in the
        list becomes the pattern name (as a string).

    Returns:
      Tuple ``(patterns, aligned_patterns)`` where ``aligned_patterns[i]``
      is ``patterns[i].align(patterns[0], ...)``.
    """
    patterns = []
    for idx, motif in enumerate(motif_seqs):
        # The encoded motif is computed separately for every field so that
        # seq, contrib and hyp_contrib never share the same array object.
        patterns.append(Pattern(seq=encodeDNA([motif])[0],
                                contrib={"a": encodeDNA([motif])[0]},
                                hyp_contrib={"a": encodeDNA([motif])[0]},
                                name=str(idx)))

    # Every pattern (including the first one itself) is aligned to the
    # first pattern; padding uses a length-4 vector of 0.25.
    reference = patterns[0]
    aligned_patterns = []
    for pattern in patterns:
        aligned_patterns.append(
            pattern.align(reference, pad_value=np.array([0.25] * 4)))
    return patterns, aligned_patterns
@pytest.mark.parametrize("motif_seqs", [motif_seqs_1, motif_seqs_2])
def test_pattern_shift(motif_seqs):
    """Check that hits of aligned patterns follow from the alignment attrs.

    The hits of the original patterns, corrected by the ``offset`` and
    ``use_rc`` values stored in ``attrs['align']`` of each aligned pattern,
    must equal the hits obtained by scanning with the aligned patterns.
    """
    patterns, aligned_patterns = create_patterns(motif_seqs)

    hits = pd.concat([scan_seq(p, seqs_one_hot) for p in patterns])
    hits_aligned = pd.concat([scan_seq(p, seqs_one_hot)
                              for p in aligned_patterns])

    # NOTE(review): the result of this expression is discarded, so it checks
    # nothing. It looks like a missing `assert` -- confirm that the condition
    # is expected to hold before turning it into one.
    np.all(hits_aligned.center == 5)

    # Per-pattern lookup tables keyed by pattern name.
    shift = {}
    strand_shift = {}
    for p in aligned_patterns:
        alignment = p.attrs['align']
        # Signed offset: the sign is +1 when use_rc is truthy, otherwise -1.
        shift[p.name] = (alignment['use_rc'] * 2 - 1) * alignment['offset']
        strand_shift[p.name] = alignment['use_rc']

    # This works
    expected_center = hits.center - (hits.strand * 2 - 1) * hits.pattern.map(shift)
    assert np.all(hits_aligned.center == expected_center)

    # The strand is flipped exactly for the patterns aligned with use_rc.
    expected_strand = np.where(hits.pattern.map(strand_shift),
                               1 - hits.strand,
                               hits.strand)
    assert np.all(hits_aligned.strand == expected_strand)