a b/src/training/early_stopping.py
1
"""Early Stopping
2
Source: https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py
3
"""
4
5
# Base Dependencies
6
# -----------------
7
import numpy as np
8
import torch
9
10
11
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Call the instance once per epoch with the current validation loss and the
    model. When the loss improves on the best seen so far (by at least
    ``delta``), the model's ``state_dict()`` is saved to ``path`` and the
    patience counter is reset. Otherwise the counter is incremented, and
    ``early_stop`` becomes True once it reaches ``patience``.
    """

    def __init__(
        self,
        patience: int = 7,
        verbose: bool = False,
        delta: float = 0,
        path: str = "checkpoint.pt",
        trace_func=print,
    ):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive calls without improvement
        self.best_score = None  # best (negated) loss seen so far; None until first valid call
        self.early_stop = False
        # np.Inf was removed in NumPy 2.0; np.inf is available on every version.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state.

        Args:
            val_loss: Validation loss for the current epoch (lower is better).
                Must be convertible with ``float()`` (Python/NumPy scalar or a
                single-element tensor).
            model: Object whose ``state_dict()`` is saved on improvement.
        """
        # NaN compares False against everything, so without this guard it would
        # fall through to the "improved" branch, overwrite best_score with NaN
        # and make every later epoch look like an improvement. A NaN loss is
        # never an improvement.
        if np.isnan(float(val_loss)):
            self.trace_func("EarlyStopping: validation loss is NaN, counting as no improvement")
            self._register_no_improvement()
            return

        # Work with the negated loss so that "higher score is better".
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self._register_no_improvement()
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def _register_no_improvement(self):
        """Increment the patience counter and flag early stop once it reaches patience."""
        self.counter += 1
        self.trace_func(
            f"EarlyStopping counter: {self.counter} out of {self.patience}"
        )
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save the model's state_dict to ``self.path`` after the validation loss decreases."""
        if self.verbose:
            self.trace_func(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ..."
            )
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss