Diff of /R/credintLearner.R [000000] .. [a4ee51]

Switch to unified view

a b/R/credintLearner.R
1
# This is a basic utility function to plot the credible intervals based on BART posterior samples.
2
# Depends on the library mcmcplots. 
3
4
credint.learner <- function(fit, 
5
                            test = FALSE,
6
                            ylab = NULL,
7
                            xlab = "Observations",
8
                            cex.main = 1,
9
                            font.main = 1,
10
                            cex.lab = 1,
11
                            cex.axis = 1,
12
                            style = "plain",
13
                            legend = c("Y=0", "Y=1"),...){
14
  
15
  ################################################
16
  # Extract required elements from the IL object #
17
  ################################################
18
  
19
  if(fit$base_learner =="SL.BART" & fit$meta_learner=="SL.nnls.auc"){
20
    weights <- fit$weights
21
    
22
    if(test==TRUE){
23
      if(fit$test==FALSE){stop("No test set information available as part of the fit object")}
24
      dataX <- fit$X_test_layers
25
      dataY <- fit$Y_test
26
      
27
    }else{
28
      dataX <- fit$X_train_layers
29
      dataY <- fit$Y_train  
30
    }
31
    
32
    #############################
33
    # Extract posterior samples #
34
    #############################
35
    
36
    post.samples <- vector("list", length(weights))
37
    names(post.samples) <- names(dataX)
38
    
39
    for(i in seq_along(post.samples)){
40
      post.samples[[i]] <- bart_machine_get_posterior(fit$model_fits$model_layers[[i]],dataX[[i]])$y_hat_posterior_samples
41
    }
42
    
43
    ##################################
44
    # Get weighted posterior samples #
45
    ##################################
46
    
47
    weighted.post.samples <-Reduce('+', Map('*', post.samples, weights))
48
    rownames(weighted.post.samples) <- rownames(dataX[[1]])
49
    names(dataY) <- rownames(dataX[[1]])
50
    
51
    ######################################
52
    # Credible interval plot (caterplot) #
53
    ######################################
54
    
55
    pdf(file = NULL)
56
    temp <- caterplot(t(weighted.post.samples),add = FALSE)
57
    dev.off()
58
    
59
    ######################################
60
    # Save the plot as ggplot and return #
61
    ######################################
62
    
63
    if(fit$family=="gaussian"){
64
      caterplot(t(weighted.post.samples),
65
                horizontal = FALSE,labels.loc="fhfh",style=style,...)
66
      points(dataY[temp])
67
      title(main ="", xlab = xlab, ylab = ylab,
68
            line = NA, outer = FALSE,cex.main=cex.main,font.main=font.main,cex.lab=cex.lab,cex.axis=cex.axis)
69
      
70
      p <- recordPlot()   
71
    }else if(fit$family=="binomial"){
72
      caterplot(t(weighted.post.samples),
73
                pch = ifelse(dataY[temp]==0,4,20),
74
                horizontal = FALSE,labels.loc="fhfh",style=style,
75
                labels = rep("",nrow(weighted.post.samples)), ...)
76
      title(main ="", xlab = xlab, ylab = ylab,
77
            line = NA, outer = FALSE,cex.main=cex.main,font.main=font.main,cex.lab=cex.lab,cex.axis=cex.axis)
78
      legend("bottomleft", legend=legend,
79
             pch=c(4, 20), cex=0.8)
80
      p <- recordPlot() 
81
    }
82
  }else{
83
    stop("Credible Interval feature is currently only available for 
84
         BART as base learner and NNLS/AUC as the meta learner")
85
  }
86
  
87
  return(p)
88
  
89
}