Switch to unified view

a b/src/features/iob_feature.py
1
# Base Dependencies
2
# ----------------
3
import numpy as np
4
from typing import List, Any, Optional
5
6
# Local Dependencies
7
# ------------------
8
from models import RelationCollection
9
10
# 3rd-Party Dependencies
11
# ----------------------
12
from sklearn.base import BaseEstimator
13
14
# Constants
15
# ---------
16
from constants import DATASETS, DDI_IOB_TAGS, N2C2_IOB_TAGS
17
18
19
class IOBFeature(BaseEstimator):
    """
    IOB encoding

    Obtains the IOB tag index of each token in the relation's sentence.

    The sentence of a relation is laid out as
    ``left tokens + entity1 tokens + middle tokens + entity2 tokens + right tokens``.
    Tokens outside both entities get the "O" tag; the first token of an
    entity gets "B-<entity type>" and the remaining ones "I-<entity type>".
    """

    def __init__(self, dataset: str, padding_idx: Optional[int] = None):
        """
        Args:
            dataset (str): dataset name; must be contained in `DATASETS`
            padding_idx (int, optional, default = None): index reserved for
                padding. When given, every tag index greater than or equal
                to it is shifted up by one (see `iob_index`), so that no tag
                is encoded with the padding index.

        Raises:
            ValueError: if `dataset` is not contained in `DATASETS`
        """
        if dataset not in DATASETS:
            raise ValueError("unsupported dataset '{}'".format(dataset))

        self.dataset = dataset
        # Any dataset other than "n2c2" falls back to the DDI tag set.
        self.iob_tags = N2C2_IOB_TAGS if dataset == "n2c2" else DDI_IOB_TAGS
        self.padding_idx = padding_idx

    def get_feature_names(self, input_features=None):
        """
        Gets the name of the feature
        """
        return ["IOB"]

    def _entity_iob(self, entity_type: str, n_tokens: int) -> List[int]:
        """
        Builds the IOB tag indices of a single entity.

        Args:
            entity_type (str): type of the entity, appended to "B-" / "I-"
            n_tokens (int): number of tokens the entity spans

        Returns:
            one "B-" index followed by `n_tokens - 1` "I-" indices
        """
        b_idx = self.iob_index("B-" + entity_type)
        i_idx = self.iob_index("I-" + entity_type)
        # NOTE: a list multiplied by a non-positive number is empty, so an
        # entity with zero tokens still yields a single "B-" index.
        return [b_idx] + [i_idx] * (n_tokens - 1)

    def create_iob_feature(self, collection: RelationCollection) -> List[np.ndarray]:
        """
        Computes the IOB encoding for a collection of relations.

        Args:
            collection (RelationCollection): collection of relations

        Returns:
            list with one array per relation, holding the IOB tag index of
            every token of the relation's sentence
        """
        iob_all = []
        o_index = self.iob_index("O")

        for i in range(len(collection)):
            relation = collection.relations[i]

            # IOB of entity1
            iob_e1 = self._entity_iob(
                relation.entity1.type, len(collection.entities1_tokens[i])
            )

            # IOB of entity2
            iob_e2 = self._entity_iob(
                relation.entity2.type, len(collection.entities2_tokens[i])
            )

            # Context tokens around / between the entities are all "O".
            iob_sent = (
                ([o_index] * len(collection.left_tokens[i]))
                + iob_e1
                + ([o_index] * len(collection.middle_tokens[i]))
                + iob_e2
                + ([o_index] * len(collection.right_tokens[i]))
            )

            iob_all.append(np.array(iob_sent))

        return iob_all

    def iob_index(self, iob_tag: str) -> int:
        """
        Computes the index of the corresponding IOB tag

        Args:
            iob_tag (str): IOB tag, e.g. "O" or "B-<entity type>"

        Returns:
            position of the tag in the dataset's tag list, shifted up by one
            when a padding index is set and the position is not below it

        Raises:
            ValueError: if `iob_tag` is not in the dataset's tag list
        """
        idx = self.iob_tags.index(iob_tag)

        # Skip over the padding index so that it is never used for a tag.
        if self.padding_idx is not None and idx >= self.padding_idx:
            idx += 1
        return idx

    def fit(self, x: RelationCollection, y: Any = None):
        """
        Does nothing; the encoding has no parameters to learn.

        Returns:
            the estimator itself
        """
        return self

    def transform(self, x: RelationCollection) -> list:
        """
        Computes the IOB encoding of the collection (see `create_iob_feature`).
        """
        return self.create_iob_feature(x)

    def fit_transform(self, x: RelationCollection, y: Any = None) -> list:
        """
        Computes the IOB encoding of the collection; `y` is ignored.
        """
        return self.create_iob_feature(x)