Switch to unified view

a b/simdeep/simdeep_analysis.py
1
"""
2
DeepProg class for one instance model
3
"""
4
5
from sklearn.cluster import KMeans
6
from sklearn.mixture import GaussianMixture
7
from sklearn.model_selection import cross_val_score
8
9
from simdeep.deepmodel_base import DeepBase
10
11
from simdeep.survival_model_utils import ClusterWithSurvival
12
13
from simdeep.config import NB_CLUSTERS
14
from simdeep.config import CLUSTER_ARRAY
15
from simdeep.config import PVALUE_THRESHOLD
16
from simdeep.config import CINDEX_THRESHOLD
17
from simdeep.config import CLASSIFIER_TYPE
18
from simdeep.config import USE_AUTOENCODERS
19
from simdeep.config import FEATURE_SURV_ANALYSIS
20
from simdeep.config import SEED
21
22
from simdeep.config import MIXTURE_PARAMS
23
from simdeep.config import PATH_RESULTS
24
from simdeep.config import PROJECT_NAME
25
from simdeep.config import CLASSIFICATION_METHOD
26
27
from simdeep.config import CLUSTER_EVAL_METHOD
28
from simdeep.config import CLUSTER_METHOD
29
from simdeep.config import NB_THREADS_COXPH
30
from simdeep.config import NB_SELECTED_FEATURES
31
from simdeep.config import LOAD_EXISTING_MODELS
32
from simdeep.config import NODES_SELECTION
33
from simdeep.config import CLASSIFIER
34
from simdeep.config import HYPER_PARAMETERS
35
from simdeep.config import PATH_TO_SAVE_MODEL
36
from simdeep.config import CLUSTERING_OMICS
37
from simdeep.config import USE_R_PACKAGES_FOR_SURVIVAL
38
39
from simdeep.survival_utils import _process_parallel_coxph
40
from simdeep.survival_utils import _process_parallel_cindex
41
from simdeep.survival_utils import _process_parallel_feature_importance
42
from simdeep.survival_utils import _process_parallel_feature_importance_per_cluster
43
from simdeep.survival_utils import select_best_classif_params
44
45
from simdeep.simdeep_utils import metadata_usage_type
46
from simdeep.simdeep_utils import feature_selection_usage_type
47
48
from simdeep.simdeep_utils import load_labels_file
49
50
from simdeep.coxph_from_r import coxph
51
from simdeep.coxph_from_r import c_index
52
from simdeep.coxph_from_r import c_index_multiple
53
54
from simdeep.coxph_from_r import surv_median
55
56
from collections import Counter
57
58
from sklearn.metrics import silhouette_score
59
60
try:
61
    from sklearn.metrics import calinski_harabasz_score \
62
        as calinski_harabaz_score
63
except Exception:
64
    from sklearn.metrics import calinski_harabaz_score
65
66
from sklearn.model_selection import GridSearchCV
67
68
import numpy as np
69
from numpy import hstack
70
71
from collections import defaultdict
72
73
import warnings
74
75
from multiprocessing import Pool
76
77
from os.path import isdir
78
from os import mkdir
79
80
81
################ VARIABLE ############################################
82
# Values accepted for the `classification_method` setting; checked by
# SimDeep._return_train_matrix_for_classification.
_CLASSIFICATION_METHOD_LIST = ['ALL_FEATURES', 'SURVIVAL_FEATURES']
83
# NOTE(review): not referenced in the visible part of this file; presumably a
# p-value style cut-off applied to models -- confirm against its users.
MODEL_THRES = 0.05
84
######################################################################
85
86
87
class SimDeep(DeepBase):
88
    """
89
    Instantiate a new DeepProg instance.
90
    The default parameters are defined in the config.py file
91
92
    Parameters:
93
             :dataset: ExtractData instance. Default None (create a new dataset using the config variable)
94
             :nb_clusters: Number of clusters to search (default NB_CLUSTERS)
95
             :pvalue_thres: Pvalue threshold to include a feature  (default PVALUE_THRESHOLD)
96
             :clustering_omics: Which omics to use for clustering. If empty, then all the available omics will be used
97
             :cindex_thres: C-index threshold to include a feature. This parameter is used only if `node_selection` is set to "C-index" (default CINDEX_THRESHOLD)
98
             :cluster_method: Cluster method to use. possible choice ['mixture', 'kmeans']. (default CLUSTER_METHOD)
99
             :cluster_eval_method: Cluster evaluation method to use in case the `cluster_array` parameter is a list of possible K. Possible choice ['bic', 'silhouette', 'calinski'] (default CLUSTER_EVAL_METHOD)
100
             :classifier_type: Type of classifier to use. Possible choice ['svm', 'clustering']. If 'clustering' is selected, the predict method of the clustering algorithm is used (default CLASSIFIER_TYPE)
101
             :project_name: Name of the project. This name will be used to save the output files and create the output folder (default PROJECT_NAME)
102
             :path_results: Result folder path used to save the output files (default PATH_RESULTS)
103
             :cluster_array: Array of possible number of clusters to try. If set, `nb_clusters` is ignored (default CLUSTER_ARRAY)
104
             :nb_selected_features: Number of selected features to construct classifiers (default NB_SELECTED_FEATURES)
105
             :mixture_params: Dictionary of parameters used to instantiate the Gaussian mixture algorithm (default MIXTURE_PARAMS)
106
             :node_selection: Method to select new features. Possible choice ['Cox-PH', 'C-index']. (default NODES_SELECTION)
107
             :nb_threads_coxph: Number of python processes to use to compute individual survival models in parallel (default NB_THREADS_COXPH)
108
             :classification_method: Possible choice  ['ALL_FEATURES', 'SURVIVAL_FEATURES']. If 'SURVIVAL_FEATURES' is selected, the classifiers are built using survival features  (default CLASSIFICATION_METHOD)
109
             :load_existing_models: (default LOAD_EXISTING_MODELS)
110
             :path_to_save_model: (default PATH_TO_SAVE_MODEL)
111
             :metadata_usage: Meta data usage with survival models (if metadata_tsv provided as argument to the dataset). Possible choice are [None, False, 'labels', 'new-features', 'all', True] (True is the same as all)
112
             :feature_selection_usage: selection method for survival features ('individual' or 'lasso')
113
             :alternative_embedding: alternative external embedding to use instead of building autoencoders (default None)
114
             :kwargs_alternative_embedding: parameters for external embedding fitting
115
    """
116
    def __init__(self,
                 nb_clusters=NB_CLUSTERS,
                 pvalue_thres=PVALUE_THRESHOLD,
                 cindex_thres=CINDEX_THRESHOLD,
                 use_autoencoders=USE_AUTOENCODERS,
                 feature_surv_analysis=FEATURE_SURV_ANALYSIS,
                 cluster_method=CLUSTER_METHOD,
                 cluster_eval_method=CLUSTER_EVAL_METHOD,
                 classifier_type=CLASSIFIER_TYPE,
                 project_name=PROJECT_NAME,
                 path_results=PATH_RESULTS,
                 cluster_array=CLUSTER_ARRAY,
                 nb_selected_features=NB_SELECTED_FEATURES,
                 mixture_params=MIXTURE_PARAMS,
                 node_selection=NODES_SELECTION,
                 nb_threads_coxph=NB_THREADS_COXPH,
                 classification_method=CLASSIFICATION_METHOD,
                 load_existing_models=LOAD_EXISTING_MODELS,
                 path_to_save_model=PATH_TO_SAVE_MODEL,
                 clustering_omics=CLUSTERING_OMICS,
                 metadata_usage=None,
                 feature_selection_usage='individual',
                 use_r_packages=USE_R_PACKAGES_FOR_SURVIVAL,
                 seed=SEED,
                 alternative_embedding=None,
                 do_KM_plot=True,
                 verbose=True,
                 _isboosting=False,
                 dataset=None,
                 kwargs_alternative_embedding=None,
                 deep_model_additional_args=None):
        """
        Store the analysis settings, reset every result attribute and
        initialise the DeepBase parent.

        The meaning of the public parameters is described in the class
        docstring. Additional behaviour of the constructor:

        * `metadata_usage` and `feature_selection_usage` are normalised with
          `metadata_usage_type` / `feature_selection_usage_type`; when the
          normalised feature selection usage is None, `feature_surv_analysis`
          is forced to False.
        * the `path_results` folder is created when it is set and does not
          exist yet.
        * `kwargs_alternative_embedding` and `deep_model_additional_args`
          default to an empty dict (None is accepted as "no extra
          arguments"). Both are copied, so neither the caller's dictionaries
          nor a shared default are modified by this instance.
        * `path_to_save_model` is added to the extra arguments forwarded to
          `DeepBase.__init__`.
        """
        # Private copies: the second dict is modified below and both are kept
        # by the instance, so sharing them would leak state between instances
        # (and back into the caller's objects).
        kwargs_alternative_embedding = dict(kwargs_alternative_embedding or {})
        deep_model_additional_args = dict(deep_model_additional_args or {})

        self.seed = seed
        self.nb_clusters = nb_clusters
        self.pvalue_thres = pvalue_thres
        self.cindex_thres = cindex_thres
        self.use_autoencoders = use_autoencoders
        self.classifier_grid = GridSearchCV(CLASSIFIER(), HYPER_PARAMETERS, cv=5)
        self.cluster_array = cluster_array
        self.path_results = path_results
        self.clustering_omics = clustering_omics
        self.use_r_packages = use_r_packages
        self.metadata_usage = metadata_usage_type(metadata_usage)
        self.feature_selection_usage = feature_selection_usage_type(
            feature_selection_usage)

        self.feature_surv_analysis = feature_surv_analysis

        # Without a feature selection strategy there is no survival analysis
        # of the features to run.
        if self.feature_selection_usage is None:
            self.feature_surv_analysis = False

        self.alternative_embedding = alternative_embedding
        self.kwargs_alternative_embedding = kwargs_alternative_embedding

        if self.path_results and not isdir(self.path_results):
            mkdir(self.path_results)

        self.mixture_params = mixture_params

        # `_project_name` keeps the original name; `project_name` may later be
        # suffixed by load_new_test_dataset.
        self.project_name = project_name
        self._project_name = project_name
        self.do_KM_plot = do_KM_plot
        self.nb_threads_coxph = nb_threads_coxph
        self.classification_method = classification_method
        self.nb_selected_features = nb_selected_features
        self.node_selection = node_selection

        # p-values filled in by the predict_* methods
        self.train_pvalue = None
        self.train_pvalue_proba = None
        self.full_pvalue = None
        self.full_pvalue_proba = None
        self.cv_pvalue = None
        self.cv_pvalue_proba = None
        self.test_pvalue = None
        self.test_pvalue_proba = None

        self.classifier = None
        self.classifier_test = None
        self.clustering = None

        self.classifier_dict = {}

        self.encoder_for_kde_plot_dict = {}
        self._main_kernel = {}

        self.classifier_type = classifier_type

        self.used_normalization = None
        self.test_normalization = None

        self.used_features_for_classif = None

        self._isboosting = _isboosting
        self._pretrained_model = False
        self._is_fitted = False

        self.valid_node_ids_array = {}
        self.activities_array = {}
        self.activities_pred_array = {}
        self.pred_node_ids_array = {}

        self.activities_train = None
        self.activities_test = None
        self.activities_cv = None

        self.activities_for_pred_train = None
        self.activities_for_pred_test = None
        self.activities_for_pred_cv = None

        # labels (and label probabilities) inferred for each sample set
        self.test_labels = None
        self.test_labels_proba = None
        self.cv_labels = None
        self.cv_labels_proba = None
        self.full_labels = None
        self.full_labels_proba = None

        self.labels = None
        self.labels_proba = None

        self.training_omic_list = []
        self.test_omic_list = []

        self.feature_scores = defaultdict(list)
        self.feature_scores_per_cluster = {}

        self._label_ordered_dict = {}

        self.clustering_performance = None
        self.bic_score = None
        self.silhouette_score = None
        self.calinski_score = None

        self.cluster_method = cluster_method
        self.cluster_eval_method = cluster_eval_method
        self.verbose = verbose
        self._load_existing_models = load_existing_models
        self._features_scores_changed = False

        self.path_to_save_model = path_to_save_model

        deep_model_additional_args['path_to_save_model'] = self.path_to_save_model

        DeepBase.__init__(self,
                          verbose=self.verbose,
                          dataset=dataset,
                          alternative_embedding=self.alternative_embedding,
                          kwargs_alternative_embedding=self.kwargs_alternative_embedding,
                          **deep_model_additional_args)
265
266
    def _look_for_nodes(self, key):
267
        """
268
        """
269
        assert(self.node_selection in ['Cox-PH', 'C-index'])
270
271
        if self.metadata_usage in ['all', 'new-features'] and \
272
           self.dataset.metadata_mat is not None:
273
            metadata_mat = self.dataset.metadata_mat
274
        else:
275
            metadata_mat = None
276
277
        if self.node_selection == 'Cox-PH':
278
            return self._look_for_survival_nodes(
279
                key, metadata_mat=metadata_mat)
280
281
        if self.node_selection == 'C-index':
282
            return self._look_for_prediction_nodes(key)
283
284
    def load_new_test_dataset(self, tsv_dict,
285
                              fname_key=None,
286
                              path_survival_file=None,
287
                              normalization=None,
288
                              survival_flag=None,
289
                              metadata_file=None):
290
        """
291
        """
292
        self.dataset.load_new_test_dataset(
293
            tsv_dict,
294
            path_survival_file,
295
            normalization=normalization,
296
            survival_flag=survival_flag,
297
            metadata_file=metadata_file
298
        )
299
300
        if normalization is not None:
301
            self.test_normalization = {
302
                key: normalization[key]
303
                for key in normalization
304
                if normalization[key]}
305
306
        else:
307
            self.test_normalization = {
308
                key: self.dataset.normalization[key]
309
                for key in self.dataset.normalization
310
                if self.dataset.normalization[key]}
311
312
        if self.used_normalization != self.test_normalization:
313
            if self.verbose:
314
                print('recombuting feature scores...')
315
316
            self.feature_scores = {}
317
            self.compute_feature_scores(use_ref=True)
318
            self._features_scores_changed = True
319
320
        if fname_key:
321
            self.project_name = '{0}_{1}'.format(self._project_name, fname_key)
322
323
    def fit_on_pretrained_label_file(self, label_file):
        """
        Fit a DeepProg SimDeep model without training an autoencoder, using
        only an ID -> labels file to train a classifier.

        Samples of the dataset found in `label_file` form the training set
        and receive the label / label probability read from the file; the
        remaining samples (if any) form the test fold of the cross-validation
        split.

        Parameters:
                :label_file: path of the labels file, parsed with
                             `load_labels_file`
        """
        self._pretrained_model = True
        self.use_autoencoders = False
        self.feature_surv_analysis = False

        self.dataset.load_array()
        self.dataset.load_survival()
        self.dataset.load_meta_data()
        self.dataset.subset_training_sets()

        labels_dict = load_labels_file(label_file)

        train, test, labels, labels_proba = [], [], [], []

        for index, sample in enumerate(self.dataset.sample_ids):
            if sample not in labels_dict:
                test.append(index)
                continue

            label, label_proba = labels_dict[sample]

            train.append(index)
            labels.append(label)
            labels_proba.append(label_proba)

        # Without unlabelled samples there is nothing to put in a test fold.
        if test:
            self.dataset.cross_validation_instance = (train, test)
        else:
            self.dataset.cross_validation_instance = None

        self.dataset.create_a_cv_split()
        self.dataset.normalize_training_array()

        self.matrix_train_array = self.dataset.matrix_train_array

        for key in self.matrix_train_array:
            self.matrix_train_array[key] = self.matrix_train_array[key].astype('float32')

        # Stored as a real list (as fit() does): a dict keys view cannot be
        # pickled and stays tied to the dataset dictionary.
        self.training_omic_list = list(self.dataset.training_tsv.keys())

        self.predict_labels_using_external_labels(labels, labels_proba)

        self.used_normalization = {key: self.dataset.normalization[key]
                                   for key in self.dataset.normalization
                                   if self.dataset.normalization[key]}

        self.used_features_for_classif = self.dataset.feature_train_array
        self.look_for_survival_nodes()
        self.fit_classification_model()
376
377
    def predict_labels_using_external_labels(self, labels, labels_proba):
        """
        Use externally provided labels as the training labels and evaluate
        them with Cox-PH.

        `labels_proba` is repeated once per distinct label to build the
        `labels_proba` matrix. The p-value of the labels and the p-value of
        the first probability column are stored in `train_pvalue` and
        `train_pvalue_proba`. Outside boosting mode the labels are written
        with `_write_labels`.
        """
        self.labels = labels
        distinct_labels = set(self.labels)
        self.labels_proba = np.array(
            [labels_proba] * len(distinct_labels)).T

        nbdays, isdead = self.dataset.survival.T.tolist()

        km_fig_name = '{0}_KM_plot_training_dataset'.format(self.project_name)

        pvalue = coxph(self.labels, isdead, nbdays,
                       isfactor=False,
                       do_KM_plot=self.do_KM_plot,
                       png_path=self.path_results,
                       seed=self.seed,
                       use_r_packages=self.use_r_packages,
                       fig_name=km_fig_name)

        first_proba = self.labels_proba.T[0]

        pvalue_proba = coxph(first_proba, isdead, nbdays,
                             seed=self.seed,
                             use_r_packages=self.use_r_packages,
                             isfactor=False)

        if not self._isboosting:
            labels_fname = '{0}_training_set_labels'.format(self.project_name)
            self._write_labels(self.dataset.sample_ids, self.labels,
                               labels_proba=self.labels_proba.T[0],
                               fname=labels_fname)

        if self.verbose:
            print('Cox-PH p-value (Log-Rank) for the cluster labels: {0}'.format(pvalue))

        self.train_pvalue = pvalue
        self.train_pvalue_proba = pvalue_proba
409
410
    def fit(self):
411
        """
412
        main function
413
        I) construct an autoencoder or fit alternative embedding
414
        II) predict nodes linked with survival (if active)
415
        and III) do clustering
416
        """
417
        if self._load_existing_models:
418
            self.load_encoders()
419
420
        if not self.is_model_loaded:
421
            if self.alternative_embedding is not None:
422
                self.fit_alternative_embedding()
423
            else:
424
                self.construct_autoencoders()
425
426
        self.look_for_survival_nodes()
427
428
        self.training_omic_list = list(self.encoder_array.keys())
429
        self.predict_labels()
430
431
        self.used_normalization = {key: self.dataset.normalization[key]
432
                                   for key in self.dataset.normalization
433
                                   if self.dataset.normalization[key]}
434
435
        self.used_features_for_classif = self.dataset.feature_train_array
436
        self.fit_classification_model()
437
438
    def predict_labels_on_test_fold(self):
439
        """
440
        """
441
        if not self.dataset.cross_validation_instance:
442
            return
443
444
        self.dataset.load_matrix_test_fold()
445
446
        nbdays, isdead = self.dataset.survival_cv.T.tolist()
447
        self.activities_cv = self._predict_survival_nodes(
448
            self.dataset.matrix_cv_array)
449
450
        self.cv_labels, self.cv_labels_proba = self._predict_labels(
451
            self.activities_cv, self.dataset.matrix_cv_array)
452
453
        if self.verbose:
454
            print('#### report of test fold cluster:):')
455
            for key, value in Counter(self.cv_labels).items():
456
                print('class: {0}, number of samples :{1}'.format(key, value))
457
458
        if self.metadata_usage in ['all', 'labels'] and \
459
           self.dataset.metadata_mat_cv is not None:
460
            metadata_mat = self.dataset.metadata_mat_cv
461
        else:
462
            metadata_mat = None
463
464
        pvalue, pvalue_proba = self._compute_test_coxph('KM_plot_test_fold',
465
                                                        nbdays, isdead,
466
                                                        self.cv_labels,
467
                                                        self.cv_labels_proba,
468
                                                        metadata_mat=metadata_mat)
469
        self.cv_pvalue = pvalue
470
        self.cv_pvalue_proba = pvalue_proba
471
472
        if not self._isboosting:
473
            self._write_labels(self.dataset.sample_ids_cv, self.cv_labels,
474
                               labels_proba=self.cv_labels_proba.T[0],
475
                               fname='{0}_test_fold_labels'.format(self.project_name))
476
477
        return self.cv_labels, pvalue, pvalue_proba
478
479
    def predict_labels_on_full_dataset(self):
480
        """
481
        """
482
        self.dataset.load_matrix_full()
483
484
        nbdays, isdead = self.dataset.survival_full.T.tolist()
485
486
        self.activities_full = self._predict_survival_nodes(
487
            self.dataset.matrix_full_array)
488
489
        self.full_labels, self.full_labels_proba = self._predict_labels(
490
            self.activities_full, self.dataset.matrix_full_array)
491
492
        if self.verbose:
493
            print('#### report of assigned cluster for full dataset:')
494
            for key, value in Counter(self.full_labels).items():
495
                print('class: {0}, number of samples :{1}'.format(key, value))
496
497
        if self.metadata_usage in ['all', 'labels'] and \
498
           self.dataset.metadata_mat_full is not None:
499
            metadata_mat = self.dataset.metadata_mat_full
500
        else:
501
            metadata_mat = None
502
503
        pvalue, pvalue_proba = self._compute_test_coxph('KM_plot_full',
504
                                                        nbdays, isdead,
505
                                                        self.full_labels,
506
                                                        self.full_labels_proba,
507
                                                        metadata_mat=metadata_mat)
508
        self.full_pvalue = pvalue
509
        self.full_pvalue_proba = pvalue_proba
510
511
        if not self._isboosting:
512
            self._write_labels(self.dataset.sample_ids_full, self.full_labels,
513
                               labels_proba=self.full_labels_proba.T[0],
514
                               fname='{0}_full_labels'.format(self.project_name))
515
516
        return self.full_labels, pvalue, pvalue_proba
517
518
    def predict_labels_on_test_dataset(self):
519
        """
520
        """
521
        if self.dataset.survival_test is not None:
522
            nbdays, isdead = self.dataset.survival_test.T.tolist()
523
524
        self.test_omic_list = list(self.dataset.matrix_test_array.keys())
525
        self.test_omic_list = list(set(self.test_omic_list).intersection(
526
            self.training_omic_list))
527
528
        try:
529
            assert(len(self.test_omic_list) > 0)
530
        except AssertionError:
531
            raise Exception('in predict_labels_on_test_dataset: test_omic_list is empty!'\
532
                            '\n either no common omic with trining_omic_list or error!')
533
534
        self.fit_classification_test_model()
535
536
        self.activities_test = self._predict_survival_nodes(
537
            self.dataset.matrix_test_array)
538
        self._predict_test_labels(self.activities_test,
539
                                  self.dataset.matrix_test_array)
540
541
        if self.verbose:
542
            print('#### report of assigned cluster:')
543
            for key, value in Counter(self.test_labels).items():
544
                print('class: {0}, number of samples :{1}'.format(key, value))
545
546
        if self.metadata_usage in ['all', 'test-labels'] and \
547
           self.dataset.metadata_mat_test is not None:
548
            metadata_mat = self.dataset.metadata_mat_test
549
        else:
550
            metadata_mat = None
551
552
        pvalue, pvalue_proba = self._compute_test_coxph('KM_plot_test',
553
                                                        nbdays, isdead,
554
                                                        self.test_labels,
555
                                                        self.test_labels_proba,
556
                                                        metadata_mat=metadata_mat)
557
        self.test_pvalue = pvalue
558
        self.test_pvalue_proba = pvalue_proba
559
560
        if self.dataset.survival_test is not None:
561
            if np.isnan(nbdays).all():
562
                pvalue, pvalue_proba = self._compute_test_coxph(
563
                    'KM_plot_test',
564
                    nbdays, isdead,
565
                    self.test_labels, self.test_labels_proba)
566
567
                self.test_pvalue = pvalue
568
                self.test_pvalue_proba = pvalue_proba
569
570
        if not self._isboosting:
571
            self._write_labels(self.dataset.sample_ids_test, self.test_labels,
572
                               labels_proba=self.test_labels_proba.T[0],
573
                               fname='{0}_test_labels'.format(self.project_name))
574
575
        return self.test_labels, pvalue, pvalue_proba
576
577
    def _compute_test_coxph(self,
                            fname_base,
                            nbdays,
                            isdead,
                            labels,
                            labels_proba,
                            metadata_mat=None):
        """
        Compute the Cox-PH p-values of inferred labels.

        Two models are fitted with `coxph`: one on `labels` (the KM plot is
        drawn according to `self.do_KM_plot`) and one on the first column of
        `labels_proba` (never plotted). Returns (pvalue, pvalue_proba).
        """
        # keyword arguments common to both Cox-PH calls
        shared_options = {
            'isfactor': False,
            'png_path': self.path_results,
            'seed': self.seed,
            'use_r_packages': self.use_r_packages,
            'metadata_mat': metadata_mat,
        }

        pvalue = coxph(
            labels, isdead, nbdays,
            do_KM_plot=self.do_KM_plot,
            fig_name='{0}_{1}'.format(self.project_name, fname_base),
            **shared_options)

        if self.verbose:
            print('Cox-PH p-value (Log-Rank) for inferred labels: {0}'.format(pvalue))

        first_proba = labels_proba.T[0]

        pvalue_proba = coxph(
            first_proba, isdead, nbdays,
            do_KM_plot=False,
            fig_name='{0}_{1}_proba'.format(self.project_name, fname_base),
            **shared_options)

        if self.verbose:
            print('Cox-PH proba p-value (Log-Rank) for inferred labels: {0}'.format(pvalue_proba))

        return pvalue, pvalue_proba
613
614
    def compute_feature_scores(self, use_ref=False):
        """
        Score every feature of every omic against the training labels.

        Nothing is done when `self.feature_scores` is already filled. For
        each omic, `_process_parallel_feature_importance` is called on every
        (feature, feature values, labels) triplet and the results, sorted in
        ascending order of their second element, are stored in
        `self.feature_scores[omic]`.

        The computation is serial. The previous implementation spawned a
        process pool in non-boosting mode but always mapped with the builtin
        `map`, so the pool was never used and was left open whenever the
        scoring raised; it is no longer created.

        Parameters:
                :use_ref: use the reference matrices / features of the
                          dataset instead of the training ones
        """
        if self.feature_scores:
            return

        if use_ref:
            matrix_array = self.dataset.matrix_ref_array
            feature_array = self.dataset.feature_ref_array
        else:
            matrix_array = self.dataset.matrix_train_array
            feature_array = self.dataset.feature_train_array

        # snapshot of the keys, as in the original implementation
        for key in list(matrix_array.keys()):
            feature_list = feature_array[key][:]
            # transposed so that row i holds the values of feature i
            values_per_feature = matrix_array[key][:].T
            labels = self.labels[:]

            input_list = (
                (feature, values_per_feature[pos], labels)
                for pos, feature in enumerate(feature_list))

            features_scored = [
                _process_parallel_feature_importance(item)
                for item in input_list]
            features_scored.sort(key=lambda x: x[1])

            self.feature_scores[key] = features_scored
658
659
    def compute_feature_scores_per_cluster(self, use_ref=False,
                                           pval_thres=0.01):
        """
        Score the features of every omic for each cluster label.

        The entry of each current label in `self.feature_scores_per_cluster`
        is reset, then filled with the (feature, median diff, p-value)
        tuples returned by `_process_parallel_feature_importance_per_cluster`
        and sorted in ascending order of the median diff.

        Parameters:
                :use_ref: use the reference matrices / features of the
                          dataset instead of the training ones
                :pval_thres: threshold forwarded to the scoring function
        """
        print('computing feature importance per cluster...')

        for label in set(self.labels):
            self.feature_scores_per_cluster[label] = []

        if use_ref:
            matrix_array = self.dataset.matrix_ref_array
            feature_array = self.dataset.feature_ref_array
        else:
            matrix_array = self.dataset.matrix_train_array
            feature_array = self.dataset.feature_train_array

        by_median_diff = lambda scored: scored[1]

        for key in list(matrix_array.keys()):
            feature_list = feature_array[key][:]
            # transposed so that row i holds the values of feature i
            values_per_feature = matrix_array[key][:].T
            labels = self.labels[:]

            # every feature is scored before anything is stored
            features_scored = [
                hit
                for pos in range(len(feature_list))
                for hit in _process_parallel_feature_importance_per_cluster(
                    (feature_list[pos], values_per_feature[pos],
                     labels, pval_thres))]

            for label, feature, median_diff, pvalue in features_scored:
                self.feature_scores_per_cluster[label].append(
                    (feature, median_diff, pvalue))

            for label in self.feature_scores_per_cluster:
                self.feature_scores_per_cluster[label].sort(key=by_median_diff)
699
700
    def write_feature_score_per_cluster(self):
701
        """
702
        """
703
        f_file_name = '{0}/{1}_features_scores_per_clusters.tsv'.format(
704
            self.path_results, self._project_name)
705
        f_anti_name = '{0}/{1}_features_anticorrelated_scores_per_clusters.tsv'.format(
706
            self.path_results, self._project_name)
707
708
        f_file = open(f_file_name, 'w')
709
        f_anti_file = open(f_anti_name, 'w')
710
711
        f_file.write('cluster id;feature;median diff;p-value\n')
712
713
        for label in self.feature_scores_per_cluster:
714
            for feature, median_diff, pvalue in self.feature_scores_per_cluster[label]:
715
                if median_diff > 0:
716
                    f_to_write = f_file
717
                else:
718
                    f_to_write = f_anti_file
719
720
                f_to_write.write('{0};{1};{2};{3}\n'.format(label, feature, median_diff, pvalue))
721
722
        print('{0} written'.format(f_file_name))
723
        print('{0} written'.format(f_anti_name))
724
725
    def write_feature_scores(self):
726
        """
727
        """
728
        with open('{0}/{1}_features_scores.tsv'.format(
729
                self.path_results, self.project_name), 'w') as f_file:
730
731
            for key in self.feature_scores:
732
                f_file.write('#### {0} ####\n'.format(key))
733
734
                for feature, score in self.feature_scores[key]:
735
                    f_file.write('{0};{1}\n'.format(feature, score))
736
737
            print('{0}/{1}_features_scores.tsv written'.format(
738
                self.path_results, self.project_name))
739
740
    def _return_train_matrix_for_classification(self):
        """
        Build the matrix used to train the classifier from the reference
        matrices of the dataset.

        With 'SURVIVAL_FEATURES' the matrix is the output of
        `_predict_survival_nodes` (not allowed with the 'clustering'
        classifier type); with 'ALL_FEATURES' it is the output of
        `_reduce_and_stack_matrices`. NaN values are replaced with
        `np.nan_to_num` before returning.
        """
        method = self.classification_method
        assert (method in _CLASSIFICATION_METHOD_LIST)

        if self.verbose:
            print('classification method: {0}'.format(
                self.classification_method))

        if method == 'SURVIVAL_FEATURES':
            assert(self.classifier_type != 'clustering')
            ref_matrices = self.dataset.matrix_ref_array
            matrix = self._predict_survival_nodes(ref_matrices)
        elif method == 'ALL_FEATURES':
            ref_matrices = self.dataset.matrix_ref_array
            matrix = self._reduce_and_stack_matrices(ref_matrices)

        if self.verbose:
            nb_features = matrix.shape[1]
            print('number of features for the classifier: {0}'.format(
                nb_features))

        return np.nan_to_num(matrix)
762
763
    def _reduce_and_stack_matrices(self, matrices):
764
        """
765
        """
766
        if not self.nb_selected_features:
767
            return hstack(matrices.values())
768
        else:
769
            self.compute_feature_scores()
770
            matrix = []
771
772
            for key in matrices:
773
                index = [self.dataset.feature_ref_index[key][feature]
774
                         for feature, pvalue in
775
                         self.feature_scores[key][:self.nb_selected_features]
776
                         if feature in self.dataset.feature_ref_index[key]
777
                ]
778
779
                matrix.append(matrices[key].T[index].T)
780
781
            return hstack(matrix)
782
783
    def fit_classification_model(self):
        """
        Fit the supervised classifier predicting the cluster labels.

        When ``classifier_type`` is 'clustering' the clustering model is reused
        as the classifier. Otherwise the grid search held in
        ``classifier_grid`` is fitted on the training matrix, the estimator
        returned by ``select_best_classif_params`` is refitted with
        ``probability=True`` and stored in ``classifier_dict`` under the
        current normalisation.
        """
        features = self._return_train_matrix_for_classification()
        targets = self.labels

        if self.classifier_type == 'clustering':
            if self.verbose:
                print('clustering model defined as the classifier')

            self.classifier = self.clustering
            return

        if self.verbose:
            print('classification analysis...')

        # numpy is seeded only when an integer seed was provided.
        if isinstance(self.seed, int):
            np.random.seed(self.seed)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.classifier_grid.fit(features, targets)

        best_estimator, best_params = select_best_classif_params(
            self.classifier_grid)

        self.classifier = best_estimator
        self.classifier.set_params(probability=True)
        self.classifier.fit(features, targets)

        self.classifier_dict[str(self.used_normalization)] = self.classifier

        if not self.verbose:
            return

        cv_scores = cross_val_score(self.classifier, features, targets, cv=5)
        print('best params:', best_params)
        print('cross val score: {0}'.format(np.mean(cv_scores)))
        print('classification score:', self.classifier.score(
            features, targets))
819
820
    def fit_classification_test_model(self):
        """
        Fit the classifier used to label a test dataset.

        The training classifier is reused when the clustering model acts as
        the classifier, or when the features and normalisation are unchanged
        and unknown features are filled with 0. Otherwise the normalisation
        and reference features of the dataset are recorded and a dedicated
        classifier is fitted through the grid search.
        """
        same_features = (
            self.used_features_for_classif == self.dataset.feature_ref_array)
        same_normalization = self.used_normalization == self.test_normalization
        zero_filled = self.dataset.fill_unkown_feature_with_0

        can_reuse = (same_features and same_normalization and zero_filled) \
            or self.classifier_type == 'clustering'

        if can_reuse:
            if self.verbose:
                print('Not rebuilding the test classifier')

            if self.classifier_test is None:
                self.classifier_test = self.classifier
            return

        if self.verbose:
            print('classification for test set analysis...')

        # Remember the setup the test classifier is built for.
        self.used_normalization = self.dataset.normalization_test
        self.used_features_for_classif = self.dataset.feature_ref_array

        features = self._return_train_matrix_for_classification()
        targets = self.labels

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.classifier_grid.fit(features, targets)

        best_estimator, best_params = select_best_classif_params(
            self.classifier_grid)

        self.classifier_test = best_estimator
        self.classifier_test.set_params(probability=True)
        self.classifier_test.fit(features, targets)

        if not self.verbose:
            return

        cv_scores = cross_val_score(
            self.classifier_test, features, targets, cv=5)
        print('best params:', best_params)
        print('cross val score: {0}'.format(np.mean(cv_scores)))
        print('classification score:',
              self.classifier_test.score(features, targets))
858
859
    def predict_labels(self):
        """
        Cluster the training samples and store the resulting labels.

        The clustering model is chosen from ``self.cluster_method``: a class
        exposing ``fit_predict``, or one of 'kmeans', 'mixture', 'coxPH' and
        'coxPHMixture'. It is fitted on ``self.activities_train``. When
        ``cluster_array`` holds more than one value, the number of clusters is
        first selected with ``_predict_best_k_for_cluster``. Labels are then
        re-ordered according to survival, the label probabilities are padded
        with zero columns up to ``nb_clusters``, and Cox-PH p-values are
        computed for the labels and for the first probability column.

        Sets ``clustering``, ``labels``, ``labels_proba``, ``train_pvalue`` and
        ``train_pvalue_proba``. Unless boosting is active, the labels are also
        written to ``<project_name>_training_set_labels.tsv``.

        :raises Exception: if the clustering method is not recognised, or if
            ``activities_train`` has no non-zero entry
        """
        if self.verbose:
            print('performing clustering on the omic model with the following key:{0}'.format(
                self.training_omic_list))

        if hasattr(self.cluster_method, 'fit_predict'):
            # User-supplied clustering class. A stray no-op comparison
            # (``self.cluster_method == 'custom'``) used to follow this line;
            # it had no effect and was removed. ``cluster_method`` is left
            # untouched, exactly as before.
            self.clustering = self.cluster_method(n_clusters=self.nb_clusters)

        elif self.cluster_method == 'kmeans':
            self.clustering = KMeans(n_clusters=self.nb_clusters, n_init=100)

        elif self.cluster_method == 'mixture':
            self.clustering = GaussianMixture(
                n_components=self.nb_clusters,
                **self.mixture_params
            )

        elif self.cluster_method == "coxPH":
            nbdays, isdead = self.dataset.survival.T.tolist()

            self.clustering = ClusterWithSurvival(
                n_clusters=self.nb_clusters,
                isdead=isdead,
                nbdays=nbdays)

        elif self.cluster_method == "coxPHMixture":
            nbdays, isdead = self.dataset.survival.T.tolist()

            self.clustering = ClusterWithSurvival(
                n_clusters=self.nb_clusters,
                use_gaussian_to_dichotomize=True,
                isdead=isdead,
                nbdays=nbdays)

        else:
            raise Exception("No method fit and predict found for: {0}".format(
                self.cluster_method))

        if not self.activities_train.any():
            raise Exception('No components linked to survival!'\
                            ' cannot perform clustering')

        if self.cluster_array and len(self.cluster_array) > 1:
            self._predict_best_k_for_cluster()

        if hasattr(self.clustering, 'predict'):
            self.clustering.fit(self.activities_train)
            labels = self.clustering.predict(self.activities_train)
        else:
            labels = self.clustering.fit_predict(self.activities_train)

        labels = self._order_labels_according_to_survival(labels)

        self.labels = labels

        if hasattr(self.clustering, 'predict_proba'):
            self.labels_proba = self.clustering.predict_proba(self.activities_train)
        else:
            # No probabilities available: the labels are duplicated into
            # two columns instead.
            self.labels_proba = np.array([self.labels, self.labels]).T

        if len(self.labels_proba.shape) == 1:
            self.labels_proba = self.labels_proba.reshape((
                self.labels_proba.shape[0], 1))

        # Pad with zero columns so that there is one column per cluster.
        missing_columns = self.nb_clusters - self.labels_proba.shape[1]

        if missing_columns > 0:
            self.labels_proba = hstack([
                self.labels_proba, np.zeros(
                    shape=(self.labels_proba.shape[0], missing_columns))])

        if self.verbose:
            print("clustering done, labels ordered according to survival:")
            for key, value in Counter(labels).items():
                print('cluster label: {0}\t number of samples:{1}'.format(key, value))
            print('\n')

        nbdays, isdead = self.dataset.survival.T.tolist()

        if self.metadata_usage in ['all', 'labels'] and \
           self.dataset.metadata_mat is not None:
            metadata_mat = self.dataset.metadata_mat
        else:
            metadata_mat = None

        pvalue = coxph(self.labels, isdead, nbdays,
                       isfactor=False,
                       do_KM_plot=self.do_KM_plot,
                       png_path=self.path_results,
                       seed=self.seed,
                       use_r_packages=self.use_r_packages,
                       metadata_mat=metadata_mat,
                       fig_name='{0}_KM_plot_training_dataset'.format(self.project_name))

        pvalue_proba = coxph(self.labels_proba.T[0],
                             isdead, nbdays,
                             seed=self.seed,
                             use_r_packages=self.use_r_packages,
                             metadata_mat=metadata_mat,
                             isfactor=False)

        if not self._isboosting:
            self._write_labels(self.dataset.sample_ids, self.labels,
                               labels_proba=self.labels_proba.T[0],
                               fname='{0}_training_set_labels'.format(self.project_name))

        if self.verbose:
            print('Cox-PH p-value (Log-Rank) for the cluster labels: {0}'.format(pvalue))

        self.train_pvalue = pvalue
        self.train_pvalue_proba = pvalue_proba
977
978
    def evalutate_cluster_performance(self):
        """
        Compute clustering quality scores for the training labels.

        Stores the silhouette and Calinski-Harabaz scores (and the BIC when the
        clustering method is 'mixture') as attributes. Nothing is computed when
        no clustering model is defined.
        """
        if not self.clustering:
            print('clustering attribute is defined as None. ' \
                   ' Cannot evaluate cluster performance')
            return

        activities = self.activities_train

        if self.cluster_method == 'mixture':
            self.bic_score = self.clustering.bic(activities)

        self.silhouette_score = silhouette_score(activities, self.labels)
        self.calinski_score = calinski_harabaz_score(activities, self.labels)

        if not self.verbose:
            return

        print('silhouette score: {0}'.format(self.silhouette_score))
        print('calinski-harabaz score: {0}'.format(self.calinski_score))
        print('bic score: {0}'.format(self.bic_score))
996
997
    def _write_labels(self, sample_ids, labels, fname="",
998
                      labels_proba=None,
999
                      nbdays=None,
1000
                      isdead=None,
1001
                      path_file=None):
1002
        """ """
1003
        assert(fname or path_file)
1004
1005
        if not path_file:
1006
            path_file = '{0}/{1}.tsv'.format(self.path_results, fname)
1007
1008
        with open(path_file, 'w') as f_file:
1009
            for ids, (sample, label) in enumerate(zip(sample_ids, labels)):
1010
                suppl = ''
1011
1012
                if labels_proba is not None:
1013
                    suppl += '\t{0}'.format(labels_proba[ids])
1014
                if nbdays is not None:
1015
                    suppl += '\t{0}'.format(nbdays[ids])
1016
                if isdead is not None:
1017
                    suppl += '\t{0}'.format(isdead[ids])
1018
1019
                f_file.write('{0}\t{1}{2}\n'.format(sample, label, suppl))
1020
1021
        print('file written: {0}'.format(path_file))
1022
1023
    def _predict_survival_nodes(self, matrix_array, keys=None):
1024
        """
1025
        """
1026
        activities_array = {}
1027
1028
        if keys is None:
1029
            keys = list(matrix_array.keys())
1030
1031
        for key in keys:
1032
            matrix = matrix_array[key]
1033
            if not self._pretrained_model:
1034
                if self.alternative_embedding is  None and \
1035
                   self.encoder_input_shape(key)[1] != matrix.shape[1]:
1036
                    if self.verbose:
1037
                        print('matrix doesnt have the input dimension of the encoder'\
1038
                              ' returning None')
1039
                    return None
1040
1041
            if self.alternative_embedding is not None:
1042
                activities = self.embedding_predict(key, matrix)
1043
            elif self.use_autoencoders:
1044
                activities = self.encoder_predict(key, matrix)
1045
            else:
1046
                activities = np.asarray(matrix)
1047
1048
            activities_array[key] = activities.T[self.valid_node_ids_array[key]].T
1049
1050
        return hstack([activities_array[key]
1051
                       for key in keys])
1052
1053
    def look_for_survival_nodes(self, keys=None):
        """
        detect nodes from the autoencoder significantly
        linked with survival through coxph regression

        For each key the activities are computed (alternative embedding,
        encoder, or the raw training matrix) and the retained node ids are
        stored in ``self.valid_node_ids_array``. The corresponding activities
        are stored in ``self.activities_array`` and stacked into
        ``self.activities_train``.

        :param keys: keys to process; defaults to the keys of
            ``self.encoder_array`` or, when that is empty, to those of
            ``self.matrix_train_array``
        """
        if not keys:
            keys = list(self.encoder_array.keys())

            if not keys:
                keys = self.matrix_train_array.keys()

        for key in keys:
            matrix_train = self.matrix_train_array[key]

            if self.alternative_embedding is not None:
                activities = self.embedding_predict(key, matrix_train)
            elif self.use_autoencoders:
                activities = self.encoder_predict(key, matrix_train)
            else:
                activities = np.asarray(matrix_train)

            if self.feature_surv_analysis:
                valid_node_ids = self._look_for_nodes(key)
            else:
                # Keep every node. The ids index the columns of `activities`,
                # whose width differs from the input matrix when an encoder or
                # an embedding is used, so they must be derived from it.
                valid_node_ids = np.arange(activities.shape[1])

            self.valid_node_ids_array[key] = valid_node_ids
            self.activities_array[key] = activities.T[valid_node_ids].T

        # Clustering may be restricted to a subset of the omics.
        if self.clustering_omics:
            keys = self.clustering_omics

        self.activities_train = hstack([self.activities_array[key]
                                        for key in keys])
1087
1088
    def look_for_prediction_nodes(self, keys=None):
        """
        detect nodes from the autoencoder that predict a
        high c-index scores using label from the retained test fold

        For each key the retained node ids are stored in
        ``self.pred_node_ids_array`` and the corresponding activities in
        ``self.activities_pred_array``; the latter are stacked into
        ``self.activities_for_pred_train``.

        :param keys: keys to process (default: keys of ``self.encoder_array``)
        """
        if not keys:
            keys = list(self.encoder_array.keys())

        for key in keys:
            matrix_train = self.matrix_train_array[key]

            if self.alternative_embedding is not None:
                activities = self.embedding_predict(key, matrix_train)
            elif self.use_autoencoders:
                activities = self.encoder_predict(key, matrix_train)
            else:
                activities = np.asarray(matrix_train)

            if self.feature_surv_analysis:
                valid_node_ids = self._look_for_prediction_nodes(key)
            else:
                # Keep every node. The ids index the columns of `activities`,
                # whose width differs from the input matrix when an encoder or
                # an embedding is used, so they must be derived from it.
                valid_node_ids = np.arange(activities.shape[1])

            self.pred_node_ids_array[key] = valid_node_ids

            self.activities_pred_array[key] = activities.T[valid_node_ids].T

        self.activities_for_pred_train = hstack([self.activities_pred_array[key]
                                                 for key in keys])
1117
1118
    def compute_c_indexes_multiple_for_test_dataset(self):
        """
        Return the c-index computed by ``c_index_multiple`` from the training
        prediction-node activities and the activities of the test dataset.
        """
        days, dead = np.asarray(self.dataset.survival).T
        days_test, dead_test = np.asarray(self.dataset.survival_test).T

        activities_test = {}

        for key in self.dataset.matrix_test_array:
            node_ids = self.pred_node_ids_array[key]

            matrix = self.dataset.matrix_test_array[key]

            if self.alternative_embedding is not None:
                activities_test[key] = self.embedding_predict(
                    key, matrix).T[node_ids].T

            elif self.use_autoencoders:
                activities_test[key] = self.encoder_predict(
                    key, matrix).T[node_ids].T

            else:
                # NOTE(review): unlike the two branches above, the raw matrix
                # is not restricted to `node_ids` -- confirm this is intended.
                activities_test[key] = matrix

        # A dict view is not a sequence and is rejected by numpy's hstack:
        # materialise the values into a list first.
        activities_test = hstack(list(activities_test.values()))
        activities_train = hstack([self.activities_pred_array[key]
                                   for key in self.dataset.matrix_ref_array])

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            cindex = c_index_multiple(activities_train, dead, days,
                                      activities_test, dead_test, days_test,
                                      seed=self.seed,)

        if self.verbose:
            print('c-index multiple for test dataset:{0}'.format(cindex))

        return cindex
1157
1158
    def compute_c_indexes_multiple_for_test_fold_dataset(self):
        """
        Return the c-index computed by ``c_index_multiple`` from the training
        prediction-node activities and the activities of the test fold
        (``matrix_cv_array``).
        """
        days, dead = np.asarray(self.dataset.survival).T
        days_cv, dead_cv = np.asarray(self.dataset.survival_cv).T

        activities_cv = {}

        for key in self.dataset.matrix_cv_array:
            node_ids = self.pred_node_ids_array[key]
            matrix = self.dataset.matrix_cv_array[key]

            if self.alternative_embedding is not None:
                activities_cv[key] = self.embedding_predict(
                    key, matrix).T[node_ids].T

            elif self.use_autoencoders:
                activities_cv[key] = self.encoder_predict(
                    key, matrix).T[node_ids].T

            else:
                # NOTE(review): unlike the two branches above, the raw matrix
                # is not restricted to `node_ids` -- confirm this is intended.
                activities_cv[key] = matrix

        # A dict view is not a sequence and is rejected by numpy's hstack:
        # materialise the values into a list first.
        activities_cv = hstack(list(activities_cv.values()))

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            cindex = c_index_multiple(self.activities_for_pred_train, dead, days,
                                      activities_cv, dead_cv, days_cv,
                                      seed=self.seed,)

        if self.verbose:
            print('c-index multiple for test fold dataset:{0}'.format(cindex))

        return cindex
1193
1194
    def _return_test_matrix_for_classification(self, activities, matrix_array):
1195
        """
1196
        """
1197
        if self.classification_method == 'SURVIVAL_FEATURES':
1198
            return activities
1199
        elif self.classification_method == 'ALL_FEATURES':
1200
            matrix = self._reduce_and_stack_matrices(matrix_array)
1201
            return matrix
1202
1203
    def _predict_test_labels(self, activities, matrix_array):
1204
        """ """
1205
        matrix_test = self._return_test_matrix_for_classification(
1206
            activities, matrix_array)
1207
1208
        self.test_labels = self.classifier_test.predict(matrix_test)
1209
        self.test_labels_proba = self.classifier_test.predict_proba(matrix_test)
1210
1211
        if self.test_labels_proba.shape[1] < self.nb_clusters:
1212
            missing_columns = self.nb_clusters - self.test_labels_proba.shape[1]
1213
1214
            for i in range(missing_columns):
1215
                self.test_labels_proba = hstack([
1216
                    self.test_labels_proba, np.zeros(
1217
                        shape=(self.test_labels_proba, 1))])
1218
1219
    def _predict_labels(self, activities, matrix_array):
1220
        """ """
1221
        matrix_test = self._return_test_matrix_for_classification(
1222
            activities, matrix_array)
1223
1224
        labels = self.classifier.predict(matrix_test)
1225
        labels_proba = self.classifier.predict_proba(matrix_test)
1226
1227
        if labels_proba.shape[1] < self.nb_clusters:
1228
            missing_columns = self.nb_clusters - labels_proba.shape[1]
1229
1230
            for i in range(missing_columns):
1231
                labels_proba = hstack([
1232
                    labels_proba, np.zeros(
1233
                        shape=(labels_proba.shape[0], 1))])
1234
1235
        return labels, labels_proba
1236
1237
    def _predict_best_k_for_cluster(self):
1238
        """ """
1239
        criterion = None
1240
        best_k = None
1241
1242
        for k_cluster in self.cluster_array:
1243
            if self.cluster_method == 'mixture':
1244
                self.clustering.set_params(n_components=k_cluster)
1245
            else:
1246
                self.clustering.set_params(n_clusters=k_cluster)
1247
1248
            labels = self.clustering.fit_predict(self.activities_train)
1249
1250
            if self.cluster_eval_method == 'bic':
1251
                score = self.clustering.bic(self.activities_train)
1252
            elif self.cluster_eval_method == 'calinski':
1253
                score = calinski_harabaz_score(
1254
                    self.activities_train,
1255
                    labels
1256
                )
1257
            elif self.cluster_eval_method == 'silhouette':
1258
                score = silhouette_score(
1259
                    self.activities_train,
1260
                    labels
1261
                )
1262
1263
            if self.verbose:
1264
                print('obtained {2}: {0} for k = {1}'.format(score, k_cluster,
1265
                                                             self.cluster_eval_method))
1266
1267
            if criterion == None or score < criterion:
1268
                criterion, best_k = score, k_cluster
1269
1270
                self.clustering_performance = criterion
1271
1272
        if self.verbose:
1273
            print('best k: {0}'.format(best_k))
1274
1275
        if self.cluster_method == 'mixture':
1276
            self.clustering.set_params(n_components=best_k)
1277
        else:
1278
            self.clustering.set_params(n_clusters=best_k)
1279
1280
    def _order_labels_according_to_survival(self, labels):
        """
        Order cluster labels according to survival

        Each cluster is scored with ``surv_median`` on its samples; clusters
        are then renumbered 0, 1, ... by increasing score. The mapping
        old label -> new label is kept in ``self._label_ordered_dict``.
        ``labels`` is modified in place and returned.
        """
        previous_labels = labels.copy()

        days, dead = np.asarray(self.dataset.survival).T

        self._label_ordered_dict = {}

        for cluster in set(previous_labels):
            members = previous_labels == cluster
            self._label_ordered_dict[cluster] = surv_median(
                dead[members], days[members])

        ranked_pairs = sorted(self._label_ordered_dict.items(),
                              key=lambda pair: pair[1])

        self._label_ordered_dict = {
            old_label: rank
            for rank, (old_label, _) in enumerate(ranked_pairs)}

        # Masks are taken on the untouched copy so that a freshly assigned
        # label is never remapped a second time.
        for old_label, new_label in self._label_ordered_dict.items():
            labels[previous_labels == old_label] = new_label

        return labels
1305
1306
    def _look_for_survival_nodes(self, key=None,
1307
                                 activities=None,
1308
                                 survival=None,
1309
                                 metadata_mat=None):
1310
        """
1311
        """
1312
        if key is not None:
1313
            matrix_train = self.matrix_train_array[key]
1314
1315
            if self.alternative_embedding is not None:
1316
                activities = np.nan_to_num(self.embedding_predict(
1317
                    key, matrix_train))
1318
1319
            elif self.use_autoencoders:
1320
                activities = np.nan_to_num(self.encoder_predict(
1321
                    key, matrix_train))
1322
1323
            else:
1324
                activities = np.asarray(matrix_train)
1325
        else:
1326
            assert(activities is not None)
1327
1328
        if survival is not None:
1329
            nbdays, isdead = survival.T.tolist()
1330
        else:
1331
            nbdays, isdead = self.dataset.survival.T.tolist()
1332
1333
        if self.feature_selection_usage == 'lasso':
1334
            cws = ClusterWithSurvival(
1335
                isdead=isdead,
1336
                nbdays=nbdays,
1337
                metadata_mat=metadata_mat)
1338
1339
            return cws.get_nonzero_features(activities)
1340
1341
        else:
1342
            return self._get_survival_features_parallel(
1343
                isdead, nbdays, metadata_mat, activities, key)
1344
1345
    def _get_survival_features_parallel(
            self, isdead, nbdays, metadata_mat, activities, key):
        """
        Run ``_process_parallel_coxph`` on every column of ``activities`` and
        return the ids of the nodes whose p-value is below
        ``self.pvalue_thres``.

        A process pool of ``self.nb_threads_coxph`` workers is used unless
        boosting is active, in which case the computation is sequential.
        Nodes with a NaN p-value are discarded; the returned ids are ordered
        by decreasing p-value.
        """
        pool = None

        if not self._isboosting:
            pool = Pool(self.nb_threads_coxph)
            mapf = pool.map
        else:
            mapf = map

        input_list = ((node_id,
                       activity,
                       isdead,
                       nbdays,
                       self.seed,
                       metadata_mat, self.use_r_packages)
                      for node_id, activity in enumerate(activities.T))

        try:
            # Materialise the results here so that the pool can be released
            # right away, including when a worker raises.
            pvalue_list = list(mapf(_process_parallel_coxph, input_list))
        finally:
            if pool is not None:
                pool.close()
                pool.join()

        pvalue_list = [item for item in pvalue_list
                       if not np.isnan(item[1])]
        pvalue_list.sort(key=lambda x: x[1], reverse=True)

        valid_node_ids = [node_id for node_id, pvalue in pvalue_list
                          if pvalue < self.pvalue_thres]

        if self.verbose:
            print('number of components linked to survival found:{0} for key {1}'.format(
                len(valid_node_ids), key))

        return valid_node_ids
1382
1383
    def _look_for_prediction_nodes(self, key):
        """
        Return the ids of the nodes of ``key`` whose c-index, as computed by
        ``_process_parallel_cindex`` from the training and test-fold
        activities, is above ``self.cindex_thres``.

        Nodes with a NaN score are discarded; the returned ids are ordered by
        decreasing score.
        """
        nbdays, isdead = self.dataset.survival.T.tolist()
        nbdays_cv, isdead_cv = self.dataset.survival_cv.T.tolist()

        matrix_train = self.matrix_train_array[key]
        matrix_cv = self.dataset.matrix_cv_array[key]

        if self.alternative_embedding is not None:
            activities_train = self.embedding_predict(key, matrix_train)
            activities_cv = self.embedding_predict(key, matrix_cv)

        elif self.use_autoencoders:
            activities_train = self.encoder_predict(key, matrix_train)
            activities_cv = self.encoder_predict(key, matrix_cv)
        else:
            activities_train = np.asarray(matrix_train)
            activities_cv = np.asarray(matrix_cv)

        input_list = ((node_id,
                       activities_train.T[node_id], isdead, nbdays,
                       activities_cv.T[node_id], isdead_cv, nbdays_cv,
                       self.use_r_packages)
                      for node_id in range(activities_train.shape[1]))

        # `map` and `filter` are lazy iterators on Python 3: build a real list
        # so that it can be sorted and traversed twice below.
        score_list = [item
                      for item in map(_process_parallel_cindex, input_list)
                      if not np.isnan(item[1])]
        score_list.sort(key=lambda x: x[1], reverse=True)

        valid_node_ids = [node_id for node_id, cindex in score_list
                          if cindex > self.cindex_thres]

        scores = [score for node_id, score in score_list
                  if score > self.cindex_thres]

        if self.verbose:
            print('number of components with a high prediction score:{0} for key {1}'\
                  ' \n\t mean: {2} std: {3}'.format(
                      len(valid_node_ids), key, np.mean(scores), np.std(scores)))

        return valid_node_ids
1425
1426
    def compute_c_indexes_for_full_dataset(self):
        """
        Return the c-index obtained with ``c_index`` from the training labels
        and the labels of the full dataset (``self.full_labels``).

        Any exception raised by the computation is reported on stdout and NaN
        is returned instead.
        """
        train_days, train_dead = np.asarray(self.dataset.survival).T
        full_days, full_dead = np.asarray(self.dataset.survival_full).T

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            try:
                cindex = c_index(self.labels, train_dead, train_days,
                                 self.full_labels, full_dead, full_days,
                                 use_r_packages=self.use_r_packages,
                                 seed=self.seed,)
            except Exception as error:
                print('Exception while computing the c-index: {0}'.format(error))
                cindex = np.nan

        if self.verbose:
            print('c-index for full dataset:{0}'.format(cindex))

        return cindex
1448
1449
    def compute_c_indexes_for_training_dataset(self):
        """
        Return the c-index obtained with ``c_index`` when the training labels
        are used both as reference and as evaluated set.

        Any exception raised by the computation is reported on stdout and NaN
        is returned instead.
        """
        train_days, train_dead = np.asarray(self.dataset.survival).T

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            try:
                cindex = c_index(self.labels, train_dead, train_days,
                                 self.labels, train_dead, train_days,
                                 use_r_packages=self.use_r_packages,
                                 seed=self.seed,)
            except Exception as error:
                print('Exception while computing the c-index: {0}'.format(error))
                cindex = np.nan

        if self.verbose:
            print('c-index for training dataset:{0}'.format(cindex))

        return cindex
1470
1471
    def compute_c_indexes_for_test_dataset(self):
        """
        Return the c-index obtained with ``c_index`` from the training labels
        and the labels predicted for the test dataset (``self.test_labels``).

        Any exception raised by the computation is reported on stdout and NaN
        is returned instead.
        """
        train_days, train_dead = np.asarray(self.dataset.survival).T
        test_days, test_dead = np.asarray(self.dataset.survival_test).T

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            try:
                cindex = c_index(self.labels, train_dead, train_days,
                                 self.test_labels, test_dead, test_days,
                                 use_r_packages=self.use_r_packages,
                                 seed=self.seed,)
            except Exception as error:
                print('Exception while computing the c-index: {0}'.format(error))
                cindex = np.nan

        if self.verbose:
            print('c-index for test dataset:{0}'.format(cindex))

        return cindex
1493
1494
    def compute_c_indexes_for_test_fold_dataset(self):
        """
        Return the c-index obtained with ``c_index`` from the training labels
        and the labels of the test fold (``self.cv_labels``).

        Any exception raised by the computation is reported on stdout and NaN
        is returned instead. Warnings are silenced for the whole computation.
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            train_days, train_dead = np.asarray(self.dataset.survival).T
            fold_days, fold_dead = np.asarray(self.dataset.survival_cv).T

            try:
                cindex = c_index(self.labels, train_dead, train_days,
                                 self.cv_labels, fold_dead, fold_days,
                                 use_r_packages=self.use_r_packages,
                                 seed=self.seed,)
            except Exception as error:
                print('Exception while computing the c-index: {0}'.format(error))
                cindex = np.nan

            if self.verbose:
                print('c-index for test fold dataset:{0}'.format(cindex))

        return cindex
1516
1517
    def predict_nodes_activities(self, matrix_array):
        """
        Return the activities of the selected nodes for the given matrices.

        :matrix_array: mapping {key: matrix}. Keys that are absent from
                       `self.pred_node_ids_array` are ignored.

        Each matrix is transformed with `self.embedding_predict` when
        `self.alternative_embedding` is set, and with `self.encoder_predict`
        otherwise. Only the columns listed in `self.pred_node_ids_array[key]`
        are kept, and the per-key results are stacked horizontally.
        """
        selected_activities = []

        for omic in matrix_array:
            if omic in self.pred_node_ids_array:
                node_ids = self.pred_node_ids_array[omic]

                if self.alternative_embedding is None:
                    projection = self.encoder_predict(omic, matrix_array[omic])
                else:
                    projection = self.embedding_predict(omic, matrix_array[omic])

                # Transpose, pick the node rows, transpose back:
                # this keeps only the selected node columns
                selected_activities.append(projection.T[node_ids].T)

        return hstack(selected_activities)
1538
1539
    def plot_kernel_for_test_sets(self,
                                  dataset=None,
                                  labels=None,
                                  labels_proba=None,
                                  test_labels=None,
                                  test_labels_proba=None,
                                  define_as_main_kernel=False,
                                  use_main_kernel=False,
                                  activities=None,
                                  activities_test=None,
                                  key=''):
        """
        Write the kernel (KDE) plot of the test samples as an HTML file.

        The file is written to
        `<path_results>/<project_name><key>_test_kdeplot.html`.

        :dataset: dataset handed to `plot_kernel_plots`
                  (default: `self.dataset`)
        :labels: labels of the reference samples (default: `self.labels`)
        :labels_proba: accepted for interface compatibility, not used here
        :test_labels: labels of the test samples
                      (default: `self.test_labels`)
        :test_labels_proba: label probabilities of the test samples
                            (default: `self.test_labels_proba`)
        :define_as_main_kernel: if True, store a copy of the test activities
                                and test labels in `self._main_kernel`
        :use_main_kernel: if True, use the activities and labels stored in
                          `self._main_kernel` as the reference
        :activities: reference activities
        :activities_test: test activities
        :key: suffix appended to the project name in the file name

        When `activities` or `activities_test` is missing, both are taken from
        the fitted model (`self.activities_array` / `self.activities_test`).
        This is done only if the test set uses the same normalisation as the
        training set and `fill_unkown_feature_with_0` is set on the training
        dataset. Otherwise a message is printed and nothing is plotted.
        """
        from simdeep.plot_utils import plot_kernel_plots

        if dataset is None:
            dataset = self.dataset

        if labels is None:
            labels = self.labels

        if test_labels is None:
            test_labels = self.test_labels

        if test_labels_proba is None:
            test_labels_proba = self.test_labels_proba

        if activities is None or activities_test is None:
            # The activities stored in the model come from self.dataset,
            # so the normalisation check is done against self.dataset
            train_norm = self.dataset.normalization
            # Keep only the normalisation options that are switched on
            train_norm = {name: value for name, value in train_norm.items()
                          if value}

            is_same_normalization = train_norm == self.test_normalization
            is_filled_with_zero = self.dataset.fill_unkown_feature_with_0

            if not (is_same_normalization and is_filled_with_zero):
                print('\n<><><><> Cannot plot survival KDE plot' \
                      ' Different normalisation used for test set <><><><>\n')
                return

            activities = hstack([self.activities_array[omic]
                                 for omic in self.test_omic_list])
            activities_test = self.activities_test

        if define_as_main_kernel:
            self._main_kernel = {'activities': activities_test.copy(),
                                 'labels': test_labels.copy()}

        if use_main_kernel:
            activities = self._main_kernel['activities']
            labels = self._main_kernel['labels']

        html_name = '{0}/{1}{2}_test_kdeplot.html'.format(
            self.path_results,
            self.project_name,
            key)

        # Pass the resolved dataset: the caller-provided one was
        # previously ignored in favour of self.dataset
        plot_kernel_plots(
            test_labels=test_labels,
            test_labels_proba=test_labels_proba,
            labels=labels,
            activities=activities,
            activities_test=activities_test,
            dataset=dataset,
            path_html=html_name)
1610
1611
    def plot_supervised_kernel_for_test_sets(
            self,
            labels=None,
            labels_proba=None,
            dataset=None,
            key='',
            use_main_kernel=False,
            test_labels=None,
            test_labels_proba=None,
            define_as_main_kernel=False,
    ):
        """
        Plot the kernel of the test samples from supervised activities.

        The reference and test activities are computed by
        `self._predict_kde_matrix` from `labels_proba` and `dataset`, then
        everything is forwarded to `self.plot_kernel_for_test_sets` with the
        suffix '_supervised' appended to `key`.

        `labels`, `labels_proba` and `dataset` default to the corresponding
        attributes of the instance when they are None.
        """
        labels = self.labels if labels is None else labels
        labels_proba = self.labels_proba if labels_proba is None else labels_proba
        dataset = self.dataset if dataset is None else dataset

        ref_activities, test_activities = self._predict_kde_matrix(
            labels_proba, dataset)

        self.plot_kernel_for_test_sets(
            labels=labels,
            labels_proba=labels_proba,
            dataset=dataset,
            activities=ref_activities,
            activities_test=test_activities,
            key=key + '_supervised',
            use_main_kernel=use_main_kernel,
            test_labels=test_labels,
            test_labels_proba=test_labels_proba,
            define_as_main_kernel=define_as_main_kernel,
        )
1649
1650
    def _create_autoencoder_for_kernel_plot(self, labels_proba, dataset, key):
        """
        Build a supervised network on the reference matrices of `dataset`
        and store its encoders in `self.encoder_for_kde_plot_dict[key]`.

        :labels_proba: passed to `construct_supervized_network`
        :dataset: dataset providing `matrix_ref_array`
        :key: key under which the encoders are stored
        """
        supervised_model = DeepBase(
            dataset=dataset,
            seed=self.seed,
            verbose=False,
            dropout=0.1,
            epochs=50,
        )

        # Train on the reference matrices of the given dataset
        supervised_model.matrix_train_array = dataset.matrix_ref_array
        supervised_model.construct_supervized_network(labels_proba)

        self.encoder_for_kde_plot_dict[key] = supervised_model.encoder_array
1663
1664
    def _predict_kde_matrix(self, labels_proba, dataset):
1665
        """
1666
        """
1667
        matrix_ref_list = []
1668
        matrix_test_list = []
1669
1670
        encoder_key = str(self.test_normalization)
1671
        encoder_key = 'omic:{0} normalisation: {1}'.format(
1672
            self.test_omic_list,
1673
            encoder_key)
1674
1675
        if encoder_key not in self.encoder_for_kde_plot_dict or \
1676
           not dataset.fill_unkown_feature_with_0:
1677
            self._create_autoencoder_for_kernel_plot(
1678
                labels_proba, dataset, encoder_key)
1679
1680
        encoder_array = self.encoder_for_kde_plot_dict[encoder_key]
1681
1682
        if self.metadata_usage in ['all', 'new-features'] and \
1683
           dataset.metadata_mat is not None:
1684
            metadata_mat = dataset.metadata_mat
1685
        else:
1686
            metadata_mat = None
1687
1688
        for key in encoder_array:
1689
            matrix_ref = encoder_array[key].predict(
1690
                dataset.matrix_ref_array[key])
1691
            matrix_test = encoder_array[key].predict(
1692
                dataset.matrix_test_array[key])
1693
1694
            survival_node_ids = self._look_for_survival_nodes(
1695
                activities=matrix_ref, survival=dataset.survival,
1696
                metadata_mat=metadata_mat)
1697
1698
            if len(survival_node_ids) > 1:
1699
                matrix_ref = matrix_ref.T[survival_node_ids].T
1700
                matrix_test = matrix_test.T[survival_node_ids].T
1701
            else:
1702
                print('not enough survival nodes to construct kernel for key: {0}' \
1703
                      'skipping the {0} matrix'.format(key))
1704
                continue
1705
1706
            matrix_ref_list.append(matrix_ref)
1707
            matrix_test_list.append(matrix_test)
1708
1709
        if not matrix_ref_list:
1710
            print('matrix_ref_list / matrix_test_list empty!' \
1711
                  'take the last OMIC ({0}) matrix as ref'.format(key))
1712
            matrix_ref_list.append(matrix_ref)
1713
            matrix_test_list.append(matrix_test)
1714
1715
        return hstack(matrix_ref_list), hstack(matrix_test_list)
1716
1717
1718
    def _get_probas_for_full_model(self):
1719
        """
1720
        return sample and proba
1721
        """
1722
        return list(zip(self.dataset.sample_ids_full, self.full_labels_proba))
1723
1724
1725
    def _get_pvalues_and_pvalues_proba(self):
1726
        """
1727
        """
1728
        return self.full_pvalue, self.full_pvalue_proba
1729
1730
    def _get_from_dataset(self, attr):
1731
        """
1732
        """
1733
        return getattr(self.dataset, attr)
1734
1735
    def _get_attibute(self, attr):
1736
        """
1737
        """
1738
        return getattr(self, attr)
1739
1740
1741
    def _partial_fit_model_pool(self):
1742
        """
1743
        """
1744
        try:
1745
            self.load_training_dataset()
1746
            self.fit()
1747
1748
            if len(set(self.labels)) < 1:
1749
                raise Exception('only one class!')
1750
1751
            if self.train_pvalue > MODEL_THRES:
1752
                raise Exception('pvalue: {0} not significant!'.format(self.train_pvalue))
1753
1754
        except Exception as e:
1755
            print('model with random state:{1} didn\'t converge:{0}'.format(str(e), self.seed))
1756
            return False
1757
1758
        else:
1759
            print('model with random state:{0} fitted'.format(self.seed))
1760
            self._is_fitted = True
1761
1762
        self.predict_labels_on_test_fold()
1763
        self.predict_labels_on_full_dataset()
1764
        self.evalutate_cluster_performance()
1765
1766
        return self._is_fitted
1767
1768
    def _partial_fit_model_with_pretrained_pool(self, labels_file):
1769
        """
1770
        """
1771
        self.fit_on_pretrained_label_file(labels_file)
1772
1773
        self.predict_labels_on_test_fold()
1774
        self.predict_labels_on_full_dataset()
1775
        self.evalutate_cluster_performance()
1776
1777
        self._is_fitted = True
1778
1779
        return self._is_fitted
1780
1781
    def _predict_new_dataset(self,
1782
                             tsv_dict,
1783
                             path_survival_file,
1784
                             normalization,
1785
                             survival_flag=None,
1786
                             metadata_file=None):
1787
        """
1788
        """
1789
        self.load_new_test_dataset(
1790
            tsv_dict=tsv_dict,
1791
            path_survival_file=path_survival_file,
1792
            normalization=normalization,
1793
            survival_flag=survival_flag,
1794
            metadata_file=metadata_file
1795
        )
1796
1797
        self.predict_labels_on_test_dataset()