Diff of /skull-stripping/data.py [000000] .. [c9a8c4]

Switch to unified view

a b/skull-stripping/data.py
1
from __future__ import print_function
2
3
import numpy as np
4
import os
5
6
from skimage.io import imread
7
from skimage.transform import rescale
8
from skimage.transform import rotate
9
10
image_rows = 256
11
image_cols = 256
12
13
channels = 3    # refers to neighboring slices; if set to 3, takes previous and next slice as additional channels
14
modalities = 1  # refers to pre, flair and post modalities; if set to 3, uses all and if set to 1, only flair
15
16
17
def load_data(path):
18
    """
19
    Assumes filenames in given path to be in the following format as defined in `preprocessing3D.m`:
20
    for images: <case_id>_<slice_number>.tif
21
    for masks: <case_id>_<slice_number>_mask.tif
22
23
        Args:
24
            path: string to the folder with images
25
26
        Returns:
27
            np.ndarray: array of images
28
            np.ndarray: array of masks
29
            np.chararray: array of corresponding images' filenames without extensions
30
    """
31
    images_list = os.listdir(path)
32
    total_count = len(images_list) / 2
33
    images = np.ndarray((total_count, image_rows, image_cols,
34
                         channels * modalities), dtype=np.uint8)
35
    masks = np.ndarray((total_count, image_rows, image_cols), dtype=np.uint8)
36
    names = np.chararray(total_count, itemsize=64)
37
38
    i = 0
39
    for image_name in images_list:
40
        if 'mask' in image_name:
41
            continue
42
43
        names[i] = image_name.split('.')[0]
44
        slice_number = int(names[i].split('_')[-1])
45
        patient_id = '_'.join(names[i].split('_')[:-1])
46
47
        image_mask_name = image_name.split('.')[0] + '_mask.tif'
48
        img = imread(os.path.join(path, image_name), as_grey=(modalities == 1))
49
        img_mask = imread(os.path.join(path, image_mask_name), as_grey=True)
50
51
        if channels > 1:
52
            img_prev = read_slice(path, patient_id, slice_number - 1)
53
            img_next = read_slice(path, patient_id, slice_number + 1)
54
55
            img = np.dstack((img_prev, img[..., np.newaxis], img_next))
56
57
        elif modalities == 1:
58
            img = np.array([img])
59
60
        img_mask = np.array([img_mask])
61
62
        images[i] = img
63
        masks[i] = img_mask
64
65
        i += 1
66
67
    images = images.astype('float32')
68
    masks = masks[..., np.newaxis]
69
    masks = masks.astype('float32')
70
    masks /= 255.
71
72
    return images, masks, names
73
74
75
def read_slice(path, patient_id, slice):
76
    img = np.zeros((image_rows, image_cols))
77
    img_name = patient_id + '_' + str(slice) + '.tif'
78
    img_path = os.path.join(path, img_name)
79
80
    try:
81
        img = imread(img_path, as_grey=(modalities == 1))
82
    except Exception:
83
        pass
84
85
    return img[..., np.newaxis]