Diff of /model/optimiser.py [000000] .. [bad60c]

Switch to unified view

a b/model/optimiser.py
1
import pytorch_pretrained_bert as Bert
2
3
def adam(params, config=None):
4
    if config is None:
5
        config = {
6
            'lr': 3e-5,
7
            'warmup_proportion': 0.1,
8
            'weight_decay': 0.01
9
        }
10
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
11
12
    optimizer_grouped_parameters = [
13
        {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
14
        {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0}
15
    ]
16
17
    optim = Bert.optimization.BertAdam(optimizer_grouped_parameters,
18
                                       lr=config['lr'],
19
                                       warmup=config['warmup_proportion'])
20
    return optim