|
a |
|
b/R/train_fit_by_epoch.R |
|
|
1 |
# train_fit_by_epoch.R |
|
|
2 |
require(data.table) |
|
|
3 |
require(ggplot2) |
|
|
4 |
require(gganimate) |
|
|
5 |
# devtools::install_github("thomasp85/transformr") |
|
|
6 |
require(transformr) |
|
|
7 |
|
|
|
8 |
# Load the per-epoch prediction files for one hyperparameter configuration and
# stack them into a single data.table with an integer `epoch` column.
# The epoch number is parsed from each file name, which must contain
# "Epoch_<n>_".
results_dir <- "Data/EpochResults/{'drp_first_layer_size': 998, 'drp_last_layer_size': 21, 'gnn_out_channels': 243}/"

# Only list files that carry an epoch number. Anything else in the folder
# (notes, hidden files) would otherwise break fread() or yield an NA epoch.
all_epochs <- list.files(
  path = results_dir,
  pattern = "Epoch_\\d+_",
  full.names = TRUE
)
if (length(all_epochs) == 0L) {
  stop("No epoch result files found in: ", results_dir, call. = FALSE)
}

# Preallocate one slot per file, then bind once at the end.
# NOTE: the loop intentionally leaves `epoch` in the global environment;
# the plotting section further down removes it explicitly.
all_results <- vector(mode = "list", length = length(all_epochs))
for (i in seq_along(all_epochs)) {
  cur_res <- fread(all_epochs[i])
  epoch <- gsub(".+Epoch_(\\d+)_.+", "\\1", all_epochs[i])
  cur_res$epoch <- as.integer(epoch)
  all_results[[i]] <- cur_res
}
all_results <- rbindlist(all_results)
|
|
18 |
|
|
|
19 |
theme_set(theme_bw())

# Frequency polygons of predicted vs observed values, pooled over all epochs.
# Colours are mapped through a manual scale so the plot gets a legend
# (predicted = red, observed = black).
p <- ggplot(data = all_results) +
  geom_freqpoly(aes(x = predicted, colour = "Predicted"),
                alpha = 0.7, binwidth = 0.01) +
  geom_freqpoly(aes(x = target, colour = "Observed"),
                alpha = 0.7, binwidth = 0.01) +
  scale_colour_manual(name = NULL,
                      values = c(Predicted = "red", Observed = "black")) +
  xlab("Area Above Curve") + ylab("Frequency")
print(p)

# The same plot animated over epochs. frame_time is interpolated between the
# integer epochs, so it is rounded for the title.
train_fit_anim <- p + transition_time(epoch) +
  labs(title = "Epoch: {round(frame_time)}")

train_fit_dir <- file.path("Plots", "Train_Fit_Animations")
if (!dir.exists(train_fit_dir)) {
  dir.create(train_fit_dir, recursive = TRUE)
}
# Pass the animation explicitly. The default, last_animation(), is only set
# once an animation has been printed, which does not happen when the script
# is source()d.
anim_save(
  file.path(train_fit_dir,
            "gnndrug_exp_amsgrad_silu_standardization_nobatchnorm.gif"),
  animation = train_fit_anim
)
|
|
32 |
|
|
|
33 |
# The data-loading loop leaves a global character `epoch` behind. Drop it so
# only the `epoch` data column is in play below. The exists() guard avoids a
# warning if it is already gone.
if (exists("epoch", inherits = FALSE)) rm(epoch)

# Predicted vs observed scatter with the y = x identity line, animated over
# epochs. frame_time is interpolated between epochs, so it is rounded for the
# title.
# (ggplot2 >= 3.4.0 prefers `linewidth` over `size` for lines; `size` still
# works with a deprecation warning.)
p <- ggplot(data = all_results, aes(x = predicted, y = target)) +
  geom_point() +
  geom_abline(intercept = 0,
              slope = 1,
              color = "red",
              size = 2) +
  xlab("Predicted") + ylab("Observed") +
  transition_time(epoch) +
  labs(title = "Epoch: {round(frame_time)}")

r2_dir <- file.path("Plots", "R2_Line")
if (!dir.exists(r2_dir)) {
  dir.create(r2_dir, recursive = TRUE)
}
# `p` is only assigned, never printed, so the default last_animation() would
# hold whatever animation was rendered last (or nothing), not this one.
# Pass it explicitly.
anim_save(
  file.path(r2_dir,
            "gnndrug_exp_amsgrad_silu_standardization_nobatchnorm.gif"),
  animation = p
)
|
|
48 |
|