|
a |
|
b/fetal_net/data.py |
|
|
1 |
import os |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import tables |
|
|
5 |
from scipy.ndimage import zoom |
|
|
6 |
|
|
|
7 |
from fetal_net.utils.utils import read_img, resize |
|
|
8 |
from .normalize import normalize_data_storage, normalize_data_storage_each |
|
|
9 |
|
|
|
10 |
|
|
|
11 |
def create_data_file(out_file, n_samples):
    """
    Creates a new hdf5 file holding three empty variable-length object arrays: 'data', 'truth' and 'mask'.

    :param out_file: Path of the hdf5 file to create. It is opened with mode 'w'.
    :param n_samples: Expected number of rows, passed to each array as ``expectedrows``.
    :return: Tuple (hdf5_file, data_storage, truth_storage, mask_storage). The file handle is returned open.
    """
    hdf5_file = tables.open_file(out_file, mode='w')
    try:
        filters = tables.Filters(complevel=5, complib='blosc')
        data_storage = hdf5_file.create_vlarray(hdf5_file.root, 'data', tables.ObjectAtom(), filters=filters,
                                                expectedrows=n_samples)
        truth_storage = hdf5_file.create_vlarray(hdf5_file.root, 'truth', tables.ObjectAtom(), filters=filters,
                                                 expectedrows=n_samples)
        mask_storage = hdf5_file.create_vlarray(hdf5_file.root, 'mask', tables.ObjectAtom(), filters=filters,
                                                expectedrows=n_samples)
    except Exception:
        # The caller never receives the handle when setup fails, so it has to be released here.
        hdf5_file.close()
        raise
    return hdf5_file, data_storage, truth_storage, mask_storage
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def write_image_data_to_file(image_files, data_storage, truth_storage, mask_storage, truth_dtype=np.uint8, scale=None,
                             preproc=None):
    """
    Reads every set of image files and appends its arrays to the given storages.

    :param image_files: Iterable of file sets, one per subject. Index 0 of each set is handled as the data volume
    and index 1 as the truth volume; any further volumes are passed on untouched.
    :param data_storage: Storage that receives the data volume of each subject.
    :param truth_storage: Storage that receives the truth volume of each subject.
    :param mask_storage: Storage that receives the third volume of a subject, when there is one.
    :param truth_dtype: dtype the truth volume is stored with. Default is 8-bit unsigned integer.
    :param scale: Optional zoom factor. When given, the data volume is zoomed with the default interpolation and
    the truth volume with order=0.
    :param preproc: Optional callable applied to the data volume only, after any rescaling.
    :return: Tuple (data_storage, truth_storage, mask_storage).
    """
    for subject_files in image_files:
        loaded = list(map(read_img, subject_files))
        volumes = [img.get_data() for img in loaded]

        if scale is not None:
            # NOTE(review): only the first two volumes are rescaled; a third (mask) volume keeps
            # its original shape - confirm this is intended.
            volumes[0] = zoom(volumes[0], scale)
            # order=0 resamples the truth volume without interpolating between its values.
            volumes[1] = zoom(volumes[1], scale, order=0)

        if preproc is not None:
            volumes[0] = preproc(volumes[0])

        print(volumes[0].shape)
        add_data_to_storage(data_storage, truth_storage, mask_storage, volumes, truth_dtype)

    return data_storage, truth_storage, mask_storage
|
|
33 |
|
|
|
34 |
|
|
|
35 |
def add_data_to_storage(data_storage, truth_storage, mask_storage, subject_data, truth_dtype):
    """
    Appends the arrays of a single subject to the storages.

    :param data_storage: Receives subject_data[0] converted to a float (float64) array.
    :param truth_storage: Receives subject_data[1] converted to an array of truth_dtype.
    :param mask_storage: Receives subject_data[2] converted to a float (float64) array. It is left untouched when
    subject_data has fewer than three items.
    :param subject_data: Sequence of array-likes ordered as (data, truth[, mask]).
    :param truth_dtype: dtype used for the truth array.
    """
    # The builtin float is used because the np.float alias (the same type) was removed in NumPy 1.24.
    data_storage.append(np.asarray(subject_data[0]).astype(float))
    truth_storage.append(np.asarray(subject_data[1], dtype=truth_dtype))
    if len(subject_data) > 2:
        mask_storage.append(np.asarray(subject_data[2]).astype(float))
|
|
40 |
|
|
|
41 |
|
|
|
42 |
def write_data_to_file(training_data_files, out_file, truth_dtype=np.uint8,
                       subject_ids=None, normalize='all', scale=None, preproc=None):
    """
    Takes in a set of training images and writes those images to an hdf5 file.

    :param training_data_files: List of tuples containing the image files of each subject. The files should be listed
    in the same order in each tuple: the first file is stored as the data volume, the second as the truth (labeled)
    volume and an optional third one as the mask.
    Example: [('sub1-volume.nii.gz', 'sub1-truth.nii.gz'),
              ('sub2-volume.nii.gz', 'sub2-truth.nii.gz')]
    :param out_file: Where the hdf5 file will be written to.
    :param truth_dtype: Default is 8-bit unsigned integer.
    :param subject_ids: Optional sequence stored in the file as the 'subject_ids' array. Skipped when None or empty.
    :param normalize: 'all' selects normalize_data_storage and 'each' selects normalize_data_storage_each. Any
    non-string value disables normalization.
    :param scale: Optional zoom factor forwarded to write_image_data_to_file.
    :param preproc: Optional callable forwarded to write_image_data_to_file.
    :return: Tuple (out_file, (mean, std)). mean and std are None when normalization is disabled.
    :raises KeyError: If normalize is a string other than 'all' or 'each'.
    """
    n_samples = len(training_data_files)
    try:
        hdf5_file, data_storage, truth_storage, mask_storage = create_data_file(out_file, n_samples=n_samples)
    except Exception:
        # If something goes wrong, delete the incomplete data file. The file may never have been
        # created, so the removal is guarded to keep it from masking the original error.
        if os.path.exists(out_file):
            os.remove(out_file)
        raise

    try:
        write_image_data_to_file(training_data_files, data_storage, truth_storage, mask_storage,
                                 truth_dtype=truth_dtype, scale=scale, preproc=preproc)
        if subject_ids:
            hdf5_file.create_array(hdf5_file.root, 'subject_ids', obj=subject_ids)
        if isinstance(normalize, str):
            normalizer = {
                'all': normalize_data_storage,
                'each': normalize_data_storage_each
            }[normalize]
            _, mean, std = normalizer(data_storage)
        else:
            mean, std = None, None
    finally:
        # Release the file handle whether or not writing and normalizing succeeded.
        hdf5_file.close()
    return out_file, (mean, std)
|
|
75 |
|
|
|
76 |
|
|
|
77 |
def open_data_file(filename, readwrite="r"):
    """
    Opens an hdf5 data file through PyTables.

    :param filename: Path of the hdf5 file to open.
    :param readwrite: Mode handed to tables.open_file. Default is "r".
    :return: The object returned by tables.open_file.
    """
    file_mode = readwrite
    return tables.open_file(filename, mode=file_mode)