# medacy/tests/data/test_dataset.py
1
import os
2
import shutil
3
import tempfile
4
import unittest
5
from collections import Counter
6
from pathlib import Path
7
8
import pkg_resources
9
10
from medacy.data.dataset import Dataset
11
from medacy.data.annotations import Annotations
12
from medacy.data.data_file import DataFile
13
from medacy.tests.sample_data import test_dir
14
15
16
class TestDataset(unittest.TestCase):
    """Unit tests for Dataset"""

    # Base names of the files the tests below expect to find in 'sample_dataset_1'
    _SAMPLE_FILE_NAMES = ("PMC1257590", "PMC1314908", "PMC1392236")

    @classmethod
    def setUpClass(cls):
        """
        Loads the sample dataset and creates two temporary directories derived from it:
        one holding only the txt files and one holding only the ann files.
        """
        cls.dataset = Dataset(os.path.join(test_dir, 'sample_dataset_1'))
        cls.entities = cls.dataset.get_labels(as_list=True)
        cls.ann_files = []

        # unittest does not call tearDownClass() when setUpClass() raises, so any
        # temporary directory created here must be removed by hand on failure.
        created_directories = []
        try:
            cls.prediction_directory = tempfile.mkdtemp()  # Set up predict directory
            created_directories.append(cls.prediction_directory)
            cls.ann_dir = tempfile.mkdtemp()
            created_directories.append(cls.ann_dir)

            for data_file in cls.dataset:
                # Fill directory of prediction files (only the text files)
                new_file_path = os.path.join(cls.prediction_directory, data_file.file_name + '.txt')
                shutil.copyfile(data_file.txt_path, new_file_path)

                # Fill a directory with just ann files
                new_ann_path = os.path.join(cls.ann_dir, data_file.file_name + '.ann')
                shutil.copyfile(data_file.ann_path, new_ann_path)
        except Exception:
            for directory in created_directories:
                shutil.rmtree(directory, ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        """Releases extracted package resources and deletes the temporary directories"""
        pkg_resources.cleanup_resources()
        for directory in (cls.prediction_directory, cls.ann_dir):
            shutil.rmtree(directory)

    @classmethod
    def _expected_data_files(cls, directory, *, has_txt, has_ann, has_metamapped):
        """
        Builds the DataFiles expected for the sample file names under a given directory.

        :param directory: the directory the Dataset under test was created from
        :param has_txt: whether txt_path should point at '<name>.txt' (otherwise None)
        :param has_ann: whether ann_path should point at '<name>.ann' (otherwise None)
        :param has_metamapped: whether metamapped_path should point at
            'metamapped/<name>.metamapped' (otherwise None)
        :return: a list of DataFile objects sorted by file name
        """
        root = Path(directory)
        return [
            DataFile(
                file_name=name,
                txt_path=(root / (name + ".txt")) if has_txt else None,
                ann_path=(root / (name + ".ann")) if has_ann else None,
                metamapped_path=(root / "metamapped" / (name + ".metamapped")) if has_metamapped else None
            )
            for name in sorted(cls._SAMPLE_FILE_NAMES)
        ]

    def test_init(self):
        """Tests initializing Datasets from different directories to see that they create accurate DataFiles"""

        # Each scenario runs in its own subTest so one failure does not hide the others

        # Test both txt, ann, and metamapped
        with self.subTest("txt, ann, and metamapped"):
            expected = self._expected_data_files(
                self.dataset.data_directory, has_txt=True, has_ann=True, has_metamapped=True
            )
            actual = list(self.dataset)
            self.assertListEqual(actual, expected)

        # Test txt only
        with self.subTest("txt only"):
            expected = self._expected_data_files(
                self.prediction_directory, has_txt=True, has_ann=False, has_metamapped=False
            )
            actual = list(Dataset(self.prediction_directory))
            self.assertListEqual(actual, expected)

        # Test ann only
        with self.subTest("ann only"):
            expected = self._expected_data_files(
                self.ann_dir, has_txt=False, has_ann=True, has_metamapped=False
            )
            actual = list(Dataset(self.ann_dir))
            self.assertListEqual(actual, expected)

    def test_init_with_data_limit(self):
        """Tests that initializing with a data limit works"""
        dataset = Dataset(self.dataset.data_directory, data_limit=1)
        self.assertEqual(len(list(dataset)), 1)

    def test_generate_annotations(self):
        """Tests that generate_annotations() creates Annotations objects"""
        for ann in self.dataset.generate_annotations():
            self.assertIsInstance(ann, Annotations)

    def test_get_labels(self):
        """Tests that get_labels returns a set of the correct labels"""
        expected = {
            'DoseFrequency', 'SampleSize', 'TimeUnits', 'Vehicle', 'TestArticlePurity', 'Endpoint', 'TestArticle',
            'GroupName', 'DoseDurationUnits', 'GroupSize', 'TimeAtFirstDose', 'Dose', 'DoseDuration', 'Species',
            'DoseUnits', 'Sex', 'EndpointUnitOfMeasure', 'TimeEndpointAssessed', 'DoseRoute', 'CellLine', 'Strain'
        }
        actual = self.dataset.get_labels()
        self.assertSetEqual(actual, expected)

    def test_compute_counts(self):
        """Tests that compute_counts() returns a Counter containing counts for all labels"""
        counts = self.dataset.compute_counts()
        self.assertIsInstance(counts, Counter)
        for label in self.dataset.get_labels():
            self.assertIn(label, counts)

    def test_getitem(self):
        """Tests that some_dataset['filename'] returns an Annotations for 'filename.ann', or raises FileNotFoundError"""
        some_file_name = self.dataset.data_files[0].file_name
        result = self.dataset[some_file_name]
        self.assertIsInstance(result, Annotations)

        # Only the raised exception matters here, so the lookup result is discarded
        with self.assertRaises(FileNotFoundError):
            self.dataset['notafilepath']

    def test_valid_datafiles(self):
        """Tests that each DataFile in the Dataset is an existing file"""
        for d in self.dataset:
            # subTest names the offending DataFile in the failure report
            with self.subTest(file_name=d.file_name):
                self.assertTrue(os.path.isfile(d.txt_path))
                self.assertTrue(os.path.isfile(d.ann_path))
                self.assertTrue(os.path.isfile(d.metamapped_path))
167
168
if __name__ == '__main__':
    # Run the tests in this module when the file is executed directly as a script
    unittest.main()