--- a
+++ b/R/plotLearner.R
@@ -0,0 +1,159 @@
+#' Plot the summary curves produced by an IntegratedLearner object
+#'
+#' @description Plots the ROC curves (binomial family) or R^2 bar charts
+#'   (gaussian family) for the training set, and for the test set if one was
+#'   supplied, produced by an IntegratedLearner object.
+#'
+#' @param fit fitted "IntegratedLearner" object
+#' @param label_size (optional) Numerical value indicating the label size.
+#'   Default is 8.
+#' @param label_x (optional) Single value or vector of x positions for plot
+#'   labels, relative to each subplot. Default is 0.3 for all labels.
+#' @param vjust Adjusts the vertical position of each label. More positive
+#'   values move the label further down on the plot canvas. Can be a single
+#'   value (applied to all labels) or a vector of values (one for each label).
+#'   Default is 0.1.
+#' @param rowwise_plot If both train and test results are available, should the
+#'   train and test plots be stacked row-wise (one above the other)? Default is
+#'   TRUE. If FALSE, the plots are placed side by side (column-wise).
+#'
+#' @return Invisibly, a list with element \code{plot} (the combined cowplot /
+#'   ggplot2 object, which is also printed) and either \code{ROC_table}
+#'   (binomial) or \code{R2_table} (gaussian) for the training predictions.
+#'   When test predictions are available, \code{ROC_table_valid} or
+#'   \code{R2_table_valid} is included as well. In \code{ROC_table}, the
+#'   \code{specificity} column holds 1 - specificity (the false positive rate).
+#' @export
+plot.learner <- function(fit, label_size = 8, label_x = 0.3, vjust = 0.1,
+                         rowwise_plot = TRUE) {
+  family <- as.character(fit$family)
+  if (length(family) != 1L || !(family %in% c("binomial", "gaussian"))) {
+    stop("plot.learner() supports only the 'binomial' and 'gaussian' ",
+         "families; got: ", paste(family, collapse = ", "), call. = FALSE)
+  }
+
+  # Literal match (fixed = TRUE): as a regex, "SL." would also strip "SL"
+  # followed by any other character.
+  clean_base_learner <- gsub("SL.", "", fit$base_learner, fixed = TRUE)
+  clean_meta_learner <- gsub("SL.", "", fit$meta_learner, fixed = TRUE)
+  method <- paste(clean_base_learner, clean_meta_learner, sep = " + ")
+
+  # Both families follow the same recipe; only the summary table, the panel
+  # type and the names of the returned tables differ.
+  if (family == "binomial") {
+    make_table <- .learner_roc_table
+    make_plot <- .learner_roc_plot
+    table_name <- "ROC_table"
+  } else {
+    make_table <- .learner_r2_table
+    make_plot <- .learner_r2_plot
+    table_name <- "R2_table"
+  }
+
+  # isTRUE() treats a missing (NULL) or NA `test` flag as "no test set".
+  has_test <- isTRUE(fit$test)
+
+  train_table <- make_table(fit$yhat.train, fit$Y_train, method)
+  p_train <- make_plot(train_table)
+  p_test <- NULL
+  if (has_test) {
+    test_table <- make_table(fit$yhat.test, fit$Y_test, method)
+    p_test <- make_plot(test_table)
+  }
+
+  p <- .arrange_learner_plots(p_train, p_test, folds = fit$folds,
+                              rowwise_plot = rowwise_plot,
+                              label_size = label_size, label_x = label_x,
+                              vjust = vjust)
+  print(p)
+
+  result <- list(plot = p)
+  result[[table_name]] <- train_table
+  if (has_test) {
+    result[[paste0(table_name, "_valid")]] <- test_table
+  }
+  invisible(result)
+}
+
+# Per-layer ROC curve data for one prediction matrix (one column per layer),
+# stacked into a single data frame. `y` holds the observed binary outcomes.
+.learner_roc_table <- function(yhat, y, method) {
+  layers <- colnames(yhat)
+  tables <- lapply(seq_len(ncol(yhat)), function(k) {
+    pred <- ROCR::prediction(yhat[, k], y)
+    auc <- round(ROCR::performance(pred, "auc")@y.values[[1]], 2)
+    perf <- ROCR::performance(pred, "sens", "spec")
+    data.frame(sensitivity = methods::slot(perf, "y.values")[[1]],
+               # Kept under the historical name "specificity", but this is
+               # 1 - specificity (the false positive rate) used as the x-axis.
+               specificity = 1 - methods::slot(perf, "x.values")[[1]],
+               AUC = auc,
+               layer = layers[k],
+               method = method)
+  })
+  names(tables) <- layers
+  do.call(rbind, tables)
+}
+
+# ROC panel: one curve per layer, labelled with that layer's AUC.
+.learner_roc_plot <- function(roc_table) {
+  plot_data <- roc_table
+  plot_data$displayItem <- paste0(plot_data$layer, " AUC = ", plot_data$AUC)
+  # Keep legend entries in the order in which the layers appear.
+  plot_data$displayItem <- factor(plot_data$displayItem,
+                                  levels = unique(plot_data$displayItem))
+  ggplot(plot_data,
+         aes(x = specificity, y = sensitivity, group = displayItem)) +
+    geom_line(aes(color = displayItem)) +
+    theme_bw() +
+    xlab("False Positive Rate") +
+    ylab("True Positive Rate") +
+    theme(legend.position = "right", legend.direction = "vertical") +
+    labs(color = "")
+}
+
+# R^2 (squared Pearson correlation with the observed outcome) per layer,
+# stacked into a single data frame.
+.learner_r2_table <- function(yhat, y, method) {
+  layers <- colnames(yhat)
+  tables <- lapply(seq_len(ncol(yhat)), function(k) {
+    data.frame(R2 = as.vector(stats::cor(yhat[, k], y)^2),
+               layer = layers[k],
+               method = method)
+  })
+  names(tables) <- layers
+  do.call(rbind, tables)
+}
+
+# R^2 panel: one dodged bar per layer.
+.learner_r2_plot <- function(r2_table) {
+  ggplot(r2_table, aes(x = method, y = R2)) +
+    geom_bar(position = "dodge", stat = "identity", aes(fill = layer)) +
+    xlab("") +
+    ylab(expression(paste("Prediction accuracy (", R^2, ")"))) +
+    scale_fill_discrete(name = "") +
+    theme_bw() +
+    guides(fill = guide_legend(title = "")) +
+    theme(legend.position = "right", legend.direction = "vertical",
+          strip.background = element_blank()) +
+    labs(fill = "")
+}
+
+# Combine the training panel (and the test panel, if any) with cowplot and add
+# the "A."/"B." panel labels. A lone training panel always fills the canvas;
+# with two panels, `rowwise_plot` selects stacked (TRUE) or side by side.
+.arrange_learner_plots <- function(p_train, p_test, folds, rowwise_plot,
+                                   label_size, label_x, vjust) {
+  cv_label <- paste0("A. ", folds, "-fold CV")
+  if (is.null(p_test)) {
+    p <- plot_grid(p_train, ncol = 1, labels = cv_label,
+                   label_size = label_size, label_x = label_x, vjust = vjust)
+  } else {
+    p <- plot_grid(p_train, p_test,
+                   nrow = if (rowwise_plot) 2 else 1,
+                   ncol = if (rowwise_plot) 1 else 2,
+                   labels = c(cv_label, "B. Independent Validation"),
+                   label_size = label_size, label_x = label_x, vjust = vjust)
+  }
+  p + theme(plot.margin = unit(c(1, 1, 1, 1), "cm"))
+}