--- a
+++ b/tests/testthat/test-predict.R
@@ -0,0 +1,198 @@
+context("predict")  # labels this file's tests; NOTE(review): context() is deprecated in testthat 3e - confirm edition in use
+
+test_that("Sucessful prediction", {  # NOTE(review): the description misspells "Successful"
+   # Exercises predict_model() with a small model over several input and output configurations.
+   # Skip unless reticulate can find the Python tensorflow module (used instead of skip_if_not_installed("tensorflow")).
+   testthat::skip_if_not(reticulate::py_module_available("tensorflow"))
+   
+   sequence <- "AAACCNGGGTTT"  # 12 characters, one of them the ambiguous N
+   maxlen <- 8
+   filename <- tempfile(fileext = ".h5")
+   
+   model <- create_model_lstm_cnn(
+      maxlen = maxlen,
+      verbose = FALSE,
+      layer_dense = 4,
+      layer_lstm = 8)
+   
+   # h5 output: predictions are returned and also written to `filename`; both are compared below
+   pred <- predict_model(layer_name = NULL, sequence = sequence,
+                         filename = filename, step = 1,
+                         batch_size = 1, 
+                         return_states = TRUE,
+                         verbose = FALSE,
+                         output_type = "h5",
+                         model = model,
+                         mode = "label", 
+                         include_seq = TRUE)
+   # States must lie in [0, 1]; with 12 characters, maxlen 8 and step 1 the expected end positions are 8..12
+   expect_true(all(pred$states >= 0))
+   expect_true(all(pred$states <= 1))
+   expect_equal(pred$sample_end_position, 8:12)
+   # Content read back from the h5 file must equal the returned values
+   pred_h5 <- load_prediction(filename, get_sample_position = TRUE)
+   expect_equal(pred_h5$states, pred$states)
+   expect_equal(pred_h5$sample_end_position, pred$sample_end_position)
+   
+   # batch size (100) bigger than the number of samples; filename is NULL for this call
+   pred2 <- predict_model(layer_name = NULL, sequence = sequence,
+                          filename = NULL, step = 1,
+                          batch_size = 100, 
+                          return_states = TRUE,
+                          verbose = FALSE,
+                          output_type = "h5",
+                          model = model,
+                          mode = "label", 
+                          include_seq = TRUE)
+   # Must agree with the batch_size = 1 run above, states within an absolute tolerance of 1e-06
+   expect_true(all(abs(pred$states - pred2$states) < 1e-06))
+   expect_equal(pred$sample_end_position, pred2$sample_end_position)
+   
+   # csv output combined with padding = "maxlen", ambiguous_nuc = "empirical" and an explicit nuc_dist
+   filename <- tempfile(fileext = ".csv")
+   pred <- predict_model(layer_name = NULL, sequence = sequence,
+                         filename = filename, step = 1,
+                         batch_size = 2, 
+                         return_states = TRUE,
+                         padding = "maxlen",
+                         verbose = FALSE,
+                         output_type = "csv", model = model,
+                         mode = "label", ambiguous_nuc = "empirical",
+                         nuc_dist = c(0.1,0.4,0.4,0.1),
+                         include_seq = TRUE)
+   # With padding = "maxlen" the expected end positions run from 0 through 12
+   expect_true(all(pred$states >= 0))
+   expect_true(all(pred$states <= 1))
+   expect_equal(pred$sample_end_position, 0:12)
+   # The csv written to `filename` must match the returned states
+   pred_csv <- read.csv(filename)
+   expect_equal(as.matrix(pred_csv), pred$states)
+   
+   # padding = "standard": "AAA" is shorter than maxlen, one sample ending at position 3 is expected
+   pred <- predict_model(layer_name = NULL, sequence = "AAA",
+                         filename = NULL, step = 2,
+                         batch_size = 2, 
+                         return_states = TRUE,
+                         padding = "standard",
+                         verbose = FALSE,
+                         output_type = "csv", model = model,
+                         mode = "label", 
+                         include_seq = TRUE)
+   
+   expect_true(all(pred$states >= 0))
+   expect_true(all(pred$states <= 1))
+   expect_equal(pred$sample_end_position, 3)
+   expect_equal(nrow(pred$states), length(pred$sample_end_position))
+   
+   # step = 2 on a 10-character sequence: the expected end positions are 8 and 10
+   pred <- predict_model(layer_name = NULL, sequence = "AAAAACCCCC",
+                         filename = NULL, step = 2,
+                         batch_size = 2, 
+                         return_states = TRUE,
+                         padding = "standard",
+                         verbose = FALSE,
+                         output_type = "csv", model = model,
+                         mode = "label", 
+                         include_seq = TRUE)
+   
+   expect_equal(pred$sample_end_position, c(8, 10))
+   expect_equal(nrow(pred$states), length(pred$sample_end_position))
+   
+   # fasta input with output_format = "by_entry": files named states_nr_<i>.h5 are expected in output_dir
+   Sequence <- c("AAAACCCC", "TT", "AAACCCGGGTTT")
+   Header <- letters[1:3]   
+   df <- data.frame(Sequence, Header)
+   fasta_path <- tempfile(fileext = ".fasta")
+   microseq::writeFasta(df, fasta_path)
+   output_path <- tempfile()
+   dir.create(output_path)
+   # verbose = TRUE here: the call is expected to emit at least one message
+   expect_message(
+      predict_model(layer_name = NULL, 
+                    path_input = fasta_path,
+                    output_format = "by_entry",
+                    output_dir = output_path,
+                    filename = "states.h5",
+                    step = 2,
+                    batch_size = 2, 
+                    padding = "none",
+                    verbose = TRUE,
+                    output_type = "h5",
+                    model = model,
+                    mode = "label", 
+                    include_seq = TRUE)
+   )
+   # Files for entries 1 and 3 are expected first; presumably "TT" (entry 2) is too short with padding = "none" - TODO confirm
+   h5_files <- list.files(output_path, full.names = TRUE)
+   expect_true(basename(h5_files[1]) == "states_nr_1.h5")
+   expect_true(basename(h5_files[2]) == "states_nr_3.h5")
+   
+   output_list_1 <- load_prediction(h5_files[1], get_sample_position = TRUE)
+   expect_equal(output_list_1$sample_end_position, 8)
+   output_list_2 <- load_prediction(h5_files[2], get_sample_position = TRUE)
+   expect_equal(output_list_2$sample_end_position, c(8,10,12))
+   
+   # fasta input with output_format = "by_entry_one_file": a single h5 file, compared with the by_entry results above
+   h5_file <- tempfile(fileext = ".h5")
+   pred <- predict_model(layer_name = NULL, 
+                         path_input = fasta_path,
+                         output_format = "by_entry_one_file",
+                         filename = h5_file,
+                         step = 2,
+                         batch_size = 2, 
+                         padding = "none",
+                         verbose = FALSE,
+                         output_type = "h5",
+                         model = model,
+                         mode = "label")
+   # Element [[2]] of the loaded list is compared with the file written for fasta entry 3 (states_nr_3.h5)
+   output_list <- load_prediction(h5_file, get_sample_position = TRUE)
+   expect_true(all(output_list[[1]]$states == output_list_1$states))
+   expect_true(all(output_list[[1]]$sample_end_position == output_list_1$sample_end_position))
+   expect_true(all(output_list[[2]]$states == output_list_2$states))
+   expect_true(all(output_list[[2]]$sample_end_position == output_list_2$sample_end_position))
+   
+   # output_format = "one_pred_per_entry": one row of states is expected per fasta entry (nrow(df) is 3)
+   h5_file <- tempfile(fileext = ".h5")
+   pred <- predict_model(layer_name = NULL, 
+                         path_input = fasta_path,
+                         output_format = "one_pred_per_entry",
+                         filename = h5_file,
+                         step = 2,
+                         batch_size = 2, 
+                         verbose = FALSE,
+                         output_type = "h5",
+                         model = model,
+                         mode = "label")
+   
+   output_list <- load_prediction(h5_file)
+   expect_equal(nrow(output_list$states),  nrow(df))
+   
+   # mode = "lm" with a target-middle model (lm_format = "target_middle_lstm", target_len = 1); `model` is replaced here
+   model <- create_model_lstm_cnn_target_middle(
+      maxlen = maxlen,
+      verbose = FALSE,
+      layer_dense = 4,
+      layer_lstm = 8)
+   
+   h5_file <- tempfile(fileext = ".h5")
+   pred <- predict_model(layer_name = NULL, 
+                         path_input = fasta_path,
+                         output_format = "by_entry_one_file",
+                         filename = h5_file,
+                         step = 2,
+                         target_len = 1,
+                         batch_size = 2, 
+                         padding = "standard",
+                         verbose = FALSE,
+                         output_type = "h5",
+                         lm_format = "target_middle_lstm",
+                         model = model,
+                         mode = "lm")
+   # Expected end positions: 8 for entry 1, 2 for entry 2, and 9 as the first position of entry 3
+   output_list <- load_prediction(h5_file, get_sample_position = TRUE)
+   expect_equal(output_list[[1]]$sample_end_position, 8)
+   expect_equal(output_list[[2]]$sample_end_position, 2)
+   expect_equal(output_list[[3]]$sample_end_position[1], 9)
+   
+})