# tests/testthat/test-train.R
context("training")

# End-to-end smoke test for train_model() on small dummy models, covering the
# three train_type values used below: "lm", "label_folder" and "label_csv".
test_that("Sucessful training from a dummy model", {

  # Skip unless the Python `tensorflow` module is available through reticulate.
  # (Earlier gate, kept for reference:
  #  testthat::skip_if_not_installed("tensorflow"))
  testthat::skip_if_not(reticulate::py_module_available("tensorflow"))

  # language model ----
  # NOTE(review): "fasta" is a relative path; presumably a fixture directory
  # next to this test file -- confirm.
  maxlen <- 30
  batch_size <- 10

  model <- create_model_lstm_cnn(
    maxlen = maxlen,
    layer_dense = 4,
    layer_lstm = 8,
    solver = "adam",
    vocabulary_size = 4,
    compile = TRUE)

  trainedNetwork <- train_model(reduce_lr_on_plateau = FALSE,
                                train_type = "lm",
                                path = "fasta",
                                path_val = "fasta",
                                model = model,
                                train_val_ratio = 0.2,
                                steps_per_epoch = 3,
                                batch_size = batch_size,
                                epochs = 1)

  # The result is a list whose "metrics" entry holds non-negative double
  # values for both the training and the validation loss.
  expect_type(trainedNetwork, "list")
  expect_type(trainedNetwork[["metrics"]][["loss"]], "double")
  expect_gte(trainedNetwork[["metrics"]][["loss"]], 0)
  expect_type(trainedNetwork[["metrics"]][["val_loss"]], "double")
  expect_gte(trainedNetwork[["metrics"]][["val_loss"]], 0)

  # label folder ----
  # Two classes ("A", "B"); the same folder is passed once per class.
  model <- create_model_lstm_cnn(
    maxlen = maxlen,
    kernel_size = c(4,4),
    pool_size = c(2,2),
    filters = c(2,4),
    layer_dense = 2,
    layer_lstm = 8,
    solver = "adam",
    vocabulary_size = 4,
    compile = TRUE)

  trainedNetwork <- train_model(reduce_lr_on_plateau = FALSE,
                                train_type = "label_folder",
                                path = rep("fasta", 2),
                                path_val = rep("fasta", 2),
                                model = model,
                                vocabulary_label = c("A", "B"),
                                train_val_ratio = 0.2,
                                steps_per_epoch = 3,
                                batch_size = batch_size,
                                epochs = 1)

  expect_type(trainedNetwork, "list")
  expect_type(trainedNetwork[["metrics"]][["loss"]], "double")
  expect_gte(trainedNetwork[["metrics"]][["loss"]], 0)
  expect_type(trainedNetwork[["metrics"]][["val_loss"]], "double")
  expect_gte(trainedNetwork[["metrics"]][["val_loss"]], 0)

  # train/val split with csv ----
  maxlen <- 3
  model <- create_model_lstm_cnn(
    maxlen = maxlen,
    layer_dense = 2,
    layer_lstm = 6,
    solver = "adam",
    vocabulary_size = 4,
    compile = TRUE)

  files <- as.list(list.files(c("fasta_2", "fasta"), full.names = TRUE))

  # target csv: one row per file with a random one-hot label; B is the
  # complement of A.
  A <- sample(c(0,1), length(files), replace = TRUE)
  B <- ifelse(A == 0, 1, 0)
  target_df <- data.frame(file = basename(unlist(files)), A = A, B = B)
  target_from_csv <- tempfile(fileext = ".csv")
  write.csv(target_df, target_from_csv, row.names = FALSE)

  # train/val csv: assigns every file to either "train" or "validation".
  # NOTE(review): `each = 2` produces exactly four labels, so this assumes the
  # two directories contain four files in total -- confirm.
  train_val_split_csv <- tempfile(fileext = ".csv")
  ttv_df <- data.frame(file = basename(unlist(files)),
                       type = rep(c("train", "validation"), each = 2))
  train_files <- ttv_df$file[ttv_df$type == "train"]
  val_files <- ttv_df$file[ttv_df$type == "validation"]
  write.csv(ttv_df, train_val_split_csv, row.names = FALSE)

  # File log written by train_model(); read back below to see which files
  # were used.
  path_file_log <- tempfile(fileext = ".csv")

  trainedNetwork <- train_model(reduce_lr_on_plateau = FALSE,
                                train_type = "label_csv",
                                path = files,
                                path_val = files,
                                model = model,
                                vocabulary_label = c("A", "B"),
                                train_val_ratio = 0.2,
                                steps_per_epoch = 3,
                                batch_size = 4,
                                max_samples = 2,
                                target_from_csv = target_from_csv,
                                train_val_split_csv = train_val_split_csv,
                                path_file_log = path_file_log,
                                epochs = 2)

  file_log <- read.csv(path_file_log)
  train_file_log <- unique(basename(file_log[,1]))

  # The logged files must be exactly the files marked "train" ...
  # (setequal() avoids the silent recycling of an element-wise `==` when the
  # two vectors differ in length.)
  expect_true(setequal(train_file_log, train_files))
  # ... and none of them may be a validation file. An element-wise `!=` on the
  # sorted vectors only compares matching positions and would miss a leaked
  # file, so test the intersection instead.
  expect_length(intersect(train_file_log, val_files), 0)

})