#' @title Train neural network on genomic data
#'
#' @description
#' Train a neural network on genomic data. Data can be fasta/fastq files, rds files or a prepared data set.
#' If the data is given as a collection of fasta, fastq or rds files, the function will create a data generator that extracts training and validation batches
#' from the files. The function includes several options to determine the sampling strategy of the generator and the preprocessing of the data.
#' Training progress can be visualized in tensorboard. Model weights can be stored during training using checkpoints.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_folder
#' @inheritParams generator_fasta_label_header_csv
#' @inheritParams get_generator
#' @param train_type Either `"lm"`, `"lm_rds"`, `"masked_lm"` for language model; `"label_header"`, `"label_folder"`, `"label_csv"`, `"label_rds"` for classification or `"dummy_gen"`.
#' \itemize{
#' \item Language model is trained to predict character(s) in a sequence. \cr
#' \item `"label_header"`/`"label_folder"`/`"label_csv"` are trained to predict a corresponding class given a sequence as input.
#' \item If `"label_header"`, class will be read from fasta headers.
#' \item If `"label_folder"`, class will be read from folder, i.e. all files in one folder must belong to the same class.
#' \item If `"label_csv"`, targets are read from a csv file. This file should have one column named "file"; the remaining entries in a row are the targets for that file.
#' Example: if we are currently working with a file called "a.fasta" and the corresponding label is "label_1", there should be a row in our csv file
#'
#'  |  file       | label_1 | label_2 |
#'  |   ---       |   ---   |  ---    |
#'  | "a.fasta"   |    1    |    0    |
#'
#' \item If `"label_rds"`, generator will iterate over a set of .rds files, each containing a list of input and target tensors. Not implemented for models
#' with multiple inputs.
#' \item If `"lm_rds"`, generator will iterate over a set of .rds files and will split each tensor according to the `target_len` argument
#' (targets are the last `target_len` nucleotides of each sequence).
#' \item If `"dummy_gen"`, generator creates random data once and repeatedly feeds these to the model.
#' \item If `"masked_lm"`, generator masks some parts of the input. See the `masked_lm` argument for details.
#' }
#' @param model A keras model.
#' @param path Path to training data. If \code{train_type} is \code{label_folder}, should be a vector or list
#' where each entry corresponds to a class (list elements can be directories and/or individual files). If \code{train_type} is not \code{label_folder},
#' can be a single directory or file or a list of directories and/or files.
#' @param path_val Path to validation data. See `path` argument for details.
#' @param dataset List of training data holding training samples in RAM instead of using generator. Should be a list with two entries called `"X"` and `"Y"`.
#' @param dataset_val List of validation data. Should have two entries called `"X"` and `"Y"`.
#' @param path_checkpoint Path to checkpoints folder or `NULL`. If `NULL`, checkpoints don't get stored.
#' @param path_log Path to directory to write training scores. File name is `run_name` + `".csv"`. No output if `NULL`.
#' @param train_val_ratio For the generator, defines the fraction of batches that will be used for validation (compared to size of training data), i.e. one validation iteration
#' processes \code{batch_size} \eqn{*} \code{steps_per_epoch} \eqn{*} \code{train_val_ratio} samples. If you use a dataset instead of a generator and \code{dataset_val} is `NULL`, splits \code{dataset}
#' into train/validation data.
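#' For example (illustrative numbers): with \code{batch_size = 64}, \code{steps_per_epoch = 1000} and \code{train_val_ratio = 0.2},
#' one validation iteration processes 64 \eqn{*} 1000 \eqn{*} 0.2 = 12800 samples.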
#' @param run_name Name of the run. Name will be used to identify output from callbacks. If `NULL`, will use date as run name.
#' If the name is already present, will add `"_2"` to the name, or `"_{x+1}"` if the name ends with `_x`, where `x` is some integer.
#' @param batch_size Number of samples used for one network update.
#' @param epochs Number of training epochs.
#' @param max_queue_size Maximum size for the generator queue.
#' @param reduce_lr_on_plateau Whether to use learning rate scheduler.
#' @param lr_plateau_factor Factor of decreasing learning rate when plateau is reached.
#' @param patience Number of epochs waiting for decrease in validation loss before reducing learning rate.
#' @param cooldown Number of epochs without changing learning rate.
#' @param steps_per_epoch Number of training batches per epoch.
#' @param step Frequency of sampling steps.
#' @param shuffle_file_order Boolean, whether to go through files sequentially or shuffle beforehand.
#' @param vocabulary Vector of allowed characters. Characters outside vocabulary get encoded as specified in \code{ambiguous_nuc}.
#' @param initial_epoch Epoch at which to start training. Note that network
#' will run for (\code{epochs} - \code{initial_epoch}) rounds and not \code{epochs} rounds.
#' @param path_tensorboard Path to tensorboard directory or `NULL`. If `NULL`, training not tracked on tensorboard.
#' @param save_best_only Only save a model that improved on some score. Not applied if argument is `NULL`. Otherwise must be a
#' list with entry `monitor` or `save_freq` (can only use one option). `monitor` specifies what metric to use.
#' `save_freq` is an integer specifying how often to store a checkpoint (in epochs).
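#' For example (illustrative): `save_best_only = list(monitor = "val_loss")` or `save_best_only = list(save_freq = 5)`.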
#' @param save_weights_only Whether to save weights only.
#' @param seed Sets seed for reproducible results.
#' @param shuffle_input Whether to shuffle entries in file.
#' @param tb_images Whether to show custom images (confusion matrix) in tensorboard "IMAGES" tab.
#' @param format File format, `"fasta"`, `"fastq"`, `"rds"` or `"fasta.tar.gz"`, `"fastq.tar.gz"` for `tar.gz` files.
#' @param path_file_log Write names of files used for training to csv file if path is specified.
#' @param vocabulary_label Character vector of possible targets. Targets outside \code{vocabulary_label} will get discarded if
#' \code{train_type = "label_header"}.
#' @param file_limit Integer or `NULL`. If integer, use only specified number of randomly sampled files for training. Ignored if greater than number of files in \code{path}.
#' @param reverse_complement_encoding Whether to use both original sequence and reverse complement as two input sequences.
#' @param output_format Determines shape of output tensor for language model.
#' Either `"target_right"`, `"target_middle_lstm"`, `"target_middle_cnn"` or `"wavenet"`.
#' Assume a sequence `"AACCGTA"`. Outputs correspond as follows:
#' \itemize{
#' \item `"target_right": X = "AACCGT", Y = "A"`
#' \item `"target_middle_lstm": X = (X_1 = "AAC", X_2 = "ATG"), Y = "C"` (note reversed order of X_2)
#' \item `"target_middle_cnn": X = "AACGTA", Y = "C"`
#' \item `"wavenet": X = "AACCGT", Y = "ACCGTA"`
#' }
#' @param reset_states Whether to reset hidden states of RNN layer at every new input file and before/after validation.
#' @param use_quality_score Whether to use fastq quality scores. If `TRUE`, input is not one-hot encoded but corresponds to probabilities.
#' For example, (0.97, 0.01, 0.01, 0.01) instead of (1, 0, 0, 0).
#' @param padding Whether to pad sequences too short for one sample with zeros.
#' @param early_stopping_time Time in seconds after which to stop training.
#' @param validation_only_after_training Whether to skip validation during training and only do one validation iteration after training.
#' @param skip_amb_nuc Threshold of ambiguous nucleotides to accept in a fasta entry. Complete entry will get discarded otherwise.
#' @param class_weight List of weights for output. Order should correspond to \code{vocabulary_label}.
#' You can use the \code{\link{get_class_weight}} function to estimate class weights:
#'
#' \code{class_weights <- get_class_weight(path = path, train_type = train_type)}
#'
#' If \code{train_type = "label_csv"} you need to add the path to the csv file:
#'
#' \code{class_weights <- get_class_weight(path = path, train_type = train_type, csv_path = target_from_csv)}
#' @param print_scores Whether to print train/validation scores during training.
#' @param train_val_split_csv A csv file specifying the train/validation split. The csv file should contain one column named `"file"` and one column named
#' `"type"`. The `"file"` column contains names of fasta/fastq files and the `"type"` column specifies if a file is used for training or validation.
#' Entries in `"type"` must be `"train"` or `"val"` (or `"validation"`), otherwise the file will not be used for either. `path` and `path_val` arguments should be the same.
#' Not implemented for `train_type = "label_folder"`. See the example layout below.
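#'
#' For example, a split file could look like
#'
#'  |  file       | type    |
#'  |   ---       |   ---   |
#'  | "a.fasta"   | train   |
#'  | "b.fasta"   | val     |
#'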
#' @param set_learning When you want to assign one label to a set of samples. Only implemented for `train_type = "label_folder"`.
#' Input is a list with the following parameters (see the example after this list):
#' \itemize{
#' \item `samples_per_target`: how many samples to use for one target.
#' \item `maxlen`: length of one sample.
#' \item `reshape_mode`: `"time_dist"`, `"multi_input"` or `"concat"`.
#' \itemize{
#' \item If `reshape_mode` is `"multi_input"`, generator will produce `samples_per_target` separate inputs, each of length `maxlen` (model should have
#' `samples_per_target` input layers).
#' \item If `reshape_mode` is `"time_dist"`, generator will produce a 4D input array. The dimensions correspond to
#' `(batch_size, samples_per_target, maxlen, length(vocabulary))`.
#' \item If `reshape_mode` is `"concat"`, generator will concatenate `samples_per_target` sequences
#' of length `maxlen` to one long sequence.
#' }
#' \item If `reshape_mode` is `"concat"`, there is an additional `buffer_len`
#' argument. If `buffer_len` is an integer, the subsequences are interspaced with `buffer_len` rows. The input length is
#' (`maxlen` \eqn{*} `samples_per_target`) + `buffer_len` \eqn{*} (`samples_per_target` - 1).
#' }
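#' For example (illustrative values), `set_learning = list(samples_per_target = 4, maxlen = 50, reshape_mode = "concat", buffer_len = 1)`
#' produces one concatenated input of length 4 \eqn{*} 50 + 1 \eqn{*} (4 - 1) = 203.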
#' @param random_sampling Whether samples should be taken from random positions when using `max_samples` argument. If `FALSE`, random
#' samples are taken from a consecutive subsequence.
#' @param n_gram_stride Step size for n-gram encoding. For AACCGGTT with `n_gram = 4` and `n_gram_stride = 2`, generator encodes
#' `(AACC), (CCGG), (GGTT)`; for `n_gram_stride = 4` generator encodes `(AACC), (GGTT)`.
#' @param callback_list Add additional callbacks to the `keras::fit` call.
#' @param model_card List of arguments describing the training run. Must contain at least an entry `path_model_card`, i.e. the
#' directory where parameters are stored. List can contain additional (optional) arguments, for example
#' `model_card = list(path_model_card = "/path/to/logs", description = "transfer learning with BERT model on virus data", ...)`
#' @param return_gen Whether to return the train and validation generators (instead of training).
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create dummy data
#' path_train_1 <- tempfile()
#' path_train_2 <- tempfile()
#' path_val_1 <- tempfile()
#' path_val_2 <- tempfile()
#' 
#' for (current_path in c(path_train_1, path_train_2,
#'                        path_val_1, path_val_2)) {
#'   dir.create(current_path)
#'   create_dummy_data(file_path = current_path,
#'                     num_files = 3,
#'                     seq_length = 10,
#'                     num_seq = 5,
#'                     vocabulary = c("a", "c", "g", "t"))
#' }
#' 
#' # create model
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 2, maxlen = 5)
#' 
#' # train model
#' hist <- train_model(train_type = "label_folder",
#'                     model = model,
#'                     path = c(path_train_1, path_train_2),
#'                     path_val = c(path_val_1, path_val_2),
#'                     batch_size = 8,
#'                     epochs = 3,
#'                     steps_per_epoch = 6,
#'                     step = 5,
#'                     format = "fasta",
#'                     vocabulary_label = c("label_1", "label_2"))
#' 
#' @returns A list of training metrics.
#' @export
train_model <- function(model = NULL,
                        dataset = NULL,
                        dataset_val = NULL,
                        # training args
                        train_val_ratio = 0.2,
                        run_name = "run_1",
                        initial_epoch = 0,
                        class_weight = NULL,
                        print_scores = TRUE,
                        epochs = 10,
                        max_queue_size = 100,
                        steps_per_epoch = 1000,
                        # callbacks
                        path_checkpoint = NULL,
                        path_tensorboard = NULL,
                        path_log = NULL,
                        save_best_only = NULL,
                        save_weights_only = FALSE,
                        tb_images = FALSE,
                        path_file_log = NULL,
                        reset_states = FALSE,
                        early_stopping_time = NULL,
                        validation_only_after_training = FALSE,
                        train_val_split_csv = NULL,
                        reduce_lr_on_plateau = TRUE,
                        lr_plateau_factor = 0.9,
                        patience = 20,
                        cooldown = 1,
                        model_card = NULL,
                        callback_list = NULL,
                        # generator args
                        train_type = "label_folder",
                        path = NULL,
                        path_val = NULL,
                        batch_size = 64,
                        step = NULL,
                        shuffle_file_order = TRUE,
                        vocabulary = c("a", "c", "g", "t"),
                        format = "fasta",
                        ambiguous_nuc = "zero",
                        seed = c(1234, 4321),
                        file_limit = NULL,
                        use_coverage = NULL,
                        set_learning = NULL,
                        proportion_entries = NULL,
                        sample_by_file_size = FALSE,
                        n_gram = NULL,
                        n_gram_stride = 1,
                        masked_lm = NULL,
                        random_sampling = FALSE,
                        add_noise = NULL,
                        return_int = FALSE,
                        maxlen = NULL,
                        reverse_complement = FALSE,
                        reverse_complement_encoding = FALSE,
                        output_format = "target_right",
                        proportion_per_seq = NULL,
                        read_data = FALSE,
                        use_quality_score = FALSE,
                        padding = FALSE,
                        concat_seq = NULL,
                        target_len = 1,
                        skip_amb_nuc = NULL,
                        max_samples = NULL,
                        added_label_path = NULL,
                        add_input_as_seq = NULL,
                        target_from_csv = NULL,
                        target_split = NULL,
                        shuffle_input = TRUE,
                        vocabulary_label = NULL,
                        delete_used_files = FALSE,
                        reshape_xy = NULL,
                        return_gen = FALSE) {
  
  if (!is.null(model_card)) {
    if (!is.list(model_card) || is.null(model_card$path_model_card)) {
      stop("model_card must be a list and contain at least an entry called 'path_model_card'")
    }
  }
  
  # initialize metrics, temporary fix
  model <- manage_metrics(model)
  
  run_name <- get_run_name(run_name, path_tensorboard, path_checkpoint, path_log,
                           path_model_card = model_card$path_model_card,
                           auto_extend = TRUE)
  train_with_gen <- is.null(dataset)
  output <- list(tensorboard = FALSE, checkpoints = FALSE)
  if (!is.null(path_tensorboard)) output$tensorboard <- TRUE
  if (!is.null(path_checkpoint)) output$checkpoints <- TRUE
  wavenet_format <- FALSE; target_middle <- FALSE; cnn_format <- FALSE
  if (train_type != "label_csv") target_from_csv <- NULL
  
  if (train_with_gen) {
    stopifnot(train_type %in% c("lm", "label_header", "label_folder", "label_csv", "label_rds", "lm_rds", "dummy_gen", "masked_lm"))
    stopifnot(ambiguous_nuc %in% c("zero", "equal", "discard", "empirical"))
    stopifnot(length(vocabulary) == length(unique(vocabulary)))
    stopifnot(length(vocabulary_label) == length(unique(vocabulary_label)))
    labelByFolder <- FALSE
    labelGen <- ifelse(train_type == "lm", FALSE, TRUE)
    
    if (train_type == "label_header") target_from_csv <- NULL
    if (train_type == "label_csv") {
      #train_type <- "label_header"
      if (is.null(target_from_csv)) {
        stop('You need to add a path to csv file for target_from_csv when using train_type = "label_csv"')
      }
      if (is.null(vocabulary_label)) {
        message("Reading vocabulary_label from csv header")
        if (!is.data.frame(target_from_csv)) {
          output_label_csv <- utils::read.csv2(target_from_csv, header = TRUE, stringsAsFactors = FALSE)
          if (dim(output_label_csv)[2] == 1) {
            output_label_csv <- utils::read.csv(target_from_csv, header = TRUE, stringsAsFactors = FALSE)
          }
        } else {
          output_label_csv <- target_from_csv
        }
        vocabulary_label <- names(output_label_csv)
        vocabulary_label <- vocabulary_label[vocabulary_label != "file"]
      }
    }
    
    if (!is.null(skip_amb_nuc)) {
      if ((skip_amb_nuc > 1) || (skip_amb_nuc < 0)) {
        stop("skip_amb_nuc should be between 0 and 1 or NULL")
      }
    }
    
    if (!is.null(proportion_per_seq)) {
      if (any(proportion_per_seq > 1) || any(proportion_per_seq < 0)) {
        stop("proportion_per_seq should be between 0 and 1 or NULL")
      }
    }
    
    # TODO: adjust for multi output model
    # if (!is.null(class_weight) && (length(class_weight) != length(vocabulary_label))) {
    #   stop("class_weight and vocabulary_label must have same length")
    # }
    
    if (!is.null(concat_seq)) {
      if (!is.null(use_coverage)) stop("Coverage encoding not implemented for concat_seq")
    }
    
    # train/validation split via csv file
    if (!is.null(train_val_split_csv)) {
      
      train_val_file <- utils::read.csv2(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      if (dim(train_val_file)[2] == 1) {
        train_val_file <- utils::read.csv(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      }
      train_val_file <- dplyr::distinct(train_val_file)
      
      if (!all(c("file", "type") %in% names(train_val_file))) {
        stop("Column names of train_val_split_csv file must be 'file' and 'type'")
      }
      
      if (train_type == "label_folder") {
        stop('train_val_split_csv not implemented for train_type = "label_folder"')
      }
      
      if (is.null(path)) {
        path <- train_val_file %>% dplyr::filter(type %in% c("train", "val", "validation")) %>%
          dplyr::select(file) %>% as.list()
      }
      
      if (is.null(path_val)) {
        path_val <- path
      } else {
        if (!all(unlist(path_val) %in% unlist(path))) {
          warning("Train/validation split done via file in train_val_split_csv. Only using files from path argument.")
        }
        path_val <- path
      }
      
      if (length(train_val_file$file) != length(unique(train_val_file$file))) {
        stop("In train_val_split_csv all entries in 'file' column must be unique")
      }
      
      train_files <- train_val_file %>% dplyr::filter(type == "train")
      train_files <- as.character(train_files$file)
      val_files <- train_val_file %>% dplyr::filter(type == "val" | type == "validation")
      val_files <- as.character(val_files$file)
    } else {
      train_files <- NULL
      val_files <- NULL
    }
    
    if (train_type == "lm") {
      stopifnot(output_format %in% c("target_right", "target_middle_lstm", "target_middle_cnn", "wavenet"))
      if (output_format == "target_middle_lstm") target_middle <- TRUE
      if (output_format == "target_middle_cnn") cnn_format <- TRUE
      if (output_format == "wavenet") wavenet_format <- TRUE
    }
    
    if (train_type == "label_header" && is.null(target_from_csv)) {
      stopifnot(!is.null(vocabulary_label))
    }
    
    if (train_type == "label_folder") {
      labelByFolder <- TRUE
      stopifnot(!is.null(vocabulary_label))
      stopifnot(length(path) == length(vocabulary_label))
    }
    
  }
  
  model_weights <- model$get_weights()
  
  # function arguments
  argumentList <- as.list(match.call(expand.dots = FALSE))
  #argumentList <- c(as.list(environment()), list(...)) log default args too
  argumentList <- argumentList[names(argumentList) != ""]
  argumentList <- lapply(argumentList, eval, envir = parent.frame())
  
  # extract maxlen from model
  if (is.null(maxlen)) {
    maxlen <- get_maxlen(model, set_learning, target_middle, read_data)
  }
  
  if (is.null(step)) step <- maxlen
  vocabulary_label_size <- length(vocabulary_label)
  vocabulary_size <- length(vocabulary)
  
  if (is.null(dataset) && labelByFolder) {
    if (length(path) == 1) warning("Training with just one label")
  }
  
  # add empty hparam dict if none exists
  if (!reticulate::py_has_attr(model, "hparam")) {
    model$hparam <- reticulate::dict()
  }
  
  # temporary file to log training data
  removeLog <- FALSE
  if (is.null(path_file_log)) {
    removeLog <- TRUE
    path_file_log <- tempfile(pattern = "", fileext = ".csv")
  } else {
    if (!endsWith(path_file_log, ".csv")) path_file_log <- paste0(path_file_log, ".csv")
    #path_file_logVal <- tempfile(pattern = "", fileext = ".csv")
  }
  if (reset_states) {
    path_file_logVal <- tempfile(pattern = "", fileext = ".csv")
  } else {
    path_file_logVal <- NULL
  }
  
  # if no dataset is supplied, external fasta generator will generate batches
  gen <- NULL
  gen.val <- NULL
  if (train_with_gen) {
    
    gen <- get_generator(path = path, batch_size = batch_size, model = model,
                         maxlen = maxlen, step = step, shuffle_file_order = shuffle_file_order,
                         vocabulary = vocabulary, seed = seed[1], proportion_entries = proportion_entries,
                         shuffle_input = shuffle_input, format = format, reshape_xy = reshape_xy,
                         path_file_log = path_file_log, reverse_complement = reverse_complement, n_gram_stride = n_gram_stride,
                         output_format = output_format, ambiguous_nuc = ambiguous_nuc,
                         proportion_per_seq = proportion_per_seq, skip_amb_nuc = skip_amb_nuc,
                         use_quality_score = use_quality_score, padding = padding, n_gram = n_gram,
                         added_label_path = added_label_path, add_input_as_seq = add_input_as_seq,
                         max_samples = max_samples, concat_seq = concat_seq, target_len = target_len,
                         file_filter = train_files, use_coverage = use_coverage, random_sampling = random_sampling,
                         train_type = train_type, set_learning = set_learning, file_limit = file_limit,
                         reverse_complement_encoding = reverse_complement_encoding, read_data = read_data,
                         sample_by_file_size = sample_by_file_size, add_noise = add_noise, target_split = target_split,
                         target_from_csv = target_from_csv, masked_lm = masked_lm, return_int = return_int,
                         path_file_logVal = path_file_logVal, delete_used_files = delete_used_files,
                         vocabulary_label = vocabulary_label, val = FALSE)
    
    if (!is.null(path_val)) {
      
      gen.val <- get_generator(path = path_val, batch_size = batch_size, model = model,
                               maxlen = maxlen, step = step, shuffle_file_order = shuffle_file_order,
                               vocabulary = vocabulary, seed = seed[2], proportion_entries = proportion_entries,
                               shuffle_input = shuffle_input, format = format, delete_used_files = FALSE,
                               path_file_log = path_file_logVal, reverse_complement = reverse_complement, n_gram_stride = n_gram_stride,
                               output_format = output_format, ambiguous_nuc = ambiguous_nuc, reshape_xy = reshape_xy,
                               proportion_per_seq = proportion_per_seq, skip_amb_nuc = skip_amb_nuc,
                               use_quality_score = use_quality_score, padding = padding, n_gram = n_gram,
                               added_label_path = added_label_path, add_input_as_seq = add_input_as_seq,
                               max_samples = max_samples, concat_seq = concat_seq, target_len = target_len,
                               file_filter = val_files, use_coverage = use_coverage, random_sampling = random_sampling,
                               train_type = train_type, set_learning = set_learning, file_limit = file_limit,
                               reverse_complement_encoding = reverse_complement_encoding, read_data = read_data,
                               sample_by_file_size = sample_by_file_size, add_noise = add_noise, target_split = target_split,
                               target_from_csv = target_from_csv, masked_lm = masked_lm, return_int = return_int,
                               path_file_logVal = path_file_logVal, vocabulary_label = vocabulary_label,
                               val = TRUE)
    }
    
  }
  
  # skip validation callback
  if (validation_only_after_training || is.null(train_val_ratio) || train_val_ratio == 0) {
    validation_data <- NULL
  } else {
    if (train_with_gen) {
      if (is.null(path_val)) {
        validation_data <- NULL
      } else {
        validation_data <- gen.val
      }
    } else {
      validation_data <- dataset_val
    }
  }
  
  if (is.null(validation_data)) {
    validation_steps <- NULL
  } else {
    validation_steps <- ceiling(steps_per_epoch * train_val_ratio)
  }
  
  callbacks <- get_callbacks(default_arguments = NULL, model = model, path_tensorboard = path_tensorboard, run_name = run_name, train_type = train_type,
                             path = path, train_val_ratio = train_val_ratio, batch_size = batch_size, epochs = epochs,
                             max_queue_size = max_queue_size, lr_plateau_factor = lr_plateau_factor, patience = patience, cooldown = cooldown, format = format,
                             steps_per_epoch = steps_per_epoch, step = step, shuffle_file_order = shuffle_file_order, initial_epoch = initial_epoch, vocabulary = vocabulary,
                             learning_rate = model$optimizer$learning_rate$numpy(), solver = stringr::str_to_lower(model$optimizer$get_config()["name"]),
                             shuffle_input = shuffle_input, vocabulary_label = vocabulary_label,
                             file_limit = file_limit, reverse_complement = reverse_complement, wavenet_format = wavenet_format, cnn_format = cnn_format,
                             train_val_split_csv = train_val_split_csv, n_gram = n_gram, path_file_logVal = path_file_logVal, validation_steps = validation_steps,
                             create_model_function = NULL, vocabulary_size = vocabulary_size, gen_cb = NULL, argumentList = argumentList, output = output,
                             maxlen = maxlen, labelGen = labelGen, labelByFolder = labelByFolder, vocabulary_label_size = vocabulary_label_size, tb_images = tb_images,
                             target_middle = target_middle, path_file_log = path_file_log, proportion_per_seq = proportion_per_seq,
                             skip_amb_nuc = skip_amb_nuc, max_samples = max_samples, proportion_entries = proportion_entries, path_log = path_log,
                             train_with_gen = train_with_gen, random_sampling = random_sampling, reduce_lr_on_plateau = reduce_lr_on_plateau,
                             save_weights_only = save_weights_only, path_checkpoint = path_checkpoint, save_best_only = save_best_only, gen.val = gen.val,
                             target_from_csv = target_from_csv, reset_states = reset_states, early_stopping_time = early_stopping_time,
                             validation_only_after_training = validation_only_after_training, model_card = model_card, dataset_val = dataset_val)
  
  # training
  if (train_with_gen) {
    
    if (!is.null(dataset_val)) {
      validation_data <- dataset_val
      validation_steps <- NULL
    }
    
    if (return_gen) {
      return(list(gen = gen, gen.val = gen.val))
    }
    
    model <- keras::set_weights(model, model_weights)
    history <-
      model %>% keras::fit(
        x = gen,
        validation_data = validation_data,
        validation_steps = validation_steps,
        steps_per_epoch = steps_per_epoch,
        max_queue_size = max_queue_size,
        epochs = epochs,
        initial_epoch = initial_epoch,
        callbacks = c(callbacks, callback_list),
        class_weight = class_weight,
        batch_size = batch_size,
        verbose = print_scores)
    
    if (validation_only_after_training) {
      history$val_loss <- model$val_loss
      history$val_acc <- model$val_acc
      model$val_loss <- NULL
      model$val_acc <- NULL
    }
    
  } else {
    
    model <- keras::set_weights(model, model_weights)
    if (!is.null(dataset_val)) {
      validation_data <- list(dataset_val[[1]], dataset_val[[2]])
    } else {
      validation_data <- NULL
    }
    
    history <- keras::fit(
      object = model,
      x = dataset[[1]],
      y = dataset[[2]],
      batch_size = batch_size,
      validation_split = train_val_ratio,
      validation_data = validation_data,
      callbacks = c(callbacks, callback_list),
      epochs = epochs,
      class_weight = class_weight,
      verbose = print_scores)
  }
  
  if (removeLog && file.exists(path_file_log)) {
    file.remove(path_file_log)
  }
  
  message("Training done.")
  
  return(history)
}

#' Generate run_name if none is given or name is already present.
#'
#' If no run name is given, will use the date as run name. If the run name is already present, will add "_2" to the name,
#' or "_{x+1}" if the name ends with "_x" and x is an integer.
#'
#' @param auto_extend If run_name is already present, add "_2" to name. If name already ends with "_x", replace x with x+1.
#' @noRd
get_run_name <- function(run_name = NULL, path_tensorboard, path_checkpoint, path_log, path_model_card, auto_extend = FALSE) {
  
  if (is.null(run_name)) {
    run_name <- as.character(Sys.time()) %>% stringr::str_replace_all(" ", "_")
  }
  
  tb_names <- ""
  cp_names <- ""
  log_names <- ""
  mc_names <- ""
  name_present_tb <- FALSE
  name_present_cp <- FALSE
  name_present_log <- FALSE
  name_present_mc <- FALSE
  
  if (!is.null(path_tensorboard)) {
    tb_names <- list.files(path_tensorboard)
    name_present_tb <- (run_name %in% tb_names)
  }
  if (!is.null(path_checkpoint)) {
    cp_names <- list.files(path_checkpoint)
    name_present_cp <- (run_name %in% cp_names)
  }
  if (!is.null(path_log)) {
    log_names <- list.files(path_log)
    name_present_log <- (run_name %in% log_names)
  }
  if (!is.null(path_model_card)) {
    mc_names <- list.files(path_model_card)
    name_present_mc <- (run_name %in% mc_names)
  }
  
  name_present <- name_present_tb | name_present_cp | name_present_log | name_present_mc
  
  if (name_present && auto_extend) {
    
    ends_with_int <- stringr::str_detect(run_name, "_\\d+$")
    if (ends_with_int) {
      int_ending <- stringr::str_extract(run_name, "\\d+$") %>% as.integer()
      run_name_new <- paste0(stringr::str_remove(run_name, "\\d+$"), int_ending + 1)
    } else {
      run_name_new <- paste0(run_name, "_2")
    }
    
    int_ending <- stringr::str_subset(c(tb_names, cp_names, log_names, mc_names),
                                      paste0("^", stringr::str_remove(run_name, "_\\d+$"))) %>% unique()
    int_ending <- stringr::str_subset(int_ending, "_\\d+$")
    if (length(int_ending) > 0) {
      max_int_ending <- stringr::str_extract(int_ending, "_\\d+$") %>% stringr::str_remove("_") %>% as.integer() %>% max()
      if (!ends_with_int) {
        run_name_new <- paste0(run_name, "_", max_int_ending + 1)
      } else {
        run_name_new <- paste0(stringr::str_remove(run_name, "\\d+$"), max_int_ending + 1)
      }
    }
    
    if (length(int_ending) > 0) {
      name_order <- stringr::str_extract(int_ending, "\\d+$") %>% as.integer() %>% order()
      prev_names <- unique(c(run_name, int_ending[name_order]))
      if (ends_with_int) {
        name_order <- stringr::str_extract(prev_names, "\\d+$") %>% as.integer() %>% order()
        prev_names <- prev_names[name_order]
      }
      
      if (length(prev_names) > 8) {
        old_names <- run_name
      } else {
        old_names <- paste(prev_names, collapse = ", ")
      }
      message(paste("run_name", old_names, "already present, setting run_name to", run_name_new))
    } else {
      message(paste("run_name", run_name, "already present, setting run_name to", run_name_new))
    }
  }
  
  if (name_present && !auto_extend) {
    stop("run_name already present, please give your run a unique name")
  }
  
  if (!name_present) {
    return(run_name)
  }
  
  return(run_name_new)
}
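
# Example (illustrative): if "myrun" already exists in one of the output folders,
# get_run_name("myrun", path_tensorboard, path_checkpoint, path_log, NULL,
#              auto_extend = TRUE) returns "myrun_2"; if "myrun_2" exists as
# well, it returns "myrun_3".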
#' Continue training from model card
#'
#' Use information from a model card to resume training from the corresponding checkpoint with the same training arguments.
#'
#' @param path_model_card Path to model card to resume training from.
#' @param seed Seed for reproducible results. If `NULL`, set random seed.
#' @param epoch Epoch to resume from. If `NULL`, use last epoch.
#' @param new_run_name New run name. If `NULL`, new run name is old run name + '_cont'.
#' @param new_args Named list of arguments to overwrite. Will use previous arguments from model card otherwise.
#' For example, if you want to change the batch size and padding option:
#' `new_args = list(batch_size = 6, padding = TRUE)`.
#' @param new_compile List of arguments to compile the model again. If `NULL`, use compiled model from checkpoint.
#' Example: `new_compile = list(loss = 'binary_crossentropy', metrics = 'acc', optimizer = keras::optimizer_adam())`
#' @param use_mirrored_strategy Whether to use distributed mirrored strategy.
#' If `NULL`, will use distributed mirrored strategy only if >1 GPU is available.
#' @param unfreeze If `TRUE`, set trainable attribute of model to `TRUE` (unfreeze weights).
#' @param verbose Whether to print all training arguments.
#' @examples
#' \donttest{
#' library(keras)
#' # create dummy data and temp directories
#' path_train_1 <- tempfile()
#' path_train_2 <- tempfile()
#' path_val_1 <- tempfile()
#' path_val_2 <- tempfile()
#' path_checkpoint <- tempfile()
#' dir.create(path_checkpoint)
#' path_model_card <- tempfile()
#' dir.create(path_model_card)
#' 
#' for (current_path in c(path_train_1, path_train_2,
#'                        path_val_1, path_val_2)) {
#'   dir.create(current_path)
#'   create_dummy_data(file_path = current_path,
#'                     num_files = 3,
#'                     seq_length = 10,
#'                     num_seq = 5,
#'                     vocabulary = c("a", "c", "g", "t"))
#' }
#' 
#' # create model
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 2, maxlen = 5)
#' 
#' # train model
#' run_name <- 'test_run_1'
#' hist <- train_model(train_type = "label_folder",
#'                     run_name = run_name,
#'                     path_checkpoint = path_checkpoint,
#'                     model_card = list(path_model_card = path_model_card, description = 'test run'),
#'                     model = model,
#'                     path = c(path_train_1, path_train_2),
#'                     path_val = c(path_val_1, path_val_2),
#'                     batch_size = 8,
#'                     epochs = 3,
#'                     steps_per_epoch = 6,
#'                     vocabulary_label = c("label_1", "label_2"))
#' 
#' # resume training
#' resume_training_from_model_card(path_model_card = file.path(path_model_card, run_name))
#' }
#' @returns A list of training metrics.
#' @export
resume_training_from_model_card <- function(path_model_card,
                                            seed = NULL,
                                            epoch = NULL,
                                            new_run_name = NULL,
                                            new_args = NULL,
                                            new_compile = NULL,
                                            use_mirrored_strategy = NULL,
                                            unfreeze = FALSE,
                                            verbose = FALSE) {
  
  if (is.null(use_mirrored_strategy)) use_mirrored_strategy <- ifelse(count_gpu() > 1, TRUE, FALSE)
  
  info <- file.info(path_model_card)
  is_directory <- info$isdir
  
  if (is.na(is_directory)) {
    stop("model_card path does not exist.\n")
  } else if (is_directory) {
    mc <- get_mc(path_model_card = path_model_card, epoch = epoch)
  } else {
    mc <- path_model_card
  }
  
  mc_args <- readRDS(mc)
  train_args_mc <- mc_args$train_model_args
  new_train_args <- train_args_mc
  
  if (is.null(new_run_name)) {
    new_train_args$run_name <- set_new_run_name(train_args_mc$run_name)
  } else {
    new_train_args$run_name <- new_run_name
  }
  
  # overwrite args
  if (is.null(seed)) seed <- get_seed()
  new_train_args$seed <- seed
  
  # load checkpoint to resume from
  if (is.null(train_args_mc$path_checkpoint)) {
    stop('Did not save checkpoints in the run from model card')
  }
  
  if (use_mirrored_strategy) {
    mirrored_strategy <- tensorflow::tf$distribute$MirroredStrategy()
    with(mirrored_strategy$scope(), {
      model <- load_model(cp_path = file.path(train_args_mc$path_checkpoint, train_args_mc$run_name),
                          ep_index = epoch,
                          new_compile = new_compile)
    })
  } else {
    model <- load_model(cp_path = file.path(train_args_mc$path_checkpoint, train_args_mc$run_name),
                        ep_index = epoch,
                        new_compile = new_compile)
  }
778
  
779
  new_train_args$model <- model
780
  
781
  if (!is.null(new_args)) {
782
    stopifnot(is.list(new_args))
783
    for (n in names(new_args))
784
      new_train_args[[n]] <- new_args[[n]]
785
  }
786
  
787
  new_train_args$model_card[['cont_train_info']] <- paste0('run continues training from run ',
788
                                                           train_args_mc, ' and epoch ',
789
                                                           max(mc_args$logs$processing_step))
790
  
791
  if (verbose) {
792
    print(new_train_args)
793
  }
794
  
795
  do.call(train_model, new_train_args)
796
  
797
}
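# Pick the model card file for a given epoch from a model card directory.
# Assumes file names follow a pattern like '<prefix>_<epoch>_...', since the
# epoch is read from the second underscore-separated field of the basename.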
get_mc <- function(path_model_card, epoch = NULL) {
  
  all_cards <- list.files(path_model_card, full.names = TRUE)
  all_epochs <- vector("integer", length(all_cards))
  for (i in seq_along(all_cards)) {
    split_string <- all_cards[i] %>% basename() %>% stringr::str_split("_")
    all_epochs[i] <- split_string[[1]][2] %>% as.integer()
  }
  
  if (is.null(epoch)) epoch <- max(all_epochs)
  
  index <- all_epochs == epoch
  if (sum(index) == 0) {
    error_message <- paste('epoch not found in model card directory, possible values:',
                           paste(all_epochs, collapse = ", "))
    stop(error_message)
  }
  
  mc <- all_cards[index]
  return(mc)
  
}
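# Load a model from a checkpoint. If new_compile is supplied, the checkpoint is
# loaded without compiling and the model is compiled again with the given
# optimizer, loss and metrics.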
load_model <- function(cp_path,
                       ep_index,
                       new_compile) {
  
  model <- load_cp(cp_path,
                   ep_index = ep_index,
                   mirrored_strategy = FALSE,
                   compile = ifelse(is.null(new_compile), TRUE, FALSE))
  
  if (!is.null(new_compile)) {
    model <- keras::compile(model,
                            optimizer = new_compile$optimizer,
                            loss = new_compile$loss,
                            metrics = new_compile$metrics)
  }
  
  return(model)
  
}
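# Derive a seed pair from the current time: the last five digits of the
# timestamp seed the RNG, which then draws two integers (e.g. usable as
# train/validation seeds).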
get_seed <- function() {
  
  current_time <- Sys.time()
  current_time <- as.numeric(current_time) * 1e2
  seed_value <- (current_time %% 10^5) %>% as.integer()
  set.seed(seed_value)
  return(sample(1:10^6, 2))
  
}
set_new_run_name <- function(run_name_old) {
  
  run_name_new <- paste0(run_name_old, '_cont')
  return(run_name_new)
  
}