Diff of /simdeep/extract_data.py [000000] .. [53737a]

"""
Module used to load and preprocess the omic matrices, the survival data and
the metadata defining the training and test datasets.
"""
from sklearn.preprocessing import Normalizer
from sklearn.preprocessing import RobustScaler
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import quantile_transform

from simdeep.config import TRAINING_TSV
from simdeep.config import SURVIVAL_TSV

from simdeep.config import TEST_TSV
from simdeep.config import SURVIVAL_TSV_TEST

from simdeep.config import PATH_DATA
from simdeep.config import STACK_MULTI_OMIC

from simdeep.config import NORMALIZATION

from simdeep.config import FILL_UNKOWN_FEATURE_WITH_0

from simdeep.config import CROSS_VALIDATION_INSTANCE
from simdeep.config import TEST_FOLD

from simdeep.config import SURVIVAL_FLAG

from simdeep.survival_utils import load_data_from_tsv
from simdeep.survival_utils import load_survival_file
from simdeep.survival_utils import return_intersection_indexes
from simdeep.survival_utils import translate_index
from simdeep.survival_utils import MadScaler
from simdeep.survival_utils import RankNorm
from simdeep.survival_utils import CorrelationReducer
from simdeep.survival_utils import VarianceReducer
from simdeep.survival_utils import SampleReducer
from simdeep.survival_utils import convert_metadata_frame_to_matrix

from simdeep.survival_utils import save_matrix

from collections import defaultdict

from os.path import isfile

from time import time

import numpy as np

import pandas as pd

from numpy import hstack
from numpy import vstack


######################## VARIABLE ############################
QUANTILE_OPTION = {'n_quantiles': 100,
                   'output_distribution': 'normal'}
###############################################################


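# Illustrative sketch (not part of the pipeline): with the options above,
# quantile_transform maps each feature onto a normal distribution. The toy
# matrix below is an assumption used only for this example.
#
#     toy = np.random.rand(200, 5)
#     transformed = quantile_transform(toy, **QUANTILE_OPTION)
#     # each column of `transformed` now approximately follows N(0, 1)

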
class LoadData():
    """
    Loader for the multi-omic training and test matrices, the survival data
    and the optional metadata.
    """

    def __init__(
            self,
            path_data=PATH_DATA,
            training_tsv=TRAINING_TSV,
            survival_tsv=SURVIVAL_TSV,
            metadata_tsv=None,
            metadata_test_tsv=None,
            test_tsv=TEST_TSV,
            survival_tsv_test=SURVIVAL_TSV_TEST,
            cross_validation_instance=CROSS_VALIDATION_INSTANCE,
            test_fold=TEST_FOLD,
            stack_multi_omic=STACK_MULTI_OMIC,
            fill_unkown_feature_with_0=FILL_UNKOWN_FEATURE_WITH_0,
            normalization=NORMALIZATION,
            survival_flag=SURVIVAL_FLAG,
            subset_training_with_meta={},
            _autoencoder_parameters={},
            verbose=True,
    ):
        """
        Class used to load and preprocess the matrices, the survival data
        and the metadata.

        :path_data: str    path to the folder containing the data
        :training_tsv: dict    dict('data type', 'name of the tsv file')
        :survival_tsv: str    name of the tsv file containing the survival data
                              of the training set
        :survival_tsv_test: str    name of the tsv file containing the survival data
                                   of the test set
        :metadata_tsv: str    name of the file containing the metadata
        :metadata_test_tsv: str    name of the file containing the metadata
                                   of the test set
        :test_tsv: dict    dict('data type', 'name of the tsv file') defining
                           the test dataset; the data types must match keys
                           existing in training_tsv
        """

        self.verbose = verbose
        self.do_stack_multi_omic = stack_multi_omic
        self.path_data = path_data
        self.survival_tsv = survival_tsv
        self.metadata_tsv = metadata_tsv
        self.training_tsv = training_tsv
        self.fill_unkown_feature_with_0 = fill_unkown_feature_with_0
        self.survival_flag = survival_flag
        self.feature_array = {}
        self.matrix_array = {}
        self.subset_training_with_meta = subset_training_with_meta

        self.test_tsv = test_tsv
        self.matrix_train_array = {}

        self.sample_ids = []
        self.data_type = list(training_tsv.keys())

        self.survival = None
        self.survival_tsv_test = survival_tsv_test
        self.metadata_test_tsv = metadata_test_tsv

        self.matrix_full_array = {}
        self.sample_ids_full = []
        self.survival_full = None

        self.feature_test_array = {}
        self.matrix_test_array = {}

        self.sample_ids_cv = []
        self.matrix_cv_array = {}
        self.matrix_cv_unormalized_array = {}
        self.survival_cv = None

        self._cv_loaded = False
        self._full_loaded = False

        self.matrix_ref_array = {}
        self.feature_ref_array = {}
        self.feature_ref_index = {}
        self.feature_train_array = {}
        self.feature_train_index = {}

        self.metadata_frame_full = None
        self.metadata_frame_cv = None
        self.metadata_frame_test = None
        self.metadata_frame = None

        self.metadata_mat_full = None
        self.metadata_mat_cv = None
        self.metadata_mat_test = None
        self.metadata_mat = None

        self.survival_test = None
        self.sample_ids_test = None

        self.cross_validation_instance = cross_validation_instance
        self.test_fold = test_fold

        self.do_feature_reduction = None

        self.normalizer = Normalizer()
        self.mad_scaler = MadScaler()
        self.robust_scaler = RobustScaler()
        self.min_max_scaler = MinMaxScaler()
        self.dim_reducer = CorrelationReducer()
        self.variance_reducer = VarianceReducer()

        self._autoencoder_parameters = _autoencoder_parameters
        self.normalization = defaultdict(bool, normalization)
        self.normalization_test = None

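    # A minimal usage sketch, assuming hypothetical TSV files in ./data/
    # ('rna.tsv', 'mir.tsv' and 'survival.tsv' are placeholders, not files
    # shipped with the package):
    #
    #     loader = LoadData(
    #         path_data='./data/',
    #         training_tsv={'RNA': 'rna.tsv', 'MIR': 'mir.tsv'},
    #         survival_tsv='survival.tsv')
    #     loader.load_array()                 # read and align the matrices
    #     loader.load_survival()              # attach the survival data
    #     loader.create_a_cv_split()          # optional cross-validation fold
    #     loader.normalize_training_array()   # fit the normalisations
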
    def __del__(self):
        """
        Trigger garbage collection when the instance is deleted.
        """
        try:
            import gc
            gc.collect()
        except Exception:
            pass

    def _stack_multiomics(self, arrays=None, features=None):
        """
        Stack all the omic matrices (and their feature lists) into a single
        'STACKED' entry when the stack_multi_omic option is set.
        """
        if not self.do_stack_multi_omic:
            return

        if arrays is not None:
            arrays['STACKED'] = hstack(
                tuple(arrays.values()))

            for key in list(arrays.keys()):
                if key != 'STACKED':
                    arrays.pop(key)

        if not features:
            return

        features['STACKED'] = [feat for key in features
                               for feat in features[key]]
        for key in list(features.keys()):
            if key != 'STACKED':
                features.pop(key)

        self.feature_ref_index['STACKED'] = {feature: pos for pos, feature
                                             in enumerate(features['STACKED'])}

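    # Toy illustration of the stacking above (the two arrays are assumptions):
    #
    #     rna = np.ones((3, 2))           # 3 samples x 2 RNA features
    #     mir = np.zeros((3, 4))          # 3 samples x 4 miRNA features
    #     stacked = np.hstack((rna, mir))
    #     stacked.shape                   # -> (3, 6): one matrix, same samples
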
    def load_matrix_test_fold(self):
        """
        Normalise the matrices of the cross-validation fold, using the
        training matrices as reference.
        """
        if not self.cross_validation_instance or self._cv_loaded:
            return

        for key in self.matrix_array:

            matrix_test = self.matrix_cv_array[key].copy()
            matrix_ref = self.matrix_array[key].copy()

            matrix_ref, matrix_test = self.transform_matrices(
                matrix_ref, matrix_test, key,
            )

            self.matrix_cv_unormalized_array[key] = \
                self.matrix_cv_array[key].copy()
            self.matrix_cv_array[key] = matrix_test

        self._stack_multiomics(self.matrix_cv_array)
        self._cv_loaded = True

    def load_matrix_test(self, normalization=None):
        """
        Load the test matrices and normalise them together with the
        training (reference) matrices.
        """
        if normalization is not None:
            self.normalization_test = normalization
        else:
            self.normalization_test = self.normalization

        for key in self.test_tsv:
            sample_ids, feature_ids, matrix = load_data_from_tsv(
                f_name=self.test_tsv[key],
                key=key,
                path_data=self.path_data)

            feature_ids_ref = self.feature_array[key]
            matrix_ref = self.matrix_array[key].copy()

            common_features = set(feature_ids).intersection(feature_ids_ref)

            if self.verbose:
                print('nb common features for the test set: {0}'.format(len(common_features)))

            feature_ids_dict = {feat: i for i, feat in enumerate(feature_ids)}
            feature_ids_ref_dict = {feat: i for i, feat in enumerate(feature_ids_ref)}

            if len(common_features) < len(feature_ids_ref) and self.fill_unkown_feature_with_0:
                missing_features = set(feature_ids_ref).difference(common_features)

                if self.verbose:
                    print('filling {0} with 0 for {1} additional features'.format(
                        key, len(missing_features)))

                matrix = hstack([matrix, np.zeros((len(sample_ids), len(missing_features)))])

                for i, feat in enumerate(missing_features):
                    feature_ids_dict[feat] = i + len(feature_ids)

                common_features = feature_ids_ref

            feature_index = [feature_ids_dict[feature] for feature in common_features]
            feature_ref_index = [feature_ids_ref_dict[feature] for feature in common_features]

            matrix_test = np.nan_to_num(matrix.T[feature_index].T)
            matrix_ref = np.nan_to_num(matrix_ref.T[feature_ref_index].T)

            self.feature_test_array[key] = list(common_features)

            if self.sample_ids_test is not None:
                try:
                    assert(self.sample_ids_test == sample_ids)
                except AssertionError:
                    raise Exception('Assertion error when loading test sample ids!')
            else:
                self.sample_ids_test = sample_ids

            matrix_ref, matrix_test = self.transform_matrices(
                matrix_ref, matrix_test, key, normalization=normalization)

            self._define_test_features(key, normalization)

            self.matrix_test_array[key] = matrix_test
            self.matrix_ref_array[key] = matrix_ref
            self.feature_ref_array[key] = self.feature_test_array[key]
            self.feature_ref_index[key] = {feat: pos for pos, feat in enumerate(common_features)}

            self._define_ref_features(key, normalization)

        self._stack_multiomics(self.matrix_test_array,
                               self.feature_test_array)
        self._stack_multiomics(self.matrix_ref_array,
                               self.feature_ref_array)

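    # Sketch of the feature alignment performed above (toy feature names):
    # only the features shared with the training set are kept and, when
    # fill_unkown_feature_with_0 is set, missing training features are
    # zero-filled so the test matrix reaches the training dimensionality.
    #
    #     train_features = ['g1', 'g2', 'g3']
    #     test_features = ['g2', 'g3', 'g4']
    #     # common features -> {'g2', 'g3'}; 'g4' is dropped and, with
    #     # zero-filling enabled, 'g1' becomes a column of zeros in the test set
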
    def load_meta_data_test(self, metadata_file="", sep="\t"):
        """
        Load the metadata of the test set and check that every test sample
        is present in the metadata file.
        """
        if metadata_file:
            self.metadata_test_tsv = metadata_file

        if not self.metadata_test_tsv:
            return

        if isfile("{0}/{1}".format(self.path_data, self.metadata_test_tsv)):
            self.metadata_test_tsv = "{0}/{1}".format(
                self.path_data, self.metadata_test_tsv)

        frame = pd.read_csv(self.metadata_test_tsv, sep=sep, index_col=0)

        diff = set(self.sample_ids_test).difference(frame.index)

        if diff:
            raise Exception(
                "Error! samples from the test dataset not present in metadata: {0}".format(
                    list(diff)[:5]))

        self.metadata_frame_test = frame.T[self.sample_ids_test].T
        self.metadata_mat_test = convert_metadata_frame_to_matrix(
            self.metadata_frame_test)

    def load_meta_data(self, sep="\t"):
        """
        Load the metadata of the training samples and align it with the
        full, cross-validation and training sample lists.
        """

        if not self.metadata_tsv:
            return

        if isfile("{0}/{1}".format(self.path_data, self.metadata_tsv)):
            self.metadata_tsv = "{0}/{1}".format(
                self.path_data, self.metadata_tsv)

        frame = pd.read_csv(self.metadata_tsv, sep=sep, index_col=0)

        ## FULL ##
        if self.sample_ids_full:
            diff = set(self.sample_ids_full).difference(frame.index)

            if diff:
                raise Exception("Error! sample not present in metadata: {0}".format(
                    list(diff)[:5]))

            self.metadata_frame_full = frame.T[self.sample_ids_full].T

            self.metadata_mat_full = convert_metadata_frame_to_matrix(
                self.metadata_frame_full)

        ## CV ##
        if len(self.sample_ids_cv):
            diff = set(self.sample_ids_cv).difference(frame.index)

            if diff:
                raise Exception("Error! sample not present in metadata: {0}".format(
                    list(diff)[:5]))

            self.metadata_frame_cv = frame.T[self.sample_ids_cv].T
            self.metadata_mat_cv = convert_metadata_frame_to_matrix(
                self.metadata_frame_cv)

        ## ALL ##
        diff = set(self.sample_ids).difference(frame.index)

        if diff:
            raise Exception("Error! sample not present in metadata: {0}".format(
                list(diff)[:5]))

        self.metadata_frame = frame.T[self.sample_ids].T
        self.metadata_mat = convert_metadata_frame_to_matrix(
            self.metadata_frame)

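    # Expected metadata layout, inferred from the parsing above (index_col=0,
    # one row per sample, one column per variable, separated by `sep`); the
    # sample and column names below are hypothetical:
    #
    #     Samples    sex    stage
    #     sample_1   M      II
    #     sample_2   F      I
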
    def subset_training_sets(self, change_cv=False):
        """
        Subset the training (and cross-validation) samples using the
        metadata values defined in subset_training_with_meta.
        """
        if not self.subset_training_with_meta:
            print("Not subsetting training dataset.")
            return

        if self.metadata_frame is None:
            print("No metadata parsed. Not subsetting training sets")
            return

        samples_subset = set()
        samples_subset_cv = set()

        for key, values in self.subset_training_with_meta.items():
            if not isinstance(values, list):
                values = [values]

            for value in values:
                if key not in self.metadata_frame:
                    raise Exception(
                        "Subsetting key does not exist in the metadata: {0}".format(
                            key))

                index = self.metadata_frame[self.metadata_frame[key] == value].index

                if self.metadata_frame_cv is not None:
                    index_cv = self.metadata_frame_cv[self.metadata_frame_cv[key] == value].index
                    samples_subset_cv.update(index_cv)

                samples_subset.update(index)

        new_index = translate_index(self.sample_ids, samples_subset)

        for key in self.matrix_train_array:
            self.matrix_train_array[key] = self.matrix_train_array[key][new_index]
        for key in self.matrix_ref_array:
            self.matrix_ref_array[key] = self.matrix_ref_array[key][new_index]
        for key in self.matrix_array:
            self.matrix_array[key] = self.matrix_array[key][new_index]

        self.survival = self.survival[new_index]

        self.metadata_frame = self.metadata_frame.T[list(samples_subset)].T
        self.metadata_mat = convert_metadata_frame_to_matrix(
            self.metadata_frame)

        self.sample_ids = list(samples_subset)

        if self.survival_cv is not None:
            new_index_cv = translate_index(self.sample_ids_cv,
                                           samples_subset_cv)
            for key in self.matrix_cv_array:
                self.matrix_cv_array[key] = self.matrix_cv_array[key][new_index_cv]

                if key in self.matrix_cv_unormalized_array:
                    self.matrix_cv_unormalized_array[key] = self.matrix_cv_unormalized_array[
                        key][new_index_cv]

            self.metadata_frame_cv = self.metadata_frame_cv.T[
                list(samples_subset_cv)].T
            self.metadata_mat_cv = convert_metadata_frame_to_matrix(
                self.metadata_frame_cv)

            self.sample_ids_cv = list(samples_subset_cv)
            self.survival_cv = self.survival_cv[new_index_cv]

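    # Hedged sketch: subset_training_with_meta keeps only the samples whose
    # metadata matches the given key/value pairs ('stage' and its values are
    # hypothetical metadata content):
    #
    #     loader = LoadData(subset_training_with_meta={'stage': ['I', 'II']})
    #     ...
    #     loader.subset_training_sets()
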
    def load_new_test_dataset(self, tsv_dict,
                              path_survival_file=None,
                              survival_flag=None,
                              normalization=None,
                              metadata_file=None):
        """
        Reset the test attributes and load a new test dataset, together with
        its survival data and its metadata.
        """
        if normalization is not None:
            normalization = defaultdict(bool, normalization)
        else:
            normalization = self.normalization.copy()

        self.test_tsv = tsv_dict.copy()

        for key in tsv_dict:
            if key not in self.training_tsv:
                self.test_tsv.pop(key)

        self.survival_test = None
        self.sample_ids_test = None

        self.metadata_frame_test = None
        self.metadata_mat_test = None

        self.survival_tsv_test = path_survival_file

        self.matrix_test_array = {}
        self.matrix_ref_array = {}
        self.feature_test_array = {}
        self.feature_ref_array = {}
        self.feature_ref_index = {}

        self.load_matrix_test(normalization)
        self.load_survival_test(survival_flag)
        self.load_meta_data_test(metadata_file=metadata_file)

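    # A minimal sketch (hypothetical file names): keys of tsv_dict must match
    # the training omic types, otherwise they are dropped.
    #
    #     loader.load_new_test_dataset(
    #         {'RNA': 'rna_test.tsv'},
    #         path_survival_file='survival_test.tsv',
    #         normalization={'TRAIN_NORM_SCALE': True})
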
    def _create_ref_matrix(self, key):
        """
        Restrict the reference (training) matrix to the features of the test
        set, in the same order.
        """
        features_test = self.feature_test_array[key]

        features_train = self.feature_train_array[key]
        matrix_train = self.matrix_ref_array[key]

        test_dict = {feat: pos for pos, feat in enumerate(features_test)}
        train_dict = {feat: pos for pos, feat in enumerate(features_train)}

        index = [train_dict[feat] for feat in features_test]

        self.feature_ref_array[key] = self.feature_test_array[key]
        self.matrix_ref_array[key] = np.nan_to_num(matrix_train.T[index].T)

        self.feature_ref_index[key] = test_dict

    def load_array(self):
        """
        Load the training matrices defined in training_tsv and align their
        samples on a common set of patient IDs.
        """
        if self.verbose:
            print('loading data...')

        t = time()

        self.feature_array = {}
        self.matrix_array = {}

        data = list(self.data_type)[0]
        f_name = self.training_tsv[data]

        self.sample_ids, feature_ids, matrix = load_data_from_tsv(
            f_name=f_name,
            key=data,
            path_data=self.path_data)

        if self.verbose:
            print('{0} loaded of dim: {1}'.format(f_name, matrix.shape))

        self.feature_array[data] = feature_ids
        self.matrix_array[data] = matrix

        for data in self.data_type[1:]:
            f_name = self.training_tsv[data]
            sample_ids, feature_ids, matrix = load_data_from_tsv(
                f_name=f_name,
                key=data,
                path_data=self.path_data)

            if self.sample_ids != sample_ids:
                print('#### Different patient IDs for the {0} matrix ####'.format(data))

                index1, index2, sample_ids = return_intersection_indexes(
                    self.sample_ids, sample_ids)

                self.sample_ids = sample_ids
                matrix = matrix[index2]

                for data2 in self.matrix_array:
                    self.matrix_array[data2] = self.matrix_array[data2][index1]

            self.feature_array[data] = feature_ids
            self.matrix_array[data] = matrix

            if self.verbose:
                print('{0} loaded of dim: {1}'.format(f_name, matrix.shape))

        self._discard_training_samples()

        if self.verbose:
            print('data loaded in {0} s'.format(time() - t))

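    # Toy illustration of the alignment step above, assuming
    # return_intersection_indexes returns the positional indexes of the
    # common IDs in each list, together with the common IDs themselves:
    #
    #     ids_a = ['s1', 's2', 's3']
    #     ids_b = ['s2', 's3', 's4']
    #     # index1 -> [1, 2], index2 -> [0, 1], sample_ids -> ['s2', 's3']
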
    def _discard_training_samples(self):
        """
        Discard a fraction of the training samples (via SampleReducer) when
        the DISCARD_TRAINING_SAMPLES normalization option is set.
        """
        if self.normalization['DISCARD_TRAINING_SAMPLES']:
            sample_reducer = SampleReducer(1.0 - self.normalization['DISCARD_TRAINING_SAMPLES'])
            index = range(len(self.sample_ids))
            to_keep, to_remove = sample_reducer.sample_to_keep(self.matrix_array, index)

            self.sample_ids = np.asarray(self.sample_ids)[to_keep].tolist()

            for key in self.matrix_array:
                self.matrix_array[key] = self.matrix_array[key][to_keep]

            if self.verbose:
                print('{0} training samples discarded'.format(len(to_remove)))

    def reorder_matrix_array(self, new_sample_ids):
        """
        Reorder the training matrices and the survival array to follow
        new_sample_ids (which must contain exactly the same samples).
        """
        assert(set(new_sample_ids) == set(self.sample_ids))
        index_dict = {sample: pos for pos, sample in enumerate(self.sample_ids)}
        index = [index_dict[sample] for sample in new_sample_ids]

        self.sample_ids = np.asarray(self.sample_ids)[index].tolist()

        for key in self.matrix_array:
            self.matrix_array[key] = self.matrix_array[key][index]

        self.survival = self.survival[index]

    def create_a_cv_split(self):
        """
        Split the training samples into a training set and a
        cross-validation fold.
        """
        if not self.cross_validation_instance:
            return

        cv = self.cross_validation_instance

        if isinstance(self.cross_validation_instance, tuple):
            train, test = self.cross_validation_instance
        else:
            train, test = [(tn, tt)
                           for tn, tt in
                           cv.split(self.sample_ids)][self.test_fold]

        if self.normalization['PERC_SAMPLE_TO_KEEP']:
            sample_reducer = SampleReducer(self.normalization['PERC_SAMPLE_TO_KEEP'])
            to_keep, to_remove = sample_reducer.sample_to_keep(self.matrix_array, train)

            test = list(train[to_remove]) + list(test)
            train = train[to_keep]

        for key in self.matrix_array:
            self.matrix_cv_array[key] = self.matrix_array[key][test]
            self.matrix_array[key] = self.matrix_array[key][train]

        self.survival_cv = self.survival.copy()[test]
        self.survival = self.survival[train]

        if self.metadata_frame is not None:
            # cv
            self.metadata_frame_cv = self.metadata_frame.T[
                list(np.asarray(self.sample_ids)[test])].T
            self.metadata_mat_cv = self.metadata_mat.T[test].T
            self.metadata_mat_cv.index = range(len(test))
            # train
            self.metadata_frame = self.metadata_frame.T[
                list(np.asarray(self.sample_ids)[train])].T
            self.metadata_mat = self.metadata_mat.T[train].T
            self.metadata_mat.index = range(len(train))

        self.sample_ids_cv = np.asarray(self.sample_ids)[test].tolist()
        self.sample_ids = np.asarray(self.sample_ids)[train].tolist()

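    # Hedged sketch: cross_validation_instance can be any splitter exposing a
    # sklearn-style split(), or a precomputed (train, test) tuple of indexes.
    #
    #     from sklearn.model_selection import KFold
    #     loader = LoadData(
    #         cross_validation_instance=KFold(n_splits=5, shuffle=True,
    #                                         random_state=0),
    #         test_fold=0)
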
    def load_matrix_full(self):
        """
        Build the full dataset by concatenating the training matrices with
        the cross-validation fold.
        """
        if self._full_loaded:
            return

        if not self.cross_validation_instance:
            self.matrix_full_array = self.matrix_train_array
            self.sample_ids_full = self.sample_ids
            self.survival_full = self.survival
            self.metadata_frame_full = self.metadata_frame
            self.metadata_mat_full = self.metadata_mat
            return

        if not self._cv_loaded:
            self.load_matrix_test_fold()

        for key in self.matrix_train_array:
            self.matrix_full_array[key] = vstack([self.matrix_train_array[key],
                                                  self.matrix_cv_array[key]])

        self.sample_ids_full = self.sample_ids[:] + self.sample_ids_cv[:]
        self.survival_full = vstack([self.survival, self.survival_cv])

        if self.metadata_frame is not None:
            self.metadata_frame_full = pd.concat([self.metadata_frame,
                                                  self.metadata_frame_cv])
            self.metadata_mat_full = pd.concat([self.metadata_mat,
                                                self.metadata_mat_cv])
            self.metadata_mat_full.index = range(len(self.sample_ids_full))

        self._full_loaded = True

    def load_survival(self):
        """
        Load the survival data of the training set and remove the samples
        without survival information.
        """
        survival = load_survival_file(self.survival_tsv, path_data=self.path_data,
                                      survival_flag=self.survival_flag)
        matrix = []

        retained_samples = []
        sample_removed = 0

        for ids, sample in enumerate(self.sample_ids):
            if sample not in survival:
                sample_removed += 1
                continue

            retained_samples.append(ids)
            matrix.append(survival[sample])

        self.survival = np.asmatrix(matrix)

        if sample_removed:
            for key in self.matrix_array:
                self.matrix_array[key] = self.matrix_array[key][retained_samples]

            self.sample_ids = np.asarray(self.sample_ids)[retained_samples]

            if self.verbose:
                print('{0} samples without survival removed'.format(sample_removed))

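    # Hedged sketch of a survival_flag mapping: the keys follow the
    # convention assumed by load_survival_file, and the column names on the
    # right are hypothetical; they must match the header of the survival TSV.
    #
    #     survival_flag = {'patient_id': 'barcode',
    #                      'survival': 'days',
    #                      'event': 'recurrence'}
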
    def load_survival_test(self, survival_flag=None):
        """
        Load the survival data of the test set. When no survival file is
        given, a NaN-filled array is used instead.
        """
        if self.survival_tsv_test is None:
            self.survival_test = np.empty(
                shape=(len(self.sample_ids_test), 2))

            self.survival_test[:] = np.nan

            return

        if survival_flag is None:
            survival_flag = self.survival_flag

        survival = load_survival_file(self.survival_tsv_test,
                                      path_data=self.path_data,
                                      survival_flag=survival_flag)
        matrix = []

        retained_samples = []
        sample_removed = 0

        for ids, sample in enumerate(self.sample_ids_test):
            if sample not in survival:
                sample_removed += 1
                continue

            retained_samples.append(ids)
            matrix.append(survival[sample])

        self.survival_test = np.asmatrix(matrix)

        if sample_removed:
            for key in self.matrix_test_array:
                self.matrix_test_array[key] = self.matrix_test_array[key][retained_samples]

            self.sample_ids_test = np.asarray(self.sample_ids_test)[retained_samples]

            if self.verbose:
                print('{0} samples without survival removed'.format(sample_removed))

    def _define_train_features(self, key):
        """
        Define the feature names of the training matrix for a given omic,
        taking the feature-reduction options into account.
        """
        self.feature_train_array[key] = self.feature_array[key][:]

        if self.normalization['TRAIN_CORR_REDUCTION']:
            self.feature_train_array[key] = ['{0}_{1}'.format(key, sample)
                                             for sample in self.sample_ids]
        elif self.normalization['NB_FEATURES_TO_KEEP']:
            self.feature_train_array[key] = np.array(self.feature_train_array[key])[
                self.variance_reducer.index_to_keep].tolist()

        self.feature_ref_array[key] = self.feature_train_array[key]

        self.feature_train_index[key] = {feat: pos for pos, feat in enumerate(
            self.feature_train_array[key])}
        self.feature_ref_index[key] = self.feature_train_index[key]

    def _define_test_features(self, key, normalization=None):
        """
        Define the feature names of the test matrix for a given omic.
        """
        if normalization is None:
            normalization = self.normalization

        if normalization['TRAIN_CORR_REDUCTION']:
            self.feature_test_array[key] = ['{0}_{1}'.format(key, sample)
                                            for sample in self.sample_ids]

        elif normalization['NB_FEATURES_TO_KEEP']:
            self.feature_test_array[key] = np.array(self.feature_test_array[key])[
                self.variance_reducer.index_to_keep].tolist()

    def _define_ref_features(self, key, normalization=None):
        """
        Define the feature names and index of the reference matrix for a
        given omic.
        """
        if normalization is None:
            normalization = self.normalization

        if normalization['TRAIN_CORR_REDUCTION']:
            self.feature_ref_array[key] = ['{0}_{1}'.format(key, sample)
                                           for sample in self.sample_ids]

            self.feature_ref_index[key] = {feat: pos for pos, feat in
                                           enumerate(self.feature_ref_array[key])}

        elif normalization['NB_FEATURES_TO_KEEP']:
            self.feature_ref_index[key] = {feat: pos for pos, feat in
                                           enumerate(self.feature_ref_array[key])}

    def normalize_training_array(self):
        """
        Normalise each training matrix and define the corresponding feature
        names and indexes.
        """
        for key in self.matrix_array:
            matrix = self.matrix_array[key].copy()
            matrix = self._normalize(matrix, key)

            self.matrix_train_array[key] = matrix
            self.matrix_ref_array[key] = self.matrix_train_array[key]
            self._define_train_features(key)

        self._stack_multiomics(self.matrix_train_array, self.feature_train_array)
        self._stack_multiomics(self.matrix_ref_array, self.feature_ref_array)
        self._stack_index()

    def _stack_index(self):
        """
        Merge the per-omic feature indexes into a single 'STACKED' index
        with shifted positions.
        """
        if not self.do_stack_multi_omic:
            return

        index = {'STACKED': {}}
        count = 0

        for key in self.feature_train_index:
            for feature in self.feature_train_index[key]:
                index['STACKED'][feature] = count + self.feature_train_index[key][feature]

            count += len(self.feature_train_index[key])

        self.feature_train_index = index
        self.feature_ref_index = self.feature_train_index

    def _normalize(self, matrix, key):
        """
        Apply the normalization steps defined in self.normalization to a
        training matrix.
        """
        if self.verbose:
            print('normalizing for {0}...'.format(key))

        if self.normalization['NB_FEATURES_TO_KEEP']:
            self.variance_reducer.nb_features = self.normalization[
                'NB_FEATURES_TO_KEEP']
            matrix = self.variance_reducer.fit_transform(matrix)

        if self.normalization['CUSTOM']:
            custom_norm = self.normalization['CUSTOM']()
            assert(hasattr(custom_norm, 'fit') and hasattr(
                custom_norm, 'fit_transform'))
            matrix = custom_norm.fit_transform(matrix)

        if self.normalization['TRAIN_MIN_MAX']:
            matrix = MinMaxScaler().fit_transform(matrix.T).T

        if self.normalization['TRAIN_MAD_SCALE']:
            matrix = self.mad_scaler.fit_transform(matrix.T).T

        if self.normalization['TRAIN_ROBUST_SCALE'] or\
           self.normalization['TRAIN_ROBUST_SCALE_TWO_WAY']:
            matrix = self.robust_scaler.fit_transform(matrix)

        if self.normalization['TRAIN_NORM_SCALE']:
            matrix = self.normalizer.fit_transform(matrix)

        if self.normalization['TRAIN_QUANTILE_TRANSFORM']:
            matrix = quantile_transform(matrix, **QUANTILE_OPTION)

        if self.normalization['TRAIN_RANK_NORM']:
            matrix = RankNorm().fit_transform(
                matrix)

        if self.normalization['TRAIN_CORR_REDUCTION']:
            args = self.normalization['TRAIN_CORR_REDUCTION']
            if args is True:
                args = {}

            if self.verbose:
                print('dim reduction for {0}...'.format(key))

            reducer = CorrelationReducer(**args)
            matrix = reducer.fit_transform(
                matrix)

            if self.normalization['TRAIN_CORR_RANK_NORM']:
                matrix = RankNorm().fit_transform(
                    matrix)

            if self.normalization['TRAIN_CORR_QUANTILE_NORM']:
                matrix = quantile_transform(matrix, **QUANTILE_OPTION)

            if self.normalization['TRAIN_CORR_NORM_SCALE']:
                matrix = self.normalizer.fit_transform(matrix)

        return np.nan_to_num(matrix)

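    # Hedged sketch of a normalization dict (unset keys default to False via
    # defaultdict(bool); the values below are illustrative):
    #
    #     normalization = {
    #         'NB_FEATURES_TO_KEEP': 100,    # variance-based feature selection
    #         'TRAIN_RANK_NORM': True,       # rank normalisation
    #         'TRAIN_CORR_REDUCTION': True,  # correlation-based reduction
    #     }
    #     loader = LoadData(normalization=normalization)
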
    def transform_matrices(self, matrix_ref, matrix, key, normalization=None):
        """
        Normalise a (reference, test) pair of matrices, fitting the scalers
        on the reference matrix.
        """
        if normalization is None:
            normalization = self.normalization

        if self.verbose:
            print('Scaling/Normalising dataset...')

        if normalization['LOG_REF_MATRIX']:
            matrix_ref = np.log2(1.0 + matrix_ref)

        if normalization['LOG_TEST_MATRIX']:
            matrix = np.log2(1.0 + matrix)

        if normalization['CUSTOM']:
            custom_norm = normalization['CUSTOM']()
            assert(hasattr(custom_norm, 'fit') and hasattr(
                custom_norm, 'fit_transform'))
            matrix_ref = custom_norm.fit_transform(matrix_ref)
            matrix = custom_norm.transform(matrix)

        if normalization['NB_FEATURES_TO_KEEP']:
            self.variance_reducer.nb_features = normalization[
                'NB_FEATURES_TO_KEEP']
            matrix_ref = self.variance_reducer.fit_transform(matrix_ref)
            matrix = self.variance_reducer.transform(matrix)

        if normalization['TRAIN_MIN_MAX']:
            matrix_ref = self.min_max_scaler.fit_transform(matrix_ref.T).T
            matrix = self.min_max_scaler.fit_transform(matrix.T).T

        if normalization['TRAIN_MAD_SCALE']:
            matrix_ref = self.mad_scaler.fit_transform(matrix_ref.T).T
            matrix = self.mad_scaler.fit_transform(matrix.T).T

        if normalization['TRAIN_ROBUST_SCALE']:
            matrix_ref = self.robust_scaler.fit_transform(matrix_ref)
            matrix = self.robust_scaler.transform(matrix)

        if normalization['TRAIN_ROBUST_SCALE_TWO_WAY']:
            matrix_ref = self.robust_scaler.fit_transform(matrix_ref)
            matrix = self.robust_scaler.transform(matrix)

        if normalization['TRAIN_NORM_SCALE']:
            matrix_ref = self.normalizer.fit_transform(matrix_ref)
            matrix = self.normalizer.transform(matrix)

        if normalization['TRAIN_QUANTILE_TRANSFORM']:
            matrix_ref = quantile_transform(matrix_ref, **QUANTILE_OPTION)
            matrix = quantile_transform(matrix, **QUANTILE_OPTION)

        if normalization['TRAIN_RANK_NORM']:
            matrix_ref = RankNorm().fit_transform(matrix_ref)
            matrix = RankNorm().fit_transform(matrix)

        if normalization['TRAIN_CORR_REDUCTION']:
            args = normalization['TRAIN_CORR_REDUCTION']

            if args is True:
                args = {}

            reducer = CorrelationReducer(**args)
            matrix_ref = reducer.fit_transform(matrix_ref)
            matrix = reducer.transform(matrix)

            if normalization['TRAIN_CORR_RANK_NORM']:
                matrix_ref = RankNorm().fit_transform(matrix_ref)
                matrix = RankNorm().fit_transform(matrix)

            if normalization['TRAIN_CORR_QUANTILE_TRANSFORM']:
                matrix_ref = quantile_transform(matrix_ref, **QUANTILE_OPTION)
                matrix = quantile_transform(matrix, **QUANTILE_OPTION)

            if normalization['TRAIN_CORR_NORM_SCALE']:
                matrix_ref = self.normalizer.fit_transform(matrix_ref)
                matrix = self.normalizer.fit_transform(matrix)

        return np.nan_to_num(matrix_ref), np.nan_to_num(matrix)

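    # Design note: where possible, scalers are fitted on the reference
    # (training) matrix and only applied to the test matrix (e.g. the
    # RobustScaler above), so no test statistics leak into the model;
    # row-wise transforms such as RankNorm are applied per matrix.
    # A minimal sklearn sketch (`ref` and `test` are toy arrays, assumptions):
    #
    #     scaler = RobustScaler()
    #     ref_scaled = scaler.fit_transform(ref)   # fit on training data
    #     test_scaled = scaler.transform(test)     # reuse training statistics
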
    def save_ref_matrix(self, path_folder, project_name):
        """
        Save each reference matrix to disk using save_matrix.
        """
        for key in self.matrix_ref_array:
            save_matrix(
                matrix=self.matrix_ref_array[key],
                feature_array=self.feature_ref_array[key],
                sample_array=self.sample_ids,
                path_folder=path_folder,
                project_name=project_name,
                key=key
            )