|
a |
|
b/src/features/position_feature.py |
|
|
1 |
# Base Dependencies |
|
|
2 |
# ---------------- |
|
|
3 |
import numpy |
|
|
4 |
|
|
|
5 |
# Local Dependencies |
|
|
6 |
# ------------------ |
|
|
7 |
from models import RelationCollection |
|
|
8 |
|
|
|
9 |
# 3rd-Party Dependencies |
|
|
10 |
# ---------------------- |
|
|
11 |
from sklearn.base import BaseEstimator |
|
|
12 |
|
|
|
13 |
# Constants |
|
|
14 |
# --------- |
|
|
15 |
from constants import ( |
|
|
16 |
N2C2_ATTR_ENTITY_CANDIDATES, |
|
|
17 |
DDI_ATTR_ENTITY_CANDIDATES, |
|
|
18 |
) |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
class PositionFeature(BaseEstimator):
    """
    Position Distance

    Computes the position of the entity candidate (drug) with respect to
    the attribute among the entire entity candidates of the attribute, where
    the position of medical attribute is set to 0.

    For every relation the transformer emits two integer columns
    (``position_1``, ``position_2``):

    * ``[0, n]``  when the first entity of the relation has the attribute's type
    * ``[-n, 0]`` when the second entity of the relation has the attribute's type

    where ``n`` is the number of middle entities whose type is listed as a
    candidate for the attribute's type.

    Source:
        Alimova and Tutubalina (2020) - Multiple features for clinical relation extraction: A machine learning approach
    """

    def __init__(self, dataset: str):
        """
        Select the attribute -> candidate-types table for the given dataset.

        Args:
            dataset: name of the dataset, either ``"n2c2"`` or ``"ddi"``.

        Raises:
            ValueError: if ``dataset`` is not one of the supported names.
        """
        if dataset == "n2c2":
            self.attr_entity_candidates = N2C2_ATTR_ENTITY_CANDIDATES
        elif dataset == "ddi":
            self.attr_entity_candidates = DDI_ATTR_ENTITY_CANDIDATES
        else:
            raise ValueError(
                "only datasets 'n2c2' and 'ddi' are supported, but no '{}'".format(
                    dataset
                )
            )
        self.dataset = dataset

    def get_feature_names(self, input_features=None):
        """Return the names of the two output columns (``input_features`` is ignored)."""
        return ["position_1", "position_2"]

    def create_position_feature(self, collection: RelationCollection) -> numpy.ndarray:
        """
        Build the position feature matrix for a collection of relations.

        Args:
            collection: relations to featurize; iterated via ``collection.relations``.

        Returns:
            Integer array with one row per relation and two columns. An empty
            collection yields an array of shape ``(0, 2)``.

        Raises:
            ValueError: if neither entity of a relation has the attribute's type.
        """
        features = []
        for r in collection.relations:
            # only the attribute is needed here; the second ordered entity is unused
            attr, _ = r._ordered_entities
            candidates = self.attr_entity_candidates[attr.type]

            # count middle entities which could form the same type of relation
            # i.e., count number of middle entities that are drugs for n2c2 and DDI
            position = sum(1 for ent in r.middle_entities if ent.type in candidates)

            # NOTE(review): the attribute is located by comparing entity *types*;
            # if both entities share the attribute's type the first branch always
            # wins. Confirm this is intended for datasets where that can happen.
            if r.entity1.type == attr.type:
                # if the attribute is the first entity, the position is positive
                feature = [0, position]
            elif r.entity2.type == attr.type:
                # if the attribute is the second entity, the position is negative
                feature = [-position, 0]
            else:
                raise ValueError("none of the entities correspond with the attribute")

            features.append(feature)

        # reshape keeps the two-column layout even when there are no relations,
        # so the result always agrees with get_feature_names()
        return numpy.array(features, dtype=int).reshape(-1, 2)

    def fit(self, x: RelationCollection, y=None):
        """No-op: the feature has nothing to learn. Returns ``self``."""
        return self

    def transform(self, x: RelationCollection) -> numpy.ndarray:
        """Compute the position feature matrix for ``x``."""
        return self.create_position_feature(x)

    def fit_transform(self, x: RelationCollection, y=None) -> numpy.ndarray:
        """Equivalent to ``transform(x)``, since fitting is a no-op."""
        return self.transform(x)