Switch to unified view

a b/src/features/pos_feature.py
1
# Base Dependencies
2
# ----------------
3
from typing import Optional
4
5
# Local Dependencies
6
# ------------------
7
from models import RelationCollection
8
9
# 3rd-Party Dependencies
10
# ----------------------
11
from sklearn.base import BaseEstimator
12
13
# Constants
14
# ---------
15
from constants import U_POS_TAGS
16
17
18
class POSFeature(BaseEstimator):
    """
    PoS Tagging

    Encodes the universal POS tag of every token in each relation's
    sentence as an integer, namely the position of the tag inside
    ``U_POS_TAGS``.

    The object exposes ``fit`` / ``transform`` / ``fit_transform`` so it can
    be used like a stateless transformer: nothing is learned from the data.
    """

    def __init__(self, padding_idx: Optional[int] = None):
        """
        Args:
            padding_idx: optional index reserved for padding. When given,
                every tag index greater than or equal to it is shifted up
                by one, so this value is never emitted for a real tag.
        """
        self.padding_idx = padding_idx

    def get_feature_names(self, input_features=None):
        """Returns the name of the single feature produced: ``["POS"]``.

        ``input_features`` is accepted but not used.
        """
        return ["POS"]

    def create_pos_feature(self, collection: RelationCollection) -> list:
        """
        Builds the POS index sequences for a whole collection.

        Args:
            collection: object whose ``tokens`` attribute is iterated; each
                item is itself iterated and every element's ``pos_``
                attribute is read.
                NOTE(review): presumably spaCy docs/tokens -- confirm.

        Returns:
            A list with one inner list per item of ``collection.tokens``,
            holding the encoded POS index of each of its tokens, in order.
        """
        return [
            [self.pos_index(token.pos_) for token in sentence]
            for sentence in collection.tokens
        ]

    def pos_index(self, pos_tag: str):
        """
        Computes the index of the corresponding POS tag.

        The index is the tag's position in ``U_POS_TAGS``, moved up by one
        when a padding index is configured and the position is at or above
        it. An unknown tag propagates whatever ``U_POS_TAGS.index`` raises
        (``ValueError`` if it is a list or tuple).
        """
        position = U_POS_TAGS.index(pos_tag)
        reserved = self.padding_idx

        # Skip over the slot reserved for padding.
        if reserved is not None and position >= reserved:
            return position + 1
        return position

    def fit(self, x: RelationCollection, y=None):
        """No-op fit: there is nothing to learn. Returns ``self``."""
        return self

    def transform(self, x: RelationCollection) -> list:
        """Encodes ``x``; see ``create_pos_feature`` for the output layout."""
        return self.create_pos_feature(x)

    def fit_transform(self, x: RelationCollection, y=None) -> list:
        """Same result as ``transform``; ``y`` is ignored."""
        return self.create_pos_feature(x)