|
a |
|
b/R/VarImpLearner.R |
|
|
1 |
# This is a basic utility function to plot the top 20 feasture importance scores based on BART posterior samples. |
|
|
2 |
# Depends on the library bartMachine. |
|
|
3 |
|
|
|
4 |
VarImp.learner <- function(fit, |
|
|
5 |
num.var = 20, |
|
|
6 |
layer.names = NULL){ |
|
|
7 |
|
|
|
8 |
################################################ |
|
|
9 |
# Extract required elements from the IL object # |
|
|
10 |
################################################ |
|
|
11 |
|
|
|
12 |
if(fit$meta_learner=="SL.nnls.auc"){ |
|
|
13 |
VIMP_stack<- cbind.data.frame(fit$weights) |
|
|
14 |
colnames(VIMP_stack)<-c('mean') |
|
|
15 |
VIMP_stack$sd <- NA |
|
|
16 |
VIMP_stack$type<-'stack' |
|
|
17 |
}else{ |
|
|
18 |
VIMP_stack <- NULL |
|
|
19 |
} |
|
|
20 |
|
|
|
21 |
if(fit$base_learner=="SL.BART"){ |
|
|
22 |
if(is.null(layer.names)){ |
|
|
23 |
layer.names <- names(fit$model_fits$model_layers) |
|
|
24 |
} |
|
|
25 |
|
|
|
26 |
if(all(layer.names %in% names(fit$model_fits$model_layers))==FALSE){ |
|
|
27 |
stop(paste(layer.names[!(layer.names %in% names(fit$model_fits$model_layers))], |
|
|
28 |
"is not a valid layer in the fit object.")) |
|
|
29 |
} |
|
|
30 |
|
|
|
31 |
|
|
|
32 |
####################################################### |
|
|
33 |
# Extract per-layer feature importance scores (VIMPs) # |
|
|
34 |
####################################################### |
|
|
35 |
|
|
|
36 |
VIMP_list <- list() |
|
|
37 |
for( i in 1:length(layer.names)){ |
|
|
38 |
qq<-bartMachine::investigate_var_importance( |
|
|
39 |
fit$model_fits$model_layers[[layer.names[i]]],plot = FALSE) |
|
|
40 |
|
|
|
41 |
VIMP_layer<-cbind.data.frame(qq$avg_var_props, qq$sd_var_props) |
|
|
42 |
colnames(VIMP_layer)<-c('mean', 'sd') |
|
|
43 |
VIMP_layer$type<-layer.names[i] |
|
|
44 |
VIMP_list[[i]] <- VIMP_layer[1:num.var, ] |
|
|
45 |
} |
|
|
46 |
VIMP <- do.call(rbind,VIMP_list) |
|
|
47 |
}else{ |
|
|
48 |
stop("This functionality is currently available only for BART base learner") |
|
|
49 |
} |
|
|
50 |
|
|
|
51 |
########################### |
|
|
52 |
# Feature importance plot # |
|
|
53 |
########################### |
|
|
54 |
|
|
|
55 |
if(!is.null(VIMP_stack)){ VIMP <- as.data.frame(rbind.data.frame(VIMP_stack,VIMP))} |
|
|
56 |
|
|
|
57 |
VIMP<-rownames_to_column(VIMP, 'ID') |
|
|
58 |
p<-VIMP %>% |
|
|
59 |
filter(type %in% layer.names) %>% |
|
|
60 |
arrange(mean) %>% |
|
|
61 |
mutate(ID = str_replace_all(ID, fixed("_"), " ")) %>% |
|
|
62 |
mutate(type = factor(type, |
|
|
63 |
levels = layer.names, |
|
|
64 |
labels = layer.names)) %>% |
|
|
65 |
ggplot(aes(reorder(ID, -mean), mean, fill = type)) + |
|
|
66 |
facet_wrap(.~ type, scale = 'free') + |
|
|
67 |
geom_bar(stat = "identity", fill = "lightsalmon") + |
|
|
68 |
geom_errorbar(aes(ymin=ifelse(mean-sd>0,mean-sd,0), ymax=mean+sd), width=.2, position=position_dodge(.9)) + |
|
|
69 |
theme_bw() + |
|
|
70 |
coord_flip() + |
|
|
71 |
omicsEye_theme() + |
|
|
72 |
theme (strip.background = element_blank()) + |
|
|
73 |
ylab('Inclusion proportion') + |
|
|
74 |
xlab('') |
|
|
75 |
|
|
|
76 |
return(p) |
|
|
77 |
} |
|
|
78 |
|
|
|
79 |
# This is a ggplot theme from the Rahnavard lab at GWU. |
|
|
80 |
omicsEye_theme <- function() { |
|
|
81 |
# set default text format based on categorical and length |
|
|
82 |
angle = 45 |
|
|
83 |
hjust = 1 |
|
|
84 |
size = 6 |
|
|
85 |
return (ggplot2::theme_bw() + ggplot2::theme( |
|
|
86 |
axis.text.x = ggplot2::element_text(size = 8, vjust = 1, hjust = hjust, angle = angle), |
|
|
87 |
axis.text.y = ggplot2::element_text(size = 8, hjust = 1), |
|
|
88 |
axis.title = ggplot2::element_text(size = 10), |
|
|
89 |
plot.title = ggplot2::element_text(size = 10), |
|
|
90 |
plot.subtitle = ggplot2::element_text(size = 8), |
|
|
91 |
legend.title = ggplot2::element_text(size = 6, face = 'bold'), |
|
|
92 |
legend.text = ggplot2::element_text(size = 7), |
|
|
93 |
axis.line = ggplot2::element_line(colour = 'black', size = .25), |
|
|
94 |
ggplot2::element_line(colour = 'black', size = .25), |
|
|
95 |
axis.line.x = ggplot2::element_line(colour = 'black', size = .25), |
|
|
96 |
axis.line.y = ggplot2::element_line(colour = 'black', size = .25), |
|
|
97 |
panel.border = ggplot2::element_blank(), |
|
|
98 |
panel.grid.major = ggplot2::element_blank(), |
|
|
99 |
panel.grid.minor = ggplot2::element_blank()) |
|
|
100 |
) |
|
|
101 |
} |