util/visualizer.py

import os
import time
import numpy as np
import pandas as pd
import sklearn as sk
import sklearn.metrics  # explicit submodule import: "import sklearn" alone does not load sklearn.metrics, which the sk.metrics.* calls below rely on
from sklearn.preprocessing import label_binarize
from util import util
from util import metrics
from torch.utils.tensorboard import SummaryWriter


class Visualizer:
    """
    This class prints and saves logging information.
    """
|
|
    def __init__(self, param):
        """
        Initialize the Visualizer class
        """
        self.param = param
        self.output_path = os.path.join(param.checkpoints_dir, param.experiment_name)
        tb_dir = os.path.join(self.output_path, 'tb_log')
        util.mkdir(tb_dir)

        if param.isTrain:
            # Create a logging file to store training losses
            self.train_log_filename = os.path.join(self.output_path, 'train_log.txt')
            with open(self.train_log_filename, 'a') as log_file:
                now = time.strftime('%c')
                log_file.write('----------------------- Training Log ({:s}) -----------------------\n'.format(now))

            self.train_summary_filename = os.path.join(self.output_path, 'train_summary.txt')
            with open(self.train_summary_filename, 'a') as log_file:
                now = time.strftime('%c')
                log_file.write('----------------------- Training Summary ({:s}) -----------------------\n'.format(now))

            # Create log folder for TensorBoard
            tb_train_dir = os.path.join(self.output_path, 'tb_log', 'train')
            util.mkdir(tb_train_dir)
            util.clear_dir(tb_train_dir)

            # Create TensorBoard writer
            self.train_writer = SummaryWriter(log_dir=tb_train_dir)

        if param.isTest:
            # Create a logging file to store testing metrics
            self.test_log_filename = os.path.join(self.output_path, 'test_log.txt')
            with open(self.test_log_filename, 'a') as log_file:
                now = time.strftime('%c')
                log_file.write('----------------------- Testing Log ({:s}) -----------------------\n'.format(now))

            self.test_summary_filename = os.path.join(self.output_path, 'test_summary.txt')
            with open(self.test_summary_filename, 'a') as log_file:
                now = time.strftime('%c')
                log_file.write('----------------------- Testing Summary ({:s}) -----------------------\n'.format(now))

            # Create log folder for TensorBoard
            tb_test_dir = os.path.join(self.output_path, 'tb_log', 'test')
            util.mkdir(tb_test_dir)
            util.clear_dir(tb_test_dir)

            # Create TensorBoard writer
            self.test_writer = SummaryWriter(log_dir=tb_test_dir)
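
        # Descriptive note: everything is written under <checkpoints_dir>/<experiment_name>/.
        # Train mode creates train_log.txt, train_summary.txt and tb_log/train/ (TensorBoard);
        # test mode creates test_log.txt, test_summary.txt and tb_log/test/.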
|
|
    def print_train_log(self, epoch, iteration, losses_dict, metrics_dict, load_time, comp_time, batch_size, dataset_size, with_time=True):
        """
        Print the training log on the console and save the message to disk

        Parameters:
            epoch (int)                -- current epoch
            iteration (int)            -- current training iteration during this epoch
            losses_dict (OrderedDict)  -- training losses stored in an ordered dict
            metrics_dict (OrderedDict) -- metrics stored in an ordered dict
            load_time (float)          -- data loading time per data point (normalized by batch_size)
            comp_time (float)          -- computational time per data point (normalized by batch_size)
            batch_size (int)           -- training batch size
            dataset_size (int)         -- size of the training dataset
            with_time (bool)           -- whether to print the running time
        """
        data_point_covered = min((iteration + 1) * batch_size, dataset_size)
        if with_time:
            message = '[TRAIN] [Epoch: {:3d} Iter: {:4d} Load_t: {:.3f} Comp_t: {:.3f}] '.format(epoch, data_point_covered, load_time, comp_time)
        else:
            message = '[TRAIN] [Epoch: {:3d} Iter: {:4d}] '.format(epoch, data_point_covered)
        for name, loss in losses_dict.items():
            message += '{:s}: {:.3f} '.format(name, loss[-1])
        for name, metric in metrics_dict.items():
            message += '{:s}: {:.3f} '.format(name, metric)

        print(message)  # print the message

        with open(self.train_log_filename, 'a') as log_file:
            log_file.write(message + '\n')  # save the message
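
        # Illustrative console/log line produced by this method (values are made up; the loss and
        # metric names come from the keys of losses_dict and metrics_dict):
        #   [TRAIN] [Epoch:   3 Iter:  256 Load_t: 0.012 Comp_t: 0.047] loss: 1.234 accuracy: 0.876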
|
|
    def print_train_summary(self, epoch, losses_dict, output_dict, train_time, current_lr):
        """
        Print the summary of this training epoch

        Parameters:
            epoch (int)               -- epoch number of this training model
            losses_dict (OrderedDict) -- the losses dictionary
            output_dict (OrderedDict) -- the downstream output dictionary
            train_time (float)        -- time used for training this epoch
            current_lr (float)        -- the learning rate of this epoch
        """
        write_message = '{:s}\t'.format(str(epoch))
        print_message = '[TRAIN] [Epoch: {:3d}]\n'.format(int(epoch))

        for name, loss in losses_dict.items():
            write_message += '{:.6f}\t'.format(np.mean(loss))
            print_message += name + ': {:.3f} '.format(np.mean(loss))
            self.train_writer.add_scalar('loss_' + name, np.mean(loss), epoch)

        metrics_dict = self.get_epoch_metrics(output_dict)
        for name, metric in metrics_dict.items():
            write_message += '{:.6f}\t'.format(metric)
            print_message += name + ': {:.3f} '.format(metric)
            self.train_writer.add_scalar('metric_' + name, metric, epoch)

        train_time_msg = 'Training time used: {:.3f}s'.format(train_time)
        print_message += '\n' + train_time_msg
        with open(self.train_log_filename, 'a') as log_file:
            log_file.write(train_time_msg + '\n')

        current_lr_msg = 'Learning rate for this epoch: {:.7f}'.format(current_lr)
        print_message += '\n' + current_lr_msg
        self.train_writer.add_scalar('lr', current_lr, epoch)

        with open(self.train_summary_filename, 'a') as log_file:
            log_file.write(write_message + '\n')

        print(print_message)
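
        # Descriptive note: each epoch appends one tab-separated row to train_summary.txt (epoch
        # number, the mean of every training loss, then every epoch-level metric from
        # get_epoch_metrics); the same values are streamed to TensorBoard as 'loss_<name>' and
        # 'metric_<name>' scalars, plus the learning rate under 'lr'.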
|
|
    def print_test_log(self, epoch, iteration, losses_dict, metrics_dict, batch_size, dataset_size):
        """
        Print performance metrics of this iteration on the console and save the message to disk

        Parameters:
            epoch (int)                -- epoch number of this testing model
            iteration (int)            -- current testing iteration during this epoch
            losses_dict (OrderedDict)  -- testing losses stored in an ordered dict
            metrics_dict (OrderedDict) -- metrics stored in an ordered dict
            batch_size (int)           -- testing batch size
            dataset_size (int)         -- size of the testing dataset
        """
        data_point_covered = min((iteration + 1) * batch_size, dataset_size)
        message = '[TEST] [Epoch: {:3d} Iter: {:4d}] '.format(int(epoch), data_point_covered)
        for name, loss in losses_dict.items():
            message += '{:s}: {:.3f} '.format(name, loss[-1])
        for name, metric in metrics_dict.items():
            message += '{:s}: {:.3f} '.format(name, metric)

        print(message)

        with open(self.test_log_filename, 'a') as log_file:
            log_file.write(message + '\n')

    def print_test_summary(self, epoch, losses_dict, output_dict, test_time):
        """
        Print the summary of this testing epoch

        Parameters:
            epoch (int)               -- epoch number of this testing model
            losses_dict (OrderedDict) -- the losses dictionary
            output_dict (OrderedDict) -- the downstream output dictionary
            test_time (float)         -- time used for testing this epoch
        """
        write_message = '{:s}\t'.format(str(epoch))
        print_message = '[TEST] [Epoch: {:3d}] '.format(int(epoch))

        for name, loss in losses_dict.items():
            # write_message += '{:.6f}\t'.format(np.mean(loss))
            print_message += name + ': {:.3f} '.format(np.mean(loss))
            self.test_writer.add_scalar('loss_' + name, np.mean(loss), epoch)

        metrics_dict = self.get_epoch_metrics(output_dict)

        for name, metric in metrics_dict.items():
            write_message += '{:.6f}\t'.format(metric)
            print_message += name + ': {:.3f} '.format(metric)
            self.test_writer.add_scalar('metric_' + name, metric, epoch)

        with open(self.test_summary_filename, 'a') as log_file:
            log_file.write(write_message + '\n')

        test_time_msg = 'Testing time used: {:.3f}s'.format(test_time)
        print_message += '\n' + test_time_msg
        print(print_message)
        with open(self.test_log_filename, 'a') as log_file:
            log_file.write(test_time_msg + '\n')
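
        # Descriptive note: unlike the training summary, the row written to test_summary.txt holds
        # only the epoch number and the epoch metrics; per-loss means are printed and sent to
        # TensorBoard but not written to the file (see the commented-out line above).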
|
|
    def get_epoch_metrics(self, output_dict):
        """
        Get the downstream task metrics for the whole epoch

        Parameters:
            output_dict (OrderedDict) -- the output dictionary used to compute the downstream task metrics
        """
|
|
        if self.param.downstream_task == 'classification':
            y_true = output_dict['y_true'].cpu().numpy()
            y_true_binary = label_binarize(y_true, classes=range(self.param.class_num))
            y_pred = output_dict['y_pred'].cpu().numpy()
            y_prob = output_dict['y_prob'].cpu().numpy()
            if self.param.class_num == 2:
                y_prob = y_prob[:, 1]

            accuracy = sk.metrics.accuracy_score(y_true, y_pred)
            precision = sk.metrics.precision_score(y_true, y_pred, average='macro', zero_division=0)
            recall = sk.metrics.recall_score(y_true, y_pred, average='macro', zero_division=0)
            f1 = sk.metrics.f1_score(y_true, y_pred, average='macro', zero_division=0)
            try:
                auc = sk.metrics.roc_auc_score(y_true_binary, y_prob, multi_class='ovo', average='macro')
            except ValueError:
                auc = -1
                print('ValueError: ROC AUC score is not defined in this case.')

            return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}

        elif self.param.downstream_task == 'regression':
            y_true = output_dict['y_true'].cpu().numpy()
            y_pred = output_dict['y_pred'].cpu().detach().numpy()

            mse = sk.metrics.mean_squared_error(y_true, y_pred)
            rmse = sk.metrics.mean_squared_error(y_true, y_pred, squared=False)
            mae = sk.metrics.mean_absolute_error(y_true, y_pred)
            medae = sk.metrics.median_absolute_error(y_true, y_pred)
            r2 = sk.metrics.r2_score(y_true, y_pred)

            return {'mse': mse, 'rmse': rmse, 'mae': mae, 'medae': medae, 'r2': r2}

        elif self.param.downstream_task == 'survival':
            metrics_start_time = time.time()

            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            y_pred_survival = output_dict['survival'].cpu().numpy()

            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)

            try:
                c_index = metrics.c_index(y_true_T, y_true_E, y_pred_risk)
            except ValueError:
                c_index = -1
                print('ValueError: NaNs detected in input when calculating c-index.')

            try:
                ibs = metrics.ibs(y_true_T, y_true_E, y_pred_survival, time_points)
            except ValueError:
                ibs = -1
                print('ValueError: NaNs detected in input when calculating integrated brier score.')

            metrics_time = time.time() - metrics_start_time
            print('Metrics computing time: {:.3f}s'.format(metrics_time))

            return {'c-index': c_index, 'ibs': ibs}

        elif self.param.downstream_task == 'multitask':
            metrics_start_time = time.time()

            # Survival
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            y_pred_survival = output_dict['survival'].cpu().numpy()
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)
            try:
                c_index = metrics.c_index(y_true_T, y_true_E, y_pred_risk)
            except ValueError:
                c_index = -1
                print('ValueError: NaNs detected in input when calculating c-index.')
            try:
                ibs = metrics.ibs(y_true_T, y_true_E, y_pred_survival, time_points)
            except ValueError:
                ibs = -1
                print('ValueError: NaNs detected in input when calculating integrated brier score.')

            # Classification
            y_true_cla = output_dict['y_true_cla'].cpu().numpy()
            y_true_cla_binary = label_binarize(y_true_cla, classes=range(self.param.class_num))
            y_pred_cla = output_dict['y_pred_cla'].cpu().numpy()
            y_prob_cla = output_dict['y_prob_cla'].cpu().numpy()
            if self.param.class_num == 2:
                y_prob_cla = y_prob_cla[:, 1]
            accuracy = sk.metrics.accuracy_score(y_true_cla, y_pred_cla)
            precision = sk.metrics.precision_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
            recall = sk.metrics.recall_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
            f1 = sk.metrics.f1_score(y_true_cla, y_pred_cla, average='macro', zero_division=0)
            # try:
            #     auc = sk.metrics.roc_auc_score(y_true_cla_binary, y_prob_cla, multi_class='ovo', average='macro')
            # except ValueError:
            #     auc = -1
            #     print('ValueError: ROC AUC score is not defined in this case.')

            # Regression
            y_true_reg = output_dict['y_true_reg'].cpu().numpy()
            y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy()
            # mse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg)
            rmse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False)
            mae = sk.metrics.mean_absolute_error(y_true_reg, y_pred_reg)
            medae = sk.metrics.median_absolute_error(y_true_reg, y_pred_reg)
            r2 = sk.metrics.r2_score(y_true_reg, y_pred_reg)

            metrics_time = time.time() - metrics_start_time
            print('Metrics computing time: {:.3f}s'.format(metrics_time))

            return {'c-index': c_index, 'ibs': ibs, 'accuracy': accuracy, 'precision': precision, 'recall': recall,
                    'f1': f1, 'rmse': rmse, 'mae': mae, 'medae': medae, 'r2': r2}

        elif self.param.downstream_task == 'alltask':
            metrics_start_time = time.time()

            # Survival
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            y_pred_survival = output_dict['survival'].cpu().numpy()
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)
            try:
                c_index = metrics.c_index(y_true_T, y_true_E, y_pred_risk)
            except ValueError:
                c_index = -1
                print('ValueError: NaNs detected in input when calculating c-index.')
            try:
                ibs = metrics.ibs(y_true_T, y_true_E, y_pred_survival, time_points)
            except ValueError:
                ibs = -1
                print('ValueError: NaNs detected in input when calculating integrated brier score.')

            # Classification
            accuracy = []
            f1 = []
            auc = []
            for i in range(self.param.task_num - 2):
                y_true_cla = output_dict['y_true_cla'][i].cpu().numpy()
                y_true_cla_binary = label_binarize(y_true_cla, classes=range(self.param.class_num[i]))
                y_pred_cla = output_dict['y_pred_cla'][i].cpu().numpy()
                y_prob_cla = output_dict['y_prob_cla'][i].cpu().numpy()
                if self.param.class_num[i] == 2:
                    y_prob_cla = y_prob_cla[:, 1]
                accuracy.append(sk.metrics.accuracy_score(y_true_cla, y_pred_cla))
                f1.append(sk.metrics.f1_score(y_true_cla, y_pred_cla, average='macro', zero_division=0))
                try:
                    auc.append(sk.metrics.roc_auc_score(y_true_cla_binary, y_prob_cla, multi_class='ovo', average='macro'))
                except ValueError:
                    auc.append(-1)
                    print('ValueError: ROC AUC score is not defined in this case.')

            # Regression
            y_true_reg = output_dict['y_true_reg'].cpu().numpy()
            y_pred_reg = output_dict['y_pred_reg'].cpu().detach().numpy()
            # mse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg)
            rmse = sk.metrics.mean_squared_error(y_true_reg, y_pred_reg, squared=False)
            # mae = sk.metrics.mean_absolute_error(y_true_reg, y_pred_reg)
            # medae = sk.metrics.median_absolute_error(y_true_reg, y_pred_reg)
            r2 = sk.metrics.r2_score(y_true_reg, y_pred_reg)

            metrics_time = time.time() - metrics_start_time
            print('Metrics computing time: {:.3f}s'.format(metrics_time))

            return {'c-index': c_index, 'ibs': ibs,
                    'accuracy_1': accuracy[0], 'f1_1': f1[0], 'auc_1': auc[0],
                    'accuracy_2': accuracy[1], 'f1_2': f1[1], 'auc_2': auc[1],
                    'accuracy_3': accuracy[2], 'f1_3': f1[2], 'auc_3': auc[2],
                    'accuracy_4': accuracy[3], 'f1_4': f1[3], 'auc_4': auc[3],
                    'accuracy_5': accuracy[4], 'f1_5': f1[4], 'auc_5': auc[4],
                    'rmse': rmse, 'r2': r2}

    def save_output_dict(self, output_dict):
        """
        Save the downstream task output to disk

        Parameters:
            output_dict (OrderedDict) -- the downstream task output dictionary to be saved
        """
|
|
        down_path = os.path.join(self.output_path, 'down_output')
        util.mkdir(down_path)
        if self.param.downstream_task == 'classification':
            # Prepare files
            index = output_dict['index'].numpy()
            y_true = output_dict['y_true'].cpu().numpy()
            y_pred = output_dict['y_pred'].cpu().numpy()
            y_prob = output_dict['y_prob'].cpu().numpy()

            sample_list = self.param.sample_list[index]

            # Output files
            y_df = pd.DataFrame({'sample': sample_list, 'y_true': y_true, 'y_pred': y_pred}, index=index)
            y_df_path = os.path.join(down_path, 'y_df.tsv')
            y_df.to_csv(y_df_path, sep='\t')

            prob_df = pd.DataFrame(y_prob, columns=range(self.param.class_num), index=sample_list)
            y_prob_path = os.path.join(down_path, 'y_prob.tsv')
            prob_df.to_csv(y_prob_path, sep='\t')

        elif self.param.downstream_task == 'regression':
            # Prepare files
            index = output_dict['index'].numpy()
            y_true = output_dict['y_true'].cpu().numpy()
            y_pred = np.squeeze(output_dict['y_pred'].cpu().detach().numpy())

            sample_list = self.param.sample_list[index]

            # Output files
            y_df = pd.DataFrame({'sample': sample_list, 'y_true': y_true, 'y_pred': y_pred}, index=index)
            y_df_path = os.path.join(down_path, 'y_df.tsv')
            y_df.to_csv(y_df_path, sep='\t')

        elif self.param.downstream_task == 'survival':
            # Prepare files
            index = output_dict['index'].numpy()
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            survival_function = output_dict['survival'].cpu().numpy()
            y_out = output_dict['y_out'].cpu().numpy()

            sample_list = self.param.sample_list[index]
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)

            # Output files
            y_df = pd.DataFrame({'sample': sample_list, 'true_T': y_true_T, 'true_E': y_true_E, 'pred_risk': y_pred_risk}, index=index)
            y_df_path = os.path.join(down_path, 'y_df.tsv')
            y_df.to_csv(y_df_path, sep='\t')

            survival_function_df = pd.DataFrame(survival_function, columns=time_points, index=sample_list)
            survival_function_path = os.path.join(down_path, 'survival_function.tsv')
            survival_function_df.to_csv(survival_function_path, sep='\t')

            y_out_df = pd.DataFrame(y_out, index=sample_list)
            y_out_path = os.path.join(down_path, 'y_out.tsv')
            y_out_df.to_csv(y_out_path, sep='\t')

        elif self.param.downstream_task == 'multitask':
            # Survival
            index = output_dict['index'].numpy()
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            survival_function = output_dict['survival'].cpu().numpy()
            y_out_sur = output_dict['y_out_sur'].cpu().numpy()
            sample_list = self.param.sample_list[index]
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)
            y_df_sur = pd.DataFrame(
                {'sample': sample_list, 'true_T': y_true_T, 'true_E': y_true_E, 'pred_risk': y_pred_risk}, index=index)
            y_df_sur_path = os.path.join(down_path, 'y_df_survival.tsv')
            y_df_sur.to_csv(y_df_sur_path, sep='\t')
            survival_function_df = pd.DataFrame(survival_function, columns=time_points, index=sample_list)
            survival_function_path = os.path.join(down_path, 'survival_function.tsv')
            survival_function_df.to_csv(survival_function_path, sep='\t')
            y_out_sur_df = pd.DataFrame(y_out_sur, index=sample_list)
            y_out_sur_path = os.path.join(down_path, 'y_out_survival.tsv')
            y_out_sur_df.to_csv(y_out_sur_path, sep='\t')

            # Classification
            y_true_cla = output_dict['y_true_cla'].cpu().numpy()
            y_pred_cla = output_dict['y_pred_cla'].cpu().numpy()
            y_prob_cla = output_dict['y_prob_cla'].cpu().numpy()
            y_df_cla = pd.DataFrame({'sample': sample_list, 'y_true': y_true_cla, 'y_pred': y_pred_cla}, index=index)
            y_df_cla_path = os.path.join(down_path, 'y_df_classification.tsv')
            y_df_cla.to_csv(y_df_cla_path, sep='\t')
            prob_cla_df = pd.DataFrame(y_prob_cla, columns=range(self.param.class_num), index=sample_list)
            y_prob_cla_path = os.path.join(down_path, 'y_prob_classification.tsv')
            prob_cla_df.to_csv(y_prob_cla_path, sep='\t')

            # Regression
            y_true_reg = output_dict['y_true_reg'].cpu().numpy()
            y_pred_reg = np.squeeze(output_dict['y_pred_reg'].cpu().detach().numpy())
            y_df_reg = pd.DataFrame({'sample': sample_list, 'y_true': y_true_reg, 'y_pred': y_pred_reg}, index=index)
            y_df_reg_path = os.path.join(down_path, 'y_df_regression.tsv')
            y_df_reg.to_csv(y_df_reg_path, sep='\t')

        elif self.param.downstream_task == 'alltask':
            # Survival
            index = output_dict['index'].numpy()
            y_true_E = output_dict['y_true_E'].cpu().numpy()
            y_true_T = output_dict['y_true_T'].cpu().numpy()
            y_pred_risk = output_dict['risk'].cpu().numpy()
            survival_function = output_dict['survival'].cpu().numpy()
            y_out_sur = output_dict['y_out_sur'].cpu().numpy()
            sample_list = self.param.sample_list[index]
            time_points = util.get_time_points(self.param.survival_T_max, self.param.time_num)
            y_df_sur = pd.DataFrame(
                {'sample': sample_list, 'true_T': y_true_T, 'true_E': y_true_E, 'pred_risk': y_pred_risk}, index=index)
            y_df_sur_path = os.path.join(down_path, 'y_df_survival.tsv')
            y_df_sur.to_csv(y_df_sur_path, sep='\t')
            survival_function_df = pd.DataFrame(survival_function, columns=time_points, index=sample_list)
            survival_function_path = os.path.join(down_path, 'survival_function.tsv')
            survival_function_df.to_csv(survival_function_path, sep='\t')
            y_out_sur_df = pd.DataFrame(y_out_sur, index=sample_list)
            y_out_sur_path = os.path.join(down_path, 'y_out_survival.tsv')
            y_out_sur_df.to_csv(y_out_sur_path, sep='\t')

            # Classification
            for i in range(self.param.task_num - 2):
                y_true_cla = output_dict['y_true_cla'][i].cpu().numpy()
                y_pred_cla = output_dict['y_pred_cla'][i].cpu().numpy()
                y_prob_cla = output_dict['y_prob_cla'][i].cpu().numpy()
                y_df_cla = pd.DataFrame({'sample': sample_list, 'y_true': y_true_cla, 'y_pred': y_pred_cla}, index=index)
                y_df_cla_path = os.path.join(down_path, 'y_df_classification_' + str(i + 1) + '.tsv')
                y_df_cla.to_csv(y_df_cla_path, sep='\t')
                prob_cla_df = pd.DataFrame(y_prob_cla, columns=range(self.param.class_num[i]), index=sample_list)
                y_prob_cla_path = os.path.join(down_path, 'y_prob_classification_' + str(i + 1) + '.tsv')
                prob_cla_df.to_csv(y_prob_cla_path, sep='\t')

            # Regression
            y_true_reg = output_dict['y_true_reg'].cpu().numpy()
            y_pred_reg = np.squeeze(output_dict['y_pred_reg'].cpu().detach().numpy())
            y_df_reg = pd.DataFrame({'sample': sample_list, 'y_true': y_true_reg, 'y_pred': y_pred_reg}, index=index)
            y_df_reg_path = os.path.join(down_path, 'y_df_regression.tsv')
            y_df_reg.to_csv(y_df_reg_path, sep='\t')

    def save_latent_space(self, latent_dict, sample_list):
        """
        Save the latent space matrix to disk

        Parameters:
            latent_dict (OrderedDict) -- the latent space dictionary
            sample_list (ndarray)     -- the sample list for the latent matrix
        """
|
|
        reordered_sample_list = sample_list[latent_dict['index'].astype(int)]
        latent_df = pd.DataFrame(latent_dict['latent'], index=reordered_sample_list)
        output_path = os.path.join(self.param.checkpoints_dir, self.param.experiment_name, 'latent_space.tsv')
        print('Saving the latent space matrix...')
        latent_df.to_csv(output_path, sep='\t')

    @staticmethod
    def print_phase(phase):
        """
        Print the phase information

        Parameters:
            phase (str) -- the phase of the training process ('p1', 'p2' or 'p3')
        """
        if phase == 'p1':
            print('PHASE 1: Unsupervised Phase')
        elif phase == 'p2':
            print('PHASE 2: Supervised Phase')
        elif phase == 'p3':
            print('PHASE 3: Supervised Phase')