b/R/updateLearner.R
#' Update an IntegratedLearner fit object based on the layers available in the test set
#'
#' @description Updates an IntegratedLearner fit when only a subset of the omics layers is
#' available in the test set. If all layers and features match the training data, it simply
#' calls predict.learner().
#'
#' @param fit Fitted "IntegratedLearner" object.
#' @param feature_table_valid Feature table from the validation set. It should be a data frame
#' with features in rows and samples in columns. Feature names must be a subset of the training
#' data feature names.
#' @param sample_metadata_valid OPTIONAL (feature_table_valid can be provided without this):
#' sample-specific metadata table from the independent validation set. If provided, it must have
#' the exact same structure as sample_metadata. Default is NULL.
#' @param feature_metadata_valid Matrix containing feature names and their corresponding layers.
#' Must be a subset of the feature_metadata provided in the IntegratedLearner object.
#' @param seed Seed for reproducibility. Default is 1234.
#' @param verbose Logical. Should a summary of the fits/results be printed? Default is FALSE.
#'
#' @return The updated IntegratedLearner fit object.
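#'
#' @examples
#' # A minimal usage sketch (not run). The objects below are hypothetical: `fit` is
#' # assumed to be an existing IntegratedLearner fit, and the *_valid objects are
#' # assumed to hold validation data covering only a subset of the training layers.
#' \dontrun{
#' fit_updated <- update.learner(fit = fit,
#'                               feature_table_valid = feature_table_valid,
#'                               sample_metadata_valid = sample_metadata_valid,
#'                               feature_metadata_valid = feature_metadata_valid,
#'                               seed = 1234,
#'                               verbose = TRUE)
#'
#' # Per-layer (plus stacked/concatenated) predictions for the validation samples
#' head(fit_updated$yhat.test)
#' }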
#' @export
update.learner <- function(fit,
                           feature_table_valid, # Feature table from the validation set; feature names must be a subset of the training feature names.
                           sample_metadata_valid = NULL, # OPTIONAL: sample-specific metadata table from the independent validation set. Must have the exact same structure as sample_metadata.
                           feature_metadata_valid, # Feature names and their corresponding layers for the validation set.
                           seed = 1234, # Seed for reproducibility. Default is 1234.
                           verbose = FALSE
                           ){
  # Check that the validation feature table and feature metadata are not empty
  if(is.null(feature_table_valid) | is.null(feature_metadata_valid)){
    stop("feature_table_valid and feature_metadata_valid cannot be NULL in update.learner")
  }

  # Use the documented seed so that the refitted learners below are reproducible
  set.seed(seed)

  # Recover the outcome family used during training
  if(fit$family=="gaussian"){
    family<-gaussian()
  }else if(fit$family=="binomial"){
    family<-binomial()
  }

  # Outcome for the validation samples, if sample metadata is provided
  if (!is.null(sample_metadata_valid)){
    validY<-sample_metadata_valid['Y']
  }

  # Layers present in the validation set and layers the model was trained on
  feature_metadata_valid$featureType<-as.factor(feature_metadata_valid$featureType)
  name_layers_valid<-levels(droplevels(feature_metadata_valid$featureType))
  name_layers <- names(fit$model_fits$model_layers)

  # If all training layers are present in the validation set,
  # just run the predict function and return its object
  if(length(intersect(name_layers_valid,name_layers))==length(name_layers)){

    # Check if feature names are same for the train and test
    return(predict.learner(fit,
                           feature_table_valid = feature_table_valid,
                           sample_metadata_valid = sample_metadata_valid,
                           feature_metadata = feature_metadata_valid))
  }else if(length(intersect(name_layers_valid,name_layers))==0){

    stop("Validation set has no layers in common with model fit")

  }else{

    name_layers_common <- intersect(name_layers_valid,name_layers)

    # Extract only the common layers from the fit object
    fit$model_fits$model_layers <- fit$model_fits$model_layers[name_layers_common]
    fit$SL_fits$SL_fit_layers <- fit$SL_fits$SL_fit_layers[name_layers_common]
    fit$X_train_layers <- fit$X_train_layers[name_layers_common]

    # Use the common layers to get layer-wise predictions for the validation set
    X_test_layers <- vector("list", length(name_layers_common))
    names(X_test_layers) <- name_layers_common

    if (!is.null(feature_table_valid)){
      layer_wise_prediction_valid<-vector("list", length(name_layers_common))
      names(layer_wise_prediction_valid)<-name_layers_common
    }

    for(i in seq_along(name_layers_common)){
      include_list<-feature_metadata_valid %>% filter(featureType == name_layers_common[i])

      # Check that feature names in the common layers match between train and test
      if(!all(include_list$featureID==colnames(fit$X_train_layers[[name_layers_common[i]]]))){
        stop(paste0("Validation set feature names for layer ", name_layers_common[i], " do not match with training data"))
      }

      if (!is.null(feature_table_valid)){
        t_dat_slice_valid<-feature_table_valid[rownames(feature_table_valid) %in% include_list$featureID, ]
        dat_slice_valid<-as.data.frame(t(t_dat_slice_valid))
        X_test_layers[[i]] <- dat_slice_valid
        layer_wise_prediction_valid[[i]]<-predict.SuperLearner(fit$SL_fits$SL_fit_layers[[i]], newdata = dat_slice_valid)$pred
        rownames(layer_wise_prediction_valid[[i]])<-rownames(dat_slice_valid)
        fit$SL_fits$SL_fit_layers[[i]]$validX<-dat_slice_valid
        fit$SL_fits$SL_fit_layers[[i]]$validPrediction<-layer_wise_prediction_valid[[i]]
        colnames(fit$SL_fits$SL_fit_layers[[i]]$validPrediction)<-'validPrediction'
        rm(dat_slice_valid); rm(include_list)
      }
    }

    combo <- fit$yhat.train[ ,name_layers_common]

    if (!is.null(feature_table_valid)){
      combo_valid <- as.data.frame(do.call(cbind, layer_wise_prediction_valid))
      # Columns correspond to the common layers used to generate the predictions
      names(combo_valid)<-name_layers_common
    }

    if(fit$run_stacked){

      cat('Running new stacked model...\n')

      ###################################
      # Run user-specified meta learner #
      ###################################

      SL_fit_stacked<-SuperLearner::SuperLearner(Y = fit$Y_train,
                                                 X = combo,
                                                 cvControl = fit$cvControl,
                                                 verbose = verbose,
                                                 SL.library = fit$meta_learner,
                                                 family=family)

      # Extract the fit object from SuperLearner
      model_stacked <- SL_fit_stacked$fitLibrary[[1]]$object

      ###################################################
      # Append the corresponding y and X to the results #
      ###################################################

      SL_fit_stacked$Y<-fit$Y_train
      SL_fit_stacked$X<-combo
      if (!is.null(sample_metadata_valid)) SL_fit_stacked$validY<-validY

      #################################################################
      # Prepare stacked input data for validation and save prediction #
      #################################################################

      if (!is.null(feature_table_valid)){
        stacked_prediction_valid<-predict.SuperLearner(SL_fit_stacked, newdata = combo_valid)$pred
        rownames(stacked_prediction_valid)<-rownames(combo_valid)
        SL_fit_stacked$validX<-combo_valid
        SL_fit_stacked$validPrediction<-stacked_prediction_valid
        colnames(SL_fit_stacked$validPrediction)<-'validPrediction'
      }

      fit$model_fits$model_stacked <- model_stacked
      fit$SL_fits$SL_fit_stacked <- SL_fit_stacked
      fit$yhat.train$stacked <- SL_fit_stacked$Z

    }

    if(fit$run_concat){

      cat('Running new concatenated model...\n')

      ###################################
      # Prepare concatenated input data #
      ###################################
      feature_table <- Reduce(cbind.data.frame,fit$X_train_layers)
      feature_table <- feature_table[ ,feature_metadata_valid$featureID]
      fulldat<-as.data.frame(feature_table)

      ###################################
      # Run user-specified base learner #
      ###################################

      SL_fit_concat<-SuperLearner::SuperLearner(Y = fit$Y_train,
                                                X = fulldat,
                                                cvControl = fit$cvControl,
                                                verbose = verbose,
                                                SL.library = list(c(fit$base_learner,fit$base_screener)),
                                                family=family)

      # Extract the fit object from SuperLearner
      model_concat <- SL_fit_concat$fitLibrary[[1]]$object

      ###################################################
      # Append the corresponding y and X to the results #
      ###################################################

      SL_fit_concat$Y<-fit$Y_train
      SL_fit_concat$X<-fulldat
      if (!is.null(sample_metadata_valid)) SL_fit_concat$validY<-validY

      ##########################################################################
      # Prepare concatenated input data for validation set and save prediction #
      ##########################################################################

      if (!is.null(feature_table_valid)){
        fulldat_valid<-as.data.frame(t(feature_table_valid))
        # Align validation columns with the training design so the column order matches
        fulldat_valid<-fulldat_valid[ ,feature_metadata_valid$featureID, drop = FALSE]
        concat_prediction_valid<-predict.SuperLearner(SL_fit_concat, newdata = fulldat_valid)$pred
        SL_fit_concat$validX<-fulldat_valid
        rownames(concat_prediction_valid)<-rownames(fulldat_valid)
        SL_fit_concat$validPrediction<-concat_prediction_valid
        colnames(SL_fit_concat$validPrediction)<-'validPrediction'
      }

      fit$model_fits$model_concat <- model_concat
      fit$SL_fits$SL_fit_concat <- SL_fit_concat
      fit$yhat.train$concatenated <- SL_fit_concat$Z
    }

    if(fit$run_concat & fit$run_stacked){
      fit$yhat.train <- fit$yhat.train[ ,c(name_layers_common,"stacked","concatenated")]

    }else if(fit$run_concat & !fit$run_stacked){
      fit$yhat.train <- fit$yhat.train[ ,c(name_layers_common,"concatenated")]

    }else if(!fit$run_concat & fit$run_stacked){
      fit$yhat.train <- fit$yhat.train[ ,c(name_layers_common,"stacked")]

    }else if(!fit$run_concat & !fit$run_stacked){
      fit$yhat.train <- fit$yhat.train[ ,name_layers_common]

    }

    if(!is.null(feature_table_valid)){

      if(fit$run_concat & fit$run_stacked){
        yhat.test <- cbind(combo_valid, SL_fit_stacked$validPrediction, SL_fit_concat$validPrediction)
        colnames(yhat.test) <- c(colnames(combo_valid),"stacked","concatenated")

      }else if(fit$run_concat & !fit$run_stacked){
        yhat.test <- cbind(combo_valid, SL_fit_concat$validPrediction)
        colnames(yhat.test) <- c(colnames(combo_valid),"concatenated")

      }else if(!fit$run_concat & fit$run_stacked){
        yhat.test <- cbind(combo_valid, SL_fit_stacked$validPrediction)
        colnames(yhat.test) <- c(colnames(combo_valid),"stacked")

      }else if(!fit$run_concat & !fit$run_stacked){
        yhat.test <- cbind(combo_valid)
        colnames(yhat.test) <- c(colnames(combo_valid))

      }
      fit$yhat.test <- yhat.test
      fit$X_test_layers <- X_test_layers
    }
    if(is.null(sample_metadata_valid)){
      fit$test=FALSE
    }else{
      fit$test=TRUE
    }
    if(fit$meta_learner=="SL.nnls.auc" & fit$run_stacked){
      fit$weights <- fit$model_fits$model_stacked$solution
      names(fit$weights) <- colnames(combo)
    }

    if(!is.null(sample_metadata_valid)){fit$Y_test=validY$Y}

    if(fit$family=="binomial"){
      # Calculate AUC for each layer, stacked and concatenated
      pred=apply(fit$yhat.train, 2, ROCR::prediction, labels=fit$Y_train)
      AUC=vector(length = length(pred))
      names(AUC)=names(pred)
      for(i in seq_along(pred)){
        AUC[i] = round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3)
      }
      fit$AUC.train <- AUC

      if(fit$test==TRUE){
        # Calculate AUC for each layer, stacked and concatenated
        pred=apply(fit$yhat.test, 2, ROCR::prediction, labels=fit$Y_test)
        AUC=vector(length = length(pred))
        names(AUC)=names(pred)
        for(i in seq_along(pred)){
          AUC[i] = round(ROCR::performance(pred[[i]], "auc")@y.values[[1]], 3)
        }
        fit$AUC.test <- AUC
      }
    }
    if(fit$family=="gaussian"){

      # Calculate R^2 for each layer, stacked and concatenated
      R2=vector(length = ncol(fit$yhat.train))
      names(R2)=names(fit$yhat.train)
      for(i in seq_along(R2)){
        R2[i] = as.vector(cor(fit$yhat.train[ ,i], fit$Y_train)^2)
      }
      fit$R2.train <- R2
      if(fit$test==TRUE){
        # Calculate R^2 for each layer, stacked and concatenated
        R2=vector(length = ncol(fit$yhat.test))
        names(R2)=names(fit$yhat.test)
        for(i in seq_along(R2)){
          R2[i] = as.vector(cor(fit$yhat.test[ ,i], fit$Y_test)^2)
        }
        fit$R2.test <- R2
      }

    }
    fit$feature.names <- rownames(feature_table_valid)
    print.learner(fit)
    return(fit)
  }
}