--- a +++ b/src/bert/bert_model.py @@ -0,0 +1,22 @@ +import torch.nn as nn +import transformers + + +class BERTclassifier(nn.Module): + def __init__(self, bert_freeze=False): + super().__init__() + self.bert = transformers.BertModel.from_pretrained("bert-base-uncased") + self.drop = nn.Dropout(0.3) + self.out = nn.Linear(768, 10) + self.act = nn.Sigmoid() + + if bert_freeze: + for param in self.bert.parameters(): + param.requires_grad = False + + def forward(self, ids, mask, token_type_ids): + outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids) + x = self.drop(outputs[0][:, 0, :]) + x = self.out(x) + x = self.act(x) + return x