
classification/SMOTEBoost/SMOTEBoost.m
function prediction = SMOTEBoost (TRAIN,TEST,WeakLearn,ClassDist)
% This function implements the SMOTEBoost algorithm. For more details on
% the theoretical description of the algorithm please refer to the
% following paper:
% N.V. Chawla, A. Lazarevic, L.O. Hall, K.W. Bowyer, "SMOTEBoost:
% Improving Prediction of the Minority Class in Boosting", in Knowledge
% Discovery in Databases: PKDD 2003.
% Input: TRAIN = Training data as a matrix
%        TEST = Test data as a matrix
%        WeakLearn = String to choose the algorithm. Choices are
%                    'svm','tree','knn' and 'logistic'.
%        ClassDist = true or false. true indicates that the class
%                    distribution is maintained while doing weighted
%                    resampling and before SMOTE is called at each
%                    iteration. false indicates that the class distribution
%                    is not maintained while resampling.
% Output: prediction = size(TEST,1) x 2 matrix. Col 1 is the predicted
%                      class label for each instance. Col 2 is the
%                      probability of the instance being classified as
%                      the positive class.
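%
% Example usage (a minimal sketch; it assumes weka.jar and the companion
% CSVtoARFF.m are on the MATLAB path, that the last column of TRAIN/TEST
% holds the 0/1 class label, and that 1 is the minority class; the data
% file name is hypothetical):
%   data  = csvread('fraud.csv');
%   idx   = randperm(size(data,1));
%   ntr   = floor(0.7*size(data,1));
%   TRAIN = data(idx(1:ntr),:);
%   TEST  = data(idx(ntr+1:end),:);
%   prediction = SMOTEBoost(TRAIN,TEST,'tree',false);
%   labels = prediction(:,1); % predicted class labels
%   scores = prediction(:,2); % weighted vote for the positive class
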
javaaddpath('weka.jar');

%% Training SMOTEBoost
% Total number of instances in the training set
m = size(TRAIN,1);
POS_DATA = TRAIN(TRAIN(:,end)==1,:);
NEG_DATA = TRAIN(TRAIN(:,end)==0,:);
pos_size = size(POS_DATA,1);
neg_size = size(NEG_DATA,1);

% Reorganize TRAIN by putting all the positive and negative examples
% together, respectively.
TRAIN = [POS_DATA;NEG_DATA];

% Converting the training set into Weka-compatible format
CSVtoARFF (TRAIN, 'train', 'train');
train_reader = javaObject('java.io.FileReader', 'train.arff');
train = javaObject('weka.core.Instances', train_reader);
train.setClassIndex(train.numAttributes() - 1);
    
% Total number of iterations of the boosting method
T = 10;

% W stores the weights of the instances in each row for every iteration
% of boosting. Weights for all the instances are initialized to 1/m for
% the first iteration.
W = ones(1,m)/m;

% L stores the pseudo-loss values, H stores the hypotheses, and B stores
% the log(1/beta) values that are used as the weights of the hypotheses
% when forming the final hypothesis. All of the following are of length
% <= T and store values for every iteration of the boosting process.
L = [];
H = {};
B = [];

% Loop counter
t = 1;

% Keeps count of the number of times the same boosting iteration has been
% repeated
count = 0;

% Boosting T iterations
while t <= T

    % LOG MESSAGE
    disp (['Boosting iteration #' int2str(t)]);

    if ClassDist == true
        % Resampling POS_DATA with the weights of the positive examples
        POS_WT = zeros(1,pos_size);
        sum_POS_WT = sum(W(t,1:pos_size));
        for i = 1:pos_size
           POS_WT(i) = W(t,i)/sum_POS_WT;
        end
        RESAM_POS = POS_DATA(randsample(1:pos_size,pos_size,true,POS_WT),:);

        % Resampling NEG_DATA with the weights of the negative examples
        NEG_WT = zeros(1,neg_size);
        sum_NEG_WT = sum(W(t,pos_size+1:m));
        for i = 1:neg_size
           NEG_WT(i) = W(t,pos_size+i)/sum_NEG_WT;
        end
        RESAM_NEG = NEG_DATA(randsample(1:neg_size,neg_size,true,NEG_WT),:);

        % Resampled TRAIN is stored in RESAMPLED
        RESAMPLED = [RESAM_POS;RESAM_NEG];

        % Calculating the percentage by which to boost the positive class.
        % 'pert' is used as a parameter of SMOTE.
        pert = ((neg_size-pos_size)/pos_size)*100;
    else
        % Indices of the resampled train set
        RND_IDX = randsample(1:m,m,true,W(t,:));

        % Resampled TRAIN is stored in RESAMPLED
        RESAMPLED = TRAIN(RND_IDX,:);

        % Calculating the percentage by which to boost the positive class.
        % 'pert' is used as a parameter of SMOTE.
        pos_size = sum(RESAMPLED(:,end)==1);
        neg_size = sum(RESAMPLED(:,end)==0);
        pert = ((neg_size-pos_size)/pos_size)*100;
    end
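
    % Worked example for 'pert' (illustrative numbers, not from the
    % paper): with 100 positive and 400 negative instances,
    % pert = ((400-100)/100)*100 = 300, i.e. SMOTE is asked to generate
    % 300% additional synthetic minority instances, roughly balancing the
    % two classes.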
    
    % Converting the resampled training set into Weka-compatible format
    CSVtoARFF (RESAMPLED,'resampled','resampled');
    reader = javaObject('java.io.FileReader','resampled.arff');
    resampled = javaObject('weka.core.Instances',reader);
    resampled.setClassIndex(resampled.numAttributes()-1);

    % New SMOTE boosted data gets stored in S
    smote = javaObject('weka.filters.supervised.instance.SMOTE');
    smote.setPercentage(pert);
    smote.setInputFormat(resampled);

    S = weka.filters.Filter.useFilter(resampled, smote);
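
    % Note: in Weka's SMOTE filter, setPercentage(p) requests p% new
    % synthetic minority instances (so p = 100 doubles the minority
    % class), and the filter uses 5 nearest neighbors by default.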
    
    % Training a weak learner. 'pred' holds the weak hypothesis'
    % predictions on the training set; the hypothesis function itself is
    % encoded in 'model'.
    switch WeakLearn
        case 'svm'
            model = javaObject('weka.classifiers.functions.SMO');
        case 'tree'
            model = javaObject('weka.classifiers.trees.J48');
        case 'knn'
            model = javaObject('weka.classifiers.lazy.IBk');
            model.setKNN(5);
        case 'logistic'
            model = javaObject('weka.classifiers.functions.Logistic');
    end
    model.buildClassifier(S);

    pred = zeros(m,1);
    for i = 0 : m - 1
        % Weka instances are 0-indexed; MATLAB arrays are 1-indexed
        pred(i+1) = model.classifyInstance(train.instance(i));
    end

    % Computing the pseudo-loss (the weighted training error) of
    % hypothesis 'model'
    loss = 0;
    for i = 1:m
        if TRAIN(i,end)==pred(i)
            continue;
        else
            loss = loss + W(t,i);
        end
    end
    
    % If count exceeds a pre-defined threshold (5 in the current
    % implementation), the loop is aborted and the ensemble is rolled back
    % to the last state in which loss > 0.5 had not been encountered.
    if count > 5
       L = L(1:t-1);
       H = H(1:t-1);
       B = B(1:t-1);
       disp ('          Too many iterations have loss > 0.5');
       disp ('          Aborting boosting...');
       break;
    end
    
    % If the loss is greater than 1/2, an inverted hypothesis would
    % perform better. In such cases, do not take that hypothesis into
    % consideration and repeat the same iteration. 'count' keeps count of
    % the number of times the same boosting iteration has been repeated.
    if loss > 0.5
        count = count + 1;
        continue;
    else
        count = 0; % Resetting the retry counter after a good iteration
    end
    
    L(t) = loss;          % Pseudo-loss at each iteration
    H{t} = model;         % Hypothesis function
    beta = loss/(1-loss); % Setting the weight update parameter 'beta'
    B(t) = log(1/beta);   % Weight of the hypothesis
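
    % This mirrors the AdaBoost weight-update scheme: with weighted error
    % eps_t = loss, beta_t = eps_t/(1-eps_t) < 1, the weights of correctly
    % classified instances are multiplied by beta_t below, so misclassified
    % instances gain relative weight, and log(1/beta_t) becomes the vote
    % of hypothesis t in the final weighted majority.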
    
    % At the final iteration there is no need to update the weights any
    % further
    if t==T
        break;
    end

    % Updating the weights
    for i = 1:m
        if TRAIN(i,end)==pred(i)
            W(t+1,i) = W(t,i)*beta;
        else
            W(t+1,i) = W(t,i);
        end
    end

    % Normalizing the weights for the next iteration
    sum_W = sum(W(t+1,:));
    for i = 1:m
        W(t+1,i) = W(t+1,i)/sum_W;
    end

    % Incrementing the loop counter
    t = t + 1;
end

% The final hypothesis is computed and tested on the test set
% simultaneously.

%% Testing SMOTEBoost
n = size(TEST,1); % Total number of instances in the test set

CSVtoARFF(TEST,'test','test');
test_reader = javaObject('java.io.FileReader', 'test.arff');
test = javaObject('weka.core.Instances', test_reader);
test.setClassIndex(test.numAttributes() - 1);

% Normalizing B so that the hypothesis weights sum to 1
B = B/sum(B);
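
% The final hypothesis is a weighted majority vote: each test instance
% receives the label y that maximizes sum_j B(j)*[H{j}(x) = y], and the
% vote for the positive class is reported as a confidence score.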

prediction = zeros(n,2);

for i = 1:n
    % Calculating the total weight of each class label from all the
    % models produced during boosting
    wt_zero = 0;
    wt_one = 0;
    for j = 1:size(H,2)
       p = H{j}.classifyInstance(test.instance(i-1));
       if p==1
           wt_one = wt_one + B(j);
       else
           wt_zero = wt_zero + B(j);
       end
    end

    % Col 2 always reports wt_one, the weighted vote for the positive
    % class, regardless of the predicted label
    if (wt_one > wt_zero)
        prediction(i,:) = [1 wt_one];
    else
        prediction(i,:) = [0 wt_one];
    end
end