a b/Models/Network/LSTM.py
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
# Import useful packages
5
import tensorflow as tf
6
7
8
def LSTM(Input, max_time, n_input, lstm_size, keep_prob, weights_1, biases_1, weights_2, biases_2,
         training=True):
    '''
    Build a one-layer LSTM encoder followed by two fully-connected layers.

    Args:
        Input: The input EEG signals; reshaped here to [-1, max_time, n_input]
        max_time: The unfolded time slice of LSTM Model
        n_input: The input signal size at one time
        lstm_size: The number of LSTM units inside the LSTM Model
        keep_prob: The keep probability of Dropout, applied both to the inputs
                   of the LSTM cell and to the output of the first
                   fully-connected layer
        weights_1: The Weights of first fully-connected layer
                   (presumably shaped [lstm_size, n_hidden] -- TODO confirm)
        biases_1: The biases of first fully-connected layer
        weights_2: The Weights of second fully-connected layer
        biases_2: The biases of second fully-connected layer
        training: Forwarded to batch normalization. True (the default, which
                  matches the previous hard-coded behaviour) normalizes with the
                  statistics of the current batch. False uses the learned moving
                  averages, which is what evaluation/inference normally wants.
                  May be a Python bool or a boolean scalar tensor (e.g. a
                  placeholder).

    Returns:
        FC_2: Final prediction of LSTM Model. It has already been passed through
              softmax, so it holds probabilities, not logits.
        FC_1: Extracted features from the first fully connected layer

    Note:
        tf.layers.batch_normalization registers its moving-mean/variance update
        ops in tf.GraphKeys.UPDATE_OPS. The training step must run them (e.g.
        under tf.control_dependencies); otherwise the moving averages used when
        training=False are never updated.
    '''

    # One layer RNN Model
    Input = tf.reshape(Input, [-1, max_time, n_input])
    cell_encoder = tf.contrib.rnn.BasicLSTMCell(num_units=lstm_size)
    # Dropout is applied to the inputs fed into the LSTM cell at each step.
    encoder_drop = tf.contrib.rnn.DropoutWrapper(cell=cell_encoder, input_keep_prob=keep_prob)
    # Only the final state is used; the per-step outputs are discarded.
    _, final_state_encoder = tf.nn.dynamic_rnn(cell=encoder_drop, inputs=Input,
                                               dtype=tf.float32)

    # First fully-connected layer
    # With BasicLSTMCell's default state_is_tuple=True, the final state is an
    # LSTMStateTuple (c, h). Index 0 is the cell state c (the long-term memory),
    # which is used here rather than the hidden state h.
    FC_1 = tf.matmul(final_state_encoder[0], weights_1) + biases_1
    FC_1 = tf.layers.batch_normalization(FC_1, training=training)
    FC_1 = tf.nn.softplus(FC_1)
    # Pass keep_prob by keyword. In TF1, tf.nn.dropout's second positional
    # argument is keep_prob, but newer APIs put `rate` (= 1 - keep_prob) there.
    FC_1 = tf.nn.dropout(FC_1, keep_prob=keep_prob)

    # Second fully-connected layer (softmax output: probabilities, not logits)
    FC_2 = tf.matmul(FC_1, weights_2) + biases_2
    FC_2 = tf.nn.softmax(FC_2)

    return FC_2, FC_1