|
a |
|
b/partyMod/R/Classes.R |
|
|
1 |
|
|
|
2 |
# $Id$ |
|
|
3 |
|
|
|
4 |
### Linear statistic with expectation and covariance.
### Extends `ExpectCovar' (defined elsewhere) by the statistic itself and
### an object of class `ExpectCovarInfluence'.
setClass("LinStatExpectCovar",
    slots = c(linearstatistic = "numeric",
              expcovinf = "ExpectCovarInfluence"),
    contains = "ExpectCovar"
)
|
|
12 |
|
|
|
13 |
### Memory for C_svd.
### Holds the character settings `method', `jobu' and `jobv', the matrices
### `u' and `v', the numeric vector `s' and the integer `p'.
setClass("svd_mem",
    slots = c(method = "character",
              jobu   = "character",
              jobv   = "character",
              u      = "matrix",
              v      = "matrix",
              s      = "numeric",
              p      = "integer")
)
|
|
25 |
|
|
|
26 |
### Linear statistic with the Moore-Penrose inverse of the covariance
### matrix (`MPinv'), a numeric `rank' and SVD work space (`svdmem').
setClass("LinStatExpectCovarMPinv",
    slots = c(MPinv = "matrix",
              rank = "numeric",
              svdmem = "svd_mem"),
    contains = "LinStatExpectCovar"
)
|
|
35 |
|
|
|
36 |
################ Memory Classes ##################### |
|
|
37 |
|
|
|
38 |
### Memory used while fitting a tree.
### Validity: the slots `varmemory', `dontuse' and `dontusetmp' must all
### have the same length.
setClass(Class = "TreeFitMemory",
    slots = c(expcovinf = "ExpectCovarInfluence",
              expcovinfss = "ExpectCovarInfluence",
              linexpcov2sample = "LinStatExpectCovar",
              weights = "numeric",
              varmemory = "list",
              dontuse = "logical",
              dontusetmp = "logical",
              splitstatistics = "numeric"),
    validity = function(object) {
        ## Slots must be read through `object@'. The bare names are not
        ## in scope here, and the previous version failed with
        ## "object 'dontuse' not found".
        ni <- length(object@dontuse)
        if (length(object@varmemory) != ni ||
            length(object@dontusetmp) != ni)
            return(paste0("slots ", sQuote("varmemory"), ", ",
                          sQuote("dontuse"), " and ", sQuote("dontusetmp"),
                          " must have equal length"))
        TRUE
    }
)
|
|
54 |
|
|
|
55 |
|
|
|
56 |
############## Tree Classes ###################### |
|
|
57 |
|
|
|
58 |
### Slot type accepting either a data.frame or a plain list.
setClassUnion(name = "df_OR_list", members = c("data.frame", "list"))
|
|
59 |
|
|
|
60 |
### Control parameters for the variable-level test statistics.
### `teststat' is a factor with levels "max" and "quad" (default "max").
### `maxpts', `abseps' and `releps' look like settings for a numerical
### probability computation. NOTE(review): confirm against the C code.
setClass("VariableControl",
    slots = c(teststat = "factor",
              pvalue = "logical",
              tol = "numeric",
              maxpts = "integer",
              abseps = "numeric",
              releps = "numeric"),
    prototype = list(teststat = factor("max", levels = c("max", "quad")),
                     pvalue = TRUE,
                     tol = 1e-10,
                     maxpts = 25000L,
                     abseps = 1e-4,
                     releps = 0)
)
|
|
78 |
|
|
|
79 |
### Control parameters for splitting a node.
### Validity:
###   * `minsplit', `minbucket', `tol' and `maxsurrogate' must be non-negative;
###   * `minprob' must lie in the closed interval [0.01, 0.99].
### As documented in ?setValidity, a problem is reported by returning a
### message string. validObject() then raises an informative error,
### whereas warning() + FALSE only produced "...object: FALSE".
setClass(Class = "SplitControl",
    slots = c(minprob = "numeric",
              minsplit = "numeric",
              minbucket = "numeric",
              tol = "numeric",
              maxsurrogate = "integer"),
    prototype = list(minprob = 0.01,
                     minsplit = 20,
                     minbucket = 7,
                     tol = 1e-10,
                     maxsurrogate = 0L),
    validity = function(object) {
        if (any(c(object@minsplit, object@minbucket,
                  object@tol, object@maxsurrogate) < 0))
            return(paste0("no negative values allowed in objects of class ",
                          sQuote("SplitControl")))
        ## The bounds are inclusive; the default minprob is exactly 0.01.
        if (object@minprob < 0.01 || object@minprob > 0.99)
            return(paste(sQuote("minprob"), "must be in [0.01, 0.99]"))
        TRUE
    }
)
|
|
107 |
|
|
|
108 |
### Control parameters for the global (node-level) test.
### Validity:
###   * `mincriterion' must be non-negative;
###   * `mtry' must be non-negative (0 is the default);
###   * `nresample' must be at least 100.
### Problems are reported by returning a message string (see ?setValidity)
### rather than warning() + FALSE.
setClass(Class = "GlobalTestControl",
    slots = c(testtype = "factor",
              nresample = "integer",
              randomsplits = "logical",
              mtry = "integer",
              mincriterion = "numeric"),
    prototype = list(
        testtype = factor("Bonferroni",
                          levels = c("Bonferroni", "MonteCarlo", "Aggregated",
                                     "Univariate", "Teststatistic")),
        nresample = 9999L,
        randomsplits = FALSE,
        mtry = 0L,
        mincriterion = 0.95),
    validity = function(object) {
        if (object@mincriterion < 0)
            return(paste(sQuote("mincriterion"), "must not be negative"))
        ## Zero is the default, so only negative values are rejected.
        if (any(object@mtry < 0))
            return(paste(sQuote("mtry"), "must not be negative"))
        ## 100 itself is accepted.
        if (object@nresample < 100)
            return(paste(sQuote("nresample"), "must be at least 100"))
        TRUE
    }
)
|
|
141 |
|
|
|
142 |
### Control parameters for growing a tree.
### Validity: `maxdepth' must be non-negative (0 is the default).
### Problems are reported by returning a message string (see ?setValidity)
### rather than warning() + FALSE.
setClass(Class = "TreeGrowControl",
    slots = c(stump = "logical",
              varOnce = "logical",
              maxdepth = "integer",
              savesplitstats = "logical"),
    prototype = list(stump = FALSE,
                     varOnce = FALSE,
                     maxdepth = 0L,
                     savesplitstats = TRUE),
    validity = function(object) {
        if (object@maxdepth < 0)
            return(paste(sQuote("maxdepth"), "must not be negative"))
        TRUE
    }
)
|
|
161 |
|
|
|
162 |
### Bundle of the four component control objects. Each component defaults
### to a freshly constructed default object. The bundle is valid when
### every component passes its own validity check.
setClass("TreeControl",
    slots = c(varctrl = "VariableControl",
              splitctrl = "SplitControl",
              gtctrl = "GlobalTestControl",
              tgctrl = "TreeGrowControl"),
    prototype = list(varctrl = new("VariableControl"),
                     splitctrl = new("SplitControl"),
                     gtctrl = new("GlobalTestControl"),
                     tgctrl = new("TreeGrowControl")),
    validity = function(object) {
        ## validObject() signals an error for an invalid component, so
        ## reaching the end means all four components are valid.
        for (component in c("varctrl", "splitctrl", "gtctrl", "tgctrl"))
            validObject(slot(object, component))
        TRUE
    }
)
|
|
181 |
|
|
|
182 |
### Control parameters for a forest: the tree controls (inherited from
### `TreeControl') plus ensemble settings and `compress'/`expand' functions.
### Validity:
###   * `ntree' must be at least 1;
###   * `fraction' must lie in the closed interval [0.01, 0.99].
### Problems are reported by returning a message string (see ?setValidity)
### rather than warning() + FALSE.
setClass(Class = "ForestControl",
    slots = c(ntree = "integer",
              replace = "logical",
              fraction = "numeric",
              trace = "logical",
              dropcriterion = "logical",
              compress = "function",
              expand = "function"),
    contains = "TreeControl",
    validity = function(object) {
        if (object@ntree < 1)
            return(paste(sQuote("ntree"), "must be at least 1"))
        ## The bounds are inclusive, matching the comparison below.
        if (object@fraction < 0.01 || object@fraction > 0.99)
            return(paste(sQuote("fraction"), "must be in [0.01, 0.99]"))
        TRUE
    }
)
|
|
204 |
|
|
|
205 |
### A set of variables (a data.frame or a list) together with metadata:
### transformations, nominal/ordinal/censored flags, orderings, levels,
### scores, missing-value information and the counts `nobs'/`ninputs'.
setClass("VariableFrame",
    slots = c(variables = "df_OR_list",
              transformations = "list",
              is_nominal = "logical",
              is_ordinal = "logical",
              is_censored = "logical",
              ordering = "list",
              levels = "list",
              scores = "list",
              has_missings = "logical",
              whichNA = "list",
              nobs = "integer",
              ninputs = "integer")
)
|
|
220 |
|
|
|
221 |
### A VariableFrame for responses, with two additional transformation
### matrices: one used for testing and one used for prediction.
setClass("ResponseFrame",
    slots = c(test_trafo = "matrix",
              predict_trafo = "matrix"),
    contains = "VariableFrame"
)
|
|
227 |
|
|
|
228 |
### The learning sample: response and input frames, weights and the
### counts `nobs' and `ninputs'.
setClass("LearningSample",
    slots = c(responses = "ResponseFrame",
              inputs = "VariableFrame",
              weights = "numeric",
              nobs = "integer",
              ninputs = "integer")
)
|
|
237 |
|
|
|
238 |
### A LearningSample that additionally keeps its `ModelEnv' in `menv'.
setClass("LearningSampleFormula",
    slots = c(menv = "ModelEnv"),
    contains = "LearningSample"
)
|
|
243 |
|
|
|
244 |
### The tree structure itself is a list, and we need to make sure that the
### tree slot accepts the S3 classes used for its nodes and splits.
### Each of them is therefore registered as an extension of "list".
setClass("SplittingNode", contains = "list")
setClass("TerminalNode", contains = "list")
setClass("TerminalModelNode", contains = "list")
setClass("orderedSplit", contains = "list")
setClass("nominalSplit", contains = "list")
|
|
252 |
|
|
|
253 |
### Register `Surv' (an S3 class from package `survival') as an extension
### of "list" so that no warnings about an unknown class are issued.
setClass("Surv", contains = "list")
|
|
256 |
|
|
|
257 |
|
|
|
258 |
### A class for partitions induced by recursive binary splits
setClass("BinaryTreePartition",
    slots = c(
        tree = "list",      # basic tree structure as a (named or unnamed) list
        where = "integer",  # node IDs of the observations in the learning sample
        weights = "numeric" # weights in the root node
    )
)
|
|
268 |
|
|
|
269 |
### A class for binary trees: a BinaryTreePartition plus the model data,
### the responses and a set of functions operating on the fitted tree.
setClass("BinaryTree",
    slots = c(
        data = "ModelEnv",
        responses = "VariableFrame",       # response variables, used for
                                           # computing predictions
        cond_distr_response = "function",  # predict distribution
        predict_response = "function",     # predict responses
        prediction_weights = "function",   # prediction weights
        get_where = "function",            # node numbers
        update = "function"                # update weights
    ),
    contains = "BinaryTreePartition"
)
|
|
283 |
|
|
|
284 |
### A class for random forests. `ensemble', `where' and `weights' are lists
### (presumably one entry per tree; confirm against the fitting code).
### `initweights' holds numeric weights.
setClass("RandomForest",
    slots = c(
        ensemble = "list",
        where = "list",
        weights = "list",
        initweights = "numeric",
        data = "ModelEnv",
        responses = "VariableFrame",       # response variables, used for
                                           # computing predictions
        cond_distr_response = "function",  # predict distribution
        predict_response = "function",     # predict responses
        prediction_weights = "function",   # prediction weights
        get_where = "function",            # node numbers
        update = "function",               # update weights
        expand = "function"                # inverts the compress operation
    )
)
|
|
302 |
|