Switch to unified view

a b/src/features/negated_entities_feature.py
1
# Base Dependencies
2
# ----------------
3
import numpy as np
4
from typing import Optional
5
6
# Local Dependencies
7
# ------------------
8
from models import RelationCollection
9
from nlp_pipeline import get_pipeline, set_spacy_entities
10
11
# 3rd-Party Dependencies
12
# ----------------------
13
from sklearn.base import BaseEstimator
14
15
16
17
class NegatedEntitiesFeature(BaseEstimator):
    """
    Negated Entities Feature

    Determines if each of the target entities of a relation is negated or not.

    For every relation in a collection the transformer emits two integer
    flags, ``e1_negated`` and ``e2_negated``, read from the ``negex``
    extension attribute that the pipeline's "negex" component sets on each
    entity span.
    """

    def __init__(self, padding_idx: Optional[int] = None):
        """
        Args:
            padding_idx: stored on the instance as-is; it is not read by any
                method of this class.
        """
        self.padding_idx = padding_idx

    def get_feature_names(self, input_features=None):
        """Return the names of the two output columns, in column order."""
        return ["e1_negated", "e2_negated"]

    def create_negated_entities_feature(
        self, collection: RelationCollection
    ) -> np.ndarray:
        """
        Compute the negation flags for both target entities of each relation.

        Args:
            collection: relations to featurise. Its ``tokens``,
                ``left_tokens``, ``entities1_tokens``, ``middle_tokens``,
                ``entities2_tokens``, ``right_tokens`` and ``relations``
                attributes are indexed in parallel.

        Returns:
            Integer array of shape ``(n_relations, 2)`` whose columns are
            ``[e1_negated, e2_negated]``. An empty collection yields an array
            of shape ``(0, 2)``.

        Raises:
            AssertionError: if a document does not end up with exactly two
                entities after ``set_spacy_entities``.
        """
        features = []

        nlp = get_pipeline()
        parser = nlp.get_pipe("parser")
        negex = nlp.get_pipe("negex")

        for i, doc in enumerate(parser.pipe(collection.tokens)):
            relation = collection.relations[i]
            set_spacy_entities(
                doc,
                collection.left_tokens[i],
                collection.entities1_tokens[i],
                relation.entity1.type,
                collection.middle_tokens[i],
                collection.entities2_tokens[i],
                relation.entity2.type,
                collection.right_tokens[i],
            )
            # The two flags below are read positionally, so anything other
            # than exactly two entities would silently mislabel the columns.
            assert len(doc.ents) == 2, (
                f"expected exactly 2 entities for relation {i}, "
                f"found {len(doc.ents)}"
            )

            doc = negex(doc)
            first_entity, second_entity = doc.ents
            features.append(
                [int(first_entity._.negex), int(second_entity._.negex)]
            )

        # The reshape keeps the result two-dimensional even when there are no
        # relations; a bare np.array([]) would have shape (0,) and float dtype.
        return np.array(features, dtype=int).reshape(-1, 2)

    def fit(self, x: RelationCollection, y=None):
        """No-op: this feature has nothing to learn. Returns ``self``."""
        return self

    def transform(self, x: RelationCollection) -> np.ndarray:
        """Return the ``(n_relations, 2)`` negation feature matrix for ``x``."""
        return self.create_negated_entities_feature(x)

    def fit_transform(self, x: RelationCollection, y=None) -> np.ndarray:
        """Fit (a no-op) and return the negation feature matrix for ``x``."""
        return self.fit(x, y).transform(x)