Switch to unified view

a b/Semantic Features/CrossValLearn.m
1
function [ classValue, regError, successRate, trainedStruct ] = CrossValLearn2(X, Y, trainFunc, evalFunc)
2
%CrossValLearn does cross validation learning 
3
%  You need to supply it the function used to run the training and make the
4
%  prediction. Hardcoded to do 10 folds
5
6
%Perform random sampling by just jumbling up the data then slicing the new
7
%set into 4ths or nths.
8
divisions = 10;
9
numSamples = size(X,1);
10
testSize = round(numSamples/divisions);
11
12
%get a random order of our rows
13
randomRows = randsample(numSamples, numSamples);
14
15
%get vector of row order to undo the scrambling of the rows
16
for i = 1:numSamples
17
    restoreRows(i) = find(i == randomRows);
18
end
19
20
Xmixed = X(randomRows,:);
21
Ymixed = Y(randomRows,:);
22
23
%perform process repeatedly with the test set different each time untill
24
%all are covered.
25
classValue = 0;
26
testrows = cell(divisions,1);
27
trainedStruct = cell(divisions,1);
28
for i = 1:(divisions - 1) %perform all iterations guaranteeed to have a full share
29
    %start with testing at the beginning rows, then cycle down
30
    testrows{i} = [(i-1)*testSize + 1:i*testSize];
31
    
32
    Xtest = Xmixed(testrows{i}, :);
33
    Ytest = Ymixed(testrows{i}, :);
34
    
35
    Xtrain = Xmixed;
36
    Xtrain(testrows{i},:) = [];
37
    Ytrain = Ymixed;
38
    Ytrain(testrows{i},:) = [];
39
    
40
    trainedStruct{i} = trainFunc(Xtrain, Ytrain);
41
    classValue = vertcat(classValue, evalFunc(Xtest, trainedStruct{i}));
42
end
43
%collect all the remaining rows. Could be undersized, but eliminates
44
%problems of some rows getting lost
45
testrows{divisions} = [(divisions-1)*testSize + 1:numSamples];
46
    
47
Xtest = Xmixed(testrows{divisions}, :);
48
Ytest = Ymixed(testrows{divisions}, :);
49
50
Xtrain = Xmixed;
51
Xtrain(testrows{divisions},:) = [];
52
Ytrain = Ymixed;
53
Ytrain(testrows{divisions},:) = [];
54
    
55
trainedStruct{divisions} = trainFunc(Xtrain, Ytrain);
56
classValue = vertcat(classValue(2:end,:), evalFunc(Xtest, trainedStruct{divisions})); %Chop off the zero we put at the beginning
57
58
%Resort everything to the original order so we can compare against other
59
%algorithms
60
classValue = classValue(restoreRows,:);
61
62
%perform RMSE on allll the samples
63
regError = RMSE(classValue, Y); %RMSE error. Maybe better as an array so we can combine in the future
64
65
successRate = sum(round(classValue) == round(Y)) / size(Y,1);
66
67
68
end
69
70