Switch to unified view

a b/test/dummy_test_one_instance.py
1
"""
2
test one instance of SimDeep
3
"""
4
5
from os.path import abspath
6
from os.path import split
7
8
from os.path import isfile
9
from os.path import isdir
10
11
from os import remove
12
from shutil import rmtree
13
14
15
16
def test_instance():
17
    """
18
    test one instance of SimDeep
19
    """
20
    from simdeep.simdeep_analysis import SimDeep
21
    from simdeep.extract_data import LoadData
22
23
    PATH_DATA = '{0}/../examples/data/'.format(split(abspath(__file__))[0])
24
25
    TRAINING_TSV = {'RNA': 'rna_dummy.tsv', 'METH': 'meth_dummy.tsv', 'MIR': 'mir_dummy.tsv'}
26
    SURVIVAL_TSV = 'survival_dummy.tsv'
27
28
    PROJECT_NAME = 'TestProject'
29
    EPOCHS = 3
30
31
    dataset = LoadData(path_data=PATH_DATA,
32
                   survival_tsv=SURVIVAL_TSV,
33
                   training_tsv=TRAINING_TSV)
34
35
    deep_model_additional_args = {
36
        "epochs":EPOCHS, "seed":4}
37
38
    simdeep = SimDeep(dataset=dataset,
39
                      project_name=PROJECT_NAME,
40
                      path_results=PATH_DATA,
41
                      deep_model_additional_args=deep_model_additional_args,
42
                      )
43
    simdeep.load_training_dataset()
44
    simdeep.fit()
45
    simdeep.predict_labels_on_full_dataset()
46
    simdeep.predict_labels_on_test_fold()
47
48
    simdeep.load_new_test_dataset(
49
        {'RNA': 'rna_test_dummy.tsv'},
50
        'survival_test_dummy.tsv',
51
        'dummy')
52
53
    simdeep.predict_labels_on_test_dataset()
54
55
    from glob import glob
56
57
    for fil in glob('{0}/{1}*'.format(PATH_DATA, PROJECT_NAME)):
58
        if isfile(fil):
59
            remove(fil)
60
        elif isdir(fil):
61
            rmtree(fil)
62
63
64
if __name__ == '__main__':
65
    test_instance()