# src/hybrid/hybrid.py
1
import torch 
2
import torch.nn as nn
3
from src.rnn.rnn_utils import create_emb_layer
4
5
6
class hybrid(nn.Module):
    """Two-branch classifier fusing a 1-D convolutional branch and a BiGRU branch.

    The convolutional branch takes ``cnninput`` with
    ``len(vocabulary) + 1`` channels along dim 1 and ``sequence_length``
    positions along the last dim (the shape used for the size probe in
    ``__init__``). NOTE(review): presumably a one-hot encoding over the
    vocabulary -- confirm against the data pipeline.

    The recurrent branch embeds ``rnninput`` with the layer returned by
    ``create_emb_layer(weights_matrix, True)`` and runs a bidirectional GRU
    (``batch_first=True``). NOTE(review): ``rnninput`` is assumed to be a
    (batch, seq) tensor of token indices -- confirm with callers.

    Each branch is projected to ``_BRANCH_FEATURES`` features. The two are
    concatenated, passed through a linear layer and then a sigmoid, giving
    ``num_classes`` outputs in (0, 1).
    """

    # Width of each branch's projection before fusion; fc3 consumes twice this.
    _BRANCH_FEATURES = 256

    def __init__(self, vocabulary, sequence_length, weights_matrix, hidden_size,
                 num_layers=2, num_classes=10):
        """Build both branches and the fusion head.

        Args:
            vocabulary: Sized collection; ``len(vocabulary) + 1`` is the number
                of input channels of the first convolution.
            sequence_length: Length of the CNN input's last axis; used to size
                ``fc1`` via a dummy forward pass through the conv stack.
            weights_matrix: Forwarded to ``create_emb_layer`` (with its second
                argument set to True) to build the embedding layer.
            hidden_size: GRU hidden size per direction.
            num_layers: Number of stacked GRU layers.
            num_classes: Number of output units of the final layer.
        """
        super().__init__()

        self.num_layers = num_layers
        self.hidden_size = hidden_size

        in_channels = len(vocabulary) + 1

        # Convolutional branch (unpadded convolutions; the first two blocks also pool by 3).
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels, 128, kernel_size=7, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(3),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=7, padding=0),
            nn.ReLU(),
            nn.MaxPool1d(3),
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=3, padding=0),
            nn.ReLU(),
        )
        self.conv4 = nn.Sequential(
            nn.Conv1d(128, 128, kernel_size=3, padding=0),
            nn.ReLU(),
        )

        # The flattened conv output size depends on sequence_length, so probe it.
        self.output_dimension = self._get_conv_output((1, in_channels, sequence_length))

        self.fc1 = nn.Sequential(
            nn.Linear(self.output_dimension, self._BRANCH_FEATURES),
            nn.ReLU(),
        )

        # Recurrent branch; the vocabulary-size value returned by the helper is unused here.
        self.embeddings, _, embedding_size = create_emb_layer(weights_matrix, True)
        self.gru1 = nn.GRU(embedding_size, hidden_size, num_layers,
                           bidirectional=True, batch_first=True)

        self.fc2 = nn.Sequential(
            nn.Linear(2 * hidden_size, self._BRANCH_FEATURES),
            nn.ReLU(),
        )

        # Fusion head. This was previously hard-coded to 10 outputs, which
        # silently ignored num_classes.
        self.fc3 = nn.Linear(2 * self._BRANCH_FEATURES, num_classes)
        self.act = nn.Sigmoid()

    def _conv_features(self, x):
        """Apply conv1..conv4 to ``x`` and flatten to (batch, features)."""
        for conv_block in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv_block(x)
        return x.view(x.size(0), -1)

    def _get_conv_output(self, shape):
        """Return the flattened feature count the conv stack yields for inputs of ``shape``."""
        # Only the output shape matters, so skip autograd bookkeeping. torch.rand
        # (rather than zeros) keeps the global RNG draw, and therefore the seeded
        # initialization of later layers, the same as before.
        with torch.no_grad():
            probe = torch.rand(shape)
            return self._conv_features(probe).size(1)

    def _rnn_features(self, rnninput):
        """Embed ``rnninput``, run the BiGRU, and summarize both directions.

        Returns a (batch, 2 * hidden_size) tensor: the top layer's final
        forward state concatenated with its final backward state.
        """
        embedded = self.embeddings(rnninput)
        _, h_n = self.gru1(embedded)
        # h_n has shape (num_layers * 2, batch, hidden_size) regardless of
        # batch_first, so the last two rows are the top layer's forward and
        # backward final states. The previous rnn_out[:, -1, :] took the
        # backward direction after it had read only the last token.
        return torch.cat((h_n[-2], h_n[-1]), dim=1)

    def forward(self, rnninput, cnninput):
        """Run both branches, fuse them, and return sigmoid scores of shape (batch, num_classes)."""
        cnn_out = self.fc1(self._conv_features(cnninput))
        rnn_out = self.fc2(self._rnn_features(rnninput))
        fused = torch.cat((cnn_out, rnn_out), dim=1)
        return self.act(self.fc3(fused))