Switch to unified view

a b/Ensemble Learning/Bagging/bagging_train.m
1
% Practicum, Task #3, 'Compositions of algorithms'.
2
%
3
% FUNCTION:
4
% [model] = bagging_train (X, y, num_iterations, base_algorithm, ...
5
%                   param_name1, param_value1, param_name2, param_value2)
6
%
7
% DESCRIPTION:
8
% This function trains a composition of algorithms using the bagging method.
9
%
10
% INPUT:
11
% X --- matrix of objects, N x K double matrix, N --- number of objects, 
12
%       K --- number of features.
13
% y --- vector of answers, N x 1 double vector, N --- number of objects. y
14
%       can have only two values --- +1 and -1.
15
% num_iterations --- the number of algorithms in the composition, scalar.
16
% base_algorithm --- the base algorithm, string. Can have one of two
17
%                    values: 'classification_tree' or 'svm'.
18
% param_name1 --- parameter of base_algorithm. For 'classification_tree' it 
19
%            is a 'min_parent' --- min number of objects in a branch node of
20
%            classification tree. For 'svm' it is 'gamma' parameter.
21
% param_name2 --- parameter, that exists only for 'svm', it is a 'C' 
22
%                 parameter.
23
% param_value1, param_value2 --- values of the corresponding parameters,
24
%                                scalar.
25
% OUTPUT:
26
% model --- trained composition, structure with two fields
27
%       - models --- cell array with trained models
28
%       - algorithm --- string, 'svm' or 'classification_tree'
29
%
30
% AUTHOR: 
31
% Murat Apishev (great-mel@yandex.ru)
32
%
33
34
function [model] = bagging_train (X, y, num_iterations, base_algorithm, ...
                        param_name1, param_value1, param_name2, param_value2)
% BAGGING_TRAIN Train a bagging composition of SVMs or classification trees.
%
% Every base model is fitted on the distinct rows selected by one bootstrap
% draw: size(X, 1) indices are drawn uniformly with replacement and the
% duplicates are then removed.
%
% INPUT:
% X --- objects, one row per object.
% y --- answers, indexed with the same row indices as X.
% num_iterations --- number of base models to train, scalar.
% base_algorithm --- 'svm' or 'classification_tree'.
% param_name1, param_value1, param_name2, param_value2 ---
%     'svm': if param_name1 is 'gamma', param_value1 is passed as -g and
%         param_value2 as -c; for any other param_name1 the two values are
%         swapped. param_name2 itself is never inspected.
%     'classification_tree': param_value1 is passed as 'MinParent', capped
%         at the number of distinct sampled objects. The remaining three
%         arguments are not used and may be omitted.
%
% OUTPUT:
% model --- structure with two fields:
%       - models --- 1 x num_iterations cell array of trained base models
%       - algorithm --- base_algorithm, as given
%
% Raises an error for any other value of base_algorithm.

    no_objects = size(X, 1);
    models = cell([1 num_iterations]);

    if strcmp(base_algorithm, 'svm')
        % The caller may give gamma and C in either order; only the first
        % name decides which value is which.
        if strcmp(param_name1, 'gamma')
            gamma_value = param_value1;
            c_value = param_value2;
        else
            gamma_value = param_value2;
            c_value = param_value1;
        end

        % %.17g keeps the full double precision of both values. The default
        % num2str formatting rounds non-integers to a few significant
        % digits, which would silently train with different parameters.
        svm_options = sprintf(' -g %.17g -c %.17g', gamma_value, c_value);

        for iter = 1 : num_iterations
            indices = bootstrap_indices(no_objects);
            models{iter} = svmtrain(y(indices), X(indices,:), svm_options);
        end
    elseif strcmp(base_algorithm, 'classification_tree')
        for iter = 1 : num_iterations
            indices = bootstrap_indices(no_objects);
            % MinParent must not exceed the size of the training sample.
            min_parent = min(param_value1, length(indices));
            models{iter} = ClassificationTree.fit(X(indices,:), y(indices), ...
                'MinParent', min_parent);
        end
    else
        error('Incorrect type of algorithm!');
    end

    model.models = models;
    model.algorithm = base_algorithm;
end

function indices = bootstrap_indices (no_objects)
% Sorted, distinct row indices hit by one bootstrap draw of no_objects
% indices taken uniformly with replacement from 1..no_objects.
    indices = unique(randi(no_objects, 1, no_objects));
end