--- a/R/plotLearner.R
+++ b/R/plotLearner.R
@@ -0,0 +1,230 @@
+#' Plot the summary curves produced by an IntegratedLearner object
+#'
+#' @description Plots the layer-wise ROC curves (binomial family) or R^2 bar plots
+#' (gaussian family) for the training set, and for the test set if provided, of a
+#' fitted IntegratedLearner object.
+#'
+#' @param fit Fitted "IntegratedLearner" object.
+#' @param label_size (optional) Numerical value indicating the label size. Default is 8.
+#' @param label_x (optional) Single value or vector of x positions for the plot labels,
+#' relative to each subplot. Default is 0.3 for all labels (0 places a label at the left
+#' edge of its plot, 1 at the right edge).
+#' @param vjust Adjusts the vertical position of each label. More positive values move
+#' the label further down on the plot canvas. Can be a single value (applied to all
+#' labels) or a vector of values (one for each label). Default is 0.1.
+#' @param rowwise_plot When both training and test data are available, should the train
+#' and test plots be arranged row-wise (one above the other)? Default is TRUE. If FALSE,
+#' the plots are aligned column-wise (side by side).
+#'
+#' @return A list containing the combined plot and the underlying summary table(s):
+#' \code{ROC_table} (and \code{ROC_table_valid}) for the binomial family, or
+#' \code{R2_table} (and \code{R2_table_valid}) for the gaussian family.
+#'
+#' @importFrom stringr str_remove_all
+#' @importFrom cowplot plot_grid
+#' @importFrom grid unit
+#' @importFrom stats cor
+#' @import ggplot2
+#' @export
+plot.learner <- function(fit, label_size = 8, label_x = 0.3, vjust = 0.1, rowwise_plot = TRUE){
+
+  clean_base_learner <- str_remove_all(fit$base_learner, 'SL.')
+  clean_meta_learner <- str_remove_all(fit$meta_learner, 'SL.')
+  method <- paste(clean_base_learner, clean_meta_learner, sep = ' + ')
+
+  # Layout of the train/test panels when both are available
+  if(rowwise_plot){
+    nrow = 2
+    ncol = 1
+  } else{
+    nrow = 1
+    ncol = 2
+  }
+
+  if(fit$family == 'binomial'){
+
+    # Extract ROC plot data (one entry per layer)
+    list.ROC <- vector("list", length = ncol(fit$yhat.train))
+    names(list.ROC) <- colnames(fit$yhat.train)
+
+    y <- fit$Y_train
+    # Loop over layers
+    for(k in 1:length(list.ROC)){
+      preds <- fit$yhat.train[, k]
+      pred = ROCR::prediction(preds, y)
+      AUC = round(ROCR::performance(pred, "auc")@y.values[[1]], 2)
+      perf = ROCR::performance(pred, "sens", "spec")
+      # Note: the 'specificity' column stores 1 - specificity (i.e., the FPR) for plotting
+      list.ROC[[k]] <- data.frame(sensitivity = methods::slot(perf, "y.values")[[1]],
+                                  specificity = 1 - methods::slot(perf, "x.values")[[1]],
+                                  AUC = AUC,
+                                  layer = names(list.ROC)[k],
+                                  method = method)
+    }
+
+    # Combine
+    ROC_table <- do.call('rbind', list.ROC)
+
+    # Prepare data for plotting
+    plot_data <- ROC_table
+    plot_data$displayItem <- paste(plot_data$layer, " AUC = ", plot_data$AUC, sep = "")
+    plot_data$displayItem <- factor(plot_data$displayItem,
+                                    levels = unique(plot_data$displayItem))
+
+    # ROC curves
+    p1 <- ggplot(plot_data,
+                 aes(x = specificity,
+                     y = sensitivity,
+                     group = displayItem)) +
+      geom_line(aes(x = specificity, y = sensitivity, color = displayItem)) +
+      # ggtitle(paste('Training data: ', method, sep = '')) +
+      theme(legend.position = "bottom",
+            legend.background = element_blank(),
+            legend.box.background = element_rect(colour = "black")) +
+      theme_bw() +
+      xlab("False Positive Rate") +
+      ylab("True Positive Rate") +
+      theme(legend.position = "right", legend.direction = "vertical") +
+      labs(color = '')
+
+    if(fit$test == TRUE){
+
+      # Extract ROC plot data for the test set
+      list.ROC.valid <- vector("list", length = ncol(fit$yhat.test))
+      names(list.ROC.valid) <- colnames(fit$yhat.test)
+
+      y <- fit$Y_test
+      # Loop over layers
+      for(k in 1:length(list.ROC.valid)){
+        preds <- fit$yhat.test[, k]
+        pred = ROCR::prediction(preds, y)
+        AUC = round(ROCR::performance(pred, "auc")@y.values[[1]], 2)
+        perf = ROCR::performance(pred, "sens", "spec")
+        list.ROC.valid[[k]] <- data.frame(sensitivity = methods::slot(perf, "y.values")[[1]],
+                                          specificity = 1 - methods::slot(perf, "x.values")[[1]],
+                                          AUC = AUC,
+                                          layer = names(list.ROC.valid)[k],
+                                          method = method)
+      }
+
+      # Combine
+      ROC_table_valid <- do.call('rbind', list.ROC.valid)
+
+      # Prepare data for plotting
+      plot_data <- ROC_table_valid
+      plot_data$displayItem <- paste(plot_data$layer, " AUC = ", plot_data$AUC, sep = "")
+      plot_data$displayItem <- factor(plot_data$displayItem,
+                                      levels = unique(plot_data$displayItem))
+
+      # ROC curves
+      p2 <- ggplot(plot_data,
+                   aes(x = specificity,
+                       y = sensitivity,
+                       group = displayItem)) +
+        geom_line(aes(x = specificity, y = sensitivity, color = displayItem)) +
+        # ggtitle(paste('Test data: ', method, sep = '')) +
+        theme(legend.position = "bottom",
+              legend.background = element_blank(),
+              legend.box.background = element_rect(colour = "black")) +
+        theme_bw() +
+        xlab("False Positive Rate") +
+        ylab("True Positive Rate") +
+        theme(legend.position = "right", legend.direction = "vertical") +
+        labs(color = '')
+
+      # Combine train and test panels, honoring rowwise_plot
+      p <- plot_grid(p1,
+                     p2,
+                     nrow = nrow,
+                     ncol = ncol,
+                     labels = c(paste('A. ', fit$folds, '-fold CV', sep = ''),
+                                'B. Independent Validation'),
+                     label_size = label_size, label_x = label_x, vjust = vjust) +
+        theme(plot.margin = unit(c(1, 1, 1, 1), "cm"))
+      print(p)
+      return(list('plot' = p, 'ROC_table' = ROC_table, 'ROC_table_valid' = ROC_table_valid))
+    }
+
+    p <- plot_grid(p1,
+                   labels = c(paste('A. ', fit$folds, '-fold CV', sep = '')),
+                   label_size = label_size, label_x = label_x, vjust = vjust) +
+      theme(plot.margin = unit(c(1, 1, 1, 1), "cm"))
+    print(p)
+    return(list('plot' = p, 'ROC_table' = ROC_table))
+
+  } else if(fit$family == 'gaussian'){
+
+    # Extract R2 plot data (one entry per layer)
+    list.R2 <- vector("list", length = ncol(fit$yhat.train))
+    names(list.R2) <- colnames(fit$yhat.train)
+
+    y <- fit$Y_train
+    # Loop over layers
+    for(k in 1:length(list.R2)){
+      preds <- fit$yhat.train[, k]
+      R2 <- as.vector(cor(preds, y)^2)
+      list.R2[[k]] <- data.frame(R2 = R2,
+                                 layer = names(list.R2)[k],
+                                 method = method)
+    }
+
+    # Combine
+    R2_table <- do.call('rbind', list.R2)
+
+    # Plot
+    p1 <- ggplot(R2_table, aes(x = method, y = R2)) +
+      geom_bar(position = "dodge", stat = "identity", aes(fill = layer)) +
+      xlab("") +
+      ylab(expression(paste("Prediction accuracy (", R^2, ")"))) +
+      scale_fill_discrete(name = "") +
+      theme(legend.position = "bottom",
+            legend.background = element_blank(),
+            legend.box.background = element_rect(colour = "black")) +
+      theme_bw() +
+      guides(fill = guide_legend(title = "")) +
+      theme(legend.position = "right", legend.direction = "vertical",
+            strip.background = element_blank()) +
+      labs(fill = '')
+
+    if(fit$test == TRUE){
+
+      # Extract R2 plot data for the test set
+      list.R2.valid <- vector("list", length = ncol(fit$yhat.test))
+      names(list.R2.valid) <- colnames(fit$yhat.test)
+
+      y <- fit$Y_test
+      # Loop over layers
+      for(k in 1:length(list.R2.valid)){
+        preds <- fit$yhat.test[, k]
+        R2 <- as.vector(cor(preds, y)^2)
+        list.R2.valid[[k]] <- data.frame(R2 = R2,
+                                         layer = names(list.R2.valid)[k],
+                                         method = method)
+      }
+
+      # Combine
+      R2_table_valid <- do.call('rbind', list.R2.valid)
+
+      # Plot
+      p2 <- ggplot(R2_table_valid, aes(x = method, y = R2)) +
+        geom_bar(position = "dodge", stat = "identity", aes(fill = layer)) +
+        xlab("") +
+        ylab(expression(paste("Prediction accuracy (", R^2, ")"))) +
+        scale_fill_discrete(name = "") +
+        theme(legend.position = "bottom",
+              legend.background = element_blank(),
+              legend.box.background = element_rect(colour = "black")) +
+        theme_bw() +
+        guides(fill = guide_legend(title = "")) +
+        theme(legend.position = "right", legend.direction = "vertical",
+              strip.background = element_blank()) +
+        labs(fill = '')
+
+      # Combine train and test panels, honoring rowwise_plot
+      p <- plot_grid(p1,
+                     p2,
+                     nrow = nrow,
+                     ncol = ncol,
+                     labels = c(paste('A. ', fit$folds, '-fold CV', sep = ''),
+                                'B. Independent Validation'),
+                     label_size = label_size, label_x = label_x, vjust = vjust) +
+        theme(plot.margin = unit(c(1, 1, 1, 1), "cm"))
+      print(p)
+      return(list('plot' = p, 'R2_table' = R2_table, 'R2_table_valid' = R2_table_valid))
+    }
+
+    p <- plot_grid(p1,
+                   ncol = 1,
+                   labels = c(paste('A. ', fit$folds, '-fold CV', sep = '')),
+                   label_size = label_size, label_x = label_x, vjust = vjust) +
+      theme(plot.margin = unit(c(1, 1, 1, 1), "cm"))
+    print(p)
+    return(list('plot' = p, 'R2_table' = R2_table))
+  }
+}
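
As a quick sanity check, the snippet below is a minimal sketch of how the new function can be exercised with a mock `fit` object carrying only the fields `plot.learner()` reads (`base_learner`, `meta_learner`, `family`, `yhat.train`, `Y_train`, `test`, `folds`). The layer names and learner labels are made up for illustration; in real use `fit` would be the output of the package's IntegratedLearner fitting routine, and the ROCR and methods packages must be installed.

```r
# Sketch only: exercise plot.learner() on a mock binomial fit (no test set).
library(ggplot2)
library(cowplot)
library(stringr)

set.seed(1)
n <- 100
y <- rbinom(n, 1, 0.5)

# Two hypothetical per-layer prediction columns plus a stacked column
yhat <- cbind(
  species  = plogis(y + rnorm(n)),
  pathways = plogis(y + rnorm(n)),
  stacked  = plogis(y + rnorm(n, sd = 0.5))
)

mock_fit <- list(
  base_learner = "SL.BART",      # hypothetical base learner name
  meta_learner = "SL.nnls.auc",  # hypothetical meta learner name
  family       = "binomial",
  yhat.train   = yhat,
  Y_train      = y,
  test         = FALSE,          # no independent validation set in this sketch
  folds        = 5
)

res <- plot.learner(mock_fit)
res$ROC_table  # per-layer sensitivity/FPR values and AUCs
```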