Switch to unified view

a b/dosma/scan_sequences/scan_io.py
1
import inspect
2
import os
3
import warnings
4
from abc import ABC
5
from pathlib import Path
6
from typing import Any, Dict, Optional, Sequence, Set, Union
7
8
import pydicom
9
10
from dosma.core.io import format_io_utils as fio_utils
11
from dosma.core.io.dicom_io import DicomReader
12
from dosma.core.io.format_io import ImageDataFormat
13
from dosma.core.med_volume import MedicalVolume
14
from dosma.defaults import preferences
15
from dosma.tissues.tissue import Tissue
16
from dosma.utils import io_utils
17
18
19
def _contains_type(value, types):
20
    """Returns ``True`` if any value is an instance of ``types``."""
21
    if isinstance(value, types):
22
        return True
23
    if not isinstance(value, str) and isinstance(value, (Sequence, Set)) and len(value) > 0:
24
        return any(_contains_type(x, types) for x in value)
25
    elif isinstance(value, Dict):
26
        return _contains_type(value.keys(), types) or _contains_type(value.values(), types)
27
    return isinstance(value, types)
28
29
30
class ScanIOMixin(ABC):
31
    # This is just a summary on variables used in this abstract class,
32
    # the proper values/initialization should be done in child class.
33
    NAME: str
34
    __DEFAULT_SPLIT_BY__: Optional[str]
35
    _from_file_args: Dict[str, Any]
36
37
    @classmethod
38
    def from_dicom(
39
        cls,
40
        dir_or_files,
41
        group_by=None,
42
        ignore_ext: bool = False,
43
        num_workers: int = 0,
44
        verbose: bool = False,
45
        **kwargs,
46
    ):
47
        """Load scan from dicom files.
48
49
        Args:
50
            dir_or_files (str): The path to dicom directory or files.
51
            group_by: DICOM field tag name or tag number used to group dicoms. Defaults
52
                to scan's ``__DEFAULT_SPLIT_BY__``.
53
            ignore_ext (bool, optional): If `True`, ignore extension (`.dcm`)
54
                when loading dicoms from directory.
55
            num_workers (int, optional): Number of workers to use for loading.
56
            verbose (bool, optional): If ``True``, enable verbose logging for dicom loading.
57
            kwargs: Other keywords required to construct scan.
58
59
        Returns:
60
            The scan.
61
        """
62
        dr = DicomReader(num_workers, verbose)
63
        if group_by is None:
64
            group_by = cls.__DEFAULT_SPLIT_BY__
65
        volumes = dr.load(dir_or_files, group_by, ignore_ext)
66
67
        if isinstance(dir_or_files, (str, Path, os.PathLike)):
68
            dir_or_files = os.path.abspath(dir_or_files)
69
        else:
70
            dir_or_files = type(dir_or_files)([os.path.abspath(x) for x in dir_or_files])
71
72
        scan = cls(volumes, **kwargs)
73
        scan._from_file_args = {
74
            "dir_or_files": dir_or_files,
75
            "ignore_ext": ignore_ext,
76
            "group_by": group_by,
77
            "_type": "dicom",
78
        }
79
80
        return scan
81
82
    @classmethod
83
    def from_dict(cls, data: Dict[str, Any], force: bool = False):
84
        """Loads class from data dictionary.
85
86
        Args:
87
            data (Dict): The data.
88
            force (bool, optional): If ``True``, writes attributes even if they do not exist.
89
                Use with caution.
90
91
        Returns:
92
            The scan
93
94
        Examples:
95
            >>> scan = ... # some scan
96
            >>> filepath = scan.save("/path/to/base/directory")
97
            >>> scan_from_saved = type(scan).from_dict(io_utils.load_pik(filepath))
98
            >>> scan_from_dict = type(scan).from_dict(scan.__dict__)
99
        """
100
        # TODO: Add check for deprecated and converted attribute names.
101
        data = cls._convert_attr_name(data)
102
103
        # TODO: Convert metadata to appropriate type.
104
        # Converting metadata type is important when loading MedicalVolume data (for example).
105
        # The data is stored as a path, but should be loaded as a MedicalVolume.
106
        data = cls.load_custom_data(data)
107
108
        signature = inspect.signature(cls)
109
        init_metadata = {k: v for k, v in data.items() if k in signature.parameters}
110
        scan = cls(**init_metadata)
111
        for k in init_metadata.keys():
112
            data.pop(k)
113
114
        for k, v in data.items():
115
            if not hasattr(scan, k) and not force:
116
                warnings.warn(f"{cls.__name__} does not have attribute {k}. Skipping...")
117
                continue
118
            scan.__setattr__(k, v)
119
120
        return scan
121
122
    def save(
123
        self,
124
        path: str,
125
        save_custom: bool = False,
126
        image_data_format: ImageDataFormat = None,
127
        num_workers: int = 0,
128
    ):
129
        """Saves scan data to disk with option for custom saving.
130
131
        Custom saving may be useful to reduce redundant saving and/or save data in standard
132
        compatible formats (e.g. medical images - nifti/dicom), which are not feasible with
133
        python serialization libraries, like pickle.
134
135
        When ``save_custom=True``, this method overloads standard pickling with customizable
136
        saving by first saving data in customizable way (e.g. MedicalVolume -> Nifti file),
137
        and then pickling the reference to the saved object (e.g. Nifti filepath).
138
139
        Currently certain custom saving of objects such as ``pydicom.FileDataset`` and
140
        :cls:`Tissue` objects are not supported.
141
142
        To load the data, do the following:
143
144
        >>> filepath = scan.save("/path/to/directory", save_custom=True)
145
        >>> scan_loaded = type(scan).load(io_utils.load_pik(filepath))
146
147
        Args:
148
            path (str): Directory where data is stored.
149
            data_format (ImageDataFormat, optional): Format to save data.
150
                Defaults to ``preferences.image_data_format``.
151
            save_custom (bool, optional): If ``True``, saves data in custom way specified
152
                by :meth:`save_custom_data` in format specified by ``data_format``. For
153
                example, for default classes this will save :cls:`MedicalVolume` data
154
                to nifti/dicom files as specified by ``image_data_format``.
155
            image_data_format (ImageDataFormat, optional): The data format to save
156
                :cls:`MedicalVolume` data. Only used if save_custom is ``True``.
157
            num_workers (int, bool): Number of workers for saving custom data.
158
                Only used if save_custom is ``True``.
159
160
        Returns:
161
            str: The path to the pickled file.
162
        """
163
        if image_data_format is None:
164
            image_data_format = preferences.image_data_format
165
166
        save_dirpath = path  # self._save_dir(path)
167
        os.makedirs(save_dirpath, exist_ok=True)
168
        filepath = os.path.join(save_dirpath, "%s.data" % self.NAME)
169
170
        metadata: Dict = {}
171
        for attr in self.__serializable_variables__():
172
            metadata[attr] = self.__getattribute__(attr)
173
174
        if save_custom:
175
            metadata = self._save(
176
                metadata, save_dirpath, image_data_format=image_data_format, num_workers=num_workers
177
            )
178
179
        io_utils.save_pik(filepath, metadata)
180
        return filepath
181
182
    @classmethod
183
    def load(cls, path_or_data: Union[str, Dict], num_workers: int = 0):
184
        """Load scan.
185
186
        This method overloads the :func:`from_dict` method by supporting loading from a file
187
        in addition to the data dictionary. If loading and constructing a scan using
188
        :func:`from_dict` fails, defaults to loading data from original dicoms
189
        (if ``self._from_file_args`` is initialized).
190
191
        Args:
192
            path_or_data (Union[str, Dict]): Pickle file to load or data dictionary.
193
            num_workers (int, optional): Number of workers to use for loading.
194
195
        Returns:
196
            ScanSequence: Of type ``cls``.
197
198
        Raises:
199
            ValueError: If ``scan`` cannot be constructed.
200
        """
201
        if isinstance(path_or_data, (str, Path, os.PathLike)):
202
            if os.path.isdir(path_or_data):
203
                path_or_data = os.path.join(path_or_data, f"{cls.NAME}.data")
204
205
            if not os.path.isfile(path_or_data):
206
                raise FileNotFoundError(f"File {path_or_data} does not exist")
207
            data = io_utils.load_pik(path_or_data)
208
        else:
209
            data = path_or_data
210
211
        try:
212
            scan = cls.from_dict(data)
213
            return scan
214
        except Exception:
215
            warnings.warn(
216
                f"Failed to load {cls.__name__} from data. Trying to load from dicom file."
217
            )
218
219
        data = cls._convert_attr_name(data)
220
        data = cls.load_custom_data(data, num_workers=num_workers)
221
222
        scan = None
223
        if "_from_file_args" in data:
224
            dicom_args = data.pop("_from_file_args")
225
            assert dicom_args.pop("_type") == "dicom"
226
            scan = cls.from_dicom(**dicom_args, num_workers=num_workers)
227
        elif "dicom_path" in data:
228
            # Backwards compatibility
229
            dicom_path = data.pop("dicom_path")
230
            ignore_ext = data.pop("ignore_ext", False)
231
            group_by = data.pop("split_by", cls.__DEFAULT_SPLIT_BY__)
232
            scan = cls.from_dicom(
233
                dicom_path, ignore_ext=ignore_ext, group_by=group_by, num_workers=num_workers
234
            )
235
236
        if scan is None:
237
            raise ValueError(f"Data is insufficient to construct {cls.__name__}")
238
239
        for k, v in data.items():
240
            if not hasattr(scan, k):
241
                warnings.warn(f"{cls.__name__} does not have attribute {k}. Skipping...")
242
                continue
243
            scan.__setattr__(k, v)
244
245
        return scan
246
247
    def save_data(
248
        self, base_save_dirpath: str, data_format: ImageDataFormat = preferences.image_data_format
249
    ):
250
        """Deprecated: Alias for :func:`self.save`."""
251
        warnings.warn(
252
            "save_data is deprecated since v0.0.13 and will no longer be "
253
            "available in v0.1. Use `save` instead.",
254
            DeprecationWarning,
255
        )
256
        return self.save(base_save_dirpath, data_format)
257
258
    def _save(
259
        self,
260
        metadata: Dict[str, Any],
261
        save_dir: str,
262
        fname_fmt: Dict[Union[str, type], str] = None,
263
        **kwargs,
264
    ):
265
        if fname_fmt is None:
266
            fname_fmt = {}
267
268
        default_fname_fmt = {MedicalVolume: "image-{}"}
269
        for k, v in default_fname_fmt.items():
270
            if k not in fname_fmt:
271
                fname_fmt[k] = v
272
273
        for attr in metadata.keys():
274
            val = metadata[attr]
275
            path = fname_fmt.get(attr, None)
276
277
            if path is None:
278
                path = os.path.abspath(os.path.join(save_dir, attr))
279
            if not os.path.isabs(path):
280
                path = os.path.join(save_dir, attr, path)
281
            try:
282
                metadata[attr] = self.save_custom_data(val, path, fname_fmt, **kwargs)
283
            except Exception as e:
284
                raise RuntimeError(f"Failed to save metadata {attr} - {e}")
285
286
        return metadata
287
288
    def save_custom_data(
289
        self, metadata, paths, fname_fmt: Dict[Union[str, type], str] = None, **kwargs
290
    ):
291
        """
292
        Finds all attributes of type MedicalVolume or Sequence/Mapping to MedicalVolume
293
        and saves them.
294
        """
295
        if isinstance(metadata, (Dict, Sequence, Set)):
296
            if isinstance(paths, str):
297
                paths = [paths] * len(metadata)
298
            else:
299
                assert len(paths) == len(metadata)
300
301
        if isinstance(metadata, Dict):
302
            keys = metadata.keys()
303
            if isinstance(paths, Dict):
304
                paths = [paths[k] for k in keys]
305
            paths = [os.path.join(_path, f"{k}") for k, _path in zip(keys, paths)]
306
            values = self.save_custom_data(metadata.values(), paths, fname_fmt, **kwargs)
307
            metadata = {k: v for k, v in zip(keys, values)}
308
        elif not isinstance(metadata, str) and isinstance(metadata, (Sequence, Set)):
309
            values = list(metadata)
310
            paths = [os.path.join(_path, "{:03d}".format(i)) for i, _path in enumerate(paths)]
311
            values = [
312
                self.save_custom_data(_x, _path, fname_fmt, **kwargs)
313
                for _x, _path in zip(values, paths)
314
            ]
315
            if not isinstance(values, type(metadata)):
316
                metadata = type(metadata)(values)
317
            else:
318
                metadata = values
319
        else:
320
            formatter = [fname_fmt.get(x) for x in type(metadata).__mro__]
321
            formatter = [x for x in formatter if x is not None]
322
            if len(formatter) == 0:
323
                formatter = None
324
            else:
325
                formatter = formatter[0]
326
            metadata = self._save_custom_data_base(metadata, paths, formatter, **kwargs)
327
328
        return metadata
329
330
    def _save_custom_data_base(self, metadata, path, formatter: str = None, **kwargs):
331
        """The base condition for :meth:`save_custom_data`.
332
333
        Args:
334
            metadata (Any): The data to save.
335
            path (str): The path to save the data.
336
            formatter (str, optional): If provided, this formatted string
337
                will be used to format ``path``.
338
        """
339
        out = {"__dtype__": type(metadata)}
340
        # TODO: Add support for num workers.
341
        # num_workers = kwargs.pop("num_workers", 0)
342
343
        if formatter:
344
            path = os.path.join(os.path.dirname(path), formatter.format(os.path.basename(path)))
345
346
        if isinstance(metadata, MedicalVolume):
347
            image_data_format = kwargs.get("image_data_format", preferences.image_data_format)
348
            # TODO: Once, `files` property added to MedicalVolume, check if property is
349
            # set before doing saving.
350
            path = fio_utils.convert_image_data_format(path, image_data_format)
351
            metadata.save_volume(path, data_format=image_data_format)
352
            out["__value__"] = path
353
        else:
354
            out = metadata
355
356
        return metadata
357
358
    @classmethod
359
    def _convert_attr_name(cls, data: Dict[str, Any]):
360
        return data
361
362
    @classmethod
363
    def load_custom_data(cls, data: Any, **kwargs):
364
        """Recursively converts data to appropriate types.
365
366
        By default, this loads all :class:`MedicalVolume` objects from their corresponding paths.
367
368
        Args:
369
            data (Any): The data. Can either be the dictionary or the metadata value.
370
                If the data corresponds to a custom type, it should have the following
371
                schema:
372
373
                {
374
                    '__dtype__': The type of the data.
375
                    '__value__': The value from which object of type __dtype__ can be constructed.
376
                }
377
378
            **kwargs: Keyword Arguments to pass to :meth:`_load_custom_data_base`.
379
380
        Returns:
381
            Any: The loaded metadata.
382
        """
383
        dtype = type(data)
384
        if isinstance(data, Dict) and "__value__" in data:
385
            dtype = data["__dtype__"]
386
            data = data["__value__"]
387
388
        if issubclass(dtype, Dict):
389
            keys = cls.load_custom_data(data.keys(), **kwargs)
390
            values = cls.load_custom_data(data.values(), **kwargs)
391
            data = {k: v for k, v in zip(keys, values)}
392
        elif not issubclass(dtype, str) and issubclass(dtype, (list, tuple, set)):
393
            data = dtype([cls.load_custom_data(x, **kwargs) for x in data])
394
        else:
395
            data = cls._load_custom_data_base(data, dtype, **kwargs)
396
397
        return data
398
399
    @classmethod
400
    def _load_custom_data_base(cls, data, dtype=None, **kwargs):
401
        """The base condition for :meth:`load_custom_data`.
402
403
        Args:
404
            data:
405
            dtype (type): The data type.
406
407
        Return:
408
            The loaded data.
409
        """
410
        if dtype is None:
411
            dtype = type(data)
412
413
        # TODO: Add support for loading with num_workers
414
        num_workers = kwargs.get("num_workers", 0)
415
        if isinstance(data, str) and issubclass(dtype, MedicalVolume):
416
            data = fio_utils.generic_load(data, num_workers=num_workers)
417
418
        return data
419
420
    def __serializable_variables__(
421
        self, ignore_types=(pydicom.FileDataset, pydicom.Dataset, Tissue), ignore_attrs=()
422
    ) -> Set:
423
        """
424
        By default, all instance attributes are serialized except those
425
        corresponding to headers, :class:`MedicalVolume`(s), or :class:`Tissues`.
426
        Properties and class attributes are also not stored. Class attributes are
427
        indentified using the PEP8 nomenclature of all caps variables.
428
429
        Note:
430
            This method has not been profiled, but times may be large if
431
            the instance contains many variables. Currently this is not
432
            cached as attributes values can change and, as a result, must
433
            be inspected.
434
        """
435
        serializable = []
436
        for attr, value in self.__dict__.items():
437
            if attr in ignore_attrs or _contains_type(value, ignore_types):
438
                continue
439
            if attr.startswith("temp") or attr.startswith("_temp"):
440
                continue
441
            if attr.upper() == attr or (attr.startswith("__") and attr.endswith("__")):
442
                continue
443
            if callable(value) or isinstance(value, property):
444
                continue
445
            serializable.append(attr)
446
447
        return set(serializable)