Diff of /fetal_net/data.py [000000] .. [ccb1dd]

Switch to unified view

a b/fetal_net/data.py
1
import os
2
3
import numpy as np
4
import tables
5
from scipy.ndimage import zoom
6
7
from fetal_net.utils.utils import read_img, resize
8
from .normalize import normalize_data_storage, normalize_data_storage_each
9
10
11
def create_data_file(out_file, n_samples):
12
    hdf5_file = tables.open_file(out_file, mode='w')
13
    filters = tables.Filters(complevel=5, complib='blosc')
14
    data_storage = hdf5_file.create_vlarray(hdf5_file.root, 'data', tables.ObjectAtom(), filters=filters, expectedrows=n_samples)
15
    truth_storage = hdf5_file.create_vlarray(hdf5_file.root, 'truth', tables.ObjectAtom(), filters=filters, expectedrows=n_samples)
16
    mask_storage = hdf5_file.create_vlarray(hdf5_file.root, 'mask', tables.ObjectAtom(), filters=filters, expectedrows=n_samples)
17
    return hdf5_file, data_storage, truth_storage, mask_storage
18
19
20
def write_image_data_to_file(image_files, data_storage, truth_storage, mask_storage, truth_dtype=np.uint8, scale=None,
21
                             preproc=None):
22
    for set_of_files in image_files:
23
        images = [read_img(_) for _ in set_of_files]
24
        subject_data = [image.get_data() for image in images]
25
        if scale is not None:
26
            subject_data[0] = zoom(subject_data[0], scale) # for sub_data in subject_data]
27
            subject_data[1] = zoom(subject_data[1], scale, order=0) # for sub_data in subject_data]
28
        if preproc is not None:
29
            subject_data[0] = preproc(subject_data[0])
30
        print(subject_data[0].shape)
31
        add_data_to_storage(data_storage, truth_storage, mask_storage, subject_data, truth_dtype)
32
    return data_storage, truth_storage, mask_storage
33
34
35
def add_data_to_storage(data_storage, truth_storage, mask_storage, subject_data, truth_dtype):
36
    data_storage.append(np.asarray(subject_data[0]).astype(np.float))
37
    truth_storage.append(np.asarray(subject_data[1], dtype=truth_dtype))
38
    if len(subject_data) > 2:
39
        mask_storage.append(np.asarray(subject_data[2]).astype(np.float))
40
41
42
def write_data_to_file(training_data_files, out_file, truth_dtype=np.uint8,
43
                       subject_ids=None, normalize='all', scale=None, preproc=None):
44
    """
45
    Takes in a set of training images and writes those images to an hdf5 file.
46
    :param training_data_files: List of tuples containing the training data files. The modalities should be listed in
47
    the same order in each tuple. The last item in each tuple must be the labeled image. 
48
    Example: [('sub1-T1.nii.gz', 'sub1-T2.nii.gz', 'sub1-truth.nii.gz'), 
49
              ('sub2-T1.nii.gz', 'sub2-T2.nii.gz', 'sub2-truth.nii.gz')]
50
    :param out_file: Where the hdf5 file will be written to.
51
    :param truth_dtype: Default is 8-bit unsigned integer.
52
    :return: Location of the hdf5 file with the image data written to it. 
53
    """
54
    n_samples = len(training_data_files)
55
    try:
56
        hdf5_file, data_storage, truth_storage, mask_storage = create_data_file(out_file, n_samples=n_samples)
57
    except Exception as e:
58
        # If something goes wrong, delete the incomplete data file
59
        os.remove(out_file)
60
        raise e
61
62
    write_image_data_to_file(training_data_files, data_storage, truth_storage, mask_storage,
63
                             truth_dtype=truth_dtype, scale=scale, preproc=preproc)
64
    if subject_ids:
65
        hdf5_file.create_array(hdf5_file.root, 'subject_ids', obj=subject_ids)
66
    if isinstance(normalize, str):
67
        _, mean, std = {
68
            'all': normalize_data_storage,
69
            'each': normalize_data_storage_each
70
        }[normalize](data_storage)
71
    else:
72
        mean, std = None, None
73
    hdf5_file.close()
74
    return out_file, (mean, std)
75
76
77
def open_data_file(filename, readwrite="r"):
78
    return tables.open_file(filename, readwrite)