Diff of /earlystoping.py [000000] .. [2d53aa]

Switch to unified view

a b/earlystoping.py
1
import torch
2
3
class Earlystopping:
4
    def __init__(self, number=50, path='../ssd/checkpoint.pt'):
5
        self.number = number
6
        self.path = path
7
        self.counter = 0
8
        self.epoch_count = 0
9
        self.best_epoch_num = 1
10
        self.max_acc = None
11
        self.stop_now = False
12
13
    def __call__(self, model, acc):
14
        if self.max_acc is None:
15
            self.epoch_count += 1
16
            self.max_acc = acc
17
            self.save_model(model)
18
        elif acc > self.max_acc:
19
            self.epoch_count += 1
20
            self.best_epoch_num = self.epoch_count
21
            print('New maximum accuracy: {:.4f}% -> {:.4f}%\n'.format(self.max_acc, acc))
22
            self.max_acc = acc
23
            self.save_model(model)
24
            self.counter = 0
25
        else:
26
            self.epoch_count += 1
27
            print('Current maximum accuracy: {:.4f}%\n'
28
                  'Early stopping counter: {:d}/{:d}\n'.format(self.max_acc, self.counter, self.number))
29
            if self.counter >= self.number:
30
                self.stop_now = True
31
            self.counter += 1
32
33
    def save_model(self, model):
34
        torch.save(model.state_dict(), self.path)