[7e250a]: / src / hint / trial / layers.py

Download this file

13 lines (8 with data), 451 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
import torch.nn as nn
def FeedForward(input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, dropout: float = 0.0):
    """Build a multilayer perceptron wrapped in an ``nn.Sequential``.

    Layout, in order:
      * ``Linear(input_dim, hidden_dim)`` -> ``ReLU`` -> ``Dropout(dropout)``
      * ``num_layers`` repetitions of
        ``Linear(hidden_dim, hidden_dim)`` -> ``ReLU`` -> ``Dropout(dropout)``
      * ``Linear(hidden_dim, output_dim)`` with no activation or dropout after it.

    ``num_layers`` counts only the extra hidden-to-hidden layers. The result
    therefore contains ``num_layers + 2`` ``nn.Linear`` modules. A value
    ``<= 0`` adds no extra hidden layers.

    Args:
        input_dim: Number of input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of output features.
        num_layers: Number of additional hidden-to-hidden layers.
        dropout: Dropout probability used after every hidden activation.
            A Dropout module is inserted even when this is 0.0.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    def hidden_block(in_features: int) -> list:
        # Linear -> ReLU -> Dropout. The order fixes the Sequential indices,
        # and therefore the state_dict key names.
        return [nn.Linear(in_features, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]

    # The first block reads the raw input. Every later block reads hidden
    # activations. A non-positive num_layers yields no extra widths.
    in_widths = [input_dim] + [hidden_dim] * num_layers
    modules = [module for width in in_widths for module in hidden_block(width)]

    # Final projection to the output size; deliberately left without an activation.
    modules.append(nn.Linear(hidden_dim, output_dim))
    return nn.Sequential(*modules)