--- a +++ b/tests/testthat/test-predict.R @@ -0,0 +1,198 @@ +context("predict") + +test_that("Sucessful prediction", { + + #testthat::skip_if_not_installed("tensorflow") + testthat::skip_if_not(reticulate::py_module_available("tensorflow")) + + sequence <- "AAACCNGGGTTT" + maxlen <- 8 + filename <- tempfile(fileext = ".h5") + + model <- create_model_lstm_cnn( + maxlen = maxlen, + verbose = FALSE, + layer_dense = 4, + layer_lstm = 8) + + # test h5 output + pred <- predict_model(layer_name = NULL, sequence = sequence, + filename = filename, step = 1, + batch_size = 1, + return_states = TRUE, + verbose = FALSE, + output_type = "h5", + model = model, + mode = "label", + include_seq = TRUE) + + expect_true(all(pred$states >= 0)) + expect_true(all(pred$states <= 1)) + expect_equal(pred$sample_end_position, 8:12) + + pred_h5 <- load_prediction(filename, get_sample_position = TRUE) + expect_equal(pred_h5$states, pred$states) + expect_equal(pred_h5$sample_end_position, pred$sample_end_position) + + # batch size bigger than number of samples + pred2 <- predict_model(layer_name = NULL, sequence = sequence, + filename = NULL, step = 1, + batch_size = 100, + return_states = TRUE, + verbose = FALSE, + output_type = "h5", + model = model, + mode = "label", + include_seq = TRUE) + + expect_true(all(abs(pred$states - pred2$states) < 1e-06)) + expect_equal(pred$sample_end_position, pred2$sample_end_position) + + # test csv + padding maxlen + ... (nuc_dist) + filename <- tempfile(fileext = ".csv") + pred <- predict_model(layer_name = NULL, sequence = sequence, + filename = filename, step = 1, + batch_size = 2, + return_states = TRUE, + padding = "maxlen", + verbose = FALSE, + output_type = "csv", model = model, + mode = "label", ambiguous_nuc = "empirical", + nuc_dist = c(0.1,0.4,0.4,0.1), + include_seq = TRUE) + + expect_true(all(pred$states >= 0)) + expect_true(all(pred$states <= 1)) + expect_equal(pred$sample_end_position, 0:12) + + pred_csv <- read.csv(filename) + expect_equal(as.matrix(pred_csv), pred$states) + + # padding + pred <- predict_model(layer_name = NULL, sequence = "AAA", + filename = NULL, step = 2, + batch_size = 2, + return_states = TRUE, + padding = "standard", + verbose = FALSE, + output_type = "csv", model = model, + mode = "label", + include_seq = TRUE) + + expect_true(all(pred$states >= 0)) + expect_true(all(pred$states <= 1)) + expect_equal(pred$sample_end_position, 3) + expect_equal(nrow(pred$states), length(pred$sample_end_position)) + + # step + pred <- predict_model(layer_name = NULL, sequence = "AAAAACCCCC", + filename = NULL, step = 2, + batch_size = 2, + return_states = TRUE, + padding = "standard", + verbose = FALSE, + output_type = "csv", model = model, + mode = "label", + include_seq = TRUE) + + expect_equal(pred$sample_end_position, c(8, 10)) + expect_equal(nrow(pred$states), length(pred$sample_end_position)) + + # fasta file by_entry + Sequence <- c("AAAACCCC", "TT", "AAACCCGGGTTT") + Header <- letters[1:3] + df <- data.frame(Sequence, Header) + fasta_path <- tempfile(fileext = ".fasta") + microseq::writeFasta(df, fasta_path) + output_path <- tempfile() + dir.create(output_path) + + expect_message( + predict_model(layer_name = NULL, + path_input = fasta_path, + output_format = "by_entry", + output_dir = output_path, + filename = "states.h5", + step = 2, + batch_size = 2, + padding = "none", + verbose = TRUE, + output_type = "h5", + model = model, + mode = "label", + include_seq = TRUE) + ) + + h5_files <- list.files(output_path, full.names = TRUE) + expect_true(basename(h5_files[1]) == "states_nr_1.h5") + expect_true(basename(h5_files[2]) == "states_nr_3.h5") + + output_list_1 <- load_prediction(h5_files[1], get_sample_position = TRUE) + expect_equal(output_list_1$sample_end_position, 8) + output_list_2 <- load_prediction(h5_files[2], get_sample_position = TRUE) + expect_equal(output_list_2$sample_end_position, c(8,10,12)) + + # fasta file, by_entry + h5_file <- tempfile(fileext = ".h5") + pred <- predict_model(layer_name = NULL, + path_input = fasta_path, + output_format = "by_entry_one_file", + filename = h5_file, + step = 2, + batch_size = 2, + padding = "none", + verbose = FALSE, + output_type = "h5", + model = model, + mode = "label") + + output_list <- load_prediction(h5_file, get_sample_position = TRUE) + expect_true(all(output_list[[1]]$states == output_list_1$states)) + expect_true(all(output_list[[1]]$sample_end_position == output_list_1$sample_end_position)) + expect_true(all(output_list[[2]]$states == output_list_2$states)) + expect_true(all(output_list[[2]]$sample_end_position == output_list_2$sample_end_position)) + + # one pred per entry + h5_file <- tempfile(fileext = ".h5") + pred <- predict_model(layer_name = NULL, + path_input = fasta_path, + output_format = "one_pred_per_entry", + filename = h5_file, + step = 2, + batch_size = 2, + verbose = FALSE, + output_type = "h5", + model = model, + mode = "label") + + output_list <- load_prediction(h5_file) + expect_equal(nrow(output_list$states), nrow(df)) + + # lm, target middle + model <- create_model_lstm_cnn_target_middle( + maxlen = maxlen, + verbose = FALSE, + layer_dense = 4, + layer_lstm = 8) + + h5_file <- tempfile(fileext = ".h5") + pred <- predict_model(layer_name = NULL, + path_input = fasta_path, + output_format = "by_entry_one_file", + filename = h5_file, + step = 2, + target_len = 1, + batch_size = 2, + padding = "standard", + verbose = FALSE, + output_type = "h5", + lm_format = "target_middle_lstm", + model = model, + mode = "lm") + + output_list <- load_prediction(h5_file, get_sample_position = TRUE) + expect_equal(output_list[[1]]$sample_end_position, 8) + expect_equal(output_list[[2]]$sample_end_position, 2) + expect_equal(output_list[[3]]$sample_end_position[1], 9) + +})