|
a |
|
b/utils/callbacks.py |
|
|
1 |
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license |
|
|
2 |
""" |
|
|
3 |
Callback utils |
|
|
4 |
""" |
|
|
5 |
|
|
|
6 |
import threading |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
class Callbacks: |
|
|
10 |
"""" |
|
|
11 |
Handles all registered callbacks for YOLOv5 Hooks |
|
|
12 |
""" |
|
|
13 |
|
|
|
14 |
def __init__(self): |
|
|
15 |
# Define the available callbacks |
|
|
16 |
self._callbacks = { |
|
|
17 |
'on_pretrain_routine_start': [], |
|
|
18 |
'on_pretrain_routine_end': [], |
|
|
19 |
'on_train_start': [], |
|
|
20 |
'on_train_epoch_start': [], |
|
|
21 |
'on_train_batch_start': [], |
|
|
22 |
'optimizer_step': [], |
|
|
23 |
'on_before_zero_grad': [], |
|
|
24 |
'on_train_batch_end': [], |
|
|
25 |
'on_train_epoch_end': [], |
|
|
26 |
'on_val_start': [], |
|
|
27 |
'on_val_batch_start': [], |
|
|
28 |
'on_val_image_end': [], |
|
|
29 |
'on_val_batch_end': [], |
|
|
30 |
'on_val_end': [], |
|
|
31 |
'on_fit_epoch_end': [], # fit = train + val |
|
|
32 |
'on_model_save': [], |
|
|
33 |
'on_train_end': [], |
|
|
34 |
'on_params_update': [], |
|
|
35 |
'teardown': [], } |
|
|
36 |
self.stop_training = False # set True to interrupt training |
|
|
37 |
|
|
|
38 |
def register_action(self, hook, name='', callback=None): |
|
|
39 |
""" |
|
|
40 |
Register a new action to a callback hook |
|
|
41 |
|
|
|
42 |
Args: |
|
|
43 |
hook: The callback hook name to register the action to |
|
|
44 |
name: The name of the action for later reference |
|
|
45 |
callback: The callback to fire |
|
|
46 |
""" |
|
|
47 |
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" |
|
|
48 |
assert callable(callback), f"callback '{callback}' is not callable" |
|
|
49 |
self._callbacks[hook].append({'name': name, 'callback': callback}) |
|
|
50 |
|
|
|
51 |
def get_registered_actions(self, hook=None): |
|
|
52 |
"""" |
|
|
53 |
Returns all the registered actions by callback hook |
|
|
54 |
|
|
|
55 |
Args: |
|
|
56 |
hook: The name of the hook to check, defaults to all |
|
|
57 |
""" |
|
|
58 |
return self._callbacks[hook] if hook else self._callbacks |
|
|
59 |
|
|
|
60 |
def run(self, hook, *args, thread=False, **kwargs): |
|
|
61 |
""" |
|
|
62 |
Loop through the registered actions and fire all callbacks on main thread |
|
|
63 |
|
|
|
64 |
Args: |
|
|
65 |
hook: The name of the hook to check, defaults to all |
|
|
66 |
args: Arguments to receive from YOLOv5 |
|
|
67 |
thread: (boolean) Run callbacks in daemon thread |
|
|
68 |
kwargs: Keyword Arguments to receive from YOLOv5 |
|
|
69 |
""" |
|
|
70 |
|
|
|
71 |
assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}" |
|
|
72 |
for logger in self._callbacks[hook]: |
|
|
73 |
if thread: |
|
|
74 |
threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start() |
|
|
75 |
else: |
|
|
76 |
logger['callback'](*args, **kwargs) |