--- a
+++ b/R/visualization.R
@@ -0,0 +1,444 @@
+#' Interpolation between baseline and prediction
+#'
+#' @param baseline_type Baseline sequence, either "zero" for all zeros, "shuffle" for a random permutation of
+#' input_seq or "unif" for random values drawn from a uniform distribution.
+#' @param m_steps Number of steps between baseline and original input.
+#' @param input_seq Input tensor.
+#' @noRd
+interpolate_seq <- function(m_steps = 50,
+                            baseline_type = "shuffle",
+                            input_seq) {
+  
+  stopifnot(baseline_type %in% c("zero", "shuffle", "unif"))
+  if (is.list(input_seq)) {
+    baseline <- list()
+    for (i in 1:length(input_seq)) {
+      input_dim <- dim(input_seq[[i]])
+      if (baseline_type == "zero") {
+        baseline[[i]] <- array(rep(0, prod(input_dim)), dim = input_dim)
+      } 
+      if (baseline_type == "shuffle") {
+        baseline[[i]] <- array(input_seq[[i]][ , sample(input_dim[2]), ], dim = input_dim)
+      }
+      if (baseline_type == "unif") {
+        baseline[[i]] <- array(stats::runif(prod(input_dim)), dim = input_dim)
+      } 
+    }
+  } else {
+    if (baseline_type == "zero") {
+      baseline <- array(rep(0, prod(dim(input_seq))), dim = dim(input_seq))
+    } 
+    if (baseline_type == "shuffle") {
+      baseline <- array(input_seq[ , sample(dim(input_seq)[2]), ], dim = dim(input_seq))
+    }
+    if (baseline_type == "unif") {
+      baseline <- array(stats::runif(prod(dim(input_seq))), dim = dim(input_seq))
+    } 
+  }
+  
+  m_steps <- as.integer(m_steps)
+  alphas <- tensorflow::tf$linspace(start = 0.0, stop = 1.0, num = m_steps + 1L) # Generate m_steps intervals for integral_approximation() below.
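+  # Reshape to (m_steps + 1, 1, 1) so that alphas broadcast over the remaining input axes.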
+  alphas_x <- alphas[ , tensorflow::tf$newaxis, tensorflow::tf$newaxis]
+  if (is.list(baseline)) {
+    delta <- list()
+    sequences <- list()
+    for (i in 1:length(baseline)) {
+      delta[[i]] <- input_seq[[i]] - baseline[[i]]
+      sequences[[i]] <- baseline[[i]] +  alphas_x * delta[[i]]
+    }
+  } else {
+    delta <- input_seq - baseline
+    sequences <- baseline +  alphas_x * delta
+  }
+  return(sequences)
+}
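+
+# Illustrative sketch (not run), assuming a one-hot input of shape (1, maxlen, 4):
+# seqs <- interpolate_seq(m_steps = 50, baseline_type = "zero", input_seq = input_seq)
+# 'seqs' then holds m_steps + 1 = 51 interpolated sequences along the first axis,
+# i.e. baseline + alpha * (input_seq - baseline) for alpha = 0, 1/50, ..., 1.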
+
+#' Compute gradients
+#'
+#' @param input_tensor Input tensor or list of input tensors (one per input layer).
+#' @param target_class_idx Index of class to compute gradient for.
+#' @param model Model to compute gradient for.
+#' @param input_idx Input layer to monitor for > 1 input.
+#' @param pred_stepwise Whether to do predictions with batch size 1 rather than all at once. Can be used if
+#' input is too big to handle at once.
+#' @noRd
+compute_gradients <- function(input_tensor, target_class_idx, model, input_idx = NULL, pred_stepwise = FALSE) {
+  
+
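+  # The gradient is computed through a short Python snippet via reticulate so that
+  # tf.GradientTape can watch the (non-variable) input tensor(s).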
+  reticulate::py_run_string("import tensorflow as tf")
+  py$input_tensor <- input_tensor
+  if (!is.null(input_idx)) py$input_idx <- as.integer(input_idx - 1)
+  py$target_class_idx <- as.integer(target_class_idx - 1)
+  py$model <- model
+  
+  if (!is.null(input_idx)) {
+    reticulate::py_run_string(
+      "with tf.GradientTape() as tape:
+             tape.watch(input_tensor[input_idx])
+             probs = model(input_tensor)[:, target_class_idx]
+    ")
+  } else {
+    reticulate::py_run_string(
+      "with tf.GradientTape() as tape:
+             tape.watch(input_tensor)
+             probs = model(input_tensor)[:, target_class_idx]
+    ")
+  }
+  
+  grad <- py$tape$gradient(py$probs, py$input_tensor)
+  if (!is.null(input_idx)) {
+    return(grad[[input_idx]])
+  } else {
+    return(grad)
+  }
+}
+
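+#' Riemann trapezoidal approximation
+#'
+#' Approximates the path integral by averaging adjacent gradients along the
+#' interpolation path and taking the mean over all steps.
+#'
+#' @param gradients Gradients along the interpolation path (first axis indexes the steps).
+#' @noRd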
+integral_approximation <- function(gradients) {
+  reticulate::py_run_string("import tensorflow as tf")
+  py$gradients <- gradients
+  # riemann_trapezoidal
+  reticulate::py_run_string("grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)")
+  reticulate::py_run_string("integrated_gradients = tf.math.reduce_mean(grads, axis=0)")
+  return(py$integrated_gradients)
+}
+
+#' Compute integrated gradients
+#'
+#' Computes integrated gradient scores for a model and an input sequence.
+#' This can be used to visualize which parts of the input are important for the model's decision.
+#' The code is an R implementation of the Python code from [here](https://www.tensorflow.org/tutorials/interpretability/integrated_gradients).
+#' The TensorFlow implementation is based on this [paper](https://arxiv.org/abs/1703.01365).
+#' 
+#' @param baseline_type Baseline sequence, either `"zero"` for all zeros, `"shuffle"` for a random permutation
+#' of `input_seq` or `"unif"` for random values drawn from a uniform distribution.
+#' @param m_steps Number of steps between baseline and original input.
+#' @param input_seq Input tensor.
+#' @param target_class_idx Index of class to compute gradient for.
+#' @param model Model to compute gradient for.
+#' @param pred_stepwise Whether to do predictions with batch size 1 rather than all at once. Can be used if
+#' input is too big to handle at once. Only supported for single input layer.
+#' @param num_baseline_repeats Number of different baseline estimations if baseline_type is `"shuffle"` (estimate integrated
+#' gradient repeatedly for different shuffles). Final result is the average of \code{num_baseline_repeats} single calculations.
+#' @examplesIf reticulate::py_module_available("tensorflow")
+#' library(reticulate)
+#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 3, maxlen = 20, verbose = FALSE)
+#' random_seq <- sample(0:3, 20, replace = TRUE)
+#' input_seq <- array(keras::to_categorical(random_seq), dim = c(1, 20, 4))
+#' integrated_gradients(
+#'   input_seq = input_seq,
+#'   target_class_idx = 3,
+#'   model = model)
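+#' 
+#' # Average over several shuffled baselines instead of the zero baseline (not run):
+#' # integrated_gradients(input_seq = input_seq, target_class_idx = 3, model = model,
+#' #   baseline_type = "shuffle", num_baseline_repeats = 5)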
+#'   
+#' @returns A tensorflow tensor.
+#' @export
+integrated_gradients <- function(m_steps = 50,
+                                 baseline_type = "zero",
+                                 input_seq,
+                                 target_class_idx,
+                                 model,
+                                 pred_stepwise = FALSE,
+                                 num_baseline_repeats = 1) {
+
+  reticulate::py_run_string("import tensorflow as tf")
+  input_idx <- NULL
+  if (num_baseline_repeats > 1 && baseline_type == "zero") {
+    warning('Ignoring num_baseline_repeats if baseline is of type "zero". Did you mean to use baseline_type = "shuffle"?')
+  }
+  
+  if (num_baseline_repeats == 1 || baseline_type == "zero") {
+    
+    baseline_seq <- interpolate_seq(m_steps = m_steps,
+                                    baseline_type = baseline_type,
+                                    input_seq = input_seq)
+    
+    if (is.list(baseline_seq)) {
+      for (i in 1:length(baseline_seq)) {
+        baseline_seq[[i]] <- tensorflow::tf$cast(baseline_seq[[i]], dtype = "float32")
+      }
+    } else {
+      baseline_seq <- tensorflow::tf$cast(baseline_seq, dtype = "float32")
+    }
+    
+    if (is.list(input_seq)) {
+      path_gradients <- list()
+      avg_grads <- list()
+      ig <- list()
+      
+      if (pred_stepwise) {
+        path_gradients <- gradients_stepwise(
+          model = model,
+          baseline_seq = baseline_seq,
+          target_class_idx = target_class_idx)
+      } else {
+        
+        path_gradients <- compute_gradients(
+          model = model,
+          input_tensor = baseline_seq,
+          target_class_idx = target_class_idx,
+          input_idx = NULL,
+          pred_stepwise = pred_stepwise)
+      }
+      
+      for (i in 1:length(input_seq)) {
+        avg_grads[[i]] <- integral_approximation(gradients = path_gradients[[i]])
+        ig[[i]] <- ((input_seq[[i]] - baseline_seq[[i]][1, , ]) * avg_grads[[i]])[1, , ]
+      }
+    } else {
+      
+      if (pred_stepwise) {
+        path_gradients <- gradients_stepwise(model = model,
+                                             baseline_seq = baseline_seq,
+                                             target_class_idx = target_class_idx,
+                                             input_idx = NULL)
+      } else {
+        path_gradients <- compute_gradients(
+          model = model,
+          input_tensor = baseline_seq,
+          target_class_idx = target_class_idx,
+          input_idx = NULL,
+          pred_stepwise = pred_stepwise)
+      }
+      
+      avg_grads <- integral_approximation(gradients = path_gradients)
+      ig <- ((input_seq - baseline_seq[1, , ]) * avg_grads)[1, , ]
+    }
+  } else {
+    ig_list <- list()
+    for (i in 1:num_baseline_repeats) {
+      ig_list[[i]] <- integrated_gradients(m_steps = m_steps,
+                                           baseline_type = "shuffle",
+                                           input_seq = input_seq,
+                                           target_class_idx = target_class_idx,
+                                           model = model,
+                                           pred_stepwise = pred_stepwise,
+                                           num_baseline_repeats = 1)
+    }
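+    # Average the integrated gradient maps obtained for the different shuffled baselines.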
+    ig_stacked <- tensorflow::tf$stack(ig_list, axis = 0L)
+    ig <- tensorflow::tf$reduce_mean(ig_stacked, axis = 0L)
+  }
+  
+  return(ig)
+}
+
+#' Compute gradients stepwise (one batch at a time)
+#'
+#' @param model Model to compute gradients for.
+#' @param baseline_seq Interpolated sequences (output of `interpolate_seq()`).
+#' @param target_class_idx Index of class to compute gradient for.
+#' @param input_idx Input layer to monitor for > 1 input.
+#' @noRd
+gradients_stepwise <- function(model, baseline_seq, target_class_idx,
+                               input_idx = NULL) {
+  
+  if (is.list(baseline_seq)) {
+    first_dim <- dim(baseline_seq[[1]])[1]
+    num_input_layers <- length(baseline_seq)
+    
+    l <- list()
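+    # Feed one interpolation step at a time (batch of size 1) through the model and collect the gradients.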
+    for (j in 1:first_dim) {
+      input_list <- list()
+      for (k in 1:length(baseline_seq)) {
+        input <- as.array(baseline_seq[[k]][j, , ])
+        input <- array(input, dim = c(1, dim(baseline_seq[[k]])[-1]))
+        input <- tensorflow::tf$cast(input, baseline_seq[[k]]$dtype)
+        input_list[[k]] <- input
+      }
+      output <- compute_gradients(
+        model = model,
+        input_tensor = input_list,
+        target_class_idx = target_class_idx,
+        input_idx = NULL)
+      for (m in 1:length(output)) {
+        output[[m]] <- tensorflow::tf$squeeze(output[[m]])
+      }
+      l[[j]] <- output
+    }
+    
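+    # Re-assemble: for each input layer, stack the per-step gradients along the first axis.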
+    path_gradients <- vector("list", num_input_layers)
+    for (n in 1:num_input_layers) {
+      temp_list <- vector("list", first_dim)
+      for (p in 1:first_dim){
+        temp_list[[p]] <- l[[p]][[n]]
+      }
+      path_gradients[[n]] <- tensorflow::tf$stack(temp_list)
+    }
+    
+  } else {
+    l <- list()
+    for (j in 1:dim(baseline_seq)[1]) {
+      input <- as.array(baseline_seq[j, , ])
+      input <- array(input, dim = c(1, dim(baseline_seq)[-1]))
+      input <- tensorflow::tf$cast(input, baseline_seq$dtype)
+      output <- compute_gradients(
+        model = model,
+        input_tensor = input,
+        target_class_idx = target_class_idx,
+        input_idx = NULL)
+      output <- tensorflow::tf$squeeze(output)
+      l[[j]] <- output
+    }
+    path_gradients <- tensorflow::tf$stack(l)
+  }
+  return(path_gradients)
+}
+
+
+#' Heatmap of integrated gradient scores
+#' 
+#' Creates a heatmap from the output of the \code{\link{integrated_gradients}} function. The first row contains
+#' the column-wise absolute sums of the IG scores and the second row the column-wise sums. Rows 3 to 6 contain the IG scores for each
+#' position and each nucleotide. The last row shows the nucleotide of the input sequence at each position.
+#'
+#' @param integrated_grads Matrix of integrated gradient scores (output of \code{\link{integrated_gradients}} function).
+#' @param input_seq Input sequence for model. Should be the same as \code{input_seq} input for corresponding
+#' \code{\link{integrated_gradients}} call that computed input for \code{integrated_grads} argument.
+#' @examplesIf reticulate::py_module_available("tensorflow")  && requireNamespace("ComplexHeatmap", quietly = TRUE)
+#' library(reticulate)
+#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 3, maxlen = 20, verbose = FALSE)
+#' random_seq <- sample(0:3, 20, replace = TRUE)
+#' input_seq <- array(keras::to_categorical(random_seq), dim = c(1, 20, 4))
+#' ig <- integrated_gradients(
+#'   input_seq = input_seq,
+#'   target_class_idx = 3,
+#'   model = model)
+#' heatmaps_integrated_grad(integrated_grads = ig,
+#'                          input_seq = input_seq)
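+#' # each heatmap in the returned list can be drawn explicitly, e.g. via
+#' # ComplexHeatmap::draw() on a single list element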
+#'  
+#' @returns A list of heatmaps.                          
+#' @export
+heatmaps_integrated_grad <- function(integrated_grads,
+                                     input_seq) {
+  
+  if (is.list(input_seq)) {
+    for (i in 1:length(input_seq)) {
+      input_seq[[i]] <- tensorflow::tf$cast(input_seq[[i]], dtype = "float32")
+    }
+    
+    for (i in 1:length(integrated_grads)) {
+      integrated_grads[[i]] <- tensorflow::tf$cast(integrated_grads[[i]], dtype = "float32")
+    }
+  } else {
+    input_seq <- tensorflow::tf$cast(input_seq, dtype = "float32")
+    integrated_grads <- tensorflow::tf$cast(integrated_grads, dtype = "float32")
+  }
+  
+  
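+  # For each input layer, compute position-wise (absolute) sums of the IG scores
+  # and recover the nucleotide present at each position of the input.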
+  if (is.list(input_seq)) {
+    num_input <- length(input_seq)
+    attribution_mask <- list()
+    nuc_matrix <- list()
+    nuc_seq <- list()
+    sum_nuc <- list()
+    for (i in 1:length(integrated_grads)) {
+      py$integrated_grads <- integrated_grads[[i]]
+      reticulate::py_run_string("attribution_mask = tf.reduce_sum(tf.math.abs(integrated_grads), axis=-1)")
+      reticulate::py_run_string("sum_nuc = tf.reduce_sum(integrated_grads, axis=-1)")
+      attribution_mask[[i]] <- py$attribution_mask
+      attribution_mask[[i]] <- as.matrix(attribution_mask[[i]], nrow = 1) %>% as.data.frame()
+      colnames(attribution_mask[[i]]) <- "abs_sum"
+      sum_nuc[[i]] <- py$sum_nuc
+      sum_nuc[[i]] <- as.matrix(sum_nuc[[i]], nrow = 1) %>% as.data.frame()
+      colnames(sum_nuc[[i]]) <- "sum"
+      
+      if (length(dim(integrated_grads[[i]])) == 3) {
+        nuc_matrix[[i]] <- as.matrix(integrated_grads[[i]][1, , ])
+      }
+      if (length(dim(integrated_grads[[i]])) == 2) {
+        nuc_matrix[[i]] <- as.matrix(integrated_grads[[i]])
+      }
+      amb_nuc <- (apply(input_seq[[i]][1, ,], 1, max) %>% as.character()) != "1"
+      nuc_seq[[i]] <- apply(input_seq[[i]][1, ,], 1, which.max) %>% as.character()
+      nuc_seq[[i]] <- nuc_seq[[i]] %>% stringr::str_replace_all("1", "A") %>%
+        stringr::str_replace_all("2", "C") %>%
+        stringr::str_replace_all("3", "G") %>%
+        stringr::str_replace_all("4", "T")
+      nuc_seq[[i]][amb_nuc] <- "0"
+      rownames(nuc_matrix[[i]]) <- nuc_seq[[i]]
+      colnames(nuc_matrix[[i]]) <- c("A", "C", "G", "T")
+    }
+    
+  } else {
+    num_input <- 1
+    py$integrated_grads <- integrated_grads
+    reticulate::py_run_string("attribution_mask = tf.reduce_sum(tf.math.abs(integrated_grads), axis=-1)")
+    reticulate::py_run_string("sum_nuc = tf.reduce_sum(integrated_grads, axis=-1)")
+    
+    attribution_mask <- py$attribution_mask
+    attribution_mask <- as.matrix(attribution_mask, nrow = 1) %>% as.data.frame()
+    colnames(attribution_mask) <- "abs_sum"
+    
+    sum_nuc <- py$sum_nuc
+    sum_nuc <- as.matrix(sum_nuc, nrow = 1) %>% as.data.frame()
+    colnames(sum_nuc) <- "sum"
+    
+    if (length(dim(integrated_grads)) == 3) {
+      nuc_matrix <- as.matrix(integrated_grads[1, , ])
+    }
+    if (length(dim(integrated_grads)) == 2) {
+      nuc_matrix <- as.matrix(integrated_grads)
+    }
+    amb_nuc <- (apply(input_seq[1, ,], 1, max) %>% as.character()) != "1"
+    nuc_seq <- apply(input_seq[1, ,], 1, which.max) %>% as.character()
+    nuc_seq <- nuc_seq %>% stringr::str_replace_all("1", "A") %>%
+      stringr::str_replace_all("2", "C") %>%
+      stringr::str_replace_all("3", "G") %>%
+      stringr::str_replace_all("4", "T")
+    nuc_seq[amb_nuc] <- "0"
+    rownames(nuc_matrix) <- nuc_seq
+    colnames(nuc_matrix) <- c("A", "C", "G", "T")
+  }
+  
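+  # Blue-white-red color scale spanning the range of the IG scores.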
+  if (num_input == 1) {
+    ig_min <- keras::k_min(integrated_grads)$numpy()
+    ig_max <- keras::k_max(integrated_grads)$numpy()
+    col_fun <- circlize::colorRamp2(c(ig_min, mean(c(ig_max, ig_min)) , ig_max), c("blue", "white", "red"))
+  } else {
+    col_fun <- list()
+    for (i in 1:num_input) {
+      ig_min <- keras::k_min(integrated_grads[[i]])$numpy()
+      ig_max <- keras::k_max(integrated_grads[[i]])$numpy()
+      col_fun[[i]] <- circlize::colorRamp2(c(ig_min, mean(c(ig_max, ig_min)) , ig_max), c("blue", "white", "red"))
+    }
+  }
+  
+  hm_list <- list()
+  if (num_input == 1) {
+    row_ha <- ComplexHeatmap::columnAnnotation(abs_sum = attribution_mask[,1], sum = sum_nuc[,1])
+    nuc_col <- c("A" = "red", "C" = "green", "G" = "blue", "T" = "yellow")
+    if ("0" %in% row.names(nuc_matrix)) {
+      nuc_col <- c(nuc_col, "0" = "white")
+    }
+    ha <- ComplexHeatmap::HeatmapAnnotation(nuc = row.names(nuc_matrix), col = list(nuc = nuc_col))
+    hm_list[[1]] <- ComplexHeatmap::Heatmap(matrix = t(nuc_matrix),
+                                            name = "hm",
+                                            top_annotation = row_ha,
+                                            bottom_annotation = ha,
+                                            col = col_fun,
+                                            cluster_rows = FALSE,
+                                            cluster_columns = FALSE,
+                                            column_names_rot = 0
+    )
+  } else {
+    for (i in 1:num_input) {
+      row_ha <- ComplexHeatmap::columnAnnotation(abs_sum = attribution_mask[[i]][,1], sum = sum_nuc[[i]][,1])
+      nuc_col <- c("A" = "red", "C" = "green", "G" = "blue", "T" = "yellow")
+      if ("0" %in% row.names(nuc_matrix[[i]])) {
+        nuc_col <- c(nuc_col, "0" = "white")
+      }
+      ha <- ComplexHeatmap::HeatmapAnnotation(nuc = row.names(nuc_matrix[[i]]), col = list(nuc = nuc_col))
+      hm_list[[i]] <- ComplexHeatmap::Heatmap(matrix = t(nuc_matrix[[i]]),
+                                              name = paste0("hm_", i),
+                                              top_annotation = row_ha,
+                                              bottom_annotation = ha,
+                                              col = col_fun[[i]],
+                                              cluster_rows = FALSE,
+                                              cluster_columns = FALSE,
+                                              column_names_rot = 0
+      )
+    }
+  }
+  hm_list
+}