Diff of /R/generators.R [000000] .. [409433]

Switch to unified view

a b/R/generators.R
1
#' Wrapper for generator functions
2
#' 
3
#' For a detailed description see the data generator [tutorial](https://deepg.de/articles/data_generator.html).
4
#' Will choose one of the generators from \code{\link{generator_fasta_lm}}, 
5
#' \code{\link{generator_fasta_label_folder}}, \code{\link{generator_fasta_label_header_csv}}, 
6
#' \code{\link{generator_rds}}, \code{\link{generator_random}}, \code{\link{generator_dummy}} or 
7
#' \code{\link{generator_fasta_lm}} according to the \code{train_type} and \code{random_sampling}
8
#' arguments.
9
#'
10
#' @inheritParams train_model
11
#' @inheritParams generator_fasta_lm
12
#' @inheritParams generator_fasta_label_folder
13
#' @inheritParams generator_fasta_label_header_csv
14
#' @inheritParams generator_rds
15
#' @inheritParams generator_random
16
#' @inheritParams generator_initialize
17
#' @param path_file_logVal Path to csv file logging used validation files.
18
#' @examplesIf reticulate::py_module_available("tensorflow")
19
#' # create dummy fasta files
20
#' fasta_path <- tempfile()
21
#' dir.create(fasta_path)
22
#' create_dummy_data(file_path = fasta_path,
23
#'                   num_files = 3,
24
#'                   seq_length = 10,
25
#'                   num_seq = 5,
26
#'                   vocabulary = c("a", "c", "g", "t"))
27
#' 
28
#' gen <- get_generator(path = fasta_path,
29
#'                      maxlen = 5, train_type = "lm",
30
#'                      output_format = "target_right",
31
#'                      step = 3, batch_size = 7)
32
#' z <- gen()
33
#' x <- z[[1]]
34
#' y <- z[[2]]
35
#' dim(x)
36
#' dim(y)
37
#' 
38
#' @returns A generator function.
39
#' @export
40
get_generator <- function(path = NULL,
41
                          train_type,
42
                          batch_size,
43
                          maxlen,
44
                          step = NULL,
45
                          shuffle_file_order = FALSE,
46
                          vocabulary = c("A", "C", "G", "T"),
47
                          seed = 1,
48
                          proportion_entries = NULL,
49
                          shuffle_input = FALSE,
50
                          format = "fasta",
51
                          path_file_log = NULL,
52
                          reverse_complement = FALSE,
53
                          n_gram = NULL,
54
                          n_gram_stride = NULL,
55
                          output_format = "target_right",
56
                          ambiguous_nuc = "zero",
57
                          proportion_per_seq = NULL,
58
                          skip_amb_nuc = NULL,
59
                          use_quality_score = FALSE,
60
                          padding = FALSE,
61
                          added_label_path = NULL,
62
                          target_from_csv = NULL,
63
                          add_input_as_seq = NULL,
64
                          max_samples = NULL,
65
                          concat_seq = NULL,
66
                          target_len = 1,
67
                          file_filter = NULL,
68
                          use_coverage = NULL,
69
                          sample_by_file_size = FALSE,
70
                          add_noise = NULL,
71
                          random_sampling = FALSE,
72
                          set_learning = NULL,
73
                          file_limit = NULL,
74
                          reverse_complement_encoding = FALSE,
75
                          read_data = FALSE,
76
                          target_split = NULL,
77
                          path_file_logVal = NULL,
78
                          model = NULL,
79
                          vocabulary_label = NULL,
80
                          masked_lm = NULL,
81
                          val = FALSE,
82
                          return_int = FALSE,
83
                          verbose = TRUE,
84
                          delete_used_files = FALSE,
85
                          reshape_xy = NULL) {
86
  
87
  if (random_sampling) {
88
    if (use_quality_score) stop("use_quality_score not implemented for random sampling")
89
    if (read_data) stop("read_data not implemented for random sampling")
90
    if (!is.null(use_coverage)) stop("use_coverage not implemented for random sampling")
91
    if (!is.null(add_noise)) stop("add_noise not implemented for random sampling")
92
  }
93
  
94
  if (train_type %in% c("label_rds", "lm_rds") & format != "rds") {
95
    warning(paste("train_type is", train_type, "but format is not 'rds'"))
96
  }
97
  
98
  # adjust batch size
99
  if ((length(batch_size) == 1) && (batch_size %% length(path) != 0) & train_type == "label_folder") {
100
    batch_size <- ceiling(batch_size/length(path)) * length(path)
101
    if (!val) {
102
      message(paste("Batch size needs to be multiple of number of targets. Setting batch_size to", batch_size))
103
    }
104
  }
105
  
106
  if (is.null(step)) step <- maxlen
107
  
108
  if (train_type == "dummy_gen") {
109
    #gen <- generator_dummy(model, ifelse(is.null(set_learning), batch_size, new_batch_size))
110
    gen <- generator_dummy(model, batch_size)
111
    removeLog <- FALSE
112
  }
113
  
114
  if (!is.null(added_label_path) & is.null(add_input_as_seq)) {
115
    add_input_as_seq <- rep(FALSE, length(added_label_path))
116
  }
117
  
118
  # language model
119
  if (train_type == "lm" & random_sampling) {
120
    
121
    gen <- generator_random(
122
      train_type = "lm",
123
      output_format = output_format,
124
      seed = seed[1],
125
      format = format,
126
      reverse_complement = reverse_complement,
127
      reverse_complement_encoding = reverse_complement_encoding,
128
      path = path,
129
      batch_size = batch_size,
130
      maxlen = maxlen,
131
      ambiguous_nuc = ambiguous_nuc,
132
      padding = padding,
133
      vocabulary = vocabulary,
134
      number_target_nt = target_len,
135
      target_split = target_split,
136
      target_from_csv = target_from_csv,
137
      n_gram = n_gram,
138
      n_gram_stride = n_gram_stride,
139
      sample_by_file_size = sample_by_file_size,
140
      max_samples = max_samples,
141
      skip_amb_nuc = skip_amb_nuc,
142
      vocabulary_label = vocabulary_label,
143
      shuffle_input = shuffle_input,
144
      proportion_entries = proportion_entries,
145
      return_int = return_int,
146
      concat_seq = concat_seq,
147
      reshape_xy = reshape_xy)
148
  } 
149
  
150
  if (train_type == "lm" & !random_sampling) {
151
    
152
    gen <- generator_fasta_lm(path_corpus = path, batch_size = batch_size,
153
                              maxlen = maxlen, step = step, shuffle_file_order = shuffle_file_order,
154
                              vocabulary = vocabulary, seed = seed[1], proportion_entries = proportion_entries,
155
                              shuffle_input = shuffle_input, format = format, n_gram_stride = n_gram_stride,
156
                              path_file_log = path_file_log, reverse_complement = reverse_complement, 
157
                              output_format = output_format, ambiguous_nuc = ambiguous_nuc,
158
                              proportion_per_seq = proportion_per_seq, skip_amb_nuc = skip_amb_nuc,
159
                              use_quality_score = use_quality_score, padding = padding, n_gram = n_gram,
160
                              added_label_path = added_label_path, add_input_as_seq = add_input_as_seq,
161
                              max_samples = max_samples, concat_seq = concat_seq, target_len = target_len,
162
                              file_filter = file_filter, use_coverage = use_coverage, return_int = return_int,
163
                              sample_by_file_size = sample_by_file_size, add_noise = add_noise,
164
                              reshape_xy = reshape_xy)
165
  }
166
  
167
  # label by folder
168
  if (train_type %in% c("label_folder", "masked_lm") & random_sampling) {
169
    
170
    gen <- generator_random(
171
      train_type = train_type,
172
      seed = seed[1],
173
      format = format,
174
      reverse_complement = reverse_complement,
175
      path = path,
176
      batch_size = batch_size,
177
      maxlen = maxlen,
178
      ambiguous_nuc = ambiguous_nuc,
179
      padding = padding,
180
      vocabulary = vocabulary,
181
      number_target_nt = NULL,
182
      n_gram = n_gram,
183
      n_gram_stride = n_gram_stride,
184
      sample_by_file_size = sample_by_file_size,
185
      max_samples = max_samples,
186
      skip_amb_nuc = skip_amb_nuc,
187
      shuffle_input = shuffle_input,
188
      set_learning = set_learning,
189
      reverse_complement_encoding = reverse_complement_encoding,
190
      vocabulary_label = vocabulary_label,
191
      proportion_entries = proportion_entries,
192
      masked_lm = masked_lm,
193
      return_int = return_int,
194
      concat_seq = concat_seq,
195
      reshape_xy = reshape_xy)
196
  } 
197
  
198
  if (train_type == "label_folder" & !random_sampling) {
199
    
200
    gen_list <- generator_initialize(directories = path, format = format, batch_size = batch_size, maxlen = maxlen, vocabulary = vocabulary,
201
                                     verbose = verbose, shuffle_file_order = shuffle_file_order, step = step, seed = seed[1],
202
                                     shuffle_input = shuffle_input, file_limit = file_limit, skip_amb_nuc = skip_amb_nuc,
203
                                     path_file_log = path_file_log, reverse_complement = reverse_complement,
204
                                     reverse_complement_encoding = reverse_complement_encoding, return_int = return_int,
205
                                     ambiguous_nuc = ambiguous_nuc, proportion_per_seq = proportion_per_seq,
206
                                     read_data = read_data, use_quality_score = use_quality_score, val = val,
207
                                     padding = padding, max_samples = max_samples, concat_seq = concat_seq,
208
                                     added_label_path = added_label_path, add_input_as_seq = add_input_as_seq, use_coverage = use_coverage,
209
                                     set_learning = set_learning, proportion_entries = proportion_entries,
210
                                     sample_by_file_size = sample_by_file_size, n_gram = n_gram, n_gram_stride = n_gram_stride,
211
                                     add_noise = add_noise, reshape_xy = reshape_xy)
212
213
    gen <- generator_fasta_label_folder_wrapper(val = val, path = path, 
214
                                                batch_size = batch_size, voc_len = length(vocabulary),
215
                                                gen_list = gen_list,
216
                                                maxlen = maxlen, set_learning = set_learning)
217
    
218
  }
219
  
220
  if (train_type == "masked_lm" & !random_sampling) {
221
    
222
    stopifnot(!is.null(masked_lm))
223
    
224
    gen <- generator_fasta_label_folder(path_corpus = unlist(path),
225
                                        format = format,
226
                                        batch_size = batch_size,
227
                                        maxlen = maxlen,
228
                                        vocabulary = vocabulary,
229
                                        shuffle_file_order = shuffle_file_order,
230
                                        step = step,
231
                                        seed = seed,
232
                                        shuffle_input = shuffle_input,
233
                                        file_limit = file_limit,
234
                                        path_file_log = path_file_log,
235
                                        reverse_complement = reverse_complement,
236
                                        reverse_complement_encoding = reverse_complement_encoding,
237
                                        num_targets = 1,
238
                                        ones_column = 1,
239
                                        ambiguous_nuc = ambiguous_nuc,
240
                                        proportion_per_seq = proportion_per_seq,
241
                                        read_data = read_data,
242
                                        use_quality_score = use_quality_score,
243
                                        padding = padding,
244
                                        added_label_path = added_label_path,
245
                                        add_input_as_seq = add_input_as_seq,
246
                                        skip_amb_nuc = skip_amb_nuc,
247
                                        max_samples = max_samples,
248
                                        concat_seq = concat_seq,
249
                                        file_filter = NULL,
250
                                        return_int = return_int,
251
                                        use_coverage = use_coverage,
252
                                        proportion_entries = proportion_entries,
253
                                        sample_by_file_size = sample_by_file_size,
254
                                        n_gram = n_gram,
255
                                        n_gram_stride = n_gram_stride,
256
                                        masked_lm = masked_lm,
257
                                        add_noise = add_noise,
258
                                        reshape_xy = reshape_xy) 
259
  }
260
  
261
  
262
  if ((train_type == "label_csv" | train_type == "label_header") & !random_sampling) {
263
    
264
    gen <- generator_fasta_label_header_csv(path_corpus = path, format = format, batch_size = batch_size, maxlen = maxlen,
265
                                            vocabulary = vocabulary, verbose = verbose, shuffle_file_order = shuffle_file_order, step = step,
266
                                            seed = seed[1], shuffle_input = shuffle_input, return_int = return_int,
267
                                            path_file_log = path_file_log, vocabulary_label = vocabulary_label, reverse_complement = reverse_complement,
268
                                            ambiguous_nuc = ambiguous_nuc, proportion_per_seq = proportion_per_seq,
269
                                            read_data = read_data, use_quality_score = use_quality_score, padding = padding,
270
                                            added_label_path = added_label_path, add_input_as_seq = add_input_as_seq,
271
                                            skip_amb_nuc = skip_amb_nuc, max_samples = max_samples, concat_seq = concat_seq,
272
                                            target_from_csv = target_from_csv, target_split = target_split, file_filter = file_filter,
273
                                            use_coverage = use_coverage, proportion_entries = proportion_entries,
274
                                            sample_by_file_size = sample_by_file_size, n_gram = n_gram, n_gram_stride = n_gram_stride,
275
                                            add_noise = add_noise, reverse_complement_encoding = reverse_complement_encoding,
276
                                            reshape_xy = reshape_xy)
277
  }
278
  
279
  if ((train_type == "label_csv" | train_type == "label_header") & random_sampling) {
280
    
281
    gen <- generator_random(
282
      train_type = train_type, 
283
      output_format = output_format,
284
      seed = seed[1],
285
      format = format,
286
      reverse_complement = reverse_complement,
287
      reverse_complement_encoding = reverse_complement_encoding,
288
      path = path,
289
      batch_size = batch_size,
290
      maxlen = maxlen,
291
      ambiguous_nuc = ambiguous_nuc,
292
      padding = padding,
293
      vocabulary = vocabulary,
294
      number_target_nt = NULL,
295
      n_gram = n_gram,
296
      n_gram_stride = n_gram_stride,
297
      sample_by_file_size = sample_by_file_size,
298
      max_samples = max_samples,
299
      skip_amb_nuc = skip_amb_nuc,
300
      vocabulary_label = vocabulary_label,
301
      target_from_csv = target_from_csv,
302
      target_split = target_split,
303
      verbose = verbose,
304
      shuffle_input = shuffle_input,
305
      proportion_entries = proportion_entries,
306
      return_int = return_int,
307
      concat_seq = concat_seq,
308
      reshape_xy = reshape_xy)
309
  }
310
  
311
  if (train_type %in% c("label_rds", "lm_rds")) {
312
    reverse_complement <- FALSE
313
    step <- 1
314
    if (train_type == "label_rds") target_len <- NULL
315
    gen <- generator_rds(rds_folder = path, batch_size = batch_size, path_file_log = path_file_log,
316
                         max_samples = max_samples, proportion_per_seq = proportion_per_seq,
317
                         sample_by_file_size = sample_by_file_size, add_noise = add_noise,
318
                         reverse_complement_encoding = reverse_complement_encoding, seed = seed[1],
319
                         target_len = target_len, n_gram = n_gram, n_gram_stride = n_gram_stride,
320
                         delete_used_files = delete_used_files, reshape_xy = reshape_xy)
321
    
322
  }
323
  
324
  return(gen)
325
  
326
}