|
a |
|
b/R/predictLearner.R |
|
|
1 |
#' Make predictions using a trained 'IntegratedLearner' model
#'
#' @description Scores new samples with a fitted 'IntegratedLearner' model.
#'   One prediction is produced per layer (one layer per level of
#'   `featureType`), plus a stacked prediction and/or a concatenated
#'   prediction when the corresponding flags of `fit` are `TRUE`. When
#'   validation metadata is supplied, the observed outcome and a per-column
#'   performance measure are returned as well.
#'
#' @param fit Fitted "IntegratedLearner" object. The elements read here are
#'   `feature.names`, `SL_fits` (`SL_fit_layers`, `SL_fit_stacked`,
#'   `SL_fit_concat`), `run_stacked`, `run_concat` and `family`.
#' @param feature_table_valid Feature table from the validation set, with
#'   features in rows and samples in columns. Required. Its row names must be
#'   identical (same values, same order) to `fit$feature.names`.
#' @param sample_metadata_valid OPTIONAL: sample-specific metadata table from
#'   the validation set. If provided, its row names must equal the column
#'   names of `feature_table_valid` and it must contain the columns
#'   `subjectID` and `Y`.
#' @param feature_metadata Data frame with one row per feature, row names
#'   identical to `fit$feature.names`, and the columns `featureID` and
#'   `featureType`. Must be the same as that used to train `fit`.
#'
#' @return A list with element `yhat.test` (data frame, one row per sample,
#'   one column per layer followed by `stacked` and/or `concatenated` when
#'   those models were run). If `sample_metadata_valid` is given the list also
#'   holds `Y_test` and, depending on `fit$family`, `AUC.test` (for
#'   `"binomial"`) or `R2.test` (for `"gaussian"`).
#' @export
predict.learner <- function(fit,
                            feature_table_valid = NULL,
                            sample_metadata_valid = NULL,
                            feature_metadata = NULL) {

  # Length-aware, NA-safe comparison of two name vectors. A bare
  # all(a == b) recycles silently on unequal lengths and is vacuously TRUE
  # when either side is NULL.
  same_names <- function(expected, observed) {
    length(expected) == length(observed) &&
      isTRUE(all(expected == observed))
  }

  ####################
  # Input validation #
  ####################

  if (is.null(feature_metadata)) {
    stop("Feature metadata cannot be empty")
  }

  if (!same_names(fit$feature.names, rownames(feature_metadata))) {
    stop("Both training feature_table and feature_metadata should have the same rownames.")
  }

  if (is.null(feature_table_valid)) {
    stop("Feature table for validation set cannot be empty")
  }

  if (!same_names(fit$feature.names, rownames(feature_table_valid))) {
    stop("Both feature_table and feature_table_valid should have the same rownames.")
  }

  has_metadata <- !is.null(sample_metadata_valid)

  if (has_metadata &&
      !same_names(colnames(feature_table_valid),
                  rownames(sample_metadata_valid))) {
    stop("Row names of sample_metadata_valid must match the column names of feature_table_valid")
  }

  if (!"featureID" %in% colnames(feature_metadata)) {
    stop("feature_metadata must have a column named 'featureID' describing per-feature unique identifiers.")
  }

  if (!"featureType" %in% colnames(feature_metadata)) {
    stop("feature_metadata must have a column named 'featureType' describing the corresponding source layers.")
  }

  if (has_metadata) {
    if (!"subjectID" %in% colnames(sample_metadata_valid)) {
      stop("sample_metadata_valid must have a column named 'subjectID' describing per-subject unique identifiers.")
    }
    if (!"Y" %in% colnames(sample_metadata_valid)) {
      stop("sample_metadata_valid must have a column named 'Y' describing the outcome of interest.")
    }
  }

  run_stacked <- isTRUE(fit$run_stacked)
  run_concat <- isTRUE(fit$run_concat)

  #####################################################################
  # Stacked generalization input data preparation for validation data #
  #####################################################################

  # Layers are the (used) levels of featureType; the i-th layer is scored
  # with the i-th element of fit$SL_fits$SL_fit_layers.
  layer_of_feature <- droplevels(as.factor(feature_metadata$featureType))
  name_layers <- levels(layer_of_feature)

  layer_wise_prediction_valid <- vector("list", length(name_layers))
  names(layer_wise_prediction_valid) <- name_layers

  for (i in seq_along(name_layers)) {

    ############################################################
    # Prepare single-omic validation data and save predictions #
    ############################################################

    # which() discards NA comparisons, so features with a missing
    # featureType are never assigned to a layer.
    layer_ids <- feature_metadata$featureID[
      which(layer_of_feature == name_layers[i])
    ]
    in_layer <- rownames(feature_table_valid) %in% layer_ids

    # drop = FALSE keeps a two-dimensional slice even when the layer holds a
    # single feature; otherwise a matrix input collapses to a vector and the
    # transpose comes out as 1 x n_samples instead of n_samples x 1.
    dat_slice_valid <- as.data.frame(
      t(feature_table_valid[in_layer, , drop = FALSE])
    )

    layer_pred <- predict.SuperLearner(
      fit$SL_fits$SL_fit_layers[[i]],
      newdata = dat_slice_valid
    )$pred
    rownames(layer_pred) <- rownames(dat_slice_valid)
    layer_wise_prediction_valid[[i]] <- layer_pred
  }

  combo_valid <- as.data.frame(do.call(cbind, layer_wise_prediction_valid))
  names(combo_valid) <- name_layers

  #################################################
  # Assemble predictions: layers, stacked, concat #
  #################################################

  yhat_test <- combo_valid

  if (run_stacked) {
    # The stacked model takes the layer-wise predictions as its features.
    stacked_prediction_valid <- predict.SuperLearner(
      fit$SL_fits$SL_fit_stacked,
      newdata = combo_valid
    )$pred
    rownames(stacked_prediction_valid) <- rownames(combo_valid)
    yhat_test <- cbind(yhat_test, stacked_prediction_valid)
    colnames(yhat_test)[ncol(yhat_test)] <- "stacked"
  }

  if (run_concat) {
    # The concatenated model takes every feature of every layer at once.
    fulldat_valid <- as.data.frame(t(feature_table_valid))
    concat_prediction_valid <- predict.SuperLearner(
      fit$SL_fits$SL_fit_concat,
      newdata = fulldat_valid
    )$pred
    rownames(concat_prediction_valid) <- rownames(fulldat_valid)
    yhat_test <- cbind(yhat_test, concat_prediction_valid)
    colnames(yhat_test)[ncol(yhat_test)] <- "concatenated"
  }

  res <- list()
  if (has_metadata) {
    # The outcome is only used below to score the predictions.
    res$Y_test <- sample_metadata_valid[["Y"]]
  }
  res$yhat.test <- yhat_test

  ###############################################
  # Performance metrics (only with validation Y) #
  ###############################################

  if (has_metadata) {
    y_obs <- res$Y_test

    if (fit$family == "binomial") {
      # AUC for each layer, stacked and concatenated, rounded to 3 digits.
      res$AUC.test <- vapply(
        yhat_test,
        function(p) {
          rocr_pred <- ROCR::prediction(p, labels = y_obs)
          round(ROCR::performance(rocr_pred, "auc")@y.values[[1]], 3)
        },
        numeric(1)
      )
    }

    if (fit$family == "gaussian") {
      # Squared Pearson correlation for each layer, stacked and concatenated.
      res$R2.test <- vapply(
        yhat_test,
        function(p) as.vector(cor(p, y_obs)^2),
        numeric(1)
      )
    }
  }

  return(res)
}
|
|
154 |
|