Diff of /R/generators.R [000000] .. [409433]

Switch to side-by-side view

--- a
+++ b/R/generators.R
@@ -0,0 +1,326 @@
+#' Wrapper for generator functions
+#' 
+#' For a detailed description see the data generator [tutorial](https://deepg.de/articles/data_generator.html).
+#' Will choose one of the generators from \code{\link{generator_fasta_lm}}, 
+#' \code{\link{generator_fasta_label_folder}}, \code{\link{generator_fasta_label_header_csv}}, 
+#' \code{\link{generator_rds}}, \code{\link{generator_random}}, \code{\link{generator_dummy}} or 
+#' \code{\link{generator_fasta_lm}} according to the \code{train_type} and \code{random_sampling}
+#' arguments.
+#'
+#' @inheritParams train_model
+#' @inheritParams generator_fasta_lm
+#' @inheritParams generator_fasta_label_folder
+#' @inheritParams generator_fasta_label_header_csv
+#' @inheritParams generator_rds
+#' @inheritParams generator_random
+#' @inheritParams generator_initialize
+#' @param path_file_logVal Path to csv file logging used validation files.
+#' @examplesIf reticulate::py_module_available("tensorflow")
+#' # create dummy fasta files
+#' fasta_path <- tempfile()
+#' dir.create(fasta_path)
+#' create_dummy_data(file_path = fasta_path,
+#'                   num_files = 3,
+#'                   seq_length = 10,
+#'                   num_seq = 5,
+#'                   vocabulary = c("a", "c", "g", "t"))
+#' 
+#' gen <- get_generator(path = fasta_path,
+#'                      maxlen = 5, train_type = "lm",
+#'                      output_format = "target_right",
+#'                      step = 3, batch_size = 7)
+#' z <- gen()
+#' x <- z[[1]]
+#' y <- z[[2]]
+#' dim(x)
+#' dim(y)
+#' 
+#' @returns A generator function.
+#' @export
+get_generator <- function(path = NULL,
+                          train_type,
+                          batch_size,
+                          maxlen,
+                          step = NULL,
+                          shuffle_file_order = FALSE,
+                          vocabulary = c("A", "C", "G", "T"),
+                          seed = 1,
+                          proportion_entries = NULL,
+                          shuffle_input = FALSE,
+                          format = "fasta",
+                          path_file_log = NULL,
+                          reverse_complement = FALSE,
+                          n_gram = NULL,
+                          n_gram_stride = NULL,
+                          output_format = "target_right",
+                          ambiguous_nuc = "zero",
+                          proportion_per_seq = NULL,
+                          skip_amb_nuc = NULL,
+                          use_quality_score = FALSE,
+                          padding = FALSE,
+                          added_label_path = NULL,
+                          target_from_csv = NULL,
+                          add_input_as_seq = NULL,
+                          max_samples = NULL,
+                          concat_seq = NULL,
+                          target_len = 1,
+                          file_filter = NULL,
+                          use_coverage = NULL,
+                          sample_by_file_size = FALSE,
+                          add_noise = NULL,
+                          random_sampling = FALSE,
+                          set_learning = NULL,
+                          file_limit = NULL,
+                          reverse_complement_encoding = FALSE,
+                          read_data = FALSE,
+                          target_split = NULL,
+                          path_file_logVal = NULL,
+                          model = NULL,
+                          vocabulary_label = NULL,
+                          masked_lm = NULL,
+                          val = FALSE,
+                          return_int = FALSE,
+                          verbose = TRUE,
+                          delete_used_files = FALSE,
+                          reshape_xy = NULL) {
+  
+  if (random_sampling) {
+    if (use_quality_score) stop("use_quality_score not implemented for random sampling")
+    if (read_data) stop("read_data not implemented for random sampling")
+    if (!is.null(use_coverage)) stop("use_coverage not implemented for random sampling")
+    if (!is.null(add_noise)) stop("add_noise not implemented for random sampling")
+  }
+  
+  if (train_type %in% c("label_rds", "lm_rds") & format != "rds") {
+    warning(paste("train_type is", train_type, "but format is not 'rds'"))
+  }
+  
+  # adjust batch size
+  if ((length(batch_size) == 1) && (batch_size %% length(path) != 0) & train_type == "label_folder") {
+    batch_size <- ceiling(batch_size/length(path)) * length(path)
+    if (!val) {
+      message(paste("Batch size needs to be multiple of number of targets. Setting batch_size to", batch_size))
+    }
+  }
+  
+  if (is.null(step)) step <- maxlen
+  
+  if (train_type == "dummy_gen") {
+    #gen <- generator_dummy(model, ifelse(is.null(set_learning), batch_size, new_batch_size))
+    gen <- generator_dummy(model, batch_size)
+    removeLog <- FALSE
+  }
+  
+  if (!is.null(added_label_path) & is.null(add_input_as_seq)) {
+    add_input_as_seq <- rep(FALSE, length(added_label_path))
+  }
+  
+  # language model
+  if (train_type == "lm" & random_sampling) {
+    
+    gen <- generator_random(
+      train_type = "lm",
+      output_format = output_format,
+      seed = seed[1],
+      format = format,
+      reverse_complement = reverse_complement,
+      reverse_complement_encoding = reverse_complement_encoding,
+      path = path,
+      batch_size = batch_size,
+      maxlen = maxlen,
+      ambiguous_nuc = ambiguous_nuc,
+      padding = padding,
+      vocabulary = vocabulary,
+      number_target_nt = target_len,
+      target_split = target_split,
+      target_from_csv = target_from_csv,
+      n_gram = n_gram,
+      n_gram_stride = n_gram_stride,
+      sample_by_file_size = sample_by_file_size,
+      max_samples = max_samples,
+      skip_amb_nuc = skip_amb_nuc,
+      vocabulary_label = vocabulary_label,
+      shuffle_input = shuffle_input,
+      proportion_entries = proportion_entries,
+      return_int = return_int,
+      concat_seq = concat_seq,
+      reshape_xy = reshape_xy)
+  } 
+  
+  if (train_type == "lm" & !random_sampling) {
+    
+    gen <- generator_fasta_lm(path_corpus = path, batch_size = batch_size,
+                              maxlen = maxlen, step = step, shuffle_file_order = shuffle_file_order,
+                              vocabulary = vocabulary, seed = seed[1], proportion_entries = proportion_entries,
+                              shuffle_input = shuffle_input, format = format, n_gram_stride = n_gram_stride,
+                              path_file_log = path_file_log, reverse_complement = reverse_complement, 
+                              output_format = output_format, ambiguous_nuc = ambiguous_nuc,
+                              proportion_per_seq = proportion_per_seq, skip_amb_nuc = skip_amb_nuc,
+                              use_quality_score = use_quality_score, padding = padding, n_gram = n_gram,
+                              added_label_path = added_label_path, add_input_as_seq = add_input_as_seq,
+                              max_samples = max_samples, concat_seq = concat_seq, target_len = target_len,
+                              file_filter = file_filter, use_coverage = use_coverage, return_int = return_int,
+                              sample_by_file_size = sample_by_file_size, add_noise = add_noise,
+                              reshape_xy = reshape_xy)
+  }
+  
+  # label by folder
+  if (train_type %in% c("label_folder", "masked_lm") & random_sampling) {
+    
+    gen <- generator_random(
+      train_type = train_type,
+      seed = seed[1],
+      format = format,
+      reverse_complement = reverse_complement,
+      path = path,
+      batch_size = batch_size,
+      maxlen = maxlen,
+      ambiguous_nuc = ambiguous_nuc,
+      padding = padding,
+      vocabulary = vocabulary,
+      number_target_nt = NULL,
+      n_gram = n_gram,
+      n_gram_stride = n_gram_stride,
+      sample_by_file_size = sample_by_file_size,
+      max_samples = max_samples,
+      skip_amb_nuc = skip_amb_nuc,
+      shuffle_input = shuffle_input,
+      set_learning = set_learning,
+      reverse_complement_encoding = reverse_complement_encoding,
+      vocabulary_label = vocabulary_label,
+      proportion_entries = proportion_entries,
+      masked_lm = masked_lm,
+      return_int = return_int,
+      concat_seq = concat_seq,
+      reshape_xy = reshape_xy)
+  } 
+  
+  if (train_type == "label_folder" & !random_sampling) {
+    
+    gen_list <- generator_initialize(directories = path, format = format, batch_size = batch_size, maxlen = maxlen, vocabulary = vocabulary,
+                                     verbose = verbose, shuffle_file_order = shuffle_file_order, step = step, seed = seed[1],
+                                     shuffle_input = shuffle_input, file_limit = file_limit, skip_amb_nuc = skip_amb_nuc,
+                                     path_file_log = path_file_log, reverse_complement = reverse_complement,
+                                     reverse_complement_encoding = reverse_complement_encoding, return_int = return_int,
+                                     ambiguous_nuc = ambiguous_nuc, proportion_per_seq = proportion_per_seq,
+                                     read_data = read_data, use_quality_score = use_quality_score, val = val,
+                                     padding = padding, max_samples = max_samples, concat_seq = concat_seq,
+                                     added_label_path = added_label_path, add_input_as_seq = add_input_as_seq, use_coverage = use_coverage,
+                                     set_learning = set_learning, proportion_entries = proportion_entries,
+                                     sample_by_file_size = sample_by_file_size, n_gram = n_gram, n_gram_stride = n_gram_stride,
+                                     add_noise = add_noise, reshape_xy = reshape_xy)
+
+    gen <- generator_fasta_label_folder_wrapper(val = val, path = path, 
+                                                batch_size = batch_size, voc_len = length(vocabulary),
+                                                gen_list = gen_list,
+                                                maxlen = maxlen, set_learning = set_learning)
+    
+  }
+  
+  if (train_type == "masked_lm" & !random_sampling) {
+    
+    stopifnot(!is.null(masked_lm))
+    
+    gen <- generator_fasta_label_folder(path_corpus = unlist(path),
+                                        format = format,
+                                        batch_size = batch_size,
+                                        maxlen = maxlen,
+                                        vocabulary = vocabulary,
+                                        shuffle_file_order = shuffle_file_order,
+                                        step = step,
+                                        seed = seed,
+                                        shuffle_input = shuffle_input,
+                                        file_limit = file_limit,
+                                        path_file_log = path_file_log,
+                                        reverse_complement = reverse_complement,
+                                        reverse_complement_encoding = reverse_complement_encoding,
+                                        num_targets = 1,
+                                        ones_column = 1,
+                                        ambiguous_nuc = ambiguous_nuc,
+                                        proportion_per_seq = proportion_per_seq,
+                                        read_data = read_data,
+                                        use_quality_score = use_quality_score,
+                                        padding = padding,
+                                        added_label_path = added_label_path,
+                                        add_input_as_seq = add_input_as_seq,
+                                        skip_amb_nuc = skip_amb_nuc,
+                                        max_samples = max_samples,
+                                        concat_seq = concat_seq,
+                                        file_filter = NULL,
+                                        return_int = return_int,
+                                        use_coverage = use_coverage,
+                                        proportion_entries = proportion_entries,
+                                        sample_by_file_size = sample_by_file_size,
+                                        n_gram = n_gram,
+                                        n_gram_stride = n_gram_stride,
+                                        masked_lm = masked_lm,
+                                        add_noise = add_noise,
+                                        reshape_xy = reshape_xy) 
+  }
+  
+  
+  if ((train_type == "label_csv" | train_type == "label_header") & !random_sampling) {
+    
+    gen <- generator_fasta_label_header_csv(path_corpus = path, format = format, batch_size = batch_size, maxlen = maxlen,
+                                            vocabulary = vocabulary, verbose = verbose, shuffle_file_order = shuffle_file_order, step = step,
+                                            seed = seed[1], shuffle_input = shuffle_input, return_int = return_int,
+                                            path_file_log = path_file_log, vocabulary_label = vocabulary_label, reverse_complement = reverse_complement,
+                                            ambiguous_nuc = ambiguous_nuc, proportion_per_seq = proportion_per_seq,
+                                            read_data = read_data, use_quality_score = use_quality_score, padding = padding,
+                                            added_label_path = added_label_path, add_input_as_seq = add_input_as_seq,
+                                            skip_amb_nuc = skip_amb_nuc, max_samples = max_samples, concat_seq = concat_seq,
+                                            target_from_csv = target_from_csv, target_split = target_split, file_filter = file_filter,
+                                            use_coverage = use_coverage, proportion_entries = proportion_entries,
+                                            sample_by_file_size = sample_by_file_size, n_gram = n_gram, n_gram_stride = n_gram_stride,
+                                            add_noise = add_noise, reverse_complement_encoding = reverse_complement_encoding,
+                                            reshape_xy = reshape_xy)
+  }
+  
+  if ((train_type == "label_csv" | train_type == "label_header") & random_sampling) {
+    
+    gen <- generator_random(
+      train_type = train_type, 
+      output_format = output_format,
+      seed = seed[1],
+      format = format,
+      reverse_complement = reverse_complement,
+      reverse_complement_encoding = reverse_complement_encoding,
+      path = path,
+      batch_size = batch_size,
+      maxlen = maxlen,
+      ambiguous_nuc = ambiguous_nuc,
+      padding = padding,
+      vocabulary = vocabulary,
+      number_target_nt = NULL,
+      n_gram = n_gram,
+      n_gram_stride = n_gram_stride,
+      sample_by_file_size = sample_by_file_size,
+      max_samples = max_samples,
+      skip_amb_nuc = skip_amb_nuc,
+      vocabulary_label = vocabulary_label,
+      target_from_csv = target_from_csv,
+      target_split = target_split,
+      verbose = verbose,
+      shuffle_input = shuffle_input,
+      proportion_entries = proportion_entries,
+      return_int = return_int,
+      concat_seq = concat_seq,
+      reshape_xy = reshape_xy)
+  }
+  
+  if (train_type %in% c("label_rds", "lm_rds")) {
+    reverse_complement <- FALSE
+    step <- 1
+    if (train_type == "label_rds") target_len <- NULL
+    gen <- generator_rds(rds_folder = path, batch_size = batch_size, path_file_log = path_file_log,
+                         max_samples = max_samples, proportion_per_seq = proportion_per_seq,
+                         sample_by_file_size = sample_by_file_size, add_noise = add_noise,
+                         reverse_complement_encoding = reverse_complement_encoding, seed = seed[1],
+                         target_len = target_len, n_gram = n_gram, n_gram_stride = n_gram_stride,
+                         delete_used_files = delete_used_files, reshape_xy = reshape_xy)
+    
+  }
+  
+  return(gen)
+  
+}