|
a |
|
b/R/predict.R |
|
|
1 |
#' Make prediction for nucleotide sequence or entries in fasta/fastq file
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta/fastq file or nucleotide sequence.
#' Writes states to h5 or csv file (access content of h5 output with \code{\link{load_prediction}} function).
#' There are several options on how to process an input file:
#' \itemize{
#' \item If `"one_seq"`, computes prediction for sequence argument or fasta/fastq file.
#' Combines fasta entries in file to one sequence. This means predictor sequences can contain elements from more than one fasta entry.
#' \item If `"by_entry"`, will output a separate file for each fasta/fastq entry.
#' Names of output files are: `output_dir` + "/" + `filename` + "_nr_" + i + "." + `output_type`, where i is the number of the fasta entry.
#' \item If `"by_entry_one_file"`, will store prediction for all fasta entries in one h5 file.
#' \item If `"one_pred_per_entry"`, will make one prediction for each entry by either picking random sample for long sequences
#' or pad sequence for short sequences.
#' }
#'
#' @inheritParams get_generator
#' @inheritParams train_model
#' @param layer_name Name of layer to get output from. If `NULL`, will use the last layer.
#' @param path_input Path to fasta file.
#' @param sequence Character string, ignores path_input if argument given.
#' @param round_digits Number of decimal places.
#' @param mode Either `"lm"` for language model or `"label"` for label classification.
#' @param include_seq Whether to include input sequence in h5 file.
#' @param output_format Either `"one_seq"`, `"by_entry"`, `"by_entry_one_file"`, `"one_pred_per_entry"`.
#' @param output_type `"h5"` or `"csv"`. If `output_format` is `"by_entry_one_file"` or `"one_pred_per_entry"`, can only be `"h5"`.
#' @param return_states Return predictions as data frame. Only supported for output_format `"one_seq"`.
#' @param padding Either `"none"`, `"maxlen"`, `"standard"` or `"self"`.
#' \itemize{
#' \item If `"none"`, apply no padding and skip sequences that are too short.
#' \item If `"maxlen"`, pad with maxlen number of zeros vectors.
#' \item If `"standard"`, pad with zero vectors only if sequence is shorter than maxlen. Pads to minimum size required for one prediction.
#' \item If `"self"`, concatenate sequence with itself until sequence is long enough for one prediction.
#' Example: if sequence is "ACGT" and maxlen is 10, make prediction for "ACGTACGTAC".
#' Only applied if sequence is shorter than maxlen.
#' }
#' @param verbose Boolean.
#' @param filename Filename to store states in. No file output if argument is `NULL`.
#' If `output_format = "by_entry"`, adds "_nr_" + "i" after name, where i is entry number.
#' @param output_dir Directory for file output.
#' @param use_quality Whether to use quality scores.
#' @param quality_string String for encoding with quality scores (as used in fastq format).
#' @param lm_format Either `"target_right"`, `"target_middle_lstm"`, `"target_middle_cnn"` or `"wavenet"`.
#' @param ... Further arguments for sequence encoding with \code{\link{seq_encoding_label}}.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # make prediction for single sequence and write to h5 file
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' vocabulary <- c("a", "c", "g", "t")
#' sequence <- paste(sample(vocabulary, 200, replace = TRUE), collapse = "")
#' output_file <- tempfile(fileext = ".h5")
#' predict_model(output_format = "one_seq", model = model, step = 10,
#'              sequence = sequence, filename = output_file, mode = "label")
#'
#' # make prediction for fasta file with multiple entries, write output to separate h5 files
#' fasta_path <- tempfile(fileext = ".fasta")
#' create_dummy_data(file_path = fasta_path, num_files = 1,
#'                  num_seq = 5, seq_length = 100,
#'                  write_to_file_path = TRUE)
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' output_dir <- tempfile()
#' dir.create(output_dir)
#' predict_model(output_format = "by_entry", model = model, step = 10, verbose = FALSE,
#'                output_dir = output_dir, mode = "label", path_input = fasta_path)
#' list.files(output_dir)
#'
#' @returns If `return_states = TRUE` returns a list of model predictions and position of corresponding sequences.
#' If additionally `include_seq = TRUE`, list contains sequence strings.
#' If `return_states = FALSE` returns nothing, just writes output to file(s).
#' @export
predict_model <- function(model, output_format = "one_seq", layer_name = NULL, sequence = NULL, path_input = NULL,
                          round_digits = NULL, filename = "states.h5", step = 1, vocabulary = c("a", "c", "g", "t"),
                          batch_size = 256, verbose = TRUE, return_states = FALSE,
                          output_type = "h5", padding = "none", use_quality = FALSE, quality_string = NULL,
                          mode = "label", lm_format = "target_right", output_dir = NULL,
                          format = "fasta", include_seq = FALSE, reverse_complement_encoding = FALSE,
                          ambiguous_nuc = "zero", ...) {

  stopifnot(padding %in% c("standard", "self", "none", "maxlen"))
  stopifnot(output_format %in% c("one_seq", "by_entry", "by_entry_one_file", "one_pred_per_entry"))

  # The multi-entry single-file formats are only implemented for h5 output,
  # so a csv request is downgraded with a notice instead of failing.
  if (output_format %in% c("by_entry_one_file", "one_pred_per_entry") && output_type == "csv") {
    message("by_entry_one_file or one_pred_per_entry only implemented for h5 output.
            Setting output_type to h5")
    output_type <- "h5"
  }

  if (output_format == "one_seq") {
    output_list <- predict_model_one_seq(layer_name = layer_name, sequence = sequence, path_input = path_input,
                                         round_digits = round_digits, filename = filename, step = step, vocabulary = vocabulary,
                                         batch_size = batch_size, verbose = verbose, return_states = return_states,
                                         padding = padding, quality_string = quality_string, use_quality = use_quality,
                                         output_type = output_type, model = model, mode = mode, lm_format = lm_format,
                                         format = format, include_seq = include_seq, ambiguous_nuc = ambiguous_nuc,
                                         reverse_complement_encoding = reverse_complement_encoding, ...)
    return(output_list)
  }

  # NOTE: quality_string and ambiguous_nuc are deliberately not forwarded to
  # the per-entry helpers below: those helpers set both themselves when they
  # call predict_model_one_seq, so forwarding them through `...` would supply
  # the same argument twice. `vocabulary` is a formal argument of every helper
  # and must be forwarded, otherwise a custom vocabulary is silently replaced
  # by the helpers' a/c/g/t default.
  if (output_format == "by_entry") {
    predict_model_by_entry(layer_name = layer_name, path_input = path_input,
                           round_digits = round_digits, filename = filename, step = step,
                           vocabulary = vocabulary,
                           batch_size = batch_size, verbose = verbose, use_quality = use_quality,
                           output_type = output_type, model = model, mode = mode, lm_format = lm_format,
                           output_dir = output_dir, format = format, include_seq = include_seq,
                           padding = padding, reverse_complement_encoding = reverse_complement_encoding, ...)
  }

  if (output_format == "by_entry_one_file") {
    predict_model_by_entry_one_file(layer_name = layer_name, path_input = path_input,
                                    round_digits = round_digits, filename = filename, step = step,
                                    vocabulary = vocabulary,
                                    batch_size = batch_size, verbose = verbose, use_quality = use_quality,
                                    model = model, mode = mode, lm_format = lm_format, format = format,
                                    padding = padding, include_seq = include_seq,
                                    reverse_complement_encoding = reverse_complement_encoding, ...)
  }

  if (output_format == "one_pred_per_entry") {
    if (mode == "lm") {
      stop("one_pred_per_entry only implemented for label classification")
    }
    predict_model_one_pred_per_entry(layer_name = layer_name, path_input = path_input,
                                     round_digits = round_digits, filename = filename,
                                     batch_size = batch_size, verbose = verbose, model = model, format = format,
                                     vocabulary = vocabulary,
                                     reverse_complement_encoding = reverse_complement_encoding, ...)
  }

}
|
|
130 |
|
|
|
131 |
|
|
|
132 |
#' Write output of specific model layer to h5 or csv file.
#'
#' Removes layers (optional) from pretrained model and calculates states of fasta file, writes states to h5/csv file.
#' Function combines fasta entries in file to one sequence. This means predictor sequences can contain elements from more than one fasta entry.
#' h5 file also contains sequence and positions of targets corresponding to states.
#'
#' @inheritParams generator_fasta_lm
#' @param layer_name Name of layer to get output from. If `NULL`, will use the last layer.
#' @param path_input Path to fasta file.
#' @param sequence Character string, ignores path_input if argument given.
#' @param round_digits Number of decimal places.
#' @param batch_size Number of samples to evaluate at once. Does not change output, only relevant for speed and memory.
#' @param step Frequency of sampling steps.
#' @param filename Filename to store states in. No file output if argument is `NULL`.
#' @param vocabulary Vector of allowed characters, character outside vocabulary get encoded as 0-vector.
#' @param return_states Logical scalar, return states matrix.
#' @param target_len Number of positions added to maxlen to get the sample length when `mode = "lm"`.
#' @param quality_string Quality string for `sequence`. Supplying it switches `use_quality` on.
#' @param ambiguous_nuc `"zero"` or `"equal"`.
#' @param verbose Whether to print model before and after removing layers.
#' @param output_type Either `"h5"` or `"csv"`.
#' @param output_dir Not used by this function.
#' @param model A keras model.
#' @param mode Either `"lm"` for language model or `"label"` for label classification.
#' @param format Either `"fasta"` or `"fastq"`.
#' @param include_seq Whether to include input sequence in h5 file.
#' @param ... Further arguments for sequence encoding with \code{\link{seq_encoding_label}}.
#' @noRd
predict_model_one_seq <- function(model, layer_name = NULL, sequence = NULL, path_input = NULL, round_digits = 2,
                                  filename = "states.h5", step = 1, vocabulary = c("a", "c", "g", "t"), batch_size = 256, verbose = TRUE,
                                  return_states = FALSE, target_len = 1, use_quality = FALSE, quality_string = NULL,
                                  output_type = "h5", mode = "lm", lm_format = "target_right",
                                  ambiguous_nuc = "zero", padding = "none", format = "fasta", output_dir = NULL,
                                  include_seq = TRUE, reverse_complement_encoding = FALSE, ...) {

  vocabulary <- stringr::str_to_lower(vocabulary)
  stopifnot(mode %in% c("lm", "label"))
  stopifnot(output_type %in% c("h5", "csv"))
  if (!is.null(quality_string)) use_quality <- TRUE

  file_output <- !is.null(filename)
  if (!file_output && !return_states) {
    stop("If filename is NULL, return_states must be TRUE; otherwise function produces no output.")
  }
  stopifnot(batch_size > 0)
  # never overwrite an existing output file
  if (file_output) stopifnot(!file.exists(filename))

  if (reverse_complement_encoding) {
    # reject the vocabulary as soon as ANY position differs from a/c/g/t
    if (length(vocabulary) != 4 || any(sort(vocabulary) != c("a", "c", "g", "t"))) {
      stop("reverse_complement_encoding only implemented for A,C,G,T vocabulary")
    }
  }

  # token for ambiguous nucleotides: first letter that is not part of the vocabulary
  amb_nuc_token <- setdiff(letters, vocabulary)[1]

  tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "0"), c(vocabulary, amb_nuc_token))

  if (is.null(layer_name)) {
    layer_name <- model$output_names
    if (verbose) message(paste("layer_name not specified. Using layer", layer_name))
  }

  # sequence argument takes precedence over the input file
  fasta.file <- NULL
  if (!is.null(sequence) && sequence != "") {
    nt_seq <- stringr::str_to_lower(sequence)
  } else {
    stopifnot(format %in% c("fasta", "fastq"))
    if (format == "fasta") {
      fasta.file <- microseq::readFasta(path_input)
    } else {
      fasta.file <- microseq::readFastq(path_input)
    }
    if (nrow(fasta.file) > 1 && verbose) {
      text_1 <- paste("Your file has", nrow(fasta.file), "entries. 'one_seq' output_format will concatenate them to a single sequence.\n")
      text_2 <- "Use 'by_entry' or 'by_entry_one_file' output_format to evaluate them separately."
      message(paste0(text_1, text_2))
    }
    nt_seq <- stringr::str_to_lower(paste(fasta.file$Sequence, collapse = ""))
  }

  # tokenize ambiguous nt: everything outside the vocabulary becomes amb_nuc_token
  pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")
  nt_seq <- stringr::str_replace_all(string = nt_seq, pattern = pattern, amb_nuc_token)

  if (use_quality && is.null(quality_string)) {
    if (is.null(fasta.file)) {
      stop("use_quality is TRUE but no quality_string was supplied for the sequence argument.")
    }
    quality_string <- paste(fasta.file$Quality, collapse = "")
  }

  # extract maxlen
  target_middle <- mode == "lm" && (lm_format %in% c("target_middle_lstm", "target_middle_cnn"))
  if (!target_middle) {
    if (reverse_complement_encoding) {
      maxlen <- model$input[[1]]$shape[[2]]
    } else {
      maxlen <- model$input$shape[[2]]
    }
  } else {
    # target-middle models have two inputs; total context is the sum of both
    maxlen_1 <- model$input[[1]]$shape[[2]]
    maxlen_2 <- model$input[[2]]$shape[[2]]
    maxlen <- maxlen_1 + maxlen_2
  }

  total_seq_len <- if (mode == "lm") maxlen + target_len else maxlen

  # pad sequence
  unpadded_seq_len <- nchar(nt_seq)
  pad_len <- 0
  if (padding == "maxlen") {
    pad_len <- maxlen
  }
  if (padding == "standard" && (unpadded_seq_len < total_seq_len)) {
    pad_len <- total_seq_len - unpadded_seq_len
  }
  if (padding == "self" && (unpadded_seq_len < total_seq_len)) {
    # repeat sequence (and quality string) until long enough, then cut to size
    num_repeats <- ceiling(total_seq_len / unpadded_seq_len)
    nt_seq <- substr(strrep(nt_seq, num_repeats), 1, total_seq_len)
    if (use_quality) {
      quality_string <- substr(strrep(quality_string, num_repeats), 1, total_seq_len)
    }
  } else {
    # "0" is the tokenizer's out-of-vocabulary token; pad_len may be 0 here
    nt_seq <- paste0(strrep("0", pad_len), nt_seq)
    if (use_quality) quality_string <- paste0(strrep("0", pad_len), quality_string)
  }

  if (nchar(nt_seq) < total_seq_len) {
    stop(paste0("Input sequence is shorter than required length (", total_seq_len, "). Use padding argument to pad sequence to bigger size."))
  }

  if (use_quality) {
    quality_vector <- quality_to_probability(quality_string)
  } else {
    quality_vector <- NULL
  }

  # start of samples
  start_indices <- seq(1, nchar(nt_seq) - total_seq_len + 1, by = step)
  num_samples <- length(start_indices)

  check_layer_name(model, layer_name)
  model <- tensorflow::tf$keras$Model(model$input, model$get_layer(layer_name)$output)
  if (verbose) {
    cat("Computing output for model at layer", layer_name, "\n")
    print(model)
  }

  # 3-dimensional layer output is only accepted for a layer named "lstm"
  if (length(model$output$shape$dims) == 3) {
    if (!("lstm" %in% stringr::str_to_lower(model$output_names))) {
      stop("Output dimension of layer is > 1, format not supported yet")
    }
  }

  # tokenize sequence (shift tokens by one to start at 0)
  tokSeq <- keras::texts_to_sequences(tokenizer, nt_seq)[[1]] - 1

  # seq end position, relative to the unpadded sequence
  pos_arg <- start_indices + total_seq_len - pad_len - 1

  if (include_seq) {
    output_seq <- substr(nt_seq, pad_len + 1, nchar(nt_seq))
  }

  # create h5 file to store states; only when file output was requested, and
  # make sure the handle is released even if prediction fails
  write_h5 <- file_output && output_type == "h5"
  if (write_h5) {
    h5_file <- hdf5r::H5File$new(filename, mode = "w")
    on.exit(h5_file$close_all(), add = TRUE)
    h5_file[["multi_entries"]] <- FALSE
    h5_file[["sample_end_position"]] <- pos_arg
    if (include_seq) h5_file[["sequence"]] <- output_seq
  }

  number_batches <- ceiling(num_samples / batch_size)
  pred_list <- vector("list", number_batches)

  # subset input for target middle: positions left and right of the target
  if (target_middle) {
    index_x_1 <- 1:ceiling((total_seq_len - target_len) / 2)
    index_x_2 <- (max(index_x_1) + target_len + 1):total_seq_len
  }

  for (i in seq_len(number_batches)) {

    index_start <- ((i - 1) * batch_size) + 1
    index_end <- min(num_samples, index_start + batch_size - 1)
    index <- index_start:index_end

    x <- seq_encoding_label(sequence = tokSeq,
                            maxlen = total_seq_len,
                            vocabulary = vocabulary,
                            start_ind = start_indices[index],
                            ambiguous_nuc = ambiguous_nuc,
                            tokenizer = NULL,
                            adjust_start_ind = FALSE,
                            quality_vector = quality_vector,
                            ...
    )

    # drop = FALSE keeps the (sample, position, channel) layout even for a
    # single sample or a single position
    if (mode == "lm" && lm_format == "target_middle_lstm") {
      x1 <- x[, index_x_1, , drop = FALSE]
      x2 <- x[, rev(index_x_2), , drop = FALSE] # reverse order
      x <- list(x1, x2)
    }

    if (mode == "lm" && lm_format == "target_middle_cnn") {
      x <- x[, c(index_x_1, index_x_2), , drop = FALSE]
    }

    if (reverse_complement_encoding) x <- list(x, reverse_complement_tensor(x))

    y <- stats::predict(model, x, verbose = 0)
    if (!is.null(round_digits)) y <- round(y, round_digits)
    pred_list[[i]] <- y

  }

  states <- do.call(rbind, pred_list)

  if (file_output) {
    if (output_type == "h5") {
      h5_file[["states"]] <- states
    } else {
      colnames(states) <- paste0("N", seq_len(ncol(states)))
      utils::write.csv(x = states, file = filename, row.names = FALSE)
    }
  }

  if (return_states) {
    output_list <- list()
    output_list$states <- states
    output_list$sample_end_position <- pos_arg
    if (include_seq) output_list$sequence <- output_seq
    return(output_list)
  } else {
    return(NULL)
  }
}
|
|
394 |
|
|
|
395 |
#' Write states to h5 file
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta file, writes a separate
#' h5 file for every fasta entry in fasta file. h5 files also contain the nucleotide sequence and positions of targets corresponding to states.
#' Names of output files are: output_dir + "/" + filename + "_nr_" + i + "." + output_type, where i is the number of the fasta entry.
#'
#' @param filename Filename to store states, function adds "_nr_" + "i" after name, where i is entry number.
#' @param output_dir Path to folder, where to write output.
#' @noRd
predict_model_by_entry <- function(model, layer_name = NULL, path_input, round_digits = 2,
                                   filename = "states.h5", output_dir = NULL, step = 1, vocabulary = c("a", "c", "g", "t"),
                                   batch_size = 256, output_type = "h5", mode = "lm",
                                   lm_format = "target_right", format = "fasta", use_quality = FALSE,
                                   reverse_complement_encoding = FALSE, padding = "none",
                                   verbose = FALSE, include_seq = FALSE, ambiguous_nuc = "zero", ...) {

  stopifnot(mode %in% c("lm", "label"))
  stopifnot(!is.null(filename))
  stopifnot(!is.null(output_dir))
  stopifnot(format %in% c("fasta", "fastq"))

  # strip file extension (and directory part); the entry number and the
  # extension are appended per entry below
  if (endsWith(filename, paste0(".", output_type))) {
    filename <- stringr::str_remove(filename, paste0(".", output_type, "$"))
    filename <- basename(filename)
  }

  # extract maxlen
  target_middle <- mode == "lm" && (lm_format %in% c("target_middle_lstm", "target_middle_cnn"))
  if (!target_middle) {
    if (reverse_complement_encoding) {
      maxlen <- model$input[[1]]$shape[[2]]
    } else {
      maxlen <- model$input$shape[[2]]
    }
  } else {
    maxlen_1 <- model$input[[1]]$shape[[2]]
    maxlen_2 <- model$input[[2]]$shape[[2]]
    maxlen <- maxlen_1 + maxlen_2
  }

  # load fasta file
  if (format == "fasta") {
    fasta.file <- microseq::readFasta(path_input)
  } else {
    fasta.file <- microseq::readFastq(path_input)
  }
  df <- fasta.file[, c("Sequence", "Header")]
  names(df) <- c("seq", "header")
  rownames(df) <- NULL
  num_skipped_seq <- 0

  for (i in seq_len(nrow(df))) {

    # skip entry if too short
    if ((nchar(df[i, "seq"]) < maxlen) && padding == "none") {
      num_skipped_seq <- num_skipped_seq + 1
      next
    }

    if (use_quality) {
      quality_string <- fasta.file$Quality[i]
    } else {
      quality_string <- NULL
    }

    current_file <- paste0(output_dir, "/", filename, "_nr_", as.character(i), ".", output_type)

    # model summary is only printed for the first entry
    predict_model_one_seq(layer_name = layer_name, sequence = df[i, "seq"],
                          round_digits = round_digits, path_input = path_input,
                          filename = current_file, quality_string = quality_string,
                          step = step, vocabulary = vocabulary, batch_size = batch_size,
                          verbose = verbose && i == 1,
                          output_type = output_type, mode = mode,
                          lm_format = lm_format, model = model, include_seq = include_seq,
                          padding = padding, ambiguous_nuc = ambiguous_nuc,
                          reverse_complement_encoding = reverse_complement_encoding, ...)
  }

  if (verbose && num_skipped_seq > 0) {
    message(paste0("Skipped ", num_skipped_seq,
                   ifelse(num_skipped_seq == 1, " entry", " entries"),
                   ". Use different padding option to evaluate all."))
  }

}
|
|
480 |
|
|
|
481 |
#' Write states to h5 file
#'
#' @description Removes layers (optional) from pretrained model and calculates states of fasta file,
#' writes separate states matrix in one .h5 file for every fasta entry.
#' h5 file also contains the nucleotide sequences and positions of targets corresponding to states.
#' Entries are stored under the names "entry_" + i, where i is the number of the fasta entry.
#'
#' @param ambiguous_nuc Passed on to the per-entry prediction; defaults to `"zero"`.
#' @noRd
predict_model_by_entry_one_file <- function(model, path_input, round_digits = 2, filename = "states.h5",
                                            step = 1, vocabulary = c("a", "c", "g", "t"), batch_size = 256, layer_name = NULL,
                                            verbose = TRUE, mode = "lm", use_quality = FALSE,
                                            lm_format = "target_right", padding = "none",
                                            format = "fasta", include_seq = TRUE, reverse_complement_encoding = FALSE,
                                            ambiguous_nuc = "zero", ...) {

  vocabulary <- stringr::str_to_lower(vocabulary)
  stopifnot(mode %in% c("lm", "label"))
  stopifnot(format %in% c("fasta", "fastq"))

  target_middle <- mode == "lm" && (lm_format %in% c("target_middle_lstm", "target_middle_cnn"))
  # extract maxlen
  if (!target_middle) {
    if (reverse_complement_encoding) {
      maxlen <- model$input[[1]]$shape[[2]]
    } else {
      maxlen <- model$input$shape[[2]]
    }
  } else {
    maxlen_1 <- model$input[[1]]$shape[[2]]
    maxlen_2 <- model$input[[2]]$shape[[2]]
    maxlen <- maxlen_1 + maxlen_2
  }

  # 3-dimensional model output is only accepted for an output named "lstm"
  if (length(model$output$shape$dims) == 3) {
    if (!("lstm" %in% stringr::str_to_lower(model$output_names))) {
      stop("Output dimension of layer is > 1, format not supported yet")
    }
  }

  # load fasta file
  if (format == "fasta") {
    fasta.file <- microseq::readFasta(path_input)
  } else {
    fasta.file <- microseq::readFastq(path_input)
  }
  df <- fasta.file[, c("Sequence", "Header")]
  names(df) <- c("seq", "header")
  rownames(df) <- NULL

  if (verbose) {
    # check if names are unique
    if (length(df$header) != length(unique(df$header))) {
      message("Header names are not unique, adding '_header_x' to names (x being the header number)")
      df$header <- paste0(df$header, paste0("_header_", seq_along(df$header)))
    }
  }

  # create h5 file to store states; on.exit releases the handle even when
  # prediction for one of the entries fails
  h5_file <- hdf5r::H5File$new(filename, mode = "w")
  on.exit(h5_file$close_all(), add = TRUE)
  h5_file[["multi_entries"]] <- TRUE
  states.grp <- h5_file$create_group("states")
  sample_end_position.grp <- h5_file$create_group("sample_end_position")
  if (include_seq) seq.grp <- h5_file$create_group("sequence")

  num_skipped_seq <- 0

  for (i in seq_len(nrow(df))) {

    seq_name <- paste0("entry_", i)

    # skip entry if too short
    if ((nchar(df[i, "seq"]) < maxlen) && padding == "none") {
      num_skipped_seq <- num_skipped_seq + 1
      next
    }

    if (use_quality) {
      quality_string <- fasta.file$Quality[i]
    } else {
      quality_string <- NULL
    }

    # predict_model_one_seq writes to a scratch file; only the returned list is
    # used here, so the scratch file is removed right after the call
    temp_file <- tempfile(fileext = ".h5")
    output_list <- tryCatch(
      predict_model_one_seq(layer_name = layer_name, sequence = df$seq[i], path_input = path_input,
                            round_digits = round_digits, filename = temp_file, step = step, vocabulary = vocabulary,
                            batch_size = batch_size, return_states = TRUE, quality_string = quality_string,
                            output_type = "h5", model = model, mode = mode, lm_format = lm_format,
                            ambiguous_nuc = ambiguous_nuc, verbose = verbose && i == 1,
                            padding = padding, format = format, include_seq = include_seq,
                            reverse_complement_encoding = reverse_complement_encoding, ...),
      finally = unlink(temp_file)
    )

    states.grp[[seq_name]] <- output_list$states
    sample_end_position.grp[[seq_name]] <- output_list$sample_end_position

    if (include_seq) seq.grp[[seq_name]] <- output_list$sequence
  }

  if (verbose && num_skipped_seq > 0) {
    message(paste0("Skipped ", num_skipped_seq,
                   ifelse(num_skipped_seq == 1, " entry", " entries"),
                   ". Use different padding option to evaluate all."))
  }

}
|
|
589 |
|
|
|
590 |
#' Get states for label classification model. |
|
|
591 |
#' |
|
|
592 |
#' Computes output at specified model layer. Forces every fasta entry to have length maxlen by either padding sequences shorter than maxlen or taking random subsequence for |
|
|
593 |
#' longer sequences. |
|
|
594 |
#' |
|
|
595 |
#' @inheritParams predict_model_one_seq |
|
|
596 |
#' @noRd |
|
|
597 |
predict_model_one_pred_per_entry <- function(model, layer_name = NULL, path_input, round_digits = 2, format = "fasta", |
|
|
598 |
ambiguous_nuc = "zero", filename = "states.h5", padding = padding, |
|
|
599 |
vocabulary = c("a", "c", "g", "t"), batch_size = 256, verbose = TRUE, |
|
|
600 |
return_states = FALSE, reverse_complement_encoding = FALSE, |
|
|
601 |
include_seq = FALSE, use_quality = FALSE, ...) { |
|
|
602 |
|
|
|
603 |
vocabulary <- stringr::str_to_lower(vocabulary) |
|
|
604 |
file_type <- "h5" |
|
|
605 |
stopifnot(batch_size > 0) |
|
|
606 |
stopifnot(!file.exists(filename)) |
|
|
607 |
# token for ambiguous nucleotides |
|
|
608 |
for (i in letters) { |
|
|
609 |
if (!(i %in% stringr::str_to_lower(vocabulary))) { |
|
|
610 |
amb_nuc_token <- i |
|
|
611 |
break |
|
|
612 |
} |
|
|
613 |
} |
|
|
614 |
tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "N"), vocabulary) |
|
|
615 |
|
|
|
616 |
if (is.null(layer_name)) { |
|
|
617 |
layer_name <- model$output_names |
|
|
618 |
if (verbose) message(paste("layer_name not specified. Using layer", layer_name)) |
|
|
619 |
} |
|
|
620 |
|
|
|
621 |
# extract maxlen |
|
|
622 |
if (reverse_complement_encoding) { |
|
|
623 |
maxlen <- model$input[[1]]$shape[[2]] |
|
|
624 |
} else { |
|
|
625 |
maxlen <- model$input$shape[[2]] |
|
|
626 |
} |
|
|
627 |
|
|
|
628 |
if (format == "fasta") { |
|
|
629 |
fasta.file <- microseq::readFasta(path_input) |
|
|
630 |
} |
|
|
631 |
if (format == "fastq") { |
|
|
632 |
fasta.file <- microseq::readFastq(path_input) |
|
|
633 |
} |
|
|
634 |
|
|
|
635 |
num_samples <- nrow(fasta.file) |
|
|
636 |
|
|
|
637 |
nucSeq <- as.character(fasta.file$Sequence) |
|
|
638 |
seq_length <- nchar(fasta.file$Sequence) |
|
|
639 |
if (use_quality) { |
|
|
640 |
quality_string <- vector("character", length(nucSeq)) |
|
|
641 |
} else { |
|
|
642 |
quality_string <- NULL |
|
|
643 |
} |
|
|
644 |
|
|
|
645 |
for (i in 1:length(nucSeq)) { |
|
|
646 |
# take random subsequence |
|
|
647 |
if (seq_length[i] > maxlen) { |
|
|
648 |
start <- sample(1 : (seq_length[i] - maxlen + 1) , size = 1) |
|
|
649 |
nucSeq[i] <- substr(nucSeq[i], start = start, stop = start + maxlen - 1) |
|
|
650 |
if (use_quality) { |
|
|
651 |
quality_string[i] <- substr(fasta.file$Quality[i], start = start, stop = start + maxlen - 1) |
|
|
652 |
} |
|
|
653 |
} |
|
|
654 |
# pad sequence |
|
|
655 |
if (seq_length[i] < maxlen) { |
|
|
656 |
nucSeq[i] <- paste0(paste(rep("N", maxlen - seq_length[i]), collapse = ""), nucSeq[i]) |
|
|
657 |
if (use_quality) { |
|
|
658 |
quality_string[i] <- paste0(paste(rep("0", maxlen - seq_length[i]), collapse = ""), fasta.file$Quality[i]) |
|
|
659 |
} |
|
|
660 |
} |
|
|
661 |
} |
|
|
662 |
|
|
|
663 |
model <- tensorflow::tf$keras$Model(model$input, model$get_layer(layer_name)$output) |
|
|
664 |
if (verbose) { |
|
|
665 |
cat("Computing output for model at layer", layer_name, "\n") |
|
|
666 |
print(model) |
|
|
667 |
} |
|
|
668 |
|
|
|
669 |
# extract number of neurons in last layer |
|
|
670 |
if (length(model$output$shape$dims) == 3) { |
|
|
671 |
if (!("lstm" %in% stringr::str_to_lower(model$output_names))) { |
|
|
672 |
stop("Output dimension of layer is > 1, format not supported yet") |
|
|
673 |
} |
|
|
674 |
layer.size <- model$output$shape[[3]] |
|
|
675 |
} else { |
|
|
676 |
layer.size <- model$output$shape[[2]] |
|
|
677 |
} |
|
|
678 |
|
|
|
679 |
if (file_type == "h5") { |
|
|
680 |
# create h5 file to store states |
|
|
681 |
h5_file <- hdf5r::H5File$new(filename, mode = "w") |
|
|
682 |
if (!missing(path_input)) h5_file[["fasta_file"]] <- path_input |
|
|
683 |
|
|
|
684 |
h5_file[["header_names"]] <- fasta.file$Header |
|
|
685 |
if (include_seq) h5_file[["sequences"]] <- nucSeq |
|
|
686 |
h5_file[["states"]] <- array(0, dim = c(0, layer.size)) |
|
|
687 |
h5_file[["multi_entries"]] <- FALSE |
|
|
688 |
writer <- h5_file[["states"]] |
|
|
689 |
} |
|
|
690 |
|
|
|
691 |
rm(fasta.file) |
|
|
692 |
#nucSeq <- paste(nucSeq, collapse = "") %>% stringr::str_to_lower() |
|
|
693 |
number_batches <- ceiling(num_samples/batch_size) |
|
|
694 |
if (verbose) cat("Evaluating", number_batches, ifelse(number_batches > 1, "batches", "batch"), "\n") |
|
|
695 |
row <- 1 |
|
|
696 |
string_start_index <- 1 |
|
|
697 |
ten_percent_steps <- seq(number_batches/10, number_batches, length.out = 10) |
|
|
698 |
percentage_index <- 1 |
|
|
699 |
|
|
|
700 |
if (number_batches > 1) { |
|
|
701 |
for (i in 1:(number_batches - 1)) { |
|
|
702 |
string_end_index <-string_start_index + batch_size - 1 |
|
|
703 |
char_seq <- nucSeq[string_start_index : string_end_index] %>% paste(collapse = "") |
|
|
704 |
|
|
|
705 |
if (use_quality) { |
|
|
706 |
quality_string_subset <- quality_string[string_start_index : string_end_index] %>% paste(collapse = "") |
|
|
707 |
} else { |
|
|
708 |
quality_string_subset <- NULL |
|
|
709 |
} |
|
|
710 |
|
|
|
711 |
if (i == 1) start_ind <- seq(1, nchar(char_seq), maxlen) |
|
|
712 |
one_hot_batch <- seq_encoding_label(sequence = NULL, maxlen = maxlen, vocabulary = vocabulary, |
|
|
713 |
start_ind = start_ind, ambiguous_nuc = ambiguous_nuc, |
|
|
714 |
char_sequence = char_seq, quality_vector = quality_string_subset, |
|
|
715 |
tokenizer = tokenizer, adjust_start_ind = TRUE, ...) |
|
|
716 |
if (reverse_complement_encoding) one_hot_batch <- list(one_hot_batch, reverse_complement_tensor(one_hot_batch)) |
|
|
717 |
activations <- keras::predict_on_batch(model, one_hot_batch) |
|
|
718 |
writer[row : (row + batch_size - 1), ] <- activations |
|
|
719 |
row <- row + batch_size |
|
|
720 |
string_start_index <- string_end_index + 1 |
|
|
721 |
|
|
|
722 |
if (verbose & (i > ten_percent_steps[percentage_index]) & percentage_index < 10) { |
|
|
723 |
cat("Progress: ", percentage_index * 10 ,"% \n") |
|
|
724 |
percentage_index <- percentage_index + 1 |
|
|
725 |
} |
|
|
726 |
|
|
|
727 |
} |
|
|
728 |
} |
|
|
729 |
|
|
|
730 |
# last batch might be shorter |
|
|
731 |
char_seq <- nucSeq[string_start_index : length(nucSeq)] %>% paste(collapse = "") |
|
|
732 |
if (use_quality) { |
|
|
733 |
quality_string_subset <- quality_string[string_start_index : length(nucSeq)] %>% paste(collapse = "") |
|
|
734 |
} else { |
|
|
735 |
quality_string_subset <- NULL |
|
|
736 |
} |
|
|
737 |
one_hot_batch <- seq_encoding_label(sequence = NULL, maxlen = maxlen, vocabulary = vocabulary, |
|
|
738 |
start_ind = seq(1, nchar(char_seq), maxlen), ambiguous_nuc = "zero", nuc_dist = NULL, |
|
|
739 |
quality_vector = quality_string_subset, use_coverage = FALSE, max_cov = NULL, |
|
|
740 |
cov_vector = NULL, n_gram = NULL, n_gram_stride = 1, char_sequence = char_seq, |
|
|
741 |
tokenizer = tokenizer, adjust_start_ind = TRUE, ...) |
|
|
742 |
if (reverse_complement_encoding) one_hot_batch <- list(one_hot_batch, reverse_complement_tensor(one_hot_batch)) |
|
|
743 |
activations <- keras::predict_on_batch(model, one_hot_batch) |
|
|
744 |
writer[row : num_samples, ] <- activations[1 : length(row:num_samples), ] |
|
|
745 |
|
|
|
746 |
if (verbose) cat("Progress: 100 % \n") |
|
|
747 |
|
|
|
748 |
if (return_states & (file_type == "h5")) states <- writer[ , ] |
|
|
749 |
if (file_type == "h5") h5_file$close_all() |
|
|
750 |
if (return_states) return(states) |
|
|
751 |
} |
|
|
752 |
|
|
|
753 |
|
|
|
754 |
#' Read states from h5 file
#'
#' Reads h5 file created by \code{\link{predict_model}} function.
#'
#' @param h5_path Path to h5 file.
#' @param rows Range of rows to read. If `NULL` read everything. For files with
#' several entries, the same rows are selected from every entry.
#' @param get_sample_position Return position of sample corresponding to state if `TRUE`.
#' @param get_seq Return nucleotide sequence if `TRUE`.
#' @param verbose Boolean.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # make prediction for single sequence and write to h5 file
#' model <- create_model_lstm_cnn(maxlen = 20, layer_lstm = 8, layer_dense = 2, verbose = FALSE)
#' vocabulary <- c("a", "c", "g", "t")
#' sequence <- paste(sample(vocabulary, 200, replace = TRUE), collapse = "")
#' output_file <- tempfile(fileext = ".h5")
#' predict_model(output_format = "one_seq", model = model, step = 10,
#'               sequence = sequence, filename = output_file, mode = "label")
#' load_prediction(h5_path = output_file)
#'
#' @returns A list with element `states` (a matrix) and, if requested and present
#' in the file, `sample_end_position` and `sequence`. If the file stores several
#' entries, a named list with one such list per entry.
#' @export
load_prediction <- function(h5_path, rows = NULL, verbose = FALSE,
                            get_sample_position = FALSE, get_seq = FALSE) {

  # TRUE when the complete states matrix is requested; must be defined for
  # both cases, otherwise any call with `rows` fails at the first `if (complete)`
  complete <- is.null(rows)

  h5_file <- hdf5r::H5File$new(h5_path, mode = "r")
  # release the file handle on every exit path, including errors while reading
  on.exit(h5_file$close_all(), add = TRUE)

  multi_entries <- ifelse(h5_file[["multi_entries"]][], TRUE, FALSE)

  if (get_sample_position &&
      !any(c("sample_end_position", "target_position") %in% names(h5_file))) {
    get_sample_position <- FALSE
    message("File does not contain target positions.")
  }

  if (!multi_entries) {

    read_states <- h5_file[["states"]]

    if (get_sample_position) {
      read_targetPos <- h5_file[["sample_end_position"]]
    }

    if (verbose) {
      # query the dataset dimensions instead of loading the full matrix
      state_dims <- read_states$dims
      cat("states matrix has", state_dims[1], "rows and", state_dims[2], "columns \n")
    }

    if (complete) {
      states <- read_states[ , ]
      if (get_sample_position) {
        targetPos <- read_targetPos[ ]
      }
    } else {
      states <- read_states[rows, ]
      if (get_sample_position) {
        targetPos <- read_targetPos[rows]
      }
    }

    # restore matrix shape if the read returned a dimensionless vector
    if (is.null(dim(states))) {
      states <- matrix(states, nrow = 1)
    }

    # NOTE(review): a writer earlier in this file stores its input under
    # "sequences" (plural); confirm whether that key should be read here too.
    contains_seq <- FALSE
    if (get_seq) {
      if ("sequence" %in% names(h5_file)) {
        contains_seq <- TRUE
        sequence <- h5_file[["sequence"]][]
      } else {
        message("File does not contain sequence.")
      }
    }

    output_list <- list(states = states)
    if (get_sample_position) {
      output_list$sample_end_position <- targetPos
    }
    if (contains_seq) {
      output_list$sequence <- sequence
    }

    return(output_list)
  }

  # multi entries: one sub-list per entry stored under "states"
  entry_names <- names(h5_file[["states"]])
  number_entries <- length(entry_names)

  if (verbose) {
    cat("file contains", number_entries, "entries \n")
  }

  if (get_seq && !("sequence" %in% names(h5_file))) {
    message("File does not contain sequence.")
    get_seq <- FALSE
  }

  target_name <- "sample_end_position"
  output_list <- list()

  for (i in seq_along(entry_names)) {

    entry_name <- entry_names[i]
    states <- h5_file[["states"]][[entry_name]][ , ]
    if (is.null(dim(states))) {
      states <- matrix(states, nrow = 1)
    }
    if (!complete) {
      # keep the matrix shape even when a single row or column remains
      states <- states[rows, , drop = FALSE]
    }
    entry <- list(states = states)

    if (get_sample_position) {
      targetPos <- h5_file[[target_name]][[entry_name]][ ]
      if (!complete) {
        # subset the positions already read for this entry
        targetPos <- targetPos[rows]
      }
      entry$sample_end_position <- targetPos
    }

    if (get_seq) {
      entry[["sequence"]] <- h5_file[["sequence"]][[entry_name]][ ]
    }

    output_list[[entry_name]] <- entry
  }

  names(output_list) <- entry_names
  return(output_list)
}
|
|
911 |
|
|
|
912 |
#' Create summary of predictions
#'
#' Create summary data frame for confidence predictions over 1 or several state files or a data frame.
#' Columns in file or data frame should be confidence predictions for one class,
#' i.e. each row should sum to 1 and have nonnegative entries.
#' Output data frame contains average confidence scores, max score and percentage of votes for each class.
#'
#' @param states_path Folder containing state files or a single file with same ending as `file_type`.
#' @param label_names Names of predicted classes. If `NULL`, names `X_1`, `X_2`, ... are
#' derived from the number of columns of the first input.
#' @param file_type `"h5"` or `"csv"`.
#' @param df A states data frame. Ignore `states_path` argument if not `NULL`.
#' @examples
#' m <- c(0.9,  0.1, 0.2, 0.01,
#'        0.05, 0.7, 0.2, 0,
#'        0.05, 0.2, 0.6, 0.99) %>% matrix(ncol = 3)
#'
#' label_names <- paste0("class_", 1:3)
#' df <- as.data.frame(m)
#' pred_summary <- summarize_states(label_names = label_names, df = df)
#' pred_summary
#'
#' @returns A data frame of predictions summaries, with one row per state file
#' (or a single row if `df` is given).
#' @export
summarize_states <- function(states_path = NULL, label_names = NULL, file_type = "h5", df = NULL) {

  if (!is.null(df)) {
    # a supplied data frame takes precedence over any path
    states_path <- NULL
    df <- as.data.frame(df)
  } else {
    if (is.null(states_path)) {
      stop("Either `states_path` or `df` must be provided.", call. = FALSE)
    }
    if (!(file_type %in% c("h5", "csv"))) {
      stop("`file_type` must be \"h5\" or \"csv\".", call. = FALSE)
    }
  }

  if (is.null(states_path)) {
    # single pseudo entry: only `df` gets summarized
    state_files <- 1
  } else if (endsWith(states_path, file_type)) {
    state_files <- states_path
  } else {
    state_files <- list.files(states_path, full.names = TRUE)
  }

  num_labels <- if (is.null(label_names)) NULL else length(label_names)

  summary_list <- list()

  for (state_file in state_files) {

    # Load into a per-iteration variable. Overwriting `df` here would make
    # every later file reuse the data of the first one.
    if (!is.null(df)) {
      states_df <- df
    } else if (file_type == "h5") {
      pred <- load_prediction(h5_path = state_file, get_sample_position = FALSE, verbose = FALSE)
      states_df <- as.data.frame(pred$states)
    } else {
      states_df <- utils::read.csv(state_file)
      # column count mismatch: retry with read.csv2 (semicolon-separated);
      # only possible when the expected number of labels is known
      if (!is.null(num_labels) && ncol(states_df) != num_labels) {
        states_df <- utils::read.csv2(state_file)
      }
    }

    # derive default label names from the first input
    if (is.null(label_names)) {
      label_names <- paste0("X_", seq_len(ncol(states_df)))
      num_labels <- length(label_names)
    }

    stopifnot(ncol(states_df) == num_labels)

    names(states_df) <- label_names

    # per-class average and maximum confidence
    mean_conf <- vapply(label_names, function(label) mean(states_df[[label]]), numeric(1))
    max_conf <- vapply(label_names, function(label) max(states_df[[label]]), numeric(1))

    mean_df <- data.frame(matrix(mean_conf, nrow = 1, ncol = num_labels))
    names(mean_df) <- paste0("mean_conf_", label_names)
    max_df <- data.frame(matrix(max_conf, nrow = 1, ncol = num_labels))
    names(max_df) <- paste0("max_conf_", label_names)

    # each row votes for its highest scoring class
    vote_distribution <- apply(states_df[label_names], 1, which.max)
    vote_perc <- table(factor(vote_distribution, levels = seq_len(num_labels))) / length(vote_distribution)
    votes_df <- data.frame(matrix(vote_perc, nrow = 1, ncol = num_labels))
    names(votes_df) <- paste0("vote_perc_", label_names)

    mean_prediction <- label_names[which.max(mean_conf)]
    max_prediction <- label_names[which.max(max_conf)]
    vote_prediction <- label_names[which.max(vote_perc)]

    if (is.null(states_path)) {
      file_name <- NA
    } else {
      file_name <- basename(state_file)
    }

    summary_list[[state_file]] <- data.frame(file_name = file_name, mean_df, max_df, votes_df,
                                             mean_prediction = mean_prediction,
                                             max_prediction = max_prediction,
                                             vote_prediction = vote_prediction,
                                             num_prediction = nrow(states_df))

  }

  summary_df <- data.table::rbindlist(summary_list)
  return(summary_df)
}