|
a |
|
b/MultiOmiVAE.py |
|
|
1 |
import torch |
|
|
2 |
import numpy as np |
|
|
3 |
import pandas as pd |
|
|
4 |
from sklearn.model_selection import train_test_split |
|
|
5 |
from torch import nn, optim |
|
|
6 |
from torch.utils.data import Dataset, DataLoader |
|
|
7 |
from torch.nn import functional as F |
|
|
8 |
from torch.utils.tensorboard import SummaryWriter |
|
|
9 |
from earlystoping import Earlystopping |
|
|
10 |
from sklearn import metrics |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def MultiOmiVAE(input_path, expr_df, methy_chr_df_list, random_seed=42, no_cuda=False, model_parallelism=True,
                separate_testing=True, batch_size=32, latent_dim=128, learning_rate=1e-3, p1_epoch_num=50,
                p2_epoch_num=100, output_loss_record=True, classifier=True, early_stopping=True):
    """Build and train a multi-omics VAE on gene expression plus per-chromosome DNA methylation.

    Args:
        input_path: directory containing 'both_samples.tsv' (sample IDs) and
            'both_samples_tumour_type_digit.tsv' (integer tumour-type labels).
        expr_df: gene-expression DataFrame, features x samples.
        methy_chr_df_list: list of 23 methylation DataFrames (features x samples),
            one per chromosome (1-22 plus X).
        random_seed: seed for torch RNGs and the stratified splits.
        no_cuda: force CPU even when CUDA is available.
        model_parallelism: split the model across two GPUs when >1 is visible.
        separate_testing: hold out stratified validation and test sets.
        batch_size, latent_dim, learning_rate: usual hyper-parameters.
        p1_epoch_num, p2_epoch_num: epoch counts of the two training phases
            (used below to size the loss record; the phase loop itself is
            further down the file).
        output_loss_record, classifier, early_stopping: feature switches.
    """

    # Reproducibility: seed the CPU RNG and every CUDA device RNG.
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)

    device = torch.device('cuda:0' if not no_cuda and torch.cuda.is_available() else 'cpu')
    # Model parallelism only activates with at least two visible GPUs.
    parallel = torch.cuda.device_count() > 1 and model_parallelism

    # Sample ID and order that has both gene expression and DNA methylation data
    sample_id = np.loadtxt(input_path + 'both_samples.tsv', delimiter='\t', dtype='str')

    # Loading label
    label = pd.read_csv(input_path + 'both_samples_tumour_type_digit.tsv', sep='\t', header=0, index_col=0)
    class_num = len(label.tumour_type.unique())
    label_array = label['tumour_type'].values

    if separate_testing:
        # Get testing set index and training set index
        # Separate according to different tumour types (stratified twice:
        # 20% held out, then split in half -> 10% validation / 10% test)
        testset_ratio = 0.2
        valset_ratio = 0.5

        train_index, test_index, train_label, test_label = train_test_split(sample_id, label_array,
                                                                            test_size=testset_ratio,
                                                                            random_state=random_seed,
                                                                            stratify=label_array)
        val_index, test_index, val_label, test_label = train_test_split(test_index, test_label, test_size=valset_ratio,
                                                                        random_state=random_seed, stratify=test_label)

        # Column-select each omics table by sample ID (samples are columns).
        expr_df_test = expr_df[test_index]
        expr_df_val = expr_df[val_index]
        expr_df_train = expr_df[train_index]

        methy_chr_df_test_list = []
        methy_chr_df_val_list = []
        methy_chr_df_train_list = []
        for chrom_index in range(0, 23):
            methy_chr_df_test = methy_chr_df_list[chrom_index][test_index]
            methy_chr_df_test_list.append(methy_chr_df_test)
            methy_chr_df_val = methy_chr_df_list[chrom_index][val_index]
            methy_chr_df_val_list.append(methy_chr_df_val)
            methy_chr_df_train = methy_chr_df_list[chrom_index][train_index]
            methy_chr_df_train_list.append(methy_chr_df_train)

    # Get multi-omics dataset information
    sample_num = len(sample_id)
    expr_feature_num = expr_df.shape[0]
    methy_feature_num_list = []
    for chrom_index in range(0, 23):
        feature_num = methy_chr_df_list[chrom_index].shape[0]
        methy_feature_num_list.append(feature_num)
    methy_feature_num_array = np.array(methy_feature_num_list)
    methy_feature_num = methy_feature_num_array.sum()
    print('\nNumber of samples: {}'.format(sample_num))
    print('Number of gene expression features: {}'.format(expr_feature_num))
    print('Number of methylation features: {}'.format(methy_feature_num))
    if classifier:
        print('Number of classes: {}'.format(class_num))
|
|
73 |
|
|
|
74 |
class MultiOmiDataset(Dataset):
    """Dataset of matched multi-omics profiles.

    Each item is ``[omics_data, label]`` where ``omics_data`` is a list of
    tensors: gene expression at index 0 followed by one methylation tensor
    per chromosome DataFrame (indices 1..len(methy_df_list)).

    Generalised from the original hard-coded 23-chromosome loop: it now
    iterates over ``methy_df_list`` directly, so any number of methylation
    tables works; behaviour is identical for the standard 23-entry input.
    """

    def __init__(self, expr_df, methy_df_list, labels):
        # expr_df: features x samples DataFrame (samples are columns)
        # methy_df_list: list of features x samples DataFrames, one per chromosome
        # labels: array-like of integer labels in the same sample order
        self.expr_df = expr_df
        self.methy_df_list = methy_df_list
        self.labels = labels

    def __len__(self):
        # Samples are the columns of the expression table.
        return self.expr_df.shape[1]

    def __getitem__(self, index):
        omics_data = []
        # Gene expression tensor at position 0
        expr_line = self.expr_df.iloc[:, index].values
        omics_data.append(torch.Tensor(expr_line))
        # One methylation tensor per chromosome at positions 1..N
        for methy_df in self.methy_df_list:
            methy_chr_line = methy_df.iloc[:, index].values
            omics_data.append(torch.Tensor(methy_chr_line))
        label = self.labels[index]
        return [omics_data, label]
|
|
100 |
|
|
|
101 |
# DataSets and DataLoaders
if separate_testing:
    train_dataset = MultiOmiDataset(expr_df=expr_df_train, methy_df_list=methy_chr_df_train_list, labels=train_label)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
    val_dataset = MultiOmiDataset(expr_df=expr_df_val, methy_df_list=methy_chr_df_val_list, labels=val_label)
    # NOTE(review): shuffle=True on the validation and test loaders is
    # unusual — it does not change loss/accuracy but obscures per-sample
    # comparisons; confirm intentional.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
    test_dataset = MultiOmiDataset(expr_df=expr_df_test, methy_df_list=methy_chr_df_test_list, labels=test_label)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
else:
    train_dataset = MultiOmiDataset(expr_df=expr_df, methy_df_list=methy_chr_df_list, labels=label_array)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
# Unshuffled loader over the full cohort (consumer is outside this chunk —
# presumably for exporting latent embeddings; verify against later code).
full_dataset = MultiOmiDataset(expr_df=expr_df, methy_df_list=methy_chr_df_list, labels=label_array)
full_loader = DataLoader(full_dataset, batch_size=batch_size, num_workers=6)

# Setting dimensions of the encoder/decoder levels
latent_space_dim = latent_dim
input_dim_expr = expr_feature_num
input_dim_methy_array = methy_feature_num_array  # per-chromosome input sizes
level_2_dim_expr = 4096
level_2_dim_methy = 256  # per-chromosome hidden width (x23 after concat)
level_3_dim_expr = 1024
level_3_dim_methy = 1024
level_4_dim = 512
classifier_1_dim = 128
classifier_2_dim = 64
classifier_out_dim = class_num
|
|
127 |
|
|
|
128 |
class VAE(nn.Module):
    """Multi-omics VAE with an attached tumour-type classifier head.

    Encoder: 23 per-chromosome methylation input heads plus one expression
    head are compressed level by level into Gaussian latent parameters
    (mean, log-variance). The decoder mirrors the encoder and reconstructs
    every input tensor with sigmoid outputs. The classifier is a small MLP
    applied to the latent mean.

    When ``parallel`` (closure variable) is set, the encoder lives on
    cuda:0 and the decoder/classifier on cuda:1 (two-GPU model parallelism).

    The original hand-written per-chromosome layer lists (~180 duplicated
    lines) are replaced with loops. Attribute names (e_fc1_methy_1 ...
    e_fc1_methy_X, d_fc1_methy_* etc.) and the layer-creation order are
    preserved exactly, so seeded initialisation and checkpoint/state_dict
    compatibility are unchanged.
    """

    # Suffixes of the per-chromosome layer attributes, in input order
    # (chromosomes 1-22 followed by X).
    CHROM_NAMES = [str(n) for n in range(1, 23)] + ['X']

    def __init__(self):
        super(VAE, self).__init__()
        # ENCODER fc layers
        # Level 1: one input head per methylation chromosome. Created in the
        # same order as the original explicit list for RNG reproducibility.
        for chrom_i, suffix in enumerate(self.CHROM_NAMES):
            setattr(self, 'e_fc1_methy_{}'.format(suffix),
                    self.fc_layer(input_dim_methy_array[chrom_i], level_2_dim_methy))

        # Expr input head
        self.e_fc1_expr = self.fc_layer(input_dim_expr, level_2_dim_expr)

        # Level 2: all chromosome heads are concatenated before this layer.
        self.e_fc2_methy = self.fc_layer(level_2_dim_methy*23, level_3_dim_methy)
        self.e_fc2_expr = self.fc_layer(level_2_dim_expr, level_3_dim_expr)

        # Level 3: the methylation and expression branches merge here.
        self.e_fc3 = self.fc_layer(level_3_dim_methy+level_3_dim_expr, level_4_dim)

        # Level 4: linear heads (no activation) for the Gaussian parameters.
        self.e_fc4_mean = self.fc_layer(level_4_dim, latent_space_dim, activation=0)
        self.e_fc4_log_var = self.fc_layer(level_4_dim, latent_space_dim, activation=0)

        # model parallelism: whole encoder on the first GPU
        if parallel:
            for suffix in self.CHROM_NAMES:
                getattr(self, 'e_fc1_methy_{}'.format(suffix)).to('cuda:0')
            self.e_fc1_expr.to('cuda:0')
            self.e_fc2_methy.to('cuda:0')
            self.e_fc2_expr.to('cuda:0')
            self.e_fc3.to('cuda:0')
            self.e_fc4_mean.to('cuda:0')
            self.e_fc4_log_var.to('cuda:0')

        # DECODER fc layers (mirror of the encoder)
        self.d_fc4 = self.fc_layer(latent_space_dim, level_4_dim)
        self.d_fc3 = self.fc_layer(level_4_dim, level_3_dim_methy+level_3_dim_expr)
        self.d_fc2_methy = self.fc_layer(level_3_dim_methy, level_2_dim_methy*23)
        self.d_fc2_expr = self.fc_layer(level_3_dim_expr, level_2_dim_expr)
        # Level 1: sigmoid output heads, one per chromosome, matching each
        # chromosome's input dimension.
        for chrom_i, suffix in enumerate(self.CHROM_NAMES):
            setattr(self, 'd_fc1_methy_{}'.format(suffix),
                    self.fc_layer(level_2_dim_methy, input_dim_methy_array[chrom_i], activation=2))
        # Expr output head
        self.d_fc1_expr = self.fc_layer(level_2_dim_expr, input_dim_expr, activation=2)

        # model parallelism: whole decoder on the second GPU
        if parallel:
            self.d_fc4.to('cuda:1')
            self.d_fc3.to('cuda:1')
            self.d_fc2_methy.to('cuda:1')
            self.d_fc2_expr.to('cuda:1')
            for suffix in self.CHROM_NAMES:
                getattr(self, 'd_fc1_methy_{}'.format(suffix)).to('cuda:1')
            self.d_fc1_expr.to('cuda:1')

        # CLASSIFIER fc layers (applied to the latent mean)
        self.c_fc1 = self.fc_layer(latent_space_dim, classifier_1_dim)
        self.c_fc2 = self.fc_layer(classifier_1_dim, classifier_2_dim)
        self.c_fc3 = self.fc_layer(classifier_2_dim, classifier_out_dim, activation=0)

        # model parallelism: classifier shares the decoder GPU
        if parallel:
            self.c_fc1.to('cuda:1')
            self.c_fc2.to('cuda:1')
            self.c_fc3.to('cuda:1')

    # Activation - 0: no activation, 1: ReLU, 2: Sigmoid
    def fc_layer(self, in_dim, out_dim, activation=1, dropout=False, dropout_p=0.5):
        """Linear + BatchNorm1d block with the selected activation.

        ``dropout`` only applies in the ReLU (default) configuration.
        """
        if activation == 0:
            layer = nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim))
        elif activation == 2:
            layer = nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                nn.Sigmoid())
        else:
            if dropout:
                layer = nn.Sequential(
                    nn.Linear(in_dim, out_dim),
                    nn.BatchNorm1d(out_dim),
                    nn.ReLU(),
                    nn.Dropout(p=dropout_p))
            else:
                layer = nn.Sequential(
                    nn.Linear(in_dim, out_dim),
                    nn.BatchNorm1d(out_dim),
                    nn.ReLU())
        return layer

    def encode(self, x):
        """Map the 24 input tensors (x[0] expression, x[1:] methylation in
        chromosome order) to the latent Gaussian parameters."""
        # Per-chromosome level-2 activations, concatenated in chromosome order
        # (methy tensors occupy x[1]..x[23]).
        methy_level2_parts = [
            getattr(self, 'e_fc1_methy_{}'.format(suffix))(x[chrom_i + 1])
            for chrom_i, suffix in enumerate(self.CHROM_NAMES)]
        methy_level2_layer = torch.cat(methy_level2_parts, 1)

        expr_level2_layer = self.e_fc1_expr(x[0])

        methy_level3_layer = self.e_fc2_methy(methy_level2_layer)
        expr_level3_layer = self.e_fc2_expr(expr_level2_layer)
        level_3_layer = torch.cat((methy_level3_layer, expr_level3_layer), 1)

        level_4_layer = self.e_fc3(level_3_layer)

        latent_mean = self.e_fc4_mean(level_4_layer)
        latent_log_var = self.e_fc4_log_var(level_4_layer)

        return latent_mean, latent_log_var

    def reparameterize(self, mean, log_var):
        """Sample z = mean + sigma * eps (reparameterization trick)."""
        sigma = torch.exp(0.5 * log_var)
        eps = torch.randn_like(sigma)
        return mean + eps * sigma

    def decode(self, z):
        """Reconstruct all 24 tensors from a latent sample: index 0 is the
        expression reconstruction, 1-23 the per-chromosome methylation."""
        level_4_layer = self.d_fc4(z)

        level_3_layer = self.d_fc3(level_4_layer)
        # Split the merged level-3 activations back into the two branches.
        methy_level3_layer = level_3_layer.narrow(1, 0, level_3_dim_methy)
        expr_level3_layer = level_3_layer.narrow(1, level_3_dim_methy, level_3_dim_expr)

        methy_level2_layer = self.d_fc2_methy(methy_level3_layer)
        expr_level2_layer = self.d_fc2_expr(expr_level3_layer)

        recon_x = [self.d_fc1_expr(expr_level2_layer)]
        # Slice the shared methylation activations into equal per-chromosome
        # chunks and run each through its own sigmoid output head.
        for chrom_i, suffix in enumerate(self.CHROM_NAMES):
            chunk = methy_level2_layer.narrow(1, level_2_dim_methy * chrom_i, level_2_dim_methy)
            recon_x.append(getattr(self, 'd_fc1_methy_{}'.format(suffix))(chunk))
        return recon_x

    def classifier(self, mean):
        """Predict tumour-type logits from the latent mean."""
        level_1_layer = self.c_fc1(mean)
        level_2_layer = self.c_fc2(level_1_layer)
        output_layer = self.c_fc3(level_2_layer)
        return output_layer

    def forward(self, x):
        mean, log_var = self.encode(x)
        z = self.reparameterize(mean, log_var)
        # The classifier reads the (deterministic) latent mean, not z.
        classifier_x = mean
        if parallel:
            # Hand over to the decoder/classifier GPU.
            z = z.to('cuda:1')
            classifier_x = classifier_x.to('cuda:1')
        recon_x = self.decode(z)
        pred_y = self.classifier(classifier_x)
        return z, recon_x, mean, log_var, pred_y
|
|
450 |
|
|
|
451 |
# Instantiate VAE
if parallel:
    # Sub-modules were already placed on cuda:0/cuda:1 inside VAE.__init__,
    # so no whole-model .to(device) here.
    vae_model = VAE()
else:
    vae_model = VAE().to(device)

# Early Stopping
if early_stopping:
    early_stop_ob = Earlystopping()

# Tensorboard writers — separate log dirs so train/val curves overlay.
train_writer = SummaryWriter(log_dir='logs/train')
val_writer = SummaryWriter(log_dir='logs/val')

# print the model information
# print('\nModel information:')
# print(vae_model)
total_params = sum(params.numel() for params in vae_model.parameters())
print('Number of parameters: {}'.format(total_params))

optimizer = optim.Adam(vae_model.parameters(), lr=learning_rate)
|
|
472 |
|
|
|
473 |
def methy_recon_loss(recon_x, x):
    """Mean per-chromosome methylation reconstruction loss.

    Sums binary cross-entropy over every methylation entry (indices
    1..len(x)-1; index 0 is gene expression) and averages over the number
    of methylation tensors. Generalised from the original hard-coded
    23-chromosome version: with the standard 24-entry input the result is
    unchanged (sum over indices 1-23 divided by 23).

    Args:
        recon_x: list of reconstructed tensors (expr at 0, methy after).
        x: matching list of target tensors with values in [0, 1].
    """
    methy_count = len(x) - 1  # entry 0 is the expression tensor
    loss = sum(F.binary_cross_entropy(recon_x[i], x[i], reduction='sum')
               for i in range(1, len(x)))
    return loss / methy_count
|
|
479 |
|
|
|
480 |
def expr_recon_loss(recon_x, x):
    """Summed binary cross-entropy between the reconstructed and original
    gene-expression tensors (index 0 of each list)."""
    return F.binary_cross_entropy(recon_x[0], x[0], reduction='sum')
|
|
483 |
|
|
|
484 |
def kl_loss(mean, log_var):
    """KL divergence between N(mean, exp(log_var)) and the standard normal
    prior, summed over batch and latent dimensions.

    Algebraically identical to -0.5 * sum(1 + log_var - mean^2 - exp(log_var)).
    """
    return 0.5 * torch.sum(mean.pow(2) + log_var.exp() - log_var - 1)
|
|
487 |
|
|
|
488 |
def classifier_loss(pred_y, y):
    """Summed cross-entropy between classifier logits and integer labels."""
    return F.cross_entropy(pred_y, y, reduction='sum')
|
|
491 |
|
|
|
492 |
# k_methy_recon = 1
# k_expr_recon = 1
# k_kl = 1
# k_class = 1

# loss record: one row per tracked quantity, one column per epoch across
# both training phases (plus one spare column).
loss_array = np.zeros(shape=(11, p1_epoch_num+p2_epoch_num+1))
# performance metrics buffer — indices 0 (accuracy) and 1 (weighted
# precision) are filled in this chunk; NOTE(review): confirm the layout of
# the remaining slots against the code past this view.
metrics_array = np.zeros(4)
|
|
501 |
|
|
|
502 |
def train(e_index, e_num, k_methy_recon, k_expr_recon, k_kl, k_c):
    """Run one training epoch over train_loader and log per-epoch averages.

    Args:
        e_index: zero-based epoch index (used for the loss record and
            TensorBoard x-axis).
        e_num: total epoch count (display only).
        k_methy_recon, k_expr_recon, k_kl, k_c: weights of the four loss
            terms in the total objective.
    """
    vae_model.train()
    # Per-epoch accumulators: summed batch losses, normalised per sample
    # at the end of the epoch.
    train_methy_recon = 0
    train_expr_recon = 0
    train_kl = 0
    train_classifier = 0
    train_correct_num = 0
    train_total_loss = 0
    for batch_index, sample in enumerate(train_loader):
        data = sample[0]  # [expr tensor, 23 methylation tensors]
        y = sample[1]
        # Move every omics tensor (index 0 expr, 1-23 methy) to the device.
        for chr_i in range(24):
            data[chr_i] = data[chr_i].to(device)
        y = y.to(device)
        optimizer.zero_grad()
        _, recon_data, mean, log_var, pred_y = vae_model(data)
        if parallel:
            # Decoder/classifier outputs live on cuda:1; bring them back
            # to cuda:0 where the inputs (the loss targets) are.
            for chr_i in range(24):
                recon_data[chr_i] = recon_data[chr_i].to('cuda:0')
            pred_y = pred_y.to('cuda:0')

        methy_recon = methy_recon_loss(recon_data, data)
        expr_recon = expr_recon_loss(recon_data, data)
        kl = kl_loss(mean, log_var)
        class_loss = classifier_loss(pred_y, y)
        # Weighted sum of the four objectives.
        loss = k_methy_recon * methy_recon + k_expr_recon * expr_recon + k_kl * kl + k_c * class_loss

        loss.backward()

        with torch.no_grad():
            # Batch accuracy from the classifier head.
            pred_y_softmax = F.softmax(pred_y, dim=1)
            _, predicted = torch.max(pred_y_softmax, 1)
            correct = (predicted == y).sum().item()

        train_methy_recon += methy_recon.item()
        train_expr_recon += expr_recon.item()
        train_kl += kl.item()
        train_classifier += class_loss.item()
        train_correct_num += correct
        train_total_loss += loss.item()

        optimizer.step()

        # if batch_index % 15 == 0:
        #     print('Epoch {:3d}/{:3d} --- [{:5d}/{:5d}] ({:2d}%)\n'
        #           'Methy Recon Loss: {:.2f} Expr Recon Loss: {:.2f} KL Loss: {:.2f} '
        #           'Classification Loss: {:.2f}\nACC: {:.2f}%'.format(
        #         e_index + 1, e_num, batch_index * len(data[0]), len(train_dataset),
        #         round(100. * batch_index / len(train_loader)), methy_recon.item() / len(data[0]),
        #         expr_recon.item() / len(data[0]), kl.item() / len(data[0]), class_loss.item() / len(data[0]),
        #         correct / len(data[0]) * 100))

    # Per-sample averages over the whole epoch.
    train_methy_recon_ave = train_methy_recon / len(train_dataset)
    train_expr_recon_ave = train_expr_recon / len(train_dataset)
    train_kl_ave = train_kl / len(train_dataset)
    train_classifier_ave = train_classifier / len(train_dataset)
    train_accuracy = train_correct_num / len(train_dataset) * 100
    train_total_loss_ave = train_total_loss / len(train_dataset)

    print('Epoch {:3d}/{:3d}\n'
          'Training\n'
          'Methy Recon Loss: {:.2f} Expr Recon Loss: {:.2f} KL Loss: {:.2f} '
          'Classification Loss: {:.2f}\nACC: {:.2f}%'.
          format(e_index + 1, e_num, train_methy_recon_ave, train_expr_recon_ave, train_kl_ave,
                 train_classifier_ave, train_accuracy))
    # Record this epoch in the loss history (rows 0-4 = training metrics).
    loss_array[0, e_index] = train_methy_recon_ave
    loss_array[1, e_index] = train_expr_recon_ave
    loss_array[2, e_index] = train_kl_ave
    loss_array[3, e_index] = train_classifier_ave
    loss_array[4, e_index] = train_accuracy

    # TensorBoard scalars
    train_writer.add_scalar('Total loss', train_total_loss_ave, e_index)
    train_writer.add_scalar('Methy recon loss', train_methy_recon_ave, e_index)
    train_writer.add_scalar('Expr recon loss', train_expr_recon_ave, e_index)
    train_writer.add_scalar('KL loss', train_kl_ave, e_index)
    train_writer.add_scalar('Classification loss', train_classifier_ave, e_index)
    train_writer.add_scalar('Accuracy', train_accuracy, e_index)
|
|
580 |
|
|
|
581 |
if separate_testing:
    def val(e_index, get_metrics=False):
        """Evaluate the model on the validation set for one epoch.

        Accumulates the four loss terms and classification accuracy over
        ``val_loader``, prints a summary, writes per-sample averages into
        ``loss_array`` rows 5-9 at column ``e_index`` and into the
        validation TensorBoard writer. When ``get_metrics`` is True, also
        fills ``metrics_array`` with accuracy / precision / recall / F1.

        Returns:
            (val_accuracy, output_pred_y): accuracy in percent and the
            predicted labels for the whole validation set as a NumPy array.
        """
        vae_model.eval()
        # Running sums over all validation batches; averaged per sample below.
        val_methy_recon = 0
        val_expr_recon = 0
        val_kl = 0
        val_classifier = 0
        val_correct_num = 0
        val_total_loss = 0
        # Seeded with one dummy element so torch.cat always has a base;
        # the dummy is stripped with [1:] after the loop.
        y_store = torch.tensor([0])
        predicted_store = torch.tensor([0])

        with torch.no_grad():
            for batch_index, sample in enumerate(val_loader):
                data = sample[0]
                y = sample[1]
                # Move the 24 per-chromosome tensors to the device
                # (presumably chr1-22 + X + Y — TODO confirm).
                for chr_i in range(24):
                    data[chr_i] = data[chr_i].to(device)
                y = y.to(device)
                _, recon_data, mean, log_var, pred_y = vae_model(data)
                if parallel:
                    # With model parallelism the outputs may live on another
                    # GPU; gather everything on cuda:0 before computing losses.
                    for chr_i in range(24):
                        recon_data[chr_i] = recon_data[chr_i].to('cuda:0')
                    pred_y = pred_y.to('cuda:0')

                # NOTE(review): both reconstruction losses receive the same
                # (recon_data, data) pair — presumably each helper extracts
                # its own modality internally; confirm against their defs.
                methy_recon = methy_recon_loss(recon_data, data)
                expr_recon = expr_recon_loss(recon_data, data)
                kl = kl_loss(mean, log_var)
                class_loss = classifier_loss(pred_y, y)
                loss = methy_recon + expr_recon + kl + class_loss

                # Predicted class = argmax of the softmax over class logits.
                pred_y_softmax = F.softmax(pred_y, dim=1)
                _, predicted = torch.max(pred_y_softmax, 1)
                correct = (predicted == y).sum().item()

                y_store = torch.cat((y_store, y.cpu()))
                predicted_store = torch.cat((predicted_store, predicted.cpu()))

                val_methy_recon += methy_recon.item()
                val_expr_recon += expr_recon.item()
                val_kl += kl.item()
                val_classifier += class_loss.item()
                val_correct_num += correct
                val_total_loss += loss.item()

        # Drop the dummy leading element added at initialisation.
        output_y = y_store[1:].numpy()
        output_pred_y = predicted_store[1:].numpy()

        if get_metrics:
            # Final-evaluation metrics; 'weighted' averaging accounts for
            # class imbalance across tumour types.
            metrics_array[0] = metrics.accuracy_score(output_y, output_pred_y)
            metrics_array[1] = metrics.precision_score(output_y, output_pred_y, average='weighted')
            metrics_array[2] = metrics.recall_score(output_y, output_pred_y, average='weighted')
            metrics_array[3] = metrics.f1_score(output_y, output_pred_y, average='weighted')

        # Per-sample averages over the whole validation set.
        val_methy_recon_ave = val_methy_recon / len(val_dataset)
        val_expr_recon_ave = val_expr_recon / len(val_dataset)
        val_kl_ave = val_kl / len(val_dataset)
        val_classifier_ave = val_classifier / len(val_dataset)
        val_accuracy = val_correct_num / len(val_dataset) * 100
        val_total_loss_ave = val_total_loss / len(val_dataset)

        print('Validation\n'
              'Methy Recon Loss: {:.2f} Expr Recon Loss: {:.2f} KL Loss: {:.2f} Classification Loss: {:.2f}'
              '\nACC: {:.2f}%\n'.
              format(val_methy_recon_ave, val_expr_recon_ave, val_kl_ave, val_classifier_ave, val_accuracy))
        # Rows 5-9 of the loss record hold the validation-side curves.
        loss_array[5, e_index] = val_methy_recon_ave
        loss_array[6, e_index] = val_expr_recon_ave
        loss_array[7, e_index] = val_kl_ave
        loss_array[8, e_index] = val_classifier_ave
        loss_array[9, e_index] = val_accuracy

        # TensorBoard scalars: one point per epoch on the validation writer.
        val_writer.add_scalar('Total loss', val_total_loss_ave, e_index)
        val_writer.add_scalar('Methy recon loss', val_methy_recon_ave, e_index)
        val_writer.add_scalar('Expr recon loss', val_expr_recon_ave, e_index)
        val_writer.add_scalar('KL loss', val_kl_ave, e_index)
        val_writer.add_scalar('Classification loss', val_classifier_ave, e_index)
        val_writer.add_scalar('Accuracy', val_accuracy, e_index)

        return val_accuracy, output_pred_y
|
|
661 |
|
|
|
662 |
# Phase 1 (unsupervised): train reconstruction + KL only; the
# classification term is switched off via k_c=0.
print('\nUNSUPERVISED PHASE\n')
for ep in range(p1_epoch_num):
    train(e_index=ep, e_num=p1_epoch_num + p2_epoch_num,
          k_methy_recon=1, k_expr_recon=1, k_kl=1, k_c=0)
    if separate_testing:
        # Track validation performance during phase 1 as well.
        _, out_pred_y = val(ep)
|
|
668 |
|
|
|
669 |
# Phase 2 (supervised): only the classification loss is active (k_c=1);
# the VAE terms are frozen out of the objective via zero weights.
print('\nSUPERVISED PHASE\n')
epoch_number = p1_epoch_num
last_epoch_index = p1_epoch_num + p2_epoch_num - 1
for epoch_index in range(p1_epoch_num, p1_epoch_num + p2_epoch_num):
    epoch_number += 1
    train(e_index=epoch_index, e_num=p1_epoch_num + p2_epoch_num,
          k_methy_recon=0, k_expr_recon=0, k_kl=0, k_c=1)
    if separate_testing:
        # Compute the full metric set only on the final scheduled epoch.
        val_classification_acc, out_pred_y = val(
            epoch_index, get_metrics=(epoch_index == last_epoch_index))
        if early_stopping:
            # The early stopper checkpoints the model and raises stop_now
            # once validation accuracy stops improving.
            early_stop_ob(vae_model, val_classification_acc)
            if early_stop_ob.stop_now:
                print('Early stopping\n')
                break
|
|
685 |
|
|
|
686 |
if early_stopping:
    # Convert the early-stopper's best epoch (counted within the supervised
    # phase) into an absolute epoch number across both phases.
    best_epoch = p1_epoch_num + early_stop_ob.best_epoch_num
    # Row 10 of the loss record stores the selected epoch (scalar in col 0).
    loss_array[10, 0] = best_epoch
    print('Load model of Epoch {:d}'.format(best_epoch))
    # Restore the best weights checkpointed during training.
    # NOTE(review): path is hard-coded to '../ssd/checkpoint.pt' — must match
    # the save path used inside the Earlystopping helper; confirm.
    vae_model.load_state_dict(torch.load('../ssd/checkpoint.pt'))
    # Re-run validation with the restored weights to record final metrics.
    # NOTE(review): if the loop ran to completion, epoch_number equals
    # p1_epoch_num + p2_epoch_num, i.e. one past the last epoch column —
    # verify loss_array is wide enough, otherwise this indexes out of bounds.
    _, out_pred_y = val(epoch_number, get_metrics=True)
|
|
692 |
|
|
|
693 |
# Encode every sample (train + validation + test) into the latent space.
print('Encoding all the data into latent space...')
vae_model.eval()
with torch.no_grad():
    # Collect each batch's latent codes and concatenate ONCE after the loop.
    # The original concatenated inside the loop (O(n^2) total copying) and
    # needed a dummy zero row that was stripped with [1:] afterwards.
    latent_batches = []
    for sample in full_loader:
        d = sample[0]
        # Move the 24 per-chromosome tensors to the device
        # (presumably chr1-22 + X + Y — TODO confirm).
        for chr_i in range(24):
            d[chr_i] = d[chr_i].to(device)
        # Third return value is the latent code used downstream
        # (matches the `mean` position unpacked in val()).
        _, _, d_z, _, _ = vae_model(d)
        latent_batches.append(d_z)
    if latent_batches:
        all_data_z = torch.cat(latent_batches, 0)
    else:
        # Empty loader: keep the original's "no rows" result shape.
        all_data_z = torch.zeros(0, latent_dim, device=device)
all_data_z_np = all_data_z.cpu().numpy()
|
|
706 |
|
|
|
707 |
# Persist the latent embedding of every sample as a TSV, indexed by sample ID.
print('Preparing the output files... ')
# NOTE: if input_path ends with '/', the last split component is '' and the
# output names collapse to 'results/<dim>D_...' — kept as-is for compatibility.
input_path_name = input_path.split('/')[-1]
latent_space_path = f'results/{input_path_name}{latent_dim}D_latent_space.tsv'
all_data_z_df = pd.DataFrame(all_data_z_np, index=sample_id)
all_data_z_df.to_csv(latent_space_path, sep='\t')
|
|
715 |
|
|
|
716 |
if separate_testing:
    # Predicted labels for the held-out validation samples.
    pred_y_path = f'results/{input_path_name}{latent_dim}D_pred_y.tsv'
    np.savetxt(pred_y_path, out_pred_y, delimiter='\t')

    # Accuracy / precision / recall / F1 filled by val(get_metrics=True).
    metrics_record_path = f'results/{input_path_name}{latent_dim}D_metrics.tsv'
    np.savetxt(metrics_record_path, metrics_array, delimiter='\t')
|
|
722 |
|
|
|
723 |
if output_loss_record: |
|
|
724 |
loss_record_path = 'results/' + input_path_name + str(latent_dim) + 'D_loss_record.tsv' |
|
|
725 |
np.savetxt(loss_record_path, loss_array, delimiter='\t') |
|
|
726 |
|
|
|
727 |
return all_data_z_df |