Diff of /src/rnn/gru.py [000000] .. [71ad2f]

Switch to unified view

a b/src/rnn/gru.py
1
import torch 
2
import torch.nn as nn
3
from src.rnn.rnn_utils import create_emb_layer
4
5
class GRUw2vmodel(nn.Module) :
6
  def __init__(self, weights_matrix, hidden_size, num_layers, device, num_classes = 10) :
7
8
    super().__init__()
9
    self.num_layers = num_layers
10
    self.hidden_size = hidden_size
11
    self.device = device
12
    self.embeddings, num_embeddings, embedding_size = create_emb_layer(weights_matrix, True)
13
    self.gru1 = nn.GRU(embedding_size, hidden_size, num_layers, batch_first=True)
14
        
15
    self.fc1 = nn.Linear(hidden_size, 10)
16
        
17
    self.act = nn.Sigmoid()
18
      
19
      
20
  def forward(self, x):     
21
    x = self.embeddings(x)
22
    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(self.device)
23
    gru_out, _ = self.gru1(x, h0)
24
    out = self.fc1(gru_out[:,-1,:])
25
    return self.act(out)