[409433]: / R / generator_lm.R

Download this file

216 lines (204 with data), 12.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
#' Language model generator for fasta/fastq files
#'
#' @description Iterates over folder containing fasta/fastq files and produces encoding of predictor sequences
#' and target variables. Will take a sequence of fixed size and use some part of sequence as input and other part as target.
#'
#' @inheritParams train_model
#' @param path_corpus Input directory where fasta files are located or path to single file ending with fasta or fastq
#' (as specified in format argument). Can also be a list of directories and/or files.
#' @param format File format, either `"fasta"` or `"fastq"`.
#' @param batch_size Number of samples in one batch.
#' @param maxlen Length of predictor sequence.
#' @param max_iter Stop after `max_iter` iterations have failed to produce a new batch.
#' @param shuffle_file_order Logical, whether to go through files randomly or sequentially.
#' @param step How often to take a sample.
#' @param seed Sets seed for `set.seed` function for reproducible results.
#' @param shuffle_input Whether to shuffle entries in every fasta/fastq file before extracting samples.
#' @param verbose Whether to show messages.
#' @param path_file_log Write name of files to csv file if path is specified.
#' @param reverse_complement Boolean, for every new file decide randomly to use original data or its reverse complement.
#' @param ambiguous_nuc How to handle nucleotides outside vocabulary, either `"zero"`, `"discard"`, `"empirical"` or `"equal"`.
#' \itemize{
#' \item If `"zero"`, input gets encoded as zero vector.
#' \item If `"equal"`, input is repetition of `1/length(vocabulary)`.
#' \item If `"discard"`, samples containing nucleotides outside vocabulary get discarded.
#' \item If `"empirical"`, use nucleotide distribution of current file.
#' }
#' @param proportion_per_seq Numerical value between 0 and 1. Proportion of sequence to take samples from (use random subsequence).
#' @param use_quality_score Whether to use fastq quality scores. If TRUE input is not one-hot-encoding but corresponds to probabilities.
#' For example (0.97, 0.01, 0.01, 0.01) instead of (1, 0, 0, 0).
#' @param padding Whether to pad sequences too short for one sample with zeros.
#' @param added_label_path Path to file with additional input labels. Should be a csv file with one column named "file". Other columns should correspond to labels.
#' @param add_input_as_seq Boolean vector specifying for each entry in \code{added_label_path} if rows from csv should be encoded as a sequence or used directly.
#' If a row in your csv file is a sequence this should be `TRUE`. For example you may want to add another sequence, say ACCGT. Then this would correspond to 1,2,2,3,4 in
#' csv file (if vocabulary = c("A", "C", "G", "T")). If \code{add_input_as_seq} is `TRUE`, 12234 gets one-hot encoded, so added input is a 3D tensor. If \code{add_input_as_seq} is
#' `FALSE` this will feed network just raw data (a 2D tensor).
#' @param skip_amb_nuc Threshold of ambiguous nucleotides to accept in fasta entry. Complete entry will get discarded otherwise.
#' @param max_samples Maximum number of samples to use from one file. If not `NULL` and file has more than \code{max_samples} samples, will randomly choose a
#' subset of \code{max_samples} samples.
#' @param concat_seq Character string or `NULL`. If not `NULL` all entries from file get concatenated to one sequence with `concat_seq` string between them.
#' Example: If the first entry is AACC, the second entry is TTTG and `concat_seq = "ZZZ"`, this becomes AACCZZZTTTG.
#' @param target_len Number of nucleotides to predict at once for language model.
#' @param file_filter Vector of file names to use from path_corpus.
#' @param use_coverage Integer or `NULL`. If not `NULL`, use coverage as encoding rather than one-hot encoding and normalize.
#' Coverage information must be contained in fasta header: there must be a string `"cov_n"` in the header, where `n` is some integer.
#' @param proportion_entries Proportion of fasta entries to keep. For example, if fasta file has 50 entries and `proportion_entries = 0.1`,
#' will randomly select 5 entries.
#' @param sample_by_file_size Sample new file weighted by file size (bigger files more likely).
#' @param n_gram Integer, encode target not nucleotide wise but combine n nucleotides at once. For example for `n=2, "AA" -> (1, 0,..., 0),`
#' `"AC" -> (0, 1, 0,..., 0), "TT" -> (0,..., 0, 1)`, where the one-hot vectors have length `length(vocabulary)^n`.
#' @param add_noise `NULL` or list of arguments. If not `NULL`, list must contain the following arguments: \code{noise_type} can be `"normal"` or `"uniform"`;
#' optional arguments `sd` or `mean` if noise_type is `"normal"` (default is `sd=1` and `mean=0`) or `min, max` if `noise_type` is `"uniform"`
#' (default is `min=0, max=1`).
#' @param return_int Whether to return integer encoding or one-hot encoding.
#' @param reshape_xy Can be a list of functions to apply to input and/or target. List elements (containing the reshape functions)
#' must be called x for input or y for target and each have arguments called x and y. For example:
#' `reshape_xy = list(x = function(x, y) {return(x+1)}, y = function(x, y) {return(x+y)})` .
#' For the rds generator, each function needs to have an additional argument called sw.
#' @rawNamespace import(data.table, except = c(first, last, between))
#' @importFrom magrittr %>%
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create dummy fasta files
#' path_input_1 <- tempfile()
#' dir.create(path_input_1)
#' create_dummy_data(file_path = path_input_1,
#' num_files = 2,
#' seq_length = 8,
#' num_seq = 1,
#' vocabulary = c("a", "c", "g", "t"))
#'
#' gen <- generator_fasta_lm(path_corpus = path_input_1, batch_size = 2,
#' maxlen = 7)
#' z <- gen()
#' dim(z[[1]])
#' z[[2]]
#'
#' @returns A generator function.
#' @export
generator_fasta_lm <- function(path_corpus,
                               format = "fasta",
                               batch_size = 256,
                               maxlen = 250,
                               max_iter = 10000,
                               vocabulary = c("a", "c", "g", "t"),
                               verbose = FALSE,
                               shuffle_file_order = FALSE,
                               step = 1,
                               seed = 1234,
                               shuffle_input = FALSE,
                               file_limit = NULL,
                               path_file_log = NULL,
                               reverse_complement = FALSE,
                               output_format = "target_right",
                               ambiguous_nuc = "zeros",
                               use_quality_score = FALSE,
                               proportion_per_seq = NULL,
                               padding = TRUE,
                               added_label_path = NULL,
                               add_input_as_seq = NULL,
                               skip_amb_nuc = NULL,
                               max_samples = NULL,
                               concat_seq = NULL,
                               target_len = 1,
                               file_filter = NULL,
                               use_coverage = NULL,
                               proportion_entries = NULL,
                               sample_by_file_size = FALSE,
                               n_gram = NULL,
                               n_gram_stride = 1,
                               add_noise = NULL,
                               return_int = FALSE,
                               reshape_xy = NULL) {

  # NOTE(review): the default `ambiguous_nuc = "zeros"` does not match the
  # documented option `"zero"`; confirm which spelling the downstream
  # generator expects before changing either.

  # TODO: add check for n-gram and option for stride (currently only
  # n_gram_stride == n_gram is accepted when n_gram is set).
  if (!is.null(n_gram) && n_gram_stride != n_gram) {
    stop("When using train_type='lm' with n_gram encoding, n_gram_stride must be equal to n_gram.")
  }

  # A reshape function is valid only if it has formal arguments named `x`
  # and `y`. The check must look at the *names* of the formals; the formals
  # themselves hold default values, not argument names.
  has_xy_args <- function(f) {
    all(c("x", "y") %in% names(formals(f)))
  }

  reshape_xy_bool <- !is.null(reshape_xy)
  reshape_x_bool <- !is.null(reshape_xy$x)
  reshape_y_bool <- !is.null(reshape_xy$y)
  if (reshape_x_bool && !has_xy_args(reshape_xy$x)) {
    stop("function reshape_xy$x needs to have arguments named x and y")
  }
  if (reshape_y_bool && !has_xy_args(reshape_xy$y)) {
    stop("function reshape_xy$y needs to have arguments named x and y")
  }

  # The underlying generator is asked for windows covering input and target
  # together; each window is split into x and y by slice_tensor_lm() below.
  total_seq_len <- maxlen + target_len
  has_added_input <- !is.null(added_label_path)

  gen <- generator_fasta_label_folder(path_corpus = path_corpus,
                                      format = format,
                                      batch_size = batch_size,
                                      maxlen = total_seq_len,
                                      max_iter = max_iter,
                                      vocabulary = vocabulary,
                                      shuffle_file_order = shuffle_file_order,
                                      step = step,
                                      seed = seed,
                                      shuffle_input = shuffle_input,
                                      file_limit = file_limit,
                                      path_file_log = path_file_log,
                                      reverse_complement = reverse_complement,
                                      reverse_complement_encoding = FALSE,
                                      num_targets = 1,
                                      ones_column = 1,
                                      ambiguous_nuc = ambiguous_nuc,
                                      proportion_per_seq = proportion_per_seq,
                                      read_data = FALSE,
                                      use_quality_score = use_quality_score,
                                      padding = padding,
                                      added_label_path = added_label_path,
                                      add_input_as_seq = add_input_as_seq,
                                      skip_amb_nuc = skip_amb_nuc,
                                      max_samples = max_samples,
                                      concat_seq = concat_seq,
                                      file_filter = file_filter,
                                      use_coverage = use_coverage,
                                      proportion_entries = proportion_entries,
                                      sample_by_file_size = sample_by_file_size,
                                      n_gram = n_gram,
                                      n_gram_stride = n_gram_stride,
                                      masked_lm = NULL,
                                      add_noise = add_noise,
                                      return_int = return_int)

  # Returned generator: every call produces one batch.
  function() {
    batch <- gen()[[1]]
    if (has_added_input) {
      # With additional inputs the first element of the generator output is
      # itself a list: the last entry is used as the sequence encoding, all
      # preceding entries are kept as additional inputs.
      added_input <- batch[-length(batch)]
      xy <- batch[[length(batch)]]
    } else {
      xy <- batch
    }

    xy_list <- slice_tensor_lm(xy = xy,
                               output_format = output_format,
                               target_len = target_len,
                               n_gram = n_gram,
                               n_gram_stride = n_gram_stride,
                               total_seq_len = total_seq_len,
                               return_int = return_int)

    # Keep x and y in local variables: once the additional inputs are
    # prepended the result is an unnamed list, so `$x` / `$y` can no longer
    # be used to hand the tensors to the reshape step.
    x <- xy_list$x
    y <- xy_list$y
    if (has_added_input) {
      x <- append(added_input, list(x))
      xy_list <- list(x, y)
    }

    if (reshape_xy_bool) {
      xy_list <- f_reshape(x = x, y = y,
                           reshape_xy = reshape_xy,
                           reshape_x_bool = reshape_x_bool,
                           reshape_y_bool = reshape_y_bool,
                           reshape_sw_bool = FALSE, sw = NULL)
    }
    xy_list
  }
}