--- a +++ b/dosma/scan_sequences/mri/cones.py @@ -0,0 +1,228 @@ +"""Ultra-short Echo Time Cones (UTE-Cones).""" +import logging +import os +from typing import Sequence + +import numpy as np + +from dosma import file_constants as fc +from dosma.core import quant_vals as qv +from dosma.core.fitting import MonoExponentialFit +from dosma.core.io import format_io_utils as fio_utils +from dosma.core.io.nifti_io import NiftiReader +from dosma.core.med_volume import MedicalVolume +from dosma.core.registration import apply_warp, register +from dosma.scan_sequences.scans import NonTargetSequence +from dosma.tissues.tissue import Tissue +from dosma.utils.cmd_line_utils import ActionWrapper + +__all__ = ["Cones"] + +__EXPECTED_NUM_ECHO_TIMES__ = 4 + +__INITIAL_T2_STAR_VAL__ = 30.0 + +__T2_STAR_LOWER_BOUND__ = 0 +__T2_STAR_UPPER_BOUND__ = np.inf +__T2_STAR_DECIMAL_PRECISION__ = 3 + +_logger = logging.getLogger(__name__) + + +class Cones(NonTargetSequence): + """UTE-Cones MRI sequence. + + Ultra-short echo time cones (UTE-Cones) is a :math:`T_2^*`-weighted sequence. + In practice, many of these scans are low resolution and are ofter interregistered + with higher-resolution scans. This can be done with :meth:`Cones.interregister`. + + References: + Qian Y, Williams AA, Chu CR, Boada FE. Multicomponent T2* mapping of + knee cartilage: technical feasibility ex vivo. + Magnetic resonance in medicine 2010;64(5):1426-1431." + """ + + NAME = "cones" + + def __init__(self, volumes, echo_times: Sequence[float] = None): + super().__init__(volumes) + + if echo_times is None: + try: + if all(x.headers() is not None for x in self.volumes): + echo_times = [x.get_metadata("EchoTime", float) for x in self.volumes] + except (KeyError, AttributeError, RuntimeError) as e: + raise ValueError( + f"Could not extract echo times from header. " + f"Please specify `echo_times` argument - {e}" + ) + + self.echo_times = echo_times + + def interregister(self, target_path: str, target_mask_path: str = None): + volumes = self.volumes + echo_times = self.echo_times + idxs = np.argsort(echo_times) + + echo_times = [echo_times[i] for i in idxs] + volumes = [volumes[i] for i in idxs] + nr = NiftiReader() + out_path = os.path.join(self.temp_path, "interregistered") + os.makedirs(out_path, exist_ok=True) + + # TODO: Make these into parameters + num_threads = 2 + num_workers = 0 + verbose = True + + if verbose: # pragma: no cover + _logger.info("") + _logger.info("==" * 40) + _logger.info("Interregistering...") + _logger.info("Target: {}".format(target_path)) + if target_mask_path is not None: + _logger.info("Mask: {}".format(target_mask_path)) + _logger.info("==" * 40) + + # Target mask path has to be dilated. + if target_mask_path: + target_mask_path = self.__dilate_mask__(target_mask_path, out_path) + parameter_files = [ + fc.ELASTIX_RIGID_INTERREGISTER_PARAMS_FILE, + fc.ELASTIX_AFFINE_INTERREGISTER_PARAMS_FILE, + ] + use_mask = [False, True] + else: + parameter_files = [fc.ELASTIX_RIGID_PARAMS_FILE, fc.ELASTIX_AFFINE_PARAMS_FILE] + use_mask = None + + # Last echo should be the base. + base, moving = volumes[-1], volumes[:-1] + + out_reg, _ = register( + target_path, + base, + parameters=parameter_files, + output_path=out_path, + sequential=True, + collate=True, + num_workers=num_workers, + num_threads=num_threads, + return_volumes=False, + target_mask=target_mask_path, + use_mask=use_mask, + rtype=tuple, + show_pbar=verbose, + ) + out_reg = out_reg[0] + + reg_vols = [] + for mvg in moving: + reg_vols.append(apply_warp(mvg, out_reg.transform)) + reg_vols.append(nr.load(out_reg.warped_file)) # base volume is last + + # Undo sorting by echo time. + reverse_idxs = {v: i for i, v in enumerate(idxs)} + reg_vols = [reg_vols[reverse_idxs[k]] for k in sorted(reverse_idxs.keys())] + + self.volumes = reg_vols + + def generate_t2_star_map(self, tissue: Tissue, mask_path: str = None, num_workers: int = 0): + """ + Generate 3D :math:`T_2^* map and r-squared fit map using mono-exponential fit + across subvolumes acquired at different echo times. + + :math:`T_2^* map is also added to the tissue. + + Args: + tissue (Tissue): Tissue to generate quantitative value for. + mask_path (:obj:`str`, optional): File path to mask of ROI to analyze. + If specified, only voxels specified by mask will be fit. + This can considerably speed up computation. + num_workers (int, optional): Number of subprocesses to use for fitting. + If `0`, will execute on the main thread. + + Returns: + qv.T2Star: :math:`T_2^* fit for tissue. + + Raises: + ValueError: If ``mask_path`` corresponds to non-binary volume. + """ + # only calculate for focused region if a mask is available, this speeds up computation + mask = tissue.get_mask() + if mask_path is not None: + mask = ( + fio_utils.generic_load(mask_path, expected_num_volumes=1) + if isinstance(mask_path, (str, os.PathLike)) + else mask_path + ) + + spin_lock_times = self.echo_times + subvolumes_list = self.volumes + + mef = MonoExponentialFit( + bounds=(__T2_STAR_LOWER_BOUND__, __T2_STAR_UPPER_BOUND__), + tc0="polyfit", + decimal_precision=__T2_STAR_DECIMAL_PRECISION__, + num_workers=num_workers, + verbose=True, + ) + + t2star_map, r2 = mef.fit(spin_lock_times, subvolumes_list, mask=mask) + + quant_val_map = qv.T2Star(t2star_map) + quant_val_map.add_additional_volume("r2", r2) + + tissue.add_quantitative_value(quant_val_map) + + return quant_val_map + + def _save(self, metadata, save_dir, fname_fmt=None, **kwargs): + default_fmt = {MedicalVolume: "echo-{}"} + default_fmt.update(fname_fmt if fname_fmt else {}) + return super()._save(metadata, save_dir, fname_fmt=default_fmt, **kwargs) + + @classmethod + def from_dict(cls, data, force: bool = False) -> "Cones": + interregistered_dirpath = None + if "subvolumes" in data: + interregistered_dirpath = os.path.dirname(data.pop("subvolumes")[0]) + scan: Cones = super().from_dict(data, force=force) + if interregistered_dirpath is not None: + subvolumes = scan.__load_interregistered_files__(interregistered_dirpath) + cls.volumes = [subvolumes[k] for k in sorted(subvolumes.keys())] + + return scan + + @classmethod + def cmd_line_actions(cls): + """ + Provide command line information (such as name, help strings, etc) + as list of dictionary. + """ + + interregister_action = ActionWrapper( + name=cls.interregister.__name__, + help="register to another scan", + param_help={ + "target_path": "path to target image in nifti format (.nii.gz)", + "target_mask_path": "path to target mask in nifti format (.nii.gz)", + }, + alternative_param_names={ + "target_path": ["tp", "target"], + "target_mask_path": ["tm", "target_mask"], + }, + ) + generate_t2star_map_action = ActionWrapper( + name=cls.generate_t2_star_map.__name__, + help="generate T2-star map", + param_help={ + "mask_path": "Mask used for fitting select voxels - " "in nifti format (.nii.gz)" + }, + aliases=["t2_star"], + ) + + return [ + (cls.interregister, interregister_action), + (cls.generate_t2_star_map, generate_t2star_map_action), + ]