Switch to unified view

a b/src/features/sentence_embedding.py
1
# Base Dependencies
2
# ----------------
3
import numpy as np
4
from typing import List
5
6
# Local Dependencies
7
# ------------------
8
from models import RelationCollection
9
10
# 3rd-Party Dependencies
11
# ----------------------
12
from gensim.models import KeyedVectors
13
from sklearn.base import BaseEstimator
14
15
16
class SentenceEmbedding(BaseEstimator):
    """
    Sentence Embedding

    Represents every sentence of a collection by the mean of the word
    vectors of its lower-cased tokens, as computed by the wrapped
    embedding model's ``get_mean_vector``.

    Source:
        Alimova and Tutubalina (2020) - Multiple features for clinical relation extraction: A machine learning approach
    """

    def __init__(self, model: KeyedVectors):
        """
        Args:
            model: word-embedding model used to look up and average the
                token vectors. It is stored as given and never modified
                by this class.
        """
        self.model = model

    def get_feature_names(self, input_features=None):
        """Return the single feature name produced by this transformer.

        ``input_features`` is accepted but not used.
        """
        return ["sentence_embedding"]

    def create_sentence_embedding(self, collection: RelationCollection) -> np.ndarray:
        """Compute one mean word vector per sentence in ``collection``.

        Args:
            collection: object whose ``tokens`` attribute iterates over
                sentences, each sentence being an iterable of tokens that
                expose a ``text`` string.

        Returns:
            The per-sentence vectors, in input order, packed with
            ``np.array``.
        """
        sent_embs = []
        for doc in collection.tokens:
            sent_tokens = [token.text.lower() for token in doc]
            if sent_tokens:
                sent_embs.append(self.model.get_mean_vector(sent_tokens))
            else:
                # gensim's get_mean_vector raises ValueError on an empty key
                # list. Fall back to the zero vector it already returns when
                # none of the tokens are in the vocabulary, so one empty
                # sentence does not abort the whole transform.
                sent_embs.append(
                    np.zeros(
                        self.model.vector_size, dtype=self.model.vectors.dtype
                    )
                )

        return np.array(sent_embs)

    def fit(self, x: RelationCollection, y=None):
        """No-op: nothing is learned from the data. Returns ``self``."""
        return self

    def transform(self, x: RelationCollection) -> np.ndarray:
        """Return the sentence embeddings of ``x``."""
        return self.create_sentence_embedding(x)

    def fit_transform(self, x: RelationCollection, y=None) -> np.ndarray:
        """Return the sentence embeddings of ``x``.

        Fitting is a no-op, so this simply delegates to ``transform``;
        ``y`` is ignored.
        """
        return self.transform(x)