[735bb5]: / src / extensions / transformers.py

Download this file

41 lines (30 with data), 1.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Base Dependencies
# -----------------
from typing import Optional, Tuple, Union

# 3rd-Party Dependencies
# ----------------------
import torch
from torch import nn
from transformers import Trainer
class WeightedLossTrainer(Trainer):
    """Custom Transformers ``Trainer`` that trains with a class-weighted cross-entropy loss.

    Per-class weights can be given either through the ``class_weights`` keyword
    argument of the constructor or by assigning the ``class_weights`` attribute
    before training. When ``class_weights`` is ``None`` the loss is the ordinary,
    unweighted cross-entropy.
    """

    # Per-class weights passed as ``weight`` to ``nn.CrossEntropyLoss`` (one entry
    # per class, indexed by label id). ``None`` means unweighted cross-entropy.
    class_weights: Optional[torch.Tensor] = None

    def __init__(self, *args, class_weights: Optional[torch.Tensor] = None, **kwargs):
        """Create the trainer; all arguments except ``class_weights`` go to ``Trainer``.

        Args:
            class_weights (Optional[torch.Tensor]): per-class loss weights. Any
                sequence accepted by ``torch.as_tensor`` also works. If omitted,
                the class-level attribute (default ``None``) is used.
        """
        super().__init__(*args, **kwargs)
        # Only override when explicitly given, so code that sets ``class_weights``
        # on the class or on the instance after construction keeps working.
        if class_weights is not None:
            self.class_weights = class_weights

    def compute_loss(
        self,
        model: nn.Module,
        inputs: dict,
        return_outputs: bool = False,
        num_items_in_batch: Optional[int] = None,
    ) -> Union[Tuple[torch.Tensor, dict], torch.Tensor]:
        """Compute the class-weighted cross-entropy loss of ``model`` on a batch.

        Args:
            model (nn.Module): Transformers model whose outputs expose ``"logits"``.
            inputs (dict): batch of model inputs; must contain a ``"labels"`` entry.
                It is passed to the model unchanged (labels included).
            return_outputs (bool): if True, also return the model outputs.
            num_items_in_batch (Optional[int]): accepted because recent
                ``transformers`` releases pass it to ``compute_loss``. It is unused
                here because the loss below is mean-reduced over the batch.
                NOTE(review): confirm this matches the Trainer's
                gradient-accumulation normalization for your transformers version.

        Returns:
            Union[Tuple[torch.Tensor, dict], torch.Tensor]: ``(loss, outputs)`` when
            ``return_outputs`` is True, otherwise just the scalar ``loss`` tensor.

        Raises:
            ValueError: if ``inputs`` has no ``"labels"`` entry.
        """
        labels = inputs.get("labels")
        if labels is None:
            raise ValueError(
                "WeightedLossTrainer.compute_loss requires a 'labels' entry in the inputs"
            )

        # Forward pass.
        outputs = model(**inputs)
        logits = outputs.get("logits")

        weight = self.class_weights
        if weight is not None:
            # CrossEntropyLoss needs the weight on the same device as the logits and
            # with a matching floating dtype (e.g. CPU/float64 weights vs GPU/float32
            # logits would otherwise fail); as_tensor also accepts lists.
            weight = torch.as_tensor(weight, dtype=logits.dtype, device=logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=weight)

        # Flatten to (N, num_classes) logits and (N,) targets so both (batch, C)
        # and (batch, seq, C) logits are handled; reshape (unlike view) also
        # accepts non-contiguous tensors.
        num_classes = logits.size(-1)
        loss = loss_fct(logits.reshape(-1, num_classes), labels.reshape(-1))
        return (loss, outputs) if return_outputs else loss