#' Encodes integer sequence for language model
#'
#' Helper function for \code{\link{generator_fasta_lm}}.
#' Encodes integer sequence to input/target list according to \code{output_format} argument.
#'
#' @inheritParams generator_fasta_lm
#' @param sequence Sequence of integers.
#' @param start_ind Start positions of samples in \code{sequence}.
#' @param ambiguous_nuc How to handle nucleotides outside vocabulary, either `"zero"`, `"empirical"` or `"equal"`.
#' See \code{\link{train_model}}. Note that the `"discard"` option is not available for this function.
#' @param nuc_dist Nucleotide distribution.
#' @param max_cov Biggest coverage value. Only applies if `use_coverage = TRUE`.
#' @param cov_vector Vector of coverage values associated to the input.
#' @param adjust_start_ind Whether to shift values in \code{start_ind} to start at 1: for example, (5,11,25) becomes (1,7,21).
#' @param quality_vector Vector of quality probabilities.
#' @param tokenizer A keras tokenizer.
#' @param char_sequence A character string.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # use integer sequence as input
#'
#' z <- seq_encoding_lm(sequence = c(1,0,5,1,3,4,3,1,4,1,2),
#'                      maxlen = 5,
#'                      vocabulary = c("a", "c", "g", "t"),
#'                      start_ind = c(1,3),
#'                      ambiguous_nuc = "equal",
#'                      target_len = 1,
#'                      output_format = "target_right")
#'
#' x <- z[[1]]
#' y <- z[[2]]
#'
#' x[1,,] # 1,0,5,1,3
#' y[1,] # 4
#'
#' x[2,,] # 5,1,3,4,3
#' y[2,] # 1
#'
#' # use character string as input
#' z <- seq_encoding_lm(sequence = NULL,
#'                      maxlen = 5,
#'                      vocabulary = c("a", "c", "g", "t"),
#'                      start_ind = c(1,3),
#'                      ambiguous_nuc = "zero",
#'                      target_len = 1,
#'                      output_format = "target_right",
#'                      char_sequence = "ACTaaTNTNaZ")
#'
#' x <- z[[1]]
#' y <- z[[2]]
#'
#' x[1,,] # actaa
#' y[1,] # t
#'
#' x[2,,] # taatn
#' y[2,] # t
#'
#' @returns A list of 2 tensors.
#' @export
seq_encoding_lm <- function(sequence = NULL, maxlen, vocabulary, start_ind, ambiguous_nuc = "zero",
                            nuc_dist = NULL, quality_vector = NULL, return_int = FALSE,
                            target_len = 1, use_coverage = FALSE, max_cov = NULL, cov_vector = NULL,
                            n_gram = NULL, n_gram_stride = 1, output_format = "target_right",
                            char_sequence = NULL, adjust_start_ind = FALSE,
                            tokenizer = NULL) {

  use_quality <- !is.null(quality_vector)
  discard_amb_nt <- FALSE
  ## TODO: add discard_amb_nt
  if (!is.null(char_sequence)) {

    vocabulary <- stringr::str_to_lower(vocabulary)
    pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")

    # token for ambiguous nucleotides
    for (i in letters) {
      if (!(i %in% stringr::str_to_lower(vocabulary))) {
        amb_nuc_token <- i
        break
      }
    }

    if (is.null(tokenizer)) {
      tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "0"), c(vocabulary, amb_nuc_token))
    }

    sequence <- stringr::str_to_lower(char_sequence)
    sequence <- stringr::str_replace_all(string = sequence, pattern = pattern, amb_nuc_token)
    sequence <- keras::texts_to_sequences(tokenizer, sequence)[[1]] - 1
  }

  voc_len <- length(vocabulary)
  if (target_len == 1) {
    n_gram <- NULL
  }
  if (!is.null(n_gram)) {
    if (target_len < n_gram) stop("target_len needs to be at least as big as n_gram")
  }

  if (adjust_start_ind) start_ind <- start_ind - start_ind[1] + 1
  numberOfSamples <- length(start_ind)

  # every row in z one-hot encodes one character in sequence, oov is zero-vector
  num_classes <- voc_len + 2
  z <- keras::to_categorical(sequence, num_classes = num_classes)[ , -c(1, num_classes)]

  if (use_quality) {
    ones_pos <- apply(z, 1, which.max)
    is_zero_row <- apply(z == 0, 1, all)
    z <- purrr::map(1:length(quality_vector), ~create_quality_vector(pos = ones_pos[.x], prob = quality_vector[.x],
                                                                     voc_length = length(vocabulary))) %>% unlist() %>%
      matrix(ncol = length(vocabulary), byrow = TRUE)
    z[is_zero_row, ] <- 0
  }

  if (ambiguous_nuc == "equal") {
    amb_nuc_pos <- which(sequence == (voc_len + 1))
    z[amb_nuc_pos, ] <- matrix(rep(1/voc_len, ncol(z) * length(amb_nuc_pos)), ncol = ncol(z))
  }

  if (ambiguous_nuc == "empirical") {
    if (!is.null(n_gram)) stop("Can only use equal, zero or discard option for ambiguous_nuc when using n_gram encoding")
    amb_nuc_pos <- which(sequence == (voc_len + 1))
    z[amb_nuc_pos, ] <- matrix(rep(nuc_dist, length(amb_nuc_pos)), nrow = length(amb_nuc_pos), byrow = TRUE)
  }

  if (use_coverage) {
    z <- z * (cov_vector/max_cov)
  }

  if (target_len == 1) {

    if (output_format == "target_right") {
      x <- array(0, dim = c(numberOfSamples, maxlen, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen - 1), ]
      }
      y <- z[start_ind + maxlen, ]
    }

    if (output_format == "wavenet") {
      if (!is.null(n_gram)) stop("Wavenet format not implemented for n_gram.")
      x <- array(0, dim = c(numberOfSamples, maxlen, voc_len))
      y <- array(0, dim = c(numberOfSamples, maxlen, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen - 1), ]
        y[i, , ] <- z[(start + 1) : (start + maxlen), ]
      }
    }

    if (output_format == "target_middle_cnn") {
      x <- array(0, dim = c(numberOfSamples, maxlen + 1, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen), ]
      }
      missing_val <- ceiling(maxlen/2)
      y <- z[start_ind + missing_val, ]
      x <- x[ , -(missing_val + 1), ]
    }

    if (output_format == "target_middle_lstm") {
      len_input_1 <- ceiling(maxlen/2)
      len_input_2 <- floor(maxlen/2)
      input_tensor_1 <- array(0, dim = c(numberOfSamples, len_input_1, voc_len))
      input_tensor_2 <- array(0, dim = c(numberOfSamples, len_input_2, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        input_tensor_1[i, , ] <- z[start : (start + len_input_1 - 1), ]
        # second input stores the right context in reversed order
        input_tensor_2[i, , ] <- z[(start + maxlen) : (start + len_input_1 + 1), ]
      }
      if (!is.null(n_gram)) {
        input_tensor_1 <- input_tensor_1[ , 1:(dim(input_tensor_1)[2] - n_gram + 1), ]
        input_tensor_2 <- input_tensor_2[ , 1:(dim(input_tensor_2)[2] - n_gram + 1), ]
      }
      x <- list(input_tensor_1, input_tensor_2)
      y <- z[start_ind + len_input_1, ]
    }

  }

  if (target_len > 1) {

    if (output_format == "target_right") {
      x <- array(0, dim = c(numberOfSamples, maxlen - target_len + 1, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen - target_len), ]
      }
      y <- list()
      for (i in 1:target_len) {
        y[[i]] <- z[start_ind + maxlen - target_len + i, ]
      }
    }

    if (output_format == "target_middle_cnn") {
      x <- array(0, dim = c(numberOfSamples, maxlen + 1, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen), ]
      }
      missing_val <- ceiling((maxlen - target_len)/2)
      y <- list()
      for (i in 1:target_len) {
        y[[i]] <- z[start_ind + missing_val + i - 1, ]
      }
      x <- x[ , -((missing_val + 1):(missing_val + target_len)), ]
    }

    if (output_format == "target_middle_lstm") {
      len_input_1 <- ceiling((maxlen - target_len + 1)/2)
      len_input_2 <- maxlen + 1 - len_input_1 - target_len
      input_tensor_1 <- array(0, dim = c(numberOfSamples, len_input_1, voc_len))
      input_tensor_2 <- array(0, dim = c(numberOfSamples, len_input_2, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        input_tensor_1[i, , ] <- z[start : (start + len_input_1 - 1), ]
        input_tensor_2[i, , ] <- z[(start + maxlen) : (start + maxlen - len_input_2 + 1), ]
      }

      x <- list(input_tensor_1, input_tensor_2)
      y <- list()
      for (i in 1:target_len) {
        y[[i]] <- z[start_ind + len_input_1 - 1 + i, ]
      }
    }

    if (output_format == "wavenet") {
      stop("Multi target not implemented for wavenet format.")
    }
  }

  if (is.matrix(x)) {
    x <- array(x, dim = c(1, dim(x)))
  }

  if (!is.null(n_gram)) {
    if (is.list(y)) y <- do.call(rbind, y)
    y_list <- list()
    for (i in 1:numberOfSamples) {
      index <- (i-1) + (1 + (0:(target_len-1))*numberOfSamples)
      input_matrix <- y[index, ]
      if (length(index) == 1) input_matrix <- matrix(input_matrix, nrow = 1)
      n_gram_matrix <- n_gram_of_matrix(input_matrix = input_matrix, n = n_gram)
      y_list[[i]] <- n_gram_matrix
    }
    y_tensor <- keras::k_stack(y_list, axis = 1L) %>% keras::k_eval()
    y <- vector("list", dim(y_tensor)[2])

    for (i in 1:dim(y_tensor)[2]) {
      y_subset <- y_tensor[ , i, ]
      if (numberOfSamples == 1) y_subset <- matrix(y_subset, nrow = 1)
      y[[i]] <- y_subset
    }

    if (is.list(y) && length(y) == 1) {
      y <- y[[1]]
    }

    if (n_gram_stride > 1 && is.list(y)) {
      stride_index <- 0:(length(y)-1) %% n_gram_stride == 0
      y <- y[stride_index]
    }
  }

  return(list(x, y))
}
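
## Usage sketch (illustration only, not executed; assumes a TensorFlow backend is
## available): with output_format = "target_middle_lstm", x is a list of two input
## tensors and the second tensor holds the right context in reversed order.
if (FALSE) {
  z <- seq_encoding_lm(sequence = c(1,0,5,1,3,4,3,1,4,1,2),
                       maxlen = 6,
                       vocabulary = c("a", "c", "g", "t"),
                       start_ind = c(1,3),
                       ambiguous_nuc = "zero",
                       target_len = 1,
                       output_format = "target_middle_lstm")
  length(z[[1]])    # 2 input tensors
  dim(z[[1]][[1]])  # 2 3 4 (ceiling(maxlen/2) time steps)
  dim(z[[1]][[2]])  # 2 3 4 (floor(maxlen/2) time steps, reversed)
}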

#' Encodes integer sequence for label classification.
#'
#' Returns encoding for integer or character sequence.
#'
#' @inheritParams seq_encoding_lm
#' @inheritParams generator_fasta_lm
#' @inheritParams train_model
#' @param return_int Whether to return integer encoding or one-hot encoding.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # use integer sequence as input
#' x <- seq_encoding_label(sequence = c(1,0,5,1,3,4,3,1,4,1,2),
#'                         maxlen = 5,
#'                         vocabulary = c("a", "c", "g", "t"),
#'                         start_ind = c(1,3),
#'                         ambiguous_nuc = "equal")
#'
#' x[1,,] # 1,0,5,1,3
#'
#' x[2,,] # 5,1,3,4,3
#'
#' # use character string as input
#' x <- seq_encoding_label(maxlen = 5,
#'                         vocabulary = c("a", "c", "g", "t"),
#'                         start_ind = c(1,3),
#'                         ambiguous_nuc = "equal",
#'                         char_sequence = "ACTaaTNTNaZ")
#'
#' x[1,,] # actaa
#'
#' x[2,,] # taatn
#'
#' @returns A 3D tensor, or a list of input, target and sample weight tensors if \code{masked_lm} is used.
#' @export
seq_encoding_label <- function(sequence = NULL, maxlen, vocabulary, start_ind, ambiguous_nuc = "zero", nuc_dist = NULL,
                               quality_vector = NULL, use_coverage = FALSE, max_cov = NULL,
                               cov_vector = NULL, n_gram = NULL, n_gram_stride = 1, masked_lm = NULL,
                               char_sequence = NULL, tokenizer = NULL, adjust_start_ind = FALSE,
                               return_int = FALSE) {

  ## TODO: add discard_amb_nt, add conditions for return_int
  use_quality <- !is.null(quality_vector)
  discard_amb_nt <- FALSE
  maxlen_original <- maxlen
  if (return_int) ambiguous_nuc <- "zero"

  if (!is.null(char_sequence)) {

    vocabulary <- stringr::str_to_lower(vocabulary)
    pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")

    # token for ambiguous nucleotides
    for (i in letters) {
      if (!(i %in% stringr::str_to_lower(vocabulary))) {
        amb_nuc_token <- i
        break
      }
    }

    if (is.null(tokenizer)) {
      tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "0"), c(vocabulary, amb_nuc_token))
    }

    sequence <- stringr::str_to_lower(char_sequence)
    sequence <- stringr::str_replace_all(string = sequence, pattern = pattern, amb_nuc_token)
    sequence <- keras::texts_to_sequences(tokenizer, sequence)[[1]] - 1
  }

  if (adjust_start_ind) start_ind <- start_ind - start_ind[1] + 1
  numberOfSamples <- length(start_ind)

  if (is.null(n_gram_stride)) n_gram_stride <- 1
  voc_len <- length(vocabulary)
  if (!is.null(n_gram)) {
    sequence <- int_to_n_gram(int_seq = sequence, n = n_gram, voc_size = length(vocabulary))
    maxlen <- ceiling((maxlen - n_gram + 1)/n_gram_stride)
    voc_len <- length(vocabulary)^n_gram
  }

  if (!is.null(masked_lm)) {
    l <- mask_seq(int_seq = sequence,
                  mask_rate = masked_lm$mask_rate,
                  random_rate = masked_lm$random_rate,
                  identity_rate = masked_lm$identity_rate,
                  start_ind = start_ind,
                  block_len = masked_lm$block_len,
                  voc_len = voc_len)
    masked_seq <- l$masked_seq
    sample_weight_seq <- l$sample_weight_seq
  }

  if (!return_int) {
    if (!is.null(masked_lm)) {
      # every row in z one-hot encodes one character in sequence, oov is zero-vector
      z_masked <- keras::to_categorical(masked_seq, num_classes = voc_len + 2)[ , -c(1)]
      z_masked <- matrix(z_masked, ncol = voc_len + 1)
      z <- keras::to_categorical(sequence, num_classes = voc_len + 2)[ , -c(1)]
      z <- matrix(z, ncol = voc_len + 1)
    } else {
      # every row in z one-hot encodes one character in sequence, oov is zero-vector
      z <- keras::to_categorical(sequence, num_classes = voc_len + 2)[ , -c(1, voc_len + 2)]
      z <- matrix(z, ncol = voc_len)
    }
  }

  if (use_quality) {
    ones_pos <- apply(z, 1, which.max)
    is_zero_row <- apply(z == 0, 1, all)
    z <- purrr::map(1:length(quality_vector), ~create_quality_vector(pos = ones_pos[.x], prob = quality_vector[.x],
                                                                     voc_length = voc_len)) %>% unlist() %>% matrix(ncol = voc_len, byrow = TRUE)
    z[is_zero_row, ] <- 0
  }

  if (ambiguous_nuc == "equal") {
    amb_nuc_pos <- which(sequence == (voc_len + 1))
    z[amb_nuc_pos, ] <- matrix(rep(1/voc_len, ncol(z) * length(amb_nuc_pos)), ncol = ncol(z))
  }

  if (ambiguous_nuc == "empirical") {
    amb_nuc_pos <- which(sequence == (voc_len + 1))
    z[amb_nuc_pos, ] <- matrix(rep(nuc_dist, length(amb_nuc_pos)), nrow = length(amb_nuc_pos), byrow = TRUE)
  }

  if (use_coverage) {
    z <- z * (cov_vector/max_cov)
  }

  remove_end_of_seq <- ifelse(is.null(n_gram), 1, n_gram)

  if (!return_int) {
    if (is.null(masked_lm)) {

      x <- array(0, dim = c(numberOfSamples, maxlen, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        subset_index <- seq(start, (start + maxlen_original - remove_end_of_seq), by = n_gram_stride)
        x[i, , ] <- z[subset_index, ]
      }
      return(x)

    } else {

      x <- array(0, dim = c(numberOfSamples, maxlen, voc_len + 1))
      y <- array(0, dim = c(numberOfSamples, maxlen, voc_len + 1))
      sw <- array(0, dim = c(numberOfSamples, maxlen))

      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        subset_index <- seq(start, (start + maxlen - remove_end_of_seq), by = n_gram_stride)
        x[i, , ] <- z_masked[subset_index, ]
        y[i, , ] <- z[subset_index, ]
        sw[i, ] <- sample_weight_seq[subset_index]
      }
      return(list(x = x, y = y, sample_weight = sw))

    }
  }

  if (return_int) {
    if (is.null(masked_lm)) {

      x <- array(0, dim = c(numberOfSamples, maxlen))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        subset_index <- seq(start, (start + maxlen_original - remove_end_of_seq), by = n_gram_stride)
        x[i, ] <- sequence[subset_index]
      }
      return(x)

    } else {
      x <- array(0, dim = c(numberOfSamples, maxlen))
      y <- array(0, dim = c(numberOfSamples, maxlen))
      sw <- array(0, dim = c(numberOfSamples, maxlen))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        subset_index <- seq(start, (start + maxlen_original - remove_end_of_seq), by = n_gram_stride)
        x[i, ] <- masked_seq[subset_index]
        y[i, ] <- sequence[subset_index]
        sw[i, ] <- sample_weight_seq[subset_index]
      }
      return(list(x = x, y = y, sample_weight = sw))

    }
  }

}
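
## Usage sketch (illustration only, not executed; the rate values are made up):
## with masked_lm, the function returns the masked input, the original sequence
## as target and a sample weight marking the selected positions.
if (FALSE) {
  out <- seq_encoding_label(sequence = c(1,2,3,4,1,2,3,4,1,2),
                            maxlen = 10,
                            vocabulary = c("a", "c", "g", "t"),
                            start_ind = 1,
                            masked_lm = list(mask_rate = 0.2,
                                             random_rate = 0.05,
                                             identity_rate = 0.05,
                                             block_len = NULL))
  names(out) # "x" "y" "sample_weight"
}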

#' Computes start position of samples
#'
#' Helper function for data generators.
#' Computes start positions in sequence where samples can be extracted, given maxlen, step size and ambiguous nucleotide constraints.
#'
#' @inheritParams train_model
#' @param seq_vector Vector of character sequences.
#' @param length_vector Length of sequences in \code{seq_vector}.
#' @param maxlen Length of one predictor sequence.
#' @param step Distance between samples from one entry in \code{seq_vector}.
#' @param train_mode Either `"lm"` for language model or `"label"` for label classification.
#' @param discard_amb_nuc Whether to discard all samples that contain characters outside vocabulary.
#' @examples
#' seq_vector <- c("AAACCCNNNGGGTTT")
#' get_start_ind(
#'   seq_vector = seq_vector,
#'   length_vector = nchar(seq_vector),
#'   maxlen = 4,
#'   step = 2,
#'   train_mode = "label",
#'   discard_amb_nuc = TRUE,
#'   vocabulary = c("A", "C", "G", "T"))
#'
#' @returns A numeric vector.
#' @export
get_start_ind <- function(seq_vector, length_vector, maxlen,
                          step, train_mode = "label",
                          discard_amb_nuc = FALSE,
                          vocabulary = c("A", "C", "G", "T")) {

  stopifnot(train_mode %in% c("lm", "label"))
  if (!discard_amb_nuc) {
    if (length(length_vector) > 1) {
      startNewEntry <- cumsum(c(1, length_vector[-length(length_vector)]))
      if (train_mode == "label") {
        indexVector <- purrr::map(1:(length(length_vector) - 1), ~seq(startNewEntry[.x], startNewEntry[.x + 1] - maxlen, by = step))
      } else {
        indexVector <- purrr::map(1:(length(length_vector) - 1), ~seq(startNewEntry[.x], startNewEntry[.x + 1] - maxlen - 1, by = step))
      }
      indexVector <- unlist(indexVector)
      last_seq <- length(seq_vector)
      if (!(startNewEntry[last_seq] > (sum(length_vector) - maxlen + 1))) {
        if (train_mode == "label") {
          indexVector <- c(indexVector, seq(startNewEntry[last_seq], sum(length_vector) - maxlen + 1, by = step))
        } else {
          indexVector <- c(indexVector, seq(startNewEntry[last_seq], sum(length_vector) - maxlen, by = step))
        }
      }
      return(indexVector)
    } else {
      if (train_mode == "label") {
        indexVector <- seq(1, length_vector - maxlen + 1, by = step)
      } else {
        indexVector <- seq(1, length_vector - maxlen, by = step)
      }
    }
  } else {
    indexVector <- start_ind_ignore_amb(seq_vector = seq_vector, length_vector = length_vector,
                                        maxlen = maxlen, step = step, vocabulary = c(vocabulary, "0"), train_mode = train_mode)
  }
  return(indexVector)
}
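
## Sketch (illustration only, not executed): with several sequences, start
## positions are reported in the coordinates of the concatenated sequence,
## so the second entry starts at position 5 here.
if (FALSE) {
  get_start_ind(seq_vector = c("ACGT", "ACGTACGT"),
                length_vector = c(4, 8),
                maxlen = 4,
                step = 2,
                train_mode = "label") # 1 5 7 9
}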


#' Helper function for get_start_ind, extracts the start positions of all potential samples (considering step size and vocabulary)
#'
#' @param seq Sequences.
#' @param maxlen Length of one sample.
#' @param step How often to take a sample.
#' @param vocabulary Vector of allowed characters in samples.
#' @param train_mode "lm" or "label".
#' @noRd
start_ind_ignore_amb_single_seq <- function(seq, maxlen, step, vocabulary, train_mode = "lm") {

  vocabulary <- stringr::str_to_lower(vocabulary)
  vocabulary <- c(vocabulary, "0")
  seq <- stringr::str_to_lower(seq)
  len_seq <- nchar(seq)
  if (train_mode != "label") maxlen <- maxlen + 1
  stopifnot(len_seq >= maxlen)
  # regular expression for allowed characters
  voc_pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")

  pos_of_amb_nucleotides <- stringr::str_locate_all(seq, pattern = voc_pattern)[[1]][ , 1]
  non_start_index <- pos_of_amb_nucleotides - maxlen + 1

  # define range of unallowed start indices
  non_start_index <- purrr::map(non_start_index, ~(.x:(.x + maxlen - 1))) %>%
    unlist() %>% union((len_seq - maxlen + 2):len_seq) %>% unique()
  # drop non-positive values
  if (length(non_start_index[non_start_index < 1])) {
    non_start_index <- unique(c(1, non_start_index[non_start_index >= 1]))
  }

  non_start_index <- non_start_index %>% sort()
  allowed_start <- setdiff(1:len_seq, non_start_index)
  len_start_vector <- length(allowed_start)

  if (len_start_vector < 1) {
    # cannot extract a single sample with the current settings
    return(NULL)
  }

  # only keep indices with sufficient distance, as defined by step
  start_indices <- vector("integer")
  index <- allowed_start[1]
  start_indices[1] <- index
  count <- 1
  if (length(allowed_start) > 1) {
    for (j in 1:(length(allowed_start) - 1)) {
      if (allowed_start[j + 1] - index >= step) {
        count <- count + 1
        start_indices[count] <- allowed_start[j + 1]
        index <- allowed_start[j + 1]
      }
    }
  }

  start_indices
}


#' Helper function for get_start_ind, extracts the start positions of all potential samples (considering step size and vocabulary)
#'
#' @param seq_vector Vector of character sequences.
#' @param maxlen Length of one sample.
#' @param step How often to take a sample.
#' @param vocabulary Vector of allowed characters in samples.
#' @param train_mode "lm" or "label".
#' @noRd
start_ind_ignore_amb <- function(seq_vector, length_vector, maxlen, step, vocabulary, train_mode = "lm") {
  start_ind <- purrr::map(1:length(seq_vector), ~start_ind_ignore_amb_single_seq(seq = seq_vector[.x],
                                                                                 maxlen = maxlen,
                                                                                 step = step,
                                                                                 vocabulary = vocabulary,
                                                                                 train_mode = train_mode))

  # shift indices of later sequences by the cumulative length of previous ones
  cum_sum_length <- cumsum(length_vector)
  if (length(start_ind) > 1) {
    for (i in 2:length(start_ind)) {
      start_ind[[i]] <- start_ind[[i]] + cum_sum_length[i - 1]
    }
  }
  start_ind <- unlist(start_ind)
  start_ind
}

quality_to_probability <- function(quality_vector) {
  # Phred+33 encoding: P(correct) = 1 - 10^(-Q/10)
  Q <- utf8ToInt(quality_vector) - 33
  1 - 10^(-Q/10)
}

create_quality_vector <- function(pos, prob, voc_length = 4) {
  vec <- rep(0, voc_length)
  vec[pos] <- prob
  vec[-pos] <- (1 - prob)/(voc_length - 1)
  vec
}
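
## Sketch (illustration only, not executed): the quality symbol "I" corresponds
## to Q = 40, i.e. an error probability of 1e-4; the remaining probability mass
## is spread evenly over the other characters of the vocabulary.
if (FALSE) {
  quality_to_probability("I") # 0.9999
  create_quality_vector(pos = 2, prob = 0.9999, voc_length = 4)
  # 3.33e-05 0.9999 3.33e-05 3.33e-05
}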

remove_amb_nuc_entries <- function(fasta.file, skip_amb_nuc, pattern) {
  chars_per_row <- nchar(fasta.file$Sequence)
  amb_per_row <- stringr::str_count(stringr::str_to_lower(fasta.file$Sequence), pattern)
  threshold_index <- (amb_per_row/chars_per_row) > skip_amb_nuc
  fasta.file <- fasta.file[!threshold_index, ]
  fasta.file
}

#' Estimate frequency of different classes
#'
#' Count the number of nucleotides for each class and use the counts to estimate the class distribution.
#' Outputs a list of class relations. Can be used as input for \code{class_weight} in the \code{\link{train_model}} function.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_header_csv
#' @inheritParams train_model
#' @param file_proportion Proportion of files to randomly sample for estimating class distributions.
#' @param csv_path If `train_type = "label_csv"`, path to csv file containing labels.
#' @param named_list Whether to name the entries of the class weight list `"0", "1", ...` or not.
#' @examples
#'
#' # create dummy data
#' path_1 <- tempfile()
#' path_2 <- tempfile()
#'
#' for (current_path in c(path_1, path_2)) {
#'
#'   dir.create(current_path)
#'   # create twice as much data for first class
#'   num_files <- ifelse(current_path == path_1, 6, 3)
#'   create_dummy_data(file_path = current_path,
#'                     num_files = num_files,
#'                     seq_length = 10,
#'                     num_seq = 5,
#'                     vocabulary = c("a", "c", "g", "t"))
#' }
#'
#' class_weight <- get_class_weight(
#'   path = c(path_1, path_2),
#'   vocabulary_label = c("A", "B"),
#'   format = "fasta",
#'   file_proportion = 1,
#'   train_type = "label_folder",
#'   csv_path = NULL)
#'
#' class_weight
#'
#' @returns A list of numeric values (class weights).
#' @export
get_class_weight <- function(path,
                             vocabulary_label = NULL,
                             format = "fasta",
                             file_proportion = 1,
                             train_type = "label_folder",
                             named_list = FALSE,
                             csv_path = NULL) {

  classes <- count_nuc(path = path,
                       vocabulary_label = vocabulary_label,
                       format = format,
                       file_proportion = file_proportion,
                       train_type = train_type,
                       csv_path = csv_path)

  zero_entry <- classes == 0
  if (sum(zero_entry) > 0) {
    warning_message <- paste("The following classes have no samples:", paste(vocabulary_label[zero_entry], collapse = " "),
                             "\n Try bigger file_proportion size or check vocabulary_label.")
    warning(warning_message)
  }

  if (!is.list(classes)) {
    num_classes <- length(classes)
    total <- sum(classes)
    weight_list <- list()
    for (i in 1:(length(classes))) {
      weight_list[[as.character(i-1)]] <- total/(classes[i] * num_classes)
    }
    if (!named_list) names(weight_list) <- NULL # no list names in tf version > 2.8
    classes <- weight_list
  } else {
    weight_collection <- list()
    for (j in 1:length(classes)) {
      num_classes <- length(classes[[j]])
      total <- sum(classes[[j]])
      weight_list <- list()
      for (i in 1:(length(classes[[j]]))) {
        weight_list[[as.character(i-1)]] <- total/(classes[[j]][i] * num_classes)
      }
      if (!named_list) names(weight_list) <- NULL
      weight_collection[[j]] <- weight_list
    }
    classes <- weight_collection
  }

  classes
}
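
## Worked example (illustration only, not executed): with nucleotide counts
## a = 300 and b = 100, total = 400 and num_classes = 2, so the weights are
## 400/(300 * 2) = 0.67 and 400/(100 * 2) = 2, i.e. the rarer class is upweighted.
if (FALSE) {
  classes <- c(a = 300, b = 100)
  total <- sum(classes)
  total / (classes * length(classes)) # 0.67 2.00
}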

#' Count nucleotides per class
#'
#' @inheritParams get_class_weight
#' @noRd
count_nuc <- function(path,
                      vocabulary_label = NULL,
                      format = "fasta",
                      # estimate class distribution from subset
                      file_proportion = 1,
                      train_type = "label_folder",
                      csv_path = NULL) {

  classes <- rep(0, length(vocabulary_label))
  names(classes) <- vocabulary_label

  # label by folder
  if (train_type == "label_folder") {
    for (j in 1:length(path)) {
      files <- list.files(path[[j]], full.names = TRUE)
      if (file_proportion < 1) {
        files <- sample(files, floor(file_proportion * length(files)))
      }
      for (i in files) {
        if (format == "fasta") {
          fasta.file <- microseq::readFasta(i)
        }
        if (format == "fastq") {
          fasta.file <- microseq::readFastq(i)
        }
        freq <- sum(nchar(fasta.file$Sequence))
        classes[j] <- classes[j] + freq
      }
    }
  }

  # label by header
  if (train_type == "label_header") {
    files <- list.files(unlist(path), full.names = TRUE)
    if (file_proportion < 1) {
      files <- sample(files, floor(file_proportion * length(files)))
    }
    for (i in files) {
      if (format == "fasta") {
        fasta.file <- microseq::readFasta(i)
      }
      if (format == "fastq") {
        fasta.file <- microseq::readFastq(i)
      }
      df <- data.frame(Header = fasta.file$Header, freq = nchar(fasta.file$Sequence))
      df <- stats::aggregate(df$freq, by = list(Category = df$Header), FUN = sum)
      freq <- df$x
      names(freq) <- df$Category
      for (k in names(freq)) {
        classes[k] <- classes[k] + freq[k]
      }
    }
  }

  # label by csv
  if (train_type == "label_csv") {

    label_csv <- utils::read.csv2(csv_path, header = TRUE, stringsAsFactors = FALSE)
    if (dim(label_csv)[2] == 1) {
      label_csv <- utils::read.csv(csv_path, header = TRUE, stringsAsFactors = FALSE)
    }
    if (!("file" %in% names(label_csv))) {
      stop('csv file needs one column named "file"')
    }

    row_sums <- label_csv %>% dplyr::select(-file) %>% rowSums()
    if (!(all(row_sums == 1))) {
      stop("Can only estimate class weights if labels are mutually exclusive.")
    }

    if (is.null(vocabulary_label) || missing(vocabulary_label)) {
      vocabulary_label <- names(label_csv)[!names(label_csv) == "file"]
    } else {
      label_csv <- label_csv %>% dplyr::select(c(dplyr::all_of(vocabulary_label), "file"))
    }

    classes <- rep(0, length(vocabulary_label))
    names(classes) <- vocabulary_label

    path <- unlist(path)
    single_file_index <- stringr::str_detect(path, "fasta$|fastq$")
    files <- c(list.files(path[!single_file_index], full.names = TRUE), path[single_file_index])
    if (file_proportion < 1) {
      files <- sample(files, floor(file_proportion * length(files)))
    }
    for (i in files) {
      if (format == "fasta") {
        fasta.file <- microseq::readFasta(i)
      }
      if (format == "fastq") {
        fasta.file <- microseq::readFastq(i)
      }
      nuc_count <- sum(nchar(fasta.file$Sequence))
      df <- label_csv %>% dplyr::filter(file == basename(i))
      if (nrow(df) == 0) next
      index <- df[1, ] == 1
      current_label <- names(df)[index]
      classes[current_label] <- classes[current_label] + nuc_count
    }
  }
  return(classes)
}

read_fasta_fastq <- function(format, skip_amb_nuc, file_index, pattern, shuffle_input,
                             reverse_complement, fasta.files, use_coverage = FALSE, proportion_entries = NULL,
                             vocabulary_label = NULL, filter_header = FALSE, target_from_csv = NULL) {

  if (stringr::str_detect(format, "fasta")) {
    if (is.null(skip_amb_nuc)) {
      fasta.file <- microseq::readFasta(fasta.files[file_index])
    } else {
      fasta.file <- remove_amb_nuc_entries(microseq::readFasta(fasta.files[file_index]), skip_amb_nuc = skip_amb_nuc,
                                           pattern = pattern)
    }

    if (filter_header && is.null(target_from_csv)) {
      label_vector <- trimws(stringr::str_to_lower(fasta.file$Header))
      label_filter <- label_vector %in% vocabulary_label
      fasta.file <- fasta.file[label_filter, ]
    }

    if (!is.null(proportion_entries) && proportion_entries < 1) {
      index <- sample(nrow(fasta.file), max(1, floor(nrow(fasta.file) * proportion_entries)))
      fasta.file <- fasta.file[index, ]
    }

    if (shuffle_input) {
      fasta.file <- fasta.file[sample(nrow(fasta.file)), ]
    }

    if (reverse_complement) {
      index <- sample(c(TRUE, FALSE), nrow(fasta.file), replace = TRUE)
      fasta.file$Sequence[index] <- microseq::reverseComplement(fasta.file$Sequence[index])
    }

  }

  if (stringr::str_detect(format, "fastq")) {
    if (is.null(skip_amb_nuc)) {
      fasta.file <- microseq::readFastq(fasta.files[file_index])
    } else {
      fasta.file <- remove_amb_nuc_entries(microseq::readFastq(fasta.files[file_index]), skip_amb_nuc = skip_amb_nuc,
                                           pattern = pattern)
    }

    if (filter_header && is.null(target_from_csv)) {
      label_vector <- trimws(stringr::str_to_lower(fasta.file$Header))
      label_filter <- label_vector %in% vocabulary_label
      fasta.file <- fasta.file[label_filter, ]
    }

    if (!is.null(proportion_entries) && proportion_entries < 1) {
      index <- sample(nrow(fasta.file), max(1, floor(nrow(fasta.file) * proportion_entries)))
      fasta.file <- fasta.file[index, ]
    }

    if (shuffle_input) {
      fasta.file <- fasta.file[sample(nrow(fasta.file)), ]
    }

    if (reverse_complement && sample(c(TRUE, FALSE), 1)) {
      fasta.file$Sequence <- microseq::reverseComplement(fasta.file$Sequence)
    }
  }
  return(fasta.file)
}

input_from_csv <- function(added_label_path) {
  .datatable.aware = TRUE
  label_csv <- utils::read.csv2(added_label_path, header = TRUE, stringsAsFactors = FALSE)
  if (dim(label_csv)[2] == 1) {
    label_csv <- utils::read.csv(added_label_path, header = TRUE, stringsAsFactors = FALSE)
  }
  if (!("file" %in% names(label_csv))) {
    stop('names in added_label_path should contain one column named "file"')
  }
  label_csv <- data.table::as.data.table(label_csv)
  label_csv$file <- stringr::str_to_lower(as.character(label_csv$file))
  data.table::setkey(label_csv, file)
  added_label_by_header <- FALSE

  col_name <- ifelse(added_label_by_header, "header", "file")
  return(list(label_csv = label_csv, col_name = col_name))
}

#' @rawNamespace import(data.table, except = c(first, last, between))
#' @noRd
csv_to_tensor <- function(label_csv, added_label_vector, added_label_by_header, batch_size,
                          start_index_list) {
  .datatable.aware = TRUE
  label_tensor <- matrix(0, ncol = ncol(label_csv) - 1, nrow = batch_size, byrow = TRUE)

  if (added_label_by_header) {
    header_unique <- unique(added_label_vector)
    for (i in header_unique) {
      label_from_csv <- label_csv[ .(i), -"header"]
      index_label_vector <- added_label_vector == i
      if (nrow(label_from_csv) > 0) {
        label_tensor[index_label_vector, ] <- matrix(as.matrix(label_from_csv[1, ]),
                                                     nrow = sum(index_label_vector), ncol = ncol(label_tensor), byrow = TRUE)
      }
    }
  } else {
    row_index <- 1
    for (i in 1:length(added_label_vector)) {
      row_filter <- added_label_vector[i]
      label_from_csv <- label_csv[data.table(row_filter), -"file"]
      samples_per_file <- length(start_index_list[[i]])
      assign_rows <- row_index:(row_index + samples_per_file - 1)

      if (nrow(stats::na.omit(label_from_csv)) > 0) {
        label_tensor[assign_rows, ] <- matrix(as.matrix(label_from_csv[1, ]),
                                              nrow = samples_per_file, ncol = ncol(label_tensor), byrow = TRUE)
      }
      row_index <- row_index + samples_per_file
    }
  }
  return(label_tensor)
}

#' Divide tensor to list of subsets
#'
#' @noRd
slice_tensor <- function(tensor, target_split) {

  num_row <- nrow(tensor)
  l <- vector("list", length = length(target_split))
  for (i in 1:length(target_split)) {
    if (length(target_split[[i]]) == 1 || num_row == 1) {
      l[[i]] <- matrix(tensor[ , target_split[[i]]], ncol = length(target_split[[i]]))
    } else {
      l[[i]] <- tensor[ , target_split[[i]]]
    }
  }
  return(l)
}

check_header_names <- function(target_split, vocabulary_label) {
  target_split <- unlist(target_split)
  if (!all(target_split %in% vocabulary_label)) {
    stop_text <- paste("Your csv file has no columns named",
                       paste(target_split[!(target_split %in% vocabulary_label)], collapse = " "))
    stop(stop_text)
  }
  if (!all(vocabulary_label %in% target_split)) {
    warning_text <- paste("target_split does not cover the following columns:",
                          paste(vocabulary_label[!(vocabulary_label %in% target_split)], collapse = " "))
    warning(warning_text)
  }
}

count_files <- function(path, format = "fasta", train_type,
                        target_from_csv = NULL, train_val_split_csv = NULL) {

  num_files <- rep(0, length(path))
  if (!is.null(target_from_csv) && train_type == "label_csv") {
    target_files <- utils::read.csv(target_from_csv)
    if (ncol(target_files) == 1) target_files <- utils::read.csv2(target_from_csv)
    target_files <- target_files$file
    # are files given with absolute path
    full.names <- ifelse(dirname(target_files[1]) == ".", FALSE, TRUE)
  }
  if (!is.null(train_val_split_csv)) {
    tvt_files <- utils::read.csv(train_val_split_csv)
    if (ncol(tvt_files) == 1) tvt_files <- utils::read.csv2(train_val_split_csv)
    train_index <- tvt_files$type == "train"
    tvt_files <- tvt_files$file
    if (is.null(target_from_csv)) {
      target_files <- tvt_files[train_index]
    } else {
      target_files <- intersect(tvt_files[train_index], target_files)
    }
  }

  for (i in 1:length(path)) {
    for (k in 1:length(path[[i]])) {
      current_path <- path[[i]][[k]]

      if (!is.null(train_val_split_csv)) {
        if (!(current_path %in% target_files)) next
      }

      if (endsWith(current_path, paste0(".", format))) {
        # remove files not in csv file
        if (!is.null(target_from_csv)) {
          current_files <- length(intersect(basename(target_files), basename(current_path)))
        } else {
          current_files <- 1
        }
      } else {
        # remove files not in csv file
        if (!is.null(target_from_csv)) {
          current_files <- list.files(current_path, pattern = paste0("\\.", format, "$"), full.names = full.names) %>%
            intersect(target_files) %>% length()
        } else {
          current_files <- list.files(current_path, pattern = paste0("\\.", format, "$")) %>% length()
        }
      }
      num_files[i] <- num_files[i] + current_files

      if (current_files == 0) {
        stop(paste0(path[[i]][[k]], " is empty or contains no files with .", format, " ending"))
      }
    }
  }

  # return number of files per class for "label_folder"
  if (train_type == "label_folder") {
    return(num_files)
  } else {
    return(sum(num_files))
  }
}

list_fasta_files <- function(path_corpus, format, file_filter) {

  fasta.files <- list()
  path_corpus <- unlist(path_corpus)

  for (i in 1:length(path_corpus)) {

    if (endsWith(path_corpus[[i]], paste0(".", format))) {
      fasta.files[[i]] <- path_corpus[[i]]

    } else {

      fasta.files[[i]] <- list.files(
        path = path_corpus[[i]],
        pattern = paste0("\\.", format, "$"),
        full.names = TRUE)
    }
  }
  fasta.files <- unlist(fasta.files)
  num_files <- length(fasta.files)

  if (!is.null(file_filter)) {

    # file filter files given with/without absolute path
    if (all(basename(file_filter) == file_filter)) {
      fasta.files <- fasta.files[basename(fasta.files) %in% file_filter]
    } else {
      fasta.files <- fasta.files[fasta.files %in% file_filter]
    }

    if (length(fasta.files) < 1) {
      stop_text <- paste0("None of the files from ", unlist(path_corpus),
                          " are present in train_val_split_csv table for either train or validation. \n")
      stop(stop_text)
    }
  }

  fasta.files <- gsub(pattern = "/+", replacement = "/", x = fasta.files)
  fasta.files <- gsub(pattern = "/$", replacement = "", x = fasta.files)
  return(fasta.files)
}

get_coverage <- function(fasta.file) {
  header <- fasta.file$Header
  cov <- stringr::str_extract(header, "cov_\\d+") %>%
    stringr::str_extract("\\d+") %>% as.integer()
  cov[is.na(cov)] <- 1
  return(cov)
}
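
## Sketch (illustration only, not executed): coverage is parsed from headers of
## the form "..._cov_<number>"; headers without that pattern default to 1.
if (FALSE) {
  get_coverage(data.frame(Header = c("NODE_1_length_600_cov_8", "seq_2")))
  # 8 1
}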

get_coverage_concat <- function(fasta.file, concat_seq) {
  header <- fasta.file$Header
  cov <- stringr::str_extract(header, "cov_\\d+") %>%
    stringr::str_extract("\\d+") %>% as.integer()
  cov[is.na(cov)] <- 1
  len_vec <- nchar(fasta.file$Sequence)
  cov <- purrr::map(1:nrow(fasta.file), ~rep(cov[.x], times = len_vec[.x]))
  cov <- lapply(cov, append, rep(1, nchar(concat_seq)))
  cov <- unlist(cov)
  # remove the buffer values appended after the last entry
  cov <- cov[-((length(cov) - nchar(concat_seq) + 1) : length(cov))]
  return(cov)
}

#' Reshape tensors for set learning
#'
#' Reshape input x and target y. Aggregates multiple samples from x and y into single input/target batches.
#'
#' @param x 3D input tensor.
#' @param y 2D target tensor.
#' @param samples_per_target How many samples to use for one target.
#' @param reshape_mode `"time_dist"`, `"multi_input"` or `"concat"`:
#' \itemize{
#' \item If `"multi_input"`, will produce `samples_per_target` separate inputs, each of length `maxlen`.
#' \item If `"time_dist"`, will produce a 4D input array. The dimensions correspond to
#' `(new_batch_size, samples_per_target, maxlen, length(vocabulary))`.
#' \item If `"concat"`, will concatenate `samples_per_target` sequences of length `maxlen` to one long sequence.
#' }
#' @param buffer_len Only applies if `reshape_mode = "concat"`. If `buffer_len` is an integer, the subsequences are interspaced with `buffer_len` rows. The reshaped x has
#' new maxlen: (`maxlen` \eqn{*} `samples_per_target`) + `buffer_len` \eqn{*} (`samples_per_target` - 1).
#' @param new_batch_size Size of first axis of input/targets after reshaping.
#' @param check_y Check if entries in `y` are consistent with reshape strategy (same label when aggregating).
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create dummy data
#' batch_size <- 8
#' maxlen <- 11
#' voc_len <- 4
#' x <- sample(0:(voc_len-1), maxlen*batch_size, replace = TRUE)
#' x <- keras::to_categorical(x, num_classes = voc_len)
#' x <- array(x, dim = c(batch_size, maxlen, voc_len))
#' y <- rep(0:1, each = batch_size/2)
#' y <- keras::to_categorical(y, num_classes = 2)
#' y
#'
#' # reshape data for multi input model
#' reshaped_data <- reshape_tensor(
#'   x = x,
#'   y = y,
#'   new_batch_size = 2,
#'   samples_per_target = 4,
#'   reshape_mode = "multi_input")
#'
#' length(reshaped_data[[1]])
#' dim(reshaped_data[[1]][[1]])
#' reshaped_data[[2]]
#'
#' @returns A list of 2 tensors.
#' @export
reshape_tensor <- function(x, y, new_batch_size,
                           samples_per_target,
                           buffer_len = NULL,
                           reshape_mode = "time_dist",
                           check_y = FALSE) {

  batch_size <- dim(x)[1]
  maxlen <- dim(x)[2]
  voc_len <- dim(x)[3]
  num_classes <- dim(y)[2]

  if (check_y) {
    targets <- apply(y, 1, which.max)
    test_y_dist <- all(targets == rep(1:num_classes, each = batch_size/num_classes))
    if (!test_y_dist) {
      stop("y must have same number of samples for each class")
    }
  }

  if (reshape_mode == "time_dist") {

    x_new <- array(0, dim = c(new_batch_size, samples_per_target, maxlen, voc_len))
    y_new <- array(0, dim = c(new_batch_size, num_classes))
    for (i in 1:new_batch_size) {
      index <- (1:samples_per_target) + (i-1)*samples_per_target
      x_new[i, , , ] <- x[index, , ]
      y_new[i, ] <- y[index[1], ]
    }

    return(list(x = x_new, y = y_new))
  }

  if (reshape_mode == "multi_input") {

    x_list <- vector("list", samples_per_target)
    for (i in 1:samples_per_target) {
      x_index <- base::seq(i, batch_size, samples_per_target)
      x_list[[i]] <- x[x_index, , ]
    }
    y <- y[base::seq(1, batch_size, samples_per_target), ]
    return(list(x = x_list, y = y))
  }

  if (reshape_mode == "concat") {

    use_buffer <- !is.null(buffer_len) && buffer_len > 0
    if (use_buffer) {
      buffer_tensor <- array(0, dim = c(buffer_len, voc_len))
      buffer_tensor[ , voc_len] <- 1
      concat_maxlen <- (maxlen * samples_per_target) + (buffer_len * (samples_per_target - 1))
    } else {
      concat_maxlen <- maxlen * samples_per_target
    }

    x_new <- array(0, dim = c(new_batch_size, concat_maxlen, voc_len))
    y_new <- array(0, dim = c(new_batch_size, num_classes))

    for (i in 1:new_batch_size) {
      index <- (1:samples_per_target) + (i-1)*samples_per_target
      if (!use_buffer) {
        x_temp <- x[index, , ]
        x_temp <- reticulate::array_reshape(x_temp, dim = c(1, dim(x_temp)[1] * dim(x_temp)[2], voc_len))
      } else {
        # create list of subsequences interspaced with buffer tensor
        x_list <- vector("list", (2*samples_per_target) - 1)
        x_list[seq(2, length(x_list), by = 2)] <- list(buffer_tensor)
        for (k in 1:length(index)) {
          x_list[[(2*k) - 1]] <- x[index[k], , ]
        }
        x_temp <- do.call(rbind, x_list)
      }

      x_new[i, , ] <- x_temp
      y_new[i, ] <- y[index[1], ]
    }
    return(list(x = x_new, y = y_new))
  }
}
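
## Worked example (illustration only, not executed; dummy data made up here):
## with maxlen = 11, samples_per_target = 4 and buffer_len = 2, the
## concatenated length is 11 * 4 + 2 * (4 - 1) = 50.
if (FALSE) {
  x <- array(stats::runif(8 * 11 * 4), dim = c(8, 11, 4))
  y <- keras::to_categorical(rep(0:1, each = 4), num_classes = 2)
  out <- reshape_tensor(x, y, new_batch_size = 2, samples_per_target = 4,
                        buffer_len = 2, reshape_mode = "concat")
  dim(out$x) # 2 50 4
}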

#' Transform confusion matrix with absolute counts to matrix with column-wise proportions.
#'
#' @noRd
cm_perc <- function(cm, round_dig = 2) {
  col_sums <- colSums(cm)
  for (i in 1:ncol(cm)) {
    if (col_sums[i] == 0) {
      cm[ , i] <- 0
    } else {
      cm[ , i] <- cm[ , i]/col_sums[i]
    }
  }
  cm <- round(cm, round_dig)
  cm
}

create_conf_mat_obj <- function(m, confMatLabels) {
  dimnames(m) <- list(Prediction = confMatLabels, Truth = confMatLabels)
  l <- list()
  m <- as.table(m)
  l[["table"]] <- m
  l[["dots"]] <- list()
  class(l) <- "conf_mat"
  return(l)
}

#' Encode sequence of integers to sequence of n-grams
#'
#' Input is a sequence of integers from a vocabulary of size \code{voc_size}.
#' Returns a vector of integers corresponding to the n-gram encoding.
#' Integers greater than `voc_size` get encoded as `voc_size^n + 1`.
#'
#' @param int_seq Integer sequence.
#' @param n Length of n-gram aggregation.
#' @param voc_size Size of vocabulary.
#' @examples
#' int_to_n_gram(int_seq = c(1,1,2,4,4), n = 2, voc_size = 4)
#'
#' @returns A numeric vector.
#' @export
int_to_n_gram <- function(int_seq, n, voc_size = 4) {

  encoding_len <- length(int_seq) - n + 1
  n_gram_encoding <- vector("numeric", encoding_len)
  oov_token <- voc_size^n + 1
  padding_token <- 0

  for (i in 1:encoding_len) {
    int_seq_subset <- int_seq[i:(i + n - 1)]

    if (prod(int_seq_subset) == 0) {
      n_gram_encoding[i] <- padding_token
    } else {
      # encoding for ambiguous nucleotides
      if (any(int_seq_subset > voc_size)) {
        n_gram_encoding[i] <- oov_token
      } else {
        # interpret the n-gram as a base-voc_size number (1-based)
        int_seq_subset <- int_seq_subset - 1
        n_gram_encoding[i] <- 1 + sum(voc_size^((n-1):0) * (int_seq_subset))
      }
    }
  }
  n_gram_encoding
}
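
## Worked example (illustration only, not executed): for n = 2 and voc_size = 4,
## the pair (i, j) maps to 1 + 4*(i-1) + (j-1); e.g. (1,1) -> 1, (1,2) -> 2 and
## (2,4) -> 8, so c(1,1,2,4) encodes to c(1, 2, 8).
if (FALSE) {
  int_to_n_gram(int_seq = c(1,1,2,4), n = 2, voc_size = 4) # 1 2 8
}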

#' One-hot encoding matrix to n-gram encoding matrix
#'
#' @param input_matrix Matrix with one 1 per row and zeros otherwise.
#' @param n Length of one n-gram.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' x <- c(0,0,1,3,3)
#' input_matrix <- keras::to_categorical(x, 4)
#' n_gram_of_matrix(input_matrix, n = 2)
#'
#' @returns Matrix of one-hot encodings.
#' @export
n_gram_of_matrix <- function(input_matrix, n = 3) {
  voc_len <- ncol(input_matrix)^n
  oov_index <- apply(input_matrix, 1, max) != 1
  max_index <- apply(input_matrix, 1, which.max)
  max_index[oov_index] <- voc_len + 1
  int_enc <- int_to_n_gram(int_seq = max_index, n = n, voc_size = ncol(input_matrix))
  if (length(int_enc) == 1) {
    n_gram_matrix <- matrix(keras::to_categorical(int_enc, num_classes = voc_len + 2), nrow = 1)[ , -c(1, voc_len + 2)]
  } else {
    n_gram_matrix <- keras::to_categorical(int_enc, num_classes = voc_len + 2)[ , -c(1, voc_len + 2)]
  }
  n_gram_matrix <- matrix(n_gram_matrix, ncol = voc_len)
  return(n_gram_matrix)
}

n_gram_of_3d_tensor <- function(tensor_3d, n) {
  new_dim <- dim(tensor_3d)
  new_dim[2] <- new_dim[2] - n + 1
  new_dim[3] <- new_dim[3]^n
  new_tensor <- array(0, dim = new_dim)
  for (i in 1:dim(tensor_3d)[1]) {
    new_tensor[i, , ] <- n_gram_of_matrix(tensor_3d[i, , ], n = n)
  }
  new_tensor
}

n_gram_vocabulary <- function(n_gram = 3, vocabulary = c("A", "C", "G", "T")) {
  l <- list()
  for (i in 1:n_gram) {
    l[[i]] <- vocabulary
  }
  df <- expand.grid(l)
  df <- df[ , ncol(df) : 1]
  n_gram_nuc <- apply(df, 1, paste, collapse = "")
  n_gram_nuc
}


#' Split fasta file into smaller files.
#'
#' Returns smaller files with same file name and "_x" (where x is an integer). For example,
#' assume we have an input file called "abc.fasta" with 100 entries and `split_n = 50`. The function will
#' create two files called "abc_1.fasta" and "abc_2.fasta" in `target_folder`.
#'
#' @param path_input Fasta file to split into smaller files.
#' @param split_n Maximum number of entries to use in smaller file.
#' @param target_folder Directory for output.
#' @param shuffle_entries Whether to shuffle fasta entries before split.
#' @param delete_input Whether to delete the original file.
#' @examples
#' path_input <- tempfile(fileext = '.fasta')
#' create_dummy_data(file_path = path_input,
#'                   num_files = 1,
#'                   write_to_file_path = TRUE,
#'                   seq_length = 7,
#'                   num_seq = 25,
#'                   vocabulary = c("a", "c", "g", "t"))
#' target_folder <- tempfile()
#' dir.create(target_folder)
#'
#' # split 25 entries into 5 files
#' split_fasta(path_input = path_input,
#'             target_folder = target_folder,
#'             split_n = 5)
#' length(list.files(target_folder))
#'
#' @returns None. Writes files to output.
#' @export
split_fasta <- function(path_input,
                        target_folder,
                        split_n = 500,
                        shuffle_entries = TRUE,
                        delete_input = FALSE) {

  fasta_file <- microseq::readFasta(path_input)

  base_name <- basename(stringr::str_remove(path_input, "\\.fasta$"))
  new_path <- paste0(target_folder, "/", base_name)
  count <- 1
  start_index <- 1
  end_index <- 1

  if (nrow(fasta_file) == 1) {
    fasta_name <- paste0(new_path, "_", count, ".fasta")
    microseq::writeFasta(fasta_file, fasta_name)
    if (delete_input) {
      file.remove(path_input)
    }
    return(NULL)
  }

  if (shuffle_entries) {
    fasta_file <- fasta_file[sample(nrow(fasta_file)), ]
  }

  while (end_index < nrow(fasta_file)) {
    end_index <- min(start_index + split_n - 1, nrow(fasta_file))
    index <- start_index : end_index
    sub_df <- fasta_file[index, ]
    fasta_name <- paste0(new_path, "_", count, ".fasta")
    microseq::writeFasta(sub_df, fasta_name)
    start_index <- start_index + split_n
    count <- count + 1
  }

  if (delete_input) {
    file.remove(path_input)
  }
}

#' Add noise to tensor
#'
#' @param noise_type "normal" or "uniform".
#' @param ... additional arguments for rnorm or runif call.
#' @noRd
add_noise_tensor <- function(x, noise_type, ...) {

  stopifnot(noise_type %in% c("normal", "uniform"))
  random_fn <- ifelse(noise_type == "normal", "rnorm", "runif")

  if (is.list(x)) {
    for (i in 1:length(x)) {
      x_dim <- dim(x[[i]])
      noise_tensor <- do.call(random_fn, list(n = prod(x_dim[-1]), ...))
      noise_tensor <- array(noise_tensor, dim = x_dim)
      x[[i]] <- x[[i]] + noise_tensor
    }
  } else {
    x_dim <- dim(x)
    noise_tensor <- do.call(random_fn, list(n = prod(x_dim[-1]), ...))
    noise_tensor <- array(noise_tensor, dim = x_dim)
    x <- x + noise_tensor
  }

  return(x)
}

reverse_complement_tensor <- function(x) {
  stopifnot(dim(x)[3] == 4)
  x_rev_comp <- x[ , dim(x)[2]:1, 4:1]
  x_rev_comp <- array(x_rev_comp, dim = dim(x))
  x_rev_comp
}


get_pos_enc <- function(pos, i, d_model, n = 10000) {

  # sinusoidal positional encoding: even index i uses sin, odd index uses cos
  pw <- (2 * floor(i/2)) / d_model
  angle_rates <- 1 / (n ^ pw)
  angle <- pos * angle_rates
  pos_enc <- ifelse(i %% 2 == 0, sin(angle), cos(angle))
  return(pos_enc)
}

positional_encoding <- function(seq_len, d_model, n = 10000) {

  P <- matrix(0, nrow = seq_len, ncol = d_model)

  for (pos in 0:(seq_len - 1)) {
    for (i in 0:(d_model - 1)) {
      P[pos + 1, i + 1] <- get_pos_enc(pos, i, d_model, n)
    }
  }

  return(P)
}
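
## Sketch (illustration only, not executed): this follows the standard
## transformer positional encoding, PE(pos, 2k) = sin(pos/n^(2k/d_model)) and
## PE(pos, 2k+1) = cos(pos/n^(2k/d_model)).
if (FALSE) {
  P <- positional_encoding(seq_len = 4, d_model = 8)
  dim(P)  # 4 8
  P[1, ]  # pos = 0: sin terms are 0, cos terms are 1
}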


subset_tensor_list <- function(tensor_list, dim_list, subset_index, dim_n_list) {

  for (i in 1:length(tensor_list)) {
    tensor_list[[i]] <- subset_tensor(tensor = tensor_list[[i]],
                                      subset_index = subset_index,
                                      dim_n = dim_n_list[[i]])
  }
  tensor_list
}

subset_tensor <- function(tensor, subset_index, dim_n) {

  if (dim_n == 1) {
    subset_tensor <- tensor[subset_index]
  }

  if (dim_n == 2) {
    subset_tensor <- tensor[subset_index, ]
  }

  if (dim_n == 3) {
    subset_tensor <- tensor[subset_index, , ]
  }

  if (dim_n == 4) {
    subset_tensor <- tensor[subset_index, , , ]
  }

  if (length(subset_index) == 1 && dim_n > 1) {
    # subsetting with a single index drops the batch axis, so add it back
    subset_tensor <- tensorflow::tf$expand_dims(subset_tensor, axis = 0L)
  }
  subset_tensor
}


mask_seq <- function(int_seq,
                     mask_rate = NULL,
                     random_rate = NULL,
                     identity_rate = NULL,
                     block_len = NULL,
                     start_ind = NULL,
                     voc_len) {

  mask_token <- voc_len + 1
  if (is.null(mask_rate)) mask_rate <- 0
  if (is.null(random_rate)) random_rate <- 0
  if (is.null(identity_rate)) identity_rate <- 0
  mask_perc <- mask_rate + random_rate + identity_rate
  if (mask_perc > 1) {
    stop("Sum of mask_rate, random_rate, identity_rate bigger than 1")
  }
  # don't mask padding or oov positions
  valid_pos <- which(int_seq != 0 & int_seq != mask_token)

  # randomly decide whether to round up or down
  ceiling_floor <- sample(c(TRUE, FALSE), 3, replace = TRUE)
  # adjust for block len
  block_len_adjust <- ifelse(is.null(block_len), 1, block_len)

  num_mask_pos <- (mask_rate * length(valid_pos))/block_len_adjust
  num_mask_pos <- ifelse(ceiling_floor[1], floor(num_mask_pos), ceiling(num_mask_pos))
  num_random_pos <- (random_rate * length(valid_pos))/block_len_adjust
  num_random_pos <- ifelse(ceiling_floor[2], floor(num_random_pos), ceiling(num_random_pos))
  num_identity_pos <- (identity_rate * length(valid_pos))/block_len_adjust
  num_identity_pos <- ifelse(ceiling_floor[3], floor(num_identity_pos), ceiling(num_identity_pos))
  num_all_pos <- num_mask_pos + num_random_pos + num_identity_pos
  if (is.null(block_len)) {
    all_pos <- sample(valid_pos, num_all_pos)
  } else {
    valid_pos_block_len <- seq(from = sample(1:(block_len - 1), 1), to = length(valid_pos), by = block_len)
    valid_pos <- intersect(valid_pos_block_len, valid_pos)
    all_pos <- sample(valid_pos, min(num_all_pos, length(valid_pos)))
  }

  sample_weight_seq <- rep(0, length(int_seq))
  if (is.null(block_len)) {
    sample_weight_seq[all_pos] <- 1
  } else {
    all_pos_blocks <- purrr::map(all_pos, ~seq(.x, .x + block_len - 1, by = 1))
    sample_weight_seq[unlist(all_pos_blocks)] <- 1
  }

  if (num_mask_pos > 0) {
    mask_index <- sample(all_pos, num_mask_pos)
    all_pos <- setdiff(all_pos, mask_index)
    if (!is.null(block_len)) {
      mask_index <- purrr::map(mask_index, ~seq(.x, .x + block_len - 1, by = 1)) %>%
        unlist()
    }
    int_seq[mask_index] <- mask_token
  }

  if (num_random_pos > 0) {
    random_index <- sample(all_pos, num_random_pos)
    all_pos <- setdiff(all_pos, random_index)
    if (!is.null(block_len)) {
      random_index <- purrr::map(random_index, ~seq(.x, .x + block_len - 1, by = 1)) %>%
        unlist()
    }
    int_seq[random_index] <- sample(1:voc_len, length(random_index), replace = TRUE)
  }

  # assign sample weight to oov positions (already encoded as mask token)
  sample_weight_seq[int_seq == mask_token] <- 1

  return(list(masked_seq = int_seq, sample_weight_seq = sample_weight_seq))

}
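
## Sketch (illustration only, not executed; output is random): following a
## BERT-style masking scheme, a fraction of positions is replaced by the mask
## token, a fraction by random tokens and a fraction kept unchanged; the
## returned sample weights mark all selected positions.
if (FALSE) {
  set.seed(1)
  mask_seq(int_seq = rep(1:4, 5),
           mask_rate = 0.15,
           random_rate = 0.1,
           identity_rate = 0.05,
           voc_len = 4)
}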

#' Char sequence corresponding to one-hot matrix.
#'
#' Return character sequence corresponding to one-hot elements in matrix or tensor.
#'
#' @inheritParams generator_fasta_lm
#' @param m One-hot encoding matrix or 3d array where each element of first axis is one-hot matrix.
#' @param amb_enc Either `"zero"` or `"equal"`. How oov tokens were treated for one-hot encoding.
#' @param amb_char Char to use for oov positions.
#' @param paste_chars Whether to return vector or single sequence.
#' @examples
#' m <- matrix(c(1,0,0,0,0,1,0,0), 2)
#' one_hot_to_seq(m)
#'
#' @returns A string (or a character vector, if \code{paste_chars = FALSE}).
#' @export
one_hot_to_seq <- function(m, vocabulary = c("A", "C", "G", "T"), amb_enc = "zero",
                           amb_char = "N", paste_chars = TRUE) {

  if (length(dim(m)) == 3) {
    seq_list <- list()
    for (i in 1:dim(m)[1]) {
      seq_list[[i]] <- one_hot_to_seq(m = m[i, , ], vocabulary = vocabulary, amb_enc = amb_enc,
                                      amb_char = amb_char, paste_chars = paste_chars)
    }
    return(seq_list)
  }

  if (amb_enc == "zero") {
    amb_row <- which(rowSums(m) == 0)
  }

  if (amb_enc == "equal") {
    amb_row <- which(m[ , 1] == 1/length(vocabulary))
  }

  nt_seq <- vocabulary[apply(m, 1, which.max)]
  nt_seq[amb_row] <- amb_char

  if (paste_chars) {
    nt_seq <- paste(nt_seq, collapse = "")
  }

  return(nt_seq)

}