"""Loading and tensor preparation helpers for paired single-cell modalities.

The loaders read one ``.h5ad`` file holding two modalities (GEX plus ATAC or
ADT), keep the raw counts of the selected features, and return one AnnData
per modality.  The ``data_process_moETM*`` functions turn such AnnData pairs
into row-normalized torch tensors, optionally split into train and test
sets.
"""

import gc
import warnings

import anndata as ad
import numpy as np
import scanpy as sc
import torch


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------

def _coding_feature_mask(feature_ids, encoding, id_key, flag_key):
    """Return a full-length boolean mask of features flagged in ``encoding``.

    ``encoding[id_key]`` and ``encoding[flag_key]`` are read positionally:
    entry ``i`` is compared with feature ``i``.
    NOTE(review): this assumes the encoding table is ordered like the
    features (e.g. a DataFrame with its default integer index) -- confirm
    against the caller.

    A feature whose id does not match the encoding entry at the same
    position is excluded, and a single warning reports how many were
    dropped.  The mask therefore always has one entry per feature, so it is
    safe to use for column subsetting.

    Raises:
        ValueError: if ``encoding`` has fewer rows than there are features.
    """
    # Object dtype forces an element-wise Python comparison regardless of
    # the underlying string/categorical dtypes.
    feature_ids = np.asarray(feature_ids, dtype=object)
    n_features = feature_ids.shape[0]

    reference_ids = np.asarray(encoding[id_key], dtype=object)
    flags = np.asarray(encoding[flag_key], dtype=object)
    if reference_ids.shape[0] < n_features or flags.shape[0] < n_features:
        raise ValueError(
            "encoding table has fewer rows than the {} features to "
            "check".format(n_features)
        )

    matches = np.asarray(feature_ids == reference_ids[:n_features], dtype=bool)
    flagged = np.array([bool(v) for v in flags[:n_features]], dtype=bool)

    n_mismatch = int(n_features - matches.sum())
    if n_mismatch:
        warnings.warn(
            "{} of {} feature ids do not match '{}' in the encoding table; "
            "those features are dropped".format(n_mismatch, n_features, id_key),
            RuntimeWarning,
            stacklevel=3,
        )
    return matches & flagged


def _split_by_feature_type(adata, second_feature_type):
    """Return independent copies of the GEX and ``second_feature_type`` columns."""
    feature_types = np.array(adata.var.feature_types)
    mod1 = adata[:, feature_types == 'GEX'].copy()
    mod2 = adata[:, feature_types == second_feature_type].copy()
    return mod1, mod2


def _keep_highly_variable(adata):
    """Subset ``adata`` to scanpy's highly variable features, keeping raw values.

    Selection runs on a scratch copy that is total-count normalized
    (``target_sum=1e4``) and log1p-transformed; ``adata`` itself is never
    modified.

    Returns:
        ``(subset, mask)`` where ``subset`` is a copy of ``adata`` restricted
        to the selected columns and ``mask`` is the boolean selection.
    """
    scratch = adata.copy()
    sc.pp.normalize_total(scratch, target_sum=1e4)
    sc.pp.log1p(scratch)
    sc.pp.highly_variable_genes(scratch)
    mask = scratch.var['highly_variable'].values

    # Free the normalized copy before the subset copy is allocated.
    del scratch
    gc.collect()

    return adata[:, mask].copy(), mask


def _subset_x_obs(adata, index):
    """Return a new AnnData holding only ``X`` and ``obs`` of ``adata[index]``."""
    subset = adata[index]
    return ad.AnnData(X=subset.X, obs=subset.obs)


def _dense_counts(adata):
    """Return ``adata.X`` as a dense 2-D ndarray (sparse or dense input)."""
    X = adata.X
    if hasattr(X, 'toarray'):
        return np.asarray(X.toarray())
    return np.array(X)


def _normalized_tensor(adata):
    """Return ``(row-normalized float tensor, per-row totals)`` for ``adata.X``.

    Each row is divided by its sum.  Rows whose sum is zero are left as all
    zeros instead of becoming NaN.
    """
    counts = _dense_counts(adata)
    totals = counts.sum(1)
    divisor = totals[:, np.newaxis]
    # Dividing by 1 keeps an all-zero row at zero rather than 0/0 = NaN.
    divisor = np.where(divisor == 0, 1, divisor)
    return torch.from_numpy(counts / divisor).float(), totals


def _batch_tensor(adata, remap=False, to_gpu=False):
    """Return ``obs['batch_indices']`` as an int64 tensor.

    Args:
        remap: renumber the batch ids present to ``0..k-1`` in ascending
            order of the original ids.
        to_gpu: move the tensor to the CUDA device when one is available;
            on CPU-only hosts the tensor stays on the CPU.
    """
    batch_index = np.array(adata.obs['batch_indices'])
    if remap:
        _, batch_index = np.unique(batch_index, return_inverse=True)
        batch_index = batch_index.reshape(-1)
    tensor = torch.from_numpy(batch_index).to(torch.int64)
    if to_gpu and torch.cuda.is_available():
        tensor = tensor.cuda()
    return tensor


def _train_test_tensors(adata_mod1, adata_mod2, train_index, test_index,
                        remap_batches=False):
    """Subset both modalities and build train/test tensors.

    Returns a dict with keys ``X1_train``, ``X2_train``, ``batch_train``,
    ``X1_test``, ``X2_test``, ``batch_test``, ``train_adata_mod1``,
    ``test_adata_mod1``, ``test_sum1`` and ``test_sum2`` (the last two are
    the per-row totals of the test matrices before normalization).
    """
    train_adata_mod1 = _subset_x_obs(adata_mod1, train_index)
    train_adata_mod2 = _subset_x_obs(adata_mod2, train_index)
    test_adata_mod1 = _subset_x_obs(adata_mod1, test_index)
    test_adata_mod2 = _subset_x_obs(adata_mod2, test_index)

    X1_train, _ = _normalized_tensor(train_adata_mod1)
    X2_train, _ = _normalized_tensor(train_adata_mod2)
    # Only the training batch tensor is placed on the GPU.
    batch_train = _batch_tensor(train_adata_mod1, remap=remap_batches,
                                to_gpu=True)

    X1_test, test_sum1 = _normalized_tensor(test_adata_mod1)
    X2_test, test_sum2 = _normalized_tensor(test_adata_mod2)
    batch_test = _batch_tensor(test_adata_mod1, remap=remap_batches)

    return {
        'X1_train': X1_train,
        'X2_train': X2_train,
        'batch_train': batch_train,
        'X1_test': X1_test,
        'X2_test': X2_test,
        'batch_test': batch_test,
        'train_adata_mod1': train_adata_mod1,
        'test_adata_mod1': test_adata_mod1,
        'test_sum1': test_sum1,
        'test_sum2': test_sum2,
    }


# ---------------------------------------------------------------------------
# Loaders
# ---------------------------------------------------------------------------

def load_nips_rna_atac_dataset(mod_file_path, gene_encoding):
    """Load a paired GEX/ATAC ``.h5ad`` file and return raw-count AnnData objects.

    Genes are kept when ``gene_encoding['is_gene_coding']`` is truthy and the
    gene id matches ``gene_encoding['gene_id']`` at the same position.  Both
    modalities are then restricted to scanpy's highly variable features; the
    returned ``X`` holds the ``'counts'`` layer, not normalized values.

    Args:
        mod_file_path: path passed to ``anndata.read_h5ad``.
        gene_encoding: table with ``'gene_id'`` and ``'is_gene_coding'``
            columns, read positionally.

    Returns:
        ``(adata_mod1, adata_mod2)`` for GEX and ATAC respectively.
    """
    adata = ad.read_h5ad(mod_file_path)
    gex, atac = _split_by_feature_type(adata, 'ATAC')
    del adata
    gc.collect()

    keep = _coding_feature_mask(gex.var['gene_id'], gene_encoding,
                                'gene_id', 'is_gene_coding')
    adata_mod1 = gex[:, keep].copy()
    del gex
    gc.collect()

    adata_mod1.X = adata_mod1.layers['counts']
    atac.X = atac.layers['counts']

    adata_mod1, _ = _keep_highly_variable(adata_mod1)
    adata_mod2, _ = _keep_highly_variable(atac)

    return adata_mod1, adata_mod2


def prepare_nips_dataset(adata_gex, adata_mod2,
                         batch_col='batch',
                         ):
    """Attach integer batch ids and drop cells with no signal in modality 2.

    The labels in ``adata_gex.obs[batch_col]`` are mapped to ``0..k-1`` in
    sorted order and stored as ``obs['batch_indices']`` on both outputs.
    Cells whose ``adata_mod2.X`` row sum is not positive are removed from
    both modalities.  The inputs are not modified.

    Returns:
        ``(adata_gex, adata_mod2)`` as new AnnData objects carrying only
        ``X`` and ``obs``.
    """
    batch_labels = np.array(adata_gex.obs[batch_col].values)
    _, batch_index = np.unique(batch_labels, return_inverse=True)
    batch_index = batch_index.reshape(-1)

    # Work on copies so the caller's obs tables are left untouched and the
    # function can be called again on the same inputs.
    obs_gex = adata_gex.obs.copy()
    obs_gex['batch_indices'] = batch_index
    obs_mod2 = adata_mod2.obs.copy()
    obs_mod2['batch_indices'] = batch_index

    X_mod2 = adata_mod2.X
    # ravel (not squeeze) keeps a 1-D mask even for a single cell.
    keep = np.asarray(X_mod2.sum(1) > 0).ravel()

    adata_gex = _subset_x_obs(ad.AnnData(X=adata_gex.X, obs=obs_gex), keep)
    adata_mod2 = _subset_x_obs(ad.AnnData(X=X_mod2, obs=obs_mod2), keep)

    return adata_gex, adata_mod2


# ---------------------------------------------------------------------------
# Tensor preparation
# ---------------------------------------------------------------------------

def data_process_moETM(adata_mod1, adata_mod2):
    """Build training tensors from the whole dataset (no held-out cells).

    Returns:
        ``(X_mod1, X_mod2, batch_index, adata_mod1)`` where the two ``X``
        tensors are row-normalized float tensors, ``batch_index`` is the
        int64 tensor of ``obs['batch_indices']`` and ``adata_mod1`` is the
        input object itself.
    """
    X_mod1_train_T, _ = _normalized_tensor(adata_mod1)
    X_mod2_train_T, _ = _normalized_tensor(adata_mod2)
    batch_index_train_T = _batch_tensor(adata_mod1)

    return X_mod1_train_T, X_mod2_train_T, batch_index_train_T, adata_mod1


def data_process_moETM_split(adata_mod1, adata_mod2, n_sample, test_ratio=0.1):
    """Randomly split cells into train and test sets and build tensors.

    ``int(n_sample * (1 - test_ratio))`` cells are drawn without replacement
    for training; the remaining indices, in ascending order, form the test
    set.

    Returns:
        ``(X_mod1_train, X_mod2_train, batch_train, X_mod1_test,
        X_mod2_test, batch_test, test_adata_mod1)``.  ``batch_train`` is
        moved to the GPU when CUDA is available.
    """
    from sklearn.utils import resample

    all_index = np.arange(0, n_sample)
    train_index = resample(all_index,
                           n_samples=int(n_sample * (1 - test_ratio)),
                           replace=False)
    # setdiff1d keeps an integer dtype even when the test set is empty.
    test_index = np.setdiff1d(all_index, train_index)

    out = _train_test_tensors(adata_mod1, adata_mod2, train_index, test_index)

    return (out['X1_train'], out['X2_train'], out['batch_train'],
            out['X1_test'], out['X2_test'], out['batch_test'],
            out['test_adata_mod1'])


def data_process_moETM_leave_one_batch(adata_mod1, adata_mod2, batch_index_as_test):
    """Hold out one batch for testing and train on all the others.

    Cells with ``obs['batch_indices'] == batch_index_as_test`` form the test
    set.  Within each split the batch ids are renumbered to ``0..k-1`` in
    ascending order of the original ids.

    Returns:
        ``(X_mod1_train, X_mod2_train, batch_train, X_mod1_test,
        X_mod2_test, batch_test, test_adata_mod1, train_adata_mod1)``.
        ``batch_train`` is moved to the GPU when CUDA is available.
    """
    batch_ids = np.asarray(adata_mod1.obs['batch_indices'])
    test_index = batch_ids == batch_index_as_test
    train_index = batch_ids != batch_index_as_test

    out = _train_test_tensors(adata_mod1, adata_mod2, train_index, test_index,
                              remap_batches=True)

    return (out['X1_train'], out['X2_train'], out['batch_train'],
            out['X1_test'], out['X2_test'], out['batch_test'],
            out['test_adata_mod1'], out['train_adata_mod1'])


def data_process_moETM_cross_prediction(adata_mod1, adata_mod2, n_sample):
    """Bootstrap split: train on a resample, test on the cells never drawn.

    ``n_sample`` indices are drawn with ``sklearn.utils.resample`` using its
    default sampling mode; indices that were never drawn, in ascending
    order, form the test set.

    Returns:
        ``(X_mod1_train, X_mod2_train, batch_train, X_mod1_test,
        X_mod2_test, batch_test, test_adata_mod1, train_adata_mod1, sum1,
        sum2)`` where ``sum1``/``sum2`` are the per-cell totals of the test
        matrices before normalization.  ``batch_train`` is moved to the GPU
        when CUDA is available.
    """
    from sklearn.utils import resample

    all_index = np.arange(0, n_sample)
    train_index = resample(all_index, n_samples=n_sample)
    test_index = np.setdiff1d(all_index, train_index)

    out = _train_test_tensors(adata_mod1, adata_mod2, train_index, test_index)

    return (out['X1_train'], out['X2_train'], out['batch_train'],
            out['X1_test'], out['X2_test'], out['batch_test'],
            out['test_adata_mod1'], out['train_adata_mod1'],
            out['test_sum1'], out['test_sum2'])


def load_nips_dataset_rna_protein_dataset(mod_file_path, gene_encoding, protein_encoding):
    """Load a paired GEX/ADT ``.h5ad`` file and return raw-count AnnData objects.

    Genes are kept when ``gene_encoding['is_gene_coding']`` is truthy and
    the variable name matches ``gene_encoding['X']`` at the same position,
    then restricted to scanpy's highly variable genes.  Proteins are kept
    when ``protein_encoding['is_protein_coding']`` is truthy and the
    variable name matches ``protein_encoding['X']``; no variability filter
    is applied to them.

    Returns:
        ``(adata_mod1, adata_mod2)`` for GEX and ADT respectively, with
        ``X`` holding the ``'counts'`` layer.
    """
    adata = ad.read_h5ad(mod_file_path)
    adata_mod1, adata_mod2 = _split_by_feature_type(adata, 'ADT')

    adata_mod1.X = adata_mod1.layers['counts']
    adata_mod2.X = adata_mod2.layers['counts']

    keep_genes = _coding_feature_mask(adata_mod1.var.index, gene_encoding,
                                      'X', 'is_gene_coding')
    adata_mod1 = adata_mod1[:, keep_genes].copy()
    adata_mod1, _ = _keep_highly_variable(adata_mod1)

    keep_proteins = _coding_feature_mask(adata_mod2.var.index, protein_encoding,
                                         'X', 'is_protein_coding')
    adata_mod2 = adata_mod2[:, keep_proteins].copy()

    return adata_mod1, adata_mod2


def load_nips_rna_atac_dataset_with_pathway(mod_file_path, gene_encoding, gene_pathway):
    """Load paired GEX/ATAC data together with a matching gene-pathway matrix.

    Works like ``load_nips_rna_atac_dataset`` but additionally requires each
    kept gene to have a non-zero column sum in ``gene_pathway``.  The
    columns of ``gene_pathway`` are filtered with the same masks as the
    genes, so the returned matrix stays aligned with ``adata_mod1``.

    Args:
        mod_file_path: path passed to ``anndata.read_h5ad``.
        gene_encoding: table with ``'gene_id'`` and ``'is_gene_coding'``
            columns, read positionally.
        gene_pathway: 2-D array indexed as ``[:, gene]``.
            NOTE(review): presumably pathways x genes in the GEX feature
            order of the file -- confirm against the caller.

    Returns:
        ``(adata_mod1, adata_mod2, gene_pathway)``.
    """
    adata = ad.read_h5ad(mod_file_path)
    gex, atac = _split_by_feature_type(adata, 'ATAC')
    del adata
    gc.collect()

    coding = _coding_feature_mask(gex.var['gene_id'], gene_encoding,
                                  'gene_id', 'is_gene_coding')
    in_pathway = np.asarray(gene_pathway.sum(0)).ravel().astype(bool)
    keep = coding & in_pathway

    adata_mod1 = gex[:, keep].copy()
    gene_pathway = gene_pathway[:, keep].copy()
    del gex
    gc.collect()

    adata_mod1.X = adata_mod1.layers['counts']
    atac.X = atac.layers['counts']

    adata_mod1, hvg_mask = _keep_highly_variable(adata_mod1)
    gene_pathway = gene_pathway[:, hvg_mask].copy()

    adata_mod2, _ = _keep_highly_variable(atac)

    return adata_mod1, adata_mod2, gene_pathway