--- /dev/null
+++ b/src/optimizer.py
@@ -0,0 +1,66 @@
+import torch.optim as optim
+
+def get_params_to_update(model, print_params=True):
+    # Gather the parameters to be optimized/updated in this run. If we are
+    # finetuning, we update all parameters. If we are doing feature
+    # extraction, we only update the parameters that we have just
+    # initialized, i.e. the parameters whose requires_grad is True.
+    params_to_update = []
+    if print_params: print("Params to learn:")
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            params_to_update.append(param)
+            if print_params: print(name)
+    return params_to_update
+
+def get_optimizer(conf):
+    model = conf['model']
+    print_params = conf['optimizer'].get('print_params', True)
+    params_to_update = get_params_to_update(model=model,
+                                            print_params=print_params)
+
+    optimizer_name = conf['optimizer']['name']
+    # SGD
+    if optimizer_name == 'SGD':
+        lr = conf['optimizer']['lr']
+        momentum = conf['optimizer'].get('momentum', 0)
+        dampening = conf['optimizer'].get('dampening', 0)
+        weight_decay = conf['optimizer'].get('weight_decay', 0)
+        nesterov = conf['optimizer'].get('nesterov', False)
+        optimizer = optim.SGD(params=params_to_update,
+                              lr=lr,
+                              dampening=dampening,
+                              momentum=momentum,
+                              weight_decay=weight_decay,
+                              nesterov=nesterov)
+    # Adam
+    elif optimizer_name == 'Adam':
+        lr = conf['optimizer'].get('lr', 0.001)
+        beta0 = conf['optimizer'].get('beta', {}).get('beta0', 0.9)
+        beta1 = conf['optimizer'].get('beta', {}).get('beta1', 0.999)
+        eps = conf['optimizer'].get('eps', 1e-8)
+        weight_decay = conf['optimizer'].get('weight_decay', 0.0)
+        amsgrad = conf['optimizer'].get('amsgrad', False)
+        optimizer = optim.Adam(params=params_to_update,
+                               lr=lr,
+                               betas=(beta0, beta1),
+                               eps=eps,
+                               weight_decay=weight_decay,
+                               amsgrad=amsgrad)
+    # AdamW
+    elif optimizer_name == 'AdamW':
+        lr = conf['optimizer'].get('lr', 0.001)
+        beta0 = conf['optimizer'].get('beta', {}).get('beta0', 0.9)
+        beta1 = conf['optimizer'].get('beta', {}).get('beta1', 0.999)
+        eps = conf['optimizer'].get('eps', 1e-8)
+        weight_decay = conf['optimizer'].get('weight_decay', 0.01)
+        amsgrad = conf['optimizer'].get('amsgrad', False)
+        optimizer = optim.AdamW(params=params_to_update,
+                                lr=lr,
+                                betas=(beta0, beta1),
+                                eps=eps,
+                                weight_decay=weight_decay,
+                                amsgrad=amsgrad)
+    else:
+        raise ValueError(f'Optimizer {optimizer_name} not supported.')
+    return optimizer
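For reference, a minimal sketch of how this entry point could be driven. The `conf` layout (the model under `conf['model']`, optimizer settings under `conf['optimizer']`, with the `beta` sub-dict and other keys read above) follows what this file expects; the resnet18 feature-extraction setup, values, and import path are illustrative assumptions, not part of this change.

    # Illustrative usage only; model choice, config values, and import path are assumptions.
    import torch.nn as nn
    import torchvision.models as models
    from src.optimizer import get_optimizer  # assumes src/ is on the import path

    model = models.resnet18(weights=None)    # torchvision >= 0.13 API
    for p in model.parameters():             # feature extraction: freeze the backbone
        p.requires_grad = False
    model.fc = nn.Linear(model.fc.in_features, 10)  # new head, requires_grad=True

    conf = {
        'model': model,
        'optimizer': {
            'name': 'AdamW',
            'print_params': True,
            'lr': 3e-4,
            'beta': {'beta0': 0.9, 'beta1': 0.999},
            'weight_decay': 0.01,
        },
    }
    optimizer = get_optimizer(conf)  # only the new head's parameters appear under "Params to learn:"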