a b/pretrain/extract_bert.py
1
import torch
2
import sys
3
import os
4
5
6
def _load_checkpoint(path):
    """Fully unpickle the object saved at *path*, mapping all tensors to CPU.

    The input is expected to be a whole pickled model (not a state_dict), so
    ``weights_only=False`` must be passed explicitly: torch >= 2.6 defaults to
    ``weights_only=True``, which refuses arbitrary pickled objects. Older torch
    releases without that parameter always unpickle fully, so it is only
    passed when ``torch.load`` accepts it.

    SECURITY: unpickling can execute arbitrary code; only load trusted files.
    """
    import inspect

    kwargs = {"map_location": torch.device('cpu')}
    if "weights_only" in inspect.signature(torch.load).parameters:
        kwargs["weights_only"] = False
    return torch.load(path, **kwargs)


def main(argv=None):
    """Extract the ``bert`` submodule from a saved model and save it on its own.

    Args:
        argv: Argument list ``[input_path, output_path, ...]``; defaults to
            ``sys.argv[1:]``. ``input_path`` is a file produced by
            ``torch.save`` holding an object with a ``bert`` attribute;
            ``output_path`` receives ``torch.save(model.bert)``. Any further
            arguments are ignored, as in the original script.

    Raises:
        SystemExit: with a usage message if fewer than two arguments are given.
        AttributeError: if the loaded object has no ``bert`` attribute.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        sys.exit("usage: extract_bert.py <input_model.pt> <output_bert.pt>")
    src, dst = args[0], args[1]

    model = _load_checkpoint(src)
    # The module object itself is pickled (not its state_dict), so loading the
    # output later still requires the same class definitions to be importable.
    torch.save(model.bert, dst)


if __name__ == "__main__":
    main()