[a4ee51]: / R / credintLearner.R

Download this file

89 lines (72 with data), 3.2 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Plot credible intervals from the weighted BART posterior samples of a fit.
# Depends on the library mcmcplots (caterplot); bart_machine_get_posterior is
# presumably provided by bartMachine -- confirm against the package imports.
#
# Arguments:
#   fit       Fitted object. Fields read here: base_learner, meta_learner,
#             family, weights, test, X_train_layers / Y_train,
#             X_test_layers / Y_test, and model_fits$model_layers.
#   test      If TRUE, use the test-set layers and outcome stored in `fit`;
#             otherwise use the training-set ones.
#   ylab, xlab, cex.main, font.main, cex.lab, cex.axis
#             Forwarded to title().
#   style     Forwarded to caterplot().
#   legend    Legend labels for the binomial plot (pch 4 and pch 20, in that
#             order).
#   ...       Forwarded to the visible caterplot() call.
#
# Returns the plot captured with recordPlot().
#
# Errors if the fit is not SL.BART + SL.nnls.auc, if the family is neither
# "gaussian" nor "binomial", if test = TRUE but the fit has no test data, or
# if the number of weights does not match the number of layers.
credint.learner <- function(fit,
                            test = FALSE,
                            ylab = NULL,
                            xlab = "Observations",
                            cex.main = 1,
                            font.main = 1,
                            cex.lab = 1,
                            cex.axis = 1,
                            style = "plain",
                            legend = c("Y=0", "Y=1"), ...) {
  ####################################
  # Validate the fit before any work #
  ####################################
  # isTRUE() keeps a missing (NULL) field from crashing `if` with
  # "argument is of length zero"; it falls through to the message below.
  supported <- isTRUE(fit$base_learner == "SL.BART") &&
    isTRUE(fit$meta_learner == "SL.nnls.auc")
  if (!supported) {
    stop("Credible Interval feature is currently only available for\nBART as base learner and NNLS/AUC as the meta learner")
  }

  # Fail fast: only these two families have a plotting branch below, so
  # reject anything else before the posterior extraction.
  family <- fit$family
  if (length(family) != 1L || !(family %in% c("gaussian", "binomial"))) {
    stop("Credible interval plots are only available for the ",
         "\"gaussian\" and \"binomial\" families")
  }

  ################################################
  # Extract required elements from the IL object #
  ################################################
  if (isTRUE(as.logical(test))) {
    if (!isTRUE(as.logical(fit$test))) {
      stop("No test set information available as part of the fit object")
    }
    dataX <- fit$X_test_layers
    dataY <- fit$Y_test
  } else {
    dataX <- fit$X_train_layers
    dataY <- fit$Y_train
  }

  weights <- fit$weights
  if (length(weights) == 0L || length(weights) != length(dataX)) {
    stop("The number of weights (", length(weights),
         ") must be positive and equal to the number of layers (",
         length(dataX), ")")
  }

  #############################
  # Extract posterior samples #
  #############################
  # One posterior-sample object per layer, taken from that layer's model.
  post.samples <- lapply(seq_along(weights), function(i) {
    bart_machine_get_posterior(
      fit$model_fits$model_layers[[i]], dataX[[i]]
    )$y_hat_posterior_samples
  })
  names(post.samples) <- names(dataX)

  ##################################
  # Get weighted posterior samples #
  ##################################
  # Scale each layer's samples by its weight, then sum across layers.
  weighted.post.samples <- Reduce(`+`, Map(`*`, post.samples, weights))
  obs.names <- rownames(dataX[[1]])
  rownames(weighted.post.samples) <- obs.names
  names(dataY) <- obs.names

  ############################################################
  # Dry run on a null device to learn caterplot's plot order #
  ############################################################
  # The value returned by caterplot is used to index dataY below (per the
  # mcmcplots documentation it is the parameter names in plotted order).
  # NOTE(review): the dry run does not receive `...`, so an ordering-related
  # argument passed through `...` could make the two calls disagree.
  prev.dev <- grDevices::dev.cur()
  grDevices::pdf(file = NULL)
  plot.order <- tryCatch(
    caterplot(t(weighted.post.samples), add = FALSE),
    finally = {
      # Always close the scratch device, even if caterplot fails, and hand
      # control back to the device that was active before the dry run.
      grDevices::dev.off()
      if (prev.dev > 1L) {
        grDevices::dev.set(prev.dev)
      }
    }
  )

  ######################################
  # Credible interval plot (caterplot) #
  ######################################
  if (family == "gaussian") {
    caterplot(t(weighted.post.samples),
              horizontal = FALSE, labels.loc = "fhfh", style = style, ...)
    # Overlay the observed outcomes in the same order as the intervals.
    graphics::points(dataY[plot.order])
  } else {
    # Binomial: the point symbol encodes the observed class
    # (4 when the outcome is 0, 20 otherwise); labels are blanked out.
    caterplot(t(weighted.post.samples),
              pch = ifelse(dataY[plot.order] == 0, 4, 20),
              horizontal = FALSE, labels.loc = "fhfh", style = style,
              labels = rep("", nrow(weighted.post.samples)), ...)
  }
  graphics::title(main = "", xlab = xlab, ylab = ylab,
                  line = NA, outer = FALSE,
                  cex.main = cex.main, font.main = font.main,
                  cex.lab = cex.lab, cex.axis = cex.axis)
  if (family == "binomial") {
    # Namespaced call: the `legend` argument shadows the function name.
    graphics::legend("bottomleft", legend = legend,
                     pch = c(4, 20), cex = 0.8)
  }

  ###########################################
  # Record the finished plot and return it  #
  ###########################################
  p <- grDevices::recordPlot()
  p
}