Switch to unified view

a b/main_cross_prediction_rna_protein.py
1
2
3
from moETM.train import Trainer_moETM_for_cross_prediction, Train_moETM_for_cross_prediction
4
from dataloader import load_nips_dataset_rna_protein_dataset, prepare_nips_dataset, data_process_moETM_cross_prediction
5
from moETM.build_model import build_moETM
6
import pandas as pd
7
import gc
8
import os
9
import numpy as np
10
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
11
12
import warnings
13
warnings.filterwarnings('ignore')
14
15
# Load dataset
16
mod_file_path = "./data/GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad"
17
gene_encoding = pd.read_csv('./useful_file/gene_coding_nips_rna_protein.csv')
18
protein_encoding = pd.read_csv('./useful_file/protein_coding_nips_rna_protein.csv')
19
20
adata_mod1, adata_mod2 = load_nips_dataset_rna_protein_dataset(mod_file_path, gene_encoding, protein_encoding)
21
gc.collect()
22
23
# Prepare dataset
24
adata_mod1, adata_mod2 = prepare_nips_dataset(adata_mod1, adata_mod2)
25
26
n_total_sample = adata_mod1.shape[0]
27
28
X_mod1_train_T, X_mod2_train_T, batch_index_train_T, X_mod1_test_T, X_mod2_test_T, batch_index_test_T, test_adata_mod1, train_adata_mod1, test_mod1_sum, test_mod2_sum= data_process_moETM_cross_prediction(adata_mod1, adata_mod2, n_sample=np.int(np.floor(n_total_sample*0.8)))
29
30
num_batch = len(batch_index_train_T.unique())
31
input_dim_mod1 = X_mod1_train_T.shape[1]
32
input_dim_mod2 = X_mod2_train_T.shape[1]
33
train_num = X_mod1_train_T.shape[0]
34
35
num_topic = 200
36
emd_dim = 400
37
encoder_mod1, encoder_mod2, decoder, optimizer = build_moETM(input_dim_mod1, input_dim_mod2, num_batch, num_topic=num_topic, emd_dim=emd_dim)
38
39
direction = 'another_to_rna'   # Or another_to_rna
40
trainer = Trainer_moETM_for_cross_prediction(encoder_mod1, encoder_mod2, decoder, optimizer, direction)
41
42
Total_epoch = 500
43
batch_size = 2000
44
Train_set = [X_mod1_train_T, X_mod2_train_T, batch_index_train_T]
45
Test_set = [X_mod1_test_T, X_mod2_test_T, batch_index_test_T, test_adata_mod1, test_mod1_sum, test_mod2_sum]
46
Train_moETM_for_cross_prediction(trainer, Total_epoch, train_num, batch_size, Train_set, Test_set)
47