--- a
+++ b/utils_lung.py
@@ -0,0 +1,480 @@
+import dicom
+import SimpleITK as sitk
+import numpy as np
+import csv
+import os
+from collections import defaultdict
+import cPickle as pickle
+import glob
+import utils
+
+
+def read_pkl(path):
+    """Load a pickled scan and return (pixel_data, origin, spacing).
+
+    The pickle must be a dict with 'pixel_data', 'origin' and 'spacing'
+    keys. Unpickling can execute arbitrary code, so only load trusted files.
+    """
+    with open(path, "rb") as f:
+        d = pickle.load(f)
+    return d['pixel_data'], d['origin'], d['spacing']
+
+
+def read_mhd(path):
+    """Read a MetaImage (.mhd) scan with SimpleITK.
+
+    Returns (pixel_data, origin, spacing). SimpleITK reports origin and
+    spacing in (x, y, z) order; they are reversed here to match the axis
+    order of the array returned by sitk.GetArrayFromImage.
+    """
+    itk_data = sitk.ReadImage(path.encode('utf-8'))
+    pixel_data = sitk.GetArrayFromImage(itk_data)
+    origin = np.array(list(reversed(itk_data.GetOrigin())))
+    spacing = np.array(list(reversed(itk_data.GetSpacing())))
+    return pixel_data, origin, spacing
+
+
+def world2voxel(world_coord, origin, spacing):
+    """Convert world (mm) coordinates to voxel coordinates.
+
+    All three arguments must use the same axis order (e.g. the (z, y, x)
+    order returned by read_mhd).
+    NOTE(review): np.absolute drops the sign of (world_coord - origin), so
+    points on the negative side of the origin still map to positive voxel
+    coordinates -- presumably intentional for flipped scans; confirm.
+    """
+    stretched_voxel_coord = np.absolute(world_coord - origin)
+    voxel_coord = stretched_voxel_coord / spacing
+    return voxel_coord
+
+
+def read_dicom(path):
+    """Read one DICOM slice and return (pixel_array, metadata).
+
+    metadata maps every attribute name starting with an uppercase letter
+    (except PixelData) to its value; the geometry and rescale fields are
+    converted to numeric types. SliceLocation is None when it is missing
+    or cannot be converted to a number.
+    """
+    d = dicom.read_file(path)
+    metadata = {}
+    for attr in dir(d):
+        if attr[0].isupper() and attr != 'PixelData':
+            try:
+                metadata[attr] = getattr(d, attr)
+            except AttributeError:
+                pass
+
+    metadata['InstanceNumber'] = int(metadata['InstanceNumber'])
+    metadata['PixelSpacing'] = np.float32(metadata['PixelSpacing'])
+    metadata['ImageOrientationPatient'] = np.float32(metadata['ImageOrientationPatient'])
+    try:
+        metadata['SliceLocation'] = np.float32(metadata['SliceLocation'])
+    except (KeyError, TypeError, ValueError):
+        # SliceLocation may be absent or malformed; catch only those errors so
+        # that interrupts and genuine bugs still propagate.
+        metadata['SliceLocation'] = None
+    metadata['ImagePositionPatient'] = np.float32(metadata['ImagePositionPatient'])
+    metadata['Rows'] = int(metadata['Rows'])
+    metadata['Columns'] = int(metadata['Columns'])
+    metadata['RescaleSlope'] = float(metadata['RescaleSlope'])
+    metadata['RescaleIntercept'] = float(metadata['RescaleIntercept'])
+    return np.array(d.pixel_array), metadata
+
+
+def extract_pid_dir(patient_data_path):
+    """Return the last path component (the patient id); trailing '/' is ignored."""
+    return patient_data_path.rstrip('/').split('/')[-1]
+
+
+def extract_pid_filename(file_path, replace_str='.mhd'):
+    """Return the file's base name with replace_str and '.pkl' removed."""
+    return os.path.basename(file_path).replace(replace_str, '').replace('.pkl', '')
+
+
+def get_candidates_paths(path):
+    """Map patient id -> path for every '*.pkl' file in directory `path`."""
+    file_paths = sorted(glob.glob(path + '/*.pkl'))
+    return dict((extract_pid_filename(p, '.pkl'), p) for p in file_paths)
+
+
+def get_patient_data(patient_data_path):
+    """Read every DICOM file in a patient directory.
+
+    Returns (sid2data, sid2metadata), both keyed by slice id: the file name
+    without its extension.
+    """
+    sid2data = {}
+    sid2metadata = {}
+    for s in os.listdir(patient_data_path):
+        # splitext strips only the last extension, so dotted (UID-style)
+        # file names cannot collapse onto one slice id and overwrite slices.
+        slice_id = os.path.splitext(s)[0]
+        data, metadata = read_dicom(os.path.join(patient_data_path, s))
+        sid2data[slice_id] = data
+        sid2metadata[slice_id] = metadata
+    return sid2data, sid2metadata
+
+
+def ct2HU(x, metadata):
+    """Rescale raw slice values to Hounsfield units, clipping below -1000.
+
+    Uses the slice's RescaleSlope and RescaleIntercept; x is not modified.
+    """
+    x = metadata['RescaleSlope'] * x + metadata['RescaleIntercept']
+    x[x < -1000] = -1000
+    return x
+
+
+def _z_spacings(sids_sorted, sid2position):
+    """Return position differences between consecutive slices in sids_sorted."""
+    return np.array([sid2position[s1] - sid2position[s0]
+                     for s0, s1 in zip(sids_sorted[:-1], sids_sorted[1:])])
+
+
+def _has_uniform_spacing(z_pixel_spacing, tol=0.01):
+    """True if every spacing is within tol of the first one, in either direction."""
+    return np.all(np.abs(z_pixel_spacing - z_pixel_spacing[0]) < tol)
+
+
+def read_dicom_scan(patient_data_path):
+    """Load a patient's DICOM slices as a single volume in Hounsfield units.
+
+    Slices are sorted by position along the slice normal. If the spacing is
+    not uniform the directory is assumed to hold two interleaved series: of
+    each consecutive pair of slices the one with the higher InstanceNumber
+    is kept (an unpaired last slice is dropped).
+
+    Returns (img, pixel_spacing): img stacks the slices along axis 0 and
+    pixel_spacing is (z spacing, PixelSpacing[0], PixelSpacing[1]).
+    Raises ValueError for fewer than two slices and AssertionError if the
+    spacing is still not uniform after dropping a series.
+    """
+    sid2data, sid2metadata = get_patient_data(patient_data_path)
+    if len(sid2data) < 2:
+        raise ValueError('Need at least two slices to build a scan: %s' % patient_data_path)
+    sid2position = dict((sid, get_slice_position(sid2metadata[sid])) for sid in sid2data)
+    sids_sorted = sorted(sid2position, key=lambda sid: sid2position[sid])
+    z_pixel_spacing = _z_spacings(sids_sorted, sid2position)
+
+    # Explicit checks instead of assert/bare-except: they still run under
+    # `python -O`, and np.abs also catches spacings *smaller* than the first
+    # one (e.g. duplicated slice positions from a second series).
+    if not _has_uniform_spacing(z_pixel_spacing):
+        print 'This patient has multiple series, we will remove one'
+        sids_sorted_2 = []
+        for s1, s2 in zip(sids_sorted[::2], sids_sorted[1::2]):
+            if sid2metadata[s1]["InstanceNumber"] > sid2metadata[s2]["InstanceNumber"]:
+                sids_sorted_2.append(s1)
+            else:
+                sids_sorted_2.append(s2)
+        sids_sorted = sids_sorted_2
+        z_pixel_spacing = _z_spacings(sids_sorted, sid2position)
+        if len(z_pixel_spacing) == 0 or not _has_uniform_spacing(z_pixel_spacing):
+            raise AssertionError('Non-uniform slice spacing after removing a series: %s'
+                                 % patient_data_path)
+
+    pixel_spacing = np.array((z_pixel_spacing[0],
+                              sid2metadata[sids_sorted[0]]['PixelSpacing'][0],
+                              sid2metadata[sids_sorted[0]]['PixelSpacing'][1]))
+
+    img = np.stack([ct2HU(sid2data[sid], sid2metadata[sid]) for sid in sids_sorted])
+
+    return img, pixel_spacing
+
+
+def sort_slices_position(patient_data):
+    """Sort slice records (dicts with a 'metadata' entry) by slice position."""
+    return sorted(patient_data, key=lambda x: get_slice_position(x['metadata']))
+
+
+def sort_sids_by_position(sid2metadata):
+    """Return slice ids sorted by get_slice_position of their metadata."""
+    return sorted(sid2metadata.keys(), key=lambda x: get_slice_position(sid2metadata[x]))
+
+
+def sort_slices_jonas(sid2metadata):
+    """Return slice ids sorted by the positions from slice_location_finder."""
+    sid2position = slice_location_finder(sid2metadata)
+    return sorted(sid2metadata.keys(), key=lambda x: sid2position[x])
+
+
+def get_slice_position(slice_metadata):
+    """Return the slice's scalar position along its normal vector.
+
+    The normal is the cross product of the two direction-cosine triples in
+    ImageOrientationPatient; the position is ImagePositionPatient dotted
+    with it. Based on
+    https://www.kaggle.com/rmchamberlain/data-science-bowl-2017/dicom-to-3d-numpy-arrays
+    """
+    orientation = tuple(slice_metadata['ImageOrientationPatient'])
+    position = tuple(slice_metadata['ImagePositionPatient'])
+    rowvec, colvec = orientation[:3], orientation[3:]
+    normal_vector = np.cross(rowvec, colvec)
+    slice_pos = np.dot(position, normal_vector)
+    return slice_pos
+
+
+def slice_location_finder(sid2metadata):
+    """Assign each slice a scalar position along the scan axis.
+
+    The world position of every slice's centre pixel is computed; the two
+    centres furthest apart define a reference line, and each slice's
+    position is its centre projected onto that line, measured from the
+    first reference point.
+
+    :param sid2metadata: dict with arbitrary keys and read_dicom metadata values
+    :return: dict with the same keys mapping to positions; all positions are
+        0. when there is at most one slice or all centres coincide
+    """
+    sid2midpix = {}
+    sid2position = {}
+
+    for sid in sid2metadata:
+        metadata = sid2metadata[sid]
+        image_orientation = metadata["ImageOrientationPatient"]
+        image_position = metadata["ImagePositionPatient"]
+        pixel_spacing = metadata["PixelSpacing"]
+        rows = metadata['Rows']
+        columns = metadata['Columns']
+
+        # calculate value of middle pixel
+        F = np.array(image_orientation).reshape((2, 3))
+        # reversed order, as per http://nipy.org/nibabel/dicom/dicom_orientation.html
+        # NOTE(review): that page pairs the column index with PixelSpacing[1]
+        # and the row index with PixelSpacing[0]; this looks swapped, which
+        # only matters for non-square pixels -- confirm.
+        i, j = columns / 2.0, rows / 2.0
+        im_pos = np.array([[i * pixel_spacing[0], j * pixel_spacing[1]]], dtype='float32')
+        pos = np.array(image_position).reshape((1, 3))
+        position = np.dot(im_pos, F) + pos
+        sid2midpix[sid] = position[0, :]
+
+    if len(sid2midpix) <= 1:
+        for sp in sid2midpix:
+            sid2position[sp] = 0.
+        return sid2position
+
+    # find the keys of the 2 points furthest away from each other
+    max_dist = -1.0
+    max_dist_keys = []
+    for sp1, midpix1 in sid2midpix.iteritems():
+        for sp2, midpix2 in sid2midpix.iteritems():
+            if sp1 == sp2:
+                continue
+            distance = np.sqrt(np.sum((midpix1 - midpix2) ** 2))
+            if distance > max_dist:
+                max_dist_keys = [sp1, sp2]
+                max_dist = distance
+
+    if max_dist == 0:
+        # All centres coincide, so there is no direction to project onto;
+        # normalising v1 below would divide by zero and produce NaNs.
+        for sp in sid2midpix:
+            sid2position[sp] = 0.
+        return sid2position
+
+    # project every centre onto the line between these 2 points
+    p_ref1 = sid2midpix[max_dist_keys[0]]
+    p_ref2 = sid2midpix[max_dist_keys[1]]
+    v1 = p_ref2 - p_ref1
+    v1 /= np.linalg.norm(v1)
+
+    for sp, midpix in sid2midpix.iteritems():
+        v2 = midpix - p_ref1
+        sid2position[sp] = np.inner(v1, v2)
+
+    return sid2position
+
+
+def get_patient_data_paths(data_dir):
+    """Return data_dir + '/' + name for every entry of data_dir, sorted by name."""
+    pids = sorted(os.listdir(data_dir))
+    return [data_dir + '/' + p for p in pids]
+
+
+def read_patient_annotations_luna(pid, directory):
+    """Unpickle <directory>/<pid>.pkl (trusted files only: pickle runs code)."""
+    with open(os.path.join(directory, pid + '.pkl'), "rb") as f:
+        return pickle.load(f)
+
+
+def _read_csv_rows(file_path, delimiter=','):
+    """Return the data rows of a delimited text file as lists of strings.
+
+    The first (header) line and blank lines are skipped, and trailing line
+    endings (including Windows CR) are removed before splitting on delimiter.
+    The file is closed before returning.
+    """
+    with open(file_path) as f:
+        lines = f.readlines()[1:]
+    return [line.rstrip('\r\n').split(delimiter) for line in lines if line.strip()]
+
+
+def read_labels(file_path):
+    """Read a comma-separated 'id,label' file with a header into {id: int label}."""
+    id2labels = {}
+    for pid, label in _read_csv_rows(file_path, ','):
+        id2labels[pid] = int(label)
+    return id2labels
+
+
+def read_test_labels(file_path):
+    """Read a semicolon-separated 'id;label' file with a header into {id: int label}."""
+    id2labels = {}
+    for pid, label in _read_csv_rows(file_path, ';'):
+        id2labels[pid] = int(label)
+    return id2labels
+
+
+def read_luna_annotations(file_path):
+    """Read a LUNA annotation CSV (id,x,y,z,d with a header).
+
+    Returns {id: [[z, y, x, d], ...]} as floats, coordinates reordered to zyx.
+    """
+    id2xyzd = defaultdict(list)
+    for pid, x, y, z, d in _read_csv_rows(file_path):
+        id2xyzd[pid].append([float(z), float(y), float(x), float(d)])
+    return id2xyzd
+
+
+def read_luna_negative_candidates(file_path):
+    """Like read_luna_annotations, but keep only rows whose 5th column is 0.
+
+    NOTE(review): that column is presumably the candidate class label -- confirm.
+    """
+    id2xyzd = defaultdict(list)
+    for pid, x, y, z, d in _read_csv_rows(file_path):
+        if float(d) == 0:
+            id2xyzd[pid].append([float(z), float(y), float(x), float(d)])
+    return id2xyzd
+
+
+def write_submission(pid2prediction, submission_path):
+    """Write predictions to a CSV file with header 'id,cancer'.
+
+    :param pid2prediction: dict of {patient_id: label}
+    :param submission_path: output path; an existing file is overwritten
+    """
+    with open(submission_path, 'w+') as f:
+        fo = csv.writer(f, lineterminator='\n')
+        fo.writerow(['id', 'cancer'])
+        for pid, prediction in pid2prediction.iteritems():
+            fo.writerow([pid, prediction])
+
+
+def filter_close_neighbors(candidates, min_dist=16, z_scale=2.5):
+    """Drop candidates lying within min_dist of an already kept candidate.
+
+    Candidates (presumably numpy rows) are visited in order; the distance
+    uses entries 0-2 (zyx), with the z difference multiplied by z_scale.
+    When two candidates are too close, the one with the larger entry 4
+    (presumably the score -- confirm) is kept; ties keep the earlier one.
+    Returns a set of tuples.
+
+    :raises ValueError: when a candidate identical to a kept one is met.
+    """
+    # TODO: derive z_scale from the scan's pixel spacing instead of a constant.
+    candidates_wo_dupes = set()
+    no_pairs = 0
+    for can1 in candidates:
+        found_close_candidate = False
+        swap_candidate = None
+        for can2 in candidates_wo_dupes:
+            if (can1 == can2).all():
+                # Raising a plain string is itself a TypeError in Python >= 2.6.
+                raise ValueError("Candidate should not be in the target array yet")
+            # Work in float so scaling the z delta is not truncated for int input.
+            delta = np.asarray(can1[:3], dtype=float) - np.asarray(can2[:3], dtype=float)
+            delta[0] = z_scale * delta[0]  # zyx coos
+            dist = np.sum(delta ** 2) ** (1. / 2)
+            if dist < min_dist:
+                no_pairs += 1
+                print 'Warning: there is a pair nodules close together', can1[:3], can2[:3]
+                found_close_candidate = True
+                if can1[4] > can2[4]:
+                    swap_candidate = can2
+                break
+        if not found_close_candidate:
+            candidates_wo_dupes.add(tuple(can1))
+        elif swap_candidate is not None:
+            candidates_wo_dupes.remove(swap_candidate)
+            candidates_wo_dupes.add(tuple(can1))
+    print 'n candidates filtered out', no_pairs
+    return candidates_wo_dupes
+
+
+def dice_index(predictions, targets, epsilon=1e-12):
+    """Dice overlap 2*sum(p*t) / (sum(p) + sum(t)), smoothed by epsilon."""
+    predictions = np.asarray(predictions).flatten()
+    targets = np.asarray(targets).flatten()
+    dice = (2. * np.sum(targets * predictions) + epsilon) / (np.sum(predictions) + np.sum(targets) + epsilon)
+    return dice
+
+
+def cross_entropy(predictions, targets, epsilon=1e-12):
+    """Mean Bernoulli log-likelihood of targets under clipped predictions.
+
+    NOTE: unlike log_loss the result is not negated, so it is <= 0 for
+    targets in [0, 1] (minus the usual cross-entropy).
+    """
+    predictions = np.asarray(predictions).flatten()
+    predictions = np.clip(predictions, epsilon, 1. - epsilon)
+    targets = np.asarray(targets).flatten()
+    ce = np.mean(np.log(predictions) * targets + np.log(1 - predictions) * (1. - targets))
+    return ce
+
+
+def get_generated_pids(predictions_dir):
+    """Return pids of the files in predictions_dir ([] if it is not a directory)."""
+    pids = []
+    if os.path.isdir(predictions_dir):
+        pids = [extract_pid_filename(p) for p in os.listdir(predictions_dir)]
+    return pids
+
+
+def evaluate_log_loss(pid2prediction, pid2label):
+    """Return log_loss over all patients; both dicts must have the same keys."""
+    predictions, labels = [], []
+    assert set(pid2prediction.keys()) == set(pid2label.keys())
+    for k, v in pid2prediction.iteritems():
+        predictions.append(v)
+        labels.append(pid2label[k])
+    return log_loss(labels, predictions)
+
+
+def log_loss(y_real, y_pred, eps=1e-15):
+    """Binary log loss (mean negative log-likelihood), y_pred clipped to [eps, 1-eps]."""
+    y_pred = np.clip(y_pred, eps, 1 - eps)
+    y_real = np.array(y_real)
+    losses = y_real * np.log(y_pred) + (1 - y_real) * np.log(1 - y_pred)
+    return - np.average(losses)
+
+
+def read_luna_properties(file_path):
+    """Read a LUNA nodule-property CSV with a header.
+
+    Each row is id, x, y, z, diameter followed by nine numeric properties.
+    Returns {id: [[z, y, x, d, properties_dict], ...]}.
+    """
+    id2xyzp = defaultdict(list)
+    for annotation in _read_csv_rows(file_path):
+        pid = annotation[0]
+        x = float(annotation[1])
+        y = float(annotation[2])
+        z = float(annotation[3])
+        d = float(annotation[4])
+        properties_dict = {
+            'diameter': d,
+            'calcification': float(annotation[5]),
+            'internalStructure': float(annotation[6]),
+            'lobulation': float(annotation[7]),
+            'malignancy': float(annotation[8]),
+            'margin': float(annotation[9]),
+            'sphericity': float(annotation[10]),
+            'spiculation': float(annotation[11]),
+            'subtlety': float(annotation[12]),
+            'texture': float(annotation[13]),
+        }
+        id2xyzp[pid].append([z, y, x, d, properties_dict])
+    return id2xyzp