|
a |
|
b/deep_learning_multiscale_lstm.m |
|
|
1 |
% Multiscale LSTM combining fine-grained short-term dynamics (each 6 h time |
|
|
2 |
% block) and coarst long-term dynsmics (all past recordings up to now). |
|
|
3 |
% Wei-Long Zheng, MGH |
|
|
4 |
% Email: weilonglive@gmail.com |
|
|
5 |
|
|
|
6 |
% short term lstms |
|
|
7 |
clear |
|
|
8 |
close all |
|
|
9 |
load('feature_sequences\all_features_sequences.mat') |
|
|
10 |
save_path = ('multi_scale\'); |
|
|
11 |
for num_neuron = 30 |
|
|
12 |
preds_all = {}; |
|
|
13 |
labels_all = {}; |
|
|
14 |
pred_probability_all = {}; |
|
|
15 |
net_model_all = {}; |
|
|
16 |
pts_id_all = {}; |
|
|
17 |
|
|
|
18 |
for itrial = 1:10 |
|
|
19 |
time_step = 6; |
|
|
20 |
time_range = 12:time_step:96; |
|
|
21 |
num_fold = 5; |
|
|
22 |
|
|
|
23 |
preds = cell(length(time_range),num_fold); |
|
|
24 |
labels = cell(length(time_range),num_fold); |
|
|
25 |
pred_probability = cell(length(time_range),num_fold); |
|
|
26 |
net_model = cell(length(time_range),num_fold); |
|
|
27 |
for i = length(time_range)-5%1:length(time_range) |
|
|
28 |
|
|
|
29 |
% short sequneces |
|
|
30 |
bs_x = bs(:,time_range(i)-time_step+1:time_range(i),:); |
|
|
31 |
bs_x = permute(bs_x,[1,3,2]); |
|
|
32 |
bs_x = reshape(bs_x,[size(bs_x,1),size(bs_x,2)*size(bs_x,3)]); |
|
|
33 |
spike_x = spike(:,time_range(i)-time_step+1:time_range(i),:); |
|
|
34 |
spike_x = permute(spike_x,[1,3,2]); |
|
|
35 |
spike_x = reshape(spike_x,[size(spike_x,1),size(spike_x,2)*size(spike_x,3)]); |
|
|
36 |
x = cat(3, bs_x, spike_x); |
|
|
37 |
|
|
|
38 |
for ifea = 1:7 |
|
|
39 |
feature_tmp = features{ifea}(:,time_range(i)-time_step+1:time_range(i),:); |
|
|
40 |
feature_tmp = permute(feature_tmp,[1,3,2]); |
|
|
41 |
feature_tmp = reshape(feature_tmp,[size(feature_tmp,1),size(feature_tmp,2)*size(feature_tmp,3)]); |
|
|
42 |
if ifea==4||ifea==5||ifea==6||ifea==7 |
|
|
43 |
feature_tmp = 10*log10(feature_tmp); |
|
|
44 |
end |
|
|
45 |
x = cat(3, x, feature_tmp); |
|
|
46 |
end |
|
|
47 |
|
|
|
48 |
% long sequences |
|
|
49 |
bs_x = bs(:,time_range(1)-time_step+1:time_range(i),:); |
|
|
50 |
bs_x = permute(bs_x,[1,3,2]); |
|
|
51 |
bs_x = reshape(bs_x,[size(bs_x,1),size(bs_x,2)*size(bs_x,3)]); |
|
|
52 |
spike_x = spike(:,time_range(1)-time_step+1:time_range(i),:); |
|
|
53 |
spike_x = permute(spike_x,[1,3,2]); |
|
|
54 |
spike_x = reshape(spike_x,[size(spike_x,1),size(spike_x,2)*size(spike_x,3)]); |
|
|
55 |
x_long = cat(3, bs_x, spike_x); |
|
|
56 |
|
|
|
57 |
for ifea = 1:7 |
|
|
58 |
feature_tmp = features{ifea}(:,time_range(1)-time_step+1:time_range(i),:); |
|
|
59 |
feature_tmp = permute(feature_tmp,[1,3,2]); |
|
|
60 |
feature_tmp = reshape(feature_tmp,[size(feature_tmp,1),size(feature_tmp,2)*size(feature_tmp,3)]); |
|
|
61 |
if ifea==4||ifea==5||ifea==6||ifea==7 |
|
|
62 |
feature_tmp = 10*log10(feature_tmp); |
|
|
63 |
end |
|
|
64 |
x_long = cat(3, x_long, feature_tmp); |
|
|
65 |
end |
|
|
66 |
|
|
|
67 |
cpc_scores_binary = cpc_scores; |
|
|
68 |
pos_index = find(cpc_scores<3); |
|
|
69 |
neg_index = find(cpc_scores>=3); |
|
|
70 |
cpc_scores_binary(pos_index) = 1; |
|
|
71 |
cpc_scores_binary(neg_index) = 0; |
|
|
72 |
|
|
|
73 |
X = {}; |
|
|
74 |
X_long = {}; |
|
|
75 |
Y = {}; |
|
|
76 |
pts_id = {}; |
|
|
77 |
for ipts = 1:size(x,1) |
|
|
78 |
x_tmp = squeeze(x(ipts,:,:)); |
|
|
79 |
x_tmp = x_tmp'; |
|
|
80 |
x_long_tmp = squeeze(x_long(ipts,:,:)); |
|
|
81 |
x_long_tmp = x_long_tmp'; |
|
|
82 |
|
|
|
83 |
for itmp = 1:size(x_tmp,1) |
|
|
84 |
x_tmp(itmp,isinf(x_tmp(itmp,:))) = nan; |
|
|
85 |
x_long_tmp(itmp,isinf(x_long_tmp(itmp,:))) = nan; |
|
|
86 |
x_tmp(itmp,isnan(x_tmp(itmp,:))) = nanmean(x_tmp(itmp,:)); |
|
|
87 |
temp = x_long_tmp(itmp,end-time_step*12+1:end); |
|
|
88 |
temp(1,isnan(temp(1,:))) = nanmean(temp(1,:)); |
|
|
89 |
x_long_tmp(itmp,end-time_step*12+1:end) = temp; |
|
|
90 |
nan_index = find(isnan(x_long_tmp(itmp,:))); |
|
|
91 |
for jnan = length(nan_index):-1:1 |
|
|
92 |
if nan_index(jnan)+time_step*12<=size(x_long_tmp,2) |
|
|
93 |
x_long_tmp(itmp,nan_index(jnan)) = nanmean(x_long_tmp(itmp,nan_index(jnan)+1:nan_index(jnan)+time_step*12)); |
|
|
94 |
end |
|
|
95 |
end |
|
|
96 |
% index = find(isinf(x_long_tmp(itmp,:))); |
|
|
97 |
% for idex = 1:length(index) |
|
|
98 |
% if ~isinf(x_long_tmp(itmp,index(idex)-1)) |
|
|
99 |
% x_long_tmp(itmp,index(idex)) = x_long_tmp(itmp,index(idex)-1); |
|
|
100 |
% else |
|
|
101 |
% x_long_tmp(itmp,index(idex)) = x_long_tmp(itmp,index(idex)+1); |
|
|
102 |
% end |
|
|
103 |
% end |
|
|
104 |
end |
|
|
105 |
|
|
|
106 |
% reshape inputs of long term lstm with the same dimensions |
|
|
107 |
reduced_dim = 72*2; |
|
|
108 |
x_long_tmp_short = zeros(9,reduced_dim); |
|
|
109 |
for itmp = 1:size(x_long_tmp,1) |
|
|
110 |
if size(x_long_tmp,2)>reduced_dim |
|
|
111 |
n = fix(size(x_long_tmp,2)/72/2); % for every n points, generate 1 points |
|
|
112 |
b = arrayfun(@(i) mean(x_long_tmp(itmp,i:i+n-1),2),sort([1:size(x_long_tmp,2)/72:size(x_long_tmp,2),n:size(x_long_tmp,2)/72:size(x_long_tmp,2)])); % the averaged vector |
|
|
113 |
x_long_tmp_short(itmp,:) = b; |
|
|
114 |
end |
|
|
115 |
end |
|
|
116 |
|
|
|
117 |
if ~isnan(sum(sum(x_tmp)))&&~isinf(sum(sum(x_tmp))) |
|
|
118 |
X = [X; x_tmp]; |
|
|
119 |
X_long = [X_long; x_long_tmp_short]; |
|
|
120 |
Y = [Y; num2str(cpc_scores_binary(ipts))]; |
|
|
121 |
pts_id = [pts_id; unique_names{ipts}]; |
|
|
122 |
end |
|
|
123 |
end |
|
|
124 |
|
|
|
125 |
|
|
|
126 |
|
|
|
127 |
|
|
|
128 |
XV = [X{:}]; |
|
|
129 |
mu = mean(XV,2); |
|
|
130 |
sg = std(XV,[],2); |
|
|
131 |
X = cellfun(@(x)(x-mu)./sg,X,'UniformOutput',false); |
|
|
132 |
XV = [X_long{:}]; |
|
|
133 |
mu = mean(XV,2); |
|
|
134 |
sg = std(XV,[],2); |
|
|
135 |
X_long = cellfun(@(x)(x-mu)./sg,X_long,'UniformOutput',false); |
|
|
136 |
|
|
|
137 |
num_pts = length(Y); |
|
|
138 |
num_test = round(num_pts/num_fold); |
|
|
139 |
idx = randperm(num_pts); |
|
|
140 |
X = X(idx); |
|
|
141 |
X_long = X_long(idx); |
|
|
142 |
Y = Y(idx); |
|
|
143 |
pts_id = pts_id(idx); |
|
|
144 |
% Y = categorical(Y); |
|
|
145 |
for ifold = 1:num_fold |
|
|
146 |
if ifold~=num_fold |
|
|
147 |
start_index = (ifold-1)*num_test+1; |
|
|
148 |
end_index = ifold*num_test; |
|
|
149 |
else |
|
|
150 |
start_index = (ifold-1)*num_test+1; |
|
|
151 |
end_index = num_pts; |
|
|
152 |
end |
|
|
153 |
train_index = setdiff(1:num_pts,start_index:end_index); |
|
|
154 |
|
|
|
155 |
test_data_short = X(start_index:end_index); |
|
|
156 |
test_label = Y(start_index:end_index); |
|
|
157 |
train_data_short = X(train_index); |
|
|
158 |
train_label = Y(train_index); |
|
|
159 |
|
|
|
160 |
test_data_long = X_long(start_index:end_index); |
|
|
161 |
train_data_long = X_long(train_index); |
|
|
162 |
train_set = [train_data_short,train_data_long,train_label]; |
|
|
163 |
test_set = [test_data_short,test_data_long,test_label]; |
|
|
164 |
|
|
|
165 |
pts_id_test = pts_id(start_index:end_index); |
|
|
166 |
|
|
|
167 |
ipath_short = [save_path,num2str(i),'h','_train_short_',num2str(round(rand(1)*10e6)),'\']; |
|
|
168 |
mkdir(ipath_short); |
|
|
169 |
for itrain = 1:length(train_data_short) |
|
|
170 |
train_samples = train_data_short{itrain}; |
|
|
171 |
save([ipath_short,num2str(itrain)],'train_samples'); |
|
|
172 |
end |
|
|
173 |
ipath_long = [save_path,num2str(i),'h','_train_long_',num2str(round(rand(1)*10e6)),'\']; |
|
|
174 |
mkdir(ipath_long); |
|
|
175 |
for itrain = 1:length(train_data_long) |
|
|
176 |
train_samples = train_data_long{itrain}; |
|
|
177 |
save([ipath_long,num2str(itrain)],'train_samples'); |
|
|
178 |
end |
|
|
179 |
ipath_label = [save_path,num2str(i),'h','_train_label_',num2str(round(rand(1)*10e6)),'\']; |
|
|
180 |
mkdir(ipath_label); |
|
|
181 |
for itrain = 1:length(train_label) |
|
|
182 |
train_labels = train_label{itrain}; |
|
|
183 |
save([ipath_label,num2str(itrain)],'train_labels'); |
|
|
184 |
end |
|
|
185 |
|
|
|
186 |
fds_short = fileDatastore(ipath_short,'ReadFcn',@load_variable,'FileExtensions','.mat'); |
|
|
187 |
fds_long = fileDatastore(ipath_long,'ReadFcn',@load_variable,'FileExtensions','.mat'); |
|
|
188 |
fds_label = fileDatastore(ipath_label,'ReadFcn',@load_variable,'FileExtensions','.mat'); |
|
|
189 |
train_datastore = combine(fds_short,fds_long); |
|
|
190 |
|
|
|
191 |
%% LSTM |
|
|
192 |
miniBatchSize = 150; |
|
|
193 |
maxEpochs = 100; |
|
|
194 |
layers_short = [ ... |
|
|
195 |
sequenceInputLayer(9,'Name','InputLayer') |
|
|
196 |
sequenceFoldingLayer('Name','fold') |
|
|
197 |
splittingLayer('Splitting-1st','1st') |
|
|
198 |
bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm1_short') |
|
|
199 |
dropoutLayer(0.1,'Name','dropout1_short') |
|
|
200 |
bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm2_short') |
|
|
201 |
% dropoutLayer(0.1) |
|
|
202 |
bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm3_short') |
|
|
203 |
% dropoutLayer(0.1) |
|
|
204 |
bilstmLayer(num_neuron,'OutputMode','last','Name','lstm4_short') |
|
|
205 |
% dropoutLayer(0.1) |
|
|
206 |
fullyConnectedLayer(num_neuron,'Name','fc_short') |
|
|
207 |
concatenationLayer(1,2,'Name','cat') |
|
|
208 |
% additionLayer(2,'Name','add') |
|
|
209 |
fullyConnectedLayer(2,'Name','fc') |
|
|
210 |
softmaxLayer('Name','softmax_short') |
|
|
211 |
classificationLayer('Name','classOutput') |
|
|
212 |
]; |
|
|
213 |
layers_long = [ ... |
|
|
214 |
% sequenceInputLayer(9,'Name','input_long') |
|
|
215 |
splittingLayer('Splitting-2nd','2nd') |
|
|
216 |
bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm1_long') |
|
|
217 |
dropoutLayer(0.1,'Name','dropout1_long') |
|
|
218 |
bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm2_long') |
|
|
219 |
% dropoutLayer(0.1) |
|
|
220 |
bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm3_long') |
|
|
221 |
% dropoutLayer(0.1) |
|
|
222 |
bilstmLayer(num_neuron,'OutputMode','last','Name','lstm4_long') |
|
|
223 |
% dropoutLayer(0.1) |
|
|
224 |
fullyConnectedLayer(num_neuron,'Name','fc_long') |
|
|
225 |
% softmaxLayer('Name','softmax_long') |
|
|
226 |
]; |
|
|
227 |
lgraph = layerGraph(layers_short); |
|
|
228 |
lgraph = addLayers(lgraph,layers_long); |
|
|
229 |
lgraph = connectLayers(lgraph,'fc_long','cat/in2'); |
|
|
230 |
layers = connectLayers(lgraph,'InputLayer','Splitting-2nd'); |
|
|
231 |
% lgraph = addLayers(lgraph,sequenceInputLayer(9,'Name','input')); |
|
|
232 |
% lgraph = connectLayers(lgraph,'input','lstm1_short'); |
|
|
233 |
% lgraph = connectLayers(lgraph,'input','lstm1_long'); |
|
|
234 |
figure,plot(lgraph) |
|
|
235 |
|
|
|
236 |
options = trainingOptions('sgdm', ...%adam sgdm |
|
|
237 |
'MaxEpochs',maxEpochs, ... |
|
|
238 |
'MiniBatchSize', miniBatchSize, ... |
|
|
239 |
'InitialLearnRate', 0.1, ... %0.8 |
|
|
240 |
'ExecutionEnvironment',"cpu",...%'GradientThreshold', 1, ... |
|
|
241 |
'Shuffle','never', ... %every-epoch |
|
|
242 |
'plots','training-progress', ...%training-progress none |
|
|
243 |
'ValidationData',{test_set(:,1:2),categorical(test_label)},... |
|
|
244 |
'Verbose',false);%'OutputFcn', @(info)savetrainingplot(info) |
|
|
245 |
|
|
|
246 |
% concatenation or addition |
|
|
247 |
net = trainNetwork(train_datastore,train_label,lgraph,options); |
|
|
248 |
[pred,probabilities] = classify(net,test_data); |
|
|
249 |
|
|
|
250 |
preds{i,ifold} = pred; |
|
|
251 |
labels{i,ifold} = test_label; |
|
|
252 |
pred_probability{i,ifold} = probabilities; |
|
|
253 |
net_model{i,ifold} = net; |
|
|
254 |
pts_id_fold{i,ifold} = pts_id_test; |
|
|
255 |
end |
|
|
256 |
end |
|
|
257 |
preds_all{itrial} = preds; |
|
|
258 |
labels_all{itrial} = labels; |
|
|
259 |
pred_probability_all{itrial} = pred_probability; |
|
|
260 |
net_model_all{itrial} = net_model; |
|
|
261 |
pts_id_all{itrial} = pts_id_fold; |
|
|
262 |
end |
|
|
263 |
save(['D:\Research\Cardiac_arrest_EEG\Codes\ComaPrognosticanUsingEEG-master\deep_learning_results\four_layers\','bilstm_four_neurons_',num2str(num_neuron),'_epoch_',num2str(maxEpochs)],'preds','labels','layers','options','pred_probability','net_model'); |
|
|
264 |
end |
|
|
265 |
|
|
|
266 |
%[updatedNet,YPred] = predictAndUpdateState(recNet,sequences) |
|
|
267 |
|
|
|
268 |
% function stop=savetrainingplot(info) |
|
|
269 |
% stop=false; %prevents this function from ending trainNetwork prematurely |
|
|
270 |
% if info.State=='done' %check if all iterations have completed |
|
|
271 |
% % if true |
|
|
272 |
% saveas(gca,'training_process.png') % save figure as .png, you can change this |
|
|
273 |
% |
|
|
274 |
% end |
|
|
275 |
% end |
|
|
276 |
|
|
|
277 |
% options = trainingOptions('sgdm',... |
|
|
278 |
% 'InitialLearnRate',0.003,... |
|
|
279 |
% 'Plots','training-progress', ... |
|
|
280 |
% 'ValidationData',garVal,... |
|
|
281 |
% 'ValidationFrequency',40,... |
|
|
282 |
% 'MaxEpochs',1,... |
|
|
283 |
% 'LearnRateSchedule', 'piecewise',... |
|
|
284 |
% 'LearnRateDropPeriod',3,... |
|
|
285 |
% 'Shuffle','every-epoch',... |
|
|
286 |
% 'ValidationPatience',5,... |
|
|
287 |
% 'OutputFcn',@(info)SaveTrainingPlot(info),... |
|
|
288 |
c% 'Verbose',true); |
|
|
289 |
% % ... Training code ... |
|
|
290 |
% % At the end of the script: |
|
|
291 |
% function stop = SaveTrainingPlot(info) |
|
|
292 |
% stop = false; |
|
|
293 |
% if info.State == "done" |
|
|
294 |
% currentfig = findall(groot,'Type','Figure'); |
|
|
295 |
% savefig(currentfig,'prova.png') |
|
|
296 |
% end |
|
|
297 |
% end |