|
a |
|
b/earlystoping.py |
|
|
1 |
import torch |
|
|
2 |
|
|
|
3 |
class Earlystopping: |
|
|
4 |
def __init__(self, number=50, path='../ssd/checkpoint.pt'): |
|
|
5 |
self.number = number |
|
|
6 |
self.path = path |
|
|
7 |
self.counter = 0 |
|
|
8 |
self.epoch_count = 0 |
|
|
9 |
self.best_epoch_num = 1 |
|
|
10 |
self.max_acc = None |
|
|
11 |
self.stop_now = False |
|
|
12 |
|
|
|
13 |
def __call__(self, model, acc): |
|
|
14 |
if self.max_acc is None: |
|
|
15 |
self.epoch_count += 1 |
|
|
16 |
self.max_acc = acc |
|
|
17 |
self.save_model(model) |
|
|
18 |
elif acc > self.max_acc: |
|
|
19 |
self.epoch_count += 1 |
|
|
20 |
self.best_epoch_num = self.epoch_count |
|
|
21 |
print('New maximum accuracy: {:.4f}% -> {:.4f}%\n'.format(self.max_acc, acc)) |
|
|
22 |
self.max_acc = acc |
|
|
23 |
self.save_model(model) |
|
|
24 |
self.counter = 0 |
|
|
25 |
else: |
|
|
26 |
self.epoch_count += 1 |
|
|
27 |
print('Current maximum accuracy: {:.4f}%\n' |
|
|
28 |
'Early stopping counter: {:d}/{:d}\n'.format(self.max_acc, self.counter, self.number)) |
|
|
29 |
if self.counter >= self.number: |
|
|
30 |
self.stop_now = True |
|
|
31 |
self.counter += 1 |
|
|
32 |
|
|
|
33 |
def save_model(self, model): |
|
|
34 |
torch.save(model.state_dict(), self.path) |