Diff of /R/VarImpLearner.R [000000] .. [a4ee51]

Switch to side-by-side view

--- a
+++ b/R/VarImpLearner.R
@@ -0,0 +1,101 @@
+# This is a basic utility function to plot the top 20 feasture importance scores based on BART posterior samples.
+# Depends on the library bartMachine. 
+
+VarImp.learner <- function(fit,
+                           num.var = 20,
+                           layer.names = NULL){
+  
+  ################################################
+  # Extract required elements from the IL object #
+  ################################################
+
+  if(fit$meta_learner=="SL.nnls.auc"){
+    VIMP_stack<- cbind.data.frame(fit$weights)
+    colnames(VIMP_stack)<-c('mean')
+    VIMP_stack$sd <- NA
+    VIMP_stack$type<-'stack'
+  }else{
+    VIMP_stack <- NULL
+  }
+  
+  if(fit$base_learner=="SL.BART"){
+    if(is.null(layer.names)){
+      layer.names <- names(fit$model_fits$model_layers)  
+    }
+    
+    if(all(layer.names %in% names(fit$model_fits$model_layers))==FALSE){
+      stop(paste(layer.names[!(layer.names %in% names(fit$model_fits$model_layers))],
+                             "is not a valid layer in the fit object."))
+    }
+    
+    
+    #######################################################
+    # Extract per-layer feature importance scores (VIMPs) #
+    #######################################################
+    
+    VIMP_list <- list()
+    for( i in 1:length(layer.names)){
+      qq<-bartMachine::investigate_var_importance(
+        fit$model_fits$model_layers[[layer.names[i]]],plot = FALSE)
+      
+      VIMP_layer<-cbind.data.frame(qq$avg_var_props, qq$sd_var_props)
+      colnames(VIMP_layer)<-c('mean', 'sd')
+      VIMP_layer$type<-layer.names[i]
+      VIMP_list[[i]] <- VIMP_layer[1:num.var, ]
+    }
+    VIMP <- do.call(rbind,VIMP_list)
+  }else{
+    stop("This functionality is currently available only for BART base learner")
+  }
+  
+  ###########################
+  # Feature importance plot #
+  ###########################
+  
+  if(!is.null(VIMP_stack)){ VIMP <- as.data.frame(rbind.data.frame(VIMP_stack,VIMP))}
+    
+  VIMP<-rownames_to_column(VIMP, 'ID')
+  p<-VIMP %>% 
+    filter(type %in% layer.names) %>% 
+    arrange(mean) %>% 
+    mutate(ID = str_replace_all(ID, fixed("_"), " ")) %>% 
+    mutate(type = factor(type, 
+                         levels = layer.names,
+                         labels = layer.names)) %>% 
+    ggplot(aes(reorder(ID, -mean), mean, fill = type)) +
+    facet_wrap(.~ type, scale = 'free') + 
+    geom_bar(stat = "identity", fill = "lightsalmon") + 
+    geom_errorbar(aes(ymin=ifelse(mean-sd>0,mean-sd,0), ymax=mean+sd), width=.2, position=position_dodge(.9)) +
+    theme_bw() + 
+    coord_flip() + 
+    omicsEye_theme() + 
+    theme (strip.background = element_blank()) + 
+    ylab('Inclusion proportion') + 
+    xlab('') 
+
+  return(p)  
+}
+
+# This is a ggplot theme from the Rahnavard lab at GWU.
+omicsEye_theme <- function() {
+  # set default text format based on categorical and length
+  angle = 45
+  hjust = 1
+  size = 6
+  return (ggplot2::theme_bw() + ggplot2::theme(
+    axis.text.x = ggplot2::element_text(size = 8, vjust = 1, hjust = hjust, angle = angle),
+    axis.text.y = ggplot2::element_text(size = 8, hjust = 1),
+    axis.title = ggplot2::element_text(size = 10),
+    plot.title = ggplot2::element_text(size = 10),
+    plot.subtitle = ggplot2::element_text(size = 8),
+    legend.title = ggplot2::element_text(size = 6, face = 'bold'),
+    legend.text = ggplot2::element_text(size = 7),
+    axis.line = ggplot2::element_line(colour = 'black', size = .25),
+    ggplot2::element_line(colour = 'black', size = .25),
+    axis.line.x = ggplot2::element_line(colour = 'black', size = .25),
+    axis.line.y = ggplot2::element_line(colour = 'black', size = .25),
+    panel.border = ggplot2::element_blank(),
+    panel.grid.major = ggplot2::element_blank(),
+    panel.grid.minor = ggplot2::element_blank())
+  )
+}