Diff of /tests/test_gin_parse.py [000000] .. [d45a3a]

Switch to unified view

a b/tests/test_gin_parse.py
1
"""Test gin config string parsing
2
"""
3
4
5
from bpnet.cli.train import gin2dict
6
GIN_STR = """
7
import bpnet
8
import bpnet.datasets
9
import bpnet.heads
10
import bpnet.layers
11
import bpnet.losses
12
import bpnet.metrics
13
import bpnet.models
14
import bpnet.seqmodel
15
import bpnet.trainers
16
17
# Macros:
18
# ==============================================================================
19
augment_interval = True
20
batchnorm = False
21
dataspec = 'dataspec.task1.yml'
22
exclude_chr = ['chr1', 'chr2']
23
filters = 64
24
lambda = 10
25
lr = 0.004
26
n_bias_tracks = 0
27
n_dil_layers = 1
28
seq_width = 200
29
tasks = ['Task1']
30
tconv_kernel_size = 25
31
test_chr = []
32
use_bias = False
33
valid_chr = ['chr2']
34
35
# Parameters for bpnet_data:
36
# ==============================================================================
37
bpnet_data.augment_interval = %augment_interval
38
bpnet_data.dataspec = %dataspec
39
bpnet_data.exclude_chr = %exclude_chr
40
bpnet_data.include_metadata = False
41
bpnet_data.interval_augmentation_shift = 100
42
bpnet_data.intervals_file = None
43
bpnet_data.peak_width = %seq_width
44
bpnet_data.seq_width = %seq_width
45
bpnet_data.shuffle = True
46
"""
47
48
49
def test_gin2dict():
50
    d = gin2dict(GIN_STR)
51
    assert d['bpnet_data.dataspec'] == 'dataspec.task1.yml'
52
    assert d['bpnet_data.peak_width'] == 200
53
    assert d['bpnet_data.seq_width'] == 200
54
    assert d['bpnet_data.exclude_chr'] == ['chr1', 'chr2']