[c3444c]: / pretrain / extract_bert.py

Download this file

9 lines (6 with data), 170 Bytes

1
2
3
4
5
6
7
8
import torch
import sys
import os
# Extract the ``bert`` submodule from a fully pickled model checkpoint.
#
# Usage: python extract_bert.py INPUT_CHECKPOINT OUTPUT_PATH
#
# INPUT_CHECKPOINT must have been written with ``torch.save(model)`` (the whole
# module object, not a state_dict). The output is likewise the whole pickled
# ``bert`` module written with ``torch.save``, so reading it back on recent
# PyTorch also requires ``torch.load(..., weights_only=False)``.


def _load_full_model(path):
    """Load a whole pickled model object from *path*, mapped onto the CPU.

    PyTorch 2.6 changed ``torch.load``'s default to ``weights_only=True``,
    which refuses to unpickle arbitrary module objects. Full unpickling is
    therefore requested explicitly whenever the installed ``torch.load``
    accepts the ``weights_only`` flag; older versions unpickle fully anyway.

    SECURITY: full unpickling can execute arbitrary code -- only run this on
    checkpoints from a trusted source.
    """
    import inspect

    kwargs = {"map_location": torch.device("cpu")}
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


def extract_bert(src_path, dst_path):
    """Save the ``bert`` submodule of the model stored at *src_path*.

    Args:
        src_path: path of a checkpoint holding a whole pickled model.
        dst_path: destination the extracted submodule is written to with
            ``torch.save``.

    Returns:
        The extracted ``bert`` submodule.

    Raises:
        AttributeError: if the loaded object (after optional ``.module``
            unwrapping) has no ``bert`` attribute.
    """
    model = _load_full_model(src_path)
    # Models saved through a DataParallel-style wrapper keep the real network
    # under ``.module``; unwrap only when ``bert`` is not reachable directly,
    # so unwrapped models are handled exactly as before.
    if not hasattr(model, "bert") and hasattr(model, "module"):
        model = model.module
    if not hasattr(model, "bert"):
        raise AttributeError(
            f"{type(model).__name__} loaded from {src_path!r} "
            "has no 'bert' attribute"
        )
    bert_model = model.bert
    torch.save(bert_model, dst_path)
    return bert_model


def main(argv=None):
    """Command-line entry point: ``extract_bert.py INPUT OUTPUT``.

    Args:
        argv: argument list excluding the program name; defaults to
            ``sys.argv[1:]``. Arguments beyond the first two are ignored.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        sys.exit(f"usage: {sys.argv[0]} INPUT_CHECKPOINT OUTPUT_PATH")
    extract_bert(args[0], args[1])


if __name__ == "__main__":
    main()