Diff of /R/help_functions.R [000000] .. [409433]

Switch to unified view

a b/R/help_functions.R
1
#' Stride length calculation
#'
#' Collect every candidate stride between `ceiling(plen / 3)` and
#' `floor(plen / 2) - 1` (both inclusive) that divides `maxlen - plen`
#' without remainder.
#'
#' @param maxlen Length of the input sequence.
#' @param plen Length of a patch.
#' @returns Vector of admissible strides in increasing order, or `NULL` if
#'   there is none.
#' @noRd
stridecalc <- function(maxlen, plen) {
  lower <- ceiling(plen / 3)
  upper <- floor(plen / 2) - 1
  # For short patches the candidate range is empty. `lower:upper` would then
  # count downwards and may reach 0, where `%% 0` yields NaN and breaks `if`.
  if (lower > upper) {
    return(NULL)
  }
  candidates <- lower:upper
  strides <- candidates[(maxlen - plen) %% candidates == 0]
  # Keep the historical "nothing found" value (`c()` is NULL).
  if (length(strides) == 0) {
    return(NULL)
  }
  strides
}
18
19
20
#' Number of patches calculation
#'
#' Number of patches of length `plen` obtained when sliding over a sequence of
#' length `maxlen` with step `stride`.
#'
#' @param plen Length of a patch.
#' @param maxlen Length of the input sequence.
#' @param stride Stride.
#' @returns Numerical value (not rounded).
#' @noRd
nopatchescalc <- function(plen, maxlen, stride) {
  remaining <- maxlen - plen
  1 + remaining / stride
}
32
33
#' Sequence length calculation
#'
#' Inverse of `nopatchescalc()`: sequence length covered by `nopatches`
#' patches of length `plen` placed `stride` positions apart.
#'
#' @param plen Length of a patch.
#' @param nopatches Number of patches.
#' @param stride Stride.
#' @returns Numerical value.
#' @noRd
maxlencalc <- function(plen, nopatches, stride) {
  steps <- nopatches - 1
  plen + stride * steps
}
36
37
38
#' Checkpoints saving function
#'
#' Writes model, optimizer weights and training history to files whose names
#' start with `file.path(path_checkpoint, runname, cp)`. Note that `cp` is a
#' file name prefix, i.e. the files are `<cp>mod.h5`, `<cp>opt.npy` and
#' `<cp>history.rds` inside `path_checkpoint/runname`.
#'
#' @param cp Type of the checkpoint.
#' @param runname Name of the run. Name will be used to identify output from callbacks.
#' @param model A keras model.
#' @param optimizer A keras optimizer.
#' @param history A keras history object.
#' @param path_checkpoint Directory that contains the run folders.
#' @returns None. Saves object to file.
#' @noRd
savechecks <- function(cp, runname, model, optimizer, history, path_checkpoint) {
  np <- reticulate::import("numpy", convert = FALSE)

  # common prefix of all files written for this checkpoint
  prefix <- file.path(path_checkpoint, runname, cp)
  target <- function(suffix) paste0(prefix, suffix)

  # model: write to a temporary name first, then move into place
  keras::save_model_hdf5(model, target("mod_temp.h5"))
  file.rename(target("mod_temp.h5"), target("mod.h5"))

  # optimizer weights, stored as a pickled numpy object array
  opt_weights <- keras::backend(FALSE)$batch_get_value(optimizer$weights)
  np$save(
    target("opt.npy"),
    np$array(opt_weights, dtype = "object"),
    allow_pickle = TRUE
  )

  # history: same write-then-rename scheme as the model
  saveRDS(history, target("history_temp.rds"))
  file.rename(target("history_temp.rds"), target("history.rds"))

  message(paste0("---------- New ", cp, " model saved\n"))
}
70
71
#' Tensorboard Writer
#'
#' Writes the loss and the accuracy for a given epoch to the tensorboard
#' under the tags `epoch_loss` and `epoch_accuracy`.
#'
#' @param writer A tensorboard summary writer (must provide `as_default()`).
#' @param loss Metric object holding the loss of the epoch (must provide `result()`).
#' @param acc Metric object holding the accuracy of the epoch (must provide `result()`).
#' @param epoch Epoch, for which the values shall be written to the tensorboard.
#' @returns None. Writes summaries via `writer`.
#' @noRd
TB_loss_acc <- function(writer, loss, acc, epoch) {
  # one scalar summary per metric, indexed by the epoch as int64 step
  write_scalar <- function(tag, metric) {
    tensorflow::tf$summary$scalar(
      tag,
      metric$result(),
      step = tensorflow::tf$cast(epoch, "int64")
    )
  }
  with(writer$as_default(), {
    write_scalar("epoch_loss", loss)
    write_scalar("epoch_accuracy", acc)
  })
}
91
92
93
#' Step function
#'
#' Converts the `x` element of a batch to a tensor and passes it through the
#' model. For `train_type = "Self-GenomeNet"` a copy of the batch with the
#' second and third axes reversed is appended along the batch axis first.
#'
#' @param trainvaldat A batch as returned by a data generator (list with element `x`).
#' @param model A keras model.
#' @param train_type Either `"cpc"`, `"Self-GenomeNet"`.
#' @param training Boolean. Whether this step is a training step.
#' @returns Output of `model` for the batch.
#' @noRd
modelstep <- function(trainvaldat,
                      model,
                      train_type = "cpc",
                      training = FALSE) {
  batch <- tensorflow::tf$convert_to_tensor(trainvaldat$x)

  if (train_type == "Self-GenomeNet") {
    # Reverse positions (axis 2) and the 4 channels (axis 3). This is the
    # reverse complement provided channels are ordered so that 4:1 swaps
    # complementary bases -- TODO confirm against the generator's vocabulary.
    reversed <- as.array(batch)[, (dim(batch)[2]):1, 4:1]
    # restore the 3D shape in case indexing dropped a dimension
    reversed <- array(reversed,
                      dim = c(dim(batch)[1], dim(batch)[2], dim(batch)[3]))
    batch <- tensorflow::tf$concat(
      list(batch, tensorflow::tf$convert_to_tensor(reversed)),
      axis = 0L
    )
  }

  model(batch, training = training)
}
116
117
118
#' Reading Pretrained Model function
#'
#' Restores an Adam optimizer that was stored next to a saved model: the
#' configuration is read from `optconfig.rds` and the weights from
#' `<prefix>opt.npy`, where `<prefix>` is the model file name without
#' `mod.h5`. Both files are expected in the directory of `pretrained_model`.
#'
#' @param pretrained_model The path to a saved keras model (using `/` as separator).
#' @returns A keras optimizer with restored weights.
#' @noRd
ReadOpt <- function(pretrained_model) {
  # numpy handle used to read the pickled weights; imported locally so the
  # function does not depend on a global `np` object
  np <- reticulate::import("numpy", convert = FALSE)

  # directory part of the model path (everything before the last "/")
  cp_dir <- sub("/[^/]+$", "", pretrained_model)

  ## Read configuration
  optconf <- readRDS(paste(cp_dir, "optconfig.rds", sep = "/"))
  ## Read optimizer
  optimizer <- tensorflow::tf$optimizers$Adam$from_config(optconf)

  # Initialize optimizer so that its weight slots exist before set_weights().
  # NOTE(review): `model` is not an argument of this function; it is looked up
  # in the enclosing/global environment -- confirm that callers define it.
  with(
    keras::backend()$name_scope(optimizer$`_name`),
    with(tensorflow::tf$python$framework$ops$init_scope(), {
      optimizer$iterations
      optimizer$`_create_hypers`()
      optimizer$`_create_slots`(model$trainable_weights)
    })
  )

  # Read optimizer weights: file name of the model with "mod.h5" removed gives
  # the checkpoint prefix
  model_file <- utils::tail(strsplit(pretrained_model, "/")[[1]], 1)
  cp_prefix <- stringr::str_remove(model_file, "mod.h5")
  wts2 <- np$load(paste0(cp_dir, "/", cp_prefix, "opt.npy"),
                  allow_pickle = TRUE)

  # Set optimizer weights
  optimizer$set_weights(wts2)
  optimizer
}
155
156
#' Learning Rate Schedule - Parameter Check
#'
#' Checks, whether the names of `lr_schedule` are exactly the parameters
#' required by the schedule it contains. Only the first schedule found (in the
#' order cosine annealing, step decay, exponential decay) is checked.
#'
#' @param lr_schedule Named list describing a learning rate schedule; one
#'   element holds the schedule name.
#' @returns `NULL`, invisibly. Stops with an error if parameters do not match.
#' @noRd
LRstop <- function(lr_schedule) {
  required <- list(
    cosine_annealing = c("schedule", "lrmin", "lrmax", "restart", "mult"),
    step_decay = c("schedule", "lrmax", "newstep", "mult"),
    exp_decay = c("schedule", "lrmax", "mult")
  )
  error_text <- c(
    cosine_annealing = "Please define lrmin, lrmax, restart, and mult within the list to use cosine annealing",
    step_decay = "Please define lrmax, newstep, and mult within the list to use step decay",
    exp_decay = "Please define lrmax, and mult within the list to use exponential decay"
  )

  for (schedule_name in names(required)) {
    if (schedule_name %in% lr_schedule) {
      names_match <- isTRUE(all.equal(
        sort(names(lr_schedule)),
        sort(required[[schedule_name]])
      ))
      if (!names_match) {
        stop(error_text[[schedule_name]])
      }
      # only the first matching schedule is validated
      break
    }
  }
  invisible(NULL)
}
188
189
#' Learning Rate Calculator
#'
#' Computes the learning rate for a given epoch by dispatching on
#' `lr_schedule$schedule` to `sgdr()`, `stepdecay()` or `exp_decay()`.
#'
#' @param lr_schedule Named list with element `schedule` and the parameters of
#'   that schedule.
#' @param epoch Epoch, for which the learning rate shall be calculated.
#' @returns Value returned by the schedule function; `NULL` (invisibly) for an
#'   unknown schedule name.
#' @noRd
getEpochLR <- function(lr_schedule, epoch) {
  # cosine annealing
  if (lr_schedule$schedule == "cosine_annealing") {
    return(sgdr(
      lrmin = lr_schedule$lrmin,
      restart = lr_schedule$restart,
      lrmax = lr_schedule$lrmax,
      mult = lr_schedule$mult,
      epoch = epoch
    ))
  }

  # step decay
  if (lr_schedule$schedule == "step_decay") {
    return(stepdecay(
      newstep = lr_schedule$newstep,
      lrmax = lr_schedule$lrmax,
      mult = lr_schedule$mult,
      epoch = epoch
    ))
  }

  # exponential decay
  if (lr_schedule$schedule == "exp_decay") {
    return(exp_decay(
      lrmax = lr_schedule$lrmax,
      mult = lr_schedule$mult,
      epoch = epoch
    ))
  }

  invisible(NULL)
}
224
225
226
########################################################################################################
227
########################################### Parameter Lists ############################################
228
########################################################################################################
229
#
230
#' Generator parameter list
#'
#' Validates the parameters shared by training and validation generators and
#' bundles them in a list of class `"Params"`.
#'
#' @param maxlen Integer >= 1.
#' @param batch_size Integer >= 1.
#' @param step Integer >= 1.
#' @param proportion_per_seq Number in \[0, 1\] or `NULL`.
#' @param max_samples Integer >= 1 or `NULL`.
#' @returns List of class `"Params"`.
#' @noRd
GenParams <- function(maxlen,
                      batch_size,
                      step,
                      proportion_per_seq,
                      max_samples) {
  checkmate::assertInt(maxlen, lower = 1)
  checkmate::assertInt(batch_size, lower = 1)
  checkmate::assertInt(step, lower = 1)
  checkmate::assertInt(max_samples, lower = 1, null.ok = TRUE)
  checkmate::assertNumber(proportion_per_seq, lower = 0, upper = 1,
                          null.ok = TRUE)

  params <- list(
    maxlen = maxlen,
    batch_size = batch_size,
    step = step,
    proportion_per_seq = proportion_per_seq,
    max_samples = max_samples
  )
  class(params) <- "Params"
  params
}
257
258
259
#' Training generator parameter list
#'
#' Validates the training specific generator parameters and bundles them in a
#' list of class `"Params"`.
#'
#' @param path Stored as `path_corpus`.
#' @param shuffle_file_orderTrain Logical, stored as `shuffle_file_order`.
#' @param path_file_log Stored unchanged.
#' @param seed Integer seed.
#' @returns List of class `"Params"`.
#' @noRd
GenTParams <- function(path,
                       shuffle_file_orderTrain,
                       path_file_log,
                       seed) {
  checkmate::assertLogical(shuffle_file_orderTrain)
  checkmate::assertInt(seed)

  params <- list(
    path_corpus = path,
    shuffle_file_order = shuffle_file_orderTrain,
    path_file_log = path_file_log,
    seed = seed
  )
  class(params) <- "Params"
  params
}
276
277
#' Validation generator parameter list
#'
#' Validates the validation specific generator parameters and bundles them in
#' a list of class `"Params"`.
#'
#' @param path_val Only its first element is used, stored as `path_corpus`.
#' @param shuffle_file_orderVal Logical, stored as `shuffle_file_order`.
#' @returns List of class `"Params"`.
#' @noRd
GenVParams <- function(path_val,
                       shuffle_file_orderVal) {
  checkmate::assertLogical(shuffle_file_orderVal)

  params <- list(
    path_corpus = path_val[[1]],
    shuffle_file_order = shuffle_file_orderVal
  )
  class(params) <- "Params"
  params
}
285
286
#' Add list of hyperparameters to model
#'
#' Cleans up a list of arguments/local variables of a model creation function
#' and attaches it to the model as `model$hparam`: vector valued
#' hyperparameters are collapsed to a single space separated string, helper
#' objects (tensors, layers, metrics, ...) are removed and the number of model
#' parameters is added as `number_model_params`.
#'
#' @param model A model providing `count_params()`.
#' @param argg Named list of arguments.
#' @returns `model` with element `hparam`.
#' @noRd
add_hparam_list <- function(model, argg) {

  # Entries stored as one space separated string. Entries missing in `argg`
  # are added as "" (as before). `[[` is used on purpose: `$` matches
  # partially, so e.g. `dropout_lstm` would leak into `dropout`.
  collapse_entry <- function(arg_list, keys) {
    for (key in keys) {
      arg_list[key] <- paste(as.character(arg_list[[key]]), collapse = " ")
    }
    arg_list
  }

  argg <- collapse_entry(argg, c(
    "layer_lstm", "filters", "kernel_size", "pool_size", "strides",
    "residual_block", "residual_block_length", "size_reduction_1Dconv",
    "layer_dense", "padding", "use_bias", "num_heads", "head_size", "dropout"
  ))
  argg[["number_model_params"]] <- model$count_params()
  argg <- collapse_entry(argg, c("residual_blocks", "number_of_cnn_layers"))

  # Numbered helper objects, one per label input. At least index 1 is always
  # removed, matching the previous `1:length(...)` loop for empty input.
  label_idx <- seq_len(max(1L, length(argg[["label_input"]])))

  # Objects that are not hyperparameters
  drop_keys <- c(
    "model_metrics", "model", "i", "optimizer",
    "input_tensor", "input_tensor_1", "input_tensor_2",
    paste0("input_tensor_", label_idx),
    paste0("label_input_layer_", label_idx),
    "input_label_list", "label_input", "label_inputs",
    "maxlen_1", "maxlen_2",
    "f1", "multi_acc", "auc", "multi_label", "macro_average_cb",
    "output_tensor", "output_tensor_1", "output_tensor_2", "output_list",
    "residual_layer", "label_noise_matrix", "smooth_loss", "noisy_loss",
    "col_sums", "verbose", "embedded_indices", "position_indices",
    "attn_block", "feature_ext_model", "ff_dim", "pos_enc_layer",
    "pe_matrix", "position_embedding_layer", "layer_add_td"
  )
  argg[intersect(names(argg), drop_keys)] <- NULL

  model$hparam <- argg
  model
}
367
368
369
#' Get sequence length from model or `set_learning` list
#'
#' If `set_learning` is given, its `maxlen` element is returned. Otherwise the
#' length is read from the second dimension of the model input shape(s); for
#' models with several inputs either the last input is used or the last two
#' inputs are added up. If `n_gram` is given, `n_gram - 1` is added.
#'
#' @param model A model with `input`/`inputs` elements.
#' @param set_learning `NULL` or list with element `maxlen`.
#' @param target_middle,read_data Booleans; if any is `TRUE`, the lengths of
#'   the last two inputs are summed for multi-input models.
#' @param return_int Currently unused.
#' @param n_gram `NULL` or n-gram size.
#' @returns Sequence length.
#' @noRd
get_maxlen <- function(model, set_learning, target_middle, read_data, return_int = FALSE,
                       n_gram = NULL) {

  # explicit setting takes precedence over the model architecture
  if (!is.null(set_learning)) {
    return(set_learning$maxlen)
  }

  n_inputs <- length(model$inputs)
  if (n_inputs == 1) {
    maxlen <- model$input$shape[[2]]
  } else if (target_middle | read_data) {
    # sequence is split over the last two inputs
    len_first <- model$inputs[[n_inputs - 1]]$shape[[2]]
    len_second <- model$inputs[[n_inputs]]$shape[[2]]
    maxlen <- len_first + len_second
  } else {
    maxlen <- model$input[[n_inputs]]$shape[[2]]
  }

  if (!is.null(n_gram)) {
    maxlen <- maxlen + n_gram - 1
  }
  return(maxlen)
}
392
393
#' Combine lists containing x, y and sample weight subsets
#'
#' Takes a list whose elements each hold `x`, `y` (and optionally
#' `sample_weight`) arrays and binds the pieces along the first axis.
#'
#' @param array_lists List of lists with elements `x`, `y` and, if
#'   `include_sw` is `TRUE`, `sample_weight`.
#' @param include_sw Whether to combine sample weights too; `NULL` is treated
#'   as `FALSE`.
#' @returns List with elements `x`, `y` and, if requested, `sw`.
#' @noRd
reorder_masked_lm_lists <- function(array_lists, include_sw = NULL) {

  if (is.null(include_sw)) include_sw <- FALSE

  x_parts <- list()
  y_parts <- list()
  sw_parts <- list()
  # seq_along() instead of 1:length(): no out-of-bounds access for empty input
  for (i in seq_along(array_lists)) {
    x_parts[[i]] <- array_lists[[i]]$x
    y_parts[[i]] <- array_lists[[i]]$y
    if (include_sw) sw_parts[[i]] <- array_lists[[i]]$sample_weight
  }

  combined <- list(
    x = abind::abind(x_parts, along = 1),
    y = abind::abind(y_parts, along = 1)
  )
  if (include_sw) {
    combined$sw <- abind::abind(sw_parts, along = 1)
  }
  return(combined)
}
415
416
#' Encode sequences and stack the resulting x and y arrays
#'
#' Calls `seq_encoding_lm()` once per element of `sequence_list` and binds the
#' returned inputs and targets along the first axis.
#'
#' Without `wavenet_format`, inputs are either a single array or, if
#' `target_middle` is `TRUE`, a pair of arrays that are stacked separately and
#' returned as a list of two. Targets are either a single matrix or a list of
#' matrices (decided from the first encoded element). If the stacked input has
#' only one row, the targets are reshaped to one-row matrices with
#' `length(vocabulary)` (or `length(vocabulary)^n_gram`) columns.
#'
#' With `wavenet_format`, `target_len` must be 1 and both inputs and targets
#' are bound with `abind`.
#'
#' All arguments other than `target_middle` and `wavenet_format` are passed on
#' to `seq_encoding_lm()`; list arguments are indexed per sequence.
#'
#' @returns Unnamed list: inputs first, targets second.
#' @noRd
create_x_y_tensors_lm <- function(sequence_list, nuc_dist_list, target_middle,
                                  maxlen, vocabulary, ambiguous_nuc,
                                  start_index_list, quality_list, target_len,
                                  coverage_list, use_coverage, max_cov,n_gram,
                                  n_gram_stride, output_format, wavenet_format) {

  # ---- wavenet format -------------------------------------------------------
  if (wavenet_format) {

    if (target_len > 1) {
      stop("target_len must be 1 when using wavenet_format")
    }

    # one hot encode strings collected in sequence_list and connect arrays
    array_list <- purrr::map(
      1:length(sequence_list),
      function(k) {
        seq_encoding_lm(sequence_list[[k]], ambiguous_nuc = ambiguous_nuc, adjust_start_ind = TRUE,
                        maxlen = maxlen, vocabulary = vocabulary, nuc_dist = nuc_dist_list[[k]],
                        start_ind = start_index_list[[k]],
                        quality_vector = quality_list[[k]], n_gram = n_gram,
                        cov_vector = coverage_list[[k]], use_coverage = use_coverage, max_cov = max_cov,
                        output_format = output_format)
      }
    )

    x <- array_list[[1]][[1]]
    y <- array_list[[1]][[2]]
    # indices 2, 3, ... (empty if there is only one element)
    for (i in seq_along(array_list)[-1]) {
      x <- abind::abind(x, array_list[[i]][[1]], along = 1)
      y <- abind::abind(y, array_list[[i]][[2]], along = 1)
    }
    return(list(x, y))
  }

  # ---- standard format ------------------------------------------------------
  array_list <- purrr::map(
    1:length(sequence_list),
    function(k) {
      seq_encoding_lm(sequence_list[[k]], nuc_dist = nuc_dist_list[[k]], adjust_start_ind = TRUE,
                      maxlen = maxlen, vocabulary = vocabulary, ambiguous_nuc = ambiguous_nuc,
                      start_ind = start_index_list[[k]],
                      quality_vector = quality_list[[k]], target_len = target_len,
                      cov_vector = coverage_list[[k]], use_coverage = use_coverage, max_cov = max_cov,
                      n_gram = n_gram, n_gram_stride = n_gram_stride, output_format = output_format)
    }
  )

  # the first element decides whether targets come as a list of matrices
  y_is_list <- is.list(array_list[[1]][[2]])
  remaining <- seq_along(array_list)[-1]

  # stack inputs
  if (!target_middle) {
    x <- array_list[[1]][[1]]
    for (i in remaining) {
      x <- abind::abind(x, array_list[[i]][[1]], along = 1)
    }
    n_samples <- dim(x)[1]
  } else {
    x_1 <- array_list[[1]][[1]][[1]]
    x_2 <- array_list[[1]][[1]][[2]]
    for (i in remaining) {
      x_1 <- abind::abind(x_1, array_list[[i]][[1]][[1]], along = 1)
      x_2 <- abind::abind(x_2, array_list[[i]][[1]][[2]], along = 1)
    }
    x <- list(x_1, x_2)
    n_samples <- dim(x_1)[1]
  }

  # stack targets
  y <- array_list[[1]][[2]]
  for (i in remaining) {
    if (y_is_list) {
      for (j in 1:length(y)) {
        y[[j]] <- rbind(y[[j]], array_list[[i]][[2]][[j]] )
      }
    } else {
      y <- rbind(y, array_list[[i]][[2]])
    }
  }

  # coerce y type to matrix: a single sample yields a target without row
  # dimension
  if (n_samples == 1) {
    if (is.null(n_gram)) {
      y_dim <- c(1, length(vocabulary))
    } else {
      y_dim <- c(1, length(vocabulary)^n_gram)
    }
    if (y_is_list) {
      for (i in 1:length(y)) {
        dim(y[[i]]) <- y_dim
      }
    } else {
      dim(y) <- y_dim
    }
  }

  return(list(x, y))
}
555
556
#' Split an encoded array into input and target
#'
#' Splits `xy` along its second (sequence) axis into input `x` and target `y`,
#' depending on `output_format`:
#' * `"target_right"`: positions from `get_x_index()` are the input, the rest
#'   the target.
#' * `"wavenet"`: input is all but the last position, target all but the
#'   first (requires target length 1).
#' * `"target_middle_cnn"`: `target_len` positions around the middle are the
#'   target, everything else the input.
#' * `"target_middle_lstm"`: as above, but the input is returned as a list of
#'   the part left of the target and the part right of it in reversed order.
#'
#' For a single sample (`dim(xy)[1] == 1`) the batch dimension dropped by
#' subsetting is added back to `y`.
#'
#' @param xy Matrix (if `return_int`) or 3D array, samples in the first and
#'   sequence positions in the second dimension.
#' @param output_format One of the formats listed above.
#' @param target_len Target length; divided by `n_gram` (rounded down) if
#'   `n_gram` is not `NULL`.
#' @param n_gram `NULL` or n-gram size.
#' @param n_gram_stride,total_seq_len Currently unused.
#' @param return_int Whether `xy` is an integer encoded matrix rather than a
#'   3D array.
#' @returns List with elements `x` and `y`.
#' @noRd
slice_tensor_lm <- function(xy, output_format, target_len, n_gram,
                            n_gram_stride,
                            total_seq_len, return_int) {

  xy_dim <- dim(xy)

  if (!is.null(n_gram)) {
    target_len <- floor(target_len / n_gram)
  }

  # subset the sequence axis of a matrix or a 3D array
  take <- function(index, drop) {
    if (return_int) {
      xy[, index, drop = drop]
    } else {
      xy[, index, , drop = drop]
    }
  }

  if (output_format == "target_right") {

    x_index <- get_x_index(xy_dim, output_format, target_len)
    x <- take(x_index, drop = FALSE)
    y <- take(-x_index, drop = TRUE)

  } else if (output_format == "wavenet") {

    if (target_len != 1) {
      stop("Target length must be 1 for wavenet model")
    }
    x <- take(1:(xy_dim[2] - target_len), drop = FALSE)
    y <- take(2:xy_dim[2], drop = FALSE)

  } else if (output_format == "target_middle_cnn") {

    seq_middle <- ceiling(xy_dim[2] / 2)
    y_index <- (1:target_len) + (seq_middle - ceiling(target_len / 2))
    x <- take(-y_index, drop = FALSE)
    y <- take(y_index, drop = TRUE)

  } else if (output_format == "target_middle_lstm") {

    seq_middle <- ceiling(xy_dim[2] / 2)
    y_index <- (1:target_len) + (seq_middle - ceiling(target_len / 2))
    x1 <- take(1:(min(y_index) - 1), drop = FALSE)
    # reverse order of x2
    x2 <- take(xy_dim[2]:(max(y_index) + 1), drop = FALSE)
    y <- take(y_index, drop = TRUE)
    x <- list(x1, x2)

  } else {
    stop("Unknown output_format: ", output_format)
  }

  # Restore the batch dimension of y for a single sample. wavenet targets are
  # extracted with drop = FALSE and need no repair.
  if (xy_dim[1] == 1 && output_format != "wavenet") {
    if (target_len == 1) {
      y <- matrix(y, nrow = 1)
    } else if (target_len > 1) {
      if (is.null(dim(y))) {
        # integer encoding: y is a plain vector; array(y, dim = c(1, NULL))
        # would keep only its first element
        y <- matrix(y, nrow = 1)
      } else {
        y <- array(y, dim = c(1, dim(y)))
      }
    }
  }

  return(list(x = x, y = y))
}
643
644
#' Input positions for a given output format
#'
#' For `"target_right"` the input consists of all sequence positions except
#' the last `target_len` ones. Other formats are not implemented.
#'
#' @param xy_dim Dimension vector of the encoded array; the second entry is
#'   the sequence length.
#' @param output_format Output format, only `"target_right"` is supported.
#' @param target_len Target length.
#' @returns Integer vector of input positions (empty if `target_len` equals
#'   the sequence length).
#' @noRd
get_x_index <- function(xy_dim, output_format, target_len) {

  #TODO: subset for other formats with n_gram/stride
  if (output_format != "target_right") {
    stop("get_x_index() is only implemented for output_format 'target_right'")
  }

  # seq_len() gives an empty index at the boundary, where 1:0 would give
  # c(1, 0)
  seq_len(xy_dim[2] - target_len)
}
654
655
656
#' Add a leading dimension of size 1
#'
#' A vector becomes a one-row matrix, an array with dimensions `d` becomes an
#' array with dimensions `c(1, d)`.
#'
#' @param x Vector or array.
#' @returns Matrix or array whose first dimension has size 1.
#' @noRd
add_dim <- function(x) {
  old_dim <- dim(x)
  if (!is.null(old_dim)) {
    return(array(x, dim = c(1, old_dim)))
  }
  matrix(x, nrow = 1)
}
665
666
#' Shuffle samples of a batch
#'
#' Reorders the first dimension of an array, or of every array in a list,
#' with the same index via `shuffle_sample()`.
#'
#' @param x Array or list of arrays. For a list, the number of dimensions of
#'   the first element is used for all elements.
#' @param shuffle_index Index vector giving the new order of samples.
#' @returns Object of the same structure as `x` with reordered samples.
#' @noRd
shuffle_batches <- function(x, shuffle_index) {

  if (!is.list(x)) {
    return(shuffle_sample(x, length(dim(x)), shuffle_index))
  }

  if (length(x) > 0) {
    dim_len <- length(dim(x[[1]]))
    for (i in seq_along(x)) {
      x[[i]] <- shuffle_sample(x[[i]], dim_len, shuffle_index)
    }
  }
  # return the list explicitly: a trailing `for` loop evaluates to NULL
  x
}
679
680
#' Reorder the first dimension of a vector or array
#'
#' @param x Vector or array.
#' @param dim_len Number of dimensions of `x`; `NULL`, 0 and 1 are treated as
#'   a plain vector.
#' @param shuffle_index Index vector giving the new order along the first
#'   dimension.
#' @returns `x` with reordered first dimension; the number of dimensions is
#'   preserved.
#' @noRd
shuffle_sample <- function(x, dim_len, shuffle_index) {

  # length(dim(x)) is 0 for plain vectors, so 0 must be handled like 1
  if (is.null(dim_len) || dim_len <= 1) {
    return(x[shuffle_index])
  }

  # x[shuffle_index, , ..., drop = FALSE] for any number of dimensions;
  # TRUE selects everything along the remaining axes
  index_args <- c(list(shuffle_index), rep(list(TRUE), dim_len - 1))
  do.call(`[`, c(list(x), index_args, list(drop = FALSE)))
}
696
697
#' Count GPUs
#'
#' Counts the entries of `tf$config$list_physical_devices()` whose
#' `device_type` is `"GPU"`.
#'
#' @returns Number of GPU devices (0 if none).
#' @noRd
count_gpu <- function() {

  devices <- tensorflow::tf$config$list_physical_devices()
  # vapply() also copes with an empty device list, unlike 1:length()
  is_gpu <- vapply(
    devices,
    function(device) isTRUE(device$device_type == "GPU"),
    logical(1)
  )
  # keep the double return type of the previous counter
  as.numeric(sum(is_gpu))
}
706
707
#' Reshape array to time distributed format
#'
#' Reshapes a 3D array with dimensions `(d1, d2, d3)` to
#' `(d1, samples_per_target, d2 / samples_per_target, d3)` using the keras
#' backend and returns the evaluated result.
#'
#' @param x 3D array.
#' @param samples_per_target Size of the new second dimension.
#' @returns 4D array.
#' @noRd
to_time_dist <- function(x, samples_per_target) {
  in_dim <- dim(x)
  td_shape <- c(
    in_dim[1],
    samples_per_target,
    in_dim[2] / samples_per_target,
    in_dim[3]
  )
  keras::k_eval(keras::k_reshape(x, shape = td_shape))
}
713
714
715
#' Plot confusion matrix
#' 
#' Plot confusion matrix, either with absolute numbers or percentages per column (true labels).
#' 
#' @param cm A confusion matrix
#' @param perc Whether to use absolute numbers or percentages.
#' @param cm_labels Labels corresponding to confusion matrix entries.
#' @param round_dig How to round numbers (only used if `perc = TRUE`).
#' @param text_size Size of the axis labels.
#' @param highlight_diag Whether to highlight entries in diagonal.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' cm <- matrix(c(90, 1, 0, 2, 7, 1, 8, 3, 1), nrow = 3, byrow = TRUE)
#' plot_cm(cm, perc = TRUE, cm_labels = paste0('label_', 1:3), text_size = 8)
#' 
#' @returns A ggplot of a confusion matrix.
#' @export
plot_cm <- function(cm, perc = FALSE, cm_labels, round_dig = 2, text_size = 1, highlight_diag = TRUE) {

  if (perc) {
    cm <- cm_perc(cm, round_dig)
  }
  conf_mat <- create_conf_mat_obj(cm, cm_labels)

  # heatmap with a light-to-dark blue fill and rotated x axis labels
  x_axis_text <- ggplot2::element_text(angle = 90, hjust = 1, size = text_size)
  y_axis_text <- ggplot2::element_text(size = text_size)
  cm_plot <- ggplot2::autoplot(conf_mat, type = "heatmap") +
    ggplot2::scale_fill_gradient(low = "#D6EAF8", high = "#2E86C1") +
    ggplot2::theme(axis.text.x = x_axis_text) +
    ggplot2::theme(axis.text.y = y_axis_text)

  if (highlight_diag) {
    # (almost) transparent tiles on the diagonal; effectively only their
    # white border is visible
    diagonal_data <- data.frame(x = cm_labels, y = cm_labels)
    diagonal_tiles <- ggplot2::geom_tile(
      data = diagonal_data, ggplot2::aes(x = x, y = y),
      fill = "red", colour = "white", size = 1,
      alpha = 0.000001
    )
    cm_plot <- cm_plot + diagonal_tiles
  }

  # TODO: add conf mat with ComplexHeatmap for bigger sizes
  cm_plot

}
756
757
758
759
#' Remove checkpoints
#' 
#' Remove all but n 'best' checkpoints, based on some condition. Condition can be 
#' accuracy, loss or epoch number. The scores are parsed from the file names,
#' which must contain `acc<number>`, `loss<number>` (numbers with decimal
#' point) or `Ep.<number>`, respectively.
#' 
#' @param cp_dir Directory containing checkpoints.
#' @param metric Either `"acc"`, `"loss"` or `"last_ep"`. Condition which checkpoints to keep.
#' @param best_n Number of checkpoints to keep.
#' @param ask_before_remove Whether to show files to keep before deleting rest.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' model <- create_model_lstm_cnn(layer_lstm = 8)
#' checkpoint_folder <- tempfile()
#' dir.create(checkpoint_folder)
#' keras::save_model_hdf5(model, file.path(checkpoint_folder, 'Ep.007-val_loss11.07-val_acc0.6.hdf5'))
#' keras::save_model_hdf5(model, file.path(checkpoint_folder, 'Ep.019-val_loss8.74-val_acc0.7.hdf5'))
#' keras::save_model_hdf5(model, file.path(checkpoint_folder, 'Ep.025-val_loss0.03-val_acc0.8.hdf5'))
#' remove_checkpoints(cp_dir = checkpoint_folder, metric = "acc", best_n = 2,
#'                    ask_before_remove = FALSE)
#' list.files(checkpoint_folder)
#'  
#' @returns None. Deletes certain files.
#' @export 
remove_checkpoints <- function(cp_dir, metric = "acc", best_n = 1, ask_before_remove = TRUE) {

  stopifnot(metric %in% c("acc", "loss", "last_ep"))
  stopifnot(dir.exists(cp_dir))
  stopifnot(best_n >= 1)

  files <- list.files(cp_dir, full.names = TRUE)
  if (length(files) == 0) {
    stop("Directory is empty")
  }
  files_basename <- basename(files)
  num_cp <- length(files)

  # Parse "<prefix><number>" from every file name. Stops if a name does not
  # contain a score, since NA scores would make the ranking meaningless.
  extract_score <- function(prefix, number) {
    hit <- regexpr(paste0(prefix, number), files_basename, perl = TRUE)
    found <- rep(NA_character_, num_cp)
    found[hit > 0] <- regmatches(files_basename, hit)
    scores <- as.numeric(sub(prefix, "", found, perl = TRUE))
    if (anyNA(scores)) {
      stop("Could not extract score for metric '", metric, "' from: ",
           paste0(files_basename[is.na(scores)], collapse = ", "))
    }
    scores
  }

  if (metric == "acc") {
    if (!all(grepl("acc", files_basename, fixed = TRUE))) {
      stop("No accuracy information in checkpoint names ('acc' string), use other metric.")
    }
    acc_scores <- extract_score("acc", "\\d++\\.\\d++")
    # keep the best_n highest accuracies
    index <- rank(acc_scores, ties.method = "last") > (num_cp - best_n)
  }

  if (metric == "loss") {
    if (!all(grepl("loss", files_basename, fixed = TRUE))) {
      stop("No loss information in checkpoint names ('loss' string), use other metric.")
    }
    loss_scores <- extract_score("loss", "\\d++\\.\\d++")
    # keep the best_n lowest losses
    index <- rank(loss_scores, ties.method = "last") <= best_n
  }

  if (metric == "last_ep") {
    ep_scores <- extract_score("Ep\\.", "\\d++")
    # keep the best_n latest epochs
    index <- rank(ep_scores) > (num_cp - best_n)
  }

  if (ask_before_remove) {
    num_remove <- sum(!index)
    file_word <- if (num_remove == 1) "file" else "files"
    message("Deleting ", num_remove, " ", file_word, ". Only keep \n ",
            paste0(basename(files[index]), collapse = ",\n "), "\n")
    remove_cps <- utils::askYesNo("")
  } else {
    remove_cps <- TRUE
  }

  # question was cancelled
  if (is.na(remove_cps)) return(NULL)

  if (remove_cps) {
    invisible(file.remove(files[!index]))
  }

}
836
837
#' Apply global pooling or flattening to a tensor
#'
#' @param global_pooling `NULL` or one of `"max_ch_first"`, `"max_ch_last"`,
#'   `"average_ch_first"`, `"average_ch_last"`, `"both_ch_first"`,
#'   `"both_ch_last"`, `"all"`, `"none"`, `"flatten"`. `NULL` and `"flatten"`
#'   flatten the tensor, `"none"` returns it unchanged. `"both_*"` concatenates
#'   average and max pooling for one data format, `"all"` concatenates average
#'   and max pooling for both data formats.
#' @param output_tensor A keras tensor.
#' @returns A keras tensor.
#' @noRd
pooling_flatten <- function(global_pooling = NULL, output_tensor) {

  if (is.null(global_pooling)) {
    return(keras::layer_flatten(output_tensor))
  }
  stopifnot(global_pooling %in% c("max_ch_first", "max_ch_last", "average_ch_first",
                                  "average_ch_last", "both_ch_first", "both_ch_last", "all", "none", "flatten"))

  max_pool <- function(data_format) {
    keras::layer_global_max_pooling_1d(output_tensor, data_format = data_format)
  }
  avg_pool <- function(data_format) {
    keras::layer_global_average_pooling_1d(output_tensor, data_format = data_format)
  }

  # "all" is dispatched here like every other option; it used to sit behind a
  # check for "_ch_" in the name and could therefore never be reached
  switch(
    global_pooling,
    max_ch_first = max_pool("channels_first"),
    max_ch_last = max_pool("channels_last"),
    average_ch_first = avg_pool("channels_first"),
    average_ch_last = avg_pool("channels_last"),
    both_ch_last = keras::layer_concatenate(list(
      avg_pool("channels_last"), max_pool("channels_last")
    )),
    both_ch_first = keras::layer_concatenate(list(
      avg_pool("channels_first"), max_pool("channels_first")
    )),
    all = keras::layer_concatenate(list(
      avg_pool("channels_first"), max_pool("channels_first"),
      avg_pool("channels_last"), max_pool("channels_last")
    )),
    flatten = keras::layer_flatten(output_tensor),
    none = output_tensor
  )
}
891
892
893
#' Time-distributed global pooling / flatten
#'
#' Applies the requested global pooling (or flatten) operation to
#' `output_tensor`, wrapping each keras layer in `keras::time_distributed()`.
#'
#' @param global_pooling `NULL` or one of `"max_ch_first"`, `"max_ch_last"`,
#'   `"average_ch_first"`, `"average_ch_last"`, `"both_ch_first"`,
#'   `"both_ch_last"`, `"all"`, `"none"`, `"flatten"`.
#'   `"both_*"` concatenates average and max pooling for one data format,
#'   `"all"` concatenates average and max pooling for both data formats,
#'   `"none"` returns `output_tensor` unchanged.
#' @param output_tensor Object the layers are applied to.
#' @returns The transformed tensor (or `output_tensor` itself for `"none"`).
#' @noRd
pooling_flatten_time_dist <- function(global_pooling = NULL, output_tensor) {

  # NULL keeps the historical default: a plain flatten layer.
  # NOTE(review): unlike "flatten" below, this path is not wrapped in
  # time_distributed() -- confirm whether that is intended.
  if (is.null(global_pooling)) {
    return(keras::layer_flatten(output_tensor))
  }

  stopifnot(
    length(global_pooling) == 1L,
    global_pooling %in% c("max_ch_first", "max_ch_last", "average_ch_first",
                          "average_ch_last", "both_ch_first", "both_ch_last",
                          "all", "none", "flatten")
  )

  if (global_pooling == "none") {
    return(output_tensor)
  }

  # Small helpers so every branch builds its layers the same way.
  td_max <- function(data_format) {
    keras::time_distributed(
      output_tensor,
      keras::layer_global_max_pooling_1d(data_format = data_format)
    )
  }
  td_avg <- function(data_format) {
    keras::time_distributed(
      output_tensor,
      keras::layer_global_average_pooling_1d(data_format = data_format)
    )
  }

  # Dispatch on the exact option name. The previous implementation nested the
  # "all" branch behind a "_ch_" substring test, which "all" never satisfies,
  # so that option failed with "object 'out' not found".
  switch(
    global_pooling,
    max_ch_first = td_max("channels_first"),
    max_ch_last = td_max("channels_last"),
    average_ch_first = td_avg("channels_first"),
    average_ch_last = td_avg("channels_last"),
    both_ch_last = keras::layer_concatenate(list(
      td_avg("channels_last"),
      td_max("channels_last")
    )),
    both_ch_first = keras::layer_concatenate(list(
      td_avg("channels_first"),
      td_max("channels_first")
    )),
    all = keras::layer_concatenate(list(
      td_avg("channels_first"),
      td_max("channels_first"),
      td_avg("channels_last"),
      td_max("channels_last")
    )),
    flatten = keras::time_distributed(output_tensor, keras::layer_flatten())
  )
}
947
948
949
950
#' Get a global pooling or flatten layer
#'
#' Builds the keras layer object that corresponds to `global_pooling`, without
#' applying it to a tensor.
#'
#' @param global_pooling `NULL` or one of `"max_ch_first"`, `"max_ch_last"`,
#'   `"average_ch_first"`, `"average_ch_last"`, `"both_ch_first"`,
#'   `"both_ch_last"`, `"all"`, `"none"`, `"flatten"`. `NULL`, `"flatten"` and
#'   `"none"` all yield `keras::layer_flatten()`.
#' @returns The value of the corresponding keras layer constructor call.
#' @noRd
get_pooling_flatten_layer <- function(global_pooling = NULL) {

  # The default must be handled before any string comparison: comparing or
  # pattern-matching NULL yields a zero-length logical, which `if` rejects.
  if (is.null(global_pooling)) {
    return(keras::layer_flatten())
  }

  stopifnot(
    length(global_pooling) == 1L,
    global_pooling %in% c("max_ch_first", "max_ch_last", "average_ch_first",
                          "average_ch_last", "both_ch_first", "both_ch_last",
                          "all", "none", "flatten")
  )

  max_pool <- function(data_format) {
    keras::layer_global_max_pooling_1d(data_format = data_format)
  }
  avg_pool <- function(data_format) {
    keras::layer_global_average_pooling_1d(data_format = data_format)
  }

  # Dispatch on the exact option name so that "all" (which does not contain
  # "_ch_") reaches its own branch instead of silently becoming a flatten.
  # NOTE(review): the "both_*" and "all" branches hand layer objects (not
  # tensors) to layer_concatenate(), exactly as the original code did --
  # confirm that this produces a usable layer for the callers.
  switch(
    global_pooling,
    max_ch_first = max_pool("channels_first"),
    max_ch_last = max_pool("channels_last"),
    average_ch_first = avg_pool("channels_first"),
    average_ch_last = avg_pool("channels_last"),
    both_ch_last = keras::layer_concatenate(list(
      avg_pool("channels_last"),
      max_pool("channels_last")
    )),
    both_ch_first = keras::layer_concatenate(list(
      avg_pool("channels_first"),
      max_pool("channels_first")
    )),
    all = keras::layer_concatenate(list(
      avg_pool("channels_first"),
      max_pool("channels_first"),
      avg_pool("channels_last"),
      max_pool("channels_last")
    )),
    # "flatten" and "none" fall back to a flatten layer.
    keras::layer_flatten()
  )
}
998
999
1000
#' Apply optional reshape functions to a batch
#'
#' @param x Input data.
#' @param y Target data.
#' @param reshape_xy `NULL`, or a list holding functions `x`, `y` and (when
#'   `reshape_sw_bool` is `TRUE`) `sw`. Each function is called with the named
#'   arguments `x` and `y`; `sw` is passed as well when `reshape_sw_bool` is
#'   `TRUE`.
#' @param reshape_x_bool Whether to replace `x` by `reshape_xy$x(...)`.
#' @param reshape_y_bool Whether to replace `y` by `reshape_xy$y(...)`.
#' @param reshape_sw_bool Whether to replace `sw` by `reshape_xy$sw(...)`.
#' @param sw Sample weights (or `NULL`).
#' @returns `list(X, Y, SW)` when `reshape_xy` is `NULL` or `reshape_sw_bool`
#'   is `TRUE`; otherwise `list(X, Y)`.
#' @noRd
f_reshape <- function(x, y, reshape_xy, reshape_x_bool, reshape_y_bool, reshape_sw_bool = FALSE, sw = NULL) {

  # Nothing to apply: hand the inputs back untouched.
  if (is.null(reshape_xy)) {
    return(list(X = x, Y = y, SW = sw))
  }

  # Without sample-weight reshaping the functions only receive x and y,
  # and the result carries no SW entry.
  if (!reshape_sw_bool) {
    x_out <- if (reshape_x_bool) reshape_xy$x(x = x, y = y) else x
    y_out <- if (reshape_y_bool) reshape_xy$y(x = x, y = y) else y
    return(list(X = x_out, Y = y_out))
  }

  # Sample-weight reshaping requested: every function also receives sw.
  x_out <- if (reshape_x_bool) reshape_xy$x(x = x, y = y, sw = sw) else x
  y_out <- if (reshape_y_bool) reshape_xy$y(x = x, y = y, sw = sw) else y
  sw_out <- reshape_xy$sw(x = x, y = y, sw = sw)

  list(X = x_out, Y = y_out, SW = sw_out)
}
1048
1049
#' Balanced accuracy from a confusion matrix
#'
#' For every column `i` of `cm` with a non-zero sum, computes the entry at
#' position `i` of that column divided by the column sum, and averages these
#' per-class values. Columns whose sum is zero are skipped.
#'
#' @param cm Confusion matrix; column `i` and row `i` are expected to refer to
#'   the same class.
#' @param verbose Whether to print the per-class values (named by
#'   `colnames(cm)` when available).
#' @returns Mean of the per-class values; `NaN` if no column has a non-zero
#'   sum (including a matrix with zero columns).
#' @noRd
bal_acc_from_cm <- function(cm, verbose = TRUE) {

  # seq_len() yields an empty index for a zero-column matrix, whereas
  # 1:ncol(cm) would iterate over c(1, 0) and fail on the first subscript.
  class_idx <- seq_len(ncol(cm))

  col_totals <- vapply(class_idx, function(i) sum(cm[, i]), numeric(1))
  # Entry i of column i, i.e. the diagonal element for class i.
  correct <- vapply(class_idx, function(i) cm[, i][i], numeric(1))

  # Classes without any observation in their column carry no information.
  keep <- col_totals != 0
  class_acc <- correct[keep] / col_totals[keep]
  names(class_acc) <- colnames(cm)[keep]

  if (verbose) print(class_acc)
  mean(class_acc)
}