src/re_datasets/bilstm_utils.py
# Base Dependencies
# -----------------
import numpy as np
from typing import Dict, List, Tuple

# PyTorch Dependencies
# ---------------------
import torch
from torch import Tensor

# Auxiliary Functions
# -------------------
def sort_batch(
    batch: Dict[str, np.ndarray], lengths: np.ndarray
) -> Dict[str, np.ndarray]:
    """
    Sorts a minibatch by sequence length, with the longest sequences first, and
    returns the sorted batch. This way the output can be used by pack_padded_sequence(...).

    Args:
        batch (Dict[str, np.ndarray]): batch of data
        lengths (np.ndarray): length of each sequence in the batch

    Returns:
        Dict[str, np.ndarray]: batch of data ordered in descending order of sequence length.
    """
    perm_idx = np.argsort(-lengths)

    for key in batch.keys():
        batch[key] = batch[key][perm_idx]

    return batch
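# Illustrative usage of sort_batch (a minimal sketch, not part of the original
# module; the toy values are made up, and the "sent"/"seq_length" keys mirror
# those used by pad_and_sort_batch below):
#
#   batch = {
#       "sent": np.array([[1, 2, 0], [3, 4, 5]]),
#       "seq_length": np.array([2, 3]),
#   }
#   batch = sort_batch(batch, batch["seq_length"])
#   # batch["sent"] is now [[3, 4, 5], [1, 2, 0]]
#   # batch["seq_length"] is now [3, 2]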


def pad_seqs(
    seqs: List[List[float]], lengths: List[int], padding_idx: int
) -> np.ndarray:
    """Pads sequences to the length of the longest sequence in the batch

    Args:
        seqs (List[List[float]]): sequences of different lengths
        lengths (List[int]): length of each sequence
        padding_idx (int): value used for padding

    Returns:
        np.ndarray: padded sequences of shape (batch_size, max_length)
    """
    batch_size = len(lengths)
    max_length = max(lengths)

    padded_seqs = np.full(
        shape=(batch_size, max_length), fill_value=padding_idx, dtype=np.int32
    )

    for i, l in enumerate(lengths):
        padded_seqs[i, 0:l] = seqs[i]

    return padded_seqs
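# Illustrative usage of pad_seqs (a minimal sketch with made-up toy values;
# padding_idx=0 is an assumption here, the real index comes from the vocabulary):
#
#   seqs = [[7, 8], [4, 5, 6], [9]]
#   pad_seqs(seqs, lengths=[2, 3, 1], padding_idx=0)
#   # -> array([[7, 8, 0],
#   #           [4, 5, 6],
#   #           [9, 0, 0]], dtype=int32)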


def pad_and_sort_batch(
    batch: Dict, padding_idx: int, rd_max: int
) -> Dict[str, np.ndarray]:
    """
    Pads the variable-length features of a batch and sorts it by sequence length,
    from longest to shortest, so the result can later be used by pack_padded_sequence(...).

    Args:
        batch (Dict): batch of data with list-valued features
        padding_idx (int): value used for padding token-level features
        rd_max (int): maximum relative-distance value, used to pad "rd1" and "rd2"

    Returns:
        Dict[str, np.ndarray]: padded batch in descending order of sequence length.
    """
    for key in ["char_length", "seq_length", "label"]:
        batch[key] = np.array(batch[key])

    for key in ["e1", "e2"]:
        seqs = batch[key]
        # pad the entities separately to avoid unnecessary padding
        lengths = [len(seq) for seq in seqs]
        batch[key] = pad_seqs(seqs, lengths, padding_idx)

    for key in ["rd1", "rd2"]:
        seqs = batch[key]
        # pad relative distances with the maximum value
        batch[key] = pad_seqs(seqs, batch["seq_length"], rd_max)

    for key in ["sent", "iob", "pos", "dep"]:
        seqs = batch[key]
        # pad the other features with the common padding index
        batch[key] = pad_seqs(seqs, batch["seq_length"], padding_idx)

    return sort_batch(batch, batch["seq_length"])
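# Illustrative call of pad_and_sort_batch (a minimal sketch with made-up values;
# padding_idx=0 and rd_max=60 are assumptions, the real values depend on the
# vocabulary and on how relative distances are encoded):
#
#   batch = {
#       "sent": [[5, 6, 7], [8, 9]], "iob": [[1, 2, 2], [1, 2]],
#       "pos": [[3, 4, 3], [4, 4]], "dep": [[2, 1, 2], [1, 1]],
#       "rd1": [[0, 1, 2], [0, 1]], "rd2": [[2, 1, 0], [1, 0]],
#       "e1": [[5], [8, 9]], "e2": [[7], [9]],
#       "char_length": [12, 9], "seq_length": [3, 2], "label": [1, 0],
#   }
#   batch = pad_and_sort_batch(batch, padding_idx=0, rd_max=60)
#   # every value is now a numpy array, padded to the batch maximum where
#   # needed, with rows ordered so that batch["seq_length"] is [3, 2]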


def custom_collate(
    data: List[Dict[str, np.ndarray]]
) -> Tuple[Dict[str, Tensor], Tensor]:
    """Separates the inputs and the targets

    Args:
        data (List[Dict[str, np.ndarray]]): batch of data; only the first element is used

    Returns:
        Tuple[Dict[str, Tensor], Tensor]: inputs and targets.
    """
    inputs = {}
    targets = torch.from_numpy(data[0]["label"]).long()
    for key, value in data[0].items():
        if key != "label":
            inputs[key] = torch.from_numpy(value).long()
    return inputs, targets
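# Illustrative wiring of custom_collate into a DataLoader (a minimal sketch;
# "dataset" is a hypothetical torch Dataset whose __getitem__ already returns a
# whole pre-padded batch dict, which is why custom_collate only reads data[0]
# and why batch_size=1 is assumed here):
#
#   from torch.utils.data import DataLoader
#   loader = DataLoader(dataset, batch_size=1, collate_fn=custom_collate)
#   for inputs, targets in loader:
#       ...  # inputs is a Dict[str, Tensor], targets is a LongTensor of labels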