--- a
+++ b/vignettes/integrated_gradient.Rmd
@@ -0,0 +1,279 @@
+---
+title: "Integrated Gradient"
+output: rmarkdown::html_vignette
+vignette: >
+  %\VignetteIndexEntry{Integrated Gradient}
+  %\VignetteEngine{knitr::rmarkdown}
+  %\VignetteEncoding{UTF-8}
+---
+
+```{r, echo=FALSE, warning=FALSE, message=FALSE}
+
+if (!reticulate::py_module_available("tensorflow")) {
+  knitr::opts_chunk$set(eval = FALSE)
+} else {
+  knitr::opts_chunk$set(eval = TRUE)
+}
+```
+  
+```{r, message=FALSE}
+library(deepG)
+library(keras)
+library(magrittr)
+library(ggplot2)
+library(reticulate)
+```
+
+
+```{r, echo=FALSE, warning=FALSE, message=FALSE}
+options(rmarkdown.html_vignette.check_title = FALSE)
+```
+
+```{css, echo=FALSE}
+mark.in {
+background-color: CornflowerBlue;
+}
+
+mark.out {
+background-color: IndianRed;
+}
+
+```
+
+## Introduction 
+
+The <a href="https://arxiv.org/abs/1703.01365">Integrated Gradient</a> (IG) method can be used to determine which parts of an input sequence are important for the model's decision.
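+
+Conceptually, IG compares the input $x$ to a baseline $x'$ and accumulates the model's gradients along the straight-line path between the two. For a model $F$ and input dimension $i$:
+
+$$\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\big(x' + \alpha (x - x')\big)}{\partial x_i} \, d\alpha$$
+
+In practice, the integral is approximated by a finite sum of gradients at interpolated inputs.
+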
+We start by training a model that can differentiate sequences based on their GC content
+(as described in the <a href="getting_started.html">Getting started tutorial</a>).
+
+
+## Model Training
+
+We create two simple dummy training and validation data sets. Both consist of random <tt>ACGT</tt> sequences, but in the first category <tt>G</tt> and <tt>C</tt> are each drawn with a probability of 40%, while the second category has equal probability for each nucleotide (so the first category has around 80% <tt>GC</tt> content and the second around 50%).
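+
+As a quick sanity check of these numbers, we can draw a long random sequence with the skewed probabilities and look at its fraction of <tt>G</tt>/<tt>C</tt>:
+
+```{r}
+nt <- sample(c("A", "C", "G", "T"), size = 1e5, replace = TRUE,
+             prob = c(0.1, 0.4, 0.4, 0.1))
+mean(nt %in% c("C", "G")) # expected GC fraction: 0.4 + 0.4 = 0.8
+```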
+
+```{r warning = FALSE}
+set.seed(123)
+
+# Create data 
+vocabulary <- c("A", "C", "G", "T")
+data_type <- c("train_1", "train_2", "val_1", "val_2")
+
+for (i in seq_along(data_type)) {
+  
+  temp_file <- tempfile()
+  assign(paste0(data_type[i], "_dir"), temp_file)
+  dir.create(temp_file)
+  
+  if (i %% 2 == 1) {
+    header <- "label_1"
+    prob <- c(0.1, 0.4, 0.4, 0.1)
+  } else {
+    header <- "label_2"
+    prob <- rep(0.25, 4)
+  }
+  fasta_name_start <- paste0(header, "_", data_type[i], "file")
+  
+  create_dummy_data(file_path = temp_file,
+                    num_files = 1,
+                    seq_length = 20000, 
+                    num_seq = 1,
+                    header = header,
+                    prob = prob,
+                    fasta_name_start = fasta_name_start,
+                    vocabulary = vocabulary)
+  
+}
+
+# Create model
+maxlen <- 50
+model <- create_model_lstm_cnn(maxlen = maxlen,
+                               filters = c(8, 16),
+                               kernel_size = c(8, 8),
+                               pool_size = c(3, 3),
+                               layer_lstm = 8,
+                               layer_dense = c(4, 2),
+                               model_seed = 3)
+
+# Train model
+hist <- train_model(model,
+                    train_type = "label_folder",
+                    run_name = "gc_model_1",
+                    path = c(train_1_dir, train_2_dir),
+                    path_val = c(val_1_dir, val_2_dir),
+                    epochs = 6, 
+                    batch_size = 64,
+                    steps_per_epoch = 50, 
+                    step = 50, 
+                    vocabulary_label = c("high_gc", "equal_dist"))
+
+plot(hist)
+```
+
+
+## Integrated Gradient
+
+Using Integrated Gradient, we can visualize which parts of an input sequence are important for the model's decision.
+Let's create a sequence with a high GC content. We use the same number of Cs as Gs and the same number of As as Ts.
+
+```{r warning = FALSE}
+set.seed(321)
+g_count <- 17
+stopifnot(2 * g_count < maxlen)
+a_count <- (maxlen - 2 * g_count) / 2
+high_gc_seq <- c(rep("G", g_count), rep("C", g_count), rep("A", a_count), rep("T", a_count))
+high_gc_seq <- high_gc_seq[sample(maxlen)] %>% paste(collapse = "") # shuffle nt order
+high_gc_seq
+```
+
+We need to one-hot encode the sequence before applying Integrated Gradient.
+
+```{r warning = FALSE}
+high_gc_seq_one_hot <- seq_encoding_label(char_sequence = high_gc_seq,
+                                          maxlen = 50,
+                                          maxlen = maxlen,
+                                          vocabulary = vocabulary)
+head(high_gc_seq_one_hot[1,,])
+```
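+
+To double-check the encoding, we can map each one-hot row back to a nucleotide and compare the result with the original sequence (a quick check using base R):
+
+```{r}
+decoded <- vocabulary[apply(high_gc_seq_one_hot[1, , ], 1, which.max)]
+identical(paste(decoded, collapse = ""), high_gc_seq) # should be TRUE
+```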
+
+Our model should be confident that this sequence belongs to the first class.
+
+```{r warning = FALSE}
+pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
+colnames(pred) <- c("high_gc", "equal_dist")
+pred
+```
+
+We can visualize which parts were important for the prediction.
+
+```{r warning = FALSE}
+ig <- integrated_gradients(
+  input_seq = high_gc_seq_one_hot,
+  target_class_idx = 1,
+  model = model)
+
+if (requireNamespace("ComplexHeatmap", quietly = TRUE)) {
+  heatmaps_integrated_grad(integrated_grads = ig,
+                           input_seq = high_gc_seq_one_hot)
+} else {
+  message("Skipping ComplexHeatmap-related code because the package is not installed.")
+}
+
+```
+
+We may test how our model's prediction changes if we exchange certain nucleotides in the input sequence.
+First, we look for the positions with the smallest IG score.
+
+```{r warning = FALSE}
+ig <- as.array(ig)
+smallest_index <- which(ig == min(ig), arr.ind = TRUE)
+smallest_index
+```
+
+We may change the nucleotide with the lowest score and observe the change in prediction confidence.
+
+```{r warning = FALSE}
+# copy original sequence
+high_gc_seq_one_hot_changed <- high_gc_seq_one_hot 
+
+# prediction for original sequence
+predict(model, high_gc_seq_one_hot, verbose = 0)
+
+# change nt
+# if several positions tie for the minimum, take the first one
+row_index <- smallest_index[1, "row"]
+col_index <- smallest_index[1, "col"]
+new_row <- rep(0, 4)
+nt_index_old <- col_index
+nt_index_new <- which.max(ig[row_index, ])
+new_row[nt_index_new] <- 1
+high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
+cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")
+
+pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
+print(pred)
+```
+
+Let's repeatedly apply the previous step and change the sequence after each iteration.
+
+```{r warning = FALSE}
+# copy original sequence
+high_gc_seq_one_hot_changed <- high_gc_seq_one_hot 
+
+pred_list <- list()
+pred_list[[1]] <- predict(model, high_gc_seq_one_hot, verbose = 0)
+
+# change nts
+for (i in 1:20) {
+  
+  # update ig scores for changed input
+  ig <- integrated_gradients(
+    input_seq = high_gc_seq_one_hot_changed,
+    target_class_idx = 1,
+    model = model) %>% as.array()
+  
+  # if several positions tie for the minimum, take the first one
+  smallest_index <- which(ig == min(ig), arr.ind = TRUE)
+  row_index <- smallest_index[1, "row"]
+  col_index <- smallest_index[1, "col"]
+  new_row <- rep(0, 4)
+  nt_index_old <- col_index
+  nt_index_new <- which.max(ig[row_index, ])
+  new_row[nt_index_new] <- 1
+  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
+  cat("At position", row_index, "changing", vocabulary[nt_index_old],
+      "to", vocabulary[nt_index_new], "\n")
+  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
+  pred_list[[i + 1]] <- pred 
+  
+}
+
+pred_df <- do.call(rbind, pred_list)
+pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
+names(pred_df) <- c("high_gc", "equal_dist", "iteration")
+ggplot(pred_df, aes(x = iteration, y = high_gc)) + geom_line() + ylab("high GC confidence")
+
+```
+
+We can try the same in the opposite direction, i.e. replace the nucleotides with the largest IG scores.
+
+```{r warning = FALSE}
+# copy original sequence
+high_gc_seq_one_hot_changed <- high_gc_seq_one_hot 
+
+pred_list <- list()
+pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
+pred_list[[1]] <- pred
+
+# change nts
+for (i in 1:20) {
+  
+  # update ig scores for changed input
+  ig <- integrated_gradients(
+    input_seq = high_gc_seq_one_hot_changed,
+    target_class_idx = 1,
+    model = model) %>% as.array()
+  
+  # if several positions tie for the maximum, take the first one
+  biggest_index <- which(ig == max(ig), arr.ind = TRUE)
+  row_index <- biggest_index[1, "row"]
+  col_index <- biggest_index[1, "col"]
+  new_row <- rep(0, 4)
+  nt_index_old <- col_index
+  nt_index_new <- which.min(ig[row_index, ])
+  new_row[nt_index_new] <- 1
+  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
+  cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")
+  
+  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
+  pred_list[[i + 1]] <- pred 
+  
+}
+
+pred_df <- do.call(rbind, pred_list)
+pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
+names(pred_df) <- c("high_gc", "equal_dist", "iteration")
+ggplot(pred_df, aes(x = iteration, y = high_gc)) + geom_line() + ylab("high GC confidence")
+```
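+
+The two loops above differ only in whether we target the smallest or the largest IG score, so for further experiments the procedure can be wrapped in a small helper function. Below is a minimal sketch; `perturb_by_ig` is our own name and not part of deepG:
+
+```{r}
+perturb_by_ig <- function(model, one_hot, iterations = 20,
+                          target_class_idx = 1, direction = c("min", "max")) {
+  direction <- match.arg(direction)
+  preds <- list(predict(model, one_hot, verbose = 0))
+  for (i in seq_len(iterations)) {
+    # recompute IG scores for the current (already perturbed) input
+    ig <- integrated_gradients(input_seq = one_hot,
+                               target_class_idx = target_class_idx,
+                               model = model) %>% as.array()
+    extreme <- if (direction == "min") min(ig) else max(ig)
+    idx <- which(ig == extreme, arr.ind = TRUE)
+    row_index <- idx[1, "row"] # take the first position in case of ties
+    new_row <- rep(0, ncol(ig))
+    # replace the selected nucleotide by the one with the opposite extreme score
+    nt_index_new <- if (direction == "min") which.max(ig[row_index, ]) else which.min(ig[row_index, ])
+    new_row[nt_index_new] <- 1
+    one_hot[1, row_index, ] <- new_row
+    preds[[i + 1]] <- predict(model, one_hot, verbose = 0)
+  }
+  do.call(rbind, preds)
+}
+
+# usage, equivalent to the second loop above:
+# perturb_by_ig(model, high_gc_seq_one_hot, direction = "max")
+```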