Switch to unified view

a b/tests/testthat/test-predict.R
1
context("predict")
2
3
test_that("Sucessful prediction", {
4
   
5
   #testthat::skip_if_not_installed("tensorflow")
6
   testthat::skip_if_not(reticulate::py_module_available("tensorflow"))
7
   
8
   sequence <- "AAACCNGGGTTT"
9
   maxlen <- 8
10
   filename <- tempfile(fileext = ".h5")
11
   
12
   model <- create_model_lstm_cnn(
13
      maxlen = maxlen,
14
      verbose = FALSE,
15
      layer_dense = 4,
16
      layer_lstm = 8)
17
   
18
   # test h5 output
19
   pred <- predict_model(layer_name = NULL, sequence = sequence,
20
                         filename = filename, step = 1,
21
                         batch_size = 1, 
22
                         return_states = TRUE,
23
                         verbose = FALSE,
24
                         output_type = "h5",
25
                         model = model,
26
                         mode = "label", 
27
                         include_seq = TRUE)
28
   
29
   expect_true(all(pred$states >= 0))
30
   expect_true(all(pred$states <= 1))
31
   expect_equal(pred$sample_end_position, 8:12)
32
   
33
   pred_h5 <- load_prediction(filename, get_sample_position = TRUE)
34
   expect_equal(pred_h5$states, pred$states)
35
   expect_equal(pred_h5$sample_end_position, pred$sample_end_position)
36
   
37
   # batch size bigger than number of samples
38
   pred2 <- predict_model(layer_name = NULL, sequence = sequence,
39
                          filename = NULL, step = 1,
40
                          batch_size = 100, 
41
                          return_states = TRUE,
42
                          verbose = FALSE,
43
                          output_type = "h5",
44
                          model = model,
45
                          mode = "label", 
46
                          include_seq = TRUE)
47
   
48
   expect_true(all(abs(pred$states - pred2$states) < 1e-06))
49
   expect_equal(pred$sample_end_position, pred2$sample_end_position)
50
   
51
   # test csv + padding maxlen + ... (nuc_dist)
52
   filename <- tempfile(fileext = ".csv")
53
   pred <- predict_model(layer_name = NULL, sequence = sequence,
54
                         filename = filename, step = 1,
55
                         batch_size = 2, 
56
                         return_states = TRUE,
57
                         padding = "maxlen",
58
                         verbose = FALSE,
59
                         output_type = "csv", model = model,
60
                         mode = "label", ambiguous_nuc = "empirical",
61
                         nuc_dist = c(0.1,0.4,0.4,0.1),
62
                         include_seq = TRUE)
63
   
64
   expect_true(all(pred$states >= 0))
65
   expect_true(all(pred$states <= 1))
66
   expect_equal(pred$sample_end_position, 0:12)
67
   
68
   pred_csv <- read.csv(filename)
69
   expect_equal(as.matrix(pred_csv), pred$states)
70
   
71
   # padding 
72
   pred <- predict_model(layer_name = NULL, sequence = "AAA",
73
                         filename = NULL, step = 2,
74
                         batch_size = 2, 
75
                         return_states = TRUE,
76
                         padding = "standard",
77
                         verbose = FALSE,
78
                         output_type = "csv", model = model,
79
                         mode = "label", 
80
                         include_seq = TRUE)
81
   
82
   expect_true(all(pred$states >= 0))
83
   expect_true(all(pred$states <= 1))
84
   expect_equal(pred$sample_end_position, 3)
85
   expect_equal(nrow(pred$states), length(pred$sample_end_position))
86
   
87
   # step
88
   pred <- predict_model(layer_name = NULL, sequence = "AAAAACCCCC",
89
                         filename = NULL, step = 2,
90
                         batch_size = 2, 
91
                         return_states = TRUE,
92
                         padding = "standard",
93
                         verbose = FALSE,
94
                         output_type = "csv", model = model,
95
                         mode = "label", 
96
                         include_seq = TRUE)
97
   
98
   expect_equal(pred$sample_end_position, c(8, 10))
99
   expect_equal(nrow(pred$states), length(pred$sample_end_position))
100
   
101
   # fasta file by_entry
102
   Sequence <- c("AAAACCCC", "TT", "AAACCCGGGTTT")
103
   Header <- letters[1:3]   
104
   df <- data.frame(Sequence, Header)
105
   fasta_path <- tempfile(fileext = ".fasta")
106
   microseq::writeFasta(df, fasta_path)
107
   output_path <- tempfile()
108
   dir.create(output_path)
109
   
110
   expect_message(
111
      predict_model(layer_name = NULL, 
112
                    path_input = fasta_path,
113
                    output_format = "by_entry",
114
                    output_dir = output_path,
115
                    filename = "states.h5",
116
                    step = 2,
117
                    batch_size = 2, 
118
                    padding = "none",
119
                    verbose = TRUE,
120
                    output_type = "h5",
121
                    model = model,
122
                    mode = "label", 
123
                    include_seq = TRUE)
124
   )
125
   
126
   h5_files <- list.files(output_path, full.names = TRUE)
127
   expect_true(basename(h5_files[1]) == "states_nr_1.h5")
128
   expect_true(basename(h5_files[2]) == "states_nr_3.h5")
129
   
130
   output_list_1 <- load_prediction(h5_files[1], get_sample_position = TRUE)
131
   expect_equal(output_list_1$sample_end_position, 8)
132
   output_list_2 <- load_prediction(h5_files[2], get_sample_position = TRUE)
133
   expect_equal(output_list_2$sample_end_position, c(8,10,12))
134
   
135
   # fasta file, by_entry
136
   h5_file <- tempfile(fileext = ".h5")
137
   pred <- predict_model(layer_name = NULL, 
138
                         path_input = fasta_path,
139
                         output_format = "by_entry_one_file",
140
                         filename = h5_file,
141
                         step = 2,
142
                         batch_size = 2, 
143
                         padding = "none",
144
                         verbose = FALSE,
145
                         output_type = "h5",
146
                         model = model,
147
                         mode = "label")
148
   
149
   output_list <- load_prediction(h5_file, get_sample_position = TRUE)
150
   expect_true(all(output_list[[1]]$states == output_list_1$states))
151
   expect_true(all(output_list[[1]]$sample_end_position == output_list_1$sample_end_position))
152
   expect_true(all(output_list[[2]]$states == output_list_2$states))
153
   expect_true(all(output_list[[2]]$sample_end_position == output_list_2$sample_end_position))
154
   
155
   # one pred per entry
156
   h5_file <- tempfile(fileext = ".h5")
157
   pred <- predict_model(layer_name = NULL, 
158
                         path_input = fasta_path,
159
                         output_format = "one_pred_per_entry",
160
                         filename = h5_file,
161
                         step = 2,
162
                         batch_size = 2, 
163
                         verbose = FALSE,
164
                         output_type = "h5",
165
                         model = model,
166
                         mode = "label")
167
   
168
   output_list <- load_prediction(h5_file)
169
   expect_equal(nrow(output_list$states),  nrow(df))
170
   
171
   # lm, target middle
172
   model <- create_model_lstm_cnn_target_middle(
173
      maxlen = maxlen,
174
      verbose = FALSE,
175
      layer_dense = 4,
176
      layer_lstm = 8)
177
   
178
   h5_file <- tempfile(fileext = ".h5")
179
   pred <- predict_model(layer_name = NULL, 
180
                         path_input = fasta_path,
181
                         output_format = "by_entry_one_file",
182
                         filename = h5_file,
183
                         step = 2,
184
                         target_len = 1,
185
                         batch_size = 2, 
186
                         padding = "standard",
187
                         verbose = FALSE,
188
                         output_type = "h5",
189
                         lm_format = "target_middle_lstm",
190
                         model = model,
191
                         mode = "lm")
192
   
193
   output_list <- load_prediction(h5_file, get_sample_position = TRUE)
194
   expect_equal(output_list[[1]]$sample_end_position, 8)
195
   expect_equal(output_list[[2]]$sample_end_position, 2)
196
   expect_equal(output_list[[3]]$sample_end_position[1], 9)
197
   
198
})