a b/sybil/serie.py
1
import functools
2
from typing import List, Optional, NamedTuple, Literal
3
from argparse import Namespace
4
5
import torch
6
import numpy as np
7
import pydicom
8
import torchio as tio
9
10
from sybil.datasets.utils import order_slices, VOXEL_SPACING
11
from sybil.utils.loading import get_sample_loader
12
13
14
class Meta(NamedTuple):
    """Metadata describing the slices that make up one CT serie.

    Instances are built by ``Serie._load_metadata``, either from DICOM
    headers or, for PNG slices, from caller-supplied values.
    """

    # Paths of the slice files (passed through ``order_slices`` for DICOM input).
    paths: list
    # DICOM ``SliceThickness`` as a float; for PNG input, the last entry of
    # the caller-supplied voxel spacing.
    thickness: float
    # DICOM ``PixelSpacing`` converted to floats; an empty list for PNG input.
    pixel_spacing: list
    # DICOM ``Manufacturer``; an empty string for PNG input.
    manufacturer: str
    # Last component of ``ImagePositionPatient`` for each DICOM slice;
    # consecutive integer indices for PNG input.
    slice_positions: list
    # Spacing values with a trailing 1 appended; ``Serie.get_volume`` uses it
    # as the diagonal of the image affine.
    voxel_spacing: torch.Tensor
21
22
23
class Label(NamedTuple):
    """Label bundle returned by ``Serie.get_label``."""

    # 1 when the serie is labelled positive and the event falls inside the
    # follow-up window, otherwise 0.
    y: int
    # Array of length ``max_followup``; set to 1 from index ``censor_time``
    # onward when ``y`` is truthy, all zeros otherwise.
    y_seq: np.ndarray
    # Array of length ``max_followup``; 1 for the first ``censor_time + 1``
    # entries and 0 for the rest.
    y_mask: np.ndarray
    # Censoring time; capped at ``max_followup - 1`` when ``y`` is 0.
    censor_time: int
28
29
30
class Serie:
31
    def __init__(
32
        self,
33
        dicoms: List[str],
34
        voxel_spacing: Optional[List[float]] = None,
35
        label: Optional[int] = None,
36
        censor_time: Optional[int] = None,
37
        file_type: Literal["png", "dicom"] = "dicom",
38
        split: Literal["train", "dev", "test"] = "test",
39
    ):
40
        """Initialize a Serie.
41
42
        Parameters
43
        ----------
44
        `dicoms` : List[str]
45
            [description]
46
        `voxel_spacing`: Optional[List[float]], optional
47
            The voxel spacing associated with input CT
48
            as (row spacing, col spacing, slice thickness)
49
        `label` : Optional[int], optional
50
            Whether the patient associated with this serie
51
            has or ever developped cancer.
52
        `censor_time` : Optional[int]
53
            Number of years until cancer diagnostic.
54
            If less than 1 year, should be 0.
55
        `file_type`: Literal['png', 'dicom']
56
            File type of CT slices
57
        `split`: Literal['train', 'dev', 'test']
58
            Dataset split into which the serie falls into.
59
            Assumed to be test by default
60
        """
61
        if label is not None and censor_time is None:
62
            raise ValueError("censor_time should also provided with label.")
63
        if file_type == "png" and voxel_spacing is None:
64
            raise ValueError("voxel_spacing should be provided for PNG files.")
65
66
        self._censor_time = censor_time
67
        self._label = label
68
        args = self._load_args(file_type)
69
        self._args = args
70
        self._loader = get_sample_loader(split, args)
71
        self._meta = self._load_metadata(dicoms, voxel_spacing, file_type)
72
        self._check_valid(args)
73
        self.resample_transform = tio.transforms.Resample(target=VOXEL_SPACING)
74
        self.padding_transform = tio.transforms.CropOrPad(
75
            target_shape=tuple(args.img_size + [args.num_images]), padding_mode=0
76
        )
77
78
    def has_label(self) -> bool:
79
        """Check if there is a label associated with this serie.
80
81
        Returns
82
        -------
83
        bool
84
            [description]
85
        """
86
        return self._label is not None
87
88
    def get_label(self, max_followup: int = 6) -> Label:
89
        """Get the label for this Serie.
90
91
        Parameters
92
        ----------
93
        max_followup : int, optional
94
            [description], by default 6
95
96
        Returns
97
        -------
98
        Tuple[bool, np.array, np.array, int]
99
            [description]
100
101
        Raises
102
        ------
103
        ValueError
104
            [description]
105
106
        """
107
        if not self.has_label():
108
            raise ValueError("No label in this serie.")
109
110
        # First convert months to years
111
        year_to_cancer = self._censor_time  # type: ignore
112
113
        y_seq = np.zeros(max_followup, dtype=np.float64)
114
        y = int((year_to_cancer < max_followup) and self._label)  # type: ignore
115
        if y:
116
            y_seq[year_to_cancer:] = 1
117
        else:
118
            year_to_cancer = min(year_to_cancer, max_followup - 1)
119
120
        y_mask = np.array(
121
            [1] * (year_to_cancer + 1) + [0] * (max_followup - (year_to_cancer + 1)),
122
            dtype=np.float64,
123
        )
124
        return Label(y=y, y_seq=y_seq, y_mask=y_mask, censor_time=year_to_cancer)
125
126
    def get_raw_images(self) -> List[np.ndarray]:
        """
        Load every slice of the serie without augmentations

        A fresh "test" sample loader is created with augmentations
        disabled, independent of the split given at construction.

        Returns
        -------
        List[np.ndarray]
            One entry per slice path: the "input" item returned by the
            loader. NOTE(review): the previous docstring gave the shape
            as (1, C, H, W); that is not verifiable here, confirm
            against the loader.
        """
        raw_loader = get_sample_loader("test", self._args, apply_augmentations=False)

        images = []
        for slice_path in self._meta.paths:
            sample = raw_loader.get_image(slice_path)
            images.append(sample["input"])
        return images
140
141
    @functools.lru_cache
142
    def get_volume(self) -> torch.Tensor:
143
        """
144
        Load loaded 3D CT volume
145
146
        Returns
147
        -------
148
        torch.Tensor
149
            CT volume of shape (1, C, N, H, W)
150
        """
151
152
        input_dicts = [
153
            self._loader.get_image(path) for path in self._meta.paths
154
        ]
155
156
        x = torch.cat([i["input"].unsqueeze(0) for i in input_dicts], dim=0)
157
158
        # Convert from (T, C, H, W) to (C, T, H, W)
159
        x = x.permute(1, 0, 2, 3)
160
161
        x = tio.ScalarImage(
162
            affine=torch.diag(self._meta.voxel_spacing),
163
            tensor=x.permute(0, 2, 3, 1),
164
        )
165
        x = self.resample_transform(x)
166
        x = self.padding_transform(x)
167
        x = x.data.permute(0, 3, 1, 2)
168
        x.unsqueeze_(0)
169
        return x
170
171
    def _load_metadata(self, paths, voxel_spacing, file_type):
172
        """Extract metadata from dicom files efficiently
173
174
        Parameters
175
        ----------
176
        `paths` : List[str]
177
            List of paths to dicom files
178
        `voxel_spacing`: Optional[List[float]], optional
179
            The voxel spacing associated with input CT
180
            as (row spacing, col spacing, slice thickness)
181
        `file_type` : Literal['png', 'dicom']
182
            File type of CT slices
183
184
        Returns
185
        -------
186
        Tuple[list]
187
            slice_positions: list of indices for dicoms along z-axis
188
        """
189
        if file_type == "dicom":
190
            slice_positions = []
191
            processed_paths = []
192
            for path in paths:
193
                dcm = pydicom.dcmread(path, stop_before_pixels=True)
194
                processed_paths.append(path)
195
                slice_positions.append(float(dcm.ImagePositionPatient[-1]))
196
197
            processed_paths, slice_positions = order_slices(
198
                processed_paths, slice_positions
199
            )
200
201
            thickness = float(dcm.SliceThickness)
202
            pixel_spacing = list(map(float, dcm.PixelSpacing))
203
            manufacturer = dcm.Manufacturer
204
            voxel_spacing = torch.tensor(pixel_spacing + [thickness, 1])
205
        elif file_type == "png":
206
            processed_paths = paths
207
            slice_positions = list(range(len(paths)))
208
            thickness = voxel_spacing[-1] if voxel_spacing is not None else None
209
            pixel_spacing = []
210
            manufacturer = ""
211
            voxel_spacing = (
212
                torch.tensor(voxel_spacing + [1]) if voxel_spacing is not None else None
213
            )
214
215
        meta = Meta(
216
            paths=processed_paths,
217
            thickness=thickness,
218
            pixel_spacing=pixel_spacing,
219
            manufacturer=manufacturer,
220
            slice_positions=slice_positions,
221
            voxel_spacing=voxel_spacing,
222
        )
223
        return meta
224
225
    def _load_args(self, file_type):
226
        """
227
        Load default args required for a single Serie volume
228
229
        Parameters
230
        ----------
231
        file_type : Literal['png', 'dicom']
232
            File type of CT slices
233
234
        Returns
235
        -------
236
        Namespace
237
            args with preset values
238
        """
239
        args = Namespace(
240
            **{
241
                "img_size": [256, 256],
242
                "img_mean": [128.1722],
243
                "img_std": [87.1849],
244
                "num_images": 200,
245
                "img_file_type": file_type,
246
                "num_chan": 3,
247
                "cache_path": None,
248
                "use_annotations": False,
249
                "fix_seed_for_multi_image_augmentations": True,
250
                "slice_thickness_filter": 5,
251
            }
252
        )
253
        return args
254
255
    def _check_valid(self, args):
256
        """
257
        Check if serie is acceptable:
258
259
        Parameters
260
        ----------
261
        `args` : Namespace
262
            manually set args used to develop model
263
264
        Raises
265
        ------
266
        ValueError if:
267
            - serie doesn't have a label, OR
268
            - slice thickness is too big
269
        """
270
        if self._meta.thickness is None:
271
            raise ValueError("slice thickness not found")
272
        if self._meta.thickness > args.slice_thickness_filter:
273
            raise ValueError(
274
                f"slice thickness {self._meta.thickness} is greater than {args.slice_thickness_filter}."
275
            )
276
        if self._meta.voxel_spacing is None:
277
            raise ValueError("voxel spacing either not set or not found in DICOM")