Diff of /R/VarImpLearner.R [000000] .. [a4ee51]

Switch to unified view

a b/R/VarImpLearner.R
1
# This is a basic utility function to plot the top 20 feasture importance scores based on BART posterior samples.
2
# Depends on the library bartMachine. 
3
4
VarImp.learner <- function(fit,
5
                           num.var = 20,
6
                           layer.names = NULL){
7
  
8
  ################################################
9
  # Extract required elements from the IL object #
10
  ################################################
11
12
  if(fit$meta_learner=="SL.nnls.auc"){
13
    VIMP_stack<- cbind.data.frame(fit$weights)
14
    colnames(VIMP_stack)<-c('mean')
15
    VIMP_stack$sd <- NA
16
    VIMP_stack$type<-'stack'
17
  }else{
18
    VIMP_stack <- NULL
19
  }
20
  
21
  if(fit$base_learner=="SL.BART"){
22
    if(is.null(layer.names)){
23
      layer.names <- names(fit$model_fits$model_layers)  
24
    }
25
    
26
    if(all(layer.names %in% names(fit$model_fits$model_layers))==FALSE){
27
      stop(paste(layer.names[!(layer.names %in% names(fit$model_fits$model_layers))],
28
                             "is not a valid layer in the fit object."))
29
    }
30
    
31
    
32
    #######################################################
33
    # Extract per-layer feature importance scores (VIMPs) #
34
    #######################################################
35
    
36
    VIMP_list <- list()
37
    for( i in 1:length(layer.names)){
38
      qq<-bartMachine::investigate_var_importance(
39
        fit$model_fits$model_layers[[layer.names[i]]],plot = FALSE)
40
      
41
      VIMP_layer<-cbind.data.frame(qq$avg_var_props, qq$sd_var_props)
42
      colnames(VIMP_layer)<-c('mean', 'sd')
43
      VIMP_layer$type<-layer.names[i]
44
      VIMP_list[[i]] <- VIMP_layer[1:num.var, ]
45
    }
46
    VIMP <- do.call(rbind,VIMP_list)
47
  }else{
48
    stop("This functionality is currently available only for BART base learner")
49
  }
50
  
51
  ###########################
52
  # Feature importance plot #
53
  ###########################
54
  
55
  if(!is.null(VIMP_stack)){ VIMP <- as.data.frame(rbind.data.frame(VIMP_stack,VIMP))}
56
    
57
  VIMP<-rownames_to_column(VIMP, 'ID')
58
  p<-VIMP %>% 
59
    filter(type %in% layer.names) %>% 
60
    arrange(mean) %>% 
61
    mutate(ID = str_replace_all(ID, fixed("_"), " ")) %>% 
62
    mutate(type = factor(type, 
63
                         levels = layer.names,
64
                         labels = layer.names)) %>% 
65
    ggplot(aes(reorder(ID, -mean), mean, fill = type)) +
66
    facet_wrap(.~ type, scale = 'free') + 
67
    geom_bar(stat = "identity", fill = "lightsalmon") + 
68
    geom_errorbar(aes(ymin=ifelse(mean-sd>0,mean-sd,0), ymax=mean+sd), width=.2, position=position_dodge(.9)) +
69
    theme_bw() + 
70
    coord_flip() + 
71
    omicsEye_theme() + 
72
    theme (strip.background = element_blank()) + 
73
    ylab('Inclusion proportion') + 
74
    xlab('') 
75
76
  return(p)  
77
}
78
79
# This is a ggplot theme from the Rahnavard lab at GWU.
80
omicsEye_theme <- function() {
81
  # set default text format based on categorical and length
82
  angle = 45
83
  hjust = 1
84
  size = 6
85
  return (ggplot2::theme_bw() + ggplot2::theme(
86
    axis.text.x = ggplot2::element_text(size = 8, vjust = 1, hjust = hjust, angle = angle),
87
    axis.text.y = ggplot2::element_text(size = 8, hjust = 1),
88
    axis.title = ggplot2::element_text(size = 10),
89
    plot.title = ggplot2::element_text(size = 10),
90
    plot.subtitle = ggplot2::element_text(size = 8),
91
    legend.title = ggplot2::element_text(size = 6, face = 'bold'),
92
    legend.text = ggplot2::element_text(size = 7),
93
    axis.line = ggplot2::element_line(colour = 'black', size = .25),
94
    ggplot2::element_line(colour = 'black', size = .25),
95
    axis.line.x = ggplot2::element_line(colour = 'black', size = .25),
96
    axis.line.y = ggplot2::element_line(colour = 'black', size = .25),
97
    panel.border = ggplot2::element_blank(),
98
    panel.grid.major = ggplot2::element_blank(),
99
    panel.grid.minor = ggplot2::element_blank())
100
  )
101
}