# sybil/datasets/nlst.py
import os
import copy
import json
import pickle
import traceback
import warnings
from collections import Counter

import numpy as np
import pydicom
import torch
import torch.nn.functional as F
import torchio as tio
from torch.utils import data
from tqdm import tqdm

from sybil.serie import Serie
from sybil.utils.loading import get_sample_loader
from sybil.datasets.utils import (
    METAFILE_NOTFOUND_ERR,
    LOAD_FAIL_MSG,
    VOXEL_SPACING,
)
from sybil.datasets.nlst_risk_factors import NLSTRiskFactorVectorizer

METADATA_FILENAME = {"google_test": "NLST/full_nlst_google.json"}

GOOGLE_SPLITS_FILENAME = (
    "/Mounts/rbg-storage1/datasets/NLST/Shetty_et_al(Google)/data_splits.p"
)

CORRUPTED_PATHS = "/Mounts/rbg-storage1/datasets/NLST/corrupted_img_paths.pkl"

CT_ITEM_KEYS = [
    "pid",
    "exam",
    "series",
    "y_seq",
    "y_mask",
    "time_at_event",
    "cancer_laterality",
    "has_annotation",
    "origin_dataset",
]

RACE_ID_KEYS = {
    1: "white",
    2: "black",
    3: "asian",
    4: "american_indian_alaskan",
    5: "native_hawaiian_pacific",
    6: "hispanic",
}
ETHNICITY_KEYS = {1: "Hispanic or Latino", 2: "Neither Hispanic nor Latino"}
GENDER_KEYS = {1: "Male", 2: "Female"}
EDUCAT_LEVEL = {
    1: 1,  # 8th grade = less than HS
    2: 1,  # 9th-11th grade = less than HS
    3: 2,  # HS grad
    4: 3,  # Post-HS
    5: 4,  # Some College
    6: 5,  # Bachelors = College Grad
    7: 6,  # Graduate School = Postgrad/Prof
}


class NLST_Survival_Dataset(data.Dataset):
    def __init__(self, args, split_group):
        """
        NLST Dataset.

        params: args - config.
        params: split_group - ['train'|'dev'|'test'].

        constructs: standard pytorch Dataset obj, which can be fed into a DataLoader for batching
        """
        super(NLST_Survival_Dataset, self).__init__()

        self.split_group = split_group
        self.args = args
        self._num_images = args.num_images  # number of slices in each volume
        self._max_followup = args.max_followup

        try:
            self.metadata_json = json.load(open(args.dataset_file_path, "r"))
        except Exception as e:
            raise Exception(METAFILE_NOTFOUND_ERR.format(args.dataset_file_path, e))

        self.input_loader = get_sample_loader(split_group, args)
        self.always_resample_pixel_spacing = split_group in ["dev", "test"]

        self.resample_transform = tio.transforms.Resample(target=VOXEL_SPACING)
        self.padding_transform = tio.transforms.CropOrPad(
            target_shape=tuple(args.img_size + [args.num_images]), padding_mode=0
        )

        if args.use_annotations:
            assert (
                self.args.region_annotations_filepath
            ), "ANNOTATIONS METADATA FILE NOT SPECIFIED"
            self.annotations_metadata = json.load(
                open(self.args.region_annotations_filepath, "r")
            )

        self.dataset = self.create_dataset(split_group)
        if len(self.dataset) == 0:
            return

        print(self.get_summary_statement(self.dataset, split_group))

        # inverse-frequency class weights, usable with a weighted sampler
        dist_key = "y"
        label_dist = [d[dist_key] for d in self.dataset]
        label_counts = Counter(label_dist)
        weight_per_label = 1.0 / len(label_counts)
        label_weights = {
            label: weight_per_label / count for label, count in label_counts.items()
        }

        print("Class counts are: {}".format(label_counts))
        print("Label weights are {}".format(label_weights))
        self.weights = [label_weights[d[dist_key]] for d in self.dataset]

    def create_dataset(self, split_group):
        """
        Gets the dataset from the paths and labels in the json.

        Arguments:
            split_group(str): One of ['train'|'dev'|'test'].

        Returns:
            The dataset as a list of sample dicts with img paths, label,
            and additional information regarding exam or participant.
        """
        self.corrupted_paths = self.CORRUPTED_PATHS["paths"]
        self.corrupted_series = self.CORRUPTED_PATHS["series"]
        # self.risk_factor_vectorizer = NLSTRiskFactorVectorizer(self.args)

        if self.args.assign_splits:
            np.random.seed(self.args.cross_val_seed)
            self.assign_splits(self.metadata_json)

        dataset = []

        for mrn_row in tqdm(self.metadata_json, position=0):
            pid, split, exams, pt_metadata = (
                mrn_row["pid"],
                mrn_row["split"],
                mrn_row["accessions"],
                mrn_row["pt_metadata"],
            )

            if split != split_group:
                continue

            for exam_dict in exams:

                if self.args.use_only_thin_cuts_for_ct and split_group in [
                    "train",
                    "dev",
                ]:
                    thinnest_series_id = self.get_thinnest_cut(exam_dict)

                elif split == "test" and self.args.assign_splits:
                    thinnest_series_id = self.get_thinnest_cut(exam_dict)

                elif split == "test":
                    # restrict test exams to the series used in the Google (Shetty et al.) split
                    google_series = list(self.GOOGLE_SPLITS[pid]["exams"])
                    nlst_series = list(exam_dict["image_series"].keys())
                    thinnest_series_id = [s for s in nlst_series if s in google_series]
                    assert len(thinnest_series_id) < 2
                    if len(thinnest_series_id) > 0:
                        thinnest_series_id = thinnest_series_id[0]
                    elif len(thinnest_series_id) == 0:
                        if self.args.assign_splits:
                            thinnest_series_id = self.get_thinnest_cut(exam_dict)
                        else:
                            continue

                for series_id, series_dict in exam_dict["image_series"].items():
                    if self.skip_sample(series_dict, pt_metadata):
                        continue

                    if self.args.use_only_thin_cuts_for_ct and (
                        series_id != thinnest_series_id
                    ):
                        continue

                    sample = self.get_volume_dict(
                        series_id, series_dict, exam_dict, pt_metadata, pid, split
                    )
                    if len(sample) == 0:
                        continue

                    dataset.append(sample)

        return dataset

    def get_thinnest_cut(self, exam_dict):
        # The annotated volume may not be the thinnest cut, and several volumes
        # can have the same number of slices, so: prefer the annotated series
        # if one exists, otherwise fall back to the thinnest cut (most slices).
        possibly_annotated_series = [
            s in self.annotations_metadata
            for s in list(exam_dict["image_series"].keys())
        ]
        series_lengths = [
            len(exam_dict["image_series"][series_id]["paths"])
            for series_id in exam_dict["image_series"].keys()
        ]
        thinnest_series_len = max(series_lengths)
        thinnest_series_id = [
            k
            for k, v in exam_dict["image_series"].items()
            if len(v["paths"]) == thinnest_series_len
        ]
        if any(possibly_annotated_series):
            thinnest_series_id = list(exam_dict["image_series"].keys())[
                possibly_annotated_series.index(True)
            ]
        else:
            thinnest_series_id = thinnest_series_id[0]
        return thinnest_series_id

    def skip_sample(self, series_dict, pt_metadata):
        series_data = series_dict["series_data"]
        # skip localizer scouts
        is_localizer = self.is_localizer(series_data)

        # check if restricting to specific slice thicknesses
        slice_thickness = series_data["reconthickness"][0]
        wrong_thickness = (self.args.slice_thickness_filter is not None) and (
            slice_thickness not in self.args.slice_thickness_filter
        )

        # check that the label can be computed (required fields are present)
        screen_timepoint = series_data["study_yr"][0]
        bad_label = not self.check_label(pt_metadata, screen_timepoint)

        # check that the computed label is valid
        if not bad_label:
            y, _, _, time_at_event = self.get_label(pt_metadata, screen_timepoint)
            invalid_label = (y == -1) or (time_at_event < 0)
        else:
            invalid_label = False

        insufficient_slices = len(series_dict["paths"]) < self.args.min_num_images

        return (
            is_localizer
            or wrong_thickness
            or bad_label
            or invalid_label
            or insufficient_slices
        )

    def get_volume_dict(
        self, series_id, series_dict, exam_dict, pt_metadata, pid, split
    ):
        img_paths = series_dict["paths"]
        slice_locations = series_dict["img_position"]
        series_data = series_dict["series_data"]
        device = series_data["manufacturer"][0]
        screen_timepoint = series_data["study_yr"][0]
        assert screen_timepoint == exam_dict["screen_timepoint"]

        # drop slices known to be corrupted
        if series_id in self.corrupted_series:
            if any([path in self.corrupted_paths for path in img_paths]):
                uncorrupted_imgs = np.where(
                    [path not in self.corrupted_paths for path in img_paths]
                )[0]
                img_paths = np.array(img_paths)[uncorrupted_imgs].tolist()
                slice_locations = np.array(slice_locations)[uncorrupted_imgs].tolist()

        sorted_img_paths, sorted_slice_locs = self.order_slices(
            img_paths, slice_locations
        )

        y, y_seq, y_mask, time_at_event = self.get_label(pt_metadata, screen_timepoint)

        # unique exam id: concatenation of pid, screening timepoint, and the
        # last three digits of the series UID
        exam_int = int(
            "{}{}{}".format(
                int(pid), int(screen_timepoint), int(series_id.split(".")[-1][-3:])
            )
        )
        sample = {
            "paths": sorted_img_paths,
            "slice_locations": sorted_slice_locs,
            "y": int(y),
            "time_at_event": time_at_event,
            "y_seq": y_seq,
            "y_mask": y_mask,
            "exam_str": "{}_{}".format(exam_dict["exam"], series_id),
            "exam": exam_int,
            "accession": exam_dict["accession_number"],
            "series": series_id,
            "study": series_data["studyuid"][0],
            "screen_timepoint": screen_timepoint,
            "pid": pid,
            "device": device,
            "institution": pt_metadata["cen"][0],
            "cancer_laterality": self.get_cancer_side(pt_metadata),
            "num_original_slices": len(series_dict["paths"]),
            "pixel_spacing": series_dict["pixel_spacing"]
            + [series_dict["slice_thickness"]],
            "slice_thickness": self.get_slice_thickness_class(
                series_dict["slice_thickness"]
            ),
        }

        if self.args.use_risk_factors:
            sample["risk_factors"] = self.get_risk_factors(
                pt_metadata, screen_timepoint, return_dict=False
            )

        return sample

    def check_label(self, pt_metadata, screen_timepoint):
        valid_days_since_rand = (
            pt_metadata["scr_days{}".format(screen_timepoint)][0] > -1
        )
        valid_days_to_cancer = pt_metadata["candx_days"][0] > -1
        valid_followup = pt_metadata["fup_days"][0] > -1
        return (valid_days_since_rand) and (valid_days_to_cancer or valid_followup)

    def get_label(self, pt_metadata, screen_timepoint):
        days_since_rand = pt_metadata["scr_days{}".format(screen_timepoint)][0]
        days_to_cancer_since_rand = pt_metadata["candx_days"][0]
        days_to_cancer = days_to_cancer_since_rand - days_since_rand
        years_to_cancer = (
            int(days_to_cancer // 365) if days_to_cancer_since_rand > -1 else 100
        )
        days_to_last_followup = int(pt_metadata["fup_days"][0] - days_since_rand)
        years_to_last_followup = days_to_last_followup // 365
        y = years_to_cancer < self.args.max_followup
        y_seq = np.zeros(self.args.max_followup)
        cancer_timepoint = pt_metadata["cancyr"][0]
        if y:
            if years_to_cancer > -1:
                assert screen_timepoint <= cancer_timepoint
                time_at_event = years_to_cancer
                y_seq[years_to_cancer:] = 1
        else:
            time_at_event = min(years_to_last_followup, self.args.max_followup - 1)
        y_mask = np.array(
            [1] * (time_at_event + 1)
            + [0] * (self.args.max_followup - (time_at_event + 1))
        )
        assert len(y_mask) == self.args.max_followup
        return y, y_seq.astype("float64"), y_mask.astype("float64"), time_at_event

    def is_localizer(self, series_dict):
        is_localizer = (
            (series_dict["imageclass"][0] == 0)
            or ("LOCALIZER" in series_dict["imagetype"][0])
            or ("TOP" in series_dict["imagetype"][0])
        )
        return is_localizer

    def get_cancer_side(self, pt_metadata):
        """
        Return whether the cancer is in the right lung, left lung, or elsewhere.

        right: (rhil, right hilum), (rlow, right lower lobe), (rmid, right middle lobe), (rmsb, right main stem), (rup, right upper lobe)
        left: (lhil, left hilum), (llow, left lower lobe), (lmsb, left main stem), (lup, left upper lobe), (lin, lingula)
        else: (med, mediastinum), (oth, other), (unk, unknown), (car, carina)
        """
        right_keys = ["locrhil", "locrlow", "locrmid", "locrmsb", "locrup"]
        left_keys = ["loclup", "loclmsb", "locllow", "loclhil", "loclin"]
        other_keys = ["loccar", "locmed", "locoth", "locunk"]

        right = any([pt_metadata[key][0] > 0 for key in right_keys])
        left = any([pt_metadata[key][0] > 0 for key in left_keys])
        other = any([pt_metadata[key][0] > 0 for key in other_keys])

        return np.array([int(right), int(left), int(other)])

    def order_slices(self, img_paths, slice_locations):
        sorted_ids = np.argsort(slice_locations)
        sorted_img_paths = np.array(img_paths)[sorted_ids].tolist()
        sorted_slice_locs = np.sort(slice_locations).tolist()

        # remap paths onto the local image directory if needed
        if not sorted_img_paths[0].startswith(self.args.img_dir):
            sorted_img_paths = [
                self.args.img_dir
                + path[path.find("nlst-ct-png") + len("nlst-ct-png") :]
                for path in sorted_img_paths
            ]
        if (
            self.args.img_file_type == "dicom"
        ):  # ! NOTE: removing the file extension affects how get_ct_annotations maps paths to annotations
            sorted_img_paths = [
                path.replace("nlst-ct-png", "nlst-ct").replace(".png", "")
                for path in sorted_img_paths
            ]

        return sorted_img_paths, sorted_slice_locs

    def get_risk_factors(self, pt_metadata, screen_timepoint, return_dict=False):
        age_at_randomization = pt_metadata["age"][0]
        days_since_randomization = pt_metadata["scr_days{}".format(screen_timepoint)][0]
        current_age = age_at_randomization + days_since_randomization // 365

        age_start_smoking = pt_metadata["smokeage"][0]
        age_quit_smoking = pt_metadata["age_quit"][0]
        years_smoking = pt_metadata["smokeyr"][0]
        is_smoker = pt_metadata["cigsmok"][0]

        years_since_quit_smoking = 0 if is_smoker else current_age - age_quit_smoking

        education = pt_metadata["educat"][0]  # -1 sentinel is passed through unchanged

        race = pt_metadata["race"][0] if pt_metadata["race"][0] != -1 else 0
        race = 6 if pt_metadata["ethnic"][0] == 1 else race
        ethnicity = pt_metadata["ethnic"][0]

        weight = pt_metadata["weight"][0] if pt_metadata["weight"][0] != -1 else 0
        height = pt_metadata["height"][0] if pt_metadata["height"][0] != -1 else 0
        bmi = weight / (height**2) * 703 if height > 0 else 0  # height in inches, weight in lbs

        prior_cancer_keys = [
            "cancblad",
            "cancbrea",
            "canccerv",
            "canccolo",
            "cancesop",
            "canckidn",
            "canclary",
            "canclung",
            "cancoral",
            "cancnasa",
            "cancpanc",
            "cancphar",
            "cancstom",
            "cancthyr",
            "canctran",
        ]
        cancer_hx = any([pt_metadata[key][0] == 1 for key in prior_cancer_keys])
        family_hx = any(
            [pt_metadata[key][0] == 1 for key in pt_metadata if key.startswith("fam")]
        )

        risk_factors = {
            "age": current_age,
            "race": race,
            "race_name": RACE_ID_KEYS.get(pt_metadata["race"][0], "UNK"),
            "ethnicity": ethnicity,
            "ethnicity_name": ETHNICITY_KEYS.get(ethnicity, "UNK"),
            "education": education,
            "bmi": bmi,
            "cancer_hx": cancer_hx,
            "family_lc_hx": family_hx,
            "copd": pt_metadata["diagcopd"][0],
            "is_smoker": is_smoker,
            "smoking_intensity": pt_metadata["smokeday"][0],
            "smoking_duration": pt_metadata["smokeyr"][0],
            "years_since_quit_smoking": years_since_quit_smoking,
            "weight": weight,
            "height": height,
            "gender": GENDER_KEYS.get(pt_metadata["gender"][0], "UNK"),
        }

        if return_dict:
            return risk_factors
        else:
            # vector form drops the string-valued entries (race_name, ethnicity_name, gender)
            return np.array(
                [v for v in risk_factors.values() if not isinstance(v, str)]
            )

    def assign_splits(self, meta):
        if self.args.split_type == "institution_split":
            self.assign_institutions_splits(meta)
        elif self.args.split_type == "random":
            for idx in range(len(meta)):
                meta[idx]["split"] = np.random.choice(
                    ["train", "dev", "test"], p=self.args.split_probs
                )

    def assign_institutions_splits(self, meta):
        institutions = set([m["pt_metadata"]["cen"][0] for m in meta])
        institutions = sorted(institutions)
        institute_to_split = {
            cen: np.random.choice(["train", "dev", "test"], p=self.args.split_probs)
            for cen in institutions
        }
        for idx in range(len(meta)):
            meta[idx]["split"] = institute_to_split[meta[idx]["pt_metadata"]["cen"][0]]

    @property
    def METADATA_FILENAME(self):
        return METADATA_FILENAME["google_test"]

    @property
    def CORRUPTED_PATHS(self):
        return pickle.load(open(CORRUPTED_PATHS, "rb"))

    def get_summary_statement(self, dataset, split_group):
        summary = "Constructed NLST CT Cancer Risk {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
        class_balance = Counter([d["y"] for d in dataset])
        exams = set([d["exam"] for d in dataset])
        patients = set([d["pid"] for d in dataset])
        statement = summary.format(
            split_group, len(dataset), len(exams), len(patients), class_balance
        )
        statement += "\n" + "Censor Times: {}".format(
            Counter([d["time_at_event"] for d in dataset])
        )
        return statement

    @property
    def GOOGLE_SPLITS(self):
        return pickle.load(open(GOOGLE_SPLITS_FILENAME, "rb"))

    def get_ct_annotations(self, sample):
        # correct empty lists of annotations
        if sample["series"] in self.annotations_metadata:
            self.annotations_metadata[sample["series"]] = {
                k: v
                for k, v in self.annotations_metadata[sample["series"]].items()
                if len(v) > 0
            }

        if sample["series"] in self.annotations_metadata:
            # store annotation(s) data (x, y, width, height) for each slice
            if (
                self.args.img_file_type == "dicom"
            ):  # DICOM paths have no file extension, so os.path.splitext would misbehave
                sample["annotations"] = [
                    {
                        "image_annotations": self.annotations_metadata[
                            sample["series"]
                        ].get(os.path.basename(path), None)
                    }
                    for path in sample["paths"]
                ]
            else:  # a file extension is expected, so strip it with os.path.splitext
                sample["annotations"] = [
                    {
                        "image_annotations": self.annotations_metadata[
                            sample["series"]
                        ].get(os.path.splitext(os.path.basename(path))[0], None)
                    }
                    for path in sample["paths"]
                ]
        else:
            sample["annotations"] = [
                {"image_annotations": None} for path in sample["paths"]
            ]
        return sample

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        if self.args.use_annotations:
            sample = self.get_ct_annotations(sample)
        try:
            item = {}
            input_dict = self.get_images(sample["paths"], sample)

            x = input_dict["input"]

            if self.args.use_annotations:
                mask = torch.abs(input_dict["mask"])
                mask_area = mask.sum(dim=(-1, -2))
                item["volume_annotations"] = mask_area[0] / max(1, mask_area.sum())
                item["annotation_areas"] = mask_area[0] / (
                    mask.shape[-2] * mask.shape[-1]
                )
                mask_area = mask_area.unsqueeze(-1).unsqueeze(-1)
                mask_area[mask_area == 0] = 1
                item["image_annotations"] = mask / mask_area
                item["has_annotation"] = item["volume_annotations"].sum() > 0

            if self.args.use_risk_factors:
                item["risk_factors"] = sample["risk_factors"]

            item["x"] = x
            item["y"] = sample["y"]
            for key in CT_ITEM_KEYS:
                if key in sample:
                    item[key] = sample[key]

            return item
        except Exception:
            warnings.warn(LOAD_FAIL_MSG.format(sample["exam"], traceback.format_exc()))

    def get_images(self, paths, sample):
        """
        Returns a stack of transformed images by their absolute paths.
        If cache is used, transformed images will be loaded if available,
        and saved to cache if not.
        """
        out_dict = {}
        if self.args.fix_seed_for_multi_image_augmentations:
            sample["seed"] = np.random.randint(0, 2**32 - 1)

        # get images for multi-image input
        s = copy.deepcopy(sample)
        input_dicts = []
        for e, path in enumerate(paths):
            if self.args.use_annotations:
                s["annotations"] = sample["annotations"][e]
            input_dicts.append(self.input_loader.get_image(path, s))

        images = [i["input"] for i in input_dicts]
        input_arr = self.reshape_images(images)
        if self.args.use_annotations:
            masks = [i["mask"] for i in input_dicts]
            mask_arr = self.reshape_images(masks)

        # resample pixel spacing
        resample_now = self.args.resample_pixel_spacing_prob > np.random.uniform()
        if self.always_resample_pixel_spacing or resample_now:
            spacing = torch.tensor(sample["pixel_spacing"] + [1])
            input_arr = tio.ScalarImage(
                affine=torch.diag(spacing),
                tensor=input_arr.permute(0, 2, 3, 1),
            )
            input_arr = self.resample_transform(input_arr)
            input_arr = self.padding_transform(input_arr.data)

            if self.args.use_annotations:
                mask_arr = tio.ScalarImage(
                    affine=torch.diag(spacing),
                    tensor=mask_arr.permute(0, 2, 3, 1),
                )
                mask_arr = self.resample_transform(mask_arr)
                mask_arr = self.padding_transform(mask_arr.data)

        # NOTE: the final permute assumes the (C, H, W, T) layout produced by
        # the resample branch above; reshape_images alone yields (C, T, H, W).
        out_dict["input"] = input_arr.data.permute(0, 3, 1, 2)
        if self.args.use_annotations:
            out_dict["mask"] = mask_arr.data.permute(0, 3, 1, 2)

        return out_dict

    def reshape_images(self, images):
        images = [im.unsqueeze(0) for im in images]
        images = torch.cat(images, dim=0)
        # Convert from (T, C, H, W) to (C, T, H, W)
        images = images.permute(1, 0, 2, 3)
        return images

    def get_slice_thickness_class(self, thickness):
        # bins: <=1 -> 0, <=1.5 -> 1, <=2 -> 2, <=2.5 -> 3, else 4
        BINS = [1, 1.5, 2, 2.5]
        for i, tau in enumerate(BINS):
            if thickness <= tau:
                return i
        if self.args.slice_thickness_filter is not None:
            raise ValueError("THICKNESS > 2.5")
        return 4


class NLST_for_PLCO(NLST_Survival_Dataset):
    """
    Dataset for risk factor-based risk model
    """

    def get_volume_dict(
        self, series_id, series_dict, exam_dict, pt_metadata, pid, split
    ):
        series_data = series_dict["series_data"]
        screen_timepoint = series_data["study_yr"][0]
        assert screen_timepoint == exam_dict["screen_timepoint"]

        y, y_seq, y_mask, time_at_event = self.get_label(pt_metadata, screen_timepoint)

        exam_int = int(
            "{}{}{}".format(
                int(pid), int(screen_timepoint), int(series_id.split(".")[-1][-3:])
            )
        )

        riskfactors = self.get_risk_factors(
            pt_metadata, screen_timepoint, return_dict=True
        )

        riskfactors["education"] = EDUCAT_LEVEL.get(riskfactors["education"], -1)
        riskfactors["race"] = RACE_ID_KEYS.get(pt_metadata["race"][0], -1)

        sample = {
            "y": int(y),
            "time_at_event": time_at_event,
            "y_seq": y_seq,
            "y_mask": y_mask,
            "exam_str": "{}_{}".format(exam_dict["exam"], series_id),
            "exam": exam_int,
            "accession": exam_dict["accession_number"],
            "series": series_id,
            "study": series_data["studyuid"][0],
            "screen_timepoint": screen_timepoint,
            "pid": pid,
        }
        sample.update(riskfactors)

        # drop samples with missing PLCO inputs
        if (
            riskfactors["education"] == -1
            or riskfactors["race"] == -1
            or pt_metadata["weight"][0] == -1
            or pt_metadata["height"][0] == -1
        ):
            return {}

        return sample


class NLST_for_PLCO_Screening(NLST_for_PLCO):
    def create_dataset(self, split_group):
        generated_lung_rads = pickle.load(
            open("/data/rsg/mammogram/NLST/nlst_acc2lungrads.p", "rb")
        )
        dataset = super().create_dataset(split_group)
        # collect the Lung-RADS result for each of a patient's screening years
        pid2lungrads = {}
        for d in dataset:
            lungrads = generated_lung_rads[d["exam"]]
            if d["pid"] in pid2lungrads:
                pid2lungrads[d["pid"]][d["screen_timepoint"]] = lungrads
            else:
                pid2lungrads[d["pid"]] = {d["screen_timepoint"]: lungrads}
        plco_results_dataset = []
        for d in dataset:
            # keep only patients with all three screens
            if len(pid2lungrads[d["pid"]]) < 3:
                continue
            is_third_screen = d["screen_timepoint"] == 2
            is_1yr_ca_free = (d["y"] and d["time_at_event"] > 0) or (not d["y"])
            if is_third_screen and is_1yr_ca_free:
                d["scr_group_coef"] = self.get_screening_group(pid2lungrads[d["pid"]])
                # advance time-varying risk factors by one year
                for k in ["age", "years_since_quit_smoking", "smoking_duration"]:
                    d[k] = d[k] + 1
                plco_results_dataset.append(d)
            else:
                continue
        return plco_results_dataset

    def get_screening_group(self, lung_rads_dict):
        """doi:10.1001/jamanetworkopen.2019.0204 Table 1"""
        scr1, scr2, scr3 = lung_rads_dict[0], lung_rads_dict[1], lung_rads_dict[2]

        if all([not scr1, not scr2, not scr3]):
            return 0
        elif (not scr3) and ((not scr1) or (not scr2)):
            return 0.6554117
        elif ((not scr3) and all([scr1, scr2])) or (
            all([not scr1, not scr2]) and (scr3)
        ):
            return 0.9798233
        elif (
            (all([scr1, scr3]) and not scr2)
            or (not scr1 and all([scr2, scr3]))
            or (all([scr1, scr2, scr3]))
        ):
            return 2.1940610
        raise ValueError(
            "Screen {} has no equivalent PLCO group".format(lung_rads_dict)
        )


class NLST_Risk_Factor_Task(NLST_Survival_Dataset):
    """
    Dataset for risk factor-based risk model
    """

    def get_risk_factors(self, pt_metadata, screen_timepoint, return_dict=False):
        # NOTE: relies on self.risk_factor_vectorizer, whose construction is
        # commented out in NLST_Survival_Dataset.create_dataset; re-enable the
        # NLSTRiskFactorVectorizer line there before using this class.
        return self.risk_factor_vectorizer.get_risk_factors_for_sample(
            pt_metadata, screen_timepoint
        )