Diff of /R/evaluation.R [000000] .. [409433]


1
#' Evaluates a trained model on fasta, fastq or rds files
2
#'
3
#' Returns evaluation metrics such as confusion matrix, loss, AUC, AUPRC, MAE or MSE (depending on the output layer).
4
#'
5
#' @inheritParams generator_fasta_lm
6
#' @inheritParams generator_fasta_label_folder
7
#' @inheritParams generator_fasta_label_header_csv
8
#' @param path_input Input directory where fasta, fastq or rds files are located.
9
#' @param model A keras model.
10
#' @param batch_size Number of samples per batch.
11
#' @param step How often to take a sample.
12
#' @param vocabulary Vector of allowed characters. Characters outside the vocabulary get encoded as specified in `ambiguous_nuc`.
13
#' @param vocabulary_label List of labels for targets of each output layer.
14
#' @param number_batches How many batches to evaluate.
15
#' @param format File format, `"fasta"`, `"fastq"` or `"rds"`.
16
#' @param mode Either `"lm"` for language model or `"label_header"`, `"label_csv"` or `"label_folder"` for label classification.
17
#' @param verbose Boolean. Whether to print progress messages.
18
#' @param target_middle Whether the model is a language model with separate input layers.
19
#' @param evaluate_all_files Boolean, if `TRUE` will iterate over all files in \code{path_input} once. \code{number_batches} will be overwritten.
20
#' @param auc Whether to include AUC metric. If output layer activation is `"softmax"`, only possible for 2 targets. Computes the average if output layer has sigmoid
21
#' activation and multiple targets.
22
#' @param auprc Whether to include AUPRC metric. If output layer activation is `"softmax"`, only possible for 2 targets. Computes the average if output layer has sigmoid
23
#' activation and multiple targets.
24
#' @param path_pred_list Path to store list of predictions (output of output layers) and corresponding true labels as rds file. 
25
#' @param exact_num_samples Exact number of samples to evaluate; use this if the number of samples is not divisible by `batch_size`. Useful if you want
26
#' to evaluate a data set exactly once and already know the number of samples. Should be a vector if `mode = "label_folder"` (with the same length as `vocabulary_label`)
27
#' and an integer otherwise.
28
#' @param activations List containing output activations for output layers (`softmax`, `sigmoid` or `linear`). If `NULL`, will be inferred from the model.
29
#' @param include_seq Whether to store input. Only applies if `path_pred_list` is not `NULL`.
30
#' @param ... Further generator options. See \code{\link{get_generator}}.
31
#' @examplesIf reticulate::py_module_available("tensorflow")
32
#' # create dummy data
33
#' path_input <- tempfile()
34
#' dir.create(path_input)
35
#' create_dummy_data(file_path = path_input,
36
#'                   num_files = 3,
37
#'                   seq_length = 11, 
38
#'                   num_seq = 5,
39
#'                   vocabulary = c("a", "c", "g", "t"))
40
#' # create model
41
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 4, maxlen = 10, verbose = FALSE)
42
#' # evaluate
43
#' evaluate_model(path_input = path_input,
44
#'   model = model,
45
#'   step = 11,
46
#'   vocabulary = c("a", "c", "g", "t"),
47
#'   vocabulary_label = list(c("a", "c", "g", "t")),
48
#'   mode = "lm",
49
#'   output_format = "target_right",
50
#'   evaluate_all_files = TRUE,
51
#'   verbose = FALSE)
52
#'   
53
#' @returns A list of evaluation results. Each list element corresponds to an output layer of the model.   
54
#' @export
55
evaluate_model <- function(path_input,
56
                           model = NULL,
57
                           batch_size = 100,
58
                           step = 1,
59
                           padding = FALSE,
60
                           vocabulary = c("a", "c", "g", "t"),
61
                           vocabulary_label = list(c("a", "c", "g", "t")),
62
                           number_batches = 10,
63
                           format = "fasta",
64
                           target_middle = FALSE,
65
                           mode = "lm",
66
                           output_format = "target_right",
67
                           ambiguous_nuc = "zero",
68
                           evaluate_all_files = FALSE,
69
                           verbose = TRUE,
70
                           max_iter = 20000,
71
                           target_from_csv = NULL,
72
                           max_samples = NULL,
73
                           proportion_per_seq = NULL,
74
                           concat_seq = NULL,
75
                           seed = 1234,
76
                           auc = FALSE,
77
                           auprc = FALSE,
78
                           path_pred_list = NULL,
79
                           exact_num_samples = NULL,
80
                           activations = NULL,
81
                           shuffle_file_order = FALSE,
82
                           include_seq = FALSE,
83
                           ...) {
84
  
85
  set.seed(seed)
86
  path_model <- NULL
87
  stopifnot(mode %in% c("lm", "label_header", "label_folder", "label_csv", "lm_rds", "label_rds"))
88
  stopifnot(format %in% c("fasta", "fastq", "rds"))
89
  stopifnot(is.null(proportion_per_seq) || proportion_per_seq <= 1)
90
  if (!is.null(exact_num_samples) & evaluate_all_files) {
91
    warning(paste("Will evaluate number of samples as specified in exact_num_samples argument. Setting evaluate_all_files to FALSE."))
92
    evaluate_all_files <- FALSE
93
  }
94
  eval_exact_num_samples <- !is.null(exact_num_samples) | evaluate_all_files
95
  if (is.null(activations)) activations <- get_output_activations(model)
96
  if (is.null(path_pred_list) & include_seq) {
97
    stop("Can only store input, if path_pred_list is specified.")
98
  }
99
  if (is.null(vocabulary_label)) vocabulary_label <- list(vocabulary)
100
  if (!is.list(vocabulary_label)) vocabulary_label <- list(vocabulary_label)
101
  if (mode == "label_folder") {
102
    number_batches <- rep(ceiling(number_batches/length(path_input)), length(path_input))
103
  }
104
  num_classes <- ifelse(mode == "label_folder", length(path_input), 1)
105
  num_out_layers <- length(activations)
106
  
107
  # extract maxlen from model
108
  num_in_layers <- length(model$inputs)
109
  if (num_in_layers == 1) {
110
    maxlen <- model$input$shape[[2]]
111
  } else {
112
    if (!target_middle) {
113
      maxlen <- model$input[[num_in_layers]]$shape[[2]]
114
    } else {
115
      maxlen <- model$input[[num_in_layers - 1]]$shape[[2]] + model$input[[num_in_layers]]$shape[[2]]
116
    }
117
  }
118
  
119
  if (evaluate_all_files & (format %in% c("fasta", "fastq"))) {
120
    
121
    number_batches <- NULL
122
    num_samples <- rep(0, length(path_input))
123
    
124
    for (i in 1:num_classes) {
125
      if (mode == "label_folder") {
126
        files <- list_fasta_files(path_input[[i]], format = format, file_filter = NULL)
127
      } else {
128
        files <- list_fasta_files(path_input, format = format, file_filter = NULL)
129
      }
130
      
131
      # remove files not in csv table
132
      if (mode == "label_csv") {
133
        csv_file <- utils::read.csv2(target_from_csv, header = TRUE, stringsAsFactors = FALSE)
134
        if (dim(csv_file)[2] == 1) {
135
          csv_file <- utils::read.csv(target_from_csv, header = TRUE, stringsAsFactors = FALSE)
136
        }
137
        index <- basename(files) %in% csv_file$file
138
        files <- files[index]
139
        if (length(files) == 0) {
140
          stop("No files from path_input have label in target_from_csv file.")
141
        }
142
      }
143
      
144
      for (file in files) {
145
        if (format == "fasta") {
146
          fasta_file <- microseq::readFasta(file)
147
        } else {
148
          fasta_file <- microseq::readFastq(file)
149
        }
150
        
151
        # remove entries with wrong header
152
        if (mode == "label_header") {
153
          index <- fasta_file$Header %in% vocabulary_label
154
          fasta_file <- fasta_file[index, ]
155
        }
156
        
157
        seq_vector <- fasta_file$Sequence
158
        
159
        if (!is.null(concat_seq)) {
160
          seq_vector <- paste(seq_vector, collapse = concat_seq)
161
        }
162
        
163
        if (!is.null(proportion_per_seq)) {
164
          fasta_width <- nchar(seq_vector)
165
          sample_range <- floor(fasta_width - (proportion_per_seq * fasta_width))
166
          start <- mapply(sample_range, FUN = sample, size = 1)
167
          perc_length <- floor(fasta_width * proportion_per_seq)
168
          stop <- start + perc_length
169
          seq_vector <- mapply(seq_vector, FUN = substr, start = start, stop = stop)
170
        }
171
        
172
        if (mode == "lm") {
173
          if (!padding) {
174
            seq_vector <- seq_vector[nchar(seq_vector) >= (maxlen + 1)]
175
          } else {
176
            length_vector <- nchar(seq_vector)
177
            short_seq_index <- which(length_vector < (maxlen + 1))
178
            for (ssi in short_seq_index) {
179
              seq_vector[ssi] <- paste0(paste(rep("0", (maxlen + 1) - length_vector[ssi]), collapse = ""), seq_vector[ssi])
180
            }
181
          }
182
        } else {
183
          if (!padding) {
184
            seq_vector <- seq_vector[nchar(seq_vector) >= (maxlen)]
185
          } else {
186
            length_vector <- nchar(seq_vector)
187
            short_seq_index <- which(length_vector < (maxlen))
188
            for (ssi in short_seq_index) {
189
              seq_vector[ssi] <- paste0(paste(rep("0", (maxlen) - length_vector[ssi]), collapse = ""), seq_vector[ssi])
190
            }
191
          }
192
        }
193
        
194
        if (length(seq_vector) == 0) next
195
        new_samples <- get_start_ind(seq_vector = seq_vector,
196
                                     length_vector = nchar(seq_vector),
197
                                     maxlen = maxlen,
198
                                     step = step,
199
                                     train_mode = ifelse(mode == "lm", "lm", "label"),
200
                                     discard_amb_nuc = ifelse(ambiguous_nuc == "discard", TRUE, FALSE),
201
                                     vocabulary = vocabulary
202
        ) %>% length()
203
        
204
        if (is.null(max_samples)) {
205
          num_samples[i] <- num_samples[i] + new_samples
206
        } else {
207
          num_samples[i] <- num_samples[i] + min(new_samples, max_samples)
208
        }
209
      }
210
      number_batches[i] <- ceiling(num_samples[i]/batch_size)
211
      
212
    }
213
    if (mode == "label_folder") {
214
      message_string <- paste0("Evaluate ", num_samples, " samples for class ", vocabulary_label[[1]], ".\n")
215
    } else {
216
      message_string <- paste0("Evaluate ", sum(num_samples), " samples.")
217
    }
218
    message(message_string)
219
  }
220
  
221
  if (evaluate_all_files & format == "rds") {
222
    rds_files <- list_fasta_files(path_corpus = path_input,
223
                                  format = "rds",
224
                                  file_filter = NULL)
225
    num_samples <- 0
226
    for (file in rds_files) {
227
      rds_file <- readRDS(file)
228
      x <- rds_file[[1]]
229
      while (is.list(x)) {
230
        x <- x[[1]]
231
      }
232
      num_samples <- dim(x)[1] + num_samples
233
    }
234
    number_batches <- ceiling(num_samples/batch_size)
235
    message_string <- paste0("Evaluate ", num_samples, " samples.")
236
    message(message_string)
237
  }
238
  
239
  if (!is.null(exact_num_samples)) {
240
    num_samples <- exact_num_samples
241
    number_batches <- ceiling(num_samples/batch_size)
242
  }
243
  
244
  overall_num_batches <- sum(number_batches)
245
  
246
  if (mode == "lm") {
247
    gen <- generator_fasta_lm(path_corpus = path_input,
248
                              format = format,
249
                              batch_size = batch_size,
250
                              maxlen = maxlen,
251
                              max_iter = max_iter,
252
                              vocabulary = vocabulary,
253
                              verbose = FALSE,
254
                              shuffle_file_order = shuffle_file_order,
255
                              step = step,
256
                              concat_seq = concat_seq,
257
                              padding = padding,
258
                              shuffle_input = FALSE,
259
                              reverse_complement = FALSE,
260
                              output_format = output_format,
261
                              ambiguous_nuc = ambiguous_nuc,
262
                              proportion_per_seq = proportion_per_seq,
263
                              max_samples = max_samples,
264
                              seed = seed,
265
                              ...)
266
  }
267
  
268
  if (mode == "label_header" | mode == "label_csv") {
269
    gen <- generator_fasta_label_header_csv(path_corpus = path_input,
270
                                            format = format,
271
                                            batch_size = batch_size,
272
                                            maxlen = maxlen,
273
                                            max_iter = max_iter,
274
                                            vocabulary = vocabulary,
275
                                            verbose = FALSE,
276
                                            shuffle_file_order = shuffle_file_order,
277
                                            step = step,
278
                                            padding = padding,
279
                                            shuffle_input = FALSE,
280
                                            concat_seq = concat_seq,
281
                                            vocabulary_label = vocabulary_label[[1]],
282
                                            reverse_complement = FALSE,
283
                                            ambiguous_nuc = ambiguous_nuc,
284
                                            target_from_csv = target_from_csv,
285
                                            proportion_per_seq = proportion_per_seq,
286
                                            max_samples = max_samples,
287
                                            seed = seed, ...)
288
  }
289
  
290
  if (mode == "label_rds" | mode == "lm_rds") {
291
    gen <- generator_rds(rds_folder = path_input, batch_size = batch_size, path_file_log = NULL, ...)
292
  }
293
  
294
  batch_index <- 1
295
  start_time <- Sys.time()
296
  ten_percent_steps <- seq(overall_num_batches/10, overall_num_batches, length.out = 10)
297
  percentage_index <- 1
298
  count <- 1
299
  y_conf_list <- vector("list", overall_num_batches)
300
  y_list <- vector("list", overall_num_batches)
301
  if (include_seq) {
302
    x_list <- vector("list", overall_num_batches)
303
  }
304
  
305
  for (k in 1:num_classes) {
306
    
307
    index <- NULL
308
    if (mode == "label_folder") {
309
      gen <- generator_fasta_label_folder(path_corpus = path_input[[k]],
310
                                          format = format,
311
                                          batch_size = batch_size,
312
                                          maxlen = maxlen,
313
                                          max_iter = max_iter,
314
                                          vocabulary = vocabulary,
315
                                          step = step,
316
                                          padding = padding,
317
                                          concat_seq = concat_seq,
318
                                          reverse_complement = FALSE,
319
                                          num_targets = length(path_input),
320
                                          ones_column = k,
321
                                          ambiguous_nuc = ambiguous_nuc,
322
                                          proportion_per_seq = proportion_per_seq,
323
                                          max_samples = max_samples,
324
                                          seed = seed, ...)
325
    }
326
    
327
    for (i in 1:number_batches[k]) {
328
      z <- gen()
329
      x <- z[[1]]
330
      y <- z[[2]]
331
      
332
      y_conf <- model(x)
333
      batch_index <- batch_index + 1
334
      
335
      # remove double predictions
336
      if (eval_exact_num_samples & (i == number_batches[k])) {
337
        double_index <- (i * batch_size) - num_samples[k]
338
        
339
        if (double_index > 0) {
340
          index <- 1:(nrow(y_conf) - double_index)
341
          
342
          if (is.list(y_conf)) {
343
            for (m in 1:length(y_conf)) {
344
              y_conf[[m]] <- y_conf[[m]][index, ]
345
              y[[m]] <- y[[m]][index, ]
346
            }
347
          } else {
348
            y_conf <- y_conf[index, ]
349
            y <- y[index, ]
350
          }
351
          
352
          # vector to matrix
353
          if (length(index) == 1) {
354
            if (is.list(y_conf)) {
355
              for (m in 1:length(y_conf)) {
356
                y_conf[[m]] <- array(as.array(y_conf[[m]]), dim = c(1, length(y_conf[[m]])))
357
                y[[m]] <- matrix(y[[m]], ncol = length(y[[m]]))
358
              }
359
            } else {
360
              y_conf <- array(as.array(y_conf), dim = c(1, length(y_conf)))
361
              y <- matrix(y, ncol = length(y))
362
            }
363
          }
364
          
365
        }
366
      }
367
      
368
      if (include_seq) {
369
        x_list[[count]] <- x
370
      }
371
      y_conf_list[[count]] <- y_conf
372
      if (batch_size == 1 | (!is.null(index) && length(index) == 1)) {
373
        col_num <- ncol(y_conf)
374
        if (is.na(col_num)) col_num <- length(y_conf)
375
        y_list[[count]] <- matrix(y, ncol = col_num)
376
      } else {
377
        y_list[[count]] <- y
378
      }
379
      count <- count + 1
380
      
381
      if (verbose & (batch_index == 10)) {
382
        time_passed <- as.double(difftime(Sys.time(), start_time, units = "hours"))
383
        time_estimation <- (overall_num_batches/10) * time_passed
384
        cat("Evaluation will take approximately", round(time_estimation, 3), "hours. Starting time:", format(Sys.time(), "%F %R."), " \n")
385
        
386
      }
387
      
388
      if (verbose & (batch_index > ten_percent_steps[percentage_index]) & percentage_index < 10) {
389
        cat("Progress: ", percentage_index * 10 ,"% \n")
390
        time_passed <- as.double(difftime(Sys.time(), start_time, units = "hours"))
391
        cat("Time passed: ", round(time_passed, 3), "hours \n")
392
        percentage_index <- percentage_index + 1
393
      }
394
      
395
    }
396
  }
397
  
398
  if (verbose) {
399
    cat("Progress: 100 % \n")
400
    time_passed <- as.double(difftime(Sys.time(), start_time, units = "hours"))
401
    cat("Time passed: ", round(time_passed, 3), "hours \n")
402
  }
403
  
404
  y_conf_list <- reshape_y_list(y_conf_list, num_out_layers = num_out_layers, tf_format = TRUE)
405
  y_list <- reshape_y_list(y_list, num_out_layers = num_out_layers, tf_format = FALSE)
406
  
407
  if (!is.null(path_pred_list)) {
408
    if (include_seq) {
409
      if (is.list(x_list[[1]])) {
410
        num_layers <- length(x_list[[1]])
411
      } else {
412
        num_layers <- 1 
413
      }
414
      x_list <- reshape_y_list(x_list, num_out_layers = num_layers, tf_format = FALSE)
415
      saveRDS(list(pred = y_conf_list, true = y_list, x = x_list), path_pred_list)
416
    } else {
417
      saveRDS(list(pred = y_conf_list, true = y_list), path_pred_list)
418
    }
419
  } 
420
  
421
  eval_list <- list()
422
  for (i in 1:num_out_layers) {
423
    
424
    if (activations[i] == "softmax") {
425
      eval_list[[i]] <- evaluate_softmax(y = y_list[[i]], y_conf = y_conf_list[[i]],
426
                                         auc = auc, auprc = auprc,
427
                                         label_names = vocabulary_label[[i]])
428
    }
429
    
430
    if (activations[i] == "sigmoid") {
431
      eval_list[[i]] <- evaluate_sigmoid(y = y_list[[i]], y_conf = y_conf_list[[i]],
432
                                         auc = auc, auprc = auprc,
433
                                         label_names = vocabulary_label[[i]])
434
    }
435
    
436
    if (activations[i] == "linear") {
437
      eval_list[[i]] <- evaluate_linear(y_true = y_list[[i]], y_pred = y_conf_list[[i]], label_names = vocabulary_label[[i]])
438
    }
439
    
440
  }
441
  
442
  return(eval_list)
443
}
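
# Usage sketch (not run): evaluating a two-class model trained on labeled
# folders and plotting a ROC curve from the stored predictions. The folder
# paths, class names, the file name "pred_list.rds" and the pre-trained
# `model` below are placeholders for illustration, not objects defined in
# this package.
#
# eval_res <- evaluate_model(path_input = c("/path/class_a", "/path/class_b"),
#                            model = model,
#                            vocabulary_label = list(c("class_a", "class_b")),
#                            mode = "label_folder",
#                            format = "fasta",
#                            evaluate_all_files = TRUE,
#                            auc = TRUE,
#                            path_pred_list = "pred_list.rds",
#                            verbose = FALSE)
# pred_list <- readRDS("pred_list.rds")
# plot_roc(y_true = pred_list$true[[1]], y_conf = pred_list$pred[[1]])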
444
445
446
reshape_y_list <- function(y, num_out_layers, tf_format = TRUE) {
447
  
448
  if (num_out_layers > 1) {
449
    y <- do.call(c, y)
450
  }
451
  
452
  reshaped_list <- vector("list", num_out_layers)
453
  
454
  for (i in 1:num_out_layers) {
455
    index <- seq(i, length(y), by = num_out_layers)
456
    if (tf_format) {
457
      reshaped_list[[i]] <- y[index] %>%
458
        tensorflow::tf$concat(axis = 0L) %>%
459
        keras::k_eval()
460
    } else {
461
      reshaped_list[[i]] <- do.call(rbind, y[index])
462
    }
463
  }
464
  return(reshaped_list)
465
}
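
# Example (sketch) for reshape_y_list() above, using the tf_format = FALSE
# branch with plain matrices; the toy batches are made up for illustration.
#
# batch_1 <- list(matrix(1:4, nrow = 2), matrix(5:8, nrow = 2))    # 2 output layers
# batch_2 <- list(matrix(9:12, nrow = 2), matrix(13:16, nrow = 2))
# y <- reshape_y_list(list(batch_1, batch_2), num_out_layers = 2, tf_format = FALSE)
# # y[[1]] is rbind(batch_1[[1]], batch_2[[1]]), i.e. all rows of output layer 1;
# # y[[2]] collects the rows of output layer 2.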
466
467
#' Evaluate matrices of true targets and predictions from a layer with softmax activation.
468
#' 
469
#' Compute confusion matrix, accuracy, categorical crossentropy and (optionally) AUC or AUPRC, given predictions and
470
#' true targets. AUC and AUPRC only possible for 2 targets. 
471
#' 
472
#' @param y Matrix of true targets.
473
#' @param y_conf Matrix of predictions.
474
#' @param auc Whether to include AUC metric. Only possible for 2 targets. 
475
#' @param auprc Whether to include AUPRC metric. Only possible for 2 targets. 
476
#' @param label_names Names of corresponding labels. Length must be equal to number of columns of \code{y}.
477
#' @examplesIf reticulate::py_module_available("tensorflow")
478
#' y <- matrix(c(1, 0, 0, 0, 1, 1), ncol = 2)
479
#' y_conf <- matrix(c(0.3, 0.5, 0.1, 0.7, 0.5, 0.9), ncol = 2)
480
#' evaluate_softmax(y, y_conf, auc = TRUE, auprc = TRUE, label_names = c("A", "B")) 
481
#' 
482
#' @returns A list of evaluation results. 
483
#' @export    
484
evaluate_softmax <- function(y, y_conf, auc = FALSE, auprc = FALSE, label_names = NULL) {
485
  
486
  if (ncol(y) != 2 & (auc | auprc)) {
487
    message("Can only compute AUC or AUPRC if output layer with softmax acticvation has two neurons.")
488
    auc <- FALSE
489
    auprc <- FALSE
490
  }
491
  
492
  y_pred <- apply(y_conf, 1, which.max)
493
  y_true <- apply(y, 1, FUN = which.max) - 1
494
  
495
  df_true_pred <- data.frame(
496
    true = factor(y_true + 1, levels = 1:(length(label_names)), labels = label_names),
497
    pred = factor(y_pred, levels = 1:(length(label_names)), labels = label_names)
498
  )
499
  
500
  loss_per_class <- list()
501
  for (i in 1:ncol(y)) {
502
    index <- (y_true + 1) == i
503
    if (any(index)) {
504
      cce_loss_class <- tensorflow::tf$keras$losses$categorical_crossentropy(y[index, ], y_conf[index, ])
505
      loss_per_class[[i]] <- cce_loss_class$numpy()
506
    } else {
507
      loss_per_class[[i]] <- NA
508
    }
509
  }
510
  
511
  cm <- yardstick::conf_mat(df_true_pred, true, pred)
512
  confMat <- cm[[1]]
513
  
514
  acc <- sum(diag(confMat))/sum(confMat)
515
  loss <- mean(unlist(loss_per_class))
516
  
517
  for (i in 1:length(loss_per_class)) {
518
    loss_per_class[[i]] <- mean(unlist(loss_per_class[[i]]), na.rm = TRUE)
519
  }
520
  
521
  loss_per_class <- unlist(loss_per_class)
522
  m <- as.matrix(confMat)
523
  class_acc <- vector("numeric")
524
  for (i in 1:ncol(m)) {
525
    if (sum(m[ , i]) == 0) {
526
      class_acc[i] <- NA
527
    } else {
528
      class_acc[i] <- m[i, i]/sum(m[ , i])
529
    }
530
  }
531
  names(class_acc) <- label_names
532
  names(loss_per_class) <- label_names
533
  balanced_acc <- mean(class_acc)
534
  
535
  if (auc) {
536
    auc_list <- PRROC::roc.curve(
537
      scores.class0 = y_conf[ , 2],
538
      weights.class0 = y_true)
539
  } else {
540
    auc_list <- NULL
541
  }
542
  
543
  if (auprc) {
544
    auprc_list <- PRROC::pr.curve(
545
      scores.class0 = y_conf[ , 2],
546
      weights.class0 = y_true)
547
  } else {
548
    auprc_list <- NULL
549
  }
550
  
551
  return(list(confusion_matrix = confMat,
552
              accuracy = acc,
553
              categorical_crossentropy_loss = loss,
554
              #balanced_accuracy = balanced_acc,
555
              #loss_per_class = loss_per_class,
556
              #accuracy_per_class = class_acc,
557
              AUC = auc_list$auc,
558
              AUPRC = auprc_list$auc.integral))
559
}
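
# Note on evaluate_softmax() above: when auc/auprc is TRUE, the second column
# of y_conf is used as the score of the positive class and the 0/1-coded true
# class as weights, i.e. the returned AUC equals
#   PRROC::roc.curve(scores.class0 = y_conf[, 2],
#                    weights.class0 = apply(y, 1, which.max) - 1)$auc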
560
561
#' Evaluate matrices of true targets and predictions from a layer with sigmoid activation.
562
#' 
563
#' Compute accuracy, binary crossentropy and (optionally) AUC or AUPRC, given predictions and
564
#' true targets. Outputs columnwise average.  
565
#' 
566
#' @inheritParams evaluate_model
567
#' @inheritParams evaluate_softmax
568
#' @param auc Whether to include AUC metric.
569
#' @param auprc Whether to include AUPRC metric. 
570
#' @examplesIf reticulate::py_module_available("tensorflow")
571
#' y <- matrix(sample(c(0, 1), 30, replace = TRUE), ncol = 3)
572
#' y_conf <- matrix(runif(n = 30), ncol = 3)
573
#' evaluate_sigmoid(y, y_conf, auc = TRUE, auprc = TRUE)
574
#' 
575
#' @returns A list of evaluation results. 
576
#' @export    
577
evaluate_sigmoid <- function(y, y_conf, auc = FALSE, auprc = FALSE, label_names = NULL) {
578
  
579
  y_pred <- ifelse(y_conf > 0.5, 1, 0)
580
  
581
  loss_per_class <- list()
582
  for (i in 1:ncol(y)) {
583
    bce_loss_class <- tensorflow::tf$keras$losses$binary_crossentropy(y[ , i], y_conf[ , i])
584
    loss_per_class[[i]] <- bce_loss_class$numpy()
585
  }
586
  
587
  loss_per_class <- unlist(loss_per_class)
588
  names(loss_per_class) <- label_names
589
  loss <- mean(unlist(loss_per_class))
590
  
591
  class_acc <- vector("numeric", ncol(y))
592
  for (i in 1:ncol(y)) {
593
    num_true_pred <-  sum(y[ , i] == y_pred[ , i])
594
    class_acc[i] <- num_true_pred / nrow(y)
595
  }
596
  names(class_acc) <- label_names
597
  acc <- mean(class_acc)
598
  
599
  if (auc) {
600
    auc_list <- purrr::map(1:ncol(y_conf), ~PRROC::roc.curve(
601
      scores.class0 = y_conf[ , .x],
602
      weights.class0 = y[ , .x]))
603
    auc_vector <- vector("numeric", ncol(y))
604
    for (i in 1:length(auc_vector)) {
605
      auc_vector[i] <- auc_list[[i]]$auc
606
    }
607
    
608
    na_count <- sum(is.na(auc_vector))
609
    if (na_count > 0) {
610
      message(paste(na_count, ifelse(na_count > 1, "columns", "column"),
611
                    "removed from AUC evaluation since they contain only one label"))
612
    }
613
    AUC <- mean(auc_vector, na.rm = TRUE)
614
  } else {
615
    AUC <- NULL
616
  }
617
  
618
  if (auprc) {
619
    auprc_list <- purrr::map(1:ncol(y_conf), ~PRROC::pr.curve(
620
      scores.class0 = y_conf[ , .x],
621
      weights.class0 = y[ , .x]))
622
    auprc_vector <- vector("numeric", ncol(y))
623
    for (i in 1:length(auprc_vector)) {
624
      auprc_vector[i] <- auprc_list[[i]]$auc.integral
625
    }
626
    AUPRC <- mean(auprc_vector, na.rm = TRUE) 
627
  } else {
628
    AUPRC <- NULL
629
  }
630
  
631
  return(list(accuracy = acc,
632
              binary_crossentropy_loss = loss,
633
              #loss_per_class = loss_per_class,
634
              #accuracy_per_class = class_acc,
635
              AUC = AUC,
636
              AUPRC = AUPRC))
637
  
638
}
639
640
#' Evaluate matrices of true targets and predictions from a layer with linear activation.
641
#' 
642
#' Compute MAE and MSE, given predictions and
643
#' true targets. Outputs columnwise average.  
644
#' 
645
#' @inheritParams evaluate_model
646
#' @inheritParams evaluate_softmax
647
#' @param y_true Matrix of true labels.
648
#' @param y_pred Matrix of predictions.
649
#' @examplesIf reticulate::py_module_available("tensorflow")
650
#' y_true <- matrix(rnorm(n = 12), ncol = 3)
651
#' y_pred <- matrix(rnorm(n = 12), ncol = 3)
652
#' evaluate_linear(y_true, y_pred)
653
#' 
654
#' @returns A list of evaluation results. 
655
#' @export    
656
evaluate_linear <- function(y_true, y_pred, label_names = NULL) {
657
  
658
  loss_per_class_mse <- list()
659
  loss_per_class_mae <- list()
660
  for (i in 1:ncol(y_true)) {
661
    mse_loss_class <- tensorflow::tf$keras$losses$mean_squared_error(y_true[ ,i], y_pred[ , i])
662
    mae_loss_class <- tensorflow::tf$keras$losses$mean_absolute_error(y_true[ ,i], y_pred[ , i])
663
    loss_per_class_mse[[i]] <- mse_loss_class$numpy()
664
    loss_per_class_mae[[i]] <- mae_loss_class$numpy()
665
  }
666
  
667
  return(list(mse = mean(unlist(loss_per_class_mse)),
668
              mae = mean(unlist(loss_per_class_mae))))
669
  
670
}
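
# Sanity-check sketch for evaluate_linear() above: the returned values are the
# per-column losses averaged over all columns, so (up to floating-point
# tolerance; toy data for illustration only):
#
# y_true <- matrix(rnorm(12), ncol = 3)
# y_pred <- matrix(rnorm(12), ncol = 3)
# all.equal(evaluate_linear(y_true, y_pred)$mse,
#           mean(colMeans((y_true - y_pred)^2)))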
671
672
673
#' Plot ROC
674
#' 
675
#' Compute ROC and AUC from target and prediction matrix and plot ROC. Target/prediction matrix should 
676
#' have one column if it is the output of a layer with sigmoid activation, and two columns for softmax activation.
677
#' 
678
#' @inheritParams evaluate_softmax
679
#' @inheritParams evaluate_linear
680
#' @param path_roc_plot Where to store ROC plot.
681
#' @param return_plot Whether to return plot.
682
#' @examples
683
#' y_true <- matrix(c(1, 0, 0, 0, 1, 1), ncol = 1)
684
#' y_conf <- matrix(runif(n = nrow(y_true)), ncol = 1)
685
#' p <- plot_roc(y_true, y_conf, return_plot = TRUE)
686
#' p
687
#' 
688
#' @returns A ggplot of ROC curve.
689
#' @export    
690
plot_roc <- function(y_true, y_conf, path_roc_plot = NULL,
691
                     return_plot = TRUE) {
692
  
693
  if (!all(y_true == 0 | y_true == 1)) {
694
    stop("y_true should only contain 0 and 1 entries")
695
  }
696
  
697
  if (is.matrix(y_true) && ncol(y_true) > 2) {
698
    stop("y_true can contain 1 or 2 columns")
699
  }
700
  
701
  if (is.matrix(y_true) && ncol(y_true) == 2) {
702
    y_true <- y_true[ , 1] 
703
    y_conf <- y_conf[ , 2]
704
  }
705
  
706
  if (stats::var(y_true) == 0) {
707
    stop("y_true contains just one label")
708
  }
709
  
710
  y_true <- as.vector(y_true)
711
  y_conf <- as.vector(y_conf)
712
  
713
  rocobj <-  pROC::roc(y_true, y_conf, quiet = TRUE)
714
  auc <- round(pROC::auc(y_true, y_conf, quiet = TRUE), 4)
715
  p <- pROC::ggroc(rocobj,  size = 1, color = "black")
716
  p <- p + ggplot2::theme_classic() + ggplot2::theme(aspect.ratio = 1) 
717
  p <- p + ggplot2::ggtitle(paste0('ROC Curve ', '(AUC = ', auc, ')'))
718
  p <- p + ggplot2::geom_abline(intercept = 1, slope = -1, linetype = 2, color = "grey50")
719
  p <- p + ggplot2::geom_vline(xintercept = 1, linetype = 2, color = "grey50")
720
  p <- p + ggplot2::geom_hline(yintercept = 1,  linetype = 2, color = "grey50")
721
  
722
  if (!is.null(path_roc_plot)) {
723
    ggplot2::ggsave(path_roc_plot, p)
724
  }
725
  
726
  if (return_plot) {
727
    return(p)
728
  } else {
729
    return(NULL)
730
  }
731
  
732
}
733
734
# plot_roc_auprc <- function(y_true, y_conf, path_roc_plot = NULL, path_auprc_plot = NULL,
735
#                            return_plot = TRUE, layer_activation = "softmax") {
736
#   
737
#   if (layer_activation == "softmax") {
738
#     
739
#     if (!all(y_true == 0 | y_true == 1)) {
740
#       stop("y_true should only contain 0 and 1 entries")
741
#     }
742
#     
743
#     if (ncol(y_true) != 2 & (auc | auprc)) {
744
#       message("Can only compute AUC or AUPRC if output layer with softmax acticvation has two neurons.")
745
#     }
746
#     
747
#     auc_list <- PRROC::roc.curve(
748
#       scores.class0 = y_conf[ , 2],
749
#       weights.class0 = y_true[ , 2], curve = TRUE)
750
#     
751
#     
752
#     auprc_list <- PRROC::pr.curve(
753
#       scores.class0 = y_conf[ , 2],
754
#       weights.class0 = y_true[ , 2], curve = TRUE)
755
#     
756
#     #auc_plot <- NULL
757
#     #auprc_plot <- NULL  
758
#     
759
#   }
760
#   
761
#   if (layer_activation == "sigmoid") {
762
#     
763
#     auc_list <- purrr::map(1:ncol(y_conf), ~PRROC::roc.curve(
764
#       scores.class0 = y_conf[ , .x],
765
#       weights.class0 = y[ , .x], curve = TRUE))
766
#     auc_vector <- vector("numeric", ncol(y))
767
#     
768
#     
769
#     auprc_list <- purrr::map(1:ncol(y_conf), ~PRROC::pr.curve(
770
#       scores.class0 = y_conf[ , .x],
771
#       weights.class0 = y[ , .x], curve = TRUE))
772
#     auprc_vector <- vector("numeric", ncol(y))
773
#     
774
#   }
775
#   
776
#   if (!is.null(path_roc_plot)) {
777
#     
778
#   }
779
#   
780
#   if (!is.null(path_auprc_plot)) {
781
#     
782
#   }
783
#   
784
# }