a | b/create_rnn_model.py | ||
---|---|---|---|
1 | from tensorflow.keras.models import Sequential |
||
2 | from tensorflow.keras.layers import Embedding, LSTM, Dense |
||
3 | |||
4 | def create_rnn_model(input_shape, num_classes): |
||
5 | model = Sequential() |
||
6 | model.add(Embedding(input_dim=10000, output_dim=128, input_length=input_shape[1])) |
||
7 | model.add(LSTM(64)) |
||
8 | model.add(Dense(128, activation='relu')) |
||
9 | model.add(Dense(num_classes, activation='softmax')) |
||
10 | |||
11 | return model |