medicalbert/classifiers/standard/bert_head.py
from torch import nn


class BertMeanPooling(nn.Module):
    def __init__(self, config):
        super(BertMeanPooling, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        # We "pool" the model by simply averaging the hidden states over the sequence dimension.
        pooled_output = self.dense(hidden_states.mean(1))
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BERTFCHead(nn.Module):
    def __init__(self, config):
        super(BERTFCHead, self).__init__()
        # Two stacked LSTM layers over the token-level hidden states.
        self.lstm = nn.LSTM(config.hidden_size, config.hidden_size, 2, batch_first=True)

    def forward(self, hidden_states):
        # Run the full sequence of hidden states through the LSTM and return its
        # output sequence, i.e. the first element of the (output, (h_n, c_n)) tuple.
        pooled_output = self.lstm(hidden_states)
        return pooled_output[0]
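

# A minimal usage sketch (an illustration, not part of the original module): both
# heads consume the encoder's full sequence of hidden states, shaped
# (batch, seq_len, hidden_size). The SimpleNamespace config and the dummy tensor
# below are assumptions made purely for this example.
if __name__ == "__main__":
    import torch
    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=768)
    hidden_states = torch.randn(4, 128, config.hidden_size)  # dummy encoder output

    mean_head = BertMeanPooling(config)
    print(mean_head(hidden_states).shape)  # torch.Size([4, 768]) -- one vector per example

    lstm_head = BERTFCHead(config)
    print(lstm_head(hidden_states).shape)  # torch.Size([4, 128, 768]) -- per-token LSTM outputs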