Diff of /R/updateLearner.R [000000] .. [a4ee51]

Switch to unified view

a b/R/updateLearner.R
1
#' Update IntegratedLearner fit object based on layers available in the test set
2
#'
3
#' @description Allow update of IntegratedLearner if only a subset of omics layers are available in test set. If all layers and features match, it calls predict.learner() 
4
#'
5
#' @param fit fitted "IntegratedLearner" object 
6
#' 
7
#' 
8
#'
9
#' @param feature_table_valid Feature table from validation set. It should be a data frame with features in rows and samples in columns. Feature names should be a subset of training data feature names.
10
#' @param sample_metadata_valid OPTIONAL (can provide feature_table_valid and not this):  Sample-specific metadata table from independent validation set. If provided, it must have the exact same structure as sample_metadata. Default is NULL. 
11
#' @param feature_metadata_valid Matrix containing feature names and their corresponding layers. Must be subset of feature_metadata provided in IntegratedLearner object. 
12
#' @param seed Seed for reproducibility. Default is 1234.
13
#' @param verbose Should a summary of fits/ results be printed. Default is FALSE
14
#'
15
#' @return SL object
16
#' @export
17
update.learner <- function(fit,                                     
18
                           feature_table_valid, # Feature table from validation set. Must have the exact same structure as feature_table. If missing, uses feature_table for feature_table_valid.
19
                           sample_metadata_valid=NULL, # OPTIONAL (can provide feature_table_valid and not this):  Sample-specific metadata table from independent validation set. Must have the exact same structure as sample_metadata.
20
                           feature_metadata_valid,
21
                           seed = 1234, # Specify the arbitrary seed value for reproducibility. Default is 1234.
22
                           verbose=FALSE
23
){
24
  # Check that feature table and feature meta data valid is not empty here
25
  if(is.null(feature_table_valid | is.null(feature_metadata_valid))){
26
    stop("feature table/ feature metadata cannot be NULL for validation set in update learner")
27
  }
28
  
29
  if(fit$family=="gaussian"){
30
    family=gaussian()
31
  }else if(fit$family=="binomial"){
32
    family=binomial()
33
  }
34
  
35
  if (!is.null(sample_metadata_valid)){
36
    validY<-sample_metadata_valid['Y']
37
  }
38
  
39
  
40
  feature_metadata_valid$featureType<-as.factor(feature_metadata_valid$featureType)
41
  name_layers_valid<-with(droplevels(feature_metadata_valid), list(levels = levels(featureType)), nlevels = nlevels(featureType))$levels
42
  
43
  
44
  name_layers <- names(fit$model_fits$model_layers)
45
  
46
  # If layers in validation match layers in train
47
  # Just run predict function and return its object
48
  if(length(intersect(name_layers_valid,name_layers))==length(name_layers)){
49
    
50
    # Check if feature names are same for the train and test 
51
    
52
    return(predict.learner(fit, 
53
                           feature_table_valid = feature_table_valid,
54
                           sample_metadata_valid = sample_metadata_valid,
55
                           feature_metadata = feature_metadata_valid))
56
  }else if(length(intersect(name_layers_valid,name_layers))==0){
57
    
58
    stop("Validation set has no layers in common with model fit")
59
    
60
  }else{
61
    
62
    name_layers_common <- intersect(name_layers_valid,name_layers)
63
    
64
    
65
    
66
    # Extract only common name layers part of the fit object 
67
    fit$model_fits$model_layers <- fit$model_fits$model_layers[name_layers_common]
68
    fit$SL_fits$SL_fit_layers <-  fit$SL_fits$SL_fit_layers[name_layers_common]
69
    fit$X_train_layers <- fit$X_train_layers[name_layers_common]
70
    
71
    # Use common layers to get layer wise predictions for validation set
72
    X_test_layers <- vector("list", length(name_layers_common)) 
73
    names(X_test_layers) <- name_layers_common
74
    
75
    if (!is.null(feature_table_valid)){
76
      layer_wise_prediction_valid<-vector("list", length(name_layers_common))
77
      names(layer_wise_prediction_valid)<-name_layers_common
78
    } 
79
    
80
    
81
    for(i in seq_along(name_layers_common)){
82
      include_list<-feature_metadata_valid %>% filter(featureType == name_layers_common[i]) 
83
      
84
      # check if feature names in common layers match for train and test set 
85
      if(!all(include_list$featureID==colnames(fit$X_train_layers[name_layers_common[i]]))){
86
        stop(paste0("Validation set feature names for layer ", name_layers_common[i]," do not match with training data" ))
87
      }
88
      
89
      
90
      if (!is.null(feature_table_valid)){
91
        t_dat_slice_valid<-feature_table_valid[rownames(feature_table_valid) %in% include_list$featureID, ]
92
        dat_slice_valid<-as.data.frame(t(t_dat_slice_valid))
93
        X_test_layers[[i]] <- dat_slice_valid
94
        layer_wise_prediction_valid[[i]]<-predict.SuperLearner(fit$SL_fits$SL_fit_layers[[i]], newdata = dat_slice_valid)$pred
95
        rownames(layer_wise_prediction_valid[[i]])<-rownames(dat_slice_valid)
96
        fit$SL_fits$SL_fit_layers[[i]]$validX<-dat_slice_valid
97
        fit$SL_fits$SL_fit_layers[[i]]$validPrediction<-layer_wise_prediction_valid[[i]]
98
        colnames(fit$SL_fits$SL_fit_layers[[i]]$validPrediction)<-'validPrediction'
99
        rm(dat_slice_valid); rm(include_list)
100
      }
101
    }
102
    
103
    combo <- fit$yhat.train[ ,name_layers_common]
104
    
105
    if (!is.null(feature_table_valid)){
106
      combo_valid <- as.data.frame(do.call(cbind, layer_wise_prediction_valid))
107
      names(combo_valid)<-name_layers_valid
108
    }
109
    
110
    
111
    if(fit$run_stacked){
112
      
113
      cat('Running new stacked model...\n')
114
      #}
115
      
116
      ###################################
117
      # Run user-specified meta learner #
118
      ###################################
119
      
120
      SL_fit_stacked<-SuperLearner::SuperLearner(Y = fit$Y_train, 
121
                                                 X = combo, 
122
                                                 cvControl = fit$cvControl,    
123
                                                 verbose = verbose, 
124
                                                 SL.library = fit$meta_learner,
125
                                                 family=family)
126
      
127
      # Extract the fit object from superlearner
128
      model_stacked <- SL_fit_stacked$fitLibrary[[1]]$object
129
      
130
      ###################################################
131
      # Append the corresponding y and X to the results #
132
      ###################################################
133
      
134
      SL_fit_stacked$Y<-fit$Y_train
135
      SL_fit_stacked$X<-combo
136
      if (!is.null(sample_metadata_valid)) SL_fit_stacked$validY<-validY
137
      
138
      #################################################################
139
      # Prepate stacked input data for validation and save prediction #
140
      #################################################################
141
      
142
      if (!is.null(feature_table_valid)){
143
        stacked_prediction_valid<-predict.SuperLearner(SL_fit_stacked, newdata = combo_valid)$pred
144
        rownames(stacked_prediction_valid)<-rownames(combo_valid)
145
        SL_fit_stacked$validX<-combo_valid
146
        SL_fit_stacked$validPrediction<-stacked_prediction_valid
147
        colnames(SL_fit_stacked$validPrediction)<-'validPrediction'
148
      }
149
      
150
      fit$model_fits$model_stacked <- model_stacked
151
      fit$SL_fits$SL_fit_stacked <- SL_fit_stacked
152
      fit$yhat.train$stacked <- SL_fit_stacked$Z
153
      
154
      
155
    }
156
    
157
    
158
    if(fit$run_concat){
159
      #if (verbose) {
160
      cat('Running new concatenated model...\n')
161
      #}
162
      ###################################
163
      # Prepate concatenated input data #
164
      ###################################
165
      feature_table <-  Reduce(cbind.data.frame,fit$X_train_layers)
166
      feature_table <- feature_table[ ,feature_metadata_valid$featureID]
167
      fulldat<-as.data.frame(feature_table)
168
      
169
      ###################################
170
      # Run user-specified base learner #
171
      ###################################
172
      
173
      SL_fit_concat<-SuperLearner::SuperLearner(Y = fit$Y_train, 
174
                                                X = fulldat, 
175
                                                cvControl = fit$cvControl,    
176
                                                verbose = verbose, 
177
                                                SL.library = list(c(fit$base_learner,fit$base_screener)),
178
                                                family=family)
179
      
180
      # Extract the fit object from superlearner
181
      model_concat <- SL_fit_concat$fitLibrary[[1]]$object
182
      
183
      ###################################################
184
      # Append the corresponding y and X to the results #
185
      ###################################################
186
      
187
      SL_fit_concat$Y<-fit$Y_train
188
      SL_fit_concat$X<-fulldat
189
      if (!is.null(sample_metadata_valid)) SL_fit_concat$validY<-validY
190
      
191
      #########################################################################
192
      # Prepate concatenated input data for validaton set and save prediction #
193
      #########################################################################
194
      
195
      if (!is.null(feature_table_valid)){
196
        fulldat_valid<-as.data.frame(t(feature_table_valid))
197
        concat_prediction_valid<-predict.SuperLearner(SL_fit_concat, newdata = fulldat_valid)$pred
198
        SL_fit_concat$validX<-fulldat_valid
199
        rownames(concat_prediction_valid)<-rownames(fulldat_valid)
200
        SL_fit_concat$validPrediction<-concat_prediction_valid
201
        colnames(SL_fit_concat$validPrediction)<-'validPrediction'
202
      }
203
      
204
      fit$model_fits$model_concat <- model_concat
205
      fit$SL_fits$SL_fit_concat <- SL_fit_concat
206
      fit$yhat.train$concatenated <- SL_fit_concat$Z
207
    }
208
    
209
    
210
    if(fit$run_concat & fit$run_stacked){
211
      fit$yhat.train <- fit$yhat.train[ ,c(name_layers_common,"stacked","concatenated")]
212
      
213
    }else if(fit$run_concat & !fit$run_stacked){
214
      fit$yhat.train <- fit$yhat.train[ ,c(name_layers_common,"concatenated")]
215
      
216
    }else if(!fit$run_concat & fit$run_stacked){
217
      fit$yhat.train <- fit$yhat.train[ ,c(name_layers_common,"stacked")]
218
      
219
    }else if(!fit$run_concat & !fit$run_stacked){
220
      fit$yhat.train <- fit$yhat.train[ ,name_layers_common]
221
      
222
    }
223
    
224
    
225
    if(!is.null(feature_table_valid)){
226
      
227
      if(fit$run_concat & fit$run_stacked){
228
        yhat.test <- cbind(combo_valid, SL_fit_stacked$validPrediction,SL_fit_concat$validPrediction)
229
        colnames(yhat.test) <- c(colnames(combo_valid),"stacked","concatenated")
230
        
231
      }else if(fit$run_concat & !fit$run_stacked){
232
        yhat.test <- cbind(combo_valid, SL_fit_concat$validPrediction)
233
        colnames(yhat.test) <- c(colnames(combo_valid),"concatenated")
234
        
235
      }else if(!fit$run_concat & fit$run_stacked){
236
        yhat.test <- cbind(combo_valid, SL_fit_stacked$validPrediction)
237
        colnames(yhat.test) <- c(colnames(combo_valid),"stacked")
238
        
239
      }else if(!fit$run_concat & !fit$run_stacked){
240
        yhat.test <- cbind(combo_valid)
241
        colnames(yhat.test) <- c(colnames(combo_valid))
242
        
243
      }
244
      fit$yhat.test <- yhat.test
245
      fit$X_test_layers <- X_test_layers
246
    }
247
    if(is.null(sample_metadata_valid)){
248
      fit$test=FALSE
249
    }else{
250
      fit$test=TRUE
251
    }
252
    if(fit$meta_learner=="SL.nnls.auc" & fit$run_stacked){
253
      fit$weights <- fit$model_fits$model_stacked$solution
254
      names(fit$weights) <- colnames(combo)
255
    }
256
    
257
    if(!is.null(sample_metadata_valid)){fit$Y_test=validY$Y}
258
    
259
    if(fit$family=="binomial"){
260
      # Calculate AUC for each layer, stacked and concatenated 
261
      pred=apply(fit$yhat.train, 2, ROCR::prediction, labels=fit$Y_train)
262
      AUC=vector(length = length(pred))
263
      names(AUC)=names(pred)
264
      for(i in seq_along(pred)){
265
        AUC[i] = round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3)
266
      }
267
      fit$AUC.train <- AUC
268
      
269
      if(fit$test==TRUE){
270
        # Calculate AUC for each layer, stacked and concatenated 
271
        pred=apply(fit$yhat.test, 2, ROCR::prediction, labels=fit$Y_test)
272
        AUC=vector(length = length(pred))
273
        names(AUC)=names(pred)
274
        for(i in seq_along(pred)){
275
          AUC[i] = round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3)
276
        }
277
        fit$AUC.test <- AUC  
278
      }
279
    }
280
    if(fit$family=="gaussian"){
281
      
282
      # Calculate R^2 for each layer, stacked and concatenated 
283
      R2=vector(length = ncol(fit$yhat.train))
284
      names(R2)=names(fit$yhat.train)
285
      for(i in seq_along(R2)){
286
        R2[i] = as.vector(cor(fit$yhat.train[ ,i], fit$Y_train)^2)
287
      }
288
      fit$R2.train <- R2
289
      if(fit$test==TRUE){
290
        # Calculate R^2 for each layer, stacked and concatenated 
291
        R2=vector(length = ncol(fit$yhat.test))
292
        names(R2)=names(fit$yhat.test)
293
        for(i in seq_along(R2)){
294
          R2[i] = as.vector(cor(fit$yhat.test[ ,i], fit$Y_test)^2)
295
        }
296
        fit$R2.test <- R2
297
      }
298
      
299
    }  
300
    fit$feature.names <- rownames(feature_table_valid)
301
    print.learner(fit)
302
    return(fit)
303
  }
304
}