Diff of /R/predictLearner.R [000000] .. [a4ee51]

Switch to unified view

a b/R/predictLearner.R
1
#' Make predictions using a trained 'IntegratedLearner' model
2
#'
3
#'@description This function makes predictions using a trained 'IntegratedLearner' model for new samples for which predictions are to be made
4
#'
5
#' @param fit fitted "IntegratedLearner" object 
6
#' @param feature_table_valid Feature table from validation set. Must have the exact same structure as feature_table.
7
#' @param sample_metadata_valid OPTIONAL (can provide feature_table_valid and not this):  Sample-specific metadata table from independent validation set. If provided, it must have the exact same structure as sample_metadata.
8
#' @param feature_metadata Matrix containing feature names and their corresponding layers. Must be same as that provided in IntegratedLearner object. 
9
#'
10
#' @return Predicted values
11
#' @export
12
predict.learner <- function(fit,
13
                            feature_table_valid = NULL, # Feature table from validation set. Must have the exact same structure as feature_table. If missing, uses feature_table for feature_table_valid.
14
                            sample_metadata_valid = NULL, # Optional: Sample-specific metadata table from independent validation set. Must have the exact same structure as sample_metadata.
15
                            feature_metadata=NULL){
16
  
17
  if(all(fit$feature.names==rownames(feature_metadata))==FALSE){
18
    stop("Both training feature_table and feature_metadata should have the same rownames.")
19
  }
20
  
21
  
22
  if(is.null(feature_table_valid)){
23
    stop("Feature table for validation set cannot be empty")
24
  } 
25
  # if(is.null(sample_metadata_valid)){
26
  #   stop("Sample metadata for validation set cannot be empty")
27
  # }
28
  
29
  if (!is.null(feature_table_valid)){
30
    if(all(fit$feature.names==rownames(feature_table_valid))==FALSE)
31
      stop("Both feature_table and feature_table_valid should have the same rownames.")
32
  }
33
  
34
  if (!is.null(sample_metadata_valid)){
35
    if(all(colnames(feature_table_valid)==rownames(sample_metadata_valid))==FALSE)
36
      stop("Row names of sample_metadata_valid must match the column names of feature_table_valid")
37
  }
38
  
39
  
40
  
41
  if (!'featureID' %in% colnames(feature_metadata)){
42
    stop("feature_metadata must have a column named 'featureID' describing per-feature unique identifiers.")
43
  }
44
  
45
  if (!'featureType' %in% colnames(feature_metadata)){
46
    stop("feature_metadata must have a column named 'featureType' describing the corresponding source layers.")
47
  }
48
  
49
  if (!is.null(sample_metadata_valid)){
50
    if (!'subjectID' %in% colnames(sample_metadata_valid)){
51
      stop("sample_metadata_valid must have a column named 'subjectID' describing per-subject unique identifiers.")
52
    }
53
    
54
    if (!'Y' %in% colnames(sample_metadata_valid)){
55
      stop("sample_metadata_valid must have a column named 'Y' describing the outcome of interest.")
56
    }
57
  }
58
  
59
  #############################################################################################
60
  # Extract validation Y right away (will not be used anywhere during the validation process) #
61
  #############################################################################################
62
  
63
  if (!is.null(sample_metadata_valid)){validY<-sample_metadata_valid['Y']}
64
  
65
  #####################################################################
66
  # Stacked generalization input data preparation for validation data #
67
  #####################################################################
68
  feature_metadata$featureType<-as.factor(feature_metadata$featureType)
69
  name_layers<-with(droplevels(feature_metadata), list(levels = levels(featureType)), 
70
                    nlevels = nlevels(featureType))$levels
71
  
72
  X_test_layers <- vector("list", length(name_layers)) 
73
  names(X_test_layers) <- name_layers
74
  
75
  layer_wise_prediction_valid<-vector("list", length(name_layers))
76
  names(layer_wise_prediction_valid)<-name_layers
77
  
78
  for(i in seq_along(name_layers)){
79
    
80
    ############################################################
81
    # Prepare single-omic validation data and save predictions #
82
    ############################################################
83
    include_list<-feature_metadata %>% filter(featureType == name_layers[i]) 
84
    t_dat_slice_valid<-feature_table_valid[rownames(feature_table_valid) %in% include_list$featureID, ]
85
    dat_slice_valid<-as.data.frame(t(t_dat_slice_valid))
86
    X_test_layers[[i]] <- dat_slice_valid
87
    layer_wise_prediction_valid[[i]]<-predict.SuperLearner(fit$SL_fits$SL_fit_layers[[i]], newdata = dat_slice_valid)$pred
88
    rownames(layer_wise_prediction_valid[[i]])<-rownames(dat_slice_valid)
89
    rm(dat_slice_valid); rm(include_list)
90
  }
91
  
92
  combo_valid <- as.data.frame(do.call(cbind, layer_wise_prediction_valid))
93
  names(combo_valid)<-name_layers
94
  
95
  if(fit$run_stacked==TRUE){
96
    stacked_prediction_valid<-predict.SuperLearner(fit$SL_fits$SL_fit_stacked, newdata = combo_valid)$pred
97
    rownames(stacked_prediction_valid)<-rownames(combo_valid)  
98
  }
99
  if(fit$run_concat==TRUE){
100
    fulldat_valid<-as.data.frame(t(feature_table_valid))
101
    concat_prediction_valid<-predict.SuperLearner(fit$SL_fits$SL_fit_concat, 
102
                                                  newdata = fulldat_valid)$pred
103
    rownames(concat_prediction_valid)<-rownames(fulldat_valid)
104
  }
105
  
106
  res=list()
107
  
108
  if (!is.null(sample_metadata_valid)){
109
    Y_test=validY$Y
110
    res$Y_test =Y_test
111
  }
112
  
113
  if(fit$run_concat & fit$run_stacked){
114
    yhat.test <- cbind(combo_valid, stacked_prediction_valid , concat_prediction_valid)
115
    colnames(yhat.test) <- c(colnames(combo_valid),"stacked","concatenated")  
116
  }else if(fit$run_concat & !fit$run_stacked){
117
    yhat.test <- cbind(combo_valid,  concat_prediction_valid)
118
    colnames(yhat.test) <- c(colnames(combo_valid),"concatenated")  
119
  }else if(!fit$run_concat & fit$run_stacked){
120
    yhat.test <- cbind(combo_valid, stacked_prediction_valid )
121
    colnames(yhat.test) <- c(colnames(combo_valid),"stacked")  
122
  }else{
123
    yhat.test <- combo_valid   
124
  }
125
  
126
  res$yhat.test <- yhat.test
127
  if (!is.null(sample_metadata_valid)){
128
    if(fit$family=='binomial'){
129
      # Calculate AUC for each layer, stacked and concatenated 
130
      pred=apply(res$yhat.test, 2, ROCR::prediction, labels=res$Y_test)
131
      AUC=vector(length = length(pred))
132
      names(AUC)=names(pred)
133
      for(i in seq_along(pred)){
134
        AUC[i] = round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3)
135
      }
136
      res$AUC.test <- AUC  
137
      
138
    }
139
    
140
    if(fit$family=='gaussian'){
141
      # Calculate R^2 for each layer, stacked and concatenated 
142
      R2=vector(length = ncol(res$yhat.test))
143
      names(R2)=names(res$yhat.test)
144
      for(i in seq_along(R2)){
145
        R2[i] = as.vector(cor(res$yhat.test[ ,i], res$Y_test)^2)
146
      }
147
      res$R2.test <- R2
148
    }
149
  }
150
  
151
  return(res)
152
  
153
}
154