|
sybil/datasets/validation.py
|
|
import numpy as np
import torch
from torch.utils import data
import warnings
import json
import csv
import traceback
from collections import Counter
from sybil.augmentations import get_augmentations
from tqdm import tqdm
from sybil.serie import Serie
from sybil.datasets.utils import order_slices, METAFILE_NOTFOUND_ERR, LOAD_FAIL_MSG
from sybil.loaders.image_loaders import OpenCVLoader, DicomLoader

class CSVDataset(data.Dataset):
    """
    Dataset used for large validations, built from a CSV manifest of image series.
    """
|
|
    def __init__(self, args, split_group):
        '''
        params: args - config.
        params: split_group - ['train'|'dev'|'test'].

        constructs: standard PyTorch Dataset object, which can be fed to a DataLoader
        for batching (see the usage sketch after this method).
        '''
        super(CSVDataset, self).__init__()

        self.split_group = split_group
        self.args = args
        self._num_images = args.num_images  # number of slices in each volume
        self._max_followup = args.max_followup

        try:
            self.dataset_dicts = self.parse_csv_dataset(args.dataset_file_path)
        except Exception as e:
            raise Exception(METAFILE_NOTFOUND_ERR.format(args.dataset_file_path, e))

        augmentations = get_augmentations(split_group, args)
        if args.img_file_type == 'dicom':
            self.input_loader = DicomLoader(args.cache_path, augmentations, args)
        else:
            self.input_loader = OpenCVLoader(args.cache_path, augmentations, args)

        self.dataset = self.create_dataset(split_group)
        if len(self.dataset) == 0:
            return

        print(self.get_summary_statement(self.dataset, split_group))

        # Inverse-frequency class weights: each class gets an equal share of the total
        # weight, split evenly across its samples. self.weights can be passed to
        # torch.utils.data.WeightedRandomSampler for class-balanced batching.
        dist_key = 'y'
        label_dist = [d[dist_key] for d in self.dataset]
        label_counts = Counter(label_dist)
        weight_per_label = 1. / len(label_counts)
        label_weights = {
            label: weight_per_label / count for label, count in label_counts.items()
        }

        print("Class counts are: {}".format(label_counts))
        print("Label weights are {}".format(label_weights))
        self.weights = [label_weights[d[dist_key]] for d in self.dataset]
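    # Usage sketch (illustrative only, not part of the library API). The config fields
    # shown are the ones read directly by this class; get_augmentations and the loaders
    # read additional fields, so a real config has more entries. All values below are
    # hypothetical:
    #
    #   from argparse import Namespace
    #   from torch.utils.data import DataLoader, WeightedRandomSampler
    #
    #   args = Namespace(dataset_file_path='manifest.csv', img_file_type='dicom',
    #                    cache_path='cache/', num_images=200, max_followup=6,
    #                    assign_splits=False, cross_val_seed=0, ...)
    #   dataset = CSVDataset(args, 'test')
    #   sampler = WeightedRandomSampler(dataset.weights, num_samples=len(dataset))
    #   loader = DataLoader(dataset, batch_size=1, sampler=sampler)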
|
|
    def parse_csv_dataset(self, file_path):
        """
        Convert a CSV file into a list of dictionaries, one per (patient, exam, series),
        structured like:
        [
            {
                'patient_id': str,
                'split': str,
                'exam_id': str,
                'series_id': str,
                'ever_has_future_cancer': bool,
                'years_to_cancer': int,
                'years_to_last_negative_followup': int,
                'paths': [str],
                'slice_locations': [str]
            }
        ]

        Parameters
        ----------
        file_path : str
            path to csv file

        Returns
        -------
        list
            list of patient cases in the above structure
        """
|
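        # Illustrative manifest layout (hypothetical paths and values); one row per
        # slice, grouped into a series by (patient_id, exam_id, series_id):
        #
        #   patient_id,exam_id,series_id,split,ever_has_future_cancer,years_to_cancer,years_to_last_negative_followup,file_path,slice_position
        #   P001,E1,S1,test,True,2,0,data/P001/E1/S1/slice_000.dcm,0
        #   P001,E1,S1,test,True,2,0,data/P001/E1/S1/slice_001.dcm,1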
|
        dataset_dicts = {}
        with open(file_path, 'r') as f:
            _reader = csv.DictReader(f)
            for _row in _reader:
                row = {k.encode('ascii', 'ignore').decode(): v.encode('ascii', 'ignore').decode() for k, v in _row.items()}
                patient_id, exam_id, series_id = row['patient_id'], row['exam_id'], row['series_id']
                unique_id = '{}_{}_{}'.format(patient_id, exam_id, series_id)
                if unique_id in dataset_dicts:
                    dataset_dicts[unique_id]['paths'].append(row['file_path'])
                    dataset_dicts[unique_id]['slice_locations'].append(row['slice_position'])
                else:
                    dataset_dicts[unique_id] = {
                        'unique_id': unique_id,
                        'patient_id': patient_id,
                        'exam_id': exam_id,
                        'series_id': series_id,
                        'split': row['split'],
                        'ever_has_future_cancer': row['ever_has_future_cancer'],
                        'years_to_cancer': row['years_to_cancer'],
                        'years_to_last_negative_followup': row['years_to_last_negative_followup'],
                        'paths': [row['file_path']],
                        'slice_locations': [row['slice_position']]
                    }

        dataset_dicts = list(dataset_dicts.values())
        return dataset_dicts
|
|
    def create_dataset(self, split_group):
        """
        Gets the dataset from the paths and labels parsed out of the CSV.
        Arguments:
            split_group(str): One of ['train'|'dev'|'test'].
        Returns:
            The dataset as a list of dictionaries, one per series, holding the
            Serie object, its label, and additional exam/participant information.
        """
|
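        # Each appended entry looks roughly like (illustrative values):
        #   {'serie': <Serie>, 'exam': 'P001_E1_S1', 'patient_id': 'P001',
        #    'y': True, 'time_at_event': 2}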
|
        if self.args.assign_splits:
            # Optional split re-assignment hook; assign_splits and metadata_json are
            # not defined in this class and would need to be provided elsewhere.
            np.random.seed(self.args.cross_val_seed)
            self.assign_splits(self.metadata_json)

        dataset = []

        for mrn_row in tqdm(self.dataset_dicts, position=0):

            # CSV values are read as strings; normalize to the documented types
            # (assumes a True/False- or 0/1-style label column).
            label = str(mrn_row['ever_has_future_cancer']).strip().lower() in ('1', 'true')
            censor_time = int(mrn_row['years_to_cancer']) if label else int(mrn_row['years_to_last_negative_followup'])
            paths = order_slices(mrn_row['paths'], mrn_row['slice_locations'])
            try:
                series_object = Serie(
                    paths,
                    label,
                    censor_time,
                    self.args.img_file_type,
                    mrn_row['split']
                )
            except Exception:
                continue

            if self.skip_sample(series_object, mrn_row, split_group):
                continue

            dataset.append({
                'serie': series_object,
                'exam': mrn_row['unique_id'],
                # fields used by get_summary_statement and the class-balance weights
                'patient_id': mrn_row['patient_id'],
                'y': label,
                'time_at_event': censor_time
            })

        return dataset
|
|
    def skip_sample(self, series_object, row, split_group):
        if row['split'] != split_group:
            return True

        if not series_object.has_label():
            return True

        return False

    def get_summary_statement(self, dataset, split_group):
        summary = "Constructed Sybil Cancer Risk {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
        class_balance = Counter([d['y'] for d in dataset])
        exams = set([d['exam'] for d in dataset])
        patients = set([d['patient_id'] for d in dataset])
        statement = summary.format(split_group, len(dataset), len(exams), len(patients), class_balance)
        statement += "\n" + "Censor Times: {}".format(Counter([d['time_at_event'] for d in dataset]))
        return statement
|
|
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        serie = sample['serie']
        try:
            labels = serie.get_label()
            item = {
                'x': serie.get_volume(),
                'y': labels.y,
                'y_seq': labels.y_seq,
                'y_mask': labels.y_mask,
                'time_at_event': labels.censor_time,
                'exam': sample['exam']
            }
            return item
        except Exception:
            warnings.warn(LOAD_FAIL_MSG.format(sample['exam'], traceback.format_exc()))
|