# R/train_fit_by_epoch.R
1
# train_fit_by_epoch.R
2
require(data.table)
3
require(ggplot2)
4
require(gganimate)
5
# devtools::install_github("thomasp85/transformr")
6
require(transformr)
7
8
# Load every per-epoch result file for one hyper-parameter configuration and
# stack them into a single table, tagging each row with the epoch it came from.
epoch_dir <- "Data/EpochResults/{'drp_first_layer_size': 998, 'drp_last_layer_size': 21, 'gnn_out_channels': 243}/"
all_epochs <- list.files(path = epoch_dir, full.names = TRUE)

# Fail early with a clear message: with no files, `1:length(all_epochs)` would
# iterate over c(1, 0) and hand NA to fread().
if (length(all_epochs) == 0L) {
  stop("No epoch result files found in: ", epoch_dir, call. = FALSE)
}

all_results <- vector(mode = "list", length = length(all_epochs))
for (i in seq_along(all_epochs)) {
  cur_res <- fread(all_epochs[i])
  # The epoch number is taken from the "Epoch_<digits>_" part of the file path.
  epoch <- gsub(".+Epoch_(\\d+)_.+", "\\1", all_epochs[i])
  cur_res$epoch <- as.integer(epoch)
  all_results[[i]] <- cur_res
}
all_results <- rbindlist(all_results)
18
19
theme_set(theme_bw())

# Frequency polygons of the predicted (red) and target (black) values,
# drawn from all rows of `all_results` (every epoch pooled together).
p <- ggplot(data = all_results) +
  geom_freqpoly(alpha = 0.7, aes(x = predicted), binwidth = 0.01, colour = "red") +
  geom_freqpoly(alpha = 0.7, aes(x = target), binwidth = 0.01, colour = "black") +
  xlab("Area Above Curve") + ylab("Frequency")
# Explicit print so the static plot is also drawn when the script is source()d.
print(p)

# Same plot animated over epochs; one frame time per value of `epoch`.
fit_anim <- p + transition_time(epoch) +
  labs(title = "Epoch: {frame_time}")

# recursive = TRUE creates "Plots/" if needed; showWarnings = FALSE keeps
# re-runs quiet when the directory already exists.
dir.create("Plots/Train_Fit_Animations/", recursive = TRUE, showWarnings = FALSE)
# Pass the animation explicitly instead of relying on last_animation(), which
# is only set if the animation happened to be auto-printed beforehand.
anim_save(
  "Plots/Train_Fit_Animations/gnndrug_exp_amsgrad_silu_standardization_nobatchnorm.gif",
  animation = fit_anim
)
32
33
# Drop the scalar `epoch` left in the workspace (if any) so that `epoch` below
# can only refer to the column of `all_results`.
# NOTE(review): presumably this is why the original removed it -- confirm.
if (exists("epoch", inherits = FALSE)) {
  rm(epoch)
}

# Predicted vs. observed scatter with the identity line (y = x) in red,
# animated over epochs.
p <- ggplot(data = all_results, aes(x = predicted, y = target)) +
  geom_point() +
  geom_abline(
    intercept = 0,
    slope = 1,
    color = "red",
    size = 2
  ) +
  xlab("Predicted") + ylab("Observed") +
  transition_time(epoch) +
  labs(title = "Epoch: {frame_time}")

# recursive = TRUE creates "Plots/" if needed; showWarnings = FALSE keeps
# re-runs quiet when the directory already exists.
dir.create("Plots/R2_Line/", recursive = TRUE, showWarnings = FALSE)
# `p` is only assigned, never rendered, so anim_save()'s default of
# last_animation() would write a previously rendered animation to this file.
# Passing `p` explicitly makes anim_save() render and save this plot.
anim_save(
  "Plots/R2_Line/gnndrug_exp_amsgrad_silu_standardization_nobatchnorm.gif",
  animation = p
)
48