Diff of /Models/Network/BiGRU.py [000000] .. [259458]

#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Import useful packages
import tensorflow as tf


def BiGRU(Input, max_time, n_input, gru_size, keep_prob, weights_1, biases_1, weights_2, biases_2):
    '''

    Args:
        Input: The reshaped input EEG signals
        max_time: The number of unfolded time slices of the BiGRU model
        n_input: The input signal size at one time step
        gru_size: The number of GRU units inside the BiGRU model
        keep_prob: The keep probability for dropout
        weights_1: The weights of the first fully-connected layer
        biases_1: The biases of the first fully-connected layer
        weights_2: The weights of the second fully-connected layer
        biases_2: The biases of the second fully-connected layer

    Returns:
        FC_2: The final prediction of the BiGRU model
        FC_1: The features extracted by the first fully-connected layer

    '''

    # Reshape the input EEG signals to [batch_size, max_time, n_input]
    Input = tf.reshape(Input, [-1, max_time, n_input])

    # Forward and backward GRU cells (BiGRU model)
    gru_fw_cell = tf.compat.v1.nn.rnn_cell.GRUCell(num_units=gru_size)
    gru_bw_cell = tf.compat.v1.nn.rnn_cell.GRUCell(num_units=gru_size)

    # Dropout on the inputs of the forward and backward GRU cells
    gru_fw_drop = tf.compat.v1.nn.rnn_cell.DropoutWrapper(cell=gru_fw_cell, input_keep_prob=keep_prob)
    gru_bw_drop = tf.compat.v1.nn.rnn_cell.DropoutWrapper(cell=gru_bw_cell, input_keep_prob=keep_prob)

    # One-layer BiGRU model
    outputs, _ = tf.compat.v1.nn.bidirectional_dynamic_rnn(gru_fw_drop, gru_bw_drop, Input, dtype=tf.float32)

    # Concatenate forward and backward outputs, then keep only the last time step
    outputs = tf.concat(outputs, 2)
    outputs = outputs[:, max_time - 1, :]

    # First fully-connected layer: affine -> batch norm -> softplus -> dropout
    FC_1 = tf.matmul(outputs, weights_1) + biases_1
    FC_1 = tf.compat.v1.layers.batch_normalization(FC_1, training=True)
    FC_1 = tf.nn.softplus(FC_1)
    FC_1 = tf.compat.v1.nn.dropout(FC_1, keep_prob)

    # Second fully-connected layer: affine -> softmax over classes
    FC_2 = tf.matmul(FC_1, weights_2) + biases_2
    FC_2 = tf.nn.softmax(FC_2)

    return FC_2, FC_1
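

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It shows one way the
# BiGRU() graph above could be wired up and evaluated in TF1-style graph mode.
# The dimensions (max_time, n_input, gru_size, n_hidden_fc, n_classes) and the
# random dummy batch are assumptions for illustration only; the repository's
# own training script may construct the placeholders and weights differently.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    tf.compat.v1.disable_eager_execution()

    # Assumed dimensions: each EEG segment is split into 4 time slices of 16 values
    max_time, n_input = 4, 16
    gru_size, n_hidden_fc, n_classes = 32, 64, 4

    # Placeholders for a batch of flattened EEG segments and the dropout keep probability
    x = tf.compat.v1.placeholder(tf.float32, [None, max_time * n_input])
    keep_prob = tf.compat.v1.placeholder(tf.float32)

    # Fully-connected parameters; the first layer sees 2 * gru_size features
    # because the forward and backward GRU outputs are concatenated
    weights_1 = tf.Variable(tf.random.truncated_normal([2 * gru_size, n_hidden_fc], stddev=0.1))
    biases_1 = tf.Variable(tf.zeros([n_hidden_fc]))
    weights_2 = tf.Variable(tf.random.truncated_normal([n_hidden_fc, n_classes], stddev=0.1))
    biases_2 = tf.Variable(tf.zeros([n_classes]))

    prediction, features = BiGRU(x, max_time, n_input, gru_size, keep_prob,
                                 weights_1, biases_1, weights_2, biases_2)

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        dummy_batch = np.random.randn(8, max_time * n_input).astype(np.float32)
        probs = sess.run(prediction, feed_dict={x: dummy_batch, keep_prob: 1.0})
        print(probs.shape)  # expected: (8, 4), i.e. (batch_size, n_classes)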