Diff of /sybil/datasets/nlst.py [000000] .. [d9566e]

b/sybil/datasets/nlst.py
import os
import traceback, warnings
import pickle, json
import numpy as np
import pydicom
import torchio as tio
from tqdm import tqdm
from collections import Counter
import torch
import torch.nn.functional as F
from torch.utils import data
from sybil.serie import Serie
from sybil.utils.loading import get_sample_loader
from sybil.datasets.utils import (
    METAFILE_NOTFOUND_ERR,
    LOAD_FAIL_MSG,
    VOXEL_SPACING,
)
import copy
from sybil.datasets.nlst_risk_factors import NLSTRiskFactorVectorizer

METADATA_FILENAME = {"google_test": "NLST/full_nlst_google.json"}

GOOGLE_SPLITS_FILENAME = (
    "/Mounts/rbg-storage1/datasets/NLST/Shetty_et_al(Google)/data_splits.p"
)

CORRUPTED_PATHS = "/Mounts/rbg-storage1/datasets/NLST/corrupted_img_paths.pkl"

CT_ITEM_KEYS = [
    "pid",
    "exam",
    "series",
    "y_seq",
    "y_mask",
    "time_at_event",
    "cancer_laterality",
    "has_annotation",
    "origin_dataset",
]

RACE_ID_KEYS = {
    1: "white",
    2: "black",
    3: "asian",
    4: "american_indian_alaskan",
    5: "native_hawaiian_pacific",
    6: "hispanic",
}
ETHNICITY_KEYS = {1: "Hispanic or Latino", 2: "Neither Hispanic nor Latino"}
GENDER_KEYS = {1: "Male", 2: "Female"}
EDUCAT_LEVEL = {
    1: 1,  # 8th grade = less than HS
    2: 1,  # 9th-11th grade = less than HS
    3: 2,  # HS grad
    4: 3,  # post-HS
    5: 4,  # some college
    6: 5,  # bachelors = college grad
    7: 6,  # graduate school = postgrad/professional
}


class NLST_Survival_Dataset(data.Dataset):
    def __init__(self, args, split_group):
        """
        NLST Dataset

        params: args - config
        params: split_group - ['train'|'dev'|'test']

        constructs: standard pytorch Dataset obj, which can be fed in a DataLoader for batching
        """
        super(NLST_Survival_Dataset, self).__init__()

        self.split_group = split_group
        self.args = args
        self._num_images = args.num_images  # number of slices in each volume
        self._max_followup = args.max_followup

        try:
            self.metadata_json = json.load(open(args.dataset_file_path, "r"))
        except Exception as e:
            raise Exception(METAFILE_NOTFOUND_ERR.format(args.dataset_file_path, e))

        self.input_loader = get_sample_loader(split_group, args)
        self.always_resample_pixel_spacing = split_group in ["dev", "test"]

        self.resample_transform = tio.transforms.Resample(target=VOXEL_SPACING)
        self.padding_transform = tio.transforms.CropOrPad(
            target_shape=tuple(args.img_size + [args.num_images]), padding_mode=0
        )

        if args.use_annotations:
            assert (
                self.args.region_annotations_filepath
            ), "ANNOTATIONS METADATA FILE NOT SPECIFIED"
            self.annotations_metadata = json.load(
                open(self.args.region_annotations_filepath, "r")
            )

        self.dataset = self.create_dataset(split_group)
        if len(self.dataset) == 0:
            return

        print(self.get_summary_statement(self.dataset, split_group))

        dist_key = "y"
        label_dist = [d[dist_key] for d in self.dataset]
        label_counts = Counter(label_dist)
        weight_per_label = 1.0 / len(label_counts)
        label_weights = {
            label: weight_per_label / count for label, count in label_counts.items()
        }

        print("Class counts are: {}".format(label_counts))
        print("Label weights are: {}".format(label_weights))
        self.weights = [label_weights[d[dist_key]] for d in self.dataset]

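    # Editor's sketch (illustrative, not part of the original file): the
    # inverse-frequency `self.weights` computed above are sized so each class
    # contributes equal total mass; a hypothetical class-balanced loader:
    @staticmethod
    def _example_balanced_loader(dataset, batch_size=4):
        from torch.utils.data import DataLoader, WeightedRandomSampler

        # each record is drawn with probability proportional to its weight,
        # so the rarer label is oversampled toward a uniform class mix
        sampler = WeightedRandomSampler(
            weights=dataset.weights, num_samples=len(dataset), replacement=True
        )
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
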
    def create_dataset(self, split_group):
        """
        Gets the dataset from the paths and labels in the json.
        Arguments:
            split_group(str): One of ['train'|'dev'|'test'].
        Returns:
            The dataset as a list of dicts, each with img paths, labels,
            and additional information regarding the exam and participant
        """
        self.corrupted_paths = self.CORRUPTED_PATHS["paths"]
        self.corrupted_series = self.CORRUPTED_PATHS["series"]
        # self.risk_factor_vectorizer = NLSTRiskFactorVectorizer(self.args)

        if self.args.assign_splits:
            np.random.seed(self.args.cross_val_seed)
            self.assign_splits(self.metadata_json)

        dataset = []

        for mrn_row in tqdm(self.metadata_json, position=0):
            pid, split, exams, pt_metadata = (
                mrn_row["pid"],
                mrn_row["split"],
                mrn_row["accessions"],
                mrn_row["pt_metadata"],
            )

            if split != split_group:
                continue

            for exam_dict in exams:

                if self.args.use_only_thin_cuts_for_ct and split_group in [
                    "train",
                    "dev",
                ]:
                    thinnest_series_id = self.get_thinnest_cut(exam_dict)

                elif split == "test" and self.args.assign_splits:
                    thinnest_series_id = self.get_thinnest_cut(exam_dict)

                elif split == "test":
                    google_series = list(self.GOOGLE_SPLITS[pid]["exams"])
                    nlst_series = list(exam_dict["image_series"].keys())
                    thinnest_series_id = [s for s in nlst_series if s in google_series]
                    assert len(thinnest_series_id) < 2
                    if len(thinnest_series_id) > 0:
                        thinnest_series_id = thinnest_series_id[0]
                    elif len(thinnest_series_id) == 0:
                        if self.args.assign_splits:
                            thinnest_series_id = self.get_thinnest_cut(exam_dict)
                        else:
                            continue

                for series_id, series_dict in exam_dict["image_series"].items():
                    if self.skip_sample(series_dict, pt_metadata):
                        continue

                    if self.args.use_only_thin_cuts_for_ct and (
                        series_id != thinnest_series_id
                    ):
                        continue

                    sample = self.get_volume_dict(
                        series_id, series_dict, exam_dict, pt_metadata, pid, split
                    )
                    if len(sample) == 0:
                        continue

                    dataset.append(sample)

        return dataset

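    # Editor's sketch (hypothetical record, not part of the original file):
    # minimal shape of one `metadata_json` entry, inferred from the keys that
    # create_dataset and get_volume_dict read. All values are placeholders.
    @staticmethod
    def _example_metadata_record():
        return {
            "pid": "100012",
            "split": "train",
            "pt_metadata": {"cen": [1]},  # plus the NLST clinical fields
            "accessions": [
                {
                    "exam": "100012_T0",
                    "accession_number": "100012_T0",
                    "screen_timepoint": 0,
                    "image_series": {
                        "1.2.840.0.001": {  # series instance UID
                            "paths": ["/path/to/slice1.png"],
                            "img_position": [0.0],
                            "pixel_spacing": [0.7, 0.7],
                            "slice_thickness": 2.0,
                            # plus other DICOM-derived fields, e.g. study_yr
                            "series_data": {"study_yr": [0]},
                        }
                    },
                }
            ],
        }
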
    def get_thinnest_cut(self, exam_dict):
        # the annotated volume might not be the thinnest cut, and multiple
        # volumes can share the same slice count, so: use the annotated series
        # if available, otherwise fall back to the series with the most slices
        possibly_annotated_series = [
            s in self.annotations_metadata
            for s in list(exam_dict["image_series"].keys())
        ]
        series_lengths = [
            len(exam_dict["image_series"][series_id]["paths"])
            for series_id in exam_dict["image_series"].keys()
        ]
        thinnest_series_len = max(series_lengths)
        thinnest_series_id = [
            k
            for k, v in exam_dict["image_series"].items()
            if len(v["paths"]) == thinnest_series_len
        ]
        if any(possibly_annotated_series):
            thinnest_series_id = list(exam_dict["image_series"].keys())[
                possibly_annotated_series.index(True)
            ]
        else:
            thinnest_series_id = thinnest_series_id[0]
        return thinnest_series_id

    def skip_sample(self, series_dict, pt_metadata):
        series_data = series_dict["series_data"]
        # check if screen is localizer screen or not enough images
        is_localizer = self.is_localizer(series_data)

        # check if restricting to specific slice thicknesses
        slice_thickness = series_data["reconthickness"][0]
        wrong_thickness = (self.args.slice_thickness_filter is not None) and (
            slice_thickness not in self.args.slice_thickness_filter
        )

        # check if label info is missing
        screen_timepoint = series_data["study_yr"][0]
        bad_label = not self.check_label(pt_metadata, screen_timepoint)

        # check if the label, when computable, is invalid
        if not bad_label:
            y, _, _, time_at_event = self.get_label(pt_metadata, screen_timepoint)
            invalid_label = (y == -1) or (time_at_event < 0)
        else:
            invalid_label = False

        insufficient_slices = len(series_dict["paths"]) < self.args.min_num_images

        return (
            is_localizer
            or wrong_thickness
            or bad_label
            or invalid_label
            or insufficient_slices
        )

    def get_volume_dict(
        self, series_id, series_dict, exam_dict, pt_metadata, pid, split
    ):
        img_paths = series_dict["paths"]
        slice_locations = series_dict["img_position"]
        series_data = series_dict["series_data"]
        device = series_data["manufacturer"][0]
        screen_timepoint = series_data["study_yr"][0]
        assert screen_timepoint == exam_dict["screen_timepoint"]

        if series_id in self.corrupted_series:
            if any([path in self.corrupted_paths for path in img_paths]):
                uncorrupted_imgs = np.where(
                    [path not in self.corrupted_paths for path in img_paths]
                )[0]
                img_paths = np.array(img_paths)[uncorrupted_imgs].tolist()
                slice_locations = np.array(slice_locations)[uncorrupted_imgs].tolist()

        sorted_img_paths, sorted_slice_locs = self.order_slices(
            img_paths, slice_locations
        )

        y, y_seq, y_mask, time_at_event = self.get_label(pt_metadata, screen_timepoint)

        exam_int = int(
            "{}{}{}".format(
                int(pid), int(screen_timepoint), int(series_id.split(".")[-1][-3:])
            )
        )
        sample = {
            "paths": sorted_img_paths,
            "slice_locations": sorted_slice_locs,
            "y": int(y),
            "time_at_event": time_at_event,
            "y_seq": y_seq,
            "y_mask": y_mask,
            "exam_str": "{}_{}".format(exam_dict["exam"], series_id),
            "exam": exam_int,
            "accession": exam_dict["accession_number"],
            "series": series_id,
            "study": series_data["studyuid"][0],
            "screen_timepoint": screen_timepoint,
            "pid": pid,
            "device": device,
            "institution": pt_metadata["cen"][0],
            "cancer_laterality": self.get_cancer_side(pt_metadata),
            "num_original_slices": len(series_dict["paths"]),
            "pixel_spacing": series_dict["pixel_spacing"]
            + [series_dict["slice_thickness"]],
            "slice_thickness": self.get_slice_thickness_class(
                series_dict["slice_thickness"]
            ),
        }

        if self.args.use_risk_factors:
            sample["risk_factors"] = self.get_risk_factors(
                pt_metadata, screen_timepoint, return_dict=False
            )

        return sample

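    # Editor's sketch (illustrative, not part of the original file): how the
    # integer exam id above is composed from the pid, the screening timepoint,
    # and the last three characters of the series UID's final component.
    @staticmethod
    def _example_exam_int(pid="100012", screen_timepoint=1, series_id="1.2.840.305"):
        return int(
            "{}{}{}".format(
                int(pid), int(screen_timepoint), int(series_id.split(".")[-1][-3:])
            )
        )  # -> 1000121305
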
    def check_label(self, pt_metadata, screen_timepoint):
        valid_days_since_rand = (
            pt_metadata["scr_days{}".format(screen_timepoint)][0] > -1
        )
        valid_days_to_cancer = pt_metadata["candx_days"][0] > -1
        valid_followup = pt_metadata["fup_days"][0] > -1
        return valid_days_since_rand and (valid_days_to_cancer or valid_followup)

    def get_label(self, pt_metadata, screen_timepoint):
        days_since_rand = pt_metadata["scr_days{}".format(screen_timepoint)][0]
        days_to_cancer_since_rand = pt_metadata["candx_days"][0]
        days_to_cancer = days_to_cancer_since_rand - days_since_rand
        years_to_cancer = (
            int(days_to_cancer // 365) if days_to_cancer_since_rand > -1 else 100
        )  # 100 acts as a sentinel for "no recorded cancer diagnosis"
        days_to_last_followup = int(pt_metadata["fup_days"][0] - days_since_rand)
        years_to_last_followup = days_to_last_followup // 365
        y = years_to_cancer < self.args.max_followup
        y_seq = np.zeros(self.args.max_followup)
        cancer_timepoint = pt_metadata["cancyr"][0]
        if y:
            if years_to_cancer > -1:
                assert screen_timepoint <= cancer_timepoint
            time_at_event = years_to_cancer
            y_seq[years_to_cancer:] = 1
        else:
            time_at_event = min(years_to_last_followup, self.args.max_followup - 1)
        y_mask = np.array(
            [1] * (time_at_event + 1)
            + [0] * (self.args.max_followup - (time_at_event + 1))
        )
        assert len(y_mask) == self.args.max_followup
        return y, y_seq.astype("float64"), y_mask.astype("float64"), time_at_event

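    # Editor's sketch (illustrative, not part of the original file): worked
    # example of the targets get_label builds, assuming max_followup = 6 and a
    # diagnosis 2 years after the scan.
    @staticmethod
    def _example_survival_targets(years_to_cancer=2, max_followup=6):
        y_seq = np.zeros(max_followup)
        y_seq[years_to_cancer:] = 1  # [0, 0, 1, 1, 1, 1]: positive from year 2 on
        time_at_event = years_to_cancer  # event observed in year 2
        y_mask = np.array(
            [1] * (time_at_event + 1)
            + [0] * (max_followup - (time_at_event + 1))
        )  # [1, 1, 1, 0, 0, 0]: only observed years contribute to the loss
        return y_seq, y_mask, time_at_event
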
    def is_localizer(self, series_dict):
        is_localizer = (
            (series_dict["imageclass"][0] == 0)
            or ("LOCALIZER" in series_dict["imagetype"][0])
            or ("TOP" in series_dict["imagetype"][0])
        )
        return is_localizer

    def get_cancer_side(self, pt_metadata):
        """
        Return whether the cancer is in the right lung, left lung, or elsewhere,
        as a binary [right, left, other] vector.

        right: (rhil, right hilum), (rlow, right lower lobe), (rmid, right middle lobe), (rmsb, right main stem), (rup, right upper lobe)
        left: (lhil, left hilum), (llow, left lower lobe), (lmsb, left main stem), (lup, left upper lobe), (lin, lingula)
        other: (med, mediastinum), (oth, other), (unk, unknown), (car, carina)
        """
        right_keys = ["locrhil", "locrlow", "locrmid", "locrmsb", "locrup"]
        left_keys = ["loclup", "loclmsb", "locllow", "loclhil", "loclin"]
        other_keys = ["loccar", "locmed", "locoth", "locunk"]

        right = any([pt_metadata[key][0] > 0 for key in right_keys])
        left = any([pt_metadata[key][0] > 0 for key in left_keys])
        other = any([pt_metadata[key][0] > 0 for key in other_keys])

        return np.array([int(right), int(left), int(other)])

    def order_slices(self, img_paths, slice_locations):
        sorted_ids = np.argsort(slice_locations)
        sorted_img_paths = np.array(img_paths)[sorted_ids].tolist()
        sorted_slice_locs = np.sort(slice_locations).tolist()

        if not sorted_img_paths[0].startswith(self.args.img_dir):
            sorted_img_paths = [
                self.args.img_dir
                + path[path.find("nlst-ct-png") + len("nlst-ct-png") :]
                for path in sorted_img_paths
            ]
        if (
            self.args.img_file_type == "dicom"
        ):  # NOTE: removing the file extension affects how get_ct_annotations maps paths to annotations
            sorted_img_paths = [
                path.replace("nlst-ct-png", "nlst-ct").replace(".png", "")
                for path in sorted_img_paths
            ]

        return sorted_img_paths, sorted_slice_locs

    def get_risk_factors(self, pt_metadata, screen_timepoint, return_dict=False):
        age_at_randomization = pt_metadata["age"][0]
        days_since_randomization = pt_metadata["scr_days{}".format(screen_timepoint)][0]
        current_age = age_at_randomization + days_since_randomization // 365

        age_start_smoking = pt_metadata["smokeage"][0]
        age_quit_smoking = pt_metadata["age_quit"][0]
        years_smoking = pt_metadata["smokeyr"][0]
        is_smoker = pt_metadata["cigsmok"][0]

        years_since_quit_smoking = 0 if is_smoker else current_age - age_quit_smoking

        education = pt_metadata["educat"][0]  # -1 when missing; callers remap or filter

        race = pt_metadata["race"][0] if pt_metadata["race"][0] != -1 else 0
        race = 6 if pt_metadata["ethnic"][0] == 1 else race
        ethnicity = pt_metadata["ethnic"][0]

        weight = pt_metadata["weight"][0] if pt_metadata["weight"][0] != -1 else 0
        height = pt_metadata["height"][0] if pt_metadata["height"][0] != -1 else 0
        bmi = weight / (height**2) * 703 if height > 0 else 0  # 703 * lbs / inches^2

        prior_cancer_keys = [
            "cancblad",
            "cancbrea",
            "canccerv",
            "canccolo",
            "cancesop",
            "canckidn",
            "canclary",
            "canclung",
            "cancoral",
            "cancnasa",
            "cancpanc",
            "cancphar",
            "cancstom",
            "cancthyr",
            "canctran",
        ]
        cancer_hx = any([pt_metadata[key][0] == 1 for key in prior_cancer_keys])
        family_hx = any(
            [pt_metadata[key][0] == 1 for key in pt_metadata if key.startswith("fam")]
        )

        risk_factors = {
            "age": current_age,
            "race": race,
            "race_name": RACE_ID_KEYS.get(pt_metadata["race"][0], "UNK"),
            "ethnicity": ethnicity,
            "ethnicity_name": ETHNICITY_KEYS.get(ethnicity, "UNK"),
            "education": education,
            "bmi": bmi,
            "cancer_hx": cancer_hx,
            "family_lc_hx": family_hx,
            "copd": pt_metadata["diagcopd"][0],
            "is_smoker": is_smoker,
            "smoking_intensity": pt_metadata["smokeday"][0],
            "smoking_duration": pt_metadata["smokeyr"][0],
            "years_since_quit_smoking": years_since_quit_smoking,
            "weight": weight,
            "height": height,
            "gender": GENDER_KEYS.get(pt_metadata["gender"][0], "UNK"),
        }

        if return_dict:
            return risk_factors
        else:
            return np.array(
                [v for v in risk_factors.values() if not isinstance(v, str)]
            )

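    # Editor's note (illustrative, not part of the original file): with
    # return_dict=False the method returns only the numeric values, in dict
    # insertion order, dropping every string entry (race_name, ethnicity_name,
    # gender). A quick sanity check of that filtering:
    @staticmethod
    def _example_numeric_vector():
        risk_factors = {"age": 62, "race_name": "white", "bmi": 27.1}
        vec = np.array([v for v in risk_factors.values() if not isinstance(v, str)])
        return vec  # array([62. , 27.1])
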
    def assign_splits(self, meta):
        if self.args.split_type == "institution_split":
            self.assign_institutions_splits(meta)
        elif self.args.split_type == "random":
            for idx in range(len(meta)):
                meta[idx]["split"] = np.random.choice(
                    ["train", "dev", "test"], p=self.args.split_probs
                )

    def assign_institutions_splits(self, meta):
        institutions = set([m["pt_metadata"]["cen"][0] for m in meta])
        institutions = sorted(institutions)
        institute_to_split = {
            cen: np.random.choice(["train", "dev", "test"], p=self.args.split_probs)
            for cen in institutions
        }
        for idx in range(len(meta)):
            meta[idx]["split"] = institute_to_split[meta[idx]["pt_metadata"]["cen"][0]]

    @property
    def METADATA_FILENAME(self):
        return METADATA_FILENAME["google_test"]

    @property
    def CORRUPTED_PATHS(self):
        # note: re-reads the pickle on every access; create_dataset caches the
        # contents into self.corrupted_paths / self.corrupted_series
        return pickle.load(open(CORRUPTED_PATHS, "rb"))

    def get_summary_statement(self, dataset, split_group):
        summary = "Constructed NLST CT Cancer Risk {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
        class_balance = Counter([d["y"] for d in dataset])
        exams = set([d["exam"] for d in dataset])
        patients = set([d["pid"] for d in dataset])
        statement = summary.format(
            split_group, len(dataset), len(exams), len(patients), class_balance
        )
        statement += "\n" + "Censor Times: {}".format(
            Counter([d["time_at_event"] for d in dataset])
        )
        return statement

    @property
    def GOOGLE_SPLITS(self):
        return pickle.load(open(GOOGLE_SPLITS_FILENAME, "rb"))

    def get_ct_annotations(self, sample):
        # correct empty lists of annotations
        if sample["series"] in self.annotations_metadata:
            self.annotations_metadata[sample["series"]] = {
                k: v
                for k, v in self.annotations_metadata[sample["series"]].items()
                if len(v) > 0
            }

        if sample["series"] in self.annotations_metadata:
            # store annotation(s) data (x, y, width, height) for each slice
            if (
                self.args.img_file_type == "dicom"
            ):  # no file extension, so os.path.splitext would break the mapping
                sample["annotations"] = [
                    {
                        "image_annotations": self.annotations_metadata[
                            sample["series"]
                        ].get(os.path.basename(path), None)
                    }
                    for path in sample["paths"]
                ]
            else:  # expects a file extension to exist, so use os.path.splitext
                sample["annotations"] = [
                    {
                        "image_annotations": self.annotations_metadata[
                            sample["series"]
                        ].get(os.path.splitext(os.path.basename(path))[0], None)
                    }
                    for path in sample["paths"]
                ]
        else:
            sample["annotations"] = [
                {"image_annotations": None} for path in sample["paths"]
            ]
        return sample

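    # Editor's sketch (hypothetical, not part of the original file): shape of
    # one `annotations_metadata` entry as get_ct_annotations reads it, mapping
    # a series UID to per-slice bounding boxes keyed by file name (without
    # extension for non-DICOM inputs). Values are placeholders.
    @staticmethod
    def _example_annotations_entry():
        return {
            "1.2.840.0.001": {
                "slice_042": [{"x": 0.31, "y": 0.47, "width": 0.05, "height": 0.04}],
            }
        }
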
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        if self.args.use_annotations:
            sample = self.get_ct_annotations(sample)
        try:
            item = {}
            input_dict = self.get_images(sample["paths"], sample)

            x = input_dict["input"]

            if self.args.use_annotations:
                mask = torch.abs(input_dict["mask"])
                mask_area = mask.sum(dim=(-1, -2))
                item["volume_annotations"] = mask_area[0] / max(1, mask_area.sum())
                item["annotation_areas"] = mask_area[0] / (
                    mask.shape[-2] * mask.shape[-1]
                )
                mask_area = mask_area.unsqueeze(-1).unsqueeze(-1)
                mask_area[mask_area == 0] = 1
                item["image_annotations"] = mask / mask_area
                item["has_annotation"] = item["volume_annotations"].sum() > 0

            if self.args.use_risk_factors:
                item["risk_factors"] = sample["risk_factors"]

            item["x"] = x
            item["y"] = sample["y"]
            for key in CT_ITEM_KEYS:
                if key in sample:
                    item[key] = sample[key]

            return item
        except Exception:
            warnings.warn(LOAD_FAIL_MSG.format(sample["exam"], traceback.format_exc()))

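    # Editor's sketch (illustrative, not part of the original file): minimal
    # end-to-end usage, assuming `args` carries the fields read above
    # (dataset_file_path, num_images, img_size, use_annotations, ...).
    @staticmethod
    def _example_usage(args):
        dataset = NLST_Survival_Dataset(args, "train")
        loader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True)
        batch = next(iter(loader))
        # with resampling applied: (2, C, args.num_images, *args.img_size)
        return batch["x"].shape
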
    def get_images(self, paths, sample):
        """
        Returns a stack of transformed images by their absolute paths.
        If a cache is used, transformed images are loaded from it when
        available and saved to it otherwise.
        """
        out_dict = {}
        if self.args.fix_seed_for_multi_image_augmentations:
            sample["seed"] = np.random.randint(0, 2**32 - 1)

        # get images for multi image input
        s = copy.deepcopy(sample)
        input_dicts = []
        for e, path in enumerate(paths):
            if self.args.use_annotations:
                s["annotations"] = sample["annotations"][e]
            input_dicts.append(self.input_loader.get_image(path, s))

        images = [i["input"] for i in input_dicts]
        input_arr = self.reshape_images(images)  # (C, T, H, W)
        if self.args.use_annotations:
            masks = [i["mask"] for i in input_dicts]
            mask_arr = self.reshape_images(masks)

        # resample pixel spacing
        resample_now = self.args.resample_pixel_spacing_prob > np.random.uniform()
        if self.always_resample_pixel_spacing or resample_now:
            spacing = torch.tensor(sample["pixel_spacing"] + [1])
            input_arr = tio.ScalarImage(
                affine=torch.diag(spacing),
                # channels first, with the through-plane (slice) axis last
                tensor=input_arr.permute(0, 2, 3, 1),
            )
            input_arr = self.resample_transform(input_arr)
            input_arr = self.padding_transform(input_arr.data)

            if self.args.use_annotations:
                mask_arr = tio.ScalarImage(
                    affine=torch.diag(spacing),
                    tensor=mask_arr.permute(0, 2, 3, 1),
                )
                mask_arr = self.resample_transform(mask_arr)
                mask_arr = self.padding_transform(mask_arr.data)

        out_dict["input"] = input_arr.data.permute(0, 3, 1, 2)
        if self.args.use_annotations:
            out_dict["mask"] = mask_arr.data.permute(0, 3, 1, 2)

        return out_dict

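    # Editor's sketch (illustrative, not part of the original file): the
    # voxel-spacing normalization above in isolation. A (C, H, W, T) tensor,
    # as produced in get_images, is wrapped in a torchio ScalarImage whose
    # affine encodes the physical spacing, resampled to VOXEL_SPACING, then
    # cropped/padded to a fixed shape. The spacing and shapes are placeholders.
    @staticmethod
    def _example_resample(spacing_hwt=(0.7, 0.7, 2.0), shape=(1, 64, 64, 32)):
        image = tio.ScalarImage(
            affine=torch.diag(torch.tensor(list(spacing_hwt) + [1.0])),
            tensor=torch.rand(shape),
        )
        image = tio.transforms.Resample(target=VOXEL_SPACING)(image)
        padded = tio.transforms.CropOrPad(target_shape=(128, 128, 64), padding_mode=0)(
            image.data
        )
        return padded.shape  # torch.Size([1, 128, 128, 64])
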
    def reshape_images(self, images):
        images = [im.unsqueeze(0) for im in images]
        images = torch.cat(images, dim=0)
        # Convert from (T, C, H, W) to (C, T, H, W)
        images = images.permute(1, 0, 2, 3)
        return images

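    # Editor's note (illustrative, not part of the original file): the
    # cat+permute above is equivalent to stacking along a new time dimension.
    @staticmethod
    def _example_reshape_equivalence(num_slices=4, c=1, h=8, w=8):
        images = [torch.rand(c, h, w) for _ in range(num_slices)]
        stacked = torch.stack(images, dim=1)  # (C, T, H, W) in one call
        return stacked.shape  # torch.Size([1, 4, 8, 8])
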
    def get_slice_thickness_class(self, thickness):
        BINS = [1, 1.5, 2, 2.5]
        for i, tau in enumerate(BINS):
            if thickness <= tau:
                return i
        # thicknesses above 2.5mm fall into a fifth catch-all class
        if self.args.slice_thickness_filter is not None:
            raise ValueError("THICKNESS > 2.5")
        return 4


class NLST_for_PLCO(NLST_Survival_Dataset):
    """
    Dataset for risk factor-based risk model
    """

    def get_volume_dict(
        self, series_id, series_dict, exam_dict, pt_metadata, pid, split
    ):
        series_data = series_dict["series_data"]
        screen_timepoint = series_data["study_yr"][0]
        assert screen_timepoint == exam_dict["screen_timepoint"]

        y, y_seq, y_mask, time_at_event = self.get_label(pt_metadata, screen_timepoint)

        exam_int = int(
            "{}{}{}".format(
                int(pid), int(screen_timepoint), int(series_id.split(".")[-1][-3:])
            )
        )

        riskfactors = self.get_risk_factors(
            pt_metadata, screen_timepoint, return_dict=True
        )

        riskfactors["education"] = EDUCAT_LEVEL.get(riskfactors["education"], -1)
        riskfactors["race"] = RACE_ID_KEYS.get(pt_metadata["race"][0], -1)

        sample = {
            "y": int(y),
            "time_at_event": time_at_event,
            "y_seq": y_seq,
            "y_mask": y_mask,
            "exam_str": "{}_{}".format(exam_dict["exam"], series_id),
            "exam": exam_int,
            "accession": exam_dict["accession_number"],
            "series": series_id,
            "study": series_data["studyuid"][0],
            "screen_timepoint": screen_timepoint,
            "pid": pid,
        }
        sample.update(riskfactors)

        # drop exams with missing PLCO inputs
        if (
            riskfactors["education"] == -1
            or riskfactors["race"] == -1
            or pt_metadata["weight"][0] == -1
            or pt_metadata["height"][0] == -1
        ):
            return {}

        return sample


class NLST_for_PLCO_Screening(NLST_for_PLCO):
    def create_dataset(self, split_group):
        generated_lung_rads = pickle.load(
            open("/data/rsg/mammogram/NLST/nlst_acc2lungrads.p", "rb")
        )
        dataset = super().create_dataset(split_group)
        # get lung rads for each year
        pid2lungrads = {}
        for d in dataset:
            lungrads = generated_lung_rads[d["exam"]]
            if d["pid"] in pid2lungrads:
                pid2lungrads[d["pid"]][d["screen_timepoint"]] = lungrads
            else:
                pid2lungrads[d["pid"]] = {d["screen_timepoint"]: lungrads}
        plco_results_dataset = []
        for d in dataset:
            if len(pid2lungrads[d["pid"]]) < 3:
                continue
            is_third_screen = d["screen_timepoint"] == 2
            is_1yr_ca_free = (d["y"] and d["time_at_event"] > 0) or (not d["y"])
            if is_third_screen and is_1yr_ca_free:
                d["scr_group_coef"] = self.get_screening_group(pid2lungrads[d["pid"]])
                # shift age and smoking variables forward one year, matching
                # the 1-year cancer-free filter above
                for k in ["age", "years_since_quit_smoking", "smoking_duration"]:
                    d[k] = d[k] + 1
                plco_results_dataset.append(d)
        return plco_results_dataset

    def get_screening_group(self, lung_rads_dict):
        """doi:10.1001/jamanetworkopen.2019.0204 Table 1"""
        scr1, scr2, scr3 = lung_rads_dict[0], lung_rads_dict[1], lung_rads_dict[2]

        if all([not scr1, not scr2, not scr3]):
            return 0
        elif (not scr3) and ((not scr1) or (not scr2)):
            return 0.6554117
        elif ((not scr3) and all([scr1, scr2])) or (
            all([not scr1, not scr2]) and scr3
        ):
            return 0.9798233
        elif (
            (all([scr1, scr3]) and not scr2)
            or (not scr1 and all([scr2, scr3]))
            or all([scr1, scr2, scr3])
        ):
            return 2.1940610
        raise ValueError(
            "Screen {} has no equivalent PLCO group".format(lung_rads_dict)
        )
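
    # Editor's sketch (illustrative, not part of the original file): the group
    # coefficients above, keyed by which of the three screens were Lung-RADS
    # positive. For example, a single positive early screen with a negative
    # final screen maps to the 0.6554117 group.
    def _example_screening_groups(self):
        assert self.get_screening_group({0: False, 1: False, 2: False}) == 0
        assert self.get_screening_group({0: True, 1: False, 2: False}) == 0.6554117
        assert self.get_screening_group({0: True, 1: True, 2: True}) == 2.1940610
        return True

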
class NLST_Risk_Factor_Task(NLST_Survival_Dataset):
    """
    Dataset for risk factor-based risk model
    """

    def get_risk_factors(self, pt_metadata, screen_timepoint, return_dict=False):
        # relies on self.risk_factor_vectorizer; see the commented-out
        # NLSTRiskFactorVectorizer line in create_dataset above
        return self.risk_factor_vectorizer.get_risk_factors_for_sample(
            pt_metadata, screen_timepoint
        )