|
|
#' Encodes integer sequence for language model
#'
#' Helper function for \code{\link{generator_fasta_lm}}.
#' Encodes integer sequence to input/target list according to \code{output_format} argument.
#'
#' @inheritParams generator_fasta_lm
#' @param sequence Sequence of integers.
#' @param start_ind Start positions of samples in \code{sequence}.
#' @param ambiguous_nuc How to handle nucleotides outside vocabulary, either `"zero"`, `"empirical"` or `"equal"`.
#' See \code{\link{train_model}}. Note that the `"discard"` option is not available for this function.
#' @param nuc_dist Nucleotide distribution.
#' @param max_cov Maximum coverage value. Only applies if `use_coverage = TRUE`.
#' @param cov_vector Vector of coverage values associated with the input.
#' @param adjust_start_ind Whether to shift values in \code{start_ind} to start at 1: for example, (5,11,25) becomes (1,7,21).
#' @param quality_vector Vector of quality probabilities.
#' @param tokenizer A keras tokenizer.
#' @param char_sequence A character string.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # use integer sequence as input
#'
#' z <- seq_encoding_lm(sequence = c(1,0,5,1,3,4,3,1,4,1,2),
#'                      maxlen = 5,
#'                      vocabulary = c("a", "c", "g", "t"),
#'                      start_ind = c(1,3),
#'                      ambiguous_nuc = "equal",
#'                      target_len = 1,
#'                      output_format = "target_right")
#'
#' x <- z[[1]]
#' y <- z[[2]]
#'
#' x[1,,] # 1,0,5,1,3
#' y[1,] # 4
#'
#' x[2,,] # 5,1,3,4,3
#' y[2,] # 1
#'
#' # use character string as input
#' z <- seq_encoding_lm(sequence = NULL,
#'                      maxlen = 5,
#'                      vocabulary = c("a", "c", "g", "t"),
#'                      start_ind = c(1,3),
#'                      ambiguous_nuc = "zero",
#'                      target_len = 1,
#'                      output_format = "target_right",
#'                      char_sequence = "ACTaaTNTNaZ")
#'
#' x <- z[[1]]
#' y <- z[[2]]
#'
#' x[1,,] # actaa
#' y[1,] # t
#'
#' x[2,,] # taatn
#' y[2,] # t
#'
#' @returns A list of 2 tensors.
#' @export
seq_encoding_lm <- function(sequence = NULL, maxlen, vocabulary, start_ind, ambiguous_nuc = "zero",
                            nuc_dist = NULL, quality_vector = NULL, return_int = FALSE,
                            target_len = 1, use_coverage = FALSE, max_cov = NULL, cov_vector = NULL,
                            n_gram = NULL, n_gram_stride = 1, output_format = "target_right",
                            char_sequence = NULL, adjust_start_ind = FALSE,
                            tokenizer = NULL) {

  use_quality <- !is.null(quality_vector)
  discard_amb_nt <- FALSE
  ## TODO: add discard_amb_nt
  if (!is.null(char_sequence)) {

    vocabulary <- stringr::str_to_lower(vocabulary)
    pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")

    # token for ambiguous nucleotides: first letter not in vocabulary
    for (i in letters) {
      if (!(i %in% stringr::str_to_lower(vocabulary))) {
        amb_nuc_token <- i
        break
      }
    }

    if (is.null(tokenizer)) {
      tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "0"), c(vocabulary, amb_nuc_token))
    }

    sequence <- stringr::str_to_lower(char_sequence)
    sequence <- stringr::str_replace_all(string = sequence, pattern = pattern, amb_nuc_token)
    sequence <- keras::texts_to_sequences(tokenizer, sequence)[[1]] - 1
  }

  voc_len <- length(vocabulary)
  if (target_len == 1) {
    n_gram <- NULL
  }
  if (!is.null(n_gram)) {
    if (target_len < n_gram) stop("target_len needs to be at least as large as n_gram")
  }

  if (adjust_start_ind) start_ind <- start_ind - start_ind[1] + 1
  numberOfSamples <- length(start_ind)

  # every row in z one-hot encodes one character in sequence, oov is zero-vector
  num_classes <- voc_len + 2
  z <- keras::to_categorical(sequence, num_classes = num_classes)[ , -c(1, num_classes)]

  if (use_quality) {
    ones_pos <- apply(z, 1, which.max)
    is_zero_row <- apply(z == 0, 1, all)
    z <- purrr::map(1:length(quality_vector), ~create_quality_vector(pos = ones_pos[.x], prob = quality_vector[.x],
                                                                     voc_length = length(vocabulary))) %>% unlist() %>%
      matrix(ncol = length(vocabulary), byrow = TRUE)
    z[is_zero_row, ] <- 0
  }

  if (ambiguous_nuc == "equal") {
    amb_nuc_pos <- which(sequence == (voc_len + 1))
    z[amb_nuc_pos, ] <- matrix(rep(1/voc_len, ncol(z) * length(amb_nuc_pos)), ncol = ncol(z))
  }

  if (ambiguous_nuc == "empirical") {
    if (!is.null(n_gram)) stop("Can only use equal, zero or discard option for ambiguous_nuc when using n_gram encoding")
    amb_nuc_pos <- which(sequence == (voc_len + 1))
    z[amb_nuc_pos, ] <- matrix(rep(nuc_dist, length(amb_nuc_pos)), nrow = length(amb_nuc_pos), byrow = TRUE)
  }

  if (use_coverage) {
    z <- z * (cov_vector/max_cov)
  }

  if (target_len == 1) {

    if (output_format == "target_right") {
      x <- array(0, dim = c(numberOfSamples, maxlen, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen - 1), ]
      }
      y <- z[start_ind + maxlen, ]
    }

    if (output_format == "wavenet") {
      if (!is.null(n_gram)) stop("Wavenet format not implemented for n_gram.")
      x <- array(0, dim = c(numberOfSamples, maxlen, voc_len))
      y <- array(0, dim = c(numberOfSamples, maxlen, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen - 1), ]
        y[i, , ] <- z[(start + 1) : (start + maxlen), ]
      }
    }

    if (output_format == "target_middle_cnn") {
      x <- array(0, dim = c(numberOfSamples, maxlen + 1, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen), ]
      }
      missing_val <- ceiling(maxlen/2)
      y <- z[start_ind + missing_val, ]
      x <- x[ , -(missing_val + 1), ]
    }

    if (output_format == "target_middle_lstm") {
      len_input_1 <- ceiling(maxlen/2)
      len_input_2 <- floor(maxlen/2)
      input_tensor_1 <- array(0, dim = c(numberOfSamples, len_input_1, voc_len))
      input_tensor_2 <- array(0, dim = c(numberOfSamples, len_input_2, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        input_tensor_1[i, , ] <- z[start : (start + len_input_1 - 1), ]
        # second input reads the right half of the window in reverse order
        input_tensor_2[i, , ] <- z[(start + maxlen) : (start + len_input_1 + 1), ]
      }
      if (!is.null(n_gram)) {
        input_tensor_1 <- input_tensor_1[ , 1:(dim(input_tensor_1)[2] - n_gram + 1), ]
        input_tensor_2 <- input_tensor_2[ , 1:(dim(input_tensor_2)[2] - n_gram + 1), ]
      }
      x <- list(input_tensor_1, input_tensor_2)
      y <- z[start_ind + len_input_1, ]
    }

  }

  if (target_len > 1) {

    if (output_format == "target_right") {
      x <- array(0, dim = c(numberOfSamples, maxlen - target_len + 1, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen - target_len), ]
      }
      y <- list()
      for (i in 1:target_len) {
        y[[i]] <- z[start_ind + maxlen - target_len + i, ]
      }
    }

    if (output_format == "target_middle_cnn") {
      x <- array(0, dim = c(numberOfSamples, maxlen + 1, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        x[i, , ] <- z[start : (start + maxlen), ]
      }
      missing_val <- ceiling((maxlen - target_len)/2)
      y <- list()
      for (i in 1:target_len) {
        y[[i]] <- z[start_ind + missing_val + i - 1, ]
      }
      x <- x[ , -((missing_val + 1):(missing_val + target_len)), ]
    }

    if (output_format == "target_middle_lstm") {
      len_input_1 <- ceiling((maxlen - target_len + 1)/2)
      len_input_2 <- maxlen + 1 - len_input_1 - target_len
      input_tensor_1 <- array(0, dim = c(numberOfSamples, len_input_1, voc_len))
      input_tensor_2 <- array(0, dim = c(numberOfSamples, len_input_2, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        input_tensor_1[i, , ] <- z[start : (start + len_input_1 - 1), ]
        input_tensor_2[i, , ] <- z[(start + maxlen) : (start + maxlen - len_input_2 + 1), ]
      }

      x <- list(input_tensor_1, input_tensor_2)
      y <- list()
      for (i in 1:target_len) {
        y[[i]] <- z[start_ind + len_input_1 - 1 + i, ]
      }
    }

    if (output_format == "wavenet") {
      stop("Multi target not implemented for wavenet format.")
    }
  }

  if (is.matrix(x)) {
    x <- array(x, dim = c(1, dim(x)))
  }

  if (!is.null(n_gram)) {
    if (is.list(y)) y <- do.call(rbind, y)
    y_list <- list()
    for (i in 1:numberOfSamples) {
      index <- (i-1) + (1 + (0:(target_len-1))*numberOfSamples)
      input_matrix <- y[index, ]
      if (length(index) == 1) input_matrix <- matrix(input_matrix, nrow = 1)
      n_gram_matrix <- n_gram_of_matrix(input_matrix = input_matrix, n = n_gram)
      y_list[[i]] <- n_gram_matrix
    }
    y_tensor <- keras::k_stack(y_list, axis = 1L) %>% keras::k_eval()
    y <- vector("list", dim(y_tensor)[2])

    for (i in 1:dim(y_tensor)[2]) {
      y_subset <- y_tensor[ , i, ]
      if (numberOfSamples == 1) y_subset <- matrix(y_subset, nrow = 1)
      y[[i]] <- y_subset
    }

    if (is.list(y) && length(y) == 1) {
      y <- y[[1]]
    }

    if (n_gram_stride > 1 && is.list(y)) {
      stride_index <- 0:(length(y)-1) %% n_gram_stride == 0
      y <- y[stride_index]
    }
  }

  return(list(x, y))
}
|
|
#' Encodes integer sequence for label classification.
#'
#' Returns encoding for integer or character sequence.
#'
#' @inheritParams seq_encoding_lm
#' @inheritParams generator_fasta_lm
#' @inheritParams train_model
#' @param return_int Whether to return integer encoding or one-hot encoding.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # use integer sequence as input
#' x <- seq_encoding_label(sequence = c(1,0,5,1,3,4,3,1,4,1,2),
#'                         maxlen = 5,
#'                         vocabulary = c("a", "c", "g", "t"),
#'                         start_ind = c(1,3),
#'                         ambiguous_nuc = "equal")
#'
#' x[1,,] # 1,0,5,1,3
#'
#' x[2,,] # 5,1,3,4,3
#'
#' # use character string as input
#' x <- seq_encoding_label(maxlen = 5,
#'                         vocabulary = c("a", "c", "g", "t"),
#'                         start_ind = c(1,3),
#'                         ambiguous_nuc = "equal",
#'                         char_sequence = "ACTaaTNTNaZ")
#'
#' x[1,,] # actaa
#'
#' x[2,,] # taatn
#'
#' @returns A tensor of encoded samples; if \code{masked_lm} is used, a list of input, target and sample-weight tensors.
#' @export
seq_encoding_label <- function(sequence = NULL, maxlen, vocabulary, start_ind, ambiguous_nuc = "zero", nuc_dist = NULL,
                               quality_vector = NULL, use_coverage = FALSE, max_cov = NULL,
                               cov_vector = NULL, n_gram = NULL, n_gram_stride = 1, masked_lm = NULL,
                               char_sequence = NULL, tokenizer = NULL, adjust_start_ind = FALSE,
                               return_int = FALSE) {

  ## TODO: add discard_amb_nt, add conditions for return_int
  use_quality <- !is.null(quality_vector)
  discard_amb_nt <- FALSE
  maxlen_original <- maxlen
  if (return_int) ambiguous_nuc <- "zero"

  if (!is.null(char_sequence)) {

    vocabulary <- stringr::str_to_lower(vocabulary)
    pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")

    # token for ambiguous nucleotides: first letter not in vocabulary
    for (i in letters) {
      if (!(i %in% stringr::str_to_lower(vocabulary))) {
        amb_nuc_token <- i
        break
      }
    }

    if (is.null(tokenizer)) {
      tokenizer <- keras::fit_text_tokenizer(keras::text_tokenizer(char_level = TRUE, lower = TRUE, oov_token = "0"), c(vocabulary, amb_nuc_token))
    }

    sequence <- stringr::str_to_lower(char_sequence)
    sequence <- stringr::str_replace_all(string = sequence, pattern = pattern, amb_nuc_token)
    sequence <- keras::texts_to_sequences(tokenizer, sequence)[[1]] - 1
  }

  if (adjust_start_ind) start_ind <- start_ind - start_ind[1] + 1
  numberOfSamples <- length(start_ind)

  if (is.null(n_gram_stride)) n_gram_stride <- 1
  voc_len <- length(vocabulary)
  if (!is.null(n_gram)) {
    sequence <- int_to_n_gram(int_seq = sequence, n = n_gram, voc_size = length(vocabulary))
    maxlen <- ceiling((maxlen - n_gram + 1)/n_gram_stride)
    voc_len <- length(vocabulary)^n_gram
  }

  if (!is.null(masked_lm)) {
    l <- mask_seq(int_seq = sequence,
                  mask_rate = masked_lm$mask_rate,
                  random_rate = masked_lm$random_rate,
                  identity_rate = masked_lm$identity_rate,
                  start_ind = start_ind,
                  block_len = masked_lm$block_len,
                  voc_len = voc_len)
    masked_seq <- l$masked_seq
    sample_weight_seq <- l$sample_weight_seq
  }

  if (!return_int) {
    if (!is.null(masked_lm)) {
      # every row in z one-hot encodes one character in sequence, oov is zero-vector
      z_masked <- keras::to_categorical(masked_seq, num_classes = voc_len + 2)[ , -c(1)]
      z_masked <- matrix(z_masked, ncol = voc_len + 1)
      z <- keras::to_categorical(sequence, num_classes = voc_len + 2)[ , -c(1)]
      z <- matrix(z, ncol = voc_len + 1)
    } else {
      # every row in z one-hot encodes one character in sequence, oov is zero-vector
      z <- keras::to_categorical(sequence, num_classes = voc_len + 2)[ , -c(1, voc_len + 2)]
      z <- matrix(z, ncol = voc_len)
    }
  }

  if (use_quality) {
    ones_pos <- apply(z, 1, which.max)
    is_zero_row <- apply(z == 0, 1, all)
    z <- purrr::map(1:length(quality_vector), ~create_quality_vector(pos = ones_pos[.x], prob = quality_vector[.x],
                                                                     voc_length = voc_len)) %>% unlist() %>% matrix(ncol = voc_len, byrow = TRUE)
    z[is_zero_row, ] <- 0
  }

  if (ambiguous_nuc == "equal") {
    amb_nuc_pos <- which(sequence == (voc_len + 1))
    z[amb_nuc_pos, ] <- matrix(rep(1/voc_len, ncol(z) * length(amb_nuc_pos)), ncol = ncol(z))
  }

  if (ambiguous_nuc == "empirical") {
    amb_nuc_pos <- which(sequence == (voc_len + 1))
    z[amb_nuc_pos, ] <- matrix(rep(nuc_dist, length(amb_nuc_pos)), nrow = length(amb_nuc_pos), byrow = TRUE)
  }

  if (use_coverage) {
    z <- z * (cov_vector/max_cov)
  }

  remove_end_of_seq <- ifelse(is.null(n_gram), 1, n_gram)

  if (!return_int) {
    if (is.null(masked_lm)) {

      x <- array(0, dim = c(numberOfSamples, maxlen, voc_len))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        subset_index <- seq(start, (start + maxlen_original - remove_end_of_seq), by = n_gram_stride)
        x[i, , ] <- z[subset_index, ]
      }
      return(x)

    } else {

      x <- array(0, dim = c(numberOfSamples, maxlen, voc_len + 1))
      y <- array(0, dim = c(numberOfSamples, maxlen, voc_len + 1))
      sw <- array(0, dim = c(numberOfSamples, maxlen))

      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        subset_index <- seq(start, (start + maxlen - remove_end_of_seq), by = n_gram_stride)
        x[i, , ] <- z_masked[subset_index, ]
        y[i, , ] <- z[subset_index, ]
        sw[i, ] <- sample_weight_seq[subset_index]
      }
      return(list(x = x, y = y, sample_weight = sw))

    }
  }

  if (return_int) {
    if (is.null(masked_lm)) {

      x <- array(0, dim = c(numberOfSamples, maxlen))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        subset_index <- seq(start, (start + maxlen_original - remove_end_of_seq), by = n_gram_stride)
        x[i, ] <- sequence[subset_index]
      }
      return(x)

    } else {
      x <- array(0, dim = c(numberOfSamples, maxlen))
      y <- array(0, dim = c(numberOfSamples, maxlen))
      sw <- array(0, dim = c(numberOfSamples, maxlen))
      for (i in 1:numberOfSamples) {
        start <- start_ind[i]
        subset_index <- seq(start, (start + maxlen_original - remove_end_of_seq), by = n_gram_stride)
        x[i, ] <- masked_seq[subset_index]
        y[i, ] <- sequence[subset_index]
        sw[i, ] <- sample_weight_seq[subset_index]
      }
      return(list(x = x, y = y, sample_weight = sw))

    }
  }

}
|
|
#' Computes start position of samples
#'
#' Helper function for data generators.
#' Computes the start positions in a sequence from which samples can be extracted, given maxlen, step size and ambiguous nucleotide constraints.
#'
#' @inheritParams train_model
#' @param seq_vector Vector of character sequences.
#' @param length_vector Length of sequences in \code{seq_vector}.
#' @param maxlen Length of one predictor sequence.
#' @param step Distance between samples from one entry in \code{seq_vector}.
#' @param train_mode Either `"lm"` for language model or `"label"` for label classification.
#' @param discard_amb_nuc Whether to discard all samples that contain characters outside vocabulary.
#' @examples
#' seq_vector <- c("AAACCCNNNGGGTTT")
#' get_start_ind(
#'   seq_vector = seq_vector,
#'   length_vector = nchar(seq_vector),
#'   maxlen = 4,
#'   step = 2,
#'   train_mode = "label",
#'   discard_amb_nuc = TRUE,
#'   vocabulary = c("A", "C", "G", "T"))
#'
#' @returns A numeric vector.
#' @export
get_start_ind <- function(seq_vector, length_vector, maxlen,
                          step, train_mode = "label",
                          discard_amb_nuc = FALSE,
                          vocabulary = c("A", "C", "G", "T")) {

  stopifnot(train_mode == "lm" | train_mode == "label")
  if (!discard_amb_nuc) {
    if (length(length_vector) > 1) {
      startNewEntry <- cumsum(c(1, length_vector[-length(length_vector)]))
      if (train_mode == "label") {
        indexVector <- purrr::map(1:(length(length_vector) - 1), ~seq(startNewEntry[.x], startNewEntry[.x + 1] - maxlen, by = step))
      } else {
        indexVector <- purrr::map(1:(length(length_vector) - 1), ~seq(startNewEntry[.x], startNewEntry[.x + 1] - maxlen - 1, by = step))
      }
      indexVector <- unlist(indexVector)
      last_seq <- length(seq_vector)
      if (!(startNewEntry[last_seq] > (sum(length_vector) - maxlen + 1))) {
        if (train_mode == "label") {
          indexVector <- c(indexVector, seq(startNewEntry[last_seq], sum(length_vector) - maxlen + 1, by = step))
        } else {
          indexVector <- c(indexVector, seq(startNewEntry[last_seq], sum(length_vector) - maxlen, by = step))
        }
      }
      return(indexVector)
    } else {
      if (train_mode == "label") {
        indexVector <- seq(1, length_vector - maxlen + 1, by = step)
      } else {
        indexVector <- seq(1, length_vector - maxlen, by = step)
      }
    }
  } else {
    indexVector <- start_ind_ignore_amb(seq_vector = seq_vector, length_vector = length_vector,
                                        maxlen = maxlen, step = step, vocabulary = c(vocabulary, "0"), train_mode = train_mode)
  }
  return(indexVector)
}
|
|
#' Helper function for get_start_ind, extracts the start positions of all potential samples (considering step size and vocabulary)
#'
#' @param seq A character sequence.
#' @param maxlen Length of one sample.
#' @param step Step size between start positions of samples.
#' @param vocabulary Vector of allowed characters in samples.
#' @param train_mode "lm" or "label".
#' @noRd
start_ind_ignore_amb_single_seq <- function(seq, maxlen, step, vocabulary, train_mode = "lm") {

  vocabulary <- stringr::str_to_lower(vocabulary)
  vocabulary <- c(vocabulary, "0")
  seq <- stringr::str_to_lower(seq)
  len_seq <- nchar(seq)
  if (train_mode != "label") maxlen <- maxlen + 1
  stopifnot(len_seq >= maxlen)
  # regular expression for characters outside the vocabulary
  voc_pattern <- paste0("[^", paste0(vocabulary, collapse = ""), "]")

  pos_of_amb_nucleotides <- stringr::str_locate_all(seq, pattern = voc_pattern)[[1]][ , 1]
  non_start_index <- pos_of_amb_nucleotides - maxlen + 1

  # define range of disallowed start indices
  non_start_index <- purrr::map(non_start_index, ~(.x:(.x + maxlen - 1))) %>%
    unlist() %>% union((len_seq - maxlen + 2):len_seq) %>% unique()
  # drop non-positive values
  if (length(non_start_index[non_start_index < 1]) > 0) {
    non_start_index <- unique(c(1, non_start_index[non_start_index >= 1]))
  }

  non_start_index <- non_start_index %>% sort()
  allowed_start <- setdiff(1:len_seq, non_start_index)
  len_start_vector <- length(allowed_start)

  if (len_start_vector < 1) {
    # message("Can not extract a single sampling point with current settings.")
    return(NULL)
  }

  # only keep indices with sufficient distance, as defined by step
  start_indices <- vector("integer")
  index <- allowed_start[1]
  start_indices[1] <- index
  count <- 1
  if (length(allowed_start) > 1) {
    for (j in 1:(length(allowed_start) - 1)) {
      if (allowed_start[j + 1] - index >= step) {
        count <- count + 1
        start_indices[count] <- allowed_start[j + 1]
        index <- allowed_start[j + 1]
      }
    }
  }

  start_indices
}
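
# Worked illustration (hypothetical values, not executed): for seq = "AANAAAA",
# maxlen = 2, step = 1, train_mode = "label" and vocabulary = c("A", "C", "G", "T"),
# the "N" at position 3 blocks start positions 2 and 3, and position 7 is too
# close to the sequence end, so the function returns c(1, 4, 5, 6).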
|
|
#' Helper function for get_start_ind, extracts the start positions of all potential samples (considering step size and vocabulary)
#'
#' @param seq_vector Vector of character sequences.
#' @param maxlen Length of one sample.
#' @param step Step size between start positions of samples.
#' @param vocabulary Vector of allowed characters in samples.
#' @param train_mode "lm" or "label".
#' @noRd
start_ind_ignore_amb <- function(seq_vector, length_vector, maxlen, step, vocabulary, train_mode = "lm") {
  start_ind <- purrr::map(1:length(seq_vector), ~start_ind_ignore_amb_single_seq(seq = seq_vector[.x],
                                                                                 maxlen = maxlen,
                                                                                 step = step,
                                                                                 vocabulary = vocabulary,
                                                                                 train_mode = train_mode))

  # shift per-sequence indices to positions in the concatenated sequence
  cum_sum_length <- cumsum(length_vector)
  if (length(start_ind) > 1) {
    for (i in 2:length(start_ind)) {
      start_ind[[i]] <- start_ind[[i]] + cum_sum_length[i - 1]
    }
  }
  start_ind <- unlist(start_ind)
  start_ind
}
|
|
quality_to_probability <- function(quality_vector) {
  # Phred+33 encoding: probability that the base call is correct
  Q <- utf8ToInt(quality_vector) - 33
  1 - 10^(-Q/10)
}
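
# For example, the Phred+33 character "I" encodes Q = utf8ToInt("I") - 33 = 40,
# so quality_to_probability("I") returns 1 - 10^(-4) = 0.9999.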
|
|
create_quality_vector <- function(pos, prob, voc_length = 4) {
  # put probability mass `prob` on the called base, split the rest evenly
  vec <- rep(0, voc_length)
  vec[pos] <- prob
  vec[-pos] <- (1 - prob)/(voc_length - 1)
  vec
}
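
# For example, create_quality_vector(pos = 2, prob = 0.9) returns
# c(1/30, 0.9, 1/30, 1/30): 90% on the called base, the rest split evenly.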
|
|
remove_amb_nuc_entries <- function(fasta.file, skip_amb_nuc, pattern) {
  chars_per_row <- nchar(fasta.file$Sequence)
  amb_per_row <- stringr::str_count(stringr::str_to_lower(fasta.file$Sequence), pattern)
  # drop entries whose share of ambiguous characters exceeds skip_amb_nuc
  threshold_index <- (amb_per_row/chars_per_row) > skip_amb_nuc
  fasta.file <- fasta.file[!threshold_index, ]
  fasta.file
}
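
# Illustration (assuming pattern matches ambiguous characters, e.g. "[^acgt]"):
# with skip_amb_nuc = 0.1, an entry with sequence "ACGTN" (1 of 5 characters
# ambiguous, i.e. 20%) is removed, while "ACGTACGTAN" (exactly 10%) is kept.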
|
|
#' Estimate frequency of different classes
#'
#' Count number of nucleotides for each class and use as estimation for relation of class distribution.
#' Outputs list of class relations. Can be used as input for \code{class_weight} in \code{\link{train_model}} function.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_header_csv
#' @inheritParams train_model
#' @param file_proportion Proportion of files to randomly sample for estimating class distributions.
#' @param csv_path If `train_type = "label_csv"`, path to csv file containing labels.
#' @param named_list Whether to give class weight list names `"0", "1", ...` or not.
#' @examples
#'
#' # create dummy data
#' path_1 <- tempfile()
#' path_2 <- tempfile()
#'
#' for (current_path in c(path_1, path_2)) {
#'
#'   dir.create(current_path)
#'   # create twice as much data for first class
#'   num_files <- ifelse(current_path == path_1, 6, 3)
#'   create_dummy_data(file_path = current_path,
#'                     num_files = num_files,
#'                     seq_length = 10,
#'                     num_seq = 5,
#'                     vocabulary = c("a", "c", "g", "t"))
#' }
#'
#' class_weight <- get_class_weight(
#'   path = c(path_1, path_2),
#'   vocabulary_label = c("A", "B"),
#'   format = "fasta",
#'   file_proportion = 1,
#'   train_type = "label_folder",
#'   csv_path = NULL)
#'
#' class_weight
#'
#' @returns A list of numeric values (class weights).
#' @export
get_class_weight <- function(path,
                             vocabulary_label = NULL,
                             format = "fasta",
                             file_proportion = 1,
                             train_type = "label_folder",
                             named_list = FALSE,
                             csv_path = NULL) {

  classes <- count_nuc(path = path,
                       vocabulary_label = vocabulary_label,
                       format = format,
                       file_proportion = file_proportion,
                       train_type = train_type,
                       csv_path = csv_path)

  zero_entry <- classes == 0
  if (sum(zero_entry) > 0) {
    warning_message <- paste("The following classes have no samples:", paste(vocabulary_label[zero_entry]),
                             "\n Try bigger file_proportion size or check vocabulary_label.")
    warning(warning_message)
  }

  if (!is.list(classes)) {
    num_classes <- length(classes)
    total <- sum(classes)
    weight_list <- list()
    for (i in 1:(length(classes))) {
      weight_list[[as.character(i-1)]] <- total/(classes[i] * num_classes)
    }
    if (!named_list) names(weight_list) <- NULL # no list names in tf version > 2.8
    classes <- weight_list
  } else {
    weight_collection <- list()
    for (j in 1:length(classes)) {
      num_classes <- length(classes[[j]])
      total <- sum(classes[[j]])
      weight_list <- list()
      for (i in 1:(length(classes[[j]]))) {
        weight_list[[as.character(i-1)]] <- total/(classes[[j]][i] * num_classes)
      }
      if (!named_list) names(weight_list) <- NULL
      weight_collection[[j]] <- weight_list
    }
    classes <- weight_collection
  }

  classes
}
|
|
#' Count nucleotides per class
#'
#' @inheritParams get_class_weight
#' @noRd
count_nuc <- function(path,
                      vocabulary_label = NULL,
                      format = "fasta",
                      # estimate class distribution from subset
                      file_proportion = 1,
                      train_type = "label_folder",
                      csv_path = NULL) {

  classes <- rep(0, length(vocabulary_label))
  names(classes) <- vocabulary_label

  # label by folder
  if (train_type == "label_folder") {
    for (j in 1:length(path)) {
      files <- list.files(path[[j]], full.names = TRUE)
      if (file_proportion < 1) {
        files <- sample(files, floor(file_proportion * length(files)))
      }
      for (i in files) {
        if (format == "fasta") {
          fasta.file <- microseq::readFasta(i)
        }
        if (format == "fastq") {
          fasta.file <- microseq::readFastq(i)
        }
        freq <- sum(nchar(fasta.file$Sequence))
        classes[j] <- classes[j] + freq
      }
    }
  }

  # label by header
  if (train_type == "label_header") {
    files <- list.files(unlist(path), full.names = TRUE)
    if (file_proportion < 1) {
      files <- sample(files, floor(file_proportion * length(files)))
    }
    for (i in files) {
      if (format == "fasta") {
        fasta.file <- microseq::readFasta(i)
      }
      if (format == "fastq") {
        fasta.file <- microseq::readFastq(i)
      }
      df <- data.frame(Header = fasta.file$Header, freq = nchar(fasta.file$Sequence))
      df <- stats::aggregate(df$freq, by = list(Category = df$Header), FUN = sum)
      freq <- df$x
      names(freq) <- df$Category
      for (k in names(freq)) {
        classes[k] <- classes[k] + freq[k]
      }
    }
  }

  # label by csv
  if (train_type == "label_csv") {

    label_csv <- utils::read.csv2(csv_path, header = TRUE, stringsAsFactors = FALSE)
    if (dim(label_csv)[2] == 1) {
      label_csv <- utils::read.csv(csv_path, header = TRUE, stringsAsFactors = FALSE)
    }
    if (!("file" %in% names(label_csv))) {
      stop('csv file needs one column named "file"')
    }

    row_sums <- label_csv %>% dplyr::select(-file) %>% rowSums()
    if (!(all(row_sums == 1))) {
      stop("Can only estimate class weights if labels are mutually exclusive.")
    }

    if (is.null(vocabulary_label) || missing(vocabulary_label)) {
      vocabulary_label <- names(label_csv)[!names(label_csv) == "file"]
    } else {
      label_csv <- label_csv %>% dplyr::select(c(dplyr::all_of(vocabulary_label), "file"))
    }

    classes <- rep(0, length(vocabulary_label))
    names(classes) <- vocabulary_label

    path <- unlist(path)
    single_file_index <- stringr::str_detect(path, "fasta$|fastq$")
    files <- c(list.files(path[!single_file_index], full.names = TRUE), path[single_file_index])
    if (file_proportion < 1) {
      files <- sample(files, floor(file_proportion * length(files)))
    }
    for (i in files) {
      if (format == "fasta") {
        fasta.file <- microseq::readFasta(i)
      }
      if (format == "fastq") {
        fasta.file <- microseq::readFastq(i)
      }
      count_nuc <- sum(nchar(fasta.file$Sequence))
      df <- label_csv %>% dplyr::filter(file == basename(i))
      if (nrow(df) == 0) next
      index <- df[1, ] == 1
      current_label <- names(df)[index]
      classes[current_label] <- classes[current_label] + count_nuc
    }
  }
  return(classes)
}
|
|
read_fasta_fastq <- function(format, skip_amb_nuc, file_index, pattern, shuffle_input,
                             reverse_complement, fasta.files, use_coverage = FALSE, proportion_entries = NULL,
                             vocabulary_label = NULL, filter_header = FALSE, target_from_csv = NULL) {

  if (stringr::str_detect(format, "fasta")) {
    if (is.null(skip_amb_nuc)) {
      fasta.file <- microseq::readFasta(fasta.files[file_index])
    } else {
      fasta.file <- remove_amb_nuc_entries(microseq::readFasta(fasta.files[file_index]), skip_amb_nuc = skip_amb_nuc,
                                           pattern = pattern)
    }

    if (filter_header & is.null(target_from_csv)) {
      label_vector <- trimws(stringr::str_to_lower(fasta.file$Header))
      label_filter <- label_vector %in% vocabulary_label
      fasta.file <- fasta.file[label_filter, ]
    }

    if (!is.null(proportion_entries) && proportion_entries < 1) {
      index <- sample(nrow(fasta.file), max(1, floor(nrow(fasta.file) * proportion_entries)))
      fasta.file <- fasta.file[index, ]
    }

    if (shuffle_input) {
      fasta.file <- fasta.file[sample(nrow(fasta.file)), ]
    }

    if (reverse_complement) {
      index <- sample(c(TRUE, FALSE), nrow(fasta.file), replace = TRUE)
      fasta.file$Sequence[index] <- microseq::reverseComplement(fasta.file$Sequence[index])
    }

  }

  if (stringr::str_detect(format, "fastq")) {
    if (is.null(skip_amb_nuc)) {
      fasta.file <- microseq::readFastq(fasta.files[file_index])
    } else {
      fasta.file <- remove_amb_nuc_entries(microseq::readFastq(fasta.files[file_index]), skip_amb_nuc = skip_amb_nuc,
                                           pattern = pattern)
    }

    if (filter_header & is.null(target_from_csv)) {
      label_vector <- trimws(stringr::str_to_lower(fasta.file$Header))
      label_filter <- label_vector %in% vocabulary_label
      fasta.file <- fasta.file[label_filter, ]
    }

    if (!is.null(proportion_entries) && proportion_entries < 1) {
      index <- sample(nrow(fasta.file), max(1, floor(nrow(fasta.file) * proportion_entries)))
      fasta.file <- fasta.file[index, ]
    }

    if (shuffle_input) {
      fasta.file <- fasta.file[sample(nrow(fasta.file)), ]
    }

    if (reverse_complement & sample(c(TRUE, FALSE), 1)) {
      fasta.file$Sequence <- microseq::reverseComplement(fasta.file$Sequence)
    }
  }
  return(fasta.file)
}
|
|
input_from_csv <- function(added_label_path) {
  .datatable.aware = TRUE
  label_csv <- utils::read.csv2(added_label_path, header = TRUE, stringsAsFactors = FALSE)
  if (dim(label_csv)[2] == 1) {
    label_csv <- utils::read.csv(added_label_path, header = TRUE, stringsAsFactors = FALSE)
  }
  # validate before the "file" column is accessed below
  if (!("file" %in% names(label_csv))) {
    stop('names in added_label_path should contain one column named "file" ')
  }
  label_csv <- data.table::as.data.table(label_csv)
  label_csv$file <- stringr::str_to_lower(as.character(label_csv$file))
  data.table::setkey(label_csv, file)
  added_label_by_header <- FALSE

  col_name <- ifelse(added_label_by_header, "header", "file")
  return(list(label_csv = label_csv, col_name = col_name))
}
|
|
#' @rawNamespace import(data.table, except = c(first, last, between))
#' @noRd
csv_to_tensor <- function(label_csv, added_label_vector, added_label_by_header, batch_size,
                          start_index_list) {
  .datatable.aware = TRUE
  label_tensor <- matrix(0, ncol = ncol(label_csv) - 1, nrow = batch_size, byrow = TRUE)

  if (added_label_by_header) {
    header_unique <- unique(added_label_vector)
    for (i in header_unique) {
      label_from_csv <- label_csv[ .(i), -"header"]
      index_label_vector <- added_label_vector == i
      if (nrow(label_from_csv) > 0) {
        label_tensor[index_label_vector, ] <- matrix(as.matrix(label_from_csv[1, ]),
                                                     nrow = sum(index_label_vector), ncol = ncol(label_tensor), byrow = TRUE)
      }
    }
  } else {
    row_index <- 1
    for (i in 1:length(added_label_vector)) {
      row_filter <- added_label_vector[i]
      label_from_csv <- label_csv[data.table(row_filter), -"file"]
      samples_per_file <- length(start_index_list[[i]])
      assign_rows <- row_index:(row_index + samples_per_file - 1)

      if (nrow(stats::na.omit(label_from_csv)) > 0) {
        label_tensor[assign_rows, ] <- matrix(as.matrix(label_from_csv[1, ]),
                                              nrow = samples_per_file, ncol = ncol(label_tensor), byrow = TRUE)
      }
      row_index <- row_index + samples_per_file
    }
  }
  return(label_tensor)
}
|
|
#' Divide tensor to list of subsets
#'
#' @noRd
slice_tensor <- function(tensor, target_split) {

  num_row <- nrow(tensor)
  l <- vector("list", length = length(target_split))
  for (i in 1:length(target_split)) {
    if (length(target_split[[i]]) == 1 || num_row == 1) {
      l[[i]] <- matrix(tensor[ , target_split[[i]]], ncol = length(target_split[[i]]))
    } else {
      l[[i]] <- tensor[ , target_split[[i]]]
    }
  }
  return(l)
}
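
# For example, with tensor <- matrix(1:12, nrow = 3) and
# target_split <- list(1:2, 3:4), the result is a list of two 3 x 2 matrices
# holding columns 1-2 and 3-4, respectively.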
|
|
check_header_names <- function(target_split, vocabulary_label) {
  target_split <- unlist(target_split)
  if (!all(target_split %in% vocabulary_label)) {
    stop_text <- paste("Your csv file has no columns named",
                       paste(target_split[!(target_split %in% vocabulary_label)], collapse = " "))
    stop(stop_text)
  }
  if (!all(vocabulary_label %in% target_split)) {
    warning_text <- paste("target_split does not cover the following columns:",
                          paste(vocabulary_label[!(vocabulary_label %in% target_split)], collapse = " "))
    warning(warning_text)
  }
}
|
|
count_files <- function(path, format = "fasta", train_type,
                        target_from_csv = NULL, train_val_split_csv = NULL) {

  num_files <- rep(0, length(path))
  if (!is.null(target_from_csv) && train_type == "label_csv") {
    target_files <- utils::read.csv(target_from_csv)
    if (ncol(target_files) == 1) target_files <- utils::read.csv2(target_from_csv)
    target_files <- target_files$file
    # are files given with absolute path?
    full.names <- dirname(target_files[1]) != "."
  }
  if (!is.null(train_val_split_csv)) {
    tvt_files <- utils::read.csv(train_val_split_csv)
    if (ncol(tvt_files) == 1) tvt_files <- utils::read.csv2(train_val_split_csv)
    train_index <- tvt_files$type == "train"
    tvt_files <- tvt_files$file
    target_files <- intersect(tvt_files[train_index], target_files)
  }

  for (i in 1:length(path)) {
    for (k in 1:length(path[[i]])) {
      current_path <- path[[i]][[k]]

      if (!is.null(train_val_split_csv)) {
        if (!(current_path %in% target_files)) next
      }

      if (endsWith(current_path, paste0(".", format))) {
        # remove files not in csv file
        if (!is.null(target_from_csv)) {
          current_files <- length(intersect(basename(target_files), basename(current_path)))
        } else {
          current_files <- 1
        }
      } else {
        # remove files not in csv file
        if (!is.null(target_from_csv)) {
          current_files <- list.files(current_path, pattern = paste0(".", format, "$"), full.names = full.names) %>%
            intersect(target_files) %>% length()
        } else {
          current_files <- list.files(current_path, pattern = paste0(".", format, "$")) %>% length()
        }
      }
      num_files[i] <- num_files[i] + current_files

      if (current_files == 0) {
        stop(paste0(path[[i]][[k]], " is empty or contains no files with .", format, " extension"))
      }
    }
  }

  # return number of files per class for "label_folder"
  if (train_type == "label_folder") {
    return(num_files)
  } else {
    return(sum(num_files))
  }
}
|
|
list_fasta_files <- function(path_corpus, format, file_filter) {

  fasta.files <- list()
  path_corpus <- unlist(path_corpus)

  for (i in 1:length(path_corpus)) {

    if (endsWith(path_corpus[[i]], paste0(".", format))) {
      fasta.files[[i]] <- path_corpus[[i]]

    } else {

      fasta.files[[i]] <- list.files(
        path = path_corpus[[i]],
        pattern = paste0("\\.", format, "$"),
        full.names = TRUE)
    }
  }
  fasta.files <- unlist(fasta.files)
  num_files <- length(fasta.files)

  if (!is.null(file_filter)) {

    # file_filter entries may be given with or without absolute path
    if (all(basename(file_filter) == file_filter)) {
      fasta.files <- fasta.files[basename(fasta.files) %in% file_filter]
    } else {
      fasta.files <- fasta.files[fasta.files %in% file_filter]
    }

    if (length(fasta.files) < 1) {
      stop_text <- paste0("None of the files from ", unlist(path_corpus),
                          " are present in train_val_split_csv table for either train or validation. \n")
      stop(stop_text)
    }
  }

  fasta.files <- gsub(pattern = "/+", replacement = "/", x = fasta.files)
  fasta.files <- gsub(pattern = "/$", replacement = "", x = fasta.files)
  return(fasta.files)
}
|
|
get_coverage <- function(fasta.file) {
  header <- fasta.file$Header
  cov <- stringr::str_extract(header, "cov_\\d+") %>%
    stringr::str_extract("\\d+") %>% as.integer()
  cov[is.na(cov)] <- 1
  return(cov)
}
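
# For example, a header such as "NODE_1_length_500_cov_23" (a hypothetical
# assembler-style name) yields 23; headers without a "cov_" tag default to 1.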
|
|
get_coverage_concat <- function(fasta.file, concat_seq) {
  header <- fasta.file$Header
  cov <- stringr::str_extract(header, "cov_\\d+") %>%
    stringr::str_extract("\\d+") %>% as.integer()
  cov[is.na(cov)] <- 1
  len_vec <- nchar(fasta.file$Sequence)
  cov <- purrr::map(1:nrow(fasta.file), ~rep(cov[.x], times = len_vec[.x]))
  cov <- lapply(cov, append, rep(1, nchar(concat_seq)))
  cov <- unlist(cov)
  # remove the separator appended after the last entry
  cov <- cov[-((length(cov) - nchar(concat_seq) + 1) : length(cov))]
  return(cov)
}
|
|
#' Reshape tensors for set learning
#'
#' Reshape input x and target y. Aggregates multiple samples from x and y into single input/target batches.
#'
#' @param x 3D input tensor.
#' @param y 2D target tensor.
#' @param samples_per_target How many samples to use for one target.
#' @param reshape_mode `"time_dist", "multi_input"` or `"concat"`.
#' \itemize{
#' \item If `"multi_input"`, will produce `samples_per_target` separate inputs, each of length `maxlen`.
#' \item If `"time_dist"`, will produce a 4D input array. The dimensions correspond to
#' `(new_batch_size, samples_per_target, maxlen, length(vocabulary))`.
#' \item If `"concat"`, will concatenate `samples_per_target` sequences of length `maxlen` to one long sequence.
#' }
#' @param buffer_len Only applies if `reshape_mode = "concat"`. If `buffer_len` is an integer, the subsequences are interspaced with `buffer_len` rows. The reshaped x has
#' new maxlen: (`maxlen` \eqn{*} `samples_per_target`) + `buffer_len` \eqn{*} (`samples_per_target` - 1).
#' @param new_batch_size Size of first axis of input/targets after reshaping.
#' @param check_y Check if entries in `y` are consistent with reshape strategy (same label when aggregating).
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create dummy data
#' batch_size <- 8
#' maxlen <- 11
#' voc_len <- 4
#' x <- sample(0:(voc_len-1), maxlen*batch_size, replace = TRUE)
#' x <- keras::to_categorical(x, num_classes = voc_len)
#' x <- array(x, dim = c(batch_size, maxlen, voc_len))
#' y <- rep(0:1, each = batch_size/2)
#' y <- keras::to_categorical(y, num_classes = 2)
#' y
#'
#' # reshape data for multi input model
#' reshaped_data <- reshape_tensor(
#'   x = x,
#'   y = y,
#'   new_batch_size = 2,
#'   samples_per_target = 4,
#'   reshape_mode = "multi_input")
#'
#' length(reshaped_data[[1]])
#' dim(reshaped_data[[1]][[1]])
#' reshaped_data[[2]]
#'
#' @returns A list of 2 tensors.
#' @export
reshape_tensor <- function(x, y, new_batch_size,
                           samples_per_target,
                           buffer_len = NULL,
                           reshape_mode = "time_dist",
                           check_y = FALSE) {

  batch_size <- dim(x)[1]
  maxlen <- dim(x)[2]
  voc_len <- dim(x)[3]
  num_classes <- dim(y)[2]

  if (check_y) {
    targets <- apply(y, 1, which.max)
    test_y_dist <- all(targets == rep(1:num_classes, each = batch_size/num_classes))
    if (!test_y_dist) {
      stop("y must have same number of samples for each class")
    }
  }

  if (reshape_mode == "time_dist") {

    x_new <- array(0, dim = c(new_batch_size, samples_per_target, maxlen, voc_len))
    y_new <- array(0, dim = c(new_batch_size, num_classes))
    for (i in 1:new_batch_size) {
      index <- (1:samples_per_target) + (i-1)*samples_per_target
      x_new[i, , , ] <- x[index, , ]
      y_new[i, ] <- y[index[1], ]
    }

    return(list(x = x_new, y = y_new))
  }

  if (reshape_mode == "multi_input") {

    x_list <- vector("list", samples_per_target)
    for (i in 1:samples_per_target) {
      x_index <- base::seq(i, batch_size, samples_per_target)
      x_list[[i]] <- x[x_index, , ]
    }
    y <- y[base::seq(1, batch_size, samples_per_target), ]
    return(list(x = x_list, y = y))
  }

  if (reshape_mode == "concat") {

    use_buffer <- !is.null(buffer_len) && buffer_len > 0
    if (use_buffer) {
      buffer_tensor <- array(0, dim = c(buffer_len, voc_len))
      buffer_tensor[ , voc_len] <- 1
      concat_maxlen <- (maxlen * samples_per_target) + (buffer_len * (samples_per_target - 1))
    } else {
      concat_maxlen <- maxlen * samples_per_target
    }

    x_new <- array(0, dim = c(new_batch_size, concat_maxlen, voc_len))
    y_new <- array(0, dim = c(new_batch_size, num_classes))

    for (i in 1:new_batch_size) {
      index <- (1:samples_per_target) + (i-1)*samples_per_target
      if (!use_buffer) {
        x_temp <- x[index, , ]
        x_temp <- reticulate::array_reshape(x_temp, dim = c(1, dim(x_temp)[1] * dim(x_temp)[2], voc_len))
      } else {
        # create list of subsequences interspaced with buffer tensor
        x_list <- vector("list", (2*samples_per_target) - 1)
        x_list[seq(2, length(x_list), by = 2)] <- list(buffer_tensor)
        for (k in 1:length(index)) {
          x_list[[(2*k) - 1]] <- x[index[k], , ]
        }
        x_temp <- do.call(rbind, x_list)
      }

      x_new[i, , ] <- x_temp
      y_new[i, ] <- y[index[1], ]
    }
    return(list(x = x_new, y = y_new))
  }
}
|
|
#' Transform confusion matrix with absolute counts to matrix of column-wise proportions.
#'
#' @noRd
cm_perc <- function(cm, round_dig = 2) {
  col_sums <- colSums(cm)
  for (i in 1:ncol(cm)) {
    if (col_sums[i] == 0) {
      cm[ , i] <- 0
    } else {
      cm[ , i] <- cm[ , i]/col_sums[i]
    }
  }
  cm <- round(cm, round_dig)
  cm
}
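
# For example, cm_perc(matrix(c(2, 2, 0, 4), ncol = 2)) returns columns
# c(0.5, 0.5) and c(0, 1): each column is divided by its sum.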
|
|
create_conf_mat_obj <- function(m, confMatLabels) {
  dimnames(m) <- list(Prediction = confMatLabels, Truth = confMatLabels)
  l <- list()
  m <- as.table(m)
  l[["table"]] <- m
  l[["dots"]] <- list()
  class(l) <- "conf_mat"
  return(l)
}
|
|
#' Encode sequence of integers to sequence of n-grams
#'
#' Input is a sequence of integers from a vocabulary of size \code{voc_size}.
#' Returns a vector of integers corresponding to the n-gram encoding.
#' Integers greater than `voc_size` get encoded as `voc_size^n + 1`.
#'
#' @param int_seq Integer sequence.
#' @param n Length of n-gram aggregation.
#' @param voc_size Size of vocabulary.
#' @examples
#' int_to_n_gram(int_seq = c(1,1,2,4,4), n = 2, voc_size = 4)
#'
#' @returns A numeric vector.
#' @export
int_to_n_gram <- function(int_seq, n, voc_size = 4) {

  encoding_len <- length(int_seq) - n + 1
  n_gram_encoding <- vector("numeric", encoding_len)
  oov_token <- voc_size^n + 1
  padding_token <- 0

  for (i in 1:encoding_len) {
    int_seq_subset <- int_seq[i:(i + n - 1)]

    if (prod(int_seq_subset) == 0) {
      n_gram_encoding[i] <- padding_token
    } else {
      # encoding for ambiguous nucleotides
      if (any(int_seq_subset > voc_size)) {
        n_gram_encoding[i] <- oov_token
      } else {
        int_seq_subset <- int_seq_subset - 1
        n_gram_encoding[i] <- 1 + sum(voc_size^((n-1):0) * (int_seq_subset))
      }
    }
  }
  n_gram_encoding
}
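
# Worked example for the call in @examples: with n = 2 and voc_size = 4, each
# pair maps to 1 + 4*(i1 - 1) + (i2 - 1), so (1,1), (1,2), (2,4), (4,4)
# give c(1, 2, 8, 16).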
|
|
#' One-hot encoding matrix to n-gram encoding matrix
#'
#' @param input_matrix Matrix with one 1 per row and zeros otherwise.
#' @param n Length of one n-gram.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' x <- c(0,0,1,3,3)
#' input_matrix <- keras::to_categorical(x, 4)
#' n_gram_of_matrix(input_matrix, n = 2)
#'
#' @returns Matrix of one-hot encodings.
#' @export
n_gram_of_matrix <- function(input_matrix, n = 3) {
  voc_len <- ncol(input_matrix)^n
  oov_index <- apply(input_matrix, 1, max) != 1
  max_index <- apply(input_matrix, 1, which.max)
  max_index[oov_index] <- voc_len + 1
  int_enc <- int_to_n_gram(int_seq = max_index, n = n, voc_size = ncol(input_matrix))
  if (length(int_enc) == 1) {
    n_gram_matrix <- matrix(keras::to_categorical(int_enc, num_classes = voc_len + 2), nrow = 1)[ , -c(1, voc_len + 2)]
  } else {
    n_gram_matrix <- keras::to_categorical(int_enc, num_classes = voc_len + 2)[ , -c(1, voc_len + 2)]
  }
  n_gram_matrix <- matrix(n_gram_matrix, ncol = voc_len)
  return(n_gram_matrix)
}
|
|
n_gram_of_3d_tensor <- function(tensor_3d, n) {
  new_dim <- dim(tensor_3d)
  new_dim[2] <- new_dim[2] - n + 1
  new_dim[3] <- new_dim[3]^n
  new_tensor <- array(0, dim = new_dim)
  for (i in 1:dim(tensor_3d)[1]) {
    new_tensor[i, , ] <- n_gram_of_matrix(tensor_3d[i, , ], n = n)
  }
  new_tensor
}
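
# Shape arithmetic: an input of dim (batch, steps, voc) becomes
# (batch, steps - n + 1, voc^n), e.g. (2, 5, 4) with n = 2 gives (2, 4, 16).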
|
|
1319 |
|
|
|
1320 |
# Character vocabulary of all n-grams over `vocabulary`, ordered so that the
# i-th n-gram corresponds to integer i produced by int_to_n_gram.
n_gram_vocabulary <- function(n_gram = 3, vocabulary = c("A", "C", "G", "T")) {
  l <- list()
  for (i in 1:n_gram) {
    l[[i]] <- vocabulary
  }
  df <- expand.grid(l)
  # reverse column order so the first position varies slowest
  df <- df[ , ncol(df):1]
  n_gram_nuc <- apply(df, 1, paste, collapse = "")
  n_gram_nuc
}
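
# For example (illustration only):
# n_gram_vocabulary(n_gram = 2, vocabulary = c("A", "C", "G", "T"))
# returns the 16 strings "AA", "AC", "AG", "AT", "CA", ..., "TT",
# matching the integer codes produced by int_to_n_gram.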
|
|
#' Split fasta file into smaller files.
#'
#' Creates smaller files with the same base name plus suffix "_x" (where x is an
#' integer). For example, assume we have an input file called "abc.fasta" with
#' 100 entries and `split_n = 50`. The function will create two files called
#' "abc_1.fasta" and "abc_2.fasta" in `target_folder`.
#'
#' @param path_input Fasta file to split into smaller files.
#' @param split_n Maximum number of entries per smaller file.
#' @param target_folder Directory for output.
#' @param shuffle_entries Whether to shuffle fasta entries before split.
#' @param delete_input Whether to delete the original file.
#' @examples
#' path_input <- tempfile(fileext = '.fasta')
#' create_dummy_data(file_path = path_input,
#'                   num_files = 1,
#'                   write_to_file_path = TRUE,
#'                   seq_length = 7,
#'                   num_seq = 25,
#'                   vocabulary = c("a", "c", "g", "t"))
#' target_folder <- tempfile()
#' dir.create(target_folder)
#'
#' # split 25 entries into 5 files
#' split_fasta(path_input = path_input,
#'             target_folder = target_folder,
#'             split_n = 5)
#' length(list.files(target_folder))
#'
#' @returns None. Writes files to output.
#' @export
split_fasta <- function(path_input,
                        target_folder,
                        split_n = 500,
                        shuffle_entries = TRUE,
                        delete_input = FALSE) {

  fasta_file <- microseq::readFasta(path_input)

  # strip the ".fasta" extension (escaped dot, anchored at the end)
  base_name <- basename(stringr::str_remove(path_input, "\\.fasta$"))
  new_path <- paste0(target_folder, "/", base_name)
  count <- 1
  start_index <- 1
  end_index <- 1

  if (nrow(fasta_file) == 1) {
    fasta_name <- paste0(new_path, "_", count, ".fasta")
    microseq::writeFasta(fasta_file, fasta_name)
    if (delete_input) {
      file.remove(path_input)
    }
    return(NULL)
  }

  if (shuffle_entries) {
    fasta_file <- fasta_file[sample(nrow(fasta_file)), ]
  }

  while (end_index < nrow(fasta_file)) {
    end_index <- min(start_index + split_n - 1, nrow(fasta_file))
    index <- start_index:end_index
    sub_df <- fasta_file[index, ]
    fasta_name <- paste0(new_path, "_", count, ".fasta")
    microseq::writeFasta(sub_df, fasta_name)
    start_index <- start_index + split_n
    count <- count + 1
  }

  if (delete_input) {
    file.remove(path_input)
  }
}
|
|
#' Add noise to tensor
#'
#' @param noise_type "normal" or "uniform".
#' @param ... Additional arguments for the rnorm or runif call.
#' @noRd
add_noise_tensor <- function(x, noise_type, ...) {

  stopifnot(noise_type %in% c("normal", "uniform"))
  random_fn <- ifelse(noise_type == "normal", "rnorm", "runif")

  if (is.list(x)) {
    for (i in 1:length(x)) {
      x_dim <- dim(x[[i]])
      # draw prod(x_dim[-1]) noise values and recycle them to the full tensor shape
      noise_tensor <- do.call(random_fn, list(n = prod(x_dim[-1]), ...))
      noise_tensor <- array(noise_tensor, dim = x_dim)
      x[[i]] <- x[[i]] + noise_tensor
    }
  } else {
    x_dim <- dim(x)
    noise_tensor <- do.call(random_fn, list(n = prod(x_dim[-1]), ...))
    noise_tensor <- array(noise_tensor, dim = x_dim)
    x <- x + noise_tensor
  }

  return(x)
}
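
# Illustrative sketch (mean and sd are forwarded to rnorm via `...`):
# x <- array(0, dim = c(2, 5, 4))
# x_noisy <- add_noise_tensor(x, noise_type = "normal", mean = 0, sd = 0.1)
# dim(x_noisy) # 2 5 4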
|
|
# Reverse complement of a one-hot batch tensor: reverse the sequence axis and
# flip the channel axis, i.e. (A,C,G,T) -> (T,G,C,A).
reverse_complement_tensor <- function(x) {
  stopifnot(dim(x)[3] == 4)
  x_rev_comp <- x[ , dim(x)[2]:1, 4:1]
  # restore dimensions in case subsetting dropped a length-1 batch axis
  x_rev_comp <- array(x_rev_comp, dim = dim(x))
  x_rev_comp
}
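
# Illustration (assumes channels are ordered A, C, G, T): the one-hot encoding
# of "ACGT" maps to the one-hot encoding of its reverse complement, which is
# again "ACGT".
# x <- array(diag(4), dim = c(1, 4, 4))   # one-hot "ACGT"
# all(reverse_complement_tensor(x) == x)  # TRUE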
|
|
# Single entry of the sinusoidal positional encoding, following the transformer
# convention: PE(pos, 2i) = sin(pos / n^(2i/d_model)),
# PE(pos, 2i + 1) = cos(pos / n^(2i/d_model)).
get_pos_enc <- function(pos, i, d_model, n = 10000) {

  pw <- (2 * floor(i / 2)) / d_model
  angle_rates <- 1 / (n^pw)
  angle <- pos * angle_rates
  pos_enc <- ifelse(i %% 2 == 0, sin(angle), cos(angle))
  return(pos_enc)
}

positional_encoding <- function(seq_len, d_model, n = 10000) {

  P <- matrix(0, nrow = seq_len, ncol = d_model)

  for (pos in 0:(seq_len - 1)) {
    for (i in 0:(d_model - 1)) {
      P[pos + 1, i + 1] <- get_pos_enc(pos, i, d_model, n)
    }
  }

  return(P)
}
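
# Illustration: for the lowest dimensions the angle rate is 1, so the first
# column is sin(pos) and the second is cos(pos).
# P <- positional_encoding(seq_len = 3, d_model = 4)
# dim(P)  # 3 4
# P[, 1]  # sin(0), sin(1), sin(2)
# P[, 2]  # cos(0), cos(1), cos(2)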
|
|
subset_tensor_list <- function(tensor_list, dim_list, subset_index, dim_n_list) {

  for (i in 1:length(tensor_list)) {
    tensor_list[[i]] <- subset_tensor(tensor = tensor_list[[i]],
                                      subset_index = subset_index,
                                      dim_n = dim_n_list[[i]])
  }
  # return the updated list
  tensor_list
}

subset_tensor <- function(tensor, subset_index, dim_n) {

  if (dim_n == 1) {
    result <- tensor[subset_index]
  }

  if (dim_n == 2) {
    result <- tensor[subset_index, ]
  }

  if (dim_n == 3) {
    result <- tensor[subset_index, , ]
  }

  if (dim_n == 4) {
    result <- tensor[subset_index, , , ]
  }

  # subsetting with a single index drops the batch axis; add it back
  if (length(subset_index) == 1 && dim_n > 1) {
    result <- tensorflow::tf$expand_dims(result, axis = 0L)
  }
  result
}
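
# Illustrative sketch:
# x <- array(rnorm(4 * 6 * 4), dim = c(4, 6, 4))
# dim(subset_tensor(x, subset_index = 1:2, dim_n = 3))  # 2 6 4
# subset_tensor(x, subset_index = 1, dim_n = 3) keeps the batch axis, returning
# a (1, 6, 4) tensorflow tensor (requires tensorflow).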
|
|
# Mask an integer sequence for masked language model training. A fraction
# `mask_rate` of the valid positions gets replaced by the mask token
# (voc_len + 1), `random_rate` by random tokens from the vocabulary and
# `identity_rate` is selected but left unchanged. Returns the masked sequence
# and a 0/1 sample weight vector marking the selected positions.
mask_seq <- function(int_seq,
                     mask_rate = NULL,
                     random_rate = NULL,
                     identity_rate = NULL,
                     block_len = NULL,
                     start_ind = NULL,
                     voc_len) {

  mask_token <- voc_len + 1
  if (is.null(mask_rate)) mask_rate <- 0
  if (is.null(random_rate)) random_rate <- 0
  if (is.null(identity_rate)) identity_rate <- 0
  mask_perc <- mask_rate + random_rate + identity_rate
  if (mask_perc > 1) {
    stop("Sum of mask_rate, random_rate and identity_rate is bigger than 1")
  }
  # don't mask padding or oov positions
  valid_pos <- which(int_seq != 0 & int_seq != mask_token)

  # randomly decide whether to round up or down
  ceiling_floor <- sample(c(TRUE, FALSE), 3, replace = TRUE)
  # adjust for block length
  block_len_adjust <- ifelse(is.null(block_len), 1, block_len)

  num_mask_pos <- (mask_rate * length(valid_pos)) / block_len_adjust
  num_mask_pos <- ifelse(ceiling_floor[1], floor(num_mask_pos), ceiling(num_mask_pos))
  num_random_pos <- (random_rate * length(valid_pos)) / block_len_adjust
  num_random_pos <- ifelse(ceiling_floor[2], floor(num_random_pos), ceiling(num_random_pos))
  num_identity_pos <- (identity_rate * length(valid_pos)) / block_len_adjust
  num_identity_pos <- ifelse(ceiling_floor[3], floor(num_identity_pos), ceiling(num_identity_pos))
  num_all_pos <- num_mask_pos + num_random_pos + num_identity_pos
  if (is.null(block_len)) {
    all_pos <- sample(valid_pos, num_all_pos)
  } else {
    # only allow block start positions on a random grid with spacing block_len
    valid_pos_block_len <- seq(from = sample(1:(block_len - 1), 1), to = length(valid_pos), by = block_len)
    valid_pos <- intersect(valid_pos_block_len, valid_pos)
    all_pos <- sample(valid_pos, min(num_all_pos, length(valid_pos)))
  }

  sample_weight_seq <- rep(0, length(int_seq))
  if (is.null(block_len)) {
    sample_weight_seq[all_pos] <- 1
  } else {
    all_pos_blocks <- purrr::map(all_pos, ~seq(.x, .x + block_len - 1, by = 1))
    sample_weight_seq[unlist(all_pos_blocks)] <- 1
  }

  if (num_mask_pos > 0) {
    mask_index <- sample(all_pos, num_mask_pos)
    all_pos <- setdiff(all_pos, mask_index)
    if (!is.null(block_len)) {
      mask_index <- purrr::map(mask_index, ~seq(.x, .x + block_len - 1, by = 1)) %>%
        unlist()
    }
    int_seq[mask_index] <- mask_token
  }

  if (num_random_pos > 0) {
    random_index <- sample(all_pos, num_random_pos)
    all_pos <- setdiff(all_pos, random_index)
    if (!is.null(block_len)) {
      random_index <- purrr::map(random_index, ~seq(.x, .x + block_len - 1, by = 1)) %>%
        unlist()
    }
    int_seq[random_index] <- sample(1:voc_len, length(random_index), replace = TRUE)
  }

  # give weight to all positions that now hold the mask token (this includes
  # original oov positions)
  sample_weight_seq[int_seq == mask_token] <- 1

  return(list(masked_seq = int_seq, sample_weight_seq = sample_weight_seq))
}
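
# Illustrative sketch: mask ~15% of a random integer sequence over a
# vocabulary of size 4 (the mask token is then 5).
# set.seed(1)
# res <- mask_seq(int_seq = sample(1:4, 20, replace = TRUE),
#                 mask_rate = 0.15, voc_len = 4)
# res$masked_seq         # masked positions hold the token 5
# res$sample_weight_seq  # 1 at selected positions, 0 elsewhere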
|
|
#' Char sequence corresponding to one-hot matrix.
#'
#' Return the character sequence corresponding to the one-hot elements in a matrix or tensor.
#'
#' @inheritParams generator_fasta_lm
#' @param m One-hot encoding matrix or 3d array where each element of the first axis is a one-hot matrix.
#' @param amb_enc Either `"zero"` or `"equal"`. How oov tokens were treated for one-hot encoding.
#' @param amb_char Char to use for oov positions.
#' @param paste_chars Whether to paste the characters into a single string (otherwise a character vector is returned).
#' @examples
#' m <- matrix(c(1,0,0,0,0,1,0,0), 2)
#' one_hot_to_seq(m)
#'
#' @returns A string (or a character vector if `paste_chars = FALSE`; a list for 3d input).
#' @export
one_hot_to_seq <- function(m, vocabulary = c("A", "C", "G", "T"), amb_enc = "zero",
                           amb_char = "N", paste_chars = TRUE) {

  if (length(dim(m)) == 3) {
    seq_list <- list()
    for (i in 1:dim(m)[1]) {
      seq_list[[i]] <- one_hot_to_seq(m = m[i, , ], vocabulary = vocabulary, amb_enc = amb_enc,
                                      amb_char = amb_char, paste_chars = paste_chars)
    }
    return(seq_list)
  }

  if (amb_enc == "zero") {
    # oov positions are all-zero rows
    amb_row <- which(rowSums(m) == 0)
  }

  if (amb_enc == "equal") {
    # oov positions are rows with all entries equal to 1/length(vocabulary)
    amb_row <- which(apply(m, 1, max) == 1 / length(vocabulary))
  }

  nt_seq <- vocabulary[apply(m, 1, which.max)]
  nt_seq[amb_row] <- amb_char

  if (paste_chars) {
    nt_seq <- paste(nt_seq, collapse = "")
  }

  return(nt_seq)
}
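
# Illustration for 3d input (one list entry per sample):
# a <- array(0, dim = c(2, 2, 4))
# a[1, , ] <- diag(4)[1:2, ]  # "AC"
# a[2, , ] <- diag(4)[3:4, ]  # "GT"
# one_hot_to_seq(a)  # list("AC", "GT")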