--- /dev/null
+++ b/foresight/trainer.py
@@ -0,0 +1,71 @@
+from transformers import Trainer
+# The wildcard import provides the helpers used below (get_parameter_names, nn, AdamW,
+# Adafactor, OSS, ShardedDDPOption, is_sagemaker_mp_enabled, smp).
+from transformers.trainer import *
+
+class SuperTrainer(Trainer):
+    """Trainer that also optimizes extra parameters registered via :meth:`add_params_to_be_tracked`."""
+
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through :obj:`optimizers`, or subclass and override this method.
+        """
+        if self.optimizer is None:
+            decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
+            decay_parameters = [name for name in decay_parameters if "bias" not in name]
+            optimizer_grouped_parameters = [
+                {
+                    "params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            # Add any parameters registered via add_params_to_be_tracked, flattening each
+            # registered collection into individual tensors for the appropriate group.
+            if hasattr(self, 'extra_params'):
+                optimizer_grouped_parameters[0]['params'].extend(
+                    [p for decay, params in self.extra_params if decay for p in params]
+                )
+                optimizer_grouped_parameters[1]['params'].extend(
+                    [p for decay, params in self.extra_params if not decay for p in params]
+                )
+
+            if self.args.adafactor:
+                optimizer_cls = Adafactor
+                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
+            else:
+                optimizer_cls = AdamW
+                optimizer_kwargs = {
+                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
+                    "eps": self.args.adam_epsilon,
+                }
+            optimizer_kwargs["lr"] = self.args.learning_rate
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+                self.optimizer = OSS(
+                    params=optimizer_grouped_parameters,
+                    optim=optimizer_cls,
+                    **optimizer_kwargs,
+                )
+            else:
+                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+        return self.optimizer
+
+    def add_params_to_be_tracked(self, params, decay=True):
+        """Register extra parameters (a single parameter or an iterable of parameters) to be optimized."""
+        # Normalize to a list so create_optimizer can flatten the registered groups.
+        params = [params] if isinstance(params, nn.Parameter) else list(params)
+        if hasattr(self, 'extra_params'):
+            self.extra_params.append((decay, params))
+        else:
+            self.extra_params = [(decay, params)]
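
Usage note: below is a minimal sketch of how the `add_params_to_be_tracked` hook might be exercised. The model name, `TrainingArguments` values, and the `extra_embedding` tensor are illustrative placeholders, not part of this patch, and it assumes the transformers version the patch targets (one where `Trainer.create_optimizer` takes no arguments and `self.sharded_ddp` is set in `__init__`).

```python
import torch
from transformers import AutoModelForSequenceClassification, TrainingArguments

from foresight.trainer import SuperTrainer

# Illustrative placeholders: any model and TrainingArguments work the same way.
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
args = TrainingArguments(output_dir="out", learning_rate=5e-5, weight_decay=0.01)

# A tensor that lives outside the model (e.g. a learned prompt embedding) but should
# still be updated by the optimizer.
extra_embedding = torch.nn.Parameter(torch.zeros(1, model.config.hidden_size))

trainer = SuperTrainer(model=model, args=args)
# decay=False routes the tensor into the no-weight-decay parameter group.
trainer.add_params_to_be_tracked([extra_embedding], decay=False)

optimizer = trainer.create_optimizer()
assert any(p is extra_embedding for group in optimizer.param_groups for p in group["params"])
```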