|
a |
|
b/partyMod/R/Predict.R |
|
|
1 |
|
|
|
2 |
# $Id$ |
|
|
3 |
|
|
|
4 |
predict.BinaryTree <- function(object, ...) { |
|
|
5 |
conditionalTree@predict(object, ...) |
|
|
6 |
} |
|
|
7 |
|
|
|
8 |
predict.RandomForest <- function(object, OOB = FALSE, ...) { |
|
|
9 |
RandomForest@predict(object, OOB = OOB, ...) |
|
|
10 |
} |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
setGeneric("treeresponse", function(object, ...) |
|
|
14 |
standardGeneric("treeresponse")) |
|
|
15 |
|
|
|
16 |
setMethod("treeresponse", signature = signature(object = "BinaryTree"), |
|
|
17 |
definition = function(object, newdata = NULL, ...) |
|
|
18 |
object@cond_distr_response(newdata = newdata, ...) |
|
|
19 |
) |
|
|
20 |
|
|
|
21 |
setMethod("treeresponse", signature = signature(object = "RandomForest"), |
|
|
22 |
definition = function(object, newdata = NULL, ...) |
|
|
23 |
object@cond_distr_response(newdata = newdata, ...) |
|
|
24 |
) |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
### weights is an S3 generic |
|
|
28 |
###setGeneric("weights", function(object, ...) standardGeneric("weights")) |
|
|
29 |
### |
|
|
30 |
###setMethod("weights", signature = signature(object = "BinaryTree"), |
|
|
31 |
### definition = function(object, newdata = NULL, ...) |
|
|
32 |
### object@prediction_weights(newdata = newdata, ...) |
|
|
33 |
###) |
|
|
34 |
### |
|
|
35 |
###setMethod("weights", signature = signature(object = "RandomForest"), |
|
|
36 |
### definition = function(object, newdata = NULL, OOB = FALSE, ...) |
|
|
37 |
### object@prediction_weights(newdata = newdata, OOB = OOB, ...) |
|
|
38 |
###) |
|
|
39 |
|
|
|
40 |
weights.BinaryTree <- function(object, newdata = NULL, ...) |
|
|
41 |
object@prediction_weights(newdata = newdata, ...) |
|
|
42 |
|
|
|
43 |
weights.RandomForest <- function(object, newdata = NULL, OOB = FALSE, ...) |
|
|
44 |
object@prediction_weights(newdata = newdata, OOB = OOB, ...) |
|
|
45 |
|
|
|
46 |
|
|
|
47 |
setGeneric("where", function(object, ...) standardGeneric("where")) |
|
|
48 |
|
|
|
49 |
setMethod("where", signature = signature(object = "BinaryTree"), |
|
|
50 |
definition = function(object, newdata = NULL, ...) { |
|
|
51 |
if(is.null(newdata)) object@where |
|
|
52 |
else object@get_where(newdata = newdata, ...) |
|
|
53 |
} |
|
|
54 |
) |
|
|
55 |
|
|
|
56 |
setMethod("where", signature = signature(object = "RandomForest"), |
|
|
57 |
definition = function(object, newdata = NULL, ...) { |
|
|
58 |
if(is.null(newdata)) object@where |
|
|
59 |
else object@get_where(newdata = newdata, ...) |
|
|
60 |
} |
|
|
61 |
) |
|
|
62 |
|
|
|
63 |
|
|
|
64 |
setGeneric("nodes", function(object, where, ...) standardGeneric("nodes")) |
|
|
65 |
|
|
|
66 |
setMethod("nodes", signature = signature(object = "BinaryTree", |
|
|
67 |
where = "integer"), |
|
|
68 |
definition = function(object, where, ...) |
|
|
69 |
lapply(where, function(i) .Call("R_get_nodebynum", object@tree, i)) |
|
|
70 |
) |
|
|
71 |
|
|
|
72 |
setMethod("nodes", signature = signature(object = "BinaryTree", |
|
|
73 |
where = "numeric"), |
|
|
74 |
definition = function(object, where, ...) |
|
|
75 |
nodes(object, as.integer(where)) |
|
|
76 |
) |