a b/dosma/tissues/femoral_cartilage.py
1
import os
2
import warnings
3
4
import numpy as np
5
import pandas as pd
6
import scipy.ndimage as sni
7
8
from dosma.core.device import get_array_module
9
from dosma.core.io.format_io import ImageDataFormat
10
from dosma.core.med_volume import MedicalVolume
11
from dosma.core.quant_vals import QuantitativeValueType
12
from dosma.defaults import preferences
13
from dosma.tissues.tissue import Tissue, largest_cc
14
from dosma.utils import img_utils, io_utils
15
from dosma.utils.geometry_utils import cart2pol, circle_fit
16
17
import matplotlib.pyplot as plt
18
19
# Upper limit, in milliseconds, of the colour scale used when plotting unrolled
# maps of each quantitative value type.
BOUNDS = dict(
    (
        (QuantitativeValueType.T2, 80.0),
        (QuantitativeValueType.T1_RHO, 100.0),
        (QuantitativeValueType.T2_STAR, 80.0),
    )
)

__all__ = ["FemoralCartilage"]
27
28
29
class FemoralCartilage(Tissue):
    """Handles analysis and visualization for femoral cartilage.

    This class extends functionality from `Tissue`.

    For visualization, the femoral cartilage is unrolled onto a 2D plane using angular
    binning [1].

    References:
        [1] Monu UD, Jordan CD, Samuelson BL, Hargreaves BA, Gold GE, McWalter EJ.
        "Cluster analysis of quantitative MRI T2 and :math:`T1\\rho` relaxation times of
        cartilage identifies differences between healthy and ACL-injured individuals at 3T."
        Osteoarthritis and cartilage 2017;25(4):513-520.
    """

    ID = 1
    STR_ID = "fc"
    FULL_NAME = "femoral cartilage"

    # Expected quantitative values
    T1_EXPECTED = 1200  # milliseconds

    # Every region key is a distinct bit, so a voxel's region label is the bitwise OR
    # (equivalently the sum) of its keys.
    # Bit layout, most- to least-significant: 'T S D L M A C P'
    # (T=total, S=superficial, D=deep, L=lateral, M=medial,
    #  A=anterior, C=central, P=posterior).

    # Coronal keys
    _POSTERIOR_KEY = 2 ** 0
    _CENTRAL_KEY = 2 ** 1
    _ANTERIOR_KEY = 2 ** 2
    _CORONAL_KEYS = [_POSTERIOR_KEY, _CENTRAL_KEY, _ANTERIOR_KEY]

    # Sagittal keys
    _MEDIAL_KEY = 2 ** 3
    _LATERAL_KEY = 2 ** 4
    _SAGITTAL_KEYS = [_MEDIAL_KEY, _LATERAL_KEY]

    # Axial keys
    _DEEP_KEY = 2 ** 5
    _SUPERFICIAL_KEY = 2 ** 6
    _TOTAL_AXIAL_KEY = 2 ** 7
    _AXIAL_KEYS = [_DEEP_KEY, _SUPERFICIAL_KEY, _TOTAL_AXIAL_KEY]

    # Do not change order of below.
    # Order reflects order of _CORONAL_KEYS, _SAGITTAL_KEYS, _AXIAL_KEYS
    _AXIAL_NAMES = ["deep", "superficial", "total"]
    _SAGITTAL_NAMES = ["medial", "lateral"]
    _CORONAL_NAMES = ["posterior", "central", "anterior"]

    # Angles (degrees) separating anterior/central and central/posterior regions.
    # The central region is the 30 degrees in [-105, -75).
    _ANTERIOR_CENTRAL_THETA = -105
    _CENTRAL_POSTERIOR_THETA = -75

    # Slice index of the medial/lateral split, and angular-bin indices of the
    # anterior/central and central/posterior splits (see ``split_regions``).
    ML_BOUNDARY = None
    ACP_BOUNDARY = None

    def __init__(self, weights_dir=None, medial_to_lateral=None):
        """
        Args:
            weights_dir (str, optional): Forwarded to ``Tissue.__init__``.
            medial_to_lateral (bool, optional): If truthy, slices before the
                medial/lateral boundary are labeled medial; otherwise they are
                labeled lateral.
        """
        super().__init__(weights_dir=weights_dir)

        self.regions_mask = None
        self.theta_bins = None

        self.medial_to_lateral = medial_to_lateral

    def split_regions(
        self, base_map: np.ndarray, thickness_divisor=0.5, num_bins=72, theta=(-270, 90)
    ):
        """Split the femoral cartilage volume into anatomical regions.

        Voxels of the femoral cartilage are divided across 3 planes:
            - Coronal: Posterior, Central, or Anterior
            - Sagittal: Medial, Lateral
            - Axial: Deep, Superficial (every voxel is additionally labeled Total)

        For example, a voxel could correspond to the Posterior Lateral Deep region of
        femoral cartilage.

        Only voxels where ``mask * base_map`` is positive are used to fit the circle and
        to compute the deep/superficial thresholds.

        Args:
            base_map (np.ndarray): 3D array multiplied with the tissue mask
                (typically the volume being analyzed). NaNs are treated as 0.
            thickness_divisor (float): Fraction of the radial thickness, measured from
                the inner radius, at which deep and superficial layers are separated.
            num_bins (int): Number of angular bins spanning 360 degrees.
            theta (tuple): ``(min, max)`` half-open range of angles in degrees.
                Computed angles are wrapped into ``[-270, 90)``.

        Returns:
            tuple:
                - np.ndarray: ``uint16`` volume (mask shape) holding the bitwise OR of
                  each voxel's region keys.
                - np.ndarray: 2D ``(height, width)`` array of angular bin indices.
                - int: Slice index of the medial/lateral boundary.
                - list[int]: Angular bin indices of the anterior/central and
                  central/posterior boundaries.
        """
        dtheta = 360 / num_bins
        theta_min, theta_max = tuple(theta)

        mask = self.__mask__.volume * np.nan_to_num(base_map)
        height, width, num_slices = mask.shape

        # STEP 1: project along the slice axis and fit a circle to the projected tissue.
        projected = np.max(mask, axis=2)
        rows, cols = np.nonzero(projected)
        xc_fit, yc_fit, _ = circle_fit(cols, rows)

        # STEP 2: polar coordinates of every in-plane pixel about the fitted center.
        yv, xv = np.meshgrid(range(height), range(width), indexing="ij")
        rho, angles = cart2pol(xv - xc_fit, yc_fit - yv)
        angles = np.where(angles >= 90, angles - 360, angles)  # range: [-270, 90)

        assert np.min(angles) >= theta_min and np.max(angles) < theta_max, (
            "Expected theta range is [{}, {}) degrees. Received min: {} max: {}".format(
                theta_min, theta_max, np.min(angles), np.max(angles)
            )
        )

        theta_bins = np.floor((angles - theta_min) / dtheta)

        # STEP 3: radial threshold per (bin, slice) separating deep from superficial.
        # The bin mask is computed once per bin instead of once per (bin, slice) pair.
        rhos_threshold_volume = np.zeros(mask.shape)
        for curr_bin in range(num_bins):
            in_bin = theta_bins == curr_bin
            rho_in_bin = rho[in_bin]  # (n_pixels,)
            mask_in_bin = mask[in_bin]  # (n_pixels, num_slices)
            for curr_slice in range(num_slices):
                rhos_valid = rho_in_bin[mask_in_bin[:, curr_slice] > 0]
                if rhos_valid.size == 0:
                    continue
                rho_min = np.min(rhos_valid)
                rho_max = np.max(rhos_valid)
                rhos_threshold_volume[in_bin, curr_slice] = (
                    thickness_divisor * (rho_max - rho_min) + rho_min
                )

        regions_volume = np.zeros(mask.shape, dtype=np.uint16)

        # Anterior/central/posterior division (identical for every slice): angles below
        # the anterior/central angle are anterior, at/above the central/posterior angle
        # are posterior, and the rest are central.
        acp_map = np.where(
            angles < self._ANTERIOR_CENTRAL_THETA,
            self._ANTERIOR_KEY,
            np.where(
                angles < self._CENTRAL_POSTERIOR_THETA,
                self._CENTRAL_KEY,
                self._POSTERIOR_KEY,
            ),
        )
        regions_volume += np.asarray(acp_map, dtype=np.uint16)[..., np.newaxis]

        # Medial/lateral division at the slice-axis center of mass of the weighted mask,
        # taking the scanning direction into account.
        ml_boundary = int(np.ceil(sni.center_of_mass(mask)[-1]))
        if self.medial_to_lateral:
            first_key, second_key = self._MEDIAL_KEY, self._LATERAL_KEY
        else:
            first_key, second_key = self._LATERAL_KEY, self._MEDIAL_KEY
        regions_volume[..., :ml_boundary] += first_key
        regions_volume[..., ml_boundary:] += second_key

        # Deep/superficial division. A voxel exactly at the threshold receives both keys;
        # every voxel receives the "total" key.
        rho_column = rho[..., np.newaxis]
        deep = rho_column <= rhos_threshold_volume
        superficial = rho_column >= rhos_threshold_volume
        regions_volume += np.asarray(
            deep * self._DEEP_KEY + superficial * self._SUPERFICIAL_KEY + self._TOTAL_AXIAL_KEY,
            dtype=np.uint16,
        )

        acp_boundary = [
            int(np.floor((self._ANTERIOR_CENTRAL_THETA - theta_min) / dtheta)),
            int(np.floor((self._CENTRAL_POSTERIOR_THETA - theta_min) / dtheta)),
        ]

        return regions_volume, theta_bins, ml_boundary, acp_boundary

    def unroll(self, qv_map: np.ndarray, regions_mask: np.ndarray, theta_bins):
        """Unroll a 3D femoral cartilage quantitative value (qv) map to 2D.

        ``qv_map`` is multiplied by the tissue mask; non-positive values (background, or
        qv <= 0 which is impractical) and NaNs are ignored. For every slice and angular
        bin the mean qv is computed over the whole bin, its superficial voxels, and its
        deep voxels.

        Args:
            qv_map (np.ndarray): 3D array (slices last) with the same shape as the tissue
                mask, describing quantitative parameter values.
            regions_mask (np.ndarray): 3D region-label volume from :meth:`split_regions`.
            theta_bins (np.ndarray): 2D ``(height, width)`` angular bin indices from
                :meth:`split_regions`.

        Returns:
            tuple: ``(total, superficial, deep)``, each of shape
            ``(num_bins, num_slices)`` (rows = angular bins, columns = slices), where
            ``num_bins`` is the largest bin index + 1. ``total`` is the mean over the
            whole bin. Entries without valid values are ``np.nan``.

        Raises:
            ValueError: If ``qv_map`` does not match the mask shape or is not 3D.
        """
        mask = self.__mask__.volume

        if qv_map.shape != mask.shape:
            raise ValueError("qv_map and mask must have same shape")

        if len(qv_map.shape) != 3:
            raise ValueError("qv_map and mask must be 3D")

        # Bin indices are used directly as row indices, so size by the largest index
        # rather than by the number of distinct bins present.
        num_bins = int(np.max(theta_bins)) + 1
        num_slices = qv_map.shape[-1]

        qv_map = np.multiply(mask, np.nan_to_num(qv_map))  # apply binary mask
        if not np.issubdtype(qv_map.dtype, np.floating):
            # NaN cannot be stored in integer arrays.
            qv_map = qv_map.astype(np.float64)
        # Wherever qv_map is <= 0: either no cartilage or an impractical qv value.
        qv_map[qv_map <= 0] = np.nan

        unrolled_total = np.zeros((num_bins, num_slices))
        unrolled_sup = np.zeros((num_bins, num_slices))
        unrolled_deep = np.zeros((num_bins, num_slices))

        bin_masks = [theta_bins == curr_bin for curr_bin in range(num_bins)]

        for slice_ind in range(num_slices):
            qv_slice = qv_map[..., slice_ind]

            # If the slice is all NaNs, there is nothing to analyze.
            if np.all(np.isnan(qv_slice)):
                continue

            region_slice = regions_mask[..., slice_ind]
            superficial_px = self.__binarize_region_mask__(region_slice, self._SUPERFICIAL_KEY)
            deep_px = self.__binarize_region_mask__(region_slice, self._DEEP_KEY)

            for curr_bin, in_bin in enumerate(bin_masks):
                qv_bin = qv_slice[in_bin]
                if np.all(np.isnan(qv_bin)):
                    continue

                unrolled_total[curr_bin, slice_ind] = np.nanmean(qv_bin)

                qv_superficial = np.nan_to_num(qv_slice[in_bin & superficial_px])
                qv_deep = np.nan_to_num(qv_slice[in_bin & deep_px])

                # Empty layers keep 0 and become NaN below (avoids empty-mean warnings).
                sup_vals = qv_superficial[qv_superficial > 0]
                if sup_vals.size:
                    unrolled_sup[curr_bin, slice_ind] = np.mean(sup_vals)
                deep_vals = qv_deep[qv_deep > 0]
                if deep_vals.size:
                    unrolled_deep[curr_bin, slice_ind] = np.mean(deep_vals)

        for layer in (unrolled_total, unrolled_sup, unrolled_deep):
            layer[layer == 0] = np.nan

        return unrolled_total, unrolled_sup, unrolled_deep

    def __calc_quant_vals__(self, quant_map: MedicalVolume, map_type):
        """Calculate quantitative values per region and 2D visualizations.

        1. Save 2D figure (deep, superficial, total) information to use with matplotlib
            (title, data, xlabel, ylabel, filename, raw_data_filename).

        2. Save a dataframe with one row per region:
                [['DMA', 'DMC', 'DMP'], ['DLA', 'DLC', 'DLP'],
                 ['SMA', 'SMC', 'SMP'], ['SLA', 'SLC', 'SLP'],
                 ['TMA', 'TMC', 'TMP'], ['TLA', 'TLC', 'TLP']]

                 D=deep, S=superficial, T=total,
                 M=medial, L=lateral,
                 A=anterior, C=central, P=posterior

        Args:
            quant_map (MedicalVolume): 3D volumes of quantitative values.
                Volume should have ``np.nan`` values for all pixels unable to be calculated.
            map_type (QuantitativeValueType): Type of quantitative value to analyze.
        """
        super().__calc_quant_vals__(quant_map, map_type)

        # Regions are recomputed for every quantitative map because voxels where the
        # map is 0/NaN are excluded from the circle fit and layer thresholds.
        regions_mask, theta_bins, ml_boundary, acp_boundary = self.split_regions(quant_map.volume)
        if self.ML_BOUNDARY is None:
            self.ML_BOUNDARY = ml_boundary
        if self.ACP_BOUNDARY is None:
            self.ACP_BOUNDARY = acp_boundary

        total, superficial, deep = self.unroll(quant_map.volume, regions_mask, theta_bins)

        assert total.shape == deep.shape
        assert deep.shape == superficial.shape

        # Loop-invariant: quantitative values restricted to the tissue mask.
        masked_qv = self.__mask__.volume * quant_map.volume

        subject_pid = self.pid
        pd_header = ["Subject", "Location", "Side", "Region", "Mean", "Std", "Median", "# Voxels"]
        pd_list = []

        for axial, axial_name in zip(self._AXIAL_KEYS, self._AXIAL_NAMES):
            for sagittal, sagittal_name in zip(self._SAGITTAL_KEYS, self._SAGITTAL_NAMES):
                for coronal, coronal_name in zip(self._CORONAL_KEYS, self._CORONAL_NAMES):
                    region = self.__binarize_region_mask__(
                        regions_mask, (axial | coronal | sagittal)
                    )
                    region_qv = region * masked_qv

                    # Discard all values that are <= 0 (and NaNs, which compare False).
                    qv_region_vals = region_qv[region_qv > 0]

                    pd_list.append(
                        [
                            subject_pid,
                            axial_name,
                            sagittal_name,
                            coronal_name,
                            np.nanmean(qv_region_vals),
                            np.nanstd(qv_region_vals),
                            np.nanmedian(qv_region_vals),
                            len(qv_region_vals),
                        ]
                    )

        df = pd.DataFrame(pd_list, columns=pd_header)
        qv_name = map_type.name
        maps = [
            {
                "title": "{} {}".format(qv_name, layer_name),
                "data": layer_data,
                "xlabel": "Slice",
                "ylabel": "Angle (binned)",
                "filename": "{}_{}".format(qv_name, layer_name),
                "raw_data_filename": "{}_{}.data".format(qv_name, layer_name),
            }
            for layer_name, layer_data in (
                ("deep", deep),
                ("superficial", superficial),
                ("total", total),
            )
        ]

        self.__store_quant_vals__(maps, df, map_type)

    def set_mask(
        self, mask: MedicalVolume, use_largest_cc: bool = True, split_regions: bool = True
    ):
        """Set mask for tissue.

        Femoral cartilage is expected to be a single connected tissue, so by default the
        mask is cleaned by keeping only its largest connected component.

        Args:
            mask (MedicalVolume): Binary mask of segmented tissue.
            use_largest_cc (bool): If ``True``, keep only the largest connected component.
            split_regions (bool): If ``True``, compute and store ``regions_mask``,
                ``theta_bins``, ``ML_BOUNDARY`` and ``ACP_BOUNDARY`` from the mask.
        """
        xp = get_array_module(mask.A)
        if use_largest_cc:
            msk = xp.asarray(largest_cc(mask.A), dtype=xp.uint8)
        else:
            msk = xp.asarray(mask.A, dtype=xp.uint8)
        mask_copy = mask._partial_clone(volume=msk)

        super().set_mask(mask_copy)

        if split_regions:
            (
                self.regions_mask,
                self.theta_bins,
                self.ML_BOUNDARY,
                self.ACP_BOUNDARY,
            ) = self.split_regions(self.__mask__.volume)

    def __save_quant_data__(self, dirpath: str):
        """Save quantitative data and 2D visualizations of femoral cartilage.

        For each quantitative value type (T2, T1rho, etc.) computed for this tissue:

            1. Save 2D total, superficial, and deep visualization maps (and raw data).
            2. Save {'medial', 'lateral'}, {'anterior', 'central', 'posterior'},
               {'deep', 'superficial', 'total'} data to an excel file.

        Args:
            dirpath (str): Directory path to tissue data.
        """
        q_names = []
        dfs = []

        for quant_val in QuantitativeValueType:
            if quant_val.name not in self.quant_vals.keys():
                continue

            q_names.append(quant_val.name)
            q_val = self.quant_vals[quant_val.name]
            dfs.append(q_val[1])

            q_name_dirpath = io_utils.mkdirs(os.path.join(dirpath, quant_val.name.lower()))
            upper_bound = BOUNDS.get(quant_val)

            for q_map_data in q_val[0]:
                filepath = os.path.join(q_name_dirpath, q_map_data["filename"])
                data_map = q_map_data["data"]

                plt.clf()

                if upper_bound is None:
                    # No display bound defined for this type: autoscale.
                    plt.imshow(data_map, cmap="jet")
                elif preferences.visualization_use_vmax:
                    # Hard bounds - clipping
                    plt.imshow(data_map, cmap="jet", vmin=0.0, vmax=upper_bound)
                elif not np.any(data_map > upper_bound):
                    # Soft bounds: every value fits, so use the fixed scale.
                    plt.imshow(data_map, cmap="jet", vmin=0.0, vmax=upper_bound)
                else:
                    warnings.warn(
                        "%s: Pixel value exceeded upper bound (%0.1f). Using normalized scale."
                        % (quant_val.name, upper_bound)
                    )
                    plt.imshow(data_map, cmap="jet")

                plt.xlabel(q_map_data.get("xlabel", "Slice"))
                plt.ylabel(q_map_data.get("ylabel", "Angle (binned)"))
                plt.title(q_map_data["title"])
                clb = plt.colorbar()
                clb.ax.set_title("(ms)")

                plt.savefig(filepath)

                # Save data
                raw_data_filepath = os.path.join(
                    q_name_dirpath, "raw_data", q_map_data["raw_data_filename"]
                )
                io_utils.save_pik(raw_data_filepath, data_map)

        if len(dfs) > 0:
            io_utils.save_tables(os.path.join(dirpath, "data.xlsx"), dfs, q_names)

    def save_data(self, save_dirpath, data_format: ImageDataFormat = preferences.image_data_format):
        """Save tissue data and, when regions are available, an unrolled region map.

        Args:
            save_dirpath (str): Base directory forwarded to ``Tissue.save_data``.
            data_format (ImageDataFormat): Format forwarded to ``Tissue.save_data``.
        """
        super().save_data(save_dirpath, data_format=data_format)

        save_dirpath = self.__save_dirpath__(save_dirpath)

        if self.regions_mask is None:
            return

        acp_unrolled, ml_unrolled = self.__split_mask__()

        # Encode each (sagittal, coronal) pair as one integer; 1 is added because no key
        # can be 0. Tens come from the medial/lateral key (medial=8 -> 90,
        # lateral=16 -> 170), units from the coronal key (posterior=1 -> 2,
        # central=2 -> 3, anterior=4 -> 5). Ascending codes 92, 93, 95, 172, 173, 175
        # follow the order of ``labels`` below.
        # NOTE(review): assumes img_utils.write_regions pairs labels with the sorted
        # unique codes -- confirm.
        joined_mask = (ml_unrolled + 1) * 10 + (acp_unrolled + 1)
        labels = [
            "medial posterior",
            "medial central",
            "medial anterior",
            "lateral posterior",
            "lateral central",
            "lateral anterior",
        ]
        plt_dict = {
            "labels": labels,
            "xlabel": "Slice",
            "ylabel": "Angle (binned)",
            "title": "Unrolled Regions",
        }
        img_utils.write_regions(
            os.path.join(save_dirpath, "region_map"), joined_mask, plt_dict=plt_dict
        )

    def __binarize_region_mask__(self, region_mask, roi):
        """Return a boolean mask of voxels whose region label contains every bit of ``roi``."""
        return np.asarray(np.bitwise_and(region_mask, roi) == roi, dtype=bool)

    def __split_mask__(self):
        """Return unrolled ``(anterior/central/posterior, medial/lateral)`` key maps.

        Both maps have the shape of the unrolled mask (rows = angular bins,
        columns = slices) and are ``np.nan`` wherever the unrolled mask has no tissue.
        """
        assert (
            self.ML_BOUNDARY is not None and self.ACP_BOUNDARY is not None
        ), "medial/lateral and anterior/central/posterior boundaries should be specified"

        # Unroll the mask itself to find which (bin, slice) cells contain tissue.
        unrolled_total, _, _ = self.unroll(
            np.asarray(self.__mask__.volume, dtype=np.float32), self.regions_mask, self.theta_bins
        )

        acp_division_unrolled = np.zeros(unrolled_total.shape)

        ac_threshold = self.ACP_BOUNDARY[0]
        cp_threshold = self.ACP_BOUNDARY[1]
        acp_division_unrolled[:ac_threshold, :] = self._ANTERIOR_KEY
        acp_division_unrolled[ac_threshold:cp_threshold, :] = self._CENTRAL_KEY
        acp_division_unrolled[cp_threshold:, :] = self._POSTERIOR_KEY

        ml_division_unrolled = np.zeros(unrolled_total.shape)
        if self.medial_to_lateral:
            ml_division_unrolled[..., : self.ML_BOUNDARY] = self._MEDIAL_KEY
            ml_division_unrolled[..., self.ML_BOUNDARY :] = self._LATERAL_KEY
        else:
            ml_division_unrolled[..., : self.ML_BOUNDARY] = self._LATERAL_KEY
            ml_division_unrolled[..., self.ML_BOUNDARY :] = self._MEDIAL_KEY

        no_tissue = np.isnan(unrolled_total)
        acp_division_unrolled[no_tissue] = np.nan
        ml_division_unrolled[no_tissue] = np.nan

        return acp_division_unrolled, ml_division_unrolled