|
a |
|
b/tests/conftest.py |
|
|
1 |
import logging |
|
|
2 |
import os |
|
|
3 |
import time |
|
|
4 |
from datetime import datetime |
|
|
5 |
|
|
|
6 |
import pandas as pd |
|
|
7 |
import pytest |
|
|
8 |
import spacy |
|
|
9 |
from helpers import make_nlp |
|
|
10 |
from pytest import fixture |
|
|
11 |
|
|
|
12 |
import edsnlp |
|
|
13 |
|
|
|
14 |
# ---- Global test-session configuration -------------------------------------
# Environment, timezone, logging, and optional heavy dependencies.

# Cap edsnlp's multiprocessing fan-out so CI machines are not oversubscribed.
os.environ["EDSNLP_MAX_CPU_WORKERS"] = "2"
# Pin the timezone so any date-dependent test output is reproducible.
os.environ["TZ"] = "Europe/Paris"

# time.tzset() only exists on Unix; on other platforms the attribute is
# missing, in which case we simply keep the process default.
if hasattr(time, "tzset"):
    time.tzset()

logging.basicConfig(level=logging.INFO)

# torch is an optional dependency: the ML fixtures below look at the module
# name and skip themselves when it is None.
try:
    import torch.nn
except ImportError:
    torch = None

# The documentation/code-block tests rely on rich; without it, skip the module.
pytest.importorskip("rich")
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def pytest_collection_modifyitems(items):
    """Reorder collected tests so `test_code_blocks*` items run last.

    The docstring previously claimed `test_docs*`, but the code has always
    matched names starting with ``test_code_blocks``; the doc is now aligned
    with the actual behavior. The list is reordered in place (pytest requires
    mutation of `items`), preserving relative order within each group.
    """
    deferred = [item for item in items if item.name.startswith("test_code_blocks")]
    regular = [item for item in items if not item.name.startswith("test_code_blocks")]
    # In-place slice assignment: pytest inspects the same list object.
    items[:] = regular + deferred
|
|
40 |
|
|
|
41 |
|
|
|
42 |
@fixture(scope="session", params=["eds", "fr"])
def lang(request):
    """Parametrized session fixture yielding each supported language code."""
    return request.param
|
|
45 |
|
|
|
46 |
|
|
|
47 |
@fixture(scope="session")
def nlp(lang):
    """Session-wide pipeline built for the currently parametrized language."""
    return make_nlp(lang)
|
|
50 |
|
|
|
51 |
|
|
|
52 |
@fixture(scope="session")
def nlp_eds():
    """Session-wide pipeline fixed to the "eds" language, regardless of `lang`."""
    return make_nlp("eds")
|
|
55 |
|
|
|
56 |
|
|
|
57 |
@fixture
def blank_nlp(lang):
    """Fresh minimal pipeline per test: a blank model plus sentence detection.

    Uses a spacy blank for "eds" and an edsnlp blank for anything else.
    """
    model = spacy.blank("eds") if lang == "eds" else edsnlp.blank("fr")
    model.add_pipe("eds.sentences")
    return model
|
|
65 |
|
|
|
66 |
|
|
|
67 |
def make_ml_pipeline():
    """Build a small trainable edsnlp pipeline (tiny BERT embedding + CRF NER).

    Returns a blank "eds" pipeline with sentence splitting, a
    prajjwal1/bert-tiny transformer embedding, and an `eds.ner_crf` component
    initialized with the PERSON and GIFT labels.
    """
    pipeline = edsnlp.blank("eds")
    pipeline.add_pipe("eds.sentences", name="sentences")

    # Tiny transformer keeps the fixture fast while still exercising the
    # embedding machinery.
    transformer_config = {
        "model": "prajjwal1/bert-tiny",
        "window": 128,
        "stride": 96,
    }
    pipeline.add_pipe("eds.transformer", name="transformer", config=transformer_config)

    ner_config = {
        "embedding": pipeline.get_pipe("transformer"),
        "mode": "independent",
        "target_span_getter": ["ents", "ner-preds"],
        "span_setter": "ents",
    }
    pipeline.add_pipe("eds.ner_crf", name="ner", config=ner_config)

    pipeline.get_pipe("ner").update_labels(["PERSON", "GIFT"])
    return pipeline
|
|
92 |
|
|
|
93 |
|
|
|
94 |
@fixture()
def ml_nlp():
    """Per-test trainable ML pipeline; skipped when torch is unavailable."""
    if torch is None:
        pytest.skip("torch not installed", allow_module_level=False)
    return make_ml_pipeline()
|
|
99 |
|
|
|
100 |
|
|
|
101 |
@fixture(scope="session")
def frozen_ml_nlp():
    """Session-scoped ML pipeline shared by tests that never mutate it.

    Skipped when torch is unavailable.
    """
    if torch is None:
        pytest.skip("torch not installed", allow_module_level=False)
    return make_ml_pipeline()
|
|
106 |
|
|
|
107 |
|
|
|
108 |
@fixture()
def text():
    """Sample French clinical note exercised throughout the test suite."""
    parts = [
        "Le patient est admis pour des douleurs dans le bras droit, ",
        "mais n'a pas de problème de locomotion. ",
        "Historique d'AVC dans la famille. pourrait être un cas de rhume.\n",
        "NBNbWbWbNbWbNBNbNbWbWbNBNbWbNbNbWbNBNbWbNbNBWbWbNbNbNBWbNbWbNbWBNb",
        "NbWbNbNBNbWbWbNbWBNbNbWbNBNbWbWbNb\n",
        "Pourrait être un cas de rhume.\n",
        "Motif :\n",
        "Douleurs dans le bras droit.\n",
        "ANTÉCÉDENTS\n",
        "Le patient est déjà venu pendant les vacances\n",
        "d'été.\n",
        "Pas d'anomalie détectée.",
    ]
    return "".join(parts)
|
|
124 |
|
|
|
125 |
|
|
|
126 |
@fixture
def doc(nlp, text):
    """The sample note processed by the full language pipeline."""
    return nlp(text)
|
|
129 |
|
|
|
130 |
|
|
|
131 |
@fixture
def blank_doc(blank_nlp, text):
    """The sample note processed by the minimal (blank + sentences) pipeline."""
    return blank_nlp(text)
|
|
134 |
|
|
|
135 |
|
|
|
136 |
@fixture
def df_notes(text):
    """A 100-row pandas DataFrame of identical notes with ids and timestamps.

    Bug fix: the fixture previously took no parameter, so the module-level
    `text` *fixture function object* — not the sample string — was placed in
    the `note_text` column. Requesting the `text` fixture fills the column
    with the actual note text.
    """
    n_lines = 100
    return pd.DataFrame(
        data={
            "note_id": list(range(n_lines)),
            "note_text": n_lines * [text],
            # Same timestamp on every row is intentional for these tests.
            "note_datetime": n_lines * [datetime.today()],
        }
    )
|
|
148 |
|
|
|
149 |
|
|
|
150 |
def make_df_note(text, module):
    """Build a 20-row note dataframe in the requested dataframe backend.

    Parameters
    ----------
    text : str
        Note text repeated on every row.
    module : str
        One of ``"pandas"``, ``"pyspark"`` or ``"koalas"``.

    Returns
    -------
    A dataframe with columns note_id, person_id, note_text, note_datetime
    (5 notes per person, fixed 2021-01-01 timestamp).

    Raises
    ------
    ValueError
        If `module` is not one of the supported backends (previously the
        function silently returned None).
    """
    data = [(i, i // 5, text, datetime(2021, 1, 1)) for i in range(20)]

    if module == "pandas":
        return pd.DataFrame(
            data=data, columns=["note_id", "person_id", "note_text", "note_datetime"]
        )

    # Fix: import pyspark only after the pandas branch, so the pandas path
    # works in environments where pyspark is not installed.
    from pyspark.sql import types as T
    from pyspark.sql.session import SparkSession

    note_schema = T.StructType(
        [
            T.StructField("note_id", T.IntegerType()),
            T.StructField("person_id", T.IntegerType()),
            T.StructField("note_text", T.StringType()),
            T.StructField(
                "note_datetime",
                T.TimestampType(),
            ),
        ]
    )

    spark = SparkSession.builder.getOrCreate()
    notes = spark.createDataFrame(data=data, schema=note_schema)
    if module == "pyspark":
        return notes

    if module == "koalas":
        try:
            import databricks.koalas  # noqa F401
        except ImportError:
            pytest.skip("Koalas not installed")
        return notes.to_koalas()

    raise ValueError(f"Unknown dataframe module: {module!r}")
|
|
184 |
|
|
|
185 |
|
|
|
186 |
@fixture
def df_notes_pandas(text):
    """Note dataframe in the pandas backend."""
    return make_df_note(text, "pandas")
|
|
189 |
|
|
|
190 |
|
|
|
191 |
@fixture
def df_notes_pyspark(text):
    """Note dataframe in the pyspark backend."""
    return make_df_note(text, "pyspark")
|
|
194 |
|
|
|
195 |
|
|
|
196 |
@fixture
def df_notes_koalas(text):
    """Note dataframe in the koalas backend (skips if koalas is missing)."""
    return make_df_note(text, "koalas")
|
|
199 |
|
|
|
200 |
|
|
|
201 |
@fixture
def run_in_test_dir(request, monkeypatch):
    """Run the requesting test with its own directory as the working directory.

    monkeypatch restores the previous working directory on teardown.
    """
    monkeypatch.chdir(request.fspath.dirname)
|
|
204 |
|
|
|
205 |
|
|
|
206 |
@pytest.fixture(autouse=True)
def stop_spark():
    """After every test, stop any Spark session left active.

    No-op when pyspark is not installed. Old pyspark versions lack
    `SparkSession.getActiveSession`; a JVM-probing fallback is used there.
    """
    yield
    try:
        from pyspark.sql import SparkSession
    except ImportError:
        return

    get_active = getattr(SparkSession, "getActiveSession", None)
    if get_active is None:

        def get_active():  # pragma: no cover
            # Fallback for pyspark < 3: ask the JVM directly whether a
            # session is active, then read it back from the Python side.
            from pyspark import SparkContext

            sc = SparkContext._active_spark_context
            if sc is None:
                return None
            assert sc._jvm is not None
            if not sc._jvm.SparkSession.getActiveSession().isDefined():
                return None
            # Materialize the Python-side wrapper of the JVM session.
            SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get())
            # Attribute name changed across pyspark versions.
            for attr in ("_activeSession", "_instantiatedSession"):
                try:
                    return getattr(SparkSession, attr)
                except AttributeError:
                    continue
            return None

    session = get_active()
    if session is not None:
        session.stop()
session.stop() |