Switch to unified view

a b/src/features/entity_embedding.py
1
# Base Dependencies
2
# ----------------
3
import numpy as np
4
from typing import List, Tuple
5
6
# Local Dependencies
7
# ------------------
8
from models import RelationCollection
9
10
# 3rd-Party Dependencies
11
# ----------------------
12
from gensim.models import KeyedVectors
13
from spacy.tokens import Doc
14
from sklearn.base import BaseEstimator
15
16
# Constants
17
# ---------
18
from constants import DATASETS
19
20
21
class EntityEmbedding(BaseEstimator):
22
    """
23
    Entity Embedding
24
25
    Obtains the vectors indexes of the two entities in the relation.
26
    
27
    Source: 
28
        Alimova and Tutubalina (2020) - Multiple features for clinical relation extraction: A machine learning approach
29
    """
30
31
    def __init__(self, dataset: str, model: KeyedVectors):
32
        if dataset not in DATASETS:
33
            raise ValueError("unsupported dataset '{}'".format(dataset))
34
        self.dataset = dataset
35
        self.model = model
36
37
    def get_feature_names(self, input_features=None):
38
        return ["ent_emb"]
39
40
    def create_entity_embedding(
41
        self, collection: RelationCollection
42
    ) -> Tuple[np.array, np.array]:
43
        e1_embs = []
44
        e2_embs = []
45
        entities1: List[Doc] = collection.entities1_tokens
46
        entities2: List[Doc] = collection.entities2_tokens
47
48
        assert len(entities1) == len(entities2)
49
50
        for e1, e2 in zip(entities1, entities2):
51
            e1_tokens: List[str] = list(map(lambda t: t.text.lower(), e1))
52
            e2_tokens: List[str] = list(map(lambda t: t.text.lower(), e2))
53
54
            e1_embs.append(self.model.get_mean_vector(e1_tokens))
55
            e2_embs.append(self.model.get_mean_vector(e2_tokens))
56
57
        return np.array(e1_embs), np.array(e2_embs)
58
59
    def fit(self, x: RelationCollection, y=None):
60
        return self
61
62
    def transform(self, x: RelationCollection) -> Tuple[np.array, np.array]:
63
        return self.create_entity_embedding(x)
64
65
    def fit_transform(self, x: RelationCollection, y=None) -> Tuple[np.array, np.array]:
66
        return self.create_entity_embedding(x)