Diff of /test/test_generator.py [000000] .. [ccb1dd]

Switch to unified view

a b/test/test_generator.py
1
import os
2
from unittest import TestCase
3
4
import numpy as np
5
6
from fetal_net.data import add_data_to_storage, create_data_file
7
from fetal_net.generator import get_multi_class_labels, get_training_and_validation_generators
8
from fetal_net.augment import generate_permutation_keys, permute_data, reverse_permute_data
9
10
11
class TestDataGenerator(TestCase):
12
    def setUp(self):
13
        self.tmp_files = list()
14
        self.data_file = None
15
16
    def tearDown(self):
17
        if self.data_file:
18
            self.data_file.close()
19
        self.rm_tmp_files()
20
21
    def create_data_file(self, n_samples=20, len_x=5, len_y=5, len_z=10, n_channels=1):
22
        self.data_file_path = "./temporary_data_test_file.h5"
23
        self.training_keys_file = "./temporary_training_keys_file.pkl"
24
        self.validation_keys_file = "./temporary_validation_keys_file.pkl"
25
        self.tmp_files = [self.data_file_path, self.training_keys_file, self.validation_keys_file]
26
27
        self.rm_tmp_files()
28
29
        self.n_samples = n_samples
30
        self.n_channels = n_channels
31
        self.n_labels = 1
32
33
        image_shape = (len_x, len_y, len_z)
34
        data_size = self.n_samples * self.n_channels * len_x * len_y * len_z
35
        data = np.asarray(np.arange(data_size).reshape((self.n_samples, self.n_channels, len_x, len_y, len_z)),
36
                          dtype=np.int16)
37
        self.assertEqual(data.shape[-3:], image_shape)
38
        truth = (data[:, 0] == 3).astype(np.int8).reshape(data.shape[0], 1, data.shape[2], data.shape[3], data.shape[4])
39
        affine = np.diag(np.ones(4))
40
        affine[:, -1] = 1
41
        self.data_file, data_storage, truth_storage, affine_storage = create_data_file(self.data_file_path,
42
                                                                                       self.n_channels, self.n_samples,
43
                                                                                       image_shape)
44
45
        for index in range(self.n_samples):
46
            add_data_to_storage(data_storage, truth_storage, affine_storage,
47
                                np.concatenate([data[index], truth[index]], axis=0), affine=affine,
48
                                n_channels=self.n_channels,
49
                                truth_dtype=np.int16)
50
            self.assertTrue(np.all(data_storage[index] == data[index]))
51
            self.assertTrue(np.all(truth_storage[index] == truth[index]))
52
53
    def rm_tmp_files(self):
54
        for tmp_file in self.tmp_files:
55
            if os.path.exists(tmp_file):
56
                os.remove(tmp_file)
57
58
    def test_multi_class_labels(self):
59
        n_labels = 5
60
        labels = np.arange(1, n_labels+1)
61
        x_dim = 3
62
        label_map = np.asarray([[[np.arange(n_labels+1)] * x_dim]])
63
        binary_labels = get_multi_class_labels(label_map, n_labels, labels)
64
65
        for label in labels:
66
            self.assertTrue(np.all(binary_labels[:, label - 1][label_map[:, 0] == label] == 1))
67
68
    def test_get_training_and_validation_generators(self):
69
        self.create_data_file()
70
71
        validation_split = 0.8
72
        batch_size = 3
73
        validation_batch_size = 3
74
75
        generators = get_training_and_validation_generators(data_file=self.data_file, batch_size=batch_size,
76
                                                            n_labels=self.n_labels,
77
                                                            training_keys_file=self.training_keys_file,
78
                                                            validation_keys_file=self.validation_keys_file,
79
                                                            data_split=validation_split,
80
                                                            validation_batch_size=validation_batch_size,
81
                                                            skip_blank=False)
82
        training_generator, validation_generator, n_training_steps, n_validation_steps = generators
83
84
        self.verify_generator(training_generator, n_training_steps, batch_size,
85
                              np.round(validation_split * self.n_samples))
86
87
        self.verify_generator(validation_generator, n_validation_steps, validation_batch_size,
88
                              np.round((1 - validation_split) * self.n_samples))
89
90
        self.data_file.close()
91
        self.rm_tmp_files()
92
93
    def verify_generator(self, generator, steps, batch_size, expected_samples):
94
        # check that the generator covers all the samples
95
        n_validation_samples = 0
96
        validation_samples = list()
97
        for i in range(steps):
98
            x, y = next(generator)
99
            hash_x = hash(str(x))
100
            self.assertNotIn(hash_x, validation_samples)
101
            validation_samples.append(hash_x)
102
            n_validation_samples += x.shape[0]
103
            if i + 1 != steps:
104
                self.assertEqual(x.shape[0], batch_size)
105
        self.assertEqual(n_validation_samples, expected_samples)
106
107
    def test_patch_generators(self):
108
        self.create_data_file(len_x=4, len_y=4, len_z=4)
109
110
        validation_split = 0.8
111
        batch_size = 10
112
        validation_batch_size = 3
113
        patch_shape = (2, 2, 2)
114
115
        generators = get_training_and_validation_generators(self.data_file, batch_size, self.n_labels,
116
                                                            self.training_keys_file, self.validation_keys_file,
117
                                                            patch_shape=patch_shape, data_split=validation_split,
118
                                                            validation_batch_size=validation_batch_size,
119
                                                            skip_blank=False)
120
        training_generator, validation_generator, n_training_steps, n_validation_steps = generators
121
122
        expected_training_samples = int(np.round(self.n_samples * validation_split)) * 2**3
123
124
        self.verify_generator(training_generator, n_training_steps, batch_size, expected_training_samples)
125
126
        expected_validation_samples = int(np.round(self.n_samples * (1 - validation_split))) * 2**3
127
128
        self.verify_generator(validation_generator, n_validation_steps, validation_batch_size,
129
                              expected_validation_samples)
130
131
        self.data_file.close()
132
        self.rm_tmp_files()
133
134
    def test_random_patch_start(self):
135
        self.create_data_file(len_x=10, len_y=10, len_z=10)
136
137
        validation_split = 0.8
138
        batch_size = 10
139
        validation_batch_size = 3
140
        patch_shape = (5, 5, 5)
141
        random_start = (3, 3, 3)
142
        overlap = 2
143
144
        generators = get_training_and_validation_generators(self.data_file, batch_size, self.n_labels,
145
                                                            self.training_keys_file, self.validation_keys_file,
146
                                                            patch_shape=patch_shape, data_split=validation_split,
147
                                                            validation_batch_size=validation_batch_size,
148
                                                            skip_blank=False)
149
150
        training_generator, validation_generator, n_training_steps, n_validation_steps = generators
151
152
        expected_training_samples = int(np.round(self.n_samples * validation_split)) * 2**3
153
154
        self.verify_generator(training_generator, n_training_steps, batch_size, expected_training_samples)
155
156
        expected_validation_samples = int(np.round(self.n_samples * (1 - validation_split))) * 4**3
157
158
        self.verify_generator(validation_generator, n_validation_steps, validation_batch_size,
159
                              expected_validation_samples)
160
161
        self.data_file.close()
162
        self.rm_tmp_files()
163
164
    def test_unique_permutations(self):
165
        permutations = list()
166
        shape = (2, 3, 3, 3)
167
        data = np.arange(54).reshape(shape)
168
        for key in generate_permutation_keys():
169
            permutations.append(permute_data(data, key))
170
            for array in permutations[:-1]:
171
                self.assertTrue(permutations[-1].shape == shape)
172
                self.assertFalse(np.all(array == permutations[-1]))
173
                self.assertEqual(np.sum(data), np.sum(permutations[-1]))
174
175
    def test_n_permutations(self):
176
        self.assertEqual(len(generate_permutation_keys()), 48)
177
178
    def test_generator_with_permutations(self):
179
        self.create_data_file(len_x=5, len_y=5, len_z=5, n_channels=5)
180
        batch_size = 2
181
        generators = get_training_and_validation_generators(self.data_file, batch_size, self.n_labels,
182
                                                            self.training_keys_file, self.validation_keys_file)
183
        training_generator, validation_generator, n_training_steps, n_validation_steps = generators
184
185
        for x in training_generator:
186
            break
187
188
        self.rm_tmp_files()
189
190
    def test_reverse_permutation(self):
191
        data_shape = (4, 32, 32, 32)
192
        data = np.arange(np.prod(data_shape)).reshape(data_shape)
193
        for permutation_key in generate_permutation_keys():
194
            permuted_data = permute_data(data, permutation_key)
195
            reversed_permutation = reverse_permute_data(permuted_data, permutation_key)
196
            self.assertTrue(np.all(data == reversed_permutation))