Diff of /utils_lung.py [000000] .. [70b6b3]

Switch to unified view

a b/utils_lung.py
1
import dicom
2
import SimpleITK as sitk
3
import numpy as np
4
import csv
5
import os
6
from collections import defaultdict
7
import cPickle as pickle
8
import glob
9
import utils
10
11
12
def read_pkl(path):
    """Load a pickled scan and return its pixel data, origin and spacing.

    :param path: path of a pickle file holding a dict with the keys
        'pixel_data', 'origin' and 'spacing'
    :return: tuple (pixel_data, origin, spacing)
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle even when unpickling fails.
    with open(path, "rb") as f:
        d = pickle.load(f)
    return d['pixel_data'], d['origin'], d['spacing']
15
16
def read_mhd(path):
    """Read a MetaImage (.mhd) volume through SimpleITK.

    :param path: path of the .mhd file
    :return: tuple (pixel_data, origin, spacing); origin and spacing are
        numpy arrays holding the values reported by SimpleITK in reversed
        order
    """
    image = sitk.ReadImage(path.encode('utf-8'))
    voxels = sitk.GetArrayFromImage(image)
    # Flip the geometry tuples so their axis order is reversed.
    flipped_origin = np.array(list(image.GetOrigin())[::-1])
    flipped_spacing = np.array(list(image.GetSpacing())[::-1])
    return voxels, flipped_origin, flipped_spacing
22
23
24
def world2voxel(world_coord, origin, spacing):
    """Convert a world coordinate into voxel units.

    :param world_coord: coordinate in world units
    :param origin: world coordinate of the volume origin
    :param spacing: voxel size per axis
    :return: absolute offset from the origin divided by the spacing
    """
    return np.abs(world_coord - origin) / spacing
28
29
30
def read_dicom(path):
    """Read one DICOM file and return its pixels and converted metadata.

    :param path: path of the DICOM file
    :return: tuple (pixel array, metadata dict). The metadata dict holds
        every readable attribute whose name starts with an upper-case
        letter, except 'PixelData'; a handful of entries are converted to
        plain int / float / float32 values. 'SliceLocation' is None when it
        is missing or cannot be converted.
    """
    d = dicom.read_file(path)
    metadata = {}
    for attr in dir(d):
        if attr[0].isupper() and attr != 'PixelData':
            try:
                metadata[attr] = getattr(d, attr)
            except AttributeError:
                # Listed by dir() but not readable on this dataset.
                pass

    metadata['InstanceNumber'] = int(metadata['InstanceNumber'])
    metadata['PixelSpacing'] = np.float32(metadata['PixelSpacing'])
    metadata['ImageOrientationPatient'] = np.float32(metadata['ImageOrientationPatient'])
    try:
        metadata['SliceLocation'] = np.float32(metadata['SliceLocation'])
    except (KeyError, TypeError, ValueError):
        # Absent or non-numeric slice location: record it as unknown rather
        # than hiding unrelated failures behind a bare except.
        metadata['SliceLocation'] = None
    metadata['ImagePositionPatient'] = np.float32(metadata['ImagePositionPatient'])
    metadata['Rows'] = int(metadata['Rows'])
    metadata['Columns'] = int(metadata['Columns'])
    metadata['RescaleSlope'] = float(metadata['RescaleSlope'])
    metadata['RescaleIntercept'] = float(metadata['RescaleIntercept'])
    return np.array(d.pixel_array), metadata
53
54
55
def extract_pid_dir(patient_data_path):
    """Return the patient id, i.e. the last component of a '/'-separated path.

    Trailing slashes are ignored, so 'data/pid/' yields 'pid' instead of an
    empty string.

    :param patient_data_path: path of a patient directory
    :return: name of the last path component
    """
    return patient_data_path.rstrip('/').split('/')[-1]
57
58
59
def extract_pid_filename(file_path, replace_str='.mhd'):
    """Derive a patient id from a file path.

    :param file_path: path of the file
    :param replace_str: substring removed from the base name
    :return: base name of the file with every occurrence of `replace_str`
        and of '.pkl' removed
    """
    name = os.path.basename(file_path)
    for fragment in (replace_str, '.pkl'):
        name = name.replace(fragment, '')
    return name
61
62
63
def get_candidates_paths(path):
    """Map patient ids to the '.pkl' files found directly inside `path`.

    :param path: directory that is searched with the pattern '<path>/*.pkl'
    :return: dict {patient_id: file path}
    """
    pkl_paths = sorted(glob.glob(path + '/*.pkl'))
    return {extract_pid_filename(pkl, '.pkl'): pkl for pkl in pkl_paths}
70
71
72
def get_patient_data(patient_data_path):
    """Read every file of a patient directory as a DICOM slice.

    :param patient_data_path: directory holding the slice files
    :return: tuple (sid2data, sid2metadata), both keyed by the part of the
        file name before its first '.'
    """
    sid2data, sid2metadata = {}, {}
    for file_name in os.listdir(patient_data_path):
        sid = file_name.partition('.')[0]
        pixels, meta = read_dicom(patient_data_path + '/' + file_name)
        sid2data[sid] = pixels
        sid2metadata[sid] = meta
    return sid2data, sid2metadata
82
83
84
def ct2HU(x, metadata):
    """Rescale raw pixel values with the slice's slope and intercept.

    :param x: numpy array of raw pixel values (left unmodified)
    :param metadata: dict providing 'RescaleSlope' and 'RescaleIntercept'
    :return: new array with the rescaled values, where everything below
        -1000 is set to -1000
    """
    hu = metadata['RescaleSlope'] * x + metadata['RescaleIntercept']
    too_low = hu < -1000
    hu[too_low] = -1000
    return hu
88
89
90
def read_dicom_scan(patient_data_path):
    """Read a patient directory into one volume sorted by slice position.

    :param patient_data_path: directory holding the DICOM slices
    :return: tuple (img, pixel_spacing); img stacks the rescaled slices in
        order of increasing slice position, pixel_spacing is the array
        (first z spacing, PixelSpacing[0], PixelSpacing[1]) taken from the
        first slice
    :raises AssertionError: if the z spacing is still not uniform after
        removing one of two interleaved series
    :raises IndexError: if fewer than two slices remain
    """
    sid2data, sid2metadata = get_patient_data(patient_data_path)
    sid2position = dict((sid, get_slice_position(sid2metadata[sid])) for sid in sid2data)
    sids_sorted = sorted(sid2position, key=sid2position.get)
    z_pixel_spacing = _z_spacings(sids_sorted, sid2position)

    # An explicit check replaces the former assert-inside-bare-except, so the
    # fallback also runs under `python -O` and no longer hides interrupts.
    if not _has_uniform_spacing(z_pixel_spacing):
        print('This patient has multiple series, we will remove one')
        # Neighbouring slices are paired up and, of each pair, the slice
        # with the larger InstanceNumber is kept.
        sids_sorted = [s1 if sid2metadata[s1]["InstanceNumber"] > sid2metadata[s2]["InstanceNumber"] else s2
                       for s1, s2 in zip(sids_sorted[::2], sids_sorted[1::2])]
        z_pixel_spacing = _z_spacings(sids_sorted, sid2position)
        # Indexing an empty array raises IndexError, as it always did.
        if not np.all((z_pixel_spacing - z_pixel_spacing[0]) < 0.01):
            raise AssertionError('slices are not uniformly spaced along z')

    first_slice_spacing = sid2metadata[sids_sorted[0]]['PixelSpacing']
    pixel_spacing = np.array((z_pixel_spacing[0],
                              first_slice_spacing[0],
                              first_slice_spacing[1]))

    img = np.stack([ct2HU(sid2data[sid], sid2metadata[sid]) for sid in sids_sorted])

    return img, pixel_spacing


def _z_spacings(sids, sid2position):
    """Return the position differences between consecutive slice ids."""
    return np.array([sid2position[nxt] - sid2position[prev]
                     for prev, nxt in zip(sids[:-1], sids[1:])])


def _has_uniform_spacing(z_spacings):
    """Tell whether no spacing exceeds the first one by 0.01 or more."""
    # NOTE(review): the difference is not wrapped in abs(), so spacings that
    # are *smaller* than the first one pass. Kept as in the original; confirm
    # whether that is intended.
    if len(z_spacings) == 0:
        return False
    return bool(np.all((z_spacings - z_spacings[0]) < 0.01))
125
126
127
def sort_slices_position(patient_data):
    """Return the slices ordered by increasing slice position.

    :param patient_data: iterable of dicts that each carry a 'metadata' entry
    :return: new sorted list
    """
    def position_of(entry):
        return get_slice_position(entry['metadata'])

    return sorted(patient_data, key=position_of)
129
130
131
def sort_sids_by_position(sid2metadata):
    """Return the slice ids ordered by increasing slice position.

    :param sid2metadata: dict {slice_id: metadata}
    :return: list of slice ids
    """
    sids = list(sid2metadata.keys())
    sids.sort(key=lambda sid: get_slice_position(sid2metadata[sid]))
    return sids
133
134
135
def sort_slices_jonas(sid2metadata):
    """Return the slice ids ordered by the positions that
    slice_location_finder assigns to them.

    :param sid2metadata: dict {slice_id: metadata}
    :return: list of slice ids
    """
    positions = slice_location_finder(sid2metadata)
    return sorted(sid2metadata.keys(), key=positions.__getitem__)
138
139
140
def get_slice_position(slice_metadata):
    """
    Project the slice position on the normal of the slice plane.

    The normal is the cross product of the first three and the last three
    values of 'ImageOrientationPatient'.

    https://www.kaggle.com/rmchamberlain/data-science-bowl-2017/dicom-to-3d-numpy-arrays

    :param slice_metadata: dict providing 'ImageOrientationPatient' (6
        values) and 'ImagePositionPatient' (3 values)
    :return: dot product of the position with the plane normal
    """
    orientation = tuple(slice_metadata['ImageOrientationPatient'])
    position = tuple(slice_metadata['ImagePositionPatient'])
    normal = np.cross(orientation[:3], orientation[3:])
    return np.dot(position, normal)
150
151
152
def slice_location_finder(sid2metadata):
    """
    Assign every slice a scalar position along the axis of the scan.

    The centre pixel of each slice is mapped through the slice orientation
    and position; the two centres that lie furthest apart define an axis and
    every centre is projected on it.

    :param sid2metadata: dict with arbitrary keys, and metadata values. Each
        metadata dict must provide 'ImageOrientationPatient',
        'ImagePositionPatient', 'PixelSpacing', 'Rows' and 'Columns'.
    :return: dict {key: position along the axis}; every position is 0. when
        there is at most one slice or when all slice centres coincide
    """
    sid2midpix = {}
    for sid, metadata in sid2metadata.items():
        pixel_spacing = metadata["PixelSpacing"]

        # calculate value of middle pixel
        F = np.array(metadata["ImageOrientationPatient"]).reshape((2, 3))
        # reversed order, as per http://nipy.org/nibabel/dicom/dicom_orientation.html
        i, j = metadata['Columns'] / 2.0, metadata['Rows'] / 2.0
        im_pos = np.array([[i * pixel_spacing[0], j * pixel_spacing[1]]], dtype='float32')
        pos = np.array(metadata["ImagePositionPatient"]).reshape((1, 3))
        sid2midpix[sid] = (np.dot(im_pos, F) + pos)[0, :]

    if len(sid2midpix) <= 1:
        return dict.fromkeys(sid2midpix, 0.)

    # find the keys of the 2 points furthest away from each other; the
    # distance is symmetric, so every unordered pair is visited only once
    # (the first pair reaching the maximum wins, as before)
    midpix_items = list(sid2midpix.items())
    max_dist = -1.0
    max_dist_keys = []
    for idx, (sp1, midpix1) in enumerate(midpix_items):
        for sp2, midpix2 in midpix_items[idx + 1:]:
            distance = np.sqrt(np.sum((midpix1 - midpix2) ** 2))
            if distance > max_dist:
                max_dist_keys = [sp1, sp2]
                max_dist = distance

    # project the others on the line between these 2 points
    p_ref1 = sid2midpix[max_dist_keys[0]]
    v1 = sid2midpix[max_dist_keys[1]] - p_ref1
    norm = np.linalg.norm(v1)
    if norm == 0:
        # All centres coincide: there is no axis to project on, and dividing
        # by the zero norm would turn every position into NaN.
        return dict.fromkeys(sid2midpix, 0.)
    v1 /= norm

    sid2position = {}
    for sp, midpix in midpix_items:
        sid2position[sp] = np.inner(v1, midpix - p_ref1)
    return sid2position
206
207
208
def get_patient_data_paths(data_dir):
    """List the entries of `data_dir` as paths, sorted by entry name.

    :param data_dir: directory whose entries are listed
    :return: list of '<data_dir>/<entry>' strings
    """
    prefix = data_dir + '/'
    return [prefix + name for name in sorted(os.listdir(data_dir))]
211
212
def read_patient_annotations_luna(pid, directory):
    """Load the pickled annotations stored as '<directory>/<pid>.pkl'.

    :param pid: patient id, used as the file name without extension
    :param directory: directory holding the annotation pickles
    :return: the unpickled object
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes the handle even when unpickling fails.
    with open(os.path.join(directory, pid + '.pkl'), "rb") as f:
        return pickle.load(f)
214
215
216
def read_labels(file_path):
    """Read a comma-separated label file.

    The first line is treated as a header and skipped; blank lines are
    ignored.

    :param file_path: path of a file with 'id,label' rows
    :return: dict {id: int label}
    """
    id2labels = {}
    with open(file_path) as labels_csv:
        next(labels_csv, None)  # skip the header row
        for line in labels_csv:
            row = line.replace('\n', '')
            if not row.strip():
                continue  # e.g. a trailing empty line
            pid, label = row.split(',')
            id2labels[pid] = int(label)
    return id2labels
228
229
230
def read_test_labels(file_path):
    """Read a semicolon-separated label file.

    The first line is treated as a header and skipped; blank lines are
    ignored.

    :param file_path: path of a file with 'id;label' rows
    :return: dict {id: int label}
    """
    id2labels = {}
    with open(file_path) as labels_csv:
        next(labels_csv, None)  # skip the header row
        for line in labels_csv:
            row = line.replace('\n', '')
            if not row.strip():
                continue  # e.g. a trailing empty line
            pid, label = row.split(';')
            id2labels[pid] = int(label)
    return id2labels
242
243
244
def read_luna_annotations(file_path):
    """Read an annotation file with 'id,x,y,z,d' rows.

    The first line is treated as a header and skipped; blank lines are
    ignored.

    :param file_path: path of the comma-separated annotation file
    :return: defaultdict {id: list of [z, y, x, d] float lists}; note that
        the coordinate order is reversed with respect to the file
    """
    id2xyzd = defaultdict(list)
    with open(file_path) as annotations_csv:
        next(annotations_csv, None)  # skip the header row
        for line in annotations_csv:
            row = line.replace('\n', '')
            if not row.strip():
                continue  # e.g. a trailing empty line
            sid, x, y, z, d = row.split(',')
            id2xyzd[sid].append([float(z), float(y), float(x), float(d)])
    return id2xyzd
256
257
258
def read_luna_negative_candidates(file_path):
    """Read the rows of a candidates file whose last column equals 0.

    The first line is treated as a header and skipped; blank lines are
    ignored.

    :param file_path: path of a comma-separated file with 'id,x,y,z,d' rows
    :return: defaultdict {id: list of [z, y, x, d] float lists} holding only
        the rows with d == 0
    """
    id2xyzd = defaultdict(list)
    with open(file_path) as candidates_csv:
        next(candidates_csv, None)  # skip the header row
        for line in candidates_csv:
            row = line.replace('\n', '')
            if not row.strip():
                continue  # e.g. a trailing empty line
            sid, x, y, z, d = row.split(',')
            if float(d) == 0:
                id2xyzd[sid].append([float(z), float(y), float(x), float(d)])
    return id2xyzd
271
272
273
def write_submission(pid2prediction, submission_path):
    """
    Write the predictions as a CSV file with the header 'id,cancer'.

    :param pid2prediction: dict of {patient_id: label}
    :param submission_path: path of the file to create or overwrite
    """
    # The context manager closes the file even when a row fails to serialise.
    with open(submission_path, 'w') as f:
        fo = csv.writer(f, lineterminator='\n')
        fo.writerow(['id', 'cancer'])
        for pid in pid2prediction:
            fo.writerow([pid, pid2prediction[pid]])
284
285
286
def filter_close_neighbors(candidates, min_dist=16, pixel_spacing=(2.5, 1., 1.)):
    """Drop candidates that lie closer than `min_dist` to an accepted one.

    Candidates are visited in order. A candidate that is close to an already
    accepted one replaces it when its value at index 4 is larger, and is
    discarded otherwise. (Index 4 is presumably a score or probability;
    verify against the caller.)

    :param candidates: iterable of sequences whose first three entries are
        z, y, x coordinates and which have an entry at index 4
    :param min_dist: minimum allowed distance between two kept candidates
    :param pixel_spacing: per-axis (z, y, x) factors applied to coordinate
        differences before measuring the distance; the default reproduces
        the formerly hard-coded 2.5 scaling of the z axis
    :return: set of the kept candidates, each converted to a tuple
    :raises ValueError: if a candidate equals an already accepted one
    """
    spacing = np.asarray(pixel_spacing, dtype=float)
    candidates_wo_dupes = set()
    no_pairs = 0
    for can1 in candidates:
        can1 = np.asarray(can1)
        close_candidate = None
        for can2 in candidates_wo_dupes:
            if np.array_equal(can1, can2):
                # Raising a plain string is itself a TypeError; raise a real
                # exception carrying the intended message.
                raise ValueError("Candidate should not be in the target array yet")
            delta = (can1[:3] - np.asarray(can2[:3])) * spacing  # zyx coos
            dist = np.sum(delta ** 2) ** (1. / 2)
            if dist < min_dist:
                no_pairs += 1
                print('Warning: there is a pair nodules close together %s %s' % (can1[:3], can2[:3]))
                close_candidate = can2
                break
        if close_candidate is None:
            candidates_wo_dupes.add(tuple(can1))
        elif can1[4] > close_candidate[4]:
            # The set is only modified after the inner loop has finished.
            candidates_wo_dupes.remove(close_candidate)
            candidates_wo_dupes.add(tuple(can1))
    print('n candidates filtered out %d' % no_pairs)
    return candidates_wo_dupes
314
315
def dice_index(predictions, targets, epsilon=1e-12):
    """Compute the Dice coefficient of two flattened arrays.

    :param predictions: array-like of predicted values
    :param targets: array-like of target values
    :param epsilon: added to numerator and denominator, so two all-zero
        inputs give 1 instead of a division by zero
    :return: (2 * sum(targets * predictions) + epsilon) /
        (sum(predictions) + sum(targets) + epsilon)
    """
    flat_pred = np.asarray(predictions).flatten()
    flat_true = np.asarray(targets).flatten()
    overlap = np.sum(flat_true * flat_pred)
    total = np.sum(flat_pred) + np.sum(flat_true)
    return (2. * overlap + epsilon) / (total + epsilon)
320
321
322
def cross_entropy(predictions, targets, epsilon=1e-12):
    """Compute the mean log-likelihood of binary targets.

    Note that the result is not negated: it is the mean of
    log(p) * t + log(1 - p) * (1 - t) and therefore never positive for
    targets in [0, 1].

    :param predictions: array-like of probabilities, clipped to
        [epsilon, 1 - epsilon]
    :param targets: array-like of targets
    :param epsilon: clipping margin that keeps the logarithms finite
    :return: mean over all flattened elements
    """
    probs = np.clip(np.asarray(predictions).flatten(), epsilon, 1. - epsilon)
    flat_true = np.asarray(targets).flatten()
    per_element = np.log(probs) * flat_true + np.log(1 - probs) * (1. - flat_true)
    return np.mean(per_element)
328
329
330
def get_generated_pids(predictions_dir):
    """List the patient ids for which a file exists in `predictions_dir`.

    :param predictions_dir: directory holding one file per patient
    :return: list of ids derived from the file names with
        extract_pid_filename; empty when the directory does not exist
    """
    if not os.path.isdir(predictions_dir):
        return []
    return [extract_pid_filename(name) for name in os.listdir(predictions_dir)]
336
337
def evaluate_log_loss(pid2prediction, pid2label):
    """Compute the log loss of per-patient predictions.

    :param pid2prediction: dict {patient_id: predicted probability}
    :param pid2label: dict {patient_id: label}
    :return: result of log_loss over all patients
    :raises AssertionError: if both dicts do not have exactly the same keys
    """
    # Raised explicitly (instead of via `assert`) so the check is not
    # stripped under `python -O`; the exception type is unchanged.
    if set(pid2prediction.keys()) != set(pid2label.keys()):
        raise AssertionError('predictions and labels must cover the same patient ids')
    predictions, labels = [], []
    # .items() works on both Python 2 and Python 3 dicts.
    for pid, prediction in pid2prediction.items():
        predictions.append(prediction)
        labels.append(pid2label[pid])
    return log_loss(labels, predictions)
344
345
346
def log_loss(y_real, y_pred, eps=1e-15):
    """Compute the binary log loss.

    :param y_real: array-like of labels
    :param y_pred: array-like of probabilities, clipped to [eps, 1 - eps]
    :param eps: clipping margin that keeps the logarithms finite
    :return: negated average of y * log(p) + (1 - y) * log(1 - p)
    """
    probs = np.clip(y_pred, eps, 1 - eps)
    truth = np.array(y_real)
    log_likelihood = truth * np.log(probs) + (1 - truth) * np.log(1 - probs)
    return - np.average(log_likelihood)
351
352
353
def read_luna_properties(file_path):
    """Read an annotation file that carries nodule properties.

    Each row holds an id, the x, y, z coordinates, a diameter and nine
    further property columns. The first line is treated as a header and
    skipped; blank lines are ignored.

    :param file_path: path of the comma-separated file
    :return: defaultdict {id: list of [z, y, x, d, properties_dict]} where
        properties_dict maps 'diameter' and the nine property names to
        floats
    :raises IndexError: if a row has fewer than 14 columns
    """
    # Property names in the order of columns 5..13.
    property_names = ('calcification', 'internalStructure', 'lobulation',
                      'malignancy', 'margin', 'sphericity', 'spiculation',
                      'subtlety', 'texture')
    id2xyzp = defaultdict(list)
    with open(file_path) as properties_csv:
        next(properties_csv, None)  # skip the header row
        for line in properties_csv:
            row = line.replace('\n', '')
            if not row.strip():
                continue  # e.g. a trailing empty line
            annotation = row.split(',')
            sid = annotation[0]
            x = float(annotation[1])
            y = float(annotation[2])
            z = float(annotation[3])
            d = float(annotation[4])
            properties_dict = {'diameter': d}
            for offset, name in enumerate(property_names):
                properties_dict[name] = float(annotation[5 + offset])
            id2xyzp[sid].append([z, y, x, d, properties_dict])
    return id2xyzp