Diff of /utils/callbacks.py [000000] .. [190ca4]

Switch to unified view

a b/utils/callbacks.py
1
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
2
"""
3
Callback utils
4
"""
5
6
import threading
7
8
9
class Callbacks:
10
    """"
11
    Handles all registered callbacks for YOLOv5 Hooks
12
    """
13
14
    def __init__(self):
15
        # Define the available callbacks
16
        self._callbacks = {
17
            'on_pretrain_routine_start': [],
18
            'on_pretrain_routine_end': [],
19
            'on_train_start': [],
20
            'on_train_epoch_start': [],
21
            'on_train_batch_start': [],
22
            'optimizer_step': [],
23
            'on_before_zero_grad': [],
24
            'on_train_batch_end': [],
25
            'on_train_epoch_end': [],
26
            'on_val_start': [],
27
            'on_val_batch_start': [],
28
            'on_val_image_end': [],
29
            'on_val_batch_end': [],
30
            'on_val_end': [],
31
            'on_fit_epoch_end': [],  # fit = train + val
32
            'on_model_save': [],
33
            'on_train_end': [],
34
            'on_params_update': [],
35
            'teardown': [], }
36
        self.stop_training = False  # set True to interrupt training
37
38
    def register_action(self, hook, name='', callback=None):
39
        """
40
        Register a new action to a callback hook
41
42
        Args:
43
            hook: The callback hook name to register the action to
44
            name: The name of the action for later reference
45
            callback: The callback to fire
46
        """
47
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
48
        assert callable(callback), f"callback '{callback}' is not callable"
49
        self._callbacks[hook].append({'name': name, 'callback': callback})
50
51
    def get_registered_actions(self, hook=None):
52
        """"
53
        Returns all the registered actions by callback hook
54
55
        Args:
56
            hook: The name of the hook to check, defaults to all
57
        """
58
        return self._callbacks[hook] if hook else self._callbacks
59
60
    def run(self, hook, *args, thread=False, **kwargs):
61
        """
62
        Loop through the registered actions and fire all callbacks on main thread
63
64
        Args:
65
            hook: The name of the hook to check, defaults to all
66
            args: Arguments to receive from YOLOv5
67
            thread: (boolean) Run callbacks in daemon thread
68
            kwargs: Keyword Arguments to receive from YOLOv5
69
        """
70
71
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
72
        for logger in self._callbacks[hook]:
73
            if thread:
74
                threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
75
            else:
76
                logger['callback'](*args, **kwargs)