#' @title Train CPC-inspired model
#'
#' @description
#' Train a CPC-inspired (Oord et al.) neural network on genomic data.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_folder
#' @inheritParams generator_fasta_label_header_csv
#' @inheritParams train_model
#' @param train_type Either `"cpc"` or `"Self-GenomeNet"`.
#' @param encoder A keras encoder for the cpc function.
#' @param context A keras context model for the cpc function.
#' @param path Path to training data. If \code{train_type} is \code{label_folder}, should be a vector or list
#' where each entry corresponds to a class (list elements can be directories and/or individual files). If \code{train_type} is not \code{label_folder},
#' can be a single directory or file or a list of directories and/or files.
#' @param path_val Path to validation data. See `path` argument for details.
#' @param path_checkpoint Path to checkpoints folder or `NULL`. If `NULL`, checkpoints are not stored.
#' @param path_tensorboard Path to tensorboard directory or `NULL`. If `NULL`, training is not tracked on tensorboard.
#' @param train_val_ratio For the generator, defines the fraction of batches that will be used for validation (compared to the size of the training data), i.e. one validation iteration
#' processes \code{batch_size} \eqn{*} \code{steps_per_epoch} \eqn{*} \code{train_val_ratio} samples. If you use a dataset instead of a generator and \code{dataset_val} is `NULL`, splits \code{dataset}
#' into train/validation data.
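#' For example, with the defaults \code{batch_size = 32}, \code{steps_per_epoch = 2000} and
#' \code{train_val_ratio = 0.2}, one validation iteration processes 32 * 2000 * 0.2 = 12800 samples.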
#' @param run_name Name of the run. Name will be used to identify output from callbacks.
#' @param batch_size Number of samples used for one network update.
#' @param epochs Number of training epochs.
#' @param steps_per_epoch Number of training batches per epoch.
#' @param shuffle_file_order Boolean, whether to go through files sequentially or shuffle beforehand.
#' @param initial_epoch Epoch at which to start training. Note that network
#' will run for (\code{epochs} - \code{initial_epoch}) rounds and not \code{epochs} rounds.
#' @param seed Sets seed for reproducible results.
#' @param file_limit Integer or `NULL`. If integer, use only specified number of randomly sampled files for training. Ignored if greater than number of files in \code{path}.
#' @param patchlen The length of a patch when splitting the input sequence.
#' @param nopatches The number of patches when splitting the input sequence.
#' @param step Frequency of sampling steps.
#' @param stride The overlap between two patches when splitting the input sequence, given as a fraction of \code{patchlen}.
#' @param pretrained_model A pretrained keras model for which training will be continued.
#' @param learningrate A Tensor, floating point value. If a schedule is defined, this value gives the initial learning rate. Defaults to 0.001.
#' @param learningrate_schedule A schedule for a non-constant learning rate over the training. Either "cosine_annealing", "step_decay", or "exp_decay".
#' @param k Value of k for sparse top k categorical accuracy. Defaults to 5.
#' @param stepsmin In CPC, a patch is predicted given another patch. \code{stepsmin} defines how many patches between these two should be ignored during prediction.
#' @param stepsmax The maximum distance between the predicted patch and the given patch.
#' @param emb_scale Scales the impact of a patch's context.
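#' @details
#' The parameters \code{maxlen}, \code{patchlen}, \code{nopatches} and \code{stride} are linked by
#' \code{maxlen = (nopatches - 1) * (stride * patchlen) + patchlen}. When no pretrained model is
#' given, it suffices to supply \code{patchlen} together with either \code{maxlen} or
#' \code{nopatches}; the missing value is derived from this relation.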
#' @examplesIf reticulate::py_module_available("tensorflow")
#'
#' # create dummy data
#' path_train_1 <- tempfile()
#' path_train_2 <- tempfile()
#' path_val_1 <- tempfile()
#' path_val_2 <- tempfile()
#'
#' for (current_path in c(path_train_1, path_train_2,
#'                        path_val_1, path_val_2)) {
#'   dir.create(current_path)
#'   deepG::create_dummy_data(file_path = current_path,
#'                            num_files = 3,
#'                            seq_length = 10,
#'                            num_seq = 5,
#'                            vocabulary = c("a", "c", "g", "t"))
#' }
#'
#' # create model
#' encoder <- function(maxlen = NULL,
#'                     patchlen = NULL,
#'                     nopatches = NULL,
#'                     eval = FALSE) {
#'   if (is.null(nopatches)) {
#'     nopatches <- nopatchescalc(patchlen, maxlen, patchlen * 0.4)
#'   }
#'   inp <- keras::layer_input(shape = c(maxlen, 4))
#'   stridelen <- as.integer(0.4 * patchlen)
#'   createpatches <- inp %>%
#'     keras::layer_reshape(list(maxlen, 4L, 1L), name = "prep_reshape1", dtype = "float32") %>%
#'     tensorflow::tf$image$extract_patches(
#'       sizes = list(1L, patchlen, 4L, 1L),
#'       strides = list(1L, stridelen, 4L, 1L),
#'       rates = list(1L, 1L, 1L, 1L),
#'       padding = "VALID",
#'       name = "prep_patches"
#'     ) %>%
#'     keras::layer_reshape(list(nopatches, patchlen, 4L),
#'                          name = "prep_reshape2") %>%
#'     tensorflow::tf$reshape(list(-1L, patchlen, 4L),
#'                            name = "prep_reshape3")
#'
#'   danQ <- createpatches %>%
#'     keras::layer_conv_1d(
#'       input_shape = c(maxlen, 4L),
#'       filters = 320L,
#'       kernel_size = 26L,
#'       activation = "relu"
#'     ) %>%
#'     keras::layer_max_pooling_1d(pool_size = 13L, strides = 13L) %>%
#'     keras::layer_dropout(0.2) %>%
#'     keras::layer_lstm(units = 320, return_sequences = TRUE) %>%
#'     keras::layer_dropout(0.5) %>%
#'     keras::layer_flatten() %>%
#'     keras::layer_dense(925, activation = "relu")
#'   patchesback <- danQ %>%
#'     tensorflow::tf$reshape(list(-1L, tensorflow::tf$cast(nopatches, tensorflow::tf$int16), 925L))
#'   keras::keras_model(inp, patchesback)
#' }
#'
#' context <- function(latents) {
#'   cres <- latents
#'   cres_dim <- cres$shape
#'   predictions <-
#'     cres %>%
#'     keras::layer_lstm(
#'       return_sequences = TRUE,
#'       units = 256,
#'       name = "context_LSTM_1",
#'       activation = "relu"
#'     )
#'   return(predictions)
#' }
#'
#' # train model
#' temp_dir <- tempdir()
#' hist <- train_model_cpc(train_type = "CPC",
#'                         ### cpc functions ###
#'                         encoder = encoder,
#'                         context = context,
#'                         #### Generator settings ####
#'                         path_checkpoint = temp_dir,
#'                         path = c(path_train_1, path_train_2),
#'                         path_val = c(path_val_1, path_val_2),
#'                         run_name = "TEST",
#'                         batch_size = 8,
#'                         epochs = 3,
#'                         steps_per_epoch = 6,
#'                         patchlen = 100,
#'                         nopatches = 8)
#'
#' @returns A list of training metrics.
#' @export
train_model_cpc <-
  function(train_type = "CPC",
           ### cpc functions ###
           encoder = NULL,
           context = NULL,
           #### Generator settings ####
           path,
           path_val = NULL,
           path_checkpoint = NULL,
           path_tensorboard = NULL,
           train_val_ratio = 0.2,
           run_name,

           batch_size = 32,
           epochs = 100,
           steps_per_epoch = 2000,
           shuffle_file_order = FALSE,
           initial_epoch = 1,
           seed = 1234,

           path_file_log = TRUE,
           train_val_split_csv = NULL,
           file_limit = NULL,
           proportion_per_seq = NULL,
           max_samples = NULL,
           maxlen = NULL,

           patchlen = NULL,
           nopatches = NULL,
           step = NULL,
           file_filter = NULL,
           stride = 0.4,
           pretrained_model = NULL,
           learningrate = 0.001,
           learningrate_schedule = NULL,
           k = 5,
           stepsmin = 2,
           stepsmax = 3,
           emb_scale = 0.1) {

    # Stride is currently fixed at 0.4 * patchlen; the stride argument is ignored for now
    stride <- 0.4

    patchlen <- as.integer(patchlen)

    ########################################################################################################
    ############################### Warning messages if wrong initialization ###############################
    ########################################################################################################

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Model specification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## Three options:
    ## 1. Define maxlen and patchlen
    ## 2. Define number of patches and patchlen
    ## ---> in both cases the respectively missing value will be calculated
    ## 3. Pretrained model provides the specs
    ## Error if none of these is fulfilled
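    ## Worked example (with stride = 0.4): patchlen = 100 and nopatches = 8 give
    ## maxlen = (8 - 1) * (0.4 * 100) + 100 = 380; conversely, maxlen = 380 gives
    ## nopatches = (380 - 100) / (0.4 * 100) + 1 = 8.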

    if (is.null(pretrained_model)) {
      ## If no pretrained model, patchlen has to be defined
      if (is.null(patchlen)) {
        stop("Please define patchlen")
      }
      ## Either maxlen or number of patches is needed
      if (is.null(maxlen) & is.null(nopatches)) {
        stop("Please define either maxlen or nopatches")
        ## the respectively missing value will be calculated
      } else if (is.null(maxlen) & !is.null(nopatches)) {
        maxlen <- (nopatches - 1) * (stride * patchlen) + patchlen
      } else if (!is.null(maxlen) & is.null(nopatches)) {
        nopatches <-
          as.integer((maxlen - patchlen) / (stride * patchlen) + 1)
      }
      ## if step is not defined, we do not use overlapping sequences
      if (is.null(step)) {
        step <- maxlen
      }
    } else if (!is.null(pretrained_model)) {
      specs <-
        readRDS(paste(
          sub("/[^/]+$", "", pretrained_model),
          "modelspecs.rds",
          sep = "/"
        ))
      patchlen          <- specs$patchlen
      maxlen            <- specs$maxlen
      nopatches         <- specs$nopatches
      stride            <- specs$stride
      step              <- specs$step
      k                 <- specs$k
      emb_scale         <- specs$emb_scale
    }
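    ## Note: modelspecs.rds is written next to the checkpoints by this function
    ## (see "Saving necessary model objects" below), so a run started with
    ## train_model_cpc can later be resumed without re-specifying these values.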


    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Learning rate schedule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## If a learning rate schedule is wanted, all necessary parameters must be given
    LRstop(learningrate_schedule)
    ########################################################################################################
    ################################### Preparation: Data, paths, metrics ##################################
    ########################################################################################################

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Path definition ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    runname <-
      paste0(run_name, format(Sys.time(), "_%y%m%d_%H%M%S"))

    ## Create folder for model
    if (!is.null(path_checkpoint)) {
      dir.create(paste(path_checkpoint, runname, sep = "/"))
      dir <- paste(path_checkpoint, runname, sep = "/")
      ## Create folder for filelog
      path_file_log <-
        paste(path_checkpoint, runname, "filelog.csv", sep = "/")
    } else {
      path_file_log <- NULL
    }

    GenConfig <-
      GenParams(maxlen, batch_size, step, proportion_per_seq, max_samples)
    GenTConfig <-
      GenTParams(path, shuffle_file_order, path_file_log, seed)
    GenVConfig <- GenVParams(path_val, shuffle_file_order)

    # train/validation split via csv file
    if (!is.null(train_val_split_csv)) {
      if (is.null(path_val)) {
        path_val <- path
      } else {
        if (!all(unlist(path_val) %in% unlist(path))) {
          warning("Train/validation split done via file in train_val_split_csv. Only using files from path argument.")
        }
        path_val <- path
      }

      train_val_file <- utils::read.csv2(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      if (dim(train_val_file)[2] == 1) {
        train_val_file <- utils::read.csv(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      }
      train_val_file <- dplyr::distinct(train_val_file)

      if (!all(c("file", "type") %in% names(train_val_file))) {
        stop("Column names of train_val_split_csv file must be 'file' and 'type'")
      }

      if (length(train_val_file$file) != length(unique(train_val_file$file))) {
        stop("In train_val_split_csv all entries in 'file' column must be unique")
      }

      file_filter <- list()
      file_filter[[1]] <- train_val_file %>% dplyr::filter(type == "train")
      file_filter[[1]] <- as.character(file_filter[[1]]$file)
      file_filter[[2]] <- train_val_file %>% dplyr::filter(type == "val" | type == "validation")
      file_filter[[2]] <- as.character(file_filter[[2]]$file)
    }
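    ## Illustrative train_val_split_csv layout (file names are hypothetical):
    ##   file,type
    ##   seq_a.fasta,train
    ##   seq_b.fasta,validation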


    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ File count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    if (is.null(file_filter) && is.null(train_val_split_csv)) {
      if (is.null(file_limit)) {
        if (is.list(path)) {
          num_files <- 0
          for (i in seq_along(path)) {
            num_files <- num_files + length(list.files(path[[i]]))
          }
        } else {
          num_files <- length(list.files(path))
        }
      } else {
        num_files <- file_limit * length(path)
      }
    } else {
      ## number of training files in the filter (file_filter[[1]] holds the train set)
      num_files <- length(file_filter[[1]])
    }

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Creation of generators ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    message(format(Sys.time(), "%F %R"), ": Preparing the data\n")
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Training Generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    fastrain <-
      do.call(generator_fasta_lm,
              c(GenConfig, GenTConfig, file_filter = file_filter[1]))


    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Validation Generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    fasval <-
      do.call(
        generator_fasta_lm,
        c(
          GenConfig,
          GenVConfig,
          seed = seed,
          file_filter = file_filter[2]
        )
      )
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Creation of metrics ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    message(format(Sys.time(), "%F %R"), ": Preparing the metrics\n")
    train_loss <- tensorflow::tf$keras$metrics$Mean(name = 'train_loss')
    val_loss <- tensorflow::tf$keras$metrics$Mean(name = 'val_loss')
    train_acc <- tensorflow::tf$keras$metrics$Mean(name = 'train_acc')
    val_acc <- tensorflow::tf$keras$metrics$Mean(name = 'val_acc')

    ########################################################################################################
    ###################################### History object preparation ######################################
    ########################################################################################################

    history <- list(
      params = list(
        batch_size = batch_size,
        epochs = 0,
        steps = steps_per_epoch,
        samples = steps_per_epoch * batch_size,
        verbose = 1,
        do_validation = TRUE,
        metrics = c("loss", "accuracy", "val_loss", "val_accuracy")
      ),
      metrics = list(
        loss = c(),
        accuracy = c(),
        val_loss = c(),
        val_accuracy = c()
      )
    )

    eploss <- list()
    epacc <- list()

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Reformat to S3 object ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    class(history) <- "keras_training_history"

    ########################################################################################################
    ############################################ Model creation ############################################
    ########################################################################################################
    if (is.null(pretrained_model)) {
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unsupervised Build from scratch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      message(format(Sys.time(), "%F %R"), ": Creating the model\n")
      ## Build encoder
      enc <-
        encoder(maxlen = maxlen,
                patchlen = patchlen,
                nopatches = nopatches)

      ## Build model
      model <-
        keras::keras_model(
          enc$input,
          cpcloss(
            enc$output,
            context,
            batch_size = batch_size,
            steps_to_ignore = stepsmin,
            steps_to_predict = stepsmax,
            train_type = train_type,
            k = k,
            emb_scale = emb_scale
          )
        )

      ## Build optimizer
      optimizer <- # keras::optimizer_adam(
        tensorflow::tf$keras$optimizers$legacy$Adam(
          learning_rate = learningrate,
          beta_1 = 0.8,
          epsilon = 10 ^ -8,
          decay = 0.999,
          clipnorm = 0.01
        )
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~ Unsupervised Read if pretrained model given ~~~~~~~~~~~~~~~~~~~~~~~~~####

    } else {
      message(format(Sys.time(), "%F %R"), ": Loading the trained model.\n")
      ## Read model
      model <- keras::load_model_hdf5(pretrained_model, compile = FALSE)
      optimizer <- ReadOpt(pretrained_model)
      optimizer$learning_rate$assign(learningrate)
    }

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving necessary model objects ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## optimizer configuration

    if (!is.null(path_checkpoint)) {
      saveRDS(optimizer$get_config(),
              paste(dir, "optconfig.rds", sep = "/"))
      ## model parameters
      saveRDS(
        list(
          maxlen = maxlen,
          patchlen = patchlen,
          stride = stride,
          nopatches = nopatches,
          step = step,
          batch_size = batch_size,
          epochs = epochs,
          steps_per_epoch = steps_per_epoch,
          train_val_ratio = train_val_ratio,
          max_samples = max_samples,
          k = k,
          emb_scale = emb_scale,
          learningrate = learningrate
        ),
        paste(dir, "modelspecs.rds", sep = "/")
      )
    }
    ########################################################################################################
    ######################################## Tensorboard connection ########################################
    ########################################################################################################

    if (!is.null(path_tensorboard)) {
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Initialize Tensorboard writers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      logdir <- path_tensorboard
      writertrain <-
        tensorflow::tf$summary$create_file_writer(file.path(logdir, runname, "train"))
      writerval <-
        tensorflow::tf$summary$create_file_writer(file.path(logdir, runname, "validation"))

      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Write parameters to Tensorboard ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      tftext <-
        lapply(as.list(match.call())[-1][-c(1, 2)], function(x)
          ifelse(all(nchar(deparse(
            eval(x)
          )) < 20) && !is.null(eval(x)), eval(x), deparse(x)))

      with(writertrain$as_default(), {
        tensorflow::tf$summary$text("Specification",
                                    paste(
                                      names(tftext),
                                      tftext,
                                      sep = " = ",
                                      collapse = "  \n"
                                    ),
                                    step = 0L)
      })
    }

    ########################################################################################################
    ######################################## Training loop function ########################################
    ########################################################################################################

    train_val_loop <-
      function(batches = steps_per_epoch, epoch, train_val_ratio) {
        ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Start of loop ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
        for (i in c("train", "val")) {
          if (i == "val") {
            ## Calculate steps for validation
            batches <- ceiling(batches * train_val_ratio)
          }

          for (b in seq(batches)) {
            if (i == "train") {
              ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Training step ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
              ## If a learning rate schedule is specified, calculate the learning rate for the current epoch
              if (!is.null(learningrate_schedule)) {
                optimizer$learning_rate$assign(getEpochLR(learningrate_schedule, epoch))
              }
              ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Optimization step ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####

              # with(tensorflow::tf$GradientTape() %as% tape, {
              with(reticulate::`%as%`(tensorflow::tf$GradientTape(), tape), {

                out <-
                  modelstep(fastrain(),
                            model,
                            train_type,
                            TRUE)
                l <- out[1]
                acc <- out[2]
              })

              gradients <-
                tape$gradient(l, model$trainable_variables)
              optimizer$apply_gradients(purrr::transpose(list(
                gradients, model$trainable_variables
              )))
              train_loss(l)
              train_acc(acc)

            } else {
              ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Validation step ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
              out <-
                modelstep(fasval(),
                          model,
                          train_type,
                          FALSE)

              l <- out[1]
              acc <- out[2]
              val_loss(l)
              val_acc(acc)

            }

            ## Print epoch progress (one dash per 10% of batches)
            if (b %in% seq(0, batches, by = batches / 10)) {
              message("-")
            }
          }

          ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Epoch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
          if (i == "train") {
            ## Training step
            # Write epoch result metric values to tensorboard
            if (!is.null(path_tensorboard)) {
              TB_loss_acc(writertrain, train_loss, train_acc, epoch)
              with(writertrain$as_default(), {
                tensorflow::tf$summary$scalar('epoch_lr',
                                              optimizer$learning_rate,
                                              step = tensorflow::tf$cast(epoch, "int64"))
                tensorflow::tf$summary$scalar(
                  'training files seen',
                  nrow(
                    readr::read_csv(
                      path_file_log,
                      col_names = FALSE,
                      col_types = readr::cols()
                    )
                  ) / num_files,
                  step = tensorflow::tf$cast(epoch, "int64")
                )
              })
            }
            # Print epoch result metric values to console
            tensorflow::tf$print(" Train Loss",
                                 train_loss$result(),
                                 ", Train Acc",
                                 train_acc$result())

            # Save epoch result metric values to history object
            history$params$epochs <- epoch
            history$metrics$loss[epoch] <-
              as.double(train_loss$result())
            history$metrics$accuracy[epoch] <-
              as.double(train_acc$result())

            # Reset states
            train_loss$reset_states()
            train_acc$reset_states()

          } else {
            ## Validation step
            # Write epoch result metric values to tensorboard
            if (!is.null(path_tensorboard)) {
              TB_loss_acc(writerval, val_loss, val_acc, epoch)
            }

            # Print epoch result metric values to console
            tensorflow::tf$print(" Validation Loss",
                                 val_loss$result(),
                                 ", Validation Acc",
                                 val_acc$result())

            # Save results for best-model saving condition
            if (b == max(seq(batches))) {
              eploss[[epoch]] <- as.double(val_loss$result())
              epacc[[epoch]] <- as.double(val_acc$result())
            }

            # Save epoch result metric values to history object
            history$metrics$val_loss[epoch] <-
              as.double(val_loss$result())
            history$metrics$val_accuracy[epoch] <-
              as.double(val_acc$result())

            # Reset states
            val_loss$reset_states()
            val_acc$reset_states()
          }
        }
        return(list(history, eploss, epacc))
      }

    ########################################################################################################
    ############################################# Training run #############################################
    ########################################################################################################


    message(format(Sys.time(), "%F %R"), ": Starting Training\n")

    ## Training loop
    for (i in seq(initial_epoch, (epochs + initial_epoch - 1))) {
      message(format(Sys.time(), "%F %R"), ": EPOCH ", i, " \n")

      ## Epoch loop
      out <- train_val_loop(epoch = i, train_val_ratio = train_val_ratio)
      history <- out[[1]]
      eploss <- out[[2]]
      epacc <- out[[3]]
      ## Save checkpoints
      # best model (smallest validation loss so far)
      if (eploss[[i]] == min(unlist(eploss))) {
        savechecks("best", runname, model, optimizer, history, path_checkpoint)
      }
      # backup model every 2 epochs
      if (i %% 2 == 0) {
        savechecks("backup", runname, model, optimizer, history, path_checkpoint)
      }
    }

    ########################################################################################################
    ############################################# Final saves ##############################################
    ########################################################################################################

    savechecks(cp = "FINAL", runname, model, optimizer, history, path_checkpoint)
    if (!is.null(path_tensorboard)) {
      writegraph <-
        tensorflow::tf$keras$callbacks$TensorBoard(file.path(logdir, runname))
      writegraph$set_model(model)
    }
  }