--- /dev/null
+++ b/tests/processing/test_processing.py
@@ -0,0 +1,231 @@
+from datetime import datetime
+from typing import Any, Dict, List
+
+import pandas as pd
+import pytest
+from spacy.tokens import Doc
+
+text = """
+Motif :
+Le patient est admis le 29 août 2020 pour des difficultés respiratoires.
+
+Antécédents familiaux :
+Le père est asthmatique, sans traitement particulier.
+
+HISTOIRE DE LA MALADIE
+Le patient dit avoir de la toux. \
+Elle a empiré jusqu'à nécessiter un passage aux urgences.
+La patiente avait un SOFA à l'admission de 8.
+
+Conclusion
+Possible infection au coronavirus
+"""
+
+
+def note(module):
+    from pyspark.sql import types as T
+    from pyspark.sql.session import SparkSession
+
+    data = [(i, i // 5, text, datetime(2021, 1, 1)) for i in range(20)]
+
+    if module == "pandas":
+        return pd.DataFrame(
+            data=data, columns=["note_id", "person_id", "note_text", "note_datetime"]
+        )
+
+    note_schema = T.StructType(
+        [
+            T.StructField("note_id", T.IntegerType()),
+            T.StructField("person_id", T.IntegerType()),
+            T.StructField("note_text", T.StringType()),
+            T.StructField(
+                "note_datetime",
+                T.TimestampType(),
+            ),
+        ]
+    )
+
+    spark = SparkSession.builder.getOrCreate()
+    notes = spark.createDataFrame(data=data, schema=note_schema)
+    if module == "pyspark":
+        return notes
+
+    if module == "koalas":
+        return notes.to_koalas()
+
+
+@pytest.fixture
+def model(blank_nlp):
+    # Start from the blank spaCy pipeline provided by the fixture
+    nlp = blank_nlp
+
+    # Normalisation of accents, case and other special characters
+    nlp.add_pipe("eds.normalizer")
+
+    # Extraction of named entities
+    nlp.add_pipe(
+        "eds.matcher",
+        config=dict(
+            terms=dict(
+                respiratoire=[
+                    "difficultes respiratoires",
+                    "asthmatique",
+                    "toux",
+                ]
+            ),
+            regex=dict(
+                covid=r"(?i)(?:infection\sau\s)?(covid[\s\-]?19|corona[\s\-]?virus)",
+                traitement=r"(?i)traitements?|medicaments?",
+                respiratoire="respiratoires",
+            ),
+            attr="NORM",
+        ),
+    )
+
+    # Qualification of matched entities
+    nlp.add_pipe("eds.negation")
+    nlp.add_pipe("eds.hypothesis")
+    nlp.add_pipe("eds.family")
+    nlp.add_pipe("eds.reported_speech")
+    nlp.add_pipe("eds.sofa")
+    nlp.add_pipe("eds.dates")
+
+    return nlp
+
+
+params = [
+    dict(module="pandas", n_jobs=1),
+    dict(module="pandas", n_jobs=-2),
+    dict(module="pyspark", n_jobs=None),
+]
+
+try:
+    import databricks.koalas  # noqa: F401
+
+    params.append(dict(module="koalas", n_jobs=None))
+except ImportError:
+    pass
+
+
+@pytest.mark.parametrize("param", params)
+def test_pipelines(param, model):
+    from pyspark.sql import types as T
+
+    from edsnlp.processing import pipe
+
+    module = param["module"]
+
+    note_nlp = pipe(
+        note(module=module),
+        nlp=model,
+        n_jobs=param["n_jobs"],
+        context=["note_datetime"],
+        extensions={
+            "score_method": T.StringType(),
+            "negation": T.BooleanType(),
+            "hypothesis": T.BooleanType(),
+            "family": T.BooleanType(),
+            "reported_speech": T.BooleanType(),
+            "date.year": T.IntegerType(),
+            "date.month": T.IntegerType(),
+        }
+        if module in ("pyspark", "koalas")
+        else [
+            "score_method",
+            "negation",
+            "hypothesis",
+            "family",
+            "reported_speech",
+            "date_year",
+            "date_month",
+        ],
+        additional_spans=["dates"],
+    )
+
+    if module == "pyspark":
+        note_nlp = note_nlp.toPandas()
+    elif module == "koalas":
+        note_nlp = note_nlp.to_pandas()
+
+    assert len(note_nlp) == 140
+    assert set(note_nlp.columns) == set(
+        (
+            "note_id",
+            "lexical_variant",
+            "label",
+            "span_type",
+            "start",
+            "end",
+            "negation",
+            "hypothesis",
+            "reported_speech",
+            "family",
+            "score_method",
+            "date_year",
+            "date_month",
+        )
+    )
+
+
+def test_spark_missing_types(model):
+    from edsnlp.processing import pipe
+
+    with pytest.warns(Warning) as warned:
+        pipe(
+            note(module="pyspark"),
+            nlp=model,
+            extensions={"negation", "hypothesis", "family"},
+        )
+    assert any(
+        "The following schema was inferred" in str(warning.message)
+        for warning in warned
+    )
+
+
+@pytest.mark.parametrize("param", params)
def test_arbitrary_callback(param, model):
+    from pyspark.sql import types as T
+
+    from edsnlp.processing import pipe
+
+    # Arbitrary extraction callback; it must also work on PySpark workers
+    def dummy_extractor(doc: Doc) -> List[Dict[str, Any]]:
+        return [
+            dict(
+                snippet=ent.text,
+                length=len(ent.text),
+                note_datetime=doc._.note_datetime,
+            )
+            for ent in doc.ents
+        ]
+
+    module = param["module"]
+
+    note_nlp = pipe(
+        note(module=module),
+        nlp=model,
+        n_jobs=param["n_jobs"],
+        context=["note_datetime"],
+        results_extractor=dummy_extractor,
+        dtypes={
+            "snippet": T.StringType(),
+            "length": T.IntegerType(),
+        },
+    )
+
+    if module == "pandas":
+        assert set(note_nlp.columns) == {"snippet", "length", "note_datetime"}
+        assert (note_nlp.snippet.str.len() == note_nlp.length).all()
+
+    else:
+        if module == "pyspark":
+            note_nlp = note_nlp.toPandas()
+        elif module == "koalas":
+            note_nlp = note_nlp.to_pandas()
+
+        assert set(note_nlp.columns) == {
+            "note_id",
+            "snippet",
+            "length",
+        }
+        assert (note_nlp.snippet.str.len() == note_nlp.length).all()