Diff of /allennlp/loader.py [000000] .. [2d4573]

b/allennlp/loader.py
from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
from allennlp.predictors.sentence_tagger import SentenceTaggerPredictor

from allennlp.models.archival import load_archive
from allennlp.data import DatasetReader
from allennlp.common.params import Params
from allennlp.models import Model

# Imported for their side effects: these register the classification and
# tagging models/dataset readers so Predictor.from_path can resolve them.
import allennlp_models.classification
import allennlp_models.tagging

import torch
import os
import shutil


def rm_tmp(tmp_start):
    # Remove any directories created under /tmp since tmp_start was snapshotted,
    # e.g. the temporary directories left behind when model archives are extracted
    tmp_now = os.listdir('/tmp')
    for i in tmp_now:
        if i not in tmp_start:
            print('removing directory /tmp/' + i)
            shutil.rmtree('/tmp/' + i)


def get_config(archive_path):
    # Load a model archive and return an independent copy of its configuration
    archive = load_archive(archive_path)
    config = archive.config.duplicate()
    return config


def download_glove():
    # Downloads the GloVe sentiment predictor and saves the whole
    # TextClassifierPredictor object next to this file
    tmp_start = os.listdir('/tmp')
    glove_predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/basic_stanford_sentiment_treebank-2020.06.09.tar.gz")
    glove_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'glove_sentiment_predictor.txt')
    torch.save(glove_predictor, glove_path)
    try:
        rm_tmp(tmp_start)
    except OSError:
        pass
    return glove_predictor
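
# Despite the .txt extension, glove_sentiment_predictor.txt is a torch pickle of
# the whole predictor object; load_glove below restores it with a plain torch.load.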


def download_roberta():
    # Downloads the RoBERTa SST predictor and serializes its config and model
    tmp_start = os.listdir('/tmp')
    archive_path = "https://storage.googleapis.com/allennlp-public-models/sst-roberta-large-2020.06.08.tar.gz"
    config = get_config(archive_path)
    roberta_predictor = Predictor.from_path(archive_path)

    serialization_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'roberta', '')
    if os.path.exists(serialization_dir):
        shutil.rmtree(serialization_dir)
    os.makedirs(serialization_dir)

    # Create config and model files
    config.to_file(os.path.join(serialization_dir, 'config.json'))
    with open(os.path.join(serialization_dir, 'whole_model.pt'), 'wb') as file:
        torch.save(roberta_predictor._model, file)

    try:
        rm_tmp(tmp_start)
    except OSError:
        pass
    return roberta_predictor
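
# Note: the whole Model object is pickled above, so load_roberta can restore it
# with a plain torch.load; config.json is kept alongside so the matching
# DatasetReader can be rebuilt from the archive's configuration.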


def download_ner():
    # Downloads the ELMo NER predictor and serializes its config, vocabulary,
    # and model weights
    tmp_start = os.listdir('/tmp')
    archive_path = "https://storage.googleapis.com/allennlp-public-models/ner-model-2020.02.10.tar.gz"
    config = get_config(archive_path)
    ner_predictor = Predictor.from_path(archive_path)

    serialization_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'elmo-ner', '')
    if os.path.exists(serialization_dir):
        shutil.rmtree(serialization_dir)
    os.makedirs(serialization_dir)

    # Create config, vocabulary, and weights files
    config.to_file(os.path.join(serialization_dir, 'config.json'))
    vocab = ner_predictor._model.vocab
    vocab.save_to_files(os.path.join(serialization_dir, 'vocabulary'))
    with open(os.path.join(serialization_dir, 'whole_model.pt'), 'wb') as file:
        torch.save(ner_predictor._model.state_dict(), file)

    try:
        rm_tmp(tmp_start)
    except OSError:
        pass
    return ner_predictor
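
# Unlike download_roberta, only the state_dict and the vocabulary are written
# here, so load_ner has to reconstruct the model from config.json (via
# Model.load) before the weights can be applied.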


def load_glove():
    # Loads the GloVe sentiment predictor from disk, downloading it if missing
    glove_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "glove_sentiment_predictor.txt")
    if os.path.exists(glove_path):  # same dir for github
        print('Loading GloVe Sentiment Analysis Model')
        predictor = torch.load(glove_path)
    else:
        print('Downloading GloVe Sentiment Analysis Model')
        predictor = download_glove()
    return predictor
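
# Example use of the loaded predictor (a sketch: the "sentence" keyword and the
# "label"/"probs" output keys follow AllenNLP's text-classifier predictor
# interface and are assumptions, not defined in this file):
#
#     glove = load_glove()
#     result = glove.predict(sentence="a charming, funny and well-made film")
#     print(result["label"], result["probs"])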


def load_roberta():
    # Loads the RoBERTa sentiment predictor from disk, downloading it if missing
    serialization_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'roberta', '')
    config_file = os.path.join(serialization_dir, 'config.json')
    if os.path.exists(config_file):
        print('Loading RoBERTa Sentiment Analysis Model')
        model_file = os.path.join(serialization_dir, 'whole_model.pt')
        model = torch.load(model_file)
        loaded_params = Params.from_file(config_file)
        dataset_reader = DatasetReader.from_params(loaded_params.get('dataset_reader'))

        # Build the predictor from the model and dataset reader
        predictor = TextClassifierPredictor(model, dataset_reader)
    else:
        print('Downloading RoBERTa Sentiment Analysis Model')
        predictor = download_roberta()
    return predictor
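
# If whole_model.pt was produced on a GPU machine, loading it on a CPU-only box
# may require torch.load(model_file, map_location='cpu'); the code above assumes
# the default device mapping is fine.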


def load_ner():
    # Loads the NER predictor from disk, downloading it if missing.
    # Uses the same serialization directory that download_ner writes to.
    serialization_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'elmo-ner', '')
    config_file = os.path.join(serialization_dir, 'config.json')
    weights_file = os.path.join(serialization_dir, 'whole_model.pt')
    if os.path.exists(config_file):
        print('Loading NER Model')
        loaded_params = Params.from_file(config_file)
        loaded_model = Model.load(loaded_params, serialization_dir, weights_file)
        dataset_reader = DatasetReader.from_params(loaded_params.get('dataset_reader'))
        predictor = SentenceTaggerPredictor(loaded_model, dataset_reader)
    else:
        print('Downloading NER Model')
        predictor = download_ner()
    return predictor
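
# Example use of the NER predictor (a sketch: the "sentence" keyword and the
# "words"/"tags" output keys are the usual sentence-tagger interface and are
# assumptions here):
#
#     ner = load_ner()
#     result = ner.predict(sentence="AllenNLP was developed in Seattle.")
#     print(list(zip(result["words"], result["tags"])))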


if __name__ == "__main__":
    download_glove()
    download_roberta()
    download_ner()
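
# Running this file directly pre-downloads and caches all three models, e.g.
# (exact invocation depends on the project layout):
#
#     python allennlp/loader.py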