|
a |
|
b/R/IntegratedLearner.R |
|
|
1 |
#' Integrated machine learning for multi-omics prediction and classification |
|
|
2 |
#' |
|
|
3 |
#' Performs integrated machine learning to predict a binary or continuous outcome based on two or more omics layers (views). |
|
|
4 |
#' The \code{IntegratedLearner} function takes a training set (Y, X1, X2,...,Xn) and returns the predicted values based on a validation set. |
|
|
5 |
#' It also performs V-fold nested cross-validation to estimate the prediction accuracy of various fusion algorithms. |
|
|
6 |
#' Three types of integration paradigms are supported: early, late, and intermediate. |
|
|
7 |
#' The software includes multiple ML models based on the \code{\link[SuperLearner]{SuperLearner}} R package as well as several data exploration capabilities and visualization modules in a unified estimation framework. |
|
|
8 |
#' @param feature_table An R data frame containing multiview features (in rows) and samples (in columns). |
|
|
9 |
#' Column names of \code{feature_table} must match the row names of \code{sample_metadata}. |
|
|
10 |
#' @param sample_metadata An R data frame of metadata variables (in columns). |
|
|
11 |
#' Must have a column named \code{subjectID} describing per-subject unique identifiers. |
|
|
12 |
#' For longitudinal designs, this variable is expected to have non-unique values. |
|
|
13 |
#' Additionally, a column named \code{Y} must be present which is the outcome of interest (can be binary or continuous). |
|
|
14 |
#' Row names of \code{sample_metadata} must match the column names of \code{feature_table}. |
|
|
15 |
#' @param feature_metadata An R data frame of feature-specific metadata across views (in columns) and features (in rows). |
|
|
16 |
#' Must have a column named \code{featureID} describing per-feature unique identifiers. |
|
|
17 |
#' Additionally, a column named \code{featureType} should describe the corresponding source layers. |
|
|
18 |
#' Row names of \code{feature_metadata} must match the row names of \code{feature_table}. |
|
|
19 |
#' @param feature_table_valid Feature table from validation set for which prediction is desired. |
|
|
20 |
#' Must have the exact same structure as \code{feature_table}. If missing, uses \code{feature_table} for \code{feature_table_valid}. |
|
|
21 |
#' @param sample_metadata_valid Sample-specific metadata table from independent validation set when available. |
|
|
22 |
#' Must have the exact same structure as \code{sample_metadata}. |
|
|
23 |
#' @param folds How many folds in the V-fold nested cross-validation? Default is 5. |
|
|
24 |
#' @param seed Specify the arbitrary seed value for reproducibility. Default is 1234. |
|
|
25 |
#' @param base_learner Base learner for late fusion and early fusion. |
|
|
26 |
#' Check out the \href{https://cran.r-project.org/web/packages/SuperLearner/vignettes/Guide-to-SuperLearner.html}{SuperLearner user manual} for all available options. Default is \code{`SL.BART`}. |
|
|
27 |
#' @param base_screener Whether to screen variables before fitting base models? \code{All} means no screening which is the default. |
|
|
28 |
#' Check out the \href{https://cran.r-project.org/web/packages/SuperLearner/vignettes/Guide-to-SuperLearner.html}{SuperLearner user manual} for all available options. |
|
|
29 |
#' @param meta_learner Meta-learner for late fusion (stacked generalization). Defaults to \code{`SL.nnls.auc`}. |
|
|
30 |
#' Check out the \href{https://cran.r-project.org/web/packages/SuperLearner/vignettes/Guide-to-SuperLearner.html}{SuperLearner user manual} for all available options. |
|
|
31 |
#' @param run_concat Should early fusion be run? Default is TRUE. Uses the specified \code{base_learner} as the learning algorithm. |
|
|
32 |
#' @param run_stacked Should stacked model (late fusion) be run? Default is TRUE. |
|
|
33 |
#' @param verbose logical; TRUE for \code{SuperLearner} printing progress (helpful for debugging). Default is FALSE. |
|
|
34 |
#' @param print_learner logical; Should a detailed summary be printed? Default is TRUE. |
|
|
35 |
#' @param refit.stack logical; For late fusion, post-refit predictions on the entire data is returned if specified. Default is FALSE. |
|
|
36 |
#' @param family Currently allows \code{`gaussian()`} for continuous or \code{`binomial()`} for binary outcomes. |
|
|
37 |
#' @param ... Additional arguments. Not used currently. |
|
|
38 |
#' |
|
|
39 |
#' @return A list containing the trained model fits (per-layer, stacked, and/or concatenated), the training and (if supplied) validation predictions, and performance metrics (AUC for binomial or R-squared for gaussian outcomes). |
|
|
40 |
#' |
|
|
41 |
#' @author Himel Mallick, \email{him4004@@med.cornell.edu} |
|
|
42 |
#' |
|
|
43 |
#' @keywords microbiome, metagenomics, multiomics, scRNASeq, tweedie, singlecell |
|
|
44 |
#' @export |
|
|
45 |
IntegratedLearner<-function(feature_table, |
|
|
46 |
sample_metadata, |
|
|
47 |
feature_metadata, |
|
|
48 |
feature_table_valid = NULL, |
|
|
49 |
sample_metadata_valid = NULL, |
|
|
50 |
folds = 5, |
|
|
51 |
seed = 1234, |
|
|
52 |
base_learner = 'SL.BART', |
|
|
53 |
base_screener = 'All', |
|
|
54 |
meta_learner = 'SL.nnls.auc', |
|
|
55 |
run_concat = TRUE, |
|
|
56 |
run_stacked = TRUE, |
|
|
57 |
verbose = FALSE, |
|
|
58 |
print_learner = TRUE, |
|
|
59 |
refit.stack = FALSE, |
|
|
60 |
family=gaussian(), ...) |
|
|
61 |
{ |
|
|
62 |
|
|
|
63 |
############## |
|
|
64 |
# Track time # |
|
|
65 |
############## |
|
|
66 |
|
|
|
67 |
start.time<-Sys.time() |
|
|
68 |
|
|
|
69 |
####################### |
|
|
70 |
# Basic sanity checks # |
|
|
71 |
####################### |
|
|
72 |
|
|
|
73 |
###################################### |
|
|
74 |
# Check Y is appropriate with family # |
|
|
75 |
###################################### |
|
|
76 |
|
|
|
77 |
if (family$family=='gaussian' && length(unique(sample_metadata$Y)) <= 5) { |
|
|
78 |
warning("The response has five or fewer unique values. Are you sure you want the family to be gaussian?") |
|
|
79 |
} |
|
|
80 |
if (family$family=='binomial' && (length(unique(sample_metadata$Y))< 2)) |
|
|
81 |
stop("Need at least two classes to do classification.") |
|
|
82 |
|
|
|
83 |
if (family$family=='binomial' && (length(unique(sample_metadata$Y))> 2)) |
|
|
84 |
stop("Classification with more than two classes currently not supported") |
|
|
85 |
|
|
|
86 |
############################ |
|
|
87 |
# Check dimension mismatch # |
|
|
88 |
############################ |
|
|
89 |
|
|
|
90 |
if(all(rownames(feature_table)==rownames(feature_metadata))==FALSE) |
|
|
91 |
stop("Both feature_table and feature_metadata should have the same rownames.") |
|
|
92 |
|
|
|
93 |
if(all(colnames(feature_table)==rownames(sample_metadata))==FALSE) |
|
|
94 |
stop("Row names of sample_metadata must match the column names of feature_table.") |
|
|
95 |
|
|
|
96 |
if (!is.null(feature_table_valid)){ |
|
|
97 |
if(all(rownames(feature_table)==rownames(feature_table_valid))==FALSE) |
|
|
98 |
stop("Both feature_table and feature_table_valid should have the same rownames.") |
|
|
99 |
} |
|
|
100 |
|
|
|
101 |
if (!is.null(sample_metadata_valid)){ |
|
|
102 |
if(all(colnames(feature_table_valid)==rownames(sample_metadata_valid))==FALSE) |
|
|
103 |
stop("Row names of sample_metadata_valid must match the column names of feature_table_valid") |
|
|
104 |
} |
|
|
105 |
|
|
|
106 |
######################### |
|
|
107 |
# Check missing columns # |
|
|
108 |
######################### |
|
|
109 |
|
|
|
110 |
if (!'subjectID' %in% colnames(sample_metadata)){ |
|
|
111 |
stop("sample_metadata must have a column named 'subjectID' describing per-subject unique identifiers.") |
|
|
112 |
} |
|
|
113 |
|
|
|
114 |
if (!'Y' %in% colnames(sample_metadata)){ |
|
|
115 |
stop("sample_metadata must have a column named 'Y' describing the outcome of interest.") |
|
|
116 |
} |
|
|
117 |
|
|
|
118 |
if (!'featureID' %in% colnames(feature_metadata)){ |
|
|
119 |
stop("feature_metadata must have a column named 'featureID' describing per-feature unique identifiers.") |
|
|
120 |
} |
|
|
121 |
|
|
|
122 |
if (!'featureType' %in% colnames(feature_metadata)){ |
|
|
123 |
stop("feature_metadata must have a column named 'featureType' describing the corresponding source layers.") |
|
|
124 |
} |
|
|
125 |
|
|
|
126 |
if (!is.null(sample_metadata_valid)){ |
|
|
127 |
if (!'subjectID' %in% colnames(sample_metadata_valid)){ |
|
|
128 |
stop("sample_metadata_valid must have a column named 'subjectID' describing per-subject unique identifiers.") |
|
|
129 |
} |
|
|
130 |
|
|
|
131 |
if (!'Y' %in% colnames(sample_metadata_valid)){ |
|
|
132 |
stop("sample_metadata_valid must have a column named 'Y' describing the outcome of interest.") |
|
|
133 |
} |
|
|
134 |
} |
|
|
135 |
|
|
|
136 |
############################################################################################# |
|
|
137 |
# Extract validation Y right away (will not be used anywhere during the validation process) # |
|
|
138 |
############################################################################################# |
|
|
139 |
|
|
|
140 |
if (!is.null(sample_metadata_valid)){ |
|
|
141 |
validY<-sample_metadata_valid['Y'] |
|
|
142 |
} |
|
|
143 |
|
|
|
144 |
############################################################### |
|
|
145 |
# Set parameters and extract subject IDs for sample splitting # |
|
|
146 |
############################################################### |
|
|
147 |
|
|
|
148 |
set.seed(seed) |
|
|
149 |
subjectID <- unique(sample_metadata$subjectID) |
|
|
150 |
|
|
|
151 |
################################## |
|
|
152 |
# Trigger V-fold CV (Outer Loop) # |
|
|
153 |
################################## |
|
|
154 |
|
|
|
155 |
subjectCvFoldsIN <- caret::createFolds(1:length(subjectID), k = folds, returnTrain=TRUE) |
|
|
156 |
|
|
|
157 |
######################################## |
|
|
158 |
# Curate subject-level samples per fold # |
|
|
159 |
######################################## |
|
|
160 |
|
|
|
161 |
obsIndexIn <- vector("list", folds) |
|
|
162 |
for(k in 1:length(obsIndexIn)){ |
|
|
163 |
x <- which(!sample_metadata$subjectID %in% subjectID[subjectCvFoldsIN[[k]]]) |
|
|
164 |
obsIndexIn[[k]] <- x |
|
|
165 |
} |
|
|
166 |
names(obsIndexIn) <- sapply(1:folds, function(x) paste(c("fold", x), collapse='')) |
|
|
167 |
|
|
|
168 |
############################### |
|
|
169 |
# Set up data for SL training # |
|
|
170 |
############################### |
|
|
171 |
|
|
|
172 |
cvControl = list(V = folds, shuffle = FALSE, validRows = obsIndexIn) |
|
|
173 |
|
|
|
174 |
################################################# |
|
|
175 |
# Stacked generalization input data preparation # |
|
|
176 |
################################################# |
|
|
177 |
|
|
|
178 |
feature_metadata$featureType<-as.factor(feature_metadata$featureType) |
|
|
179 |
name_layers<-with(droplevels(feature_metadata), list(levels = levels(featureType)), nlevels = nlevels(featureType))$levels |
|
|
180 |
SL_fit_predictions<-vector("list", length(name_layers)) |
|
|
181 |
SL_fit_layers<-vector("list", length(name_layers)) |
|
|
182 |
names(SL_fit_layers)<-name_layers |
|
|
183 |
names(SL_fit_predictions)<-name_layers |
|
|
184 |
X_train_layers <- vector("list", length(name_layers)) |
|
|
185 |
names(X_train_layers) <- name_layers |
|
|
186 |
X_test_layers <- vector("list", length(name_layers)) |
|
|
187 |
names(X_test_layers) <- name_layers |
|
|
188 |
layer_wise_predictions_train<-vector("list", length(name_layers)) |
|
|
189 |
names(layer_wise_predictions_train)<-name_layers |
|
|
190 |
|
|
|
191 |
##################################################################### |
|
|
192 |
# Stacked generalization input data preparation for validation data # |
|
|
193 |
##################################################################### |
|
|
194 |
|
|
|
195 |
if (!is.null(feature_table_valid)){ |
|
|
196 |
layer_wise_prediction_valid<-vector("list", length(name_layers)) |
|
|
197 |
names(layer_wise_prediction_valid)<-name_layers |
|
|
198 |
} |
|
|
199 |
|
|
|
200 |
################################################################## |
|
|
201 |
# Carefully subset data per omics and run each individual layers # |
|
|
202 |
################################################################## |
|
|
203 |
|
|
|
204 |
for (i in seq_along(name_layers)){ |
|
|
205 |
#if (verbose){ |
|
|
206 |
cat('Running base model for layer ', i, "...", "\n") |
|
|
207 |
#} |
|
|
208 |
|
|
|
209 |
################################## |
|
|
210 |
# Prepate single-omic input data # |
|
|
211 |
################################## |
|
|
212 |
|
|
|
213 |
include_list<-feature_metadata %>% dplyr::filter(featureType == name_layers[i]) |
|
|
214 |
t_dat_slice<-feature_table[rownames(feature_table) %in% include_list$featureID, ] |
|
|
215 |
dat_slice<-as.data.frame(t(t_dat_slice)) |
|
|
216 |
Y = sample_metadata$Y |
|
|
217 |
X = dat_slice |
|
|
218 |
X_train_layers[[i]] <- X |
|
|
219 |
|
|
|
220 |
################################### |
|
|
221 |
# Run user-specified base learner # |
|
|
222 |
################################### |
|
|
223 |
|
|
|
224 |
SL_fit_layers[[i]] <- SuperLearner::SuperLearner(Y = Y, |
|
|
225 |
X = X, |
|
|
226 |
cvControl = cvControl, |
|
|
227 |
verbose = verbose, |
|
|
228 |
SL.library = list(c(base_learner,base_screener)), |
|
|
229 |
family = family) |
|
|
230 |
|
|
|
231 |
################################################### |
|
|
232 |
# Append the corresponding y and X to the results # |
|
|
233 |
################################################### |
|
|
234 |
|
|
|
235 |
SL_fit_layers[[i]]$Y<-sample_metadata['Y'] |
|
|
236 |
SL_fit_layers[[i]]$X<-X |
|
|
237 |
if (!is.null(sample_metadata_valid)) SL_fit_layers[[i]]$validY<-validY |
|
|
238 |
|
|
|
239 |
################################################################## |
|
|
240 |
# Remove redundant data frames and collect pre-stack predictions # |
|
|
241 |
################################################################## |
|
|
242 |
|
|
|
243 |
rm(t_dat_slice); rm(dat_slice); rm(X) |
|
|
244 |
SL_fit_predictions[[i]]<-SL_fit_layers[[i]]$Z |
|
|
245 |
|
|
|
246 |
################################################## |
|
|
247 |
# Re-fit to entire dataset for final predictions # |
|
|
248 |
################################################## |
|
|
249 |
|
|
|
250 |
layer_wise_predictions_train[[i]]<-SL_fit_layers[[i]]$SL.predict |
|
|
251 |
|
|
|
252 |
############################################################ |
|
|
253 |
# Prepate single-omic validation data and save predictions # |
|
|
254 |
############################################################ |
|
|
255 |
|
|
|
256 |
if (!is.null(feature_table_valid)){ |
|
|
257 |
t_dat_slice_valid<-feature_table_valid[rownames(feature_table_valid) %in% include_list$featureID, ] |
|
|
258 |
dat_slice_valid<-as.data.frame(t(t_dat_slice_valid)) |
|
|
259 |
X_test_layers[[i]] <- dat_slice_valid |
|
|
260 |
layer_wise_prediction_valid[[i]]<-predict.SuperLearner(SL_fit_layers[[i]], newdata = dat_slice_valid)$pred |
|
|
261 |
layer_wise_prediction_valid[[i]] <- matrix(layer_wise_prediction_valid[[i]], ncol = 1) # <- Change here |
|
|
262 |
rownames(layer_wise_prediction_valid[[i]])<-rownames(dat_slice_valid) |
|
|
263 |
SL_fit_layers[[i]]$validX<-dat_slice_valid |
|
|
264 |
SL_fit_layers[[i]]$validPrediction<-layer_wise_prediction_valid[[i]] |
|
|
265 |
SL_fit_layers[[i]]$validPrediction <- matrix(SL_fit_layers[[i]]$validPrediction, ncol = 1) # <- Change here |
|
|
266 |
colnames(SL_fit_layers[[i]]$validPrediction)<-'validPrediction' |
|
|
267 |
rm(dat_slice_valid); rm(include_list) |
|
|
268 |
} |
|
|
269 |
} |
|
|
270 |
|
|
|
271 |
############################## |
|
|
272 |
# Prepate stacked input data # |
|
|
273 |
############################## |
|
|
274 |
|
|
|
275 |
combo <- as.data.frame(do.call(cbind, SL_fit_predictions)) |
|
|
276 |
names(combo)<-name_layers |
|
|
277 |
|
|
|
278 |
############################### |
|
|
279 |
# Set aside final predictions # |
|
|
280 |
############################### |
|
|
281 |
|
|
|
282 |
combo_final <- as.data.frame(do.call(cbind, layer_wise_predictions_train)) |
|
|
283 |
names(combo_final)<-name_layers |
|
|
284 |
|
|
|
285 |
if (!is.null(feature_table_valid)){ |
|
|
286 |
combo_valid <- as.data.frame(do.call(cbind, layer_wise_prediction_valid)) |
|
|
287 |
names(combo_valid)<-name_layers |
|
|
288 |
} |
|
|
289 |
|
|
|
290 |
#################### |
|
|
291 |
# Stack all models # |
|
|
292 |
#################### |
|
|
293 |
|
|
|
294 |
if (run_stacked){ |
|
|
295 |
|
|
|
296 |
#if (verbose) { |
|
|
297 |
cat('Running stacked model...\n') |
|
|
298 |
#} |
|
|
299 |
|
|
|
300 |
################################### |
|
|
301 |
# Run user-specified meta learner # |
|
|
302 |
################################### |
|
|
303 |
|
|
|
304 |
SL_fit_stacked<-SuperLearner::SuperLearner(Y = Y, |
|
|
305 |
X = combo, |
|
|
306 |
cvControl = cvControl, |
|
|
307 |
verbose = verbose, |
|
|
308 |
SL.library = meta_learner, |
|
|
309 |
family=family) |
|
|
310 |
|
|
|
311 |
|
|
|
312 |
# Extract the fit object from SuperLearner |
|
|
313 |
model_stacked <- SL_fit_stacked$fitLibrary[[1]]$object |
|
|
314 |
stacked_prediction_train<-predict.SuperLearner(SL_fit_stacked, newdata = combo_final)$pred |
|
|
315 |
|
|
|
316 |
################################################### |
|
|
317 |
# Append the corresponding y and X to the results # |
|
|
318 |
################################################### |
|
|
319 |
|
|
|
320 |
SL_fit_stacked$Y<-sample_metadata['Y'] |
|
|
321 |
SL_fit_stacked$X<-combo |
|
|
322 |
if (!is.null(sample_metadata_valid)) SL_fit_stacked$validY<-validY |
|
|
323 |
|
|
|
324 |
################################################################# |
|
|
325 |
# Prepate stacked input data for validation and save prediction # |
|
|
326 |
################################################################# |
|
|
327 |
|
|
|
328 |
if (!is.null(feature_table_valid)){ |
|
|
329 |
stacked_prediction_valid<-predict.SuperLearner(SL_fit_stacked, newdata = combo_valid)$pred |
|
|
330 |
rownames(stacked_prediction_valid)<-rownames(combo_valid) |
|
|
331 |
SL_fit_stacked$validX<-combo_valid |
|
|
332 |
SL_fit_stacked$validPrediction<-stacked_prediction_valid |
|
|
333 |
colnames(SL_fit_stacked$validPrediction)<-'validPrediction' |
|
|
334 |
} |
|
|
335 |
} |
|
|
336 |
|
|
|
337 |
####################################### |
|
|
338 |
# Run concatenated model if specified # |
|
|
339 |
####################################### |
|
|
340 |
|
|
|
341 |
if(run_concat){ |
|
|
342 |
#if (verbose) { |
|
|
343 |
cat('Running concatenated model...\n') |
|
|
344 |
#} |
|
|
345 |
################################### |
|
|
346 |
# Prepate concatenated input data # |
|
|
347 |
################################### |
|
|
348 |
|
|
|
349 |
fulldat<-as.data.frame(t(feature_table)) |
|
|
350 |
|
|
|
351 |
################################### |
|
|
352 |
# Run user-specified base learner # |
|
|
353 |
################################### |
|
|
354 |
|
|
|
355 |
SL_fit_concat<-SuperLearner::SuperLearner(Y = Y, |
|
|
356 |
X = fulldat, |
|
|
357 |
cvControl = cvControl, |
|
|
358 |
verbose = verbose, |
|
|
359 |
SL.library = list(c(base_learner,base_screener)), |
|
|
360 |
family=family) |
|
|
361 |
|
|
|
362 |
# Extract the fit object from superlearner |
|
|
363 |
model_concat <- SL_fit_concat$fitLibrary[[1]]$object |
|
|
364 |
|
|
|
365 |
################################################### |
|
|
366 |
# Append the corresponding y and X to the results # |
|
|
367 |
################################################### |
|
|
368 |
|
|
|
369 |
SL_fit_concat$Y<-sample_metadata['Y'] |
|
|
370 |
SL_fit_concat$X<-fulldat |
|
|
371 |
if (!is.null(sample_metadata_valid)) SL_fit_concat$validY<-validY |
|
|
372 |
|
|
|
373 |
######################################################################### |
|
|
374 |
# Prepate concatenated input data for validaton set and save prediction # |
|
|
375 |
######################################################################### |
|
|
376 |
|
|
|
377 |
if (!is.null(feature_table_valid)){ |
|
|
378 |
fulldat_valid<-as.data.frame(t(feature_table_valid)) |
|
|
379 |
concat_prediction_valid<-predict.SuperLearner(SL_fit_concat, newdata = fulldat_valid)$pred |
|
|
380 |
SL_fit_concat$validX<-fulldat_valid |
|
|
381 |
rownames(concat_prediction_valid)<-rownames(fulldat_valid) |
|
|
382 |
SL_fit_concat$validPrediction<-concat_prediction_valid |
|
|
383 |
colnames(SL_fit_concat$validPrediction)<-'validPrediction' |
|
|
384 |
} |
|
|
385 |
} |
|
|
386 |
|
|
|
387 |
###################### |
|
|
388 |
# Save model results # |
|
|
389 |
###################### |
|
|
390 |
|
|
|
391 |
# Extract the fit object from superlearner |
|
|
392 |
model_layers <- vector("list", length(name_layers)) |
|
|
393 |
names(model_layers) <- name_layers |
|
|
394 |
for (i in seq_along(name_layers)) { |
|
|
395 |
model_layers[[i]] <- SL_fit_layers[[i]]$fitLibrary[[1]]$object |
|
|
396 |
} |
|
|
397 |
|
|
|
398 |
################## |
|
|
399 |
# CONCAT + STACK # |
|
|
400 |
################## |
|
|
401 |
|
|
|
402 |
if(run_concat & run_stacked){ |
|
|
403 |
|
|
|
404 |
model_fits <- list(model_layers=model_layers, |
|
|
405 |
model_stacked=model_stacked, |
|
|
406 |
model_concat=model_concat) |
|
|
407 |
|
|
|
408 |
SL_fits<-list(SL_fit_layers = SL_fit_layers, |
|
|
409 |
SL_fit_stacked = SL_fit_stacked, |
|
|
410 |
SL_fit_concat = SL_fit_concat) |
|
|
411 |
|
|
|
412 |
############################### |
|
|
413 |
# Prediction (Stack + Concat) # |
|
|
414 |
############################### |
|
|
415 |
|
|
|
416 |
if(refit.stack){ |
|
|
417 |
yhat.train <- cbind(combo, stacked_prediction_train, SL_fit_concat$Z) |
|
|
418 |
} else{ |
|
|
419 |
yhat.train <- cbind(combo, SL_fit_stacked$Z, SL_fit_concat$Z) |
|
|
420 |
} |
|
|
421 |
colnames(yhat.train) <- c(colnames(combo), "stacked", "concatenated") |
|
|
422 |
|
|
|
423 |
############################### |
|
|
424 |
# Validation (Stack + Concat) # |
|
|
425 |
############################### |
|
|
426 |
|
|
|
427 |
if(!is.null(feature_table_valid)){ |
|
|
428 |
yhat.test <- cbind(combo_valid, SL_fit_stacked$validPrediction,SL_fit_concat$validPrediction) |
|
|
429 |
colnames(yhat.test) <- c(colnames(combo_valid),"stacked","concatenated") |
|
|
430 |
|
|
|
431 |
######## |
|
|
432 |
# Save # |
|
|
433 |
######## |
|
|
434 |
|
|
|
435 |
res <- list(model_fits=model_fits, |
|
|
436 |
SL_fits=SL_fits, |
|
|
437 |
X_train_layers=X_train_layers, |
|
|
438 |
Y_train=Y, |
|
|
439 |
yhat.train=yhat.train, |
|
|
440 |
X_test_layers=X_test_layers, |
|
|
441 |
yhat.test=yhat.test |
|
|
442 |
) |
|
|
443 |
}else{ |
|
|
444 |
res <- list(model_fits=model_fits, |
|
|
445 |
SL_fits=SL_fits, |
|
|
446 |
X_train_layers=X_train_layers, |
|
|
447 |
Y_train=Y, |
|
|
448 |
yhat.train=yhat.train |
|
|
449 |
) |
|
|
450 |
|
|
|
451 |
} |
|
|
452 |
|
|
|
453 |
############### |
|
|
454 |
# CONCAT ONLY # |
|
|
455 |
############### |
|
|
456 |
|
|
|
457 |
} else if (run_concat & !run_stacked){ |
|
|
458 |
|
|
|
459 |
model_fits <- list(model_layers=model_layers, |
|
|
460 |
model_concat=model_concat) |
|
|
461 |
|
|
|
462 |
SL_fits<-list(SL_fit_layers = SL_fit_layers, |
|
|
463 |
SL_fit_concat = SL_fit_concat) |
|
|
464 |
|
|
|
465 |
|
|
|
466 |
############################ |
|
|
467 |
# Prediction (Concat Only) # |
|
|
468 |
############################ |
|
|
469 |
|
|
|
470 |
yhat.train <- cbind(combo, SL_fit_concat$Z) |
|
|
471 |
colnames(yhat.train) <- c(colnames(combo), "concatenated") |
|
|
472 |
|
|
|
473 |
############################ |
|
|
474 |
# Validation (Concat Only) # |
|
|
475 |
############################ |
|
|
476 |
|
|
|
477 |
if(!is.null(feature_table_valid)){ |
|
|
478 |
yhat.test <- cbind(combo_valid,SL_fit_concat$validPrediction) |
|
|
479 |
colnames(yhat.test) <- c(colnames(combo_valid),"concatenated") |
|
|
480 |
|
|
|
481 |
res <- list(model_fits=model_fits, |
|
|
482 |
SL_fits=SL_fits, |
|
|
483 |
X_train_layers=X_train_layers, |
|
|
484 |
Y_train=Y, |
|
|
485 |
yhat.train=yhat.train, |
|
|
486 |
X_test_layers=X_test_layers, |
|
|
487 |
yhat.test=yhat.test |
|
|
488 |
) |
|
|
489 |
}else{ |
|
|
490 |
res <- list(model_fits=model_fits, |
|
|
491 |
SL_fits=SL_fits, |
|
|
492 |
X_train_layers=X_train_layers, |
|
|
493 |
Y_train=Y, |
|
|
494 |
yhat.train=yhat.train |
|
|
495 |
) |
|
|
496 |
|
|
|
497 |
} |
|
|
498 |
|
|
|
499 |
|
|
|
500 |
############## |
|
|
501 |
# STACK ONLY # |
|
|
502 |
############## |
|
|
503 |
|
|
|
504 |
} else if (!run_concat & run_stacked){ |
|
|
505 |
|
|
|
506 |
model_fits <- list(model_layers = model_layers, |
|
|
507 |
model_stacked = model_stacked) |
|
|
508 |
|
|
|
509 |
SL_fits<-list(SL_fit_layers = SL_fit_layers, |
|
|
510 |
SL_fit_stacked = SL_fit_stacked) |
|
|
511 |
|
|
|
512 |
########################### |
|
|
513 |
# Prediction (Stack Only) # |
|
|
514 |
########################### |
|
|
515 |
|
|
|
516 |
if(refit.stack){ |
|
|
517 |
yhat.train <- cbind(combo, stacked_prediction_train) |
|
|
518 |
} else{ |
|
|
519 |
yhat.train <- cbind(combo, SL_fit_stacked$Z) |
|
|
520 |
} |
|
|
521 |
colnames(yhat.train) <- c(colnames(combo), "stacked") |
|
|
522 |
|
|
|
523 |
########################### |
|
|
524 |
# Validation (Stack Only) # |
|
|
525 |
########################### |
|
|
526 |
|
|
|
527 |
if(!is.null(feature_table_valid)){ |
|
|
528 |
yhat.test <- cbind(combo_valid, SL_fit_stacked$validPrediction) |
|
|
529 |
colnames(yhat.test) <- c(colnames(combo_valid),"stacked") |
|
|
530 |
|
|
|
531 |
######## |
|
|
532 |
# Save # |
|
|
533 |
######## |
|
|
534 |
|
|
|
535 |
res <- list(model_fits=model_fits, |
|
|
536 |
SL_fits=SL_fits, |
|
|
537 |
X_train_layers=X_train_layers, |
|
|
538 |
Y_train=Y, |
|
|
539 |
yhat.train=yhat.train, |
|
|
540 |
X_test_layers=X_test_layers, |
|
|
541 |
yhat.test=yhat.test |
|
|
542 |
) |
|
|
543 |
}else{ |
|
|
544 |
res <- list(model_fits=model_fits, |
|
|
545 |
SL_fits=SL_fits, |
|
|
546 |
X_train_layers=X_train_layers, |
|
|
547 |
Y_train=Y, |
|
|
548 |
yhat.train=yhat.train |
|
|
549 |
) |
|
|
550 |
|
|
|
551 |
} |
|
|
552 |
|
|
|
553 |
|
|
|
554 |
############################ |
|
|
555 |
# NEITHER CONCAT NOR STACK # |
|
|
556 |
############################ |
|
|
557 |
|
|
|
558 |
} else{ |
|
|
559 |
|
|
|
560 |
model_fits <- list(model_layers=model_layers) |
|
|
561 |
SL_fits<-list(SL_fit_layers = SL_fit_layers) |
|
|
562 |
|
|
|
563 |
######################################### |
|
|
564 |
# Prediction (Neither Stack nor Concat) # |
|
|
565 |
######################################### |
|
|
566 |
|
|
|
567 |
yhat.train <- combo |
|
|
568 |
colnames(yhat.train) <- colnames(combo) |
|
|
569 |
|
|
|
570 |
######################################### |
|
|
571 |
# Validation (Neither Stack nor Concat) # |
|
|
572 |
######################################### |
|
|
573 |
|
|
|
574 |
if(!is.null(feature_table_valid)){ |
|
|
575 |
yhat.test <- combo_valid |
|
|
576 |
colnames(yhat.test) <- colnames(combo_valid) |
|
|
577 |
|
|
|
578 |
######### |
|
|
579 |
# Save # |
|
|
580 |
######## |
|
|
581 |
|
|
|
582 |
res <- list(model_fits=model_fits, |
|
|
583 |
SL_fits=SL_fits, |
|
|
584 |
X_train_layers=X_train_layers, |
|
|
585 |
Y_train=Y, |
|
|
586 |
yhat.train=yhat.train, |
|
|
587 |
X_test_layers=X_test_layers, |
|
|
588 |
yhat.test=yhat.test |
|
|
589 |
) |
|
|
590 |
}else{ |
|
|
591 |
res <- list(model_fits=model_fits, |
|
|
592 |
SL_fits=SL_fits, |
|
|
593 |
X_train_layers=X_train_layers, |
|
|
594 |
Y_train=Y, |
|
|
595 |
yhat.train=yhat.train |
|
|
596 |
) |
|
|
597 |
|
|
|
598 |
} |
|
|
599 |
|
|
|
600 |
|
|
|
601 |
} |
|
|
602 |
if(!is.null(sample_metadata_valid)){res$Y_test=validY$Y} |
|
|
603 |
res$base_learner <- base_learner |
|
|
604 |
res$meta_learner <- meta_learner |
|
|
605 |
res$base_screener <- base_screener |
|
|
606 |
res$run_concat <- run_concat |
|
|
607 |
res$run_stacked <- run_stacked |
|
|
608 |
res$family <- family$family |
|
|
609 |
res$feature.names <- rownames(feature_table) |
|
|
610 |
if(is.null(sample_metadata_valid)){ |
|
|
611 |
res$test=FALSE |
|
|
612 |
}else{ |
|
|
613 |
res$test=TRUE |
|
|
614 |
} |
|
|
615 |
if(meta_learner=="SL.nnls.auc" & run_stacked){ |
|
|
616 |
res$weights <- res$model_fits$model_stacked$solution |
|
|
617 |
names(res$weights) <- colnames(combo) |
|
|
618 |
} |
|
|
619 |
|
|
|
620 |
if(res$family=="binomial"){ |
|
|
621 |
# Calculate AUC for each layer, stacked and concatenated |
|
|
622 |
pred=apply(res$yhat.train, 2, ROCR::prediction, labels=res$Y_train) |
|
|
623 |
AUC=vector(length = length(pred)) |
|
|
624 |
names(AUC)=names(pred) |
|
|
625 |
for(i in seq_along(pred)){ |
|
|
626 |
AUC[i] = round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3) |
|
|
627 |
} |
|
|
628 |
res$AUC.train <- AUC |
|
|
629 |
|
|
|
630 |
if(res$test==TRUE){ |
|
|
631 |
|
|
|
632 |
# Calculate AUC for each layer, stacked and concatenated |
|
|
633 |
pred=apply(res$yhat.test, 2, ROCR::prediction, labels=res$Y_test) |
|
|
634 |
AUC=vector(length = length(pred)) |
|
|
635 |
names(AUC)=names(pred) |
|
|
636 |
for(i in seq_along(pred)){ |
|
|
637 |
AUC[i] = round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3) |
|
|
638 |
} |
|
|
639 |
res$AUC.test <- AUC |
|
|
640 |
} |
|
|
641 |
} |
|
|
642 |
if(res$family=="gaussian"){ |
|
|
643 |
|
|
|
644 |
# Calculate R^2 for each layer, stacked and concatenated |
|
|
645 |
R2=vector(length = ncol(res$yhat.train)) |
|
|
646 |
names(R2)=names(res$yhat.train) |
|
|
647 |
for(i in seq_along(R2)){ |
|
|
648 |
R2[i] = as.vector(cor(res$yhat.train[ ,i], res$Y_train)^2) |
|
|
649 |
} |
|
|
650 |
res$R2.train <- R2 |
|
|
651 |
if(res$test==TRUE){ |
|
|
652 |
# Calculate R^2 for each layer, stacked and concatenated |
|
|
653 |
R2=vector(length = ncol(res$yhat.test)) |
|
|
654 |
names(R2)=names(res$yhat.test) |
|
|
655 |
for(i in seq_along(R2)){ |
|
|
656 |
R2[i] = as.vector(cor(res$yhat.test[ ,i], res$Y_test)^2) |
|
|
657 |
} |
|
|
658 |
res$R2.test <- R2 |
|
|
659 |
} |
|
|
660 |
|
|
|
661 |
} |
|
|
662 |
res$folds <- folds |
|
|
663 |
res$cvControl <- cvControl |
|
|
664 |
res$id <- id |
|
|
665 |
stop.time<-Sys.time() |
|
|
666 |
time <- as.numeric(round(difftime(stop.time, start.time, units="min"), 3), units = "mins") |
|
|
667 |
res$time <- time |
|
|
668 |
########## |
|
|
669 |
# Return # |
|
|
670 |
########## |
|
|
671 |
|
|
|
672 |
if(print_learner==TRUE){print.learner(res)} |
|
|
673 |
return(res) |
|
|
674 |
} |
|
|
675 |
|