--- a +++ b/tests/modisco/test_pattern_instances.py @@ -0,0 +1,70 @@ +"""Test pattern_instances.py + +""" +import pytest +from pytest import fixture +import pandas as pd +import numpy as np +from bpnet.modisco.utils import longer_pattern +import bpnet +from bpnet.modisco.core import Pattern +from concise.preprocessing import encodeDNA + + +def scan_seq(p, seq): + scanned = p.scan_seq(seqs_one_hot, n_jobs=1) + strands = scanned.argmax(-1) + positions = scanned.max(axis=-1).argmax(-1) + strands = strands[np.arange(len(positions)), positions] + + return pd.DataFrame({"pattern": p.name, + "strand": strands, + "center": positions, + "seq_idx": np.arange(len(seqs_one_hot))}) + +# 'TTTACAATTT' # seq1 +# 'TTTACAATT' # seq2 +# ' AACAAA ' # m1 +# ' AAACAA ' # m1 +# ' ACAAT ' # m2 + + +seqs = ['TTTACAATTT', + 'TTTACAATT'] +seqs_one_hot = encodeDNA(seqs) + +motif_seqs_1 = ['TTTGTT', + 'AAACAA', + 'TTGTTT', + 'ACAATT', + 'TATTGT'] + +motif_seqs_2 = ['AACAAA', + 'AAACAA', + 'TTGTTT', + 'ACAATT', + 'TATTGT'] + + +def create_patterns(motif_seqs): + patterns = [Pattern(seq=encodeDNA([s])[0], + contrib=dict(a=encodeDNA([s])[0]), + hyp_contrib=dict(a=encodeDNA([s])[0]), name=str(i)) for i, s in enumerate(motif_seqs)] + + aligned_patterns = [p.align(patterns[0], pad_value=np.array([0.25] * 4)) for p in patterns] + return patterns, aligned_patterns + + +@pytest.mark.parametrize("motif_seqs", [motif_seqs_1, motif_seqs_2]) +def test_pattern_shift(motif_seqs): + patterns, aligned_patterns = create_patterns(motif_seqs) + dfi = pd.concat([scan_seq(p, seqs_one_hot) for p in patterns]) + dfi_aligned = pd.concat([scan_seq(p, seqs_one_hot) for p in aligned_patterns]) + np.all(dfi_aligned.center == 5) + + shift = {p.name: (p.attrs['align']['use_rc'] * 2 - 1) * p.attrs['align']['offset'] for p in aligned_patterns} + # This works + strand_shift = {p.name: p.attrs['align']['use_rc'] for p in aligned_patterns} + + assert np.all(dfi_aligned.center == dfi.center - (dfi.strand * 2 - 1) * dfi.pattern.map(shift)) + assert np.all(dfi_aligned.strand == np.where(dfi.pattern.map(strand_shift), 1 - dfi.strand, dfi.strand))