|
a |
|
b/src/features/pos_feature.py |
|
|
1 |
# Base Dependencies |
|
|
2 |
# ---------------- |
|
|
3 |
from typing import Optional |
|
|
4 |
|
|
|
5 |
# Local Dependencies |
|
|
6 |
# ------------------ |
|
|
7 |
from models import RelationCollection |
|
|
8 |
|
|
|
9 |
# 3rd-Party Dependencies |
|
|
10 |
# ---------------------- |
|
|
11 |
from sklearn.base import BaseEstimator |
|
|
12 |
|
|
|
13 |
# Constants |
|
|
14 |
# --------- |
|
|
15 |
from constants import U_POS_TAGS |
|
|
16 |
|
|
|
17 |
|
|
|
18 |
class POSFeature(BaseEstimator):
    """
    PoS Tagging

    Maps every token of every tokenized entry in a relation collection to the
    position of its universal POS tag inside `U_POS_TAGS`.
    """

    def __init__(self, padding_idx: Optional[int] = None):
        """
        Args:
            padding_idx: index value that must never be produced for a tag.
                When given, tag positions at or above it are moved up by one.
                `None` leaves the positions untouched.
        """
        self.padding_idx = padding_idx

    def get_feature_names(self, input_features=None):
        """
        Returns the name of the single feature produced by this transformer.
        `input_features` is accepted but not used.
        """
        return ["POS"]

    def create_pos_feature(self, collection: RelationCollection) -> list:
        """
        Builds one list of POS indices per entry of `collection.tokens`.

        Each entry is iterated token by token and the token's `pos_`
        attribute is converted with `pos_index`.
        """
        return [
            [self.pos_index(token.pos_) for token in tokens]
            for tokens in collection.tokens
        ]

    def pos_index(self, pos_tag: str):
        """
        Computes the index of the corresponding POS tag

        The index is the position of `pos_tag` in `U_POS_TAGS`; a tag missing
        from that sequence makes the `.index` lookup fail. When `padding_idx`
        is set, positions greater than or equal to it are incremented by one
        so that the padding value itself is never returned.
        """
        position = U_POS_TAGS.index(pos_tag)

        # Skip over the slot reserved for padding, if one was configured.
        needs_shift = self.padding_idx is not None and position >= self.padding_idx
        return position + 1 if needs_shift else position

    def fit(self, x: RelationCollection, y=None):
        """
        Nothing is learned from the data; returns the instance itself.
        """
        return self

    def transform(self, x: RelationCollection) -> list:
        """
        Returns the POS indices of `x` (see `create_pos_feature`).
        """
        return self.create_pos_feature(x)

    def fit_transform(self, x: RelationCollection, y=None) -> list:
        """
        Same result as `transform`; `y` is ignored.
        """
        return self.create_pos_feature(x)