src/re_datasets/bilstm_utils.py


# Base Dependencies
# -----------------
import numpy as np
from typing import Dict, List, Tuple

# PyTorch Dependencies
# --------------------
import torch
from torch import Tensor


# Auxiliary Functions
# -------------------
def sort_batch(
    batch: Dict[str, np.ndarray], lengths: np.ndarray
) -> Dict[str, np.ndarray]:
    """Sorts a minibatch by sequence length, with the longest sequences first.

    Returns the batch with every field reordered in descending order of
    sequence length, so that the output can later be consumed by
    ``torch.nn.utils.rnn.pack_padded_sequence``.

    Args:
        batch (Dict[str, np.ndarray]): batch of data, one array per feature
        lengths (np.ndarray): length of each sequence in the batch

    Returns:
        Dict[str, np.ndarray]: batch of data ordered in descending order of sequence length
    """
    perm_idx = np.argsort(-lengths)
    for key in batch.keys():
        batch[key] = batch[key][perm_idx]
    return batch
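
# Usage sketch (illustrative, not part of the original module): with a toy
# batch whose fields are already NumPy arrays, sort_batch reorders every row
# by descending length:
#
#   >>> lengths = np.array([2, 5, 3])
#   >>> batch = {"sent": np.arange(15).reshape(3, 5)}
#   >>> sort_batch(batch, lengths)["sent"][0]
#   array([5, 6, 7, 8, 9])  # the length-5 row now comes first
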
def pad_seqs(
    seqs: List[List[int]], lengths: List[int], padding_idx: int
) -> np.ndarray:
    """Pads sequences to a common length.

    Args:
        seqs (List[List[int]]): sequences of different lengths
        lengths (List[int]): length of each sequence
        padding_idx (int): value used for padding

    Returns:
        np.ndarray: padded sequences of shape (batch_size, max_length)
    """
    batch_size = len(lengths)
    max_length = max(lengths)
    padded_seqs = np.full(
        shape=(batch_size, max_length), fill_value=padding_idx, dtype=np.int32
    )
    for i, l in enumerate(lengths):
        padded_seqs[i, 0:l] = seqs[i]
    return padded_seqs
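
# Usage sketch (illustrative): two sequences of lengths 3 and 1, padded with
# padding_idx 0 to the length of the longest one:
#
#   >>> pad_seqs([[4, 8, 2], [6]], lengths=[3, 1], padding_idx=0)
#   array([[4, 8, 2],
#          [6, 0, 0]], dtype=int32)
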
def pad_and_sort_batch(
    batch: Dict[str, list], padding_idx: int, rd_max: int
) -> Dict[str, np.ndarray]:
    """Pads every sequence feature of a batch and sorts it by sequence length.

    Each feature of the batch is padded to its maximum length, and the whole
    batch is then ordered from the longest to the shortest sentence so that it
    can later be packed with ``pack_padded_sequence``.

    Args:
        batch (Dict[str, list]): batch of data, one list of sequences per feature
        padding_idx (int): padding value for token-level features
        rd_max (int): maximum relative-distance value, used to pad ``rd1`` and ``rd2``

    Returns:
        Dict[str, np.ndarray]: padded batch in descending order of sequence length
    """
    for key in ["char_length", "seq_length", "label"]:
        batch[key] = np.array(batch[key])

    for key in ["e1", "e2"]:
        seqs = batch[key]
        # pad the entity mentions separately to avoid unnecessary padding
        lengths = [len(s) for s in seqs]
        batch[key] = pad_seqs(seqs, lengths, padding_idx)

    for key in ["rd1", "rd2"]:
        seqs = batch[key]
        # pad relative distances with their maximum value
        batch[key] = pad_seqs(seqs, batch["seq_length"], rd_max)

    for key in ["sent", "iob", "pos", "dep"]:
        seqs = batch[key]
        # pad the remaining token-level features with the common padding index
        batch[key] = pad_seqs(seqs, batch["seq_length"], padding_idx)

    return sort_batch(batch, batch["seq_length"])
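
# Usage sketch (illustrative; the key names are the ones iterated over above,
# while the toy values and the exact meaning of "char_length" are
# assumptions). A raw batch maps each feature to one Python list per example:
#
#   >>> raw = {
#   ...     "sent": [[5, 6, 7], [8, 9]], "iob": [[1, 2, 1], [1, 1]],
#   ...     "pos": [[3, 3, 4], [2, 2]], "dep": [[1, 1, 2], [2, 1]],
#   ...     "e1": [[5], [8, 9]], "e2": [[7], [9]],
#   ...     "rd1": [[0, 1, 2], [0, 1]], "rd2": [[2, 1, 0], [1, 0]],
#   ...     "char_length": [12, 9], "seq_length": [3, 2], "label": [0, 1],
#   ... }
#   >>> padded = pad_and_sort_batch(raw, padding_idx=0, rd_max=60)
#   >>> padded["sent"]
#   array([[5, 6, 7],
#          [8, 9, 0]], dtype=int32)
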
def custom_collate(
    data: List[Dict[str, np.ndarray]],
) -> Tuple[Dict[str, Tensor], Tensor]:
    """Separates the inputs and the targets of a batch.

    Only ``data[0]`` is read: the batch is expected to arrive wrapped in a
    single-element list, as produced by a DataLoader with ``batch_size=1``
    over pre-batched samples.

    Args:
        data (List[Dict[str, np.ndarray]]): batch of data wrapped in a list

    Returns:
        Tuple[Dict[str, Tensor], Tensor]: inputs and targets
    """
    inputs = {}
    targets = torch.from_numpy(data[0]["label"]).long()
    for key, value in data[0].items():
        if key != "label":
            inputs[key] = torch.from_numpy(value).long()
    return inputs, targets
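
# Usage sketch (illustrative; `batched_dataset` is a hypothetical Dataset
# whose __getitem__ returns one already padded batch dict, e.g. the output of
# pad_and_sort_batch). Because custom_collate only reads data[0], the
# DataLoader is expected to deliver one such dict per step:
#
#   >>> from torch.utils.data import DataLoader
#   >>> loader = DataLoader(
#   ...     batched_dataset, batch_size=1, collate_fn=custom_collate
#   ... )
#   >>> inputs, targets = next(iter(loader))
#   >>> targets.dtype
#   torch.int64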