Diff of /tests/conftest.py [000000] .. [cad161]

--- /dev/null
+++ b/tests/conftest.py
import logging
import os
import time
from datetime import datetime

import pandas as pd
import pytest
import spacy
from helpers import make_nlp
from pytest import fixture

import edsnlp

os.environ["EDSNLP_MAX_CPU_WORKERS"] = "2"
os.environ["TZ"] = "Europe/Paris"

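# time.tzset() re-reads the TZ variable so date handling is deterministic in
# tests; it is POSIX-only, hence the AttributeError guard (e.g. on Windows).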
try:
    time.tzset()
except AttributeError:
    pass
logging.basicConfig(level=logging.INFO)
# torch is optional: the ML fixtures below skip themselves when it is absent.
try:
    import torch.nn
except ImportError:
    torch = None

pytest.importorskip("rich")


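# The collection hook below pushes the documentation code-block tests to the
# end of the run, presumably so that the faster unit tests fail first.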
def pytest_collection_modifyitems(items):
    """Run test_code_blocks* tests at the end"""
    first_tests = []
    last_tests = []
    for item in items:
        if item.name.startswith("test_code_blocks"):
            last_tests.append(item)
        else:
            first_tests.append(item)
    items[:] = first_tests + last_tests


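# `lang` is parametrized, so every test or fixture depending on `nlp` or
# `blank_nlp` runs once per language (pytest ids like test_x[eds], test_x[fr]).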
@fixture(scope="session", params=["eds", "fr"])
def lang(request):
    return request.param


@fixture(scope="session")
def nlp(lang):
    return make_nlp(lang)


@fixture(scope="session")
def nlp_eds():
    return make_nlp("eds")


@fixture
def blank_nlp(lang):
    if lang == "eds":
        model = spacy.blank("eds")
    else:
        model = edsnlp.blank("fr")
    model.add_pipe("eds.sentences")
    return model


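# A small ML pipeline used by the ML fixtures: a tiny transformer
# ("prajjwal1/bert-tiny") embeds the text over sliding windows of 128 tokens
# with a stride of 96 (i.e. a 32-token overlap between consecutive windows),
# and a CRF NER head predicts PERSON / GIFT entities into doc.ents.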
def make_ml_pipeline():
    nlp = edsnlp.blank("eds")
    nlp.add_pipe("eds.sentences", name="sentences")
    nlp.add_pipe(
        "eds.transformer",
        name="transformer",
        config=dict(
            model="prajjwal1/bert-tiny",
            window=128,
            stride=96,
        ),
    )
    nlp.add_pipe(
        "eds.ner_crf",
        name="ner",
        config=dict(
            embedding=nlp.get_pipe("transformer"),
            mode="independent",
            target_span_getter=["ents", "ner-preds"],
            span_setter="ents",
        ),
    )
    ner = nlp.get_pipe("ner")
    ner.update_labels(["PERSON", "GIFT"])
    return nlp


@fixture()
def ml_nlp():
    if torch is None:
        pytest.skip("torch not installed", allow_module_level=False)
    return make_ml_pipeline()


@fixture(scope="session")
def frozen_ml_nlp():
    if torch is None:
        pytest.skip("torch not installed", allow_module_level=False)
    return make_ml_pipeline()


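# Hypothetical usage sketch (not a test in this file): `ml_nlp` is rebuilt
# for each test and may be mutated, while the session-scoped `frozen_ml_nlp`
# is shared across tests and should be treated as read-only:
#
#     def test_pipeline_runs(frozen_ml_nlp):
#         assert frozen_ml_nlp.pipe_names == ["sentences", "transformer", "ner"]
#         doc = frozen_ml_nlp("Jean a offert un cadeau.")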
@fixture()
def text():
    return (
        "Le patient est admis pour des douleurs dans le bras droit, "
        "mais n'a pas de problème de locomotion. "
        "Historique d'AVC dans la famille. pourrait être un cas de rhume.\n"
        "NBNbWbWbNbWbNBNbNbWbWbNBNbWbNbNbWbNBNbWbNbNBWbWbNbNbNBWbNbWbNbWBNb"
        "NbWbNbNBNbWbWbNbWBNbNbWbNBNbWbWbNb\n"
        "Pourrait être un cas de rhume.\n"
        "Motif :\n"
        "Douleurs dans le bras droit.\n"
        "ANTÉCÉDENTS\n"
        "Le patient est déjà venu pendant les vacances\n"
        "d'été.\n"
        "Pas d'anomalie détectée."
    )


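# The "NBNbWbWb…" run and the lowercase sentence starts in the `text` fixture
# above look deliberate: degenerate input meant to exercise sentence
# segmentation and text normalization.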
@fixture
def doc(nlp, text):
    return nlp(text)


@fixture
def blank_doc(blank_nlp, text):
    return blank_nlp(text)


@fixture
def df_notes(text):
    N_LINES = 100
    notes = pd.DataFrame(
        data={
            "note_id": list(range(N_LINES)),
            "note_text": N_LINES * [text],
            "note_datetime": N_LINES * [datetime.today()],
        }
    )

    return notes


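# make_df_note builds the same 20-note table (note_id, person_id, note_text,
# note_datetime) as a pandas, pyspark or koalas dataframe depending on
# `module`.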
def make_df_note(text, module):
    data = [(i, i // 5, text, datetime(2021, 1, 1)) for i in range(20)]

    if module == "pandas":
        return pd.DataFrame(
            data=data, columns=["note_id", "person_id", "note_text", "note_datetime"]
        )

    # Only the pyspark / koalas variants need Spark, so import it lazily here
    # to keep the pandas variant usable without pyspark installed.
    from pyspark.sql import types as T
    from pyspark.sql.session import SparkSession

    note_schema = T.StructType(
        [
            T.StructField("note_id", T.IntegerType()),
            T.StructField("person_id", T.IntegerType()),
            T.StructField("note_text", T.StringType()),
            T.StructField(
                "note_datetime",
                T.TimestampType(),
            ),
        ]
    )

    spark = SparkSession.builder.getOrCreate()
    notes = spark.createDataFrame(data=data, schema=note_schema)
    if module == "pyspark":
        return notes

    if module == "koalas":
        try:
            import databricks.koalas  # noqa: F401
        except ImportError:
            pytest.skip("Koalas not installed")
        return notes.to_koalas()


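# One fixture per dataframe backend; the koalas variant skips itself when
# databricks.koalas is not installed.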
@fixture
def df_notes_pandas(text):
    return make_df_note(text, "pandas")


@fixture
def df_notes_pyspark(text):
    return make_df_note(text, "pyspark")


@fixture
def df_notes_koalas(text):
    return make_df_note(text, "koalas")


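# run_in_test_dir chdirs into the requesting test's directory; monkeypatch
# restores the original working directory on teardown.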
@fixture
def run_in_test_dir(request, monkeypatch):
    monkeypatch.chdir(request.fspath.dirname)


@pytest.fixture(autouse=True)
def stop_spark():
    # Teardown-only fixture: after each test, stop any Spark session a test
    # may have left running so it does not leak into the next test.
    yield
    try:
        from pyspark.sql import SparkSession
    except ImportError:
        return
    try:
        getActiveSession = SparkSession.getActiveSession
    except AttributeError:
        # SparkSession.getActiveSession was only added in pyspark 3.0:
        # emulate it on older versions through the JVM gateway.
        def getActiveSession():  # pragma: no cover
            from pyspark import SparkContext

            sc = SparkContext._active_spark_context
            if sc is None:
                return None
            else:
                assert sc._jvm is not None
                if sc._jvm.SparkSession.getActiveSession().isDefined():
                    # Instantiating SparkSession registers the JVM session as
                    # the active Python-side session.
                    SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get())
                    try:
                        return SparkSession._activeSession
                    except AttributeError:
                        try:
                            return SparkSession._instantiatedSession
                        except AttributeError:
                            return None
                else:
                    return None

    session = getActiveSession()
    if session is not None:
        session.stop()