|
a |
|
b/tests.py |
|
|
1 |
import os |
|
|
2 |
import unittest |
|
|
3 |
|
|
|
4 |
from models.unet import get_unet_model |
|
|
5 |
from train import load_data, load_dataset, shuffling |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class TestClassifier(unittest.TestCase): |
|
|
9 |
def test_model(self): |
|
|
10 |
""" |
|
|
11 |
Check the input layer and the output layer of the model. |
|
|
12 |
Args: |
|
|
13 |
None |
|
|
14 |
Returns: |
|
|
15 |
None |
|
|
16 |
""" |
|
|
17 |
model = get_unet_model((512, 512, 3)) |
|
|
18 |
# Check the image height |
|
|
19 |
self.assertEqual( |
|
|
20 |
model.get_layer("input_1").input_shape[0][0], |
|
|
21 |
tuple(model.get_layer("conv2d_18").output.shape)[0], |
|
|
22 |
) |
|
|
23 |
# Check the image width |
|
|
24 |
self.assertEqual( |
|
|
25 |
model.get_layer("input_1").input_shape[0][1], |
|
|
26 |
tuple(model.get_layer("conv2d_18").output.shape)[1], |
|
|
27 |
) |
|
|
28 |
|
|
|
29 |
def test_dataloader(self): |
|
|
30 |
""" |
|
|
31 |
Check the number of images and masks of the training and validation pipeline. |
|
|
32 |
Args: |
|
|
33 |
None |
|
|
34 |
Returns: |
|
|
35 |
None |
|
|
36 |
""" |
|
|
37 |
dataset_path = os.path.join("new_data") |
|
|
38 |
train_path = os.path.join(dataset_path, "train") |
|
|
39 |
valid_path = os.path.join(dataset_path, "valid") |
|
|
40 |
|
|
|
41 |
train_x, train_y = load_data(train_path) |
|
|
42 |
train_x, train_y = shuffling(train_x, train_y) |
|
|
43 |
valid_x, valid_y = load_data(valid_path) |
|
|
44 |
|
|
|
45 |
train_dataset = load_dataset(train_x, train_y, batch_size=16) |
|
|
46 |
valid_dataset = load_dataset(valid_x, valid_y, batch_size=16) |
|
|
47 |
|
|
|
48 |
train_image, train_mask = next(iter(train_dataset)) |
|
|
49 |
valid_image, valid_mask = next(iter(valid_dataset)) |
|
|
50 |
|
|
|
51 |
self.assertEqual(len(train_image), len(train_mask)) |
|
|
52 |
self.assertEqual(len(valid_image), len(valid_mask)) |
|
|
53 |
|
|
|
54 |
|
|
|
55 |
if __name__ == "__main__": |
|
|
56 |
unittest.main() |