# tests/testthat/test-train.R
# Report every test in this file under the "training" context label.
# NOTE(review): context() is deprecated under testthat's 3rd edition, where
# the file name is used instead; harmless under the 2nd edition.
testthat::context(desc = "training")
# End-to-end smoke test for train_model() on tiny dummy LSTM-CNN models.
#
# Exercises three training modes ("lm", "label_folder", "label_csv"). For the
# csv mode it also checks that the per-file training log contains exactly the
# files assigned to the training split and none of the validation files.
# Expects the `fasta` and `fasta_2` directories in the test working directory.
test_that("Successful training from a dummy model", {
  testthat::skip_if_not(reticulate::py_module_available("tensorflow"))

  # Shared assertions: the training history is a list whose train and
  # validation losses are non-negative doubles.
  expect_nonnegative_losses <- function(history) {
    expect_type(history, "list")
    for (metric in c("loss", "val_loss")) {
      value <- history[["metrics"]][[metric]]
      expect_type(value, "double")
      expect_gte(value, 0)
    }
  }

  # language model
  maxlen <- 30
  batch_size <- 10

  model <- create_model_lstm_cnn(
    maxlen = maxlen,
    layer_dense = 4,
    layer_lstm = 8,
    solver = "adam",
    vocabulary_size = 4,
    compile = TRUE
  )

  trained_network <- train_model(
    reduce_lr_on_plateau = FALSE,
    train_type = "lm",
    path = "fasta",
    path_val = "fasta",
    model = model,
    train_val_ratio = 0.2,
    steps_per_epoch = 3,
    batch_size = batch_size,
    epochs = 1
  )
  expect_nonnegative_losses(trained_network)

  # label folder
  model <- create_model_lstm_cnn(
    maxlen = maxlen,
    kernel_size = c(4, 4),
    pool_size = c(2, 2),
    filters = c(2, 4),
    layer_dense = 2,
    layer_lstm = 8,
    solver = "adam",
    vocabulary_size = 4,
    compile = TRUE
  )

  trained_network <- train_model(
    reduce_lr_on_plateau = FALSE,
    train_type = "label_folder",
    path = rep("fasta", 2),
    path_val = rep("fasta", 2),
    model = model,
    vocabulary_label = c("A", "B"),
    train_val_ratio = 0.2,
    steps_per_epoch = 3,
    batch_size = batch_size,
    epochs = 1
  )
  expect_nonnegative_losses(trained_network)

  # train/val split with csv
  maxlen <- 3
  model <- create_model_lstm_cnn(
    maxlen = maxlen,
    layer_dense = 2,
    layer_lstm = 6,
    solver = "adam",
    vocabulary_size = 4,
    compile = TRUE
  )

  files <- as.list(list.files(c("fasta_2", "fasta"), full.names = TRUE))
  file_names <- basename(unlist(files))

  # Target csv: a random one-hot A/B label for every file name.
  label_a <- sample(c(0, 1), length(files), replace = TRUE)
  target_df <- data.frame(file = file_names, A = label_a, B = 1 - label_a)
  target_from_csv <- tempfile(fileext = ".csv")
  write.csv(target_df, target_from_csv, row.names = FALSE)

  # Train/validation csv: files alternate in pairs between the two splits.
  # `length.out` keeps the split vector as long as the file list, so the
  # data.frame() call no longer requires a file count divisible by 4.
  split_type <- rep(
    c("train", "validation"),
    each = 2,
    length.out = length(file_names)
  )
  ttv_df <- data.frame(file = file_names, type = split_type)
  train_files <- ttv_df$file[ttv_df$type == "train"]
  val_files <- ttv_df$file[ttv_df$type == "validation"]
  train_val_split_csv <- tempfile(fileext = ".csv")
  write.csv(ttv_df, train_val_split_csv, row.names = FALSE)

  path_file_log <- tempfile(fileext = ".csv")

  train_model(
    reduce_lr_on_plateau = FALSE,
    train_type = "label_csv",
    path = files,
    path_val = files,
    model = model,
    vocabulary_label = c("A", "B"),
    train_val_ratio = 0.2,
    steps_per_epoch = 3,
    batch_size = 4,
    max_samples = 2,
    target_from_csv = target_from_csv,
    train_val_split_csv = train_val_split_csv,
    path_file_log = path_file_log,
    epochs = 2
  )

  # The training log must list exactly the training files and must not
  # contain any validation file. Set operations avoid the silent recycling
  # of element-wise comparisons between vectors of different lengths.
  file_log <- read.csv(path_file_log)
  train_file_log <- unique(basename(file_log[, 1]))
  expect_setequal(train_file_log, train_files)
  expect_false(any(train_file_log %in% val_files))
})