Diff of /skull-stripping/data.py [000000] .. [c9a8c4]

Switch to side-by-side view

--- a
+++ b/skull-stripping/data.py
@@ -0,0 +1,85 @@
+from __future__ import print_function
+
+import numpy as np
+import os
+
+from skimage.io import imread
+from skimage.transform import rescale
+from skimage.transform import rotate
+
+image_rows = 256
+image_cols = 256
+
+channels = 3    # refers to neighboring slices; if set to 3, takes previous and next slice as additional channels
+modalities = 1  # refers to pre, flair and post modalities; if set to 3, uses all and if set to 1, only flair
+
+
+def load_data(path):
+    """
+    Assumes filenames in given path to be in the following format as defined in `preprocessing3D.m`:
+    for images: <case_id>_<slice_number>.tif
+    for masks: <case_id>_<slice_number>_mask.tif
+
+        Args:
+            path: string to the folder with images
+
+        Returns:
+            np.ndarray: array of images
+            np.ndarray: array of masks
+            np.chararray: array of corresponding images' filenames without extensions
+    """
+    images_list = os.listdir(path)
+    total_count = len(images_list) / 2
+    images = np.ndarray((total_count, image_rows, image_cols,
+                         channels * modalities), dtype=np.uint8)
+    masks = np.ndarray((total_count, image_rows, image_cols), dtype=np.uint8)
+    names = np.chararray(total_count, itemsize=64)
+
+    i = 0
+    for image_name in images_list:
+        if 'mask' in image_name:
+            continue
+
+        names[i] = image_name.split('.')[0]
+        slice_number = int(names[i].split('_')[-1])
+        patient_id = '_'.join(names[i].split('_')[:-1])
+
+        image_mask_name = image_name.split('.')[0] + '_mask.tif'
+        img = imread(os.path.join(path, image_name), as_grey=(modalities == 1))
+        img_mask = imread(os.path.join(path, image_mask_name), as_grey=True)
+
+        if channels > 1:
+            img_prev = read_slice(path, patient_id, slice_number - 1)
+            img_next = read_slice(path, patient_id, slice_number + 1)
+
+            img = np.dstack((img_prev, img[..., np.newaxis], img_next))
+
+        elif modalities == 1:
+            img = np.array([img])
+
+        img_mask = np.array([img_mask])
+
+        images[i] = img
+        masks[i] = img_mask
+
+        i += 1
+
+    images = images.astype('float32')
+    masks = masks[..., np.newaxis]
+    masks = masks.astype('float32')
+    masks /= 255.
+
+    return images, masks, names
+
+
+def read_slice(path, patient_id, slice):
+    img = np.zeros((image_rows, image_cols))
+    img_name = patient_id + '_' + str(slice) + '.tif'
+    img_path = os.path.join(path, img_name)
+
+    try:
+        img = imread(img_path, as_grey=(modalities == 1))
+    except Exception:
+        pass
+
+    return img[..., np.newaxis]