[ccb1dd]: / test / test_training.py

Download this file

17 lines (10 with data), 567 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
from unittest import TestCase
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from fetal_net.training import get_callbacks
class TestCallbakcs(TestCase):
def test_reduce_on_plateau(self):
_, _, scheduler = get_callbacks(model_file='model.h5', learning_rate_patience=50, learning_rate_drop=0.5)
self.assertIsInstance(scheduler, ReduceLROnPlateau)
def test_early_stopping(self):
_, _, _, stopper = get_callbacks(model_file='model.h5', early_stopping_patience=100)
self.assertIsInstance(stopper, EarlyStopping)