b/src/rnn/lstm.py
import torch
import torch.nn as nn

from src.rnn.rnn_utils import create_emb_layer


class LSTMw2vmodel(nn.Module):
    """Bidirectional LSTM classifier on top of pretrained word2vec embeddings."""

    def __init__(self, weights_matrix, hidden_size, num_layers, device, num_classes=10):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.device = device
        # Embedding layer initialised from the pretrained weight matrix.
        self.embeddings, num_embeddings, embedding_size = create_emb_layer(weights_matrix, True)
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers,
                            batch_first=True, bidirectional=True)
        # The bidirectional LSTM yields 2*hidden_size features per time step.
        self.fc1 = nn.Sequential(
            nn.Linear(2 * hidden_size, 128),
            nn.ReLU(),
        )
        self.fc2 = nn.Linear(128, num_classes)
        self.act = nn.Sigmoid()

    def forward(self, x):
        # x: (batch, seq_len) token indices -> (batch, seq_len, embedding_size)
        x = self.embeddings(x)
        # Initial hidden/cell states: (num_directions * num_layers, batch, hidden_size).
        h0 = torch.zeros(2 * self.num_layers, x.size(0), self.hidden_size).to(self.device)
        c0 = torch.zeros(2 * self.num_layers, x.size(0), self.hidden_size).to(self.device)
        lstm_out, (ht, ct) = self.lstm(x, (h0, c0))
        # Classify from the output at the final time step (both directions concatenated).
        out = self.fc1(lstm_out[:, -1, :])
        out = self.fc2(out)
        return self.act(out)
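

# --- Usage sketch (illustrative, not part of the module): a minimal smoke test.
# --- Assumes create_emb_layer accepts a (vocab_size, embedding_dim) FloatTensor;
# --- the random weights_matrix and the sizes below are hypothetical placeholders.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weights_matrix = torch.randn(5000, 300)  # hypothetical vocab of 5000 tokens, 300-d vectors
    model = LSTMw2vmodel(weights_matrix, hidden_size=64, num_layers=2,
                         device=device, num_classes=10).to(device)
    tokens = torch.randint(0, 5000, (8, 20), device=device)  # batch of 8 sequences, 20 tokens each
    probs = model(tokens)  # -> (8, 10) per-class sigmoid scores
    print(probs.shape)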