# medacy/pipeline_components/learners/bilstm_crf_learner.py
"""
BiLSTM+CRF PyTorch network and model.
"""
import logging
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from medacy.pipeline_components.learners.nn.bilstm_crf import BiLstmCrf
from medacy.pipeline_components.learners.nn.vectorizer import Vectorizer
class BiLstmCrfLearner:
    """
    BiLSTM-CRF model class for using the network. Currently handles all vectorization as well.

    :ivar device: PyTorch device to use.
    :ivar model: Instance of BiLstmCrfNetwork to use.
    :ivar word_embeddings_file: File to load word embeddings from.
    :ivar word_vectors: Gensim word vectors object for use in configuring word embeddings.
    """

    def __init__(self, word_embeddings, cuda_device):
        """Init BiLstmCrfLearner object.

        :param word_embeddings: Path to word embeddings file to use.
        :param cuda_device: Index of cuda device to use. Use -1 to use CPU.
        :raises ValueError: If word_embeddings is None.
        """
        torch.manual_seed(1)  # fixed seed for reproducible training runs

        device_string = 'cuda:%d' % cuda_device if cuda_device >= 0 else 'cpu'
        self.device = torch.device(device_string)
        self.vectorizer = Vectorizer(self.device)

        if word_embeddings is None:
            raise ValueError('BiLSTM-CRF requires word embeddings.')
        else:
            self.word_embeddings_file = word_embeddings
            self.word_vectors = None

        # Other instance attributes
        self.model = None
        self.learning_rate = 0.01

        # TODO: Implement cleaner way to handle this
        # The tiny test-embeddings fixture gets a short schedule so the
        # unit tests finish quickly.
        if word_embeddings.endswith('test_word_embeddings.txt'):
            self.epochs = 2
            self.crf_delay = 1
        else:
            self.epochs = 40
            self.crf_delay = 20

    def fit(self, x_data, y_data):
        """Fully train model based on x and y data. self.model is set to trained model.

        :param x_data: List of list of tokens.
        :param y_data: List of list of correct labels for the tokens.
        """
        if self.vectorizer.word_vectors is None:
            self.vectorizer.load_word_embeddings(self.word_embeddings_file)

        data = self.vectorizer.vectorize_dataset(x_data, y_data)

        # Create network. Second argument is the per-token feature count
        # minus the first two entries of each token vector — presumably
        # embedding/index slots; TODO confirm against Vectorizer.
        model = BiLstmCrf(
            self.vectorizer.word_vectors,
            len(data[0][0][0][2:]),
            len(self.vectorizer.tag_to_index),
            self.device
        )

        # Move to GPU if possible
        if self.device.type != 'cpu':
            logging.info('CUDA available. Moving model to GPU.')
        model = model.to(self.device)

        # Setup optimizer and loss function
        optimizer = optim.SGD(model.parameters(), lr=self.learning_rate)
        loss_function = nn.NLLLoss()

        logging.info('Training BiLSTM-CRF...')

        # Training loop: warm up with a token-level NLL objective for the
        # first crf_delay epochs, then switch to the CRF log-likelihood.
        for i in range(1, self.epochs + 1):
            random.shuffle(data)
            epoch_losses = []

            for sentence, sentence_tags in data:
                if i <= self.crf_delay:
                    emissions = model(sentence)
                    predictions = F.log_softmax(emissions, dim=1)
                    loss = loss_function(predictions, sentence_tags)
                else:
                    # CRF expects (seq_len, batch, num_tags) emissions and
                    # (seq_len, batch) tags; add the batch dimension.
                    emissions = model(sentence).unsqueeze(1)
                    sentence_tags = sentence_tags.unsqueeze(1)
                    loss = -model.crf(emissions, sentence_tags)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Store the Python scalar, not the tensor: appending the loss
                # tensor would retain every iteration's autograd graph for the
                # whole epoch and steadily grow memory usage.
                epoch_losses.append(loss.item())

            # Guard against empty training data to avoid ZeroDivisionError.
            if epoch_losses:
                average_loss = sum(epoch_losses) / len(epoch_losses)
                logging.info('Epoch %d average loss: %f', i, average_loss)

        self.model = model

    def predict(self, sequences):
        """Use model to make predictions over a given dataset.

        :param sequences: Sequences to predict labels for.
        :return: List of list of predicted labels.
        :raises RuntimeError: If word embeddings have not been loaded.
        """
        # Explicit None check, matching fit(): relying on the truthiness of a
        # gensim KeyedVectors object would wrongly raise for empty-but-loaded
        # vectors (and is ambiguous across gensim versions).
        if self.vectorizer.word_vectors is None:
            raise RuntimeError('Loading word embeddings is required.')

        with torch.no_grad():
            predictions = []
            for sequence in sequences:
                vectorized_tokens = self.vectorizer.vectorize_tokens(sequence)
                # Add a batch dimension for the CRF decoder.
                emissions = self.model(vectorized_tokens).unsqueeze(1)
                tag_indices = self.model.crf.decode(emissions)
                predictions.append(self.vectorizer.devectorize_tag(tag_indices[0]))

        return predictions

    def save(self, path):
        """Save model and other required values.

        :param path: Path to save model to; '.pt' is appended if missing.
        """
        vectorizer_values = self.vectorizer.get_values()

        properties = {
            'model': self.model,
            'vectorizer_values': vectorizer_values,
        }

        if not path.endswith('.pt'):
            path += '.pt'

        torch.save(properties, path)

    def load(self, path):
        """Load model and other required values from given path.

        :param path: Path of saved model.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # model files from trusted sources.
        saved_data = torch.load(path, map_location=self.device)

        self.vectorizer.load_values(saved_data['vectorizer_values'])

        model = saved_data['model']
        # nn.Module.to() moves modules in place and returns the same object,
        # so set the device attribute once, after the move.
        self.model = model.to(self.device)
        self.model.device = self.device
        self.model.eval()

        self.vectorizer.load_word_embeddings(self.word_embeddings_file)