--- a
+++ b/vignettes/integrated_gradient.Rmd
@@ -0,0 +1,279 @@
+---
+title: "Integrated Gradient"
+output: rmarkdown::html_vignette
+vignette: >
+  %\VignetteIndexEntry{Integrated Gradient}
+  %\VignetteEngine{knitr::rmarkdown}
+  %\VignetteEncoding{UTF-8}
+---
+
+```{r, echo=FALSE, warning=FALSE, message=FALSE}
+knitr::opts_chunk$set(eval = reticulate::py_module_available("tensorflow"))
+```
+
+```{r, message=FALSE}
+library(deepG)
+library(keras)
+library(magrittr)
+library(ggplot2)
+library(reticulate)
+```
+
+```{r, echo=FALSE, warning=FALSE, message=FALSE}
+options(rmarkdown.html_vignette.check_title = FALSE)
+```
+
+```{css, echo=FALSE}
+mark.in {
+  background-color: CornflowerBlue;
+}
+
+mark.out {
+  background-color: IndianRed;
+}
+```
+
+## Introduction
+
+The <a href="https://arxiv.org/abs/1703.01365">Integrated Gradient</a> (IG) method can be used to determine which parts of an input sequence are important for the model's decision.
+We start by training a model that can differentiate sequences based on their GC content
+(as described in the <a href="getting_started.html">Getting started tutorial</a>).
+
+## Model Training
+
+We create simple dummy training and validation data sets. Both consist of random <tt>ACGT</tt> sequences, but in the first category <tt>G</tt> and <tt>C</tt> are each drawn with a probability of 40%, while the second category uses equal probabilities for all four nucleotides (so the first category has around 80% <tt>GC</tt> content and the second around 50%).
+
+```{r warning = FALSE}
+set.seed(123)
+
+# Create data
+vocabulary <- c("A", "C", "G", "T")
+data_type <- c("train_1", "train_2", "val_1", "val_2")
+
+for (i in seq_along(data_type)) {
+
+  temp_file <- tempfile()
+  assign(paste0(data_type[i], "_dir"), temp_file)
+  dir.create(temp_file)
+
+  if (i %% 2 == 1) {
+    header <- "label_1"
+    prob <- c(0.1, 0.4, 0.4, 0.1) # 40% each for C and G
+  } else {
+    header <- "label_2"
+    prob <- rep(0.25, 4) # equal distribution
+  }
+  fasta_name_start <- paste0(header, "_", data_type[i], "file")
+
+  create_dummy_data(file_path = temp_file,
+                    num_files = 1,
+                    seq_length = 20000,
+                    num_seq = 1,
+                    header = header,
+                    prob = prob,
+                    fasta_name_start = fasta_name_start,
+                    vocabulary = vocabulary)
+
+}
+
+# Create model
+maxlen <- 50
+model <- create_model_lstm_cnn(maxlen = maxlen,
+                               filters = c(8, 16),
+                               kernel_size = c(8, 8),
+                               pool_size = c(3, 3),
+                               layer_lstm = 8,
+                               layer_dense = c(4, 2),
+                               model_seed = 3)
+
+# Train model
+hist <- train_model(model,
+                    train_type = "label_folder",
+                    run_name = "gc_model_1",
+                    path = c(train_1_dir, train_2_dir),
+                    path_val = c(val_1_dir, val_2_dir),
+                    epochs = 6,
+                    batch_size = 64,
+                    steps_per_epoch = 50,
+                    step = 50,
+                    vocabulary_label = c("high_gc", "equal_dist"))
+
+plot(hist)
+```
+
+## Integrated Gradient
+
+Using Integrated Gradient, we can try to visualize which parts of an input sequence are important for the model's decision.
+Let's create a sequence with a high GC content. We use the same number of <tt>C</tt>s as <tt>G</tt>s and of <tt>A</tt>s as <tt>T</tt>s.
+
+```{r warning = FALSE}
+set.seed(321)
+g_count <- 17
+stopifnot(2 * g_count < maxlen)
+a_count <- (maxlen - (2 * g_count)) / 2
+high_gc_seq <- c(rep("G", g_count), rep("C", g_count), rep("A", a_count), rep("T", a_count))
+high_gc_seq <- high_gc_seq[sample(maxlen)] %>% paste(collapse = "") # shuffle nt order
+high_gc_seq
+```
+
+We need to one-hot encode the sequence before applying Integrated Gradient.
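+
+As a quick illustration of what this representation looks like, each nucleotide becomes a unit vector over the vocabulary. The helper below is a minimal base-R sketch for exposition only (`one_hot_sketch` is not a deepG function); the deepG call in the next chunk produces the actual input tensor.
+
+```{r warning = FALSE}
+# Minimal sketch of one-hot encoding, for illustration only
+one_hot_sketch <- function(seq_string, vocabulary = c("A", "C", "G", "T")) {
+  chars <- strsplit(seq_string, "")[[1]]
+  # one row per position, one column per nucleotide
+  mat <- matrix(0L, nrow = length(chars), ncol = length(vocabulary),
+                dimnames = list(NULL, vocabulary))
+  # set a single 1 per row at the matching nucleotide
+  mat[cbind(seq_along(chars), match(chars, vocabulary))] <- 1L
+  mat
+}
+one_hot_sketch("GATC")
+```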
+
+```{r warning = FALSE}
+high_gc_seq_one_hot <- seq_encoding_label(char_sequence = high_gc_seq,
+                                          maxlen = maxlen,
+                                          start_ind = 1,
+                                          vocabulary = vocabulary)
+head(high_gc_seq_one_hot[1, , ])
+```
+
+Our model should be confident that this sequence belongs to the first class:
+
+```{r warning = FALSE}
+pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
+colnames(pred) <- c("high_gc", "equal_dist")
+pred
+```
+
+We can visualize which parts of the sequence were important for the prediction.
+
+```{r warning = FALSE}
+ig <- integrated_gradients(
+  input_seq = high_gc_seq_one_hot,
+  target_class_idx = 1,
+  model = model)
+
+if (requireNamespace("ComplexHeatmap", quietly = TRUE)) {
+  heatmaps_integrated_grad(integrated_grads = ig,
+                           input_seq = high_gc_seq_one_hot)
+} else {
+  message("Skipping ComplexHeatmap-related code because the package is not installed.")
+}
+```
+
+We can test how our model's prediction changes if we exchange certain nucleotides in the input sequence.
+First, we look for the position with the smallest IG score.
+
+```{r warning = FALSE}
+ig <- as.array(ig)
+smallest_index <- which(ig == min(ig), arr.ind = TRUE)
+smallest_index
+```
+
+We change the nucleotide with the lowest score and observe the change in prediction confidence.
+
+```{r warning = FALSE}
+# copy original sequence
+high_gc_seq_one_hot_changed <- high_gc_seq_one_hot
+
+# prediction for original sequence
+predict(model, high_gc_seq_one_hot, verbose = 0)
+
+# change nt: replace the nucleotide with the smallest IG score by the
+# nucleotide with the biggest score at that position
+smallest_index <- which(ig == min(ig), arr.ind = TRUE)[1, , drop = FALSE] # first hit on ties
+row_index <- smallest_index[ , "row"]
+col_index <- smallest_index[ , "col"]
+new_row <- rep(0, 4)
+nt_index_old <- col_index
+nt_index_new <- which.max(ig[row_index, ])
+new_row[nt_index_new] <- 1
+high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
+cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")
+
+pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
+print(pred)
+```
+
+Let's repeatedly apply the previous step, recomputing the IG scores after each change.
+
+```{r warning = FALSE}
+# copy original sequence
+high_gc_seq_one_hot_changed <- high_gc_seq_one_hot
+
+pred_list <- list()
+pred_list[[1]] <- pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
+
+# change nts
+for (i in 1:20) {
+
+  # update ig scores for changed input
+  ig <- integrated_gradients(
+    input_seq = high_gc_seq_one_hot_changed,
+    target_class_idx = 1,
+    model = model) %>% as.array()
+
+  # keep only the first hit in case of ties
+  smallest_index <- which(ig == min(ig), arr.ind = TRUE)[1, , drop = FALSE]
+  row_index <- smallest_index[ , "row"]
+  col_index <- smallest_index[ , "col"]
+  new_row <- rep(0, 4)
+  nt_index_old <- col_index
+  nt_index_new <- which.max(ig[row_index, ])
+  new_row[nt_index_new] <- 1
+  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
+  cat("At position", row_index, "changing", vocabulary[nt_index_old],
+      "to", vocabulary[nt_index_new], "\n")
+  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
+  pred_list[[i + 1]] <- pred
+
+}
+
+pred_df <- do.call(rbind, pred_list)
+pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
+names(pred_df) <- c("high_gc", "equal_dist", "iteration")
+ggplot(pred_df, aes(x = iteration, y = high_gc)) + geom_line() + ylab("high GC confidence")
+```
+
+We can try the same in the opposite direction, i.e. replace the nucleotides with the biggest IG scores.
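+
+Either direction of this replacement step can be factored into a small helper. The sketch below is purely illustrative; `flip_extreme_nt` is our own hypothetical name, not a deepG function.
+
+```{r warning = FALSE}
+# Hypothetical helper (not part of deepG): at the position holding the most
+# extreme IG score, swap the current nucleotide for the one with the
+# opposite-extreme score at that position.
+flip_extreme_nt <- function(one_hot, ig, direction = c("min", "max")) {
+  direction <- match.arg(direction)
+  extreme <- if (direction == "min") min(ig) else max(ig)
+  idx <- which(ig == extreme, arr.ind = TRUE)[1, , drop = FALSE] # first hit on ties
+  row_index <- idx[ , "row"]
+  nt_index_new <- if (direction == "min") which.max(ig[row_index, ]) else which.min(ig[row_index, ])
+  new_row <- rep(0, ncol(ig))
+  new_row[nt_index_new] <- 1
+  one_hot[1, row_index, ] <- new_row
+  one_hot
+}
+# usage (sketch): high_gc_seq_one_hot_changed <- flip_extreme_nt(high_gc_seq_one_hot_changed, ig, "max")
+```
+
+The loop below performs the same replacement inline, now targeting the biggest scores: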
+
+```{r warning = FALSE}
+# copy original sequence
+high_gc_seq_one_hot_changed <- high_gc_seq_one_hot
+
+pred_list <- list()
+pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
+pred_list[[1]] <- pred
+
+# change nts
+for (i in 1:20) {
+
+  # update ig scores for changed input
+  ig <- integrated_gradients(
+    input_seq = high_gc_seq_one_hot_changed,
+    target_class_idx = 1,
+    model = model) %>% as.array()
+
+  # keep only the first hit in case of ties
+  biggest_index <- which(ig == max(ig), arr.ind = TRUE)[1, , drop = FALSE]
+  row_index <- biggest_index[ , "row"]
+  col_index <- biggest_index[ , "col"]
+  new_row <- rep(0, 4)
+  nt_index_old <- col_index
+  nt_index_new <- which.min(ig[row_index, ])
+  new_row[nt_index_new] <- 1
+  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
+  cat("At position", row_index, "changing", vocabulary[nt_index_old],
+      "to", vocabulary[nt_index_new], "\n")
+
+  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
+  pred_list[[i + 1]] <- pred
+
+}
+
+pred_df <- do.call(rbind, pred_list)
+pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
+names(pred_df) <- c("high_gc", "equal_dist", "iteration")
+ggplot(pred_df, aes(x = iteration, y = high_gc)) + geom_line() + ylab("high GC confidence")
+```
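+
+As a closing sanity check, the IG paper linked above proves a completeness property: the attributions sum to the difference between the model's output at the input and at the baseline. The sketch below assumes a zero-filled baseline; deepG's default baseline and the numerical Riemann approximation may introduce some deviation.
+
+```{r warning = FALSE}
+# Completeness check (sketch): the sum of IG scores should roughly equal the
+# difference between the prediction for the input and for the baseline.
+# A zero-filled baseline is an assumption; deepG's default may differ.
+baseline <- array(0, dim = dim(high_gc_seq_one_hot))
+pred_input <- predict(model, high_gc_seq_one_hot, verbose = 0)[1, 1]
+pred_baseline <- predict(model, baseline, verbose = 0)[1, 1]
+ig_sum <- integrated_gradients(input_seq = high_gc_seq_one_hot,
+                               target_class_idx = 1,
+                               model = model) %>% as.array() %>% sum()
+c(ig_sum = ig_sum, output_difference = pred_input - pred_baseline)
+```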