Diff of /R/visualization.R [000000] .. [409433]

Switch to unified view

a b/R/visualization.R
1
#' Interpolation between baseline and prediction
#'
#' Builds a baseline for `input_seq` and returns the linear path from that
#' baseline (alpha = 0) to the original input (alpha = 1), evaluated at
#' `m_steps + 1` evenly spaced alpha values.
#'
#' @param m_steps Number of steps between baseline and original input.
#' @param baseline_type Baseline sequence, either `"zero"` for all zeros,
#' `"shuffle"` for a random permutation of `input_seq` along its second
#' dimension or `"unif"` for values drawn with `stats::runif()`.
#' @param input_seq Input array or a list of input arrays. Each array is
#' indexed with three subscripts, i.e. assumed to have three dimensions.
#' @returns The interpolated path, or an unnamed list with one path per element
#' if `input_seq` is a list. The first axis of a path indexes the alpha values.
#' @noRd
interpolate_seq <- function(m_steps = 50,
                            baseline_type = "shuffle",
                            input_seq) {

  stopifnot(length(baseline_type) == 1,
            baseline_type %in% c("zero", "shuffle", "unif"))

  # Baseline with the same dimensions as a single input array.
  make_baseline <- function(x) {
    x_dim <- dim(x)
    switch(baseline_type,
           zero = array(0, dim = x_dim),
           shuffle = array(x[, sample(x_dim[2]), ], dim = x_dim),
           unif = array(stats::runif(prod(x_dim)), dim = x_dim))
  }

  m_steps <- as.integer(m_steps)
  # m_steps intervals need m_steps + 1 points; the intervals are consumed by
  # integral_approximation().
  alphas <- tensorflow::tf$linspace(start = 0.0, stop = 1.0, num = m_steps + 1L)
  # Add two trailing axes so the alphas broadcast against the input.
  alphas_x <- alphas[, tensorflow::tf$newaxis, tensorflow::tf$newaxis]

  interpolate <- function(x) {
    baseline <- make_baseline(x)
    baseline + alphas_x * (x - baseline)
  }

  if (is.list(input_seq)) {
    # unname(): reticulate converts a named R list to a Python dict, an
    # unnamed one to a Python list.
    lapply(unname(input_seq), interpolate)
  } else {
    interpolate(input_seq)
  }
}
55
56
#' Compute gradients
#'
#' Computes the gradient of the model output for one class with respect to
#' the input. The forward pass is recorded with a Python `tf.GradientTape`;
#' inputs, model and results are exchanged through Python's `__main__` module.
#'
#' @param input_tensor Input tensor, or list of input tensors for a model with
#' more than one input.
#' @param target_class_idx Index of class to compute gradient for (1-based,
#' converted to a 0-based Python index internally).
#' @param model Model to compute gradient for.
#' @param input_idx Input layer to monitor for > 1 input (1-based). If `NULL`,
#' `input_tensor` is watched as a whole.
#' @param pred_stepwise Not used inside this function; stepwise evaluation is
#' implemented by `gradients_stepwise()`. Kept for interface compatibility.
#' @returns The gradient with respect to `input_tensor`, or only its
#' `input_idx`-th element if `input_idx` is given.
#' @noRd
compute_gradients <- function(input_tensor, target_class_idx, model, input_idx = NULL, pred_stepwise = FALSE) {

  reticulate::py_run_string("import tensorflow as tf")
  # Explicit handle to Python's __main__, so the function does not depend on
  # reticulate's `py` object being attached to the search path.
  py_main <- reticulate::import_main()
  py_main$input_tensor <- input_tensor
  py_main$target_class_idx <- as.integer(target_class_idx - 1)
  py_main$model <- model

  if (is.null(input_idx)) {
    watched <- "input_tensor"
  } else {
    py_main$input_idx <- as.integer(input_idx - 1)
    watched <- "input_tensor[input_idx]"
  }

  reticulate::py_run_string(paste0(
    "with tf.GradientTape() as tape:\n",
    "    tape.watch(", watched, ")\n",
    "    probs = model(input_tensor)[:, target_class_idx]\n"
  ))

  grad <- py_main$tape$gradient(py_main$probs, py_main$input_tensor)
  if (is.null(input_idx)) {
    grad
  } else {
    grad[[input_idx]]
  }
}
97
98
#' Approximate the path integral of gradients
#'
#' Trapezoidal rule: averages every pair of neighbouring gradients along the
#' first axis and then takes the mean of these averages over that axis.
#'
#' @param gradients Tensor of gradients; slicing and averaging happen along
#' its first axis.
#' @returns Tensor with the first axis of `gradients` reduced away.
#' @noRd
integral_approximation <- function(gradients) {
  reticulate::py_run_string("import tensorflow as tf")
  # Explicit handle to Python's __main__, so the function does not depend on
  # reticulate's `py` object being attached to the search path.
  py_main <- reticulate::import_main()
  py_main$gradients <- gradients
  # riemann_trapezoidal
  reticulate::py_run_string(paste(
    "grads = (gradients[:-1] + gradients[1:]) / tf.constant(2.0)",
    "integrated_gradients = tf.math.reduce_mean(grads, axis=0)",
    sep = "\n"
  ))
  py_main$integrated_gradients
}
106
107
#' Compute integrated gradients
#'
#' Computes integrated gradients scores for model and an input sequence.
#' This can be used to visualize what part of the input is important for the model's decision.
#' Code is R implementation of python code from [here](https://www.tensorflow.org/tutorials/interpretability/integrated_gradients).
#' Tensorflow implementation is based on this [paper](https://arxiv.org/abs/1703.01365).
#'
#' @param baseline_type Baseline sequence, either `"zero"` for all zeros, `"shuffle"` for random permutation of `input_seq`
#' or `"unif"` for uniformly distributed random values.
#' @param m_steps Number of steps between baseline and original input.
#' @param input_seq Input tensor, or list of input tensors for a model with several inputs.
#' @param target_class_idx Index of class to compute gradient for
#' @param model Model to compute gradient for.
#' @param pred_stepwise Whether to do predictions with batch size 1 rather than all at once. Can be used if
#' input is too big to handle at once.
#' @param num_baseline_repeats Number of different baseline estimations if baseline_type is random, i.e. `"shuffle"` or
#' `"unif"` (estimate integrated gradient repeatedly for different baselines). Final result is average of
#' \code{num_baseline_repeats} single calculations. Ignored (with a warning) for `"zero"`.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' library(reticulate)
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 3, maxlen = 20, verbose = FALSE)
#' random_seq <- sample(0:3, 20, replace = TRUE)
#' input_seq <- array(keras::to_categorical(random_seq), dim = c(1, 20, 4))
#' integrated_gradients(
#'   input_seq = input_seq,
#'   target_class_idx = 3,
#'   model = model)
#'
#' @returns A tensorflow tensor, or a list with one tensor per input if `input_seq` is a list.
#' @export
integrated_gradients <- function(m_steps = 50,
                                 baseline_type = "zero",
                                 input_seq,
                                 target_class_idx,
                                 model,
                                 pred_stepwise = FALSE,
                                 num_baseline_repeats = 1) {

  stopifnot(length(num_baseline_repeats) == 1, num_baseline_repeats >= 1)
  reticulate::py_run_string("import tensorflow as tf")
  if (num_baseline_repeats > 1 && baseline_type == "zero") {
    warning('Ignoring num_baseline_repeats if baseline is of type "zero". Did you mean to use baseline_type = "shuffle"?')
  }

  # Random baselines: average the scores of several independent single runs.
  if (num_baseline_repeats != 1 && baseline_type != "zero") {
    ig_runs <- lapply(seq_len(num_baseline_repeats), function(i) {
      integrated_gradients(m_steps = m_steps,
                           baseline_type = baseline_type,
                           input_seq = input_seq,
                           target_class_idx = target_class_idx,
                           model = model,
                           pred_stepwise = pred_stepwise,
                           num_baseline_repeats = 1)
    })
    mean_over_runs <- function(tensors) {
      stacked <- tensorflow::tf$stack(tensors, axis = 0L)
      tensorflow::tf$reduce_mean(stacked, axis = 0L)
    }
    if (is.list(input_seq)) {
      # Every run returns one tensor per input; average input by input so
      # inputs of different shapes are supported and the result stays a list.
      return(lapply(seq_along(input_seq), function(k) {
        mean_over_runs(lapply(ig_runs, function(run) run[[k]]))
      }))
    }
    return(mean_over_runs(ig_runs))
  }

  # Interpolated path from baseline (first entry) to the input (last entry).
  path_seq <- interpolate_seq(m_steps = m_steps,
                              baseline_type = baseline_type,
                              input_seq = input_seq)
  if (is.list(path_seq)) {
    path_seq <- lapply(path_seq, function(x) tensorflow::tf$cast(x, dtype = "float32"))
  } else {
    path_seq <- tensorflow::tf$cast(path_seq, dtype = "float32")
  }

  if (pred_stepwise) {
    path_gradients <- gradients_stepwise(model = model,
                                         baseline_seq = path_seq,
                                         target_class_idx = target_class_idx,
                                         input_idx = NULL)
  } else {
    path_gradients <- compute_gradients(model = model,
                                        input_tensor = path_seq,
                                        target_class_idx = target_class_idx,
                                        input_idx = NULL,
                                        pred_stepwise = pred_stepwise)
  }

  # Scale the averaged gradients by (input - baseline); the baseline is the
  # first entry of the interpolated path.
  if (is.list(input_seq)) {
    ig <- lapply(seq_along(input_seq), function(i) {
      avg_grads <- integral_approximation(gradients = path_gradients[[i]])
      ((input_seq[[i]] - path_seq[[i]][1, , ]) * avg_grads)[1, , ]
    })
  } else {
    avg_grads <- integral_approximation(gradients = path_gradients)
    ig <- ((input_seq - path_seq[1, , ]) * avg_grads)[1, , ]
  }

  ig
}
224
225
#' Compute gradients stepwise (one batch at a time)
#'
#' Calls `compute_gradients()` separately for every entry along the first axis
#' of `baseline_seq` (batch size 1) and stacks the results again.
#'
#' @param model Model to compute gradient for.
#' @param baseline_seq Tensor, or list of tensors (one per model input), whose
#' first axis is iterated over. Each tensor is indexed with three subscripts.
#' @param target_class_idx Index of class to compute gradient for.
#' @param input_idx Not used; `compute_gradients()` is always called with
#' `input_idx = NULL`.
#' @returns Tensor of stacked gradients, or a list with one such tensor per
#' model input if `baseline_seq` is a list.
#' @noRd
gradients_stepwise <- function(model, baseline_seq, target_class_idx,
                               input_idx = NULL) {

  tf <- tensorflow::tf

  # j-th entry of x with a leading batch axis of length 1.
  step_slice <- function(x, j) {
    tf$expand_dims(x[j, , ], axis = 0L)
  }
  # Drop only the batch axis again, so that other axes of length 1 survive.
  drop_batch <- function(g) {
    tf$squeeze(g, axis = 0L)
  }

  if (is.list(baseline_seq)) {
    baseline_seq <- unname(baseline_seq)
    num_steps <- dim(baseline_seq[[1]])[1]

    # step_grads[[j]][[k]]: gradient for step j with respect to input k.
    step_grads <- lapply(seq_len(num_steps), function(j) {
      grads <- compute_gradients(
        model = model,
        input_tensor = lapply(baseline_seq, step_slice, j = j),
        target_class_idx = target_class_idx,
        input_idx = NULL)
      lapply(grads, drop_batch)
    })

    # Regroup by input and stack the steps along a new first axis.
    path_gradients <- lapply(seq_along(baseline_seq), function(k) {
      tf$stack(lapply(step_grads, function(g) g[[k]]))
    })

  } else {
    step_grads <- lapply(seq_len(dim(baseline_seq)[1]), function(j) {
      grad <- compute_gradients(
        model = model,
        input_tensor = step_slice(baseline_seq, j),
        target_class_idx = target_class_idx,
        input_idx = NULL)
      drop_batch(grad)
    })
    path_gradients <- tf$stack(step_grads)
  }

  path_gradients
}
282
283
284
#' Heatmap of integrated gradient scores
#'
#' Creates a heatmap from output of \code{\link{integrated_gradients}} function. The first row contains
#' the column-wise absolute sums of IG scores and the second row the sums. Rows 3 to 6 contain the IG scores for each
#' position and each nucleotide. The last row contains nucleotide information.
#'
#' @param integrated_grads Matrix of integrated gradient scores (output of \code{\link{integrated_gradients}} function),
#' or a list of such matrices if `input_seq` is a list.
#' @param input_seq Input sequence for model. Should be the same as \code{input_seq} input for corresponding
#' \code{\link{integrated_gradients}} call that computed input for \code{integrated_grads} argument.
#' @examplesIf reticulate::py_module_available("tensorflow")  && requireNamespace("ComplexHeatmap", quietly = TRUE)
#' library(reticulate)
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 3, maxlen = 20, verbose = FALSE)
#' random_seq <- sample(0:3, 20, replace = TRUE)
#' input_seq <- array(keras::to_categorical(random_seq), dim = c(1, 20, 4))
#' ig <- integrated_gradients(
#'   input_seq = input_seq,
#'   target_class_idx = 3,
#'   model = model)
#' heatmaps_integrated_grad(integrated_grads = ig,
#'                          input_seq = input_seq)
#'
#' @returns A list of heatmaps.
#' @export
heatmaps_integrated_grad <- function(integrated_grads,
                                     input_seq) {

  tf <- tensorflow::tf
  nuc_levels <- c("A", "C", "G", "T")
  # "0" marks positions whose largest input value is not exactly 1.
  nuc_palette <- c("A" = "red", "C" = "green", "G" = "blue", "T" = "yellow", "0" = "white")

  # Heatmap for a single pair of IG scores and input sequence.
  build_heatmap <- function(ig, seq_input, hm_name) {
    ig <- tf$cast(ig, dtype = "float32")
    seq_input <- tf$cast(seq_input, dtype = "float32")

    # Work on 2-D scores (position x nucleotide) so the annotations below
    # have one value per position.
    ig_rank <- length(dim(ig))
    if (ig_rank == 3) {
      ig <- ig[1, , ]
    } else if (ig_rank != 2) {
      stop("integrated_grads must have 2 or 3 dimensions")
    }

    abs_sum <- as.numeric(as.array(tf$reduce_sum(tf$math$abs(ig), axis = -1L)))
    sum_nuc <- as.numeric(as.array(tf$reduce_sum(ig, axis = -1L)))
    nuc_matrix <- as.array(ig)

    # Nucleotide per position = column holding the largest input value.
    one_hot <- as.array(seq_input[1, , ])
    amb_nuc <- apply(one_hot, 1, max) != 1
    nuc_seq <- nuc_levels[apply(one_hot, 1, which.max)]
    nuc_seq[amb_nuc] <- "0"
    rownames(nuc_matrix) <- nuc_seq
    colnames(nuc_matrix) <- nuc_levels

    ig_min <- min(nuc_matrix)
    ig_max <- max(nuc_matrix)
    col_fun <- circlize::colorRamp2(c(ig_min, mean(c(ig_max, ig_min)), ig_max), c("blue", "white", "red"))

    # Only colours for labels that actually occur, whichever and however many
    # these are.
    nuc_col <- nuc_palette[names(nuc_palette) %in% nuc_seq]

    row_ha <- ComplexHeatmap::columnAnnotation(abs_sum = abs_sum, sum = sum_nuc)
    ha <- ComplexHeatmap::HeatmapAnnotation(nuc = row.names(nuc_matrix), col = list(nuc = nuc_col))
    ComplexHeatmap::Heatmap(matrix = t(nuc_matrix),
                            name = hm_name,
                            top_annotation = row_ha,
                            bottom_annotation = ha,
                            col = col_fun,
                            cluster_rows = FALSE,
                            cluster_columns = FALSE,
                            column_names_rot = 0)
  }

  if (!is.list(input_seq)) {
    return(list(build_heatmap(integrated_grads, input_seq, "hm")))
  }

  stopifnot(is.list(integrated_grads),
            length(integrated_grads) == length(input_seq))
  if (length(input_seq) == 1) {
    hm_names <- "hm"
  } else {
    hm_names <- paste0("hm_", seq_along(input_seq))
  }
  lapply(seq_along(input_seq), function(i) {
    build_heatmap(integrated_grads[[i]], input_seq[[i]], hm_names[i])
  })
}