|
a |
|
b/combinedDeepLearningActiveContour/functions/mrTrain.m |
|
|
1 |
function [softmaxModel] = mrTrain(inputSize, outSize, lambda, inputData, labels, options) |
|
|
2 |
%softmaxTrain Train a softmax model with the given parameters on the given |
|
|
3 |
% data. Returns softmaxOptTheta, a vector containing the trained parameters |
|
|
4 |
% for the model. |
|
|
5 |
% |
|
|
6 |
% inputSize: the size of an input vector x^(i) |
|
|
7 |
% numClasses: the number of classes |
|
|
8 |
% lambda: weight decay parameter |
|
|
9 |
% inputData: an N by M matrix containing the input data, such that |
|
|
10 |
% inputData(:, c) is the cth input |
|
|
11 |
% labels: M by 1 matrix containing the class labels for the |
|
|
12 |
% corresponding inputs. labels(c) is the class label for |
|
|
13 |
% the cth input |
|
|
14 |
% options (optional): options |
|
|
15 |
% options.maxIter: number of iterations to train for |
|
|
16 |
|
|
|
17 |
if ~exist('options', 'var') |
|
|
18 |
options = struct; |
|
|
19 |
end |
|
|
20 |
|
|
|
21 |
if ~isfield(options, 'maxIter') |
|
|
22 |
options.maxIter = 400; |
|
|
23 |
end |
|
|
24 |
|
|
|
25 |
% initialize parameters |
|
|
26 |
theta = 0.005 * randn(outSize * inputSize, 1); |
|
|
27 |
|
|
|
28 |
% Use minFunc to minimize the function |
|
|
29 |
addpath minFunc/ |
|
|
30 |
options.Method = 'lbfgs'; % Here, we use L-BFGS to optimize our cost |
|
|
31 |
% function. Generally, for minFunc to work, you |
|
|
32 |
% need a function pointer with two outputs: the |
|
|
33 |
% function value and the gradient. In our problem, |
|
|
34 |
% softmaxCost.m satisfies this. |
|
|
35 |
options.display = 'on'; |
|
|
36 |
|
|
|
37 |
[softmaxOptTheta, cost] = minFunc( @(p) mrCost(p, ... |
|
|
38 |
outSize, inputSize, lambda, ... |
|
|
39 |
inputData, labels), ... |
|
|
40 |
theta, options); |
|
|
41 |
|
|
|
42 |
% Fold softmaxOptTheta into a nicer format |
|
|
43 |
softmaxModel.optTheta = reshape(softmaxOptTheta, outSize, inputSize); |
|
|
44 |
softmaxModel.inputSize = inputSize; |
|
|
45 |
softmaxModel.numClasses = outSize; |
|
|
46 |
|
|
|
47 |
end |