Diff of /dosma/tissues/tissue.py [000000] .. [030aeb]

Switch to unified view

a b/dosma/tissues/tissue.py
1
import os
2
from abc import ABC, abstractmethod
3
from typing import Union
4
5
import numpy as np
6
import pandas as pd
7
import scipy.ndimage as sni
8
9
from dosma.core.io import format_io_utils as fio_utils
10
from dosma.core.io.format_io import ImageDataFormat
11
from dosma.core.med_volume import MedicalVolume
12
from dosma.core.orientation import SAGITTAL
13
from dosma.core.quant_vals import QuantitativeValue, QuantitativeValueType
14
from dosma.defaults import preferences
15
from dosma.utils import io_utils
16
17
# Filename suffix identifying a segmentation weights file. ``Tissue.find_weights``
# matches it with ``str.endswith``, so it is stored without a leading dot.
WEIGHTS_FILE_EXT = "h5"
18
19
# Names exported by ``from <module> import *``. Only the ``Tissue`` base class is
# listed; other module-level names (e.g. ``largest_cc``) must be imported by name.
__all__ = ["Tissue"]
20
21
22
class Tissue(ABC):
    """Abstract class for tissues.

    Tissues are defined loosely as any tissue structures (bones, soft tissue, etc.).

    Args:
        weights_dir (str, optional): Directory to all segmentation weights. When given,
            the weights file for this tissue is located with :meth:`find_weights`.
        medial_to_lateral (`bool`, optional): If `True`, anatomy is from medial_to_lateral.

    Attributes:
        FULL_NAME (str): Full name of tissue 'femoral cartilage' for femoral cartilage.
        ID (int): Unique integer ID for tissue. Should be unique to all tissues,
            and should not change.
        STR_ID (str): Short hand string id such as 'fc' for femoral cartilage.
        T1_EXPECTED (float): Expected T1 value (in milliseconds).
        medial_to_lateral (bool): If ``True``, mask is in medial to lateral direction.
        pid (str): Patient/subject ID. Should be anonymized.
        quant_vals (dict[str, tuple]): Mapping from quantitative value name
            (t2, t1-rho, etc.) to the ``(quant_map, quant_df)`` pair stored by
            :meth:`__store_quant_vals__`.
        quantitative_values (list[QuantitativeValue]): Quantitative values attached to
            this tissue, in insertion order.
        weights_file_path (str): Path to the weights file found for this tissue, or
            ``None`` if no weights directory was provided.
    """

    ID = -1
    STR_ID = ""
    FULL_NAME = ""

    # Expected quantitative param values.
    T1_EXPECTED = None

    def __init__(self, weights_dir: str = None, medial_to_lateral: bool = None):
        self.pid = None
        self.__mask__ = None
        self.quant_vals = {}
        self.weights_file_path = None

        if weights_dir is not None:
            self.weights_file_path = self.find_weights(weights_dir)

        self.medial_to_lateral = medial_to_lateral

        # quantitative value list
        self.quantitative_values = []

    @abstractmethod
    def split_regions(self, base_map: Union[np.ndarray, MedicalVolume]):
        """Split mask into anatomical regions.

        Args:
            base_map (np.ndarray): 3D numpy array typically corresponding to volume to split.

        Returns:
            np.ndarray: 4D numpy array (region, height, width, depth).
                        Saved in variable `self.regions`.
        """
        pass

    def calc_quant_vals(self):
        """Calculate quantitative values for pixels corresponding to the tissue.

        Runs :meth:`__calc_quant_vals__` once per entry in ``self.quantitative_values``.
        Requires mask to be set for this tissue.
        """
        for qv in self.quantitative_values:
            self.__calc_quant_vals__(qv.volumetric_map, qv.qv_type)

    @abstractmethod
    def __calc_quant_vals__(self, quant_map: MedicalVolume, map_type: QuantitativeValueType):
        """Helper method to get quantitative values for tissue - implemented per tissue.

        Different tissues should override this as they see fit. This base implementation
        only validates the inputs and reformats ``quant_map`` **in place** to the
        orientation of the tissue mask.

        Args:
            quant_map (MedicalVolume): 3D map of pixel-wise quantitative measures
                (T2, T2*, T1-rho, etc.). Volume should have ``np.nan`` values for
                all pixels unable to be calculated.
            map_type (QuantitativeValueType): Type of quantitative value to analyze.

        Raises:
            TypeError: If `quant_map` is not of type `MedicalVolume` or `map_type` is not of type
                `QuantitativeValueType`.
            ValueError: If no mask is found for tissue.
        """
        if not isinstance(quant_map, MedicalVolume):
            raise TypeError("`Expected type 'MedicalVolume' for `quant_map`")
        if not isinstance(map_type, QuantitativeValueType):
            raise TypeError("`Expected type 'QuantitativeValueType' for `map_type`")

        if self.__mask__ is None:
            raise ValueError("Please initialize mask for {}".format(self.FULL_NAME))

        # Align the map with the mask so downstream voxel-wise operations line up.
        quant_map.reformat(self.__mask__.orientation, inplace=True)

    def __store_quant_vals__(
        self, quant_map: MedicalVolume, quant_df: pd.DataFrame, map_type: QuantitativeValueType
    ):
        """Adds quantitative value in `self.quant_vals`.

        The pair ``(quant_map, quant_df)`` is stored under the key ``map_type.name``,
        replacing any entry previously stored for the same type.

        Args:
            quant_map: Map data for this quantitative value. NOTE(review): the annotation
                says ``MedicalVolume`` while earlier documentation described a list of
                dictionaries of unrolled maps and plotting data (title, xlabel, etc.) -
                confirm against the subclasses that call this method.
            quant_df (pd.DataFrame): Computed data for this quantitative value.
            map_type (QuantitativeValueType): Type of quantitative value to analyze.
        """
        self.quant_vals[map_type.name] = (quant_map, quant_df)

    def find_weights(self, weights_dir: str):
        """Search for weights file in weights directory.

        A file matches when it is a regular file, its name ends in ``WEIGHTS_FILE_EXT``
        and its name contains ``self.STR_ID`` (e.g. 'fc_weights.h5'). The match is also
        stored in ``self.weights_file_path``.

        Args:
            weights_dir (str): Directory where weights are stored.

        Returns:
            str: File path to weights corresponding to tissue.

        Raises:
            ValueError: If multiple weights files exists for the tissue
                or no valid weights file found.
        """
        weights_file = None
        for fname in os.listdir(weights_dir):
            # Cheap string checks first so the filesystem is only queried for candidates.
            if not (fname.endswith(WEIGHTS_FILE_EXT) and self.STR_ID in fname):
                continue
            fpath = os.path.join(weights_dir, fname)
            if not os.path.isfile(fpath):
                continue
            if weights_file is not None:
                raise ValueError("There are multiple weights files, please remove duplicates")
            weights_file = fpath

        if weights_file is None:
            raise ValueError(
                "No file found that contains '{}' and ends in '{}'".format(
                    self.STR_ID, WEIGHTS_FILE_EXT
                )
            )

        self.weights_file_path = weights_file

        return weights_file

    def save_data(self, save_dirpath: str, data_format: ImageDataFormat = None):
        """Save data for tissue.

        Saves mask and quantitative values associated with this tissue.

        Override in subclasses to save additional data. When overriding in subclasses, call
        ``super().save_data(save_dirpath)`` first to save mask and quantitative values by default.
        See :mod:`dosma.tissues.femoral_cartilage` for details.

        .. literalinclude:: femoral_cartilage.py

        Args:
            save_dirpath (str): Directory path where all data is stored.
            data_format (`ImageDataFormat`, optional): Format to save data. Defaults to
                ``preferences.image_data_format``, looked up when the method is called.
        """
        # Resolve the default at call time; a default bound in the signature is
        # evaluated once at definition time and would ignore later preference changes.
        if data_format is None:
            data_format = preferences.image_data_format

        save_dirpath = self.__save_dirpath__(save_dirpath)

        if self.__mask__ is not None:
            mask_file_path = os.path.join(save_dirpath, "{}.nii.gz".format(self.STR_ID))
            mask_file_path = fio_utils.convert_image_data_format(mask_file_path, data_format)
            self.__mask__.save_volume(mask_file_path, data_format=data_format)

        for qv in self.quantitative_values:
            qv.save_data(save_dirpath, data_format)

        self.__save_quant_data__(save_dirpath)

    @abstractmethod
    def __save_quant_data__(self, dirpath: str):
        """Save quantitative data generated for this tissue.

        Called by `save_data`.

        Args:
            dirpath (str): Directory path to tissue data.
        """
        pass

    def save_quant_data(self, dirpath: str):
        """Save quantitative data generated for this tissue.

        Does not save mask or quantitative parameter map. Unlike :meth:`save_data`,
        ``dirpath`` is passed to :meth:`__save_quant_data__` as-is.

        Args:
            dirpath (str): Directory path to tissue data.
        """
        return self.__save_quant_data__(dirpath)

    def load_data(self, load_dir_path: str):
        """Load data for tissue.

        All tissue information is based on the mask. If mask for tissue doesn't exist,
        there is no information to load.

        Args:
            load_dir_path (str): Directory path where all data is stored.
        """
        # NOTE: `__save_dirpath__` is documented to create the tissue subdirectory
        # if it is missing, so loading may create an empty directory.
        load_dir_path = self.__save_dirpath__(load_dir_path)
        mask_file_path = os.path.join(load_dir_path, "{}.nii.gz".format(self.STR_ID))

        # Try to load mask, if file exists. A missing mask is not an error.
        try:
            msk = fio_utils.generic_load(mask_file_path, expected_num_volumes=1)
            self.set_mask(msk)
        except FileNotFoundError:
            pass

        self.quantitative_values = QuantitativeValue.load_qvs(load_dir_path)

    def __save_dirpath__(self, dirpath):
        """Tissue-specific subdirectory to store data.

        Subdirectory will have path '`dirpath`/`self.STR_ID`/'.

        If directory does not exist, it will be created.

        Args:
            dirpath (str): Directory path where all data is stored.

        Returns:
            str: Tissue-specific data directory.
        """
        return io_utils.mkdirs(os.path.join(dirpath, self.STR_ID))

    # TODO (arjundd): Refactor get/set methods of mask to property.
    def set_mask(self, mask: MedicalVolume):
        """Set mask for tissue.

        The mask is stored reformatted to the sagittal orientation.

        Args:
            mask (MedicalVolume): Binary mask of segmented tissue.

        Raises:
            TypeError: If ``mask`` is not a ``MedicalVolume``.
        """
        # Explicit check instead of `assert`: assertions are stripped under `python -O`,
        # and `isinstance` also accepts MedicalVolume subclasses.
        if not isinstance(mask, MedicalVolume):
            raise TypeError("mask for tissue must be of type MedicalVolume")
        self.__mask__ = mask.reformat(SAGITTAL)

    def get_mask(self):
        """Get mask for tissue.

        Returns:
            MedicalVolume: Binary mask of segmented tissue, or ``None`` if no mask
                has been set.
        """
        return self.__mask__

    def add_quantitative_value(self, qv_new: QuantitativeValue):
        """Add quantitative value to the tissue.

        Args:
            qv_new (QuantitativeValue): Quantitative value to add to tissue.
        """
        # NOTE: duplicates are not rejected; a value of the same type as an existing
        # one is appended as an additional entry.
        self.quantitative_values.append(qv_new)

    def __get_axis_bounds__(
        self, im: np.ndarray, ignore_nan: bool = True, leave_buffer: bool = False
    ):
        """Get tightest per-axis index bounds of the non-zero data in the array.

        When plotting data, we would like to avoid showing large empty margins.
        To avoid this, we make our bounds as tight as possible.

        Bounds are index ranges, not value ranges: for each axis, the smallest and
        largest index at which a non-zero element occurs. For example, an array whose
        only non-zero entries are at ``[1, 2]`` and ``[3, 0]`` yields
        ``[(1, 3), (0, 2)]``.

        Args:
            im (np.ndarray): Array containing information for which bounds have to be computed.
            ignore_nan (obj:`bool`, optional): Treat `nan` values as zero (i.e. as
                background) when computing the bounds.
            leave_buffer (obj:`bool`, optional): Widen each bound by 5 indices on both
                sides. The result is not clipped to the array extent, so the lower
                bound may be negative and the upper bound may exceed the axis length.

        Returns:
            list[tuple]: One ``(min_index, max_index)`` pair per array axis.

        Raises:
            ValueError: If the array has no non-zero elements.
        """
        data = np.nan_to_num(im) if ignore_nan else im

        bounds = []
        for axis_indices in np.nonzero(data):
            if axis_indices.size == 0:
                raise ValueError("Cannot compute axis bounds: array has no non-zero elements")

            lower, upper = axis_indices.min(), axis_indices.max()
            if leave_buffer:
                lower -= 5
                upper += 5

            bounds.append((lower, upper))

        return bounds
316
317
318
def largest_cc(mask, num=1):
    """Return the largest `num` connected component(s) of a 3D mask array.

    Components are labelled with ``scipy.ndimage.label`` using its default
    structuring element; zero-valued voxels are background.

    Args:
        mask (np.ndarray): 3D mask array (boolean or integer dtype).
        num (int, optional): Maximum number of connected components to keep.

    Returns:
        mask (np.ndarray): Boolean 3D mask array with at most `num` connected components.

    Raises:
        ValueError: If `mask` has no non-zero values.

    Note:
        Adapted from nipy (https://github.com/nipy/nipy/blob/master/nipy/labs/mask.py)
        due to dependency issues.
    """
    # We use asarray to be able to work with masked arrays.
    mask = np.asarray(mask)
    labels, label_nb = sni.label(mask)
    if not label_nb:
        raise ValueError("No non-zero values: no connected components")
    if label_nb == 1:
        # Builtin `bool`/`int` replace the `np.bool`/`np.int` aliases,
        # which were removed in NumPy 1.24.
        return mask.astype(bool)

    label_count = np.bincount(labels.ravel().astype(int))
    # Label 0 is the background; zero its count so it can never be selected.
    label_count[0] = 0

    # Split num=1 case for speed.
    if num == 1:
        return labels == label_count.argmax()

    # With its count zeroed, label 0 sorts first in ascending order; drop it and
    # reverse so the remaining labels run from largest to smallest component.
    order = np.argsort(label_count)[1:][::-1]
    return np.isin(labels, order[:num])