--- /dev/null
+++ b/simdeep/simdeep_analysis.py
@@ -0,0 +1,1797 @@
+"""
+DeepProg class for one instance model
+"""
+
+from sklearn.cluster import KMeans
+from sklearn.mixture import GaussianMixture
+from sklearn.model_selection import cross_val_score
+
+from simdeep.deepmodel_base import DeepBase
+
+from simdeep.survival_model_utils import ClusterWithSurvival
+
+from simdeep.config import NB_CLUSTERS
+from simdeep.config import CLUSTER_ARRAY
+from simdeep.config import PVALUE_THRESHOLD
+from simdeep.config import CINDEX_THRESHOLD
+from simdeep.config import CLASSIFIER_TYPE
+from simdeep.config import USE_AUTOENCODERS
+from simdeep.config import FEATURE_SURV_ANALYSIS
+from simdeep.config import SEED
+
+from simdeep.config import MIXTURE_PARAMS
+from simdeep.config import PATH_RESULTS
+from simdeep.config import PROJECT_NAME
+from simdeep.config import CLASSIFICATION_METHOD
+
+from simdeep.config import CLUSTER_EVAL_METHOD
+from simdeep.config import CLUSTER_METHOD
+from simdeep.config import NB_THREADS_COXPH
+from simdeep.config import NB_SELECTED_FEATURES
+from simdeep.config import LOAD_EXISTING_MODELS
+from simdeep.config import NODES_SELECTION
+from simdeep.config import CLASSIFIER
+from simdeep.config import HYPER_PARAMETERS
+from simdeep.config import PATH_TO_SAVE_MODEL
+from simdeep.config import CLUSTERING_OMICS
+from simdeep.config import USE_R_PACKAGES_FOR_SURVIVAL
+
+from simdeep.survival_utils import _process_parallel_coxph
+from simdeep.survival_utils import _process_parallel_cindex
+from simdeep.survival_utils import _process_parallel_feature_importance
+from simdeep.survival_utils import _process_parallel_feature_importance_per_cluster
+from simdeep.survival_utils import select_best_classif_params
+
+from simdeep.simdeep_utils import metadata_usage_type
+from simdeep.simdeep_utils import feature_selection_usage_type
+
+from simdeep.simdeep_utils import load_labels_file
+
+from simdeep.coxph_from_r import coxph
+from simdeep.coxph_from_r import c_index
+from simdeep.coxph_from_r import c_index_multiple
+
+from simdeep.coxph_from_r import surv_median
+
+from collections import Counter
+
+from sklearn.metrics import silhouette_score
+
+try:
+    from sklearn.metrics import calinski_harabasz_score \
+        as calinski_harabaz_score
+except Exception:
+    from sklearn.metrics import calinski_harabaz_score
+
+from sklearn.model_selection import GridSearchCV
+
+import numpy as np
+from numpy import hstack
+
+from collections import defaultdict
+
+import warnings
+
+from multiprocessing import Pool
+
+from os.path import isdir
+from os import mkdir
+
+
+################ VARIABLE ############################################
+_CLASSIFICATION_METHOD_LIST = ['ALL_FEATURES', 'SURVIVAL_FEATURES']
+MODEL_THRES = 0.05
+######################################################################
+
+
+class SimDeep(DeepBase):
+    """
+    Instantiate a new DeepProg instance.
+    The default parameters are defined in the config.py file
+
+    Parameters:
+        :dataset: ExtractData instance. Default None (create a new dataset using the config variables)
+        :nb_clusters: Number of clusters to search (default NB_CLUSTERS)
+        :pvalue_thres: P-value threshold to include a feature (default PVALUE_THRESHOLD)
+        :clustering_omics: Which omics to use for clustering. If empty, all the available omics will be used
+        :cindex_thres: C-index threshold to include a feature. This parameter is used only if `node_selection` is set to "C-index" (default CINDEX_THRESHOLD)
+        :cluster_method: Clustering method to use. Possible choices: ['mixture', 'kmeans'] (default CLUSTER_METHOD)
+        :cluster_eval_method: Cluster evaluation method to use when the `cluster_array` parameter is a list of possible K. Possible choices: ['bic', 'silhouette', 'calinski'] (default CLUSTER_EVAL_METHOD)
+        :classifier_type: Type of classifier to use. Possible choices: ['svm', 'clustering']. If 'clustering' is selected, the predict method of the clustering algorithm is used (default CLASSIFIER_TYPE)
+        :project_name: Name of the project. This name will be used to save the output files and create the output folder (default PROJECT_NAME)
+        :path_results: Result folder path used to save the output files (default PATH_RESULTS)
+        :cluster_array: Array of possible numbers of clusters to try. If set, `nb_clusters` is ignored (default CLUSTER_ARRAY)
+        :nb_selected_features: Number of selected features used to construct the classifiers (default NB_SELECTED_FEATURES)
+        :mixture_params: Dictionary of parameters used to instantiate the Gaussian mixture algorithm (default MIXTURE_PARAMS)
+        :node_selection: Method to select new features. Possible choices: ['Cox-PH', 'C-index'] (default NODES_SELECTION)
+        :nb_threads_coxph: Number of python processes used to compute individual survival models in parallel (default NB_THREADS_COXPH)
+        :classification_method: Possible choices: ['ALL_FEATURES', 'SURVIVAL_FEATURES']. If 'SURVIVAL_FEATURES' is selected, the classifiers are built using survival features (default CLASSIFICATION_METHOD)
+        :load_existing_models: (default LOAD_EXISTING_MODELS)
+        :path_to_save_model: (default PATH_TO_SAVE_MODEL)
+        :metadata_usage: Metadata usage with survival models (if metadata_tsv is provided as argument to the dataset). Possible choices are [None, False, 'labels', 'new-features', 'all', True] (True is the same as 'all')
+        :feature_selection_usage: Selection method for survival features ('individual' or 'lasso')
+        :alternative_embedding: Alternative external embedding to use instead of building autoencoders (default None)
+        :kwargs_alternative_embedding: Parameters for external embedding fitting
+    """
+    def __init__(self,
+                 nb_clusters=NB_CLUSTERS,
+                 pvalue_thres=PVALUE_THRESHOLD,
+                 cindex_thres=CINDEX_THRESHOLD,
+                 use_autoencoders=USE_AUTOENCODERS,
+                 feature_surv_analysis=FEATURE_SURV_ANALYSIS,
+                 cluster_method=CLUSTER_METHOD,
+                 cluster_eval_method=CLUSTER_EVAL_METHOD,
+                 classifier_type=CLASSIFIER_TYPE,
+                 project_name=PROJECT_NAME,
+                 path_results=PATH_RESULTS,
+                 cluster_array=CLUSTER_ARRAY,
+                 nb_selected_features=NB_SELECTED_FEATURES,
+                 mixture_params=MIXTURE_PARAMS,
+                 node_selection=NODES_SELECTION,
+                 nb_threads_coxph=NB_THREADS_COXPH,
+                 classification_method=CLASSIFICATION_METHOD,
+                 load_existing_models=LOAD_EXISTING_MODELS,
+                 path_to_save_model=PATH_TO_SAVE_MODEL,
+                 clustering_omics=CLUSTERING_OMICS,
+                 metadata_usage=None,
+                 feature_selection_usage='individual',
+                 use_r_packages=USE_R_PACKAGES_FOR_SURVIVAL,
+                 seed=SEED,
+                 alternative_embedding=None,
+                 do_KM_plot=True,
+                 verbose=True,
+                 _isboosting=False,
+                 dataset=None,
+                 kwargs_alternative_embedding={},
+                 deep_model_additional_args={}):
+        """
+        """
+        self.seed = seed
+        self.nb_clusters = nb_clusters
+        self.pvalue_thres = pvalue_thres
+        self.cindex_thres = cindex_thres
+        self.use_autoencoders = use_autoencoders
+        self.classifier_grid = GridSearchCV(CLASSIFIER(), HYPER_PARAMETERS, cv=5)
+        self.cluster_array = cluster_array
+        self.path_results = path_results
+        self.clustering_omics = clustering_omics
+        self.use_r_packages = use_r_packages
+        self.metadata_usage = metadata_usage_type(metadata_usage)
+        self.feature_selection_usage = feature_selection_usage_type(
+            feature_selection_usage)
+
+        self.feature_surv_analysis = feature_surv_analysis
+
+        if self.feature_selection_usage is None:
+            self.feature_surv_analysis = False
+
+        self.alternative_embedding = alternative_embedding
+        self.kwargs_alternative_embedding = kwargs_alternative_embedding
+
+        if self.path_results and not isdir(self.path_results):
+            mkdir(self.path_results)
+
+        self.mixture_params = mixture_params
+
+        self.project_name = project_name
+        self._project_name = project_name
+        self.do_KM_plot = do_KM_plot
+        self.nb_threads_coxph = nb_threads_coxph
+        self.classification_method = classification_method
+        self.nb_selected_features = nb_selected_features
+        self.node_selection = node_selection
+
+        self.train_pvalue = None
+        self.train_pvalue_proba = None
+        self.full_pvalue = None
+        self.full_pvalue_proba = None
+        self.cv_pvalue = None
+        self.cv_pvalue_proba = None
+        self.test_pvalue = None
+        self.test_pvalue_proba = None
+
+        self.classifier = None
+        self.classifier_test = None
+        self.clustering = None
+
+        self.classifier_dict = {}
+
+        self.encoder_for_kde_plot_dict = {}
+        self._main_kernel = {}
+
+        self.classifier_type = classifier_type
+
+        self.used_normalization = None
+        self.test_normalization = None
+
+        self.used_features_for_classif = None
+
+        self._isboosting = _isboosting
+        self._pretrained_model = False
+        self._is_fitted = False
+
+        self.valid_node_ids_array = {}
+        self.activities_array = {}
+        self.activities_pred_array = {}
+        self.pred_node_ids_array = {}
+
+        self.activities_train = None
+        self.activities_test = None
+        self.activities_cv = None
+
+        self.activities_for_pred_train = None
+        self.activities_for_pred_test = None
+        self.activities_for_pred_cv = None
+
+        self.test_labels = None
+        self.test_labels_proba = None
+        self.cv_labels = None
+        self.cv_labels_proba = None
+        self.full_labels = None
+        self.full_labels_proba = None
+
+        self.labels = None
+        self.labels_proba = None
+
+        self.training_omic_list = []
+        self.test_omic_list = []
+
+        self.feature_scores = defaultdict(list)
+        self.feature_scores_per_cluster = {}
+
+        self._label_ordered_dict = {}
+
+        self.clustering_performance = None
+        self.bic_score = None
+        self.silhouette_score = None
+        self.calinski_score = None
+
+        self.cluster_method = cluster_method
+        self.cluster_eval_method = cluster_eval_method
+        self.verbose = verbose
+        self._load_existing_models = load_existing_models
+        self._features_scores_changed = False
+
+        self.path_to_save_model = path_to_save_model
+
+        deep_model_additional_args['path_to_save_model'] = self.path_to_save_model
+
+        DeepBase.__init__(self,
+                          verbose=self.verbose,
+                          dataset=dataset,
+                          alternative_embedding=self.alternative_embedding,
+                          kwargs_alternative_embedding=self.kwargs_alternative_embedding,
+                          **deep_model_additional_args)
+
+    def _look_for_nodes(self, key):
+        """
+        """
+        assert(self.node_selection in ['Cox-PH', 'C-index'])
+
+        if self.metadata_usage in ['all', 'new-features'] and \
+           self.dataset.metadata_mat is not None:
+            metadata_mat = self.dataset.metadata_mat
+        else:
+            metadata_mat = None
+
+        if self.node_selection == 'Cox-PH':
+            return self._look_for_survival_nodes(
+                key, metadata_mat=metadata_mat)
+
+        if self.node_selection == 'C-index':
+            return self._look_for_prediction_nodes(key)
+
+    def load_new_test_dataset(self, tsv_dict,
+                              fname_key=None,
+                              path_survival_file=None,
+                              normalization=None,
+                              survival_flag=None,
+                              metadata_file=None):
+        """
+        """
+        self.dataset.load_new_test_dataset(
+            tsv_dict,
+            path_survival_file,
+            normalization=normalization,
+            survival_flag=survival_flag,
+            metadata_file=metadata_file
+        )
+
+        if normalization is not None:
+            self.test_normalization = {
+                key: normalization[key]
+                for key in normalization
+                if normalization[key]}
+
+        else:
+            self.test_normalization = {
+                key: self.dataset.normalization[key]
+                for key in self.dataset.normalization
+                if self.dataset.normalization[key]}
+
+        if self.used_normalization != self.test_normalization:
+            if self.verbose:
+                print('recomputing feature scores...')
+
+            self.feature_scores = {}
+            self.compute_feature_scores(use_ref=True)
+            self._features_scores_changed = True
+
+        if fname_key:
+            self.project_name = '{0}_{1}'.format(self._project_name, fname_key)
+
+    def fit_on_pretrained_label_file(self, label_file):
+        """
+        Fit a DeepProg SimDeep model without training an autoencoder:
+        an ID->labels file is used to train a classifier instead
+        """
+        self._pretrained_model = True
+        self.use_autoencoders = False
+        self.feature_surv_analysis = False
+
+        self.dataset.load_array()
+        self.dataset.load_survival()
+        self.dataset.load_meta_data()
+        self.dataset.subset_training_sets()
+
+        labels_dict = load_labels_file(label_file)
+
+        train, test, labels, labels_proba = [], [], [], []
+
+        for index, sample in enumerate(self.dataset.sample_ids):
+
+            if sample in labels_dict:
+                train.append(index)
+                label, label_proba = labels_dict[sample]
+
+                labels.append(label)
+                labels_proba.append(label_proba)
+
+            else:
+                test.append(index)
+
+        if test:
+            self.dataset.cross_validation_instance = (train, test)
+        else:
+            self.dataset.cross_validation_instance = None
+
+        self.dataset.create_a_cv_split()
+        self.dataset.normalize_training_array()
+
+        self.matrix_train_array = self.dataset.matrix_train_array
+
+        for key in self.matrix_train_array:
+            self.matrix_train_array[key] = self.matrix_train_array[key].astype('float32')
+
+        self.training_omic_list = self.dataset.training_tsv.keys()
+
+        self.predict_labels_using_external_labels(labels, labels_proba)
+
+        self.used_normalization = {key: self.dataset.normalization[key]
+                                   for key in self.dataset.normalization
+                                   if self.dataset.normalization[key]}
+
+        self.used_features_for_classif = self.dataset.feature_train_array
+        self.look_for_survival_nodes()
+        self.fit_classification_model()
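+
+    # Illustrative sketch (not part of the original source): the ID->labels
+    # file consumed above by `fit_on_pretrained_label_file` via
+    # `load_labels_file` maps each sample to a (label, label probability)
+    # pair. A plausible tab-separated layout, assumed here to mirror the
+    # output of `_write_labels`, would be:
+    #
+    #     sample_1    0    0.98
+    #     sample_2    1    0.76
+    #     sample_3    0    0.85
+    #
+    # Samples present in the file define the training split; any remaining
+    # samples are assigned to the test fold.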
+
+    def predict_labels_using_external_labels(self, labels, labels_proba):
+        """
+        """
+        self.labels = labels
+        nb_clusters = len(set(self.labels))
+        self.labels_proba = np.array([labels_proba for _ in range(nb_clusters)]).T
+
+        nbdays, isdead = self.dataset.survival.T.tolist()
+
+        pvalue = coxph(self.labels, isdead, nbdays,
+                       isfactor=False,
+                       do_KM_plot=self.do_KM_plot,
+                       png_path=self.path_results,
+                       seed=self.seed,
+                       use_r_packages=self.use_r_packages,
+                       fig_name='{0}_KM_plot_training_dataset'.format(self.project_name))
+
+        pvalue_proba = coxph(self.labels_proba.T[0], isdead, nbdays,
+                             seed=self.seed,
+                             use_r_packages=self.use_r_packages,
+                             isfactor=False)
+
+        if not self._isboosting:
+            self._write_labels(self.dataset.sample_ids, self.labels,
+                               labels_proba=self.labels_proba.T[0],
+                               fname='{0}_training_set_labels'.format(self.project_name))
+
+        if self.verbose:
+            print('Cox-PH p-value (Log-Rank) for the cluster labels: {0}'.format(pvalue))
+
+        self.train_pvalue = pvalue
+        self.train_pvalue_proba = pvalue_proba
+
+    def fit(self):
+        """
+        main function:
+        I) construct an autoencoder or fit an alternative embedding,
+        II) predict nodes linked with survival (if active),
+        and III) do the clustering
+        """
+        if self._load_existing_models:
+            self.load_encoders()
+
+        if not self.is_model_loaded:
+            if self.alternative_embedding is not None:
+                self.fit_alternative_embedding()
+            else:
+                self.construct_autoencoders()
+
+        self.look_for_survival_nodes()
+
+        self.training_omic_list = list(self.encoder_array.keys())
+        self.predict_labels()
+
+        self.used_normalization = {key: self.dataset.normalization[key]
+                                   for key in self.dataset.normalization
+                                   if self.dataset.normalization[key]}
+
+        self.used_features_for_classif = self.dataset.feature_train_array
+        self.fit_classification_model()
+
+    def predict_labels_on_test_fold(self):
+        """
+        """
+        if not self.dataset.cross_validation_instance:
+            return
+
+        self.dataset.load_matrix_test_fold()
+
+        nbdays, isdead = self.dataset.survival_cv.T.tolist()
+        self.activities_cv = self._predict_survival_nodes(
+            self.dataset.matrix_cv_array)
+
+        self.cv_labels, self.cv_labels_proba = self._predict_labels(
+            self.activities_cv, self.dataset.matrix_cv_array)
+
+        if self.verbose:
+            print('#### report of test fold clusters:')
+            for key, value in Counter(self.cv_labels).items():
+                print('class: {0}, number of samples: {1}'.format(key, value))
+
+        if self.metadata_usage in ['all', 'labels'] and \
+           self.dataset.metadata_mat_cv is not None:
+            metadata_mat = self.dataset.metadata_mat_cv
+        else:
+            metadata_mat = None
+
+        pvalue, pvalue_proba = self._compute_test_coxph('KM_plot_test_fold',
+                                                        nbdays, isdead,
+                                                        self.cv_labels,
+                                                        self.cv_labels_proba,
+                                                        metadata_mat=metadata_mat)
+        self.cv_pvalue = pvalue
+        self.cv_pvalue_proba = pvalue_proba
+
+        if not self._isboosting:
+            self._write_labels(self.dataset.sample_ids_cv, self.cv_labels,
+                               labels_proba=self.cv_labels_proba.T[0],
+                               fname='{0}_test_fold_labels'.format(self.project_name))
+
+        return self.cv_labels, pvalue, pvalue_proba
+
+    def predict_labels_on_full_dataset(self):
+        """
+        """
+        self.dataset.load_matrix_full()
+
+        nbdays, isdead = self.dataset.survival_full.T.tolist()
+
+        self.activities_full = self._predict_survival_nodes(
+            self.dataset.matrix_full_array)
+
+        self.full_labels, self.full_labels_proba = self._predict_labels(
+            self.activities_full, self.dataset.matrix_full_array)
+
+        if self.verbose:
+            print('#### report of assigned clusters for the full dataset:')
+            for key, value in Counter(self.full_labels).items():
+                print('class: {0}, number of samples: {1}'.format(key, value))
+
+        if self.metadata_usage in ['all', 'labels'] and \
+           self.dataset.metadata_mat_full is not None:
+            metadata_mat = self.dataset.metadata_mat_full
+        else:
+            metadata_mat = None
+
+        pvalue, pvalue_proba = self._compute_test_coxph('KM_plot_full',
+                                                        nbdays, isdead,
+                                                        self.full_labels,
+                                                        self.full_labels_proba,
+                                                        metadata_mat=metadata_mat)
+        self.full_pvalue = pvalue
+        self.full_pvalue_proba = pvalue_proba
+
+        if not self._isboosting:
+            self._write_labels(self.dataset.sample_ids_full, self.full_labels,
+                               labels_proba=self.full_labels_proba.T[0],
+                               fname='{0}_full_labels'.format(self.project_name))
+
+        return self.full_labels, pvalue, pvalue_proba
+
+    def predict_labels_on_test_dataset(self):
+        """
+        """
+        if self.dataset.survival_test is not None:
+            nbdays, isdead = self.dataset.survival_test.T.tolist()
+
+        self.test_omic_list = list(self.dataset.matrix_test_array.keys())
+        self.test_omic_list = list(set(self.test_omic_list).intersection(
+            self.training_omic_list))
+
+        try:
+            assert(len(self.test_omic_list) > 0)
+        except AssertionError:
+            raise Exception('in predict_labels_on_test_dataset: test_omic_list is empty!'
+                            '\n either no common omic with training_omic_list or error!')
+
+        self.fit_classification_test_model()
+
+        self.activities_test = self._predict_survival_nodes(
+            self.dataset.matrix_test_array)
+        self._predict_test_labels(self.activities_test,
+                                  self.dataset.matrix_test_array)
+
+        if self.verbose:
+            print('#### report of assigned clusters:')
+            for key, value in Counter(self.test_labels).items():
+                print('class: {0}, number of samples: {1}'.format(key, value))
+
+        if self.metadata_usage in ['all', 'test-labels'] and \
+           self.dataset.metadata_mat_test is not None:
+            metadata_mat = self.dataset.metadata_mat_test
+        else:
+            metadata_mat = None
+
+        pvalue, pvalue_proba = None, None
+
+        if self.dataset.survival_test is not None:
+            pvalue, pvalue_proba = self._compute_test_coxph('KM_plot_test',
+                                                            nbdays, isdead,
+                                                            self.test_labels,
+                                                            self.test_labels_proba,
+                                                            metadata_mat=metadata_mat)
+            self.test_pvalue = pvalue
+            self.test_pvalue_proba = pvalue_proba
+
+            if np.isnan(nbdays).all():
+                pvalue, pvalue_proba = self._compute_test_coxph(
+                    'KM_plot_test',
+                    nbdays, isdead,
+                    self.test_labels, self.test_labels_proba)
+
+                self.test_pvalue = pvalue
+                self.test_pvalue_proba = pvalue_proba
+
+        if not self._isboosting:
+            self._write_labels(self.dataset.sample_ids_test, self.test_labels,
+                               labels_proba=self.test_labels_proba.T[0],
+                               fname='{0}_test_labels'.format(self.project_name))
+
+        return self.test_labels, pvalue, pvalue_proba
+
+    def _compute_test_coxph(self,
+                            fname_base,
+                            nbdays,
+                            isdead,
+                            labels,
+                            labels_proba,
+                            metadata_mat=None):
+        """ """
+        pvalue = coxph(
+            labels, isdead, nbdays,
+            isfactor=False,
+            do_KM_plot=self.do_KM_plot,
+            png_path=self.path_results,
+            seed=self.seed,
+            use_r_packages=self.use_r_packages,
+            metadata_mat=metadata_mat,
+            fig_name='{0}_{1}'.format(self.project_name, fname_base))
+
+        if self.verbose:
+            print('Cox-PH p-value (Log-Rank) for inferred labels: {0}'.format(pvalue))
+
+        pvalue_proba = coxph(
+            labels_proba.T[0],
+            isdead, nbdays,
+            isfactor=False,
+            do_KM_plot=False,
+            png_path=self.path_results,
+            seed=self.seed,
+            use_r_packages=self.use_r_packages,
+            metadata_mat=metadata_mat,
+            fig_name='{0}_{1}_proba'.format(self.project_name, fname_base))
+
+        if self.verbose:
+            print('Cox-PH proba p-value (Log-Rank) for inferred labels: {0}'.format(pvalue_proba))
+
+        return pvalue, pvalue_proba
+
+    def compute_feature_scores(self, use_ref=False):
+        """
+        """
+        if self.feature_scores:
+            return
+
+        pool = None
+
+        if not self._isboosting:
+            pool = Pool(self.nb_threads_coxph)
+            mapf = pool.map
+        else:
+            mapf = map
+
+        def generator(labels, feature_list, matrix):
+            for i in range(len(feature_list)):
+                yield feature_list[i], matrix[i], labels
+
+        if use_ref:
+            key_array = list(self.dataset.matrix_ref_array.keys())
+        else:
+            key_array = list(self.dataset.matrix_train_array.keys())
+
+        for key in key_array:
+            if use_ref:
+                feature_list = self.dataset.feature_ref_array[key][:]
+                matrix = self.dataset.matrix_ref_array[key][:]
+            else:
+                feature_list = self.dataset.feature_train_array[key][:]
+                matrix = self.dataset.matrix_train_array[key][:]
+
+            labels = self.labels[:]
+
+            input_list = generator(labels, feature_list, matrix.T)
+
+            features_scored = list(mapf(_process_parallel_feature_importance, input_list))
+            features_scored.sort(key=lambda x: x[1])
+
+            self.feature_scores[key] = features_scored
+
+        if pool is not None:
+            pool.close()
+            pool.join()
+
+    def compute_feature_scores_per_cluster(self, use_ref=False,
+                                           pval_thres=0.01):
+        """
+        """
+        print('computing feature importance per cluster...')
+
+        mapf = map
+
+        for label in set(self.labels):
+            self.feature_scores_per_cluster[label] = []
+
+        def generator(labels, feature_list, matrix):
+            for i in range(len(feature_list)):
+                yield feature_list[i], matrix[i], labels, pval_thres
+
+        if use_ref:
+            key_array = list(self.dataset.matrix_ref_array.keys())
+        else:
+            key_array = list(self.dataset.matrix_train_array.keys())
+
+        for key in key_array:
+            if use_ref:
+                feature_list = self.dataset.feature_ref_array[key][:]
+                matrix = self.dataset.matrix_ref_array[key][:]
+            else:
+                feature_list = self.dataset.feature_train_array[key][:]
+                matrix = self.dataset.matrix_train_array[key][:]
+
+            labels = self.labels[:]
+
+            input_list = generator(labels, feature_list, matrix.T)
+
+            features_scored = mapf(_process_parallel_feature_importance_per_cluster, input_list)
+            features_scored = [feat for feat_list in features_scored for feat in feat_list]
+
+            for label, feature, median_diff, pvalue in features_scored:
+                self.feature_scores_per_cluster[label].append((feature, median_diff, pvalue))
+
+        for label in self.feature_scores_per_cluster:
+            self.feature_scores_per_cluster[label].sort(key=lambda x: x[1])
+
+    def write_feature_score_per_cluster(self):
+        """
+        """
+        f_file_name = '{0}/{1}_features_scores_per_clusters.tsv'.format(
+            self.path_results, self._project_name)
+        f_anti_name = '{0}/{1}_features_anticorrelated_scores_per_clusters.tsv'.format(
+            self.path_results, self._project_name)
+
+        f_file = open(f_file_name, 'w')
+        f_anti_file = open(f_anti_name, 'w')
+
+        f_file.write('cluster id;feature;median diff;p-value\n')
+        f_anti_file.write('cluster id;feature;median diff;p-value\n')
+
+        for label in self.feature_scores_per_cluster:
+            for feature, median_diff, pvalue in self.feature_scores_per_cluster[label]:
+                if median_diff > 0:
+                    f_to_write = f_file
+                else:
+                    f_to_write = f_anti_file
+
+                f_to_write.write('{0};{1};{2};{3}\n'.format(
+                    label, feature, median_diff, pvalue))
+
+        f_file.close()
+        f_anti_file.close()
+
+        print('{0} written'.format(f_file_name))
+        print('{0} written'.format(f_anti_name))
+
+    def write_feature_scores(self):
+        """
+        """
+        with open('{0}/{1}_features_scores.tsv'.format(
+                self.path_results, self.project_name), 'w') as f_file:
+
+            for key in self.feature_scores:
+                f_file.write('#### {0} ####\n'.format(key))
+
+                for feature, score in self.feature_scores[key]:
+                    f_file.write('{0};{1}\n'.format(feature, score))
+
+        print('{0}/{1}_features_scores.tsv written'.format(
+            self.path_results, self.project_name))
+
+    def _return_train_matrix_for_classification(self):
+        """
+        """
+        assert(self.classification_method in _CLASSIFICATION_METHOD_LIST)
+
+        if self.verbose:
+            print('classification method: {0}'.format(
+                self.classification_method))
+
+        if self.classification_method == 'SURVIVAL_FEATURES':
+            assert(self.classifier_type != 'clustering')
+            matrix = self._predict_survival_nodes(
+                self.dataset.matrix_ref_array)
+        elif self.classification_method == 'ALL_FEATURES':
+            matrix = self._reduce_and_stack_matrices(
+                self.dataset.matrix_ref_array)
+
+        if self.verbose:
+            print('number of features for the classifier: {0}'.format(
+                matrix.shape[1]))
+
+        return np.nan_to_num(matrix)
+
+    def _reduce_and_stack_matrices(self, matrices):
+        """
+        """
+        if not self.nb_selected_features:
+            return hstack(matrices.values())
+        else:
+            self.compute_feature_scores()
+            matrix = []
+
+            for key in matrices:
+                index = [self.dataset.feature_ref_index[key][feature]
+                         for feature, pvalue in
+                         self.feature_scores[key][:self.nb_selected_features]
+                         if feature in self.dataset.feature_ref_index[key]
+                         ]
+
+                matrix.append(matrices[key].T[index].T)
+
+            return hstack(matrix)
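+
+    # Summary of the two classification regimes handled above (a reading
+    # aid, not new behaviour): with classification_method='SURVIVAL_FEATURES'
+    # the classifier input is the matrix of survival-linked node activities
+    # returned by _predict_survival_nodes; with 'ALL_FEATURES' it is the
+    # stacked omic matrices, optionally truncated by
+    # _reduce_and_stack_matrices to the `nb_selected_features` top-ranked
+    # features of each omic.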
+
+    def fit_classification_model(self):
+        """ """
+        train_matrix = self._return_train_matrix_for_classification()
+        labels = self.labels
+
+        if self.classifier_type == 'clustering':
+            if self.verbose:
+                print('clustering model defined as the classifier')
+
+            self.classifier = self.clustering
+            return
+
+        if self.verbose:
+            print('classification analysis...')
+
+        if isinstance(self.seed, int):
+            np.random.seed(self.seed)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            self.classifier_grid.fit(train_matrix, labels)
+
+        self.classifier, params = select_best_classif_params(
+            self.classifier_grid)
+
+        self.classifier.set_params(probability=True)
+        self.classifier.fit(train_matrix, labels)
+
+        self.classifier_dict[str(self.used_normalization)] = self.classifier
+
+        if self.verbose:
+            cvs = cross_val_score(self.classifier, train_matrix, labels, cv=5)
+            print('best params:', params)
+            print('cross val score: {0}'.format(np.mean(cvs)))
+            print('classification score:', self.classifier.score(
+                train_matrix, labels))
+
+    def fit_classification_test_model(self):
+        """ """
+        is_same_features = self.used_features_for_classif == self.dataset.feature_ref_array
+        is_same_normalization = self.used_normalization == self.test_normalization
+        is_filled_with_zero = self.dataset.fill_unkown_feature_with_0
+
+        if (is_same_features and is_same_normalization and is_filled_with_zero)\
+           or self.classifier_type == 'clustering':
+            if self.verbose:
+                print('Not rebuilding the test classifier')
+
+            if self.classifier_test is None:
+                self.classifier_test = self.classifier
+            return
+
+        if self.verbose:
+            print('classification for test set analysis...')
+
+        self.used_normalization = self.dataset.normalization_test
+        self.used_features_for_classif = self.dataset.feature_ref_array
+
+        train_matrix = self._return_train_matrix_for_classification()
+        labels = self.labels
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            self.classifier_grid.fit(train_matrix, labels)
+
+        self.classifier_test, params = select_best_classif_params(self.classifier_grid)
+
+        self.classifier_test.set_params(probability=True)
+        self.classifier_test.fit(train_matrix, labels)
+
+        if self.verbose:
+            cvs = cross_val_score(self.classifier_test, train_matrix, labels, cv=5)
+            print('best params:', params)
+            print('cross val score: {0}'.format(np.mean(cvs)))
+            print('classification score:', self.classifier_test.score(train_matrix, labels))
+
+    def predict_labels(self):
+        """
+        predict labels for the training set by applying the clustering
+        algorithm to the node activities,
+        using only nodes linked to survival
+        """
+        if self.verbose:
+            print('performing clustering on the omic model with the following keys: {0}'.format(
+                self.training_omic_list))
+
+        if hasattr(self.cluster_method, 'fit_predict'):
+            # a custom clustering class was supplied directly
+            self.clustering = self.cluster_method(n_clusters=self.nb_clusters)
+            self.cluster_method = 'custom'
+
+        elif self.cluster_method == 'kmeans':
+            self.clustering = KMeans(n_clusters=self.nb_clusters, n_init=100)
+
+        elif self.cluster_method == 'mixture':
+            self.clustering = GaussianMixture(
+                n_components=self.nb_clusters,
+                **self.mixture_params
+            )
+
+        elif self.cluster_method == "coxPH":
+            nbdays, isdead = self.dataset.survival.T.tolist()
+
+            self.clustering = ClusterWithSurvival(
+                n_clusters=self.nb_clusters,
+                isdead=isdead,
+                nbdays=nbdays)
+
+        elif self.cluster_method == "coxPHMixture":
+            nbdays, isdead = self.dataset.survival.T.tolist()
+
+            self.clustering = ClusterWithSurvival(
+                n_clusters=self.nb_clusters,
+                use_gaussian_to_dichotomize=True,
+                isdead=isdead,
+                nbdays=nbdays)
+
+        else:
+            raise Exception("No fit/predict method found for: {0}".format(
+                self.cluster_method))
+
+        if not self.activities_train.any():
+            raise Exception('No components linked to survival!'
+                            ' cannot perform clustering')
+
+        if self.cluster_array and len(self.cluster_array) > 1:
+            self._predict_best_k_for_cluster()
+
+        if hasattr(self.clustering, 'predict'):
+            self.clustering.fit(self.activities_train)
+            labels = self.clustering.predict(self.activities_train)
+        else:
+            labels = self.clustering.fit_predict(self.activities_train)
+
+        labels = self._order_labels_according_to_survival(labels)
+
+        self.labels = labels
+
+        if hasattr(self.clustering, 'predict_proba'):
+            self.labels_proba = self.clustering.predict_proba(self.activities_train)
+        else:
+            self.labels_proba = np.array([self.labels, self.labels]).T
+
+        if len(self.labels_proba.shape) == 1:
+            self.labels_proba = self.labels_proba.reshape((
+                self.labels_proba.shape[0], 1))
+
+        if self.labels_proba.shape[1] < self.nb_clusters:
+            missing_columns = self.nb_clusters - self.labels_proba.shape[1]
+
+            for i in range(missing_columns):
+                self.labels_proba = hstack([
+                    self.labels_proba, np.zeros(
+                        shape=(self.labels_proba.shape[0], 1))])
+
+        if self.verbose:
+            print("clustering done, labels ordered according to survival:")
+            for key, value in Counter(labels).items():
+                print('cluster label: {0}\t number of samples: {1}'.format(key, value))
+            print('\n')
+
+        nbdays, isdead = self.dataset.survival.T.tolist()
+
+        if self.metadata_usage in ['all', 'labels'] and \
+           self.dataset.metadata_mat is not None:
+            metadata_mat = self.dataset.metadata_mat
+        else:
+            metadata_mat = None
+
+        pvalue = coxph(self.labels, isdead, nbdays,
+                       isfactor=False,
+                       do_KM_plot=self.do_KM_plot,
+                       png_path=self.path_results,
+                       seed=self.seed,
+                       use_r_packages=self.use_r_packages,
+                       metadata_mat=metadata_mat,
+                       fig_name='{0}_KM_plot_training_dataset'.format(self.project_name))
+
+        pvalue_proba = coxph(self.labels_proba.T[0],
+                             isdead, nbdays,
+                             seed=self.seed,
+                             use_r_packages=self.use_r_packages,
+                             metadata_mat=metadata_mat,
+                             isfactor=False)
+
+        if not self._isboosting:
+            self._write_labels(self.dataset.sample_ids, self.labels,
+                               labels_proba=self.labels_proba.T[0],
+                               fname='{0}_training_set_labels'.format(self.project_name))
+
+        if self.verbose:
+            print('Cox-PH p-value (Log-Rank) for the cluster labels: {0}'.format(pvalue))
+
+        self.train_pvalue = pvalue
+        self.train_pvalue_proba = pvalue_proba
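+
+    # Note on the cluster-quality metrics used below and in
+    # _predict_best_k_for_cluster: BIC is a penalised likelihood (lower is
+    # better, and it is only defined for the 'mixture' method), whereas the
+    # silhouette and Calinski-Harabasz scores increase with cluster
+    # separation (higher is better).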
+
+    def evalutate_cluster_performance(self):
+        """
+        """
+        if not self.clustering:
+            print('clustering attribute is defined as None. '
+                  'Cannot evaluate cluster performance')
+            return
+
+        if self.cluster_method == 'mixture':
+            self.bic_score = self.clustering.bic(self.activities_train)
+
+        self.silhouette_score = silhouette_score(self.activities_train, self.labels)
+        self.calinski_score = calinski_harabaz_score(self.activities_train, self.labels)
+
+        if self.verbose:
+            print('silhouette score: {0}'.format(self.silhouette_score))
+            print('calinski-harabasz score: {0}'.format(self.calinski_score))
+            print('bic score: {0}'.format(self.bic_score))
+
+    def _write_labels(self, sample_ids, labels, fname="",
+                      labels_proba=None,
+                      nbdays=None,
+                      isdead=None,
+                      path_file=None):
+        """ """
+        assert(fname or path_file)
+
+        if not path_file:
+            path_file = '{0}/{1}.tsv'.format(self.path_results, fname)
+
+        with open(path_file, 'w') as f_file:
+            for ids, (sample, label) in enumerate(zip(sample_ids, labels)):
+                suppl = ''
+
+                if labels_proba is not None:
+                    suppl += '\t{0}'.format(labels_proba[ids])
+                if nbdays is not None:
+                    suppl += '\t{0}'.format(nbdays[ids])
+                if isdead is not None:
+                    suppl += '\t{0}'.format(isdead[ids])
+
+                f_file.write('{0}\t{1}{2}\n'.format(sample, label, suppl))
+
+        print('file written: {0}'.format(path_file))
+
+    def _predict_survival_nodes(self, matrix_array, keys=None):
+        """
+        """
+        activities_array = {}
+
+        if keys is None:
+            keys = list(matrix_array.keys())
+
+        for key in keys:
+            matrix = matrix_array[key]
+            if not self._pretrained_model:
+                if self.alternative_embedding is None and \
+                   self.encoder_input_shape(key)[1] != matrix.shape[1]:
+                    if self.verbose:
+                        print("matrix doesn't have the input dimension of the encoder;"
+                              ' returning None')
+                    return None
+
+            if self.alternative_embedding is not None:
+                activities = self.embedding_predict(key, matrix)
+            elif self.use_autoencoders:
+                activities = self.encoder_predict(key, matrix)
+            else:
+                activities = np.asarray(matrix)
+
+            activities_array[key] = activities.T[self.valid_node_ids_array[key]].T
+
+        return hstack([activities_array[key]
+                       for key in keys])
+
+    def look_for_survival_nodes(self, keys=None):
+        """
+        detect nodes from the autoencoder significantly
+        linked with survival through Cox-PH regression
+        """
+        if not keys:
+            keys = list(self.encoder_array.keys())
+
+        if not keys:
+            keys = self.matrix_train_array.keys()
+
+        for key in keys:
+            matrix_train = self.matrix_train_array[key]
+
+            if self.alternative_embedding is not None:
+                activities = self.embedding_predict(key, matrix_train)
+            elif self.use_autoencoders:
+                activities = self.encoder_predict(key, matrix_train)
+            else:
+                activities = np.asarray(matrix_train)
+
+            if self.feature_surv_analysis:
+                valid_node_ids = self._look_for_nodes(key)
+            else:
+                valid_node_ids = np.arange(matrix_train.shape[1])
+
+            self.valid_node_ids_array[key] = valid_node_ids
+            self.activities_array[key] = activities.T[valid_node_ids].T
+
+        if self.clustering_omics:
+            keys = self.clustering_omics
+
+        self.activities_train = hstack([self.activities_array[key]
+                                        for key in keys])
+
+    def look_for_prediction_nodes(self, keys=None):
+        """
+        detect nodes from the autoencoder reaching a
+        high C-index score using labels from the retained test fold
+        """
+        if not keys:
+            keys = list(self.encoder_array.keys())
+
+        for key in keys:
+            matrix_train = self.matrix_train_array[key]
+
+            if self.alternative_embedding is not None:
+                activities = self.embedding_predict(key, matrix_train)
+            elif self.use_autoencoders:
+                activities = self.encoder_predict(key, matrix_train)
+            else:
+                activities = np.asarray(matrix_train)
+
+            if self.feature_surv_analysis:
+                valid_node_ids = self._look_for_prediction_nodes(key)
+            else:
+                valid_node_ids = np.arange(matrix_train.shape[1])
+
+            self.pred_node_ids_array[key] = valid_node_ids
+
+            self.activities_pred_array[key] = activities.T[valid_node_ids].T
+
+        self.activities_for_pred_train = hstack([self.activities_pred_array[key]
+                                                 for key in keys])
+
+    def compute_c_indexes_multiple_for_test_dataset(self):
+        """
+        return the multiple C-index for the test dataset using the node activities as predictors
+        """
+        days, dead = np.asarray(self.dataset.survival).T
+        days_test, dead_test = np.asarray(self.dataset.survival_test).T
+
+        activities_test = {}
+
+        for key in self.dataset.matrix_test_array:
+            node_ids = self.pred_node_ids_array[key]
+
+            matrix = self.dataset.matrix_test_array[key]
+
+            if self.alternative_embedding is not None:
+                activities_test[key] = self.embedding_predict(
+                    key, matrix).T[node_ids].T
+
+            elif self.use_autoencoders:
+                activities_test[key] = self.encoder_predict(
+                    key, matrix).T[node_ids].T
+
+            else:
+                activities_test[key] = self.dataset.matrix_test_array[key]
+
+        activities_test = hstack(activities_test.values())
+        activities_train = hstack([self.activities_pred_array[key]
+                                   for key in self.dataset.matrix_ref_array])
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            cindex = c_index_multiple(activities_train, dead, days,
+                                      activities_test, dead_test, days_test,
+                                      seed=self.seed,)
+
+        if self.verbose:
+            print('c-index multiple for test dataset: {0}'.format(cindex))
+
+        return cindex
+
+    def compute_c_indexes_multiple_for_test_fold_dataset(self):
+        """
+        return the multiple C-index for the test fold using the node activities as predictors
+        """
+        days, dead = np.asarray(self.dataset.survival).T
+        days_cv, dead_cv = np.asarray(self.dataset.survival_cv).T
+
+        activities_cv = {}
+
+        for key in self.dataset.matrix_cv_array:
+            node_ids = self.pred_node_ids_array[key]
+
+            if self.alternative_embedding is not None:
+                activities_cv[key] = self.embedding_predict(
+                    key, self.dataset.matrix_cv_array[key]).T[node_ids].T
+
+            elif self.use_autoencoders:
+                activities_cv[key] = self.encoder_predict(
+                    key, self.dataset.matrix_cv_array[key]).T[node_ids].T
+
+            else:
+                activities_cv[key] = self.dataset.matrix_cv_array[key]
+
+        activities_cv = hstack(activities_cv.values())
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            cindex = c_index_multiple(self.activities_for_pred_train, dead, days,
+                                      activities_cv, dead_cv, days_cv,
+                                      seed=self.seed,)
+
+        if self.verbose:
+            print('c-index multiple for test fold dataset: {0}'.format(cindex))
+
+        return cindex
+
+    def _return_test_matrix_for_classification(self, activities, matrix_array):
+        """
+        """
+        if self.classification_method == 'SURVIVAL_FEATURES':
+            return activities
+        elif self.classification_method == 'ALL_FEATURES':
+            matrix = self._reduce_and_stack_matrices(matrix_array)
+            return matrix
+
+    def _predict_test_labels(self, activities, matrix_array):
+        """ """
+        matrix_test = self._return_test_matrix_for_classification(
+            activities, matrix_array)
+
+        self.test_labels = self.classifier_test.predict(matrix_test)
+        self.test_labels_proba = self.classifier_test.predict_proba(matrix_test)
+
+        if self.test_labels_proba.shape[1] < self.nb_clusters:
+            missing_columns = self.nb_clusters - self.test_labels_proba.shape[1]
+
+            for i in range(missing_columns):
+                self.test_labels_proba = hstack([
+                    self.test_labels_proba, np.zeros(
+                        shape=(self.test_labels_proba.shape[0], 1))])
+
+    def _predict_labels(self, activities, matrix_array):
+        """ """
+        matrix_test = self._return_test_matrix_for_classification(
+            activities, matrix_array)
+
+        labels = self.classifier.predict(matrix_test)
+        labels_proba = self.classifier.predict_proba(matrix_test)
+
+        if labels_proba.shape[1] < self.nb_clusters:
+            missing_columns = self.nb_clusters - labels_proba.shape[1]
+
+            for i in range(missing_columns):
+                labels_proba = hstack([
+                    labels_proba, np.zeros(
+                        shape=(labels_proba.shape[0], 1))])
+
+        return labels, labels_proba
+
+    def _predict_best_k_for_cluster(self):
+        """ """
+        criterion = None
+        best_k = None
+
+        for k_cluster in self.cluster_array:
+            if self.cluster_method == 'mixture':
+                self.clustering.set_params(n_components=k_cluster)
+            else:
+                self.clustering.set_params(n_clusters=k_cluster)
+
+            labels = self.clustering.fit_predict(self.activities_train)
+
+            if self.cluster_eval_method == 'bic':
+                score = self.clustering.bic(self.activities_train)
+            elif self.cluster_eval_method == 'calinski':
+                score = calinski_harabaz_score(
+                    self.activities_train,
+                    labels
+                )
+            elif self.cluster_eval_method == 'silhouette':
+                score = silhouette_score(
+                    self.activities_train,
+                    labels
+                )
+
+            if self.verbose:
+                print('obtained {2}: {0} for k = {1}'.format(score, k_cluster,
+                                                             self.cluster_eval_method))
+
+            # BIC is minimised, while silhouette and calinski are maximised
+            if self.cluster_eval_method == 'bic':
+                is_better = criterion is None or score < criterion
+            else:
+                is_better = criterion is None or score > criterion
+
+            if is_better:
+                criterion, best_k = score, k_cluster
+
+        self.clustering_performance = criterion
+
+        if self.verbose:
+            print('best k: {0}'.format(best_k))
+
+        if self.cluster_method == 'mixture':
+            self.clustering.set_params(n_components=best_k)
+        else:
+            self.clustering.set_params(n_clusters=best_k)
+
+    def _order_labels_according_to_survival(self, labels):
+        """
+        Order cluster labels according to the median survival of each cluster
+        """
+        labels_old = labels.copy()
+
+        days, dead = np.asarray(self.dataset.survival).T
+
+        self._label_ordered_dict = {}
+
+        for label in set(labels_old):
+            median = surv_median(dead[labels_old == label],
+                                 days[labels_old == label])
+            self._label_ordered_dict[label] = median
+
+        label_ordered = [label for label, _ in
+                         sorted(self._label_ordered_dict.items(), key=lambda x: x[1])]
+
+        self._label_ordered_dict = {old_label: new_label
+                                    for new_label, old_label in enumerate(label_ordered)}
+
+        for old_label in self._label_ordered_dict:
+            labels[labels_old == old_label] = self._label_ordered_dict[old_label]
+
+        return labels
+
+    def _look_for_survival_nodes(self, key=None,
+                                 activities=None,
+                                 survival=None,
+                                 metadata_mat=None):
+        """
+        """
+        if key is not None:
+            matrix_train = self.matrix_train_array[key]
+
+            if self.alternative_embedding is not None:
+                activities = np.nan_to_num(self.embedding_predict(
+                    key, matrix_train))
+
+            elif self.use_autoencoders:
+                activities = np.nan_to_num(self.encoder_predict(
+                    key, matrix_train))
+
+            else:
+                activities = np.asarray(matrix_train)
+        else:
+            assert(activities is not None)
+
+        if survival is not None:
+            nbdays, isdead = survival.T.tolist()
+        else:
+            nbdays, isdead = self.dataset.survival.T.tolist()
+
+        if self.feature_selection_usage == 'lasso':
+            cws = ClusterWithSurvival(
+                isdead=isdead,
+                nbdays=nbdays,
+                metadata_mat=metadata_mat)
+
+            return cws.get_nonzero_features(activities)
+
+        else:
+            return self._get_survival_features_parallel(
+                isdead, nbdays, metadata_mat, activities, key)
+
+    def _get_survival_features_parallel(
+            self, isdead, nbdays, metadata_mat, activities, key):
+        """ """
+        pool = None
+
+        if not self._isboosting:
+            pool = Pool(self.nb_threads_coxph)
+            mapf = pool.map
+        else:
+            mapf = map
+
+        input_list = iter((node_id,
+                           activity,
+                           isdead,
+                           nbdays,
+                           self.seed,
+                           metadata_mat, self.use_r_packages)
+                          for node_id, activity in enumerate(activities.T))
+
+        pvalue_list = mapf(_process_parallel_coxph, input_list)
+
+        pvalue_list = list(filter(lambda x: not np.isnan(x[1]), pvalue_list))
+        pvalue_list.sort(key=lambda x: x[1], reverse=True)
+
+        valid_node_ids = [node_id for node_id, pvalue in pvalue_list
+                          if pvalue < self.pvalue_thres]
+
+        if self.verbose:
+            print('number of components linked to survival found: {0} for key {1}'.format(
+                len(valid_node_ids), key))
+
+        if pool is not None:
+            pool.close()
+            pool.join()
+
+        return valid_node_ids
+
+    def _look_for_prediction_nodes(self, key):
+        """
+        """
+        nbdays, isdead = self.dataset.survival.T.tolist()
+        nbdays_cv, isdead_cv = self.dataset.survival_cv.T.tolist()
+
+        matrix_train = self.matrix_train_array[key]
+        matrix_cv = self.dataset.matrix_cv_array[key]
+
+        if self.alternative_embedding is not None:
+            activities_train = self.embedding_predict(key, matrix_train)
+            activities_cv = self.embedding_predict(key, matrix_cv)
+
+        elif self.use_autoencoders:
+            activities_train = self.encoder_predict(key, matrix_train)
+            activities_cv = self.encoder_predict(key, matrix_cv)
+        else:
+            activities_train = np.asarray(matrix_train)
+            activities_cv = np.asarray(matrix_cv)
+
+        input_list = iter((node_id,
+                           activities_train.T[node_id], isdead, nbdays,
+                           activities_cv.T[node_id], isdead_cv, nbdays_cv, self.use_r_packages)
+                          for node_id in range(activities_train.shape[1]))
+
+        score_list = list(map(_process_parallel_cindex, input_list))
+
+        score_list = list(filter(lambda x: not np.isnan(x[1]), score_list))
+        score_list.sort(key=lambda x: x[1], reverse=True)
+
+        valid_node_ids = [node_id for node_id, cindex in score_list
+                          if cindex > self.cindex_thres]
+
+        scores = [score for node_id, score in score_list
+                  if score > self.cindex_thres]
+
+        if self.verbose:
+            print('number of components with a high prediction score: {0} for key {1}'
+                  ' \n\t mean: {2} std: {3}'.format(
+                      len(valid_node_ids), key, np.mean(scores), np.std(scores)))
+
+        return valid_node_ids
+
+    def compute_c_indexes_for_full_dataset(self):
+        """
+        return the C-index using the cluster labels as predictor
+        """
+        days, dead = np.asarray(self.dataset.survival).T
+        days_full, dead_full = np.asarray(self.dataset.survival_full).T
+
+        try:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                cindex = c_index(self.labels, dead, days,
+                                 self.full_labels, dead_full, days_full,
+                                 use_r_packages=self.use_r_packages,
+                                 seed=self.seed,)
+        except Exception as e:
+            print('Exception while computing the c-index: {0}'.format(e))
+            cindex = np.nan
+
+        if self.verbose:
+            print('c-index for full dataset: {0}'.format(cindex))
+
+        return cindex
+
+    def compute_c_indexes_for_training_dataset(self):
+        """
+        return the C-index using the cluster labels as predictor
+        """
+        days, dead = np.asarray(self.dataset.survival).T
+
+        try:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                cindex = c_index(self.labels, dead, days,
+                                 self.labels, dead, days,
+                                 use_r_packages=self.use_r_packages,
+                                 seed=self.seed,)
+        except Exception as e:
+            print('Exception while computing the c-index: {0}'.format(e))
+            cindex = np.nan
+
+        if self.verbose:
+            print('c-index for training dataset: {0}'.format(cindex))
+
+        return cindex
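+
+    # The c_index helpers below all follow the same pattern: the training
+    # labels fit a survival model whose concordance is then measured on the
+    # inferred labels of the other dataset (test set or held-out test fold).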
+
+    def compute_c_indexes_for_test_dataset(self):
+        """
+        return the C-index using the cluster labels as predictor
+        """
+        days, dead = np.asarray(self.dataset.survival).T
+        days_test, dead_test = np.asarray(self.dataset.survival_test).T
+
+        try:
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                cindex = c_index(self.labels, dead, days,
+                                 self.test_labels, dead_test, days_test,
+                                 use_r_packages=self.use_r_packages,
+                                 seed=self.seed,)
+        except Exception as e:
+            print('Exception while computing the c-index: {0}'.format(e))
+            cindex = np.nan
+
+        if self.verbose:
+            print('c-index for test dataset: {0}'.format(cindex))
+
+        return cindex
+
+    def compute_c_indexes_for_test_fold_dataset(self):
+        """
+        return the C-index using the cluster labels as predictor
+        """
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            days, dead = np.asarray(self.dataset.survival).T
+            days_cv, dead_cv = np.asarray(self.dataset.survival_cv).T
+
+            try:
+                cindex = c_index(self.labels, dead, days,
+                                 self.cv_labels, dead_cv, days_cv,
+                                 use_r_packages=self.use_r_packages,
+                                 seed=self.seed,)
+            except Exception as e:
+                print('Exception while computing the c-index: {0}'.format(e))
+                cindex = np.nan
+
+        if self.verbose:
+            print('c-index for test fold dataset: {0}'.format(cindex))
+
+        return cindex
+
+    def predict_nodes_activities(self, matrix_array):
+        """
+        """
+        activities = []
+
+        for key in matrix_array:
+            if key not in self.pred_node_ids_array:
+                continue
+
+            node_ids = self.pred_node_ids_array[key]
+
+            if self.alternative_embedding is not None:
+                activities.append(
+                    self.embedding_predict(
+                        key, matrix_array[key]).T[node_ids].T)
+            else:
+                activities.append(
+                    self.encoder_predict(
+                        key, matrix_array[key]).T[node_ids].T)
+
+        return hstack(activities)
+
+    def plot_kernel_for_test_sets(self,
+                                  dataset=None,
+                                  labels=None,
+                                  labels_proba=None,
+                                  test_labels=None,
+                                  test_labels_proba=None,
+                                  define_as_main_kernel=False,
+                                  use_main_kernel=False,
+                                  activities=None,
+                                  activities_test=None,
+                                  key=''):
+        """
+        """
+        from simdeep.plot_utils import plot_kernel_plots
+
+        if dataset is None:
+            dataset = self.dataset
+
+        if labels is None:
+            labels = self.labels
+
+        if labels_proba is None:
+            labels_proba = self.labels_proba
+
+        if test_labels is None:
+            test_labels = self.test_labels
+
+        if test_labels_proba is None:
+            test_labels_proba = self.test_labels_proba
+
+        test_norm = self.test_normalization
+        train_norm = self.dataset.normalization
+        train_norm = {key: train_norm[key] for key in train_norm if train_norm[key]}
+
+        is_same_normalization = train_norm == test_norm
+        is_filled_with_zero = self.dataset.fill_unkown_feature_with_0
+
+        if activities is None or activities_test is None:
+            if not (is_same_normalization and is_filled_with_zero):
+                print('\n<><><><> Cannot plot survival KDE plot:'
+                      ' different normalisation used for the test set <><><><>\n')
+                return
+
+            activities = hstack([self.activities_array[omic]
+                                 for omic in self.test_omic_list])
+            activities_test = self.activities_test
+
+        if define_as_main_kernel:
+            self._main_kernel = {'activities': activities_test.copy(),
+                                 'labels': test_labels.copy()}
+
+        if use_main_kernel:
+            activities = self._main_kernel['activities']
+            labels = self._main_kernel['labels']
+
+        html_name = '{0}/{1}{2}_test_kdeplot.html'.format(
+            self.path_results,
+            self.project_name,
+            key)
+
+        plot_kernel_plots(
+            test_labels=test_labels,
+            test_labels_proba=test_labels_proba,
+            labels=labels,
+            activities=activities,
+            activities_test=activities_test,
+            dataset=self.dataset,
+            path_html=html_name)
+
+    def plot_supervised_kernel_for_test_sets(
+            self,
+            labels=None,
+            labels_proba=None,
+            dataset=None,
+            key='',
+            use_main_kernel=False,
+            test_labels=None,
+            test_labels_proba=None,
+            define_as_main_kernel=False,
+    ):
+        """
+        """
+        if labels is None:
+            labels = self.labels
+
+        if labels_proba is None:
+            labels_proba = self.labels_proba
+
+        if dataset is None:
+            dataset = self.dataset
+
+        activities, activities_test = self._predict_kde_matrix(
+            labels_proba, dataset)
+
+        key += '_supervised'
+
+        self.plot_kernel_for_test_sets(labels=labels,
+                                       labels_proba=labels_proba,
+                                       dataset=dataset,
+                                       activities=activities,
+                                       activities_test=activities_test,
+                                       key=key,
+                                       use_main_kernel=use_main_kernel,
+                                       test_labels=test_labels,
+                                       test_labels_proba=test_labels_proba,
+                                       define_as_main_kernel=define_as_main_kernel,
+                                       )
+
+    def _create_autoencoder_for_kernel_plot(self, labels_proba, dataset, key):
+        """
+        """
+        autoencoder = DeepBase(dataset=dataset,
+                               seed=self.seed,
+                               verbose=False,
+                               dropout=0.1,
+                               epochs=50)
+
+        autoencoder.matrix_train_array = dataset.matrix_ref_array
+        autoencoder.construct_supervized_network(labels_proba)
+
+        self.encoder_for_kde_plot_dict[key] = autoencoder.encoder_array
+
+    def _predict_kde_matrix(self, labels_proba, dataset):
+        """
+        """
+        matrix_ref_list = []
+        matrix_test_list = []
+
+        encoder_key = str(self.test_normalization)
+        encoder_key = 'omic:{0} normalisation: {1}'.format(
+            self.test_omic_list,
+            encoder_key)
+
+        if encoder_key not in self.encoder_for_kde_plot_dict or \
+           not dataset.fill_unkown_feature_with_0:
+            self._create_autoencoder_for_kernel_plot(
+                labels_proba, dataset, encoder_key)
+
+        encoder_array = self.encoder_for_kde_plot_dict[encoder_key]
+
+        if self.metadata_usage in ['all', 'new-features'] and \
+           dataset.metadata_mat is not None:
+            metadata_mat = dataset.metadata_mat
+        else:
+            metadata_mat = None
+
+        for key in encoder_array:
+            matrix_ref = encoder_array[key].predict(
+                dataset.matrix_ref_array[key])
+            matrix_test = encoder_array[key].predict(
+                dataset.matrix_test_array[key])
+
+            survival_node_ids = self._look_for_survival_nodes(
+                activities=matrix_ref, survival=dataset.survival,
+                metadata_mat=metadata_mat)
+
+            if len(survival_node_ids) > 1:
+                matrix_ref = matrix_ref.T[survival_node_ids].T
+                matrix_test = matrix_test.T[survival_node_ids].T
+            else:
+                print('not enough survival nodes to construct kernel for key: {0};'
+                      ' skipping the {0} matrix'.format(key))
+                continue
+
+            matrix_ref_list.append(matrix_ref)
+            matrix_test_list.append(matrix_test)
+
+        if not matrix_ref_list:
+            print('matrix_ref_list / matrix_test_list empty!'
+                  ' take the last OMIC ({0}) matrix as ref'.format(key))
+            matrix_ref_list.append(matrix_ref)
+            matrix_test_list.append(matrix_test)
+
+        return hstack(matrix_ref_list), hstack(matrix_test_list)
+
+    def _get_probas_for_full_model(self):
+        """
+        return samples and probas
+        """
+        return list(zip(self.dataset.sample_ids_full, self.full_labels_proba))
+
+    def _get_pvalues_and_pvalues_proba(self):
+        """
+        """
+        return self.full_pvalue, self.full_pvalue_proba
+
+    def _get_from_dataset(self, attr):
+        """
+        """
+        return getattr(self.dataset, attr)
+
+    def _get_attibute(self, attr):
+        """
+        """
+        return getattr(self, attr)
+
+    def _partial_fit_model_pool(self):
+        """
+        """
+        try:
+            self.load_training_dataset()
+            self.fit()
+
+            if len(set(self.labels)) < 2:
+                raise Exception('only one class!')
+
+            if self.train_pvalue > MODEL_THRES:
+                raise Exception('pvalue: {0} not significant!'.format(self.train_pvalue))
+
+        except Exception as e:
+            print("model with random state: {1} didn't converge: {0}".format(str(e), self.seed))
+            return False
+
+        else:
+            print('model with random state: {0} fitted'.format(self.seed))
+            self._is_fitted = True
+
+            self.predict_labels_on_test_fold()
+            self.predict_labels_on_full_dataset()
+            self.evalutate_cluster_performance()
+
+        return self._is_fitted
+
+    def _partial_fit_model_with_pretrained_pool(self, labels_file):
+        """
+        """
+        self.fit_on_pretrained_label_file(labels_file)
+
+        self.predict_labels_on_test_fold()
+        self.predict_labels_on_full_dataset()
+        self.evalutate_cluster_performance()
+
+        self._is_fitted = True
+
+        return self._is_fitted
+
+    def _predict_new_dataset(self,
+                             tsv_dict,
+                             path_survival_file,
+                             normalization,
+                             survival_flag=None,
+                             metadata_file=None):
+        """
+        """
+        self.load_new_test_dataset(
+            tsv_dict=tsv_dict,
+            path_survival_file=path_survival_file,
+            normalization=normalization,
+            survival_flag=survival_flag,
+            metadata_file=metadata_file
+        )
+
+        self.predict_labels_on_test_dataset()
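+
+
+# Minimal usage sketch (illustrative only; not part of the original module).
+# It assumes a dataset built from the simdeep.config defaults, as described
+# in the SimDeep class docstring, with training data paths configured there.
+if __name__ == '__main__':
+    simdeep = SimDeep()  # dataset=None -> built from the config variables
+    simdeep.load_training_dataset()
+    simdeep.fit()  # autoencoders + survival nodes + clustering
+    simdeep.predict_labels_on_test_fold()
+    simdeep.predict_labels_on_full_dataset()
+    simdeep.evalutate_cluster_performance()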