--- /dev/null
+++ b/foresight/trainer.py
@@ -0,0 +1,83 @@
+from transformers import Trainer
+from transformers.trainer import *  # provides get_parameter_names, nn, Adafactor, AdamW, ShardedDDPOption, OSS, is_sagemaker_mp_enabled, smp
+
+class SuperTrainer(Trainer):
+    def create_optimizer(self):
+        """
+        Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through :obj:`optimizers`, or subclass and override this method.
+        """
+        if self.optimizer is None:
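+            # Same grouping as the stock Trainer: weight decay is applied to every
+            # parameter except biases and LayerNorm weights.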
+            decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
+            decay_parameters = [name for name in decay_parameters if "bias" not in name]
+            optimizer_grouped_parameters = [
+                {
+                    "params": [p for n, p in self.model.named_parameters() if n in decay_parameters],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [p for n, p in self.model.named_parameters() if n not in decay_parameters],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            # Fold in any parameters registered through add_params_to_be_tracked,
+            # placing each one in the group that matches its decay flag.
+            if hasattr(self, 'extra_params'):
+                optimizer_grouped_parameters[0]['params'].extend(
+                    p for decay, params in self.extra_params if decay for p in params
+                )
+                optimizer_grouped_parameters[1]['params'].extend(
+                    p for decay, params in self.extra_params if not decay for p in params
+                )
+
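+            # Pick the optimizer class from the TrainingArguments: Adafactor (run with a
+            # fixed external learning rate) or the default AdamW.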
+            if self.args.adafactor:
+                optimizer_cls = Adafactor
+                optimizer_kwargs = {"scale_parameter": False, "relative_step": False}
+            else:
+                optimizer_cls = AdamW
+                optimizer_kwargs = {
+                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
+                    "eps": self.args.adam_epsilon,
+                }
+            optimizer_kwargs["lr"] = self.args.learning_rate
+            if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+                self.optimizer = OSS(
+                    params=optimizer_grouped_parameters,
+                    optim=optimizer_cls,
+                    **optimizer_kwargs,
+                )
+            else:
+                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+
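+        # Under SageMaker model parallelism the optimizer has to be wrapped in
+        # smp.DistributedOptimizer.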
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+        return self.optimizer
+
+    def add_params_to_be_tracked(self, params, decay=True):
+        """Register parameters that live outside ``self.model`` (a single parameter or an
+        iterable of parameters) so that ``create_optimizer`` adds them to the weight-decay
+        group selected by ``decay``."""
+        if isinstance(params, nn.Parameter):
+            params = [params]  # accept a single parameter as well as an iterable
+        if not hasattr(self, 'extra_params'):
+            self.extra_params = []
+        self.extra_params.append((decay, list(params)))
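+
+
+# Hypothetical usage sketch (the names below are illustrative, not part of this module):
+#
+#     trainer = SuperTrainer(model=model, args=training_args, train_dataset=train_ds)
+#     trainer.add_params_to_be_tracked(prompt_embeddings.parameters(), decay=False)
+#     trainer.train()  # create_optimizer() will now also update prompt_embeddings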