Diff of /fetal_net/data.py [000000] .. [ccb1dd]

Switch to side-by-side view

--- a
+++ b/fetal_net/data.py
@@ -0,0 +1,78 @@
+import os
+
+import numpy as np
+import tables
+from scipy.ndimage import zoom
+
+from fetal_net.utils.utils import read_img, resize
+from .normalize import normalize_data_storage, normalize_data_storage_each
+
+
+def create_data_file(out_file, n_samples):
+    hdf5_file = tables.open_file(out_file, mode='w')
+    filters = tables.Filters(complevel=5, complib='blosc')
+    data_storage = hdf5_file.create_vlarray(hdf5_file.root, 'data', tables.ObjectAtom(), filters=filters, expectedrows=n_samples)
+    truth_storage = hdf5_file.create_vlarray(hdf5_file.root, 'truth', tables.ObjectAtom(), filters=filters, expectedrows=n_samples)
+    mask_storage = hdf5_file.create_vlarray(hdf5_file.root, 'mask', tables.ObjectAtom(), filters=filters, expectedrows=n_samples)
+    return hdf5_file, data_storage, truth_storage, mask_storage
+
+
+def write_image_data_to_file(image_files, data_storage, truth_storage, mask_storage, truth_dtype=np.uint8, scale=None,
+                             preproc=None):
+    for set_of_files in image_files:
+        images = [read_img(_) for _ in set_of_files]
+        subject_data = [image.get_data() for image in images]
+        if scale is not None:
+            subject_data[0] = zoom(subject_data[0], scale) # for sub_data in subject_data]
+            subject_data[1] = zoom(subject_data[1], scale, order=0) # for sub_data in subject_data]
+        if preproc is not None:
+            subject_data[0] = preproc(subject_data[0])
+        print(subject_data[0].shape)
+        add_data_to_storage(data_storage, truth_storage, mask_storage, subject_data, truth_dtype)
+    return data_storage, truth_storage, mask_storage
+
+
+def add_data_to_storage(data_storage, truth_storage, mask_storage, subject_data, truth_dtype):
+    data_storage.append(np.asarray(subject_data[0]).astype(np.float))
+    truth_storage.append(np.asarray(subject_data[1], dtype=truth_dtype))
+    if len(subject_data) > 2:
+        mask_storage.append(np.asarray(subject_data[2]).astype(np.float))
+
+
+def write_data_to_file(training_data_files, out_file, truth_dtype=np.uint8,
+                       subject_ids=None, normalize='all', scale=None, preproc=None):
+    """
+    Takes in a set of training images and writes those images to an hdf5 file.
+    :param training_data_files: List of tuples containing the training data files. The modalities should be listed in
+    the same order in each tuple. The last item in each tuple must be the labeled image. 
+    Example: [('sub1-T1.nii.gz', 'sub1-T2.nii.gz', 'sub1-truth.nii.gz'), 
+              ('sub2-T1.nii.gz', 'sub2-T2.nii.gz', 'sub2-truth.nii.gz')]
+    :param out_file: Where the hdf5 file will be written to.
+    :param truth_dtype: Default is 8-bit unsigned integer.
+    :return: Location of the hdf5 file with the image data written to it. 
+    """
+    n_samples = len(training_data_files)
+    try:
+        hdf5_file, data_storage, truth_storage, mask_storage = create_data_file(out_file, n_samples=n_samples)
+    except Exception as e:
+        # If something goes wrong, delete the incomplete data file
+        os.remove(out_file)
+        raise e
+
+    write_image_data_to_file(training_data_files, data_storage, truth_storage, mask_storage,
+                             truth_dtype=truth_dtype, scale=scale, preproc=preproc)
+    if subject_ids:
+        hdf5_file.create_array(hdf5_file.root, 'subject_ids', obj=subject_ids)
+    if isinstance(normalize, str):
+        _, mean, std = {
+            'all': normalize_data_storage,
+            'each': normalize_data_storage_each
+        }[normalize](data_storage)
+    else:
+        mean, std = None, None
+    hdf5_file.close()
+    return out_file, (mean, std)
+
+
+def open_data_file(filename, readwrite="r"):
+    return tables.open_file(filename, readwrite)