[89883a]: / tests.py

Download this file

57 lines (46 with data), 1.7 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import unittest
from models.unet import get_unet_model
from train import load_data, load_dataset, shuffling
class TestClassifier(unittest.TestCase):
    """Smoke tests for the U-Net model definition and the training data pipeline."""

    def test_model(self):
        """
        Check that the model's output has the same height and width as its input.

        The model is built for 512x512 RGB images. The first input tensor and the
        first output tensor must both report that spatial size.

        Args:
            None
        Returns:
            None
        """
        height, width = 512, 512
        model = get_unet_model((height, width, 3))
        # Read shapes from the model's own input/output tensors instead of looking
        # layers up by auto-generated names ("input_1", "conv2d_18"). Those names
        # are assigned by a per-session counter and can change when another model
        # has been built earlier in the same process.
        # NOTE(review): assumes the former "conv2d_18" layer is the model's (only)
        # output, as the original docstring ("the output layer") implies.
        # Shapes are (batch, height, width, channels). Index 0 is the batch
        # dimension (None), so height is index 1 and width is index 2. The
        # original compared index 0 (batch) and index 1 (height) by mistake.
        input_shape = tuple(model.inputs[0].shape)
        output_shape = tuple(model.outputs[0].shape)
        # Check the image height
        self.assertEqual(input_shape[1], height)
        self.assertEqual(output_shape[1], input_shape[1])
        # Check the image width
        self.assertEqual(input_shape[2], width)
        self.assertEqual(output_shape[2], input_shape[2])

    def _first_batch(self, dataset, name):
        """
        Return the first (images, masks) batch of `dataset`.

        Fails the test with a clear message if the dataset yields nothing. A bare
        `next(iter(...))` would instead raise StopIteration and be reported as
        an error.

        Args:
            dataset: An iterable pipeline yielding (images, masks) pairs.
            name: Label used in the failure message (e.g. "train").
        Returns:
            The first (images, masks) pair yielded by `dataset`.
        """
        batch = next(iter(dataset), None)
        self.assertIsNotNone(batch, f"{name} dataset yielded no batches")
        return batch

    def test_dataloader(self):
        """
        Check that the training and validation pipelines yield one mask per image.

        Builds both pipelines from the prepared "new_data" directory. The test
        asserts that the first batch of each contains as many masks as images.
        It is skipped (not errored) when the dataset directory is not present.

        Args:
            None
        Returns:
            None
        """
        dataset_path = "new_data"
        batch_size = 16
        if not os.path.isdir(dataset_path):
            self.skipTest(f"dataset directory {dataset_path!r} not found")
        train_path = os.path.join(dataset_path, "train")
        valid_path = os.path.join(dataset_path, "valid")

        train_x, train_y = load_data(train_path)
        # Shuffle the training split as the training script does, so the test
        # exercises the same pipeline path.
        train_x, train_y = shuffling(train_x, train_y)
        valid_x, valid_y = load_data(valid_path)

        train_dataset = load_dataset(train_x, train_y, batch_size=batch_size)
        valid_dataset = load_dataset(valid_x, valid_y, batch_size=batch_size)

        train_image, train_mask = self._first_batch(train_dataset, "train")
        valid_image, valid_mask = self._first_batch(valid_dataset, "valid")
        self.assertEqual(len(train_image), len(train_mask))
        self.assertEqual(len(valid_image), len(valid_mask))
if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed as a script.
    unittest.main()