# File: R/generator_folder_collect.R

#' Initializes generators defined by \code{generator_fasta_label_folder} function
#'
#' Initializes generators defined by \code{\link{generator_fasta_label_folder}} function. Targets get encoded in order of directories.
#' Number of classes is given by length of \code{directories}.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_header_csv
#' @inheritParams generator_fasta_label_folder
#' @inheritParams train_model
#' @param directories Vector of paths to folder containing fasta files. Files in one folder should belong to one class.
#' @param val Logical, call initialized generator "genY" or "genValY" where Y is an integer between 1 and length of directories.
#' @param target_middle Split input sequence into two sequences while removing nucleotide in middle. If input is x_1,..., x_(n+1), input gets split into
#' input_1 = x_1,..., x_m and input_2 = x_(n+1),..., x_(m+2) where m = ceiling((n+1)/2) and n = maxlen. Note that x_(m+1) is not used.
#' @param read_data If true the first element of output is a list of length 2, each containing one part of paired read.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create two folders with dummy fasta files
#' path_input_1 <- tempfile()
#' dir.create(path_input_1)
#' create_dummy_data(file_path = path_input_1, num_files = 2, seq_length = 5,
#'                   num_seq = 2, vocabulary = c("a", "c", "g", "t"))
#' path_input_2 <- tempfile()
#' dir.create(path_input_2)
#' create_dummy_data(file_path = path_input_2, num_files = 3, seq_length = 7,
#'                   num_seq = 5, vocabulary = c("a", "c", "g", "t"))
#' 
#' gen_list <- generator_initialize(directories = c(path_input_1, path_input_1),
#'                                         batch_size = 4, maxlen = 5)
#' z1 <- gen_list[[1]]()
#' z1[[1]]
#' z1[[2]]
#' 
#' @returns Named list of generator functions, one per entry of \code{directories}.
#' Names are "gen1", "gen2", ... or, if \code{val = TRUE}, "genVal1", "genVal2", ...
#' @export
generator_initialize <- function(directories,
                                 format = "fasta",
                                 batch_size = 256,
                                 maxlen = 250,
                                 max_iter = 10000,
                                 vocabulary = c("a", "c", "g", "t"),
                                 verbose = FALSE,
                                 shuffle_file_order = FALSE,
                                 step = 1,
                                 seed = 1234,
                                 shuffle_input = FALSE,
                                 file_limit = NULL,
                                 path_file_log = NULL,
                                 reverse_complement = FALSE,
                                 reverse_complement_encoding = FALSE,
                                 val = FALSE,
                                 ambiguous_nuc = "zero",
                                 proportion_per_seq = NULL,
                                 target_middle = FALSE,
                                 read_data = FALSE,
                                 use_quality_score = FALSE,
                                 padding = TRUE,
                                 added_label_path = NULL,
                                 add_input_as_seq = NULL,
                                 skip_amb_nuc = NULL,
                                 max_samples = NULL,
                                 file_filter = NULL,
                                 concat_seq = NULL,
                                 use_coverage = NULL,
                                 set_learning = NULL,
                                 proportion_entries = NULL,
                                 sample_by_file_size = FALSE,
                                 n_gram = NULL,
                                 n_gram_stride = 1,
                                 add_noise = NULL,
                                 return_int = FALSE,
                                 reshape_xy = NULL) {
  
  # one generator (and one target class) per directory
  num_targets <- length(directories)
  if (num_targets == 0) {
    stop("directories must contain at least one path")
  }
  
  # fail early if user supplied reshape functions lack the expected arguments
  if (!is.null(reshape_xy)) {
    if (!is.null(reshape_xy$x) && !all(c('x', 'y') %in% names(formals(reshape_xy$x)))) {
      stop("function reshape_xy$x needs to have arguments named x and y")
    }
    if (!is.null(reshape_xy$y) && !all(c('x', 'y') %in% names(formals(reshape_xy$y)))) {
      stop("function reshape_xy$y needs to have arguments named x and y")
    }
  }
  
  # adjust batch_size: round a scalar batch size up to the next multiple of the number of targets
  if (is.null(set_learning) && (length(batch_size) == 1) && (batch_size %% num_targets != 0)) {
    batch_size <- ceiling(batch_size / num_targets) * num_targets
    if (!val) {
      message(paste("Batch size needs to be multiple of number of targets. Setting batch_size to", batch_size))
    }
  }
  
  if (!is.null(set_learning)) {
    reshape_mode <- set_learning$reshape_mode
    samples_per_target <- set_learning$samples_per_target
    # in set learning mode maxlen is taken from set_learning, not from the maxlen argument
    maxlen <- set_learning$maxlen
    
    if (reshape_mode == "concat") {
      
      if (sum(batch_size) %% num_targets != 0) {
        stop_text <- paste("batch_size is", paste(batch_size, collapse = ", "),
                           "but needs to be multiple of number of classes (",
                           num_targets, ") for set learning with 'concat'")
        stop(stop_text)
      }
      # a buffer between concatenated sequences is encoded with the extra token "Z"
      if (!is.null(set_learning$buffer_len)) {
        if (any(c("z", "Z") %in% vocabulary)) {
          stop("'Z' is used as token for separating sequences and can not be in vocabulary.")
        }
        vocabulary <- c(vocabulary, "Z")
      }
    }
    
    if (any(batch_size[1] != batch_size)) {
      stop("Set learning only implemented for uniform batch_size for all classes.")
    }
    # each generator has to deliver samples_per_target samples per final sample
    batch_size <- samples_per_target * batch_size
  }
  
  # scalar batch size gets split evenly between the generators
  if (length(batch_size) == 1) {
    batch_size <- rep(batch_size / num_targets, num_targets)
  }
  
  # arguments that may differ per directory: length 1 (recycled), length num_targets or NULL
  # NOTE(review): target_middle is validated here but is not forwarded to
  # generator_fasta_label_folder below (same as before this rewrite) - confirm intended.
  per_class <- list(format = format,
                    batch_size = batch_size,
                    shuffle_file_order = shuffle_file_order,
                    step = step,
                    seed = seed,
                    shuffle_input = shuffle_input,
                    file_limit = file_limit,
                    path_file_log = path_file_log,
                    reverse_complement = reverse_complement,
                    ambiguous_nuc = ambiguous_nuc,
                    proportion_per_seq = proportion_per_seq,
                    target_middle = target_middle,
                    padding = padding,
                    max_samples = max_samples)
  
  for (arg_name in names(per_class)) {
    arg_value <- per_class[[arg_name]]
    if (length(arg_value) == 1) {
      per_class[[arg_name]] <- rep(arg_value, num_targets)
    } else if ((length(arg_value) != num_targets) && !is.null(arg_value)) {
      stop_message <- paste("Incorrect argument length,", arg_name, "argument vector must have length 1 or", num_targets)
      stop(stop_message)
    }
  }
  
  # different names for validation generators
  gen_prefix <- if (val) "genVal" else "gen"
  
  gen_list <- list()
  # create generator for every folder; indexing a NULL entry with [i] yields NULL again
  for (i in seq_len(num_targets)) {
    gen_list[[paste0(gen_prefix, i)]] <- generator_fasta_label_folder(
      path_corpus = directories[[i]],
      format = per_class$format[i],
      batch_size = per_class$batch_size[i],
      maxlen = maxlen,
      max_iter = max_iter,
      vocabulary = vocabulary,
      verbose = verbose,
      shuffle_file_order = per_class$shuffle_file_order[i],
      step = per_class$step[i],
      seed = per_class$seed[i],
      shuffle_input = per_class$shuffle_input[i],
      file_limit = per_class$file_limit[i],
      path_file_log = per_class$path_file_log[i],
      reverse_complement = per_class$reverse_complement[i],
      reverse_complement_encoding = reverse_complement_encoding,
      num_targets = num_targets,
      # position of this directory determines the target column
      ones_column = i,
      ambiguous_nuc = per_class$ambiguous_nuc[i],
      proportion_per_seq = per_class$proportion_per_seq[i],
      read_data = read_data,
      use_quality_score = use_quality_score,
      padding = per_class$padding[i],
      file_filter = file_filter,
      added_label_path = added_label_path,
      add_input_as_seq = add_input_as_seq,
      skip_amb_nuc = skip_amb_nuc,
      max_samples = per_class$max_samples[i],
      concat_seq = concat_seq,
      use_coverage = use_coverage,
      proportion_entries = proportion_entries,
      sample_by_file_size = sample_by_file_size,
      n_gram = n_gram,
      masked_lm = NULL,
      n_gram_stride = n_gram_stride,
      add_noise = add_noise,
      return_int = return_int,
      reshape_xy = reshape_xy
    )
  }
  
  gen_list
}

#' Generator wrapper
#'
#' Combines generators created by \code{\link{generator_initialize}} into a single generator.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_folder
#' @inheritParams train_model
#' @inheritParams reshape_tensor
#' @param val Train or validation generator.
#' @param path Path to input files.
#' @param voc_len Length of vocabulary.
#' @param gen_list List of generator functions.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create two folders with dummy fasta files
#' path_input_1 <- tempfile()
#' dir.create(path_input_1)
#' create_dummy_data(file_path = path_input_1, num_files = 2, seq_length = 5,
#'                   num_seq = 2, vocabulary = c("a", "c", "g", "t"))
#' path_input_2 <- tempfile()
#' dir.create(path_input_2)
#' create_dummy_data(file_path = path_input_2, num_files = 3, seq_length = 7,
#'                   num_seq = 5, vocabulary = c("a", "c", "g", "t"))
#' 
#' maxlen <- 5
#' p <- c(path_input_1, path_input_1)
#' gen_list <- generator_initialize(directories = p,
#'                                  batch_size = 4, maxlen = maxlen)
#' gen <- generator_fasta_label_folder_wrapper(val = FALSE, batch_size = 8,
#'                                             path = p, voc_len = 4, 
#'                                             maxlen = maxlen,
#'                                             gen_list = gen_list)
#' z <- gen()
#' dim(z[[1]])
#' z[[2]]
#' 
#' @returns A generator function.  
#' @export
generator_fasta_label_folder_wrapper <- function(val, 
                                                 batch_size = NULL,
                                                 path = NULL, voc_len = NULL, 
                                                 maxlen = NULL,
                                                 gen_list = NULL, 
                                                 set_learning = NULL) {
  
  # settings handed to reshape_tensor; all NULL when set learning is not used
  samples_per_target <- NULL
  new_batch_size <- NULL
  reshape_mode <- NULL
  buffer_len <- NULL
  
  if (!is.null(set_learning)) {
    reshape_mode <- set_learning$reshape_mode
    samples_per_target <- set_learning$samples_per_target
    buffer_len <- set_learning$buffer_len
    
    if (reshape_mode == "concat") {
      
      if (sum(batch_size) %% length(path) != 0) {
        stop_text <- paste("batch_size is", paste(batch_size, collapse = ", "),
                           "but needs to be multiple of number of classes (",
                           length(path), ") for set learning with 'concat'")
        stop(stop_text)
      }
      # no buffer between concatenated sequences unless requested
      if (is.null(buffer_len)) {
        buffer_len <- 0
      }
    }
    
    if (any(batch_size[1] != batch_size)) {
      stop("Set learning only implemented for uniform batch_size for all classes.")
    }
    new_batch_size <- batch_size
  }
  
  # generators are looked up by name: "gen1", "gen2", ... or "genVal1", "genVal2", ...
  gen_prefix <- if (val) "genVal" else "gen"
  
  # stack two sub batches along the first (sample) dimension
  bind_batches <- function(a, b) {
    abind::abind(a, b, along = 1)
  }
  
  function() {
    # the first generator is always called, even if path has length < 2
    num_generators <- max(1L, length(path))
    sub_batches <- lapply(seq_len(num_generators), function(i) {
      gen_list[[paste0(gen_prefix, i)]]()
    })
    
    # a list as first element means the model has several inputs
    first_x <- sub_batches[[1]][[1]]
    num_inputs <- if (is.list(first_x)) length(first_x) else 1
    
    # combine generators to create one batch; Reduce returns a single
    # sub batch unchanged, so abind is only needed for more than one generator
    y <- Reduce(bind_batches, lapply(sub_batches, function(b) b[[2]]))
    if (num_inputs > 1) {
      x <- lapply(seq_len(num_inputs), function(j) {
        Reduce(bind_batches, lapply(sub_batches, function(b) b[[1]][[j]]))
      })
    } else {
      x <- Reduce(bind_batches, lapply(sub_batches, function(b) b[[1]]))
    }
    
    if (!is.null(samples_per_target)) {
      return(reshape_tensor(x = x, y = y,
                            new_batch_size = new_batch_size,
                            samples_per_target = samples_per_target,
                            buffer_len = buffer_len,
                            reshape_mode = reshape_mode))
    }
    list(X = x, Y = y)
  }
  
}