Diff of /dosma/tissues/meniscus.py [000000] .. [030aeb]

Switch to unified view

a b/dosma/tissues/meniscus.py
1
"""Analysis for meniscus.
2
3
Attributes:
4
    BOUNDS (dict): Upper bounds for quantitative values.
5
"""
6
7
import itertools
8
import os
9
import warnings
10
11
import numpy as np
12
import pandas as pd
13
import scipy.ndimage as sni
14
15
from dosma.core.device import get_array_module
16
from dosma.core.med_volume import MedicalVolume
17
from dosma.core.quant_vals import T2, QuantitativeValueType
18
from dosma.defaults import preferences
19
from dosma.tissues.tissue import Tissue, largest_cc
20
from dosma.utils import io_utils
21
22
import matplotlib.pyplot as plt
23
24
# milliseconds
25
BOUNDS = {
26
    QuantitativeValueType.T2: 60.0,
27
    QuantitativeValueType.T1_RHO: 100.0,
28
    QuantitativeValueType.T2_STAR: 50.0,
29
}
30
31
__all__ = ["Meniscus"]
32
33
34
class Meniscus(Tissue):
35
    """Handles analysis and visualization for meniscus.
36
37
    This class extends functionality from `Tissue`.
38
39
    For visualization, the meniscus is unrolled across the axial plane.
40
    """
41
42
    ID = 2
43
    STR_ID = "men"
44
    FULL_NAME = "meniscus"
45
46
    # Expected quantitative values
47
    T1_EXPECTED = 1000  # milliseconds
48
49
    # Coronal Keys
50
    _ANTERIOR_KEY = 0
51
    _POSTERIOR_KEY = 1
52
    _CORONAL_KEYS = [_ANTERIOR_KEY, _POSTERIOR_KEY]
53
54
    # Saggital Keys
55
    _MEDIAL_KEY = 0
56
    _LATERAL_KEY = 1
57
    _SAGGITAL_KEYS = [_MEDIAL_KEY, _LATERAL_KEY]
58
59
    # Axial Keys
60
    _SUPERIOR_KEY = 0
61
    _INFERIOR_KEY = 1
62
    _TOTAL_AXIAL_KEY = -1
63
64
    def __init__(
65
        self, weights_dir: str = None, medial_to_lateral: bool = None, split_ml_only: bool = False
66
    ):
67
        super().__init__(weights_dir=weights_dir, medial_to_lateral=medial_to_lateral)
68
69
        self.split_ml_only = split_ml_only
70
        self.regions_mask = None
71
72
    def unroll_axial(self, quant_map: np.ndarray):
73
        """Unroll meniscus in axial direction.
74
75
        Args:
76
            quant_map (np.ndarray): Map to roll out.
77
78
        """
79
        mask = self.__mask__.volume
80
81
        assert (
82
            self.regions_mask is not None
83
        ), "region_mask not initialized. Should be initialized when mask is set"
84
        region_mask_sup_inf = self.regions_mask[..., 0]
85
86
        superior = (region_mask_sup_inf == self._SUPERIOR_KEY) * mask * quant_map
87
        superior[superior == 0] = np.nan
88
        superior = np.nanmean(superior, axis=0)
89
90
        inferior = (region_mask_sup_inf == self._INFERIOR_KEY) * mask * quant_map
91
        inferior[inferior == 0] = np.nan
92
        inferior = np.nanmean(inferior, axis=0)
93
94
        total = mask * quant_map
95
        total[total == 0] = np.nan
96
        total = np.nanmean(total, axis=0)
97
98
        return total, superior, inferior
99
100
    def split_regions(self, base_map):
101
        """Split meniscus into subregions.
102
103
        Center-of-mass (COM) is used to subdivide into
104
        anterior/posterior, superior/inferior, and medial/lateral regions.
105
106
        Note:
107
            The anterior/posterior and superior/inferior subdivision may causes issues
108
            with tilted mensici. This will be addressed in a later release. To avoid
109
            computing metrics on these regions, set ``self.split_ml_only=True``.
110
        """
111
        center_of_mass = sni.measurements.center_of_mass(base_map)  # zero indexed
112
113
        com_sup_inf = int(np.ceil(center_of_mass[0]))
114
        com_ant_post = int(np.ceil(center_of_mass[1]))
115
        com_med_lat = int(np.ceil(center_of_mass[2]))
116
117
        region_mask_sup_inf = np.zeros(base_map.shape)
118
        region_mask_sup_inf[:com_sup_inf, :, :] = self._SUPERIOR_KEY
119
        region_mask_sup_inf[com_sup_inf:, :, :] = self._INFERIOR_KEY
120
121
        region_mask_ant_post = np.zeros(base_map.shape)
122
        region_mask_ant_post[:, :com_ant_post, :] = self._ANTERIOR_KEY
123
        region_mask_ant_post[:, com_ant_post:, :] = self._POSTERIOR_KEY
124
125
        region_mask_med_lat = np.zeros(base_map.shape)
126
        region_mask_med_lat[:, :, :com_med_lat] = (
127
            self._MEDIAL_KEY if self.medial_to_lateral else self._LATERAL_KEY
128
        )
129
        region_mask_med_lat[:, :, com_med_lat:] = (
130
            self._LATERAL_KEY if self.medial_to_lateral else self._MEDIAL_KEY
131
        )
132
133
        self.regions_mask = np.stack(
134
            [region_mask_sup_inf, region_mask_ant_post, region_mask_med_lat], axis=-1
135
        )
136
137
    def __calc_quant_vals__(self, quant_map: MedicalVolume, map_type: QuantitativeValueType):
138
        subject_pid = self.pid
139
140
        # Reformats the quantitative map to the appropriate orientation.
141
        super().__calc_quant_vals__(quant_map, map_type)
142
143
        assert (
144
            self.regions_mask is not None
145
        ), "region_mask not initialized. Should be initialized when mask is set"
146
147
        region_mask = self.regions_mask
148
        axial_region_mask = self.regions_mask[..., 0]
149
        coronal_region_mask = self.regions_mask[..., 1]
150
        sagittal_region_mask = self.regions_mask[..., 2]
151
152
        # Combine region mask into categorical mask.
153
        axial_categories = [
154
            (self._SUPERIOR_KEY, "superior"),
155
            (self._INFERIOR_KEY, "inferior"),
156
            (-1, "total"),
157
        ]
158
        coronal_categories = [
159
            (self._ANTERIOR_KEY, "anterior"),
160
            (self._POSTERIOR_KEY, "posterior"),
161
            (-1, "total"),
162
        ]
163
        sagittal_categories = [(self._MEDIAL_KEY, "medial"), (self._LATERAL_KEY, "lateral")]
164
        if self.split_ml_only:
165
            axial_categories = [x for x in axial_categories if x[0] == -1]
166
            coronal_categories = [x for x in coronal_categories if x[0] == -1]
167
168
        categorical_mask = np.zeros(region_mask.shape[:-1])
169
        base_mask = self.__mask__.A.astype(np.bool)
170
        labels = {}
171
        for idx, (
172
            (axial, axial_name),
173
            (coronal, coronal_name),
174
            (sagittal, sagittal_name),
175
        ) in enumerate(
176
            itertools.product(axial_categories, coronal_categories, sagittal_categories)
177
        ):
178
            label = idx + 1
179
            axial_map = np.asarray([True]) if axial == -1 else axial_region_mask == axial
180
            coronal_map = np.asarray([True]) if coronal == -1 else coronal_region_mask == coronal
181
            sagittal_map = sagittal_region_mask == sagittal
182
            categorical_mask[base_mask & axial_map & coronal_map & sagittal_map] = label
183
            labels[label] = f"{axial_name}-{coronal_name}-{sagittal_name}"
184
185
        # TODO: Change this to be any arbitrary quantitative value type.
186
        # Note, it does not matter what we wrap it in because the underlying operations
187
        # are not specific to the value type.
188
        t2 = T2(quant_map)
189
        categorical_mask = MedicalVolume(categorical_mask, affine=quant_map.affine)
190
        df = t2.to_metrics(categorical_mask, labels=labels, bounds=(0, np.inf), closed="neither")
191
        df.insert(0, "Subject", subject_pid)
192
193
        total, superior, inferior = self.unroll_axial(quant_map.volume)
194
        qv_name = map_type.name
195
        maps = [
196
            {
197
                "title": "%s superior" % qv_name,
198
                "data": superior,
199
                "xlabel": "Slice",
200
                "ylabel": "Angle (binned)",
201
                "filename": "%s_superior" % qv_name,
202
                "raw_data_filename": "%s_superior.data" % qv_name,
203
            },
204
            {
205
                "title": "%s inferior" % qv_name,
206
                "data": inferior,
207
                "xlabel": "Slice",
208
                "ylabel": "Angle (binned)",
209
                "filename": "%s_inferior" % qv_name,
210
                "raw_data_filename": "%s_inferior.data" % qv_name,
211
            },
212
            {
213
                "title": "%s total" % qv_name,
214
                "data": total,
215
                "xlabel": "Slice",
216
                "ylabel": "Angle (binned)",
217
                "filename": "%s_total" % qv_name,
218
                "raw_data_filename": "%s_total.data" % qv_name,
219
            },
220
        ]
221
222
        self.__store_quant_vals__(maps, df, map_type)
223
224
    def __calc_quant_vals_old__(self, quant_map, map_type):
225
        subject_pid = self.pid
226
227
        super().__calc_quant_vals__(quant_map, map_type)
228
229
        assert (
230
            self.regions_mask is not None
231
        ), "region_mask not initialized. Should be initialized when mask is set"
232
233
        quant_map_volume = quant_map.volume
234
        mask = self.__mask__.volume
235
236
        quant_map_volume = mask * quant_map_volume
237
238
        axial_region_mask = self.regions_mask[..., 0]
239
        sagittal_region_mask = self.regions_mask[..., 1]
240
        coronal_region_mask = self.regions_mask[..., 2]
241
242
        axial_names = ["superior", "inferior", "total"]
243
        coronal_names = ["medial", "lateral"]
244
        sagittal_names = ["anterior", "posterior"]
245
246
        pd_header = ["Subject", "Location", "Side", "Region", "Mean", "Std", "Median"]
247
        pd_list = []
248
249
        for axial in [self._SUPERIOR_KEY, self._INFERIOR_KEY, self._TOTAL_AXIAL_KEY]:
250
            if axial == self._TOTAL_AXIAL_KEY:
251
                axial_map = np.asarray(
252
                    axial_region_mask == self._SUPERIOR_KEY, dtype=np.float32
253
                ) + np.asarray(axial_region_mask == self._INFERIOR_KEY, dtype=np.float32)
254
                axial_map = np.asarray(axial_map, dtype=np.bool)
255
            else:
256
                axial_map = axial_region_mask == axial
257
258
            for coronal in [self._MEDIAL_KEY, self._LATERAL_KEY]:
259
                for sagittal in [self._ANTERIOR_KEY, self._POSTERIOR_KEY]:
260
                    curr_region_mask = (
261
                        quant_map_volume
262
                        * (coronal_region_mask == coronal)
263
                        * (sagittal_region_mask == sagittal)
264
                        * axial_map
265
                    )
266
                    curr_region_mask[curr_region_mask == 0] = np.nan
267
                    # discard all values that are 0
268
                    c_mean = np.nanmean(curr_region_mask)
269
                    c_std = np.nanstd(curr_region_mask)
270
                    c_median = np.nanmedian(curr_region_mask)
271
272
                    row_info = [
273
                        subject_pid,
274
                        axial_names[axial],
275
                        coronal_names[coronal],
276
                        sagittal_names[sagittal],
277
                        c_mean,
278
                        c_std,
279
                        c_median,
280
                    ]
281
282
                    pd_list.append(row_info)
283
284
        # Generate 2D unrolled matrix
285
        total, superior, inferior = self.unroll_axial(quant_map.volume)
286
287
        df = pd.DataFrame(pd_list, columns=pd_header)
288
        qv_name = map_type.name
289
        maps = [
290
            {
291
                "title": "%s superior" % qv_name,
292
                "data": superior,
293
                "xlabel": "Slice",
294
                "ylabel": "Angle (binned)",
295
                "filename": "%s_superior" % qv_name,
296
                "raw_data_filename": "%s_superior.data" % qv_name,
297
            },
298
            {
299
                "title": "%s inferior" % qv_name,
300
                "data": inferior,
301
                "xlabel": "Slice",
302
                "ylabel": "Angle (binned)",
303
                "filename": "%s_inferior" % qv_name,
304
                "raw_data_filename": "%s_inferior.data" % qv_name,
305
            },
306
            {
307
                "title": "%s total" % qv_name,
308
                "data": total,
309
                "xlabel": "Slice",
310
                "ylabel": "Angle (binned)",
311
                "filename": "%s_total" % qv_name,
312
                "raw_data_filename": "%s_total.data" % qv_name,
313
            },
314
        ]
315
316
        self.__store_quant_vals__(maps, df, map_type)
317
318
    def set_mask(self, mask: MedicalVolume, use_largest_ccs: bool = False, ml_only: bool = False):
319
        xp = get_array_module(mask.A)
320
        if use_largest_ccs:
321
            msk = xp.asarray(largest_cc(mask.A, num=2), dtype=xp.uint8)
322
        else:
323
            msk = xp.asarray(mask.A, dtype=xp.uint8)
324
        mask_copy = mask._partial_clone(volume=msk)
325
        super().set_mask(mask_copy)
326
327
        self.split_regions(self.__mask__.volume)
328
329
    def __save_quant_data__(self, dirpath):
330
        """Save quantitative data and 2D visualizations of meniscus
331
332
        Check which quantitative values (T2, T1rho, etc) are defined for meniscus and analyze these
333
            1. Save 2D total, superficial, and deep visualization maps
334
            2. Save {'medial', 'lateral'}, {'anterior', 'posterior'},
335
                {'superior', 'inferior', 'total'} data to excel file.
336
337
        Args:
338
            dirpath (str): Directory path to tissue data.
339
        """
340
        q_names = []
341
        dfs = []
342
343
        for quant_val in QuantitativeValueType:
344
            if quant_val.name not in self.quant_vals.keys():
345
                continue
346
347
            q_names.append(quant_val.name)
348
            q_val = self.quant_vals[quant_val.name]
349
            dfs.append(q_val[1])
350
351
            q_name_dirpath = io_utils.mkdirs(os.path.join(dirpath, quant_val.name.lower()))
352
            for q_map_data in q_val[0]:
353
                filepath = os.path.join(q_name_dirpath, q_map_data["filename"])
354
                xlabel = "Slice"
355
                ylabel = ""
356
                title = q_map_data["title"]
357
                data_map = q_map_data["data"]
358
359
                plt.clf()
360
361
                upper_bound = BOUNDS[quant_val]
362
363
                if preferences.visualization_use_vmax:
364
                    # Hard bounds - clipping
365
                    plt.imshow(data_map, cmap="jet", vmin=0.0, vmax=BOUNDS[quant_val])
366
                else:
367
                    # Try to use a soft bounds
368
                    if np.sum(data_map <= upper_bound) == 0:
369
                        plt.imshow(data_map, cmap="jet", vmin=0.0, vmax=BOUNDS[quant_val])
370
                    else:
371
                        warnings.warn(
372
                            "%s: Pixel value exceeded upper bound (%0.1f). Using normalized scale."
373
                            % (quant_val.name, upper_bound)
374
                        )
375
                        plt.imshow(data_map, cmap="jet")
376
377
                plt.xlabel(xlabel)
378
                plt.ylabel(ylabel)
379
                plt.title(title)
380
                clb = plt.colorbar()
381
                clb.ax.set_title("(ms)")
382
                plt.axis("tight")
383
384
                plt.savefig(filepath)
385
386
                # Save data
387
                raw_data_filepath = os.path.join(
388
                    q_name_dirpath, "raw_data", q_map_data["raw_data_filename"]
389
                )
390
                io_utils.save_pik(raw_data_filepath, data_map)
391
392
        if len(dfs) > 0:
393
            io_utils.save_tables(os.path.join(dirpath, "data.xlsx"), dfs, q_names)