--- a +++ b/R/activation_functions_plot.R @@ -0,0 +1,49 @@ +# activation_functions_plot.R +require(ggplot2) +require(cowplot) + +plot_activation_function <- function(f, title, range){ + ggplot(data.frame(x=range), mapping=aes(x=x)) + + geom_hline(yintercept=0, color='red', alpha=1/4) + + geom_vline(xintercept=0, color='red', alpha=1/4) + + stat_function(fun=f, colour = "dodgerblue3") + + ggtitle(title) + + scale_x_continuous(name='x') + + scale_y_continuous(name='f(x)') + + theme(plot.title = element_text(hjust = 0.5)) +} + + + +relu <- function(x) {ifelse(x < 0 , 0, x )} +p_relu <- plot_activation_function(relu, 'ReLU', c(-8,8)) + +lrelu <- function(x){ ifelse(x < 0 , 0.01 *x , x )} +p_lrelu <- plot_activation_function(lrelu, 'LReLU', c(-8,8)) + +softplus <- function(x){ log(1 + exp(x))} +p_softplus <- plot_activation_function(softplus, 'SoftPlus', c(-8,8)) + +prelu <- function(x, a=0.5) {ifelse(x < 0 , a * x , x )} +p_prelu <- plot_activation_function(prelu, 'PReLU (a=0.5)', c(-8,8)) + +elu <- function(x, a=0.7) {ifelse(x < 0 , a*(exp(x)-1) , x )} +p_elu <- plot_activation_function(elu, 'ELU (a=0.5)', c(-8,8)) + +selu <- function(x, a=1.6732632423543772848170429916717, l=1.0507009873554804934193349852946) {ifelse(x < 0 , l*(a*(exp(x)-1)) , x )} +p_selu <- plot_activation_function(selu, 'SELU (a=1.67, l=1.05)', c(-8,8)) + +sigmoid = function(x) { + 1 / (1 + exp(-x)) +} + +swish <- function(x) {x*sigmoid(x)} +p_swish <- plot_activation_function(swish, 'Swish/SiLU', c(-8,8)) + + +gelu <- function(x) {x*pnorm(x)} +p_gelu <- plot_activation_function(gelu, 'GELU', c(-8,8)) + +p_grid <- cowplot::plot_grid(p_relu, p_lrelu, p_prelu, p_elu, p_selu, p_softplus, p_swish, p_gelu, ncol = 4) + + theme(plot.margin = unit(c(3,3,3,3), "lines")) +ggsave("Plots/ReLU_Variant_Plots.pdf", plot = p_grid)