--- a
+++ b/tests/processing/test_backends.py
@@ -0,0 +1,577 @@
import random
import time
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Sequence

import pandas as pd
import pytest
from confit import validate_arguments
from spacy.tokens import Doc

import edsnlp.data
import edsnlp.processing
from edsnlp.data.converters import get_current_tokenizer
from edsnlp.processing.multiprocessing import get_dispatch_schedule

try:
    import torch.nn
except ImportError:
    torch = None


docs = [
    {
        "note_id": 1234,
        "note_text": "This is a test.",
        "entities": [
            {
                "note_nlp_id": 0,
                "start_char": 0,
                "end_char": 4,
                "lexical_variant": "This",
                "note_nlp_source_value": "test",
                "negation": True,
            },
            {
                "note_nlp_id": 1,
                "start_char": 5,
                "end_char": 7,
                "lexical_variant": "is",
                "note_nlp_source_value": "test",
            },
        ],
    },
    {
        "note_id": 0,
        "note_text": "This is an empty document.",
        "entities": None,
    },
]


@pytest.mark.parametrize(
    "reader_format,reader_converter,backend,writer_format,writer_converter,worker_io",
    [
        ("pandas", "omop", "simple", "pandas", "omop", False),
        ("pandas", "omop", "multiprocessing", "pandas", "omop", False),
        ("pandas", "omop", "spark", "pandas", "omop", False),
        ("parquet", "omop", "simple", "parquet", "omop", False),
        ("parquet", "omop", "multiprocessing", "parquet", "omop", False),
        ("parquet", "omop", "spark", "parquet", "omop", False),
        ("parquet", "omop", "multiprocessing", "parquet", "omop", True),
        ("parquet", "omop", "spark", "parquet", "omop", True),
        ("parquet", "omop", "multiprocessing", "iterable", None, False),
    ],
)
def test_end_to_end(
    reader_format,
    reader_converter,
    backend,
    writer_format,
    writer_converter,
    worker_io,
    nlp_eds,
    tmp_path,
):
    nlp = nlp_eds
    rsrc = Path(__file__).parent.parent.resolve() / "resources"
    if reader_format == "pandas":
        pandas_dataframe = pd.DataFrame(docs)
        data = edsnlp.data.from_pandas(
            pandas_dataframe,
            converter=reader_converter,
        )
    elif reader_format == "parquet":
        data = edsnlp.data.read_parquet(
            rsrc / "docs.parquet",
            converter=reader_converter,
            read_in_worker=worker_io,
        )
    else:
        raise ValueError(reader_format)

    data = data.map_batches(lambda x: sorted(x, key=len), batch_size=2)
    data = data.map_pipeline(nlp)

    data = data.set_processing(
        backend=backend,
        show_progress=True,
        batch_by="words",
        batch_size=2,
    )

    if writer_format == "pandas":
        data.to_pandas(converter=writer_converter)
    elif writer_format == "parquet":
        data.write_parquet(
            tmp_path,
            converter=writer_converter,
            write_in_worker=worker_io,
        )
    elif writer_format == "iterable":
        list(data)
    else:
        raise ValueError(writer_format)


def test_multiprocessing_backend(frozen_ml_nlp):
    texts = ["Ceci est un exemple", "Ceci est un autre exemple"]
    docs = list(
        frozen_ml_nlp.pipe(
            texts * 20,
            batch_size=2,
        ).set_processing(
            backend="multiprocessing",
            num_cpu_workers=-1,
            sort_chunks=True,
            chunk_size=2,
            batch_by="words",
            show_progress=True,
        )
    )
    assert len(docs) == 40
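

# Illustrative sketch (not a collected test): the smallest end-to-end use of the
# stream API exercised above. It relies only on `edsnlp.data.from_iterable`,
# `map` and `set_processing`, all of which appear in the tests in this module;
# the transform itself (`str.upper`) is an arbitrary stand-in.
def _sketch_stream_api():
    stream = edsnlp.data.from_iterable(["ab", "cd"])  # lazy source
    stream = stream.map(str.upper)  # lazy element-wise transform
    stream = stream.set_processing(backend="simple")  # run in the main process
    assert list(stream) == ["AB", "CD"]  # consuming the stream triggers the work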
"Ceci est un autre exemple" + stream = frozen_ml_nlp.pipe( + chain.from_iterable( + [ + text1, + text2, + ] + for i in range(2) + ), + ) + if backend == "simple": + pass + elif backend == "multiprocessing": + stream = stream.set_processing( + batch_size=2, + num_gpu_workers=1, + num_cpu_workers=1, + gpu_worker_devices=["cpu"], + deterministic=deterministic, + ) + elif backend == "spark": + stream = stream.set_processing(backend="spark") + list(stream) + + +def test_multiprocessing_gpu_stub_multi_cpu_deterministic_backend(frozen_ml_nlp): + text1 = "Exemple" + text2 = "Ceci est un autre exemple" + text3 = "Ceci est un très long exemple ! Regardez tous ces mots !" + texts = [text1, text2, text3] * 100 + random.Random(42).shuffle(texts) + stream = frozen_ml_nlp.pipe(iter(texts)) + stream = stream.set_processing( + batch_size="15 words", + num_gpu_workers=1, + num_cpu_workers=2, + deterministic=True, + # show_progress=True, + # just to test in gpu-less environments + gpu_worker_devices=["cpu"], + ) + list(stream) + + +@pytest.mark.parametrize("wait", [True, False]) +def test_multiprocessing_gpu_stub_wait(frozen_ml_nlp, wait): + text1 = "Ceci est un exemple" + text2 = "Ceci est un autre exemple" + it = iter( + frozen_ml_nlp.pipe( + chain.from_iterable( + [ + text1, + text2, + ] + for i in range(2) + ), + ).set_processing( + batch_size=2, + num_gpu_workers=1, + num_cpu_workers=1, + gpu_worker_devices=["cpu"], + ) + ) + if wait: + time.sleep(5) + list(it) + + +def simple_converter(obj): + tok = get_current_tokenizer() + doc = tok(obj["content"]) + doc._.note_id = obj["id"] + return doc + + +def test_iterable_error(): + class Gen: + def __iter__(self): + for i in range(5): + if i == 3: + raise ValueError("error") + yield {"content": f"text-{i}", "id": f"text-{i}"} + + with pytest.raises(ValueError): + list( + edsnlp.data.from_iterable(Gen(), converter=simple_converter).set_processing( + num_cpu_workers=2 + ) + ) + + +def test_multiprocessing_rb_error(ml_nlp): + text1 = "Ceci est un exemple" + text2 = "Ceci est un autre exemple" + ml_nlp.add_pipe(error_pipe, name="error", after="sentences") + with pytest.raises(ValueError): + docs = edsnlp.data.from_iterable( + chain.from_iterable( + [ + {"content": text1, "id": f"text-{i}"}, + {"content": text2, "id": f"other-text-{i}"}, + ] + for i in range(5) + ), + converter=simple_converter, + ).map(lambda x: time.sleep(0.2) or x) + docs = ml_nlp.pipe( + docs, + n_process=2, + batch_size=2, + ) + list(docs) + + +if torch is not None: + from edsnlp.core.torch_component import TorchComponent + + class DeepLearningError(TorchComponent): + def __init__(self, *args, **kwargs): + super().__init__() + + def preprocess(self, doc): + return {"num_words": len(doc), "doc_id": doc._.note_id} + + def collate(self, batch): + return { + "num_words": torch.tensor(batch["num_words"]), + "doc_id": batch["doc_id"], + } + + def forward(self, batch): + if "text-1" in batch["doc_id"]: + raise RuntimeError("Deep learning error") + return {} + + +@pytest.mark.skipif(torch is None, reason="torch not installed") +def test_multiprocessing_ml_error(ml_nlp): + text1 = "Ceci est un exemple" + text2 = "Ceci est un autre exemple" + ml_nlp.add_pipe( + DeepLearningError(pipeline=ml_nlp), + name="error", + after="sentences", + ) + + with pytest.raises(RuntimeError) as e: + docs = edsnlp.data.from_iterable( + chain.from_iterable( + [ + {"content": text1, "id": f"text-{i}"}, + {"content": text2, "id": f"other-text-{i}"}, + ] + for i in range(5) + ), + converter=simple_converter, + ) + docs = 


if torch is not None:
    from edsnlp.core.torch_component import TorchComponent

    class DeepLearningError(TorchComponent):
        def __init__(self, *args, **kwargs):
            super().__init__()

        def preprocess(self, doc):
            return {"num_words": len(doc), "doc_id": doc._.note_id}

        def collate(self, batch):
            return {
                "num_words": torch.tensor(batch["num_words"]),
                "doc_id": batch["doc_id"],
            }

        def forward(self, batch):
            if "text-1" in batch["doc_id"]:
                raise RuntimeError("Deep learning error")
            return {}


@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_multiprocessing_ml_error(ml_nlp):
    text1 = "Ceci est un exemple"
    text2 = "Ceci est un autre exemple"
    ml_nlp.add_pipe(
        DeepLearningError(pipeline=ml_nlp),
        name="error",
        after="sentences",
    )

    with pytest.raises(RuntimeError) as e:
        docs = edsnlp.data.from_iterable(
            chain.from_iterable(
                [
                    {"content": text1, "id": f"text-{i}"},
                    {"content": text2, "id": f"other-text-{i}"},
                ]
                for i in range(5)
            ),
            converter=simple_converter,
        )
        docs = ml_nlp.pipe(docs)
        docs = docs.set_processing(
            batch_size=2,
            num_gpu_workers=1,
            num_cpu_workers=1,
            gpu_worker_devices=["cpu"],
        )
        list(docs)
    assert "Deep learning error" in str(e.value)


@pytest.mark.parametrize(
    "backend",
    ["simple", "multiprocessing", "spark"],
)
def test_generator(backend):
    items = ["abc", "def", "ghij"]
    items = edsnlp.data.from_iterable(items)

    def gen(x):
        for char in x:
            yield char

    items = items.map(gen).set_processing(backend=backend, num_cpu_workers=2)
    # output from workers will be read in a round-robin fashion,
    # i.e. zip(
    #     ("a", "b", "c", "g", "h", "i", "j"),  # worker 1
    #     ("d", "e", "f"),  # worker 2
    # )
    assert set(items) == {"a", "d", "b", "e", "c", "f", "g", "h", "i", "j"}


@pytest.mark.parametrize("deterministic", [True, False])
def test_multiprocessing_sleep(deterministic):
    def process(x):
        if x % 2 == 0:
            time.sleep(0.1)
        return x

    items = list(range(100))
    items = edsnlp.data.from_iterable(items)
    items = items.map(process)
    items = items.set_processing(
        backend="multiprocessing",
        deterministic=deterministic,
        num_cpu_workers=2,
    )
    items = list(items)
    if deterministic:
        assert items == list(range(100))
    else:
        assert items != list(range(100))


@pytest.mark.parametrize("num_cpu_workers", [0, 1, 2])
def test_deterministic_skip(num_cpu_workers):
    def process_batch(x):
        return [i for i in x if i < 10 or i % 2 == 0]

    items = list(range(100))
    items = edsnlp.data.from_iterable(items)
    items = items.map_batches(process_batch)
    items = items.set_processing(
        deterministic=True,
        num_cpu_workers=num_cpu_workers,
    )
    items = list(items)
    assert items == [*range(0, 10), *range(10, 100, 2)]
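

# Illustrative sketch (not a collected test): the core of the determinism
# guarantee checked above. With deterministic=True, the multiprocessing backend
# is expected to emit results in input order even when workers finish out of
# order; the doubling transform is an arbitrary stand-in.
def _sketch_deterministic_order():
    stream = edsnlp.data.from_iterable(range(10))
    stream = stream.map(lambda x: x * 2)
    stream = stream.set_processing(
        backend="multiprocessing",
        num_cpu_workers=2,
        deterministic=True,
    )
    assert list(stream) == [x * 2 for x in range(10)]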


@pytest.mark.parametrize(
    "backend",
    ["simple", "multiprocessing"],
)
@pytest.mark.skipif(torch is None, reason="torch not installed")
def test_backend_cache(backend):
    import torch

    from edsnlp.core.torch_component import (
        BatchInput,
        BatchOutput,
        TorchComponent,
        _caches,
    )

    @validate_arguments
    class InnerComponent(TorchComponent):
        def __init__(self, nlp=None, *args, **kwargs):
            super().__init__()
            self.called_forward = False

        def preprocess(self, doc):
            return {"text": doc.text}

        def collate(self, batch: Dict[str, Any]) -> BatchInput:
            return {"sizes": torch.as_tensor([len(x) for x in batch["text"]])}

        def forward(self, batch):
            assert not self.called_forward
            self.called_forward = True
            return {"sizes": batch["sizes"] * 2}

    @validate_arguments
    class OuterComponent(TorchComponent):
        def __init__(self, inner):
            super().__init__()
            self.inner = inner

        def preprocess(self, doc):
            return {"inner": self.inner.preprocess(doc)}

        def collate(self, batch: Dict[str, Any]) -> BatchInput:
            return {"inner": self.inner.collate(batch["inner"])}

        def forward(self, batch: BatchInput) -> BatchOutput:
            return {"inner": self.inner(batch["inner"])["sizes"].clone()}

        def postprocess(
            self,
            docs: Sequence[Doc],
            results: BatchOutput,
            inputs: List[Dict[str, Any]],
        ) -> Sequence[Doc]:
            return docs

    nlp = edsnlp.blank("eds")
    nlp.add_pipe(InnerComponent(), name="inner")
    nlp.add_pipe(OuterComponent(nlp.pipes.inner), name="outer")
    text1 = "Word"
    text2 = "A phrase"
    text3 = "This is a sentence"
    text4 = "This is a longer document with many words."
    stream = edsnlp.data.from_iterable([text1, text2, text3, text4])
    stream = stream.map_pipeline(nlp)
    if backend == "simple":
        assert list(_caches) == []
        list(stream.set_processing(batch_size=4))
        assert list(_caches) == []
    elif backend == "multiprocessing":
        list(
            stream.set_processing(
                batch_size=2,
                num_gpu_workers=2,
                num_cpu_workers=1,
                gpu_worker_devices=["cpu", "cpu"],
            )
        )
    elif backend == "spark":
        list(stream.set_processing(backend="spark", batch_size=4))


def test_task_dispatch_schedule():
    fn = get_dispatch_schedule

    assert fn(0, range(4), range(2)) == [0, 0]
    assert fn(1, range(4), range(2)) == [1, 1]
    assert fn(2, range(4), range(2)) == [0, 0]
    assert fn(3, range(4), range(2)) == [1, 1]

    assert fn(0, range(3), range(2)) == [0, 0]
    assert fn(1, range(3), range(2)) == [1, 1]
    assert fn(2, range(3), range(2)) == [0, 1]

    assert fn(0, range(2), range(3)) == [0, 0, 2]
    assert fn(1, range(2), range(3)) == [1, 1, 2]
    assert fn(0, range(2), range(3)) == [0, 0, 2]
    assert fn(1, range(2), range(3)) == [1, 1, 2]

    assert fn(0, range(16), range(10)) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    assert fn(1, range(16), range(10)) == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
    assert fn(2, range(16), range(10)) == [2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
    assert fn(3, range(16), range(10)) == [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
    assert fn(4, range(16), range(10)) == [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
    assert fn(5, range(16), range(10)) == [5, 5, 5, 5, 5, 5, 5, 5, 5, 5]
    assert fn(6, range(16), range(10)) == [6, 6, 6, 6, 6, 6, 6, 6, 6, 6]
    assert fn(7, range(16), range(10)) == [7, 7, 7, 7, 7, 7, 7, 7, 7, 7]
    assert fn(8, range(16), range(10)) == [8, 8, 8, 8, 8, 8, 8, 8, 8, 8]
    assert fn(9, range(16), range(10)) == [9, 9, 9, 9, 9, 9, 9, 9, 9, 9]
    assert fn(10, range(16), range(10)) == [0, 0, 0, 0, 0, 0, 6, 6, 6, 6]
    assert fn(11, range(16), range(10)) == [1, 1, 1, 1, 1, 1, 7, 7, 7, 7]
    assert fn(12, range(16), range(10)) == [2, 2, 2, 2, 2, 2, 8, 8, 8, 8]
    assert fn(13, range(16), range(10)) == [3, 3, 3, 3, 3, 3, 9, 9, 9, 9]
    assert fn(14, range(16), range(10)) == [4, 4, 4, 4, 4, 4, 6, 7, 6, 7]
    assert fn(15, range(16), range(10)) == [5, 5, 5, 5, 5, 5, 8, 9, 8, 9]
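

# Illustrative sketch (not a collected test): one plausible reading of the
# assertions above is that `get_dispatch_schedule(index, producers, consumers)`
# returns, for the producer at `index`, one consumer assignment per consumer
# slot. Whatever the exact semantics, the checked cases pin down two invariants
# that can be restated directly:
def _sketch_dispatch_schedule_invariants():
    schedule = get_dispatch_schedule(2, range(3), range(2))
    assert schedule == [0, 1]  # same case as asserted above
    assert len(schedule) == 2  # one entry per consumer slot
    assert set(schedule) <= set(range(2))  # entries index into the consumers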


def test_multiprocessing_on_simple_iterable_in_main():
    exec(
        """
import edsnlp.data

counter = 0

def complex_func(n):
    global counter
    counter += 1
    return n * n

stream = edsnlp.data.from_iterable(range(20))
stream = stream.map(complex_func)
stream = stream.set_processing(num_cpu_workers=2)
res = list(stream)
""",
        {"__MODULE__": "__main__"},
    )


def test_multiprocessing_on_full_example_in_main():
    exec(
        """
from spacy.tokens import Doc

import edsnlp
import edsnlp.pipes as eds
from edsnlp.data.converters import get_current_tokenizer

if not Doc.has_extension("note_text"):
    Doc.set_extension("note_text", default=None)
if not Doc.has_extension("date"):
    Doc.set_extension("date", default=None)
if not Doc.has_extension("person_id"):
    Doc.set_extension("person_id", default=None)


def convert_row_to_doc(row):
    if row["note_text"] is None:
        row["note_text"] = ""
    text = row["note_text"]
    doc = get_current_tokenizer()(text)
    doc._.note_id = row["note_id"]
    return doc


def convert_doc_to_row(doc_):
    note_id = doc_._.note_id
    person_id = doc_._.person_id
    note_text = doc_.text
    result = []
    for date in doc_.spans["dates"]:
        result.append(
            {
                "note_id": note_id,
                "person_id": person_id,
                # "note_text" : note_text,
                # "note_doc" : doc_,
                "date": date._.date.datetime,
            }
        )
    return result


nlp = edsnlp.blank("eds")
# nlp = eds_biomedic_aphp.load()
# nlp.add_pipe(eds.sections())
nlp.add_pipe(eds.dates())
nlp.add_pipe(eds.sentences())
data = edsnlp.data.from_iterable(
    [{"note_text": "Test", "note_id": "test"}],
    converter=convert_row_to_doc,
)
data = data.map_pipeline(nlp)
data_pd = data.set_processing(show_progress=True, num_cpu_workers=5).to_pandas(
    converter=convert_doc_to_row
)
""",
        {"__MODULE__": "__main__"},
    )
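

# Illustrative sketch (not a collected test): the dict -> Doc -> dict round trip
# that the exec-based example above performs, written as plain code against
# `simple_converter` defined earlier in this module. The expected output assumes
# the default tokenizer reproduces the single-word input text verbatim.
def _sketch_converter_round_trip():
    stream = edsnlp.data.from_iterable(
        [{"content": "Test", "id": "doc-0"}],
        converter=simple_converter,  # dict -> Doc on the way in
    )
    rows = list(stream.map(lambda doc: {"id": doc._.note_id, "text": doc.text}))
    assert rows == [{"id": "doc-0", "text": "Test"}]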