Diff of /tests/cli/test_modisco.py [000000] .. [d45a3a]

Switch to side-by-side view

--- a
+++ b/tests/cli/test_modisco.py
@@ -0,0 +1,80 @@
+"""Test modisco
+"""
+import os
+import pandas as pd
+import numpy as np
+from bpnet.cli.modisco import bpnet_modisco_run, cwm_scan
+from pytest import fixture
+import gin
+
+
+@fixture
+def expected_modisco_files():
+    """Return the file names the modisco run tests look for in the output dir."""
+    run_files = (
+        'modisco-run.config.gin',
+        'modisco-run.config.gin.json',
+        'modisco-run.input-config.gin',
+        'modisco-run.kwargs.json',
+        'modisco-run.subset-contrib-file.npy',
+    )
+    return ['modisco.h5', *run_files, 'log']
+
+
+def test_modisco_run(tmp_path, contrib_score_grad, modisco_config_gin,
+                     expected_modisco_files):
+    """Run modisco without a null contribution file and check its outputs."""
+    gin.clear_config()
+    bpnet_modisco_run(contrib_file=str(contrib_score_grad), output_dir=tmp_path,
+                      config=str(modisco_config_gin))
+    produced = os.listdir(tmp_path)
+    missing = [name for name in expected_modisco_files if name not in produced]
+    assert not missing
+    # the saved subset array must contain only ones
+    assert np.all(np.load(tmp_path / 'modisco-run.subset-contrib-file.npy') == 1)
+
+
+def test_modisco_run_null(tmp_path, contrib_score_grad, contrib_score_grad_null,
+                          modisco_config_gin, expected_modisco_files):
+    """Run modisco with a null contribution file and check its outputs."""
+    gin.clear_config()
+    bpnet_modisco_run(contrib_file=str(contrib_score_grad),
+                      null_contrib_file=str(contrib_score_grad_null),
+                      output_dir=str(tmp_path), config=str(modisco_config_gin))
+    produced = os.listdir(tmp_path)
+    missing = [name for name in expected_modisco_files if name not in produced]
+    assert not missing
+    # the saved subset array must contain only ones
+    assert np.all(np.load(tmp_path / 'modisco-run.subset-contrib-file.npy') == 1)
+
+
+def test_cwm_scan(tmp_path, modisco_dir, contrib_file):
+    """Scan with trim_frac left at its default, no contrib file, no profile features."""
+    output_file = str(tmp_path / 'instances.csv.gz')
+    cwm_scan(modisco_dir=str(modisco_dir), output_file=output_file,
+             contrib_file=None, add_profile_features=False)
+    df = pd.read_csv(output_file)
+    assert list(df.columns[:7]) == ['example_chrom', 'pattern_start_abs', 'pattern_end_abs',
+                                    'pattern', 'contrib_weighted_p', 'strand', 'match_weighted_p']
+    # presumably 0.08 is cwm_scan's default trim_frac, which the file name encodes
+    cm_path = modisco_dir / 'cwm-scan-seqlets.trim-frac=0.08.csv.gz'
+    assert os.path.exists(cm_path)
+
+
+def test_cwm_scan_new_file(tmp_path, modisco_dir, contrib_file):
+    """Scan with an explicit trim_frac, a contrib file and profile features."""
+    trim_frac = 0.05
+    output_file = str(tmp_path / 'instances.csv.gz')
+    cwm_scan(modisco_dir=str(modisco_dir),
+             output_file=output_file,
+             contrib_file=contrib_file,
+             trim_frac=trim_frac, add_profile_features=True)
+
+    # expect a seqlet file in modisco_dir whose name encodes the trim fraction
+    cm_path = modisco_dir / f'cwm-scan-seqlets.trim-frac={trim_frac:.2f}.csv.gz'
+    assert os.path.exists(cm_path)
+
+    # profile features were requested, so their columns must be in the output
+    df = pd.read_csv(output_file)
+    for column in ('Oct4/profile_counts', 'Oct4/profile_match_p'):
+        assert column in df