|
a |
|
b/tests/test_train.py |
|
|
1 |
"""test run gin -> TODO rename to bpnet |
|
|
2 |
""" |
|
|
3 |
from kipoi.data import Dataset |
|
|
4 |
import numpy as np |
|
|
5 |
import os |
|
|
6 |
# from gin_train.cli.gin_train import gin_train |
|
|
7 |
# import gin |
|
|
8 |
|
|
|
9 |
|
|
|
10 |
# @gin.configurable |
|
|
11 |
# class Dummy(Dataset): |
|
|
12 |
# def __init__(self, n, |
|
|
13 |
# incl_chromosomes=None, |
|
|
14 |
# excl_chromosomes=None): |
|
|
15 |
# self.n = n |
|
|
16 |
|
|
|
17 |
# def __len__(self): |
|
|
18 |
# return self.n |
|
|
19 |
|
|
|
20 |
# def __getitem__(self, idx): |
|
|
21 |
# return {"inputs": np.array([idx, idx + 1]), |
|
|
22 |
# "targets": idx // 2 |
|
|
23 |
# } |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
# @gin.configurable |
|
|
27 |
# def dummy_model(n_hidden, lr=0.04): |
|
|
28 |
# import keras.layers as kl |
|
|
29 |
# from keras.models import Model |
|
|
30 |
# inp = kl.Input((2,)) |
|
|
31 |
# x = kl.Dense(n_hidden)(inp) |
|
|
32 |
# x = kl.Dense(1)(x) |
|
|
33 |
# model = Model([inp], x) |
|
|
34 |
# model.compile('Adam', loss="mse") |
|
|
35 |
# return model |
|
|
36 |
|
|
|
37 |
|
|
|
38 |
# @gin.configurable |
|
|
39 |
# def train_valid_dataset(dataset_cls): |
|
|
40 |
# return dataset_cls(), dataset_cls() |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
# def test_gin_train(tmpdir): |
|
|
44 |
# run_id = 'test' |
|
|
45 |
# gin_train("tests/data/example.gin", str(tmpdir), run_id=run_id, force_overwrite=True) |
|
|
46 |
|
|
|
47 |
# output_dir = os.path.join(str(tmpdir), run_id) |
|
|
48 |
# # produced files |
|
|
49 |
# # assert os.path.exists(os.path.join(str(tmpdir), "log/stdout.log")) |
|
|
50 |
# assert os.path.exists(os.path.join(output_dir, "config.gin")) |
|
|
51 |
# assert os.path.exists(os.path.join(output_dir, "model.h5")) |
|
|
52 |
# assert os.path.exists(os.path.join(output_dir, "history.csv")) |