a b/tests/test_dataset.py
1
from __future__ import division, print_function
2
3
import unittest
4
5
import numpy as np
6
from rvseg import dataset
7
8
class TestDataset(unittest.TestCase):
    """Tests for ``rvseg.dataset.create_generators`` against the test assets.

    The ``test-assets`` directory is expected to hold a total of 3 training
    images of size 216x256 (single channel).
    """

    # Spatial size of every image/mask in the test assets.
    IMAGE_HEIGHT = 216
    IMAGE_WIDTH = 256
    # Mask selections exercised by the generator tests.
    MASK_TYPES = ('inner', 'outer', 'both')
    BATCH_SIZE = 2

    @staticmethod
    def _data_dir():
        """Return the path of the ``test-assets`` directory (trailing separator).

        Resolved relative to this file instead of the current working
        directory, so the suite passes whether it is launched from ``tests/``
        or from the repository root.
        """
        import os
        here = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(here, os.pardir, "test-assets", "")

    @staticmethod
    def _num_classes(mask):
        """Return the expected number of mask channels (3 for 'both', else 2)."""
        return 3 if mask == 'both' else 2

    def _create_generators(self, mask, validation_split, **kwargs):
        """Call ``dataset.create_generators`` on the test assets.

        Returns the tuple ``(train_generator, train_steps_per_epoch,
        val_generator, val_steps_per_epoch)``; extra keyword arguments are
        forwarded unchanged.
        """
        return dataset.create_generators(
            self._data_dir(), self.BATCH_SIZE,
            validation_split=validation_split, mask=mask, **kwargs)

    def _assert_batch(self, images, masks, n, classes):
        """Assert a batch holds ``n`` 1-channel images and ``classes``-channel masks."""
        self.assertEqual(images.shape,
                         (n, self.IMAGE_HEIGHT, self.IMAGE_WIDTH, 1))
        self.assertEqual(masks.shape,
                         (n, self.IMAGE_HEIGHT, self.IMAGE_WIDTH, classes))

    def test_generator(self):
        for mask in self.MASK_TYPES:
            self._test_generator(mask=mask)

    def test_no_validation(self):
        for mask in self.MASK_TYPES:
            self._test_no_validation(mask=mask)

    def _test_generator(self, mask):
        """Check step counts and batch shapes with a 50% validation split."""
        # With a total of 3 training images, this split will create 1
        # training image and 2 validation images.
        (train_generator, train_steps_per_epoch,
         val_generator, val_steps_per_epoch) = self._create_generators(
             mask, validation_split=0.5)

        self.assertEqual(train_steps_per_epoch, 1)
        self.assertEqual(val_steps_per_epoch, 1)

        classes = self._num_classes(mask)

        images, masks = next(train_generator)
        self._assert_batch(images, masks, 1, classes)

        images, masks = next(val_generator)
        self._assert_batch(images, masks, 2, classes)

    def _test_no_validation(self, mask):
        """Check the generators when no validation split is requested."""
        (train_generator, train_steps_per_epoch,
         val_generator, val_steps_per_epoch) = self._create_generators(
             mask, validation_split=0.0)

        self.assertEqual(train_steps_per_epoch, 2)
        self.assertEqual(val_steps_per_epoch, 0)

        classes = self._num_classes(mask)

        # first 2 train images
        images, masks = next(train_generator)
        self._assert_batch(images, masks, 2, classes)

        # last train image (for total of 3): a short final batch
        images, masks = next(train_generator)
        self._assert_batch(images, masks, 1, classes)

        # first 2 train images again: the generator wraps around
        images, masks = next(train_generator)
        self._assert_batch(images, masks, 2, classes)

        # there should be no validation generator at all
        self.assertIsNone(val_generator)

    def test_shuffle_train_val(self):
        """Test shuffling of the entire dataset prior to the train-val split.

        Shuffling within each epoch is deliberately not tested, so
        ``shuffle=False`` is passed everywhere to isolate the train/val
        shuffle.
        """
        validation_split = 0.5
        seed = 5               # random number seed

        # There should be 2 images in the validation set, and we check that
        # they always appear in the same order with a fixed seed.
        mask = "inner"
        classes = self._num_classes(mask)
        first_images = []
        first_masks = []
        for _ in range(10):
            _, _, val_generator, _ = self._create_generators(
                mask, validation_split, shuffle_train_val=True,
                shuffle=False, seed=seed, normalize_images=True)

            images, masks = next(val_generator)
            self._assert_batch(images, masks, 2, classes)

            # also check image normalization (zero mean, unit std per image)
            for image in images:
                self.assertAlmostEqual(np.mean(image), 0)
                self.assertAlmostEqual(np.std(image), 1, places=5)

            first_images.append(images[0])
            first_masks.append(masks[0])

        # first image/mask in each case should be the same
        for image in first_images[1:]:
            np.testing.assert_array_equal(first_images[0], image)
        for other_mask in first_masks[1:]:
            np.testing.assert_array_equal(first_masks[0], other_mask)

        # Now test that things get shuffled if we don't specify a seed.
        mask = "both"
        _, _, val_generator, _ = self._create_generators(
            mask, validation_split, shuffle_train_val=True, shuffle=False,
            seed=None, normalize_images=True)
        images, _ = next(val_generator)
        image0 = images[0]

        # Bounded so that a broken shuffle fails the test instead of hanging.
        max_attempts = 100
        for _ in range(max_attempts):
            _, _, val_generator, _ = self._create_generators(
                mask, validation_split, shuffle_train_val=True,
                shuffle=False, seed=None, normalize_images=True)
            images, _ = next(val_generator)
            if not np.array_equal(image0, images[0]):
                break           # arrays are different (= success!)
        else:
            self.fail("validation split never changed after %d unseeded "
                      "train/val shuffles" % max_attempts)