|
a |
|
b/R/evaluation.R |
|
|
1 |
#' Evaluates a trained model on fasta, fastq or rds files
#'
#' Returns evaluation metric like confusion matrix, loss, AUC, AUPRC, MAE, MSE (depending on output layer).
#'
#' The function first determines how many batches to draw (from `number_batches`, `exact_num_samples`
#' or, if `evaluate_all_files = TRUE`, by counting the samples in `path_input`), then pulls that many
#' batches from the matching data generator, feeds them through `model` and evaluates the collected
#' predictions per output layer. Note that `set.seed(seed)` is called, i.e. the global RNG state is modified.
#'
#' @inheritParams generator_fasta_lm
#' @inheritParams generator_fasta_label_folder
#' @inheritParams generator_fasta_label_header_csv
#' @param path_input Input directory where fasta, fastq or rds files are located.
#' @param model A keras model. Must be supplied (an error is raised if `NULL`).
#' @param batch_size Number of samples per batch.
#' @param step How often to take a sample.
#' @param vocabulary Vector of allowed characters. Characters outside vocabulary get encoded as specified in ambiguous_nuc.
#' @param vocabulary_label List of labels for targets of each output layer.
#' @param number_batches How many batches to evaluate.
#' @param format File format, `"fasta"`, `"fastq"` or `"rds"`.
#' @param mode Either `"lm"` for language model or `"label_header"`, `"label_csv"` or `"label_folder"` for label classification.
#' `"lm_rds"` and `"label_rds"` are accepted for rds input.
#' @param verbose Boolean.
#' @param target_middle Whether model is language model with separate input layers.
#' @param evaluate_all_files Boolean, if `TRUE` will iterate over all files in \code{path_input} once. \code{number_batches} will be overwritten.
#' @param auc Whether to include AUC metric. If output layer activation is `"softmax"`, only possible for 2 targets. Computes the average if output layer has sigmoid
#' activation and multiple targets.
#' @param auprc Whether to include AUPRC metric. If output layer activation is `"softmax"`, only possible for 2 targets. Computes the average if output layer has sigmoid
#' activation and multiple targets.
#' @param path_pred_list Path to store list of predictions (output of output layers) and corresponding true labels as rds file.
#' @param exact_num_samples Exact number of samples to evaluate. If you want to evaluate a number of samples not divisible by batch_size. Useful if you want
#' to evaluate a data set exactly once and know the number of samples already. Should be a vector if `mode = "label_folder"` (with same length as `vocabulary_label`)
#' and else an integer.
#' @param activations List containing output formats for output layers (`softmax, sigmoid` or `linear`). If `NULL`, will be estimated from model.
#' @param include_seq Whether to store input. Only applies if `path_pred_list` is not `NULL`.
#' @param ... Further generator options. See \code{\link{get_generator}}.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' # create dummy data
#' path_input <- tempfile()
#' dir.create(path_input)
#' create_dummy_data(file_path = path_input,
#'                   num_files = 3,
#'                   seq_length = 11,
#'                   num_seq = 5,
#'                   vocabulary = c("a", "c", "g", "t"))
#' # create model
#' model <- create_model_lstm_cnn(layer_lstm = 8, layer_dense = 4, maxlen = 10, verbose = FALSE)
#' # evaluate
#' evaluate_model(path_input = path_input,
#'   model = model,
#'   step = 11,
#'   vocabulary = c("a", "c", "g", "t"),
#'   vocabulary_label = list(c("a", "c", "g", "t")),
#'   mode = "lm",
#'   output_format = "target_right",
#'   evaluate_all_files = TRUE,
#'   verbose = FALSE)
#'
#' @returns A list of evaluation results. Each list element corresponds to an output layer of the model.
#' @export
evaluate_model <- function(path_input,
                           model = NULL,
                           batch_size = 100,
                           step = 1,
                           padding = FALSE,
                           vocabulary = c("a", "c", "g", "t"),
                           vocabulary_label = list(c("a", "c", "g", "t")),
                           number_batches = 10,
                           format = "fasta",
                           target_middle = FALSE,
                           mode = "lm",
                           output_format = "target_right",
                           ambiguous_nuc = "zero",
                           evaluate_all_files = FALSE,
                           verbose = TRUE,
                           max_iter = 20000,
                           target_from_csv = NULL,
                           max_samples = NULL,
                           proportion_per_seq = NULL,
                           concat_seq = NULL,
                           seed = 1234,
                           auc = FALSE,
                           auprc = FALSE,
                           path_pred_list = NULL,
                           exact_num_samples = NULL,
                           activations = NULL,
                           shuffle_file_order = FALSE,
                           include_seq = FALSE,
                           ...) {

  set.seed(seed)

  # ---- argument checks -------------------------------------------------------
  if (is.null(model)) {
    stop("model must be specified.")
  }
  stopifnot(mode %in% c("lm", "label_header", "label_folder", "label_csv", "lm_rds", "label_rds"))
  stopifnot(format %in% c("fasta", "fastq", "rds"))
  stopifnot(is.null(proportion_per_seq) || proportion_per_seq <= 1)
  if (!is.null(exact_num_samples) && evaluate_all_files) {
    warning("Will evaluate number of samples as specified in exact_num_samples argument. Setting evaluate_all_files to FALSE.")
    evaluate_all_files <- FALSE
  }
  # TRUE if the last batch may have to be trimmed to hit an exact sample count
  eval_exact_num_samples <- !is.null(exact_num_samples) || evaluate_all_files
  if (is.null(activations)) activations <- get_output_activations(model)
  if (is.null(path_pred_list) && include_seq) {
    stop("Can only store input, if path_pred_list is specified.")
  }
  if (is.null(vocabulary_label)) vocabulary_label <- list(vocabulary)
  if (!is.list(vocabulary_label)) vocabulary_label <- list(vocabulary_label)
  if (mode == "label_folder") {
    # spread the requested batches evenly over the class folders
    number_batches <- rep(ceiling(number_batches / length(path_input)), length(path_input))
  }
  num_classes <- if (mode == "label_folder") length(path_input) else 1
  num_out_layers <- length(activations)

  # ---- extract maxlen from model ---------------------------------------------
  num_in_layers <- length(model$inputs)
  if (num_in_layers == 1) {
    maxlen <- model$input$shape[[2]]
  } else {
    if (!target_middle) {
      maxlen <- model$input[[num_in_layers]]$shape[[2]]
    } else {
      maxlen <- model$input[[num_in_layers - 1]]$shape[[2]] + model$input[[num_in_layers]]$shape[[2]]
    }
  }

  # ---- count samples if every fasta/fastq file should be evaluated once -------
  if (evaluate_all_files && (format %in% c("fasta", "fastq"))) {

    number_batches <- NULL
    num_samples <- rep(0, length(path_input))

    for (i in seq_len(num_classes)) {
      if (mode == "label_folder") {
        files <- list_fasta_files(path_input[[i]], format = format, file_filter = NULL)
      } else {
        files <- list_fasta_files(path_input, format = format, file_filter = NULL)
      }

      # remove files not in csv table
      if (mode == "label_csv") {
        csv_file <- utils::read.csv2(target_from_csv, header = TRUE, stringsAsFactors = FALSE)
        # a single column means the file is comma- rather than semicolon-separated
        if (ncol(csv_file) == 1) {
          csv_file <- utils::read.csv(target_from_csv, header = TRUE, stringsAsFactors = FALSE)
        }
        index <- basename(files) %in% csv_file$file
        files <- files[index]
        if (length(files) == 0) {
          stop("No files from path_input have label in target_from_csv file.")
        }
      }

      for (file in files) {
        if (format == "fasta") {
          fasta_file <- microseq::readFasta(file)
        } else {
          fasta_file <- microseq::readFastq(file)
        }

        # remove entries with wrong header
        if (mode == "label_header") {
          # vocabulary_label is always a list here; compare against the labels of
          # the first output layer (the same vector that is passed to the generator)
          index <- fasta_file$Header %in% vocabulary_label[[1]]
          fasta_file <- fasta_file[index, ]
        }

        seq_vector <- fasta_file$Sequence

        if (!is.null(concat_seq)) {
          seq_vector <- paste(seq_vector, collapse = concat_seq)
        }

        if (!is.null(proportion_per_seq)) {
          # take a random contiguous subsequence covering proportion_per_seq of each sequence
          fasta_width <- nchar(seq_vector)
          sample_range <- floor(fasta_width - (proportion_per_seq * fasta_width))
          start <- mapply(sample_range, FUN = sample, size = 1)
          perc_length <- floor(fasta_width * proportion_per_seq)
          end_pos <- start + perc_length
          seq_vector <- mapply(seq_vector, FUN = substr, start = start, stop = end_pos)
        }

        # language model needs one extra character as target
        min_length <- if (mode == "lm") maxlen + 1 else maxlen
        if (!padding) {
          seq_vector <- seq_vector[nchar(seq_vector) >= min_length]
        } else {
          # left-pad too short sequences with "0"
          length_vector <- nchar(seq_vector)
          short_seq_index <- which(length_vector < min_length)
          for (ssi in short_seq_index) {
            seq_vector[ssi] <- paste0(paste(rep("0", min_length - length_vector[ssi]), collapse = ""), seq_vector[ssi])
          }
        }

        if (length(seq_vector) == 0) next
        new_samples <- length(get_start_ind(seq_vector = seq_vector,
                                            length_vector = nchar(seq_vector),
                                            maxlen = maxlen,
                                            step = step,
                                            train_mode = if (mode == "lm") "lm" else "label",
                                            discard_amb_nuc = (ambiguous_nuc == "discard"),
                                            vocabulary = vocabulary))

        if (is.null(max_samples)) {
          num_samples[i] <- num_samples[i] + new_samples
        } else {
          num_samples[i] <- num_samples[i] + min(new_samples, max_samples)
        }
      }
      number_batches[i] <- ceiling(num_samples[i] / batch_size)

    }
    if (mode == "label_folder") {
      message_string <- paste0("Evaluate ", num_samples, " samples for class ", vocabulary_label[[1]], ".\n")
    } else {
      message_string <- paste0("Evaluate ", sum(num_samples), " samples.")
    }
    message(message_string)
  }

  # ---- count samples if every rds file should be evaluated once ---------------
  if (evaluate_all_files && format == "rds") {
    rds_files <- list_fasta_files(path_corpus = path_input,
                                  format = "rds",
                                  file_filter = NULL)
    num_samples <- 0
    for (file in rds_files) {
      rds_file <- readRDS(file)
      # descend into nested lists until the first array is reached
      x <- rds_file[[1]]
      while (is.list(x)) {
        x <- x[[1]]
      }
      num_samples <- dim(x)[1] + num_samples
    }
    number_batches <- ceiling(num_samples / batch_size)
    message_string <- paste0("Evaluate ", num_samples, " samples.")
    message(message_string)
  }

  if (!is.null(exact_num_samples)) {
    num_samples <- exact_num_samples
    number_batches <- ceiling(num_samples / batch_size)
  }

  overall_num_batches <- sum(number_batches)
  if (overall_num_batches < 1) {
    stop("No batches to evaluate: no valid samples found or number_batches is 0.")
  }

  # ---- set up generator (label_folder generators are created per class below) --
  if (mode == "lm") {
    gen <- generator_fasta_lm(path_corpus = path_input,
                              format = format,
                              batch_size = batch_size,
                              maxlen = maxlen,
                              max_iter = max_iter,
                              vocabulary = vocabulary,
                              verbose = FALSE,
                              shuffle_file_order = shuffle_file_order,
                              step = step,
                              concat_seq = concat_seq,
                              padding = padding,
                              shuffle_input = FALSE,
                              reverse_complement = FALSE,
                              output_format = output_format,
                              ambiguous_nuc = ambiguous_nuc,
                              proportion_per_seq = proportion_per_seq,
                              max_samples = max_samples,
                              seed = seed,
                              ...)
  }

  if (mode == "label_header" || mode == "label_csv") {
    gen <- generator_fasta_label_header_csv(path_corpus = path_input,
                                            format = format,
                                            batch_size = batch_size,
                                            maxlen = maxlen,
                                            max_iter = max_iter,
                                            vocabulary = vocabulary,
                                            verbose = FALSE,
                                            shuffle_file_order = shuffle_file_order,
                                            step = step,
                                            padding = padding,
                                            shuffle_input = FALSE,
                                            concat_seq = concat_seq,
                                            vocabulary_label = vocabulary_label[[1]],
                                            reverse_complement = FALSE,
                                            ambiguous_nuc = ambiguous_nuc,
                                            target_from_csv = target_from_csv,
                                            proportion_per_seq = proportion_per_seq,
                                            max_samples = max_samples,
                                            seed = seed, ...)
  }

  if (mode == "label_rds" || mode == "lm_rds") {
    gen <- generator_rds(rds_folder = path_input, batch_size = batch_size, path_file_log = NULL, ...)
  }

  # ---- predict batch by batch -------------------------------------------------
  batch_index <- 1
  start_time <- Sys.time()
  ten_percent_steps <- seq(overall_num_batches / 10, overall_num_batches, length.out = 10)
  percentage_index <- 1
  count <- 1
  y_conf_list <- vector("list", overall_num_batches)
  y_list <- vector("list", overall_num_batches)
  if (include_seq) {
    x_list <- vector("list", overall_num_batches)
  }

  for (k in seq_len(num_classes)) {

    # row indices kept from the (trimmed) last batch of this class; NULL if untrimmed
    index <- NULL
    if (mode == "label_folder") {
      gen <- generator_fasta_label_folder(path_corpus = path_input[[k]],
                                          format = format,
                                          batch_size = batch_size,
                                          maxlen = maxlen,
                                          max_iter = max_iter,
                                          vocabulary = vocabulary,
                                          step = step,
                                          padding = padding,
                                          concat_seq = concat_seq,
                                          reverse_complement = FALSE,
                                          num_targets = length(path_input),
                                          ones_column = k,
                                          ambiguous_nuc = ambiguous_nuc,
                                          proportion_per_seq = proportion_per_seq,
                                          max_samples = max_samples,
                                          seed = seed, ...)
    }

    for (i in seq_len(number_batches[k])) {
      z <- gen()
      x <- z[[1]]
      y <- z[[2]]

      y_conf <- model(x)
      batch_index <- batch_index + 1

      # remove double predictions
      if (eval_exact_num_samples && (i == number_batches[k])) {
        double_index <- (i * batch_size) - num_samples[k]

        if (double_index > 0) {
          # multi-output models return a list of tensors, which has no nrow()
          rows_in_batch <- if (is.list(y_conf)) nrow(y_conf[[1]]) else nrow(y_conf)
          index <- seq_len(rows_in_batch - double_index)

          # drop = FALSE keeps targets as matrices even for one row or one column
          if (is.list(y_conf)) {
            for (m in seq_along(y_conf)) {
              y_conf[[m]] <- y_conf[[m]][index, ]
              y[[m]] <- y[[m]][index, , drop = FALSE]
            }
          } else {
            y_conf <- y_conf[index, ]
            y <- y[index, , drop = FALSE]
          }

          # vector to matrix
          if (length(index) == 1) {
            if (is.list(y_conf)) {
              for (m in seq_along(y_conf)) {
                y_conf[[m]] <- array(as.array(y_conf[[m]]), dim = c(1, length(y_conf[[m]])))
                y[[m]] <- matrix(y[[m]], ncol = length(y[[m]]))
              }
            } else {
              y_conf <- array(as.array(y_conf), dim = c(1, length(y_conf)))
              y <- matrix(y, ncol = length(y))
            }
          }

        }
      }

      if (include_seq) {
        x_list[[count]] <- x
      }
      y_conf_list[[count]] <- y_conf
      # Force single-output targets into matrix shape whenever the batch has one
      # sample or was trimmed. The previous test `length(index == 1)` was TRUE for
      # every trimmed batch; that effective behavior is kept, but list outputs are
      # skipped since ncol() is undefined for them.
      if (!is.list(y_conf) && (batch_size == 1 || !is.null(index))) {
        col_num <- ncol(y_conf)
        if (is.na(col_num)) col_num <- length(y_conf)
        y_list[[count]] <- matrix(y, ncol = col_num)
      } else {
        y_list[[count]] <- y
      }
      count <- count + 1

      if (verbose && (batch_index == 10)) {
        time_passed <- as.double(difftime(Sys.time(), start_time, units = "hours"))
        time_estimation <- (overall_num_batches / 10) * time_passed
        cat("Evaluation will take approximately", round(time_estimation, 3), "hours. Starting time:", format(Sys.time(), "%F %R."), " \n")

      }

      if (verbose && (batch_index > ten_percent_steps[percentage_index]) && percentage_index < 10) {
        cat("Progress: ", percentage_index * 10, "% \n")
        time_passed <- as.double(difftime(Sys.time(), start_time, units = "hours"))
        cat("Time passed: ", round(time_passed, 3), "hours \n")
        percentage_index <- percentage_index + 1
      }

    }
  }

  if (verbose) {
    cat("Progress: 100 % \n")
    time_passed <- as.double(difftime(Sys.time(), start_time, units = "hours"))
    cat("Time passed: ", round(time_passed, 3), "hours \n")
  }

  # ---- combine batches per output layer ---------------------------------------
  y_conf_list <- reshape_y_list(y_conf_list, num_out_layers = num_out_layers, tf_format = TRUE)
  y_list <- reshape_y_list(y_list, num_out_layers = num_out_layers, tf_format = FALSE)

  if (!is.null(path_pred_list)) {
    if (include_seq) {
      if (is.list(x_list[[1]])) {
        num_layers <- length(x_list[[1]])
      } else {
        num_layers <- 1
      }
      x_list <- reshape_y_list(x_list, num_out_layers = num_layers, tf_format = FALSE)
      saveRDS(list(pred = y_conf_list, true = y_list, x = x_list), path_pred_list)
    } else {
      saveRDS(list(pred = y_conf_list, true = y_list), path_pred_list)
    }
  }

  # ---- evaluate each output layer according to its activation -----------------
  eval_list <- list()
  for (i in seq_len(num_out_layers)) {

    if (activations[[i]] == "softmax") {
      eval_list[[i]] <- evaluate_softmax(y = y_list[[i]], y_conf = y_conf_list[[i]],
                                         auc = auc, auprc = auprc,
                                         label_names = vocabulary_label[[i]])
    }

    if (activations[[i]] == "sigmoid") {
      eval_list[[i]] <- evaluate_sigmoid(y = y_list[[i]], y_conf = y_conf_list[[i]],
                                         auc = auc, auprc = auprc,
                                         label_names = vocabulary_label[[i]])
    }

    if (activations[[i]] == "linear") {
      eval_list[[i]] <- evaluate_linear(y_true = y_list[[i]], y_pred = y_conf_list[[i]], label_names = vocabulary_label[[i]])
    }

  }

  return(eval_list)
}
|
|
444 |
|
|
|
445 |
|
|
|
446 |
# Combine per-batch outputs into one object per output layer.
#
# `y` holds one entry per batch. For a single output layer each entry is the
# batch itself; for several output layers each entry is a list with one element
# per layer. Returns a list of length `num_out_layers`, where element i is the
# row-wise concatenation of all batches of layer i. With `tf_format = TRUE` the
# pieces are concatenated with tensorflow and evaluated via keras::k_eval(),
# otherwise they are bound with rbind().
reshape_y_list <- function(y, num_out_layers, tf_format = TRUE) {

  # Flatten list(batch1 = list(out1, out2, ...), ...) so that the outputs of
  # layer i sit at positions i, i + n, i + 2n, ...
  flat <- if (num_out_layers > 1) do.call(c, y) else y

  bind_batches <- if (tf_format) {
    function(parts) keras::k_eval(tensorflow::tf$concat(parts, axis = 0L))
  } else {
    function(parts) do.call(rbind, parts)
  }

  per_layer <- vector("list", num_out_layers)
  for (layer in seq(1, num_out_layers)) {
    picks <- seq(layer, length(flat), by = num_out_layers)
    per_layer[[layer]] <- bind_batches(flat[picks])
  }

  return(per_layer)
}
|
|
466 |
|
|
|
467 |
#' Evaluate matrices of true targets and predictions from layer with softmax activation.
#'
#' Compute confusion matrix, accuracy, categorical crossentropy and (optionally) AUC or AUPRC, given predictions and
#' true targets. AUC and AUPRC only possible for 2 targets.
#'
#' The reported loss is the mean categorical crossentropy over all samples. Predicted and true classes
#' are the column indices of the row-wise maxima of \code{y_conf} and \code{y}.
#'
#' @param y Matrix of true target.
#' @param y_conf Matrix of predictions.
#' @param auc Whether to include AUC metric. Only possible for 2 targets.
#' @param auprc Whether to include AUPRC metric. Only possible for 2 targets.
#' @param label_names Names of corresponding labels. Length must be equal to number of columns of \code{y}.
#' If `NULL`, the column indices of \code{y} are used as labels.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' y <- matrix(c(1, 0, 0, 0, 1, 1), ncol = 2)
#' y_conf <- matrix(c(0.3, 0.5, 0.1, 0.7, 0.5, 0.9), ncol = 2)
#' evaluate_softmax(y, y_conf, auc = TRUE, auprc = TRUE, label_names = c("A", "B"))
#'
#' @returns A list of evaluation results.
#' @export
evaluate_softmax <- function(y, y_conf, auc = FALSE, auprc = FALSE, label_names = NULL) {

  num_targets <- ncol(y)

  if (num_targets != 2 && (auc || auprc)) {
    message("Can only compute AUC or AUPRC if output layer with softmax acticvation has two neurons.")
    auc <- FALSE
    auprc <- FALSE
  }

  # factor() rejects labels = NULL, so fall back to the column indices
  if (is.null(label_names)) {
    label_names <- as.character(seq_len(num_targets))
  }
  if (length(label_names) != num_targets) {
    stop("Length of label_names must be equal to number of columns of y.")
  }

  # predicted / true class as column index; y_true is 0-based
  y_pred <- apply(y_conf, 1, which.max)
  y_true <- apply(y, 1, FUN = which.max) - 1

  df_true_pred <- data.frame(
    true = factor(y_true + 1, levels = seq_len(num_targets), labels = label_names),
    pred = factor(y_pred, levels = seq_len(num_targets), labels = label_names)
  )

  # per-sample crossentropy, grouped by true class; NA marks classes without samples
  loss_per_class <- vector("list", num_targets)
  for (i in seq_len(num_targets)) {
    index <- (y_true + 1) == i
    if (any(index)) {
      cce_loss_class <- tensorflow::tf$keras$losses$categorical_crossentropy(
        y[index, , drop = FALSE],
        y_conf[index, , drop = FALSE]
      )
      loss_per_class[[i]] <- cce_loss_class$numpy()
    } else {
      loss_per_class[[i]] <- NA
    }
  }

  cm <- yardstick::conf_mat(df_true_pred, true, pred)
  confMat <- cm[[1]]

  acc <- sum(diag(confMat)) / sum(confMat)
  # na.rm drops the NA placeholders of empty classes, which would otherwise turn
  # the overall loss into NA
  loss <- mean(unlist(loss_per_class), na.rm = TRUE)

  # per-class summaries (currently not part of the returned list)
  loss_per_class <- vapply(loss_per_class,
                           function(l) mean(unlist(l), na.rm = TRUE),
                           numeric(1))
  m <- as.matrix(confMat)
  col_totals <- colSums(m)
  class_acc <- ifelse(col_totals == 0, NA, diag(m) / col_totals)
  names(class_acc) <- label_names
  names(loss_per_class) <- label_names
  balanced_acc <- mean(class_acc)

  if (auc) {
    auc_list <- PRROC::roc.curve(
      scores.class0 = y_conf[ , 2],
      weights.class0 = y_true)
  } else {
    auc_list <- NULL
  }

  if (auprc) {
    auprc_list <- PRROC::pr.curve(
      scores.class0 = y_conf[ , 2],
      weights.class0 = y_true)
  } else {
    auprc_list <- NULL
  }

  return(list(confusion_matrix = confMat,
              accuracy = acc,
              categorical_crossentropy_loss = loss,
              #balanced_accuracy = balanced_acc,
              #loss_per_class = loss_per_class,
              #accuracy_per_class = class_acc,
              AUC = auc_list$auc,
              AUPRC = auprc_list$auc.integral))
}
|
|
560 |
|
|
|
561 |
#' Evaluate matrices of true targets and predictions from layer with sigmoid activation.
#'
#' Compute accuracy, binary crossentropy and (optionally) AUC or AUPRC, given predictions and
#' true targets. Outputs columnwise average.
#'
#' Predictions are thresholded at 0.5 (strictly greater counts as 1) to compute accuracy.
#'
#' @inheritParams evaluate_model
#' @inheritParams evaluate_softmax
#' @param auc Whether to include AUC metric.
#' @param auprc Whether to include AUPRC metric.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' y <- matrix(sample(c(0, 1), 30, replace = TRUE), ncol = 3)
#' y_conf <- matrix(runif(n = 30), ncol = 3)
#' evaluate_sigmoid(y, y_conf, auc = TRUE, auprc = TRUE)
#'
#' @returns A list of evaluation results.
#' @export
evaluate_sigmoid <- function(y, y_conf, auc = FALSE, auprc = FALSE, label_names = NULL) {

  # hard 0/1 calls from the confidences
  hard_calls <- ifelse(y_conf > 0.5, 1, 0)
  target_cols <- 1:ncol(y)

  # binary crossentropy, one value per target column
  bce_by_target <- lapply(target_cols, function(j) {
    bce <- tensorflow::tf$keras$losses$binary_crossentropy(y[ , j], y_conf[ , j])
    bce$numpy()
  })
  bce_by_target <- unlist(bce_by_target)
  names(bce_by_target) <- label_names
  loss <- mean(unlist(bce_by_target))

  # fraction of correct calls, one value per target column
  acc_by_target <- vector("numeric", ncol(y))
  for (j in target_cols) {
    acc_by_target[j] <- sum(y[ , j] == hard_calls[ , j]) / nrow(y)
  }
  names(acc_by_target) <- label_names
  acc <- mean(acc_by_target)

  AUC <- NULL
  if (auc) {
    roc_fits <- purrr::map(1:ncol(y_conf), function(j) {
      PRROC::roc.curve(scores.class0 = y_conf[ , j],
                       weights.class0 = y[ , j])
    })
    auc_vector <- vector("numeric", ncol(y))
    for (j in 1:length(auc_vector)) {
      auc_vector[j] <- roc_fits[[j]]$auc
    }

    # columns without a defined AUC are reported and left out of the average
    n_missing <- sum(is.na(auc_vector))
    if (n_missing > 0) {
      message(paste(sum(n_missing), ifelse(n_missing > 1, "columns", "column"),
                    "removed from AUC evaluation since they contain only one label"))
    }
    AUC <- mean(auc_vector, na.rm = TRUE)
  }

  AUPRC <- NULL
  if (auprc) {
    pr_fits <- purrr::map(1:ncol(y_conf), function(j) {
      PRROC::pr.curve(scores.class0 = y_conf[ , j],
                      weights.class0 = y[ , j])
    })
    auprc_vector <- vector("numeric", ncol(y))
    for (j in 1:length(auprc_vector)) {
      auprc_vector[j] <- pr_fits[[j]]$auc.integral
    }
    AUPRC <- mean(auprc_vector, na.rm = TRUE)
  }

  return(list(accuracy = acc,
              binary_crossentropy_loss = loss,
              #loss_per_class = bce_by_target,
              #accuracy_per_class = acc_by_target,
              AUC = AUC,
              AUPRC = AUPRC))

}
|
|
639 |
|
|
|
640 |
#' Evaluate matrices of true targets and predictions from layer with linear activation.
#'
#' Compute MAE and MSE, given predictions and
#' true targets. Outputs columnwise average.
#'
#' Both losses are computed separately for every column and then averaged over the columns.
#' \code{label_names} is accepted for a uniform interface but not used in the computation.
#'
#' @inheritParams evaluate_model
#' @inheritParams evaluate_softmax
#' @param y_true Matrix of true labels.
#' @param y_pred Matrix of predictions.
#' @examplesIf reticulate::py_module_available("tensorflow")
#' y_true <- matrix(rnorm(n = 12), ncol = 3)
#' y_pred <- matrix(rnorm(n = 12), ncol = 3)
#' evaluate_linear(y_true, y_pred)
#'
#' @returns A list of evaluation results.
#' @export
evaluate_linear <- function(y_true, y_pred, label_names = NULL) {

  target_cols <- 1:ncol(y_true)

  # apply a keras loss to every column pair and collect the values
  loss_by_column <- function(loss_name) {
    lapply(target_cols, function(j) {
      loss_value <- tensorflow::tf$keras$losses[[loss_name]](y_true[ , j], y_pred[ , j])
      loss_value$numpy()
    })
  }

  mse_by_column <- loss_by_column("mean_squared_error")
  mae_by_column <- loss_by_column("mean_absolute_error")

  return(list(mse = mean(unlist(mse_by_column)),
              mae = mean(unlist(mae_by_column))))

}
|
|
671 |
|
|
|
672 |
|
|
|
673 |
#' Plot ROC
#'
#' Compute ROC and AUC from target and prediction matrix and plot ROC. Target/prediction matrix should
#' have one column if output of layer with sigmoid activation and two columns for softmax activation.
#' For two columns, the second column of both matrices is used (label and confidence of the second class).
#'
#' @inheritParams evaluate_softmax
#' @inheritParams evaluate_linear
#' @param path_roc_plot Where to store ROC plot.
#' @param return_plot Whether to return plot.
#' @examples
#' y_true <- matrix(c(1, 0, 0, 0, 1, 1), ncol = 1)
#' y_conf <- matrix(runif(n = nrow(y_true)), ncol = 1)
#' p <- plot_roc(y_true, y_conf, return_plot = TRUE)
#' p
#'
#' @returns A ggplot of ROC curve.
#' @export
plot_roc <- function(y_true, y_conf, path_roc_plot = NULL,
                     return_plot = TRUE) {

  # %in% never yields NA, so missing values are rejected here as well
  if (!all(y_true %in% c(0, 1))) {
    stop("y_true should only contain 0 and 1 entries")
  }

  if (is.matrix(y_true) && ncol(y_true) > 2) {
    stop("y_true can contain 1 or 2 columns")
  }

  if (is.matrix(y_true) && ncol(y_true) == 2) {
    # labels and scores must refer to the same (second) class
    y_true <- y_true[ , 2]
    y_conf <- y_conf[ , 2]
  }

  y_true <- as.vector(y_true)
  y_conf <- as.vector(y_conf)

  # a ROC curve needs both labels; also covers single-observation input
  if (length(unique(y_true)) < 2) {
    stop("y_true contains just one label")
  }

  rocobj <- pROC::roc(y_true, y_conf, quiet = TRUE)
  auc <- round(pROC::auc(y_true, y_conf, quiet = TRUE), 4)
  p <- pROC::ggroc(rocobj, size = 1, color = "black")
  p <- p + ggplot2::theme_classic() + ggplot2::theme(aspect.ratio = 1)
  p <- p + ggplot2::ggtitle(paste0('ROC Curve ', '(AUC = ', auc, ')'))
  # dashed guide lines
  p <- p + ggplot2::geom_abline(intercept = 1, linetype = 2, color = "grey50")
  p <- p + ggplot2::geom_vline(xintercept = 1, linetype = 2, color = "grey50")
  p <- p + ggplot2::geom_hline(yintercept = 1, linetype = 2, color = "grey50")

  if (!is.null(path_roc_plot)) {
    ggplot2::ggsave(path_roc_plot, p)
  }

  if (return_plot) {
    return(p)
  } else {
    return(NULL)
  }

}
|
|
733 |
|
|
|
734 |
# plot_roc_auprc <- function(y_true, y_conf, path_roc_plot = NULL, path_auprc_plot = NULL, |
|
|
735 |
# return_plot = TRUE, layer_activation = "softmax") { |
|
|
736 |
# |
|
|
737 |
# if (layer_activation == "softmax") { |
|
|
738 |
# |
|
|
739 |
# if (!all(y_true == 0 | y_true == 1)) { |
|
|
740 |
# stop("y_true should only contain 0 and 1 entries") |
|
|
741 |
# } |
|
|
742 |
# |
|
|
743 |
# if (ncol(y_true) != 2 & (auc | auprc)) { |
|
|
744 |
# message("Can only compute AUC or AUPRC if output layer with softmax acticvation has two neurons.") |
|
|
745 |
# } |
|
|
746 |
# |
|
|
747 |
# auc_list <- PRROC::roc.curve( |
|
|
748 |
# scores.class0 = y_conf[ , 2], |
|
|
749 |
# weights.class0 = y_true[ , 2], curve = TRUE) |
|
|
750 |
# |
|
|
751 |
# |
|
|
752 |
# auprc_list <- PRROC::pr.curve( |
|
|
753 |
# scores.class0 = y_conf[ , 2], |
|
|
754 |
# weights.class0 = y_true[ , 2], curve = TRUE) |
|
|
755 |
# |
|
|
756 |
# #auc_plot <- NULL |
|
|
757 |
# #auprc_plot <- NULL |
|
|
758 |
# |
|
|
759 |
# } |
|
|
760 |
# |
|
|
761 |
# if (layer_activation == "sigmoid") { |
|
|
762 |
# |
|
|
763 |
# auc_list <- purrr::map(1:ncol(y_conf), ~PRROC::roc.curve( |
|
|
764 |
# scores.class0 = y_conf[ , .x], |
|
|
765 |
# weights.class0 = y[ , .x], curve = TRUE)) |
|
|
766 |
# auc_vector <- vector("numeric", ncol(y)) |
|
|
767 |
# |
|
|
768 |
# |
|
|
769 |
# auprc_list <- purrr::map(1:ncol(y_conf), ~PRROC::pr.curve( |
|
|
770 |
# scores.class0 = y_conf[ , .x], |
|
|
771 |
# weights.class0 = y[ , .x], curve = TRUE)) |
|
|
772 |
# auprc_vector <- vector("numeric", ncol(y)) |
|
|
773 |
# |
|
|
774 |
# } |
|
|
775 |
# |
|
|
776 |
# if (!is.null(path_roc_plot)) { |
|
|
777 |
# |
|
|
778 |
# } |
|
|
779 |
# |
|
|
780 |
# if (!is.null(path_auprc_plot)) { |
|
|
781 |
# |
|
|
782 |
# } |
|
|
783 |
# |
|
|
784 |
# } |