Diff of /R/IntegratedLearner.R [000000] .. [a4ee51]

Switch to unified view

a b/R/IntegratedLearner.R
1
#' Integrated machine learning for multi-omics prediction and classification
#'
#' Performs integrated machine learning to predict a binary or continuous outcome based on two or more omics layers (views). 
#' The \code{IntegratedLearner} function takes a training set (Y, X1, X2,...,Xn) and returns the predicted values based on a validation set.
#' It also performs V-fold nested cross-validation to estimate the prediction accuracy of various fusion algorithms. 
#' Three types of integration paradigms are supported: early, late, and intermediate. 
#' The software includes multiple ML models based on the \code{\link[SuperLearner]{SuperLearner}} R package as well as several data exploration capabilities and visualization modules in a unified estimation framework.
#' @param feature_table An R data frame containing multiview features (in rows) and samples (in columns). 
#' Column names of \code{feature_metadata} must match the row names of \code{sample_metadata}.
#' @param sample_metadata An R data frame of metadata variables (in columns). 
#' Must have a column named \code{subjectID} describing per-subject unique identifiers. 
#' For longitudinal designs, this variable is expected to have non-unique values. 
#' Additionally, a column named \code{Y} must be present which is the outcome of interest (can be binary or continuous). 
#' Row names of \code{sample_metadata} must match the column names of \code{feature_table}.
#' @param feature_metadata An R data frame of feature-specific metadata across views (in columns) and features (in rows).
#' Must have a column named \code{featureID} describing per-feature unique identifiers. 
#' Additionally, a column named \code{featureType} should describe the corresponding source layers.
#' Row names of \code{feature_metadata} must match the row names of \code{feature_table}.
#' @param feature_table_valid Feature table from validation set for which prediction is desired. 
#' Must have the exact same structure as \code{feature_table}. If missing, uses \code{feature_table} for \code{feature_table_valid}.
#' @param sample_metadata_valid Sample-specific metadata table from independent validation set when available. 
#' Must have the exact same structure as \code{sample_metadata}. 
#' @param folds How many folds in the V-fold nested cross-validation? Default is 5.
#' @param seed Specify the arbitrary seed value for reproducibility. Default is 1234.
#' @param base_learner Base learner for late fusion and early fusion. 
#' Check out the \href{https://cran.r-project.org/web/packages/SuperLearner/vignettes/Guide-to-SuperLearner.html}{SuperLearner user manual} for all available options. Default is \code{`SL.BART`}.
#' @param base_screener Whether to screen variables before fitting base models? \code{All} means no screening which is the default.
#' Check out the \href{https://cran.r-project.org/web/packages/SuperLearner/vignettes/Guide-to-SuperLearner.html}{SuperLearner user manual} for all available options. 
#' @param meta_learner Meta-learner for late fusion (stacked generalization). Defaults to \code{`SL.nnls.auc`}.
#' Check out the \href{https://cran.r-project.org/web/packages/SuperLearner/vignettes/Guide-to-SuperLearner.html}{SuperLearner user manual} for all available options. 
#' @param run_concat Should early fusion be run? Default is TRUE. Uses the specified \code{base_learner} as the learning algorithm.
#' @param run_stacked Should stacked model (late fusion) be run? Default is TRUE.
#' @param verbose logical; TRUE for \code{SuperLearner} printing progress (helpful for debugging). Default is FALSE.
#' @param print_learner logical; Should a detailed summary be printed? Default is TRUE.
#' @param refit.stack logical; For late fusion, post-refit predictions on the entire data is returned if specified. Default is FALSE.
#' @param family Currently allows \code{`gaussian()`} for continuous or \code{`binomial()`} for binary outcomes.
#' @param ... Additional arguments. Not used currently.
#' 
#' @return A \code{SuperLearner} object containing the trained model fits.
#'
#' @author Himel Mallick, \email{him4004@@med.cornell.edu}
#' 
#' @keywords microbiome, metagenomics, multiomics, scRNASeq, tweedie, singlecell
#' @export
IntegratedLearner <- function(feature_table,
                              sample_metadata, 
                              feature_metadata,
                              feature_table_valid = NULL, 
                              sample_metadata_valid = NULL, 
                              folds = 5, 
                              seed = 1234, 
                              base_learner = 'SL.BART',
                              base_screener = 'All', 
                              meta_learner = 'SL.nnls.auc',
                              run_concat = TRUE, 
                              run_stacked = TRUE, 
                              verbose = FALSE, 
                              print_learner = TRUE, 
                              refit.stack = FALSE, 
                              family = gaussian(), ...)
{ 
  
  ##############
  # Track time #
  ##############
  
  start.time <- Sys.time()
  
  #######################
  # Basic sanity checks #
  #######################
  
  ######################################
  # Check Y is appropriate with family #
  ######################################
  
  if (family$family == 'gaussian' && length(unique(sample_metadata$Y)) <= 5) {
    warning("The response has five or fewer unique values.  Are you sure you want the family to be gaussian?")
  }
  if (family$family == 'binomial' && (length(unique(sample_metadata$Y)) < 2))
    stop("Need at least two classes to do classification.")
  
  if (family$family == 'binomial' && (length(unique(sample_metadata$Y)) > 2))
    stop("Classification with more than two classes currently not supported")
  
  ############################
  # Check dimension mismatch #
  ############################
  
  if (all(rownames(feature_table) == rownames(feature_metadata)) == FALSE)
    stop("Both feature_table and feature_metadata should have the same rownames.")
  
  if (all(colnames(feature_table) == rownames(sample_metadata)) == FALSE)
    stop("Row names of sample_metadata must match the column names of feature_table.")
  
  if (!is.null(feature_table_valid)) {
    if (all(rownames(feature_table) == rownames(feature_table_valid)) == FALSE)
      stop("Both feature_table and feature_table_valid should have the same rownames.")
  }
  
  if (!is.null(sample_metadata_valid)) {
    # Guard against a validation metadata table with no matching feature table:
    # without this check, all(NULL == x) is vacuously TRUE and the function
    # would fail much later with an obscure "object 'yhat.test' not found".
    if (is.null(feature_table_valid))
      stop("feature_table_valid must be provided when sample_metadata_valid is specified.")
    if (all(colnames(feature_table_valid) == rownames(sample_metadata_valid)) == FALSE)
      stop("Row names of sample_metadata_valid must match the column names of feature_table_valid")
  }
  
  #########################
  # Check missing columns #
  #########################
  
  if (!'subjectID' %in% colnames(sample_metadata)) {
    stop("sample_metadata must have a column named 'subjectID' describing per-subject unique identifiers.")
  }
  
  if (!'Y' %in% colnames(sample_metadata)) {
    stop("sample_metadata must have a column named 'Y' describing the outcome of interest.")
  }
  
  if (!'featureID' %in% colnames(feature_metadata)) {
    stop("feature_metadata must have a column named 'featureID' describing per-feature unique identifiers.")
  }
  
  if (!'featureType' %in% colnames(feature_metadata)) {
    stop("feature_metadata must have a column named 'featureType' describing the corresponding source layers.")
  }
  
  if (!is.null(sample_metadata_valid)) {
    if (!'subjectID' %in% colnames(sample_metadata_valid)) {
      stop("sample_metadata_valid must have a column named 'subjectID' describing per-subject unique identifiers.")
    }
    
    if (!'Y' %in% colnames(sample_metadata_valid)) {
      stop("sample_metadata_valid must have a column named 'Y' describing the outcome of interest.")
    }
  }
  
  #############################################################################################
  # Extract validation Y right away (will not be used anywhere during the validation process) #
  #############################################################################################
  
  if (!is.null(sample_metadata_valid)) {
    validY <- sample_metadata_valid['Y']
  }
  
  ###############################################################
  # Set parameters and extract subject IDs for sample splitting #
  ###############################################################
  
  set.seed(seed)
  subjectID <- unique(sample_metadata$subjectID)
  
  ##################################
  # Trigger V-fold CV (Outer Loop) #
  ##################################
  
  # Folds are created at the subject level so that repeated measures from the
  # same subject never straddle a train/validation split.
  subjectCvFoldsIN <- caret::createFolds(seq_along(subjectID), k = folds, returnTrain = TRUE)
  
  #########################################
  # Curate subject-level samples per fold #
  #########################################
  
  obsIndexIn <- vector("list", folds) 
  for (k in seq_along(obsIndexIn)) {
    # Samples whose subject is NOT in the training partition form the
    # held-out (validation) rows for fold k.
    x <- which(!sample_metadata$subjectID %in% subjectID[subjectCvFoldsIN[[k]]])
    obsIndexIn[[k]] <- x
  }
  names(obsIndexIn) <- paste0("fold", seq_len(folds))
  
  ###############################
  # Set up data for SL training #
  ###############################
  
  # shuffle = FALSE because validRows already encodes the subject-aware split.
  cvControl <- list(V = folds, shuffle = FALSE, validRows = obsIndexIn)
  
  #################################################
  # Stacked generalization input data preparation #
  #################################################
  
  feature_metadata$featureType <- as.factor(feature_metadata$featureType)
  name_layers <- levels(droplevels(feature_metadata)$featureType)
  SL_fit_predictions <- vector("list", length(name_layers))
  SL_fit_layers <- vector("list", length(name_layers)) 
  names(SL_fit_layers) <- name_layers
  names(SL_fit_predictions) <- name_layers
  X_train_layers <- vector("list", length(name_layers)) 
  names(X_train_layers) <- name_layers
  X_test_layers <- vector("list", length(name_layers)) 
  names(X_test_layers) <- name_layers
  layer_wise_predictions_train <- vector("list", length(name_layers))
  names(layer_wise_predictions_train) <- name_layers
  
  #####################################################################
  # Stacked generalization input data preparation for validation data #
  #####################################################################
  
  if (!is.null(feature_table_valid)) {
    layer_wise_prediction_valid <- vector("list", length(name_layers))
    names(layer_wise_prediction_valid) <- name_layers
  } 
  
  ##################################################################
  # Carefully subset data per omics and run each individual layers #
  ##################################################################
  
  for (i in seq_along(name_layers)) {
    # Progress message is intentionally unconditional (not gated on `verbose`).
    cat('Running base model for layer ', i, "...", "\n")
    
    ##################################
    # Prepare single-omic input data #
    ##################################
    
    include_list <- feature_metadata %>% dplyr::filter(featureType == name_layers[i]) 
    t_dat_slice <- feature_table[rownames(feature_table) %in% include_list$featureID, ]
    dat_slice <- as.data.frame(t(t_dat_slice))
    Y <- sample_metadata$Y
    X <- dat_slice
    X_train_layers[[i]] <- X

    ###################################
    # Run user-specified base learner #
    ###################################
    
    SL_fit_layers[[i]] <- SuperLearner::SuperLearner(Y = Y, 
                                                     X = X,
                                                     cvControl = cvControl,    
                                                     verbose = verbose, 
                                                     SL.library = list(c(base_learner, base_screener)),
                                                     family = family)
    
    ###################################################
    # Append the corresponding y and X to the results #
    ###################################################
    
    SL_fit_layers[[i]]$Y <- sample_metadata['Y']
    SL_fit_layers[[i]]$X <- X
    if (!is.null(sample_metadata_valid)) SL_fit_layers[[i]]$validY <- validY
    
    ##################################################################
    # Remove redundant data frames and collect pre-stack predictions #
    ##################################################################
    
    rm(t_dat_slice); rm(dat_slice); rm(X)
    # $Z holds the cross-validated (out-of-fold) predictions used as the
    # meta-learner's design matrix.
    SL_fit_predictions[[i]] <- SL_fit_layers[[i]]$Z
    
    ##################################################
    # Re-fit to entire dataset for final predictions #
    ##################################################
    
    layer_wise_predictions_train[[i]] <- SL_fit_layers[[i]]$SL.predict
    
    ############################################################
    # Prepare single-omic validation data and save predictions #
    ############################################################
    
    if (!is.null(feature_table_valid)) {
      t_dat_slice_valid <- feature_table_valid[rownames(feature_table_valid) %in% include_list$featureID, ]
      dat_slice_valid <- as.data.frame(t(t_dat_slice_valid))
      X_test_layers[[i]] <- dat_slice_valid
      layer_wise_prediction_valid[[i]] <- predict.SuperLearner(SL_fit_layers[[i]], newdata = dat_slice_valid)$pred
      # Coerce to a one-column matrix so downstream cbind/rownames behave
      # uniformly across learners that return bare vectors.
      layer_wise_prediction_valid[[i]] <- matrix(layer_wise_prediction_valid[[i]], ncol = 1)
      rownames(layer_wise_prediction_valid[[i]]) <- rownames(dat_slice_valid)
      SL_fit_layers[[i]]$validX <- dat_slice_valid
      SL_fit_layers[[i]]$validPrediction <- layer_wise_prediction_valid[[i]]
      SL_fit_layers[[i]]$validPrediction <- matrix(SL_fit_layers[[i]]$validPrediction, ncol = 1)
      colnames(SL_fit_layers[[i]]$validPrediction) <- 'validPrediction'
      rm(dat_slice_valid); rm(include_list)
    }
  }
  
  ##############################
  # Prepare stacked input data #
  ##############################
  
  combo <- as.data.frame(do.call(cbind, SL_fit_predictions))
  names(combo) <- name_layers
                              
  ###############################
  # Set aside final predictions #
  ###############################
  
  combo_final <- as.data.frame(do.call(cbind, layer_wise_predictions_train))
  names(combo_final) <- name_layers
  
  if (!is.null(feature_table_valid)) {
    combo_valid <- as.data.frame(do.call(cbind, layer_wise_prediction_valid))
    names(combo_valid) <- name_layers
  }
  
  ####################
  # Stack all models #
  ####################
  
  if (run_stacked) {
    
    cat('Running stacked model...\n')
    
    ###################################
    # Run user-specified meta learner #
    ###################################
    
    SL_fit_stacked <- SuperLearner::SuperLearner(Y = Y, 
                                                 X = combo, 
                                                 cvControl = cvControl,    
                                                 verbose = verbose, 
                                                 SL.library = meta_learner,
                                                 family = family)
    
    # Extract the fit object from SuperLearner
    model_stacked <- SL_fit_stacked$fitLibrary[[1]]$object
    stacked_prediction_train <- predict.SuperLearner(SL_fit_stacked, newdata = combo_final)$pred
    
    ###################################################
    # Append the corresponding y and X to the results #
    ###################################################
    
    SL_fit_stacked$Y <- sample_metadata['Y']
    SL_fit_stacked$X <- combo
    if (!is.null(sample_metadata_valid)) SL_fit_stacked$validY <- validY
    
    #################################################################
    # Prepare stacked input data for validation and save prediction #
    #################################################################
    
    if (!is.null(feature_table_valid)) {
      stacked_prediction_valid <- predict.SuperLearner(SL_fit_stacked, newdata = combo_valid)$pred
      rownames(stacked_prediction_valid) <- rownames(combo_valid)
      SL_fit_stacked$validX <- combo_valid
      SL_fit_stacked$validPrediction <- stacked_prediction_valid
      colnames(SL_fit_stacked$validPrediction) <- 'validPrediction'
    }
  }
  
  #######################################
  # Run concatenated model if specified #
  #######################################
  
  if (run_concat) {
    cat('Running concatenated model...\n')
    ###################################
    # Prepare concatenated input data #
    ###################################
    
    fulldat <- as.data.frame(t(feature_table))
    
    ###################################
    # Run user-specified base learner #
    ###################################
    
    SL_fit_concat <- SuperLearner::SuperLearner(Y = Y, 
                                                X = fulldat, 
                                                cvControl = cvControl,    
                                                verbose = verbose, 
                                                SL.library = list(c(base_learner, base_screener)),
                                                family = family)
    
    # Extract the fit object from superlearner
    model_concat <- SL_fit_concat$fitLibrary[[1]]$object
    
    ###################################################
    # Append the corresponding y and X to the results #
    ###################################################
    
    SL_fit_concat$Y <- sample_metadata['Y']
    SL_fit_concat$X <- fulldat
    if (!is.null(sample_metadata_valid)) SL_fit_concat$validY <- validY
    
    ##########################################################################
    # Prepare concatenated input data for validation set and save prediction #
    ##########################################################################
    
    if (!is.null(feature_table_valid)) {
      fulldat_valid <- as.data.frame(t(feature_table_valid))
      concat_prediction_valid <- predict.SuperLearner(SL_fit_concat, newdata = fulldat_valid)$pred
      SL_fit_concat$validX <- fulldat_valid
      rownames(concat_prediction_valid) <- rownames(fulldat_valid)
      SL_fit_concat$validPrediction <- concat_prediction_valid
      colnames(SL_fit_concat$validPrediction) <- 'validPrediction'
    }
  }
  
  ######################
  # Save model results #
  ######################
  
  # Extract the fit object from superlearner
  model_layers <- vector("list", length(name_layers))
  names(model_layers) <- name_layers
  for (i in seq_along(name_layers)) {
    model_layers[[i]] <- SL_fit_layers[[i]]$fitLibrary[[1]]$object
  }
  
  ##################
  # CONCAT + STACK #
  ##################
  
  if (run_concat && run_stacked) {
    
    model_fits <- list(model_layers = model_layers,
                       model_stacked = model_stacked,
                       model_concat = model_concat)
    
    SL_fits <- list(SL_fit_layers = SL_fit_layers, 
                    SL_fit_stacked = SL_fit_stacked,
                    SL_fit_concat = SL_fit_concat)
    
    ###############################
    # Prediction (Stack + Concat) #
    ###############################
    
    if (refit.stack) {
      yhat.train <- cbind(combo, stacked_prediction_train, SL_fit_concat$Z)
    } else {
      yhat.train <- cbind(combo, SL_fit_stacked$Z, SL_fit_concat$Z)
    }
    colnames(yhat.train) <- c(colnames(combo), "stacked", "concatenated")
    
    ###############################
    # Validation (Stack + Concat) #
    ###############################
    
    if (!is.null(feature_table_valid)) {
      yhat.test <- cbind(combo_valid, SL_fit_stacked$validPrediction, SL_fit_concat$validPrediction)
      colnames(yhat.test) <- c(colnames(combo_valid), "stacked", "concatenated")
      
      ########
      # Save #
      ########
      
      res <- list(model_fits = model_fits, 
                  SL_fits = SL_fits,
                  X_train_layers = X_train_layers,
                  Y_train = Y,
                  yhat.train = yhat.train,
                  X_test_layers = X_test_layers,
                  yhat.test = yhat.test
      )
    } else {
      res <- list(model_fits = model_fits, 
                  SL_fits = SL_fits,
                  X_train_layers = X_train_layers,
                  Y_train = Y,
                  yhat.train = yhat.train
      )
      
    }
    
    ###############
    # CONCAT ONLY #
    ###############
    
  } else if (run_concat && !run_stacked) {
    
    model_fits <- list(model_layers = model_layers,
                       model_concat = model_concat)
    
    SL_fits <- list(SL_fit_layers = SL_fit_layers, 
                    SL_fit_concat = SL_fit_concat)
    
    ############################
    # Prediction (Concat Only) #
    ############################
    
    yhat.train <- cbind(combo, SL_fit_concat$Z)
    colnames(yhat.train) <- c(colnames(combo), "concatenated")
  
    ############################
    # Validation (Concat Only) #
    ############################
    
    if (!is.null(feature_table_valid)) {
      yhat.test <- cbind(combo_valid, SL_fit_concat$validPrediction)
      colnames(yhat.test) <- c(colnames(combo_valid), "concatenated")
      
      res <- list(model_fits = model_fits, 
                  SL_fits = SL_fits,
                  X_train_layers = X_train_layers,
                  Y_train = Y,
                  yhat.train = yhat.train,
                  X_test_layers = X_test_layers,
                  yhat.test = yhat.test
      )
    } else {
      res <- list(model_fits = model_fits, 
                  SL_fits = SL_fits,
                  X_train_layers = X_train_layers,
                  Y_train = Y,
                  yhat.train = yhat.train
      )
      
    }
    
    ##############
    # STACK ONLY #
    ##############
    
  } else if (!run_concat && run_stacked) {
    
    model_fits <- list(model_layers = model_layers,
                       model_stacked = model_stacked)
    
    SL_fits <- list(SL_fit_layers = SL_fit_layers, 
                    SL_fit_stacked = SL_fit_stacked)
    
    ###########################
    # Prediction (Stack Only) #
    ###########################
    
    if (refit.stack) {
      yhat.train <- cbind(combo, stacked_prediction_train)
    } else {
      yhat.train <- cbind(combo, SL_fit_stacked$Z)
    }
    colnames(yhat.train) <- c(colnames(combo), "stacked")
    
    ###########################
    # Validation (Stack Only) #
    ###########################
    
    if (!is.null(feature_table_valid)) {
      yhat.test <- cbind(combo_valid, SL_fit_stacked$validPrediction)
      colnames(yhat.test) <- c(colnames(combo_valid), "stacked")
      
      ########
      # Save #
      ########
      
      res <- list(model_fits = model_fits, 
                  SL_fits = SL_fits,
                  X_train_layers = X_train_layers,
                  Y_train = Y,
                  yhat.train = yhat.train,
                  X_test_layers = X_test_layers,
                  yhat.test = yhat.test
      )
    } else {
      res <- list(model_fits = model_fits, 
                  SL_fits = SL_fits,
                  X_train_layers = X_train_layers,
                  Y_train = Y,
                  yhat.train = yhat.train
      )
      
    }
    
    ############################
    # NEITHER CONCAT NOR STACK #
    ############################
    
  } else { 
    
    model_fits <- list(model_layers = model_layers)
    SL_fits <- list(SL_fit_layers = SL_fit_layers)
    
    #########################################
    # Prediction (Neither Stack nor Concat) #
    #########################################
    
    yhat.train <- combo
    colnames(yhat.train) <- colnames(combo)
    
    #########################################
    # Validation (Neither Stack nor Concat) #
    #########################################
    
    if (!is.null(feature_table_valid)) {
      yhat.test <- combo_valid
      colnames(yhat.test) <- colnames(combo_valid)
      
      ########
      # Save #
      ########
      
      res <- list(model_fits = model_fits, 
                  SL_fits = SL_fits,
                  X_train_layers = X_train_layers,
                  Y_train = Y,
                  yhat.train = yhat.train,
                  X_test_layers = X_test_layers,
                  yhat.test = yhat.test
      )
    } else {
      res <- list(model_fits = model_fits, 
                  SL_fits = SL_fits,
                  X_train_layers = X_train_layers,
                  Y_train = Y,
                  yhat.train = yhat.train
      )
      
    }
    
  }
  if (!is.null(sample_metadata_valid)) { res$Y_test <- validY$Y }
  res$base_learner <- base_learner
  res$meta_learner <- meta_learner
  res$base_screener <- base_screener
  res$run_concat <- run_concat
  res$run_stacked <- run_stacked
  res$family <- family$family
  res$feature.names <- rownames(feature_table)
  if (is.null(sample_metadata_valid)) {
    res$test <- FALSE
  } else {
    res$test <- TRUE
  }
  if (meta_learner == "SL.nnls.auc" && run_stacked) {
    # SL.nnls.auc exposes the per-layer NNLS weights via $solution.
    res$weights <- res$model_fits$model_stacked$solution
    names(res$weights) <- colnames(combo)
  }
  
  if (res$family == "binomial") {
    # Calculate AUC for each layer, stacked and concatenated 
    pred <- apply(res$yhat.train, 2, ROCR::prediction, labels = res$Y_train)
    AUC <- vector(length = length(pred))
    names(AUC) <- names(pred)
    for (i in seq_along(pred)) {
      AUC[i] <- round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3)
    }
    res$AUC.train <- AUC
    
    if (res$test == TRUE) {
      
      # Calculate AUC for each layer, stacked and concatenated 
      pred <- apply(res$yhat.test, 2, ROCR::prediction, labels = res$Y_test)
      AUC <- vector(length = length(pred))
      names(AUC) <- names(pred)
      for (i in seq_along(pred)) {
        AUC[i] <- round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3)
      }
      res$AUC.test <- AUC  
    }
  }
  if (res$family == "gaussian") {
      
    # Calculate R^2 for each layer, stacked and concatenated 
    R2 <- vector(length = ncol(res$yhat.train))
    names(R2) <- names(res$yhat.train)
    for (i in seq_along(R2)) {
      R2[i] <- as.vector(cor(res$yhat.train[, i], res$Y_train)^2)
    }
    res$R2.train <- R2
    if (res$test == TRUE) {
      # Calculate R^2 for each layer, stacked and concatenated 
      R2 <- vector(length = ncol(res$yhat.test))
      names(R2) <- names(res$yhat.test)
      for (i in seq_along(R2)) {
        R2[i] <- as.vector(cor(res$yhat.test[, i], res$Y_test)^2)
      }
      res$R2.test <- R2
    }
      
  }    
  res$folds <- folds
  res$cvControl <- cvControl
  # NOTE: the previous revision assigned `res$id <- id`, but `id` was never
  # defined anywhere in this function, which crashed every call at this point.
  # The assignment has been removed.
  stop.time <- Sys.time()
  time <- as.numeric(round(difftime(stop.time, start.time, units = "min"), 3), units = "mins")
  res$time <- time
  ##########
  # Return #
  ##########

  if (print_learner == TRUE) { print.learner(res) }
  return(res)
}  
675