# File: R/inference_results_plot.R
1
# inference_results_plot.R
2
3
require(data.table)
4
require(ggplot2)
5
options(scipen = 3)
6
# Squared correlation between two numeric vectors, using cor()'s defaults.
#   x, y: numeric vectors of equal length (e.g. observed and predicted values).
# Returns a single number: cor(x, y) squared.
rsq <- function(x, y) {
  correlation <- cor(x, y)
  correlation^2
}
7
8
# TEMP: Train on CTRP, Test on GDSC using respective omic data (exp) ====
# Read the three inference-result tables (CTRP, GDSC1, GDSC2) stored in the
# CTRP run directory.
ctrp_gnn_exp <- fread(
  "Data/CV_Results//HyperOpt_DRP_ResponseOnly_gnndrug_exp_HyperOpt_DRP_CTRP_ResponseOnly_EncoderTrain_Split_BOTH_NoBottleNeck_NoTCGAPretrain_MergeByLMF_WeightedRMSELoss_GnnDrugs_gnndrug_exp/CTRP_AAC_SMILES_inference_results.csv"
)
gdsc1_gnn_exp <- fread(
  "Data/CV_Results//HyperOpt_DRP_ResponseOnly_gnndrug_exp_HyperOpt_DRP_CTRP_ResponseOnly_EncoderTrain_Split_BOTH_NoBottleNeck_NoTCGAPretrain_MergeByLMF_WeightedRMSELoss_GnnDrugs_gnndrug_exp/GDSC1_AAC_SMILES_inference_results.csv"
)
gdsc2_gnn_exp <- fread(
  "Data/CV_Results//HyperOpt_DRP_ResponseOnly_gnndrug_exp_HyperOpt_DRP_CTRP_ResponseOnly_EncoderTrain_Split_BOTH_NoBottleNeck_NoTCGAPretrain_MergeByLMF_WeightedRMSELoss_GnnDrugs_gnndrug_exp/GDSC2_AAC_SMILES_inference_results.csv"
)

# Squared correlation of target vs. predicted for each table; the trailing
# numbers are values noted down from an earlier run.
with(ctrp_gnn_exp, rsq(target, predicted))  # 0.833
with(gdsc1_gnn_exp, rsq(target, predicted))  # 0.07
with(gdsc2_gnn_exp, rsq(target, predicted))  # 0.119
# Conclusion: DepMap + CTRP is not good at predicting GDSC. Fine-tuning might help.
17
18
# TEMP: Train on GDSC2, Test on CTRP using respective omic data (exp) ====
# Read the three inference-result tables (CTRP, GDSC1, GDSC2) stored in the
# GDSC2 run directory. These overwrite the tables of the same name loaded in
# the CTRP-trained section.
ctrp_gnn_exp <- fread("Data/CV_Results/HyperOpt_DRP_ResponseOnly_gnndrug_exp_HyperOpt_DRP_GDSC2_ResponseOnly_EncoderTrain_Split_BOTH_NoBottleNeck_NoTCGAPretrain_MergeByLMF_WeightedRMSELoss_GNNDrugs_gnndrug_exp/CTRP_AAC_SMILES_inference_results.csv")
gdsc1_gnn_exp <- fread("Data/CV_Results/HyperOpt_DRP_ResponseOnly_gnndrug_exp_HyperOpt_DRP_GDSC2_ResponseOnly_EncoderTrain_Split_BOTH_NoBottleNeck_NoTCGAPretrain_MergeByLMF_WeightedRMSELoss_GNNDrugs_gnndrug_exp/GDSC1_AAC_SMILES_inference_results.csv")
# The name must be exactly `gdsc2_gnn_exp`: under any other name the rsq()
# call below would silently read the GDSC2 table loaded in the CTRP-trained
# section instead of this one.
gdsc2_gnn_exp <- fread("Data/CV_Results/HyperOpt_DRP_ResponseOnly_gnndrug_exp_HyperOpt_DRP_GDSC2_ResponseOnly_EncoderTrain_Split_BOTH_NoBottleNeck_NoTCGAPretrain_MergeByLMF_WeightedRMSELoss_GNNDrugs_gnndrug_exp/GDSC2_AAC_SMILES_inference_results.csv")

# Squared correlation of target vs. predicted for each table; the trailing
# numbers are values noted down from an earlier run.
rsq(ctrp_gnn_exp$target, ctrp_gnn_exp$predicted)  # 0.04
rsq(gdsc1_gnn_exp$target, gdsc1_gnn_exp$predicted)  # 0.12
# NOTE(review): the 0.119 previously recorded here was computed on the table
# from the CTRP-trained section (variable-name typo); re-run to record the
# real value.
rsq(gdsc2_gnn_exp$target, gdsc2_gnn_exp$predicted)
26
27
28
# ==== Bimodal Case ====
29
require(data.table)
30
require(ggplot2)
31
options(scipen = 3)
32
# all_csv_results <- list.files("Data/CV_Results/", "CV_results.csv", recursive = T, full.names = T)
33
# Find every CTRP inference-results CSV below Data/CV_Results/ and keep those
# whose path contains "drug_<3-5 characters>_HyperOpt".
# NOTE(review): this keeps every matching run directory regardless of which
# dataset the model was trained on (e.g. both the CTRP- and GDSC2-trained
# directories read near the top of this script match) - confirm that mixing
# them is intended.
all_csv_results <- list.files("Data/CV_Results/", "CTRP_.+_inference_results.csv",
                              recursive = TRUE, full.names = TRUE)
bimodal_results <- grep(pattern = ".+drug_.{3,5}_HyperOpt.+", x = all_csv_results, value = TRUE)

# Read each file and tag its rows with the run configuration parsed from the
# file path (omic data type, merge method, loss type, drug representation and
# split method).
all_results <- vector(mode = "list", length = length(bimodal_results))
# seq_along() gives zero iterations when nothing matched; 1:length(x) would
# iterate over c(1, 0) and fail inside fread().
for (i in seq_along(bimodal_results)) {
  cur_res <- fread(bimodal_results[i])
  data_types <- gsub(".+ResponseOnly_\\w*drug_(.+)_HyperOpt.+", "\\1", bimodal_results[i])
  data_types <- toupper(data_types)
  merge_method <- gsub(".+MergeBy(\\w+)_.*RMSE.+", "\\1", bimodal_results[i])
  loss_method <- gsub(".+_(.*)RMSE.+", "\\1RMSE", bimodal_results[i])
  drug_type <- gsub(".+ResponseOnly_(\\w*)drug.+_HyperOpt.+", "\\1drug", bimodal_results[i])
  drug_type <- toupper(drug_type)
  split_method <- gsub(".+Split_(\\w+)_NoBottleNeck.+", "\\1", bimodal_results[i])
  # data_types <- strsplit(data_types, "_")[[1]]
  # cur_res$epoch <- as.integer(epoch)
  cur_res$data_types <- data_types
  cur_res$merge_method <- merge_method
  cur_res$loss_type <- loss_method
  cur_res$drug_type <- drug_type
  cur_res$split_method <- split_method

  all_results[[i]] <- cur_res
}
# Stack the per-run tables into one data.table.
all_results <- rbindlist(all_results)
58
59
# Columns that together identify one run configuration.
config_cols <- c("data_types", "merge_method", "loss_type", "drug_type", "split_method")

# Mean RMSELoss per configuration, written onto every row of all_results.
all_results[, loss_by_config := mean(RMSELoss), by = config_cols]

# all_results <- all_results[!(V1 %in% c("max_final_epoch", "time_this_iter_s", "num_samples", "avg_cv_untrained_loss"))]

# Keep one row per configuration, then reshape to long format: the
# configuration columns are the ids and loss_by_config is the measured value.
config_losses <- unique(all_results[, c(config_cols, "loss_by_config"), with = FALSE])
long_results <- melt(config_losses, id.vars = config_cols)
65
# long_results[V1 == "avg_cv_train_loss"]$V1 <- "Mean CV Training Loss"
66
# long_results[V1 == "avg_cv_valid_loss"]$V1 <- "Mean CV Validation Loss"
67
# long_results <- long_results[split_method == "DRUG"]
68
# long_results <- long_results[merge_method == "Concat"]
69
# long_results <- long_results[merge_method == "Sum"]
70
# long_results <- long_results[loss_type == "RMSE"]
71
# long_results <- long_results[merge_method == "LMF" & loss_type == "WeightedRMSE"]
72
# long_results <- long_results[split_method == "CELL_LINE"]
73
# long_results <- long_results[drug_type == "DRUG"]
74
# All loss comparison ====
# Bar chart of the per-configuration value in long_results: data type on the
# x axis, one facet per merge method / loss type / drug type / split method.
ggplot(long_results) +
  geom_bar(mapping = aes(x = data_types, y = value, fill = split_method), stat = "identity", position = "dodge") +
  facet_wrap(~merge_method+loss_type+drug_type+split_method, nrow = 2) +
  scale_fill_discrete(name = "Split Type:") +
  # No colour aesthetic is mapped above, so this palette has no visible effect.
  scale_colour_manual(values = c("#000000", "#E69F00", "#56B4E9", "#009E73",
                                 "#F0E442", "#0072B2", "#D55E00", "#CC79A7")) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
  ggtitle(label = tools::toTitleCase("Comparison of Loss-weighting, fusion method and drug representation in the bi-modal case"),
          subtitle = "Training RMSE Loss using strict splitting during hyper-parameter optimization")

# recursive = TRUE also creates "Plots/" when it is missing;
# showWarnings = FALSE keeps re-runs quiet when the folder already exists.
dir.create("Plots/Training_Inference_Results", recursive = TRUE, showWarnings = FALSE)
# With no `plot` argument ggsave() writes the most recently created plot.
ggsave(filename = "Plots/Training_Inference_Results/Bimodal_Full_RMSELoss_Comparison.pdf")
87
88
89
# Upper AAC loss comparison ====
# Columns that together identify one run configuration.
config_cols <- c("data_types", "merge_method", "loss_type", "drug_type", "split_method")

# Keep only rows whose target is strictly above 0.7. The subset is a new
# table, so the := updates below do not modify all_results.
temp_results <- all_results[target > 0.7]
temp_results[, loss_by_config := mean(RMSELoss), by = config_cols]
# rsq_by_config is not plotted here; only the commented-out variant of this
# plot further down refers to it.
temp_results[, rsq_by_config := rsq(target, predicted), by = config_cols]
# One row per configuration, reshaped to long format for plotting.
temp_long_results <- melt(unique(temp_results[, c(config_cols, "loss_by_config"), with = FALSE]),
                          id.vars = config_cols)

ggplot(temp_long_results) +
  geom_bar(mapping = aes(x = data_types, y = value, fill = split_method), stat = "identity", position = "dodge") +
  facet_wrap(~merge_method+loss_type+drug_type+split_method, nrow = 2) +
  scale_fill_discrete(name = "Split Method:") +
  # scale_colour_manual(values=c("#000000", "#E69F00", "#56B4E9", "#009E73",
  #                              "#F0E442", "#0072B2", "#D55E00", "#CC79A7")) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
  # The subtitle states the filter actually applied above (strictly > 0.7).
  ggtitle(label = tools::toTitleCase("Comparison of Loss-weighting, fusion method and drug representation in the bi-modal case"),
          subtitle = "Training RMSE Loss with AAC Targets > 0.7, using strict splitting during hyper-parameter optimization")

# recursive = TRUE also creates "Plots/" when it is missing;
# showWarnings = FALSE keeps re-runs quiet when the folder already exists.
dir.create("Plots/Training_Inference_Results", recursive = TRUE, showWarnings = FALSE)
# With no `plot` argument ggsave() writes the most recently created plot.
ggsave(filename = "Plots/Training_Inference_Results/Bimodal_UpperAAC_RMSELoss_Comparison.pdf")
109
110
111
# temp_long_results <- melt(unique(temp_results[, c("data_types", "merge_method", "loss_type", "drug_type", "split_method", "rsq_by_config")]),
112
#                           id.vars = c("data_types", "merge_method", "loss_type", "drug_type", "split_method"))
113
# 
114
# ggplot(temp_long_results) +
115
#   geom_bar(mapping = aes(x = data_types, y = value, fill = split_method), stat = "identity", position='dodge') +
116
#   facet_wrap(~merge_method+loss_type+drug_type+split_method, nrow = 2) + 
117
#   scale_fill_discrete(name = "Split Method:") +
118
#   # scale_colour_manual(values=c("#000000", "#E69F00", "#56B4E9", "#009E73",
119
#   #                              "#F0E442", "#0072B2", "#D55E00", "#CC79A7")) +
120
#   theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
121
#   ggtitle(label = tools::toTitleCase("Comparison of Loss-weighting, fusion method and drug representation in the bi-modal case"),
122
#           subtitle = "Training RMSE Loss with AAC Targets >= 0.7, using strict splitting during hyper-parameter optimization")
123
# 
124
# dir.create("Plots/Training_Inference_Results")
125
# ggsave(filename = "Plots/Training_Inference_Results/Bimodal_UpperAAC_RMSELoss_Comparison.pdf")
126
127
128
# Lower AAC loss comparison ====
# Columns that together identify one run configuration.
config_cols <- c("data_types", "merge_method", "loss_type", "drug_type", "split_method")

# Keep only rows whose target is strictly below 0.3. The subset is a new
# table, so the := updates below do not modify all_results.
temp_results <- all_results[target < 0.3]
temp_results[, loss_by_config := mean(RMSELoss), by = config_cols]
# rsq_by_config is computed per configuration but not plotted in this section.
temp_results[, rsq_by_config := rsq(target, predicted), by = config_cols]
# One row per configuration, reshaped to long format for plotting.
temp_long_results <- melt(unique(temp_results[, c(config_cols, "loss_by_config"), with = FALSE]),
                          id.vars = config_cols)

ggplot(temp_long_results) +
  geom_bar(mapping = aes(x = data_types, y = value, fill = split_method), stat = "identity", position = "dodge") +
  facet_wrap(~merge_method+loss_type+drug_type+split_method, nrow = 2) +
  scale_fill_discrete(name = "Split Method:") +
  # scale_colour_manual(values=c("#000000", "#E69F00", "#56B4E9", "#009E73",
  #                              "#F0E442", "#0072B2", "#D55E00", "#CC79A7")) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1)) +
  # The subtitle states the filter actually applied above (strictly < 0.3).
  ggtitle(label = tools::toTitleCase("Comparison of Loss-weighting, fusion method and drug representation in the bi-modal case"),
          subtitle = "Training RMSE Loss with AAC Targets < 0.3, using strict splitting during hyper-parameter optimization")

# recursive = TRUE also creates "Plots/" when it is missing;
# showWarnings = FALSE keeps re-runs quiet when the folder already exists.
dir.create("Plots/Training_Inference_Results", recursive = TRUE, showWarnings = FALSE)
# With no `plot` argument ggsave() writes the most recently created plot.
ggsave(filename = "Plots/Training_Inference_Results/Bimodal_LowerAAC_RMSELoss_Comparison.pdf")
148
149
# =========
150
# List the CTRP inference-result CSVs for the two drug representations
# (SMILES files for the GNN runs, MORGAN_1024 files for the Morgan runs) and
# keep the paths that match the "<drug>_<3-5 characters>_HyperOpt" pattern.
# NOTE(review): the SMILES listing matches every run directory below
# Data/CV_Results/, including the GDSC2-trained directory read near the top
# of this script, while the plots built from these tables are titled
# "Model Trained on CTRPv2" - confirm whether those runs should be excluded.
all_gnn_inference_results <- list.files("Data/CV_Results/", "CTRP_AAC_SMILES_inference_results.csv", recursive = TRUE, full.names = TRUE)
all_morgan_inference_results <- list.files("Data/CV_Results/", "CTRP_AAC_MORGAN_1024_inference_results.csv", recursive = TRUE, full.names = TRUE)
gnn_bimodal_results <- grep(pattern = ".+gnndrug_.{3,5}_HyperOpt.+", x = all_gnn_inference_results, value = TRUE)
morgan_bimodal_results <- grep(pattern = ".+_drug_.{3,5}_HyperOpt.+", x = all_morgan_inference_results, value = TRUE)

# Read every GNN result file and tag its rows with the upper-cased
# "<drug>_<omic>" label parsed from the path (e.g. "GNNDRUG_EXP").
all_gnn_results <- vector(mode = "list", length = length(gnn_bimodal_results))
# seq_along() gives zero iterations when nothing matched; 1:length(x) would
# iterate over c(1, 0) and fail inside fread().
for (i in seq_along(gnn_bimodal_results)) {
  cur_res <- fread(gnn_bimodal_results[i])
  data_types <- gsub(".+ResponseOnly_(.+)_HyperOpt.+", "\\1", gnn_bimodal_results[i])
  data_types <- toupper(data_types)
  # data_types <- strsplit(data_types, "_")[[1]]
  # cur_res$epoch <- as.integer(epoch)
  cur_res$data_types <- data_types
  all_gnn_results[[i]] <- cur_res
}

# Stack the per-run tables into one data.table.
all_gnn_results <- rbindlist(all_gnn_results)
167
168
# rsq(all_gnn_results[data_types == "GNNDRUG_PROT"]$target, all_gnn_results[data_types == "GNNDRUG_PROT"]$predicted)
169
# rsq(all_gnn_results[data_types == "GNNDRUG_MUT"]$target, all_gnn_results[data_types == "GNNDRUG_MUT"]$predicted)
170
# mean(all_gnn_results[data_types == "GNNDRUG_PROT"]$RMSELoss)
171
# mean(all_gnn_results[data_types == "GNNDRUG_MUT"]$RMSELoss)
172
173
# Read every Morgan-fingerprint result file and tag its rows with the
# upper-cased "<drug>_<omic>" label parsed from the path (e.g. "DRUG_EXP").
all_morgan_results <- vector(mode = "list", length = length(morgan_bimodal_results))
# seq_along() gives zero iterations when nothing matched; 1:length(x) would
# iterate over c(1, 0) and fail inside fread().
for (i in seq_along(morgan_bimodal_results)) {
  cur_res <- fread(morgan_bimodal_results[i])
  data_types <- gsub(".+ResponseOnly_(.+)_HyperOpt.+", "\\1", morgan_bimodal_results[i])
  data_types <- toupper(data_types)
  # data_types <- strsplit(data_types, "_")[[1]]
  # cur_res$epoch <- as.integer(epoch)
  cur_res$data_types <- data_types
  all_morgan_results[[i]] <- cur_res
}

# Stack the per-run tables into one data.table.
all_morgan_results <- rbindlist(all_morgan_results)
185
186
# ggplot(data = all_gnn_results[data_types == "GNNDRUG_EXP"], aes(x = predicted, y = target)) +
187
#   geom_point() +
188
#   coord_fixed(ratio = 1) +
189
#   geom_abline(intercept = 0, colour = "red")
190
#   # facet_grid(~data_types,)
191
# 
192
# 
193
# rsq(all_morgan_results[data_types == "DRUG_PROT"]$target, all_morgan_results[data_types == "DRUG_PROT"]$predicted)
194
# mean(all_morgan_results[data_types == "DRUG_PROT"]$RMSELoss)
195
# 
196
# ggplot(data = all_morgan_results[data_types == "DRUG_PROT"], aes(x = predicted, y = target)) +
197
#   geom_point() +
198
#   coord_fixed(ratio = 1) +
199
#   geom_abline(intercept = 0, colour = "red")
200
#   # facet_grid(~data_types,)
201
202
# Omic data types for which a Morgan-vs-GNN comparison plot is produced.
all_data_types <- c("MUT", "EXP", "PROT", "MIRNA", "HIST", "METAB", "RPPA")

# Nothing else in the visible script creates this folder, and ggsave() needs
# a place to write to.
dir.create("Plots/R2_Line", recursive = TRUE, showWarnings = FALSE)

# For each data type: scatter predicted vs. target for both drug
# representations (one facet each), annotate with the squared correlations
# and save as a JPEG.
for (data_type in all_data_types) {
  # Subset once and reuse for both the plot data and the R^2 values.
  morgan_subset <- all_morgan_results[data_types == paste0("DRUG_", data_type)]
  gnn_subset <- all_gnn_results[data_types == paste0("GNNDRUG_", data_type)]
  cur_data <- rbindlist(list(morgan_subset, gnn_subset))

  # Skip data types with no results at all, so that one missing run does not
  # abort the remaining plots.
  if (nrow(cur_data) == 0) {
    message("No inference results found for data type ", data_type, "; skipping plot.")
    next
  }

  morgan_rsq <- rsq(morgan_subset$target, morgan_subset$predicted)
  gnn_rsq <- rsq(gnn_subset$target, gnn_subset$predicted)

  p <- ggplot(data = cur_data, aes(x = predicted, y = target)) +
    geom_point() +
    coord_fixed(ratio = 1) +
    # Red reference line through the origin.
    geom_abline(intercept = 0, colour = "red") +
    facet_grid(~data_types) +
    ggtitle(label = "Performance Comparison on CTRPv2", subtitle = paste0("Model Trained on CTRPv2, R^2 Morgan: ", round(morgan_rsq, 2), ", R^2 GNN Drug: ", round(gnn_rsq, 2)))
  ggsave(filename = paste0("Plots/R2_Line/CTRP_Morgan_vs_GNNDrug_", data_type, ".jpg"), plot = p)
}
219
220
# GNN-only scatter plot of predicted vs. target for a single data type.
# The original script contained this section twice, byte for byte, writing
# the same file both times; it is kept once here.
# data_type <- "EXP"
# data_type <- "METAB"
# NOTE: unless one of the lines above is uncommented, data_type still holds
# the value left behind by the preceding `for (data_type in all_data_types)`
# loop, i.e. the last element of all_data_types.
cur_data <- all_gnn_results[data_types == paste0("GNNDRUG_", data_type)]
# Reuse the subset instead of filtering all_gnn_results two more times.
gnn_rsq <- rsq(cur_data$target, cur_data$predicted)

p <- ggplot(data = cur_data, aes(x = predicted, y = target)) +
  geom_point() +
  coord_fixed(ratio = 1) +
  # Red reference line through the origin.
  geom_abline(intercept = 0, colour = "red") +
  facet_grid(~data_types) +
  ggtitle(label = "Performance Comparison on CTRPv2", subtitle = paste0("Model Trained on CTRPv2, R^2 GNN Drug: ", round(gnn_rsq, 2)))

# Make sure the output folder exists before saving.
dir.create("Plots/R2_Line", recursive = TRUE, showWarnings = FALSE)
ggsave(filename = paste0("Plots/R2_Line/GNNDrug_", data_type, ".jpg"), plot = p)