--- a
+++ b/R/updateLearner.R
@@ -0,0 +1,242 @@
+#' Update an IntegratedLearner fit to the layers available in a validation set
+#'
+#' @description Allows an IntegratedLearner fit to be used when only a subset
+#'   of its omics layers is available in a validation set. If every training
+#'   layer is present, the call is delegated to \code{predict.learner()}.
+#'   Otherwise the fit is restricted to the layers shared with the validation
+#'   set: the existing per-layer models predict the validation samples, the
+#'   stacked and/or concatenated models (whichever the original fit ran) are
+#'   refit on the shared layers only, and AUC (binomial) or R^2 (gaussian) is
+#'   recomputed.
+#'
+#' @param fit Fitted "IntegratedLearner" object.
+#' @param feature_table_valid Feature table from the validation set: a data
+#'   frame with features in rows and samples in columns. For every shared
+#'   layer it must contain all features that layer's model was trained on.
+#' @param sample_metadata_valid OPTIONAL (can provide feature_table_valid and
+#'   not this): sample-specific metadata table from the validation set with
+#'   the outcome in column \code{Y}. If provided, it must have the same
+#'   structure as the training sample_metadata. Default is NULL.
+#' @param feature_metadata_valid Data frame with columns \code{featureID} and
+#'   \code{featureType} mapping validation features to layers. For each layer
+#'   shared with \code{fit}, its features must be exactly the training
+#'   features of that layer (in any order).
+#' @param seed Seed passed to \code{set.seed()} before the stacked and/or
+#'   concatenated models are refit. Default is 1234.
+#' @param verbose Passed to \code{SuperLearner::SuperLearner()} when refitting.
+#'   Default is FALSE.
+#'
+#' @return If all training layers are present, the value of
+#'   \code{predict.learner()}. Otherwise the updated "IntegratedLearner"
+#'   object restricted to the shared layers, which is also printed with
+#'   \code{print.learner()}.
+#' @export
+update.learner <- function(fit,
+                           feature_table_valid,
+                           sample_metadata_valid = NULL,
+                           feature_metadata_valid,
+                           seed = 1234,
+                           verbose = FALSE) {
+  # Both validation inputs are required (they have no defaults).
+  if (missing(feature_table_valid) || missing(feature_metadata_valid) ||
+      is.null(feature_table_valid) || is.null(feature_metadata_valid)) {
+    stop("feature table/ feature metadata cannot be NULL for validation set in update learner")
+  }
+
+  if (fit$family == "gaussian") {
+    family <- stats::gaussian()
+  } else if (fit$family == "binomial") {
+    family <- stats::binomial()
+  } else {
+    # Otherwise `family` would silently resolve to the function stats::family.
+    stop("Unsupported family in fit object: ", fit$family)
+  }
+
+  validY <- NULL
+  if (!is.null(sample_metadata_valid)) {
+    validY <- sample_metadata_valid["Y"]
+  }
+
+  feature_metadata_valid$featureType <- as.factor(feature_metadata_valid$featureType)
+  name_layers_valid <- levels(droplevels(feature_metadata_valid$featureType))
+  name_layers <- names(fit$model_fits$model_layers)
+  name_layers_common <- intersect(name_layers_valid, name_layers)
+
+  if (length(name_layers_common) == length(name_layers)) {
+    # Every training layer is available: delegate, nothing needs refitting.
+    return(predict.learner(fit,
+                           feature_table_valid = feature_table_valid,
+                           sample_metadata_valid = sample_metadata_valid,
+                           feature_metadata = feature_metadata_valid))
+  }
+  if (length(name_layers_common) == 0) {
+    stop("Validation set has no layers in common with model fit")
+  }
+
+  # Restrict the fit to the layers shared with the validation set.
+  fit$model_fits$model_layers <- fit$model_fits$model_layers[name_layers_common]
+  fit$SL_fits$SL_fit_layers <- fit$SL_fits$SL_fit_layers[name_layers_common]
+  fit$X_train_layers <- fit$X_train_layers[name_layers_common]
+
+  # Predict the validation samples with the existing per-layer models.
+  X_test_layers <- vector("list", length(name_layers_common))
+  names(X_test_layers) <- name_layers_common
+  layer_wise_prediction_valid <- X_test_layers
+
+  for (layer in name_layers_common) {
+    train_features <- colnames(fit$X_train_layers[[layer]])
+    valid_features <- as.character(
+      feature_metadata_valid$featureID[feature_metadata_valid$featureType %in% layer]
+    )
+    if (!setequal(valid_features, train_features)) {
+      stop(paste0("Validation set feature names for layer ", layer,
+                  " do not match with training data"))
+    }
+    absent <- setdiff(train_features, rownames(feature_table_valid))
+    if (length(absent) > 0) {
+      stop(paste0("feature_table_valid is missing features of layer ", layer,
+                  ": ", paste(absent, collapse = ", ")))
+    }
+
+    # Take rows in training column order so predict() sees the same design
+    # the layer model was fit on.
+    dat_slice_valid <- as.data.frame(
+      t(feature_table_valid[train_features, , drop = FALSE])
+    )
+    X_test_layers[[layer]] <- dat_slice_valid
+
+    pred_valid <- predict.SuperLearner(fit$SL_fits$SL_fit_layers[[layer]],
+                                       newdata = dat_slice_valid)$pred
+    rownames(pred_valid) <- rownames(dat_slice_valid)
+    layer_wise_prediction_valid[[layer]] <- pred_valid
+
+    fit$SL_fits$SL_fit_layers[[layer]]$validX <- dat_slice_valid
+    fit$SL_fits$SL_fit_layers[[layer]]$validPrediction <- pred_valid
+    colnames(fit$SL_fits$SL_fit_layers[[layer]]$validPrediction) <- "validPrediction"
+  }
+
+  # Out-of-fold training predictions of the shared layers feed the meta
+  # learner; drop = FALSE keeps a data frame when only one layer is shared.
+  combo <- fit$yhat.train[, name_layers_common, drop = FALSE]
+  combo_valid <- as.data.frame(do.call(cbind, layer_wise_prediction_valid))
+  names(combo_valid) <- name_layers_common
+
+  # Refitting draws new cross-validation folds; seed them for reproducibility.
+  set.seed(seed)
+
+  if (fit$run_stacked) {
+    cat("Running new stacked model...\n")
+    SL_fit_stacked <- SuperLearner::SuperLearner(Y = fit$Y_train,
+                                                 X = combo,
+                                                 cvControl = fit$cvControl,
+                                                 verbose = verbose,
+                                                 SL.library = fit$meta_learner,
+                                                 family = family)
+    SL_fit_stacked$Y <- fit$Y_train
+    SL_fit_stacked$X <- combo
+    if (!is.null(validY)) SL_fit_stacked$validY <- validY
+
+    stacked_prediction_valid <- predict.SuperLearner(SL_fit_stacked,
+                                                     newdata = combo_valid)$pred
+    rownames(stacked_prediction_valid) <- rownames(combo_valid)
+    SL_fit_stacked$validX <- combo_valid
+    SL_fit_stacked$validPrediction <- stacked_prediction_valid
+    colnames(SL_fit_stacked$validPrediction) <- "validPrediction"
+
+    fit$model_fits$model_stacked <- SL_fit_stacked$fitLibrary[[1]]$object
+    fit$SL_fits$SL_fit_stacked <- SL_fit_stacked
+    fit$yhat.train$stacked <- SL_fit_stacked$Z
+  }
+
+  if (fit$run_concat) {
+    cat("Running new concatenated model...\n")
+    # Both designs are built layer by layer from the same per-layer column
+    # order, so training and validation columns line up and no feature of a
+    # non-shared validation layer leaks in.
+    fulldat <- as.data.frame(Reduce(cbind.data.frame, fit$X_train_layers))
+    fulldat_valid <- as.data.frame(Reduce(cbind.data.frame, X_test_layers))
+
+    SL_fit_concat <- SuperLearner::SuperLearner(Y = fit$Y_train,
+                                                X = fulldat,
+                                                cvControl = fit$cvControl,
+                                                verbose = verbose,
+                                                SL.library = list(c(fit$base_learner, fit$base_screener)),
+                                                family = family)
+    SL_fit_concat$Y <- fit$Y_train
+    SL_fit_concat$X <- fulldat
+    if (!is.null(validY)) SL_fit_concat$validY <- validY
+
+    concat_prediction_valid <- predict.SuperLearner(SL_fit_concat,
+                                                    newdata = fulldat_valid)$pred
+    rownames(concat_prediction_valid) <- rownames(fulldat_valid)
+    SL_fit_concat$validX <- fulldat_valid
+    SL_fit_concat$validPrediction <- concat_prediction_valid
+    colnames(SL_fit_concat$validPrediction) <- "validPrediction"
+
+    fit$model_fits$model_concat <- SL_fit_concat$fitLibrary[[1]]$object
+    fit$SL_fits$SL_fit_concat <- SL_fit_concat
+    fit$yhat.train$concatenated <- SL_fit_concat$Z
+  }
+
+  # Keep the shared layers plus whichever ensemble columns were refit.
+  ensemble_cols <- c(if (fit$run_stacked) "stacked",
+                     if (fit$run_concat) "concatenated")
+  fit$yhat.train <- fit$yhat.train[, c(name_layers_common, ensemble_cols),
+                                   drop = FALSE]
+
+  yhat.test <- combo_valid
+  if (fit$run_stacked) {
+    yhat.test$stacked <- as.vector(SL_fit_stacked$validPrediction)
+  }
+  if (fit$run_concat) {
+    yhat.test$concatenated <- as.vector(SL_fit_concat$validPrediction)
+  }
+  fit$yhat.test <- yhat.test
+  fit$X_test_layers <- X_test_layers
+  fit$test <- !is.null(validY)
+
+  if (fit$run_stacked && length(fit$meta_learner) == 1 &&
+      fit$meta_learner == "SL.nnls.auc") {
+    fit$weights <- fit$model_fits$model_stacked$solution
+    names(fit$weights) <- colnames(combo)
+  }
+
+  if (fit$test) {
+    fit$Y_test <- validY$Y
+  } else {
+    # Test-set results inherited from the original fit no longer apply.
+    fit$Y_test <- NULL
+    fit$AUC.test <- NULL
+    fit$R2.test <- NULL
+  }
+
+  # AUC (binomial) or squared Pearson correlation (gaussian) of every column
+  # of a prediction table against the outcome y.
+  score_columns <- function(yhat, y) {
+    yhat <- as.matrix(yhat)
+    scores <- vapply(seq_len(ncol(yhat)), function(j) {
+      if (fit$family == "binomial") {
+        roc_pred <- ROCR::prediction(yhat[, j], labels = y)
+        round(ROCR::performance(roc_pred, "auc")@y.values[[1]], 3)
+      } else {
+        as.vector(stats::cor(yhat[, j], y)^2)
+      }
+    }, numeric(1))
+    names(scores) <- colnames(yhat)
+    scores
+  }
+
+  if (fit$family == "binomial") {
+    fit$AUC.train <- score_columns(fit$yhat.train, fit$Y_train)
+    if (fit$test) fit$AUC.test <- score_columns(fit$yhat.test, fit$Y_test)
+  } else {
+    fit$R2.train <- score_columns(fit$yhat.train, fit$Y_train)
+    if (fit$test) fit$R2.test <- score_columns(fit$yhat.test, fit$Y_test)
+  }
+
+  # Record the features the updated model actually uses (shared layers only).
+  fit$feature.names <- unlist(lapply(fit$X_train_layers, colnames),
+                              use.names = FALSE)
+  print.learner(fit)
+  fit
+}