Diff of /partyMod/R/Classes.R [000000] .. [fbf06f]

Switch to unified view

a b/partyMod/R/Classes.R
1
2
# $Id$
3
4
### Linear statistic with expectation and covariance
5
setClass(Class = "LinStatExpectCovar",
6
    representation = representation(
7
        linearstatistic = "numeric",
8
        expcovinf = "ExpectCovarInfluence"
9
    ),
10
    contains = "ExpectCovar"
11
)
12
13
### Memory for C_svd
14
setClass(Class = "svd_mem",
15
    representation = representation(
16
        method = "character",
17
        jobu   = "character",
18
        jobv   = "character",
19
        u      = "matrix",
20
        v      = "matrix",
21
        s      = "numeric",
22
        p      = "integer"
23
    )
24
)
25
26
### with Moore-Penrose inverse of the covariance matrix
27
setClass(Class = "LinStatExpectCovarMPinv",
28
    representation = representation(
29
        MPinv  = "matrix",   
30
        rank   = "numeric",
31
        svdmem = "svd_mem"
32
    ), 
33
    contains = "LinStatExpectCovar"
34
)
35
36
################ Memory Classes #####################
37
38
setClass(Class = "TreeFitMemory",
39
    representation = representation(
40
        expcovinf         = "ExpectCovarInfluence",
41
        expcovinfss       = "ExpectCovarInfluence",
42
        linexpcov2sample  = "LinStatExpectCovar",
43
        weights           = "numeric",
44
        varmemory         = "list",
45
        dontuse           = "logical",
46
        dontusetmp        = "logical",
47
        splitstatistics   = "numeric"
48
    ), 
49
    validity = function(object) {
50
        ni <- length(dontuse)
51
        length(varmemory) == ni && length(dontusetmp) == ni
52
    }
53
)
54
55
56
##############  Tree Classes  ######################
57
58
setClassUnion("df_OR_list", c("data.frame", "list"))
59
60
setClass(Class = "VariableControl",
61
    representation = representation(
62
        teststat = "factor",
63
        pvalue   = "logical",
64
        tol      = "numeric",
65
        maxpts   = "integer",
66
        abseps   = "numeric",
67
        releps   = "numeric"
68
    ),
69
    prototype = list(
70
        teststat = factor("max", levels = c("max", "quad")),
71
        pvalue   = as.logical(TRUE),
72
        tol      = as.double(1e-10),
73
        maxpts   = as.integer(25000),
74
        abseps   = as.double(1e-4),
75
        releps   = as.double(0.0)
76
    )
77
)
78
79
setClass(Class = "SplitControl",
80
    representation = representation(
81
        minprob      = "numeric",
82
        minsplit     = "numeric",
83
        minbucket    = "numeric",
84
        tol          = "numeric",
85
        maxsurrogate = "integer"
86
    ),
87
    prototype = list(minprob = as.double(0.01), 
88
                     minsplit = as.double(20), 
89
                     minbucket = as.double(7), 
90
                     tol = as.double(1e-10), 
91
                     maxsurrogate = as.integer(0)
92
    ),
93
    validity = function(object) {
94
        if (any(c(object@minsplit, object@minbucket, 
95
                  object@tol, object@maxsurrogate) < 0)) {
96
            warning("no negative values allowed in objects of class ", 
97
                    sQuote("SplitControl"))
98
            return(FALSE)
99
        }
100
        if (object@minprob < 0.01 || object@minprob > 0.99) {
101
            warning(sQuote("minprob"), " must be in (0.01, 0.99)")
102
            return(FALSE)
103
        }
104
        return(TRUE)
105
    }
106
)
107
108
setClass(Class = "GlobalTestControl",
109
    representation = representation(
110
        testtype     = "factor",
111
        nresample    = "integer",
112
        randomsplits = "logical",
113
        mtry         = "integer",
114
        mincriterion = "numeric"
115
    ),
116
    prototype = list(
117
        testtype = factor("Bonferroni", 
118
            levels = c("Bonferroni", "MonteCarlo", "Aggregated", 
119
                       "Univariate", "Teststatistic")),
120
        nresample = as.integer(9999),
121
        randomsplits = as.logical(FALSE),
122
        mtry = as.integer(0),
123
        mincriterion = as.double(0.95)
124
    ),
125
    validity = function(object) {
126
        if (object@mincriterion < 0) {
127
            warning(sQuote("mincriterion"), " must not be negative")
128
            return(FALSE)
129
        }
130
        if (any(object@mtry < 0)) {
131
            warning(sQuote("mtry"), " must be positive")
132
            return(FALSE)
133
        }
134
        if (object@nresample < 100) {
135
            warning(sQuote("nresample"), " must be larger than 100")
136
            return(FALSE)
137
        }
138
        return(TRUE)
139
    },
140
)
141
142
setClass(Class = "TreeGrowControl",
143
    representation = representation(
144
        stump           = "logical",
145
        varOnce         = "logical",
146
        maxdepth        = "integer",
147
        savesplitstats  = "logical"
148
    ),
149
    prototype = list(stump = as.logical(FALSE), 
150
                     varOnce = as.logical(FALSE),
151
                     maxdepth = as.integer(0), 
152
                     savesplitstats = as.logical(TRUE)),
153
    validity = function(object) {
154
        if (object@maxdepth < 0) {
155
            warning(sQuote("maxdepth"), " must be positive")
156
            return(FALSE)
157
        }
158
        return(TRUE)
159
    }
160
)
161
162
setClass(Class = "TreeControl",
163
    representation = representation(
164
        varctrl   = "VariableControl",
165
        splitctrl = "SplitControl",
166
        gtctrl    = "GlobalTestControl",
167
        tgctrl    = "TreeGrowControl"
168
    ),
169
    prototype = list(varctrl = new("VariableControl"),
170
                     splitctrl = new("SplitControl"),
171
                     gtctrl = new("GlobalTestControl"),
172
                     tgctrl = new("TreeGrowControl")
173
    ),
174
    validity = function(object) {
175
        (validObject(object@varctrl) && 
176
        validObject(object@splitctrl)) &&
177
        (validObject(object@gtctrl) &&
178
        validObject(object@tgctrl))
179
    }
180
)
181
182
setClass(Class = "ForestControl",
183
    representation = representation(
184
        ntree    = "integer",
185
        replace  = "logical",
186
        fraction = "numeric",
187
        trace    = "logical",
188
        dropcriterion = "logical",
189
    compress = "function",
190
    expand = "function"),
191
    contains = "TreeControl",
192
    validity = function(object) {
193
        if (object@ntree < 1) {
194
            warning(sQuote("ntree"), " must be equal or greater 1")
195
            return(FALSE)
196
        }
197
        if (object@fraction < 0.01 || object@fraction > 0.99) {
198
            warning(sQuote("fraction"), " must be in (0.01, 0.99)")
199
            return(FALSE)
200
        }
201
        return(TRUE)
202
    }
203
)
204
205
setClass(Class = "VariableFrame",
206
    representation = representation(
207
        variables       = "df_OR_list", 
208
        transformations = "list", 
209
        is_nominal      = "logical", 
210
        is_ordinal      = "logical",
211
        is_censored     = "logical",
212
        ordering        = "list", 
213
        levels          = "list", 
214
        scores          = "list",
215
        has_missings    = "logical", 
216
        whichNA         = "list",
217
        nobs            = "integer",
218
        ninputs         = "integer")
219
)
220
221
setClass(Class = "ResponseFrame",
222
    representation = representation(
223
        test_trafo = "matrix",
224
        predict_trafo = "matrix"
225
    ), contains = "VariableFrame"
226
)   
227
228
setClass(Class = "LearningSample",
229
    representation = representation(
230
        responses = "ResponseFrame",
231
        inputs    = "VariableFrame",
232
        weights   = "numeric",
233
        nobs      = "integer",
234
        ninputs   = "integer"
235
    )
236
)
237
238
setClass(Class = "LearningSampleFormula",
239
    representation = representation(
240
        menv      = "ModelEnv"
241
    ), contains = "LearningSample"
242
)
243
244
### the tree structure itself is a list, 
245
### and we need to make sure that the tree slot excepts
246
### the S3 classes. 
247
setClass(Class = "SplittingNode", contains = "list")
248
setClass(Class = "TerminalNode", contains = "list")
249
setClass(Class = "TerminalModelNode", contains = "list")
250
setClass(Class = "orderedSplit", contains = "list")
251
setClass(Class = "nominalSplit", contains = "list")
252
253
### and we don't want to see warnings that class `Surv'
254
### (S3 method in `survival') is unknown
255
setClass(Class = "Surv", contains = "list")
256
257
258
### A class for partitions induced by recursive binary splits
259
setClass(Class = "BinaryTreePartition",
260
    representation = representation(
261
        tree     = "list",          # the basic tree structure as (named or
262
                                    # unnamed) list
263
        where    = "integer",       # the nodeID of the observations in the
264
                                    # learning sample
265
        weights  = "numeric"         # the weights in the root node
266
    ),
267
)
268
269
### A class for binary trees   
270
setClass(Class = "BinaryTree", 
271
    representation = representation(
272
        data                = "ModelEnv",
273
        responses           = "VariableFrame", # a list of response `variables'
274
                                               # for computing predictions
275
        cond_distr_response = "function",      # predict distribtion
276
        predict_response    = "function",      # predict responses
277
        prediction_weights  = "function",      # prediction weights
278
        get_where           = "function",      # node numbers
279
        update              = "function"       # update weights
280
    ),
281
    contains = "BinaryTreePartition"
282
)
283
284
### A class for random forest  
285
setClass(Class = "RandomForest", 
286
    representation = representation(
287
        ensemble            = "list",
288
        where               = "list",
289
        weights             = "list",
290
        initweights         = "numeric",
291
        data                = "ModelEnv",
292
        responses           = "VariableFrame", # a list of response `variables'
293
                                               # for computing predictions
294
        cond_distr_response = "function",      # predict distribtion
295
        predict_response    = "function",      # predict responses
296
        prediction_weights  = "function",      # prediction weights
297
        get_where           = "function",      # node numbers
298
    update              = "function",      # update weights
299
    expand              = "function"       # function to invert compress operation
300
    )
301
)
302