Diff of /tests/test_train.py [000000] .. [d45a3a]

Switch to unified view

a b/tests/test_train.py
1
"""test run gin -> TODO rename to bpnet
2
"""
3
from kipoi.data import Dataset
4
import numpy as np
5
import os
6
# from gin_train.cli.gin_train import gin_train
7
# import gin
8
9
10
# @gin.configurable
11
# class Dummy(Dataset):
12
#     def __init__(self, n,
13
#                  incl_chromosomes=None,
14
#                  excl_chromosomes=None):
15
#         self.n = n
16
17
#     def __len__(self):
18
#         return self.n
19
20
#     def __getitem__(self, idx):
21
#         return {"inputs": np.array([idx, idx + 1]),
22
#                 "targets": idx // 2
23
#                 }
24
25
26
# @gin.configurable
27
# def dummy_model(n_hidden, lr=0.04):
28
#     import keras.layers as kl
29
#     from keras.models import Model
30
#     inp = kl.Input((2,))
31
#     x = kl.Dense(n_hidden)(inp)
32
#     x = kl.Dense(1)(x)
33
#     model = Model([inp], x)
34
#     model.compile('Adam', loss="mse")
35
#     return model
36
37
38
# @gin.configurable
39
# def train_valid_dataset(dataset_cls):
40
#     return dataset_cls(), dataset_cls()
41
42
43
# def test_gin_train(tmpdir):
44
#     run_id = 'test'
45
#     gin_train("tests/data/example.gin", str(tmpdir), run_id=run_id, force_overwrite=True)
46
47
#     output_dir = os.path.join(str(tmpdir), run_id)
48
#     # produced files
49
#     # assert os.path.exists(os.path.join(str(tmpdir), "log/stdout.log"))
50
#     assert os.path.exists(os.path.join(output_dir, "config.gin"))
51
#     assert os.path.exists(os.path.join(output_dir, "model.h5"))
52
#     assert os.path.exists(os.path.join(output_dir, "history.csv"))