"""Download and load pretrained AllenNLP predictors.

Three models are supported, each cached next to this file:

* GloVe sentiment (whole ``Predictor`` pickled to ``glove_sentiment_predictor.txt``)
* RoBERTa sentiment (``roberta/`` with ``config.json`` + full model object)
* ELMo NER (``elmo-ner/`` with ``config.json``, vocabulary and a state dict)

Run as a script to pre-download all three.
"""

from allennlp.predictors.predictor import Predictor
from allennlp.predictors.text_classifier import TextClassifierPredictor
from allennlp.predictors.sentence_tagger import SentenceTaggerPredictor

from allennlp.models.archival import *
from allennlp.data import DatasetReader
from allennlp.common.params import Params
from allennlp.models import Model
import allennlp_models.classification  # noqa: F401 -- registers classification models
import allennlp_models.tagging  # noqa: F401 -- registers tagging models

import torch
import os
import shutil

# Directory containing this file; every model artefact is cached under it.
_BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Directory where allennlp extracts downloaded archives.
_TMP_DIR = '/tmp'


def rm_tmp(tmp_start):
    """Best-effort removal of directories created under /tmp since *tmp_start*.

    AllenNLP unpacks each downloaded archive into a fresh /tmp directory;
    this reclaims that space after the model has been re-saved locally.

    Parameters
    ----------
    tmp_start : iterable of str
        Snapshot of ``os.listdir('/tmp')`` taken before the download began.
    """
    before = set(tmp_start)  # O(1) membership tests instead of list scans
    for entry in os.listdir(_TMP_DIR):
        if entry in before:
            continue
        path = os.path.join(_TMP_DIR, entry)
        # Only directories were ever removed; a stray file must not crash us.
        if not os.path.isdir(path):
            continue
        print('removing directory ' + path)
        try:
            shutil.rmtree(path)
        except OSError as err:
            # Cleanup is best-effort: report and carry on rather than
            # aborting a successful download over a temp-file hiccup.
            print('could not remove {}: {}'.format(path, err))


def get_config(archive_path):
    """Return a copy of the training configuration stored in *archive_path*.

    Parameters
    ----------
    archive_path : str
        Local path or URL of an AllenNLP ``.tar.gz`` model archive.

    Returns
    -------
    Params
        An independent duplicate, safe for the caller to mutate.
    """
    archive = load_archive(archive_path)
    return archive.config.duplicate()


def _fresh_dir(name):
    """Create (wiping any previous copy) and return cache dir *name* under _BASE_DIR."""
    path = os.path.join(_BASE_DIR, name, '')
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)
    return path


def download_glove():
    """Download the GloVe sentiment model and pickle the whole predictor.

    The complete ``TextClassifierPredictor`` object is saved to
    ``glove_sentiment_predictor.txt`` next to this file.

    Returns
    -------
    Predictor
        The freshly downloaded predictor.
    """
    tmp_start = os.listdir(_TMP_DIR)
    glove_predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/basic_stanford_sentiment_treebank-2020.06.09.tar.gz")
    glove_path = os.path.join(_BASE_DIR, 'glove_sentiment_predictor.txt')
    torch.save(glove_predictor, glove_path)
    rm_tmp(tmp_start)
    return glove_predictor


def download_roberta():
    """Download the RoBERTa sentiment model and cache it under ``roberta/``.

    Saves ``config.json`` plus ``whole_model.pt`` (the entire model object,
    matching what :func:`load_roberta` expects to ``torch.load``).

    Returns
    -------
    Predictor
        The freshly downloaded predictor.
    """
    tmp_start = os.listdir(_TMP_DIR)
    archive_path = "https://storage.googleapis.com/allennlp-public-models/sst-roberta-large-2020.06.08.tar.gz"
    config = get_config(archive_path)
    roberta_predictor = Predictor.from_path(archive_path)

    serialization_dir = _fresh_dir('roberta')
    config.to_file(os.path.join(serialization_dir, 'config.json'))
    with open(os.path.join(serialization_dir, 'whole_model.pt'), 'wb') as file:
        torch.save(roberta_predictor._model, file)

    rm_tmp(tmp_start)
    return roberta_predictor


def download_ner():
    """Download the ELMo NER model and cache it under ``elmo-ner/``.

    Saves ``config.json``, the model vocabulary, and ``whole_model.pt``
    holding a *state dict* (unlike the RoBERTa cache, which stores the
    whole model object) — matching what :func:`load_ner` feeds to
    ``Model.load``.

    Returns
    -------
    Predictor
        The freshly downloaded predictor.
    """
    tmp_start = os.listdir(_TMP_DIR)
    archive_path = "https://storage.googleapis.com/allennlp-public-models/ner-model-2020.02.10.tar.gz"
    config = get_config(archive_path)
    ner_predictor = Predictor.from_path(archive_path)

    serialization_dir = _fresh_dir('elmo-ner')
    config.to_file(os.path.join(serialization_dir, 'config.json'))
    ner_predictor._model.vocab.save_to_files(os.path.join(serialization_dir, 'vocabulary'))
    with open(os.path.join(serialization_dir, 'whole_model.pt'), 'wb') as file:
        torch.save(ner_predictor._model.state_dict(), file)

    rm_tmp(tmp_start)
    return ner_predictor


def load_glove():
    """Load the cached GloVe predictor, downloading it first if necessary."""
    glove_path = os.path.join(_BASE_DIR, "glove_sentiment_predictor.txt")
    if os.path.exists(glove_path):  # same dir for github
        print('Loading Glove Sentiment Analysis Model')
        return torch.load(glove_path)
    print('Downloading Glove Sentiment Analysis Model')
    return download_glove()


def load_roberta():
    """Load the cached RoBERTa predictor, downloading it first if necessary.

    Rebuilds a ``TextClassifierPredictor`` from the saved whole-model file
    and the dataset reader described in the cached config.
    """
    serialization_dir = os.path.join(_BASE_DIR, 'roberta', '')
    config_file = os.path.join(serialization_dir, 'config.json')
    if not os.path.exists(config_file):
        print('Downloading Roberta Sentiment Analysis Model')
        return download_roberta()

    print('Loading Roberta Sentiment Analysis Model')
    model = torch.load(os.path.join(serialization_dir, 'whole_model.pt'))
    loaded_params = Params.from_file(config_file)
    dataset_reader = DatasetReader.from_params(loaded_params.get('dataset_reader'))
    return TextClassifierPredictor(model, dataset_reader)


def load_ner():
    """Load the cached ELMo NER predictor, downloading it first if necessary.

    Fix: reads from the same ``elmo-ner`` directory that :func:`download_ner`
    writes to (the previous relative path ``../allennlp/elmo-ner`` depended on
    the process working directory and never matched the downloader's output).
    Also gains the download-on-miss fallback the other loaders already have.
    """
    serialization_dir = os.path.join(_BASE_DIR, 'elmo-ner', '')
    config_file = os.path.join(serialization_dir, 'config.json')
    if not os.path.exists(config_file):
        print('Downloading ELMo NER Model')
        return download_ner()

    print('Loading ELMo NER Model')
    weights_file = os.path.join(serialization_dir, 'whole_model.pt')
    loaded_params = Params.from_file(config_file)
    # Model.load restores the architecture from the params and the weights
    # from the saved state dict.
    loaded_model = Model.load(loaded_params, serialization_dir, weights_file)
    dataset_reader = DatasetReader.from_params(loaded_params.get('dataset_reader'))
    return SentenceTaggerPredictor(loaded_model, dataset_reader)


if __name__ == "__main__":
    download_glove()
    download_roberta()
    download_ner()