Diff of /tests.py [000000] .. [dce3d9]

Switch to unified view

a b/tests.py
1
import os
2
import unittest
3
4
from models.unet import get_unet_model
5
from train import load_data, load_dataset, shuffling
6
7
8
class TestClassifier(unittest.TestCase):
9
    def test_model(self):
10
        """
11
        Check the input layer and the output layer of the model.
12
        Args:
13
            None
14
        Returns:
15
            None
16
        """
17
        model = get_unet_model((512, 512, 3))
18
        # Check the image height
19
        self.assertEqual(
20
            model.get_layer("input_1").input_shape[0][0],
21
            tuple(model.get_layer("conv2d_18").output.shape)[0],
22
        )
23
        # Check the image width
24
        self.assertEqual(
25
            model.get_layer("input_1").input_shape[0][1],
26
            tuple(model.get_layer("conv2d_18").output.shape)[1],
27
        )
28
29
    def test_dataloader(self):
30
        """
31
        Check the number of images and masks of the training and validation pipeline.
32
        Args:
33
            None
34
        Returns:
35
            None
36
        """
37
        dataset_path = os.path.join("new_data")
38
        train_path = os.path.join(dataset_path, "train")
39
        valid_path = os.path.join(dataset_path, "valid")
40
41
        train_x, train_y = load_data(train_path)
42
        train_x, train_y = shuffling(train_x, train_y)
43
        valid_x, valid_y = load_data(valid_path)
44
45
        train_dataset = load_dataset(train_x, train_y, batch_size=16)
46
        valid_dataset = load_dataset(valid_x, valid_y, batch_size=16)
47
48
        train_image, train_mask = next(iter(train_dataset))
49
        valid_image, valid_mask = next(iter(valid_dataset))
50
51
        self.assertEqual(len(train_image), len(train_mask))
52
        self.assertEqual(len(valid_image), len(valid_mask))
53
54
55
if __name__ == "__main__":
56
    unittest.main()