Switch to unified view

a b/deep_learning_multiscale_lstm.m
1
% Multiscale LSTM combining fine-grained short-term dynamics (each 6 h time
2
% block) and coarst long-term dynsmics (all past recordings up to now).
3
% Wei-Long Zheng, MGH
4
% Email: weilonglive@gmail.com
5
6
% short term lstms
7
clear
8
close all
9
load('feature_sequences\all_features_sequences.mat')
10
save_path = ('multi_scale\');
11
for num_neuron = 30
12
    preds_all = {};
13
    labels_all = {};
14
    pred_probability_all = {};
15
    net_model_all = {};
16
    pts_id_all = {};
17
    
18
    for itrial = 1:10
19
time_step = 6;
20
time_range = 12:time_step:96;
21
num_fold = 5;
22
23
preds = cell(length(time_range),num_fold);
24
labels = cell(length(time_range),num_fold);
25
pred_probability = cell(length(time_range),num_fold);
26
net_model = cell(length(time_range),num_fold);
27
for i = length(time_range)-5%1:length(time_range)
28
29
    % short sequneces
30
    bs_x = bs(:,time_range(i)-time_step+1:time_range(i),:);
31
    bs_x = permute(bs_x,[1,3,2]);
32
    bs_x = reshape(bs_x,[size(bs_x,1),size(bs_x,2)*size(bs_x,3)]);
33
    spike_x = spike(:,time_range(i)-time_step+1:time_range(i),:);
34
    spike_x = permute(spike_x,[1,3,2]);
35
    spike_x = reshape(spike_x,[size(spike_x,1),size(spike_x,2)*size(spike_x,3)]);
36
    x = cat(3, bs_x, spike_x);
37
38
    for ifea = 1:7
39
        feature_tmp = features{ifea}(:,time_range(i)-time_step+1:time_range(i),:);
40
        feature_tmp = permute(feature_tmp,[1,3,2]);
41
        feature_tmp = reshape(feature_tmp,[size(feature_tmp,1),size(feature_tmp,2)*size(feature_tmp,3)]);
42
        if ifea==4||ifea==5||ifea==6||ifea==7
43
            feature_tmp = 10*log10(feature_tmp);
44
        end
45
        x = cat(3, x, feature_tmp);
46
    end
47
    
48
    % long sequences
49
    bs_x = bs(:,time_range(1)-time_step+1:time_range(i),:);
50
    bs_x = permute(bs_x,[1,3,2]);
51
    bs_x = reshape(bs_x,[size(bs_x,1),size(bs_x,2)*size(bs_x,3)]);
52
    spike_x = spike(:,time_range(1)-time_step+1:time_range(i),:);
53
    spike_x = permute(spike_x,[1,3,2]);
54
    spike_x = reshape(spike_x,[size(spike_x,1),size(spike_x,2)*size(spike_x,3)]);
55
    x_long = cat(3, bs_x, spike_x);
56
57
    for ifea = 1:7
58
        feature_tmp = features{ifea}(:,time_range(1)-time_step+1:time_range(i),:);
59
        feature_tmp = permute(feature_tmp,[1,3,2]);
60
        feature_tmp = reshape(feature_tmp,[size(feature_tmp,1),size(feature_tmp,2)*size(feature_tmp,3)]);
61
        if ifea==4||ifea==5||ifea==6||ifea==7
62
            feature_tmp = 10*log10(feature_tmp);
63
        end
64
        x_long = cat(3, x_long, feature_tmp);
65
    end
66
    
67
    cpc_scores_binary = cpc_scores;
68
    pos_index = find(cpc_scores<3);
69
    neg_index = find(cpc_scores>=3);
70
    cpc_scores_binary(pos_index) = 1;
71
    cpc_scores_binary(neg_index) = 0;
72
    
73
    X = {};
74
    X_long = {};
75
    Y = {};
76
    pts_id = {};
77
    for ipts = 1:size(x,1)
78
        x_tmp = squeeze(x(ipts,:,:));
79
        x_tmp = x_tmp';
80
        x_long_tmp = squeeze(x_long(ipts,:,:));
81
        x_long_tmp = x_long_tmp';
82
        
83
        for itmp = 1:size(x_tmp,1)
84
            x_tmp(itmp,isinf(x_tmp(itmp,:))) = nan;
85
            x_long_tmp(itmp,isinf(x_long_tmp(itmp,:))) = nan;
86
            x_tmp(itmp,isnan(x_tmp(itmp,:))) = nanmean(x_tmp(itmp,:));
87
            temp = x_long_tmp(itmp,end-time_step*12+1:end);
88
            temp(1,isnan(temp(1,:))) = nanmean(temp(1,:));
89
            x_long_tmp(itmp,end-time_step*12+1:end) = temp;
90
            nan_index = find(isnan(x_long_tmp(itmp,:)));
91
            for jnan = length(nan_index):-1:1
92
                if nan_index(jnan)+time_step*12<=size(x_long_tmp,2)
93
                    x_long_tmp(itmp,nan_index(jnan)) = nanmean(x_long_tmp(itmp,nan_index(jnan)+1:nan_index(jnan)+time_step*12));
94
                end
95
            end
96
%             index = find(isinf(x_long_tmp(itmp,:)));
97
%             for idex = 1:length(index)
98
%                 if ~isinf(x_long_tmp(itmp,index(idex)-1))
99
%                     x_long_tmp(itmp,index(idex)) = x_long_tmp(itmp,index(idex)-1);
100
%                 else
101
%                     x_long_tmp(itmp,index(idex)) = x_long_tmp(itmp,index(idex)+1);
102
%                 end
103
%             end
104
        end
105
        
106
        % reshape inputs of long term lstm with the same dimensions
107
        reduced_dim = 72*2;
108
        x_long_tmp_short = zeros(9,reduced_dim);
109
        for itmp = 1:size(x_long_tmp,1)
110
            if size(x_long_tmp,2)>reduced_dim
111
                n = fix(size(x_long_tmp,2)/72/2); % for every n points, generate 1 points
112
                b = arrayfun(@(i) mean(x_long_tmp(itmp,i:i+n-1),2),sort([1:size(x_long_tmp,2)/72:size(x_long_tmp,2),n:size(x_long_tmp,2)/72:size(x_long_tmp,2)])); % the averaged vector
113
                x_long_tmp_short(itmp,:) = b;
114
            end
115
        end
116
        
117
        if ~isnan(sum(sum(x_tmp)))&&~isinf(sum(sum(x_tmp)))
118
            X = [X; x_tmp];
119
            X_long = [X_long; x_long_tmp_short];
120
            Y = [Y; num2str(cpc_scores_binary(ipts))];
121
            pts_id = [pts_id; unique_names{ipts}];
122
        end
123
    end
124
    
125
    
126
    
127
    
128
    XV = [X{:}];
129
    mu = mean(XV,2);
130
    sg = std(XV,[],2);
131
    X = cellfun(@(x)(x-mu)./sg,X,'UniformOutput',false);
132
    XV = [X_long{:}];
133
    mu = mean(XV,2);
134
    sg = std(XV,[],2);
135
    X_long = cellfun(@(x)(x-mu)./sg,X_long,'UniformOutput',false);
136
137
    num_pts = length(Y);
138
    num_test = round(num_pts/num_fold);
139
    idx = randperm(num_pts);
140
    X = X(idx);
141
    X_long = X_long(idx);
142
    Y = Y(idx);
143
    pts_id = pts_id(idx);
144
%     Y = categorical(Y);
145
    for ifold = 1:num_fold
146
        if ifold~=num_fold
147
            start_index = (ifold-1)*num_test+1;
148
            end_index = ifold*num_test;
149
        else
150
            start_index = (ifold-1)*num_test+1;
151
            end_index = num_pts;
152
        end
153
        train_index = setdiff(1:num_pts,start_index:end_index);
154
        
155
        test_data_short = X(start_index:end_index);
156
        test_label = Y(start_index:end_index);
157
        train_data_short = X(train_index);
158
        train_label = Y(train_index);
159
        
160
        test_data_long = X_long(start_index:end_index);
161
        train_data_long = X_long(train_index);
162
        train_set = [train_data_short,train_data_long,train_label];
163
        test_set = [test_data_short,test_data_long,test_label];
164
        
165
        pts_id_test = pts_id(start_index:end_index);
166
        
167
        ipath_short = [save_path,num2str(i),'h','_train_short_',num2str(round(rand(1)*10e6)),'\'];
168
        mkdir(ipath_short);
169
        for itrain = 1:length(train_data_short)
170
            train_samples = train_data_short{itrain};
171
            save([ipath_short,num2str(itrain)],'train_samples');
172
        end
173
        ipath_long = [save_path,num2str(i),'h','_train_long_',num2str(round(rand(1)*10e6)),'\'];
174
        mkdir(ipath_long);
175
        for itrain = 1:length(train_data_long)
176
            train_samples = train_data_long{itrain};
177
            save([ipath_long,num2str(itrain)],'train_samples');
178
        end
179
        ipath_label = [save_path,num2str(i),'h','_train_label_',num2str(round(rand(1)*10e6)),'\'];
180
        mkdir(ipath_label);
181
        for itrain = 1:length(train_label)
182
            train_labels = train_label{itrain};
183
            save([ipath_label,num2str(itrain)],'train_labels');
184
        end
185
        
186
        fds_short = fileDatastore(ipath_short,'ReadFcn',@load_variable,'FileExtensions','.mat');
187
        fds_long = fileDatastore(ipath_long,'ReadFcn',@load_variable,'FileExtensions','.mat');
188
        fds_label = fileDatastore(ipath_label,'ReadFcn',@load_variable,'FileExtensions','.mat');
189
        train_datastore = combine(fds_short,fds_long);
190
        
191
        %% LSTM
192
        miniBatchSize = 150;
193
        maxEpochs = 100;
194
        layers_short = [ ...
195
            sequenceInputLayer(9,'Name','InputLayer')
196
            sequenceFoldingLayer('Name','fold')
197
            splittingLayer('Splitting-1st','1st')
198
            bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm1_short')
199
            dropoutLayer(0.1,'Name','dropout1_short')
200
            bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm2_short')
201
%             dropoutLayer(0.1)
202
            bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm3_short')
203
%             dropoutLayer(0.1)
204
            bilstmLayer(num_neuron,'OutputMode','last','Name','lstm4_short')
205
%             dropoutLayer(0.1)
206
            fullyConnectedLayer(num_neuron,'Name','fc_short')
207
            concatenationLayer(1,2,'Name','cat')
208
%             additionLayer(2,'Name','add')
209
            fullyConnectedLayer(2,'Name','fc')
210
            softmaxLayer('Name','softmax_short')
211
            classificationLayer('Name','classOutput')
212
            ];
213
        layers_long = [ ...
214
%             sequenceInputLayer(9,'Name','input_long')
215
            splittingLayer('Splitting-2nd','2nd')
216
            bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm1_long')
217
            dropoutLayer(0.1,'Name','dropout1_long')
218
            bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm2_long')
219
%             dropoutLayer(0.1)
220
            bilstmLayer(num_neuron,'OutputMode','sequence','Name','lstm3_long')
221
%             dropoutLayer(0.1)
222
            bilstmLayer(num_neuron,'OutputMode','last','Name','lstm4_long')
223
%             dropoutLayer(0.1)
224
            fullyConnectedLayer(num_neuron,'Name','fc_long')
225
%             softmaxLayer('Name','softmax_long')
226
            ];
227
        lgraph = layerGraph(layers_short);
228
        lgraph = addLayers(lgraph,layers_long);
229
        lgraph = connectLayers(lgraph,'fc_long','cat/in2');
230
        layers = connectLayers(lgraph,'InputLayer','Splitting-2nd');
231
%         lgraph = addLayers(lgraph,sequenceInputLayer(9,'Name','input'));
232
%         lgraph = connectLayers(lgraph,'input','lstm1_short');
233
%         lgraph = connectLayers(lgraph,'input','lstm1_long');
234
        figure,plot(lgraph)
235
236
        options = trainingOptions('sgdm', ...%adam sgdm
237
            'MaxEpochs',maxEpochs, ...
238
            'MiniBatchSize', miniBatchSize, ...
239
            'InitialLearnRate', 0.1, ... %0.8
240
            'ExecutionEnvironment',"cpu",...%'GradientThreshold', 1, ... 
241
            'Shuffle','never', ... %every-epoch
242
            'plots','training-progress', ...%training-progress none
243
            'ValidationData',{test_set(:,1:2),categorical(test_label)},...
244
            'Verbose',false);%'OutputFcn', @(info)savetrainingplot(info)
245
        
246
        % concatenation or addition
247
        net = trainNetwork(train_datastore,train_label,lgraph,options);
248
        [pred,probabilities] = classify(net,test_data);
249
        
250
        preds{i,ifold} = pred;
251
        labels{i,ifold} = test_label;
252
        pred_probability{i,ifold} = probabilities;
253
        net_model{i,ifold} = net;
254
        pts_id_fold{i,ifold} = pts_id_test;
255
    end
256
end
257
        preds_all{itrial} = preds;
258
        labels_all{itrial} = labels;
259
        pred_probability_all{itrial} = pred_probability;
260
        net_model_all{itrial} = net_model;
261
        pts_id_all{itrial} = pts_id_fold;
262
    end
263
save(['D:\Research\Cardiac_arrest_EEG\Codes\ComaPrognosticanUsingEEG-master\deep_learning_results\four_layers\','bilstm_four_neurons_',num2str(num_neuron),'_epoch_',num2str(maxEpochs)],'preds','labels','layers','options','pred_probability','net_model');
264
end
265
266
%[updatedNet,YPred] = predictAndUpdateState(recNet,sequences)
267
268
% function stop=savetrainingplot(info)
269
% stop=false;  %prevents this function from ending trainNetwork prematurely
270
% if info.State=='done'   %check if all iterations have completed
271
% % if true
272
%       saveas(gca,'training_process.png')  % save figure as .png, you can change this
273
% 
274
% end
275
% end
276
277
% options = trainingOptions('sgdm',...
278
%     'InitialLearnRate',0.003,...
279
%     'Plots','training-progress', ...
280
%     'ValidationData',garVal,...
281
%     'ValidationFrequency',40,...
282
%     'MaxEpochs',1,...
283
%     'LearnRateSchedule', 'piecewise',...
284
%     'LearnRateDropPeriod',3,...
285
%     'Shuffle','every-epoch',...
286
%     'ValidationPatience',5,...
287
%     'OutputFcn',@(info)SaveTrainingPlot(info),...
288
c%     'Verbose',true);
289
% % ... Training code ...
290
% % At the end of the script:
291
% function stop = SaveTrainingPlot(info)
292
% stop = false;
293
% if info.State == "done"
294
%     currentfig = findall(groot,'Type','Figure');
295
%     savefig(currentfig,'prova.png')
296
% end
297
% end