--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,239 @@
+import logging
+import os
+import time
+from datetime import datetime
+
+import pandas as pd
+import pytest
+import spacy
+from helpers import make_nlp
+from pytest import fixture
+
+import edsnlp
+
+# Cap worker processes and pin the timezone so tests behave deterministically
+os.environ["EDSNLP_MAX_CPU_WORKERS"] = "2"
+os.environ["TZ"] = "Europe/Paris"
+
+try:
+    time.tzset()
+except AttributeError:
+    # time.tzset is only available on Unix
+    pass
+logging.basicConfig(level=logging.INFO)
+try:
+    import torch.nn
+except ImportError:
+    # torch is optional: the ML fixtures below skip when it is missing
+    torch = None
+
+pytest.importorskip("rich")
+
+
+def pytest_collection_modifyitems(items):
+    """Run test_code_blocks* tests at the end"""
+    first_tests = []
+    last_tests = []
+    for item in items:
+        if item.name.startswith("test_code_blocks"):
+            last_tests.append(item)
+        else:
+            first_tests.append(item)
+    items[:] = first_tests + last_tests
+
+
+@fixture(scope="session", params=["eds", "fr"])
+def lang(request):
+    return request.param
+
+
+@fixture(scope="session")
+def nlp(lang):
+    return make_nlp(lang)
+
+
+@fixture(scope="session")
+def nlp_eds():
+    return make_nlp("eds")
+
+
+@fixture
+def blank_nlp(lang):
+    if lang == "eds":
+        model = spacy.blank("eds")
+    else:
+        model = edsnlp.blank("fr")
+    model.add_pipe("eds.sentences")
+    return model
+
+
+def make_ml_pipeline():
+    nlp = edsnlp.blank("eds")
+    nlp.add_pipe("eds.sentences", name="sentences")
+    nlp.add_pipe(
+        "eds.transformer",
+        name="transformer",
+        config=dict(
+            model="prajjwal1/bert-tiny",
+            window=128,
+            stride=96,
+        ),
+    )
+    nlp.add_pipe(
+        "eds.ner_crf",
+        name="ner",
+        config=dict(
+            embedding=nlp.get_pipe("transformer"),
+            mode="independent",
+            target_span_getter=["ents", "ner-preds"],
+            span_setter="ents",
+        ),
+    )
+    ner = nlp.get_pipe("ner")
+    ner.update_labels(["PERSON", "GIFT"])
+    return nlp
+
+
+@fixture()
+def ml_nlp():
+    if torch is None:
+        pytest.skip("torch not installed", allow_module_level=False)
+    return make_ml_pipeline()
+
+
+@fixture(scope="session")
+def frozen_ml_nlp():
+    if torch is None:
+        pytest.skip("torch not installed", allow_module_level=False)
+    return make_ml_pipeline()
+
+
+@fixture()
+def text():
+    return (
+        "Le patient est admis pour des douleurs dans le bras droit, "
+        "mais n'a pas de problème de locomotion. "
+        "Historique d'AVC dans la famille. pourrait être un cas de rhume.\n"
+        "NBNbWbWbNbWbNBNbNbWbWbNBNbWbNbNbWbNBNbWbNbNBWbWbNbNbNBWbNbWbNbWBNb"
+        "NbWbNbNBNbWbWbNbWBNbNbWbNBNbWbWbNb\n"
+        "Pourrait être un cas de rhume.\n"
+        "Motif :\n"
+        "Douleurs dans le bras droit.\n"
+        "ANTÉCÉDENTS\n"
+        "Le patient est déjà venu pendant les vacances\n"
+        "d'été.\n"
+        "Pas d'anomalie détectée."
+    )
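+
+# The "NBNbWbWb..." line above is deliberately meaningless noise (presumably
+# mirroring the pollution patterns that eds.normalizer targets), so that
+# downstream pipes are also exercised on non-clinical garbage.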
+
+
+@fixture
+def doc(nlp, text):
+    return nlp(text)
+
+
+@fixture
+def blank_doc(blank_nlp, text):
+    return blank_nlp(text)
+
+
+@fixture
+def df_notes(text):
+    N_LINES = 100
+    notes = pd.DataFrame(
+        data={
+            "note_id": list(range(N_LINES)),
+            "note_text": N_LINES * [text],
+            "note_datetime": N_LINES * [datetime.today()],
+        }
+    )
+
+    return notes
+
+
+def make_df_note(text, module):
+    from pyspark.sql import types as T
+    from pyspark.sql.session import SparkSession
+
+    data = [(i, i // 5, text, datetime(2021, 1, 1)) for i in range(20)]
+
+    if module == "pandas":
+        return pd.DataFrame(
+            data=data, columns=["note_id", "person_id", "note_text", "note_datetime"]
+        )
+
+    note_schema = T.StructType(
+        [
+            T.StructField("note_id", T.IntegerType()),
+            T.StructField("person_id", T.IntegerType()),
+            T.StructField("note_text", T.StringType()),
+            T.StructField(
+                "note_datetime",
+                T.TimestampType(),
+            ),
+        ]
+    )
+
+    spark = SparkSession.builder.getOrCreate()
+    notes = spark.createDataFrame(data=data, schema=note_schema)
+    if module == "pyspark":
+        return notes
+
+    if module == "koalas":
+        try:
+            import databricks.koalas  # noqa F401
+        except ImportError:
+            pytest.skip("Koalas not installed")
+        return notes.to_koalas()
+
+
+@fixture
+def df_notes_pandas(text):
+    return make_df_note(text, "pandas")
+
+
+@fixture
+def df_notes_pyspark(text):
+    return make_df_note(text, "pyspark")
+
+
+@fixture
+def df_notes_koalas(text):
+    return make_df_note(text, "koalas")
+
+
+@fixture
+def run_in_test_dir(request, monkeypatch):
+    monkeypatch.chdir(request.fspath.dirname)
+
+
+@pytest.fixture(autouse=True)
+def stop_spark():
+    yield
+    try:
+        from pyspark.sql import SparkSession
+    except ImportError:
+        return
+    try:
+        getActiveSession = SparkSession.getActiveSession
+    except AttributeError:
+        # Backport of SparkSession.getActiveSession, added in pyspark 2.4
+        def getActiveSession():  # pragma: no cover
+            from pyspark import SparkContext
+
+            sc = SparkContext._active_spark_context
+            if sc is None:
+                return None
+            else:
+                assert sc._jvm is not None
+                if sc._jvm.SparkSession.getActiveSession().isDefined():
+                    SparkSession(sc, sc._jvm.SparkSession.getActiveSession().get())
+                    try:
+                        return SparkSession._activeSession
+                    except AttributeError:
+                        try:
+                            return SparkSession._instantiatedSession
+                        except AttributeError:
+                            return None
+                else:
+                    return None
+
+    session = getActiveSession()
+    if session is not None:
+        session.stop()
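+
+
+# Usage sketch (illustrative only, not part of this change): any test module
+# under tests/ can request the fixtures above by name. For example, since
+# blank_nlp adds eds.sentences, a test could check segmentation like this:
+#
+#     def test_blank_doc_is_segmented(blank_doc):
+#         assert len(list(blank_doc.sents)) > 1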