[e6b7a4]: / ClassificationCNN.m

Download this file

79 lines (70 with data), 2.8 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
% Start from a clean workspace and an empty command window.
clear;
clc;

%% Load data
% Each .mat file is loaded straight into the base workspace; the later
% sections of this script read the variables Nb, Lb, Rb and Vb.
fprintf('Loading data...\n');
tic;
load N_dat.mat
load L_dat.mat
load R_dat.mat
load V_dat.mat
fprintf('Finished!\n');
toc;   % report the time spent loading
fprintf('=============================================================\n');
%% Limit each class to a fixed number of samples and build one-hot labels
% Inputs (from the workspace): Nb, Vb, Rb, Lb - one sample per row.
% Outputs: Data  - the four classes stacked row-wise in the order N, V, R, L,
%                  with each row's mean removed;
%          Label - one one-hot column per row of Data, in the same order.
fprintf('Data preprocessing...\n');
tic;
numPerClass = 5000;   % samples kept from every class
Nb = Nb(1:numPerClass,:); Label1 = repmat([1;0;0;0],1,numPerClass);
Vb = Vb(1:numPerClass,:); Label2 = repmat([0;1;0;0],1,numPerClass);
Rb = Rb(1:numPerClass,:); Label3 = repmat([0;0;1;0],1,numPerClass);
Lb = Lb(1:numPerClass,:); Label4 = repmat([0;0;0;1],1,numPerClass);
Data = [Nb;Vb;Rb;Lb];
Label = [Label1,Label2,Label3,Label4];
% The per-class arrays are no longer needed once they have been stacked.
clear Nb Vb Rb Lb Label1 Label2 Label3 Label4 numPerClass;
% Subtract each row's mean so every signal is zero-mean, removing the
% baseline offset. bsxfun expands the column of means across all columns,
% so this works for any signal length instead of a hard-coded 250 samples.
Data = bsxfun(@minus, Data, mean(Data,2));
fprintf('Finished!\n');
toc;
fprintf('=============================================================\n');
%% Train/test split, model training and evaluation
% Inputs (from the workspace): Data (one sample per row) and Label (one
% one-hot column per sample). Half of the samples, chosen at random, are
% used for training and the remainder for testing.
fprintf('Model training and testing...\n');
numSamples = size(Data,1);
numTrain = floor(numSamples/2);     % 10000 for the 20000-sample data set
Nums = randperm(numSamples);        % random permutation => random split
trainIdx = Nums(1:numTrain);
testIdx = Nums(numTrain+1:end);
% Samples are transposed so that each column holds one signal.
train_x = Data(trainIdx,:)';
test_x = Data(testIdx,:)';
train_y = Label(:,trainIdx);
test_y = Label(:,testIdx);

% Network structure.
cnn.layers = {
    struct('type', 'i') %input layer
    struct('type', 'c', 'outputmaps', 4, 'kernelsize', 31,'actv','relu') %convolution layer
    struct('type', 's', 'scale', 5,'pool','mean') %sub sampling layer
    struct('type', 'c', 'outputmaps', 8, 'kernelsize', 6,'actv','relu') %convolution layer
    struct('type', 's', 'scale', 3,'pool','mean') %subsampling layer
};
cnn.output = 'softmax';

% Hyper-parameters.
% NOTE(review): some CNN toolbox trainers require batchsize to divide the
% number of training samples evenly - confirm for cnntrain1d before
% changing the data-set size.
opts.alpha = 0.01;      % learning rate
opts.batchsize = 16;    % mini-batch size
opts.numepochs = 30;    % number of training epochs

cnn = cnnsetup1d(cnn, train_x, train_y);        % build the 1-D CNN
cnn = cnntrain1d(cnn, train_x, train_y,opts);   % train the 1-D CNN
[er,bad,out] = cnntest1d(cnn, test_x, test_y);  % test the 1-D CNN

% Convert network outputs and one-hot labels to class indices.
[~,ptest] = max(out,[],1);
[~,test_yt] = max(test_y,[],1);

% Confusion matrix: rows are true classes, columns are predicted classes.
numClasses = size(test_y,1);
Conf_Mat = accumarray([test_yt(:), ptest(:)], 1, [numClasses numClasses]);
Class_Num = sum(Conf_Mat,2).';          % test samples per true class
Correct_Predict = diag(Conf_Mat).';     % correctly classified per class
ACCs = Correct_Predict./Class_Num;      % per-class accuracy

fprintf('Accuracy = %.2f%%\n',(1-er)*100);
fprintf('Accuracy_N = %.2f%%\n',ACCs(1)*100);
fprintf('Accuracy_V = %.2f%%\n',ACCs(2)*100);
fprintf('Accuracy_R = %.2f%%\n',ACCs(3)*100);
fprintf('Accuracy_L = %.2f%%\n',ACCs(4)*100);