a b/R/loss_by_lineage.R
1
# loss_by_lineage.R
2
require(data.table)
3
require(ggplot2)
4
5
6
#' Plot the mean RMSE loss per cell line lineage for the CTRPv2 results
#'
#' Reads `CV_results.csv` and `CTRP_AAC_MORGAN_1024_inference_results.csv`
#' from `path`, attaches the lineage of each cell line to every inference
#' row, averages `RMSE_loss` within each lineage and saves the resulting bar
#' plot as a PDF. A dashed red line marks the mean `RMSE_loss` over all rows.
#'
#' @param path Directory containing the result files. It is joined to the
#'   file names with `paste0()`, so it must end with a path separator.
#' @param plot_path Output directory. It is joined to `plot_filename` with
#'   `paste0()`, so it must end with a path separator.
#' @param cell_line_data Table with (at least) the columns
#'   `stripped_cell_line_name` and `lineage`.
#' @param title Plot title.
#' @param subtitle Plot subtitle; the average cross-validation loss read from
#'   `CV_results.csv` is appended on a second line.
#' @param plot_filename File name of the PDF written into `plot_path`.
#' @param display_plot If `TRUE`, the plot is also printed to the active
#'   graphics device.
#' @return The value returned by `ggsave()`.
plot_loss_by_lineage <- function(path,
                                 plot_path, cell_line_data, title, subtitle,
                                 plot_filename, display_plot = FALSE) {
  # --- Average cross-validation loss (shown in the subtitle) ---------------
  cv_results <- fread(paste0(path, "CV_results.csv"))
  # `[[2]]` extracts the second column as a plain vector; `[, 2]` would keep
  # a one-column data.table and push a table through round()/format().
  cv_valid_loss <- cv_results[V1 == "avg_cv_valid_loss"][[2]]
  if (length(cv_valid_loss) != 1L) {
    warning(
      "Expected exactly one 'avg_cv_valid_loss' row in CV_results.csv, found ",
      length(cv_valid_loss), call. = FALSE
    )
    # Fall back to the first match, or NA when there is none, so the
    # subtitle never shows something like "character(0)".
    cv_valid_loss <- if (length(cv_valid_loss) == 0L) NA_real_ else cv_valid_loss[1L]
  }
  cv_valid_loss <- format(round(cv_valid_loss, 4), nsmall = 4)

  # --- Inference results annotated with lineage ----------------------------
  # Exact duplicate annotation rows would duplicate inference rows in the
  # merge and inflate the per-lineage counts, so they are removed first.
  lineage_map <- unique(cell_line_data[, c("stripped_cell_line_name", "lineage")])
  ctrp_data <- fread(paste0(path, "CTRP_AAC_MORGAN_1024_inference_results.csv"))
  ctrp_data <- merge(ctrp_data, lineage_map,
                     by.x = "cell_name", by.y = "stripped_cell_line_name")

  # --- One row per lineage: mean, sd and number of test datapoints ---------
  all_avg_abs_by_lineage <- ctrp_data[, list(
    lineage_loss_avg = mean(RMSE_loss),
    lineage_loss_sd = sd(RMSE_loss),
    sample_by_lineage_count = .N
  ), by = "lineage"]
  # Only CTRPv2 is plotted; the Dataset column feeds the `fill` aesthetic.
  all_avg_abs_by_lineage$Dataset <- "CTRPv2"
  all_avg_abs_by_lineage$lineage <- paste0(
    all_avg_abs_by_lineage$lineage, ", n = ",
    all_avg_abs_by_lineage$sample_by_lineage_count
  )

  # Mean loss over all rows. This equals the row-level mean of the
  # per-lineage averages (each lineage weighted by its row count).
  overall_mean_loss <- mean(ctrp_data$RMSE_loss)

  g <- ggplot(
    data = all_avg_abs_by_lineage,
    mapping = aes(x = reorder(lineage, -lineage_loss_avg),
                  y = lineage_loss_avg, fill = Dataset)
  ) +
    geom_bar(stat = "identity", position = position_dodge()) +
    # Error bars are currently disabled; lineage_loss_sd is kept for them.
    # geom_errorbar(aes(ymin = lineage_loss_avg - lineage_loss_sd, ymax = lineage_loss_avg + lineage_loss_sd), width = 0.2, position = position_dodge(0.9)) +
    theme(axis.text.x = element_text(angle = 45, hjust = 1)) +
    geom_hline(yintercept = overall_mean_loss, linetype = "dashed", color = "red") +
    xlab("Cell Line Lineage + # testing datapoints") +
    ylab("RMSE Loss") +
    # Ten evenly spaced breaks on [0, 0.25] plus one at the overall mean so
    # the dashed line gets its own axis label.
    scale_y_continuous(
      breaks = sort(c(seq(0, 0.25, length.out = 10), overall_mean_loss))
    ) +
    ggtitle(
      label = title,
      subtitle = paste0(subtitle, "\nAverage Cross-Validation RMSE Loss:",
                        as.character(cv_valid_loss))
    )

  # isTRUE() tolerates NA/NULL; as.logical() keeps accepting 1 or "TRUE".
  if (isTRUE(as.logical(display_plot))) {
    print(g)
  }

  ggsave(plot = g, filename = paste0(plot_path, plot_filename), device = "pdf")
}
68
69
70
# Build the grid of configurations to plot and call plot_grid_mono() once per
# grid row.
model_types <- c("FullModel", "ResponseOnly")
# The empty string is the first entry; every other data type gets a leading
# underscore.
data_types <- c(
  "",
  paste0("_", c("mut", "exp", "prot", "mirna", "metab", "rppa", "hist"))
)
# splits <- c("CELL_LINE", "DRUG", "BOTH")
splits <- c("DRUG")
# bottlenecking <- c("WithBottleNeck", "NoBottleNeck")
bottlenecking <- c("NoBottleNeck")
drug_types <- c("OneHotDrugs")
# NOTE: expand.grid() converts character vectors to factors by default, so
# the values handed to plot_grid_mono() below are factors, not strings.
grid <- expand.grid(model_types, data_types, splits, bottlenecking, drug_types)

# NOTE(review): plot_grid_mono() is not defined in this part of the file;
# it must be available before this loop runs.
# seq_len() yields an empty sequence for a zero-row grid, whereas 1:nrow()
# would iterate over c(1, 0).
for (i in seq_len(nrow(grid))) {
  plot_grid_mono(
    model_type = grid[i, 1],
    data_type = grid[i, 2],
    split = grid[i, 3],
    bottleneck = grid[i, 4],
    drug_type = grid[i, 5]
  )
}