--- a +++ b/R/callbacks.R @@ -0,0 +1,1077 @@ +#' Create model card +#' +#' Log information about model, hyperparameters, generator options, training data, scores etc +#' +#' @param model_card_path Directory for model card logs. +#' @param run_name Name of training run. +#' @param argumentList List of training arguments. +#' @examplesIf reticulate::py_module_available("tensorflow") +#' model_card_cb <- function(model_card_path = NULL, run_name, argumentList) +#' mc <- model_card_cb(model_card_path = tempdir(), run_name = 'run_1', +#' argumentList = list(learning_rate = 0.01)) +#' +#' @returns Keras callback writing model cards every epoch. +#' @export +model_card_cb <- function(model_card_path = NULL, run_name, argumentList) { + + model_card_cb_py_class <- reticulate::PyClass("model_card_cb", + inherit = tensorflow::tf$keras$callbacks$Callback, + list( + + `__init__` = function(self, model_card_path, run_name) { + self$model_card_path <- model_card_path + self$start_time <- Sys.time() + self$mc_dir <- file.path(model_card_path, run_name) + self$param_list <- list() + self$argumentList <- argumentList + NULL + }, + + # collect all data + on_train_begin = function(self, logs) { + + if (!dir.exists(self$mc_dir)) { + dir.create(self$mc_dir) + } else { + #stop("Directory already exists. Change run_name") + } + + self$param_list <- self$model$hparam + self$param_list$train_model_args <- argumentList + for (n in names(self$param_list$train_model_args)) { + self$param_list$train_model_args[[n]] <- eval(self$param_list$train_model_args[[n]]) + } + self$param_list$train_model_args[["model"]] <- NULL + self$param_list$model_summary <- summary(self$model) + self$param_list$training_start_time <- format(self$start_time, "%a %b %d %X %Y") + + gpu_info <- tensorflow::tf$config$list_physical_devices('GPU') + self$param_list$gpu_info[["number GPUs"]] <- length(gpu_info) + if (length(gpu_info) > 0) { + for (i in 1:length(gpu_info)) { + self$param_list$gpu_info[[paste0("GPU", i)]] <- + tensorflow::tf$config$experimental$get_device_details( + gpu_info[[i]] + ) + } + } + + saveRDS(self$param_list, paste0(self$mc_dir, "/epoch_0_param_list.rds")) + + }, + + # update training scores + on_epoch_end = function(self, epoch, logs) { + time_passed <- as.double(difftime(Sys.time(), self$start_time, units = "secs")) + self$param_list[["training_time"]] <- time_passed + + if (epoch == 0) { + m <- unlist(logs) + m <- c(m, epoch, time_passed) %>% matrix(nrow = 1) %>% as.data.frame() + names(m) <- c(names(logs), "processing_step", "time") + self$param_list[["logs"]] <- m + } else { + m <- unlist(logs) + m <- c(m, epoch, time_passed) %>% matrix(nrow = 1) %>% as.data.frame() + names(m) <- c(names(logs), "processing_step", "time") + m <- rbind(self$param_list[["logs"]], m) + self$param_list[["logs"]] <- reticulate::r_to_py(m) + } + + saveRDS(self$param_list, paste0(self$mc_dir, "/epoch_", epoch + 1, "_param_list.rds")) + } + + )) + + model_card_cb_py_class(model_card_path = model_card_path, + run_name = run_name) + +} + + + +#' Stop training callback +#' +#' Stop training after specified time. +#' +#' @param stop_time Time in seconds after which to stop training. +#' @examplesIf reticulate::py_module_available("tensorflow") +#' est <- early_stopping_time_cb(stop_time = 60) +#' +#' @returns A Keras callback that stops training after specified time. +#' @export +early_stopping_time_cb <- function(stop_time = NULL) { + + early_stopping_time_cb_py_class <- reticulate::PyClass("early_stopping_time_cb", + inherit = tensorflow::tf$keras$callbacks$Callback, + list( + + `__init__` = function(self, stop_time) { + self$start_time <- Sys.time() + self$stop_time <- stop_time + NULL + }, + + on_batch_end = function(self, epoch, logs) { + time_passed <- as.double(difftime(Sys.time(), self$start_time, units = "secs")) + if (time_passed > self$stop_time) { + self$model$stop_training <- TRUE + } + } + + )) + + early_stopping_time_cb_py_class(stop_time = stop_time) + +} + +#' Early stopping callback +#' +#' @param early_stopping_time Time in seconds after which to stop training. +#' @param early_stopping_patience Stop training if val_loss does not improve for \code{early_stopping_patience}. +#' @param by_time Whether to use time or patience as metric. +#' @returns Keras callback; stop training after specified time. +#' @noRd +early_stopping_cb <- function(early_stopping_patience = 0, early_stopping_time, by_time = TRUE) { + + if (by_time) { + early_stopping_time_cb(stop_time = early_stopping_time) + } else { + keras::callback_early_stopping(patience = early_stopping_patience) + } +} + +#' Log callback +#' +#' @param path_log Path to output directory. +#' @param run_name Name of output file is run_name + ".csv". +#' @returns Keras callback, writes epoch scores to csv file. +#' @noRd +log_cb <- function(path_log, run_name) { + keras::callback_csv_logger( + paste0(path_log, "/", run_name, ".csv"), + separator = ";", + append = TRUE) +} + +#' Learning_rate callback +#' +#' @inheritParams train_model +#' @returns Keras callback, reduces learning rate. +#' @noRd +reduce_lr_cb <- function(patience, + cooldown, + lr_plateau_factor, + monitor = "val_acc") { + keras::callback_reduce_lr_on_plateau( + monitor = monitor, + factor = lr_plateau_factor, + patience = patience, + cooldown = cooldown) +} + +#' Checkpoint callback +#' +#' @inheritParams train_model +#' @returns Keras callback, store model checkpoint. +#' @noRd +checkpoint_cb <- function(filepath_checkpoints, + save_weights_only, + save_best_only, + save_freq, + monitor = "val_loss") { + + if (is.logical(save_best_only)) { + if (save_best_only) { + warning("save_best_only should not be boolean variabel, but list or NULL. Using val_loss as monitor.") + save_best_only <- list(monitor = "val_loss") + } else { + warning("save_best_only should not be boolean variabel, but list or NULL.") + save_best_only <- NULL + } + } + + if (is.null(save_best_only) | !is.null(save_best_only$monitor)) { + + keras::callback_model_checkpoint(filepath = filepath_checkpoints, + save_weights_only = save_weights_only, + save_best_only = !is.null(save_best_only), + verbose = 1, + save_freq = "epoch", + monitor = monitor) + + } else { + + cp_cb <- reticulate::PyClass("cp_cb", + inherit = tensorflow::tf$keras$callbacks$Callback, + list( + + `__init__` = function(self, filepath_checkpoints, save_freq, save_weights_only) { + self$filepath_checkpoints <- filepath_checkpoints + self$save_freq <- save_freq + self$save_weights_only <- save_weights_only + NULL + }, + + on_epoch_end = function(self, epoch, logs) { + if ((epoch + 1) %% self$save_freq == 0) { + + formatted_path <- gsub("\\{epoch:03d\\}", sprintf("%03d", epoch + 1), self$filepath_checkpoints) + formatted_path <- gsub("\\{val_loss:.2f\\}", sprintf("%.2f", logs$val_loss), formatted_path) + formatted_path <- gsub("\\{val_acc:.3f\\}", sprintf("%.3f", logs$val_acc), formatted_path) + print(formatted_path) + if (self$save_weights_only) { + keras::save_model_hdf5(self$model, formatted_path) + } else { + keras::save_model_weights_hdf5(self$model, formatted_path) + } + } + } + + )) + + return(cp_cb(filepath_checkpoints = filepath_checkpoints, + save_freq = save_best_only$save_freq, + save_weights_only = save_weights_only)) + + } + +} + +#' Non model hyperparameter callback +#' +#' Get hyperparameters excluding model parameters. +#' +#' @inheritParams train_model +#' @returns Keras callback, track model hyperparameters. +#' @noRd +hyper_param_model_outside_cb <- function(path_tensorboard, run_name, wavenet_format, cnn_format, model, vocabulary, path, reverse_complement, + vocabulary_label, maxlen, epochs, max_queue_size, lr_plateau_factor, batch_size, + patience, cooldown, steps_per_epoch, step, shuffle_file_order) { + + train_hparams <- list( + run_name = run_name, + vocabulary = paste(vocabulary, collapse = ","), + path = paste(unlist(path), collapse = ", "), + reverse_complement = paste(reverse_complement), + vocabulary_label = paste(vocabulary_label, collapse = ", "), + epochs = epochs, + max_queue_size = max_queue_size, + lr_plateau_factor = lr_plateau_factor, + batch_size = batch_size, + patience = patience, + cooldown = cooldown, + steps_per_epoch = steps_per_epoch, + step = step, + shuffle_file_order = shuffle_file_order + ) + #hparams$update(model$hparam) + model_hparams <- vector("list") + for (i in names(model$hparam)) { + model_hparams[[i]] <- model$hparam[[i]] + } + + hparams_R <- c(train_hparams, model_hparams) + + keep_entry_index <- rep(TRUE, length(hparams_R)) + for (i in 1:length(hparams_R)) { + + if (length(hparams_R[[i]]) == 0) { + keep_entry_index[i] <- FALSE + } + + if (length(hparams_R[[i]]) > 1) { + hparams_R[[i]] <- paste(hparams_R[[i]], collapse = " ") + } + } + hparams_R <- hparams_R[keep_entry_index] + + hparams <- reticulate::dict(hparams_R) + hp <- reticulate::import("tensorboard.plugins.hparams.api") + hp$KerasCallback(file.path(path_tensorboard, run_name), hparams, trial_id = run_name) +} + +#' Model hyperparameter callback +#' +#' Get model hyperparameters. +#' +#' @inheritParams train_model +#' @returns Keras callback, track training hyperparameters. +#' @noRd +hyper_param_with_model_cb <- function(default_arguments, model, path_tensorboard, run_name, train_type, path, train_val_ratio, batch_size, + epochs, max_queue_size, lr_plateau_factor, + patience, cooldown, steps_per_epoch, step, shuffle_file_order, initial_epoch, vocabulary, learning_rate, + shuffle_input, vocabulary_label, solver, file_limit, reverse_complement, wavenet_format, cnn_format) { + + model_hparam <- vector("list") + model_hparam_names <- vector("list") + for (i in 1:length(default_arguments)) { + if (is.null(default_arguments[[i]])) { + model_hparam[i] <- "NULL" + } else { + model_hparam[i] <- default_arguments[i] + } + } + names(model_hparam) <- names(default_arguments) + # hparam from train_model + learning_rate <- keras::k_eval(model$optimizer$lr) + solver <- stringr::str_to_lower(model$optimizer$get_config()["name"]) + + train_hparam_names <- c("train_type", "path", "train_val_ratio", "run_name", "batch_size", "epochs", "max_queue_size", "lr_plateau_factor", + "patience", "cooldown", "steps_per_epoch", "step", "shuffle_file_order", "initial_epoch", "vocabulary", "learning_rate", + "shuffle_input", "vocabulary_label", "solver", "file_limit", "reverse_complement", "wavenet_format", "cnn_format") + train_hparam <- vector("list") + for (i in 1:length(train_hparam_names)) { + if (is.null(eval(parse(text=train_hparam_names[i])))) { + train_hparam[[i]] <- "NULL" + } else if (length(eval(parse(text=train_hparam_names[i])) > 1)) { + train_hparam[[i]] <- toString(eval(parse(text=train_hparam_names[i]))) + if (length(train_hparam[[i]]) > 1) { + train_hparam[[i]] <- paste(train_hparam[[i]], collapse = " ") + } + } else { + train_hparam[[i]] <- eval(parse(text=train_hparam_names[i])) + if (length(train_hparam[[i]]) > 1) { + train_hparam[[i]] <- paste(train_hparam[[i]], collapse = " ") + } + } + } + names(train_hparam) <- train_hparam_names + hparams_R <- c(train_hparam, model_hparam) + hparams <- reticulate::dict(hparams_R) + hp <- reticulate::import("tensorboard.plugins.hparams.api") + return(hp$KerasCallback(file.path(path_tensorboard, run_name), hparams, trial_id = run_name)) +} + +#' Tensorboard callback +#' +#' @inheritParams train_model +#' @returns Keras callback, write tensorboard logs. +#' @noRd +tensorboard_cb <- function(path_tensorboard, run_name) { + keras::callback_tensorboard(file.path(path_tensorboard, run_name), + write_graph = TRUE, + histogram_freq = 1, + write_images = TRUE, + write_grads = TRUE) +} + +#' Function arguments callback +#' +#' Print train_model call in text field of tensorboard. +#' +#' @inheritParams train_model +#' @param argumentList List of function arguments. +#' @returns Keras callback, track arguments of `train_model` function. +#' @noRd +function_args_cb <- function(argumentList, path_tensorboard, run_name) { + + argAsChar <- as.character(argumentList) + argText <- vector("character") + if (length(argumentList$path) > 1) { + + argsInQuotes <- c("path_checkpoint", "run_name", "solver", "format", "output_format", + "path_tensorboard", "path_file_log", "train_type", "ambiguous_nuc", "added_label_path", "added_label_names", + "train_val_split_csv", "target_from_csv") + } else { + argsInQuotes <- c("path", "path_val", "path_checkpoint", "run_name", "solver", "output_format", + "path_tensorboard", "path_file_log", "train_type", "ambiguous_nuc", "format", "added_label_path", "added_label_names", + "train_val_split_csv", "target_from_csv") + } + argText[1] <- "train_model(" + for (i in 2:(length(argumentList) - 1)) { + arg <- argAsChar[[i]] + if (names(argumentList)[i] %in% argsInQuotes) { + if (arg == "NULL") { + argText[i] <- paste0(names(argumentList)[i], " = ", arg, ",") + } else { + argText[i] <- paste0(names(argumentList)[i], " = ", '\"', arg, '\"', ",") + } + } else { + argText[i] <- paste0(names(argumentList)[i], " = ", arg, ",") + } + } + i <- length(argumentList) + if (names(argumentList)[i] %in% argsInQuotes) { + if (arg == "NULL") { + argText[i] <- paste0(names(argumentList)[i], " = ", argAsChar[[i]], ")") + } else { + argText[i] <- paste0(names(argumentList)[i], " = ", '\"', argAsChar[[i]], '\"', ")") + } + } else { + argText[i] <- paste0(names(argumentList)[i], " = ", argAsChar[[i]], ")") + } + + # write function arguments as text in tensorboard + trainArguments <- keras::callback_lambda( + on_train_begin = function(logs) { + file.writer <- tensorflow::tf$summary$create_file_writer(file.path(path_tensorboard, run_name)) + file.writer$set_as_default() + tensorflow::tf$summary$text(name="Arguments", data = argText, step = 0L) + file.writer$flush() + } + ) + trainArguments +} + +#' Tensorboard callback wrapper +#' +#' @inheritParams train_model +#' @returns Keras callback, wrapper for all callbacks involving tensorboard. +#' @noRd +tensorboard_complete_cb <- function(default_arguments, model, path_tensorboard, run_name, train_type, path, train_val_ratio, batch_size, + epochs, max_queue_size, lr_plateau_factor, patience, cooldown, steps_per_epoch, step, shuffle_file_order, initial_epoch, vocabulary, learning_rate, + shuffle_input, vocabulary_label, solver, file_limit, reverse_complement, wavenet_format, cnn_format, create_model_function, vocabulary_size, gen_cb, + argumentList, maxlen, labelGen, labelByFolder, vocabulary_label_size, tb_images = FALSE, stateful, target_middle, num_train_files, path_file_log, + proportion_per_seq, skip_amb_nuc, max_samples, proportion_entries, train_with_gen, count_files = TRUE) { + l <- vector("list") + + l[[1]] <- hyper_param_model_outside_cb(path_tensorboard = path_tensorboard, run_name = run_name, wavenet_format = wavenet_format, cnn_format = cnn_format, model = model, + vocabulary = vocabulary, path = path, reverse_complement = reverse_complement, vocabulary_label = vocabulary_label, + maxlen = maxlen, epochs = epochs, max_queue_size = max_queue_size, lr_plateau_factor = lr_plateau_factor, + batch_size = batch_size, patience = patience, cooldown = cooldown, steps_per_epoch = steps_per_epoch, + step = step, shuffle_file_order = shuffle_file_order) + + l[[2]] <- tensorboard_cb(path_tensorboard = path_tensorboard, run_name = run_name) + l[[3]] <- function_args_cb(argumentList = argumentList, path_tensorboard = path_tensorboard, run_name = run_name) + + if (train_with_gen & count_files) { + + proportion_training_files_cb <- reticulate::PyClass("proportion_training_files_cb", + inherit = tensorflow::tf$keras$callbacks$Callback, + list( + + `__init__` = function(self, num_train_files, path_file_log, path_tensorboard, run_name, vocabulary_label, + path, train_type, start_index, proportion_per_seq, max_samples, step, + proportion_entries) { + self$num_train_files <- num_train_files + self$path_file_log <- path_file_log + self$path_tensorboard <- path_tensorboard + self$run_name <- run_name + self$vocabulary_label <- vocabulary_label + self$path <- path + self$train_type <- train_type + self$proportion_per_seq <- proportion_per_seq + self$max_samples <- max_samples + self$step <- step + self$start_index <- 1 + self$first_epoch <- TRUE + self$description <- "" + self$proportion_entries <- proportion_entries + NULL + }, + + on_epoch_end = function(self, epoch, logs) { + if (is.null(self$proportion_entries)) self$proportion_entries <- 1 + file.writer <- tensorflow::tf$summary$create_file_writer(file.path(self$path_tensorboard, self$run_name)) + file.writer$set_as_default() + files_used <- utils::read.csv(self$path_file_log, stringsAsFactors = FALSE, header = FALSE) + if (self$train_type == "label_folder") { + if (self$first_epoch) { + if (length(self$step) == 1) self$step <- rep(self$step, length(vocabulary_label)) + if (length(self$proportion_per_seq) == 1) { + self$proportion_per_seq <- rep(self$proportion_per_seq, length(self$vocabulary_label)) + } + if (length(max_samples) == 1) self$max_samples <- rep(max_samples, length(vocabulary_label)) + + for (i in 1:length(self$vocabulary_label)) { + if (is.null(self$max_samples)) { + self$description[i] <- paste0("Using step size ", self$step[i], ", proportion_entries ", + self$proportion_entries * 100, "% and ", + ifelse(is.null(self$proportion_per_seq[i]), 1, + self$proportion_per_seq[i]) * 100, "% per sequence") + } else { + self$description[i] <- paste0("Using step size ", self$step[i], ", ", + ifelse(is.null(self$proportion_per_seq[i]), 1, + self$proportion_per_seq[i]) * 100, "% per sequence, maximum of ", + self$max_samples[i], " samples per file and proportion_entries ", + self$proportion_entries * 100, "%") + } + } + self$first_epoch <- FALSE + } + + for (i in 1:length(self$vocabulary_label)) { + files_of_class <- sum(stringr::str_detect( + files_used[ , 1], paste(unlist(self$path[[i]]), collapse = "|") + )) + files_percentage <- 100 * files_of_class/self$num_train_files[i] + tensorflow::tf$summary$scalar(name = paste0("training files seen (%): '", + self$vocabulary_label[i], "'"), data = files_percentage, step = epoch, + description = self$description[i]) + } + } else { + files_percentage <- 100 * nrow(files_used)/self$num_train_files + if (is.null(self$max_samples)) { + description <- paste0("Using step size ", step, + ", proportion_entries ", self$proportion_entries * 100, "% and ", + ifelse(is.null(self$proportion_per_seq), 1, + self$proportion_per_seq) * 100, "% per sequence") + } else { + description <- paste0("Using step size ", step, ", ", + ifelse(is.null(self$proportion_per_seq), 1, + self$proportion_per_seq) * 100, "% per sequence, maximum of ", + self$max_samples, " samples per file and proportion_entries ", + self$proportion_entries * 100, "%") + + } + if (self$train_type == "label_rds") { + description <- paste0("Using step size ", + ifelse(is.null(self$proportion_per_seq), 1, + self$proportion_per_seq) * 100, "% per sequence and maximum of ", + self$max_samples, " samples per file.") + } + tensorflow::tf$summary$scalar(name = paste("training files seen (%)"), data = files_percentage, step = epoch, + description = description) + } + + file.writer$flush() + } + + )) + + + l[[4]] <- proportion_training_files_cb(num_train_files = num_train_files, path_file_log = path_file_log, path_tensorboard = path_tensorboard, run_name = run_name, + vocabulary_label = vocabulary_label, path = path, train_type = train_type, proportion_per_seq = proportion_per_seq, + max_samples = max_samples, step = step, proportion_entries = proportion_entries) + #names(l) <- c("hyper_param_model_outside", "tensorboard", "function_args","proportion_training_files") + } else { + #names(l) <- c("hyper_param_model_outside", "tensorboard", "function_args") + } + return(l) +} + +#' Reset states callback +#' +#' Reset states at start/end of validation and whenever file changes. Can be used for stateful LSTM. +#' +#' @param path_file_log Path to log of training files. +#' @param path_file_logVal Path to log of validation files. +#' @examplesIf reticulate::py_module_available("tensorflow") +#' rs <- reset_states_cb(path_file_log = tempfile(), path_file_logVal = tempfile()) +#' +#' @returns A keras callback that resets states of LSTM layers. +#' @export +reset_states_cb <- function(path_file_log, path_file_logVal) { + + reset_states_cb_py_class <- reticulate::PyClass("reset_states_cb", + inherit = tensorflow::tf$keras$callbacks$Callback, + list( + + `__init__` = function(self, path_file_log, path_file_logVal) { + self$path_file_log <- path_file_log + self$path_file_logVal <- path_file_logVal + self$num_files_old <- 0 + self$num_files_new <- 0 + self$num_files_old_val <- 0 + self$num_files_new_val <- 0 + NULL + }, + + on_test_begin = function(self, epoch, logs) { + self$model$reset_states() + }, + + on_test_end = function(self, epoch, logs) { + self$model$reset_states() + }, + + on_train_batch_begin = function(self, batch, logs) { + files_used <- readLines(self$path_file_log) + self$num_files_new <- length(files_used) + if (self$num_files_new > self$num_files_old) { + self$model$reset_states() + self$num_files_old <- self$num_files_new + } + }, + + on_test_batch_begin = function(self, batch, logs) { + files_used <- readLines(self$path_file_logVal) + self$num_files_new_val <- length(files_used) + if (self$num_files_new_val > self$num_files_old_val) { + self$model$reset_states() + self$num_files_old_val <- self$num_files_new_val + } + } + + )) + + reset_states_cb_py_class(path_file_log = path_file_log, path_file_logVal = path_file_logVal) +} + +#' Validation after training callback +#' +#' Do validation only once at end of training. +#' +#' @param gen.val Validation generator +#' @param validation_steps Number of validation steps. +#' @examplesIf reticulate::py_module_available("tensorflow") +#' maxlen <- 20 +#' model <- create_model_lstm_cnn(layer_lstm = 8, maxlen = maxlen) +#' gen <- get_generator(train_type = 'dummy_gen', model = model, batch_size = 4, maxlen = maxlen) +#' vat <- validation_after_training_cb(gen.val = gen, validation_steps = 10) +#' +#' @returns Keras callback, apply validation only after training. +#' @export +validation_after_training_cb <- function(gen.val, validation_steps) { + + validation_after_training_cb_py_class <- reticulate::PyClass("validation_after_training_cb", + inherit = tensorflow::tf$keras$callbacks$Callback, + list( + + `__init__` = function(self, gen.val, validation_steps) { + self$gen.val <- gen.val + self$validation_steps <- validation_steps + NULL + }, + + + on_train_end = function(self, logs = list()) { + validation_eval <- keras::evaluate_generator( + object = self$model, + generator = gen.val, + steps = self$validation_steps, + max_queue_size = 10, + workers = 1, + callbacks = NULL + ) + self$model$val_loss <- validation_eval[["loss"]] + self$model$val_acc <- validation_eval[["acc"]] + } + + )) + + validation_after_training_cb_py_class(gen.val = gen.val, validation_steps = validation_steps) + +} + +#' Confusion matrix callback. +#' +#' Create a confusion matrix to display under tensorboard images. +#' +#' @inheritParams train_model +#' @param confMatLabels Names of classes. +#' @param cm_dir Directory that contains confusion matrix files. +#' @examplesIf reticulate::py_module_available("tensorflow") +#' cm <- conf_matrix_cb(path_tensorboard = tempfile(), run_name = 'run_1', +#' confMatLabels = c('label_1', 'label_2'), cm_dir = tempfile()) +#' +#' @returns Keras callback, plot confusion matrix in tensorboard. +#' @export +conf_matrix_cb <- function(path_tensorboard, run_name, confMatLabels, cm_dir) { + + conf_matrix_cb_py_class <- reticulate::PyClass("conf_matrix_cb", + inherit = tensorflow::tf$keras$callbacks$Callback, + list( + + `__init__` = function(self, cm_dir, path_tensorboard, run_name, confMatLabels, graphics = "png") { + self$cm_dir <- cm_dir + self$path_tensorboard <- path_tensorboard + self$run_name <- run_name + self$plot_path_train <- tempfile(pattern = "", fileext = paste0(".", graphics)) + self$plot_path_val <- tempfile(pattern = "", fileext = paste0(".", graphics)) + self$confMatLabels <- confMatLabels + self$epoch <- 0 + self$train_images <- NULL + self$val_images <- NULL + self$graphics <- graphics + self$epoch <- 0 + self$text_size <- NULL + self$round_dig <- 3 + if (length(confMatLabels) < 8) { + self$text_size <- (10 - (max(nchar(confMatLabels)) * 0.15)) * (0.95^length(confMatLabels)) + } + self$cm_display_percentage <- TRUE + NULL + }, + + on_epoch_begin = function(self, epoch, logs) { + #suppressMessages(library(yardstick)) + if (epoch > 0) { + + cm_train <- readRDS(file.path(self$cm_dir, paste0("cm_train_", epoch-1, ".rds"))) + cm_val <- readRDS(file.path(self$cm_dir, paste0("cm_val_", epoch-1, ".rds"))) + if (self$cm_display_percentage) { + cm_train <- cm_perc(cm_train, self$round_dig) + cm_val <- cm_perc(cm_val, self$round_dig) + } + cm_train <- create_conf_mat_obj(cm_train, self$confMatLabels) + cm_val <- create_conf_mat_obj(cm_val, self$confMatLabels) + + + suppressMessages( + cm_plot_train <- ggplot2::autoplot(cm_train, type = "heatmap") + + ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1") + + ggplot2::theme(axis.text.x = + ggplot2::element_text(angle=90,hjust=1, size = self$text_size)) + + ggplot2::theme(axis.text.y = + ggplot2::element_text(size = self$text_size)) + ) + + suppressMessages( + cm_plot_val <- ggplot2::autoplot(cm_val, type = "heatmap") + + ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1") + + ggplot2::theme(axis.text.x = + ggplot2::element_text(angle=90,hjust=1, size = self$text_size)) + + ggplot2::theme(axis.text.y = + ggplot2::element_text(size = self$text_size)) + ) + + if (length(confMatLabels) > 4) { + plot_size <- (length(confMatLabels) * 1.3) + 1 + } else { + plot_size <- length(confMatLabels) * 3 + } + + if (self$graphics == "png") { + + suppressMessages(ggplot2::ggsave(filename = self$plot_path_train, plot = cm_plot_train, device = "png", + width = plot_size, + height = plot_size, + units = "cm")) + p_cm_train <- png::readPNG(self$plot_path_train) + + suppressMessages(ggplot2::ggsave(filename = self$plot_path_val, plot = cm_plot_val, device = "png", + width = plot_size, + height = plot_size, + units = "cm")) + p_cm_val <- png::readPNG(self$plot_path_val) + + } else { + + suppressMessages(ggplot2::ggsave(filename = self$plot_path_train, plot = cm_plot_train, device = "jpg", + width = plot_size, + height = plot_size, + units = "cm")) + p_cm_train <- jpeg::readJPEG(self$plot_path_train) + + suppressMessages(ggplot2::ggsave(filename = self$plot_path_val, plot = cm_plot_val, device = "jpg", + width = plot_size, + height = plot_size, + units = "cm")) + p_cm_train <- jpeg::readJPEG(self$plot_path_val) + } + + p_cm_train <- as.array(p_cm_train) + p_cm_train <- array(p_cm_train, dim = c(1, dim(p_cm_train))) + p_cm_val <- as.array(p_cm_val) + p_cm_val <- array(p_cm_val, dim = c(1, dim(p_cm_val))) + + num_images <- 1 + train_images <- array(0, dim = c(num_images, dim(p_cm_train)[-1])) + train_images[1, , , ] <- p_cm_train + self$train_images <- train_images + + val_images <- array(0, dim = c(num_images, dim(p_cm_val)[-1])) + val_images[1, , , ] <- p_cm_val + self$val_images <- val_images + file.writer <- tensorflow::tf$summary$create_file_writer(file.path(self$path_tensorboard, self$run_name)) + file.writer$set_as_default() + tensorflow::tf$summary$image(name = "confusion matrix train", data = self$train_images, step = as.integer(epoch-1)) + tensorflow::tf$summary$image(name = "confusion matrix validation", data = self$val_images, step = as.integer(epoch-1)) + file.writer$flush() + self$epoch <- epoch + } + }, + + on_train_end = function(self, logs) { + + epoch <- self$epoch + 1 + + # create confusion matrix for last val step manually (storing cm when calling reset_state) + for (i in 1:length(self$model$metrics)) { + if (self$model$metrics[[i]]$name == "balanced_acc") { + self$model$metrics[[i]]$reset_state() + } + } + + cm_train <- readRDS(file.path(self$cm_dir, paste0("cm_train_", epoch-1, ".rds"))) + cm_val <- readRDS(file.path(self$cm_dir, paste0("cm_val_", epoch-1, ".rds"))) + if (self$cm_display_percentage) { + cm_train <- cm_perc(cm_train, self$round_dig) + cm_val <- cm_perc(cm_val, self$round_dig) + } + cm_train <- create_conf_mat_obj(cm_train, self$confMatLabels) + cm_val <- create_conf_mat_obj(cm_val, self$confMatLabels) + + + suppressMessages( + cm_plot_train <- ggplot2::autoplot(cm_train, type = "heatmap") + + ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1") + + ggplot2::theme(axis.text.x = + ggplot2::element_text(angle=90,hjust=1, size = self$text_size)) + + ggplot2::theme(axis.text.y = + ggplot2::element_text(size = self$text_size)) + ) + + suppressMessages( + cm_plot_val <- ggplot2::autoplot(cm_val, type = "heatmap") + + ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1") + + ggplot2::theme(axis.text.x = + ggplot2::element_text(angle=90,hjust=1, size = self$text_size)) + + ggplot2::theme(axis.text.y = + ggplot2::element_text(size = self$text_size)) + ) + + if (length(confMatLabels) > 4) { + plot_size <- (length(confMatLabels) * 1.3) + 1 + } else { + plot_size <- length(confMatLabels) * 3 + } + + if (self$graphics == "png") { + + suppressMessages(ggplot2::ggsave(filename = self$plot_path_train, plot = cm_plot_train, device = "png", + width = plot_size, + height = plot_size, + units = "cm")) + p_cm_train <- png::readPNG(self$plot_path_train) + + suppressMessages(ggplot2::ggsave(filename = self$plot_path_val, plot = cm_plot_val, device = "png", + width = plot_size, + height = plot_size, + units = "cm")) + p_cm_val <- png::readPNG(self$plot_path_val) + + } else { + + suppressMessages(ggplot2::ggsave(filename = self$plot_path_train, plot = cm_plot_train, device = "jpg", + width = plot_size, + height = plot_size, + units = "cm")) + p_cm_train <- jpeg::readJPEG(self$plot_path_train) + + suppressMessages(ggplot2::ggsave(filename = self$plot_path_val, plot = cm_plot_val, device = "jpg", + width = plot_size, + height = plot_size, + units = "cm")) + p_cm_train <- jpeg::readJPEG(self$plot_path_val) + } + + p_cm_train <- as.array(p_cm_train) + p_cm_train <- array(p_cm_train, dim = c(1, dim(p_cm_train))) + p_cm_val <- as.array(p_cm_val) + p_cm_val <- array(p_cm_val, dim = c(1, dim(p_cm_val))) + + num_images <- 1 + train_images <- array(0, dim = c(num_images, dim(p_cm_train)[-1])) + train_images[1, , , ] <- p_cm_train + self$train_images <- train_images + + val_images <- array(0, dim = c(num_images, dim(p_cm_val)[-1])) + val_images[1, , , ] <- p_cm_val + self$val_images <- val_images + file.writer <- tensorflow::tf$summary$create_file_writer(file.path(self$path_tensorboard, self$run_name)) + file.writer$set_as_default() + tensorflow::tf$summary$image(name = "confusion matrix train", data = self$train_images, step = as.integer(epoch-1)) + tensorflow::tf$summary$image(name = "confusion matrix validation", data = self$val_images, step = as.integer(epoch-1)) + file.writer$flush() + } + )) + conf_matrix_cb_py_class(path_tensorboard = path_tensorboard, + run_name = run_name, + confMatLabels = confMatLabels, + cm_dir = cm_dir) +} + + +get_callbacks <- function(default_arguments, model, path_tensorboard, run_name, train_type, + path, train_val_ratio, batch_size, epochs, format, + max_queue_size, lr_plateau_factor, patience, cooldown, path_checkpoint, + steps_per_epoch, step, shuffle_file_order, initial_epoch, vocabulary, + learning_rate, shuffle_input, vocabulary_label, solver, dataset_val, + file_limit, reverse_complement, wavenet_format, cnn_format, + create_model_function = NULL, vocabulary_size, gen_cb, argumentList, + maxlen, labelGen, labelByFolder, vocabulary_label_size, tb_images, + target_middle, path_file_log, proportion_per_seq, validation_steps, + train_val_split_csv, n_gram, path_file_logVal, model_card, + skip_amb_nuc, max_samples, proportion_entries, path_log, output, + train_with_gen, random_sampling, reduce_lr_on_plateau, + save_weights_only, save_best_only, reset_states, early_stopping_time, + validation_only_after_training, gen.val, target_from_csv) { + + if (output$checkpoints) { + # create folder for checkpoints using run_name + checkpoint_dir <- paste0(path_checkpoint, "/", run_name) + dir.create(checkpoint_dir, showWarnings = FALSE) + if (!is.list(model$output) & !is.null(gen.val)) { + # filename with epoch, validation loss and validation accuracy + filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}-val_loss{val_loss:.2f}-val_acc{val_acc:.3f}.hdf5") + } else { + + # if (is.null(gen.val)) { + # filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}-loss{loss:.2f}-acc{acc:.3f}.hdf5") + # } else { + filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}.hdf5") + if ((is.list(save_best_only) && !is.null(save_best_only$monitor)) & is.null(dataset_val)) { + warning("save_best_only not implemented for multi target or training without validation data. Setting save_best_only to NULL.") + save_best_only <- NULL + } + #} + + } + } + + # Check if path_file_log is unique + if (!is.null(path_file_log) && dir.exists(path_file_log)) { + stop(paste0("path_file_log entry is already present. Please give this file a unique name.")) + } + + count_files <- !random_sampling + callbacks <- list() + callback_names <- NULL + + if (reduce_lr_on_plateau) { + if (is.list(model$outputs)) { + monitor <- "val_loss" + } else { + monitor <- "val_acc" + } + callbacks[[1]] <- reduce_lr_cb(patience = patience, cooldown = cooldown, + lr_plateau_factor = lr_plateau_factor, + monitor = monitor) + callback_names <- c("reduce_lr", callback_names) + } + + if (!is.null(path_log)) { + callbacks <- c(callbacks, log_cb(path_log, run_name)) + callback_names <- c("log", callback_names) + } + + if (!output$tensorboard) tb_images <- FALSE + if (output$tensorboard) { + + # add balanced acc score + model <- manage_metrics(model) + if (train_with_gen) { + num_targets <- ifelse(train_type == "lm", length(vocabulary), length(vocabulary_label)) + } else { + num_targets <- dim(dataset_val$Y)[2] + } + contains_macro_acc_metric <- FALSE + for (i in 1:length(model$metrics)) { + if (model$metrics[[i]]$name == "balanced_acc") contains_macro_acc_metric <- TRUE + } + + metric_names <- vector("character", length(model$metrics)) + for (i in 1:length(model$metrics)) { + metric_names[i] <- model$metrics[[i]]$name + } + loss_index <- stringr::str_detect(metric_names, "loss") + + if (!contains_macro_acc_metric) { + if (tb_images) { + if (!reticulate::py_has_attr(model, "cm_dir")) { + cm_dir <- file.path(tempdir(), paste(sample(letters, 7), collapse = "")) + dir.create(cm_dir) + model$cm_dir <- cm_dir + } + + metrics <- c(model$metrics[!loss_index], balanced_acc_wrapper(num_targets = num_targets, cm_dir = model$cm_dir)) + } + } else { + metrics <- c(model$metrics[!loss_index]) + } + + # count files in path + if (train_type == "label_rds" | train_type == "lm_rds") format <- "rds" + if (train_with_gen) { + num_train_files <- count_files(path = path, format = format, train_type = train_type, + target_from_csv = target_from_csv, + train_val_split_csv = train_val_split_csv) + } else { + num_train_files <- 1 + } + + complete_tb <- tensorboard_complete_cb(default_arguments = default_arguments, model = model, path_tensorboard = path_tensorboard, run_name = run_name, train_type = train_type, + path = path, train_val_ratio = train_val_ratio, batch_size = batch_size, epochs = epochs, + max_queue_size = max_queue_size, lr_plateau_factor = lr_plateau_factor, patience = patience, cooldown = cooldown, + steps_per_epoch = steps_per_epoch, step = step, shuffle_file_order = shuffle_file_order, initial_epoch = initial_epoch, vocabulary = vocabulary, + learning_rate = learning_rate, shuffle_input = shuffle_input, vocabulary_label = vocabulary_label, solver = solver, + file_limit = file_limit, reverse_complement = reverse_complement, wavenet_format = wavenet_format, cnn_format = cnn_format, + create_model_function = NULL, vocabulary_size = vocabulary_size, gen_cb = gen_cb, argumentList = argumentList, + maxlen = maxlen, labelGen = labelGen, labelByFolder = labelByFolder, vocabulary_label_size = vocabulary_label_size, tb_images = FALSE, + target_middle = target_middle, num_train_files = num_train_files, path_file_log = path_file_log, proportion_per_seq = proportion_per_seq, + skip_amb_nuc = skip_amb_nuc, max_samples = max_samples, proportion_entries = proportion_entries, + train_with_gen = train_with_gen, count_files = !random_sampling) + callbacks <- c(callbacks, complete_tb) + callback_names <- c(callback_names, names(complete_tb)) + } + + if (output$checkpoints) { + if (wavenet_format) { + # can only save weights for wavenet + save_weights_only <- TRUE + } + callbacks <- c(callbacks, checkpoint_cb(filepath_checkpoints = filepath_checkpoints, save_weights_only = save_weights_only, + save_best_only = save_best_only)) + callback_names <- c(callback_names, "checkpoint") + } + + if (reset_states) { + callbacks <- c(callbacks, reset_states_cb(path_file_log = path_file_log, path_file_logVal = path_file_logVal)) + callback_names <- c(callback_names, "reset_states") + } + + if (!is.null(early_stopping_time)) { + callbacks <- c(callbacks, early_stopping_cb(early_stopping_time = early_stopping_time)) + callback_names <- c(callback_names, "early_stopping") + } + + if (validation_only_after_training) { + if (!train_with_gen) stop("Validation after training only implemented for generator") + callbacks <- c(callbacks, validation_after_training_cb(gen.val = gen.val, validation_steps = validation_steps)) + callback_names <- c(callback_names, "validation_after_training") + } + + if (!is.null(model_card)) { + callbacks <- c(callbacks, model_card_cb(model_card_path = model_card$path_model_card, + run_name = run_name, argumentList = argumentList)) + } + + if (tb_images) { + if (is.list(model$output)) { + warning("Tensorboard images (confusion matrix) not implemented for model with multiple outputs. + Setting tb_images to FALSE") + tb_images <- FALSE + } + + if (model$loss == "binary_crossentropy") { + warning("Tensorboard images (confusion matrix) not implemented for sigmoid activation in last layer. + Setting tb_images to FALSE") + tb_images <- FALSE + } + } + + if (tb_images) { + + confMatLabels <- vocabulary_label + if (train_with_gen & train_type == "lm") { + if (is.null(n_gram) || n_gram == 1) { + confMatLabels <- vocabulary + } else { + l <- list() + for (i in 1:n_gram) { + l[[i]] <- vocabulary + } + confMatLabels <- expand.grid(l) %>% apply(1, paste0) %>% apply(2, paste, collapse = "") %>% sort() + } + } + + model <- model %>% keras::compile(loss = model$loss, + optimizer = model$optimizer, metrics = metrics) + + if (length(confMatLabels) > 16) { + message("Cannot display confusion matrix with more than 16 labels.") + } else { + + callbacks <- c(callbacks, conf_matrix_cb(path_tensorboard = path_tensorboard, + run_name = run_name, + confMatLabels = confMatLabels, + cm_dir = model$cm_dir)) + callback_names <- c(callback_names, "conf_matrix") + } + } + + return(callbacks) +}