# src/features/position_feature.py
1
# Base Dependencies
2
# ----------------
3
import numpy
4
5
# Local Dependencies
6
# ------------------
7
from models import RelationCollection
8
9
# 3rd-Party Dependencies
10
# ----------------------
11
from sklearn.base import BaseEstimator
12
13
# Constants
14
# ---------
15
from constants import (
16
    N2C2_ATTR_ENTITY_CANDIDATES,
17
    DDI_ATTR_ENTITY_CANDIDATES,
18
)
19
20
21
class PositionFeature(BaseEstimator):
    """
    Position Distance

    Computes the position of the entity candidate (drug) with respect to
    the attribute among the entire entity candidates of the attribute, where
    the position of medical attribute is set to 0.

    For every relation two integer columns are produced
    (``position_1``, ``position_2``). The magnitude is the number of middle
    entities whose type is listed as a candidate for the attribute's type:

    * attribute matches ``entity1``  ->  ``[0, +count]``
    * attribute matches ``entity2``  ->  ``[-count, 0]``

    The match is made on entity *type*, and ``entity1`` is checked first, so
    when both entities share the attribute's type the first form is used.

    Source:
        Alimova and Tutubalina (2020) - Multiple features for clinical relation extraction: A machine learning approach

    Parameters
    ----------
    dataset : str
        Either ``"n2c2"`` or ``"ddi"``; selects which attribute -> candidate
        mapping is used.

    Raises
    ------
    ValueError
        If ``dataset`` is not one of the supported names.
    """

    # Number of columns emitted per relation; must agree with
    # ``get_feature_names``.
    _N_FEATURES = 2

    def __init__(self, dataset: str):
        if dataset == "n2c2":
            self.attr_entity_candidates = N2C2_ATTR_ENTITY_CANDIDATES
        elif dataset == "ddi":
            self.attr_entity_candidates = DDI_ATTR_ENTITY_CANDIDATES
        else:
            raise ValueError(
                "only datasets 'n2c2' and 'ddi' are supported, but no '{}'".format(
                    dataset
                )
            )
        # Stored under the constructor argument's name so that
        # BaseEstimator's parameter introspection can find it.
        self.dataset = dataset

    def get_feature_names(self, input_features=None):
        """Return the names of the two output columns.

        ``input_features`` is accepted for signature compatibility and is
        ignored.
        """
        return ["position_1", "position_2"]

    def create_position_feature(self, collection: RelationCollection) -> numpy.ndarray:
        """Build the position feature matrix for every relation in ``collection``.

        Parameters
        ----------
        collection : RelationCollection
            Object exposing a ``relations`` iterable; each relation must
            provide ``_ordered_entities``, ``middle_entities``, ``entity1``
            and ``entity2``.

        Returns
        -------
        numpy.ndarray
            Integer array of shape ``(n_relations, 2)``. An empty collection
            yields an array of shape ``(0, 2)``.

        Raises
        ------
        ValueError
            If neither ``entity1`` nor ``entity2`` has the attribute's type.

        Notes
        -----
        A failed lookup of the attribute type in the candidate mapping is
        not caught here and propagates to the caller.
        """
        features = []
        for r in collection.relations:
            # Only the attribute is needed; unpacking still enforces that
            # exactly two ordered entities are present.
            attr, _ = r._ordered_entities
            candidates = self.attr_entity_candidates[attr.type]

            # count middle entities which could form the same type of relation
            # i.e., count number of middle entities that are drugs for n2c2 and DDI
            position = sum(1 for ent in r.middle_entities if ent.type in candidates)

            ent1 = r.entity1
            ent2 = r.entity2
            if ent1.type == attr.type:
                # attribute is the first entity: the position is positive
                feature = [0, position]
            elif ent2.type == attr.type:
                # attribute is the second entity: the position is negative
                feature = [-position, 0]
            else:
                raise ValueError("none of the entities correspond with the attribute")

            features.append(feature)

        # Explicit dtype and reshape keep the result a 2-column integer matrix
        # even when there are no relations (a plain numpy.array([]) would be a
        # float array of shape (0,)).
        return numpy.array(features, dtype=int).reshape(-1, self._N_FEATURES)

    def fit(self, x: RelationCollection, y=None):
        """No-op fit: nothing is learned from the data. Returns ``self``."""
        return self

    def transform(self, x: RelationCollection) -> numpy.ndarray:
        """Return the position feature matrix for ``x``."""
        return self.create_position_feature(x)

    def fit_transform(self, x: RelationCollection, y=None) -> numpy.ndarray:
        """Equivalent to ``transform(x)`` since fitting is a no-op; ``y`` is ignored."""
        return self.create_position_feature(x)