Diff of /test/test_simdeep.py [000000] .. [53737a]

Switch to unified view

a b/test/test_simdeep.py
1
import unittest
2
import warnings
3
4
import numpy as np
5
6
from simdeep.config import ACTIVATION
7
from simdeep.config import OPTIMIZER
8
from simdeep.config import LOSS
9
10
from os.path import abspath
11
from os.path import split
12
13
from os.path import isfile
14
from os.path import isdir
15
16
from os import remove
17
from shutil import rmtree
18
19
20
class TestPackage(unittest.TestCase):
21
    """ """
22
    def test_1_coxph_function(self):
23
        """test if the coxph function works """
24
        from simdeep.coxph_from_r import coxph
25
26
        isdead = [0, 1, 1, 1, 0, 1, 0, 0, 1, 0]
27
        nbdays = [24, 10, 25, 50, 14, 10 ,100, 10, 50, 10]
28
        values = [0, 1, 1, 0 , 1, 2, 0, 1, 0, 0]
29
30
31
        pvalue = coxph(values, isdead, nbdays, isfactor=True)
32
33
        self.assertTrue(isinstance(pvalue, float))
34
        self.assertTrue(pvalue < 0.05)
35
36
    def test_4_keras_model_instantiation(self):
37
        """
38
        test if keras can be loaded and if that a model
39
        can be instanciated
40
        """
41
        from keras.models import Sequential
42
        from keras.layers import Dense
43
44
        dummy_model = Sequential()
45
        dummy_model.add(Dense(10, input_dim=20,
46
                                   activation=ACTIVATION))
47
48
        dummy_model.compile(
49
            optimizer=OPTIMIZER, loss=LOSS)
50
51
        Xmat = np.random.random((50,20))
52
        Ymat = np.random.random((50,10))
53
54
        dummy_model.fit(
55
            x=Xmat,
56
            y=Ymat)
57
58
    def test_5_one_simdeep_instance(self):
59
        """
60
        test one simdeep instance
61
        """
62
        from simdeep.simdeep_analysis import SimDeep
63
        from simdeep.extract_data import LoadData
64
65
        PATH_DATA = '{0}/../examples/data/'.format(split(abspath(__file__))[0])
66
67
        TRAINING_TSV = {'RNA': 'rna_dummy.tsv', 'METH': 'meth_dummy.tsv', 'MIR': 'mir_dummy.tsv'}
68
        SURVIVAL_TSV = 'survival_dummy.tsv'
69
70
        PROJECT_NAME = 'TestProject'
71
        EPOCHS = 3
72
73
        deep_model_additional_args = {
74
        "epochs":EPOCHS, "seed":4}
75
76
        dataset = LoadData(path_data=PATH_DATA,
77
                       survival_tsv=SURVIVAL_TSV,
78
                       training_tsv=TRAINING_TSV)
79
80
        simdeep = SimDeep(dataset=dataset,
81
                          project_name=PROJECT_NAME,
82
                          path_results="{0}/{1}".format(PATH_DATA, PROJECT_NAME),
83
                          deep_model_additional_args=deep_model_additional_args,
84
        )
85
        simdeep.load_training_dataset()
86
        simdeep.fit()
87
        simdeep.predict_labels_on_full_dataset()
88
        simdeep.predict_labels_on_test_fold()
89
90
        simdeep.load_new_test_dataset(
91
            tsv_dict={'RNA': 'rna_test_dummy.tsv'},
92
            fname_key='dummy',
93
            path_survival_file='survival_test_dummy.tsv')
94
95
        simdeep.predict_labels_on_test_dataset()
96
97
        path_fig = '{0}/{1}/{1}_KM_plot_training_dataset.pdf'.format(PATH_DATA, PROJECT_NAME)
98
99
        print('#### asserting file: {0} exists'.format(path_fig))
100
        self.assertTrue(isfile(path_fig))
101
102
        from glob import glob
103
104
        for fil in glob('{0}/{1}*'.format(PATH_DATA, PROJECT_NAME)):
105
            if isfile(fil):
106
                remove(fil)
107
            elif isdir(fil):
108
                rmtree(fil)
109
110
    def test_6_simdeep_boosting(self):
111
        """
112
        test simdeep boosting
113
        """
114
        from simdeep.simdeep_boosting import SimDeepBoosting
115
116
        PATH_DATA = '{0}/../examples/data/'.format(split(abspath(__file__))[0])
117
118
        TRAINING_TSV = {'RNA': 'rna_dummy.tsv', 'METH': 'meth_dummy.tsv', 'MIR': 'mir_dummy.tsv'}
119
        SURVIVAL_TSV = 'survival_dummy.tsv'
120
121
        PROJECT_NAME = 'TestProject'
122
        EPOCHS = 3
123
        SEED = 3
124
        nb_it = 3
125
        nb_threads = 2
126
127
        boosting = SimDeepBoosting(
128
            nb_threads=nb_threads,
129
            nb_it=nb_it,
130
            survival_tsv=SURVIVAL_TSV,
131
            training_tsv=TRAINING_TSV,
132
            path_data=PATH_DATA,
133
            project_name=PROJECT_NAME,
134
            path_results=PATH_DATA,
135
            epochs=EPOCHS,
136
            normalization={'TRAIN_CORR_REDUCTION':True},
137
            seed=SEED)
138
139
        boosting.partial_fit()
140
        boosting.predict_labels_on_full_dataset()
141
        boosting.compute_clusters_consistency_for_full_labels()
142
        boosting.evalutate_cluster_performance()
143
        boosting.collect_cindex_for_test_fold()
144
        boosting.collect_cindex_for_full_dataset()
145
146
        boosting.load_new_test_dataset(
147
            tsv_dict={'RNA': 'rna_test_dummy.tsv'},
148
            fname_key='dummy',
149
            path_survival_file='survival_test_dummy.tsv',
150
            normalization={'TRAIN_NORM_SCALE':True},
151
        )
152
153
        boosting.predict_labels_on_test_dataset()
154
        boosting.predict_labels_on_test_dataset()
155
        boosting.compute_c_indexes_for_test_dataset()
156
        boosting.compute_clusters_consistency_for_test_labels()
157
158
        from glob import glob
159
160
        for fil in glob('{0}/{1}*'.format(PATH_DATA, PROJECT_NAME)):
161
            if isfile(fil):
162
                remove(fil)
163
            elif isdir(fil):
164
                rmtree(fil)
165
166
167
if __name__ == "__main__":
168
    unittest.main()