Diff of /R/cross.val.R [000000] .. [28e211]

Switch to unified view

a b/R/cross.val.R
1
cross.val <- function(
2
  exp.df, class.vec, segments, performance, class.algo, quiet = TRUE
3
) {
4
  # Validation
5
  if (!(class.algo %in% c("J48", "rpart"))) {
6
    stop("Unknown classification algorithm")
7
  }
8
  # Start cross validation loop
9
  class1 <- levels(class.vec)[1]
10
  for (fold in seq_len(length(segments))) {
11
    if (!quiet) message("Fold ", fold, " of ", length(segments))
12
    # Define training and test set
13
    test.ind <- segments[[fold]]
14
    training.set <- exp.df[-test.ind, ]
15
    test.set <- exp.df[test.ind, , drop = FALSE]
16
    test.set$training.class <- class.vec[-test.ind]
17
    test.class <- class.vec[test.ind]
18
    # Train J48 on training set
19
    if (class.algo == "J48") {
20
      cv.model <- J48(training.class ~ ., training.set)
21
      pred.class <- predict(cv.model, test.set)
22
    } else {
23
      cv.model <- rpart(training.class ~ ., training.set, method = "class")
24
      pred.class <- predict(cv.model, test.set, type = "class")
25
    }
26
    # Evaluate model on test set
27
    performance <- eval.pred(
28
      pred.class, test.class, class1, performance
29
    )
30
  }
31
  return(performance)
32
}