|
a |
|
b/R/loss_functions_plot.R |
|
|
1 |
# loss_functions_plot.R |
|
|
2 |
require(ggplot2) |
|
|
3 |
require(cowplot) |
|
|
4 |
|
|
|
5 |
# Plot a single loss curve over a range of predictions.
#
# Despite the historical name, this is used to plot *loss* functions
# (x axis = prediction, y axis = loss), not neural-network activations.
#
# @param f      Vectorised function of the prediction, passed to
#               ggplot2::stat_function() as `fun`.
# @param title  Plot title (centred).
# @param range  Numeric vector; its values form the data for the x axis, so
#               min/max set the plotted extent.
# @param target Target value; always included as an x-axis break.
# @param step   Spacing of the regular x-axis breaks spanning `range`.
#               The default 0.25 reproduces the original 0, 0.25, ..., 1
#               breaks when `range` is c(0, 1).
# @return A ggplot object (not printed).
plot_activation_function <- function(f, title, range, target, step = 0.25) {
  # Breaks: 0 (where the reference lines are drawn), the target, and a
  # regular grid over the whole range. unique() prevents 0 from appearing
  # twice when target == 0.
  x_breaks <- unique(c(0, target, seq(min(range), max(range), by = step)))

  ggplot(data.frame(x = range), mapping = aes(x = x)) +
    geom_hline(yintercept = 0, colour = "red", alpha = 1 / 4) +
    geom_vline(xintercept = 0, colour = "red", alpha = 1 / 4) +
    stat_function(fun = f, colour = "dodgerblue3") +
    ggtitle(title) +
    theme(
      text = element_text(size = 16),
      plot.title = element_text(hjust = 0.5),
      axis.text.x = element_text(angle = 45, hjust = 1)
    ) +
    xlab("Prediction") +
    ylab("Loss") +
    scale_x_continuous(breaks = x_breaks)
}
|
|
15 |
|
|
|
16 |
# Overlay three loss curves on one plot with a shared, top-positioned legend.
#
# Despite the historical name, this is used to plot *loss* functions
# (x axis = prediction, y axis = loss). Legend labels are fixed:
# f1 is labelled "MSE", f2 "MAE" and f3 "RMSE".
#
# @param f1,f2,f3 Vectorised functions of the prediction, each passed to
#                 ggplot2::stat_function() as `fun`.
# @param title    Plot title (centred; may be "").
# @param range    Numeric vector; its values form the data for the x axis,
#                 so min/max set the plotted extent.
# @param target   Target value; always included as an x-axis break.
# @param step     Spacing of the regular x-axis breaks spanning `range`.
#                 The default 0.25 matches the original breaks on c(0, 1) and
#                 extends them when the range is wider (e.g. c(0, 2)).
# @return A ggplot object (not printed).
plot_all_activation_functions <- function(f1, f2, f3, title, range, target,
                                          step = 0.25) {
  # Derive breaks from `range` so wider ranges get ticks too; unique() keeps
  # 0 from appearing twice when target == 0.
  x_breaks <- unique(c(0, target, seq(min(range), max(range), by = step)))

  ggplot(data.frame(x = range), mapping = aes(x = x)) +
    geom_hline(yintercept = 0, colour = "red", alpha = 1 / 4) +
    geom_vline(xintercept = 0, colour = "red", alpha = 1 / 4) +
    # Mapping colour to a constant string gives each curve a legend entry.
    stat_function(fun = f1, mapping = aes(colour = "MSE")) +
    stat_function(fun = f2, mapping = aes(colour = "MAE")) +
    stat_function(fun = f3, mapping = aes(colour = "RMSE")) +
    ggtitle(title) +
    theme(
      text = element_text(size = 18, face = "bold"),
      plot.title = element_text(hjust = 0.5),
      axis.text.x = element_text(angle = 45, hjust = 1),
      legend.position = "top"
    ) +
    xlab("Prediction") +
    ylab("Loss") +
    scale_x_continuous(breaks = x_breaks) +
    scale_colour_manual(
      name = "",
      breaks = c("MSE", "MAE", "RMSE"),
      values = c(MSE = "dodgerblue3", MAE = "red3", RMSE = "orange1")
    )
}
|
|
35 |
|
|
|
36 |
# Seed the RNG so the random draws below are reproducible.
set.seed(seed = 42)

# 32 uniform draws on [0, 1], intended as sample predictions.
# NOTE(review): `yhats` is not referenced elsewhere in this file as shown;
# confirm nothing else relies on it before removing.
yhats <- stats::runif(n = 32L, min = 0, max = 1)
|
|
38 |
# Element-wise loss functions of a prediction `yhats` against a target `y`
# (default 0). Each is vectorised over `yhats`, as
# ggplot2::stat_function() requires, and is plotted individually on [0, 1].

# Half squared error: (y - yhat)^2 / 2.
mse <- function(yhats, y = 0) {
  (y - yhats)^2 / 2
}
p_mse <- plot_activation_function(mse, "MSE", c(0, 1), 0)

# Half absolute error: |y - yhat| / 2.
# NOTE(review): the /2 factor mirrors mse() but is unusual for MAE —
# confirm it is intended.
mae <- function(yhats, y = 0) {
  abs(y - yhats) / 2
}
p_mae <- plot_activation_function(mae, "MAE", c(0, 1), 0)

# Square root of the half squared error, i.e. sqrt(mse(yhats, y)).
rmse <- function(yhats, y = 0) {
  sqrt((y - yhats)^2 / 2)
}
p_rmse <- plot_activation_function(rmse, "RMSE", c(0, 1), 0)
|
|
46 |
|
|
|
47 |
# Overlay all three losses on the wider prediction range [0, 2] and save as
# PDF. The plot is passed to ggsave() explicitly rather than relying on
# last_plot(). The output directory is created first because ggsave()
# errors when the target directory does not exist.
p_all <- plot_all_activation_functions(mse, mae, rmse, "", c(0, 2), 0)
dir.create("Plots", showWarnings = FALSE, recursive = TRUE)
ggsave("Plots/All_Loss_Functions_Plot.pdf", plot = p_all)
|
|
49 |
|
|
|
50 |
# Arrange the three individual loss plots side by side and save as PDF.
# The output directory is created first because ggsave() errors when the
# target directory does not exist.
p_grid <- cowplot::plot_grid(p_mse, p_mae, p_rmse, ncol = 3)
dir.create("Plots", showWarnings = FALSE, recursive = TRUE)
ggsave("Plots/Loss_Function_Plots.pdf", plot = p_grid)