|
a |
|
b/test/test_training.py |
|
|
1 |
from unittest import TestCase |
|
|
2 |
|
|
|
3 |
from keras.callbacks import ReduceLROnPlateau, EarlyStopping |
|
|
4 |
|
|
|
5 |
from fetal_net.training import get_callbacks |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class TestCallbakcs(TestCase):
    """Check which Keras callbacks ``get_callbacks`` builds for given options.

    NOTE(review): the class name looks like a typo of ``TestCallbacks``; it is
    kept unchanged so existing test selectors that name this class still work.
    """

    def _callbacks(self, expected_count, **kwargs):
        """Call ``get_callbacks(**kwargs)`` and assert how many callbacks it returned.

        The result is materialised with ``list`` so it can be measured and
        indexed regardless of which iterable type ``get_callbacks`` returns.
        A wrong count is reported as an assertion failure that lists the
        callbacks actually returned, instead of the bare ``ValueError`` that
        positional tuple unpacking would raise.

        :param expected_count: number of callbacks the call must produce.
        :param kwargs: keyword arguments forwarded unchanged to ``get_callbacks``.
        :return: the callbacks as a list.
        """
        callbacks = list(get_callbacks(**kwargs))
        self.assertEqual(
            len(callbacks),
            expected_count,
            msg="unexpected callbacks: {!r}".format(callbacks),
        )
        return callbacks

    def test_reduce_on_plateau(self):
        """Without early-stopping options, expect 3 callbacks, the third a ReduceLROnPlateau."""
        callbacks = self._callbacks(
            3,
            model_file='model.h5',
            learning_rate_patience=50,
            learning_rate_drop=0.5,
        )
        self.assertIsInstance(callbacks[2], ReduceLROnPlateau)

    def test_early_stopping(self):
        """With ``early_stopping_patience`` set, expect 4 callbacks, the fourth an EarlyStopping."""
        # NOTE(review): presumably get_callbacks only adds an EarlyStopping
        # when early_stopping_patience is given, which is why this test expects
        # one more callback than test_reduce_on_plateau -- confirm in
        # fetal_net.training.
        callbacks = self._callbacks(
            4,
            model_file='model.h5',
            early_stopping_patience=100,
        )
        self.assertIsInstance(callbacks[3], EarlyStopping)
|
|
16 |
|