# R/train_cpc.R

#' @title Train CPC inspired model
#'
#' @description
#' Train a CPC (Oord et al.) inspired neural network on genomic data.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_folder
#' @inheritParams generator_fasta_label_header_csv
#' @inheritParams train_model
#' @param train_type Either `"CPC"` or `"Self-GenomeNet"`.
#' @param encoder A keras encoder model for the CPC approach.
#' @param context A keras context model for the CPC approach.
#' @param path Path to training data. If \code{train_type} is \code{label_folder}, should be a vector or list
#' where each entry corresponds to a class (list elements can be directories and/or individual files). If \code{train_type} is not \code{label_folder},
#' can be a single directory or file or a list of directories and/or files.
#' @param path_val Path to validation data. See `path` argument for details.
#' @param path_checkpoint Path to checkpoints folder or `NULL`. If `NULL`, checkpoints are not stored.
#' @param path_tensorboard Path to tensorboard directory or `NULL`. If `NULL`, training is not tracked on tensorboard.
#' @param train_val_ratio For generators, defines the fraction of batches used for validation (relative to the size of the
#' training data), i.e. one validation epoch processes \code{batch_size} \eqn{*} \code{steps_per_epoch} \eqn{*} \code{train_val_ratio}
#' samples. If you use a dataset instead of a generator and \code{dataset_val} is `NULL`, splits \code{dataset}
#' into train/validation data.
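#' For example, with \code{batch_size = 32}, \code{steps_per_epoch = 2000} and \code{train_val_ratio = 0.2},
#' one validation epoch processes 32 * 2000 * 0.2 = 12800 samples.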
#' @param run_name Name of the run. Name will be used to identify output from callbacks.
#' @param batch_size Number of samples used for one network update.
#' @param epochs Number of training epochs.
#' @param steps_per_epoch Number of training batches per epoch.
#' @param shuffle_file_order Boolean, whether to go through files sequentially or shuffle beforehand.
#' @param initial_epoch Epoch at which to start training. Note that the network
#' will run for (\code{epochs} - \code{initial_epoch}) rounds and not \code{epochs} rounds.
#' @param seed Sets seed for reproducible results.
#' @param file_limit Integer or `NULL`. If an integer, use only the specified number of randomly sampled files for training. Ignored if greater than the number of files in \code{path}.
#' @param patchlen The length of a patch when splitting the input sequence.
#' @param nopatches The number of patches when splitting the input sequence.
#' @param step How often to take a sample, i.e. the distance between the start positions of two successive samples. Defaults to \code{maxlen} (non-overlapping samples).
#' @param stride The stride between two consecutive patches when splitting the input sequence, given as a fraction of \code{patchlen}.
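#' @details
#' The sequence parameters are linked through
#' \code{maxlen = (nopatches - 1) * (stride * patchlen) + patchlen},
#' so whichever of \code{maxlen} or \code{nopatches} is missing is derived from the other.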
#' @param pretrained_model A pretrained keras model for which training will be continued.
#' @param learningrate A Tensor, floating point value. If a schedule is defined, this value gives the initial learning rate. Defaults to 0.001.
#' @param learningrate_schedule A schedule for a non-constant learning rate over the training. Either "cosine_annealing", "step_decay", or "exp_decay".
#' @param k Value of k for sparse top k categorical accuracy. Defaults to 5.
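#' With \code{k = 5}, a prediction counts as correct if the true patch is among the model's five highest-scoring candidates.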
#' @param stepsmin In CPC, a patch is predicted given another patch. \code{stepsmin} defines how many patches
#' between these two should be ignored during prediction.
#' @param stepsmax The maximum distance between the predicted patch and the given patch.
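#' For example, with \code{stepsmin = 2} and \code{stepsmax = 3}, the patches immediately following the
#' current patch are skipped and prediction targets lie at most three patches ahead.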
#' @param emb_scale Scales the impact of a patch's context.
#' @examplesIf reticulate::py_module_available("tensorflow")
#'
#' # create dummy data
#' path_train_1 <- tempfile()
#' path_train_2 <- tempfile()
#' path_val_1 <- tempfile()
#' path_val_2 <- tempfile()
#'
#' for (current_path in c(path_train_1, path_train_2,
#'                        path_val_1, path_val_2)) {
#'   dir.create(current_path)
#'   deepG::create_dummy_data(file_path = current_path,
#'                            num_files = 3,
#'                            seq_length = 10,
#'                            num_seq = 5,
#'                            vocabulary = c("a", "c", "g", "t"))
#' }
#'
#' # create model
#' encoder <- function(maxlen = NULL,
#'                     patchlen = NULL,
#'                     nopatches = NULL,
#'                     eval = FALSE) {
#'   if (is.null(nopatches)) {
#'     nopatches <- nopatchescalc(patchlen, maxlen, patchlen * 0.4)
#'   }
#'   inp <- keras::layer_input(shape = c(maxlen, 4))
#'   stridelen <- as.integer(0.4 * patchlen)
#'   createpatches <- inp %>%
#'     keras::layer_reshape(list(maxlen, 4L, 1L), name = "prep_reshape1", dtype = "float32") %>%
#'     tensorflow::tf$image$extract_patches(
#'       sizes = list(1L, patchlen, 4L, 1L),
#'       strides = list(1L, stridelen, 4L, 1L),
#'       rates = list(1L, 1L, 1L, 1L),
#'       padding = "VALID",
#'       name = "prep_patches"
#'     ) %>%
#'     keras::layer_reshape(list(nopatches, patchlen, 4L),
#'                          name = "prep_reshape2") %>%
#'     tensorflow::tf$reshape(list(-1L, patchlen, 4L),
#'                            name = "prep_reshape3")
#'
#'   danQ <- createpatches %>%
#'     keras::layer_conv_1d(
#'       input_shape = c(maxlen, 4L),
#'       filters = 320L,
#'       kernel_size = 26L,
#'       activation = "relu"
#'     ) %>%
#'     keras::layer_max_pooling_1d(pool_size = 13L, strides = 13L) %>%
#'     keras::layer_dropout(0.2) %>%
#'     keras::layer_lstm(units = 320, return_sequences = TRUE) %>%
#'     keras::layer_dropout(0.5) %>%
#'     keras::layer_flatten() %>%
#'     keras::layer_dense(925, activation = "relu")
#'   patchesback <- danQ %>%
#'     tensorflow::tf$reshape(list(-1L, tensorflow::tf$cast(nopatches, tensorflow::tf$int16), 925L))
#'   keras::keras_model(inp, patchesback)
#' }
#'
#' context <- function(latents) {
#'   cres <- latents
#'   predictions <-
#'     cres %>%
#'     keras::layer_lstm(
#'       return_sequences = TRUE,
#'       units = 256,
#'       name = "context_LSTM_1",
#'       activation = "relu"
#'     )
#'   return(predictions)
#' }
#'
#' # train model
#' temp_dir <- tempdir()
#' hist <- train_model_cpc(train_type = "CPC",
#'                         ### cpc functions ###
#'                         encoder = encoder,
#'                         context = context,
#'                         #### Generator settings ####
#'                         path_checkpoint = temp_dir,
#'                         path = c(path_train_1, path_train_2),
#'                         path_val = c(path_val_1, path_val_2),
#'                         run_name = "TEST",
#'                         batch_size = 8,
#'                         epochs = 3,
#'                         steps_per_epoch = 6,
#'                         patchlen = 100,
#'                         nopatches = 8)
#'
#'
#' @returns A list of training metrics.
#' @export
train_model_cpc <-
  function(train_type = "CPC",
           ### cpc functions ###
           encoder = NULL,
           context = NULL,
           #### Generator settings ####
           path,
           path_val = NULL,
           path_checkpoint = NULL,
           path_tensorboard = NULL,
           train_val_ratio = 0.2,
           run_name,

           batch_size = 32,
           epochs = 100,
           steps_per_epoch = 2000,
           shuffle_file_order = FALSE,
           initial_epoch = 1,
           seed = 1234,

           path_file_log = TRUE,
           train_val_split_csv = NULL,
           file_limit = NULL,
           proportion_per_seq = NULL,
           max_samples = NULL,
           maxlen = NULL,

           patchlen = NULL,
           nopatches = NULL,
           step = NULL,
           file_filter = NULL,
           stride = 0.4,
           pretrained_model = NULL,
           learningrate = 0.001,
           learningrate_schedule = NULL,
           k = 5,
           stepsmin = 2,
           stepsmax = 3,
           emb_scale = 0.1) {

    # Stride is currently fixed at 0.4 x patchlen; the stride argument is overwritten
    stride <- 0.4

    if (!is.null(patchlen)) {
      patchlen <- as.integer(patchlen)
    }

    ########################################################################################################
    ############################### Warning messages if wrong initialization ###############################
    ########################################################################################################

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Model specification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## Three options:
    ## 1. Define maxlen and patchlen
    ## 2. Define number of patches and patchlen
    ## ---> in both cases the respectively missing value will be calculated
    ## 3. Pretrained model provides the specs
    ## Error if none of these is fulfilled

    if (is.null(pretrained_model)) {
      ## If no pretrained model, patchlen has to be defined
      if (is.null(patchlen)) {
        stop("Please define patchlen")
      }
      ## Either maxlen or number of patches is needed
      if (is.null(maxlen) & is.null(nopatches)) {
        stop("Please define either maxlen or nopatches")
        ## the respectively missing value will be calculated
      } else if (is.null(maxlen) & !is.null(nopatches)) {
        maxlen <- (nopatches - 1) * (stride * patchlen) + patchlen
      } else if (!is.null(maxlen) & is.null(nopatches)) {
        nopatches <-
          as.integer((maxlen - patchlen) / (stride * patchlen) + 1)
      }
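      ## e.g. patchlen = 20 with stride = 0.4 gives a stride length of 8, so
      ## maxlen = 100 yields nopatches = (100 - 20) / 8 + 1 = 11, and conversely
      ## nopatches = 11 gives maxlen = 10 * 8 + 20 = 100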
      ## if step is not defined, we do not use overlapping sequences
      if (is.null(step)) {
        step <- maxlen
      }
    } else {
      specs <-
        readRDS(file.path(dirname(pretrained_model), "modelspecs.rds"))
      patchlen <- specs$patchlen
      maxlen <- specs$maxlen
      nopatches <- specs$nopatches
      stride <- specs$stride
      step <- specs$step
      k <- specs$k
      emb_scale <- specs$emb_scale
    }


    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Learning rate schedule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## If a learning rate schedule is wanted, all necessary parameters must be given
    LRstop(learningrate_schedule)
    ########################################################################################################
    #################################### Preparation: data, paths, metrics #################################
    ########################################################################################################

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Path definition ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    runname <-
      paste0(run_name, format(Sys.time(), "_%y%m%d_%H%M%S"))

    ## Create folder for model
    if (!is.null(path_checkpoint)) {
      dir <- paste(path_checkpoint, runname, sep = "/")
      dir.create(dir)
      ## Create folder for filelog
      path_file_log <-
        paste(path_checkpoint, runname, "filelog.csv", sep = "/")
    } else {
      path_file_log <- NULL
    }

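    # GenParams/GenTParams/GenVParams (internal helpers) bundle the generator
    # arguments into named lists that are spliced into the generator_fasta_lm
    # calls below via do.call()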
    GenConfig <-
      GenParams(maxlen, batch_size, step, proportion_per_seq, max_samples)
    GenTConfig <-
      GenTParams(path, shuffle_file_order, path_file_log, seed)
    GenVConfig <- GenVParams(path_val, shuffle_file_order)

    # train/validation split via csv file
    if (!is.null(train_val_split_csv)) {
      if (is.null(path_val)) {
        path_val <- path
      } else {
        if (!all(unlist(path_val) %in% unlist(path))) {
          warning("Train/validation split done via file in train_val_split_csv. Only using files from path argument.")
        }
        path_val <- path
      }

      train_val_file <- utils::read.csv2(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      if (dim(train_val_file)[2] == 1) {
        train_val_file <- utils::read.csv(train_val_split_csv, header = TRUE, stringsAsFactors = FALSE)
      }
      train_val_file <- dplyr::distinct(train_val_file)

      if (!all(c("file", "type") %in% names(train_val_file))) {
        stop("Column names of train_val_split_csv file must be 'file' and 'type'")
      }

      if (length(train_val_file$file) != length(unique(train_val_file$file))) {
        stop("In train_val_split_csv all entries in 'file' column must be unique")
      }

      file_filter <- list()
      file_filter[[1]] <- train_val_file %>% dplyr::filter(type == "train")
      file_filter[[1]] <- as.character(file_filter[[1]]$file)
      file_filter[[2]] <- train_val_file %>% dplyr::filter(type == "val" | type == "validation")
      file_filter[[2]] <- as.character(file_filter[[2]]$file)
    }


    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ File count ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    if (is.null(file_filter) && is.null(train_val_split_csv)) {
      if (is.null(file_limit)) {
        if (is.list(path)) {
          num_files <- 0
          for (i in seq_along(path)) {
            num_files <- num_files + length(list.files(path[[i]]))
          }
        } else {
          num_files <- length(list.files(path))
        }
      } else {
        num_files <- file_limit * length(path)
      }
    } else {
      num_files <- length(file_filter[[1]])
    }

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Creation of generators ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    message(format(Sys.time(), "%F %R"), ": Preparing the data\n")
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Training Generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    fastrain <-
      do.call(generator_fasta_lm,
              c(GenConfig, GenTConfig, file_filter = file_filter[1]))


    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Validation Generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    fasval <-
      do.call(
        generator_fasta_lm,
        c(
          GenConfig,
          GenVConfig,
          seed = seed,
          file_filter = file_filter[2]
        )
      )
    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Creation of metrics ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    message(format(Sys.time(), "%F %R"), ": Preparing the metrics\n")
    train_loss <- tensorflow::tf$keras$metrics$Mean(name = 'train_loss')
    val_loss <- tensorflow::tf$keras$metrics$Mean(name = 'val_loss')
    train_acc <- tensorflow::tf$keras$metrics$Mean(name = 'train_acc')
    val_acc <- tensorflow::tf$keras$metrics$Mean(name = 'val_acc')

    ########################################################################################################
    ###################################### History object preparation ######################################
    ########################################################################################################

    history <- list(
      params = list(
        batch_size = batch_size,
        epochs = 0,
        steps = steps_per_epoch,
        samples = steps_per_epoch * batch_size,
        verbose = 1,
        do_validation = TRUE,
        metrics = c("loss", "accuracy", "val_loss", "val_accuracy")
      ),
      metrics = list(
        loss = c(),
        accuracy = c(),
        val_loss = c(),
        val_accuracy = c()
      )
    )

    eploss <- list()
    epacc <- list()

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Reformat to S3 object ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    class(history) <- "keras_training_history"

    ########################################################################################################
    ############################################ Model creation ############################################
    ########################################################################################################
    if (is.null(pretrained_model)) {
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Unsupervised Build from scratch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      message(format(Sys.time(), "%F %R"), ": Creating the model\n")
      ## Build encoder
      enc <-
        encoder(maxlen = maxlen,
                patchlen = patchlen,
                nopatches = nopatches)

      ## Build model
      model <-
        keras::keras_model(
          enc$input,
          cpcloss(
            enc$output,
            context,
            batch_size = batch_size,
            steps_to_ignore = stepsmin,
            steps_to_predict = stepsmax,
            train_type = train_type,
            k = k,
            emb_scale = emb_scale
          )
        )

      ## Build optimizer
      optimizer <-
        tensorflow::tf$keras$optimizers$legacy$Adam(
          learning_rate = learningrate,
          beta_1 = 0.8,
          epsilon = 10 ^ -8,
          decay = 0.999,
          clipnorm = 0.01
        )
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~ Unsupervised Read if pretrained model given ~~~~~~~~~~~~~~~~~~~~~~~####

    } else {
      message(format(Sys.time(), "%F %R"), ": Loading the trained model.\n")
      ## Read model
      model <- keras::load_model_hdf5(pretrained_model, compile = FALSE)
      optimizer <- ReadOpt(pretrained_model)
      optimizer$learning_rate$assign(learningrate)
    }

    ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving necessary model objects ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
    ## optimizer configuration

    if (!is.null(path_checkpoint)) {
      saveRDS(optimizer$get_config(),
              paste(dir, "optconfig.rds", sep = "/"))
      ## model parameters
      saveRDS(
        list(
          maxlen = maxlen,
          patchlen = patchlen,
          stride = stride,
          nopatches = nopatches,
          step = step,
          batch_size = batch_size,
          epochs = epochs,
          steps_per_epoch = steps_per_epoch,
          train_val_ratio = train_val_ratio,
          max_samples = max_samples,
          k = k,
          emb_scale = emb_scale,
          learningrate = learningrate
        ),
        paste(dir, "modelspecs.rds", sep = "/")
      )
    }
    ########################################################################################################
    ######################################## Tensorboard connection ########################################
    ########################################################################################################

    if (!is.null(path_tensorboard)) {
      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Initialize Tensorboard writers ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      logdir <- path_tensorboard
      writertrain <-
        tensorflow::tf$summary$create_file_writer(file.path(logdir, runname, "train"))
      writerval <-
        tensorflow::tf$summary$create_file_writer(file.path(logdir, runname, "validation"))

      ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Write parameters to Tensorboard ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
      tftext <-
        lapply(as.list(match.call())[-1][-c(1, 2)], function(x)
          ifelse(all(nchar(deparse(
            eval(x)
          )) < 20) && !is.null(eval(x)), eval(x), deparse(x)))

      with(writertrain$as_default(), {
        tensorflow::tf$summary$text("Specification",
                                    paste(
                                      names(tftext),
                                      tftext,
                                      sep = " = ",
                                      collapse = " \n"
                                    ),
                                    step = 0L)
      })
    }

    ########################################################################################################
    ######################################## Training loop function ########################################
    ########################################################################################################

    train_val_loop <-
      function(batches = steps_per_epoch, epoch, train_val_ratio) {
        ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Start of loop ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
        for (i in c("train", "val")) {
          if (i == "val") {
            ## Calculate steps for validation
            batches <- ceiling(batches * train_val_ratio)
          }

          for (b in seq(batches)) {
            if (i == "train") {
              ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Training step ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
              ## If a learning rate schedule is specified, calculate the learning rate for the current epoch
              if (!is.null(learningrate_schedule)) {
                optimizer$learning_rate$assign(getEpochLR(learningrate_schedule, epoch))
              }
              ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Optimization step ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####

              # equivalent: with(tensorflow::tf$GradientTape() %as% tape, ...)
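              # reticulate's `%as%` mirrors Python's `with ... as tape`; the tape
              # records all operations in the block for automatic differentiation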
              with(reticulate::`%as%`(tensorflow::tf$GradientTape(), tape), {

                out <-
                  modelstep(fastrain(),
                            model,
                            train_type,
                            TRUE)
                l <- out[1]
                acc <- out[2]
              })

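              # apply_gradients() expects (gradient, variable) pairs;
              # purrr::transpose() builds them from the two parallel lists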
              gradients <-
                tape$gradient(l, model$trainable_variables)
              optimizer$apply_gradients(purrr::transpose(list(
                gradients, model$trainable_variables
              )))
              train_loss(l)
              train_acc(acc)

            } else {
              ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Validation step ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
              out <-
                modelstep(fasval(),
                          model,
                          train_type,
                          FALSE)

              l <- out[1]
              acc <- out[2]
              val_loss(l)
              val_acc(acc)

            }

            ## Print progress: one marker per 10% of the epoch
            if (b %in% seq(0, batches, by = batches / 10)) {
              message("-")
            }
          }

          ####~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ End of Epoch ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~####
          if (i == "train") {
            ## Training step
            # Write epoch result metric values to tensorboard
            if (!is.null(path_tensorboard)) {
              TB_loss_acc(writertrain, train_loss, train_acc, epoch)
              with(writertrain$as_default(), {
                tensorflow::tf$summary$scalar('epoch_lr',
                                              optimizer$learning_rate,
                                              step = tensorflow::tf$cast(epoch, "int64"))
                tensorflow::tf$summary$scalar(
                  'training files seen',
                  nrow(
                    readr::read_csv(
                      path_file_log,
                      col_names = FALSE,
                      col_types = readr::cols()
                    )
                  ) / num_files,
                  step = tensorflow::tf$cast(epoch, "int64")
                )
              })
            }
            # Print epoch result metric values to console
            tensorflow::tf$print(" Train Loss",
                                 train_loss$result(),
                                 ", Train Acc",
                                 train_acc$result())

            # Save epoch result metric values to history object
            history$params$epochs <- epoch
            history$metrics$loss[epoch] <-
              as.double(train_loss$result())
            history$metrics$accuracy[epoch] <-
              as.double(train_acc$result())

            # Reset states
            train_loss$reset_states()
            train_acc$reset_states()

          } else {
            ## Validation step
            # Write epoch result metric values to tensorboard
            if (!is.null(path_tensorboard)) {
              TB_loss_acc(writerval, val_loss, val_acc, epoch)
            }

            # Print epoch result metric values to console
            tensorflow::tf$print(" Validation Loss",
                                 val_loss$result(),
                                 ", Validation Acc",
                                 val_acc$result())

            # Save results for the best-model saving condition
            if (b == batches) {
              eploss[[epoch]] <- as.double(val_loss$result())
              epacc[[epoch]] <-
                as.double(val_acc$result())
            }

            # Save epoch result metric values to history object
            history$metrics$val_loss[epoch] <-
              as.double(val_loss$result())
            history$metrics$val_accuracy[epoch] <-
              as.double(val_acc$result())

            # Reset states
            val_loss$reset_states()
            val_acc$reset_states()
          }
        }
        return(list(history, eploss, epacc))
      }

    ########################################################################################################
    ############################################# Training run #############################################
    ########################################################################################################


    message(format(Sys.time(), "%F %R"), ": Starting Training\n")

    ## Training loop
    for (i in seq(initial_epoch, (epochs + initial_epoch - 1))) {
      message(format(Sys.time(), "%F %R"), ": EPOCH ", i, " \n")

      ## Epoch loop
      out <- train_val_loop(epoch = i, train_val_ratio = train_val_ratio)
      history <- out[[1]]
      eploss <- out[[2]]
      epacc <- out[[3]]
      ## Save checkpoints
      # best model (smallest validation loss)
      if (eploss[[i]] == min(unlist(eploss))) {
        savechecks("best", runname, model, optimizer, history, path_checkpoint)
      }
      # backup model every 2 epochs
      if (i %% 2 == 0) {
        savechecks("backup", runname, model, optimizer, history, path_checkpoint)
      }
    }

    ########################################################################################################
    ############################################# Final saves ##############################################
    ########################################################################################################

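    # final snapshot: savechecks() (internal helper) stores the model, optimizer
    # state and history object; the model graph is then written for tensorboard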
    savechecks(cp = "FINAL", runname, model, optimizer, history, path_checkpoint)
    if (!is.null(path_tensorboard)) {
      writegraph <-
        tensorflow::tf$keras$callbacks$TensorBoard(file.path(logdir, runname))
      writegraph$set_model(model)
    }
  }