--- a +++ b/dosma/cli.py @@ -0,0 +1,813 @@ +"""Initialize and parse command line arguments for DOSMA. + +This module is the entry point for executing DOSMA from the command line. +The DOSMA library is critical for processing. + +To use this file, it must be run as a module from the parent directory:: + + $ python -m dosma/cli ... + +Examples: + Run T2 fitting on Subject 01, Series 007, a quantitative DESS (qDESS) scan, + for femoral cartilage:: + + python -m dosma/cli --dicom subject01/dicoms/007/ --save subject01/data/ \ + qdess --fc generate_t2_map + +Hint: + Run ``python -m dosma/cli --help`` for a detailed description of different + command line arguments. +""" + +import argparse +import ast +import functools +import inspect +import logging +import os +import time +import warnings +from collections import defaultdict +from typing import Sequence + +from dosma.core.io.format_io import ImageDataFormat +from dosma.core.quant_vals import QuantitativeValueType as QV +from dosma.defaults import preferences +from dosma.models.seg_model import SegModel +from dosma.models.util import SUPPORTED_MODELS, get_model, model_from_config +from dosma.msk import knee +from dosma.scan_sequences.mri.cones import Cones +from dosma.scan_sequences.mri.cube_quant import CubeQuant +from dosma.scan_sequences.mri.mapss import Mapss +from dosma.scan_sequences.mri.qdess import QDess +from dosma.scan_sequences.scans import ScanSequence +from dosma.tissues.tissue import Tissue +from dosma.utils import env +from dosma.utils.logger import setup_logger + +SUPPORTED_QUANTITATIVE_VALUES = [QV.T2, QV.T1_RHO, QV.T2_STAR] + +DEBUG_KEY = "debug" + +DICOM_KEY = "dicom" +SAVE_KEY = "save" +LOAD_KEY = "load" +IGNORE_EXT_KEY = "ignore_ext" +SPLIT_BY_KEY = "split_by" + +GPU_KEY = "gpu" +NUM_WORKERS_KEY = "num-workers" + +SCAN_KEY = "scan" +SCAN_ACTION_KEY = "scan_action" + +SEGMENTATION_MODEL_KEY = "model" +SEGMENTATION_CONFIG_KEY = "config" +SEGMENTATION_WEIGHTS_DIR_KEY = "weights_dir" +SEGMENTATION_BATCH_SIZE_KEY = "batch_size" + +TISSUES_KEY = "tissues" + +SUPPORTED_SCAN_TYPES = [Cones, CubeQuant, Mapss, QDess] +BASIC_TYPES = [bool, str, float, int, list, tuple] + +_logger = logging.getLogger(__name__) + + +class CommandLineScanContainer: + def __init__( + self, + scan_type: type, + dicom_path, + load_path, + ignore_ext: bool = False, + group_by=None, + num_workers=0, + **kwargs, + ): + """The class for command-line handling around :class:`ScanSequence`. + + The command line interface for :class:`ScanSequence` data is highly structured, + particularly in data saving and loading to support easy and accurate command-line + compatibility. + + This class overloads some standard functionality in :class:`ScanSequence` + (:func:`save`, :func:`load`). When methods are not implemented, this class provides + access directly to attributes/methods of the underlying scan instantiation. + For example, if ``scan_type=QDess``, the following will call + :func:`QDess.generate_t2_map`: + + >>> cli_scan = CommandLineScanContainer(QDess, dicom_path="/path/to/qdess/scan") + >>> cli_scan.generate_t2_map(...) # this calls cli_scan.scan.generate_t2_map + + Data is loaded either from the ``dicom_path`` or the ``load_path``. If both are specified, + the data is loaded from the ``dicom_path``. + + Args: + scan_type (type): A scan type. Should be subclass of `ScanSequence`. + dicom_path (str): The dicom path. This value can be ``None``, but must + be explicitly set. + load_path (str): The load path. This value can be ``None``, but must be + explicitly set. 
+ ignore_ext (bool, optional): If ``True``, ignore extensions when loading + dicom data. See :func:`DicomReader.load` for details. + group_by (optional): The value(s) to group dicoms by. See :func:`DicomReader.load` + for details. + num_workers (int, optional): Number of works for loading scan. + + Attributes: + scan_type (type): The scan type to instantiate. + scan (ScanSequence): The instantiated scan. + generic_args (Dict[str, Any]): Generic duck typed parameter names and values. + If parameters with this name are part of the method signature, they will + automatically be set to the values in this dictionary. Keys include: + * "num_workers": Number of cpu workers to use. + * "max_workers": Alias for "num_workers" in some functions + * "verbose": Verbosity + * "show_pbar": Show progress bar. + + Raises: + NotADirectoryError: If ``dicom_path`` is not a path to a directory. + """ + self.scan_type = scan_type + + if (dicom_path is not None) and (not os.path.isdir(dicom_path)): + if load_path is not None: + warnings.warn( + "Dicom_path {} not found. Will load data from {}".format(dicom_path, load_path) + ) + else: + raise NotADirectoryError("{} is not a directory".format(dicom_path)) + + # Only use dicoms if the path exists and path contains files. + is_dicom_available = (dicom_path is not None) and (os.path.isdir(dicom_path)) + + # If dicom_path is specified and exists, assume user wants to start from scratch with the + # dicoms. load_path is ignored. + group_by = group_by if group_by is not None else scan_type.__DEFAULT_SPLIT_BY__ + if is_dicom_available: + scan = scan_type.from_dicom( + dicom_path, group_by=group_by, ignore_ext=ignore_ext, num_workers=num_workers + ) + else: + scan = self.load(load_path, num_workers=num_workers) + + self.scan = scan + self.generic_args = { + "num_workers": num_workers, + "max_workers": num_workers, + "verbose": True, + "show_pbar": True, + } + + def __getattr__(self, name): + attr = getattr(self.scan, name) + if callable(attr): + params = inspect.signature(attr).parameters + params = params.keys() & self.generic_args.keys() + kwargs = {k: self.generic_args[k] for k in params} + if len(kwargs): + attr = functools.partial(attr, **kwargs) + return attr + + def load(self, path: str, num_workers: int = 0): + """Command line interface loading scan data. + + ``self.scan_type`` must be set before calling this function. + + Args + path (str): Path to pickle file or directory where data is stored. + num_workers (int, optional): Number of workers to use to load data. + + Returns: + ScanSequence: Scan of type ``self.scan_type``. + + Raises: + ValueError: If path to load data from cannot be determined. + + Examples: + >>> cli_scan.load("/path/to/pickle/file") # load data from pickle file + >>> cli_scan.load("/path/to/directory") # load data from directory + """ + scan_type = self.scan_type + + file_path = None + if os.path.isfile(path): + file_path = path + elif os.path.isdir(path) and scan_type.NAME: + fname = f"{scan_type.NAME}.data" + _paths = ( + os.path.join(path, fname), + os.path.join(self._save_dir(path, create_dir=False), fname), + ) + for _path in _paths: + if os.path.isfile(_path): + file_path = _path + break + if file_path is None: + raise ValueError(f"Cannot load {scan_type.__name__} data from path '{path}'") + + return scan_type.load(file_path, num_workers) + + def _save_dir(self, dir_path: str, create_dir: bool = True): + """Returns directory path specific to this scan. + + Formatted as '`base_load_dirpath`/`scan.NAME`'. 
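+        For example, a hypothetical ``dir_path`` of ``"subject01/data"`` for a scan
+        whose ``NAME`` is ``"qdess"`` resolves to ``"subject01/data/qdess"``; if
+        ``dir_path`` already ends with the scan name, it is returned unchanged.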
+ + Args: + dir_path (str): Directory path where all data is stored. + create_dir (`bool`, optional): If `True`, creates directory if it doesn't exist. + + Returns: + str: Data directory path for this scan. + """ + scan_type = self.scan_type + folder_id = scan_type.NAME + + name_len = len(folder_id) + 2 # buffer + if scan_type.NAME not in dir_path[-name_len:]: + scan_dirpath = os.path.join(dir_path, folder_id) + else: + scan_dirpath = dir_path + + # scan_dirpath = os.path.join(scan_dirpath, folder_id) + + if create_dir: + os.makedirs(scan_dirpath, exist_ok=True) + + return scan_dirpath + + def save( + self, + path: str, + save_custom: bool = True, + image_data_format: ImageDataFormat = None, + num_workers: int = 0, + ): + path = self._save_dir(path, create_dir=True) + return self.scan.save(path, save_custom, image_data_format, num_workers) + + +def get_nargs_for_basic_type(base_type: type): + if base_type in [str, float, int]: + return 1 + elif base_type in [list, tuple]: + return "+" + + +def add_tissues(parser: argparse.ArgumentParser): + for tissue in knee.SUPPORTED_TISSUES: + parser.add_argument( + "--%s" % tissue.STR_ID, + action="store_const", + default=False, + const=True, + help="analyze %s" % tissue.FULL_NAME, + ) + + +def parse_tissues(vargin: dict): + tissues = [] + for tissue in knee.SUPPORTED_TISSUES: + t = tissue() + if ( + t.STR_ID in vargin.keys() + and vargin[t.STR_ID] + and t.STR_ID not in [x.STR_ID for x in tissues] + ): + load_path = vargin[LOAD_KEY] + if load_path: + t.load_data(load_path) + + tissues.append(t) + + # if no tissues are specified, do computation for all supported tissues + if len(tissues) == 0: + _logger.info("No tissues specified, computing for all supported tissues...") + tissues = [] + for tissue in knee.SUPPORTED_TISSUES: + t = tissue() + if t.STR_ID not in [x.STR_ID for x in tissues]: + load_path = vargin[LOAD_KEY] + if load_path: + t.load_data(load_path) + + tissues.append(t) + + analysis_str = "Tissue(s): " + for tissue in tissues: + analysis_str += "%s, " % tissue.FULL_NAME + + _logger.info(analysis_str) + + return tissues + + +def add_segmentation_subparser(parser): + parser.add_argument( + "--%s" % SEGMENTATION_WEIGHTS_DIR_KEY, + type=str, + nargs=1, + required=True, + help="path to directory with weights", + ) + parser.add_argument( + "--%s" % SEGMENTATION_MODEL_KEY, + choices=SUPPORTED_MODELS, + nargs="?", + default=None, + help="built-in model to use for segmentation. Choices: %s" % SUPPORTED_MODELS, + ) + parser.add_argument( + "--%s" % SEGMENTATION_CONFIG_KEY, + type=str, + default=None, + help="config file for non-built-in model", + ) + parser.add_argument( + "--%s" % SEGMENTATION_BATCH_SIZE_KEY, + metavar="B", + type=int, + default=preferences.segmentation_batch_size, + nargs="?", + help="batch size for inference. 
Default: %d" % preferences.segmentation_batch_size, + ) + + return parser + + +def handle_segmentation(vargin, scan: ScanSequence, tissue: Tissue): + if not vargin[SEGMENTATION_MODEL_KEY] and not vargin[SEGMENTATION_CONFIG_KEY]: + raise ValueError( + "Either `--{}` or `--{}` must be specified".format( + SEGMENTATION_MODEL_KEY, SEGMENTATION_CONFIG_KEY + ) + ) + + segment_weights_path = vargin[SEGMENTATION_WEIGHTS_DIR_KEY][0] + if isinstance(tissue, Sequence): + weights = [t.find_weights(segment_weights_path) for t in tissue] + assert all(weights_file == weights[0] for weights_file in weights) + weights_path = weights[0] + else: + weights_path = tissue.find_weights(segment_weights_path) + + # Load model + dims = scan.get_dimensions() + # TODO: Input shape should be determined by combination of model + scan. + # Currently fixed in 2D plane + input_shape = (dims[0], dims[1], 1) + if vargin[SEGMENTATION_MODEL_KEY]: + # Use built-in model + model = get_model( + vargin[SEGMENTATION_MODEL_KEY], input_shape=input_shape, weights_path=weights_path + ) + else: + # Use config + model = model_from_config( + vargin[SEGMENTATION_CONFIG_KEY], + weights_dir=segment_weights_path, + input_shape=input_shape, + ) + model.batch_size = vargin[SEGMENTATION_BATCH_SIZE_KEY] + + return model + + +CUSTOM_TYPE_TO_HANDLE_DICT = {SegModel: handle_segmentation} + + +def add_custom_argument(parser, param_type): + # handle all custom arguments except tissues + has_custom_argument = False + if param_type is SegModel: + add_segmentation_subparser(parser) + has_custom_argument = True + + return has_custom_argument + + +def add_base_argument( + parser: argparse.ArgumentParser, + param_name, + param_type, + param_default, + param_help, + additional_param_names: list = None, +): + if additional_param_names is None: + additional_param_names = [] + + # TODO: Clean up this code block to properly do syntax parsing. + try: + if param_type not in BASIC_TYPES: + param_type = extract_basic_type(param_type) + except (AttributeError, TypeError): + raise TypeError( + "Parameter '{}' - type '{}' not in BASIC_TYPES".format(param_name, param_type) + ) + + # add default value to param help + has_default = param_default is not inspect._empty + if has_default: + param_help = "%s. Default: %s" % (param_help, param_default) + + if additional_param_names: + param_names = ["--%s" % n for n in additional_param_names] + else: + param_names = [] + + param_names.append("--%s" % param_name) + + if param_type is bool: + if not has_default: + raise ValueError("All boolean parameters must have a default value.") + + parser.add_argument( + *param_names, + action="store_%s" % (str(not param_default).lower()), + dest=param_name, + help=param_help, + ) + return + + # all other values with default have this parameter + nargs_no_default = get_nargs_for_basic_type(param_type) + nargs = "?" if has_default else nargs_no_default + + parser.add_argument( + *param_names, + nargs=nargs, + default=param_default if has_default else None, + dest=param_name, + help=param_help, + required=not has_default, + ) + + +def parse_basic_type(val, param_type): + if param_type not in BASIC_TYPES: + param_type = extract_basic_type(param_type) + + if type(val) is param_type: + return val + + if param_type in [list, tuple]: + return param_type(val) + + nargs = get_nargs_for_basic_type(param_type) + if type(val) is list and nargs == 1: + return val[0] + return param_type(val) if val else val + + +def extract_basic_type(param_type): + """Extracts basic types from ``typing`` aliases. 
+ + Args: + param_type (typing._GenericAlias): A generic alias + (e.g. ``typing.Tuple``, ``typing.List``). + + Returns: + type: The basic type. + """ + try: + # Python 3.5 / 3.6 + return param_type.__extra__ + except AttributeError: + # Python 3.7/3.8/3.9 + return param_type.__origin__ + + +def add_scans(dosma_subparser): + for scan in SUPPORTED_SCAN_TYPES: + supported_actions = scan.cmd_line_actions() + + # skip scans that are not supported on the command line + if len(supported_actions) == 0: + continue + scan_name = scan.NAME + scan_parser = dosma_subparser.add_parser(scan.NAME, help="analyze %s sequence" % scan_name) + add_tissues(scan_parser) + + if not len(supported_actions): + raise ValueError("No supported actions for scan %s" % scan_name) + + scan_subparser = scan_parser.add_subparsers( + description="%s subcommands" % scan.NAME, dest=SCAN_ACTION_KEY + ) + + for action, action_wrapper in supported_actions: + func_signature = inspect.signature(action) + func_name = action_wrapper.name + aliases = action_wrapper.aliases + action_parser = scan_subparser.add_parser( + func_name, aliases=aliases, help=action_wrapper.help + ) + + parameters = func_signature.parameters + for param_name in parameters.keys(): + param = parameters[param_name] + param_type = param.annotation + param_default = param.default + + if param_name == "self" or param_type is Tissue: + continue + + param_help = action_wrapper.get_param_help(param_name) + alternative_param_names = action_wrapper.get_alternative_param_names(param_name) + + if param_type is inspect._empty: + raise ValueError( + "scan %s, action %s, param %s does not have an annotation. " + "Use pytying in the method declaration" % (scan.NAME, func_name, param_name) + ) + + # see if the type is a custom type, if not handle it as a basic type + is_custom_arg = add_custom_argument(action_parser, param_type) + if is_custom_arg: + continue + + add_base_argument( + action_parser, + param_name, + param_type, + param_default, + param_help=param_help, + additional_param_names=alternative_param_names, + ) + + scan_parser.set_defaults(func=handle_scan) + + +def _find_tissue_groups(vargin, tissues: Sequence[Tissue]): + """Finds groups of tissues that can be segmented together. + + Some segmentation models have multiple weight files for different tissues. + Many models have one weight file for multiple tissues (i.e. multi-class segmentation). + + This function matches tissues with their corresponding weight file. + If multiple tissues share a single weight file, they will map to the same weight file, + allowing multiple tissues to be segmented simultaneously. + + This is a temporary fix for segmenting multiple classes. + It should not be extended or used as precedence for future development. 
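+
+    Args:
+        vargin (dict): Parsed command-line arguments; the weights directory is
+            read from ``vargin[SEGMENTATION_WEIGHTS_DIR_KEY][0]``.
+        tissues (Sequence[Tissue]): Tissue(s) to group. A single ``Tissue`` is
+            also accepted and is wrapped in a list.
+
+    Returns:
+        Dict[str, List[Tissue]]: Mapping from a weights file (as returned by
+            ``Tissue.find_weights``) to the tissues segmented with that file.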
+ """ + if not isinstance(tissues, Sequence): + assert isinstance(tissues, Tissue) + tissues = [tissues] + + weights_dir = vargin[SEGMENTATION_WEIGHTS_DIR_KEY][0] + weights_to_tissues = defaultdict(list) + for tissue in tissues: + weights_to_tissues[tissue.find_weights(weights_dir)].append(tissue) + + return weights_to_tissues + + +def _build_params(vargin, scan, parameters, tissue=None): + param_dict = {} + for param_name in parameters.keys(): + param = parameters[param_name] + param_type = param.annotation + + if param_name == "self": + continue + + if param_type is Tissue: + assert tissue is not None + param_dict["tissue"] = tissue + continue + + if param_type in CUSTOM_TYPE_TO_HANDLE_DICT: + param_dict[param_name] = CUSTOM_TYPE_TO_HANDLE_DICT[param_type](vargin, scan, tissue) + else: + param_dict[param_name] = parse_basic_type(vargin[param_name], param_type) + return param_dict + + +def handle_scan(vargin): + + scan_name = vargin[SCAN_KEY] + _logger.info("Analyzing {}...".format(scan_name)) + scan = None + + for p_scan in SUPPORTED_SCAN_TYPES: + if p_scan.NAME == scan_name: + scan = p_scan + break + + scan = CommandLineScanContainer( + scan, + dicom_path=vargin[DICOM_KEY], + load_path=vargin[LOAD_KEY], + ignore_ext=vargin[IGNORE_EXT_KEY], + split_by=vargin[SPLIT_BY_KEY] if vargin[SPLIT_BY_KEY] else scan.__DEFAULT_SPLIT_BY__, + num_workers=vargin[NUM_WORKERS_KEY], + ) + + tissues = vargin["tissues"] + scan_action = scan_action_str = vargin[SCAN_ACTION_KEY] + + p_action = None + for action, action_wrapper in scan.cmd_line_actions(): + if scan_action == action_wrapper.name or scan_action in action_wrapper.aliases: + p_action = action + break + + # search for name in the cmd_line actions + action = p_action + + if action is None: + scan.save(vargin[SAVE_KEY], image_data_format=preferences.image_data_format) + return + + func_signature = inspect.signature(action) + parameters = func_signature.parameters + if scan_action_str == "segment": + weights_to_tissues = _find_tissue_groups(vargin, tissues) + for _weights_file, seg_tissues in weights_to_tissues.items(): + if len(seg_tissues) == 1: + seg_tissues = seg_tissues[0] + param_dict = _build_params(vargin, scan, parameters, seg_tissues) + getattr(scan, action.__name__)(**param_dict) + else: + if "tissue" in func_signature.parameters.keys(): + for tissue in tissues: + param_dict = _build_params(vargin, scan, parameters, tissue) + getattr(scan, action.__name__)(**param_dict) + else: + param_dict = _build_params(vargin, scan, parameters) + getattr(scan, action.__name__)(**param_dict) + + scan.save(vargin[SAVE_KEY], image_data_format=preferences.image_data_format) + for tissue in tissues: + tissue.save_data(vargin[SAVE_KEY], data_format=preferences.image_data_format) + + return scan + + +def parse_dicom_tag_splitby(vargin_str): + if not vargin_str: + return vargin_str + + try: + parsed_str = ast.literal_eval(vargin_str) + return parsed_str + except Exception: + return vargin_str + + +def parse_args(f_input=None): + """Parse arguments given through command line (argv) + + :raises ValueError if dicom path is not provided + :raises NotADirectoryError if dicom path does not exist or is not a directory + """ + parser = argparse.ArgumentParser( + prog="DOSMA", + description="A deep-learning powered open source MRI analysis pipeline", + epilog="Either `--dicom` or `--load` must be specified. 
" + "If both are given, `--dicom` will be used", + ) + parser.add_argument("--%s" % DEBUG_KEY, action="store_true", help="use debug mode") + + # Dicom and results paths + parser.add_argument( + "--d", + "--%s" % DICOM_KEY, + metavar="D", + type=str, + default=None, + nargs="?", + dest=DICOM_KEY, + help="path to directory storing dicom files", + ) + parser.add_argument( + "--l", + "--%s" % LOAD_KEY, + metavar="L", + type=str, + default=None, + nargs="?", + dest=LOAD_KEY, + help="path to data directory to load from", + ) + parser.add_argument( + "--s", + "--%s" % SAVE_KEY, + metavar="S", + type=str, + default=None, + nargs="?", + dest=SAVE_KEY, + help="path to data directory to save to. Default: L/D", + ) + parser.add_argument( + "--%s" % IGNORE_EXT_KEY, + action="store_true", + default=False, + dest=IGNORE_EXT_KEY, + help="ignore .dcm extension when loading dicoms. Default: False", + ) + parser.add_argument( + "--%s" % SPLIT_BY_KEY, + metavar="G", + type=str, + default=None, + nargs="?", + dest=SPLIT_BY_KEY, + help="override dicom tag to split volumes by (eg. `EchoNumbers`)", + ) + + parser.add_argument( + "--%s" % GPU_KEY, + metavar="G", + type=str, + default=None, + nargs="?", + dest=GPU_KEY, + help="gpu id. Default: None", + ) + + parser.add_argument( + "--%s" % NUM_WORKERS_KEY, + metavar="G", + type=int, + default=0, + dest=NUM_WORKERS_KEY, + help="num cpu workers. Default: 0", + ) + + # Add preferences + preferences_flags = preferences.cmd_line_flags() + for flag in preferences_flags.keys(): + argparse_kwargs = preferences_flags[flag] + argparse_kwargs["dest"] = flag + aliases = argparse_kwargs.pop("aliases", None) + name = argparse_kwargs.pop("name", None) # noqa: F841 + parser.add_argument(*aliases, **argparse_kwargs) + + subparsers = parser.add_subparsers(help="sub-command help", dest=SCAN_KEY) + add_scans(subparsers) + + # MSK knee parser + knee.knee_parser(subparsers) + + start_time = time.time() + if f_input: + args = parser.parse_args(f_input) + else: + args = parser.parse_args() + + # Only initialize logger if called from command line. + # If UI is using it, the logger should be initialized by the UI. + setup_logger(env.log_file_path()) + + vargin = vars(args) + + if vargin[DEBUG_KEY]: + env.debug(True) + + gpu = vargin[GPU_KEY] + + _logger.debug(vargin) + + if gpu is not None: + os.environ["CUDA_VISIBLE_DEVICES"] = gpu + elif "CUDA_VISIBLE_DEVICES" not in os.environ: + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + # parse and update preferences + for flag in preferences_flags.keys(): + preferences.set(flag, vargin[flag]) + + dicom_path = vargin[DICOM_KEY] + load_path = vargin[LOAD_KEY] + + if not dicom_path and not load_path: + raise ValueError("Must provide path to dicoms or path to load data from") + + save_path = vargin[SAVE_KEY] + if not save_path: + save_path = load_path if load_path else "%s/data" % dicom_path + vargin[SAVE_KEY] = save_path + + if not os.path.isdir(save_path): + os.makedirs(save_path) + + tissues = parse_tissues(vargin) + vargin["tissues"] = tissues + + vargin[SPLIT_BY_KEY] = parse_dicom_tag_splitby(vargin[SPLIT_BY_KEY]) + + args.func(vargin) + + time_elapsed = time.time() - start_time + _logger.info("Time Elapsed: {:.2f} seconds".format(time.time() - start_time)) + + return time_elapsed + + +if __name__ == "__main__": + parse_args()
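+
+# Example (illustrative): ``parse_args`` also accepts an explicit argument list,
+# mirroring the command in the module docstring:
+#
+#   parse_args([
+#       "--dicom", "subject01/dicoms/007/",
+#       "--save", "subject01/data/",
+#       "qdess", "--fc", "generate_t2_map",
+#   ])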