[20be63]: / foresight / trainer.py

Download this file

62 lines (53 with data), 2.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from transformers import Trainer
from transformers.trainer import *
class SuperTrainer(Trainer):
    """:class:`~transformers.Trainer` whose optimizer can also update parameters living outside the model.

    Extra parameters are registered with :meth:`add_params_to_be_tracked` and are appended to the
    optimizer's parameter groups the next time :meth:`create_optimizer` has to build an optimizer.
    """

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
        the Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.

        Two parameter groups are built: the first uses ``self.args.weight_decay``, the second uses no weight decay.
        Model parameters go into the first group when their name is returned by
        ``get_parameter_names(self.model, [nn.LayerNorm])`` and does not contain ``"bias"``; all other model
        parameters go into the second group. Parameters registered through :meth:`add_params_to_be_tracked` are
        appended to the first group when registered with a truthy ``decay`` and to the second group otherwise.

        Nothing is rebuilt when ``self.optimizer`` is already set, so extra parameters must be registered before
        the optimizer is first created.

        Returns:
            The optimizer stored in ``self.optimizer``.
        """
        if self.optimizer is None:
            # A set keeps the per-parameter membership test below O(1) instead of scanning a list each time.
            decay_parameters = {
                name for name in get_parameter_names(self.model, [nn.LayerNorm]) if "bias" not in name
            }
            decay_group = []
            no_decay_group = []
            for name, param in self.model.named_parameters():
                if name in decay_parameters:
                    decay_group.append(param)
                else:
                    no_decay_group.append(param)

            # Add the extra_params
            for decay, params in getattr(self, "extra_params", ()):
                target_group = decay_group if decay else no_decay_group
                if isinstance(params, (list, tuple)):
                    # A registered list/tuple is flattened: optimizers reject nested lists inside a group.
                    target_group.extend(params)
                else:
                    target_group.append(params)

            optimizer_grouped_parameters = [
                {
                    "params": decay_group,
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": no_decay_group,
                    "weight_decay": 0.0,
                },
            ]

            if self.args.adafactor:
                optimizer_cls = Adafactor
                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
            else:
                optimizer_cls = AdamW
                optimizer_kwargs = {
                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
                    "eps": self.args.adam_epsilon,
                }
            optimizer_kwargs["lr"] = self.args.learning_rate

            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
                self.optimizer = OSS(
                    params=optimizer_grouped_parameters,
                    optim=optimizer_cls,
                    **optimizer_kwargs,
                )
            else:
                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer

    def add_params_to_be_tracked(self, params, decay=True):
        """
        Register extra parameters that the optimizer should update.

        The registration is stored as a ``(decay, params)`` tuple in ``self.extra_params`` and is only consumed
        when :meth:`create_optimizer` builds a new optimizer.

        Args:
            params: A single parameter, or a list/tuple of parameters, to hand to the optimizer.
            decay: When truthy the parameters join the weight-decay group, otherwise the group without
                weight decay.
        """
        if not hasattr(self, "extra_params"):
            # Created lazily so that no ``__init__`` override is needed.
            self.extra_params = []
        self.extra_params.append((decay, params))