|
a |
|
b/R/loss_by_lineage.R |
|
|
1 |
# loss_by_lineage.R |
|
|
2 |
require(data.table) |
|
|
3 |
require(ggplot2) |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
#' Plot the average RMSE loss per cell line lineage for CTRPv2 inference results
#'
#' Reads `CV_results.csv` and `CTRP_AAC_MORGAN_1024_inference_results.csv`
#' from `path`, attaches a lineage to every inference row through
#' `cell_line_data`, averages `RMSE_loss` within each lineage and draws one bar
#' per lineage (ordered by decreasing average loss). A dashed red line marks
#' the mean of the per-row lineage averages. The plot is always saved as a PDF.
#'
#' @param path Prefix of the directory holding the result CSVs. It is joined
#'   to the file names with `paste0()`, so it must end with a path separator.
#' @param plot_path Prefix of the output directory; joined to `plot_filename`
#'   with `paste0()`, so it must end with a path separator as well.
#' @param cell_line_data Table with (at least) the columns
#'   `stripped_cell_line_name` and `lineage`; it is merged onto the inference
#'   results through their `cell_name` column.
#' @param title Plot title.
#' @param subtitle Plot subtitle; the average cross-validation loss read from
#'   `CV_results.csv` is appended on a second line.
#' @param plot_filename File name of the PDF written below `plot_path`.
#' @param display_plot If `TRUE`, the plot is also printed to the active
#'   graphics device.
#' @return The value returned by `ggsave()`.
plot_loss_by_lineage <- function(path,
                                 plot_path, cell_line_data, title, subtitle,
                                 plot_filename, display_plot = FALSE) {
  # --- Average cross-validation loss (shown in the subtitle) ----------------
  cv_results <- fread(paste0(path, "CV_results.csv"))
  # Pull the value out as a plain vector instead of a 1x1 data.table so that
  # round()/format() operate on a number rather than on a table.
  cv_value <- cv_results[V1 == "avg_cv_valid_loss"][[2]]
  if (length(cv_value) != 1L) {
    warning(
      "Expected exactly one 'avg_cv_valid_loss' row in CV_results.csv, found ",
      length(cv_value), call. = FALSE
    )
  }
  cv_value <- if (length(cv_value) > 0L) cv_value[1L] else NA_real_
  cv_valid_loss <- format(round(cv_value, 4), nsmall = 4)

  # --- Inference results annotated with lineage -----------------------------
  ctrp_data <- fread(
    paste0(path, "CTRP_AAC_MORGAN_1024_inference_results.csv")
  )
  ctrp_data <- merge(
    ctrp_data,
    cell_line_data[, c("stripped_cell_line_name", "lineage")],
    by.x = "cell_name", by.y = "stripped_cell_line_name"
  )
  # NOTE: GDSC1 / GDSC2 results ("GDSC1_AAC_MORGAN_1024_inference_results.csv",
  # "GDSC2_AAC_MORGAN_1024_inference_results.csv") used to be processed the
  # same way and row-bound into the summary with Dataset "GDSC1" / "GDSC2";
  # that path is currently disabled and only CTRPv2 is plotted.

  # --- Per-lineage summary --------------------------------------------------
  # One grouped pass adds the lineage mean, standard deviation and row count
  # to every row of the merged table.
  ctrp_data[, `:=`(
    lineage_loss_avg = mean(RMSE_loss),
    lineage_loss_sd = sd(RMSE_loss),
    sample_by_lineage_count = .N
  ), by = "lineage"]

  # All summary columns are constant within a lineage, so unique() yields one
  # row per lineage.
  all_avg_abs_by_lineage <- unique(
    ctrp_data[, c(
      "lineage", "lineage_loss_avg", "lineage_loss_sd",
      "sample_by_lineage_count"
    )]
  )
  all_avg_abs_by_lineage$Dataset <- "CTRPv2"
  # The axis label carries the number of test rows behind each bar.
  all_avg_abs_by_lineage$lineage <- paste0(
    all_avg_abs_by_lineage$lineage, ", n = ",
    all_avg_abs_by_lineage$sample_by_lineage_count
  )

  # Mean over rows of the per-lineage averages; used for both the reference
  # line and the extra y-axis break, so compute it once.
  overall_mean <- mean(ctrp_data$lineage_loss_avg)

  # --- Plot -----------------------------------------------------------------
  g <- ggplot(
    data = all_avg_abs_by_lineage,
    mapping = aes(
      x = reorder(lineage, -lineage_loss_avg),
      y = lineage_loss_avg,
      fill = Dataset
    )
  ) +
    geom_bar(stat = "identity", position = position_dodge()) +
    # geom_errorbar(aes(ymin = lineage_loss_avg - lineage_loss_sd,
    #                   ymax = lineage_loss_avg + lineage_loss_sd),
    #               width = 0.2, position = position_dodge(0.9)) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
    geom_hline(yintercept = overall_mean, linetype = "dashed", color = "red") +
    xlab("Cell Line Lineage + # testing datapoints") + ylab("RMSE Loss") +
    # Regular breaks plus one break exactly at the reference line.
    scale_y_continuous(
      breaks = sort(c(seq(0, 0.25, length.out = 10), overall_mean))
    ) +
    ggtitle(
      label = title,
      subtitle = paste0(
        subtitle, "\nAverage Cross-Validation RMSE Loss:",
        as.character(cv_valid_loss)
      )
    )

  # as.logical() keeps accepting the same truthy inputs as `== TRUE` did,
  # while isTRUE() guards against NA and non-scalar values.
  if (isTRUE(as.logical(display_plot))) {
    print(g)
  }

  ggsave(plot = g, filename = paste0(plot_path, plot_filename), device = "pdf")
}
|
|
68 |
|
|
|
69 |
|
|
|
70 |
# Driver: enumerate every model configuration and plot each one.
model_types <- c("FullModel", "ResponseOnly")
# "" stands for the configuration without an extra data type; every other
# entry is the data type name prefixed with an underscore.
data_types <- c(
  "",
  paste0("_", c("mut", "exp", "prot", "mirna", "metab", "rppa", "hist"))
)
# splits <- c("CELL_LINE", "DRUG", "BOTH")
splits <- c("DRUG")
# bottlenecking <- c("WithBottleNeck", "NoBottleNeck")
bottlenecking <- c("NoBottleNeck")
drug_types <- c("OneHotDrugs")

# One row per combination; columns are positional (Var1..Var5) in the order
# the vectors are passed. expand.grid() turns character vectors into factors
# by default.
grid <- expand.grid(model_types, data_types, splits, bottlenecking, drug_types)

# seq_len() yields an empty sequence for an empty grid, whereas 1:nrow(grid)
# would iterate over c(1, 0).
# NOTE(review): plot_grid_mono() is not defined in the visible part of this
# file; confirm it is sourced elsewhere before this loop runs.
for (i in seq_len(nrow(grid))) {
  plot_grid_mono(
    model_type = grid[i, 1],
    data_type = grid[i, 2],
    split = grid[i, 3],
    bottleneck = grid[i, 4],
    drug_type = grid[i, 5]
  )
}