Diff of /sybil/serie.py [000000] .. [d9566e]

Switch to side-by-side view

--- a
+++ b/sybil/serie.py
@@ -0,0 +1,277 @@
+import functools
+from typing import List, Optional, NamedTuple, Literal
+from argparse import Namespace
+
+import torch
+import numpy as np
+import pydicom
+import torchio as tio
+
+from sybil.datasets.utils import order_slices, VOXEL_SPACING
+from sybil.utils.loading import get_sample_loader
+
+
+class Meta(NamedTuple):
+    paths: list
+    thickness: float
+    pixel_spacing: list
+    manufacturer: str
+    slice_positions: list
+    voxel_spacing: torch.Tensor
+
+
+class Label(NamedTuple):
+    y: int
+    y_seq: np.ndarray
+    y_mask: np.ndarray
+    censor_time: int
+
+
+class Serie:
+    def __init__(
+        self,
+        dicoms: List[str],
+        voxel_spacing: Optional[List[float]] = None,
+        label: Optional[int] = None,
+        censor_time: Optional[int] = None,
+        file_type: Literal["png", "dicom"] = "dicom",
+        split: Literal["train", "dev", "test"] = "test",
+    ):
+        """Initialize a Serie.
+
+        Parameters
+        ----------
+        `dicoms` : List[str]
+            [description]
+        `voxel_spacing`: Optional[List[float]], optional
+            The voxel spacing associated with input CT
+            as (row spacing, col spacing, slice thickness)
+        `label` : Optional[int], optional
+            Whether the patient associated with this serie
+            has or ever developped cancer.
+        `censor_time` : Optional[int]
+            Number of years until cancer diagnostic.
+            If less than 1 year, should be 0.
+        `file_type`: Literal['png', 'dicom']
+            File type of CT slices
+        `split`: Literal['train', 'dev', 'test']
+            Dataset split into which the serie falls into.
+            Assumed to be test by default
+        """
+        if label is not None and censor_time is None:
+            raise ValueError("censor_time should also provided with label.")
+        if file_type == "png" and voxel_spacing is None:
+            raise ValueError("voxel_spacing should be provided for PNG files.")
+
+        self._censor_time = censor_time
+        self._label = label
+        args = self._load_args(file_type)
+        self._args = args
+        self._loader = get_sample_loader(split, args)
+        self._meta = self._load_metadata(dicoms, voxel_spacing, file_type)
+        self._check_valid(args)
+        self.resample_transform = tio.transforms.Resample(target=VOXEL_SPACING)
+        self.padding_transform = tio.transforms.CropOrPad(
+            target_shape=tuple(args.img_size + [args.num_images]), padding_mode=0
+        )
+
+    def has_label(self) -> bool:
+        """Check if there is a label associated with this serie.
+
+        Returns
+        -------
+        bool
+            [description]
+        """
+        return self._label is not None
+
+    def get_label(self, max_followup: int = 6) -> Label:
+        """Get the label for this Serie.
+
+        Parameters
+        ----------
+        max_followup : int, optional
+            [description], by default 6
+
+        Returns
+        -------
+        Tuple[bool, np.array, np.array, int]
+            [description]
+
+        Raises
+        ------
+        ValueError
+            [description]
+
+        """
+        if not self.has_label():
+            raise ValueError("No label in this serie.")
+
+        # First convert months to years
+        year_to_cancer = self._censor_time  # type: ignore
+
+        y_seq = np.zeros(max_followup, dtype=np.float64)
+        y = int((year_to_cancer < max_followup) and self._label)  # type: ignore
+        if y:
+            y_seq[year_to_cancer:] = 1
+        else:
+            year_to_cancer = min(year_to_cancer, max_followup - 1)
+
+        y_mask = np.array(
+            [1] * (year_to_cancer + 1) + [0] * (max_followup - (year_to_cancer + 1)),
+            dtype=np.float64,
+        )
+        return Label(y=y, y_seq=y_seq, y_mask=y_mask, censor_time=year_to_cancer)
+
+    def get_raw_images(self) -> List[np.ndarray]:
+        """
+        Load raw images from serie
+
+        Returns
+        -------
+        List[np.ndarray]
+            List of CT slices of shape (1, C, H, W)
+        """
+
+        loader = get_sample_loader("test", self._args, apply_augmentations=False)
+        input_dicts = [loader.get_image(path) for path in self._meta.paths]
+        images = [i["input"] for i in input_dicts]
+        return images
+
+    @functools.lru_cache
+    def get_volume(self) -> torch.Tensor:
+        """
+        Load loaded 3D CT volume
+
+        Returns
+        -------
+        torch.Tensor
+            CT volume of shape (1, C, N, H, W)
+        """
+
+        input_dicts = [
+            self._loader.get_image(path) for path in self._meta.paths
+        ]
+
+        x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0)
+
+        # Convert from (T, C, H, W) to (C, T, H, W)
+        x = x.permute(1, 0, 2, 3)
+
+        x = tio.ScalarImage(
+            affine=torch.diag(self._meta.voxel_spacing),
+            tensor=x.permute(0, 2, 3, 1),
+        )
+        x = self.resample_transform(x)
+        x = self.padding_transform(x)
+        x = x.data.permute(0, 3, 1, 2)
+        x.unsqueeze_(0)
+        return x
+
+    def _load_metadata(self, paths, voxel_spacing, file_type):
+        """Extract metadata from dicom files efficiently
+
+        Parameters
+        ----------
+        `paths` : List[str]
+            List of paths to dicom files
+        `voxel_spacing`: Optional[List[float]], optional
+            The voxel spacing associated with input CT
+            as (row spacing, col spacing, slice thickness)
+        `file_type` : Literal['png', 'dicom']
+            File type of CT slices
+
+        Returns
+        -------
+        Tuple[list]
+            slice_positions: list of indices for dicoms along z-axis
+        """
+        if file_type == "dicom":
+            slice_positions = []
+            processed_paths = []
+            for path in paths:
+                dcm = pydicom.dcmread(path, stop_before_pixels=True)
+                processed_paths.append(path)
+                slice_positions.append(float(dcm.ImagePositionPatient[-1]))
+
+            processed_paths, slice_positions = order_slices(
+                processed_paths, slice_positions
+            )
+
+            thickness = float(dcm.SliceThickness)
+            pixel_spacing = list(map(float, dcm.PixelSpacing))
+            manufacturer = dcm.Manufacturer
+            voxel_spacing = torch.tensor(pixel_spacing + [thickness, 1])
+        elif file_type == "png":
+            processed_paths = paths
+            slice_positions = list(range(len(paths)))
+            thickness = voxel_spacing[-1] if voxel_spacing is not None else None
+            pixel_spacing = []
+            manufacturer = ""
+            voxel_spacing = (
+                torch.tensor(voxel_spacing + [1]) if voxel_spacing is not None else None
+            )
+
+        meta = Meta(
+            paths=processed_paths,
+            thickness=thickness,
+            pixel_spacing=pixel_spacing,
+            manufacturer=manufacturer,
+            slice_positions=slice_positions,
+            voxel_spacing=voxel_spacing,
+        )
+        return meta
+
+    def _load_args(self, file_type):
+        """
+        Load default args required for a single Serie volume
+
+        Parameters
+        ----------
+        file_type : Literal['png', 'dicom']
+            File type of CT slices
+
+        Returns
+        -------
+        Namespace
+            args with preset values
+        """
+        args = Namespace(
+            **{
+                "img_size": [256, 256],
+                "img_mean": [128.1722],
+                "img_std": [87.1849],
+                "num_images": 200,
+                "img_file_type": file_type,
+                "num_chan": 3,
+                "cache_path": None,
+                "use_annotations": False,
+                "fix_seed_for_multi_image_augmentations": True,
+                "slice_thickness_filter": 5,
+            }
+        )
+        return args
+
+    def _check_valid(self, args):
+        """
+        Check if serie is acceptable:
+
+        Parameters
+        ----------
+        `args` : Namespace
+            manually set args used to develop model
+
+        Raises
+        ------
+        ValueError if:
+            - serie doesn't have a label, OR
+            - slice thickness is too big
+        """
+        if self._meta.thickness is None:
+            raise ValueError("slice thickness not found")
+        if self._meta.thickness > args.slice_thickness_filter:
+            raise ValueError(
+                f"slice thickness {self._meta.thickness} is greater than {args.slice_thickness_filter}."
+            )
+        if self._meta.voxel_spacing is None:
+            raise ValueError("voxel spacing either not set or not found in DICOM")