Diff of /R/plotLearner.R [000000] .. [a4ee51]

Switch to unified view

a b/R/plotLearner.R
1
#' Plot the summary curves produced by IntegratedLearner object
#'
#' @description Plots the R^2/AUC summaries for the training (and test, if
#'   provided) set produced by an IntegratedLearner object. For
#'   `family == "binomial"` one ROC curve per layer is drawn; for
#'   `family == "gaussian"` one bar per layer shows the squared correlation
#'   between predictions and outcome.
#'
#' @param fit fitted "IntegratedLearner" object. The fields read here are
#'   `family`, `base_learner`, `meta_learner`, `folds`, `test`, `yhat.train`,
#'   `Y_train` and, when `test` is `TRUE`, `yhat.test` and `Y_test`.
#' @param label_size (optional) Numerical value indicating the label size.
#'   Default is 8.
#' @param label_x (optional) Single value or vector of x positions for plot
#'   labels, relative to each subplot. Defaults to 0.3 for all labels.
#' @param vjust Adjusts the vertical position of each label. More positive
#'   values move the label further down on the plot canvas. Can be a single
#'   value (applied to all labels) or a vector of values (one for each label).
#'   Default is 0.1.
#' @param rowwise_plot If both train and test data are available, should the
#'   train and test panels be stacked row-wise (one above the other)? Default
#'   is TRUE. If FALSE, the panels are placed side by side (column-wise).
#'   Ignored when only training data is available.
#'
#' @return A list. Element `plot` holds the combined figure (which is also
#'   printed as a side effect). For binomial fits the list additionally holds
#'   `ROC_table` (and `ROC_table_valid` when test data is available); for
#'   gaussian fits it holds `R2_table` (and `R2_table_valid`).
#' @export
plot.learner <- function(fit, label_size = 8, label_x = 0.3, vjust = 0.1,
                         rowwise_plot = TRUE) {
  # Validate the family up front so an unsupported fit fails loudly instead of
  # silently returning NULL after doing nothing.
  family <- fit$family
  supported <- length(family) == 1L && family %in% c("binomial", "gaussian")
  if (!supported) {
    stop("Unsupported `fit$family`: expected \"binomial\" or \"gaussian\".",
         call. = FALSE)
  }

  # Strip the literal "SL." prefix; fixed = TRUE keeps "." from acting as a
  # regex wildcard.
  clean_base_learner <- gsub("SL.", "", fit$base_learner, fixed = TRUE)
  clean_meta_learner <- gsub("SL.", "", fit$meta_learner, fixed = TRUE)
  method <- paste(clean_base_learner, clean_meta_learner, sep = " + ")

  if (family == "binomial") {
    build_table <- .learner_roc_table
    build_panel <- .learner_roc_panel
    table_name <- "ROC_table"
  } else {
    build_table <- .learner_r2_table
    build_panel <- .learner_r2_panel
    table_name <- "R2_table"
  }

  # Training (cross-validated) panel is always present.
  train_table <- build_table(fit$yhat.train, fit$Y_train, method)
  panels <- list(build_panel(train_table))
  tables <- stats::setNames(list(train_table), table_name)

  # isTRUE() treats a missing or NA flag as "no test data" instead of erroring.
  has_test <- isTRUE(fit$test)
  if (has_test) {
    test_table <- build_table(fit$yhat.test, fit$Y_test, method)
    panels[[2]] <- build_panel(test_table)
    tables <- c(
      tables,
      stats::setNames(list(test_table), paste0(table_name, "_valid"))
    )
  }

  # Grid shape: a single panel fills the canvas; two panels follow
  # rowwise_plot (stacked when TRUE, side by side when FALSE).
  if (!has_test) {
    n_row <- 1
    n_col <- 1
  } else if (rowwise_plot) {
    n_row <- 2
    n_col <- 1
  } else {
    n_row <- 1
    n_col <- 2
  }

  panel_labels <- paste0("A. ", fit$folds, "-fold CV")
  if (has_test) {
    panel_labels <- c(panel_labels, "B. Independent Validation")
  }

  p <- plot_grid(plotlist = panels,
                 nrow = n_row,
                 ncol = n_col,
                 labels = panel_labels,
                 label_size = label_size, label_x = label_x, vjust = vjust) +
    theme(plot.margin = unit(c(1, 1, 1, 1), "cm"))
  print(p)

  c(list(plot = p), tables)
}

# Build the long-format ROC table: one block of rows per column (layer) of
# `yhat`, scored against outcome `y`. Note that the `specificity` column stores
# 1 - specificity (i.e. the false positive rate), which is what gets plotted on
# the x axis.
.learner_roc_table <- function(yhat, y, method) {
  layers <- colnames(yhat)
  per_layer <- lapply(seq_len(ncol(yhat)), function(k) {
    pred <- ROCR::prediction(yhat[, k], y)
    auc <- round(ROCR::performance(pred, "auc")@y.values[[1]], 2)
    perf <- ROCR::performance(pred, "sens", "spec")
    data.frame(sensitivity = methods::slot(perf, "y.values")[[1]],
               specificity = 1 - methods::slot(perf, "x.values")[[1]],
               AUC = auc,
               layer = layers[k],
               method = method)
  })
  names(per_layer) <- layers
  do.call("rbind", per_layer)
}

# Draw one ROC curve per layer; the legend entry of each curve shows its AUC.
.learner_roc_panel <- function(roc_table) {
  plot_data <- roc_table
  plot_data$displayItem <- paste0(plot_data$layer, " AUC = ", plot_data$AUC)
  # Fix the level order so the legend follows the order of the layers.
  plot_data$displayItem <- factor(plot_data$displayItem,
                                  levels = unique(plot_data$displayItem))

  ggplot(plot_data,
         aes(x = specificity, y = sensitivity, group = displayItem)) +
    geom_line(aes(color = displayItem)) +
    # theme_bw() is a complete theme, so legend tweaks must come after it.
    theme_bw() +
    xlab("False Positive Rate") +
    ylab("True Positive Rate") +
    theme(legend.position = "right", legend.direction = "vertical") +
    labs(color = "")
}

# Build the R^2 table: one row per column (layer) of `yhat`, where R2 is the
# squared correlation between that layer's predictions and outcome `y`.
.learner_r2_table <- function(yhat, y, method) {
  layers <- colnames(yhat)
  per_layer <- lapply(seq_len(ncol(yhat)), function(k) {
    data.frame(R2 = as.vector(stats::cor(yhat[, k], y)^2),
               layer = layers[k],
               method = method)
  })
  names(per_layer) <- layers
  do.call("rbind", per_layer)
}

# Draw a dodged bar chart of R^2 with one bar per layer.
.learner_r2_panel <- function(r2_table) {
  ggplot(r2_table, aes(x = method, y = R2)) +
    geom_bar(aes(fill = layer), position = "dodge", stat = "identity") +
    xlab("") +
    ylab(expression(paste("Prediction accuracy (", R^2, ")"))) +
    scale_fill_discrete(name = "") +
    # theme_bw() is a complete theme, so legend tweaks must come after it.
    theme_bw() +
    theme(legend.position = "right", legend.direction = "vertical",
          strip.background = element_blank())
}