Diff of /tests/util.py [000000] .. [030aeb]

Switch to unified view

a b/tests/util.py
1
"""Utilities for unit tests."""
2
3
import datetime
4
import os
5
import re
6
import shutil
7
import subprocess
8
import tempfile
9
import unittest
10
import uuid
11
from pathlib import Path
12
from typing import Callable
13
14
import natsort
15
import numpy as np
16
from pydicom.dataset import FileDataset, FileMetaDataset
17
18
from dosma.cli import SUPPORTED_SCAN_TYPES, parse_args
19
from dosma.core.fitting import monoexponential
20
from dosma.core.io.format_io import ImageDataFormat
21
from dosma.core.med_volume import MedicalVolume
22
from dosma.utils import env
23
from dosma.utils.cmd_line_utils import ActionWrapper
24
25
UNITTEST_DATA_PATH = os.environ.get(
26
    "DOSMA_UNITTEST_DATA_PATH", os.path.join(os.path.dirname(__file__), "../unittest-data/")
27
)
28
UNITTEST_SCANDATA_PATH = os.path.join(UNITTEST_DATA_PATH, "scans")
29
TEMP_PATH = os.path.join(
30
    UNITTEST_SCANDATA_PATH, f"temp-{str(uuid.uuid1())}-{str(uuid.uuid4())}"
31
)  # should be used when for writing with assert_raises clauses
32
33
SCANS = ["qdess", "mapss", "cubequant", "cones"]
34
SCANS_INFO = {
35
    "mapss": {"expected_num_echos": 7},
36
    "qdess": {"expected_num_echos": 2},
37
    "cubequant": {"expected_num_echos": 4},
38
    "cones": {"expected_num_echos": 4},
39
}
40
41
SCAN_DIRPATHS = [os.path.join(UNITTEST_SCANDATA_PATH, x) for x in SCANS]
42
43
# Decimal precision for analysis (quantitative values, etc)
44
DECIMAL_PRECISION = 1  # (+/- 0.1ms)
45
46
# If elastix is available
47
_IS_ELASTIX_AVAILABLE = None
48
49
50
def is_data_available():
51
    disable_data = os.environ.get("DOSMA_UNITTEST_DISABLE_DATA", "").lower() == "true"
52
    return not disable_data and os.path.isdir(UNITTEST_DATA_PATH)
53
54
55
def get_scan_dirpath(scan: str):
56
    for ind, x in enumerate(SCANS):
57
        if scan == x:
58
            return SCAN_DIRPATHS[ind]
59
60
61
def get_dicoms_path(fp):
62
    return os.path.join(fp, "dicoms")
63
64
65
def get_write_path(fp, data_format: ImageDataFormat):
66
    return os.path.join(fp, "multi-echo-write-%s" % data_format.name)
67
68
69
def get_read_paths(fp, data_format: ImageDataFormat):
70
    """Get ground truth data (produced by imageviewer like itksnap, horos, etc)"""
71
    base_name = os.path.join(fp, "multi-echo-gt-%s" % data_format.name)
72
    files_or_dirs = os.listdir(base_name)
73
    fd = [x for x in files_or_dirs if re.match("e[0-9]+", x)]
74
    files_or_dirs = natsort.natsorted(fd)
75
76
    return [os.path.join(base_name, x) for x in files_or_dirs]
77
78
79
def get_data_path(fp):
80
    return os.path.join(fp, f"data-{str(uuid.uuid1())}")
81
82
83
def get_expected_data_path(fp):
84
    return os.path.join(fp, "expected")
85
86
87
def is_elastix_available():
88
    global _IS_ELASTIX_AVAILABLE
89
90
    if _IS_ELASTIX_AVAILABLE is None:
91
        disable_elastix = os.environ.get("DOSMA_UNITTEST_DISABLE_ELASTIX", None)
92
        if disable_elastix is None:
93
            _IS_ELASTIX_AVAILABLE = subprocess.run(["elastix", "--help"]).returncode == 0
94
        else:
95
            _IS_ELASTIX_AVAILABLE = disable_elastix.lower() != "true"
96
97
    return _IS_ELASTIX_AVAILABLE
98
99
100
def num_workers() -> int:
101
    return int(os.environ.get("DOSMA_NUM_WORKERS", min(8, os.cpu_count())))
102
103
104
def requires_packages(*packages):
105
    """
106
    Decorator for functions that should only execute when
107
    all packages defined by *args are supported.
108
    """
109
110
    def _decorator(func):
111
        def _wrapper(*args, **kwargs):
112
            if all([env.package_available(x) for x in packages]):
113
                func(*args, **kwargs)
114
115
        return _wrapper
116
117
    return _decorator
118
119
120
def generate_monoexp_data(shape=None, x=None, a=1.0, b=None):
121
    """Generate sample monoexponetial data.
122
    ``a=1.0``, ``b`` is randomly generated in interval [0.1, 1.1).
123
124
    The equation is :math:`y =  a * \\exp (b*x)`.
125
    """
126
    if b is None:
127
        b = np.random.rand(*shape) + 0.1
128
    else:
129
        shape = b.shape
130
    if x is None:
131
        x = np.asarray([0.5, 1.0, 2.0, 4.0])
132
    y = [MedicalVolume(monoexponential(t, a, b), affine=np.eye(4)) for t in x]
133
    return x, y, a, b
134
135
136
def build_dummy_headers(shape, fields=None):
137
    """Build dummy ``pydicom.FileDataset`` headers.
138
139
    Note these headers are not dicom compliant and should not be used to write out DICOM
140
    files.
141
142
    Args:
143
        shape (int or tuple[int]): Shape of headers array.
144
        fields (Dict): Fields and corresponding values to use to populate the header.
145
146
    Returns:
147
        ndarray: Headers
148
    """
149
    if isinstance(shape, int):
150
        shape = (shape,)
151
    num_headers = np.prod(shape)
152
    headers = np.asarray([_build_dummy_pydicom_header(fields) for _ in range(num_headers)])
153
    return headers.reshape(shape)
154
155
156
def _build_dummy_pydicom_header(fields=None):
157
    """Builds dummy pydicom-based header.
158
159
    Note these headers are not dicom compliant and should not be used to write out DICOM
160
    files.
161
162
    Adapted from
163
    https://pydicom.github.io/pydicom/dev/auto_examples/input_output/plot_write_dicom.html
164
    """
165
    suffix = ".dcm"
166
    filename_little_endian = tempfile.NamedTemporaryFile(suffix=suffix).name
167
168
    file_meta = FileMetaDataset()
169
    file_meta.MediaStorageSOPClassUID = "1.2.840.10008.5.1.4.1.1.2"
170
    file_meta.MediaStorageSOPInstanceUID = "1.2.3"
171
    file_meta.ImplementationClassUID = "1.2.3.4"
172
173
    # Create the FileDataset instance (initially no data elements, but file_meta supplied).
174
    ds = FileDataset(filename_little_endian, {}, file_meta=file_meta, preamble=b"\0" * 128)
175
176
    # Add the data elements -- not trying to set all required here. Check DICOM standard.
177
    ds.PatientName = "Test^Firstname"
178
    ds.PatientID = "123456"
179
180
    if fields is not None:
181
        for k, v in fields.items():
182
            setattr(ds, k, v)
183
184
    # Set the transfer syntax
185
    ds.is_little_endian = True
186
    ds.is_implicit_VR = True
187
188
    # Set creation date/time
189
    dt = datetime.datetime.now()
190
    ds.ContentDate = dt.strftime("%Y%m%d")
191
    timeStr = dt.strftime("%H%M%S.%f")  # long format with micro seconds
192
    ds.ContentTime = timeStr
193
194
    return ds
195
196
197
class TempPathMixin(unittest.TestCase):
198
    """Testing helper that creates temporary path for the class."""
199
200
    data_dirpath = None
201
202
    @classmethod
203
    def setUpClass(cls):
204
        cls.data_dirpath = Path(
205
            os.path.join(
206
                get_data_path(os.path.join(UNITTEST_SCANDATA_PATH, "temp")), f"{cls.__name__}"
207
            )
208
        )
209
        os.makedirs(cls.data_dirpath, exist_ok=True)
210
211
    @classmethod
212
    def tearDownClass(cls):
213
        shutil.rmtree(cls.data_dirpath)
214
215
216
class ScanTest(TempPathMixin):
217
    from dosma.scan_sequences.scans import ScanSequence
218
219
    SCAN_TYPE = ScanSequence  # override in subclasses
220
221
    dicom_dirpath = None
222
223
    def setUp(self):
224
        print("Testing: ", self._testMethodName)
225
226
    @classmethod
227
    def setUpClass(cls):
228
        super().setUpClass()
229
        if is_data_available():
230
            cls.dicom_dirpath = Path(
231
                get_dicoms_path(os.path.join(UNITTEST_SCANDATA_PATH, cls.SCAN_TYPE.NAME))
232
            )
233
234
    def test_has_cmd_line_actions_attr(self):
235
        """
236
        If scan can be accessed via the command line,
237
        verify that the scan has a ``cmd_line_actions`` method.
238
        """
239
        # if the scan is not supported via the command line, then ignore this test
240
        if self.SCAN_TYPE not in SUPPORTED_SCAN_TYPES:
241
            return
242
243
        assert hasattr(
244
            self.SCAN_TYPE, "cmd_line_actions"
245
        ), "All scans supported by command line must have `cmd_line_actions` method"
246
247
        cmd_line_actions = self.SCAN_TYPE.cmd_line_actions()
248
        for func, action in cmd_line_actions:
249
            assert isinstance(func, Callable)
250
            assert isinstance(action, ActionWrapper)
251
252
            func_name = func.__name__
253
            cls_name = self.SCAN_TYPE.__name__
254
            assert action.name, f"Action for `{cls_name}.{func_name}()` must have a name"
255
            assert action.help, f"Action for `{cls_name}.{func_name}()` must have help menu"
256
257
        assert hasattr(
258
            type(self), "test_cmd_line"
259
        ), "All scan supported in command line must have test methods `test_cmd_line`"
260
261
    def __cmd_line_helper__(self, cmdline_str: str):
262
        env_args = {"--num-workers": num_workers()}
263
        for arg, value in env_args.items():
264
            if arg in cmdline_str:
265
                continue
266
            cmdline_str = f"{arg} {value} {cmdline_str}"
267
268
        cmdline_input = cmdline_str.strip().split()
269
        parse_args(cmdline_input)