Diff of /src/optimizer.py [000000] .. [f45789]

import torch.optim as optim


def get_params_to_update(model, print_params=True):
    # Gather the parameters to be optimized/updated in this run. If we are
    # finetuning, we will be updating all parameters. However, if we are
    # using the feature-extraction method, we will only update the parameters
    # that we have just initialized, i.e. the parameters whose requires_grad
    # is True.
    params_to_update = []
    if print_params: print("Params to learn:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            if print_params: print(name)
    return params_to_update


def get_optimizer(conf):
    model = conf['model']
    print_params = conf['optimizer']['print_params']
    params_to_update = get_params_to_update(model=model,
                                            print_params=print_params)

    optimizer_name = conf['optimizer']['name']
    # SGD
    if optimizer_name == 'SGD':
        lr = conf['optimizer']['lr']
        momentum = conf['optimizer'].get('momentum', 0)
        dampening = conf['optimizer'].get('dampening', 0)
        weight_decay = conf['optimizer'].get('weight_decay', 0)
        nesterov = conf['optimizer'].get('nesterov', False)
        optimizer = optim.SGD(params=params_to_update,
                              lr=lr,
                              momentum=momentum,
                              dampening=dampening,
                              weight_decay=weight_decay,
                              nesterov=nesterov)
    # Adam
    elif optimizer_name == 'Adam':
        lr = conf['optimizer'].get('lr', 0.001)
        beta0 = conf['optimizer'].get('beta', {}).get('beta0', 0.9)
        beta1 = conf['optimizer'].get('beta', {}).get('beta1', 0.999)
        eps = conf['optimizer'].get('eps', 1e-8)
        weight_decay = conf['optimizer'].get('weight_decay', 0.0)
        amsgrad = conf['optimizer'].get('amsgrad', False)
        optimizer = optim.Adam(params=params_to_update,
                               lr=lr,
                               betas=(beta0, beta1),
                               eps=eps,
                               weight_decay=weight_decay,
                               amsgrad=amsgrad)
    # AdamW
    elif optimizer_name == 'AdamW':
        lr = conf['optimizer'].get('lr', 0.001)
        beta0 = conf['optimizer'].get('beta', {}).get('beta0', 0.9)
        beta1 = conf['optimizer'].get('beta', {}).get('beta1', 0.999)
        eps = conf['optimizer'].get('eps', 1e-8)
        weight_decay = conf['optimizer'].get('weight_decay', 0.01)
        amsgrad = conf['optimizer'].get('amsgrad', False)
        optimizer = optim.AdamW(params=params_to_update,
                                lr=lr,
                                betas=(beta0, beta1),
                                eps=eps,
                                weight_decay=weight_decay,
                                amsgrad=amsgrad)
    else:
        raise ValueError(f'Optimizer {optimizer_name} is not supported.')
    return optimizer
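
For reference, a minimal usage sketch of the config-driven factory above. The conf layout, the import path src.optimizer, and the frozen-backbone resnet18 are hypothetical examples inferred from the keys that get_optimizer reads and from the file path in the diff header; they are not taken from the repository's actual configs.

import torch.nn as nn
from torchvision import models

from src.optimizer import get_optimizer  # assumed import path based on /src/optimizer.py

# Feature extraction: freeze the whole backbone, then replace the head so that
# only the newly initialized layer has requires_grad=True and is returned by
# get_params_to_update.
model = models.resnet18()
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 10)

# Example conf; key names mirror the lookups in get_optimizer, values are illustrative.
conf = {
    'model': model,
    'optimizer': {
        'name': 'AdamW',
        'print_params': True,
        'lr': 3e-4,
        'beta': {'beta0': 0.9, 'beta1': 0.999},
        'weight_decay': 0.01,
    },
}

optimizer = get_optimizer(conf)  # prints "Params to learn:" followed by fc.weight and fc.bias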