Switch to unified view

a b/src/features/punct_distance_feature.py
1
# Base Dependencies
2
# ----------------
3
import numpy
4
5
# Local Dependencies
6
# ------------------
7
from models import RelationCollection
8
9
# 3rd-Party Dependencies
10
# ----------------------
11
from sklearn.base import BaseEstimator
12
13
14
class PunctuationFeature(BaseEstimator):
    """
    PunctuationFeature

    Computes the number of punctuation tokens between the two entities of a relation.

    For every entry of `collection.middle_tokens`, the transformer counts the
    tokens whose `is_punct` attribute is truthy and emits that count as a
    single feature column named "punct_dist".

    Source:
        Alimova and Tutubalina (2020) - Multiple features for clinical relation extraction: A machine learning approach
    """

    def __init__(self):
        """Create the transformer; it has no parameters to configure."""
        pass

    def get_feature_names(self, input_features=None):
        """Return the name of the single feature column produced by `transform`.

        Args:
            input_features: accepted for interface compatibility; not used.

        Returns:
            A one-element list containing "punct_dist".
        """
        return ["punct_dist"]

    def create_punctuation_distance_feature(
        self, collection: RelationCollection
    ) -> numpy.ndarray:
        """Count the punctuation tokens in each entry of `collection.middle_tokens`.

        Args:
            collection: object exposing `middle_tokens`, an iterable with one
                token sequence per relation. Each token must expose an
                `is_punct` attribute.
                NOTE(review): presumably these are the tokens located between
                the two entities of each relation - confirm against
                `RelationCollection`.

        Returns:
            Integer array of shape (n_relations, 1) holding one count per relation.
        """
        counts = [
            sum(1 for token in doc if token.is_punct)
            for doc in collection.middle_tokens
        ]
        # Reshape explicitly so an empty collection still yields a 2-D (0, 1)
        # integer array instead of the 1-D float array numpy builds from [].
        return numpy.array(counts, dtype=int).reshape(-1, 1)

    def fit(self, x: RelationCollection, y=None):
        """Do nothing and return `self`; the feature is stateless.

        Args:
            x: ignored.
            y: ignored.

        Returns:
            This transformer instance.
        """
        return self

    def transform(self, x: RelationCollection) -> numpy.ndarray:
        """Return the punctuation-count feature column for `x`.

        Args:
            x: collection to compute the feature for.

        Returns:
            Integer array of shape (n_relations, 1).
        """
        return self.create_punctuation_distance_feature(x)

    def fit_transform(self, x: RelationCollection, y=None) -> numpy.ndarray:
        """Fit (a no-op) and return the punctuation-count feature column for `x`.

        Args:
            x: collection to compute the feature for.
            y: ignored.

        Returns:
            Integer array of shape (n_relations, 1).
        """
        return self.fit(x, y).transform(x)