sybil/datasets/validation.py

import numpy as np
import torch
from torch.utils import data
import warnings
import csv
import traceback
from collections import Counter
from sybil.augmentations import get_augmentations
from tqdm import tqdm
from sybil.serie import Serie
from sybil.datasets.utils import order_slices, METAFILE_NOTFOUND_ERR, LOAD_FAIL_MSG
from sybil.loaders.image_loaders import OpenCVLoader, DicomLoader


class CSVDataset(data.Dataset):
    """
    Dataset used for large validations
    """
    def __init__(self, args, split_group):
        '''
        params: args - config.
        params: split_group - ['train'|'dev'|'test'].

        constructs: a standard PyTorch Dataset object, which can be fed into a DataLoader for batching
        '''
        super(CSVDataset, self).__init__()

        self.split_group = split_group
        self.args = args
        self._num_images = args.num_images  # number of slices in each volume
        self._max_followup = args.max_followup

        try:
            self.dataset_dicts = self.parse_csv_dataset(args.dataset_file_path)
        except Exception as e:
            raise Exception(METAFILE_NOTFOUND_ERR.format(args.dataset_file_path, e))

        augmentations = get_augmentations(split_group, args)
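        # Choose the slice loader by input format: DICOM series go through DicomLoader,
        # anything else (e.g. pre-extracted PNG slices) through OpenCVLoader.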
        if args.img_file_type == 'dicom':
            self.input_loader = DicomLoader(args.cache_path, augmentations, args)
        else:
            self.input_loader = OpenCVLoader(args.cache_path, augmentations, args)

        self.dataset = self.create_dataset(split_group)
        if len(self.dataset) == 0:
            return

        print(self.get_summary_statement(self.dataset, split_group))

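        # Inverse-frequency class weights: each label's samples share a total weight of
        # 1 / num_classes, so self.weights can be fed to e.g.
        # torch.utils.data.WeightedRandomSampler to draw class-balanced batches.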
        dist_key = 'y'
        label_dist = [d[dist_key] for d in self.dataset]
        label_counts = Counter(label_dist)
        weight_per_label = 1. / len(label_counts)
        label_weights = {
            label: weight_per_label / count for label, count in label_counts.items()
            }

        print("Class counts are: {}".format(label_counts))
        print("Label weights are {}".format(label_weights))
        self.weights = [label_weights[d[dist_key]] for d in self.dataset]

    def parse_csv_dataset(self, file_path):
        """
        Convert a CSV file into a list of dictionaries, one per series, like:
        [
            {
                'unique_id': str,
                'patient_id': str,
                'split': str,
                'exam_id': str,
                'series_id': str,
                'ever_has_future_cancer': bool,
                'years_to_cancer': int,
                'years_to_last_negative_followup': int,
                'paths': [str],
                'slice_locations': [str]
            }
        ]

        Parameters
        ----------
        file_path : str
            path to csv file

        Returns
        -------
        list
            list of cases in the above structure
        """
        dataset_dicts = {}
        with open(file_path, 'r') as f:
            _reader = csv.DictReader(f)
            for _row in _reader:
                row = {k.encode('ascii', 'ignore').decode(): v.encode('ascii', 'ignore').decode() for k, v in _row.items()}
                patient_id, exam_id, series_id = row['patient_id'], row['exam_id'], row['series_id']
                unique_id = '{}_{}_{}'.format(patient_id, exam_id, series_id)
                if unique_id in dataset_dicts:
                    dataset_dicts[unique_id]['paths'].append(row['file_path'])
                    dataset_dicts[unique_id]['slice_locations'].append(row['slice_position'])
                else:
                    dataset_dicts[unique_id] = {
                        'unique_id': unique_id,
                        'patient_id': patient_id,
                        'exam_id': exam_id,
                        'series_id': series_id,
                        'split': row['split'],
                        # CSV values are read as strings; cast the label and follow-up times
                        # to the types the rest of the pipeline expects (the accepted truthy
                        # encodings are an assumption about the CSV format).
                        'ever_has_future_cancer': row['ever_has_future_cancer'].strip().lower() in ('1', 'true', 'yes'),
                        'years_to_cancer': int(row['years_to_cancer']),
                        'years_to_last_negative_followup': int(row['years_to_last_negative_followup']),
                        'paths': [row['file_path']],
                        'slice_locations': [row['slice_position']]
                    }

        dataset_dicts = list(dataset_dicts.values())
        return dataset_dicts

    def create_dataset(self, split_group):
        """
        Gets the dataset from the paths and labels parsed out of the CSV file.
        Arguments:
            split_group(str): One of ['train'|'dev'|'test'].
        Returns:
            The dataset as a list of dictionaries with img paths, label,
            and additional information regarding exam or participant
        """

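        # Optional split re-assignment. Note: assign_splits and metadata_json are not
        # defined on this class; this branch assumes they are provided elsewhere
        # (e.g. by a subclass) and only runs when args.assign_splits is set.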
        if self.args.assign_splits:
            np.random.seed(self.args.cross_val_seed)
            self.assign_splits(self.metadata_json)

        dataset = []

        for mrn_row in tqdm(self.dataset_dicts, position=0):

            label = mrn_row['ever_has_future_cancer']
            censor_time = mrn_row['years_to_cancer'] if label else mrn_row['years_to_last_negative_followup']
            paths = order_slices(mrn_row['paths'], mrn_row['slice_locations'])
            try:
                series_object = Serie(
                    paths,
                    label,
                    censor_time,
                    self.args.img_file_type,
                    mrn_row['split']
                    )
            except Exception:
                continue

            if self.skip_sample(series_object, mrn_row, split_group):
                continue

            dataset.append({
                'serie': series_object,
                'paths': paths,
                'exam': mrn_row['unique_id'],
                'patient_id': mrn_row['patient_id'],
                # Keep the label and censor time on each sample so the class-weight
                # computation in __init__ and get_summary_statement can read
                # 'y' and 'time_at_event'.
                'y': label,
                'time_at_event': censor_time,
                })

        return dataset

    def skip_sample(self, series_object, row, split_group):
        if row['split'] != split_group:
            return True

        if not series_object.has_label():
            return True

        return False

    def get_summary_statement(self, dataset, split_group):
        summary = "Constructed Sybil Cancer Risk {} dataset with {} records, {} exams, {} patients, and the following class balance \n {}"
        class_balance = Counter([d['y'] for d in dataset])
        exams = set([d['exam'] for d in dataset])
        patients = set([d['patient_id'] for d in dataset])
        statement = summary.format(split_group, len(dataset), len(exams), len(patients), class_balance)
        statement += "\n" + "Censor Times: {}".format(Counter([d['time_at_event'] for d in dataset]))
        return statement

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample = self.dataset[index]
        serie = sample['serie']
        try:
            labels = serie.get_label()
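            # labels is the structure returned by Serie.get_label(): y is the binary cancer
            # label, y_seq / y_mask are the per-follow-up-year label sequence and censoring
            # mask, and censor_time is the year index of the event or last negative follow-up.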
            item = {
                'x': serie.get_volume(),
                'y': labels.y,
                'y_seq': labels.y_seq,
                'y_mask': labels.y_mask,
                'time_at_event': labels.censor_time,
                'exam': sample['exam']
            }
            return item
        except Exception:
            warnings.warn(LOAD_FAIL_MSG.format(sample['paths'], traceback.format_exc()))
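

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It shows how CSVDataset
# is typically wrapped in a torch DataLoader, optionally with a WeightedRandomSampler
# driven by the per-sample `weights` computed in __init__. The args values and the
# CSV path below are illustrative assumptions; `args` must also carry whatever
# fields get_augmentations and the image loaders expect.
#
#   from argparse import Namespace
#   from torch.utils.data import DataLoader, WeightedRandomSampler
#
#   args = Namespace(
#       num_images=200, max_followup=6,
#       dataset_file_path='path/to/validation.csv',
#       img_file_type='dicom', cache_path='/tmp/sybil_cache',
#       assign_splits=False, cross_val_seed=0,
#   )
#   dataset = CSVDataset(args, split_group='test')
#   sampler = WeightedRandomSampler(dataset.weights, num_samples=len(dataset))
#   loader = DataLoader(dataset, batch_size=2, sampler=sampler)
# ---------------------------------------------------------------------------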