# simdeep/simdeep_boosting.py

import warnings
from simdeep.simdeep_analysis import SimDeep
from simdeep.extract_data import LoadData

from simdeep.coxph_from_r import coxph
from simdeep.coxph_from_r import c_index
from simdeep.coxph_from_r import c_index_multiple
from simdeep.coxph_from_r import NALogicalType

from sklearn.model_selection import KFold
# from sklearn.preprocessing import OneHotEncoder

from collections import Counter
from collections import defaultdict
from itertools import combinations

import numpy as np

from scipy.stats import gmean
from sklearn.metrics import adjusted_rand_score

from simdeep.config import PROJECT_NAME
from simdeep.config import PATH_RESULTS
from simdeep.config import NB_THREADS
from simdeep.config import NB_ITER
from simdeep.config import NB_FOLDS
from simdeep.config import CLASS_SELECTION
from simdeep.config import NB_CLUSTERS
from simdeep.config import NORMALIZATION
from simdeep.config import EPOCHS
from simdeep.config import NEW_DIM
from simdeep.config import NB_SELECTED_FEATURES
from simdeep.config import PVALUE_THRESHOLD
from simdeep.config import CLUSTER_METHOD
from simdeep.config import CLASSIFICATION_METHOD
from simdeep.config import TRAINING_TSV
from simdeep.config import SURVIVAL_TSV
from simdeep.config import PATH_DATA
from simdeep.config import SURVIVAL_FLAG
from simdeep.config import NODES_SELECTION
from simdeep.config import CINDEX_THRESHOLD
from simdeep.config import USE_AUTOENCODERS
from simdeep.config import FEATURE_SURV_ANALYSIS
from simdeep.config import CLUSTERING_OMICS
from simdeep.config import USE_R_PACKAGES_FOR_SURVIVAL

# Parameters for the autoencoder
from simdeep.config import LEVEL_DIMS_IN
from simdeep.config import LEVEL_DIMS_OUT
from simdeep.config import LOSS
from simdeep.config import OPTIMIZER
from simdeep.config import ACT_REG
from simdeep.config import W_REG
from simdeep.config import DROPOUT
from simdeep.config import ACTIVATION
from simdeep.config import PATH_TO_SAVE_MODEL
from simdeep.config import DATA_SPLIT
from simdeep.config import MODEL_THRES

from multiprocessing import Pool

from simdeep.deepmodel_base import DeepBase

import simplejson

from distutils.dir_util import mkpath

from os.path import isdir
from os import mkdir

from glob import glob

import gc

from time import time

from numpy import hstack
from numpy import vstack

import pandas as pd

from simdeep.survival_utils import \
    _process_parallel_feature_importance_per_cluster
from simdeep.survival_utils import \
    _process_parallel_survival_feature_importance_per_cluster


class SimDeepBoosting():
    """
    Instantiate a new DeepProg Boosting instance.
    The default parameters are defined in the config.py file

    Parameters:
            :nb_it: Number of models to construct
            :do_KM_plot: Plot Kaplan-Meier (default: True)
            :distribute: Distribute DeepProg using ray (default: False)
            :nb_threads: Number of python threads to use to compute parallel Cox-PH
            :class_selection: Consensus score to agglomerate DeepProg instances {'mean', 'max', 'weighted_mean', 'weighted_max'} (default: 'mean')
            :model_thres: Cox-PH p-value threshold to reject a model for the DeepProg Boosting module
            :verbose: Verbosity (default: True)
            :seed: Seed defining the random split of the training dataset (default: None).
            :project_name: Project name used to save files
            :use_autoencoders: Use autoencoder steps to embed the data (default: True)
            :feature_surv_analysis: Use individual survival feature detection to filter out features (default: True)
            :split_n_fold: For each instance, the original dataset is split into this number of folds and one fold is left out as the instance's internal test fold
            :path_results: Path to create a result folder
            :nb_clusters: Number of clusters to use
            :epochs: Number of epochs
            :normalization: Normalisation procedure to use. See the config.py file for details
            :nb_selected_features: Number of top features selected for classification
            :cluster_method: Clustering method. Possible choices: ['mixture', 'kmeans', 'coxPH'] or a class instance having fit and fit_proba attributes
            :pvalue_thres: Threshold for survival significance to set a node as valid
            :classification_method: Possible choices: {'ALL_FEATURES', 'SURVIVAL_FEATURES'} (default: 'ALL_FEATURES')
            :new_dim: Size of the new embedding
            :training_tsv: Input matrix files
            :survival_tsv: Input survival file
            :survival_flag: Survival flag to use
            :path_data: Path of the input files
            :level_dims_in: Autoencoder node layers before the middle layer (default: [])
            :level_dims_out: Autoencoder node layers after the middle layer (default: [])
            :loss: Loss function to minimize (default: 'binary_crossentropy')
            :optimizer: Optimizer (default: adam)
            :act_reg: L2 regularization constant on the node activity (default: False)
            :w_reg: L1 regularization constant on the weights (default: False)
            :dropout: Fraction of edges dropped out at each training iteration (None for no dropout) (default: 0.5)
            :data_split: Fraction of the dataset to be used as test set when building the autoencoder (default: None)
            :node_selection: Possible choices: {'Cox-PH', 'C-index'} (default: 'Cox-PH')
            :cindex_thres: Valid if 'C-index' is chosen (default: 0.65)
            :activation: Activation function (default: 'tanh')
            :clustering_omics: Which omics to use for clustering. If empty, all the available omics are used (default: [] => all)
            :path_to_save_model: Path to save the model
            :metadata_usage: Metadata usage with survival models (if metadata_tsv is provided as argument to the dataset). Possible choices are [None, False, 'labels', 'new-features', 'all', True] (True is the same as 'all')
            :subset_training_with_meta: Use a metadata key-value dict {meta_key: value} to subset the training sets
            :alternative_embedding: Alternative external embedding to use instead of building autoencoders (default: None)
            :kwargs_alternative_embedding: Parameters for fitting the external embedding
    """
    def __init__(self,
                 nb_it=NB_ITER,
                 do_KM_plot=True,
                 distribute=False,
                 nb_threads=NB_THREADS,
                 class_selection=CLASS_SELECTION,
                 model_thres=MODEL_THRES,
                 verbose=True,
                 seed=None,
                 project_name='{0}_boosting'.format(PROJECT_NAME),
                 use_autoencoders=USE_AUTOENCODERS,
                 feature_surv_analysis=FEATURE_SURV_ANALYSIS,
                 split_n_fold=NB_FOLDS,
                 path_results=PATH_RESULTS,
                 nb_clusters=NB_CLUSTERS,
                 epochs=EPOCHS,
                 normalization=NORMALIZATION,
                 nb_selected_features=NB_SELECTED_FEATURES,
                 cluster_method=CLUSTER_METHOD,
                 pvalue_thres=PVALUE_THRESHOLD,
                 classification_method=CLASSIFICATION_METHOD,
                 new_dim=NEW_DIM,
                 training_tsv=TRAINING_TSV,
                 metadata_usage=None,
                 survival_tsv=SURVIVAL_TSV,
                 metadata_tsv=None,
                 subset_training_with_meta={},
                 survival_flag=SURVIVAL_FLAG,
                 path_data=PATH_DATA,
                 level_dims_in=LEVEL_DIMS_IN,
                 level_dims_out=LEVEL_DIMS_OUT,
                 loss=LOSS,
                 optimizer=OPTIMIZER,
                 act_reg=ACT_REG,
                 w_reg=W_REG,
                 dropout=DROPOUT,
                 data_split=DATA_SPLIT,
                 node_selection=NODES_SELECTION,
                 cindex_thres=CINDEX_THRESHOLD,
                 activation=ACTIVATION,
                 clustering_omics=CLUSTERING_OMICS,
                 path_to_save_model=PATH_TO_SAVE_MODEL,
                 feature_selection_usage='individual',
                 use_r_packages=USE_R_PACKAGES_FOR_SURVIVAL,
                 alternative_embedding=None,
                 kwargs_alternative_embedding={},
                 **additional_dataset_args):
        """ """
        assert(class_selection in ['max', 'mean', 'weighted_mean', 'weighted_max'])
        self.class_selection = class_selection

        self._instance_weights = None
        self.distribute = distribute
        self.model_thres = model_thres
        self.models = []
        self.verbose = verbose
        self.nb_threads = nb_threads
        self.do_KM_plot = do_KM_plot
        self.project_name = project_name
        self._project_name = project_name
        self.path_results = '{0}/{1}'.format(path_results, project_name)
        self.training_tsv = training_tsv
        self.survival_tsv = survival_tsv
        self.survival_flag = survival_flag
        self.path_data = path_data
        self.dataset = None
        self.cindex_thres = cindex_thres
        self.node_selection = node_selection
        self.clustering_omics = clustering_omics
        self.metadata_tsv = metadata_tsv
        self.metadata_usage = metadata_usage
        self.feature_selection_usage = feature_selection_usage
        self.subset_training_with_meta = subset_training_with_meta
        self.use_r_packages = use_r_packages

        self.metadata_mat_full = None

        self.cluster_method = cluster_method
        self.use_autoencoders = use_autoencoders
        self.feature_surv_analysis = feature_surv_analysis

        if self.feature_selection_usage is None:
            self.feature_surv_analysis = False

        self.encoder_for_kde_plot_dict = {}
        self.kde_survival_node_ids = {}
        self.kde_train_matrices = {}

        if not isdir(self.path_results):
            try:
                mkpath(self.path_results)
            except Exception:
                print('cannot find or create the result path: {0}' \
                      '\nconsider changing it with the path_results option' \
                      .format(self.path_results))

        self.test_tsv_dict = None
        self.test_survival_file = None
        self.test_normalization = None

        self.test_labels = None
        self.test_labels_proba = None

        self.cv_labels = None
        self.cv_labels_proba = None
        self.full_labels = None
        self.full_labels_dicts = None
        self.full_labels_proba = None
        self.survival_full = None
        self.sample_ids_full = None
        self.feature_scores_per_cluster = {}
        self.survival_feature_scores_per_cluster = {}

        self._pretrained_fit = False

        self.log = {}

        self.alternative_embedding = alternative_embedding
        self.kwargs_alternative_embedding = kwargs_alternative_embedding

        ######## deepprog instance parameters ########
        self.nb_clusters = nb_clusters
        self.normalization = normalization
        self.epochs = epochs
        self.new_dim = new_dim
        self.nb_selected_features = nb_selected_features
        self.pvalue_thres = pvalue_thres
        self.cluster_method = cluster_method
        self.cindex_test_folds = []
        self.classification_method = classification_method
        ##############################################

        self.test_fname_key = ''
        self.matrix_with_cv_array = None

        autoencoder_parameters = {
            'epochs': self.epochs,
            'new_dim': self.new_dim,
            'level_dims_in': level_dims_in,
            'level_dims_out': level_dims_out,
            'loss': loss,
            'optimizer': optimizer,
            'act_reg': act_reg,
            'w_reg': w_reg,
            'dropout': dropout,
            'data_split': data_split,
            'activation': activation,
            'path_to_save_model': path_to_save_model,
        }

        self.datasets = []
        self.seed = seed

        self.log['parameters'] = {}

        for arg in self.__dict__:
            self.log['parameters'][arg] = str(self.__dict__[arg])

        self.log['seed'] = seed
        self.log['parameters'].update(autoencoder_parameters)

        self.log['nb_it'] = nb_it
        self.log['normalization'] = normalization
        self.log['nb clusters'] = nb_clusters
        self.log['success'] = False
        self.log['survival_tsv'] = self.survival_tsv
        self.log['metadata_tsv'] = self.metadata_tsv
        self.log['subset_training_with_meta'] = self.subset_training_with_meta
        self.log['training_tsv'] = self.training_tsv
        self.log['path_data'] = self.path_data

        additional_dataset_args['survival_tsv'] = self.survival_tsv
        additional_dataset_args['metadata_tsv'] = self.metadata_tsv
        additional_dataset_args['subset_training_with_meta'] = self.subset_training_with_meta
        additional_dataset_args['training_tsv'] = self.training_tsv
        additional_dataset_args['path_data'] = self.path_data
        additional_dataset_args['survival_flag'] = self.survival_flag

        if 'fill_unkown_feature_with_0' in additional_dataset_args:
            self.log['fill_unkown_feature_with_0'] = additional_dataset_args[
                'fill_unkown_feature_with_0']

        self.ray = None

        self._init_datasets(nb_it, split_n_fold,
                            autoencoder_parameters,
                            **additional_dataset_args)

    def _init_datasets(self, nb_it, split_n_fold,
                       autoencoder_parameters,
                       **additional_dataset_args):
        """
        Instantiate one LoadData dataset per model instance, each with its
        own KFold split and seed derived from self.seed.
        """
        if self.seed:
            np.random.seed(self.seed)
        else:
            self.seed = np.random.randint(0, 10000000)

        max_seed = 1000
        min_seed = 0

        if self.seed > max_seed:
            min_seed = self.seed - max_seed
            max_seed = self.seed

        np.random.seed(self.seed)
        random_states = np.random.randint(min_seed, max_seed, nb_it)

        self.split_n_fold = split_n_fold

        for it in range(nb_it):
            if self.split_n_fold:
                split = KFold(n_splits=split_n_fold,
                              shuffle=True, random_state=random_states[it])
            else:
                split = None

            autoencoder_parameters['seed'] = random_states[it]

            dataset = LoadData(cross_validation_instance=split,
                               verbose=False,
                               normalization=self.normalization,
                               _autoencoder_parameters=autoencoder_parameters.copy(),
                               **additional_dataset_args)

            self.datasets.append(dataset)

    def __del__(self):
        """
        """
        for model in self.models:
            del model

        try:
            gc.collect()
        except Exception as e:
            print('Warning: exception {0} from garbage collector. Continuing...'.format(
                e))

    def _from_models(self, fname, *args, **kwargs):
        """
        Call the method `fname` on every model (remotely when distributed).
        """
        if self.distribute:
            return self.ray.get([getattr(model, fname).remote(*args, **kwargs)
                                 for model in self.models])
        else:
            return [getattr(model, fname)(*args, **kwargs)
                    for model in self.models]

    def _from_model(self, model, fname, *args, **kwargs):
        """
        Call the method `fname` on a single model (remotely when distributed).
        """
        if self.distribute:
            return self.ray.get(getattr(model, fname).remote(
                *args, **kwargs))
        else:
            return getattr(model, fname)(*args, **kwargs)

    def _from_model_attr(self, model, atname):
        """
        Fetch the attribute `atname` from a single model.
        """
        if self.distribute:
            return self.ray.get(model._get_attibute.remote(atname))
        else:
            return model._get_attibute(atname)

    def _from_models_attr(self, atname):
        """
        Fetch the attribute `atname` from every model.
        """
        if self.distribute:
            return self.ray.get([model._get_attibute.remote(atname)
                                 for model in self.models])
        else:
            return [model._get_attibute(atname) for model in self.models]

    def _from_model_dataset(self, model, atname):
        """
        Fetch the attribute `atname` from the dataset of a single model.
        """
        if self.distribute:
            return self.ray.get(model._get_from_dataset.remote(atname))
        else:
            return model._get_from_dataset(atname)

    def _do_class_selection(self, inputs, **kwargs):
        """
        Aggregate the per-model label probabilities using the configured
        consensus scheme.
        """
        if self.class_selection == 'max':
            return _highest_proba(inputs)
        elif self.class_selection == 'mean':
            return _mean_proba(inputs)
        elif self.class_selection == 'weighted_mean':
            return _weighted_mean(inputs, **kwargs)
        elif self.class_selection == 'weighted_max':
            return _weighted_max(inputs, **kwargs)

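    # Consensus sketch (assumption: the _mean_proba/_highest_proba helpers are
    # module-level functions defined elsewhere in this file, and `inputs` is
    # shaped (n_models, n_samples, n_classes)). For 'mean', the idea is to
    # average each sample's class probabilities across models and pick the
    # class with the highest mean, roughly:
    #
    #     probas = np.asarray(inputs).mean(axis=0)  # (n_samples, n_classes)
    #     labels = probas.argmax(axis=1)
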
    def partial_fit(self, debug=False):
        """
        """
        self._fit(debug=debug)

    def fit_on_pretrained_label_file(
            self,
            labels_files=[],
            labels_files_folder="",
            file_name_regex="*.tsv",
            verbose=False,
            debug=False,
    ):
        """
        Fit the DeepProg SimDeep models without training autoencoders, using
        instead ID->label files (one for each model instance)
        """
        assert(isinstance(labels_files, list))

        if not labels_files and not labels_files_folder:
            raise Exception(
                '## Error with fit_on_pretrained_label_file: ' \
                'either labels_files or labels_files_folder should be non empty')

        if not labels_files:
            labels_files = glob('{0}/{1}'.format(labels_files_folder,
                                                 file_name_regex))

        if not labels_files:
            raise Exception('## Error: labels_files empty')

        self.fit(
            verbose=verbose,
            debug=debug,
            pretrained_labels_files=labels_files)

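    # Usage sketch (the folder and file names are hypothetical): each TSV is
    # expected to map sample IDs to cluster labels, one file per model
    # instance, e.g. produced by a previous run:
    #
    #     boosting.fit_on_pretrained_label_file(
    #         labels_files_folder='saved_labels',
    #         file_name_regex='*.tsv')
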
    def fit(self, debug=False, verbose=False, pretrained_labels_files=[]):
        """
        If pretrained_labels_files is given, the models are constructed using
        these labels
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if pretrained_labels_files:
                self._pretrained_fit = True
            else:
                self._pretrained_fit = False

            if self.distribute:
                self._fit_distributed(
                    pretrained_labels_files=pretrained_labels_files)
            else:
                self._fit(
                    debug=debug,
                    verbose=verbose,
                    pretrained_labels_files=pretrained_labels_files)

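    # Distributed-mode sketch (assumes the optional ray dependency is
    # installed; ray must be initialised by the caller before fit(), as
    # _fit_distributed asserts below):
    #
    #     import ray
    #     ray.init(num_cpus=4)
    #     boosting = SimDeepBoosting(distribute=True, ...)
    #     boosting.fit()
    #     ray.shutdown()
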
    def _fit_distributed(self, pretrained_labels_files=[]):
        """ """
        print('fit models...')
        start_time = time()

        from simdeep.simdeep_distributed import SimDeepDistributed
        import ray
        assert(ray.is_initialized())
        self.ray = ray

        try:
            self.models = [SimDeepDistributed.remote(
                nb_clusters=self.nb_clusters,
                nb_selected_features=self.nb_selected_features,
                pvalue_thres=self.pvalue_thres,
                dataset=dataset,
                load_existing_models=False,
                verbose=dataset.verbose,
                _isboosting=True,
                do_KM_plot=False,
                cluster_method=self.cluster_method,
                clustering_omics=self.clustering_omics,
                use_autoencoders=self.use_autoencoders,
                use_r_packages=self.use_r_packages,
                feature_surv_analysis=self.feature_surv_analysis,
                path_results=self.path_results,
                project_name=self.project_name,
                classification_method=self.classification_method,
                cindex_thres=self.cindex_thres,
                alternative_embedding=self.alternative_embedding,
                kwargs_alternative_embedding=self.kwargs_alternative_embedding,
                node_selection=self.node_selection,
                metadata_usage=self.metadata_usage,
                feature_selection_usage=self.feature_selection_usage,
                deep_model_additional_args=dataset._autoencoder_parameters)
                           for dataset in self.datasets]

            if pretrained_labels_files:
                nb_files = len(pretrained_labels_files)
                if nb_files < len(self.models):
                    print(
                        'Number of pretrained label files ({0}) is smaller' \
                        ' than the number of instances'.format(
                            nb_files))
                    self.models = self.models[:nb_files]

                results = ray.get([
                    model._partial_fit_model_with_pretrained_pool.remote(
                        labels)
                    for model, labels in zip(self.models,
                                             pretrained_labels_files)])
            else:
                results = ray.get([
                    model._partial_fit_model_pool.remote()
                    for model in self.models])

            print("Results: {0}".format(results))
            self.models = [model for model, is_fitted in zip(self.models, results) if is_fitted]

            nb_models = len(self.models)

            print('{0} models fitted'.format(nb_models))
            self.log['nb. models fitted'] = nb_models

            assert(nb_models)

        except Exception as e:
            self.log['failure'] = str(e)
            raise e

        else:
            self.log['success'] = True
            self.log['fitting time (s)'] = time() - start_time

            if self.class_selection in ['weighted_mean', 'weighted_max']:
                self.collect_cindex_for_test_fold()

    def _fit(self, debug=False, verbose=False, pretrained_labels_files=[]):
        """
        If pretrained_labels_files is given, the models are constructed using
        these labels
        """
        print('fit models...')
        start_time = time()

        try:
            self.models = [SimDeep(
                nb_clusters=self.nb_clusters,
                nb_selected_features=self.nb_selected_features,
                pvalue_thres=self.pvalue_thres,
                dataset=dataset,
                load_existing_models=False,
                verbose=dataset.verbose,
                _isboosting=True,
                do_KM_plot=False,
                cluster_method=self.cluster_method,
                use_autoencoders=self.use_autoencoders,
                feature_surv_analysis=self.feature_surv_analysis,
                path_results=self.path_results,
                project_name=self.project_name,
                cindex_thres=self.cindex_thres,
                node_selection=self.node_selection,
                metadata_usage=self.metadata_usage,
                use_r_packages=self.use_r_packages,
                feature_selection_usage=self.feature_selection_usage,
                alternative_embedding=self.alternative_embedding,
                kwargs_alternative_embedding=self.kwargs_alternative_embedding,
                classification_method=self.classification_method,
                deep_model_additional_args=dataset._autoencoder_parameters)
                           for dataset in self.datasets]

            if pretrained_labels_files:
                nb_files = len(pretrained_labels_files)
                if nb_files < len(self.models):
                    print(
                        'Number of pretrained label files ({0}) is smaller' \
                        ' than the number of instances'.format(
                            nb_files))
                    self.models = self.models[:nb_files]

                results = [
                    model._partial_fit_model_with_pretrained_pool(labels)
                    for model, labels in zip(self.models, pretrained_labels_files)]
            else:
                results = [model._partial_fit_model_pool() for model in self.models]

            print("Results: {0}".format(results))
            self.models = [model for model, is_fitted in zip(self.models, results) if is_fitted]

            nb_models = len(self.models)

            print('{0} models fitted'.format(nb_models))
            self.log['nb. models fitted'] = nb_models

            assert(nb_models)

        except Exception as e:
            self.log['failure'] = str(e)
            raise e

        else:
            self.log['success'] = True
            self.log['fitting time (s)'] = time() - start_time

            if self.class_selection in ['weighted_mean', 'weighted_max']:
                self.collect_cindex_for_test_fold()

    def predict_labels_on_test_dataset(self):
        """
        """
        print('predict labels on test datasets...')
        test_labels_proba = np.asarray(self._from_models_attr(
            'test_labels_proba'))

        res = self._do_class_selection(
            test_labels_proba,
            weights=self.cindex_test_folds)
        self.test_labels, self.test_labels_proba = res

        print('#### Report of assigned cluster for TEST dataset {0}:'.format(
            self.test_fname_key))
        for key, value in sorted(Counter(self.test_labels).items()):
            print('class: {0}, number of samples: {1}'.format(key, value))

        nbdays, isdead = self._from_model_dataset(self.models[0], "survival_test").T.tolist()

        if np.isnan(nbdays).all():
            return np.nan, np.nan

        pvalue, pvalue_proba, pvalue_cat = self._compute_test_coxph(
            'KM_plot_boosting_test',
            nbdays, isdead,
            self.test_labels, self.test_labels_proba,
            self.project_name)

        self.log['pvalue test {0}'.format(self.test_fname_key)] = pvalue
        self.log['pvalue proba test {0}'.format(self.test_fname_key)] = pvalue_proba
        self.log['pvalue cat test {0}'.format(self.test_fname_key)] = pvalue_cat

        sample_id_test = self._from_model_dataset(self.models[0], 'sample_ids_test')

        self._from_model(self.models[0], '_write_labels',
                         sample_id_test,
                         self.test_labels,
                         '{0}_test_labels'.format(self.project_name),
                         labels_proba=self.test_labels_proba.T[0],
                         nbdays=nbdays, isdead=isdead)

        return pvalue, pvalue_proba

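    # Evaluation-on-new-cohort sketch (hypothetical file names): a new test
    # dataset is loaded first, then labels and p-values are computed for it:
    #
    #     boosting.load_new_test_dataset(
    #         {'RNA': 'rna_test.tsv'},
    #         path_survival_file='survival_test.tsv',
    #         fname_key='test_cohort')
    #     boosting.predict_labels_on_test_dataset()
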
    def compute_pvalue_for_merged_test_fold(self):
        """
        """
        print('predict labels on test fold datasets...')

        isdead_cv, nbdays_cv, labels_cv = [], [], []

        if self.metadata_usage in ['all', 'labels'] and \
           self.metadata_tsv:
            metadata_mat = []
        else:
            metadata_mat = None

        for model in self.models:
            survival_cv = self._from_model_dataset(model, 'survival_cv')

            if survival_cv is None:
                print('No survival dataset for CV fold, returning')
                return

            nbdays, isdead = survival_cv.T.tolist()
            nbdays_cv += nbdays
            isdead_cv += isdead
            labels_cv += self._from_model_attr(model, "cv_labels").tolist()

            if metadata_mat is not None:
                meta2 = self._from_model_dataset(model, 'metadata_mat_cv')

                if not len(metadata_mat):
                    metadata_mat = meta2
                else:
                    metadata_mat = pd.concat([metadata_mat, meta2])

                metadata_mat = metadata_mat.fillna(0)

        pvalue = coxph(
            labels_cv, isdead_cv, nbdays_cv,
            isfactor=False,
            do_KM_plot=self.do_KM_plot,
            png_path=self.path_results,
            fig_name='cv_analysis', seed=self.seed,
            use_r_packages=self.use_r_packages,
            metadata_mat=metadata_mat
        )

        print('Pvalue for test fold concatenated: {0}'.format(pvalue))
        self.log['pvalue cv test'] = pvalue

        return pvalue

    def collect_pvalue_on_test_fold(self):
        """
        """
        print('predict labels on test fold datasets...')
        pvalues, pvalues_proba = [], []
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for model in self.models:
                pvalues.append(self._from_model_attr(model, 'cp_pvalue'))
                pvalues_proba.append(self._from_model_attr(model, 'cp_pvalue_proba'))

            pvalue_gmean, pvalue_proba_gmean = gmean(pvalues), gmean(pvalues_proba)

            if self.verbose:
                print('geo mean pvalues: {0} geo mean pvalues probas: {1}'.format(
                    pvalue_gmean, pvalue_proba_gmean))

            self.log['pvalue geo mean test fold'] = pvalue_gmean
            self.log['pvalue proba geo mean test fold'] = pvalue_proba_gmean

        return pvalues, pvalues_proba

    def collect_pvalue_on_training_dataset(self):
        """
        """
        print('predict labels on training datasets...')
        pvalues, pvalues_proba = [], []
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            for model in self.models:
                pvalues.append(self._from_model_attr(model, 'train_pvalue'))
                pvalues_proba.append(self._from_model_attr(model, 'train_pvalue_proba'))

            pvalue_gmean, pvalue_proba_gmean = gmean(pvalues), gmean(pvalues_proba)

            if self.verbose:
                print('training geo mean pvalues: {0} geo mean pvalues probas: {1}'.format(
                    pvalue_gmean, pvalue_proba_gmean))

            self.log['pvalue geo mean train'] = pvalue_gmean
            self.log['pvalue proba geo mean train'] = pvalue_proba_gmean

        return pvalues, pvalues_proba

    def collect_pvalue_on_test_dataset(self):
        """
        """
        print('collect pvalues on test datasets...')

        pvalues, pvalues_proba = [], []

        for model in self.models:
            pvalues.append(self._from_model_attr(model, 'test_pvalue'))
            pvalues_proba.append(self._from_model_attr(model, 'test_pvalue_proba'))

        pvalue_gmean, pvalue_proba_gmean = gmean(pvalues), gmean(pvalues_proba)

        if self.verbose:
            print('test geo mean pvalues: {0} geo mean pvalues probas: {1}'.format(
                pvalue_gmean, pvalue_proba_gmean))

        self.log['pvalue geo mean test {0}'.format(self.test_fname_key)] = pvalue_gmean
        self.log['pvalue proba geo mean test {0}'.format(
            self.test_fname_key)] = pvalue_proba_gmean

        return pvalues, pvalues_proba

    def collect_pvalue_on_full_dataset(self):
        """
        """
        print('collect pvalues on full datasets...')

        pvalues, pvalues_proba = zip(*self._from_models('_get_pvalues_and_pvalues_proba'))
        pvalue_gmean, pvalue_proba_gmean = gmean(pvalues), gmean(pvalues_proba)

        if self.verbose:
            print('full geo mean pvalues: {0} geo mean pvalues probas: {1}'.format(
                pvalue_gmean, pvalue_proba_gmean))

        self.log['pvalue geo mean full'] = pvalue_gmean
        self.log['pvalue proba geo mean full'] = pvalue_proba_gmean

        return pvalues, pvalues_proba

    def collect_number_of_features_per_omic(self):
        """
        """
        counter = defaultdict(list)
        self.log['number of features per omics'] = {}

        for model in self.models:
            valid_node_ids_array = self._from_model_attr(model, 'valid_node_ids_array')
            for key in valid_node_ids_array:
                counter[key].append(len(valid_node_ids_array[key]))

        if self.verbose:
            for key in counter:
                print('key: {0} mean: {1} std: {2}'.format(
                    key, np.mean(counter[key]), np.std(counter[key])))

                self.log['number of features per omics'][key] = float(np.mean(counter[key]))

        return counter

    def collect_cindex_for_test_fold(self):
        """
        """
        self.cindex_test_folds = []

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self._from_models('predict_labels_on_test_fold')
        try:
            cindexes = self._from_models('compute_c_indexes_for_test_fold_dataset')
        except Exception as e:
            print('Exception while computing the c-index for test fold: {0}'.format(e))
            return np.nan

        for cindex in cindexes:
            if np.isnan(cindex) or isinstance(cindex, NALogicalType):
                cindex = np.nan

            self.cindex_test_folds.append(cindex)

        mean, std = np.nanmean(self.cindex_test_folds), np.nanstd(self.cindex_test_folds)

        if self.verbose:
            print('C-index results for test fold: mean {0} std {1}'.format(mean, std))

        self.log['c-indexes test fold (mean)'] = mean

        return self.cindex_test_folds

    def collect_cindex_for_full_dataset(self):
        """
        """
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self._from_models('predict_labels_on_test_fold')
        try:
            cindexes_list = self._from_models('compute_c_indexes_for_full_dataset')
        except Exception as e:
            print('Exception while computing the c-index for full dataset: {0}'.format(e))
            return np.nan

        if self.verbose:
            print('c-index results for full dataset: mean {0} std {1}'.format(
                np.mean(cindexes_list), np.std(cindexes_list)))

        self.log['c-indexes full (mean)'] = np.mean(cindexes_list)

        return cindexes_list

    def collect_cindex_for_training_dataset(self):
        """
        """
        try:
            cindexes_list = self._from_models('compute_c_indexes_for_training_dataset')
        except Exception as e:
            print('Exception while computing the c-index for training dataset: {0}'.format(e))
            self.log['c-indexes train (mean)'] = np.nan
            return np.nan

        if self.verbose:
            print('C-index results for training dataset: mean {0} std {1}'.format(
                np.mean(cindexes_list), np.std(cindexes_list)))

        self.log['c-indexes train (mean)'] = np.mean(cindexes_list)

        return cindexes_list

    def collect_cindex_for_test_dataset(self):
        """
        """
        try:
            cindexes_list = self._from_models('compute_c_indexes_for_test_dataset')
        except Exception as e:
            print('Exception while computing the c-index for test dataset: {0}'.format(e))
            self.log['C-index test {0}'.format(self.test_fname_key)] = np.nan
            return np.nan

        if self.verbose:
            print('C-index results for test: mean {0} std {1}'.format(
                np.mean(cindexes_list), np.std(cindexes_list)))

        self.log['C-index test {0}'.format(self.test_fname_key)] = np.mean(cindexes_list)

        return cindexes_list

    def predict_labels_on_full_dataset(self):
        """
        """
        print('predict labels on full datasets...')
        self._get_probas_for_full_models()
        self._reorder_survival_full_and_metadata()

        print('#### Report of assigned cluster for the full training dataset:')
        for key, value in sorted(Counter(self.full_labels).items()):
            print('class: {0}, number of samples: {1}'.format(key, value))

        nbdays, isdead = self.survival_full.T.tolist()

        pvalue, pvalue_proba, pvalue_cat = self._compute_test_coxph(
            'KM_plot_boosting_full',
            nbdays, isdead,
            self.full_labels, self.full_labels_proba,
            self._project_name)

        self.log['pvalue full'] = pvalue
        self.log['pvalue proba full'] = pvalue_proba
        self.log['pvalue cat full'] = pvalue_cat

        self._from_model(self.models[0], '_write_labels',
                         self.sample_ids_full,
                         self.full_labels,
                         '{0}_full_labels'.format(self._project_name),
                         labels_proba=self.full_labels_proba.T[0],
                         nbdays=nbdays, isdead=isdead)

        return pvalue, pvalue_proba

    def compute_clusters_consistency_for_full_labels(self):
        """
        """
        scores = []

        for model_1, model_2 in combinations(self.models, 2):
            full_labels_1_old = self._from_model_attr(model_1, 'full_labels')
            full_labels_2_old = self._from_model_attr(model_2, 'full_labels')

            full_ids_1 = self._from_model_dataset(model_1, 'sample_ids_full')
            full_ids_2 = self._from_model_dataset(model_2, 'sample_ids_full')

            full_labels_1 = _reorder_labels(full_labels_1_old, full_ids_1)
            full_labels_2 = _reorder_labels(full_labels_2_old, full_ids_2)

            scores.append(adjusted_rand_score(full_labels_1,
                                              full_labels_2))
        print('Adj. Rand scores for full label: mean: {0} std: {1}'.format(
            np.mean(scores), np.std(scores)))

        self.log['Adj. Rand scores'] = np.mean(scores)

        return scores

    def compute_clusters_consistency_for_test_labels(self):
        """
        """
        scores = []

        for model_1, model_2 in combinations(self.models, 2):
            scores.append(adjusted_rand_score(
                self._from_model_attr(model_1, 'test_labels'),
                self._from_model_attr(model_2, 'test_labels'),
            ))
        print('Adj. Rand scores for test label: mean: {0} std: {1}'.format(
            np.mean(scores), np.std(scores)))

        self.log['Adj. Rand scores test {0}'.format(self.test_fname_key)] = np.mean(scores)

        return scores

    def _reorder_survival_full_and_metadata(self):
        """
        """
        survival_old = self._from_model_dataset(self.models[0], 'survival_full')
        sample_ids = self._from_model_dataset(self.models[0], 'sample_ids_full')

        surv_dict = {sample: surv for sample, surv in zip(sample_ids, survival_old)}

        self.survival_full = np.asarray([np.asarray(surv_dict[sample])[0]
                                         for sample in self.sample_ids_full])

        metadata = self._from_model_dataset(self.models[0], 'metadata_mat_full')

        if metadata is not None:
            index_dict = {sample: pos for pos, sample in enumerate(sample_ids)}
            index = np.asarray([index_dict[sample] for sample in self.sample_ids_full])
            self.metadata_mat_full = metadata.T[index].T

    def _reorder_matrix_full(self):
        """
        """
        sample_ids = self._from_model_dataset(self.models[0], 'sample_ids_full')
        index_dict = {sample: ids for ids, sample in enumerate(sample_ids)}
        index = [index_dict[sample] for sample in self.sample_ids_full]

        self.matrix_with_cv_array = self._from_model_dataset(
            self.models[0], 'matrix_array').copy()
        matrix_cv_unormalized_array = self._from_model_dataset(
            self.models[0],
            'matrix_cv_unormalized_array')

        for key in self.matrix_with_cv_array:
            if len(matrix_cv_unormalized_array):
                self.matrix_with_cv_array[key] = vstack(
                    [self.matrix_with_cv_array[key],
                     matrix_cv_unormalized_array[key]])

            self.matrix_with_cv_array[key] = self.matrix_with_cv_array[key][index]

    def _get_probas_for_full_models(self):
        """
        """
        proba_dict = defaultdict(list)

        for sample_proba in self._from_models('_get_probas_for_full_model'):
            sample_set = set()

            for sample, proba in sample_proba:
                if sample in sample_set:
                    continue

                proba_dict[sample].append([np.nan_to_num(proba).tolist()])
                sample_set.add(sample)

        labels, probas = self._do_class_selection(hstack(list(proba_dict.values())),
                                                  weights=self.cindex_test_folds)

        self.full_labels = np.asarray(labels)
        self.full_labels_proba = probas

        self.sample_ids_full = list(proba_dict.keys())

    def _compute_test_coxph(self, fname_base, nbdays,
                            isdead, labels, labels_proba,
                            project_name, metadata_mat=None):
        """ """
        pvalue = coxph(
            labels, isdead, nbdays,
            isfactor=False,
            do_KM_plot=self.do_KM_plot,
            png_path=self.path_results,
            fig_name='{0}_{1}'.format(project_name, fname_base),
            use_r_packages=self.use_r_packages,
            metadata_mat=metadata_mat,
            seed=self.seed)

        if self.verbose:
            print('Cox-PH p-value (Log-Rank) for inferred labels: {0}'.format(pvalue))

        pvalue_proba = coxph(
            labels_proba.T[0],
            isdead, nbdays,
            isfactor=False,
            use_r_packages=self.use_r_packages,
            metadata_mat=metadata_mat,
            seed=self.seed)

        if self.verbose:
            print('Cox-PH proba p-value (Log-Rank) for inferred labels: {0}'.format(pvalue_proba))

        labels_categorical = self._labels_proba_to_labels(labels_proba)

        pvalue_cat = coxph(
            labels_categorical, isdead, nbdays,
            isfactor=False,
            do_KM_plot=self.do_KM_plot,
            png_path=self.path_results,
            use_r_packages=self.use_r_packages,
            fig_name='{0}_proba_{1}'.format(project_name, fname_base),
            metadata_mat=metadata_mat,
            seed=self.seed)

        if self.verbose:
            print('Cox-PH categorical p-value (Log-Rank) for inferred labels: {0}'.format(
                pvalue_cat))

        return pvalue, pvalue_proba, pvalue_cat

    def _labels_proba_to_labels(self, labels_proba):
        """
        """
        probas = labels_proba.T[0]
        labels = np.zeros(len(probas))
        nb_clusters = labels_proba.shape[1]

        for cluster in range(nb_clusters):
            percentile = 100 * (1.0 - 1.0 / (cluster + 1.0))
            value = np.percentile(probas, percentile)
            labels[probas >= value] = nb_clusters - cluster

        return labels

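    # Worked example of the percentile binning above, assuming 2 clusters:
    # the loop uses percentiles 0% and 50%, so every sample is first labelled
    # 2 (proba >= min), then samples at or above the median probability of
    # cluster 0 are relabelled 1. The result is a median split into labels
    # {1, 2}.
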
    def compute_c_indexes_for_test_dataset(self):
        """
        Return the c-index using labels as predictor
        """
        days_full, dead_full = np.asarray(self.survival_full).T
        days_test, dead_test = self._from_model_dataset(self.models[0], 'survival_test').T

        if np.isnan(days_test).all():
            print("Cannot compute C-index for test dataset. Need test survival file")
            return

        labels_test_categorical = self._labels_proba_to_labels(self.test_labels_proba)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            if isinstance(days_test, np.matrix):
                days_test = np.asarray(days_test)[0]
                dead_test = np.asarray(dead_test)[0]

        cindex = c_index(self.full_labels, dead_full, days_full,
                         self.test_labels, dead_test, days_test,
                         use_r_packages=self.use_r_packages,
                         seed=self.seed)

        cindex_cat = c_index(self.full_labels, dead_full, days_full,
                             labels_test_categorical, dead_test, days_test,
                             use_r_packages=self.use_r_packages,
                             seed=self.seed)

        cindex_proba = c_index(self.full_labels_proba.T[0], dead_full, days_full,
                               self.test_labels_proba.T[0], dead_test, days_test,
                               use_r_packages=self.use_r_packages,
                               seed=self.seed)

        if self.verbose:
            print('c-index for boosting test dataset: {0}'.format(cindex))
            print('c-index proba for boosting test dataset: {0}'.format(cindex_proba))
            print('c-index cat for boosting test dataset: {0}'.format(cindex_cat))

        self.log['c-index test boosting {0}'.format(self.test_fname_key)] = cindex
        self.log['c-index proba test boosting {0}'.format(self.test_fname_key)] = cindex_proba
        self.log['c-index cat test boosting {0}'.format(self.test_fname_key)] = cindex_cat

        return cindex

    def compute_c_indexes_for_full_dataset(self):
        """
        Return the c-index using labels as predictor
        """
        days_full, dead_full = np.asarray(self.survival_full).T
        labels_categorical = self._labels_proba_to_labels(self.full_labels_proba)

        cindex = c_index(self.full_labels, dead_full, days_full,
                         self.full_labels, dead_full, days_full,
                         use_r_packages=self.use_r_packages,
                         seed=self.seed)

        cindex_cat = c_index(labels_categorical, dead_full, days_full,
                             labels_categorical, dead_full, days_full,
                             use_r_packages=self.use_r_packages,
                             seed=self.seed)

        cindex_proba = c_index(self.full_labels_proba.T[0], dead_full, days_full,
                               self.full_labels_proba.T[0], dead_full, days_full,
                               use_r_packages=self.use_r_packages,
                               seed=self.seed)

        if self.verbose:
            print('c-index for boosting full dataset: {0}'.format(cindex))
            print('c-index proba for boosting full dataset: {0}'.format(cindex_proba))
            print('c-index cat for boosting full dataset: {0}'.format(cindex_cat))

        self.log['c-index full boosting {0}'.format(self.test_fname_key)] = cindex
        self.log['c-index proba full boosting {0}'.format(self.test_fname_key)] = cindex_proba
        self.log['c-index cat full boosting {0}'.format(self.test_fname_key)] = cindex_cat

        return cindex

    def compute_c_indexes_multiple_for_test_dataset(self):
        """
        Not functional!
        """
        print('not functional!')
        return

        matrix_array_train = self._from_model_dataset(self.models[0], 'matrix_ref_array')
        matrix_array_test = self._from_model_dataset(self.models[0], 'matrix_test_array')

        nbdays, isdead = self._from_model_dataset(self.models[0],
                                                  'survival').T.tolist()
        nbdays_test, isdead_test = self._from_model_dataset(self.models[0],
                                                            'survival_test').T.tolist()

        activities_train, activities_test = [], []

        for model in self.models:
            activities_train.append(model.predict_nodes_activities(matrix_array_train))
            activities_test.append(model.predict_nodes_activities(matrix_array_test))

        activities_train = hstack(activities_train)
        activities_test = hstack(activities_test)

        cindex = c_index_multiple(
            activities_train, isdead, nbdays,
            activities_test, isdead_test, nbdays_test, seed=self.seed)

        print('total number of survival features: {0}'.format(activities_train.shape[1]))
        print('cindex multiple for test set: {0}'.format(cindex))

        self.log['c-index multiple test {0}'.format(self.test_fname_key)] = cindex
        self.log['Number of survival features {0}'.format(
            self.test_fname_key)] = activities_train.shape[1]

        return cindex

    def plot_supervised_predicted_labels_for_test_sets(
            self,
            define_as_main_kernel=False,
            use_main_kernel=False):
        """
        """
        print('#### plotting supervised labels....')

        self._from_model(self.models[0], "plot_supervised_kernel_for_test_sets",
                         define_as_main_kernel=define_as_main_kernel,
                         use_main_kernel=use_main_kernel,
                         test_labels_proba=self.test_labels_proba,
                         test_labels=self.test_labels,
                         key='_' + self.test_fname_key)

    def plot_supervised_kernel_for_test_sets(self):
        """
        """
        from simdeep.plot_utils import plot_kernel_plots

        if self.verbose:
            print('plotting survival features using autoencoder...')

        encoder_key = self._create_autoencoder_for_kernel_plot()
        activities, activities_test = self._predict_kde_matrices(
            encoder_key, self.dataset.matrix_test_array)

        html_name = '{0}/{1}_{2}_supervised_kdeplot.html'.format(
            self.path_results,
            self.project_name,
            self.test_fname_key)

        plot_kernel_plots(
            test_labels=self.test_labels,
            test_labels_proba=self.test_labels_proba,
            labels=self.full_labels,
            activities=activities,
            activities_test=activities_test,
            dataset=self.dataset,
            path_html=html_name)

    def _predict_kde_survival_nodes_for_train_matrices(self, encoder_key):
        """
        """
        self.kde_survival_node_ids = {}
        encoder_array = self.encoder_for_kde_plot_dict[encoder_key]

        for key in encoder_array:
            encoder = encoder_array[key]
            matrix_ref = encoder.predict(self.dataset.matrix_ref_array[key])

            survival_node_ids = self._from_model(
                self.models[0], '_look_for_survival_nodes',
                activities=matrix_ref, survival=self.dataset.survival)

            self.kde_survival_node_ids[key] = survival_node_ids
            self.kde_train_matrices[key] = matrix_ref

    def _predict_kde_matrices(self, encoder_key,
                              matrix_test_array):
        """
        """
        matrix_test_list = []
        matrix_ref_list = []

        encoder_array = self.encoder_for_kde_plot_dict[encoder_key]

        for key in matrix_test_array:
            encoder = encoder_array[key]
            matrix_test = encoder.predict(matrix_test_array[key])
            matrix_ref = self.kde_train_matrices[key]

            survival_node_ids = self.kde_survival_node_ids[key]

            if len(survival_node_ids) > 1:
                matrix_test = matrix_test.T[survival_node_ids].T
                matrix_ref = matrix_ref.T[survival_node_ids].T
            else:
                if self.verbose:
                    print('not enough survival nodes to construct kernel for key: {0}, ' \
                          'skipping the {0} matrix'.format(key))
                continue

            matrix_ref_list.append(matrix_ref)
            matrix_test_list.append(matrix_test)

        if not matrix_ref_list:
            if self.verbose:
                print('\n<!><!><!><!><!><!><!><!><!><!><!><!><!><!><!><!><!>\n' \
                      ' matrix_ref_list / matrix_test_list empty! ' \
                      'take the last OMIC ({0}) matrix as ref \n' \
                      '<!><!><!><!><!><!><!><!><!><!><!><!><!><!><!><!><!><!>\n'.format(key))
            matrix_ref_list.append(matrix_ref)
            matrix_test_list.append(matrix_test)

        return hstack(matrix_ref_list), hstack(matrix_test_list)

    def _create_autoencoder_for_kernel_plot(self):
1331
        """
1332
        """
1333
        key_normalization = {
1334
            key: self.test_normalization[key]
1335
            for key in self.test_normalization
1336
            if self.test_normalization[key]
1337
        }
1338
1339
        encoder_key = str(key_normalization)
1340
        encoder_key = 'omic:{0} normalisation: {1}'.format(
1341
            list(self.test_tsv_dict.keys()),
1342
            encoder_key)
1343
1344
        if encoder_key in self.encoder_for_kde_plot_dict:
1345
            if self.verbose:
1346
                print('loading test data for plotting...')
1347
1348
            self.dataset.load_new_test_dataset(
1349
                tsv_dict=self.test_tsv_dict,
1350
                path_survival_file=self.test_survival_file,
1351
                normalization=self.test_normalization)
1352
1353
            return encoder_key
1354
1355
        self.dataset = LoadData(
1356
            cross_validation_instance=None,
1357
            training_tsv=self.training_tsv,
1358
            survival_tsv=self.survival_tsv,
1359
            metadata_tsv=self.metadata_tsv,
1360
            survival_flag=self.survival_flag,
1361
            path_data=self.path_data,
1362
            verbose=False,
1363
            normalization=self.test_normalization,
1364
            subset_training_with_meta=self.subset_training_with_meta
1365
        )
1366
1367
        if self.verbose:
1368
            print('preparing data for plotting...')
1369
1370
        self.dataset.load_array()
1371
        self.dataset.load_survival()
1372
        self.dataset.load_meta_data()
1373
        self.dataset.subset_training_sets()
1374
        self.dataset.reorder_matrix_array(self.sample_ids_full)
1375
        self.dataset.create_a_cv_split()
1376
        self.dataset.normalize_training_array()
1377
1378
        self.dataset.load_new_test_dataset(
1379
            tsv_dict=self.test_tsv_dict,
1380
            path_survival_file=self.test_survival_file,
1381
            normalization=self.test_normalization)
1382
1383
        if self.verbose:
1384
            print('fitting autoencoder for plotting...')
1385
1386
        autoencoder = DeepBase(dataset=self.dataset,
1387
                               seed=self.seed,
1388
                               verbose=False,
1389
                               dropout=0.1,
1390
                               epochs=50)
1391
1392
        autoencoder.matrix_train_array = self.dataset.matrix_ref_array
1393
1394
        # label_encoded = OneHotEncoder().fit_transform(
1395
        #     self.full_labels.reshape(-1, 1)).todense()
1396
1397
        # autoencoder.construct_supervized_network(label_encoded)
1398
1399
        autoencoder.construct_supervized_network(self.full_labels_proba)
1400
1401
        self.encoder_for_kde_plot_dict[encoder_key] = autoencoder.encoder_array
1402
1403
        if self.verbose:
1404
            print('fitting done!')
1405
1406
        self._predict_kde_survival_nodes_for_train_matrices(encoder_key)
1407
1408
        return encoder_key
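
    # The cache key combines the omic names with the per-omic normalisation,
    # e.g. (values hypothetical):
    #
    #     "omic:['RNA', 'MIR'] normalisation: {'RNA': True, 'MIR': True}"
    #
    # so a second call with the same test omics and normalisation reuses the
    # already-fitted encoders instead of refitting the autoencoder.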

    def load_new_test_dataset(self, tsv_dict,
                              fname_key=None,
                              path_survival_file=None,
                              normalization=None,
                              debug=False,
                              verbose=False,
                              survival_flag=None,
                              metadata_file=None
                              ):
        """
        Load a new test dataset and infer its labels with every model
        of the ensemble.
        """
        self.test_tsv_dict = tsv_dict
        self.test_survival_file = path_survival_file

        if normalization is None:
            normalization = self.normalization

        self.test_normalization = normalization

        if debug or self.nb_threads < 2:
            pass
            # for model in self.models:
            #     model.verbose = True
            #     model.dataset.verbose = True

        self.test_fname_key = fname_key

        print("Loading new test dataset {0} ...".format(
            self.test_fname_key))
        t_start = time()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self._from_models('_predict_new_dataset',
                              tsv_dict=tsv_dict,
                              path_survival_file=path_survival_file,
                              normalization=normalization,
                              survival_flag=survival_flag,
                              metadata_file=metadata_file
            )

        print("Test dataset {0} loaded in {1} s".format(
            self.test_fname_key, time() - t_start))

        if fname_key:
            self.project_name = '{0}_{1}'.format(self._project_name, fname_key)
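
    # Usage sketch (file names hypothetical; assumes `boosting` was fitted):
    #
    #     boosting.load_new_test_dataset(
    #         {'RNA': 'rna_test.tsv', 'METH': 'meth_test.tsv'},
    #         fname_key='TCGA_test',
    #         path_survival_file='survival_test.tsv')
    #
    # When `fname_key` is given, it is appended to the project name so the
    # outputs of subsequent calls are written per test set.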

    def compute_survival_feature_scores_per_cluster(self,
                                                    pval_thres=0.001,
                                                    use_meta=False):
        """
        For each cluster, test the association of the previously selected
        features with survival (Cox-PH), keeping features below pval_thres.
        Requires compute_feature_scores_per_cluster to be called first.
        """
        print('computing survival feature importance per cluster...')
        pool = Pool(self.nb_threads)
        mapf = pool.map

        if (self.metadata_usage in ['all', 'new-features'] or use_meta) and \
           self.metadata_mat_full is not None:
            metadata_mat = self.metadata_mat_full
        else:
            metadata_mat = None

        for label in set(self.full_labels):
            self.survival_feature_scores_per_cluster[label] = []

        feature_dict = self._from_model_dataset(self.models[0], 'feature_array')

        def generator(feature_list, matrix, feature_index):
            for feat in feature_list:
                i = feature_index[feat[0]]
                yield (feat,
                       np.asarray(matrix[i]).reshape(-1),
                       self.survival_full,
                       metadata_mat,
                       pval_thres,
                       self.use_r_packages)

        for key in self.matrix_with_cv_array:
            feature_index = {feat: i for i, feat in enumerate(feature_dict[key])}

            for label in self.feature_scores_per_cluster:
                matrix = self.matrix_with_cv_array[key][:]

                feature_list = self.feature_scores_per_cluster[label]
                feature_list = [feat for feat in feature_list
                                if feat[0] in feature_index]

                input_list = generator(feature_list, matrix.T, feature_index)

                features_scored = mapf(
                    _process_parallel_survival_feature_importance_per_cluster,
                    input_list)

                for feature, pvalue in features_scored:
                    if feature is not None:
                        self.survival_feature_scores_per_cluster[label].append(
                            (feature, pvalue))

                self.survival_feature_scores_per_cluster[label].sort(
                    key=lambda x: x[1])

        pool.close()
        pool.join()
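
    # Note the ordering dependency: this method reads
    # self.feature_scores_per_cluster, so a typical sequence is
    # (illustrative, assuming a fitted `boosting` instance):
    #
    #     boosting.compute_feature_scores_per_cluster()
    #     boosting.compute_survival_feature_scores_per_cluster(pval_thres=0.001)
    #     boosting.write_feature_score_per_cluster()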

    def compute_feature_scores_per_cluster(self, pval_thres=0.001):
        """
        Rank every input feature for each cluster using a Wilcoxon test
        (cluster vs. the rest) and the median difference.
        """
        print('computing feature importance per cluster...')
        self._reorder_matrix_full()
        mapf = map

        for label in set(self.full_labels):
            self.feature_scores_per_cluster[label] = []

        def generator(labels, feature_list, matrix):
            for i in range(len(feature_list)):
                yield feature_list[i], matrix[i], labels, pval_thres

        feature_dict = self._from_model_dataset(self.models[0], 'feature_array')

        for key in self.matrix_with_cv_array:
            matrix = self.matrix_with_cv_array[key][:]
            labels = self.full_labels[:]

            input_list = generator(labels, feature_dict[key], matrix.T)

            features_scored = mapf(
                _process_parallel_feature_importance_per_cluster, input_list)
            features_scored = [feat for feat_list in features_scored
                               for feat in feat_list]

            for label, feature, median_diff, pvalue in features_scored:
                self.feature_scores_per_cluster[label].append((
                    feature, median_diff, pvalue))

            for label in self.feature_scores_per_cluster:
                self.feature_scores_per_cluster[label].sort(
                    key=lambda x: x[2])
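
    # After the call, the scores are stored per cluster as a list of
    # (feature, median_diff, pvalue) tuples sorted by p-value, e.g.
    # (values hypothetical):
    #
    #     boosting.feature_scores_per_cluster[0]
    #     # [('RNA_GENE_A', 1.8, 3.2e-08), ('MIR_hsa-x', -0.9, 1.1e-05), ...]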

    def write_feature_score_per_cluster(self):
        """
        Write the per-cluster feature scores (and, if computed, the
        per-cluster survival feature scores) as TSV files.
        """
        f_file_name = '{0}/{1}_features_scores_per_clusters.tsv'.format(
            self.path_results, self._project_name)
        f_anti_name = '{0}/{1}_features_anticorrelated_scores_per_clusters.tsv'.format(
            self.path_results, self._project_name)

        f_file = open(f_file_name, 'w')
        f_anti_file = open(f_anti_name, 'w')
        f_file.write('#cluster id\tfeature\tmedian difference\tWilcoxon p-value\n')
        f_anti_file.write('#cluster id\tfeature\tmedian difference\tWilcoxon p-value\n')

        for label in self.feature_scores_per_cluster:
            for feature, median_diff, pvalue in self.feature_scores_per_cluster[label]:
                if median_diff > 0:
                    f_to_write = f_file
                else:
                    f_to_write = f_anti_file

                f_to_write.write('{0}\t{1}\t{2}\t{3}\n'.format(
                    label, feature, median_diff, pvalue))

        f_file.close()
        f_anti_file.close()

        print('{0} written'.format(f_file_name))
        print('{0} written'.format(f_anti_name))

        f_surv_name = '{0}/{1}_survival_features_scores_per_clusters.tsv'.format(
            self.path_results, self._project_name)

        if self.survival_feature_scores_per_cluster:
            f_to_write = open(f_surv_name, 'w')
            f_to_write.write(
                '#label\tfeature\tmedian difference\tcluster Wilcoxon p-value\tCoxPH Log-rank p-value\n')

            for label in self.survival_feature_scores_per_cluster:
                for feature, pvalue in self.survival_feature_scores_per_cluster[label]:
                    f_to_write.write('{0}\t{1}\t{2}\t{3}\t{4}\n'.format(
                        label, feature[0], feature[1], feature[2], pvalue))

            f_to_write.close()
            print('{0} written'.format(f_surv_name))
        else:
            print("No survival features detected. File {0} not written".format(f_surv_name))

    def evalutate_cluster_performance(self):
        """
        Print and log the BIC, silhouette and Calinski-Harabasz scores
        averaged over all models of the ensemble.
        (The misspelt method name is kept: it is part of the public API.)
        """
        if self._pretrained_fit:
            print('model is fitted on pretrained labels.'
                  ' Cannot evaluate cluster performance')
            return

        bic_scores = np.array([self._from_model_attr(model, 'bic_score')
                               for model in self.models])

        if bic_scores[0] is not None:
            bic = np.nanmean(bic_scores)
            print('bic score: mean: {0} std: {1}'.format(
                bic, np.nanstd(bic_scores)))
            self.log['bic'] = bic
        else:
            bic = np.nan

        silhouette_scores = np.array([self._from_model_attr(model, 'silhouette_score')
                                      for model in self.models])
        silhouette = silhouette_scores.mean()
        print('silhouette score: mean: {0} std: {1}'.format(
            silhouette, silhouette_scores.std()))
        self.log['silhouette'] = silhouette

        calinski_scores = np.array([self._from_model_attr(model, 'calinski_score')
                                    for model in self.models])
        calinski = calinski_scores.mean()
        print('calinski harabasz score: mean: {0} std: {1}'.format(
            calinski, calinski_scores.std()))
        self.log['calinski'] = calinski

        return bic, silhouette, calinski
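
    # Usage sketch (assumes a fitted, non-pretrained `boosting` instance):
    #
    #     bic, silhouette, calinski = boosting.evalutate_cluster_performance()
    #
    # Higher silhouette / Calinski-Harabasz and lower BIC generally indicate
    # better-separated clusters across the ensemble.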

    def save_cv_models_classes(self, path_results=""):
        """
        Write the cross-validation labels of every model
        (see save_models_classes).
        """
        self.save_models_classes(path_results=path_results,
                                 use_cv_labels=True)

    def save_test_models_classes(self, path_results=""):
        """
        Write the test labels of every model (see save_models_classes).
        """
        self.save_models_classes(path_results=path_results,
                                 use_test_labels=True)

    def save_models_classes(self, path_results="",
                            use_cv_labels=False,
                            use_test_labels=False):
        """
        Write one TSV file per model with its sample labels, label
        probabilities and survival values (train, CV or test sets).
        """
        if not path_results:
            if use_test_labels:
                path_results = '{0}/saved_models_test_classes'.format(
                    self.path_results)
            elif use_cv_labels:
                path_results = '{0}/saved_models_cv_classes'.format(
                    self.path_results)
            else:
                path_results = '{0}/saved_models_classes'.format(
                    self.path_results)

        if not isdir(path_results):
            mkdir(path_results)

        for i, model in enumerate(self.models):
            if use_test_labels:
                labels = self._from_model_attr(model, 'test_labels')
                labels_proba = self._from_model_attr(model, 'test_labels_proba')
                sample_ids = self._from_model_dataset(model, 'sample_ids_test')
                survival = self._from_model_dataset(model, 'survival_test')
            elif use_cv_labels:
                labels = self._from_model_attr(model, 'cv_labels')
                labels_proba = self._from_model_attr(model, 'cv_labels_proba')
                sample_ids = self._from_model_dataset(model, 'sample_ids_cv')
                survival = self._from_model_dataset(model, 'survival_cv')
            else:
                labels = self._from_model_attr(model, 'labels')
                labels_proba = self._from_model_attr(model, 'labels_proba')
                sample_ids = self._from_model_dataset(model, 'sample_ids')
                survival = self._from_model_dataset(model, 'survival')

            seed = self._from_model_attr(model, 'seed')

            nbdays, isdead = survival.T.tolist()

            if not seed:
                # fall back to the model index when no seed was set
                seed = i

            path_file = '{0}/model_instance_{1}.tsv'.format(
                path_results, seed)

            labels_proba = labels_proba.T[0]

            self._from_model(
                model, '_write_labels',
                sample_ids,
                labels,
                path_file=path_file,
                labels_proba=labels_proba,
                nbdays=nbdays, isdead=isdead)

        print('individual model labels saved at: {0}'.format(path_results))

    def _convert_logs(self):
        """
        Convert log values to plain JSON-serialisable types and drop any
        entry that cannot be represented as a string.
        """
        # iterate over a copy of the keys since entries may be popped below
        for key in list(self.log):
            if isinstance(self.log[key], np.float32):
                self.log[key] = float(self.log[key])
            elif isinstance(self.log[key], NALogicalType):
                self.log[key] = None
            elif np.isscalar(self.log[key]) and pd.isna(self.log[key]):
                self.log[key] = None
            try:
                str(self.log[key])
            except Exception:
                self.log.pop(key)

    def write_logs(self):
        """
        Dump the run logs as a JSON file in the results folder.
        """
        self._convert_logs()

        with open('{0}/{1}.log.json'.format(self.path_results, self._project_name), 'w') as f:
            f.write(simplejson.dumps(self.log, indent=2))


def _highest_proba(proba):
    """
    Aggregate an ensemble probability array of shape
    (nb_models, nb_samples, nb_clusters) by taking, for each sample and
    cluster, the maximum probability over the models.
    """
    labels = []
    probas = []

    clusters = range(proba.shape[2])
    samples = range(proba.shape[1])

    for sample in samples:
        proba_vector = [proba.T[cluster][sample].max() for cluster in clusters]
        label = max(enumerate(proba_vector), key=lambda x: x[1])[0]

        labels.append(label)
        probas.append(proba_vector)

    return np.asarray(labels), np.asarray(probas)
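
# Toy sketch of the (nb_models, nb_samples, nb_clusters) convention used by
# the aggregation helpers (illustrative only):
#
#     import numpy as np
#     proba = np.array([[[0.9, 0.1], [0.4, 0.6]],   # model 0: 2 samples, 2 clusters
#                       [[0.7, 0.3], [0.2, 0.8]]])  # model 1
#     labels, probas = _highest_proba(proba)
#     # labels == [0, 1]; probas[0] == [0.9, 0.3]
#
# _mean_proba works the same way but averages over the models instead of
# taking the maximum.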


def _mean_proba(proba):
    """
    Aggregate an ensemble probability array of shape
    (nb_models, nb_samples, nb_clusters) by averaging the probabilities
    over the models.
    """
    labels = []
    probas = []

    clusters = range(proba.shape[2])
    samples = range(proba.shape[1])

    for sample in samples:
        proba_vector = [proba.T[cluster][sample].mean() for cluster in clusters]
        label = max(enumerate(proba_vector), key=lambda x: x[1])[0]

        labels.append(label)
        probas.append(proba_vector)

    return np.asarray(labels), np.asarray(probas)


def _weighted_mean(proba, weights):
    """
    Same as _mean_proba but each model is weighted: weights below 0.50 are
    zeroed and the remaining ones are raised to the power of 4 to favour
    the best models.
    """
    labels = []
    probas = []
    weights = np.array(weights)
    weights[weights < 0.50] = 0.0
    weights = np.power(weights, 4)

    if weights.sum() == 0:
        # no model passed the threshold: fall back to uniform weights
        weights[:] = 1.0

    clusters = range(proba.shape[2])
    samples = range(proba.shape[1])

    for sample in samples:
        proba_vector = [np.average(proba.T[cluster][sample], weights=weights)
                        for cluster in clusters]
        label = max(enumerate(proba_vector), key=lambda x: x[1])[0]

        labels.append(label)
        probas.append(proba_vector)

    return np.asarray(labels), np.asarray(probas)
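
# Sketch of the weight transform shared by the two weighted aggregators
# (values illustrative):
#
#     weights = np.array([0.4, 0.6, 0.9])
#     # -> after thresholding: [0.0, 0.6, 0.9]
#     # -> after the power of 4: [0.0, 0.1296, 0.6561]
#
# so a model with weight 0.9 counts roughly five times more than one
# with weight 0.6.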


def _weighted_max(proba, weights):
    """
    Same as _highest_proba but each model probability is first multiplied
    by the model weight (thresholded at 0.50 and raised to the power of 4)
    before taking the maximum over the models.
    """
    labels = []
    probas = []
    weights = np.array(weights)
    weights[weights < 0.50] = 0.0
    weights = np.power(weights, 4)

    if weights.sum() == 0:
        # no model passed the threshold: fall back to uniform weights
        weights[:] = 1.0

    clusters = range(proba.shape[2])
    samples = range(proba.shape[1])

    for sample in samples:
        proba_vector = [np.max(proba.T[cluster][sample] * weights)
                        for cluster in clusters]
        label = max(enumerate(proba_vector), key=lambda x: x[1])[0]

        labels.append(label)
        probas.append(proba_vector)

    return np.asarray(labels), np.asarray(probas)


def _reorder_labels(labels, sample_ids):
    """
    Deduplicate sample_ids and return the corresponding subset of labels;
    duplicated samples keep their last index.
    """
    sample_dict = {sample: id for id, sample in enumerate(sample_ids)}
    # note: iterating over a set has no guaranteed order; it is used here
    # only to deduplicate the sample ids
    sample_ordered = set(sample_ids)

    index = [sample_dict[sample] for sample in sample_ordered]

    return labels[index]
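
# Minimal sketch (illustrative): with labels = np.array([1, 2, 3]) and
# sample_ids = ['a', 'b', 'a'], the duplicate 'a' keeps its last index (2),
# so the returned subset contains labels[2] and labels[1].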