Diff of /test/features/conftest.py [000000] .. [735bb5]

Switch to unified view

a b/test/features/conftest.py
1
# Base Dependencies 
2
# -----------------
3
import pytest
4
from typing import List
5
from pathlib import Path
6
7
# Local Dependencies
8
# ------------------
9
from models.relation import RelationN2C2
10
from models import Relation, RelationCollection, Entity
11
from nlp_pipeline import get_pipeline 
12
13
# 3rd-Party Dependencies 
14
# ----------------------
15
from gensim.models import KeyedVectors
16
17
18
# Fixtures 
19
# ---------
20
@pytest.fixture(scope="session")
def n2c2_small_collection() -> RelationCollection:
    """Session-scoped fixture providing the small n2c2 relation collection.

    The collection is built by ``RelationCollection.from_datading`` for the
    "n2c2" dataset from the ``data/n2c2/small/relations.msgpack`` file.

    Returns:
        RelationCollection: the collection returned by ``from_datading``.
    """
    # Relative path literal; presumably resolved against the directory pytest
    # is launched from -- confirm against ``from_datading``.
    relations_path = Path("data/n2c2/small/relations.msgpack")
    return RelationCollection.from_datading("n2c2", relations_path)
26
27
@pytest.fixture(scope="session")
def wv_model():
    """Session-scoped fixture returning the bioword2vec word vectors.

    The vectors are read with ``KeyedVectors.load_word2vec_format`` from a
    text-format (``binary=False``) file. A message is printed before and
    after the load.

    Returns:
        The object returned by ``KeyedVectors.load_word2vec_format``.
    """
    embeddings_path = "data/bioword2vec/bio_embedding_extrinsic.txt"

    print("Loading bioword2vec model ...")
    vectors = KeyedVectors.load_word2vec_format(embeddings_path, binary=False)
    print("Bioword2vec loaded!")

    return vectors
33
34
35
@pytest.fixture(scope="session")
def NLP():
    """Session-scoped fixture exposing the object built by ``get_pipeline``.

    Returns:
        Whatever ``nlp_pipeline.get_pipeline()`` returns.
    """
    pipeline = get_pipeline()
    return pipeline
38
39
# Fixtures
40
# ---------
41
@pytest.fixture(scope="session")
def entity1() -> Entity:
    """Session-scoped fixture: the "Ibuprofen" drug entity of doc1202.

    Returns:
        Entity: id "T11", type "Drug", starting at offset 11 and ending at
        ``start + len(text)``.
    """
    mention = "Ibuprofen"
    begin = 11
    return Entity(
        id="T11",
        text=mention,
        type="Drug",
        doc_id="doc1202",
        start=begin,
        # End offset is derived from the mention length.
        end=begin + len(mention),
    )
51
52
53
@pytest.fixture(scope="session")
def entity2() -> Entity:
    """Session-scoped fixture: the "Paracetamol" drug entity of doc1202.

    Returns:
        Entity: id "T13", type "Drug", starting at offset 24 and ending at
        ``start + len(text)``.
    """
    mention = "Paracetamol"
    begin = 24
    return Entity(
        id="T13",
        text=mention,
        type="Drug",
        doc_id="doc1202",
        start=begin,
        # End offset is derived from the mention length.
        end=begin + len(mention),
    )
63
64
@pytest.fixture(scope="session")
def entity3() -> Entity:
    """Session-scoped fixture: the "500mg" dosage entity of doc1202.

    Returns:
        Entity: id "T11", type "Dosage", starting at offset 35 and ending at
        ``start + len(text)``.
    """
    mention = "500mg"
    begin = 35
    return Entity(
        # NOTE(review): same id as the ``entity1`` fixture ("T11") -- confirm
        # this duplication is intentional and not a copy-paste slip.
        id="T11",
        text=mention,
        type="Dosage",
        doc_id="doc1202",
        start=begin,
        # End offset is derived from the mention length.
        end=begin + len(mention),
    )
74
75
76
77
@pytest.fixture(scope="function")
def relation_attributes(entity1, entity2, entity3) -> dict:
    """Function-scoped fixture with the keyword arguments for a relation.

    The returned mapping is what the ``relation`` fixture unpacks into
    ``RelationN2C2``. A new dict is built on every invocation.

    Args:
        entity1: fixture used as the relation's ``entity1``.
        entity2: fixture placed in ``middle_entities``.
        entity3: fixture used as the relation's ``entity2``.

    Returns:
        dict: attributes of a "Dosage-Drug" relation in doc1202 with label 1.
    """
    return dict(
        doc_id="doc1202",
        type="Dosage-Drug",
        entity1=entity1,
        # The relation's second entity is the ``entity3`` fixture; the
        # ``entity2`` fixture sits between the two and is listed below.
        entity2=entity3,
        label=1,
        left_context="He was administered ",
        middle_context=" and Paracetamol ",
        right_context=" for three days",
        middle_entities=[entity2],
    )
91
92
93
@pytest.fixture(scope="function")
def relation(relation_attributes) -> RelationN2C2:
    """Function-scoped fixture building a ``RelationN2C2`` instance.

    Args:
        relation_attributes: mapping unpacked as keyword arguments of
            ``RelationN2C2``.

    Returns:
        RelationN2C2: the relation constructed from those attributes.
    """
    instance = RelationN2C2(**relation_attributes)
    return instance
96
97
98
99
# @pytest.fixture(scope="session")
100
# def n2c2_train_collection() -> RelationCollection:
101
102
#     collections = RelationCollection.load_collections("n2c2", splits=["train"])
103
#     collection = collections["train"]
104
    
105
#     assert len(collection) > 0
106
    
107
#     return collection
108
109
110
# @pytest.fixture(scope="session")
111
# def ddi_train_collection() -> RelationCollection:
112
113
#     collections = RelationCollection.load_collections("ddi", splits=["train"])
114
#     collection = collections["train"].type_subcollection("Strength-Drug")
115
116
#     return collection