|
a |
|
b/simdeep/simdeep_analysis.py |
|
|
1 |
""" |
|
|
2 |
DeepProg class for one instance model |
|
|
3 |
""" |
|
|
4 |
|
|
|
5 |
from sklearn.cluster import KMeans |
|
|
6 |
from sklearn.mixture import GaussianMixture |
|
|
7 |
from sklearn.model_selection import cross_val_score |
|
|
8 |
|
|
|
9 |
from simdeep.deepmodel_base import DeepBase |
|
|
10 |
|
|
|
11 |
from simdeep.survival_model_utils import ClusterWithSurvival |
|
|
12 |
|
|
|
13 |
from simdeep.config import NB_CLUSTERS |
|
|
14 |
from simdeep.config import CLUSTER_ARRAY |
|
|
15 |
from simdeep.config import PVALUE_THRESHOLD |
|
|
16 |
from simdeep.config import CINDEX_THRESHOLD |
|
|
17 |
from simdeep.config import CLASSIFIER_TYPE |
|
|
18 |
from simdeep.config import USE_AUTOENCODERS |
|
|
19 |
from simdeep.config import FEATURE_SURV_ANALYSIS |
|
|
20 |
from simdeep.config import SEED |
|
|
21 |
|
|
|
22 |
from simdeep.config import MIXTURE_PARAMS |
|
|
23 |
from simdeep.config import PATH_RESULTS |
|
|
24 |
from simdeep.config import PROJECT_NAME |
|
|
25 |
from simdeep.config import CLASSIFICATION_METHOD |
|
|
26 |
|
|
|
27 |
from simdeep.config import CLUSTER_EVAL_METHOD |
|
|
28 |
from simdeep.config import CLUSTER_METHOD |
|
|
29 |
from simdeep.config import NB_THREADS_COXPH |
|
|
30 |
from simdeep.config import NB_SELECTED_FEATURES |
|
|
31 |
from simdeep.config import LOAD_EXISTING_MODELS |
|
|
32 |
from simdeep.config import NODES_SELECTION |
|
|
33 |
from simdeep.config import CLASSIFIER |
|
|
34 |
from simdeep.config import HYPER_PARAMETERS |
|
|
35 |
from simdeep.config import PATH_TO_SAVE_MODEL |
|
|
36 |
from simdeep.config import CLUSTERING_OMICS |
|
|
37 |
from simdeep.config import USE_R_PACKAGES_FOR_SURVIVAL |
|
|
38 |
|
|
|
39 |
from simdeep.survival_utils import _process_parallel_coxph |
|
|
40 |
from simdeep.survival_utils import _process_parallel_cindex |
|
|
41 |
from simdeep.survival_utils import _process_parallel_feature_importance |
|
|
42 |
from simdeep.survival_utils import _process_parallel_feature_importance_per_cluster |
|
|
43 |
from simdeep.survival_utils import select_best_classif_params |
|
|
44 |
|
|
|
45 |
from simdeep.simdeep_utils import metadata_usage_type |
|
|
46 |
from simdeep.simdeep_utils import feature_selection_usage_type |
|
|
47 |
|
|
|
48 |
from simdeep.simdeep_utils import load_labels_file |
|
|
49 |
|
|
|
50 |
from simdeep.coxph_from_r import coxph |
|
|
51 |
from simdeep.coxph_from_r import c_index |
|
|
52 |
from simdeep.coxph_from_r import c_index_multiple |
|
|
53 |
|
|
|
54 |
from simdeep.coxph_from_r import surv_median |
|
|
55 |
|
|
|
56 |
from collections import Counter |
|
|
57 |
|
|
|
58 |
from sklearn.metrics import silhouette_score |
|
|
59 |
|
|
|
60 |
try: |
|
|
61 |
from sklearn.metrics import calinski_harabasz_score \ |
|
|
62 |
as calinski_harabaz_score |
|
|
63 |
except Exception: |
|
|
64 |
from sklearn.metrics import calinski_harabaz_score |
|
|
65 |
|
|
|
66 |
from sklearn.model_selection import GridSearchCV |
|
|
67 |
|
|
|
68 |
import numpy as np |
|
|
69 |
from numpy import hstack |
|
|
70 |
|
|
|
71 |
from collections import defaultdict |
|
|
72 |
|
|
|
73 |
import warnings |
|
|
74 |
|
|
|
75 |
from multiprocessing import Pool |
|
|
76 |
|
|
|
77 |
from os.path import isdir |
|
|
78 |
from os import mkdir |
|
|
79 |
|
|
|
80 |
|
|
|
81 |
################ VARIABLE ############################################
# Values accepted for the `classification_method` option; checked by
# SimDeep._return_train_matrix_for_classification.
_CLASSIFICATION_METHOD_LIST = ['ALL_FEATURES', 'SURVIVAL_FEATURES']
# NOTE(review): not referenced in the visible part of this file --
# presumably a model-level significance threshold; confirm against callers.
MODEL_THRES = 0.05
######################################################################
|
|
85 |
|
|
|
86 |
|
|
|
87 |
class SimDeep(DeepBase): |
|
|
88 |
""" |
|
|
89 |
Instantiate a new DeepProg instance. |
|
|
90 |
The default parameters are defined in the config.py file |
|
|
91 |
|
|
|
92 |
Parameters: |
|
|
93 |
:dataset: ExtractData instance. Default None (create a new dataset using the config variable) |
|
|
94 |
:nb_clusters: Number of clusters to search (default NB_CLUSTERS) |
|
|
95 |
:pvalue_thres: Pvalue threshold to include a feature (default PVALUE_THRESHOLD) |
|
|
96 |
:clustering_omics: Which omics to use for clustering. If empty, then all the available omics will be used |
|
|
97 |
:cindex_thres: C-index threshold to include a feature. This parameter is used only if `node_selection` is set to "C-index" (default CINDEX_THRESHOLD) |
|
|
98 |
:cluster_method: Cluster method to use. possible choice ['mixture', 'kmeans']. (default CLUSTER_METHOD) |
|
|
99 |
:cluster_eval_method: Cluster evaluation method to use in case the `cluster_array` parameter is a list of possible K. Possible choice ['bic', 'silhouette', 'calinski'] (default CLUSTER_EVAL_METHOD) |
|
|
100 |
:classifier_type: Type of classifier to use. Possible choice ['svm', 'clustering']. If 'clustering' is selected, the predict method of the clustering algorithm is used (default CLASSIFIER_TYPE) |
|
|
101 |
:project_name: Name of the project. This name will be used to save the output files and create the output folder (default PROJECT_NAME) |
|
|
102 |
:path_results: Result folder path used to save the output files (default PATH_RESULTS) |
|
|
103 |
:cluster_array: Array of possible number of clusters to try. If set, `nb_clusters` is ignored (default CLUSTER_ARRAY) |
|
|
104 |
:nb_selected_features: Number of selected features to construct classifiers (default NB_SELECTED_FEATURES) |
|
|
105 |
:mixture_params: Dictionary of parameters used to instantiate the Gaussian mixture algorithm (default MIXTURE_PARAMS) |
|
|
106 |
:node_selection: Method to select new features. Possible choice ['Cox-PH', 'C-index']. (default NODES_SELECTION) |
|
|
107 |
:nb_threads_coxph: Number of python processes to use to compute individual survival models in parallel (default NB_THREADS_COXPH) |
|
|
108 |
:classification_method: Possible choice ['ALL_FEATURES', 'SURVIVAL_FEATURES']. If 'SURVIVAL_FEATURES' is selected, the classifiers are built using survival features (default CLASSIFICATION_METHOD) |
|
|
109 |
:load_existing_models: (default LOAD_EXISTING_MODELS) |
|
|
110 |
:path_to_save_model: (default PATH_TO_SAVE_MODEL) |
|
|
111 |
:metadata_usage: Meta data usage with survival models (if metadata_tsv provided as argument to the dataset). Possible choice are [None, False, 'labels', 'new-features', 'all', True] (True is the same as all) |
|
|
112 |
:feature_selection_usage: selection method for survival features ('individual' or 'lasso') |
|
|
113 |
:alternative_embedding: alternative external embedding to use instead of building autoencoders (default None) |
|
|
114 |
:kwargs_alternative_embedding: parameters for external embedding fitting |
|
|
115 |
""" |
|
|
116 |
def __init__(self,
             nb_clusters=NB_CLUSTERS,
             pvalue_thres=PVALUE_THRESHOLD,
             cindex_thres=CINDEX_THRESHOLD,
             use_autoencoders=USE_AUTOENCODERS,
             feature_surv_analysis=FEATURE_SURV_ANALYSIS,
             cluster_method=CLUSTER_METHOD,
             cluster_eval_method=CLUSTER_EVAL_METHOD,
             classifier_type=CLASSIFIER_TYPE,
             project_name=PROJECT_NAME,
             path_results=PATH_RESULTS,
             cluster_array=CLUSTER_ARRAY,
             nb_selected_features=NB_SELECTED_FEATURES,
             mixture_params=MIXTURE_PARAMS,
             node_selection=NODES_SELECTION,
             nb_threads_coxph=NB_THREADS_COXPH,
             classification_method=CLASSIFICATION_METHOD,
             load_existing_models=LOAD_EXISTING_MODELS,
             path_to_save_model=PATH_TO_SAVE_MODEL,
             clustering_omics=CLUSTERING_OMICS,
             metadata_usage=None,
             feature_selection_usage='individual',
             use_r_packages=USE_R_PACKAGES_FOR_SURVIVAL,
             seed=SEED,
             alternative_embedding=None,
             do_KM_plot=True,
             verbose=True,
             _isboosting=False,
             dataset=None,
             kwargs_alternative_embedding=None,
             deep_model_additional_args=None):
    """
    Store the configuration, reset every result placeholder and
    initialise the DeepBase parent.

    The public parameters are described in the class docstring.
    `kwargs_alternative_embedding` and `deep_model_additional_args`
    default to None, which is treated as an empty dictionary.

    The result folder `path_results` is created (single level, via
    mkdir) when it is set and does not exist yet.
    """
    # A `{}` default is shared between all calls. Use None as the
    # default and build a fresh dict per instance instead.
    if kwargs_alternative_embedding is None:
        kwargs_alternative_embedding = {}

    # 'path_to_save_model' is injected below: work on a private copy so
    # neither a shared default nor the caller's dictionary is modified.
    deep_model_additional_args = dict(deep_model_additional_args or {})

    self.seed = seed
    self.nb_clusters = nb_clusters
    self.pvalue_thres = pvalue_thres
    self.cindex_thres = cindex_thres
    self.use_autoencoders = use_autoencoders
    self.classifier_grid = GridSearchCV(CLASSIFIER(), HYPER_PARAMETERS, cv=5)
    self.cluster_array = cluster_array
    self.path_results = path_results
    self.clustering_omics = clustering_omics
    self.use_r_packages = use_r_packages
    self.metadata_usage = metadata_usage_type(metadata_usage)
    self.feature_selection_usage = feature_selection_usage_type(
        feature_selection_usage)

    self.feature_surv_analysis = feature_surv_analysis

    # Without a feature selection method the survival analysis of the
    # features is switched off.
    if self.feature_selection_usage is None:
        self.feature_surv_analysis = False

    self.alternative_embedding = alternative_embedding
    self.kwargs_alternative_embedding = kwargs_alternative_embedding

    if self.path_results and not isdir(self.path_results):
        mkdir(self.path_results)

    self.mixture_params = mixture_params

    # `project_name` may later be suffixed by load_new_test_dataset;
    # `_project_name` keeps the original value.
    self.project_name = project_name
    self._project_name = project_name
    self.do_KM_plot = do_KM_plot
    self.nb_threads_coxph = nb_threads_coxph
    self.classification_method = classification_method
    self.nb_selected_features = nb_selected_features
    self.node_selection = node_selection

    # p-values filled in by the fit / predict methods
    self.train_pvalue = None
    self.train_pvalue_proba = None
    self.full_pvalue = None
    self.full_pvalue_proba = None
    self.cv_pvalue = None
    self.cv_pvalue_proba = None
    self.test_pvalue = None
    self.test_pvalue_proba = None

    self.classifier = None
    self.classifier_test = None
    self.clustering = None

    self.classifier_dict = {}

    self.encoder_for_kde_plot_dict = {}
    self._main_kernel = {}

    self.classifier_type = classifier_type

    self.used_normalization = None
    self.test_normalization = None

    self.used_features_for_classif = None

    self._isboosting = _isboosting
    self._pretrained_model = False
    self._is_fitted = False

    self.valid_node_ids_array = {}
    self.activities_array = {}
    self.activities_pred_array = {}
    self.pred_node_ids_array = {}

    self.activities_train = None
    self.activities_test = None
    self.activities_cv = None

    self.activities_for_pred_train = None
    self.activities_for_pred_test = None
    self.activities_for_pred_cv = None

    # labels filled in by the predict methods
    self.test_labels = None
    self.test_labels_proba = None
    self.cv_labels = None
    self.cv_labels_proba = None
    self.full_labels = None
    self.full_labels_proba = None

    self.labels = None
    self.labels_proba = None

    self.training_omic_list = []
    self.test_omic_list = []

    self.feature_scores = defaultdict(list)
    self.feature_scores_per_cluster = {}

    self._label_ordered_dict = {}

    self.clustering_performance = None
    self.bic_score = None
    self.silhouette_score = None
    self.calinski_score = None

    self.cluster_method = cluster_method
    self.cluster_eval_method = cluster_eval_method
    self.verbose = verbose
    self._load_existing_models = load_existing_models
    self._features_scores_changed = False

    self.path_to_save_model = path_to_save_model

    # Always overrides a 'path_to_save_model' entry supplied by the caller.
    deep_model_additional_args['path_to_save_model'] = self.path_to_save_model

    DeepBase.__init__(self,
                      verbose=self.verbose,
                      dataset=dataset,
                      alternative_embedding=self.alternative_embedding,
                      kwargs_alternative_embedding=self.kwargs_alternative_embedding,
                      **deep_model_additional_args)
|
|
265 |
|
|
|
266 |
def _look_for_nodes(self, key):
    """
    Dispatch the node selection for `key` according to
    `self.node_selection`.

    'Cox-PH' delegates to `_look_for_survival_nodes` (which also receives
    the training metadata matrix when `metadata_usage` is 'all' or
    'new-features'); 'C-index' delegates to `_look_for_prediction_nodes`.
    """
    selection = self.node_selection
    assert selection in ('Cox-PH', 'C-index')

    metadata_mat = None

    if self.metadata_usage in ('all', 'new-features'):
        candidate = self.dataset.metadata_mat

        if candidate is not None:
            metadata_mat = candidate

    if selection == 'C-index':
        return self._look_for_prediction_nodes(key)

    if selection == 'Cox-PH':
        return self._look_for_survival_nodes(
            key, metadata_mat=metadata_mat)
|
|
283 |
|
|
|
284 |
def load_new_test_dataset(self, tsv_dict,
                          fname_key=None,
                          path_survival_file=None,
                          normalization=None,
                          survival_flag=None,
                          metadata_file=None):
    """
    Load a new test dataset through the dataset object and refresh the
    state that depends on its normalization.

    Parameters:
        :tsv_dict: forwarded as is to `dataset.load_new_test_dataset`
        :fname_key: if set, `project_name` becomes
            '<original project name>_<fname_key>'
        :path_survival_file: forwarded to the dataset loader
        :normalization: normalization options of the test set. When
            None, the normalization of the dataset is used
        :survival_flag: forwarded to the dataset loader
        :metadata_file: forwarded to the dataset loader
    """
    self.dataset.load_new_test_dataset(
        tsv_dict,
        path_survival_file,
        normalization=normalization,
        survival_flag=survival_flag,
        metadata_file=metadata_file
    )

    if normalization is None:
        source = self.dataset.normalization
    else:
        source = normalization

    # Only the options that are switched on are kept
    self.test_normalization = {
        name: source[name] for name in source if source[name]}

    normalization_changed = (
        self.used_normalization != self.test_normalization)

    if normalization_changed:
        if self.verbose:
            print('recombuting feature scores...')

        # Emptying the scores forces compute_feature_scores to run again
        self.feature_scores = {}
        self.compute_feature_scores(use_ref=True)
        self._features_scores_changed = True

    if fname_key:
        self.project_name = '{0}_{1}'.format(self._project_name, fname_key)
|
|
322 |
|
|
|
323 |
def fit_on_pretrained_label_file(self, label_file):
    """
    Fit a DeepProg SimDeep model without training autoencoders, using
    only an ID -> labels file to train a classifier.

    Samples of the dataset found in the label file form the training
    split; the remaining samples, if any, form the test split of the
    cross-validation instance.

    Parameters:
        :label_file: path passed to `load_labels_file`; each sample ID
            must map to a (label, label probability) pair
    """
    self._pretrained_model = True
    self.use_autoencoders = False
    self.feature_surv_analysis = False

    self.dataset.load_array()
    self.dataset.load_survival()
    self.dataset.load_meta_data()
    self.dataset.subset_training_sets()

    labels_dict = load_labels_file(label_file)

    train, test, labels, labels_proba = [], [], [], []

    for index, sample in enumerate(self.dataset.sample_ids):
        if sample in labels_dict:
            train.append(index)
            label, label_proba = labels_dict[sample]

            labels.append(label)
            labels_proba.append(label_proba)
        else:
            test.append(index)

    if test:
        self.dataset.cross_validation_instance = (train, test)
    else:
        self.dataset.cross_validation_instance = None

    self.dataset.create_a_cv_split()
    self.dataset.normalize_training_array()

    self.matrix_train_array = self.dataset.matrix_train_array

    for key in self.matrix_train_array:
        self.matrix_train_array[key] = self.matrix_train_array[key].astype('float32')

    # Store a real list, as `fit` does: a dict_keys view stays tied to
    # the source dictionary and cannot be pickled.
    self.training_omic_list = list(self.dataset.training_tsv.keys())

    self.predict_labels_using_external_labels(labels, labels_proba)

    # Only the normalization options that are switched on are kept
    self.used_normalization = {key: self.dataset.normalization[key]
                               for key in self.dataset.normalization
                               if self.dataset.normalization[key]}

    self.used_features_for_classif = self.dataset.feature_train_array
    self.look_for_survival_nodes()
    self.fit_classification_model()
|
|
376 |
|
|
|
377 |
def predict_labels_using_external_labels(self, labels, labels_proba):
    """
    Use externally provided labels as the training labels and compute
    the associated Cox-PH p-values.

    Parameters:
        :labels: label of each training sample
        :labels_proba: probability attached to each label. The same
            values are repeated in one column per distinct label to
            build `self.labels_proba`
    """
    self.labels = labels
    n_classes = len(set(self.labels))
    self.labels_proba = np.array([labels_proba] * n_classes).T

    nbdays, isdead = self.dataset.survival.T.tolist()

    km_fig_name = '{0}_KM_plot_training_dataset'.format(self.project_name)

    pvalue = coxph(self.labels, isdead, nbdays,
                   isfactor=False,
                   do_KM_plot=self.do_KM_plot,
                   png_path=self.path_results,
                   seed=self.seed,
                   use_r_packages=self.use_r_packages,
                   fig_name=km_fig_name)

    # Second model: fitted on the first probability column
    pvalue_proba = coxph(self.labels_proba.T[0], isdead, nbdays,
                         seed=self.seed,
                         use_r_packages=self.use_r_packages,
                         isfactor=False)

    if not self._isboosting:
        labels_fname = '{0}_training_set_labels'.format(self.project_name)
        self._write_labels(self.dataset.sample_ids, self.labels,
                           labels_proba=self.labels_proba.T[0],
                           fname=labels_fname)

    if self.verbose:
        print('Cox-PH p-value (Log-Rank) for the cluster labels: {0}'.format(pvalue))

    self.train_pvalue, self.train_pvalue_proba = pvalue, pvalue_proba
|
|
409 |
|
|
|
410 |
def fit(self):
    """
    Main entry point of the model:

    I) load the existing encoders if requested; when no model is
       loaded, construct the autoencoders or fit the alternative
       embedding
    II) look for the nodes linked with survival
    III) predict the labels (clustering) and fit the classification
       model
    """
    if self._load_existing_models:
        self.load_encoders()

    if not self.is_model_loaded:
        if self.alternative_embedding is None:
            self.construct_autoencoders()
        else:
            self.fit_alternative_embedding()

    self.look_for_survival_nodes()

    self.training_omic_list = [omic for omic in self.encoder_array.keys()]
    self.predict_labels()

    # Remember which normalization options were switched on for training
    active_normalization = {}

    for name in self.dataset.normalization:
        if self.dataset.normalization[name]:
            active_normalization[name] = self.dataset.normalization[name]

    self.used_normalization = active_normalization

    self.used_features_for_classif = self.dataset.feature_train_array
    self.fit_classification_model()
|
|
437 |
|
|
|
438 |
def predict_labels_on_test_fold(self):
    """
    Predict the labels of the samples of the test fold and compute the
    associated Cox-PH p-values.

    Returns None when the dataset has no cross-validation instance,
    otherwise the tuple (labels, p-value, p-value of the label
    probabilities).
    """
    dataset = self.dataset

    if not dataset.cross_validation_instance:
        return

    dataset.load_matrix_test_fold()

    nbdays, isdead = dataset.survival_cv.T.tolist()
    self.activities_cv = self._predict_survival_nodes(
        dataset.matrix_cv_array)

    predicted = self._predict_labels(
        self.activities_cv, dataset.matrix_cv_array)
    self.cv_labels, self.cv_labels_proba = predicted

    if self.verbose:
        print('#### report of test fold cluster:):')
        counts = Counter(self.cv_labels)

        for cluster_id in counts:
            print('class: {0}, number of samples :{1}'.format(
                cluster_id, counts[cluster_id]))

    # The metadata matrix is only used when the metadata are meant to be
    # combined with the labels; a missing matrix stays None.
    if self.metadata_usage in ('all', 'labels'):
        metadata_mat = dataset.metadata_mat_cv
    else:
        metadata_mat = None

    pvalue, pvalue_proba = self._compute_test_coxph(
        'KM_plot_test_fold',
        nbdays, isdead,
        self.cv_labels,
        self.cv_labels_proba,
        metadata_mat=metadata_mat)

    self.cv_pvalue, self.cv_pvalue_proba = pvalue, pvalue_proba

    if not self._isboosting:
        self._write_labels(
            dataset.sample_ids_cv, self.cv_labels,
            labels_proba=self.cv_labels_proba.T[0],
            fname='{0}_test_fold_labels'.format(self.project_name))

    return self.cv_labels, pvalue, pvalue_proba
|
|
478 |
|
|
|
479 |
def predict_labels_on_full_dataset(self):
    """
    Predict the labels of all the samples of the full dataset and
    compute the associated Cox-PH p-values.

    Returns the tuple (labels, p-value, p-value of the label
    probabilities).
    """
    dataset = self.dataset
    dataset.load_matrix_full()

    nbdays, isdead = dataset.survival_full.T.tolist()

    self.activities_full = self._predict_survival_nodes(
        dataset.matrix_full_array)

    predicted = self._predict_labels(
        self.activities_full, dataset.matrix_full_array)
    self.full_labels, self.full_labels_proba = predicted

    if self.verbose:
        print('#### report of assigned cluster for full dataset:')
        counts = Counter(self.full_labels)

        for cluster_id in counts:
            print('class: {0}, number of samples :{1}'.format(
                cluster_id, counts[cluster_id]))

    # The metadata matrix is only used when the metadata are meant to be
    # combined with the labels; a missing matrix stays None.
    if self.metadata_usage in ('all', 'labels'):
        metadata_mat = dataset.metadata_mat_full
    else:
        metadata_mat = None

    pvalue, pvalue_proba = self._compute_test_coxph(
        'KM_plot_full',
        nbdays, isdead,
        self.full_labels,
        self.full_labels_proba,
        metadata_mat=metadata_mat)

    self.full_pvalue, self.full_pvalue_proba = pvalue, pvalue_proba

    if not self._isboosting:
        self._write_labels(
            dataset.sample_ids_full, self.full_labels,
            labels_proba=self.full_labels_proba.T[0],
            fname='{0}_full_labels'.format(self.project_name))

    return self.full_labels, pvalue, pvalue_proba
|
|
517 |
|
|
|
518 |
def predict_labels_on_test_dataset(self):
    """
    Predict the labels of the samples of the loaded test dataset and,
    when survival data are available for it, compute the associated
    Cox-PH p-values.

    Returns the tuple (labels, p-value, p-value of the label
    probabilities). Both p-values are None when the test dataset has
    no survival data.

    Raises an Exception when the test dataset shares no omic with the
    training omics.
    """
    has_survival = self.dataset.survival_test is not None

    if has_survival:
        nbdays, isdead = self.dataset.survival_test.T.tolist()

    self.test_omic_list = list(self.dataset.matrix_test_array.keys())
    self.test_omic_list = list(set(self.test_omic_list).intersection(
        self.training_omic_list))

    # Explicit check instead of an assert: asserts are stripped when
    # Python runs with -O.
    if not self.test_omic_list:
        raise Exception('in predict_labels_on_test_dataset: test_omic_list is empty!'
                        '\n either no common omic with trining_omic_list or error!')

    self.fit_classification_test_model()

    self.activities_test = self._predict_survival_nodes(
        self.dataset.matrix_test_array)
    self._predict_test_labels(self.activities_test,
                              self.dataset.matrix_test_array)

    if self.verbose:
        print('#### report of assigned cluster:')
        for key, value in Counter(self.test_labels).items():
            print('class: {0}, number of samples :{1}'.format(key, value))

    if self.metadata_usage in ['all', 'test-labels'] and \
       self.dataset.metadata_mat_test is not None:
        metadata_mat = self.dataset.metadata_mat_test
    else:
        metadata_mat = None

    pvalue, pvalue_proba = None, None

    if has_survival:
        # `nbdays` / `isdead` only exist when survival data were loaded:
        # calling the survival model without them raised a NameError.
        pvalue, pvalue_proba = self._compute_test_coxph(
            'KM_plot_test',
            nbdays, isdead,
            self.test_labels,
            self.test_labels_proba,
            metadata_mat=metadata_mat)
        self.test_pvalue = pvalue
        self.test_pvalue_proba = pvalue_proba

        # NOTE(review): the model is computed again, without metadata,
        # only when every survival time is NaN. The condition looks
        # inverted; it is kept as is -- confirm the intent.
        if np.isnan(nbdays).all():
            pvalue, pvalue_proba = self._compute_test_coxph(
                'KM_plot_test',
                nbdays, isdead,
                self.test_labels, self.test_labels_proba)

            self.test_pvalue = pvalue
            self.test_pvalue_proba = pvalue_proba
    else:
        # Do not keep the p-values of a previously loaded test dataset
        self.test_pvalue = None
        self.test_pvalue_proba = None

    if not self._isboosting:
        self._write_labels(self.dataset.sample_ids_test, self.test_labels,
                           labels_proba=self.test_labels_proba.T[0],
                           fname='{0}_test_labels'.format(self.project_name))

    return self.test_labels, pvalue, pvalue_proba
|
|
576 |
|
|
|
577 |
def _compute_test_coxph(self,
                        fname_base,
                        nbdays,
                        isdead,
                        labels,
                        labels_proba,
                        metadata_mat=None):
    """
    Fit two Cox-PH models on the given survival data: one on the labels
    (with an optional KM plot) and one on the first column of the label
    probabilities (never plotted).

    Parameters:
        :fname_base: suffix of the figure name, prefixed by the project name
        :nbdays: survival times passed to `coxph`
        :isdead: event indicators passed to `coxph`
        :labels: inferred labels
        :labels_proba: label probabilities; only `labels_proba.T[0]` is used
        :metadata_mat: optional metadata matrix forwarded to `coxph`

    Returns the tuple (p-value of the labels, p-value of the probabilities).
    """
    # Options shared by the two models
    shared = dict(
        isfactor=False,
        png_path=self.path_results,
        seed=self.seed,
        use_r_packages=self.use_r_packages,
        metadata_mat=metadata_mat)

    pvalue = coxph(
        labels, isdead, nbdays,
        do_KM_plot=self.do_KM_plot,
        fig_name='{0}_{1}'.format(self.project_name, fname_base),
        **shared)

    if self.verbose:
        print('Cox-PH p-value (Log-Rank) for inferred labels: {0}'.format(pvalue))

    pvalue_proba = coxph(
        labels_proba.T[0], isdead, nbdays,
        do_KM_plot=False,
        fig_name='{0}_{1}_proba'.format(self.project_name, fname_base),
        **shared)

    if self.verbose:
        print('Cox-PH proba p-value (Log-Rank) for inferred labels: {0}'.format(pvalue_proba))

    return pvalue, pvalue_proba
|
|
613 |
|
|
|
614 |
def compute_feature_scores(self, use_ref=False):
    """
    Score every input feature against the training labels and store,
    for each omic, the list of scored features sorted by increasing
    score in `self.feature_scores`.

    Nothing is done when `self.feature_scores` is already filled.

    Parameters:
        :use_ref: use the reference matrices / features of the dataset
            instead of the training ones
    """
    if self.feature_scores:
        return

    # The scores are computed serially. A worker pool used to be created
    # here, but its `map` was overwritten by the builtin `map` before any
    # use, so processes were spawned and torn down without doing any work.
    # NOTE(review): `nb_threads_coxph` is therefore not used here; confirm
    # whether parallel scoring should be restored.

    def generator(labels, feature_list, matrix):
        for i, feature in enumerate(feature_list):
            yield feature, matrix[i], labels

    if use_ref:
        key_array = list(self.dataset.matrix_ref_array.keys())
    else:
        key_array = list(self.dataset.matrix_train_array.keys())

    for key in key_array:
        if use_ref:
            feature_list = self.dataset.feature_ref_array[key][:]
            matrix = self.dataset.matrix_ref_array[key][:]
        else:
            feature_list = self.dataset.feature_train_array[key][:]
            matrix = self.dataset.matrix_train_array[key][:]

        labels = self.labels[:]

        # The matrix is transposed so that matrix[i] holds the values of
        # feature i -- assumes a (samples, features) layout; TODO confirm
        input_list = generator(labels, feature_list, matrix.T)

        features_scored = list(
            map(_process_parallel_feature_importance, input_list))
        features_scored.sort(key=lambda x: x[1])

        self.feature_scores[key] = features_scored
|
|
658 |
|
|
|
659 |
def compute_feature_scores_per_cluster(self, use_ref=False,
                                       pval_thres=0.01):
    """
    Compute, for each cluster label, the list of
    (feature, median difference, p-value) tuples and store it in
    `self.feature_scores_per_cluster`, sorted by increasing median
    difference.

    Parameters:
        :use_ref: use the reference matrices / features of the dataset
            instead of the training ones
        :pval_thres: p-value threshold forwarded with each feature to
            the scoring function
    """
    print('computing feature importance per cluster...')

    for cluster_id in set(self.labels):
        self.feature_scores_per_cluster[cluster_id] = []

    def feature_rows(labels, names, rows):
        for position in range(len(names)):
            yield names[position], rows[position], labels, pval_thres

    dataset = self.dataset

    if use_ref:
        omic_names = list(dataset.matrix_ref_array.keys())
    else:
        omic_names = list(dataset.matrix_train_array.keys())

    for omic in omic_names:
        if use_ref:
            names = dataset.feature_ref_array[omic][:]
            values = dataset.matrix_ref_array[omic][:]
        else:
            names = dataset.feature_train_array[omic][:]
            values = dataset.matrix_train_array[omic][:]

        labels = self.labels[:]

        # Every feature yields a list of per-cluster results: flatten
        # them all before updating the per-cluster lists.
        flattened = []

        for per_feature in map(
                _process_parallel_feature_importance_per_cluster,
                feature_rows(labels, names, values.T)):
            flattened.extend(per_feature)

        for cluster_id, feature, median_diff, pvalue in flattened:
            self.feature_scores_per_cluster[cluster_id].append(
                (feature, median_diff, pvalue))

    for cluster_id in self.feature_scores_per_cluster:
        self.feature_scores_per_cluster[cluster_id].sort(
            key=lambda scored: scored[1])
|
|
699 |
|
|
|
700 |
def write_feature_score_per_cluster(self):
    """
    Write the per-cluster feature scores into two ';'-separated files
    of the result folder:

    * '<project>_features_scores_per_clusters.tsv' for the features
      having a positive median difference (starts with a header line)
    * '<project>_features_anticorrelated_scores_per_clusters.tsv' for
      the other features (no header line is written to this file)
    """
    f_file_name = '{0}/{1}_features_scores_per_clusters.tsv'.format(
        self.path_results, self._project_name)
    f_anti_name = '{0}/{1}_features_anticorrelated_scores_per_clusters.tsv'.format(
        self.path_results, self._project_name)

    # Context managers guarantee both files are flushed and closed, even
    # if writing fails half-way.
    with open(f_file_name, 'w') as f_file, \
         open(f_anti_name, 'w') as f_anti_file:

        f_file.write('cluster id;feature;median diff;p-value\n')

        for label in self.feature_scores_per_cluster:
            for feature, median_diff, pvalue in self.feature_scores_per_cluster[label]:
                if median_diff > 0:
                    f_to_write = f_file
                else:
                    f_to_write = f_anti_file

                f_to_write.write('{0};{1};{2};{3}\n'.format(
                    label, feature, median_diff, pvalue))

    print('{0} written'.format(f_file_name))
    print('{0} written'.format(f_anti_name))
|
|
724 |
|
|
|
725 |
def write_feature_scores(self):
    """
    Write the feature scores of every omic into
    '<path_results>/<project_name>_features_scores.tsv': one
    '#### <omic> ####' header line per omic followed by one
    '<feature>;<score>' line per scored feature.
    """
    out_path = '{0}/{1}_features_scores.tsv'.format(
        self.path_results, self.project_name)

    with open(out_path, 'w') as out_file:
        for omic, scored_features in self.feature_scores.items():
            out_file.write('#### {0} ####\n'.format(omic))

            for feature, score in scored_features:
                out_file.write('{0};{1}\n'.format(feature, score))

    print('{0} written'.format(out_path))
|
|
739 |
|
|
|
740 |
def _return_train_matrix_for_classification(self):
    """
    Build the matrix used to train the classifier from the reference
    matrices of the dataset.

    With 'SURVIVAL_FEATURES' the matrix comes from
    `_predict_survival_nodes` (not allowed with the 'clustering'
    classifier type); with 'ALL_FEATURES' it comes from
    `_reduce_and_stack_matrices`. NaN values are replaced through
    `np.nan_to_num` before returning.
    """
    method = self.classification_method
    assert method in _CLASSIFICATION_METHOD_LIST

    if self.verbose:
        print('classification method: {0}'.format(
            self.classification_method))

    if method == 'SURVIVAL_FEATURES':
        assert self.classifier_type != 'clustering'
        matrix = self._predict_survival_nodes(
            self.dataset.matrix_ref_array)
    elif method == 'ALL_FEATURES':
        matrix = self._reduce_and_stack_matrices(
            self.dataset.matrix_ref_array)

    if self.verbose:
        nb_features = matrix.shape[1]
        print('number of features for the classifier: {0}'.format(
            nb_features))

    return np.nan_to_num(matrix)
|
|
762 |
|
|
|
763 |
def _reduce_and_stack_matrices(self, matrices): |
|
|
764 |
""" |
|
|
765 |
""" |
|
|
766 |
if not self.nb_selected_features: |
|
|
767 |
return hstack(matrices.values()) |
|
|
768 |
else: |
|
|
769 |
self.compute_feature_scores() |
|
|
770 |
matrix = [] |
|
|
771 |
|
|
|
772 |
for key in matrices: |
|
|
773 |
index = [self.dataset.feature_ref_index[key][feature] |
|
|
774 |
for feature, pvalue in |
|
|
775 |
self.feature_scores[key][:self.nb_selected_features] |
|
|
776 |
if feature in self.dataset.feature_ref_index[key] |
|
|
777 |
] |
|
|
778 |
|
|
|
779 |
matrix.append(matrices[key].T[index].T) |
|
|
780 |
|
|
|
781 |
return hstack(matrix) |
|
|
782 |
|
|
|
783 |
def fit_classification_model(self):
    """
    Fit ``self.classifier`` on the training matrix and ``self.labels``.

    When ``self.classifier_type`` is ``'clustering'`` the clustering model
    is reused as classifier and nothing is fitted. Otherwise
    ``self.classifier_grid`` is fitted, the best estimator is selected with
    ``select_best_classif_params``, refitted with ``probability=True``
    (presumes the selected estimator accepts that parameter) and stored in
    ``self.classifier_dict`` under the current normalization.
    """
    features = self._return_train_matrix_for_classification()
    targets = self.labels

    if self.classifier_type == 'clustering':
        if self.verbose:
            print('clustering model defined as the classifier')

        self.classifier = self.clustering
        return

    if self.verbose:
        print('classification analysis...')

    if isinstance(self.seed, int):
        np.random.seed(self.seed)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        self.classifier_grid.fit(features, targets)

    self.classifier, best_params = select_best_classif_params(
        self.classifier_grid)

    self.classifier.set_params(probability=True)
    self.classifier.fit(features, targets)

    normalization_key = str(self.used_normalization)
    self.classifier_dict[normalization_key] = self.classifier

    if not self.verbose:
        return

    fold_scores = cross_val_score(self.classifier, features, targets, cv=5)
    print('best params:', best_params)
    print('cross val score: {0}'.format(np.mean(fold_scores)))
    print('classification score:', self.classifier.score(
        features, targets))
|
|
819 |
|
|
|
820 |
def fit_classification_test_model(self):
    """
    Fit ``self.classifier_test``, the classifier used for a test dataset.

    The training classifier is reused (assigned to ``self.classifier_test``
    when it is still None) if the classifier is the clustering model, or if
    the features and the normalization are unchanged and unknown features
    are filled with 0. Otherwise a new classifier is selected from
    ``self.classifier_grid`` and fitted on the training matrix.
    """
    same_features = \
        self.used_features_for_classif == self.dataset.feature_ref_array
    same_normalization = self.used_normalization == self.test_normalization
    filled_with_zero = self.dataset.fill_unkown_feature_with_0

    nothing_changed = same_features and same_normalization and filled_with_zero

    if nothing_changed or self.classifier_type == 'clustering':
        if self.verbose:
            print('Not rebuilding the test classifier')

        if self.classifier_test is None:
            self.classifier_test = self.classifier
        return

    if self.verbose:
        print('classification for test set analysis...')

    self.used_normalization = self.dataset.normalization_test
    self.used_features_for_classif = self.dataset.feature_ref_array

    features = self._return_train_matrix_for_classification()
    targets = self.labels

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        self.classifier_grid.fit(features, targets)

    self.classifier_test, best_params = select_best_classif_params(
        self.classifier_grid)

    self.classifier_test.set_params(probability=True)
    self.classifier_test.fit(features, targets)

    if not self.verbose:
        return

    fold_scores = cross_val_score(
        self.classifier_test, features, targets, cv=5)
    print('best params:', best_params)
    print('cross val score: {0}'.format(np.mean(fold_scores)))
    print('classification score:', self.classifier_test.score(features, targets))
|
|
858 |
|
|
|
859 |
def predict_labels(self):
    """
    Cluster the training samples from ``self.activities_train``.

    The clustering model is built from ``self.cluster_method``: either a
    class exposing ``fit_predict`` (instantiated with ``n_clusters``), or
    one of ``'kmeans'``, ``'mixture'``, ``'coxPH'``, ``'coxPHMixture'``.
    When ``self.cluster_array`` holds more than one value, the number of
    clusters is first tuned with ``self._predict_best_k_for_cluster``.

    Sets ``self.clustering``, ``self.labels`` (reordered with
    ``self._order_labels_according_to_survival``), ``self.labels_proba``
    (padded with zero columns up to ``self.nb_clusters``),
    ``self.train_pvalue`` and ``self.train_pvalue_proba``. The labels are
    written to disk unless ``self._isboosting`` is set.

    Raises:
        Exception: if the clustering method is unknown or if
            ``self.activities_train`` contains no non-zero value.
    """
    if self.verbose:
        print('performing clustering on the omic model with the following key:{0}'.format(
            self.training_omic_list))

    # Survival data is needed by the Cox-PH based clusterings and by the
    # p-value computations below: unpack it only once.
    nbdays, isdead = self.dataset.survival.T.tolist()

    if hasattr(self.cluster_method, 'fit_predict'):
        # custom clustering class given by the user
        self.clustering = self.cluster_method(n_clusters=self.nb_clusters)

    elif self.cluster_method == 'kmeans':
        self.clustering = KMeans(n_clusters=self.nb_clusters, n_init=100)

    elif self.cluster_method == 'mixture':
        self.clustering = GaussianMixture(
            n_components=self.nb_clusters,
            **self.mixture_params
        )

    elif self.cluster_method == "coxPH":
        self.clustering = ClusterWithSurvival(
            n_clusters=self.nb_clusters,
            isdead=isdead,
            nbdays=nbdays)

    elif self.cluster_method == "coxPHMixture":
        self.clustering = ClusterWithSurvival(
            n_clusters=self.nb_clusters,
            use_gaussian_to_dichotomize=True,
            isdead=isdead,
            nbdays=nbdays)

    else:
        raise Exception("No method fit and predict found for: {0}".format(
            self.cluster_method))

    if not self.activities_train.any():
        raise Exception('No components linked to survival!'\
                        ' cannot perform clustering')

    if self.cluster_array and len(self.cluster_array) > 1:
        self._predict_best_k_for_cluster()

    if hasattr(self.clustering, 'predict'):
        self.clustering.fit(self.activities_train)
        labels = self.clustering.predict(self.activities_train)
    else:
        labels = self.clustering.fit_predict(self.activities_train)

    labels = self._order_labels_according_to_survival(labels)
    self.labels = labels

    if hasattr(self.clustering, 'predict_proba'):
        self.labels_proba = self.clustering.predict_proba(self.activities_train)
    else:
        # no probability available: the labels are used in place
        self.labels_proba = np.array([self.labels, self.labels]).T

    if len(self.labels_proba.shape) == 1:
        self.labels_proba = self.labels_proba.reshape((
            self.labels_proba.shape[0], 1))

    # Pad with zero columns so that there is one column per cluster
    missing_columns = self.nb_clusters - self.labels_proba.shape[1]

    if missing_columns > 0:
        self.labels_proba = hstack([
            self.labels_proba, np.zeros(
                shape=(self.labels_proba.shape[0], missing_columns))])

    if self.verbose:
        print("clustering done, labels ordered according to survival:")
        for key, value in Counter(labels).items():
            print('cluster label: {0}\t number of samples:{1}'.format(key, value))
        print('\n')

    if self.metadata_usage in ['all', 'labels'] and \
       self.dataset.metadata_mat is not None:
        metadata_mat = self.dataset.metadata_mat
    else:
        metadata_mat = None

    pvalue = coxph(self.labels, isdead, nbdays,
                   isfactor=False,
                   do_KM_plot=self.do_KM_plot,
                   png_path=self.path_results,
                   seed=self.seed,
                   use_r_packages=self.use_r_packages,
                   metadata_mat=metadata_mat,
                   fig_name='{0}_KM_plot_training_dataset'.format(self.project_name))

    pvalue_proba = coxph(self.labels_proba.T[0],
                         isdead, nbdays,
                         seed=self.seed,
                         use_r_packages=self.use_r_packages,
                         metadata_mat=metadata_mat,
                         isfactor=False)

    if not self._isboosting:
        self._write_labels(self.dataset.sample_ids, self.labels,
                           labels_proba=self.labels_proba.T[0],
                           fname='{0}_training_set_labels'.format(self.project_name))

    if self.verbose:
        print('Cox-PH p-value (Log-Rank) for the cluster labels: {0}'.format(pvalue))

    self.train_pvalue = pvalue
    self.train_pvalue_proba = pvalue_proba
|
|
977 |
|
|
|
978 |
def evalutate_cluster_performance(self):
    """
    Compute clustering quality scores on ``self.activities_train``.

    Sets ``self.silhouette_score`` and ``self.calinski_score`` from
    ``self.labels`` and, for the ``'mixture'`` method only,
    ``self.bic_score``. Prints a message and returns without computing
    anything when ``self.clustering`` is falsy.
    """
    if not self.clustering:
        print('clustering attribute is defined as None. ' \
              ' Cannot evaluate cluster performance')
        return

    is_mixture = self.cluster_method == 'mixture'

    if is_mixture:
        self.bic_score = self.clustering.bic(self.activities_train)

    self.silhouette_score = silhouette_score(
        self.activities_train, self.labels)
    self.calinski_score = calinski_harabaz_score(
        self.activities_train, self.labels)

    if not self.verbose:
        return

    print('silhouette score: {0}'.format(self.silhouette_score))
    print('calinski-harabaz score: {0}'.format(self.calinski_score))
    print('bic score: {0}'.format(self.bic_score))
|
|
996 |
|
|
|
997 |
def _write_labels(self, sample_ids, labels, fname="", |
|
|
998 |
labels_proba=None, |
|
|
999 |
nbdays=None, |
|
|
1000 |
isdead=None, |
|
|
1001 |
path_file=None): |
|
|
1002 |
""" """ |
|
|
1003 |
assert(fname or path_file) |
|
|
1004 |
|
|
|
1005 |
if not path_file: |
|
|
1006 |
path_file = '{0}/{1}.tsv'.format(self.path_results, fname) |
|
|
1007 |
|
|
|
1008 |
with open(path_file, 'w') as f_file: |
|
|
1009 |
for ids, (sample, label) in enumerate(zip(sample_ids, labels)): |
|
|
1010 |
suppl = '' |
|
|
1011 |
|
|
|
1012 |
if labels_proba is not None: |
|
|
1013 |
suppl += '\t{0}'.format(labels_proba[ids]) |
|
|
1014 |
if nbdays is not None: |
|
|
1015 |
suppl += '\t{0}'.format(nbdays[ids]) |
|
|
1016 |
if isdead is not None: |
|
|
1017 |
suppl += '\t{0}'.format(isdead[ids]) |
|
|
1018 |
|
|
|
1019 |
f_file.write('{0}\t{1}{2}\n'.format(sample, label, suppl)) |
|
|
1020 |
|
|
|
1021 |
print('file written: {0}'.format(path_file)) |
|
|
1022 |
|
|
|
1023 |
def _predict_survival_nodes(self, matrix_array, keys=None): |
|
|
1024 |
""" |
|
|
1025 |
""" |
|
|
1026 |
activities_array = {} |
|
|
1027 |
|
|
|
1028 |
if keys is None: |
|
|
1029 |
keys = list(matrix_array.keys()) |
|
|
1030 |
|
|
|
1031 |
for key in keys: |
|
|
1032 |
matrix = matrix_array[key] |
|
|
1033 |
if not self._pretrained_model: |
|
|
1034 |
if self.alternative_embedding is None and \ |
|
|
1035 |
self.encoder_input_shape(key)[1] != matrix.shape[1]: |
|
|
1036 |
if self.verbose: |
|
|
1037 |
print('matrix doesnt have the input dimension of the encoder'\ |
|
|
1038 |
' returning None') |
|
|
1039 |
return None |
|
|
1040 |
|
|
|
1041 |
if self.alternative_embedding is not None: |
|
|
1042 |
activities = self.embedding_predict(key, matrix) |
|
|
1043 |
elif self.use_autoencoders: |
|
|
1044 |
activities = self.encoder_predict(key, matrix) |
|
|
1045 |
else: |
|
|
1046 |
activities = np.asarray(matrix) |
|
|
1047 |
|
|
|
1048 |
activities_array[key] = activities.T[self.valid_node_ids_array[key]].T |
|
|
1049 |
|
|
|
1050 |
return hstack([activities_array[key] |
|
|
1051 |
for key in keys]) |
|
|
1052 |
|
|
|
1053 |
def look_for_survival_nodes(self, keys=None):
    """
    detect nodes from the autoencoder significantly
    linked with survival through coxph regression

    For each key the activities are computed from the training matrix
    (alternative embedding, encoder, or raw matrix), the retained node ids
    are stored in ``self.valid_node_ids_array`` and the matching activities
    in ``self.activities_array``. ``self.activities_train`` is then the
    horizontal stack of the activities of ``self.clustering_omics`` when
    defined, of ``keys`` otherwise.

    ``keys`` defaults to the keys of ``self.encoder_array`` and, when there
    is none, to the keys of ``self.matrix_train_array``.
    """
    if not keys:
        keys = list(self.encoder_array.keys())

    if not keys:
        keys = list(self.matrix_train_array.keys())

    for key in keys:
        matrix_train = self.matrix_train_array[key]

        if self.alternative_embedding is not None:
            activities = self.embedding_predict(key, matrix_train)
        elif self.use_autoencoders:
            activities = self.encoder_predict(key, matrix_train)
        else:
            activities = np.asarray(matrix_train)

        if self.feature_surv_analysis:
            valid_node_ids = self._look_for_nodes(key)
        else:
            # Keep every node. The ids index the activities, whose width
            # differs from the input matrix once an encoder or an
            # embedding is applied.
            valid_node_ids = np.arange(activities.shape[1])

        self.valid_node_ids_array[key] = valid_node_ids
        self.activities_array[key] = activities.T[valid_node_ids].T

    if self.clustering_omics:
        keys = self.clustering_omics

    self.activities_train = hstack([self.activities_array[key]
                                    for key in keys])
|
|
1087 |
|
|
|
1088 |
def look_for_prediction_nodes(self, keys=None):
    """
    detect nodes from the autoencoder that predict a
    high c-index scores using label from the retained test fold

    For each key (default: the keys of ``self.encoder_array``) the retained
    node ids are stored in ``self.pred_node_ids_array`` and the matching
    activities in ``self.activities_pred_array``.
    ``self.activities_for_pred_train`` is the horizontal stack of these
    activities.
    """
    if not keys:
        keys = list(self.encoder_array.keys())

    for key in keys:
        matrix_train = self.matrix_train_array[key]

        if self.alternative_embedding is not None:
            activities = self.embedding_predict(key, matrix_train)
        elif self.use_autoencoders:
            activities = self.encoder_predict(key, matrix_train)
        else:
            activities = np.asarray(matrix_train)

        if self.feature_surv_analysis:
            valid_node_ids = self._look_for_prediction_nodes(key)
        else:
            # Keep every node. The ids index the activities, whose width
            # differs from the input matrix once an encoder or an
            # embedding is applied.
            valid_node_ids = np.arange(activities.shape[1])

        self.pred_node_ids_array[key] = valid_node_ids
        self.activities_pred_array[key] = activities.T[valid_node_ids].T

    self.activities_for_pred_train = hstack([self.activities_pred_array[key]
                                             for key in keys])
|
|
1117 |
|
|
|
1118 |
def compute_c_indexes_multiple_for_test_dataset(self):
    """
    Return the c-index obtained with ``c_index_multiple`` when the
    prediction-node activities of the training set are used to predict the
    survival of the test set.

    Test activities are computed for each key of
    ``self.dataset.matrix_test_array`` and restricted to
    ``self.pred_node_ids_array[key]`` when an embedding or an encoder is
    used. Training activities are taken from ``self.activities_pred_array``
    for each key of ``self.dataset.matrix_ref_array``.
    """
    days, dead = np.asarray(self.dataset.survival).T
    days_test, dead_test = np.asarray(self.dataset.survival_test).T

    test_blocks = []

    for key in self.dataset.matrix_test_array:
        node_ids = self.pred_node_ids_array[key]
        matrix = self.dataset.matrix_test_array[key]

        if self.alternative_embedding is not None:
            block = self.embedding_predict(key, matrix).T[node_ids].T
        elif self.use_autoencoders:
            block = self.encoder_predict(key, matrix).T[node_ids].T
        else:
            # NOTE(review): the raw matrix is not restricted to node_ids
            # here, unlike the training activities -- confirm intended.
            block = matrix

        test_blocks.append(block)

    # hstack needs a real sequence (a dict view is rejected by recent
    # NumPy), hence the list of blocks.
    activities_test = hstack(test_blocks)
    activities_train = hstack([self.activities_pred_array[key]
                               for key in self.dataset.matrix_ref_array])

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        cindex = c_index_multiple(activities_train, dead, days,
                                  activities_test, dead_test, days_test,
                                  seed=self.seed,)

    if self.verbose:
        print('c-index multiple for test dataset:{0}'.format(cindex))

    return cindex
|
|
1157 |
|
|
|
1158 |
def compute_c_indexes_multiple_for_test_fold_dataset(self):
    """
    Return the c-index obtained with ``c_index_multiple`` when
    ``self.activities_for_pred_train`` is used to predict the survival of
    the test fold (``self.dataset.matrix_cv_array`` /
    ``self.dataset.survival_cv``).

    Test-fold activities are restricted to ``self.pred_node_ids_array[key]``
    when an embedding or an encoder is used.
    """
    days, dead = np.asarray(self.dataset.survival).T
    days_cv, dead_cv = np.asarray(self.dataset.survival_cv).T

    cv_blocks = []

    for key in self.dataset.matrix_cv_array:
        node_ids = self.pred_node_ids_array[key]
        matrix = self.dataset.matrix_cv_array[key]

        if self.alternative_embedding is not None:
            block = self.embedding_predict(key, matrix).T[node_ids].T
        elif self.use_autoencoders:
            block = self.encoder_predict(key, matrix).T[node_ids].T
        else:
            # NOTE(review): the raw matrix is not restricted to node_ids
            # here -- confirm intended.
            block = matrix

        cv_blocks.append(block)

    # hstack needs a real sequence (a dict view is rejected by recent
    # NumPy), hence the list of blocks.
    activities_cv = hstack(cv_blocks)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        cindex = c_index_multiple(self.activities_for_pred_train, dead, days,
                                  activities_cv, dead_cv, days_cv,
                                  seed=self.seed,)

    if self.verbose:
        print('c-index multiple for test fold dataset:{0}'.format(cindex))

    return cindex
|
|
1193 |
|
|
|
1194 |
def _return_test_matrix_for_classification(self, activities, matrix_array): |
|
|
1195 |
""" |
|
|
1196 |
""" |
|
|
1197 |
if self.classification_method == 'SURVIVAL_FEATURES': |
|
|
1198 |
return activities |
|
|
1199 |
elif self.classification_method == 'ALL_FEATURES': |
|
|
1200 |
matrix = self._reduce_and_stack_matrices(matrix_array) |
|
|
1201 |
return matrix |
|
|
1202 |
|
|
|
1203 |
def _predict_test_labels(self, activities, matrix_array): |
|
|
1204 |
""" """ |
|
|
1205 |
matrix_test = self._return_test_matrix_for_classification( |
|
|
1206 |
activities, matrix_array) |
|
|
1207 |
|
|
|
1208 |
self.test_labels = self.classifier_test.predict(matrix_test) |
|
|
1209 |
self.test_labels_proba = self.classifier_test.predict_proba(matrix_test) |
|
|
1210 |
|
|
|
1211 |
if self.test_labels_proba.shape[1] < self.nb_clusters: |
|
|
1212 |
missing_columns = self.nb_clusters - self.test_labels_proba.shape[1] |
|
|
1213 |
|
|
|
1214 |
for i in range(missing_columns): |
|
|
1215 |
self.test_labels_proba = hstack([ |
|
|
1216 |
self.test_labels_proba, np.zeros( |
|
|
1217 |
shape=(self.test_labels_proba, 1))]) |
|
|
1218 |
|
|
|
1219 |
def _predict_labels(self, activities, matrix_array): |
|
|
1220 |
""" """ |
|
|
1221 |
matrix_test = self._return_test_matrix_for_classification( |
|
|
1222 |
activities, matrix_array) |
|
|
1223 |
|
|
|
1224 |
labels = self.classifier.predict(matrix_test) |
|
|
1225 |
labels_proba = self.classifier.predict_proba(matrix_test) |
|
|
1226 |
|
|
|
1227 |
if labels_proba.shape[1] < self.nb_clusters: |
|
|
1228 |
missing_columns = self.nb_clusters - labels_proba.shape[1] |
|
|
1229 |
|
|
|
1230 |
for i in range(missing_columns): |
|
|
1231 |
labels_proba = hstack([ |
|
|
1232 |
labels_proba, np.zeros( |
|
|
1233 |
shape=(labels_proba.shape[0], 1))]) |
|
|
1234 |
|
|
|
1235 |
return labels, labels_proba |
|
|
1236 |
|
|
|
1237 |
def _predict_best_k_for_cluster(self): |
|
|
1238 |
""" """ |
|
|
1239 |
criterion = None |
|
|
1240 |
best_k = None |
|
|
1241 |
|
|
|
1242 |
for k_cluster in self.cluster_array: |
|
|
1243 |
if self.cluster_method == 'mixture': |
|
|
1244 |
self.clustering.set_params(n_components=k_cluster) |
|
|
1245 |
else: |
|
|
1246 |
self.clustering.set_params(n_clusters=k_cluster) |
|
|
1247 |
|
|
|
1248 |
labels = self.clustering.fit_predict(self.activities_train) |
|
|
1249 |
|
|
|
1250 |
if self.cluster_eval_method == 'bic': |
|
|
1251 |
score = self.clustering.bic(self.activities_train) |
|
|
1252 |
elif self.cluster_eval_method == 'calinski': |
|
|
1253 |
score = calinski_harabaz_score( |
|
|
1254 |
self.activities_train, |
|
|
1255 |
labels |
|
|
1256 |
) |
|
|
1257 |
elif self.cluster_eval_method == 'silhouette': |
|
|
1258 |
score = silhouette_score( |
|
|
1259 |
self.activities_train, |
|
|
1260 |
labels |
|
|
1261 |
) |
|
|
1262 |
|
|
|
1263 |
if self.verbose: |
|
|
1264 |
print('obtained {2}: {0} for k = {1}'.format(score, k_cluster, |
|
|
1265 |
self.cluster_eval_method)) |
|
|
1266 |
|
|
|
1267 |
if criterion == None or score < criterion: |
|
|
1268 |
criterion, best_k = score, k_cluster |
|
|
1269 |
|
|
|
1270 |
self.clustering_performance = criterion |
|
|
1271 |
|
|
|
1272 |
if self.verbose: |
|
|
1273 |
print('best k: {0}'.format(best_k)) |
|
|
1274 |
|
|
|
1275 |
if self.cluster_method == 'mixture': |
|
|
1276 |
self.clustering.set_params(n_components=best_k) |
|
|
1277 |
else: |
|
|
1278 |
self.clustering.set_params(n_clusters=best_k) |
|
|
1279 |
|
|
|
1280 |
def _order_labels_according_to_survival(self, labels):
    """
    Renumber the cluster labels by increasing value of ``surv_median``
    computed on the samples of each cluster.

    ``labels`` is modified in place and returned. The mapping
    old label -> new label is kept in ``self._label_ordered_dict``.
    """
    previous = labels.copy()

    days, dead = np.asarray(self.dataset.survival).T

    self._label_ordered_dict = median_by_label = {}

    for label in set(previous):
        members = previous == label
        median_by_label[label] = surv_median(dead[members], days[members])

    # sorting the keys on their value is stable, as sorting the items is
    ranked_labels = sorted(median_by_label, key=median_by_label.get)

    self._label_ordered_dict = {}

    for rank, former_label in enumerate(ranked_labels):
        self._label_ordered_dict[former_label] = rank

    for former_label, rank in self._label_ordered_dict.items():
        labels[previous == former_label] = rank

    return labels
|
|
1305 |
|
|
|
1306 |
def _look_for_survival_nodes(self, key=None, |
|
|
1307 |
activities=None, |
|
|
1308 |
survival=None, |
|
|
1309 |
metadata_mat=None): |
|
|
1310 |
""" |
|
|
1311 |
""" |
|
|
1312 |
if key is not None: |
|
|
1313 |
matrix_train = self.matrix_train_array[key] |
|
|
1314 |
|
|
|
1315 |
if self.alternative_embedding is not None: |
|
|
1316 |
activities = np.nan_to_num(self.embedding_predict( |
|
|
1317 |
key, matrix_train)) |
|
|
1318 |
|
|
|
1319 |
elif self.use_autoencoders: |
|
|
1320 |
activities = np.nan_to_num(self.encoder_predict( |
|
|
1321 |
key, matrix_train)) |
|
|
1322 |
|
|
|
1323 |
else: |
|
|
1324 |
activities = np.asarray(matrix_train) |
|
|
1325 |
else: |
|
|
1326 |
assert(activities is not None) |
|
|
1327 |
|
|
|
1328 |
if survival is not None: |
|
|
1329 |
nbdays, isdead = survival.T.tolist() |
|
|
1330 |
else: |
|
|
1331 |
nbdays, isdead = self.dataset.survival.T.tolist() |
|
|
1332 |
|
|
|
1333 |
if self.feature_selection_usage == 'lasso': |
|
|
1334 |
cws = ClusterWithSurvival( |
|
|
1335 |
isdead=isdead, |
|
|
1336 |
nbdays=nbdays, |
|
|
1337 |
metadata_mat=metadata_mat) |
|
|
1338 |
|
|
|
1339 |
return cws.get_nonzero_features(activities) |
|
|
1340 |
|
|
|
1341 |
else: |
|
|
1342 |
return self._get_survival_features_parallel( |
|
|
1343 |
isdead, nbdays, metadata_mat, activities, key) |
|
|
1344 |
|
|
|
1345 |
def _get_survival_features_parallel(
        self, isdead, nbdays, metadata_mat, activities, key):
    """
    Return the ids of the columns of ``activities`` whose p-value, as
    returned by ``_process_parallel_coxph``, is below ``self.pvalue_thres``.

    The p-values are computed through a pool of ``self.nb_threads_coxph``
    workers, except when ``self._isboosting`` is set, in which case the
    computation is sequential. Nodes with a NaN p-value are discarded and
    the returned ids are ordered by decreasing p-value.
    """
    pool = None

    if not self._isboosting:
        pool = Pool(self.nb_threads_coxph)
        mapf = pool.map
    else:
        mapf = map

    input_list = ((node_id,
                   activity,
                   isdead,
                   nbdays,
                   self.seed,
                   metadata_mat, self.use_r_packages)
                  for node_id, activity in enumerate(activities.T))

    try:
        # Materialised here so that a failing job is raised while the
        # pool can still be released below.
        pvalue_list = list(mapf(_process_parallel_coxph, input_list))
    finally:
        if pool is not None:
            pool.close()
            pool.join()

    pvalue_list = [item for item in pvalue_list if not np.isnan(item[1])]
    pvalue_list.sort(key=lambda x: x[1], reverse=True)

    valid_node_ids = [node_id for node_id, pvalue in pvalue_list
                      if pvalue < self.pvalue_thres]

    if self.verbose:
        print('number of components linked to survival found:{0} for key {1}'.format(
            len(valid_node_ids), key))

    return valid_node_ids
|
|
1382 |
|
|
|
1383 |
def _look_for_prediction_nodes(self, key):
    """
    Return the ids of the nodes of omic ``key`` whose score, as returned by
    ``_process_parallel_cindex`` from the training and test-fold
    activities, is above ``self.cindex_thres``.

    Nodes with a NaN score are discarded and the returned ids are ordered
    by decreasing score.
    """
    nbdays, isdead = self.dataset.survival.T.tolist()
    nbdays_cv, isdead_cv = self.dataset.survival_cv.T.tolist()

    matrix_train = self.matrix_train_array[key]
    matrix_cv = self.dataset.matrix_cv_array[key]

    if self.alternative_embedding is not None:
        activities_train = self.embedding_predict(key, matrix_train)
        activities_cv = self.embedding_predict(key, matrix_cv)

    elif self.use_autoencoders:
        activities_train = self.encoder_predict(key, matrix_train)
        activities_cv = self.encoder_predict(key, matrix_cv)
    else:
        activities_train = np.asarray(matrix_train)
        activities_cv = np.asarray(matrix_cv)

    input_list = ((node_id,
                   activities_train.T[node_id], isdead, nbdays,
                   activities_cv.T[node_id], isdead_cv, nbdays_cv,
                   self.use_r_packages)
                  for node_id in range(activities_train.shape[1]))

    # A list is required: on Python 3 map/filter return iterators, which
    # cannot be sorted in place.
    score_list = [item for item in map(_process_parallel_cindex, input_list)
                  if not np.isnan(item[1])]
    score_list.sort(key=lambda x: x[1], reverse=True)

    retained = [(node_id, cindex) for node_id, cindex in score_list
                if cindex > self.cindex_thres]

    valid_node_ids = [node_id for node_id, _ in retained]
    scores = [cindex for _, cindex in retained]

    if self.verbose:
        print('number of components with a high prediction score:{0} for key {1}'\
              ' \n\t mean: {2} std: {3}'.format(
                  len(valid_node_ids), key, np.mean(scores), np.std(scores)))

    return valid_node_ids
|
|
1425 |
|
|
|
1426 |
def compute_c_indexes_for_full_dataset(self):
    """
    Return the c-index computed with ``c_index`` from the training labels
    and survival, evaluated on ``self.full_labels`` and
    ``self.dataset.survival_full``.

    Any exception raised by ``c_index`` is printed and NaN is returned.
    """
    train_days, train_dead = np.asarray(self.dataset.survival).T
    full_days, full_dead = np.asarray(self.dataset.survival_full).T

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        try:
            cindex = c_index(self.labels, train_dead, train_days,
                             self.full_labels, full_dead, full_days,
                             use_r_packages=self.use_r_packages,
                             seed=self.seed,)
        except Exception as error:
            print('Exception while computing the c-index: {0}'.format(error))
            cindex = np.nan

    if self.verbose:
        print('c-index for full dataset:{0}'.format(cindex))

    return cindex
|
|
1448 |
|
|
|
1449 |
def compute_c_indexes_for_training_dataset(self):
    """
    Return the c-index computed with ``c_index`` when the training labels
    and survival are used both to fit and to evaluate.

    Any exception raised by ``c_index`` is printed and NaN is returned.
    """
    train_days, train_dead = np.asarray(self.dataset.survival).T

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        try:
            cindex = c_index(self.labels, train_dead, train_days,
                             self.labels, train_dead, train_days,
                             use_r_packages=self.use_r_packages,
                             seed=self.seed,)
        except Exception as error:
            print('Exception while computing the c-index: {0}'.format(error))
            cindex = np.nan

    if self.verbose:
        print('c-index for training dataset:{0}'.format(cindex))

    return cindex
|
|
1470 |
|
|
|
1471 |
def compute_c_indexes_for_test_dataset(self):
    """
    Return the c-index computed with ``c_index`` from the training labels
    and survival, evaluated on ``self.test_labels`` and
    ``self.dataset.survival_test``.

    Any exception raised by ``c_index`` is printed and NaN is returned.
    """
    train_days, train_dead = np.asarray(self.dataset.survival).T
    test_days, test_dead = np.asarray(self.dataset.survival_test).T

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        try:
            cindex = c_index(self.labels, train_dead, train_days,
                             self.test_labels, test_dead, test_days,
                             use_r_packages=self.use_r_packages,
                             seed=self.seed,)
        except Exception as error:
            print('Exception while computing the c-index: {0}'.format(error))
            cindex = np.nan

    if self.verbose:
        print('c-index for test dataset:{0}'.format(cindex))

    return cindex
|
|
1493 |
|
|
|
1494 |
def compute_c_indexes_for_test_fold_dataset(self):
    """
    Return the c-index computed with ``c_index`` from the training labels
    and survival, evaluated on ``self.cv_labels`` and
    ``self.dataset.survival_cv``.

    Any exception raised by ``c_index`` is printed and NaN is returned.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        train_days, train_dead = np.asarray(self.dataset.survival).T
        fold_days, fold_dead = np.asarray(self.dataset.survival_cv).T

        try:
            cindex = c_index(self.labels, train_dead, train_days,
                             self.cv_labels, fold_dead, fold_days,
                             use_r_packages=self.use_r_packages,
                             seed=self.seed,)
        except Exception as error:
            print('Exception while computing the c-index: {0}'.format(error))
            cindex = np.nan

    if self.verbose:
        print('c-index for test fold dataset:{0}'.format(cindex))

    return cindex
|
|
1516 |
|
|
|
1517 |
def predict_nodes_activities(self, matrix_array):
    """
    Return, stacked horizontally, the activities of the prediction nodes
    (``self.pred_node_ids_array``) for the matrices of ``matrix_array``.

    Keys absent from ``self.pred_node_ids_array`` are ignored. Activities
    come from the alternative embedding when one is defined, from the
    encoder otherwise.
    """
    stacked = []

    for key in matrix_array:
        if key in self.pred_node_ids_array:
            node_ids = self.pred_node_ids_array[key]

            if self.alternative_embedding is not None:
                encoded = self.embedding_predict(key, matrix_array[key])
            else:
                encoded = self.encoder_predict(key, matrix_array[key])

            stacked.append(encoded.T[node_ids].T)

    return hstack(stacked)
|
|
1538 |
|
|
|
1539 |
def plot_kernel_for_test_sets(self,
                              dataset=None,
                              labels=None,
                              labels_proba=None,
                              test_labels=None,
                              test_labels_proba=None,
                              define_as_main_kernel=False,
                              use_main_kernel=False,
                              activities=None,
                              activities_test=None,
                              key=''):
    """
    Call `plot_kernel_plots` to produce the KDE HTML plot of a test set.

    Arguments left to None fall back to the attribute of the same name
    on the instance (`self.dataset`, `self.labels`, `self.labels_proba`,
    `self.test_labels`, `self.test_labels_proba`).

    When `activities` or `activities_test` is missing, the activities
    stored on the instance are used. In that case nothing is plotted (a
    message is printed and None is returned) unless the test
    normalisation equals the training one and unknown features are
    filled with 0.

    :define_as_main_kernel: store copies of the test activities and test
        labels in `self._main_kernel`
    :use_main_kernel: use the activities and labels stored in
        `self._main_kernel` as the reference ones
    :key: suffix inserted in the name of the HTML file
    """
    from simdeep.plot_utils import plot_kernel_plots

    # NOTE(review): `dataset` and `labels_proba` are resolved below but not
    # used afterwards; `self.dataset` is what is given to
    # `plot_kernel_plots`. Kept as is -- confirm whether this is intended.
    if dataset is None:
        dataset = self.dataset

    if labels is None:
        labels = self.labels

    if labels_proba is None:
        labels_proba = self.labels_proba

    if test_labels is None:
        test_labels = self.test_labels

    if test_labels_proba is None:
        test_labels_proba = self.test_labels_proba

    test_norm = self.test_normalization
    train_norm = self.dataset.normalization
    # only the activated normalisation options are compared
    train_norm = {norm: train_norm[norm]
                  for norm in train_norm if train_norm[norm]}

    is_same_normalization = train_norm == test_norm
    is_filled_with_zero = self.dataset.fill_unkown_feature_with_0

    if activities is None or activities_test is None:
        if not (is_same_normalization and is_filled_with_zero):
            print('\n<><><><> Cannot plot survival KDE plot' \
                  ' Different normalisation used for test set <><><><>\n')
            return

        activities = hstack([self.activities_array[omic]
                             for omic in self.test_omic_list])
        activities_test = self.activities_test

    if define_as_main_kernel:
        self._main_kernel = {'activities': activities_test.copy(),
                             'labels': test_labels.copy()}

    if use_main_kernel:
        activities = self._main_kernel['activities']
        labels = self._main_kernel['labels']

    html_name = '{0}/{1}{2}_test_kdeplot.html'.format(
        self.path_results,
        self.project_name,
        key)

    plot_kernel_plots(
        test_labels=test_labels,
        test_labels_proba=test_labels_proba,
        labels=labels,
        activities=activities,
        activities_test=activities_test,
        dataset=self.dataset,
        path_html=html_name)
|
|
1610 |
|
|
|
1611 |
def plot_supervised_kernel_for_test_sets(
        self,
        labels=None,
        labels_proba=None,
        dataset=None,
        key='',
        use_main_kernel=False,
        test_labels=None,
        test_labels_proba=None,
        define_as_main_kernel=False,
):
    """
    Plot the test set kernel using the activities returned by
    `self._predict_kde_matrix`.

    `labels`, `labels_proba` and `dataset` default to the attributes of
    the same name on the instance. The activities are computed from
    `labels_proba` and `dataset`, then everything is forwarded to
    `self.plot_kernel_for_test_sets` with `key` suffixed by '_supervised'.
    """
    labels = self.labels if labels is None else labels
    labels_proba = self.labels_proba if labels_proba is None else labels_proba
    dataset = self.dataset if dataset is None else dataset

    activities, activities_test = self._predict_kde_matrix(
        labels_proba, dataset)

    supervised_key = key + '_supervised'

    self.plot_kernel_for_test_sets(
        labels=labels,
        labels_proba=labels_proba,
        dataset=dataset,
        activities=activities,
        activities_test=activities_test,
        key=supervised_key,
        use_main_kernel=use_main_kernel,
        test_labels=test_labels,
        test_labels_proba=test_labels_proba,
        define_as_main_kernel=define_as_main_kernel,
    )
|
|
1649 |
|
|
|
1650 |
def _create_autoencoder_for_kernel_plot(self, labels_proba, dataset, key):
    """
    Build a supervised network for the kernel plot and cache its encoders.

    A `DeepBase` instance is created for `dataset`, fed with the reference
    matrices of `dataset` as training matrices, and constructed with
    `construct_supervized_network(labels_proba)`. Its `encoder_array` is
    stored in `self.encoder_for_kde_plot_dict` under `key`.
    """
    network_options = {
        'dataset': dataset,
        'seed': self.seed,
        'verbose': False,
        'dropout': 0.1,
        'epochs': 50,
    }
    supervised_net = DeepBase(**network_options)

    supervised_net.matrix_train_array = dataset.matrix_ref_array
    supervised_net.construct_supervized_network(labels_proba)

    self.encoder_for_kde_plot_dict[key] = supervised_net.encoder_array
|
|
1663 |
|
|
|
1664 |
def _predict_kde_matrix(self, labels_proba, dataset):
    """
    Return the (reference, test) matrices used for the supervised kernel
    plot, as a pair of horizontally stacked arrays.

    The encoders are looked up in `self.encoder_for_kde_plot_dict` with a
    key built from the test omic list and the test normalisation. They
    are (re)created with `self._create_autoencoder_for_kernel_plot` when
    the key is unknown or when `dataset.fill_unkown_feature_with_0` is
    false.

    For each encoder, the reference and test matrices of `dataset` are
    projected and restricted to the nodes returned by
    `self._look_for_survival_nodes`. An OMIC with fewer than two such
    nodes is skipped. If every OMIC is skipped, the unrestricted
    projection of the last OMIC is used.

    Raises ValueError when no encoder is available.
    """
    encoder_key = 'omic:{0} normalisation: {1}'.format(
        self.test_omic_list,
        str(self.test_normalization))

    if encoder_key not in self.encoder_for_kde_plot_dict or \
       not dataset.fill_unkown_feature_with_0:
        self._create_autoencoder_for_kernel_plot(
            labels_proba, dataset, encoder_key)

    encoder_array = self.encoder_for_kde_plot_dict[encoder_key]

    if self.metadata_usage in ['all', 'new-features'] and \
       dataset.metadata_mat is not None:
        metadata_mat = dataset.metadata_mat
    else:
        metadata_mat = None

    matrix_ref_list = []
    matrix_test_list = []
    # Last unrestricted projection, kept explicitly so that the fallback
    # below never depends on loop variables that may be unbound.
    last_projection = None

    for key in encoder_array:
        matrix_ref = encoder_array[key].predict(
            dataset.matrix_ref_array[key])
        matrix_test = encoder_array[key].predict(
            dataset.matrix_test_array[key])
        last_projection = (key, matrix_ref, matrix_test)

        survival_node_ids = self._look_for_survival_nodes(
            activities=matrix_ref, survival=dataset.survival,
            metadata_mat=metadata_mat)

        if len(survival_node_ids) > 1:
            matrix_ref_list.append(matrix_ref.T[survival_node_ids].T)
            matrix_test_list.append(matrix_test.T[survival_node_ids].T)
        else:
            print('not enough survival nodes to construct kernel for key: {0}' \
                  ' skipping the {0} matrix'.format(key))

    if not matrix_ref_list:
        if last_projection is None:
            raise ValueError(
                'no encoder available to construct the kernel matrices')

        key, matrix_ref, matrix_test = last_projection
        print('matrix_ref_list / matrix_test_list empty!' \
              ' take the last OMIC ({0}) matrix as ref'.format(key))
        matrix_ref_list.append(matrix_ref)
        matrix_test_list.append(matrix_test)

    return hstack(matrix_ref_list), hstack(matrix_test_list)
|
|
1716 |
|
|
|
1717 |
|
|
|
1718 |
def _get_probas_for_full_model(self):
    """
    Return the list of (sample id, label probabilities) pairs, built from
    `self.dataset.sample_ids_full` and `self.full_labels_proba`.
    """
    sample_ids = self.dataset.sample_ids_full
    probas = self.full_labels_proba

    return [(sample, proba) for sample, proba in zip(sample_ids, probas)]
|
|
1723 |
|
|
|
1724 |
|
|
|
1725 |
def _get_pvalues_and_pvalues_proba(self):
    """
    Return the pair (`self.full_pvalue`, `self.full_pvalue_proba`).
    """
    pvalue = self.full_pvalue
    pvalue_proba = self.full_pvalue_proba

    return pvalue, pvalue_proba
|
|
1729 |
|
|
|
1730 |
def _get_from_dataset(self, attr):
    """
    Return the attribute named `attr` of `self.dataset`.
    """
    dataset = self.dataset

    return getattr(dataset, attr)
|
|
1734 |
|
|
|
1735 |
def _get_attibute(self, attr):
    """
    Return the attribute named `attr` of the instance.
    """
    value = getattr(self, attr)

    return value
|
|
1739 |
|
|
|
1740 |
|
|
|
1741 |
def _partial_fit_model_pool(self):
    """
    Load the training dataset and fit the model, then predict the labels
    of the test fold and of the full dataset and evaluate the clusters.

    The fit is rejected (a message is printed and False is returned)
    when loading or fitting raises, when fewer than two distinct labels
    are obtained, or when `self.train_pvalue` is above `MODEL_THRES`.

    Returns `self._is_fitted` (set to True) otherwise.
    """
    try:
        self.load_training_dataset()
        self.fit()

        # the message reports a single class, so at least two distinct
        # labels are required for the fit to be accepted
        if len(set(self.labels)) < 2:
            raise Exception('only one class!')

        if self.train_pvalue > MODEL_THRES:
            raise Exception('pvalue: {0} not significant!'.format(self.train_pvalue))

    except Exception as e:
        print('model with random state:{1} didn\'t converge:{0}'.format(str(e), self.seed))
        return False

    else:
        print('model with random state:{0} fitted'.format(self.seed))
        self._is_fitted = True

    self.predict_labels_on_test_fold()
    self.predict_labels_on_full_dataset()
    self.evalutate_cluster_performance()

    return self._is_fitted
|
|
1767 |
|
|
|
1768 |
def _partial_fit_model_with_pretrained_pool(self, labels_file):
    """
    Fit the model from the pretrained labels in `labels_file`, predict the
    labels of the test fold and of the full dataset, and evaluate the
    clusters.

    Sets `self._is_fitted` to True once all steps have run and returns it.
    """
    self.fit_on_pretrained_label_file(labels_file)

    # the steps are looked up and run one after the other, in this order
    for step_name in ('predict_labels_on_test_fold',
                      'predict_labels_on_full_dataset',
                      'evalutate_cluster_performance'):
        getattr(self, step_name)()

    self._is_fitted = True

    return self._is_fitted
|
|
1780 |
|
|
|
1781 |
def _predict_new_dataset(self, |
|
|
1782 |
tsv_dict, |
|
|
1783 |
path_survival_file, |
|
|
1784 |
normalization, |
|
|
1785 |
survival_flag=None, |
|
|
1786 |
metadata_file=None): |
|
|
1787 |
""" |
|
|
1788 |
""" |
|
|
1789 |
self.load_new_test_dataset( |
|
|
1790 |
tsv_dict=tsv_dict, |
|
|
1791 |
path_survival_file=path_survival_file, |
|
|
1792 |
normalization=normalization, |
|
|
1793 |
survival_flag=survival_flag, |
|
|
1794 |
metadata_file=metadata_file |
|
|
1795 |
) |
|
|
1796 |
|
|
|
1797 |
self.predict_labels_on_test_dataset() |