"""Wrapper around a single CT serie: slice paths, metadata, label and volume."""
import functools
from typing import List, Optional, NamedTuple, Literal
from argparse import Namespace

import torch
import numpy as np
import pydicom
import torchio as tio

from sybil.datasets.utils import order_slices, VOXEL_SPACING
from sybil.utils.loading import get_sample_loader


class Meta(NamedTuple):
    """Metadata collected for one serie by ``Serie._load_metadata``."""

    # Slice file paths (ordered with ``order_slices`` for DICOM input).
    paths: list
    # Slice thickness; ``None`` only when PNG input came without voxel spacing.
    thickness: Optional[float]
    # In-plane pixel spacing read from DICOM; empty list for PNG input.
    pixel_spacing: list
    # DICOM ``Manufacturer`` value; empty string for PNG input.
    manufacturer: str
    # Position of every slice along the z-axis (plain indices for PNG input).
    slice_positions: list
    # Spacing values followed by a trailing 1; used as the diagonal of the
    # affine matrix in ``Serie.get_volume``. ``None`` when unavailable.
    voxel_spacing: Optional[torch.Tensor]


class Label(NamedTuple):
    """Training/evaluation target produced by ``Serie.get_label``."""

    # 1 when the serie is positive within the follow-up window, else 0.
    y: int
    # Per-year indicator: 1 from the event year onward, 0 before.
    y_seq: np.ndarray
    # Per-year mask: 1 for years that are observed, 0 afterwards.
    y_mask: np.ndarray
    # Event/censoring year (clamped to the follow-up window for negatives).
    censor_time: int


class Serie:
    """One CT serie together with its optional cancer label."""

    def __init__(
        self,
        dicoms: List[str],
        voxel_spacing: Optional[List[float]] = None,
        label: Optional[int] = None,
        censor_time: Optional[int] = None,
        file_type: Literal["png", "dicom"] = "dicom",
        split: Literal["train", "dev", "test"] = "test",
    ):
        """Initialize a Serie.

        Parameters
        ----------
        `dicoms` : List[str]
            Paths to the slice files (DICOM or PNG) that make up the serie.
        `voxel_spacing`: Optional[List[float]], optional
            The voxel spacing associated with input CT
            as (row spacing, col spacing, slice thickness).
            Required for PNG input; recomputed from the headers for DICOM.
        `label` : Optional[int], optional
            Whether the patient associated with this serie
            has or ever developed cancer.
        `censor_time` : Optional[int]
            Number of years until cancer diagnostic.
            If less than 1 year, should be 0.
        `file_type`: Literal['png', 'dicom']
            File type of CT slices
        `split`: Literal['train', 'dev', 'test']
            Dataset split into which the serie falls.
            Assumed to be test by default

        Raises
        ------
        ValueError
            If `label` is given without `censor_time`, if PNG input has no
            `voxel_spacing`, if the metadata cannot be loaded, or if the
            serie fails the checks in `_check_valid`.
        """
        if label is not None and censor_time is None:
            raise ValueError("censor_time should also provided with label.")
        if file_type == "png" and voxel_spacing is None:
            raise ValueError("voxel_spacing should be provided for PNG files.")

        self._censor_time = censor_time
        self._label = label
        # Per-instance cache for get_volume(); filled on first call.
        self._volume: Optional[torch.Tensor] = None
        args = self._load_args(file_type)
        self._args = args
        self._loader = get_sample_loader(split, args)
        self._meta = self._load_metadata(dicoms, voxel_spacing, file_type)
        self._check_valid(args)
        self.resample_transform = tio.transforms.Resample(target=VOXEL_SPACING)
        self.padding_transform = tio.transforms.CropOrPad(
            target_shape=tuple(args.img_size + [args.num_images]), padding_mode=0
        )

    def has_label(self) -> bool:
        """Check if there is a label associated with this serie.

        Returns
        -------
        bool
            True when a label was passed to the constructor.
        """
        return self._label is not None

    def get_label(self, max_followup: int = 6) -> Label:
        """Get the label for this Serie.

        Parameters
        ----------
        max_followup : int, optional
            Number of follow-up years; also the length of `y_seq` and
            `y_mask`. By default 6.

        Returns
        -------
        Label
            `y` is 1 when the label is truthy and the censor time is below
            `max_followup`; `y_seq` is 1 from the censor year onward for
            positives; `y_mask` is 1 up to and including the censor year;
            `censor_time` is clamped to `max_followup - 1` for negatives.

        Raises
        ------
        ValueError
            If the serie has no label.
        """
        if not self.has_label():
            raise ValueError("No label in this serie.")

        # censor_time is already expressed in years (see __init__).
        year_to_cancer = self._censor_time  # type: ignore

        y_seq = np.zeros(max_followup, dtype=np.float64)
        y = int((year_to_cancer < max_followup) and self._label)  # type: ignore
        if y:
            y_seq[year_to_cancer:] = 1
        else:
            # Negatives/censored cases are observed at most to the last year.
            year_to_cancer = min(year_to_cancer, max_followup - 1)

        y_mask = np.array(
            [1] * (year_to_cancer + 1) + [0] * (max_followup - (year_to_cancer + 1)),
            dtype=np.float64,
        )
        return Label(y=y, y_seq=y_seq, y_mask=y_mask, censor_time=year_to_cancer)

    def get_raw_images(self) -> List[np.ndarray]:
        """Load the individual slices without augmentations.

        A fresh "test" loader is built so that the result does not depend on
        the split this serie was created with.

        Returns
        -------
        List[np.ndarray]
            The "input" entry returned by the loader for every slice path,
            in metadata order (element type is whatever the loader yields).
        """
        loader = get_sample_loader("test", self._args, apply_augmentations=False)
        return [loader.get_image(path)["input"] for path in self._meta.paths]

    def get_volume(self) -> torch.Tensor:
        """Load the resampled and padded 3D CT volume.

        The volume is computed once and cached on the instance; later calls
        return the same tensor object. (A per-instance cache is used instead
        of ``functools.lru_cache`` so the cache does not keep instances alive
        after they are otherwise unreferenced.)

        Returns
        -------
        torch.Tensor
            CT volume of shape (1, C, N, H, W)
        """
        volume = getattr(self, "_volume", None)
        if volume is None:
            volume = self._build_volume()
            self._volume = volume
        return volume

    def _build_volume(self) -> torch.Tensor:
        """Stack all slices, resample to VOXEL_SPACING and crop/pad."""
        input_dicts = [self._loader.get_image(path) for path in self._meta.paths]

        x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0)

        # Convert from (T, C, H, W) to (C, T, H, W)
        x = x.permute(1, 0, 2, 3)

        # The transforms are fed (C, H, W, T), hence the extra permute.
        x = tio.ScalarImage(
            affine=torch.diag(self._meta.voxel_spacing),
            tensor=x.permute(0, 2, 3, 1),
        )
        x = self.resample_transform(x)
        x = self.padding_transform(x)
        # Back to (C, T, H, W), then add the leading batch dimension.
        x = x.data.permute(0, 3, 1, 2)
        x.unsqueeze_(0)
        return x

    def _load_metadata(self, paths, voxel_spacing, file_type) -> Meta:
        """Extract metadata from the slice files.

        DICOM headers are read with ``stop_before_pixels=True`` so pixel data
        is not loaded at this stage.

        Parameters
        ----------
        `paths` : List[str]
            List of paths to slice files
        `voxel_spacing`: Optional[List[float]], optional
            The voxel spacing associated with input CT
            as (row spacing, col spacing, slice thickness).
            Ignored for DICOM input, where it is read from the headers.
        `file_type` : Literal['png', 'dicom']
            File type of CT slices

        Returns
        -------
        Meta
            Paths, thickness, pixel spacing, manufacturer, slice positions
            and voxel spacing of the serie.

        Raises
        ------
        ValueError
            If `file_type` is "dicom" and `paths` is empty, or if `file_type`
            is not supported.
        """
        if file_type == "dicom":
            if not paths:
                # Without at least one header there is nothing to read the
                # spacing/thickness from.
                raise ValueError("At least one DICOM path is required.")

            slice_positions = []
            processed_paths = []
            for path in paths:
                dcm = pydicom.dcmread(path, stop_before_pixels=True)
                processed_paths.append(path)
                slice_positions.append(float(dcm.ImagePositionPatient[-1]))

            processed_paths, slice_positions = order_slices(
                processed_paths, slice_positions
            )

            # Serie-level values are taken from the last header that was read.
            thickness = float(dcm.SliceThickness)
            pixel_spacing = list(map(float, dcm.PixelSpacing))
            manufacturer = dcm.Manufacturer
            voxel_spacing = torch.tensor(pixel_spacing + [thickness, 1])
        elif file_type == "png":
            processed_paths = paths
            slice_positions = list(range(len(paths)))
            pixel_spacing = []
            manufacturer = ""
            if voxel_spacing is not None:
                thickness = voxel_spacing[-1]
                # list(...) so tuples and other sequences are accepted too.
                voxel_spacing = torch.tensor(list(voxel_spacing) + [1])
            else:
                thickness = None
        else:
            raise ValueError(f"Unsupported file_type: {file_type!r}")

        return Meta(
            paths=processed_paths,
            thickness=thickness,
            pixel_spacing=pixel_spacing,
            manufacturer=manufacturer,
            slice_positions=slice_positions,
            voxel_spacing=voxel_spacing,
        )

    def _load_args(self, file_type):
        """
        Load default args required for a single Serie volume

        Parameters
        ----------
        file_type : Literal['png', 'dicom']
            File type of CT slices

        Returns
        -------
        Namespace
            args with preset values
        """
        return Namespace(
            img_size=[256, 256],
            img_mean=[128.1722],
            img_std=[87.1849],
            num_images=200,
            img_file_type=file_type,
            num_chan=3,
            cache_path=None,
            use_annotations=False,
            fix_seed_for_multi_image_augmentations=True,
            slice_thickness_filter=5,
        )

    def _check_valid(self, args):
        """
        Check if serie is acceptable:

        Parameters
        ----------
        `args` : Namespace
            manually set args used to develop model

        Raises
        ------
        ValueError if:
            - slice thickness is missing, OR
            - slice thickness is greater than `args.slice_thickness_filter`, OR
            - voxel spacing is missing
        """
        if self._meta.thickness is None:
            raise ValueError("slice thickness not found")
        if self._meta.thickness > args.slice_thickness_filter:
            raise ValueError(
                f"slice thickness {self._meta.thickness} is greater than {args.slice_thickness_filter}."
            )
        if self._meta.voxel_spacing is None:
            raise ValueError("voxel spacing either not set or not found in DICOM")