a b/R/loss_functions_plot.R
1
# loss_functions_plot.R
2
require(ggplot2)
3
require(cowplot)
4
5
#' Plot a single per-prediction loss curve
#'
#' Draws `f` over the x-interval spanned by `range`, with faint red reference
#' lines at x = 0 and y = 0. Despite the "activation" in its name, the axes are
#' labelled "Prediction" and "Loss" and it is used for loss curves.
#'
#' @param f A vectorised function of one numeric argument (the prediction).
#' @param title Plot title (centred).
#' @param range Numeric vector whose min and max give the x-axis extent.
#' @param target The target value; it is added as an explicit x-axis break.
#' @param break_step Positive spacing of the regular x-axis breaks. The default
#'   0.25 reproduces the 0, 0.25, ..., 1 ticks for `range = c(0, 1)`.
#' @return A ggplot object.
plot_activation_function <- function(f, title, range, target,
                                      break_step = 0.25) {
  # Regular breaks anchored on multiples of `break_step` and spanning the whole
  # plotted range (previously fixed to 0..1 regardless of `range`), plus the
  # target. unique() avoids a duplicate tick when the target is on the grid.
  grid_from <- floor(min(range) / break_step) * break_step
  grid_to <- ceiling(max(range) / break_step) * break_step
  x_breaks <- sort(unique(c(seq(grid_from, grid_to, by = break_step), target)))

  ggplot(data.frame(x = range), mapping = aes(x = x)) +
    geom_hline(yintercept = 0, colour = "red", alpha = 1 / 4) +
    geom_vline(xintercept = 0, colour = "red", alpha = 1 / 4) +
    stat_function(fun = f, colour = "dodgerblue3") +
    ggtitle(title) +
    theme(
      text = element_text(size = 16),
      plot.title = element_text(hjust = 0.5),
      axis.text.x = element_text(angle = 45, hjust = 1)
    ) +
    xlab("Prediction") +
    ylab("Loss") +
    scale_x_continuous(breaks = x_breaks)
}
15
16
#' Overlay three loss curves, labelled MSE, MAE and RMSE, on a single plot
#'
#' `f1`, `f2` and `f3` are drawn in blue, red and orange. They are always
#' labelled "MSE", "MAE" and "RMSE" in a legend at the top, regardless of which
#' functions are actually passed.
#'
#' @param f1,f2,f3 Vectorised functions of one numeric argument (the
#'   prediction), shown as "MSE", "MAE" and "RMSE" respectively.
#' @param title Plot title (centred; may be "").
#' @param range Numeric vector whose min and max give the x-axis extent.
#' @param target The target value; it is added as an explicit x-axis break.
#' @param break_step Positive spacing of the regular x-axis breaks (default
#'   0.25).
#' @return A ggplot object.
plot_all_activation_functions <- function(f1, f2, f3, title, range, target,
                                          break_step = 0.25) {
  # Regular breaks spanning the whole plotted range, plus the target without
  # duplicating an existing tick. The old breaks were fixed to 0..1, which left
  # the (1, 2] half of a range = c(0, 2) axis without labels.
  grid_from <- floor(min(range) / break_step) * break_step
  grid_to <- ceiling(max(range) / break_step) * break_step
  x_breaks <- sort(unique(c(seq(grid_from, grid_to, by = break_step), target)))

  ggplot(data.frame(x = range), mapping = aes(x = x)) +
    geom_hline(yintercept = 0, colour = "red", alpha = 1 / 4) +
    geom_vline(xintercept = 0, colour = "red", alpha = 1 / 4) +
    # Constant colour mappings create the legend entries; the manual scale
    # below assigns the actual colours.
    stat_function(fun = f1, aes(colour = "MSE")) +
    stat_function(fun = f2, aes(colour = "MAE")) +
    stat_function(fun = f3, aes(colour = "RMSE")) +
    ggtitle(title) +
    theme(
      text = element_text(size = 18, face = "bold"),
      plot.title = element_text(hjust = 0.5),
      axis.text.x = element_text(angle = 45, hjust = 1),
      legend.position = "top"
    ) +
    xlab("Prediction") +
    ylab("Loss") +
    scale_x_continuous(breaks = x_breaks) +
    scale_colour_manual(
      "",
      breaks = c("MSE", "MAE", "RMSE"),
      values = c("MSE" = "dodgerblue3", "MAE" = "red3", "RMSE" = "orange1")
    )
}
35
36
# Loss functions ---------------------------------------------------------
#
# Each loss is evaluated element-wise on the predictions `yhats` against the
# target `y` (default 0). None of them averages over observations, so each
# plot shows a per-prediction loss curve.

set.seed(42)
# NOTE(review): nothing visible in this script reads `yhats` (the loss
# functions below take their own `yhats` argument); kept in case other code
# relies on it.
yhats <- runif(32, min = 0, max = 1)

# Half squared error: (y - yhat)^2 / 2.
mse <- function(yhats, y = 0) {
  (y - yhats)^2 / 2
}
p_mse <- plot_activation_function(mse, "MSE", c(0, 1), 0)

# Half absolute error: |y - yhat| / 2.
mae <- function(yhats, y = 0) {
  abs(y - yhats) / 2
}
p_mae <- plot_activation_function(mae, "MAE", c(0, 1), 0)

# Square root of the half squared error, i.e. |y - yhat| / sqrt(2).
rmse <- function(yhats, y = 0) {
  sqrt((y - yhats)^2 / 2)
}
p_rmse <- plot_activation_function(rmse, "RMSE", c(0, 1), 0)
46
47
# Output ----------------------------------------------------------------
#
# Saves the overlay plot and the side-by-side grid as PDFs under Plots/.

# ggsave() cannot write into a missing directory, so make sure it exists.
dir.create("Plots", showWarnings = FALSE, recursive = TRUE)

# Keep a handle on the overlay and pass it to ggsave() explicitly rather than
# relying on last_plot(). print() preserves the display that the original
# top-level call produced.
p_all <- plot_all_activation_functions(mse, mae, rmse, "", c(0, 2), 0)
print(p_all)
ggsave("Plots/All_Loss_Functions_Plot.pdf", plot = p_all)

p_grid <- cowplot::plot_grid(p_mse, p_mae, p_rmse, ncol = 3)
# theme(plot.margin = unit(c(1,31,3,3), "lines"))
ggsave("Plots/Loss_Function_Plots.pdf", plot = p_grid)