Diff of /R/cross.val.R [000000] .. [28e211]

Switch to side-by-side view

--- a
+++ b/R/cross.val.R
@@ -0,0 +1,32 @@
+cross.val <- function(
+  exp.df, class.vec, segments, performance, class.algo, quiet = TRUE
+) {
+  # Validation
+  if (!(class.algo %in% c("J48", "rpart"))) {
+    stop("Unknown classification algorithm")
+  }
+  # Start cross validation loop
+  class1 <- levels(class.vec)[1]
+  for (fold in seq_len(length(segments))) {
+    if (!quiet) message("Fold ", fold, " of ", length(segments))
+    # Define training and test set
+    test.ind <- segments[[fold]]
+    training.set <- exp.df[-test.ind, ]
+    test.set <- exp.df[test.ind, , drop = FALSE]
+    test.set$training.class <- class.vec[-test.ind]
+    test.class <- class.vec[test.ind]
+    # Train J48 on training set
+    if (class.algo == "J48") {
+      cv.model <- J48(training.class ~ ., training.set)
+      pred.class <- predict(cv.model, test.set)
+    } else {
+      cv.model <- rpart(training.class ~ ., training.set, method = "class")
+      pred.class <- predict(cv.model, test.set, type = "class")
+    }
+    # Evaluate model on test set
+    performance <- eval.pred(
+      pred.class, test.class, class1, performance
+    )
+  }
+  return(performance)
+}