Diff of /ClassificationCNN.m [000000] .. [e6b7a4]

Switch to unified view

a b/ClassificationCNN.m
1
clear;clc;
2
%% 载入数据;
3
fprintf('Loading data...\n');
4
tic;
5
load('N_dat.mat');
6
load('L_dat.mat');
7
load('R_dat.mat');
8
load('V_dat.mat');
9
fprintf('Finished!\n');
10
toc;
11
fprintf('=============================================================\n');
12
%% 控制使用数据量,每一类5000,并生成标签,one-hot编码;
13
fprintf('Data preprocessing...\n');
14
tic;
15
Nb=Nb(1:5000,:);Label1=repmat([1;0;0;0],1,5000);
16
Vb=Vb(1:5000,:);Label2=repmat([0;1;0;0],1,5000);
17
Rb=Rb(1:5000,:);Label3=repmat([0;0;1;0],1,5000);
18
Lb=Lb(1:5000,:);Label4=repmat([0;0;0;1],1,5000);
19
20
Data=[Nb;Vb;Rb;Lb];
21
Label=[Label1,Label2,Label3,Label4];
22
23
clear Nb;clear Label1;
24
clear Rb;clear Label2;
25
clear Lb;clear Label3;
26
clear Vb;clear Label4;
27
Data=Data-repmat(mean(Data,2),1,250); %使信号的均值为0,去掉基线的影响;
28
fprintf('Finished!\n');
29
toc;
30
fprintf('=============================================================\n');
31
32
%% 数据划分与模型训练测试;
33
fprintf('Model training and testing...\n');
34
Nums=randperm(20000);      %随机打乱样本顺序,达到随机选择训练测试样本的目的;
35
train_x=Data(Nums(1:10000),:);
36
test_x=Data(Nums(10001:end),:);
37
train_y=Label(:,Nums(1:10000));
38
test_y=Label(:,Nums(10001:end));
39
train_x=train_x';
40
test_x=test_x';
41
42
cnn.layers = {
43
    struct('type', 'i') %input layer
44
    struct('type', 'c', 'outputmaps', 4, 'kernelsize', 31,'actv','relu') %convolution layer
45
    struct('type', 's', 'scale', 5,'pool','mean') %sub sampling layer
46
    struct('type', 'c', 'outputmaps', 8, 'kernelsize', 6,'actv','relu') %convolution layer
47
    struct('type', 's', 'scale', 3,'pool','mean') %subsampling layer
48
};
49
cnn.output = 'softmax';  %确定cnn结构;
50
                         %确定超参数;
51
opts.alpha = 0.01;       %学习率;
52
opts.batchsize = 16;     %batch块大小;
53
opts.numepochs = 30;     %迭代epoch;
54
55
cnn = cnnsetup1d(cnn, train_x, train_y);      %建立1D CNN;
56
cnn = cnntrain1d(cnn, train_x, train_y,opts); %训练1D CNN;
57
[er,bad,out] = cnntest1d(cnn, test_x, test_y);%测试1D CNN;
58
59
[~,ptest]=max(out,[],1);
60
[~,test_yt]=max(test_y,[],1);
61
62
Correct_Predict=zeros(1,4);                     %统计各类准确率;
63
Class_Num=zeros(1,4);                           %并得到混淆矩阵;
64
Conf_Mat=zeros(4);
65
for i=1:10000
66
    Class_Num(test_yt(i))=Class_Num(test_yt(i))+1;
67
    Conf_Mat(test_yt(i),ptest(i))=Conf_Mat(test_yt(i),ptest(i))+1;
68
    if ptest(i)==test_yt(i)
69
        Correct_Predict(test_yt(i))= Correct_Predict(test_yt(i))+1;
70
    end
71
end
72
73
ACCs=Correct_Predict./Class_Num;
74
fprintf('Accuracy = %.2f%%\n',(1-er)*100);
75
fprintf('Accuracy_N = %.2f%%\n',ACCs(1)*100);
76
fprintf('Accuracy_V = %.2f%%\n',ACCs(2)*100);
77
fprintf('Accuracy_R = %.2f%%\n',ACCs(3)*100);
78
fprintf('Accuracy_L = %.2f%%\n',ACCs(4)*100);