Diff of /simdeep/survival_utils.py [000000] .. [53737a]

"""
Utility functions for SimDeep: matrix scalers and feature reducers,
survival / omics file loaders, and parallel statistical helpers.
"""
import re

import pandas as pd

from simdeep.config import PATH_DATA
from simdeep.config import SURVIVAL_FLAG
from simdeep.config import SEPARATOR
from simdeep.config import ENTREZ_TO_ENSG_FILE
from simdeep.config import USE_INPUT_TRANSPOSE
from simdeep.config import DEFAULTSEP
from simdeep.config import CLASSIFIER

import numpy as np

from os.path import isfile

from scipy.stats import rankdata

from numpy import hstack

from sklearn.metrics import pairwise_distances

from sklearn.preprocessing import LabelBinarizer
from sklearn.preprocessing import RobustScaler

from collections import defaultdict

from simdeep.coxph_from_r import coxph
from simdeep.coxph_from_r import c_index

from scipy.stats import kruskal
from scipy.stats import ranksums

from os.path import isdir
from os import mkdir


################ DEBUG ################
# supposed to be None for normal usage
MAX_FEATURE = None
#######################################


class MadScaler():
    """Scale each row by its median absolute deviation (MAD)."""
    def __init__(self):
        pass

    def fit_transform(self, X):
        """Return (x - median) / MAD for each row."""
        X = np.asarray(X)

        for i in range(len(X)):
            med = np.median(X[i])
            mad = np.median(np.abs(X[i] - med))
            X[i] = (X[i] - med) / mad

        # rows with zero MAD yield NaN/inf values, neutralised by nan_to_num
        return np.nan_to_num(np.matrix(X))


class RankNorm():
    """Rank-normalise each row to the (0, 1] interval."""
    def __init__(self):
        pass

    def fit_transform(self, X):
        """Replace each row by its ranks divided by the number of columns."""
        X = np.asarray(X)
        shape = list(map(float, X.shape))

        for i in range(len(X)):
            X[i] = rankdata(X[i]) / shape[1]

        return np.matrix(X)

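# Illustrative usage sketch (not part of the original module; values are made
# up for illustration): both scalers expose a sklearn-like `fit_transform`
# that operates row-wise on a 2D array.
#
#   >>> mat = np.array([[1., 2., 10.], [4., 5., 6.]])
#   >>> MadScaler().fit_transform(mat)   # row-wise (x - median) / MAD
#   >>> RankNorm().fit_transform(mat)    # row-wise ranks / nb_columns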

class SampleReducer():
    """Rank samples by their total signal and keep the top fraction."""
    def __init__(self, perc_sample_to_keep=0.90):
        assert(isinstance(perc_sample_to_keep, float))
        assert(0.0 < perc_sample_to_keep < 1.0)
        self.perc_sample_to_keep = perc_sample_to_keep

    def sample_to_keep(self, datasets, index=None):
        """Return the positions of the samples to keep and to remove."""
        nb_samples = len(list(datasets.values())[0][index])
        scores = np.zeros(nb_samples)
        threshold = int(nb_samples * self.perc_sample_to_keep)

        for key in datasets:
            scores_array = np.array([vector.sum() for vector in datasets[key][index]])
            scores = scores + scores_array

        scores = [(pos, score) for pos, score in enumerate(scores)]

        scores.sort(key=lambda x: x[1], reverse=True)
        to_keep = [pos for pos, score in scores[:threshold]]
        to_remove = [pos for pos, score in scores[threshold:]]

        return to_keep, to_remove


class VarianceReducer():
    """Keep the `nb_features` features (columns) with the highest variance."""
    def __init__(self, nb_features=200):
        self.nb_features = nb_features
        self.index_to_keep = []

    def fit(self, dataset):
        """Rank the features by variance and store the indexes to keep."""
        if self.nb_features > dataset.shape[1]:
            self.nb_features = dataset.shape[1]

        variances = [np.var(array) for array in dataset.T]
        ranked = sorted(enumerate(variances),
                        reverse=True,
                        key=lambda x: x[1])
        self.index_to_keep = [pos for pos, var in ranked[:self.nb_features]]

    def transform(self, dataset):
        """Return the dataset restricted to the selected features."""
        return dataset.T[self.index_to_keep].T

    def fit_transform(self, dataset):
        """Fit then transform."""
        self.fit(dataset)
        return self.transform(dataset)


class CorrelationReducer():
    """Encode each sample by its similarity (1 - distance) to the fitted dataset."""
    def __init__(self, distance='correlation', threshold=None):
        self.distance = distance
        self.dataset = None
        self.threshold = threshold

    def fit(self, dataset):
        """Store the reference dataset (values below `threshold` are zeroed)."""
        self.dataset = dataset

        if self.threshold:
            self.dataset[self.dataset < self.threshold] = 0

    def transform(self, dataset):
        """Return 1 - pairwise distances between `dataset` and the fitted dataset."""
        if self.threshold:
            dataset[dataset < self.threshold] = 0

        return 1.0 - pairwise_distances(dataset,
                                        self.dataset,
                                        metric=self.distance)

    def fit_transform(self, dataset):
        """Fit then transform."""
        self.fit(dataset)
        return self.transform(dataset)


class RankCorrNorm():
    """
    """
    def __init__(self, dataset):
        """
        """
        self.dataset = dataset


def load_survival_file(f_name,
                       path_data=PATH_DATA,
                       sep=DEFAULTSEP,
                       survival_flag=SURVIVAL_FLAG):
    """Load a survival table and return {patient_id: (nb_days, is_dead)}."""
    if f_name in SEPARATOR:
        sep = SEPARATOR[f_name]

    survival = {}

    filename = '{0}/{1}'.format(path_data, f_name)

    if not isfile(filename):
        raise Exception('## Error: nonexistent file: {0}'.format(filename))

    with open(filename, 'r') as f_surv:
        first_line = f_surv.readline().strip(' \n\r\t').split(sep)

        for field in survival_flag.values():
            try:
                assert(field in first_line)
            except Exception:
                raise Exception(
                    '#### Exception with survival file {fil}: the header '
                    '"{header}" does not contain the field "{field}". '
                    'Please define new survival flags using the `survival_flag` '
                    'variable. Current invalid survival_flag: {sflag}. '
                    'The header must provide the fields mapped to the keys: '
                    '"patient_id" => ID of the patient, '
                    '"survival" => time of the study, '
                    '"event" => event after survival time.'.format(
                        header=first_line,
                        field=field,
                        fil=filename,
                        sflag=survival_flag,
                    ))

        patient_id = first_line.index(survival_flag['patient_id'])
        surv_id = first_line.index(survival_flag['survival'])
        event_id = first_line.index(survival_flag['event'])

        for line in f_surv:
            line = line.strip('\r\n').split(sep)
            ids = line[patient_id].strip('"')
            ndays = line[surv_id].strip('"')
            isdead = line[event_id].strip('"')

            survival[ids] = (float(ndays), float(isdead))

    return survival

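# Illustrative example (assumption, not taken from this file): with a
# survival_flag mapping such as
# {'patient_id': 'Samples', 'survival': 'days', 'event': 'event'},
# a valid survival file would look like (tab-separated):
#
#   Samples    days    event
#   sample_1   340     1
#   sample_2   517     0
#
# and would be loaded with:
#
#   >>> survival = load_survival_file('survival.tsv', path_data=PATH_DATA)
#   >>> survival['sample_1']
#   (340.0, 1.0)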

def translate_index(original_ids, new_ids):
    """Return the positions of `new_ids` within `original_ids`."""
    index1d = {ids: pos for pos, ids in enumerate(original_ids)}

    return np.asarray([index1d[sample] for sample in new_ids])


def return_intersection_indexes(ids_1, ids_2):
    """Return the indexes of the common samples in both ID lists, plus the intersection."""
    index1d = {ids: pos for pos, ids in enumerate(ids_1)}
    index2d = {ids: pos for pos, ids in enumerate(ids_2)}

    inter = set(ids_1).intersection(ids_2)

    if len(inter) == 0:
        raise Exception("Error! No common sample index between: {0}... and {1}...".format(
            ids_1[:2], ids_2[:2]))

    index1 = np.asarray([index1d[sample] for sample in inter])
    index2 = np.asarray([index2d[sample] for sample in inter])

    return index1, index2, list(inter)


def convert_metadata_frame_to_matrix(frame):
    """One-hot encode categorical columns and robust-scale numerical columns."""
    lbl = LabelBinarizer()

    normed_matrix = np.zeros((frame.shape[0], 0))
    keys = []

    for key in frame.keys():
        if str(frame[key].dtype) in ('object', 'string'):
            matrix = lbl.fit_transform(frame[key].astype('string'))

            if lbl.y_type_ == "binary":
                keys += ["{0}_{1}".format(key, lbl.classes_[lbl.pos_label])]
            else:
                keys += ["{0}_{1}".format(key, k) for k in lbl.classes_]
        else:
            rbs = RobustScaler()
            matrix = np.asarray(frame[key]).reshape((frame.shape[0], 1))
            matrix = rbs.fit_transform(matrix)
            keys.append(key)

        normed_matrix = hstack([normed_matrix, matrix])

    return pd.DataFrame(normed_matrix, columns=keys)

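# Illustrative example (assumption, not taken from this file): for a metadata
# frame mixing a categorical and a numerical column,
#
#   >>> frame = pd.DataFrame({'sex': ['M', 'F', 'M'], 'age': [45, 62, 58]})
#   >>> convert_metadata_frame_to_matrix(frame)
#
# the 'sex' column is binarized into a single 'sex_M' indicator, while the
# 'age' column is scaled with RobustScaler and kept as a single 'age' column.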

def load_data_from_tsv(use_transpose=USE_INPUT_TRANSPOSE, **kwargs):
    """Load a tsv matrix, using the transposed reader if `use_transpose` is set."""
    if use_transpose:
        return _load_data_from_tsv_transposee(**kwargs)
    else:
        return _load_data_from_tsv(**kwargs)


def _load_data_from_tsv(
        f_name,
        key,
        path_data=PATH_DATA,
        f_type=float,
        sep=DEFAULTSEP,
        nan_to_num=True):
    """Read a samples x features tsv file and return (sample_ids, feature_ids, matrix)."""
    f_short = key

    if f_name in SEPARATOR:
        sep = SEPARATOR[f_name]

    f_tsv = open("{0}/{1}".format(path_data, f_name))
    header = f_tsv.readline().strip(sep + '\n').split(sep)

    feature_ids = ['{0}_{1}'.format(f_short, head)
                   for head in header][:MAX_FEATURE]
    sample_ids = []
    f_matrix = []

    for line in f_tsv:
        line = line.strip(sep + '\n').split(sep)
        sample_ids.append(line[0])

        if nan_to_num:
            # replace non-numeric or empty entries with 0
            line[1:] = [0 if (l.isalpha() or not l) else l
                        for l in line[1:MAX_FEATURE]]

        f_matrix.append(list(map(f_type, line[1:MAX_FEATURE])))

    f_matrix = np.array(f_matrix)

    # drop the header cell of the sample ID column if it was present
    if f_matrix.shape[1] == len(feature_ids) - 1:
        feature_ids = feature_ids[1:]

    assert(f_matrix.shape[1] == len(feature_ids))
    assert(f_matrix.shape[0] == len(sample_ids))

    f_tsv.close()

    return sample_ids, feature_ids, f_matrix

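# Illustrative input layout (assumption, not taken from this file): the
# non-transposed reader expects a header of feature names followed by one row
# per sample, e.g. (tab-separated)
#
#   [sample column]  gene_A  gene_B
#   sample_1         0.5     1.2
#   sample_2         0.0     3.4
#
#   >>> ids, features, mat = _load_data_from_tsv('rna.tsv', key='RNA')
#   >>> features   # ['RNA_gene_A', 'RNA_gene_B']
#
# The transposed reader below expects the mirror layout (features as rows,
# samples as columns).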

def _format_sample_name(sample_ids):
    """Strip quotes and '_1_[A-Z][A-Z]' suffixes from sample names."""
    regex = re.compile('_1_[A-Z][A-Z]')

    sample_ids = [regex.sub('', sample.strip('"')) for sample in sample_ids]
    return sample_ids


def _load_data_from_tsv_transposee(
        f_name,
        key,
        path_data=PATH_DATA,
        f_type=float,
        sep=DEFAULTSEP,
        nan_to_num=True):
    """Read a features x samples tsv file and return (sample_ids, feature_ids, matrix)."""
    if f_name in SEPARATOR:
        sep = SEPARATOR[f_name]

    f_tsv = open('{0}/{1}'.format(path_data, f_name))
    header = f_tsv.readline().strip(sep + '\n').split(sep)

    sample_ids = header[1:]

    sample_ids = _format_sample_name(sample_ids)

    feature_ids = []
    f_matrix = []

    if 'entrez' in f_name.lower():
        ensg_dict = load_entrezID_to_ensg()
        use_ensg = True
    else:
        use_ensg = False

    for line in f_tsv:
        line = line.strip(sep + '\n').split(sep)
        feature = line[0].strip('"')

        if nan_to_num:
            # replace non-numeric or empty entries with 0
            line[1:] = [0 if (l.isalpha() or not l) else l
                        for l in line[1:]]

        # map Entrez IDs to (possibly several) ENSG IDs when available
        if use_ensg and feature in ensg_dict:
            features = ensg_dict[feature]
        else:
            features = [feature]

        for feature in features:
            feature_ids.append('{0}_{1}'.format(key, feature))
            f_matrix.append(list(map(f_type, line[1:])))

    f_matrix = np.array(f_matrix).T

    assert(f_matrix.shape[1] == len(feature_ids))
    assert(f_matrix.shape[0] == len(sample_ids))

    f_tsv.close()

    return sample_ids, feature_ids, f_matrix


def select_best_classif_params(clf):
    """
    Select the best classifier parameters based solely
    on the cross-validation test scores.
    """
    arr = []

    for fold in range(clf.cv):
        arr.append(clf.cv_results_[
            'split{0}_test_score'.format(fold)])

    score = [ar.max() for ar in np.array(arr).T]
    index = score.index(max(score))

    params = clf.cv_results_['params'][index]

    clf = CLASSIFIER(**params)

    return clf, params

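# Illustrative usage sketch (assumption: `clf` is a fitted sklearn
# GridSearchCV-like object exposing an integer `cv` and `cv_results_`):
#
#   >>> grid = GridSearchCV(CLASSIFIER(), param_grid, cv=5).fit(X, y)
#   >>> best_clf, best_params = select_best_classif_params(grid)
#   >>> best_clf.fit(X, y)   # re-fit the freshly instantiated classifier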

def load_entrezID_to_ensg():
    """Load the Entrez ID => list of ENSG IDs mapping from ENTREZ_TO_ENSG_FILE."""
    entrez_dict = {}

    for line in open(ENTREZ_TO_ENSG_FILE):
        line = line.split()
        entrez_dict[line[0]] = line[1:]

    return entrez_dict


def _process_parallel_coxph(inp):
    """Unpack one work item and return (node_id, Cox-PH p-value)."""
    node_id, activity, isdead, nbdays, seed, metadata_mat, use_r_packages = inp

    pvalue = coxph(activity,
                   isdead,
                   nbdays,
                   seed=seed,
                   metadata_mat=metadata_mat,
                   use_r_packages=use_r_packages)

    return node_id, pvalue


def _process_parallel_cindex(inp):
    """Unpack one work item and return (node_id, c-index)."""
    (node_id,
     act_ref, isdead_ref, nbdays_ref,
     act_test, isdead_test, nbdays_test, use_r_packages) = inp

    score = c_index(act_ref, isdead_ref, nbdays_ref,
                    act_test, isdead_test, nbdays_test,
                    use_r_packages=use_r_packages)

    return node_id, score


def _process_parallel_feature_importance(inp):
    """Return (feature, Kruskal-Wallis p-value) for a feature across cluster labels."""
    arrays = defaultdict(list)
    feature, array, labels = inp

    for label, value in zip(labels, np.array(array).reshape(-1)):
        arrays[label].append(value)

    try:
        score, pvalue = kruskal(*arrays.values())
    except Exception:
        return feature, 1.0

    return feature, pvalue


def _process_parallel_feature_importance_per_cluster(inp):
    """Return (cluster, feature, median difference, p-value) tuples for significant features."""
    arrays = defaultdict(list)
    results = []

    feature, array, labels, pval_thres = inp

    for label, value in zip(labels, np.array(array).reshape(-1)):
        arrays[label].append(value)

    for cluster in arrays:
        array = np.array(arrays[cluster])
        array_comp = np.array([a for comp in arrays for a in arrays[comp]
                               if comp != cluster])

        # Wilcoxon rank-sum test: this cluster against all the other clusters
        score, pvalue = ranksums(array, array_comp)
        median_diff = np.median(array) - np.median(array_comp)

        if pvalue < pval_thres:
            results.append((cluster, feature, median_diff, pvalue))

    return results


def _process_parallel_survival_feature_importance_per_cluster(inp):
    """Return (feature, p-value) if the Cox-PH p-value passes the threshold, else (None, None)."""
    feature, array, survival, metadata_mat, pval_thres, use_r_packages = inp
    nbdays, isdead = survival.T.tolist()

    pvalue = coxph(
        array,
        isdead,
        nbdays,
        metadata_mat=metadata_mat,
        use_r_packages=use_r_packages
    )

    if not np.isnan(pvalue) and pvalue < pval_thres:
        return (feature, pvalue)

    return None, None


def save_matrix(matrix, feature_array, sample_array,
                path_folder, project_name, key='', sep='\t'):
    """Write a samples x features matrix to '<path_folder>/<project_name>[_<key>].tsv'."""
    if not isdir(path_folder):
        mkdir(path_folder)

    if key:
        key = '_' + key

    f_csv = open('{0}/{1}{2}.tsv'.format(path_folder, project_name, key), 'w')

    # header row: feature names with the '<key>_' prefix stripped
    f_csv.write(sep + sep.join(map(lambda x: x.split('_', 1)[-1], feature_array)) + '\n')

    for sample, vector in zip(sample_array, matrix):
        vector = np.asarray(vector).reshape(-1)
        f_csv.write('{0}{1}'.format(sample, sep) + sep.join(map(str, vector)) + '\n')

    f_csv.close()

    print('{0}/{1}{2}.tsv saved'.format(path_folder, project_name, key))
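
# Illustrative usage sketch (assumption, not taken from this file; names are
# hypothetical placeholders):
#
#   >>> save_matrix(mat, feature_ids, sample_ids,
#   ...             path_folder='./results', project_name='test_project',
#   ...             key='encoded')
#   ./results/test_project_encoded.tsv saved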