Switch to unified view

a b/Ensemble Learning/Boosting/gradient_boosting_predict.m
1
% Practicum, Task #3, 'Compositions of algorithms'.
2
%
3
% FUNCTION:
4
% [prediction, err] = gradient_boosting_predict (model, X, y)
5
%
6
% DESCRIPTION:
7
% This function use the composition of algorithms, trained with gradient 
8
% boosting method, for prediction.
9
%
10
% INPUT:
11
% X --- matrix of objects, N x K double matrix, N --- number of objects, 
12
%       K --- number of features.
13
% y --- vector of answers, N x 1 double vector, N --- number of objects. y
14
%       can have only two values --- +1 and -1.
15
% model --- trained composition.
16
%
17
% OUTPUT:
18
% prediction --- vector of predicted answers, N x 1 double vector.
19
% error --- the ratio of number of correct answers to number of objects on
20
%           each iteration, num_iterations x 1 vector
21
%
22
% AUTHOR: 
23
% Murat Apishev (great-mel@yandex.ru)
24
%
25
26
function [prediction, err] = gradient_boosting_predict (model, X, y)
27
28
    num_iterations = length(model.weights);
29
    no_objects = length(y);
30
    pred_prediction = zeros([no_objects num_iterations]);
31
32
    for alg = 1 : num_iterations
33
        value = zeros([no_objects 1]) + model.b_0;
34
        for i = 1 : alg
35
            if strcmp(model.algorithm, 'epsilon_svr')
36
                value = value + svmpredict(y, X, model.models{i}) * model.weights(i);
37
            elseif strcmp(model.algorithm, 'regression_tree')
38
                value = value + predict(model.models{i}, X) * model.weights(i);
39
            end
40
        end
41
        pred_prediction(:,alg) = value;
42
    end
43
    prediction = pred_prediction(:,end);
44
    err = zeros([num_iterations 1]);
45
    if strcmp(model.loss, 'absolute')
46
        temp = (bsxfun(@minus, pred_prediction, y));
47
        err = abs(sum(temp)) / no_objects;
48
    elseif strcmp(model.loss, 'logistic')
49
        prediction = sign(prediction);
50
        temp = (bsxfun(@eq, sign(pred_prediction), y));
51
        err = sum(temp == 0) / no_objects;
52
    end
53
    if size(err, 1) == 1
54
        err = err';
55
    end
56
end