|
a |
|
b/medicalbert/classifiers/standard/bert_head.py |
|
|
1 |
from torch import nn |
|
|
2 |
|
|
|
3 |
|
|
|
4 |
class BertMeanPooling(nn.Module):
    """Pooling head that summarizes a sequence by averaging token states.

    Takes the full hidden-state tensor, averages across the sequence
    dimension, then applies a dense projection followed by a Tanh
    squashing non-linearity (the same shape as BERT's standard pooler,
    but fed the mean instead of the [CLS] token).
    """

    def __init__(self, config):
        super(BertMeanPooling, self).__init__()
        # Projection keeps the hidden width unchanged.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # Average over the sequence axis, then project and squash.
        mean_state = hidden_states.mean(dim=1)
        return self.activation(self.dense(mean_state))
|
|
15 |
|
|
|
16 |
|
|
|
17 |
class BERTFCHead(nn.Module):
    """Head that feeds the BERT hidden states through a 2-layer LSTM.

    Returns the LSTM output sequence (one vector per input position),
    shaped ``(batch, seq_len, hidden_size)``.
    """

    def __init__(self, config):
        super(BERTFCHead, self).__init__()
        # Fix: size the LSTM from the config instead of hard-coding 768
        # (the original ignored ``config`` entirely, breaking any model
        # whose hidden width differs). Falls back to 768 for configs
        # without the attribute, preserving the old behavior.
        hidden_size = getattr(config, "hidden_size", 768)
        self.lstm = nn.LSTM(hidden_size, hidden_size, 2, batch_first=True)

    def forward(self, hidden_states):
        # nn.LSTM returns (output, (h_n, c_n)); we only need the full
        # output sequence. Debug print of the input shape removed.
        output, _ = self.lstm(hidden_states)
        return output