Diff of /mlp/mlp_model.py [000000] .. [3f1788]

Switch to unified view

a b/mlp/mlp_model.py
1
#!/usr/bin/env python
2
3
import torch
4
from torch import nn
5
6
class MLPModel(nn.Module):
7
  def __init__(self, vocab_sz, hidden_dim, dropout_p):
8
    super(MLPModel, self).__init__()
9
    
10
    self.fc1 = nn.Linear(in_features=vocab_sz, out_features=hidden_dim)
11
    self.relu = nn.ReLU()
12
    self.dropout = nn.Dropout(dropout_p)
13
    self.fc2 = nn.Linear(in_features=hidden_dim, out_features=1)
14
15
  def forward(self, x_in):
16
    x_out = self.fc1(x_in)
17
    x_out = self.dropout(self.relu(x_out))
18
    x_out = self.fc2(x_out)
19
    
20
    return x_out.squeeze(1)