--- a/bilstm_crf_ner/model/crf.py
+++ b/bilstm_crf_ner/model/crf.py
@@ -0,0 +1,301 @@
+from typing import List, Optional, Union
+
+from torch.autograd import Variable
+import torch
+import torch.nn as nn
+
+
+class CRF(nn.Module):
+    """Conditional random field.
+    This module implements a conditional random field [LMP]. The forward computation
+    of this class computes the log likelihood of the given sequence of tags and
+    emission score tensor. This class also has a ``decode`` method which finds the
+    best tag sequence given an emission score tensor using the `Viterbi algorithm`_.
+    Arguments
+    ---------
+    num_tags : int
+        Number of tags.
+    Attributes
+    ----------
+    num_tags : int
+        Number of tags passed to ``__init__``.
+    start_transitions : :class:`~torch.nn.Parameter`
+        Start transition score tensor of size ``(num_tags,)``.
+    end_transitions : :class:`~torch.nn.Parameter`
+        End transition score tensor of size ``(num_tags,)``.
+    transitions : :class:`~torch.nn.Parameter`
+        Transition score tensor of size ``(num_tags, num_tags)``.
+    References
+    ----------
+    .. [LMP] Lafferty, J., McCallum, A., Pereira, F. (2001).
+       "Conditional random fields: Probabilistic models for segmenting and
+       labeling sequence data". *Proc. 18th International Conf. on Machine
+       Learning*. Morgan Kaufmann. pp. 282–289.
+    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
+    """
+    def __init__(self, num_tags: int) -> None:
+        if num_tags <= 0:
+            raise ValueError(f'invalid number of tags: {num_tags}')
+        super().__init__()
+        self.num_tags = num_tags
+        self.start_transitions = nn.Parameter(torch.Tensor(num_tags))
+        self.end_transitions = nn.Parameter(torch.Tensor(num_tags))
+        self.transitions = nn.Parameter(torch.Tensor(num_tags, num_tags))
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Initialize the transition parameters.
+        The parameters will be initialized randomly from a uniform distribution
+        between -0.1 and 0.1.
+        """
+        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
+        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
+        nn.init.uniform_(self.transitions, -0.1, 0.1)
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}(num_tags={self.num_tags})'
+
+    def forward(self,
+                emissions: Variable,
+                tags: Variable,
+                mask: Optional[Variable] = None,
+                reduce: bool = True,
+                ) -> Variable:
+        """Compute the log likelihood of the given sequence of tags and emission score.
+        Arguments
+        ---------
+        emissions : :class:`~torch.autograd.Variable`
+            Emission score tensor of size ``(seq_length, batch_size, num_tags)``.
+        tags : :class:`~torch.autograd.Variable`
+            Sequence of tags as ``LongTensor`` of size ``(seq_length, batch_size)``.
+        mask : :class:`~torch.autograd.Variable`, optional
+            Mask tensor as ``ByteTensor`` of size ``(seq_length, batch_size)``.
+        reduce : bool
+            Whether to sum the log likelihood over the batch.
+        Returns
+        -------
+        :class:`~torch.autograd.Variable`
+            The log likelihood. This will have size ``(1,)`` if ``reduce=True``,
+            ``(batch_size,)`` otherwise.
+        """
+        if emissions.dim() != 3:
+            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
+        if tags.dim() != 2:
+            raise ValueError(f'tags must have dimension of 2, got {tags.dim()}')
+        if emissions.size()[:2] != tags.size():
+            raise ValueError(
+                'the first two dimensions of emissions and tags must match, '
+                f'got {tuple(emissions.size()[:2])} and {tuple(tags.size())}'
+            )
+        if emissions.size(2) != self.num_tags:
+            raise ValueError(
+                f'expected last dimension of emissions is {self.num_tags}, '
+                f'got {emissions.size(2)}'
+            )
+        if mask is not None:
+            if tags.size() != mask.size():
+                raise ValueError(
+                    f'size of tags and mask must match, got {tuple(tags.size())} '
+                    f'and {tuple(mask.size())}'
+                )
+            if not all(mask[0].data):
+                raise ValueError('mask of the first timestep must all be on')
+
+        if mask is None:
+            mask = Variable(self._new(tags.size()).fill_(1)).byte()
+
+        numerator = self._compute_joint_llh(emissions, tags, mask)
+        denominator = self._compute_log_partition_function(emissions, mask)
+        llh = numerator - denominator
+        return llh if not reduce else torch.sum(llh)
+
+    def decode(self,
+               emissions: Union[Variable, torch.FloatTensor],
+               mask: Optional[Union[Variable, torch.ByteTensor]] = None) -> List[List[int]]:
+        """Find the most likely tag sequence using the Viterbi algorithm.
+        Arguments
+        ---------
+        emissions : :class:`~torch.autograd.Variable` or :class:`~torch.FloatTensor`
+            Emission score tensor of size ``(seq_length, batch_size, num_tags)``.
+        mask : :class:`~torch.autograd.Variable` or :class:`torch.ByteTensor`
+            Mask tensor of size ``(seq_length, batch_size)``.
+        Returns
+        -------
+        list
+            List of lists containing the best tag sequence for each sample in the batch.
+        """
+        if emissions.dim() != 3:
+            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
+        if emissions.size(2) != self.num_tags:
+            raise ValueError(
+                f'expected last dimension of emissions is {self.num_tags}, '
+                f'got {emissions.size(2)}'
+            )
+        if mask is not None and emissions.size()[:2] != mask.size():
+            raise ValueError(
+                'the first two dimensions of emissions and mask must match, '
+                f'got {tuple(emissions.size()[:2])} and {tuple(mask.size())}'
+            )
+
+        if isinstance(emissions, Variable):
+            emissions = emissions.data
+        if mask is None:
+            mask = self._new(emissions.size()[:2]).fill_(1).byte()
+        elif isinstance(mask, Variable):
+            mask = mask.data
+
+        return self._viterbi_decode(emissions, mask)
+
+    def _compute_joint_llh(self,
+                           emissions: Variable,
+                           tags: Variable,
+                           mask: Variable) -> Variable:
+        # emissions: (seq_length, batch_size, num_tags)
+        # tags: (seq_length, batch_size)
+        # mask: (seq_length, batch_size)
+        assert emissions.dim() == 3 and tags.dim() == 2
+        assert emissions.size()[:2] == tags.size()
+        assert emissions.size(2) == self.num_tags
+        assert mask.size() == tags.size()
+        assert all(mask[0].data)
+
+        seq_length = emissions.size(0)
+        mask = mask.float()
+
+        # Start transition score
+        llh = self.start_transitions[tags[0]]  # (batch_size,)
+
+        for i in range(seq_length - 1):
+            cur_tag, next_tag = tags[i], tags[i+1]
+            # Emission score for the current tag
+            llh += emissions[i].gather(1, cur_tag.view(-1, 1)).squeeze(1) * mask[i]
+            # Transition score to the next tag
+            transition_score = self.transitions[cur_tag, next_tag]
+            # Only add the transition score if the next timestep is valid (mask == 1)
+            llh += transition_score * mask[i+1]
+
+        # Find last tag index
+        last_tag_indices = mask.long().sum(0) - 1  # (batch_size,)
+        last_tags = tags.gather(0, last_tag_indices.view(1, -1)).squeeze(0)
+
+        # End transition score
+        llh += self.end_transitions[last_tags]
+        # Emission score for the last tag, if the last timestep is valid (mask == 1)
+        llh += emissions[-1].gather(1, last_tags.view(-1, 1)).squeeze(1) * mask[-1]
+
+        return llh
+
+    def _compute_log_partition_function(self,
+                                        emissions: Variable,
+                                        mask: Variable) -> Variable:
+        # emissions: (seq_length, batch_size, num_tags)
+        # mask: (seq_length, batch_size)
+        assert emissions.dim() == 3 and mask.dim() == 2
+        assert emissions.size()[:2] == mask.size()
+        assert emissions.size(2) == self.num_tags
+        assert all(mask[0].data)
+
+        seq_length = emissions.size(0)
+        mask = mask.float()
+
+        # Start transition score and first emission
+        log_prob = self.start_transitions.view(1, -1) + emissions[0]
+        # Here, log_prob has size (batch_size, num_tags) where for each sample,
+        # the j-th column stores the log probability that the current timestep has tag j
+
+        for i in range(1, seq_length):
+            # Broadcast log_prob over all possible next tags
+            broadcast_log_prob = log_prob.unsqueeze(2)  # (batch_size, num_tags, 1)
+            # Broadcast transition scores over all instances in the batch
+            broadcast_transitions = self.transitions.unsqueeze(0)  # (1, num_tags, num_tags)
+            # Broadcast emission scores over all possible current tags
+            broadcast_emissions = emissions[i].unsqueeze(1)  # (batch_size, 1, num_tags)
+            # Sum current log probability, transition, and emission scores
+            score = broadcast_log_prob + broadcast_transitions \
+                + broadcast_emissions  # (batch_size, num_tags, num_tags)
+            # Sum over all possible current tags, but we're in log prob space, so a sum
+            # becomes a log-sum-exp
+            score = self._log_sum_exp(score, 1)  # (batch_size, num_tags)
+            # Set log_prob to the score if this timestep is valid (mask == 1), otherwise
+            # leave it alone
+            log_prob = score * mask[i].unsqueeze(1) + log_prob * (1. - mask[i]).unsqueeze(1)
+
+        # End transition score
+        log_prob += self.end_transitions.view(1, -1)
+        # Sum (log-sum-exp) over all possible tags
+        return self._log_sum_exp(log_prob, 1)  # (batch_size,)
+
+    def _viterbi_decode(self, emissions: torch.FloatTensor, mask: torch.ByteTensor) \
+            -> List[List[int]]:
+        # Get input sizes
+        seq_length = emissions.size(0)
+        batch_size = emissions.size(1)
+        sequence_lengths = mask.long().sum(dim=0)
+
+        # emissions: (seq_length, batch_size, num_tags)
+        assert emissions.size(2) == self.num_tags
+
+        # List to store the decoded paths
+        best_tags_list = []
+
+        # Start transition
+        viterbi_score = []
+        viterbi_score.append(self.start_transitions.data + emissions[0])
+        viterbi_path = []
+
+        # Here, viterbi_score is a list of tensors of shape (batch_size, num_tags) where the
+        # entry at index [b, j] stores the score of the best tag sequence so far for sample b
+        # that ends with tag j
+        # viterbi_path saves where the best tag candidates transitioned from; this is used
+        # when we trace back the best tag sequence
+
+        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
+        # for every possible next tag
+        for i in range(1, seq_length):
+            # Broadcast viterbi score for every possible next tag
+            broadcast_score = viterbi_score[i - 1].view(batch_size, -1, 1)
+            # Broadcast emission score for every possible current tag
+            broadcast_emission = emissions[i].view(batch_size, 1, -1)
+            # Compute the score matrix of shape (batch_size, num_tags, num_tags) where,
+            # for each sample, the entry at row i and column j stores the score of
+            # transitioning from tag i to tag j and emitting tag j
+            score = broadcast_score + self.transitions.data + broadcast_emission
+            # Find the maximum score over all possible current tags
+            best_score, best_path = score.max(1)  # (batch_size, num_tags)
+            # Save the score and the path
+            viterbi_score.append(best_score)
+            viterbi_path.append(best_path)
+
+        # Now, compute the best path for each sample
+        for idx in range(batch_size):
+            # Find the tag which maximizes the score at the last timestep; this is our best tag
+            # for the last timestep
+            seq_end = sequence_lengths[idx] - 1
+            _, best_last_tag = (viterbi_score[seq_end][idx] + self.end_transitions.data).max(0)
+            best_tags = [best_last_tag.item()]
+
+            # We trace back where the best last tag comes from, append that to our best tag
+            # sequence, and trace it back again, and so on
+            for path in reversed(viterbi_path[:sequence_lengths[idx] - 1]):
+                best_last_tag = path[idx][best_tags[-1]]
+                best_tags.append(best_last_tag.item())
+
+            # Reverse the order because we start from the last timestep
+            best_tags.reverse()
+            best_tags_list.append(best_tags)
+        return best_tags_list
+
+    @staticmethod
+    def _log_sum_exp(tensor: Variable, dim: int) -> Variable:
+        # Find the max value along `dim`
+        offset, _ = tensor.max(dim)
+        # Make offset broadcastable
+        broadcast_offset = offset.unsqueeze(dim)
+        # Perform log-sum-exp safely
+        safe_log_sum_exp = torch.log(torch.sum(torch.exp(tensor - broadcast_offset), dim))
+        # Add offset back
+        return offset + safe_log_sum_exp
+
+    def _new(self, *args, **kwargs) -> torch.FloatTensor:
+        param = next(self.parameters())
+        return param.data.new(*args, **kwargs)
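For reference, a minimal usage sketch of the module added above (not part of the diff). It assumes the file is importable as `bilstm_crf_ner.model.crf`, uses the Variable-era PyTorch API that `crf.py` itself relies on, and substitutes random tensors for the emission scores that a BiLSTM tagger would normally produce; all shapes and values are illustrative only.

```python
# Hypothetical usage sketch of the CRF layer; emissions are random stand-ins
# for BiLSTM outputs, and shapes are chosen only for illustration.
import torch
from torch.autograd import Variable

from bilstm_crf_ner.model.crf import CRF

seq_length, batch_size, num_tags = 5, 2, 4
crf = CRF(num_tags)

# Emission scores: (seq_length, batch_size, num_tags)
emissions = Variable(torch.randn(seq_length, batch_size, num_tags))
# Gold tags: (seq_length, batch_size), values in [0, num_tags)
tags = Variable((torch.rand(seq_length, batch_size) * num_tags).long())
# Mask: (seq_length, batch_size); the first timestep must be all ones
mask = Variable(torch.ByteTensor(seq_length, batch_size).fill_(1))

# Training: forward returns the summed log likelihood, so the loss is its negation
loss = -crf(emissions, tags, mask=mask)
loss.backward()

# Inference: Viterbi decoding returns one list of tag indices per sample
best_tags = crf.decode(emissions, mask=mask)
print(best_tags)  # two lists of five tag indices
```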