Diff of /R/callbacks.R [000000] .. [409433]

Switch to unified view

a b/R/callbacks.R
1
#' Create model card
2
#' 
3
#' Log information about model, hyperparameters, generator options, training data, scores etc 
4
#'
5
#' @param model_card_path Directory for model card logs.
6
#' @param run_name Name of training run.
7
#' @param argumentList List of training arguments.
8
#' @examplesIf reticulate::py_module_available("tensorflow")
9
#' model_card_cb <- function(model_card_path = NULL, run_name, argumentList)
10
#' mc <- model_card_cb(model_card_path = tempdir(), run_name = 'run_1',
11
#'                     argumentList = list(learning_rate = 0.01)) 
12
#' 
13
#' @returns Keras callback writing model cards every epoch.
14
#' @export
15
model_card_cb <- function(model_card_path = NULL, run_name, argumentList) {
16
  
17
  model_card_cb_py_class <- reticulate::PyClass("model_card_cb",
18
                                                inherit = tensorflow::tf$keras$callbacks$Callback,
19
                                                list(
20
                                                  
21
                                                  `__init__` = function(self, model_card_path, run_name) {
22
                                                    self$model_card_path <- model_card_path
23
                                                    self$start_time <- Sys.time()
24
                                                    self$mc_dir <- file.path(model_card_path, run_name)
25
                                                    self$param_list <- list()
26
                                                    self$argumentList <- argumentList
27
                                                    NULL
28
                                                  },
29
                                                  
30
                                                  # collect all data
31
                                                  on_train_begin = function(self, logs) {
32
                                                    
33
                                                    if (!dir.exists(self$mc_dir)) {
34
                                                      dir.create(self$mc_dir)
35
                                                    } else {
36
                                                      #stop("Directory already exists. Change run_name")
37
                                                    }
38
                                                    
39
                                                    self$param_list <- self$model$hparam
40
                                                    self$param_list$train_model_args <- argumentList
41
                                                    for (n in names(self$param_list$train_model_args)) {
42
                                                      self$param_list$train_model_args[[n]] <- eval(self$param_list$train_model_args[[n]])
43
                                                    }
44
                                                    self$param_list$train_model_args[["model"]] <- NULL 
45
                                                    self$param_list$model_summary <- summary(self$model)
46
                                                    self$param_list$training_start_time <- format(self$start_time, "%a %b %d %X %Y")
47
                                                    
48
                                                    gpu_info <- tensorflow::tf$config$list_physical_devices('GPU')
49
                                                    self$param_list$gpu_info[["number GPUs"]] <- length(gpu_info)
50
                                                    if (length(gpu_info) > 0) {
51
                                                      for (i in 1:length(gpu_info)) {
52
                                                        self$param_list$gpu_info[[paste0("GPU", i)]] <-
53
                                                          tensorflow::tf$config$experimental$get_device_details(
54
                                                            gpu_info[[i]]
55
                                                          )
56
                                                      }
57
                                                    }
58
                                                    
59
                                                    saveRDS(self$param_list, paste0(self$mc_dir, "/epoch_0_param_list.rds"))
60
                                                    
61
                                                  },
62
                                                  
63
                                                  # update training scores
64
                                                  on_epoch_end = function(self, epoch, logs) {
65
                                                    time_passed <- as.double(difftime(Sys.time(), self$start_time, units = "secs"))
66
                                                    self$param_list[["training_time"]] <- time_passed
67
                                                    
68
                                                    if (epoch == 0) {
69
                                                      m <- unlist(logs) 
70
                                                      m <- c(m, epoch, time_passed) %>% matrix(nrow = 1) %>% as.data.frame()
71
                                                      names(m) <- c(names(logs), "processing_step", "time")
72
                                                      self$param_list[["logs"]] <- m
73
                                                    } else {
74
                                                      m <- unlist(logs) 
75
                                                      m <- c(m, epoch, time_passed)  %>% matrix(nrow = 1) %>% as.data.frame()
76
                                                      names(m) <- c(names(logs), "processing_step", "time")
77
                                                      m <- rbind(self$param_list[["logs"]], m)
78
                                                      self$param_list[["logs"]] <- reticulate::r_to_py(m)
79
                                                    }
80
                                                    
81
                                                    saveRDS(self$param_list, paste0(self$mc_dir, "/epoch_", epoch + 1, "_param_list.rds"))
82
                                                  }
83
                                                  
84
                                                ))
85
  
86
  model_card_cb_py_class(model_card_path = model_card_path,
87
                         run_name = run_name)
88
  
89
}
90
91
92
93
#' Stop training callback
94
#' 
95
#' Stop training after specified time.
96
#'
97
#' @param stop_time Time in seconds after which to stop training.
98
#' @examplesIf reticulate::py_module_available("tensorflow")
99
#' est <- early_stopping_time_cb(stop_time = 60)
100
#' 
101
#' @returns A Keras callback that stops training after specified time.
102
#' @export
103
early_stopping_time_cb <- function(stop_time = NULL) {
104
  
105
  early_stopping_time_cb_py_class <- reticulate::PyClass("early_stopping_time_cb",
106
                                                         inherit = tensorflow::tf$keras$callbacks$Callback,
107
                                                         list(
108
                                                           
109
                                                           `__init__` = function(self, stop_time) {
110
                                                             self$start_time <- Sys.time()
111
                                                             self$stop_time <- stop_time
112
                                                             NULL
113
                                                           },
114
                                                           
115
                                                           on_batch_end = function(self, epoch, logs) {
116
                                                             time_passed <- as.double(difftime(Sys.time(), self$start_time, units = "secs"))
117
                                                             if (time_passed > self$stop_time) {
118
                                                               self$model$stop_training <- TRUE
119
                                                             }
120
                                                           }
121
                                                           
122
                                                         ))
123
  
124
  early_stopping_time_cb_py_class(stop_time = stop_time)
125
  
126
}
127
128
#' Early stopping callback
129
#'
130
#' @param early_stopping_time Time in seconds after which to stop training.
131
#' @param early_stopping_patience Stop training if val_loss does not improve for \code{early_stopping_patience}.
132
#' @param by_time Whether to use time or patience as metric.
133
#' @returns Keras callback; stop training after specified time.
134
#' @noRd
135
early_stopping_cb <- function(early_stopping_patience = 0, early_stopping_time, by_time = TRUE) {
136
  
137
  if (by_time) {
138
    early_stopping_time_cb(stop_time = early_stopping_time)
139
  } else {
140
    keras::callback_early_stopping(patience = early_stopping_patience)
141
  }
142
}
143
144
#' Log callback
145
#'
146
#' @param path_log Path to output directory.
147
#' @param run_name Name of output file is run_name + ".csv".
148
#' @returns Keras callback, writes epoch scores to csv file.
149
#' @noRd
150
log_cb <- function(path_log, run_name) {
151
  keras::callback_csv_logger(
152
    paste0(path_log, "/", run_name, ".csv"),
153
    separator = ";",
154
    append = TRUE)
155
}
156
157
#' Learning_rate callback
158
#'
159
#' @inheritParams train_model
160
#' @returns Keras callback, reduces learning rate.
161
#' @noRd
162
reduce_lr_cb <- function(patience,
163
                         cooldown,
164
                         lr_plateau_factor,
165
                         monitor = "val_acc") {
166
  keras::callback_reduce_lr_on_plateau(
167
    monitor = monitor,
168
    factor = lr_plateau_factor,
169
    patience = patience,
170
    cooldown = cooldown)
171
}
172
173
#' Checkpoint callback
174
#'
175
#' @inheritParams train_model
176
#' @returns Keras callback, store model checkpoint.
177
#' @noRd
178
checkpoint_cb <- function(filepath_checkpoints,
179
                          save_weights_only,
180
                          save_best_only,
181
                          save_freq, 
182
                          monitor = "val_loss") {
183
  
184
  if (is.logical(save_best_only)) {
185
    if (save_best_only) {
186
      warning("save_best_only should not be boolean variabel, but list or NULL. Using val_loss as monitor.")
187
      save_best_only <- list(monitor = "val_loss")
188
    } else {
189
      warning("save_best_only should not be boolean variabel, but list or NULL.")
190
      save_best_only <- NULL
191
    }
192
  }
193
  
194
  if (is.null(save_best_only) | !is.null(save_best_only$monitor)) {
195
    
196
    keras::callback_model_checkpoint(filepath = filepath_checkpoints,
197
                                     save_weights_only = save_weights_only,
198
                                     save_best_only = !is.null(save_best_only),
199
                                     verbose = 1,
200
                                     save_freq = "epoch",
201
                                     monitor = monitor)
202
    
203
  } else {
204
    
205
    cp_cb <- reticulate::PyClass("cp_cb",
206
                                 inherit = tensorflow::tf$keras$callbacks$Callback,
207
                                 list(
208
                                   
209
                                   `__init__` = function(self, filepath_checkpoints, save_freq, save_weights_only) {
210
                                     self$filepath_checkpoints <- filepath_checkpoints
211
                                     self$save_freq <- save_freq
212
                                     self$save_weights_only <- save_weights_only
213
                                     NULL
214
                                   },
215
                                   
216
                                   on_epoch_end = function(self, epoch, logs) {
217
                                     if ((epoch + 1) %% self$save_freq == 0) {
218
                                       
219
                                       formatted_path <- gsub("\\{epoch:03d\\}", sprintf("%03d", epoch + 1), self$filepath_checkpoints)
220
                                       formatted_path <- gsub("\\{val_loss:.2f\\}", sprintf("%.2f", logs$val_loss), formatted_path)
221
                                       formatted_path <- gsub("\\{val_acc:.3f\\}", sprintf("%.3f", logs$val_acc), formatted_path)
222
                                       print(formatted_path)
223
                                       if (self$save_weights_only) {
224
                                         keras::save_model_hdf5(self$model, formatted_path)
225
                                       } else {
226
                                         keras::save_model_weights_hdf5(self$model, formatted_path)
227
                                       }
228
                                     }
229
                                   }
230
                                   
231
                                 ))
232
    
233
    return(cp_cb(filepath_checkpoints = filepath_checkpoints,
234
                 save_freq = save_best_only$save_freq,
235
                 save_weights_only = save_weights_only))
236
    
237
  }
238
  
239
}
240
241
#' Non model hyperparameter callback
242
#' 
243
#' Get hyperparameters excluding model parameters.
244
#'
245
#' @inheritParams train_model
246
#' @returns Keras callback, track model hyperparameters.
247
#' @noRd
248
hyper_param_model_outside_cb <- function(path_tensorboard, run_name, wavenet_format, cnn_format, model, vocabulary, path, reverse_complement,
249
                                         vocabulary_label, maxlen, epochs, max_queue_size, lr_plateau_factor, batch_size,
250
                                         patience, cooldown, steps_per_epoch, step, shuffle_file_order) {
251
  
252
  train_hparams <- list(
253
    run_name = run_name,
254
    vocabulary = paste(vocabulary, collapse = ","),
255
    path = paste(unlist(path), collapse = ", "),
256
    reverse_complement = paste(reverse_complement),
257
    vocabulary_label = paste(vocabulary_label, collapse = ", "),
258
    epochs = epochs,
259
    max_queue_size = max_queue_size,
260
    lr_plateau_factor = lr_plateau_factor,
261
    batch_size = batch_size,
262
    patience = patience,
263
    cooldown = cooldown,
264
    steps_per_epoch = steps_per_epoch,
265
    step = step,
266
    shuffle_file_order = shuffle_file_order
267
  )
268
  #hparams$update(model$hparam)
269
  model_hparams <- vector("list")
270
  for (i in names(model$hparam)) {
271
    model_hparams[[i]] <- model$hparam[[i]]
272
  }
273
  
274
  hparams_R <- c(train_hparams, model_hparams)
275
  
276
  keep_entry_index <- rep(TRUE, length(hparams_R))
277
  for (i in 1:length(hparams_R)) {
278
    
279
    if (length(hparams_R[[i]]) == 0) { 
280
      keep_entry_index[i] <- FALSE
281
    }
282
    
283
    if (length(hparams_R[[i]]) > 1) { 
284
      hparams_R[[i]] <- paste(hparams_R[[i]], collapse = " ")
285
    }
286
  }
287
  hparams_R <- hparams_R[keep_entry_index]
288
  
289
  hparams <- reticulate::dict(hparams_R)
290
  hp <- reticulate::import("tensorboard.plugins.hparams.api")
291
  hp$KerasCallback(file.path(path_tensorboard, run_name), hparams, trial_id = run_name)
292
}
293
294
#' Model hyperparameter callback
295
#' 
296
#' Get model hyperparameters.
297
#'
298
#' @inheritParams train_model
299
#' @returns Keras callback, track training hyperparameters.
300
#' @noRd
301
hyper_param_with_model_cb <- function(default_arguments, model, path_tensorboard, run_name, train_type, path, train_val_ratio, batch_size,
302
                                      epochs, max_queue_size, lr_plateau_factor,
303
                                      patience, cooldown, steps_per_epoch, step, shuffle_file_order, initial_epoch, vocabulary, learning_rate,
304
                                      shuffle_input, vocabulary_label, solver, file_limit, reverse_complement, wavenet_format, cnn_format) {
305
  
306
  model_hparam <- vector("list")
307
  model_hparam_names <- vector("list")
308
  for (i in 1:length(default_arguments)) {
309
    if (is.null(default_arguments[[i]])) {
310
      model_hparam[i] <- "NULL"
311
    } else {
312
      model_hparam[i] <- default_arguments[i]
313
    }
314
  }
315
  names(model_hparam) <- names(default_arguments)
316
  # hparam from train_model
317
  learning_rate <- keras::k_eval(model$optimizer$lr)
318
  solver <- stringr::str_to_lower(model$optimizer$get_config()["name"])
319
  
320
  train_hparam_names <- c("train_type", "path", "train_val_ratio", "run_name", "batch_size", "epochs", "max_queue_size", "lr_plateau_factor",
321
                          "patience", "cooldown", "steps_per_epoch", "step", "shuffle_file_order", "initial_epoch", "vocabulary", "learning_rate",
322
                          "shuffle_input", "vocabulary_label", "solver", "file_limit", "reverse_complement", "wavenet_format", "cnn_format")
323
  train_hparam <- vector("list")
324
  for (i in 1:length(train_hparam_names)) {
325
    if (is.null(eval(parse(text=train_hparam_names[i])))) {
326
      train_hparam[[i]] <- "NULL"
327
    } else if (length(eval(parse(text=train_hparam_names[i])) > 1)) {
328
      train_hparam[[i]] <- toString(eval(parse(text=train_hparam_names[i])))
329
      if (length(train_hparam[[i]]) > 1) {
330
        train_hparam[[i]] <- paste(train_hparam[[i]], collapse = " ")
331
      }
332
    } else {
333
      train_hparam[[i]] <- eval(parse(text=train_hparam_names[i]))
334
      if (length(train_hparam[[i]]) > 1) {
335
        train_hparam[[i]] <- paste(train_hparam[[i]], collapse = " ")
336
      }
337
    }
338
  }
339
  names(train_hparam) <- train_hparam_names
340
  hparams_R <- c(train_hparam, model_hparam)
341
  hparams <- reticulate::dict(hparams_R)
342
  hp <- reticulate::import("tensorboard.plugins.hparams.api")
343
  return(hp$KerasCallback(file.path(path_tensorboard, run_name), hparams, trial_id = run_name))
344
}
345
346
#' Tensorboard callback
347
#'
348
#' @inheritParams train_model
349
#' @returns Keras callback, write tensorboard logs.
350
#' @noRd
351
tensorboard_cb <- function(path_tensorboard, run_name) {
352
  keras::callback_tensorboard(file.path(path_tensorboard, run_name),
353
                              write_graph = TRUE,
354
                              histogram_freq = 1,
355
                              write_images = TRUE,
356
                              write_grads = TRUE)
357
}
358
359
#' Function arguments callback
360
#' 
361
#' Print train_model call in text field of tensorboard.
362
#' 
363
#' @inheritParams train_model
364
#' @param argumentList List of function arguments.
365
#' @returns Keras callback, track arguments of `train_model` function.
366
#' @noRd
367
function_args_cb <- function(argumentList, path_tensorboard, run_name) {
368
  
369
  argAsChar <- as.character(argumentList)
370
  argText <- vector("character")
371
  if (length(argumentList$path) > 1) {
372
    
373
    argsInQuotes <- c("path_checkpoint", "run_name", "solver", "format", "output_format",
374
                      "path_tensorboard", "path_file_log", "train_type", "ambiguous_nuc", "added_label_path", "added_label_names",
375
                      "train_val_split_csv", "target_from_csv")
376
  } else {
377
    argsInQuotes <- c("path", "path_val", "path_checkpoint", "run_name", "solver", "output_format",
378
                      "path_tensorboard", "path_file_log", "train_type", "ambiguous_nuc", "format", "added_label_path", "added_label_names",
379
                      "train_val_split_csv", "target_from_csv")
380
  }
381
  argText[1] <- "train_model("
382
  for (i in 2:(length(argumentList) - 1)) {
383
    arg <- argAsChar[[i]]
384
    if (names(argumentList)[i] %in% argsInQuotes) {
385
      if (arg == "NULL") {
386
        argText[i] <- paste0(names(argumentList)[i], " = ", arg, ",")
387
      } else {
388
        argText[i] <- paste0(names(argumentList)[i], " = ", '\"', arg, '\"', ",")
389
      }
390
    } else {
391
      argText[i] <- paste0(names(argumentList)[i], " = ", arg, ",")
392
    }
393
  }
394
  i <- length(argumentList)
395
  if (names(argumentList)[i] %in% argsInQuotes) {
396
    if (arg == "NULL") {
397
      argText[i] <- paste0(names(argumentList)[i], " = ", argAsChar[[i]], ")")
398
    } else {
399
      argText[i] <- paste0(names(argumentList)[i], " = ", '\"', argAsChar[[i]], '\"', ")")
400
    }
401
  } else {
402
    argText[i] <- paste0(names(argumentList)[i], " = ", argAsChar[[i]], ")")
403
  }
404
  
405
  # write function arguments as text in tensorboard
406
  trainArguments <- keras::callback_lambda(
407
    on_train_begin = function(logs) {
408
      file.writer <- tensorflow::tf$summary$create_file_writer(file.path(path_tensorboard, run_name))
409
      file.writer$set_as_default()
410
      tensorflow::tf$summary$text(name="Arguments",  data = argText, step = 0L)
411
      file.writer$flush()
412
    }
413
  )
414
  trainArguments
415
}
416
417
#' Tensorboard callback wrapper
418
#'
419
#' @inheritParams train_model
420
#' @returns Keras callback, wrapper for all callbacks involving tensorboard.
421
#' @noRd
422
tensorboard_complete_cb <- function(default_arguments, model, path_tensorboard, run_name, train_type, path, train_val_ratio, batch_size,
423
                                    epochs, max_queue_size, lr_plateau_factor, patience, cooldown, steps_per_epoch, step, shuffle_file_order, initial_epoch, vocabulary, learning_rate,
424
                                    shuffle_input, vocabulary_label, solver, file_limit, reverse_complement, wavenet_format, cnn_format, create_model_function, vocabulary_size, gen_cb,
425
                                    argumentList, maxlen, labelGen, labelByFolder, vocabulary_label_size, tb_images = FALSE, stateful, target_middle, num_train_files, path_file_log,
426
                                    proportion_per_seq, skip_amb_nuc, max_samples, proportion_entries, train_with_gen, count_files = TRUE) {
427
  l <- vector("list")
428
  
429
  l[[1]] <- hyper_param_model_outside_cb(path_tensorboard = path_tensorboard, run_name = run_name, wavenet_format = wavenet_format, cnn_format = cnn_format, model = model,
430
                                         vocabulary = vocabulary, path = path, reverse_complement = reverse_complement, vocabulary_label = vocabulary_label,
431
                                         maxlen = maxlen, epochs = epochs, max_queue_size = max_queue_size, lr_plateau_factor = lr_plateau_factor,
432
                                         batch_size = batch_size, patience = patience, cooldown = cooldown, steps_per_epoch = steps_per_epoch,
433
                                         step = step, shuffle_file_order = shuffle_file_order)
434
  
435
  l[[2]] <- tensorboard_cb(path_tensorboard = path_tensorboard, run_name = run_name)
436
  l[[3]] <- function_args_cb(argumentList = argumentList, path_tensorboard = path_tensorboard, run_name = run_name)
437
  
438
  if (train_with_gen & count_files) {
439
    
440
    proportion_training_files_cb <- reticulate::PyClass("proportion_training_files_cb",
441
                                                        inherit = tensorflow::tf$keras$callbacks$Callback,
442
                                                        list(
443
                                                          
444
                                                          `__init__` = function(self, num_train_files, path_file_log, path_tensorboard, run_name, vocabulary_label,
445
                                                                                path, train_type, start_index, proportion_per_seq, max_samples, step,
446
                                                                                proportion_entries) {
447
                                                            self$num_train_files <- num_train_files
448
                                                            self$path_file_log <- path_file_log
449
                                                            self$path_tensorboard <- path_tensorboard
450
                                                            self$run_name <- run_name
451
                                                            self$vocabulary_label <- vocabulary_label
452
                                                            self$path <- path
453
                                                            self$train_type <- train_type
454
                                                            self$proportion_per_seq <- proportion_per_seq
455
                                                            self$max_samples <- max_samples
456
                                                            self$step <- step
457
                                                            self$start_index <- 1
458
                                                            self$first_epoch <- TRUE
459
                                                            self$description <- ""
460
                                                            self$proportion_entries <- proportion_entries
461
                                                            NULL
462
                                                          },
463
                                                          
464
                                                          on_epoch_end = function(self, epoch, logs) {
465
                                                            if (is.null(self$proportion_entries)) self$proportion_entries <- 1
466
                                                            file.writer <- tensorflow::tf$summary$create_file_writer(file.path(self$path_tensorboard, self$run_name))
467
                                                            file.writer$set_as_default()
468
                                                            files_used <- utils::read.csv(self$path_file_log, stringsAsFactors = FALSE, header = FALSE)
469
                                                            if (self$train_type == "label_folder") {
470
                                                              if (self$first_epoch) {
471
                                                                if (length(self$step) == 1) self$step <- rep(self$step, length(vocabulary_label))
472
                                                                if (length(self$proportion_per_seq) == 1) {
473
                                                                  self$proportion_per_seq <- rep(self$proportion_per_seq, length(self$vocabulary_label))
474
                                                                }
475
                                                                if (length(max_samples) == 1) self$max_samples <- rep(max_samples, length(vocabulary_label))
476
                                                                
477
                                                                for (i in 1:length(self$vocabulary_label)) {
478
                                                                  if (is.null(self$max_samples)) {
479
                                                                    self$description[i] <- paste0("Using step size ", self$step[i], ", proportion_entries ",
480
                                                                                                  self$proportion_entries * 100, "% and ",
481
                                                                                                  ifelse(is.null(self$proportion_per_seq[i]), 1,
482
                                                                                                         self$proportion_per_seq[i]) * 100, "% per sequence")
483
                                                                  } else {
484
                                                                    self$description[i] <- paste0("Using step size ", self$step[i], ", ",
485
                                                                                                  ifelse(is.null(self$proportion_per_seq[i]), 1,
486
                                                                                                         self$proportion_per_seq[i]) * 100, "% per sequence, maximum of ",
487
                                                                                                  self$max_samples[i], " samples per file and proportion_entries ",
488
                                                                                                  self$proportion_entries * 100, "%")
489
                                                                  }
490
                                                                }
491
                                                                self$first_epoch <- FALSE
492
                                                              }
493
                                                              
494
                                                              for (i in 1:length(self$vocabulary_label)) {
495
                                                                files_of_class <-  sum(stringr::str_detect(
496
                                                                  files_used[ , 1], paste(unlist(self$path[[i]]), collapse = "|")
497
                                                                ))
498
                                                                files_percentage <- 100 * files_of_class/self$num_train_files[i]
499
                                                                tensorflow::tf$summary$scalar(name = paste0("training files seen (%): '",
500
                                                                                                            self$vocabulary_label[i], "'"), data = files_percentage, step = epoch,
501
                                                                                              description = self$description[i])
502
                                                              }
503
                                                            } else {
504
                                                              files_percentage <- 100 * nrow(files_used)/self$num_train_files
505
                                                              if (is.null(self$max_samples)) {
506
                                                                description <- paste0("Using step size ", step,
507
                                                                                      ", proportion_entries ", self$proportion_entries * 100, "% and ",
508
                                                                                      ifelse(is.null(self$proportion_per_seq), 1,
509
                                                                                             self$proportion_per_seq) * 100, "% per sequence")
510
                                                              } else {
511
                                                                description <- paste0("Using step size ", step, ", ",
512
                                                                                      ifelse(is.null(self$proportion_per_seq), 1,
513
                                                                                             self$proportion_per_seq) * 100, "% per sequence, maximum of ",
514
                                                                                      self$max_samples, " samples per file and proportion_entries ",
515
                                                                                      self$proportion_entries * 100, "%")
516
                                                                
517
                                                              }
518
                                                              if (self$train_type == "label_rds") {
519
                                                                description <- paste0("Using step size ",
520
                                                                                      ifelse(is.null(self$proportion_per_seq), 1,
521
                                                                                             self$proportion_per_seq) * 100, "% per sequence and maximum of ",
522
                                                                                      self$max_samples, " samples per file.")
523
                                                              }
524
                                                              tensorflow::tf$summary$scalar(name = paste("training files seen (%)"), data = files_percentage, step = epoch,
525
                                                                                            description = description)
526
                                                            }
527
                                                            
528
                                                            file.writer$flush()
529
                                                          }
530
                                                          
531
                                                        ))
532
    
533
    
534
    l[[4]] <- proportion_training_files_cb(num_train_files = num_train_files, path_file_log = path_file_log, path_tensorboard = path_tensorboard, run_name = run_name,
535
                                           vocabulary_label = vocabulary_label, path = path, train_type = train_type, proportion_per_seq = proportion_per_seq,
536
                                           max_samples = max_samples, step = step, proportion_entries = proportion_entries)
537
    #names(l) <- c("hyper_param_model_outside", "tensorboard", "function_args","proportion_training_files")
538
  } else {
539
    #names(l) <- c("hyper_param_model_outside", "tensorboard", "function_args")
540
  }
541
  return(l)
542
}
543
544
#' Reset states callback
545
#' 
546
#' Reset states at start/end of validation and whenever file changes. Can be used for stateful LSTM.
547
#' 
548
#' @param path_file_log Path to log of training files.
549
#' @param path_file_logVal  Path to log of validation files.
550
#' @examplesIf reticulate::py_module_available("tensorflow")
551
#' rs <- reset_states_cb(path_file_log = tempfile(), path_file_logVal = tempfile())
552
#' 
553
#' @returns A keras callback that resets states of LSTM layers. 
554
#' @export
555
reset_states_cb <- function(path_file_log, path_file_logVal) {
556
  
557
  reset_states_cb_py_class <- reticulate::PyClass("reset_states_cb",
558
                                                  inherit = tensorflow::tf$keras$callbacks$Callback,
559
                                                  list(
560
                                                    
561
                                                    `__init__` = function(self, path_file_log, path_file_logVal) {
562
                                                      self$path_file_log <- path_file_log
563
                                                      self$path_file_logVal <- path_file_logVal
564
                                                      self$num_files_old <- 0
565
                                                      self$num_files_new <- 0
566
                                                      self$num_files_old_val <- 0
567
                                                      self$num_files_new_val <- 0
568
                                                      NULL
569
                                                    },
570
                                                    
571
                                                    on_test_begin = function(self, epoch, logs) {
572
                                                      self$model$reset_states()
573
                                                    },
574
                                                    
575
                                                    on_test_end = function(self, epoch, logs) {
576
                                                      self$model$reset_states()
577
                                                    },
578
                                                    
579
                                                    on_train_batch_begin = function(self, batch, logs) {
580
                                                      files_used <- readLines(self$path_file_log)
581
                                                      self$num_files_new <- length(files_used)
582
                                                      if (self$num_files_new > self$num_files_old) {
583
                                                        self$model$reset_states()
584
                                                        self$num_files_old <- self$num_files_new
585
                                                      }
586
                                                    },
587
                                                    
588
                                                    on_test_batch_begin = function(self, batch, logs) {
589
                                                      files_used <- readLines(self$path_file_logVal)
590
                                                      self$num_files_new_val <- length(files_used)
591
                                                      if (self$num_files_new_val > self$num_files_old_val) {
592
                                                        self$model$reset_states()
593
                                                        self$num_files_old_val <- self$num_files_new_val
594
                                                      }
595
                                                    }
596
                                                    
597
                                                  ))
598
  
599
  reset_states_cb_py_class(path_file_log = path_file_log, path_file_logVal = path_file_logVal)
600
}
601
602
#' Validation after training callback
603
#' 
604
#' Do validation only once at end of training.
605
#' 
606
#' @param gen.val Validation generator
607
#' @param validation_steps Number of validation steps.
608
#' @examplesIf reticulate::py_module_available("tensorflow")
609
#' maxlen <- 20
610
#' model <- create_model_lstm_cnn(layer_lstm = 8, maxlen = maxlen)
611
#' gen <- get_generator(train_type = 'dummy_gen', model = model, batch_size = 4, maxlen = maxlen)
612
#' vat <- validation_after_training_cb(gen.val = gen, validation_steps = 10)
613
#' 
614
#' @returns Keras callback, apply validation only after training.
615
#' @export
616
validation_after_training_cb <- function(gen.val, validation_steps) {
617
  
618
  validation_after_training_cb_py_class <- reticulate::PyClass("validation_after_training_cb",
619
                                                               inherit = tensorflow::tf$keras$callbacks$Callback,
620
                                                               list(
621
                                                                 
622
                                                                 `__init__` = function(self, gen.val, validation_steps) {
623
                                                                   self$gen.val <- gen.val
624
                                                                   self$validation_steps <- validation_steps
625
                                                                   NULL
626
                                                                 },
627
                                                                 
628
                                                                 
629
                                                                 on_train_end = function(self, logs = list()) {
630
                                                                   validation_eval <- keras::evaluate_generator(
631
                                                                     object = self$model,
632
                                                                     generator = gen.val,
633
                                                                     steps = self$validation_steps,
634
                                                                     max_queue_size = 10,
635
                                                                     workers = 1,
636
                                                                     callbacks = NULL
637
                                                                   )
638
                                                                   self$model$val_loss <- validation_eval[["loss"]]
639
                                                                   self$model$val_acc <- validation_eval[["acc"]]
640
                                                                 }
641
                                                                 
642
                                                               ))
643
  
644
  validation_after_training_cb_py_class(gen.val = gen.val, validation_steps = validation_steps)
645
  
646
}
647
648
#' Confusion matrix callback.
649
#' 
650
#' Create a confusion matrix to display under tensorboard images. 
651
#'
652
#' @inheritParams train_model
653
#' @param confMatLabels Names of classes.
654
#' @param cm_dir Directory that contains confusion matrix files.
655
#' @examplesIf reticulate::py_module_available("tensorflow")
656
#' cm <- conf_matrix_cb(path_tensorboard = tempfile(), run_name = 'run_1',
657
#'                      confMatLabels = c('label_1', 'label_2'), cm_dir = tempfile())
658
#' 
659
#' @returns Keras callback, plot confusion matrix in tensorboard.
660
#' @export
661
conf_matrix_cb <- function(path_tensorboard, run_name, confMatLabels, cm_dir) {
662
  
663
  conf_matrix_cb_py_class <- reticulate::PyClass("conf_matrix_cb",
664
                                                 inherit = tensorflow::tf$keras$callbacks$Callback,
665
                                                 list(
666
                                                   
667
                                                   `__init__` = function(self, cm_dir, path_tensorboard, run_name, confMatLabels, graphics = "png") {
668
                                                     self$cm_dir <- cm_dir
669
                                                     self$path_tensorboard <- path_tensorboard
670
                                                     self$run_name <- run_name
671
                                                     self$plot_path_train <- tempfile(pattern = "", fileext = paste0(".", graphics))
672
                                                     self$plot_path_val <- tempfile(pattern = "", fileext = paste0(".", graphics))
673
                                                     self$confMatLabels <- confMatLabels
674
                                                     self$epoch <- 0
675
                                                     self$train_images <- NULL
676
                                                     self$val_images <- NULL
677
                                                     self$graphics <- graphics
678
                                                     self$epoch <- 0
679
                                                     self$text_size <- NULL
680
                                                     self$round_dig <- 3
681
                                                     if (length(confMatLabels) < 8) {
682
                                                       self$text_size <- (10 - (max(nchar(confMatLabels)) * 0.15)) * (0.95^length(confMatLabels))
683
                                                     }
684
                                                     self$cm_display_percentage <- TRUE
685
                                                     NULL
686
                                                   },
687
                                                   
688
                                                   on_epoch_begin = function(self, epoch, logs) {
689
                                                     #suppressMessages(library(yardstick))
690
                                                     if (epoch > 0) {
691
                                                       
692
                                                       cm_train <- readRDS(file.path(self$cm_dir, paste0("cm_train_", epoch-1, ".rds")))
693
                                                       cm_val <- readRDS(file.path(self$cm_dir, paste0("cm_val_", epoch-1, ".rds")))
694
                                                       if (self$cm_display_percentage) {
695
                                                         cm_train <- cm_perc(cm_train, self$round_dig)
696
                                                         cm_val <- cm_perc(cm_val, self$round_dig)
697
                                                       }
698
                                                       cm_train <- create_conf_mat_obj(cm_train, self$confMatLabels)
699
                                                       cm_val <- create_conf_mat_obj(cm_val, self$confMatLabels)
700
                                                       
701
                                                       
702
                                                       suppressMessages(
703
                                                         cm_plot_train <- ggplot2::autoplot(cm_train, type = "heatmap") +
704
                                                           ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1")  +
705
                                                           ggplot2::theme(axis.text.x =
706
                                                                            ggplot2::element_text(angle=90,hjust=1, size = self$text_size)) +
707
                                                           ggplot2::theme(axis.text.y =
708
                                                                            ggplot2::element_text(size = self$text_size))
709
                                                       )
710
                                                       
711
                                                       suppressMessages(
712
                                                         cm_plot_val <- ggplot2::autoplot(cm_val, type = "heatmap") +
713
                                                           ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1")  +
714
                                                           ggplot2::theme(axis.text.x =
715
                                                                            ggplot2::element_text(angle=90,hjust=1, size = self$text_size)) +
716
                                                           ggplot2::theme(axis.text.y =
717
                                                                            ggplot2::element_text(size = self$text_size))
718
                                                       )
719
                                                       
720
                                                       if (length(confMatLabels) > 4) {
721
                                                         plot_size <- (length(confMatLabels) * 1.3) + 1
722
                                                       } else {
723
                                                         plot_size <- length(confMatLabels) * 3
724
                                                       }
725
                                                       
726
                                                       if (self$graphics == "png") {
727
                                                         
728
                                                         suppressMessages(ggplot2::ggsave(filename = self$plot_path_train, plot = cm_plot_train, device = "png",
729
                                                                                          width = plot_size,
730
                                                                                          height = plot_size,
731
                                                                                          units = "cm"))
732
                                                         p_cm_train <- png::readPNG(self$plot_path_train)
733
                                                         
734
                                                         suppressMessages(ggplot2::ggsave(filename = self$plot_path_val, plot = cm_plot_val, device = "png",
735
                                                                                          width = plot_size,
736
                                                                                          height = plot_size,
737
                                                                                          units = "cm"))
738
                                                         p_cm_val <- png::readPNG(self$plot_path_val)
739
                                                         
740
                                                       } else {
741
                                                         
742
                                                         suppressMessages(ggplot2::ggsave(filename = self$plot_path_train, plot = cm_plot_train, device = "jpg",
743
                                                                                          width = plot_size,
744
                                                                                          height = plot_size,
745
                                                                                          units = "cm"))
746
                                                         p_cm_train <- jpeg::readJPEG(self$plot_path_train)
747
                                                         
748
                                                         suppressMessages(ggplot2::ggsave(filename = self$plot_path_val, plot = cm_plot_val, device = "jpg",
749
                                                                                          width = plot_size,
750
                                                                                          height = plot_size,
751
                                                                                          units = "cm"))
752
                                                         p_cm_train <- jpeg::readJPEG(self$plot_path_val)
753
                                                       }
754
                                                       
755
                                                       p_cm_train <- as.array(p_cm_train)
756
                                                       p_cm_train <- array(p_cm_train, dim = c(1, dim(p_cm_train)))
757
                                                       p_cm_val <- as.array(p_cm_val)
758
                                                       p_cm_val <- array(p_cm_val, dim = c(1, dim(p_cm_val)))
759
                                                       
760
                                                       num_images <- 1
761
                                                       train_images <- array(0, dim = c(num_images, dim(p_cm_train)[-1]))
762
                                                       train_images[1, , , ] <- p_cm_train
763
                                                       self$train_images <- train_images
764
                                                       
765
                                                       val_images <- array(0, dim = c(num_images, dim(p_cm_val)[-1]))
766
                                                       val_images[1, , , ] <- p_cm_val
767
                                                       self$val_images <- val_images
768
                                                       file.writer <- tensorflow::tf$summary$create_file_writer(file.path(self$path_tensorboard, self$run_name))
769
                                                       file.writer$set_as_default()
770
                                                       tensorflow::tf$summary$image(name = "confusion matrix train", data = self$train_images, step = as.integer(epoch-1))
771
                                                       tensorflow::tf$summary$image(name = "confusion matrix validation", data = self$val_images, step = as.integer(epoch-1))
772
                                                       file.writer$flush()
773
                                                       self$epoch <- epoch
774
                                                     }
775
                                                   },
776
                                                   
777
                                                   on_train_end = function(self, logs) {
778
                                                     
779
                                                     epoch <- self$epoch + 1
780
                                                     
781
                                                     # create confusion matrix for last val step manually (storing cm when calling reset_state)
782
                                                     for (i in 1:length(self$model$metrics)) {
783
                                                       if (self$model$metrics[[i]]$name == "balanced_acc") {
784
                                                         self$model$metrics[[i]]$reset_state()
785
                                                       }
786
                                                     }
787
                                                     
788
                                                     cm_train <- readRDS(file.path(self$cm_dir, paste0("cm_train_", epoch-1, ".rds")))
789
                                                     cm_val <- readRDS(file.path(self$cm_dir, paste0("cm_val_", epoch-1, ".rds")))
790
                                                     if (self$cm_display_percentage) {
791
                                                       cm_train <- cm_perc(cm_train, self$round_dig)
792
                                                       cm_val <- cm_perc(cm_val, self$round_dig)
793
                                                     }
794
                                                     cm_train <- create_conf_mat_obj(cm_train, self$confMatLabels)
795
                                                     cm_val <- create_conf_mat_obj(cm_val, self$confMatLabels)
796
                                                     
797
                                                     
798
                                                     suppressMessages(
799
                                                       cm_plot_train <- ggplot2::autoplot(cm_train, type = "heatmap") +
800
                                                         ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1")  +
801
                                                         ggplot2::theme(axis.text.x =
802
                                                                          ggplot2::element_text(angle=90,hjust=1, size = self$text_size)) +
803
                                                         ggplot2::theme(axis.text.y =
804
                                                                          ggplot2::element_text(size = self$text_size))
805
                                                     )
806
                                                     
807
                                                     suppressMessages(
808
                                                       cm_plot_val <- ggplot2::autoplot(cm_val, type = "heatmap") +
809
                                                         ggplot2::scale_fill_gradient(low="#D6EAF8", high = "#2E86C1")  +
810
                                                         ggplot2::theme(axis.text.x =
811
                                                                          ggplot2::element_text(angle=90,hjust=1, size = self$text_size)) +
812
                                                         ggplot2::theme(axis.text.y =
813
                                                                          ggplot2::element_text(size = self$text_size))
814
                                                     )
815
                                                     
816
                                                     if (length(confMatLabels) > 4) {
817
                                                       plot_size <- (length(confMatLabels) * 1.3) + 1
818
                                                     } else {
819
                                                       plot_size <- length(confMatLabels) * 3
820
                                                     }
821
                                                     
822
                                                     if (self$graphics == "png") {
823
                                                       
824
                                                       suppressMessages(ggplot2::ggsave(filename = self$plot_path_train, plot = cm_plot_train, device = "png",
825
                                                                                        width = plot_size,
826
                                                                                        height = plot_size,
827
                                                                                        units = "cm"))
828
                                                       p_cm_train <- png::readPNG(self$plot_path_train)
829
                                                       
830
                                                       suppressMessages(ggplot2::ggsave(filename = self$plot_path_val, plot = cm_plot_val, device = "png",
831
                                                                                        width = plot_size,
832
                                                                                        height = plot_size,
833
                                                                                        units = "cm"))
834
                                                       p_cm_val <- png::readPNG(self$plot_path_val)
835
                                                       
836
                                                     } else {
837
                                                       
838
                                                       suppressMessages(ggplot2::ggsave(filename = self$plot_path_train, plot = cm_plot_train, device = "jpg",
839
                                                                                        width = plot_size,
840
                                                                                        height = plot_size,
841
                                                                                        units = "cm"))
842
                                                       p_cm_train <- jpeg::readJPEG(self$plot_path_train)
843
                                                       
844
                                                       suppressMessages(ggplot2::ggsave(filename = self$plot_path_val, plot = cm_plot_val, device = "jpg",
845
                                                                                        width = plot_size,
846
                                                                                        height = plot_size,
847
                                                                                        units = "cm"))
848
                                                       p_cm_train <- jpeg::readJPEG(self$plot_path_val)
849
                                                     }
850
                                                     
851
                                                     p_cm_train <- as.array(p_cm_train)
852
                                                     p_cm_train <- array(p_cm_train, dim = c(1, dim(p_cm_train)))
853
                                                     p_cm_val <- as.array(p_cm_val)
854
                                                     p_cm_val <- array(p_cm_val, dim = c(1, dim(p_cm_val)))
855
                                                     
856
                                                     num_images <- 1
857
                                                     train_images <- array(0, dim = c(num_images, dim(p_cm_train)[-1]))
858
                                                     train_images[1, , , ] <- p_cm_train
859
                                                     self$train_images <- train_images
860
                                                     
861
                                                     val_images <- array(0, dim = c(num_images, dim(p_cm_val)[-1]))
862
                                                     val_images[1, , , ] <- p_cm_val
863
                                                     self$val_images <- val_images
864
                                                     file.writer <- tensorflow::tf$summary$create_file_writer(file.path(self$path_tensorboard, self$run_name))
865
                                                     file.writer$set_as_default()
866
                                                     tensorflow::tf$summary$image(name = "confusion matrix train", data = self$train_images, step = as.integer(epoch-1))
867
                                                     tensorflow::tf$summary$image(name = "confusion matrix validation", data = self$val_images, step = as.integer(epoch-1))
868
                                                     file.writer$flush()
869
                                                   }
870
                                                 ))
871
  conf_matrix_cb_py_class(path_tensorboard = path_tensorboard,
872
                          run_name = run_name,
873
                          confMatLabels = confMatLabels,
874
                          cm_dir = cm_dir)
875
}
876
877
878
get_callbacks <- function(default_arguments, model, path_tensorboard, run_name, train_type,
879
                          path, train_val_ratio, batch_size, epochs, format,
880
                          max_queue_size, lr_plateau_factor, patience, cooldown, path_checkpoint,
881
                          steps_per_epoch, step, shuffle_file_order, initial_epoch, vocabulary,
882
                          learning_rate, shuffle_input, vocabulary_label, solver, dataset_val,
883
                          file_limit, reverse_complement, wavenet_format, cnn_format,
884
                          create_model_function = NULL, vocabulary_size, gen_cb, argumentList,
885
                          maxlen, labelGen, labelByFolder, vocabulary_label_size, tb_images,
886
                          target_middle, path_file_log, proportion_per_seq, validation_steps,
887
                          train_val_split_csv, n_gram, path_file_logVal, model_card,
888
                          skip_amb_nuc, max_samples, proportion_entries, path_log, output,
889
                          train_with_gen, random_sampling, reduce_lr_on_plateau,
890
                          save_weights_only, save_best_only, reset_states, early_stopping_time,
891
                          validation_only_after_training, gen.val, target_from_csv) {
892
  
893
  if (output$checkpoints) {
894
    # create folder for checkpoints using run_name
895
    checkpoint_dir <- paste0(path_checkpoint, "/", run_name)
896
    dir.create(checkpoint_dir, showWarnings = FALSE)
897
    if (!is.list(model$output) & !is.null(gen.val)) {
898
      # filename with epoch, validation loss and validation accuracy
899
      filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}-val_loss{val_loss:.2f}-val_acc{val_acc:.3f}.hdf5")
900
    } else {
901
      
902
      # if (is.null(gen.val)) {
903
      #   filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}-loss{loss:.2f}-acc{acc:.3f}.hdf5")
904
      # } else {
905
      filepath_checkpoints <- file.path(checkpoint_dir, "Ep.{epoch:03d}.hdf5")
906
      if ((is.list(save_best_only) && !is.null(save_best_only$monitor)) & is.null(dataset_val)) {
907
        warning("save_best_only not implemented for multi target or training without validation data. Setting save_best_only to NULL.")
908
        save_best_only <- NULL
909
      }
910
      #}
911
      
912
    }
913
  }
914
  
915
  # Check if path_file_log is unique
916
  if (!is.null(path_file_log) && dir.exists(path_file_log)) {
917
    stop(paste0("path_file_log entry is already present. Please give this file a unique name."))
918
  }
919
  
920
  count_files <- !random_sampling
921
  callbacks <- list()
922
  callback_names <- NULL
923
  
924
  if (reduce_lr_on_plateau) {
925
    if (is.list(model$outputs)) {
926
      monitor <- "val_loss"
927
    } else {
928
      monitor <- "val_acc"
929
    }
930
    callbacks[[1]] <- reduce_lr_cb(patience = patience, cooldown = cooldown,
931
                                   lr_plateau_factor = lr_plateau_factor,
932
                                   monitor = monitor)
933
    callback_names <- c("reduce_lr", callback_names)
934
  }
935
  
936
  if (!is.null(path_log)) {
937
    callbacks <- c(callbacks, log_cb(path_log, run_name))
938
    callback_names <- c("log", callback_names)
939
  }
940
  
941
  if (!output$tensorboard) tb_images <- FALSE
942
  if (output$tensorboard) {
943
    
944
    # add balanced acc score
945
    model <- manage_metrics(model)
946
    if (train_with_gen) {
947
      num_targets <- ifelse(train_type == "lm", length(vocabulary), length(vocabulary_label))
948
    } else {
949
      num_targets <- dim(dataset_val$Y)[2]
950
    }
951
    contains_macro_acc_metric <- FALSE
952
    for (i in 1:length(model$metrics)) {
953
      if (model$metrics[[i]]$name == "balanced_acc") contains_macro_acc_metric <- TRUE
954
    }
955
    
956
    metric_names <- vector("character", length(model$metrics))
957
    for (i in 1:length(model$metrics)) {
958
      metric_names[i] <-  model$metrics[[i]]$name
959
    }
960
    loss_index <- stringr::str_detect(metric_names, "loss")
961
    
962
    if (!contains_macro_acc_metric) {
963
      if (tb_images) {
964
        if (!reticulate::py_has_attr(model, "cm_dir")) {
965
          cm_dir <- file.path(tempdir(), paste(sample(letters, 7), collapse = ""))
966
          dir.create(cm_dir)
967
          model$cm_dir <- cm_dir
968
        }
969
        
970
        metrics <- c(model$metrics[!loss_index], balanced_acc_wrapper(num_targets = num_targets, cm_dir = model$cm_dir))
971
      }
972
    } else {
973
      metrics <- c(model$metrics[!loss_index])
974
    }
975
    
976
    # count files in path
977
    if (train_type == "label_rds" | train_type == "lm_rds") format <- "rds"
978
    if (train_with_gen) {
979
      num_train_files <- count_files(path = path, format = format, train_type = train_type, 
980
                                     target_from_csv = target_from_csv, 
981
                                     train_val_split_csv = train_val_split_csv)
982
    } else {
983
      num_train_files <- 1
984
    }
985
    
986
    complete_tb <- tensorboard_complete_cb(default_arguments = default_arguments, model = model, path_tensorboard = path_tensorboard, run_name = run_name, train_type = train_type,
987
                                           path = path, train_val_ratio = train_val_ratio, batch_size = batch_size, epochs = epochs,
988
                                           max_queue_size = max_queue_size, lr_plateau_factor = lr_plateau_factor, patience = patience, cooldown = cooldown,
989
                                           steps_per_epoch = steps_per_epoch, step = step, shuffle_file_order = shuffle_file_order, initial_epoch = initial_epoch, vocabulary = vocabulary,
990
                                           learning_rate = learning_rate, shuffle_input = shuffle_input, vocabulary_label = vocabulary_label, solver = solver,
991
                                           file_limit = file_limit, reverse_complement = reverse_complement, wavenet_format = wavenet_format,  cnn_format = cnn_format,
992
                                           create_model_function = NULL, vocabulary_size = vocabulary_size, gen_cb = gen_cb, argumentList = argumentList,
993
                                           maxlen = maxlen, labelGen = labelGen, labelByFolder = labelByFolder, vocabulary_label_size = vocabulary_label_size, tb_images = FALSE,
994
                                           target_middle = target_middle, num_train_files = num_train_files, path_file_log = path_file_log, proportion_per_seq = proportion_per_seq,
995
                                           skip_amb_nuc = skip_amb_nuc, max_samples = max_samples, proportion_entries = proportion_entries,
996
                                           train_with_gen = train_with_gen, count_files = !random_sampling)
997
    callbacks <- c(callbacks, complete_tb)
998
    callback_names <- c(callback_names, names(complete_tb))
999
  }
1000
  
1001
  if (output$checkpoints) {
1002
    if (wavenet_format) {
1003
      # can only save weights for wavenet
1004
      save_weights_only <- TRUE
1005
    }
1006
    callbacks <- c(callbacks, checkpoint_cb(filepath_checkpoints = filepath_checkpoints, save_weights_only = save_weights_only,
1007
                                            save_best_only = save_best_only))
1008
    callback_names <- c(callback_names, "checkpoint")
1009
  }
1010
  
1011
  if (reset_states) {
1012
    callbacks <- c(callbacks, reset_states_cb(path_file_log = path_file_log, path_file_logVal = path_file_logVal))
1013
    callback_names <- c(callback_names, "reset_states")
1014
  }
1015
  
1016
  if (!is.null(early_stopping_time)) {
1017
    callbacks <- c(callbacks, early_stopping_cb(early_stopping_time = early_stopping_time))
1018
    callback_names <- c(callback_names, "early_stopping")
1019
  }
1020
  
1021
  if (validation_only_after_training) {
1022
    if (!train_with_gen) stop("Validation after training only implemented for generator")
1023
    callbacks <- c(callbacks, validation_after_training_cb(gen.val = gen.val, validation_steps = validation_steps))
1024
    callback_names <- c(callback_names, "validation_after_training")
1025
  }
1026
  
1027
  if (!is.null(model_card)) {
1028
    callbacks <- c(callbacks, model_card_cb(model_card_path = model_card$path_model_card,
1029
                                            run_name = run_name, argumentList = argumentList))
1030
  }
1031
  
1032
  if (tb_images) {
1033
    if (is.list(model$output)) {
1034
      warning("Tensorboard images (confusion matrix) not implemented for model with multiple outputs.
1035
                 Setting tb_images to FALSE")
1036
      tb_images <- FALSE
1037
    }
1038
    
1039
    if (model$loss == "binary_crossentropy") {
1040
      warning("Tensorboard images (confusion matrix) not implemented for sigmoid activation in last layer.
1041
                 Setting tb_images to FALSE")
1042
      tb_images <- FALSE
1043
    }
1044
  }
1045
  
1046
  if (tb_images) {
1047
    
1048
    confMatLabels <- vocabulary_label
1049
    if (train_with_gen & train_type == "lm") {
1050
      if (is.null(n_gram) || n_gram == 1) {
1051
        confMatLabels <- vocabulary
1052
      } else {
1053
        l <- list()
1054
        for (i in 1:n_gram) {
1055
          l[[i]] <- vocabulary
1056
        }
1057
        confMatLabels <- expand.grid(l) %>% apply(1, paste0) %>% apply(2, paste, collapse = "") %>% sort()
1058
      }
1059
    }
1060
    
1061
    model <- model %>% keras::compile(loss = model$loss,
1062
                                      optimizer = model$optimizer, metrics = metrics)
1063
    
1064
    if (length(confMatLabels) > 16) {
1065
      message("Cannot display confusion matrix with more than 16 labels.")
1066
    } else {
1067
      
1068
      callbacks <- c(callbacks, conf_matrix_cb(path_tensorboard = path_tensorboard,
1069
                                               run_name = run_name,
1070
                                               confMatLabels = confMatLabels,
1071
                                               cm_dir = model$cm_dir))
1072
      callback_names <- c(callback_names, "conf_matrix")
1073
    }
1074
  }
1075
  
1076
  return(callbacks)
1077
}