[735bb5]: / test / features / test_entity_embedding_feature.py

Download this file

50 lines (32 with data), 1.4 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Base Dependencies
# ----------------
import pytest
import logging
# Local Dependencies
# ------------------
from features import EntityEmbedding
from models import RelationCollection
# Tests
# ------
def test_entity_embedding_init(wv_model):
    """Constructing an EntityEmbedding stores the dataset name and the model."""
    feature = EntityEmbedding("n2c2", wv_model)

    assert isinstance(feature, EntityEmbedding)
    # The constructor arguments must be exposed unchanged as attributes.
    assert feature.dataset == "n2c2"
    assert feature.model == wv_model
def test_entity_embedding_init_unsupported_dataset_raises(wv_model):
    """Constructing an EntityEmbedding with the "i2b2" dataset raises ValueError."""
    # Only the constructor call lives inside the context manager, and its
    # result is deliberately not bound: the call is expected to raise before
    # returning, so a local name here would always be unused.
    with pytest.raises(ValueError):
        EntityEmbedding("i2b2", wv_model)
def test_entity_embedding_get_feature_names(wv_model):
    """get_feature_names() reports the single name "entity_embedding"."""
    feature = EntityEmbedding("n2c2", wv_model)
    names = feature.get_feature_names()

    assert names == ["entity_embedding"]
def test_entity_embedding_fit(n2c2_small_collection, wv_model):
    """fit() on a small collection returns an EntityEmbedding instance."""
    feature = EntityEmbedding("n2c2", wv_model)
    fitted = feature.fit(n2c2_small_collection)

    assert isinstance(fitted, EntityEmbedding)
def test_entity_embedding_create_entity_embedding(n2c2_small_collection, wv_model, NLP):
    """create_entity_embedding() returns two equally shaped embeddings.

    Checks that the first axis of the result matches the number of items in
    the collection and the second axis matches ``wv_model.vector_size``.
    """
    # Register the NLP fixture on RelationCollection before building the
    # embeddings. NOTE(review): presumably needed to process the
    # collection's text -- confirm against RelationCollection.
    RelationCollection.set_nlp(NLP)

    feature = EntityEmbedding("n2c2", wv_model)
    first_emb, second_emb = feature.create_entity_embedding(n2c2_small_collection)

    # Both entities of every relation must be embedded to the same shape.
    assert first_emb.shape == second_emb.shape

    expected_rows = len(n2c2_small_collection)
    assert first_emb.shape[0] == expected_rows

    expected_width = wv_model.vector_size
    assert first_emb.shape[1] == expected_width