[1e1d11]: / configs / hparams.py

Download this file

17 lines (15 with data), 557 Bytes

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
def get_hparams_class(dataset_name):
    """Return the hyperparameter class defined in this module under *dataset_name*.

    The name is resolved against this module's global namespace, so any class
    defined at module level here (e.g. ``supervised``) can be selected by name.

    Args:
        dataset_name: Name of the hyperparameter class to look up.

    Returns:
        The class object itself (not an instance); callers instantiate it.

    Raises:
        NotImplementedError: If no class with that name is defined in this
            module. Non-class globals (functions, module dunders, constants)
            are rejected rather than returned.
    """
    candidate = globals().get(dataset_name)
    # A bare `in globals()` check would also accept this function, imported
    # modules and dunders such as `__name__`; only classes are valid results.
    if not isinstance(candidate, type):
        raise NotImplementedError(
            "Hyperparameter class not found: {}".format(dataset_name)
        )
    return candidate
class supervised():
    """Hyperparameter set selected by the name ``supervised``.

    Instances expose a single attribute, ``train_params``: a dict holding the
    epoch count, batch size, weight decay, learning rate and feature
    dimension. Presumably used for the supervised training mode — confirm
    against the callers.
    """

    def __init__(self):
        super().__init__()
        # Feature dimension is spelled as a product (1 * 128), as in the
        # original; its value is the int 128.
        self.train_params = dict(
            num_epochs=60,
            batch_size=128,
            weight_decay=1e-4,
            learning_rate=1e-3,
            feature_dim=1 * 128,
        )