Diff of /src/scheduler.py [000000] .. [f45789]

Switch to side-by-side view

--- a
+++ b/src/scheduler.py
@@ -0,0 +1,16 @@
+import torch.optim as optim
+
+def get_scheduler(conf):
+    '''Setup the learning rate scheduler'''
+    scheduler = conf['scheduler']['name']
+    if scheduler == 'StepLR':
+        optimizer = conf['optimizer']
+        step_size = conf['scheduler']['step_size']
+        gamma = conf['scheduler']['gamma']
+        scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer,
+                                              step_size=step_size,
+                                              gamma=gamma)
+    else:
+        print('Scheduler {scheduler} not suported.')
+        exit()
+    return scheduler
\ No newline at end of file