|
a |
|
b/allennlp/loader.py |
|
|
1 |
from allennlp.predictors.predictor import Predictor |
|
|
2 |
from allennlp.predictors.text_classifier import TextClassifierPredictor |
|
|
3 |
from allennlp.predictors.sentence_tagger import SentenceTaggerPredictor |
|
|
4 |
|
|
|
5 |
from allennlp.models.archival import * |
|
|
6 |
from allennlp.data import DatasetReader |
|
|
7 |
from allennlp.common.params import Params |
|
|
8 |
from allennlp.models import Model |
|
|
9 |
import allennlp_models.classification |
|
|
10 |
import allennlp_models.tagging |
|
|
11 |
|
|
|
12 |
import torch |
|
|
13 |
import os |
|
|
14 |
import shutil |
|
|
15 |
|
|
|
16 |
|
|
|
17 |
def rm_tmp(tmp_start):
    """Delete every directory under /tmp that is not listed in *tmp_start*.

    Used to clean up the cache directories that AllenNLP model downloads
    leave behind: callers snapshot ``os.listdir('/tmp')`` before the
    download and pass it here afterwards.

    Args:
        tmp_start: iterable of /tmp entry names to preserve.

    Raises:
        OSError: if removal of one of the new directories fails.
    """
    # Set gives O(1) membership tests; the original scanned the list for
    # every entry.
    keep = set(tmp_start)
    for entry in os.listdir('/tmp'):
        if entry not in keep:
            print('removing directory /tmp/' + entry)
            # os.path.join instead of manual string concatenation.
            shutil.rmtree(os.path.join('/tmp', entry))
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def get_config(archive_path):
    """Load the model archive at *archive_path* and return an independent
    copy of its configuration."""
    return load_archive(archive_path).config.duplicate()
|
|
30 |
|
|
|
31 |
|
|
|
32 |
def download_glove():
    """Download the GloVe sentiment-analysis predictor and cache it on disk.

    Saves the whole TextClassifierPredictor object via ``torch.save`` to
    ``glove_sentiment_predictor.txt`` next to this file so that
    ``load_glove()`` can restore it without re-downloading.

    Returns:
        The downloaded Predictor instance.
    """
    # Snapshot /tmp so the cache directories created by the download can be
    # removed afterwards.
    tmp_start = os.listdir('/tmp')
    glove_predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/basic_stanford_sentiment_treebank-2020.06.09.tar.gz")
    glove_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'glove_sentiment_predictor.txt')
    torch.save(glove_predictor, glove_path)
    # Cleanup is deliberately best-effort, but a bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit — catch Exception instead.
    try:
        rm_tmp(tmp_start)
    except Exception:
        pass
    return glove_predictor
|
|
43 |
|
|
|
44 |
|
|
|
45 |
def download_roberta():
    """Download the RoBERTa SST sentiment model and serialize it for reuse.

    Writes ``config.json`` and ``whole_model.pt`` (the whole model object)
    into a ``roberta`` directory next to this file, recreating the
    directory if it already exists, so that ``load_roberta()`` can rebuild
    the predictor offline.

    Returns:
        The downloaded Predictor instance.
    """
    tmp_start = os.listdir('/tmp')
    archive_path = "https://storage.googleapis.com/allennlp-public-models/sst-roberta-large-2020.06.08.tar.gz"
    config = get_config(archive_path)
    roberta_predictor = Predictor.from_path(archive_path)

    # Start from a clean serialization directory each time.
    serialization_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'roberta', '')
    if os.path.exists(serialization_dir):
        shutil.rmtree(serialization_dir)
    os.makedirs(serialization_dir)

    # Create config and model files
    config.to_file(os.path.join(serialization_dir, 'config.json'))
    with open(os.path.join(serialization_dir, 'whole_model.pt'), 'wb') as file:
        torch.save(roberta_predictor._model, file)

    # Best-effort cleanup; `except Exception` (not a bare except) so Ctrl-C
    # still interrupts.
    try:
        rm_tmp(tmp_start)
    except Exception:
        pass
    return roberta_predictor
|
|
66 |
|
|
|
67 |
|
|
|
68 |
def download_ner():
    """Download the ELMo NER model and serialize it for offline reloading.

    Writes ``config.json``, the model vocabulary, and the model state dict
    (``whole_model.pt``) into an ``elmo-ner`` directory next to this file,
    recreating the directory if it already exists.

    Returns:
        The downloaded Predictor instance.
    """
    tmp_start = os.listdir('/tmp')
    archive_path = "https://storage.googleapis.com/allennlp-public-models/ner-model-2020.02.10.tar.gz"
    config = get_config(archive_path)
    ner_predictor = Predictor.from_path(archive_path)

    # Start from a clean serialization directory each time.
    serialization_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'elmo-ner', '')
    if os.path.exists(serialization_dir):
        shutil.rmtree(serialization_dir)
    os.makedirs(serialization_dir)

    config.to_file(os.path.join(serialization_dir, 'config.json'))
    # The vocabulary is needed by Model.load when reconstructing from the
    # state dict, so save it alongside the weights.
    vocab = ner_predictor._model.vocab
    vocab.save_to_files(os.path.join(serialization_dir, 'vocabulary'))
    with open(os.path.join(serialization_dir, 'whole_model.pt'), 'wb') as file:
        torch.save(ner_predictor._model.state_dict(), file)

    # Best-effort cleanup; `except Exception` (not a bare except) so Ctrl-C
    # still interrupts.
    try:
        rm_tmp(tmp_start)
    except Exception:
        pass
    return ner_predictor
|
|
90 |
|
|
|
91 |
|
|
|
92 |
def load_glove():
    """Return the GloVe sentiment predictor.

    Restores the predictor cached next to this file when present; otherwise
    downloads (and caches) it via ``download_glove()``.
    """
    glove_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "glove_sentiment_predictor.txt")
    if not os.path.exists(glove_path):
        print('Downloading Glove Sentiment Analysis Model')
        return download_glove()
    print('Loading Glove Sentiment Analysis Model')
    return torch.load(glove_path)
|
|
102 |
|
|
|
103 |
|
|
|
104 |
def load_roberta():
    """Return the RoBERTa sentiment predictor.

    Rebuilds the predictor from the serialized config/model files written
    by ``download_roberta()`` when they exist; otherwise downloads them.
    """
    serialization_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'roberta', '')
    config_file = os.path.join(serialization_dir, 'config.json')
    if not os.path.exists(config_file):
        print('Downloading Roberta Sentiment Analysis Model')
        return download_roberta()

    print('Loading Roberta Sentiment Analysis Model')
    restored_model = torch.load(os.path.join(serialization_dir, 'whole_model.pt'))
    params = Params.from_file(config_file)
    reader = DatasetReader.from_params(params.get('dataset_reader'))
    # Wrap the restored model and dataset reader in a predictor.
    return TextClassifierPredictor(restored_model, reader)
|
|
121 |
|
|
|
122 |
|
|
|
123 |
def load_ner():
    """Rebuild the ELMo NER predictor from files written by ``download_ner()``.

    Reads ``config.json``, the saved vocabulary, and the state-dict weights
    from the ``elmo-ner`` directory next to this file.

    Returns:
        A SentenceTaggerPredictor wrapping the restored model.
    """
    # Resolve the directory relative to this file, matching where
    # download_ner() writes it.  The previous hard-coded relative path
    # "../allennlp/elmo-ner" only worked from one specific working
    # directory and disagreed with download_ner()'s output location.
    serialization_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'elmo-ner', '')

    config_file = os.path.join(serialization_dir, 'config.json')
    weights_file = os.path.join(serialization_dir, 'whole_model.pt')
    loaded_params = Params.from_file(config_file)
    # Model.load reconstructs the architecture from the params (plus the
    # saved vocabulary in serialization_dir) and loads the state dict saved
    # by download_ner().
    loaded_model = Model.load(loaded_params, serialization_dir, weights_file)
    dataset_reader = DatasetReader.from_params(loaded_params.get('dataset_reader'))

    predictor = SentenceTaggerPredictor(loaded_model, dataset_reader)
    return predictor
|
|
134 |
|
|
|
135 |
if __name__ == "__main__":
    # Running as a script fetches and caches all three models.
    for downloader in (download_glove, download_roberta, download_ner):
        downloader()
|
|
139 |
|