--- a +++ b/tests.py @@ -0,0 +1,56 @@ +import os +import unittest + +from models.unet import get_unet_model +from train import load_data, load_dataset, shuffling + + +class TestClassifier(unittest.TestCase): + def test_model(self): + """ + Check the input layer and the output layer of the model. + Args: + None + Returns: + None + """ + model = get_unet_model((512, 512, 3)) + # Check the image height + self.assertEqual( + model.get_layer("input_1").input_shape[0][0], + tuple(model.get_layer("conv2d_18").output.shape)[0], + ) + # Check the image width + self.assertEqual( + model.get_layer("input_1").input_shape[0][1], + tuple(model.get_layer("conv2d_18").output.shape)[1], + ) + + def test_dataloader(self): + """ + Check the number of images and masks of the training and validation pipeline. + Args: + None + Returns: + None + """ + dataset_path = os.path.join("new_data") + train_path = os.path.join(dataset_path, "train") + valid_path = os.path.join(dataset_path, "valid") + + train_x, train_y = load_data(train_path) + train_x, train_y = shuffling(train_x, train_y) + valid_x, valid_y = load_data(valid_path) + + train_dataset = load_dataset(train_x, train_y, batch_size=16) + valid_dataset = load_dataset(valid_x, valid_y, batch_size=16) + + train_image, train_mask = next(iter(train_dataset)) + valid_image, valid_mask = next(iter(valid_dataset)) + + self.assertEqual(len(train_image), len(train_mask)) + self.assertEqual(len(valid_image), len(valid_mask)) + + +if __name__ == "__main__": + unittest.main()