#' @title Train neural network on genomic data
#'
#' @description
#' Train a neural network on genomic data. Data can be fasta/fastq files, rds files or a prepared data set.
#' If the data is given as a collection of fasta, fastq or rds files, the function will create a data generator that extracts training and validation batches
#' from the files. The function includes several options to determine the sampling strategy of the generator and the preprocessing of the data.
#' Training progress can be visualized in tensorboard. Model weights can be stored during training using checkpoints.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_folder
#' @inheritParams generator_fasta_label_header_csv
#' @inheritParams get_generator
#' @param train_type Either `"lm"`, `"lm_rds"`, `"masked_lm"` for language model; `"label_header"`, `"label_folder"`, `"label_csv"`, `"label_rds"` for classification or `"dummy_gen"`.
#' \itemize{
#' \item A language model is trained to predict character(s) in a sequence. \cr
#' \item `"label_header"`/`"label_folder"`/`"label_csv"` train a model to predict a corresponding class given a sequence as input.
#' \item If `"label_header"`, the class will be read from fasta headers.
#' \item If `"label_folder"`, the class will be read from the folder, i.e. all files in one folder must belong to the same class.
#' \item If `"label_csv"`, targets are read from a csv file. This file should have one column named "file". The targets then correspond to entries in that row (except the "file"
#' column). Example: if we are currently working with a file called "a.fasta" and the corresponding label is "label_1", there should be a row in our csv file
#' (a code sketch for creating such a file follows after this list)
#'
#' | file | label_1 | label_2 |
#' | --- | --- | --- |
#' | "a.fasta" | 1 | 0 |
#'
#' \item If `"label_rds"`, the generator will iterate over a set of .rds files, each containing a list of input and target tensors. Not implemented for models
#' with multiple inputs.
#' \item If `"lm_rds"`, the generator will iterate over a set of .rds files and split each tensor according to the `target_len` argument
#' (targets are the last `target_len` nucleotides of each sequence).
#' \item If `"dummy_gen"`, the generator creates random data once and repeatedly feeds these to the model.
#' \item If `"masked_lm"`, the generator masks some parts of the input. See the `masked_lm` argument for details.
#' }
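#'
#' As an illustrative sketch (file and label names are made up), a csv file for `"label_csv"` could be created like this:
#'
#' ```
#' label_df <- data.frame(file = c("a.fasta", "b.fasta"),
#'                        label_1 = c(1, 0),
#'                        label_2 = c(0, 1))
#' # row.names = FALSE so the only columns are "file" and the targets
#' write.csv(label_df, "labels.csv", row.names = FALSE)
#' ```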
#' @param model A keras model.
#' @param path Path to training data. If \code{train_type} is \code{label_folder}, should be a vector or list
#' where each entry corresponds to a class (list elements can be directories and/or individual files). If \code{train_type} is not \code{label_folder},
#' can be a single directory or file or a list of directories and/or files.
#' @param path_val Path to validation data. See `path` argument for details.
#' @param dataset List of training data holding training samples in RAM instead of using a generator. Should be a list with two entries called `"X"` and `"Y"`.
#' @param dataset_val List of validation data. Should have two entries called `"X"` and `"Y"`.
#' @param path_checkpoint Path to checkpoints folder or `NULL`. If `NULL`, checkpoints don't get stored.
#' @param path_log Path to directory to write training scores. File name is `run_name` + `".csv"`. No output if `NULL`.
#' @param train_val_ratio For the generator, defines the fraction of batches that will be used for validation (compared to the size of the training data), i.e. one validation iteration
#' processes \code{batch_size} \eqn{*} \code{steps_per_epoch} \eqn{*} \code{train_val_ratio} samples (with the default settings, 64 \eqn{*} 1000 \eqn{*} 0.2 = 12800 samples).
#' If you use a dataset instead of a generator and \code{dataset_val} is `NULL`, splits \code{dataset}
#' into train/validation data.
#' @param run_name Name of the run. Name will be used to identify output from callbacks. If `NULL`, will use date as run name.
#' If name is already present, will add `"_2"` to name or `"_{x+1}"` if name ends with `_x`, where `x` is some integer.
#' @param batch_size Number of samples used for one network update.
#' @param epochs Number of training epochs.
#' @param max_queue_size Maximum size for the generator queue.
#' @param reduce_lr_on_plateau Whether to use a learning rate scheduler.
#' @param lr_plateau_factor Factor by which the learning rate is reduced when a plateau is reached.
#' @param patience Number of epochs waiting for decrease in validation loss before reducing learning rate.
#' @param cooldown Number of epochs without changing the learning rate after a reduction.
#' @param steps_per_epoch Number of training batches per epoch.
#' @param step How often to take a sample, i.e. the distance between the start positions of two successive samples.
#' @param shuffle_file_order Boolean, whether to go through files sequentially or shuffle beforehand.
#' @param vocabulary Vector of allowed characters. Characters outside vocabulary get encoded as specified in \code{ambiguous_nuc}.
#' @param initial_epoch Epoch at which to start training. Note that network
#' will run for (\code{epochs} - \code{initial_epoch}) rounds and not \code{epochs} rounds.
#' @param path_tensorboard Path to tensorboard directory or `NULL`. If `NULL`, training not tracked on tensorboard.
#' @param save_best_only Only save a model that improved on some score. Not applied if argument is `NULL`. Otherwise must be a
#' list with an entry named `monitor` or `save_freq` (can only use one option). `monitor` specifies what metric to use.
#' `save_freq` is an integer specifying how often to store a checkpoint (in epochs).
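#' For example, a call could use `save_best_only = list(monitor = "val_loss")` (illustrative; the metric name depends on how the model was compiled).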
#' @param save_weights_only Whether to save weights only.
#' @param seed Sets seed for reproducible results.
#' @param shuffle_input Whether to shuffle entries within a file.
#' @param tb_images Whether to show custom images (confusion matrix) in tensorboard "IMAGES" tab.
#' @param format File format, `"fasta"`, `"fastq"`, `"rds"` or `"fasta.tar.gz"`, `"fastq.tar.gz"` for `tar.gz` files.
#' @param path_file_log Path of a csv file to log the names of files used for training; no logging if `NULL`.
#' @param vocabulary_label Character vector of possible targets. Targets outside \code{vocabulary_label} will get discarded if
#' \code{train_type = "label_header"}.
#' @param file_limit Integer or `NULL`. If integer, use only the specified number of randomly sampled files for training. Ignored if greater than the number of files in \code{path}.
#' @param reverse_complement_encoding Whether to use both the original sequence and its reverse complement as two input sequences.
#' @param output_format Determines the shape of the output tensor for a language model.
#' Either `"target_right"`, `"target_middle_lstm"`, `"target_middle_cnn"` or `"wavenet"`.
#' Assume a sequence `"AACCGTA"`. Outputs correspond as follows
#' \itemize{
#' \item `"target_right": X = "AACCGT", Y = "A"`
#' \item `"target_middle_lstm": X = (X_1 = "AAC", X_2 = "ATG"), Y = "C"` (note the reversed order of X_2)
#' \item `"target_middle_cnn": X = "AACGTA", Y = "C"`
#' \item `"wavenet": X = "AACCGT", Y = "ACCGTA"`
#' }
#' @param reset_states Whether to reset hidden states of RNN layer at every new input file and before/after validation.
#' @param use_quality_score Whether to use fastq quality scores. If `TRUE`, the input is not a one-hot encoding but corresponds to probabilities.
#' For example (0.97, 0.01, 0.01, 0.01) instead of (1, 0, 0, 0).
#' @param padding Whether to pad sequences too short for one sample with zeros.
#' @param early_stopping_time Time in seconds after which to stop training.
#' @param validation_only_after_training Whether to skip validation during training and only do one validation iteration after training.
#' @param skip_amb_nuc Maximum fraction of ambiguous nucleotides to accept in a fasta entry. Entries above the threshold get discarded.
#' @param class_weight List of weights for output. Order should correspond to \code{vocabulary_label}.
#' You can use the \code{\link{get_class_weight}} function to estimate class weights:
#'
#' \code{class_weights <- get_class_weight(path = path, train_type = train_type)}
#'
#' If \code{train_type = "label_csv"} you need to add the path to the csv file:
#'
#' \code{class_weights <- get_class_weight(path = path, train_type = train_type, csv_path = target_from_csv)}
#' @param print_scores Whether to print train/validation scores during training.
#' @param train_val_split_csv A csv file specifying the train/validation split. The csv file should contain one column named `"file"` and one column named
#' `"type"`. The `"file"` column contains names of fasta/fastq files and the `"type"` column specifies whether the file is used for training or validation.
#' Entries in `"type"` must be named `"train"` or `"val"`, otherwise the file will not be used for either. `path` and `path_val` arguments should be the same.
#' Not implemented for `train_type = "label_folder"`.
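#'
#' A minimal sketch of such a split file (file names are made up):
#'
#' ```
#' split_df <- data.frame(file = c("a.fasta", "b.fasta", "c.fasta"),
#'                        type = c("train", "train", "val"))
#' write.csv(split_df, "train_val_split.csv", row.names = FALSE)
#' ```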
#' @param set_learning When you want to assign one label to a set of samples. Only implemented for `train_type = "label_folder"`.
#' Input is a list with the following parameters (see the sketch after this list)
#' \itemize{
#' \item `samples_per_target`: how many samples to use for one target.
#' \item `maxlen`: length of one sample.
#' \item `reshape_mode`: `"time_dist"`, `"multi_input"` or `"concat"`.
#' \itemize{
#' \item If `reshape_mode` is `"multi_input"`, generator will produce `samples_per_target` separate inputs, each of length `maxlen` (model should have
#' `samples_per_target` input layers).
#' \item If `reshape_mode` is `"time_dist"`, generator will produce a 4D input array. The dimensions correspond to
#' `(batch_size, samples_per_target, maxlen, length(vocabulary))`.
#' \item If `reshape_mode` is `"concat"`, generator will concatenate `samples_per_target` sequences
#' of length `maxlen` to one long sequence.
#' }
#' \item If `reshape_mode` is `"concat"`, there is an additional `buffer_len`
#' argument. If `buffer_len` is an integer, the subsequences are separated by `buffer_len` rows. The input length is
#' (`maxlen` \eqn{*} `samples_per_target`) + `buffer_len` \eqn{*} (`samples_per_target` - 1).
#' }
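#'
#' As an illustrative sketch, a `"concat"` configuration could look like this:
#'
#' ```
#' # 4 samples of length 50 per target; with buffer_len = 1 the
#' # concatenated input length is (50 * 4) + 1 * (4 - 1) = 203
#' set_learning <- list(samples_per_target = 4,
#'                      maxlen = 50,
#'                      reshape_mode = "concat",
#'                      buffer_len = 1)
#' ```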
#' @param random_sampling Whether samples should be taken from random positions when using the `max_samples` argument. If `FALSE`,
#' samples are taken from a consecutive subsequence.
#' @param n_gram_stride Step size for n-gram encoding. For AACCGGTT with `n_gram = 4` and `n_gram_stride = 2`, generator encodes
#' `(AACC), (CCGG), (GGTT)`; for `n_gram_stride = 4` generator encodes `(AACC), (GGTT)`.
#' @param callback_list Add additional callbacks to the `keras::fit` call.
#' @param model_card List of arguments describing the training run. Must contain at least an entry `path_model_card`, i.e. the
#' directory where parameters are stored. List can contain additional (optional) arguments, for example
#' `model_card = list(path_model_card = "/path/to/logs", description = "transfer learning with BERT model on virus data", ...)`
#' @param return_gen Whether to return the train and validation generators (instead of training).
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create dummy data
#' path_train_1 <- tempfile()
#' path_train_2 <- tempfile()
#' path_val_1 <- tempfile()
#' path_val_2 <- tempfile()
#'
#' for (current_path in c(path_train_1, path_train_2,
#'                        path_val_1, path_val_2)) {
#'   dir.create(current_path)
#'   create_dummy_data(file_path = current_path,
#'                     num_files = 3,
#'                     seq_length = 10,
#'                     num_seq = 5,
#'                     vocabulary = c("a", "c", "g", "t"))
#' }
#'
#' # create model
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 2, maxlen = 5)
#'
#' # train model
#' hist <- train_model(train_type = "label_folder",
#'                     model = model,
#'                     path = c(path_train_1, path_train_2),
#'                     path_val = c(path_val_1, path_val_2),
#'                     batch_size = 8,
#'                     epochs = 3,
#'                     steps_per_epoch = 6,
#'                     step = 5,
#'                     format = "fasta",
#'                     vocabulary_label = c("label_1", "label_2"))
#'
#' @returns A list of training metrics.
#' @export
train_model <- function(model = NULL,
                        dataset = NULL,
                        dataset_val = NULL,
                        # training args
                        train_val_ratio = 0.2,
                        run_name = "run_1",
                        initial_epoch = 0,
                        class_weight = NULL,
                        print_scores = TRUE,
                        epochs = 10,
                        max_queue_size = 100,
                        steps_per_epoch = 1000,
                        # callbacks
                        path_checkpoint = NULL,
                        path_tensorboard = NULL,
                        path_log = NULL,
                        save_best_only = NULL,
                        save_weights_only = FALSE,
                        tb_images = FALSE,
                        path_file_log = NULL,
                        reset_states = FALSE,
                        early_stopping_time = NULL,
                        validation_only_after_training = FALSE,
                        train_val_split_csv = NULL,
                        reduce_lr_on_plateau = TRUE,
                        lr_plateau_factor = 0.9,
                        patience = 20,
                        cooldown = 1,
                        model_card = NULL,
                        callback_list = NULL,
                        # generator args
                        train_type = "label_folder",
                        path = NULL,
                        path_val = NULL,
                        batch_size = 64,
                        step = NULL,
                        shuffle_file_order = TRUE,
                        vocabulary = c("a", "c", "g", "t"),
                        format = "fasta",
                        ambiguous_nuc = "zero",
                        seed = c(1234, 4321),
                        file_limit = NULL,
                        use_coverage = NULL,
                        set_learning = NULL,
                        proportion_entries = NULL,
                        sample_by_file_size = FALSE,
                        n_gram = NULL,
                        n_gram_stride = 1,
                        masked_lm = NULL,
                        random_sampling = FALSE,
                        add_noise = NULL,
                        return_int = FALSE,
                        maxlen = NULL,
                        reverse_complement = FALSE,
                        reverse_complement_encoding = FALSE,
                        output_format = "target_right",
                        proportion_per_seq = NULL,
                        read_data = FALSE,
                        use_quality_score = FALSE,
                        padding = FALSE,
                        concat_seq = NULL,
                        target_len = 1,
                        skip_amb_nuc = NULL,
                        max_samples = NULL,
                        added_label_path = NULL,
                        add_input_as_seq = NULL,
                        target_from_csv = NULL,
                        target_split = NULL,
                        shuffle_input = TRUE,
                        vocabulary_label = NULL,
                        delete_used_files = FALSE,
                        reshape_xy = NULL,
                        return_gen = FALSE) {

  if (!is.null(model_card)) {
    if (!is.list(model_card)) {
      stop("model_card must be a list and contain at least an entry called 'path_model_card'")
    }
  }

  # initialize metrics, temporary fix
  model <- manage_metrics(model)

  run_name <- get_run_name(run_name, path_tensorboard, path_checkpoint, path_log,
                           path_model_card = model_card$path_model_card,
                           auto_extend = TRUE)
  train_with_gen <- is.null(dataset)
  output <- list(tensorboard = FALSE, checkpoints = FALSE)
  if (!is.null(path_tensorboard)) output$tensorboard <- TRUE
  if (!is.null(path_checkpoint)) output$checkpoints <- TRUE
  wavenet_format <- FALSE; target_middle <- FALSE; cnn_format <- FALSE
  if (train_type != "label_csv") target_from_csv <- NULL

  if (train_with_gen) {
    stopifnot(train_type %in% c("lm", "label_header", "label_folder", "label_csv", "label_rds", "lm_rds", "dummy_gen", "masked_lm"))
    stopifnot(ambiguous_nuc %in% c("zero", "equal", "discard", "empirical"))
    stopifnot(length(vocabulary) == length(unique(vocabulary)))
    stopifnot(length(vocabulary_label) == length(unique(vocabulary_label)))
    labelByFolder <- FALSE
    labelGen <- train_type != "lm"

    if (train_type == "label_header") target_from_csv <- NULL
    if (train_type == "label_csv") {
      #train_type <- "label_header"
      if (is.null(target_from_csv)) {
        stop('You need to add a path to csv file for target_from_csv when using train_type = "label_csv"')
      }
      if (!is.null(vocabulary_label)) {
        message("Reading vocabulary_label from csv header")
        if (!is.data.frame(target_from_csv)) {
          output_label_csv <- utils::read.csv2(target_from_csv, header = TRUE, stringsAsFactors = FALSE)
          # fall back to comma separator if semicolon parsing yields a single column
          if (dim(output_label_csv)[2] == 1) {
            output_label_csv <- utils::read.csv(target_from_csv, header = TRUE, stringsAsFactors = FALSE)
          }
        } else {
          output_label_csv <- target_from_csv
        }
        vocabulary_label <- names(output_label_csv)
        vocabulary_label <- vocabulary_label[vocabulary_label != "file"]
      }
    }

    if (!is.null(skip_amb_nuc)) {
      if ((skip_amb_nuc > 1) | (skip_amb_nuc < 0)) {
        stop("skip_amb_nuc should be between 0 and 1 or NULL")
      }
    }

    if (!is.null(proportion_per_seq)) {
      if (any(proportion_per_seq > 1) | any(proportion_per_seq < 0)) {
        stop("proportion_per_seq should be between 0 and 1 or NULL")
      }
    }

    # TODO: adjust for multi output model
    # if (!is.null(class_weight) && (length(class_weight) != length(vocabulary_label))) {
    #   stop("class_weight and vocabulary_label must have same length")
    # }

    if (!is.null(concat_seq)) {
      if (!is.null(use_coverage)) stop("Coverage encoding not implemented for concat_seq")
    }

    # train/validation split via csv file
    if (!is.null(train_val_split_csv)) {

      train_val_file <- utils::read.csv2(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)

      if (is.null(path)) {
        path <- train_val_file %>% dplyr::filter(type %in% c("train", "val", "validation")) %>%
          dplyr::select(file) %>% as.list()
      }

      if (train_type == "label_folder") {
        stop('train_val_split_csv not implemented for train_type = "label_folder"')
      }
      if (is.null(path_val)) {
        path_val <- path
      } else {
        if (!all(unlist(path_val) %in% unlist(path))) {
          warning("Train/validation split done via file in train_val_split_csv. Only using files from path argument.")
        }
        path_val <- path
      }

      # fall back to comma separator if semicolon parsing yields a single column
      if (dim(train_val_file)[2] == 1) {
        train_val_file <- utils::read.csv(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      }
      train_val_file <- dplyr::distinct(train_val_file)

      if (!all(c("file", "type") %in% names(train_val_file))) {
        stop("Column names of train_val_split_csv file must be 'file' and 'type'")
      }

      if (length(train_val_file$file) != length(unique(train_val_file$file))) {
        stop("In train_val_split_csv all entries in 'file' column must be unique")
      }

      train_files <- train_val_file %>% dplyr::filter(type == "train")
      train_files <- as.character(train_files$file)
      val_files <- train_val_file %>% dplyr::filter(type == "val" | type == "validation")
      val_files <- as.character(val_files$file)
    } else {
      train_files <- NULL
      val_files <- NULL
    }

    if (train_type == "lm") {
      stopifnot(output_format %in% c("target_right", "target_middle_lstm", "target_middle_cnn", "wavenet"))
      if (output_format == "target_middle_lstm") target_middle <- TRUE
      if (output_format == "target_middle_cnn") cnn_format <- TRUE
      if (output_format == "wavenet") wavenet_format <- TRUE
    }

    if (train_type == "label_header" & is.null(target_from_csv)) {
      stopifnot(!is.null(vocabulary_label))
    }

    if (train_type == "label_folder") {
      labelByFolder <- TRUE
      stopifnot(!is.null(vocabulary_label))
      stopifnot(length(path) == length(vocabulary_label))
    }

  }

  model_weights <- model$get_weights()

  # function arguments
  argumentList <- as.list(match.call(expand.dots = FALSE))
  # argumentList <- c(as.list(environment()), list(...)) # would log default args too
  argumentList <- argumentList[names(argumentList) != ""]
  argumentList <- lapply(argumentList, eval, envir = parent.frame())

  # extract maxlen from model
  if (is.null(maxlen)) {
    maxlen <- get_maxlen(model, set_learning, target_middle, read_data)
  }

  if (is.null(step)) step <- maxlen
  vocabulary_label_size <- length(vocabulary_label)
  vocabulary_size <- length(vocabulary)

  if (is.null(dataset) && labelByFolder) {
    if (length(path) == 1) warning("Training with just one label")
  }

  # add empty hparam dict if none exists
  if (!reticulate::py_has_attr(model, "hparam")) {
    model$hparam <- reticulate::dict()
  }

  # temporary file to log training data
  removeLog <- FALSE
  if (is.null(path_file_log)) {
    removeLog <- TRUE
    path_file_log <- tempfile(pattern = "", fileext = ".csv")
  } else {
    if (!endsWith(path_file_log, ".csv")) path_file_log <- paste0(path_file_log, ".csv")
    #path_file_logVal <- tempfile(pattern = "", fileext = ".csv")
  }
  if (reset_states) {
    path_file_logVal <- tempfile(pattern = "", fileext = ".csv")
  } else {
    path_file_logVal <- NULL
  }

  # if no dataset is supplied, external fasta generator will generate batches
  if (train_with_gen) {
    #message("Starting fasta generator...")

    gen <- get_generator(path = path, batch_size = batch_size, model = model,
                         maxlen = maxlen, step = step, shuffle_file_order = shuffle_file_order,
                         vocabulary = vocabulary, seed = seed[1], proportion_entries = proportion_entries,
                         shuffle_input = shuffle_input, format = format, reshape_xy = reshape_xy,
                         path_file_log = path_file_log, reverse_complement = reverse_complement, n_gram_stride = n_gram_stride,
                         output_format = output_format, ambiguous_nuc = ambiguous_nuc,
                         proportion_per_seq = proportion_per_seq, skip_amb_nuc = skip_amb_nuc,
                         use_quality_score = use_quality_score, padding = padding, n_gram = n_gram,
                         added_label_path = added_label_path, add_input_as_seq = add_input_as_seq,
                         max_samples = max_samples, concat_seq = concat_seq, target_len = target_len,
                         file_filter = train_files, use_coverage = use_coverage, random_sampling = random_sampling,
                         train_type = train_type, set_learning = set_learning, file_limit = file_limit,
                         reverse_complement_encoding = reverse_complement_encoding, read_data = read_data,
                         sample_by_file_size = sample_by_file_size, add_noise = add_noise, target_split = target_split,
                         target_from_csv = target_from_csv, masked_lm = masked_lm, return_int = return_int,
                         path_file_logVal = path_file_logVal, delete_used_files = delete_used_files,
                         vocabulary_label = vocabulary_label, val = FALSE)

    if (!is.null(path_val)) {

      gen.val <- get_generator(path = path_val, batch_size = batch_size, model = model,
                               maxlen = maxlen, step = step, shuffle_file_order = shuffle_file_order,
                               vocabulary = vocabulary, seed = seed[2], proportion_entries = proportion_entries,
                               shuffle_input = shuffle_input, format = format, delete_used_files = FALSE,
                               path_file_log = path_file_logVal, reverse_complement = reverse_complement, n_gram_stride = n_gram_stride,
                               output_format = output_format, ambiguous_nuc = ambiguous_nuc, reshape_xy = reshape_xy,
                               proportion_per_seq = proportion_per_seq, skip_amb_nuc = skip_amb_nuc,
                               use_quality_score = use_quality_score, padding = padding, n_gram = n_gram,
                               added_label_path = added_label_path, add_input_as_seq = add_input_as_seq,
                               max_samples = max_samples, concat_seq = concat_seq, target_len = target_len,
                               file_filter = val_files, use_coverage = use_coverage, random_sampling = random_sampling,
                               train_type = train_type, set_learning = set_learning, file_limit = file_limit,
                               reverse_complement_encoding = reverse_complement_encoding, read_data = read_data,
                               sample_by_file_size = sample_by_file_size, add_noise = add_noise, target_split = target_split,
                               target_from_csv = target_from_csv, masked_lm = masked_lm, return_int = return_int,
                               path_file_logVal = path_file_logVal, vocabulary_label = vocabulary_label,
                               val = TRUE)
    } else {
      gen.val <- NULL
    }

  }

  # skip validation during training if requested
  if (validation_only_after_training || is.null(train_val_ratio) || train_val_ratio == 0) {
    validation_data <- NULL
  } else {
    if (train_with_gen) {
      if (is.null(path_val)) {
        validation_data <- NULL
      } else {
        validation_data <- gen.val
      }
    } else {
      validation_data <- dataset_val
    }
  }

  if (is.null(validation_data)) {
    validation_steps <- NULL
  } else {
    validation_steps <- ceiling(steps_per_epoch * train_val_ratio)
  }

  callbacks <- get_callbacks(default_arguments = NULL, model = model, path_tensorboard = path_tensorboard, run_name = run_name, train_type = train_type,
                             path = path, train_val_ratio = train_val_ratio, batch_size = batch_size, epochs = epochs,
                             max_queue_size = max_queue_size, lr_plateau_factor = lr_plateau_factor, patience = patience, cooldown = cooldown, format = format,
                             steps_per_epoch = steps_per_epoch, step = step, shuffle_file_order = shuffle_file_order, initial_epoch = initial_epoch, vocabulary = vocabulary,
                             learning_rate = model$optimizer$learning_rate$numpy(), solver = stringr::str_to_lower(model$optimizer$get_config()["name"]),
                             shuffle_input = shuffle_input, vocabulary_label = vocabulary_label,
                             file_limit = file_limit, reverse_complement = reverse_complement, wavenet_format = wavenet_format, cnn_format = cnn_format,
                             train_val_split_csv = train_val_split_csv, n_gram = n_gram, path_file_logVal = path_file_logVal, validation_steps = validation_steps,
                             create_model_function = NULL, vocabulary_size = vocabulary_size, gen_cb = NULL, argumentList = argumentList, output = output,
                             maxlen = maxlen, labelGen = labelGen, labelByFolder = labelByFolder, vocabulary_label_size = vocabulary_label_size, tb_images = tb_images,
                             target_middle = target_middle, path_file_log = path_file_log, proportion_per_seq = proportion_per_seq,
                             skip_amb_nuc = skip_amb_nuc, max_samples = max_samples, proportion_entries = proportion_entries, path_log = path_log,
                             train_with_gen = train_with_gen, random_sampling = random_sampling, reduce_lr_on_plateau = reduce_lr_on_plateau,
                             save_weights_only = save_weights_only, path_checkpoint = path_checkpoint, save_best_only = save_best_only, gen.val = gen.val,
                             target_from_csv = target_from_csv, reset_states = reset_states, early_stopping_time = early_stopping_time,
                             validation_only_after_training = validation_only_after_training, model_card = model_card, dataset_val = dataset_val)

  # training
  if (train_with_gen) {

    if (!is.null(dataset_val)) {
      validation_data <- dataset_val
      validation_steps <- NULL
    }

    if (return_gen) {
      return(list(gen = gen, gen.val = gen.val))
    }

    model <- keras::set_weights(model, model_weights)
    history <-
      model %>% keras::fit(
        x = gen,
        validation_data = validation_data,
        validation_steps = validation_steps,
        steps_per_epoch = steps_per_epoch,
        max_queue_size = max_queue_size,
        epochs = epochs,
        initial_epoch = initial_epoch,
        callbacks = c(callbacks, callback_list),
        class_weight = class_weight,
        batch_size = batch_size,
        verbose = print_scores)

    if (validation_only_after_training) {
      history$val_loss <- model$val_loss
      history$val_acc <- model$val_acc
      model$val_loss <- NULL
      model$val_acc <- NULL
    }

  } else {

    model <- keras::set_weights(model, model_weights)
    if (!is.null(dataset_val)) {
      validation_data <- list(dataset_val[[1]], dataset_val[[2]])
    } else {
      validation_data <- NULL
    }

    history <- keras::fit(
      object = model,
      x = dataset[[1]],
      y = dataset[[2]],
      batch_size = batch_size,
      validation_split = train_val_ratio,
      validation_data = validation_data,
      callbacks = c(callbacks, callback_list),
      epochs = epochs,
      class_weight = class_weight,
      verbose = print_scores)
  }

  if (removeLog && file.exists(path_file_log)) {
    file.remove(path_file_log)
  }

  message("Training done.")

  return(history)
}
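
# Illustrative sketch (not run): with return_gen = TRUE, train_model() returns
# the train/validation generators instead of training, which can help with
# inspecting batches. This assumes the generators are callable functions that
# yield one batch per call (arguments as in the roxygen example above):
#
# gens <- train_model(train_type = "label_folder", model = model,
#                     path = c(path_train_1, path_train_2),
#                     path_val = c(path_val_1, path_val_2),
#                     batch_size = 8, steps_per_epoch = 6, step = 5,
#                     format = "fasta",
#                     vocabulary_label = c("label_1", "label_2"),
#                     return_gen = TRUE)
# batch <- gens$gen()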

#' Generate run_name if none is given or the name is already present.
#'
#' If no run name is given, will use the date as run name. If the run name is already present, will add _2 to the name or
#' _{x+1} if the name ends with _x, where x is an integer.
#'
#' @param auto_extend If run_name is already present, add "_2" to name. If name already ends with "_x" replace x with x+1.
#' @noRd
get_run_name <- function(run_name = NULL, path_tensorboard, path_checkpoint, path_log, path_model_card, auto_extend = FALSE) {

  if (is.null(run_name)) {
    # use the current date/time as run name, so the presence checks below apply to it as well
    run_name <- as.character(Sys.time()) %>% stringr::str_replace_all(" ", "_")
  }

  tb_names <- ""
  cp_names <- ""
  log_names <- ""
  mc_names <- ""
  name_present_tb <- FALSE
  name_present_cp <- FALSE
  name_present_log <- FALSE
  name_present_mc <- FALSE

  if (!is.null(path_tensorboard)) {
    tb_names <- list.files(path_tensorboard)
    name_present_tb <- (run_name %in% tb_names) # & any(stringr::str_detect(tb_names, run_name))
  }
  if (!is.null(path_checkpoint)) {
    cp_names <- list.files(path_checkpoint)
    name_present_cp <- (run_name %in% cp_names) # & any(stringr::str_detect(cp_names, run_name))
  }
  if (!is.null(path_log)) {
    log_names <- list.files(path_log)
    name_present_log <- (run_name %in% log_names) # & any(stringr::str_detect(log_names, run_name))
  }
  if (!is.null(path_model_card)) {
    mc_names <- list.files(path_model_card)
    name_present_mc <- (run_name %in% mc_names) # & any(stringr::str_detect(mc_names, run_name))
  }

  name_present <- name_present_tb | name_present_cp | name_present_log | name_present_mc

  if (name_present & auto_extend) {

    ends_with_int <- stringr::str_detect(run_name, "_\\d+$")
    if (ends_with_int) {
      int_ending <- stringr::str_extract(run_name, "\\d+$") %>% as.integer()
      run_name_new <- paste0(stringr::str_remove(run_name, "\\d+$"), int_ending + 1)
    } else {
      run_name_new <- paste0(run_name, "_2")
    }

    int_ending <- stringr::str_subset(c(tb_names, cp_names, log_names, mc_names),
                                      paste0("^", stringr::str_remove(run_name, "_\\d+$"))) %>% unique()
    int_ending <- stringr::str_subset(int_ending, "_\\d+$")
    if (length(int_ending) > 0) {
      max_int_ending <- stringr::str_extract(int_ending, "_\\d+$") %>% stringr::str_remove("_") %>% as.integer() %>% max()
      if (!ends_with_int) {
        run_name_new <- paste0(run_name, "_", max_int_ending + 1)
      } else {
        run_name_new <- paste0(stringr::str_remove(run_name, "\\d+$"), max_int_ending + 1)
      }
    }

    if (length(int_ending) > 0) {
      name_order <- stringr::str_extract(int_ending, "\\d+$") %>% as.integer() %>% order()
      prev_names <- unique(c(run_name, int_ending[name_order]))
      if (ends_with_int) {
        name_order <- stringr::str_extract(prev_names, "\\d+$") %>% as.integer() %>% order()
        prev_names <- prev_names[name_order]
      }

      if (length(prev_names) > 8) {
        old_names_start <- paste(prev_names[1:2], collapse = ", ")
        old_names_end <- paste(prev_names[(length(prev_names) - 1):length(prev_names)], collapse = ", ")
        #old_names <- paste(old_names_start, ",...,", old_names_end) # outputs range of previously used names
        old_names <- run_name
      } else {
        old_names <- paste(prev_names, collapse = ", ")
      }
      message(paste("run_name", old_names, "already present, setting run_name to", run_name_new))
    } else {
      message(paste("run_name", run_name, "already present, setting run_name to", run_name_new))
    }
  }

  if (name_present & !auto_extend) {
    stop("run_name already present, please give your run a unique name")
  }

  if (!name_present) {
    return(run_name)
  }

  return(run_name_new)
}
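
# Illustrative behavior of the auto-extension (assuming "run_1" already has
# output in the tensorboard directory tb_dir and auto_extend = TRUE):
#
# get_run_name("run_1", path_tensorboard = tb_dir, path_checkpoint = NULL,
#              path_log = NULL, path_model_card = NULL, auto_extend = TRUE)
# # -> "run_2"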

#' Continue training from model card
#'
#' Use information from a model card to resume training from the corresponding checkpoint with the same training arguments.
#'
#' @param path_model_card Path to model card to resume training from.
#' @param seed Seed for reproducible results. If `NULL`, a random seed is chosen.
#' @param epoch Epoch to resume from. If `NULL`, use last epoch.
#' @param new_run_name New run name. If `NULL`, new run name is old run name + '_cont'.
#' @param new_args Named list of arguments to overwrite. Will use previous arguments from model card otherwise.
#' For example, if you want to change the batch size and padding option:
#' `new_args = list(batch_size = 6, padding = TRUE)`.
#' @param new_compile List of arguments to compile the model again. If `NULL`, use compiled model from checkpoint.
#' Example: `new_compile = list(loss = 'binary_crossentropy', metrics = 'acc', optimizer = keras::optimizer_adam())`
#' @param use_mirrored_strategy Whether to use distributed mirrored strategy.
#' If `NULL`, will use distributed mirrored strategy only if more than one GPU is available.
#' @param unfreeze If `TRUE`, set trainable attribute of model to `TRUE` (unfreeze weights).
#' @param verbose Whether to print all training arguments.
#' @examples
#' \donttest{
#' library(keras)
#' # create dummy data and temp directories
#' path_train_1 <- tempfile()
#' path_train_2 <- tempfile()
#' path_val_1 <- tempfile()
#' path_val_2 <- tempfile()
#' path_checkpoint <- tempfile()
#' dir.create(path_checkpoint)
#' path_model_card <- tempfile()
#' dir.create(path_model_card)
#'
#' for (current_path in c(path_train_1, path_train_2,
#'                        path_val_1, path_val_2)) {
#'   dir.create(current_path)
#'   create_dummy_data(file_path = current_path,
#'                     num_files = 3,
#'                     seq_length = 10,
#'                     num_seq = 5,
#'                     vocabulary = c("a", "c", "g", "t"))
#' }
#'
#' # create model
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 2, maxlen = 5)
#'
#' # train model
#' run_name <- 'test_run_1'
#' hist <- train_model(train_type = "label_folder",
#'                     run_name = run_name,
#'                     path_checkpoint = path_checkpoint,
#'                     model_card = list(path_model_card = path_model_card, description = 'test run'),
#'                     model = model,
#'                     path = c(path_train_1, path_train_2),
#'                     path_val = c(path_val_1, path_val_2),
#'                     batch_size = 8,
#'                     epochs = 3,
#'                     steps_per_epoch = 6,
#'                     vocabulary_label = c("label_1", "label_2"))
#'
#' # resume training
#' resume_training_from_model_card(path_model_card = file.path(path_model_card, run_name))
#' }
#' @returns A list of training metrics.
#' @export
resume_training_from_model_card <- function(path_model_card,
                                            seed = NULL,
                                            epoch = NULL,
                                            new_run_name = NULL,
                                            new_args = NULL,
                                            new_compile = NULL,
                                            use_mirrored_strategy = NULL,
                                            unfreeze = FALSE,
                                            verbose = FALSE) {

  if (is.null(use_mirrored_strategy)) use_mirrored_strategy <- count_gpu() > 1

  info <- file.info(path_model_card)
  is_directory <- info$isdir

  if (is.na(is_directory)) {
    stop("model_card path does not exist.\n")
  } else if (is_directory) {
    mc <- get_mc(path_model_card = path_model_card, epoch = epoch)
  } else {
    mc <- path_model_card
  }

  mc_args <- readRDS(mc)
  train_args_mc <- mc_args$train_model_args
  new_train_args <- train_args_mc

  if (is.null(new_run_name)) {
    new_train_args$run_name <- set_new_run_name(train_args_mc$run_name)
  } else {
    new_train_args$run_name <- new_run_name
  }

  # overwrite args
  if (is.null(seed)) seed <- get_seed()
  new_train_args$seed <- seed

  # load checkpoint to resume from
  if (is.null(train_args_mc$path_checkpoint)) {
    stop('No checkpoints were saved in the run from the model card')
  }

  if (use_mirrored_strategy) {
    mirrored_strategy <- tensorflow::tf$distribute$MirroredStrategy()
    with(mirrored_strategy$scope(), {
      model <- load_model(cp_path = file.path(train_args_mc$path_checkpoint, train_args_mc$run_name),
                          ep_index = epoch,
                          new_compile = new_compile)
    })
  } else {
    model <- load_model(cp_path = file.path(train_args_mc$path_checkpoint, train_args_mc$run_name),
                        ep_index = epoch,
                        new_compile = new_compile)
  }

  # unfreeze weights after loading the checkpoint
  if (unfreeze) {
    model$trainable <- TRUE
  }

  new_train_args$model <- model

  if (!is.null(new_args)) {
    stopifnot(is.list(new_args))
    for (n in names(new_args)) {
      new_train_args[[n]] <- new_args[[n]]
    }
  }

  new_train_args$model_card[['cont_train_info']] <- paste0('run continues training from run ',
                                                           train_args_mc$run_name, ' and epoch ',
                                                           max(mc_args$logs$processing_step))

  if (verbose) {
    print(new_train_args)
  }

  do.call(train_model, new_train_args)

}

get_mc <- function(path_model_card, epoch = NULL) {

  all_cards <- list.files(path_model_card, full.names = TRUE)
  all_epochs <- vector("integer", length(all_cards))
  # file names are expected to contain the epoch as the second underscore-separated field
  for (i in seq_along(all_cards)) {
    split_string <- all_cards[i] %>% basename() %>% stringr::str_split("_")
    all_epochs[i] <- split_string[[1]][2] %>% as.integer()
  }

  if (is.null(epoch)) epoch <- max(all_epochs)

  index <- all_epochs == epoch
  if (sum(index) == 0) {
    error_message <- paste('epoch not found in model card directory, possible values:',
                           paste(all_epochs, collapse = ", "))
    stop(error_message)
  }

  mc <- all_cards[index]
  return(mc)

}

load_model <- function(cp_path,
                       ep_index,
                       new_compile) {

  model <- load_cp(cp_path,
                   ep_index = ep_index,
                   mirrored_strategy = FALSE,
                   # compile from the checkpoint only if no new compile arguments are given
                   compile = is.null(new_compile))

  if (!is.null(new_compile)) {
    model <- keras::compile(model,
                            optimizer = new_compile$optimizer,
                            loss = new_compile$loss,
                            metrics = new_compile$metrics)
  }

  return(model)

}

get_seed <- function() {

  current_time <- Sys.time()
  current_time <- as.numeric(current_time) * 1e2
  seed_value <- (current_time %% 10^5) %>% as.integer()
  set.seed(seed_value)
  return(sample(1:10^6, 2))

}

set_new_run_name <- function(run_name_old) {

  run_name_new <- paste0(run_name_old, '_cont')
  return(run_name_new)

}