--- /dev/null
+++ b/tests/testthat/test-train.R
@@ -0,0 +1,115 @@
+context("training")
+
+test_that("Sucessful training from a dummy model", {
+   
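+   # skip when no Python tensorflow module is available to reticulate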
+   #testthat::skip_if_not_installed("tensorflow")
+   testthat::skip_if_not(reticulate::py_module_available("tensorflow"))
+   
+   # language model
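+   # a small LSTM model over a 4-letter vocabulary with input length 30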
+   maxlen <- 30
+   batch_size <- 10
+
+   model <- create_model_lstm_cnn(
+      maxlen = maxlen,
+      layer_dense = 4,
+      layer_lstm = 8,
+      solver = "adam",
+      vocabulary_size = 4,
+      compile = TRUE)
+
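+   # train for one epoch on the dummy fasta folder, used for both training and validation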
+   trainedNetwork <- train_model(reduce_lr_on_plateau = FALSE,
+                                 train_type = "lm",
+                                 path = "fasta",
+                                 path_val = "fasta",
+                                 model = model,
+                                 train_val_ratio = 0.2,
+                                 steps_per_epoch = 3,
+                                 batch_size = batch_size,
+                                 epochs = 1)
+
+   expect_type(trainedNetwork, "list")
+   expect_type(trainedNetwork[["metrics"]][["loss"]], "double")
+   expect_gte(trainedNetwork[["metrics"]][["loss"]], 0)
+   expect_type(trainedNetwork[["metrics"]][["val_loss"]], "double")
+   expect_gte(trainedNetwork[["metrics"]][["val_loss"]], 0)
+
+   # label folder
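+   # classification by folder: one class label per input path, here both labels use the same dummy folder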
+   model <- create_model_lstm_cnn(
+      maxlen = maxlen,
+      kernel_size = c(4,4),
+      pool_size = c(2,2),
+      filters = c(2,4),
+      layer_dense = 2,
+      layer_lstm = 8,
+      solver = "adam",
+      vocabulary_size = 4,
+      compile = TRUE)
+
+   trainedNetwork <- train_model(reduce_lr_on_plateau = FALSE,
+                                 train_type = "label_folder",
+                                 path = rep("fasta", 2),
+                                 path_val = rep("fasta", 2),
+                                 model = model,
+                                 vocabulary_label = c("A", "B"),
+                                 train_val_ratio = 0.2,
+                                 steps_per_epoch = 3,
+                                 batch_size = batch_size,
+                                 epochs = 1)
+
+   expect_type(trainedNetwork, "list")
+   expect_type(trainedNetwork[["metrics"]][["loss"]], "double")
+   expect_gte(trainedNetwork[["metrics"]][["loss"]], 0)
+   expect_type(trainedNetwork[["metrics"]][["val_loss"]], "double")
+   expect_gte(trainedNetwork[["metrics"]][["val_loss"]], 0)
+   
+   # train/val split with csv
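+   # label_csv training: class targets and the train/validation split both come from csv files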
+   maxlen <- 3
+   model <- create_model_lstm_cnn(
+      maxlen = maxlen,
+      layer_dense = 2,
+      layer_lstm = 6,
+      solver = "adam",
+      vocabulary_size = 4,
+      compile = TRUE)
+   
+   files <- as.list(list.files(c("fasta_2", "fasta"), full.names = TRUE))
+   
+   # target csv
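+   # one-hot targets: each file is randomly assigned to class A or B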
+   A <- sample(c(0,1), length(files), replace = TRUE)
+   B <- ifelse(A == 0, 1, 0)
+   target_df <- data.frame(file = basename(unlist(files)), A = A, B = B)
+   target_from_csv <- tempfile(fileext = ".csv")
+   write.csv(target_df, target_from_csv, row.names = FALSE)
+   
+   # train/val csv
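+   # assign files to the training and validation sets in blocks of two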
+   train_val_split_csv <- tempfile(fileext = ".csv")
+   ttv_df <- data.frame(file = basename(unlist(files)), 
+                        type = rep(c("train", "validation"), each = 2))
+   train_files <- ttv_df$file[ttv_df$type == "train"]
+   val_files <- ttv_df$file[ttv_df$type == "validation"]
+   write.csv(ttv_df, train_val_split_csv, row.names = FALSE)
+   
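+   # train_model logs the files used for training to path_file_log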
+   path_file_log <- tempfile(fileext = ".csv")
+   
+   trainedNetwork <- train_model(reduce_lr_on_plateau = FALSE,
+                                 train_type = "label_csv",
+                                 path = files,
+                                 path_val = files,
+                                 model = model,
+                                 vocabulary_label = c("A", "B"),
+                                 train_val_ratio = 0.2,
+                                 steps_per_epoch = 3,
+                                 batch_size = 4,
+                                 max_samples = 2,
+                                 target_from_csv = target_from_csv,
+                                 train_val_split_csv = train_val_split_csv,
+                                 path_file_log = path_file_log,
+                                 epochs = 2)
+   
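+   # the log should contain exactly the training files and none of the validation files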
+   file_log <- read.csv(path_file_log)
+   train_file_log <- unique(basename(file_log[,1]))
+   expect_equal(sort(train_file_log), sort(train_files))
+   expect_false(any(train_file_log %in% val_files))
+   
+})