--- /dev/null
+++ b/src/extensions/transformers.py
@@ -0,0 +1,41 @@
+# Base Dependencies
+# -----------------
+from typing import Union, Tuple
+
+# 3rd-Party Dependencies
+# ----------------------
+import torch
+from torch import nn
+from transformers import Trainer
+
+
+class WeightedLossTrainer(Trainer):
+    """Custom Transformers Trainer that uses a weighted cross-entropy loss."""
+
+    # per-class weights for the loss; must live on the same device as the logits
+    class_weights: torch.Tensor = None
+
+    def compute_loss(
+        self, model: nn.Module, inputs: dict, return_outputs: bool = False
+    ) -> Union[Tuple[float, dict], float]:
+        """Computes the weighted cross-entropy loss of a model on a batch of inputs.
+
+        Args:
+            model (nn.Module): Transformers model
+            inputs (dict): batch of inputs, including the "labels" key
+            return_outputs (bool): whether to return the model outputs along with the loss
+
+        Returns:
+            Union[Tuple[float, dict], float]: loss and outputs of the model, or the loss alone
+        """
+        labels = inputs.get("labels")
+
+        # forward pass
+        outputs = model(**inputs)
+        logits = outputs.get("logits")
+
+        # compute the weighted cross-entropy loss (e.g. 3 labels with different class weights)
+        loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
+        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+
+        return (loss, outputs) if return_outputs else loss
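For context, a minimal usage sketch of the new trainer follows. The checkpoint name, class weights, and dummy batch are illustrative assumptions and are not part of this diff; the sketch exercises `compute_loss` directly rather than running a full training loop.

```python
# Minimal usage sketch (assumed checkpoint, weights, and batch; not part of the diff).
import torch
from transformers import AutoModelForSequenceClassification, TrainingArguments

from src.extensions.transformers import WeightedLossTrainer

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=3
)
trainer = WeightedLossTrainer(
    model=model,
    args=TrainingArguments(output_dir="out", report_to="none"),
)

# Per-class weights (e.g. inverse class frequencies); keep them on the model's device.
trainer.class_weights = torch.tensor([0.5, 1.0, 2.0]).to(trainer.model.device)

# Dummy batch: two short sequences with labels 0 and 2.
batch = {
    "input_ids": torch.tensor([[101, 2023, 102], [101, 2653, 102]]),
    "attention_mask": torch.ones(2, 3, dtype=torch.long),
    "labels": torch.tensor([0, 2]),
}
batch = {k: v.to(trainer.model.device) for k, v in batch.items()}

loss = trainer.compute_loss(trainer.model, batch)
print(loss)  # scalar tensor: weighted cross-entropy over the dummy batch
```

In a real run, `Trainer.train()` invokes `compute_loss` on every batch, so setting `class_weights` once before training is sufficient; the only caveat is that the weight tensor must be on the same device as the model's logits, otherwise `nn.CrossEntropyLoss` raises a device-mismatch error.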