# tests/surrogates/test_rewrite_dataset.py
1
import argparse
2
import filecmp
3
import glob
4
from os.path import basename, dirname, join
5
6
import pytest
7
8
from deidentify.base import Annotation
9
from deidentify.surrogates import rewrite_dataset
10
11
12
def test_apply_surrogates():
    """Check that ``apply_surrogates`` replaces every annotated span.

    Each annotation is paired with the surrogate at the same position. The
    surrogates differ in length from the spans they replace, so the expected
    annotations carry offsets that are shifted to match the rewritten text.
    """
    text = 'ccc cc ccc c c ccc cccccc cccc'
    annotations = [
        Annotation('ccc', start=0, end=3, tag='A'),
        Annotation('cc', start=4, end=6, tag='A'),
        Annotation('ccc', start=15, end=18, tag='B'),
    ]
    surrogates = ['a', 'dd', 'bbbbb']

    surrogate_doc = rewrite_dataset.apply_surrogates(text, annotations, surrogates)

    # Un-annotated text between the spans must be left untouched.
    assert surrogate_doc.text == 'a dd ccc c c bbbbb cccccc cccc'
    assert surrogate_doc.annotations == [
        Annotation('a', start=0, end=1, tag='A'),
        Annotation('dd', start=2, end=4, tag='A'),
        Annotation('bbbbb', start=13, end=18, tag='B'),
    ]
    # Every annotation received a surrogate, so nothing is reported as missing.
    assert surrogate_doc.annotations_without_surrogates == []
29
30
31
def test_apply_surrogates_no_annotations():
    """Check that a document without annotations comes back unchanged."""
    surrogate_doc = rewrite_dataset.apply_surrogates(
        'ccc cc ccc',
        annotations=[],
        surrogates=[],
    )

    assert surrogate_doc.text == 'ccc cc ccc'
    assert surrogate_doc.annotations == []
    assert surrogate_doc.annotations_without_surrogates == []
36
37
38
def test_apply_surrogates_errors_raise():
    """Check that a missing surrogate raises ``ValueError`` in 'raise' mode.

    The second surrogate is ``None``. The call is expected to fail both when
    ``errors`` is omitted and when ``errors='raise'`` is passed explicitly.
    """
    text = 'ccc cc ccc'
    annotations = [
        Annotation('ccc', start=0, end=3, tag='A'),
        Annotation('cc', start=4, end=6, tag='A'),
        Annotation('ccc', start=7, end=10, tag='B'),
    ]
    surrogates = ['a', None, 'b']

    # No ``errors`` argument: the default behaviour must also raise.
    with pytest.raises(ValueError):
        rewrite_dataset.apply_surrogates(text, annotations, surrogates)

    with pytest.raises(ValueError):
        rewrite_dataset.apply_surrogates(text, annotations, surrogates, errors='raise')
52
53
54
def test_apply_surrogates_errors_ignore():
    """Check that ``errors='ignore'`` keeps the original text of a span
    whose surrogate is ``None``.

    The un-replaced annotation stays in the output with offsets shifted to
    the rewritten text, and is also reported (with its original offsets) in
    ``annotations_without_surrogates``.
    """
    text = 'ccc cc ccc'
    annotations = [
        Annotation('ccc', start=0, end=3, tag='A'),
        Annotation('cc', start=4, end=6, tag='A'),
        Annotation('ccc', start=7, end=10, tag='B'),
    ]
    surrogates = ['a', None, 'b']

    surrogate_doc = rewrite_dataset.apply_surrogates(
        text, annotations, surrogates, errors='ignore'
    )

    assert surrogate_doc.text == 'a cc b'
    assert surrogate_doc.annotations == [
        Annotation('a', start=0, end=1, tag='A'),
        # Original text kept, but offsets follow the rewritten document.
        Annotation('cc', start=2, end=4, tag='A'),
        Annotation('b', start=5, end=6, tag='B'),
    ]
    # The skipped annotation is reported with its offsets in the input text.
    assert surrogate_doc.annotations_without_surrogates == [
        Annotation('cc', start=4, end=6, tag='A'),
    ]
73
74
75
def test_apply_surrogates_errors_coerce():
    """Check that ``errors='coerce'`` substitutes a placeholder for a span
    whose surrogate is ``None``.

    In this fixture the placeholder for the tag ``'A'`` is ``'[A]'``. The
    original annotation is still reported in
    ``annotations_without_surrogates``.
    """
    text = 'ccc cc ccc'
    annotations = [
        Annotation('ccc', start=0, end=3, tag='A'),
        Annotation('cc', start=4, end=6, tag='A'),
        Annotation('ccc', start=7, end=10, tag='B'),
    ]
    surrogates = ['a', None, 'b']

    surrogate_doc = rewrite_dataset.apply_surrogates(
        text, annotations, surrogates, errors='coerce'
    )

    assert surrogate_doc.text == 'a [A] b'
    assert surrogate_doc.annotations == [
        Annotation('a', start=0, end=1, tag='A'),
        # Placeholder replaces the span; offsets follow the rewritten document.
        Annotation('[A]', start=2, end=5, tag='A'),
        Annotation('b', start=6, end=7, tag='B'),
    ]
    # The coerced annotation is reported with its offsets in the input text.
    assert surrogate_doc.annotations_without_surrogates == [
        Annotation('cc', start=4, end=6, tag='A'),
    ]
94
95
96
def test_main(tmpdir):
    """Run ``rewrite_dataset.main`` and compare its output with reference files.

    The script reads ``data/original`` and the surrogate table next to this
    test module and writes into ``tmpdir``. Every ``.ann`` and ``.txt`` file
    under ``data/rewritten`` must have an identical counterpart in the output.

    Args:
        tmpdir: pytest's temporary directory fixture, used as the output path.
    """
    args = argparse.Namespace(
        surrogate_table=join(dirname(__file__), 'data/annotations-rewrite-table.csv'),
        data_path=join(dirname(__file__), 'data/original'),
        output_path=tmpdir,
    )

    ann_files = glob.glob(join(dirname(__file__), 'data/rewritten/*.ann'))
    txt_files = glob.glob(join(dirname(__file__), 'data/rewritten/*.txt'))
    to_compare = [basename(path) for path in ann_files + txt_files]
    # Without reference files the loop below would pass without checking anything.
    assert to_compare, 'no reference files found in data/rewritten'

    rewrite_dataset.main(args)

    for filename in to_compare:
        expected = join(dirname(__file__), 'data/rewritten/', filename)
        actual = join(tmpdir, filename)
        # shallow=False forces a content comparison; the default may accept
        # files as equal based on their os.stat() signatures alone.
        assert filecmp.cmp(expected, actual, shallow=False)