Switch to unified view

a b/vignettes/integrated_gradient.Rmd
1
---
2
title: "Integrated Gradient"
3
output: rmarkdown::html_vignette
4
vignette: >
5
  %\VignetteIndexEntry{Integrated Gradient}
6
  %\VignetteEngine{knitr::rmarkdown}
7
  %\VignetteEncoding{UTF-8}
8
---
9
10
```{r, echo=FALSE, warning=FALSE, message=FALSE}
11
12
if (!reticulate::py_module_available("tensorflow")) {
13
  knitr::opts_chunk$set(eval = FALSE)
14
} else {
15
  knitr::opts_chunk$set(eval = TRUE)
16
}
17
```
18
  
19
```{r, message=FALSE}
20
library(deepG)
21
library(keras)
22
library(magrittr)
23
library(ggplot2)
24
library(reticulate)
25
```
26
27
28
```{r, echo=FALSE, warning=FALSE, message=FALSE}
29
options(rmarkdown.html_vignette.check_title = FALSE)
30
```
31
32
```{css, echo=FALSE}
33
mark.in {
34
background-color: CornflowerBlue;
35
}
36
37
mark.out {
38
background-color: IndianRed;
39
}
40
41
```
42
43
## Introduction 
44
45
The  <a href="https://arxiv.org/abs/1703.01365">Integrated Gradient</a> (IG) method can be used to determine what parts of an input sequence are important for the models decision.
46
We start with training a model that can differentiate sequences based on the GC content 
47
(as described in the <a href="getting_started.html">Getting started tutorial</a>). 
48
49
50
## Model Training
51
52
We create two simple dummy training and validation data sets. Both consist of random <tt>ACGT</tt> sequences but the first category has 
53
a probability of 40% each for drawing <tt>G</tt> or <tt>C</tt> and the second has equal probability for each nucleotide (first category has around 80% <tt>GC</tt> content and second one around 50%).   
54
55
```{r warning = FALSE}
56
set.seed(123)
57
58
# Create data 
59
vocabulary <- c("A", "C", "G", "T")
60
data_type <- c("train_1", "train_2", "val_1", "val_2")
61
62
for (i in 1:length(data_type)) {
63
  
64
  temp_file <- tempfile()
65
  assign(paste0(data_type[i], "_dir"), temp_file)
66
  dir.create(temp_file)
67
  
68
  if (i %% 2 == 1) {
69
    header <- "label_1"
70
    prob <- c(0.1, 0.4, 0.4, 0.1)
71
  } else {
72
    header <- "label_2"
73
    prob <- rep(0.25, 4)
74
  }
75
  fasta_name_start <- paste0(header, "_", data_type[i], "file")
76
  
77
  create_dummy_data(file_path = temp_file,
78
                    num_files = 1,
79
                    seq_length = 20000, 
80
                    num_seq = 1,
81
                    header = header,
82
                    prob = prob,
83
                    fasta_name_start = fasta_name_start,
84
                    vocabulary = vocabulary)
85
  
86
}
87
88
# Create model
89
maxlen <- 50
90
model <- create_model_lstm_cnn(maxlen = maxlen,
91
                               filters = c(8, 16),
92
                               kernel_size = c(8, 8),
93
                               pool_size = c(3, 3),
94
                               layer_lstm = 8,
95
                               layer_dense = c(4, 2),
96
                               model_seed = 3)
97
98
# Train model
99
hist <- train_model(model,
100
                    train_type = "label_folder",
101
                    run_name = "gc_model_1",
102
                    path = c(train_1_dir, train_2_dir),
103
                    path_val = c(val_1_dir, val_2_dir),
104
                    epochs = 6, 
105
                    batch_size = 64,
106
                    steps_per_epoch = 50, 
107
                    step = 50, 
108
                    vocabulary_label = c("high_gc", "equal_dist"))
109
110
plot(hist)
111
```
112
113
114
## Integrated Gradient
115
116
We can try to visualize what parts of an input sequence is important for the models decision, using Integrated Gradient.
117
Let's create a sequence with a high GC content. We use same number of Cs as Gs and of As as Ts.
118
119
```{r warning = FALSE}
120
set.seed(321)
121
g_count <- 17
122
stopifnot(g_count < 25)
123
a_count <- (50 - (2*g_count))/2  
124
high_gc_seq <- c(rep("G", g_count), rep("C", g_count), rep("A", a_count), rep("T", a_count))
125
high_gc_seq <- high_gc_seq[sample(maxlen)] %>% paste(collapse = "") # shuffle nt order
126
high_gc_seq
127
```
128
129
We need to one-hot encode the sequence before applying Integrated Gradient.
130
131
```{r warning = FALSE}
132
high_gc_seq_one_hot <- seq_encoding_label(char_sequence = high_gc_seq,
133
                                          maxlen = 50,
134
                                          start_ind = 1,
135
                                          vocabulary = vocabulary)
136
head(high_gc_seq_one_hot[1,,])
137
```
138
139
Our model should be confident, this sequences belongs to the first class
140
141
```{r warning = FALSE}
142
pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
143
colnames(pred) <- c("high_gc", "equal_dist")
144
pred
145
```
146
147
We can visualize what parts where important for the prediction.
148
149
```{r warning = FALSE}
150
ig <- integrated_gradients(
151
  input_seq = high_gc_seq_one_hot,
152
  target_class_idx = 1,
153
  model = model)
154
155
if (requireNamespace("ComplexHeatmap", quietly = TRUE)) {
156
  heatmaps_integrated_grad(integrated_grads = ig,
157
                           input_seq = high_gc_seq_one_hot)
158
} else {
159
  message("Skipping ComplexHeatmap-related code because the package is not installed.")
160
}
161
162
```
163
164
We may test how our models prediction changes if we exchange certain nucleotides in the input sequence.
165
First, we look for the positions with the smallest IG score.
166
167
```{r warning = FALSE}
168
ig <- as.array(ig)
169
smallest_index <- which(ig == min(ig), arr.ind = TRUE)
170
smallest_index
171
```
172
173
We may change the nucleotide with the lowest score and observe the change in prediction confidence
174
175
```{r warning = FALSE}
176
# copy original sequence
177
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot 
178
179
# prediction for original sequence
180
predict(model, high_gc_seq_one_hot, verbose = 0)
181
182
# change nt
183
smallest_index <- which(ig == min(ig), arr.ind = TRUE)
184
smallest_index
185
row_index <- smallest_index[ , "row"]
186
col_index <- smallest_index[ , "col"]               
187
new_row <- rep(0, 4)
188
nt_index_old <- col_index
189
nt_index_new <- which.max(ig[row_index, ])
190
new_row[nt_index_new] <- 1
191
high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
192
cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")
193
194
pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
195
print(pred)
196
```
197
198
Let's repeatedly apply the previous step and change the sequence after each iteration.
199
200
```{r warning = FALSE}
201
# copy original sequence
202
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot 
203
204
pred_list <- list()
205
pred_list[[1]] <- pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
206
207
# change nts
208
for (i in 1:20) {
209
  
210
  # update ig scores for changed input
211
  ig <- integrated_gradients(
212
    input_seq = high_gc_seq_one_hot_changed,
213
    target_class_idx = 1,
214
    model = model) %>% as.array()
215
  
216
  smallest_index <- which(ig == min(ig), arr.ind = TRUE)
217
  smallest_index
218
  row_index <- smallest_index[ , "row"]
219
  col_index <- smallest_index[ , "col"]               
220
  new_row <- rep(0, 4)
221
  nt_index_old <- col_index
222
  nt_index_new <- which.max(ig[row_index, ])
223
  new_row[nt_index_new] <- 1
224
  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
225
  cat("At position", row_index, "changing", vocabulary[nt_index_old],
226
      "to", vocabulary[nt_index_new], "\n")
227
  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
228
  pred_list[[i + 1]] <- pred 
229
  
230
}
231
232
pred_df <- do.call(rbind, pred_list)
233
pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
234
names(pred_df) <- c("high_gc", "equal_dist", "iteration")
235
ggplot(pred_df, aes(x = iteration, y = high_gc)) + geom_line() + ylab("high GC confidence")
236
237
```
238
239
We can try the same in the opposite direction, i.e. replace big IG scores.
240
241
```{r warning = FALSE}
242
# copy original sequence
243
high_gc_seq_one_hot_changed <- high_gc_seq_one_hot 
244
245
pred_list <- list()
246
pred <- predict(model, high_gc_seq_one_hot, verbose = 0)
247
pred_list[[1]] <- pred
248
249
# change nts
250
for (i in 1:20) {
251
  
252
  # update ig scores for changed input
253
  ig <- integrated_gradients(
254
    input_seq = high_gc_seq_one_hot_changed,
255
    target_class_idx = 1,
256
    model = model) %>% as.array()
257
  
258
  biggest_index <- which(ig == max(ig), arr.ind = TRUE)
259
  biggest_index
260
  row_index <- biggest_index[ , "row"]
261
  row_index <- row_index[1]
262
  col_index <- biggest_index[ , "col"]               
263
  new_row <- rep(0, 4)
264
  nt_index_old <- col_index
265
  nt_index_new <- which.min(ig[row_index, ])
266
  new_row[nt_index_new] <- 1
267
  high_gc_seq_one_hot_changed[1, row_index, ] <- new_row
268
  cat("At position", row_index, "changing", vocabulary[nt_index_old], "to", vocabulary[nt_index_new], "\n")
269
  
270
  pred <- predict(model, high_gc_seq_one_hot_changed, verbose = 0)
271
  pred_list[[i + 1]] <- pred 
272
  
273
}
274
275
pred_df <- do.call(rbind, pred_list)
276
pred_df <- data.frame(pred_df, iteration = 0:(nrow(pred_df) - 1))
277
names(pred_df) <- c("high_gc", "equal_dist", "iteration")
278
ggplot(pred_df, aes(x = iteration, y = high_gc)) + geom_line() + ylab("high GC confidence")
279
```