Skip to contents
# devtools::install_github("GenomeNet/deepG")
library(deepG)
library(magrittr)
library(ggplot2)

Introduction

The Integrated Gradients (IG) method can be used to determine which parts of an input sequence are important for the model's decision. We start by training a model that can differentiate sequences based on their GC content (as described in the Getting started tutorial).

Model Training

We create simple dummy training and validation data sets for two classes. Both consist of random ACGT sequences, but in the first class G and C are each drawn with a probability of 40%, while in the second class all nucleotides are equally likely (so the first class has around 80% GC content and the second around 50%).

set.seed(123)

# Create data: one temporary directory per (split, class) combination.
# Odd entries of `data_type` get the high-GC class ("label_1"), even entries
# the uniform class ("label_2"). Each directory path is stored in a global
# variable named "<data_type>_dir" (e.g. `train_1_dir`), which the
# train_model() call below reads via `path` / `path_val`.
vocabulary <- c("A", "C", "G", "T")
data_type <- c("train_1", "train_2", "val_1", "val_2")

for (i in seq_along(data_type)) {

  temp_file <- tempfile()
  assign(paste0(data_type[i], "_dir"), temp_file)
  dir.create(temp_file)

  if (i %% 2 == 1) {
    header <- "label_1"
    # probabilities follow `vocabulary` order (A, C, G, T): 40% each for C and G
    prob <- c(0.1, 0.4, 0.4, 0.1)
  } else {
    header <- "label_2"
    prob <- rep(0.25, 4)
  }
  fasta_name_start <- paste0(header, "_", data_type[i], "file")

  # one file with a single 20000 nt sequence per directory
  create_dummy_data(file_path = temp_file,
                    num_files = 1,
                    seq_length = 20000,
                    num_seq = 1,
                    header = header,
                    prob = prob,
                    fasta_name_start = fasta_name_start,
                    vocabulary = vocabulary)

}

# Create model. `maxlen` (input sequence length) is reused later when the
# test sequence is built and encoded.
maxlen <- 50
model <- create_model_lstm_cnn(
  maxlen = maxlen,
  model_seed = 3,        # presumably seeds weight initialisation for reproducibility
  # convolutional part: per-layer number of filters, kernel size and pool size
  filters = c(8, 16),
  kernel_size = c(8, 8),
  pool_size = c(3, 3),
  layer_lstm = 8,        # LSTM size (the summary below shows output (None, 8))
  layer_dense = c(4, 2)  # dense layers; the last has one unit per class
)
## Model: "model"
## _________________________________________________________________
##  Layer (type)                Output Shape              Param #   
## =================================================================
##  input_1 (InputLayer)        [(None, 50, 4)]           0         
##                                                                  
##  conv1d (Conv1D)             (None, 50, 8)             264       
##                                                                  
##  max_pooling1d (MaxPooling1  (None, 16, 8)             0         
##  D)                                                              
##                                                                  
##  batch_normalization (Batch  (None, 16, 8)             32        
##  Normalization)                                                  
##                                                                  
##  conv1d_1 (Conv1D)           (None, 16, 16)            1040      
##                                                                  
##  batch_normalization_1 (Bat  (None, 16, 16)            64        
##  chNormalization)                                                
##                                                                  
##  max_pooling1d_1 (MaxPoolin  (None, 5, 16)             0         
##  g1D)                                                            
##                                                                  
##  lstm (LSTM)                 (None, 8)                 800       
##                                                                  
##  dense (Dense)               (None, 4)                 36        
##                                                                  
##  dense_1 (Dense)             (None, 2)                 10        
##                                                                  
## =================================================================
## Total params: 2246 (8.77 KB)
## Trainable params: 2198 (8.59 KB)
## Non-trainable params: 48 (192.00 Byte)
## _________________________________________________________________
# Train model. With train_type = "label_folder" the class presumably comes
# from the folder a file lives in, so `vocabulary_label` follows the order of
# `path` / `path_val` (train_1_dir / val_1_dir hold the high-GC data).
hist <- train_model(
  model,
  run_name = "gc_model_1",
  train_type = "label_folder",
  path = c(train_1_dir, train_2_dir),
  path_val = c(val_1_dir, val_2_dir),
  vocabulary_label = c("high_gc", "equal_dist"),
  # 6 epochs of 50 batches with 64 samples each
  epochs = 6,
  steps_per_epoch = 50,
  batch_size = 64,
  step = 50
)
## Epoch 1/6
##  1/50 [..............................] - ETA: 1:00 - loss: 0.7005 - acc: 0.3906 6/50 [==>...........................] - ETA: 0s - loss: 0.6881 - acc: 0.5417  10/50 [=====>........................] - ETA: 0s - loss: 0.6825 - acc: 0.578115/50 [========>.....................] - ETA: 0s - loss: 0.6681 - acc: 0.676021/50 [===========>..................] - ETA: 0s - loss: 0.6466 - acc: 0.746325/50 [==============>...............] - ETA: 0s - loss: 0.6319 - acc: 0.777529/50 [================>.............] - ETA: 0s - loss: 0.6159 - acc: 0.805034/50 [===================>..........] - ETA: 0s - loss: 0.5983 - acc: 0.826738/50 [=====================>........] - ETA: 0s - loss: 0.5831 - acc: 0.842543/50 [========================>.....] - ETA: 0s - loss: 0.5637 - acc: 0.858347/50 [===========================>..] - ETA: 0s - loss: 0.5501 - acc: 0.867450/50 [==============================] - ETA: 0s - loss: 0.5391 - acc: 0.874750/50 [==============================] - 2s 20ms/step - loss: 0.5391 - acc: 0.8747 - val_loss: 0.5296 - val_acc: 0.9578 - lr: 0.0010
## Epoch 2/6
##  1/50 [..............................] - ETA: 0s - loss: 0.3424 - acc: 0.9844 6/50 [==>...........................] - ETA: 0s - loss: 0.3381 - acc: 0.981811/50 [=====>........................] - ETA: 0s - loss: 0.3287 - acc: 0.984417/50 [=========>....................] - ETA: 0s - loss: 0.3093 - acc: 0.985320/50 [===========>..................] - ETA: 0s - loss: 0.3038 - acc: 0.985225/50 [==============>...............] - ETA: 0s - loss: 0.2914 - acc: 0.986231/50 [=================>............] - ETA: 0s - loss: 0.2775 - acc: 0.985936/50 [====================>.........] - ETA: 0s - loss: 0.2672 - acc: 0.987039/50 [======================>.......] - ETA: 0s - loss: 0.2604 - acc: 0.987242/50 [========================>.....] - ETA: 0s - loss: 0.2541 - acc: 0.987746/50 [==========================>...] - ETA: 0s - loss: 0.2471 - acc: 0.987849/50 [============================>.] - ETA: 0s - loss: 0.2413 - acc: 0.988250/50 [==============================] - 1s 16ms/step - loss: 0.2392 - acc: 0.9884 - val_loss: 0.3314 - val_acc: 0.9406 - lr: 0.0010
## Epoch 3/6
##  1/50 [..............................] - ETA: 0s - loss: 0.1552 - acc: 0.9844 6/50 [==>...........................] - ETA: 0s - loss: 0.1424 - acc: 0.992211/50 [=====>........................] - ETA: 0s - loss: 0.1349 - acc: 0.994316/50 [========>.....................] - ETA: 0s - loss: 0.1275 - acc: 0.995122/50 [============>.................] - ETA: 0s - loss: 0.1222 - acc: 0.995026/50 [==============>...............] - ETA: 0s - loss: 0.1183 - acc: 0.995232/50 [==================>...........] - ETA: 0s - loss: 0.1130 - acc: 0.995136/50 [====================>.........] - ETA: 0s - loss: 0.1090 - acc: 0.995741/50 [=======================>......] - ETA: 0s - loss: 0.1045 - acc: 0.995846/50 [==========================>...] - ETA: 0s - loss: 0.1011 - acc: 0.995650/50 [==============================] - 1s 14ms/step - loss: 0.0977 - acc: 0.9959 - val_loss: 0.1857 - val_acc: 0.9594 - lr: 0.0010
## Epoch 4/6
##  1/50 [..............................] - ETA: 0s - loss: 0.0734 - acc: 0.9844 7/50 [===>..........................] - ETA: 0s - loss: 0.0624 - acc: 0.995511/50 [=====>........................] - ETA: 0s - loss: 0.0580 - acc: 0.997217/50 [=========>....................] - ETA: 0s - loss: 0.0545 - acc: 0.997223/50 [============>.................] - ETA: 0s - loss: 0.0528 - acc: 0.997329/50 [================>.............] - ETA: 0s - loss: 0.0503 - acc: 0.997335/50 [====================>.........] - ETA: 0s - loss: 0.0488 - acc: 0.997340/50 [=======================>......] - ETA: 0s - loss: 0.0468 - acc: 0.997746/50 [==========================>...] - ETA: 0s - loss: 0.0454 - acc: 0.997650/50 [==============================] - 1s 14ms/step - loss: 0.0441 - acc: 0.9978 - val_loss: 0.1155 - val_acc: 0.9656 - lr: 0.0010
## Epoch 5/6
##  1/50 [..............................] - ETA: 0s - loss: 0.0293 - acc: 1.0000 7/50 [===>..........................] - ETA: 0s - loss: 0.0305 - acc: 0.997811/50 [=====>........................] - ETA: 0s - loss: 0.0288 - acc: 0.998616/50 [========>.....................] - ETA: 0s - loss: 0.0272 - acc: 0.999021/50 [===========>..................] - ETA: 0s - loss: 0.0266 - acc: 0.999325/50 [==============>...............] - ETA: 0s - loss: 0.0258 - acc: 0.999431/50 [=================>............] - ETA: 0s - loss: 0.0249 - acc: 0.999536/50 [====================>.........] - ETA: 0s - loss: 0.0242 - acc: 0.999641/50 [=======================>......] - ETA: 0s - loss: 0.0235 - acc: 0.999646/50 [==========================>...] - ETA: 0s - loss: 0.0229 - acc: 0.999750/50 [==============================] - ETA: 0s - loss: 0.0224 - acc: 0.999750/50 [==============================] - 1s 16ms/step - loss: 0.0224 - acc: 0.9997 - val_loss: 0.0869 - val_acc: 0.9750 - lr: 0.0010
## Epoch 6/6
##  1/50 [..............................] - ETA: 0s - loss: 0.0164 - acc: 1.0000 7/50 [===>..........................] - ETA: 0s - loss: 0.0161 - acc: 1.000011/50 [=====>........................] - ETA: 0s - loss: 0.0159 - acc: 1.000016/50 [========>.....................] - ETA: 0s - loss: 0.0154 - acc: 1.000022/50 [============>.................] - ETA: 0s - loss: 0.0151 - acc: 1.000026/50 [==============>...............] - ETA: 0s - loss: 0.0148 - acc: 1.000032/50 [==================>...........] - ETA: 0s - loss: 0.0145 - acc: 1.000037/50 [=====================>........] - ETA: 0s - loss: 0.0142 - acc: 1.000043/50 [========================>.....] - ETA: 0s - loss: 0.0138 - acc: 1.000048/50 [===========================>..] - ETA: 0s - loss: 0.0136 - acc: 1.000050/50 [==============================] - 1s 13ms/step - loss: 0.0135 - acc: 1.0000 - val_loss: 0.0858 - val_acc: 0.9766 - lr: 0.0010
## Training done.
# Plot the training history object returned by train_model().
plot(hist)

Integrated Gradient

We can use Integrated Gradients to visualize which parts of an input sequence are important for the model's decision. Let’s create a sequence with a high GC content. We use the same number of Cs as Gs, and the same number of As as Ts.

set.seed(321)
# Build a sequence of length `maxlen` with `g_count` Gs and Cs each and
# `a_count` As and Ts each, i.e. a GC content of 2 * g_count / maxlen.
g_count <- 17
# a_count must be a positive whole number: maxlen even, more than 2 * g_count
stopifnot(maxlen %% 2 == 0, 2 * g_count < maxlen)
a_count <- (maxlen - 2 * g_count) / 2
high_gc_seq <- c(rep("G", g_count), rep("C", g_count),
                 rep("A", a_count), rep("T", a_count))
high_gc_seq <- high_gc_seq[sample(maxlen)] %>% paste(collapse = "") # shuffle nt order
high_gc_seq
## [1] "TGCGCGAGCCCAGCTAAGCGGCCTCCTTAGGCTGCCGGCGGGATCAGCTA"

We need to one-hot encode the sequence before applying Integrated Gradient.

# One-hot encode the sequence. The result is indexed as
# [sample, position, nucleotide], with nucleotide columns in `vocabulary`
# order (first position "T" -> 0 0 0 1 below).
high_gc_seq_one_hot <- seq_encoding_label(char_sequence = high_gc_seq,
                                          maxlen = maxlen,
                                          start_ind = 1,
                                          vocabulary = vocabulary)
head(high_gc_seq_one_hot[1, , ])
##      [,1] [,2] [,3] [,4]
## [1,]    0    0    0    1
## [2,]    0    0    1    0
## [3,]    0    1    0    0
## [4,]    0    0    1    0
## [5,]    0    1    0    0
## [6,]    0    0    1    0

Our model should be confident that this sequence belongs to the first class.

# Class probabilities for the encoded sequence, labelled in the same order
# as `vocabulary_label` used for training.
pred <- model %>%
  predict(high_gc_seq_one_hot, verbose = 0) %>%
  set_colnames(c("high_gc", "equal_dist"))
pred
##        high_gc equal_dist
## [1,] 0.9657075  0.0342925

We can visualize which parts were important for the prediction.

# Integrated Gradients attributions for target class 1 ("high_gc") with
# respect to the one-hot encoded input.
ig <- integrated_gradients(model = model,
                           target_class_idx = 1,
                           input_seq = high_gc_seq_one_hot)

# The heatmap is only drawn when the ComplexHeatmap package is installed.
if (!requireNamespace("ComplexHeatmap", quietly = TRUE)) {
  message("Skipping ComplexHeatmap-related code because the package is not installed.")
} else {
  heatmaps_integrated_grad(integrated_grads = ig,
                           input_seq = high_gc_seq_one_hot)
}
## [[1]]

We can test how our model's prediction changes if we exchange certain nucleotides in the input sequence. First, we look for the position with the smallest IG score.

# Convert the attributions to a plain R array (rows = positions,
# columns = nucleotides in `vocabulary` order).
ig <- as.array(ig)
# Cell with the smallest IG score. which() returns every tied cell, so keep
# only the first to always get exactly one (row, col) pair.
smallest_index <- which(ig == min(ig), arr.ind = TRUE)[1, , drop = FALSE]
smallest_index
##      row col
## [1,]  33   4

We can change the nucleotide with the lowest score and observe the change in prediction confidence.

# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot

# prediction for original sequence
predict(model, high_gc_seq_one_hot, verbose = 0)
##           [,1]      [,2]
## [1,] 0.9657075 0.0342925
# change nt: cell with the smallest IG score (first one if several tie)
smallest_index <- which(ig == min(ig), arr.ind = TRUE)[1, , drop = FALSE]
smallest_index
##      row col
## [1,]  33   4
row_index <- smallest_index[, "row"]
# Read the nucleotide currently at this position from the sequence itself;
# the minimum cell is not guaranteed to be the present nucleotide (e.g. on ties).
nt_index_old <- which.max(high_gc_seq_one_hot[1, row_index, ])
# replacement: nucleotide with the largest IG score at this position
nt_index_new <- which.max(ig[row_index, ])
new_row <- rep(0, length(vocabulary))
new_row[nt_index_new] <- 1
high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")
## At position 33 changing T to A
pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
print(pred)
##           [,1]       [,2]
## [1,] 0.9255649 0.07443508

Let’s repeatedly apply the previous step and change the sequence after each iteration.

# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot

n_iterations <- 20
# preallocated: entry 1 = original sequence, entry i + 1 = after the i-th change
pred_list <- vector("list", n_iterations + 1)
pred_list[[1]] <- pred <- predict(model, high_gc_seq_one_hot, verbose = 0)

# change nts
for (i in seq_len(n_iterations)) {

  # update ig scores for changed input
  ig <- integrated_gradients(
    input_seq = high_gc_seq_one_hot_changed,
    target_class_idx = 1,
    model = model) %>% as.array()

  # cell with the smallest IG score; keep only the first one on ties
  smallest_index <- which(ig == min(ig), arr.ind = TRUE)[1, , drop = FALSE]
  row_index <- smallest_index[, "row"]
  # nucleotide currently at this position, read from the changed sequence
  nt_index_old <- which.max(high_gc_seq_one_hot_changed[1, row_index, ])
  # Greedy choice without memory of earlier changes: the recorded output
  # shows position 19 flipping back and forth between C and A.
  nt_index_new <- which.max(ig[row_index, ])
  new_row <- rep(0, length(vocabulary))
  new_row[nt_index_new] <- 1
  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
  cat("At position", row_index, "changing", vocabulary[nt_index_old],
      "to", vocabulary[nt_index_new], "\n")
  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
  pred_list[[i + 1]] <- pred

}
## At position 33 changing T to A 
## At position 15 changing T to A 
## At position 46 changing A to C 
## At position 11 changing C to A 
## At position 19 changing C to A 
## At position 19 changing A to C 
## At position 19 changing C to A 
## At position 19 changing A to C 
## At position 19 changing C to A 
## At position 19 changing A to C 
## At position 19 changing C to A 
## At position 19 changing A to C 
## At position 19 changing C to A 
## At position 19 changing A to C 
## At position 19 changing C to A 
## At position 19 changing A to C 
## At position 19 changing C to A 
## At position 19 changing A to C 
## At position 19 changing C to A 
## At position 19 changing A to C
# One row per iteration (0 = unchanged sequence) holding both class probabilities.
pred_df <- as.data.frame(do.call(rbind, pred_list))
names(pred_df) <- c("high_gc", "equal_dist")
pred_df$iteration <- seq_len(nrow(pred_df)) - 1L
ggplot(pred_df, aes(x = iteration, y = high_gc)) +
  geom_line() +
  ylab("high GC confidence")

We can try the same in the opposite direction, i.e. replace the nucleotides with the largest IG scores.

# copy original sequence
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot

n_iterations <- 20
# preallocated: entry 1 = original sequence, entry i + 1 = after the i-th change
pred_list <- vector("list", n_iterations + 1)
pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
pred_list[[1]] <- pred

# change nts
for (i in seq_len(n_iterations)) {

  # update ig scores for changed input
  ig <- integrated_gradients(
    input_seq = high_gc_seq_one_hot_changed,
    target_class_idx = 1,
    model = model) %>% as.array()

  # cell with the largest IG score; keep only the first one on ties
  biggest_index <- which(ig == max(ig), arr.ind = TRUE)[1, , drop = FALSE]
  row_index <- biggest_index[, "row"]
  # nucleotide currently at this position, read from the changed sequence
  nt_index_old <- which.max(high_gc_seq_one_hot_changed[1, row_index, ])
  # replacement: nucleotide with the smallest IG score at this position
  nt_index_new <- which.min(ig[row_index, ])
  new_row <- rep(0, length(vocabulary))
  new_row[nt_index_new] <- 1
  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
  cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")

  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
  pred_list[[i + 1]] <- pred

}
## At position 30 changing G to A 
## At position 20 changing G to A 
## At position 34 changing G to A 
## At position 38 changing G to A 
## At position 32 changing C to A 
## At position 18 changing G to A 
## At position 19 changing C to A 
## At position 23 changing C to A 
## At position 25 changing C to A 
## At position 48 changing C to A 
## At position 41 changing G to A 
## At position 10 changing C to A 
## At position 40 changing G to A 
## At position 37 changing G to A 
## At position 42 changing G to A 
## At position 35 changing C to A 
## At position 6 changing G to A 
## At position 36 changing C to A 
## At position 45 changing C to A 
## At position 13 changing G to A
# Stack the per-iteration predictions and add the iteration number (0 = original).
pred_df <- do.call(rbind, pred_list) %>%
  as.data.frame() %>%
  setNames(c("high_gc", "equal_dist"))
pred_df[["iteration"]] <- seq_len(nrow(pred_df)) - 1L
ggplot(pred_df, aes(x = iteration, y = high_gc)) +
  geom_line() +
  ylab("high GC confidence")