a b/fetal_net/normalize.py
1
import os
2
3
import numpy as np
4
from nilearn.image import new_img_like
5
6
from fetal_net.utils.utils import resize, read_image_files
7
from .utils import crop_img, crop_img_to, read_image
8
9
10
def find_downsized_info(training_data_files, input_shape):
    """Compute crop slices and the affine/header of the downsized foreground.

    The combined foreground of every subject in ``training_data_files`` is
    cropped with ``crop_img`` and the cropped image is resized to
    ``input_shape`` using nearest-neighbour interpolation.

    Returns a tuple ``(crop_slices, affine, header)`` where ``affine`` and
    ``header`` are taken from the resized image.
    """
    combined_mask = get_complete_foreground(training_data_files)
    slices = crop_img(combined_mask, return_slices=True, copy=True)
    resized = resize(
        crop_img_to(combined_mask, slices, copy=True),
        new_shape=input_shape,
        interpolation="nearest",
    )
    return slices, resized.affine, resized.header
16
17
18
def get_cropping_parameters(in_files):
    """Return the result of ``crop_img(..., return_slices=True)`` for the foreground of ``in_files``.

    ``in_files`` is a sequence of file sets. With more than one set the
    foreground is merged across all of them; otherwise the foreground image
    of the first (only) set is used.
    """
    has_several_sets = len(in_files) > 1
    mask_image = (
        get_complete_foreground(in_files)
        if has_several_sets
        else get_foreground_from_set_of_files(in_files[0], return_image=True)
    )
    return crop_img(mask_image, return_slices=True, copy=True)
24
25
26
def reslice_image_set(in_files, image_shape, out_files=None, label_indices=None, crop=False):
    """Read ``in_files`` at ``image_shape`` and optionally write the results to disk.

    When ``crop`` is truthy, cropping parameters are derived from
    ``in_files`` and forwarded to ``read_image_files``; otherwise no crop is
    applied.

    Returns the absolute paths of ``out_files`` when ``out_files`` is given
    (each image is written to its paired path first), otherwise the list of
    images returned by ``read_image_files``.
    """
    crop_slices = get_cropping_parameters([in_files]) if crop else None
    images = read_image_files(in_files, image_shape=image_shape, crop=crop_slices, label_indices=label_indices)

    # Nothing to write: hand the in-memory images straight back.
    if not out_files:
        return images

    for resliced, destination in zip(images, out_files):
        resliced.to_filename(destination)
    return list(map(os.path.abspath, out_files))
38
39
40
def get_complete_foreground(training_data_files):
    """Merge the foreground masks of all subjects into a single image.

    Each entry of ``training_data_files`` is a set of files for one subject.
    The first subject's mask is used as the accumulator, and every later
    subject's foreground voxels are set to 1 in it. The result is wrapped
    with ``new_img_like`` using the last file of the first subject as the
    reference image.
    """
    merged = None
    for set_of_files in training_data_files:
        subject_mask = get_foreground_from_set_of_files(set_of_files)
        if merged is None:
            # The first subject's mask doubles as the accumulator.
            merged = subject_mask
        else:
            merged[subject_mask > 0] = 1

    reference = read_image(training_data_files[0][-1])
    return new_img_like(reference, merged)
49
50
51
def get_foreground_from_set_of_files(set_of_files, background_value=0, tolerance=0.00001, return_image=False):
    """Build a binary mask of voxels that are non-background in any given image.

    A voxel counts as foreground when its value lies outside
    ``[background_value - tolerance, background_value + tolerance]`` in at
    least one of the images.

    Args:
        set_of_files: Iterable of image files, each passed to ``read_image``.
        background_value: Value regarded as background.
        tolerance: Half-width of the band around ``background_value`` that is
            still treated as background.
        return_image: When true, wrap the mask with ``new_img_like`` using the
            last image read as reference; otherwise return the raw array.

    Returns:
        A ``uint8`` array shaped like the first image's data (1 = foreground),
        or an image built from it when ``return_image`` is true.

    Raises:
        ValueError: If ``set_of_files`` yields no files.
    """
    foreground = None
    image = None
    for image_file in set_of_files:
        image = read_image(image_file)
        # Fetch the voxel array once per image rather than once per comparison.
        # NOTE(review): if read_image returns a nibabel image, get_data() is
        # deprecated there in favour of get_fdata() -- confirm before switching.
        data = image.get_data()
        is_foreground = np.logical_or(data < (background_value - tolerance),
                                      data > (background_value + tolerance))
        if foreground is None:
            foreground = np.zeros(is_foreground.shape, dtype=np.uint8)
        foreground[is_foreground] = 1

    # Previously an empty input surfaced as an UnboundLocalError further down.
    if foreground is None:
        raise ValueError("set_of_files must contain at least one image file")

    if return_image:
        return new_img_like(image, foreground)
    return foreground
64
65
66
def normalize_data(data, mean, std):
    """Shift ``data`` by ``mean`` and scale it by ``std``, in place.

    ``data`` is modified in place and the same object is returned. ``mean``
    and ``std`` must broadcast against ``data`` under NumPy rules.
    """
    # Explicit ufunc calls writing into ``data`` keep the update in place.
    np.subtract(data, mean, out=data)
    np.divide(data, std, out=data)
    return data
70
71
72
def normalize_data_storage(data_storage):
    """Normalize every sample in ``data_storage`` with shared statistics.

    For each sample the mean and standard deviation over its last three axes
    are computed. These per-sample values are then averaged across samples,
    and every sample is normalized with the averaged ``mean`` and ``std``
    through ``normalize_data``. Samples are written back into
    ``data_storage``.

    Returns ``(data_storage, mean, std)``.
    """
    n_samples = data_storage.shape[0]
    spatial_axes = (-1, -2, -3)

    # One read per sample: collect (mean, std) pairs in a single pass.
    stats = [
        (sample.mean(axis=spatial_axes), sample.std(axis=spatial_axes))
        for sample in (data_storage[position] for position in range(n_samples))
    ]
    mean = np.asarray([sample_mean for sample_mean, _ in stats]).mean(axis=0)
    std = np.asarray([sample_std for _, sample_std in stats]).mean(axis=0)

    for position in range(n_samples):
        data_storage[position] = normalize_data(data_storage[position], mean, std)
    return data_storage, mean, std
84
85
86
def normalize_data_storage_each(data_storage):
    """Normalize every sample in ``data_storage`` with its own statistics.

    Each sample is normalized through ``normalize_data`` using the mean and
    standard deviation computed over that sample's last three axes, and is
    written back into ``data_storage``.

    Returns ``(data_storage, None, None)``; no shared statistics are produced.
    """
    spatial_axes = (-1, -2, -3)
    for position in range(data_storage.shape[0]):
        sample = data_storage[position]
        data_storage[position] = normalize_data(
            sample,
            sample.mean(axis=spatial_axes),
            sample.std(axis=spatial_axes),
        )
    return data_storage, None, None