|
a |
|
b/man/train_model.Rd |
|
|
1 |
% Generated by roxygen2: do not edit by hand |
|
|
2 |
% Please edit documentation in R/train.R |
|
|
3 |
\name{train_model} |
|
|
4 |
\alias{train_model} |
|
|
5 |
\title{Train neural network on genomic data} |
|
|
6 |
\usage{ |
|
|
7 |
train_model( |
|
|
8 |
model = NULL, |
|
|
9 |
dataset = NULL, |
|
|
10 |
dataset_val = NULL, |
|
|
11 |
train_val_ratio = 0.2, |
|
|
12 |
run_name = "run_1", |
|
|
13 |
initial_epoch = 0, |
|
|
14 |
class_weight = NULL, |
|
|
15 |
print_scores = TRUE, |
|
|
16 |
epochs = 10, |
|
|
17 |
max_queue_size = 100, |
|
|
18 |
steps_per_epoch = 1000, |
|
|
19 |
path_checkpoint = NULL, |
|
|
20 |
path_tensorboard = NULL, |
|
|
21 |
path_log = NULL, |
|
|
22 |
save_best_only = NULL, |
|
|
23 |
save_weights_only = FALSE, |
|
|
24 |
tb_images = FALSE, |
|
|
25 |
path_file_log = NULL, |
|
|
26 |
reset_states = FALSE, |
|
|
27 |
early_stopping_time = NULL, |
|
|
28 |
validation_only_after_training = FALSE, |
|
|
29 |
train_val_split_csv = NULL, |
|
|
30 |
reduce_lr_on_plateau = TRUE, |
|
|
31 |
lr_plateau_factor = 0.9, |
|
|
32 |
patience = 20, |
|
|
33 |
cooldown = 1, |
|
|
34 |
model_card = NULL, |
|
|
35 |
callback_list = NULL, |
|
|
36 |
train_type = "label_folder", |
|
|
37 |
path = NULL, |
|
|
38 |
path_val = NULL, |
|
|
39 |
batch_size = 64, |
|
|
40 |
step = NULL, |
|
|
41 |
shuffle_file_order = TRUE, |
|
|
42 |
vocabulary = c("a", "c", "g", "t"), |
|
|
43 |
format = "fasta", |
|
|
44 |
ambiguous_nuc = "zero", |
|
|
45 |
seed = c(1234, 4321), |
|
|
46 |
file_limit = NULL, |
|
|
47 |
use_coverage = NULL, |
|
|
48 |
set_learning = NULL, |
|
|
49 |
proportion_entries = NULL, |
|
|
50 |
sample_by_file_size = FALSE, |
|
|
51 |
n_gram = NULL, |
|
|
52 |
n_gram_stride = 1, |
|
|
53 |
masked_lm = NULL, |
|
|
54 |
random_sampling = FALSE, |
|
|
55 |
add_noise = NULL, |
|
|
56 |
return_int = FALSE, |
|
|
57 |
maxlen = NULL, |
|
|
58 |
reverse_complement = FALSE, |
|
|
59 |
reverse_complement_encoding = FALSE, |
|
|
60 |
output_format = "target_right", |
|
|
61 |
proportion_per_seq = NULL, |
|
|
62 |
read_data = FALSE, |
|
|
63 |
use_quality_score = FALSE, |
|
|
64 |
padding = FALSE, |
|
|
65 |
concat_seq = NULL, |
|
|
66 |
target_len = 1, |
|
|
67 |
skip_amb_nuc = NULL, |
|
|
68 |
max_samples = NULL, |
|
|
69 |
added_label_path = NULL, |
|
|
70 |
add_input_as_seq = NULL, |
|
|
71 |
target_from_csv = NULL, |
|
|
72 |
target_split = NULL, |
|
|
73 |
shuffle_input = TRUE, |
|
|
74 |
vocabulary_label = NULL, |
|
|
75 |
delete_used_files = FALSE, |
|
|
76 |
reshape_xy = NULL, |
|
|
77 |
return_gen = FALSE |
|
|
78 |
) |
|
|
79 |
} |
|
|
80 |
\arguments{ |
|
|
81 |
\item{model}{A keras model.} |
|
|
82 |
|
|
|
83 |
\item{dataset}{List of training data holding training samples in RAM instead of using generator. Should be list with two entries called \code{"X"} and \code{"Y"}.} |
|
|
84 |
|
|
|
85 |
\item{dataset_val}{List of validation data. Should have two entries called \code{"X"} and \code{"Y"}.} |
|
|
86 |
|
|
|
87 |
\item{train_val_ratio}{For generator defines the fraction of batches that will be used for validation (compared to size of training data), i.e. one validation iteration |
|
|
88 |
processes \code{batch_size} \eqn{*} \code{steps_per_epoch} \eqn{*} \code{train_val_ratio} samples. If you use dataset instead of generator and \code{dataset_val} is \code{NULL}, splits \code{dataset} |
|
|
89 |
into train/validation data.} |
|
|
90 |
|
|
|
91 |
\item{run_name}{Name of the run. Name will be used to identify output from callbacks. If \code{NULL}, will use date as run name. |
|
|
92 |
If name already present, will add \code{"_2"} to name or \code{"_{x+1}"} if name ends with \verb{_x}, where \code{x} is some integer.} |
|
|
93 |
|
|
|
94 |
\item{initial_epoch}{Epoch at which to start training. Note that network |
|
|
95 |
will run for (\code{epochs} - \code{initial_epochs}) rounds and not \code{epochs} rounds.} |
|
|
96 |
|
|
|
97 |
\item{class_weight}{List of weights for output. Order should correspond to \code{vocabulary_label}. |
|
|
98 |
You can use \code{\link{get_class_weight}} function to estimate class weights: |
|
|
99 |
|
|
|
100 |
\code{class_weights <- get_class_weights(path = path, train_type = train_type)} |
|
|
101 |
|
|
|
102 |
If \code{train_type = "label_csv"} you need to add path to csv file: |
|
|
103 |
|
|
|
104 |
\code{class_weights <- get_class_weights(path = path, train_type = train_type, csv_path = target_from_csv)}} |
|
|
105 |
|
|
|
106 |
\item{print_scores}{Whether to print train/validation scores during training.} |
|
|
107 |
|
|
|
108 |
\item{epochs}{Number of iterations.} |
|
|
109 |
|
|
|
110 |
\item{max_queue_size}{Maximum size for the generator queue.} |
|
|
111 |
|
|
|
112 |
\item{steps_per_epoch}{Number of training batches per epoch.} |
|
|
113 |
|
|
|
114 |
\item{path_checkpoint}{Path to checkpoints folder or \code{NULL}. If \code{NULL}, checkpoints don't get stored.} |
|
|
115 |
|
|
|
116 |
\item{path_tensorboard}{Path to tensorboard directory or \code{NULL}. If \code{NULL}, training not tracked on tensorboard.} |
|
|
117 |
|
|
|
118 |
\item{path_log}{Path to directory to write training scores. File name is \code{run_name} + \code{".csv"}. No output if \code{NULL}.} |
|
|
119 |
|
|
|
120 |
\item{save_best_only}{Only save model that improved on some score. Not applied if argument is \code{NULL}. Otherwise must be |
|
|
121 |
list with argument \code{monitor} or \code{save_freq} (can only use one option). \code{moniter} specifies what metric to use. |
|
|
122 |
\code{save_freq}, integer specifying how often to store a checkpoint (in epochs).} |
|
|
123 |
|
|
|
124 |
\item{save_weights_only}{Whether to save weights only.} |
|
|
125 |
|
|
|
126 |
\item{tb_images}{Whether to show custom images (confusion matrix) in tensorboard "IMAGES" tab.} |
|
|
127 |
|
|
|
128 |
\item{path_file_log}{Write name of files used for training to csv file if path is specified.} |
|
|
129 |
|
|
|
130 |
\item{reset_states}{Whether to reset hidden states of RNN layer at every new input file and before/after validation.} |
|
|
131 |
|
|
|
132 |
\item{early_stopping_time}{Time in seconds after which to stop training.} |
|
|
133 |
|
|
|
134 |
\item{validation_only_after_training}{Whether to skip validation during training and only do one validation iteration after training.} |
|
|
135 |
|
|
|
136 |
\item{train_val_split_csv}{A csv file specifying train/validation split. csv file should contain one column named \code{"file"} and one column named |
|
|
137 |
\code{"type"}. The \code{"file"} column contains names of fasta/fastq files and \code{"type"} column specifies if file is used for training or validation. |
|
|
138 |
Entries in \code{"type"} must be named \code{"train"} or \code{"val"}, otherwise file will not be used for either. \code{path} and \code{path_val} arguments should be the same. |
|
|
139 |
Not implemented for \code{train_type = "label_folder"}.} |
|
|
140 |
|
|
|
141 |
\item{reduce_lr_on_plateau}{Whether to use learning rate scheduler.} |
|
|
142 |
|
|
|
143 |
\item{lr_plateau_factor}{Factor of decreasing learning rate when plateau is reached.} |
|
|
144 |
|
|
|
145 |
\item{patience}{Number of epochs waiting for decrease in validation loss before reducing learning rate.} |
|
|
146 |
|
|
|
147 |
\item{cooldown}{Number of epochs without changing learning rate.} |
|
|
148 |
|
|
|
149 |
\item{model_card}{List of arguments for training parameters of training run. Must contain at least an entry \code{path_model_card}, i.e. the |
|
|
150 |
directory where parameters are stored. List can contain additional (optional) arguments, for example |
|
|
151 |
\code{model_card = list(path_model_card = "/path/to/logs", description = "transfer learning with BERT model on virus data", ...)}} |
|
|
152 |
|
|
|
153 |
\item{callback_list}{Add additional callbacks to \code{keras::fit} call.} |
|
|
154 |
|
|
|
155 |
\item{train_type}{Either \code{"lm"}, \code{"lm_rds"}, \code{"masked_lm"} for language model; \code{"label_header"}, \code{"label_folder"}, \code{"label_csv"}, \code{"label_rds"} for classification or \code{"dummy_gen"}. |
|
|
156 |
\itemize{ |
|
|
157 |
\item Language model is trained to predict character(s) in a sequence. \cr |
|
|
158 |
\item \code{"label_header"}/\code{"label_folder"}/\code{"label_csv"} are trained to predict a corresponding class given a sequence as input. |
|
|
159 |
\item If \code{"label_header"}, class will be read from fasta headers. |
|
|
160 |
\item If \code{"label_folder"}, class will be read from folder, i.e. all files in one folder must belong to the same class. |
|
|
161 |
\item If \code{"label_csv"}, targets are read from a csv file. This file should have one column named "file". The targets then correspond to entries in that row (except "file" |
|
|
162 |
column). Example: if we are currently working with a file called "a.fasta" and corresponding label is "label_1", there should be a row in our csv file\tabular{lll}{ |
|
|
163 |
file \tab label_1 \tab label_2 \cr |
|
|
164 |
"a.fasta" \tab 1 \tab 0 \cr |
|
|
165 |
} |
|
|
166 |
|
|
|
167 |
|
|
|
168 |
\item If \code{"label_rds"}, generator will iterate over set of .rds files containing each a list of input and target tensors. Not implemented for model |
|
|
169 |
with multiple inputs. |
|
|
170 |
\item If \code{"lm_rds"}, generator will iterate over set of .rds files and will split tensor according to \code{target_len} argument |
|
|
171 |
(targets are last \code{target_len} nucleotides of each sequence). |
|
|
172 |
\item If \code{"dummy_gen"}, generator creates random data once and repeatedly feeds these to model. |
|
|
173 |
\item If \code{"masked_lm"}, generator maskes some parts of the input. See \code{masked_lm} argument for details. |
|
|
174 |
}} |
|
|
175 |
|
|
|
176 |
\item{path}{Path to training data. If \code{train_type} is \code{label_folder}, should be a vector or list |
|
|
177 |
where each entry corresponds to a class (list elements can be directories and/or individual files). If \code{train_type} is not \code{label_folder}, |
|
|
178 |
can be a single directory or file or a list of directories and/or files.} |
|
|
179 |
|
|
|
180 |
\item{path_val}{Path to validation data. See \code{path} argument for details.} |
|
|
181 |
|
|
|
182 |
\item{batch_size}{Number of samples used for one network update.} |
|
|
183 |
|
|
|
184 |
\item{step}{Frequency of sampling steps.} |
|
|
185 |
|
|
|
186 |
\item{shuffle_file_order}{Boolean, whether to go through files sequentially or shuffle beforehand.} |
|
|
187 |
|
|
|
188 |
\item{vocabulary}{Vector of allowed characters. Characters outside vocabulary get encoded as specified in \code{ambiguous_nuc}.} |
|
|
189 |
|
|
|
190 |
\item{format}{File format, \code{"fasta"}, \code{"fastq"}, \code{"rds"} or \code{"fasta.tar.gz"}, \code{"fastq.tar.gz"} for \code{tar.gz} files.} |
|
|
191 |
|
|
|
192 |
\item{ambiguous_nuc}{How to handle nucleotides outside vocabulary, either \code{"zero"}, \code{"discard"}, \code{"empirical"} or \code{"equal"}. |
|
|
193 |
\itemize{ |
|
|
194 |
\item If \code{"zero"}, input gets encoded as zero vector. |
|
|
195 |
\item If \code{"equal"}, input is repetition of \code{1/length(vocabulary)}. |
|
|
196 |
\item If \code{"discard"}, samples containing nucleotides outside vocabulary get discarded. |
|
|
197 |
\item If \code{"empirical"}, use nucleotide distribution of current file. |
|
|
198 |
}} |
|
|
199 |
|
|
|
200 |
\item{seed}{Sets seed for reproducible results.} |
|
|
201 |
|
|
|
202 |
\item{file_limit}{Integer or \code{NULL}. If integer, use only specified number of randomly sampled files for training. Ignored if greater than number of files in \code{path}.} |
|
|
203 |
|
|
|
204 |
\item{use_coverage}{Integer or \code{NULL}. If not \code{NULL}, use coverage as encoding rather than one-hot encoding and normalize. |
|
|
205 |
Coverage information must be contained in fasta header: there must be a string \code{"cov_n"} in the header, where \code{n} is some integer.} |
|
|
206 |
|
|
|
207 |
\item{set_learning}{When you want to assign one label to set of samples. Only implemented for \code{train_type = "label_folder"}. |
|
|
208 |
Input is a list with the following parameters |
|
|
209 |
\itemize{ |
|
|
210 |
\item \code{samples_per_target}: how many samples to use for one target. |
|
|
211 |
\item \code{maxlen}: length of one sample. |
|
|
212 |
\item \code{reshape_mode}: \verb{"time_dist", "multi_input"} or \code{"concat"}. |
|
|
213 |
\itemize{ |
|
|
214 |
\item |
|
|
215 |
If \code{reshape_mode} is \code{"multi_input"}, generator will produce \code{samples_per_target} separate inputs, each of length \code{maxlen} (model should have |
|
|
216 |
\code{samples_per_target} input layers). |
|
|
217 |
\item If reshape_mode is \code{"time_dist"}, generator will produce a 4D input array. The dimensions correspond to |
|
|
218 |
\verb{(batch_size, samples_per_target, maxlen, length(vocabulary))}. |
|
|
219 |
\item If \code{reshape_mode} is \code{"concat"}, generator will concatenate \code{samples_per_target} sequences |
|
|
220 |
of length \code{maxlen} to one long sequence. |
|
|
221 |
} |
|
|
222 |
\item If \code{reshape_mode} is \code{"concat"}, there is an additional \code{buffer_len} |
|
|
223 |
argument. If \code{buffer_len} is an integer, the subsequences are interspaced with \code{buffer_len} rows. The input length is |
|
|
224 |
(\code{maxlen} \eqn{*} \code{samples_per_target}) + \code{buffer_len} \eqn{*} (\code{samples_per_target} - 1). |
|
|
225 |
}} |
|
|
226 |
|
|
|
227 |
\item{proportion_entries}{Proportion of fasta entries to keep. For example, if fasta file has 50 entries and \code{proportion_entries = 0.1}, |
|
|
228 |
will randomly select 5 entries.} |
|
|
229 |
|
|
|
230 |
\item{sample_by_file_size}{Sample new file weighted by file size (bigger files more likely).} |
|
|
231 |
|
|
|
232 |
\item{n_gram}{Integer, encode target not nucleotide wise but combine n nucleotides at once. For example for \verb{n=2, "AA" -> (1, 0,..., 0),} |
|
|
233 |
\verb{"AC" -> (0, 1, 0,..., 0), "TT" -> (0,..., 0, 1)}, where the one-hot vectors have length \code{length(vocabulary)^n}.} |
|
|
234 |
|
|
|
235 |
\item{n_gram_stride}{Step size for n-gram encoding. For AACCGGTT with \code{n_gram = 4} and \code{n_gram_stride = 2}, generator encodes |
|
|
236 |
\verb{(AACC), (CCGG), (GGTT)}; for \code{n_gram_stride = 4} generator encodes \verb{(AACC), (GGTT)}.} |
|
|
237 |
|
|
|
238 |
\item{masked_lm}{If not \code{NULL}, input and target are equal except some parts of the input are masked or random. |
|
|
239 |
Must be list with the following arguments: |
|
|
240 |
\itemize{ |
|
|
241 |
\item \code{mask_rate}: Rate of input to mask (rate of input to replace with mask token). |
|
|
242 |
\item \code{random_rate}: Rate of input to set to random token. |
|
|
243 |
\item \code{identity_rate}: Rate of input where sample weights are applied but input and output are identical. |
|
|
244 |
\item \code{include_sw}: Whether to include sample weights. |
|
|
245 |
\item \code{block_len} (optional): Masked/random/identity regions appear in blocks of size \code{block_len}. |
|
|
246 |
}} |
|
|
247 |
|
|
|
248 |
\item{random_sampling}{Whether samples should be taken from random positions when using \code{max_samples} argument. If \code{FALSE} random |
|
|
249 |
samples are taken from a consecutive subsequence.} |
|
|
250 |
|
|
|
251 |
\item{add_noise}{\code{NULL} or list of arguments. If not \code{NULL}, list must contain the following arguments: \code{noise_type} can be \code{"normal"} or \code{"uniform"}; |
|
|
252 |
optional arguments \code{sd} or \code{mean} if noise_type is \code{"normal"} (default is \code{sd=1} and \code{mean=0}) or \verb{min, max} if \code{noise_type} is \code{"uniform"} |
|
|
253 |
(default is \verb{min=0, max=1}).} |
|
|
254 |
|
|
|
255 |
\item{return_int}{Whether to return integer encoding or one-hot encoding.} |
|
|
256 |
|
|
|
257 |
\item{maxlen}{Length of predictor sequence.} |
|
|
258 |
|
|
|
259 |
\item{reverse_complement}{Boolean, for every new file decide randomly to use original data or its reverse complement.} |
|
|
260 |
|
|
|
261 |
\item{reverse_complement_encoding}{Whether to use both original sequence and reverse complement as two input sequences.} |
|
|
262 |
|
|
|
263 |
\item{output_format}{Determines shape of output tensor for language model. |
|
|
264 |
Either \code{"target_right"}, \code{"target_middle_lstm"}, \code{"target_middle_cnn"} or \code{"wavenet"}. |
|
|
265 |
Assume a sequence \code{"AACCGTA"}. Output correspond as follows |
|
|
266 |
\itemize{ |
|
|
267 |
\item \verb{"target_right": X = "AACCGT", Y = "A"} |
|
|
268 |
\item \verb{"target_middle_lstm": X = (X_1 = "AAC", X_2 = "ATG"), Y = "C"} (note reversed order of X_2) |
|
|
269 |
\item \verb{"target_middle_cnn": X = "AACGTA", Y = "C"} |
|
|
270 |
\item \verb{"wavenet": X = "AACCGT", Y = "ACCGTA"} |
|
|
271 |
}} |
|
|
272 |
|
|
|
273 |
\item{proportion_per_seq}{Numerical value between 0 and 1. Proportion of sequence to take samples from (use random subsequence).} |
|
|
274 |
|
|
|
275 |
\item{read_data}{If \code{TRUE} the first element of output is a list of length 2, each containing one part of paired read. Maxlen should be 2*length of one read.} |
|
|
276 |
|
|
|
277 |
\item{use_quality_score}{Whether to use fastq quality scores. If \code{TRUE} input is not one-hot-encoding but corresponds to probabilities. |
|
|
278 |
For example (0.97, 0.01, 0.01, 0.01) instead of (1, 0, 0, 0).} |
|
|
279 |
|
|
|
280 |
\item{padding}{Whether to pad sequences too short for one sample with zeros.} |
|
|
281 |
|
|
|
282 |
\item{concat_seq}{Character string or \code{NULL}. If not \code{NULL} all entries from file get concatenated to one sequence with \code{concat_seq} string between them. |
|
|
283 |
Example: If 1.entry AACC, 2. entry TTTG and \code{concat_seq = "ZZZ"} this becomes AACCZZZTTTG.} |
|
|
284 |
|
|
|
285 |
\item{target_len}{Number of nucleotides to predict at once for language model.} |
|
|
286 |
|
|
|
287 |
\item{skip_amb_nuc}{Threshold of ambiguous nucleotides to accept in fasta entry. Complete entry will get discarded otherwise.} |
|
|
288 |
|
|
|
289 |
\item{max_samples}{Maximum number of samples to use from one file. If not \code{NULL} and file has more than \code{max_samples} samples, will randomly choose a |
|
|
290 |
subset of \code{max_samples} samples.} |
|
|
291 |
|
|
|
292 |
\item{added_label_path}{Path to file with additional input labels. Should be a csv file with one column named "file". Other columns should correspond to labels.} |
|
|
293 |
|
|
|
294 |
\item{add_input_as_seq}{Boolean vector specifying for each entry in \code{added_label_path} if rows from csv should be encoded as a sequence or used directly. |
|
|
295 |
If a row in your csv file is a sequence this should be \code{TRUE}. For example you may want to add another sequence, say ACCGT. Then this would correspond to 1,2,2,3,4 in |
|
|
296 |
csv file (if vocabulary = c("A", "C", "G", "T")). If \code{add_input_as_seq} is \code{TRUE}, 12234 gets one-hot encoded, so added input is a 3D tensor. If \code{add_input_as_seq} is |
|
|
297 |
\code{FALSE} this will feed network just raw data (a 2D tensor).} |
|
|
298 |
|
|
|
299 |
\item{target_from_csv}{Path to csv file with target mapping. One column should be called "file" and other entries in row are the targets.} |
|
|
300 |
|
|
|
301 |
\item{target_split}{If target gets read from csv file, list of names to divide target tensor into list of tensors. |
|
|
302 |
Example: if csv file has header names \verb{"file", "label_1", "label_2", "label_3"} and \code{target_split = list(c("label_1", "label_2"), "label_3")}, |
|
|
303 |
this will divide target matrix to list of length 2, where the first element contains columns named \code{"label_1"} and \code{"label_2"} and the |
|
|
304 |
second entry contains the column named \code{"label_3"}.} |
|
|
305 |
|
|
|
306 |
\item{shuffle_input}{Whether to shuffle entries in file.} |
|
|
307 |
|
|
|
308 |
\item{vocabulary_label}{Character vector of possible targets. Targets outside \code{vocabulary_label} will get discarded if |
|
|
309 |
\code{train_type = "label_header"}.} |
|
|
310 |
|
|
|
311 |
\item{delete_used_files}{Whether to delete file once used. Only applies for rds files.} |
|
|
312 |
|
|
|
313 |
\item{reshape_xy}{Can be a list of functions to apply to input and/or target. List elements (containing the reshape functions) |
|
|
314 |
must be called x for input or y for target and each have arguments called x and y. For example: |
|
|
315 |
\code{reshape_xy = list(x = function(x, y) {return(x+1)}, y = function(x, y) {return(x+y)})} . |
|
|
316 |
For rds generator needs to have an additional argument called sw.} |
|
|
317 |
|
|
|
318 |
\item{return_gen}{Whether to return the train and validation generators (instead of training).} |
|
|
319 |
} |
|
|
320 |
\value{ |
|
|
321 |
A list of training metrics. |
|
|
322 |
} |
|
|
323 |
\description{ |
|
|
324 |
Train a neural network on genomic data. Data can be fasta/fastq files, rds files or a prepared data set. |
|
|
325 |
If the data is given as collection of fasta, fastq or rds files, function will create a data generator that extracts training and validation batches |
|
|
326 |
from files. Function includes several options to determine the sampling strategy of the generator and preprocessing of the data. |
|
|
327 |
Training progress can be visualized in tensorboard. Model weights can be stored during training using checkpoints. |
|
|
328 |
} |
|
|
329 |
\examples{ |
|
|
330 |
\dontshow{if (reticulate::py_module_available("tensorflow")) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} |
|
|
331 |
# create dummy data |
|
|
332 |
path_train_1 <- tempfile() |
|
|
333 |
path_train_2 <- tempfile() |
|
|
334 |
path_val_1 <- tempfile() |
|
|
335 |
path_val_2 <- tempfile() |
|
|
336 |
|
|
|
337 |
for (current_path in c(path_train_1, path_train_2, |
|
|
338 |
path_val_1, path_val_2)) { |
|
|
339 |
dir.create(current_path) |
|
|
340 |
create_dummy_data(file_path = current_path, |
|
|
341 |
num_files = 3, |
|
|
342 |
seq_length = 10, |
|
|
343 |
num_seq = 5, |
|
|
344 |
vocabulary = c("a", "c", "g", "t")) |
|
|
345 |
} |
|
|
346 |
|
|
|
347 |
# create model |
|
|
348 |
model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 2, maxlen = 5) |
|
|
349 |
|
|
|
350 |
# train model |
|
|
351 |
hist <- train_model(train_type = "label_folder", |
|
|
352 |
model = model, |
|
|
353 |
path = c(path_train_1, path_train_2), |
|
|
354 |
path_val = c(path_val_1, path_val_2), |
|
|
355 |
batch_size = 8, |
|
|
356 |
epochs = 3, |
|
|
357 |
steps_per_epoch = 6, |
|
|
358 |
step = 5, |
|
|
359 |
format = "fasta", |
|
|
360 |
vocabulary_label = c("label_1", "label_2")) |
|
|
361 |
|
|
|
362 |
\dontshow{\}) # examplesIf} |
|
|
363 |
} |