a b/src/rnn/lstm.py
1
import torch 
2
import torch.nn as nn
3
from src.rnn.rnn_utils import create_emb_layer
4
5
class LSTMw2vmodel(nn.Module):
    """Bidirectional LSTM classifier over a pretrained embedding layer.

    Token indices are embedded with the layer built by ``create_emb_layer``,
    encoded by a stacked bidirectional LSTM, and the top layer's final forward
    and backward hidden states are passed through a small MLP with a sigmoid
    on the output.
    """

    def __init__(self, weights_matrix, hidden_size, num_layers, device, num_classes=10):
        """Build the embedding, LSTM and classifier layers.

        Args:
            weights_matrix: Embedding weights handed to ``create_emb_layer``
                (presumably a (vocab_size, embedding_dim) matrix — confirm in
                ``src.rnn.rnn_utils``).
            hidden_size: Hidden units per LSTM direction.
            num_layers: Number of stacked LSTM layers.
            device: Stored as ``self.device``; ``forward`` does not need it
                because the LSTM's zero initial states follow the input.
            num_classes: Width of the final linear layer's output.
        """
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        # Kept so existing callers reading `model.device` keep working.
        self.device = device
        # NOTE(review): the `True` flag presumably freezes the embedding
        # weights — verify against create_emb_layer.
        self.embeddings, _num_embeddings, embedding_size = create_emb_layer(weights_matrix, True)
        self.lstm = nn.LSTM(
            embedding_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=True,
        )
        # Input is the concatenated forward + backward summary (2 * hidden_size).
        self.fc1 = nn.Sequential(
            nn.Linear(2 * hidden_size, 128),
            nn.ReLU(),
        )
        self.fc2 = nn.Linear(128, num_classes)
        # NOTE(review): sigmoid yields independent per-class scores; if the
        # training loop uses CrossEntropyLoss this should return raw logits —
        # confirm against the caller.
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid class scores for a batch of token sequences.

        Args:
            x: Batch-first token input for the embedding layer (presumably a
                (batch, seq_len) LongTensor of indices — confirm with the data
                pipeline).

        Returns:
            Tensor of shape (batch, num_classes) with sigmoid-squashed scores.
        """
        return self.act(self.fc2(self.fc1(self._encode(x))))

    def _encode(self, x):
        """Return the (batch, 2 * hidden_size) top-layer summary of `x`."""
        embedded = self.embeddings(x)
        # Omitting (h0, c0) makes nn.LSTM start from zeros on the input's
        # device — the same values the old explicit tensors held, without
        # the device mismatch risk.
        _, (h_n, _) = self.lstm(embedded)
        # h_n is (num_layers * 2, batch, hidden_size), laid out layer by layer
        # with forward before backward, so the last two rows are the top
        # layer's final forward and backward states. The backward final state
        # has read the whole sequence, whereas lstm_out[:, -1, hidden:] has
        # only seen the last token.
        # NOTE(review): with right-padded batches the forward state also reads
        # the padding; consider pack_padded_sequence if inputs are padded.
        return torch.cat((h_n[-2], h_n[-1]), dim=1)