Diff of /R/predict.R [000000] .. [409433]

Switch to unified view

a b/R/predict.R
1
#' Make prediction for nucleotide sequence or entries in fasta/fastq file
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta/fastq file or nucleotide sequence.
#' Writes states to h5 or csv file (access content of h5 output with \code{\link{load_prediction}} function).
#' There are several options on how to process an input file:
#' \itemize{
#' \item If `"one_seq"`, computes prediction for sequence argument or fasta/fastq file.
#' Combines fasta entries in file to one sequence. This means predictor sequences can contain elements from more than one fasta entry.
#' \item If `"by_entry"`, will output a separate file for each fasta/fastq entry.
#' Names of output files are: `output_dir` + "/" + `filename` (without extension) + "_nr_" + i + "." + `output_type`,
#' where i is the number of the fasta entry.
#' \item If `"by_entry_one_file"`, will store prediction for all fasta entries in one h5 file.
#' \item If `"one_pred_per_entry"`, will make one prediction for each entry by either picking random sample for long sequences
#' or pad sequence for short sequences.
#' }
#'
#' @inheritParams get_generator
#' @inheritParams train_model
#' @param layer_name Name of layer to get output from. If `NULL`, will use the last layer.
#' @param path_input Path to fasta file.
#' @param sequence Character string, ignores path_input if argument given.
#' @param round_digits Number of decimal places.
#' @param mode Either `"lm"` for language model or `"label"` for label classification.
#' @param include_seq Whether to include input sequence in h5 file.
#' @param output_format Either `"one_seq"`, `"by_entry"`, `"by_entry_one_file"`, `"one_pred_per_entry"`.
#' @param output_type `"h5"` or `"csv"`. If `output_format` is `"by_entry_one_file"` or `"one_pred_per_entry"`, can only be `"h5"`
#' (a `"csv"` request is switched to `"h5"` with a message).
#' @param return_states Return predictions as data frame. Only supported for output_format `"one_seq"`.
#' @param padding Either `"none"`, `"maxlen"`, `"standard"` or `"self"`.
#' \itemize{
#' \item If `"none"`, apply no padding and skip sequences that are too short.
#' \item If `"maxlen"`, pad with maxlen number of zeros vectors.
#' \item If `"standard"`, pad with zero vectors only if sequence is shorter than maxlen. Pads to minimum size required for one prediction.
#' \item If `"self"`, concatenate sequence with itself until sequence is long enough for one prediction.
#' Example: if sequence is "ACGT" and maxlen is 10, make prediction for "ACGTACGTAC".
#' Only applied if sequence is shorter than maxlen.
#' }
#' @param verbose Boolean.
#' @param filename Filename to store states in. No file output if argument is `NULL`.
#' If `output_format = "by_entry"`, adds "_nr_" + "i" after name, where i is entry number.
#' @param output_dir Directory for file output.
#' @param use_quality Whether to use quality scores.
#' @param quality_string String for encoding with quality scores (as used in fastq format).
#' @param lm_format Either `"target_right"`, `"target_middle_lstm"`, `"target_middle_cnn"` or `"wavenet"`.
#' @param ... Further arguments for sequence encoding with \code{\link{seq_encoding_label}}.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # make prediction for single sequence and write to h5 file
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' vocabulary <- c("a", "c", "g", "t")
#' sequence <- paste(sample(vocabulary, 200, replace = TRUE), collapse = "")
#' output_file <- tempfile(fileext = ".h5")
#' predict_model(output_format = "one_seq", model = model, step = 10,
#'              sequence = sequence, filename = output_file, mode = "label")
#'
#' # make prediction for fasta file with multiple entries, write output to separate h5 files
#' fasta_path <- tempfile(fileext = ".fasta")
#' create_dummy_data(file_path = fasta_path, num_files = 1,
#'                  num_seq = 5, seq_length = 100,
#'                  write_to_file_path = TRUE)
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' output_dir <- tempfile()
#' dir.create(output_dir)
#' predict_model(output_format = "by_entry", model = model, step = 10, verbose = FALSE,
#'                output_dir = output_dir, mode = "label", path_input = fasta_path)
#' list.files(output_dir)
#'
#' @returns If `return_states = TRUE` returns a list of model predictions and position of corresponding sequences.
#' If additionally `include_seq = TRUE`, list contains sequence strings.
#' If `return_states = FALSE` returns nothing, just writes output to file(s).
#' @export
predict_model <- function(model, output_format = "one_seq", layer_name = NULL, sequence = NULL, path_input = NULL,
                          round_digits = NULL, filename = "states.h5", step = 1, vocabulary = c("a", "c", "g", "t"),
                          batch_size = 256, verbose = TRUE, return_states = FALSE,
                          output_type = "h5", padding = "none", use_quality = FALSE, quality_string = NULL,
                          mode = "label", lm_format = "target_right", output_dir = NULL,
                          format = "fasta", include_seq = FALSE, reverse_complement_encoding = FALSE,
                          ambiguous_nuc = "zero", ...) {

  stopifnot(padding %in% c("standard", "self", "none", "maxlen"))
  stopifnot(output_format %in% c("one_seq", "by_entry", "by_entry_one_file", "one_pred_per_entry"))

  # Both operands are scalars, so use the short-circuit operator.
  if (output_format %in% c("by_entry_one_file", "one_pred_per_entry") && output_type == "csv") {
    message("by_entry_one_file or one_pred_per_entry only implemented for h5 output.
            Setting output_type to h5")
    output_type <- "h5"
  }

  if (output_format == "one_seq") {
    output_list <- predict_model_one_seq(layer_name = layer_name, sequence = sequence, path_input = path_input,
                                         round_digits = round_digits, filename = filename, step = step, vocabulary = vocabulary,
                                         batch_size = batch_size, verbose = verbose, return_states = return_states,
                                         padding = padding, quality_string = quality_string, use_quality = use_quality,
                                         output_type = output_type, model = model, mode = mode, lm_format = lm_format,
                                         format = format, include_seq = include_seq, ambiguous_nuc = ambiguous_nuc,
                                         reverse_complement_encoding = reverse_complement_encoding, ...)
    return(output_list)
  }

  if (output_format == "by_entry") {
    # vocabulary and ambiguous_nuc are formals of predict_model_by_entry, so they
    # are forwarded explicitly; otherwise the user's values would be silently
    # replaced by the callee's defaults. quality_string is not forwarded: the
    # callee derives it per entry from the input file.
    predict_model_by_entry(layer_name = layer_name, path_input = path_input,
                           round_digits = round_digits, filename = filename, step = step,
                           vocabulary = vocabulary, ambiguous_nuc = ambiguous_nuc,
                           batch_size = batch_size, verbose = verbose, use_quality = use_quality,
                           output_type = output_type, model = model, mode = mode, lm_format = lm_format,
                           output_dir = output_dir, format = format, include_seq = include_seq,
                           padding = padding, reverse_complement_encoding = reverse_complement_encoding, ...)
  }

  if (output_format == "by_entry_one_file") {
    # Only vocabulary is forwarded here: it is a declared formal of the callee.
    predict_model_by_entry_one_file(layer_name = layer_name, path_input = path_input,
                                    round_digits = round_digits, filename = filename, step = step,
                                    vocabulary = vocabulary,
                                    batch_size = batch_size, verbose = verbose, use_quality = use_quality,
                                    model = model, mode = mode, lm_format = lm_format, format = format,
                                    padding = padding, include_seq = include_seq,
                                    reverse_complement_encoding = reverse_complement_encoding, ...)
  }

  if (output_format == "one_pred_per_entry") {
    if (mode == "lm") {
      stop("one_pred_per_entry only implemented for label classification")
    }
    predict_model_one_pred_per_entry(layer_name = layer_name, path_input = path_input,
                                     round_digits = round_digits, filename = filename,
                                     batch_size = batch_size, verbose = verbose, model = model, format = format,
                                     vocabulary = vocabulary, ambiguous_nuc = ambiguous_nuc,
                                     reverse_complement_encoding = reverse_complement_encoding, ...)
  }

}
130
131
132
#' Write output of specific model layer to h5 or csv file.
#'
#' Removes layers (optional) from pretrained model and calculates states of fasta file, writes states to h5/csv file.
#' Function combines fasta entries in file to one sequence. This means predictor sequences can contain elements from more than one fasta entry.
#' h5 file also contains sequence and positions of targets corresponding to states.
#'
#' @inheritParams generator_fasta_lm
#' @param layer_name Name of layer to get output from. If `NULL`, will use the last layer.
#' @param path_input Path to fasta file.
#' @param sequence Character string, ignores path_input if argument given.
#' @param round_digits Number of decimal places. No rounding if `NULL`.
#' @param batch_size Number of samples to evaluate at once. Does not change output, only relevant for speed and memory.
#' @param step Frequency of sampling steps.
#' @param filename Filename to store states in. No file output if argument is `NULL`
#' (in that case `return_states` must be `TRUE`). The file must not exist yet.
#' @param vocabulary Vector of allowed characters, character outside vocabulary get encoded as 0-vector.
#' @param return_states Logical scalar, return states matrix.
#' @param target_len Number of positions added to the model input length to get the length of one sample if `mode = "lm"`.
#' @param use_quality Whether to use quality scores. Set to `TRUE` automatically if `quality_string` is given.
#' @param quality_string String for encoding with quality scores (as used in fastq format).
#' @param ambiguous_nuc `"zero"` or `"equal"`.
#' @param padding Either `"none"`, `"maxlen"`, `"standard"` or `"self"`.
#' @param verbose Whether to print model before and after removing layers.
#' @param output_type Either `"h5"` or `"csv"`.
#' @param output_dir Not used by this function.
#' @param model A keras model.
#' @param mode Either `"lm"` for language model or `"label"` for label classification.
#' @param lm_format Language model format; `"target_middle_lstm"` and `"target_middle_cnn"` split the input around the target.
#' @param format Either `"fasta"` or `"fastq"`.
#' @param include_seq Whether to include input sequence in h5 file.
#' @param reverse_complement_encoding Whether the model takes the input and its reverse complement as two inputs.
#' Only implemented for the A,C,G,T vocabulary.
#' @param ... Further arguments for sequence encoding with \code{\link{seq_encoding_label}}.
#' @return If `return_states = TRUE`, a list with elements `states`, `sample_end_position` and
#' (if `include_seq = TRUE`) `sequence`; otherwise `NULL`.
#' @noRd
predict_model_one_seq <- function(model, layer_name = NULL, sequence = NULL, path_input = NULL, round_digits = 2,
                                  filename = "states.h5", step = 1, vocabulary = c("a", "c", "g", "t"), batch_size = 256, verbose = TRUE,
                                  return_states = FALSE, target_len = 1, use_quality = FALSE, quality_string = NULL,
                                  output_type = "h5", mode = "lm", lm_format = "target_right",
                                  ambiguous_nuc = "zero", padding = "none", format = "fasta", output_dir = NULL,
                                  include_seq = TRUE, reverse_complement_encoding = FALSE, ...) {

  vocabulary <- stringr::str_to_lower(vocabulary)
  stopifnot(mode %in% c("lm", "label"))
  stopifnot(output_type %in% c("h5", "csv"))
  if (!is.null(quality_string)) use_quality <- TRUE

  file_output <- !is.null(filename)
  if (!file_output && !return_states) {
    stop("If filename is NULL, return_states must be TRUE; otherwise function produces no output.")
  }
  stopifnot(batch_size > 0)
  if (file_output) stopifnot(!file.exists(filename))

  if (reverse_complement_encoding) {
    # Reject the vocabulary if ANY element deviates from a/c/g/t.
    if (length(vocabulary) != 4 || !setequal(vocabulary, c("a", "c", "g", "t"))) {
      stop("reverse_complement_encoding only implemented for A,C,G,T vocabulary")
    }
  }

  # token for ambiguous nucleotides: first letter that is not part of the vocabulary
  amb_nuc_token <- setdiff(letters, vocabulary)[1]

  tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "0"), c(vocabulary, amb_nuc_token))

  if (is.null(layer_name)) {
    layer_name <- model$output_names
    if (verbose) message(paste("layer_name not specified. Using layer", layer_name))
  }

  # Use the sequence argument if given, otherwise read (and concatenate) the input file.
  fasta_file <- NULL
  if (!is.null(sequence) && !missing(sequence) && sequence != "") {
    nt_seq <- stringr::str_to_lower(sequence)
  } else {
    stopifnot(format %in% c("fasta", "fastq"))
    if (format == "fasta") {
      fasta_file <- microseq::readFasta(path_input)
    } else {
      fasta_file <- microseq::readFastq(path_input)
    }
    if (nrow(fasta_file) > 1 && verbose) {
      text_1 <- paste("Your file has", nrow(fasta_file), "entries. 'one_seq'  output_format will concatenate them to a single sequence.\n")
      text_2 <- "Use 'by_entry' or 'by_entry_one_file' output_format to evaluate them separately."
      message(paste0(text_1, text_2))
    }
    nt_seq <- stringr::str_to_lower(paste(fasta_file$Sequence, collapse = ""))
  }

  # tokenize ambiguous nt
  pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")
  nt_seq <- stringr::str_replace_all(string = nt_seq, pattern = pattern, amb_nuc_token)

  if (use_quality && is.null(quality_string)) {
    # Quality scores can only be taken from the input file; with a sequence
    # argument there is no file to read them from.
    if (is.null(fasta_file)) {
      stop("use_quality is TRUE but no quality_string was supplied for the sequence argument.")
    }
    quality_string <- paste(fasta_file$Quality, collapse = "")
  }

  # extract maxlen
  target_middle <- mode == "lm" && (lm_format %in% c("target_middle_lstm", "target_middle_cnn"))
  if (!target_middle) {
    if (reverse_complement_encoding) {
      maxlen <- model$input[[1]]$shape[[2]]
    } else {
      maxlen <- model$input$shape[[2]]
    }
  } else {
    maxlen_1 <- model$input[[1]]$shape[[2]]
    maxlen_2 <- model$input[[2]]$shape[[2]]
    maxlen <- maxlen_1 + maxlen_2
  }

  total_seq_len <- if (mode == "lm") maxlen + target_len else maxlen

  # pad sequence
  unpadded_seq_len <- nchar(nt_seq)
  too_short <- unpadded_seq_len < total_seq_len
  pad_len <- 0
  if (padding == "maxlen") {
    pad_len <- maxlen
  }
  if (padding == "standard" && too_short) {
    pad_len <- total_seq_len - unpadded_seq_len
  }
  if (padding == "self" && too_short) {
    # repeat the sequence (and its quality string) until one sample fits
    num_repeats <- ceiling(total_seq_len / unpadded_seq_len)
    nt_seq <- substr(strrep(nt_seq, num_repeats), 1, total_seq_len)
    if (use_quality) {
      quality_string <- substr(strrep(quality_string, num_repeats), 1, total_seq_len)
    }
  } else {
    nt_seq <- paste0(strrep("0", pad_len), nt_seq)
    if (use_quality) quality_string <- paste0(strrep("0", pad_len), quality_string)
  }

  if (nchar(nt_seq) < total_seq_len) {
    stop(paste0("Input sequence is shorter than required length (", total_seq_len, "). Use padding argument to pad sequence to bigger size."))
  }

  if (use_quality) {
    quality_vector <- quality_to_probability(quality_string)
  } else {
    quality_vector <- NULL
  }

  # start of samples
  start_indices <- seq(1, nchar(nt_seq) - total_seq_len + 1, by = step)
  num_samples <- length(start_indices)

  check_layer_name(model, layer_name)
  model <- tensorflow::tf$keras$Model(model$input, model$get_layer(layer_name)$output)
  if (verbose) {
    cat("Computing output for model at layer", layer_name,  "\n")
    print(model)
  }

  # 3D layer output is only supported for an output named "lstm"
  if (length(model$output$shape$dims) == 3) {
    if (!("lstm" %in% stringr::str_to_lower(model$output_names))) {
      stop("Output dimension of layer is > 1, format not supported yet")
    }
  }

  # tokenize sequence
  nt_seq <- stringr::str_to_lower(nt_seq)
  tokSeq <- keras::texts_to_sequences(tokenizer, nt_seq)[[1]] - 1

  # seq end position (relative to the unpadded sequence)
  pos_arg <- start_indices + total_seq_len - pad_len - 1

  if (include_seq) {
    output_seq <- substr(nt_seq, pad_len + 1, nchar(nt_seq))
  }

  # create h5 file to store states; only when file output was requested, and
  # make sure the handle is closed on every exit path (including errors)
  if (file_output && output_type == "h5") {
    h5_file <- hdf5r::H5File$new(filename, mode = "w")
    on.exit(h5_file$close_all(), add = TRUE)
    h5_file[["multi_entries"]] <- FALSE
    h5_file[["sample_end_position"]] <- pos_arg
    if (include_seq) h5_file[["sequence"]] <- output_seq
  }

  number_batches <- ceiling(num_samples / batch_size)
  pred_list <- vector("list", number_batches)

  # subset input for target middle
  if (target_middle) {
    index_x_1 <- 1:ceiling((total_seq_len - target_len)/2)
    index_x_2 <- (max(index_x_1) + target_len + 1) : total_seq_len
  }

  for (i in seq_len(number_batches)) {

    index_start <- ((i - 1) * batch_size) + 1
    index_end <- min(num_samples, index_start + batch_size - 1)
    index <- index_start : index_end

    x <- seq_encoding_label(sequence = tokSeq,
                            maxlen = total_seq_len,
                            vocabulary = vocabulary,
                            start_ind = start_indices[index],
                            ambiguous_nuc = ambiguous_nuc,
                            tokenizer = NULL,
                            adjust_start_ind = FALSE,
                            quality_vector = quality_vector,
                            ...
    )

    if (mode == "lm" && lm_format == "target_middle_lstm") {
      x1 <- x[ , index_x_1, ]
      x2 <- x[ , index_x_2, ]

      # subsetting drops dimensions of length 1; restore the leading dimension
      if (length(index_x_1) == 1 || dim(x)[1] == 1) {
        x1 <- array(x1, dim = c(1, dim(x1)))
      }
      if (length(index_x_2) == 1 || dim(x)[1] == 1) {
        x2 <- array(x2, dim = c(1, dim(x2)))
      }

      x2 <- x2[ , dim(x2)[2]:1, ] # reverse order

      if (length(dim(x2)) == 2) {
        x2 <- array(x2, dim = c(1, dim(x2)))
      }

      x <- list(x1, x2)
    }

    if (mode == "lm" && lm_format == "target_middle_cnn") {
      x <- x[ , c(index_x_1, index_x_2), ]
    }

    if (reverse_complement_encoding) x <- list(x, reverse_complement_tensor(x))

    y <- stats::predict(model, x, verbose = 0)
    if (!is.null(round_digits)) y <- round(y, round_digits)
    pred_list[[i]] <- y

  }

  states <- do.call(rbind, pred_list)

  if (file_output) {
    if (output_type == "h5") {
      h5_file[["states"]] <- states
    } else {
      colnames(states) <- paste0("N", seq_len(ncol(states)))
      utils::write.csv(x = states, file = filename, row.names = FALSE)
    }
  }

  if (return_states) {
    output_list <- list()
    output_list$states <- states
    output_list$sample_end_position <- pos_arg
    if (include_seq) output_list$sequence <- output_seq
    return(output_list)
  }

  NULL
}
394
395
#' Write states to h5 file
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta file, writes a separate
#' h5 (or csv) file for every fasta entry in fasta file. h5 files also contain the nucleotide sequence and positions of targets corresponding to states.
#' Names of output files are: output_dir + "/" + filename (without extension) + "_nr_" + i + "." + output_type,
#' where i is the number of the fasta entry.
#' Entries shorter than the model input length are skipped if `padding = "none"`.
#'
#' @param filename Filename to store states, function adds "_nr_" + "i" after name, where i is entry number.
#' Must not be `NULL`.
#' @param output_dir Path to folder, where to write output. Must not be `NULL`.
#' @param ambiguous_nuc Passed on to `predict_model_one_seq` for every entry.
#' @param ... Further arguments passed on to `predict_model_one_seq`.
#' @noRd
predict_model_by_entry <- function(model, layer_name = NULL, path_input, round_digits = 2,
                                   filename = "states.h5", output_dir = NULL, step = 1, vocabulary = c("a", "c", "g", "t"),
                                   batch_size = 256, output_type = "h5", mode = "lm",
                                   lm_format = "target_right", format = "fasta", use_quality = FALSE,
                                   reverse_complement_encoding = FALSE, padding = "none",
                                   verbose = FALSE, include_seq = FALSE, ambiguous_nuc = "zero", ...) {

  stopifnot(mode %in% c("lm", "label"))
  stopifnot(!is.null(filename))
  stopifnot(!is.null(output_dir))

  # strip the file extension (and any directory part) so the entry number can be inserted
  extension <- paste0(".", output_type)
  if (endsWith(filename, extension)) {
    filename <- substr(filename, 1, nchar(filename) - nchar(extension))
    filename <- basename(filename)
  }

  # extract maxlen
  target_middle <- mode == "lm" && (lm_format %in% c("target_middle_lstm", "target_middle_cnn"))
  if (!target_middle) {
    if (reverse_complement_encoding) {
      # the model has two inputs here; take the length from the first one
      maxlen <- model$input[[1]]$shape[[2]]
    } else {
      maxlen <- model$input$shape[[2]]
    }
  } else {
    maxlen_1 <- model$input[[1]]$shape[[2]]
    maxlen_2 <- model$input[[2]]$shape[[2]]
    maxlen <- maxlen_1 + maxlen_2
  }

  # load fasta file
  if (format == "fasta") {
    fasta.file <- microseq::readFasta(path_input)
  }
  if (format == "fastq") {
    fasta.file <- microseq::readFastq(path_input)
  }
  df <- fasta.file[ , c("Sequence", "Header")]
  names(df) <- c("seq", "header")
  rownames(df) <- NULL
  num_skipped_seq <- 0

  # seq_len() keeps the loop empty for a file without entries
  for (i in seq_len(nrow(df))) {

    # skip entry if too short
    if (padding == "none" && nchar(df[i, "seq"]) < maxlen) {
      num_skipped_seq <- num_skipped_seq + 1
      next
    }

    if (use_quality) {
      quality_string <- fasta.file$Quality[i]
    } else {
      quality_string <- NULL
    }

    current_file <- paste0(output_dir, "/", filename, "_nr_", as.character(i), ".", output_type)

    predict_model_one_seq(layer_name = layer_name, sequence = df[i, "seq"],
                          round_digits = round_digits, path_input = path_input,
                          filename = current_file, quality_string = quality_string,
                          step = step, vocabulary = vocabulary, batch_size = batch_size,
                          # only report model details for the first entry
                          verbose = if (i > 1) FALSE else verbose,
                          output_type = output_type, mode = mode,
                          lm_format = lm_format, model = model, include_seq = include_seq,
                          padding = padding, ambiguous_nuc = ambiguous_nuc,
                          reverse_complement_encoding = reverse_complement_encoding, ...)
  }

  if (verbose && num_skipped_seq > 0) {
    message(paste0("Skipped ", num_skipped_seq,
                   ifelse(num_skipped_seq == 1, " entry", " entries"),
                   ". Use different padding option to evaluate all."))
  }

}
480
481
#' Write states to h5 file
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta file,
#' writes separate states matrix in one .h5 file for every fasta entry.
#' h5 file also contains the nucleotide sequences and positions of targets corresponding to states.
#' Data for the i-th entry is stored under the name "entry_i" in the groups "states",
#' "sample_end_position" and (if `include_seq = TRUE`) "sequence".
#' Entries shorter than the model input length are skipped if `padding = "none"`.
#'
#' @param filename Path of the h5 file to write. Must not be `NULL`.
#' @param ambiguous_nuc Passed on to `predict_model_one_seq` for every entry.
#' @param ... Further arguments passed on to `predict_model_one_seq`.
#' @noRd
predict_model_by_entry_one_file <- function(model, path_input, round_digits = 2, filename = "states.h5",
                                            step = 1,  vocabulary = c("a", "c", "g", "t"), batch_size = 256, layer_name = NULL,
                                            verbose = TRUE, mode = "lm", use_quality = FALSE,
                                            lm_format = "target_right", padding = "none",
                                            format = "fasta", include_seq = TRUE, reverse_complement_encoding = FALSE,
                                            ambiguous_nuc = "zero", ...) {

  vocabulary <- stringr::str_to_lower(vocabulary)
  stopifnot(mode %in% c("lm", "label"))
  stopifnot(!is.null(filename))

  target_middle <- mode == "lm" && (lm_format %in% c("target_middle_lstm", "target_middle_cnn"))
  # extract maxlen
  if (!target_middle) {
    if (reverse_complement_encoding) {
      maxlen <- model$input[[1]]$shape[[2]]
    } else {
      maxlen <- model$input$shape[[2]]
    }
  } else {
    maxlen_1 <- model$input[[1]]$shape[[2]]
    maxlen_2 <- model$input[[2]]$shape[[2]]
    maxlen <- maxlen_1 + maxlen_2
  }

  # 3D model output is only supported for an output named "lstm"
  if (length(model$output$shape$dims) == 3) {
    if (!("lstm" %in% stringr::str_to_lower(model$output_names))) {
      stop("Output dimension of layer is > 1, format not supported yet")
    }
  }

  # load fasta file
  if (format == "fasta") {
    fasta.file <- microseq::readFasta(path_input)
  }
  if (format == "fastq") {
    fasta.file <- microseq::readFastq(path_input)
  }
  df <- fasta.file[ , c("Sequence", "Header")]
  names(df) <- c("seq", "header")
  rownames(df) <- NULL

  if (verbose) {
    # check if names are unique
    if (length(df$header) != length(unique(df$header))) {
      message("Header names are not unique, adding '_header_x' to names (x being the header number)")
      df$header <- paste0(df$header, paste0("_header_", seq_along(df$header)))
    }
  }

  # create h5 file to store states; close the handle on every exit path
  h5_file <- hdf5r::H5File$new(filename, mode = "w")
  on.exit(h5_file$close_all(), add = TRUE)
  h5_file[["multi_entries"]] <- TRUE
  states.grp <- h5_file$create_group("states")
  sample_end_position.grp <- h5_file$create_group("sample_end_position")
  if (include_seq) seq.grp <- h5_file$create_group("sequence")

  num_skipped_seq <- 0

  # seq_len() keeps the loop empty for a file without entries
  for (i in seq_len(nrow(df))) {

    seq_name <- paste0("entry_", i)

    # skip entry if too short
    if (padding == "none" && nchar(df[i, "seq"]) < maxlen) {
      num_skipped_seq <- num_skipped_seq + 1
      next
    }

    if (use_quality) {
      quality_string <- fasta.file$Quality[i]
    } else {
      quality_string <- NULL
    }

    # predict_model_one_seq writes its own h5 file; only the returned list is
    # needed here, so the intermediate file is removed after each entry
    temp_file <- tempfile(fileext = ".h5")
    output_list <- predict_model_one_seq(layer_name = layer_name, sequence = df$seq[i], path_input = path_input,
                                         round_digits = round_digits, filename = temp_file, step = step, vocabulary = vocabulary,
                                         batch_size = batch_size, return_states = TRUE, quality_string = quality_string,
                                         output_type = "h5", model = model, mode = mode, lm_format = lm_format,
                                         ambiguous_nuc = ambiguous_nuc, verbose = if (i > 1) FALSE else verbose,
                                         padding = padding, format = format, include_seq = include_seq,
                                         reverse_complement_encoding = reverse_complement_encoding, ...)
    unlink(temp_file)

    states.grp[[seq_name]] <- output_list$states
    sample_end_position.grp[[seq_name]] <- output_list$sample_end_position

    if (include_seq) seq.grp[[seq_name]] <- output_list$sequence
  }

  if (verbose && num_skipped_seq > 0) {
    message(paste0("Skipped ", num_skipped_seq,
                   ifelse(num_skipped_seq == 1, " entry", " entries"),
                   ". Use different padding option to evaluate all."))
  }

}
589
590
#' Get states for label classification model.
591
#'
592
#' Computes output at specified model layer. Forces every fasta entry to have length maxlen by either padding sequences shorter than maxlen or taking random subsequence for
593
#' longer sequences.
594
#'
595
#' @inheritParams predict_model_one_seq
596
#' @noRd
597
predict_model_one_pred_per_entry <- function(model, layer_name = NULL, path_input, round_digits = 2, format = "fasta",
598
                                             ambiguous_nuc = "zero", filename = "states.h5", padding = padding,
599
                                             vocabulary = c("a", "c", "g", "t"), batch_size = 256, verbose = TRUE,
600
                                             return_states = FALSE, reverse_complement_encoding = FALSE, 
601
                                             include_seq = FALSE, use_quality = FALSE, ...) {
602
  
603
  vocabulary <- stringr::str_to_lower(vocabulary)
604
  file_type <- "h5"
605
  stopifnot(batch_size > 0)
606
  stopifnot(!file.exists(filename))
607
  # token for ambiguous nucleotides
608
  for (i in letters) {
609
    if (!(i %in% stringr::str_to_lower(vocabulary))) {
610
      amb_nuc_token <- i
611
      break
612
    }
613
  }
614
  tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "N"), vocabulary)
615
  
616
  if (is.null(layer_name)) {
617
    layer_name <- model$output_names
618
    if (verbose) message(paste("layer_name not specified. Using layer", layer_name))
619
  }
620
621
  # extract maxlen
622
  if (reverse_complement_encoding) {
623
    maxlen <- model$input[[1]]$shape[[2]]
624
  } else {
625
    maxlen <- model$input$shape[[2]]
626
  }
627
  
628
  if (format == "fasta") {
629
    fasta.file <- microseq::readFasta(path_input)
630
  }
631
  if (format == "fastq") {
632
    fasta.file <- microseq::readFastq(path_input)
633
  }
634
  
635
  num_samples <- nrow(fasta.file)
636
  
637
  nucSeq <- as.character(fasta.file$Sequence)
638
  seq_length <- nchar(fasta.file$Sequence)
639
  if (use_quality) {
640
    quality_string <- vector("character", length(nucSeq)) 
641
  } else {
642
    quality_string <- NULL
643
  }
644
  
645
  for (i in 1:length(nucSeq)) {
646
    # take random subsequence
647
    if (seq_length[i] > maxlen) {
648
      start <- sample(1 : (seq_length[i] - maxlen + 1) , size = 1)
649
      nucSeq[i] <- substr(nucSeq[i], start = start, stop = start + maxlen - 1)
650
      if (use_quality) {
651
        quality_string[i] <- substr(fasta.file$Quality[i], start = start, stop = start + maxlen - 1)
652
      }  
653
    }
654
    # pad sequence
655
    if (seq_length[i] < maxlen) {
656
      nucSeq[i] <- paste0(paste(rep("N", maxlen - seq_length[i]), collapse = ""), nucSeq[i])
657
      if (use_quality) {
658
        quality_string[i] <- paste0(paste(rep("0", maxlen - seq_length[i]), collapse = ""), fasta.file$Quality[i])
659
      }  
660
    }
661
  }
662
  
663
  model <- tensorflow::tf$keras$Model(model$input, model$get_layer(layer_name)$output)
664
  if (verbose) {
665
    cat("Computing output for model at layer", layer_name,  "\n")
666
    print(model)
667
  } 
668
  
669
  # extract number of neurons in last layer
670
  if (length(model$output$shape$dims) == 3) {
671
    if (!("lstm" %in% stringr::str_to_lower(model$output_names))) {
672
      stop("Output dimension of layer is > 1, format not supported yet")
673
    }
674
    layer.size <- model$output$shape[[3]]
675
  } else {
676
    layer.size <- model$output$shape[[2]]
677
  }
678
  
679
  if (file_type == "h5") {
680
    # create h5 file to store states
681
    h5_file <- hdf5r::H5File$new(filename, mode = "w")
682
    if (!missing(path_input)) h5_file[["fasta_file"]] <- path_input
683
    
684
    h5_file[["header_names"]] <- fasta.file$Header
685
    if (include_seq) h5_file[["sequences"]] <- nucSeq
686
    h5_file[["states"]] <- array(0, dim = c(0, layer.size))
687
    h5_file[["multi_entries"]] <- FALSE
688
    writer <- h5_file[["states"]]
689
  }
690
  
691
  rm(fasta.file)
692
  #nucSeq <- paste(nucSeq, collapse = "") %>% stringr::str_to_lower()
693
  number_batches <- ceiling(num_samples/batch_size)
694
  if (verbose) cat("Evaluating", number_batches, ifelse(number_batches > 1, "batches", "batch"), "\n")
695
  row <- 1
696
  string_start_index <- 1 
697
  ten_percent_steps <- seq(number_batches/10, number_batches, length.out = 10)
698
  percentage_index <- 1
699
  
700
  if (number_batches > 1) {
701
    for (i in 1:(number_batches - 1)) {
702
      string_end_index <-string_start_index + batch_size - 1 
703
      char_seq <- nucSeq[string_start_index : string_end_index] %>% paste(collapse = "") 
704
      
705
      if (use_quality) {
706
        quality_string_subset <- quality_string[string_start_index : string_end_index] %>% paste(collapse = "") 
707
      } else {
708
        quality_string_subset <- NULL
709
      }
710
      
711
      if (i == 1) start_ind <- seq(1, nchar(char_seq), maxlen)
712
      one_hot_batch <- seq_encoding_label(sequence = NULL, maxlen = maxlen, vocabulary = vocabulary,
713
                                          start_ind = start_ind, ambiguous_nuc = ambiguous_nuc, 
714
                                          char_sequence = char_seq, quality_vector = quality_string_subset,
715
                                          tokenizer = tokenizer, adjust_start_ind = TRUE, ...) 
716
      if (reverse_complement_encoding) one_hot_batch <- list(one_hot_batch, reverse_complement_tensor(one_hot_batch))
717
      activations <- keras::predict_on_batch(model, one_hot_batch)
718
      writer[row : (row + batch_size - 1), ] <- activations
719
      row <- row + batch_size
720
      string_start_index <- string_end_index + 1
721
      
722
      if (verbose & (i > ten_percent_steps[percentage_index]) & percentage_index < 10) {
723
        cat("Progress: ", percentage_index * 10 ,"% \n")
724
        percentage_index <- percentage_index + 1
725
      }
726
      
727
    }
728
  }
729
  
730
  # last batch might be shorter
731
  char_seq <- nucSeq[string_start_index : length(nucSeq)] %>% paste(collapse = "") 
732
  if (use_quality) {
733
    quality_string_subset <- quality_string[string_start_index : length(nucSeq)] %>% paste(collapse = "") 
734
  } else {
735
    quality_string_subset <- NULL
736
  }
737
  one_hot_batch <- seq_encoding_label(sequence = NULL, maxlen = maxlen, vocabulary = vocabulary,
738
                                      start_ind = seq(1, nchar(char_seq), maxlen), ambiguous_nuc = "zero", nuc_dist = NULL,
739
                                      quality_vector = quality_string_subset, use_coverage = FALSE, max_cov = NULL,
740
                                      cov_vector = NULL, n_gram = NULL, n_gram_stride = 1, char_sequence = char_seq,
741
                                      tokenizer = tokenizer, adjust_start_ind = TRUE, ...) 
742
  if (reverse_complement_encoding) one_hot_batch <- list(one_hot_batch, reverse_complement_tensor(one_hot_batch))
743
  activations <- keras::predict_on_batch(model, one_hot_batch)
744
  writer[row : num_samples, ] <- activations[1 : length(row:num_samples), ]
745
  
746
  if (verbose) cat("Progress: 100 % \n")
747
  
748
  if (return_states & (file_type == "h5")) states <- writer[ , ]
749
  if (file_type == "h5") h5_file$close_all()
750
  if (return_states) return(states)
751
}
752
753
754
#' Read states from h5 file
#'
#' Reads h5 file created by \code{\link{predict_model}} function.
#'
#' @param h5_path Path to h5 file.
#' @param rows Range of rows to read. If `NULL` read everything.
#' @param get_sample_position Return position of sample corresponding to state if `TRUE`.
#' @param get_seq Return nucleotide sequence if `TRUE`.
#' @param verbose Boolean.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # make prediction for single sequence and write to h5 file
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' vocabulary <- c("a", "c", "g", "t")
#' sequence <- paste(sample(vocabulary, 200, replace = TRUE), collapse = "")
#' output_file <- tempfile(fileext = ".h5")
#' predict_model(output_format = "one_seq", model = model, step = 10,
#'               sequence = sequence, filename = output_file, mode = "label")
#' load_prediction(h5_path = output_file)
#'
#' @returns A list with element `states` (matrix of model outputs, one row per
#' prediction) and, if requested and present in the file, `sample_end_position`
#' and `sequence`. If the file stores several entries, a named list containing
#' one such list per entry.
#' @export
load_prediction <- function(h5_path, rows = NULL, verbose = FALSE,
                            get_sample_position = FALSE, get_seq = FALSE) {

  # Always define the flag; leaving it unset for non-NULL `rows` made every
  # row-subset request fail with "object 'complete' not found".
  complete <- is.null(rows)

  h5_file <- hdf5r::H5File$new(h5_path, mode = "r")
  # Release the file handle on every exit path, including errors.
  on.exit(h5_file$close_all(), add = TRUE)

  file_members <- names(h5_file)
  multi_entries <- as.logical(h5_file[["multi_entries"]][])

  if (get_sample_position &&
      !any(c("sample_end_position", "target_position") %in% file_members)) {
    get_sample_position <- FALSE
    message("File does not contain target positions.")
  }

  # NOTE(review): the one-prediction-per-entry writer in this file stores its
  # sequences under "sequences"; confirm whether that key should be read too.
  if (get_seq && !("sequence" %in% file_members)) {
    get_seq <- FALSE
    message("File does not contain sequence.")
  }

  # Reading from an h5 dataset drops dimensions of length one, so a single row
  # or a single column arrives as a plain vector. Rebuild the matrix from the
  # known column count so that both cases keep one row per prediction.
  as_state_matrix <- function(x, n_col) {
    if (is.null(dim(x))) matrix(x, ncol = n_col) else x
  }

  if (!multi_entries) {

    read_states <- h5_file[["states"]]
    state_dims <- read_states$dims

    if (verbose) {
      # Use the dataset metadata instead of loading all values to get dims.
      cat("states matrix has", state_dims[1], "rows and", state_dims[2], "columns \n")
    }

    if (complete) {
      states <- read_states[ , ]
    } else {
      states <- read_states[rows, ]
    }
    output_list <- list(states = as_state_matrix(states, state_dims[2]))

    if (get_sample_position) {
      read_positions <- h5_file[["sample_end_position"]]
      if (complete) {
        output_list$sample_end_position <- read_positions[ ]
      } else {
        output_list$sample_end_position <- read_positions[rows]
      }
    }

    if (get_seq) {
      output_list$sequence <- h5_file[["sequence"]][]
    }

    return(output_list)
  }

  # multi entries: one group member per fasta entry
  entry_names <- names(h5_file[["states"]])

  if (verbose) {
    cat("file contains", length(entry_names), "entries \n")
  }

  output_list <- list()
  for (entry_name in entry_names) {

    entry_states <- h5_file[["states"]][[entry_name]]
    states <- as_state_matrix(entry_states[ , ], entry_states$dims[2])
    if (!complete) {
      states <- states[rows, , drop = FALSE]
    }
    entry <- list(states = states)

    if (get_sample_position) {
      positions <- h5_file[["sample_end_position"]][[entry_name]][ ]
      # Subset the positions already read for this entry; the previous code
      # referenced an object that does not exist in this branch.
      if (!complete) {
        positions <- positions[rows]
      }
      entry$sample_end_position <- positions
    }

    if (get_seq) {
      entry[["sequence"]] <- h5_file[["sequence"]][[entry_name]][ ]
    }

    output_list[[entry_name]] <- entry
  }

  output_list
}
911
912
#' Create summary of predictions
#'
#' Create summary data frame for confidence predictions over 1 or several state files or a data frame.
#' Columns in file or data frame should be confidence predictions for one class,
#' i.e. each row should sum to 1 and have nonnegative entries.
#' Output data frame contains average confidence scores, max score and percentage of votes for each class.
#'
#' @param states_path Folder containing state files or a single file with same ending as `file_type`.
#' @param label_names Names of predicted classes.
#' @param file_type `"h5"` or `"csv"`.
#' @param df A states data frame. Ignore `states_path` argument if not `NULL`.
#' @examples
#' m <- c(0.9,  0.1, 0.2, 0.01,
#'        0.05, 0.7, 0.2, 0,
#'        0.05, 0.2, 0.6, 0.99) %>% matrix(ncol = 3)
#'
#' label_names <- paste0("class_", 1:3)
#' df <- as.data.frame(m)
#' pred_summary <- summarize_states(label_names = label_names, df = df)
#' pred_summary
#'
#' @returns A data frame of predictions summaries.
#' @export
summarize_states <- function(states_path = NULL, label_names = NULL, file_type = "h5", df = NULL) {

  use_df <- !is.null(df)

  if (use_df) {
    # A supplied data frame takes precedence over any path.
    states_path <- NULL
    state_files <- 1
  } else {
    if (is.null(states_path)) {
      stop("Either `states_path` or `df` must be specified.", call. = FALSE)
    }
    if (!(file_type %in% c("h5", "csv"))) {
      stop("`file_type` must be \"h5\" or \"csv\".", call. = FALSE)
    }
    if (endsWith(states_path, file_type)) {
      state_files <- states_path
    } else {
      state_files <- list.files(states_path, full.names = TRUE)
    }
  }

  num_labels <- if (is.null(label_names)) NULL else length(label_names)
  summary_list <- list()

  for (state_file in state_files) {

    # Load into a per-iteration variable. Reusing `df` here made every file
    # after the first silently report the first file's predictions.
    if (use_df) {
      states_df <- as.data.frame(df)
    } else if (file_type == "h5") {
      prediction <- load_prediction(h5_path = state_file, get_sample_position = FALSE, verbose = FALSE)
      states_df <- as.data.frame(prediction$states)
    } else {
      states_df <- utils::read.csv(state_file)
      # Retry as semicolon-separated file if the column count does not fit;
      # only possible when the expected number of labels is known.
      if (!is.null(num_labels) && ncol(states_df) != num_labels) {
        states_df <- utils::read.csv2(state_file)
      }
    }

    if (is.null(label_names)) {
      label_names <- paste0("X_", seq_len(ncol(states_df)))
      num_labels <- length(label_names)
    }

    if (ncol(states_df) != num_labels) {
      stop("Number of columns (", ncol(states_df), ") does not match number of labels (",
           num_labels, ").", call. = FALSE)
    }

    names(states_df) <- label_names

    mean_conf <- vapply(states_df, mean, numeric(1), USE.NAMES = FALSE)
    max_conf <- vapply(states_df, max, numeric(1), USE.NAMES = FALSE)

    mean_df <- data.frame(matrix(mean_conf, nrow = 1, ncol = num_labels))
    names(mean_df) <- paste0("mean_conf_", label_names)
    max_df <- data.frame(matrix(max_conf, nrow = 1, ncol = num_labels))
    names(max_df) <- paste0("max_conf_", label_names)

    # Each row votes for its highest-scoring class.
    vote_distribution <- apply(as.matrix(states_df), 1, which.max)
    vote_perc <- table(factor(vote_distribution, levels = seq_len(num_labels))) / length(vote_distribution)
    votes_df <- data.frame(matrix(vote_perc, nrow = 1, ncol = num_labels))
    names(votes_df) <- paste0("vote_perc_", label_names)

    file_name <- if (use_df) NA else basename(state_file)

    summary_list[[state_file]] <- data.frame(
      file_name = file_name, mean_df, max_df, votes_df,
      mean_prediction = label_names[which.max(mean_conf)],
      max_prediction = label_names[which.max(max_conf)],
      vote_prediction = label_names[which.max(vote_perc)],
      num_prediction = nrow(states_df)
    )
  }

  data.table::rbindlist(summary_list)
}