--- a
+++ b/src/cnn/cnn.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn as nn
+
+
+# Character-level 1D CNN: expects one-hot encoded characters of shape
+# (batch, len(vocabulary) + 1, sequence_length).
+class character_cnn(nn.Module):
+    def __init__(self, vocabulary, sequence_length, number_classes=10):
+        super().__init__()
+
+        self.conv1 = nn.Sequential(
+            nn.Conv1d(len(vocabulary) + 1, 256, kernel_size=7, padding=0),
+            nn.ReLU(),
+            nn.MaxPool1d(3),
+        )
+
+        self.conv2 = nn.Sequential(
+            nn.Conv1d(256, 256, kernel_size=7, padding=0),
+            nn.ReLU(),
+            nn.MaxPool1d(3),
+        )
+
+        self.conv3 = nn.Sequential(
+            nn.Conv1d(256, 256, kernel_size=3, padding=0),
+            nn.ReLU(),
+        )
+
+        self.conv4 = nn.Sequential(
+            nn.Conv1d(256, 256, kernel_size=3, padding=0),
+            nn.ReLU(),
+        )
+
+        # Infer the flattened feature size after the convolutional stack
+        # for the given sequence length.
+        input_shape = (1, len(vocabulary) + 1, sequence_length)
+        self.output_dimension = self._get_conv_output(input_shape)
+
+        self.fc1 = nn.Sequential(
+            nn.Linear(self.output_dimension, 1024),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+        )
+
+        self.fc2 = nn.Sequential(
+            nn.Linear(1024, 1024),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+        )
+
+        self.fc3 = nn.Linear(1024, number_classes)
+
+        # Sigmoid yields an independent probability per class.
+        self.act = nn.Sigmoid()
+
+    def _get_conv_output(self, shape):
+        # Pass a dummy batch through the conv layers to determine the
+        # flattened dimension fed into the first fully connected layer.
+        x = torch.rand(shape)
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = x.view(x.size(0), -1)
+        return x.size(1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        x = self.conv4(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc1(x)
+        x = self.fc2(x)
+        x = self.fc3(x)
+        x = self.act(x)
+        return x
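A minimal smoke test of the module above, for review purposes only. It assumes the file is importable as src.cnn.cnn; the vocabulary string, sequence length, and batch size are illustrative values, not part of this change:

import torch
import torch.nn.functional as F
from src.cnn.cnn import character_cnn  # assumes src/ is on the import path

vocabulary = "abcdefghijklmnopqrstuvwxyz0123456789"  # illustrative character set
sequence_length = 1014                               # long enough to survive the pooling layers

model = character_cnn(vocabulary, sequence_length, number_classes=10)

# Build a one-hot batch of shape (batch, len(vocabulary) + 1, sequence_length).
indices = torch.randint(0, len(vocabulary) + 1, (4, sequence_length))
batch = F.one_hot(indices, num_classes=len(vocabulary) + 1).permute(0, 2, 1).float()

out = model(batch)
print(out.shape)  # torch.Size([4, 10]); values lie in (0, 1) because of the final sigmoid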