a b/test/test_training.py
1
from unittest import TestCase
2
3
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
4
5
from fetal_net.training import get_callbacks
6
7
8
class TestCallbakcs(TestCase):  # NOTE(review): name has a typo ("Callbakcs"); kept so test IDs stay stable
    """Check that ``get_callbacks`` includes the expected Keras callbacks."""

    def _assert_single_callback(self, callbacks, callback_type):
        """Assert that *callbacks* contains exactly one *callback_type* instance.

        Callbacks are matched by type, not by tuple position. The number of
        callbacks ``get_callbacks`` returns appears to depend on its arguments:
        the previous positional unpacking expected 3 without
        ``early_stopping_patience`` and 4 with it. Positional unpacking
        therefore broke with a ``ValueError`` whenever a callback was added
        or reordered.
        """
        matches = [cb for cb in callbacks if isinstance(cb, callback_type)]
        self.assertEqual(
            len(matches), 1,
            "expected exactly one {} among {!r}".format(callback_type.__name__, callbacks),
        )

    def test_reduce_on_plateau(self):
        callbacks = get_callbacks(model_file='model.h5', learning_rate_patience=50, learning_rate_drop=0.5)
        self._assert_single_callback(callbacks, ReduceLROnPlateau)

    def test_early_stopping(self):
        callbacks = get_callbacks(model_file='model.h5', early_stopping_patience=100)
        self._assert_single_callback(callbacks, EarlyStopping)