a | b/tests/test_create_sybilnet.py | ||
---|---|---|---|
1 | import argparse |
||
2 | import datetime |
||
3 | import os |
||
4 | |||
5 | from sybil import Serie, Sybil |
||
6 | |||
7 | def test_create_sybilnet(): |
||
8 | from sybil.models.sybil import SybilNet |
||
9 | |||
10 | fake_args = argparse.Namespace( |
||
11 | dropout=0.1, |
||
12 | max_followup=5, |
||
13 | ) |
||
14 | |||
15 | sybil_net = SybilNet(fake_args) |
||
16 | |||
17 | assert sybil_net.hidden_dim == 512 |
||
18 | assert sybil_net.prob_of_failure_layer is not None |