# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# WARNING ⚠️ wandb is deprecated and will be removed in future release.
# See supported integrations at https://github.com/ultralytics/yolov5#integrations

import logging
import os
import sys
from contextlib import contextmanager
from pathlib import Path

from utils.general import LOGGER, colorstr

FILE = Path(__file__).resolve()
ROOT = FILE.parents[3]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
RANK = int(os.getenv('RANK', -1))
DEPRECATION_WARNING = f"{colorstr('wandb')}: WARNING ⚠️ wandb is deprecated and will be removed in a future release. " \
                      f'See supported integrations at https://github.com/ultralytics/yolov5#integrations.'

try:
    import wandb

    assert hasattr(wandb, '__version__')  # verify package import not local dir
    LOGGER.warning(DEPRECATION_WARNING)
except (ImportError, AssertionError):
    # Graceful degradation: with wandb set to None, WandbLogger becomes a no-op.
    wandb = None
class WandbLogger():
    """Log training runs, datasets, models, and predictions to Weights & Biases.

    This logger sends information to W&B at wandb.ai. By default, this information
    includes hyperparameters, system configuration and metrics, model metrics,
    and basic data metrics and analyses.

    By providing additional command line arguments to train.py, datasets,
    models and predictions can also be logged.

    For more on how this logger is used, see the Weights & Biases documentation:
    https://docs.wandb.com/guides/integrations/yolov5
    """

    def __init__(self, opt, run_id=None, job_type='Training'):
        """
        - Initialize WandbLogger instance
        - Upload dataset if opt.upload_dataset is True
        - Setup training processes if job_type is 'Training'

        arguments:
        opt (namespace) -- Commandline arguments for this run
        run_id (str) -- Run ID of W&B run to be resumed
        job_type (str) -- To set the job_type for this run
        """
        # Pre-training routine --
        self.job_type = job_type
        # Module-level `wandb` is None when the package is unavailable; in that
        # case every logging method below silently no-ops.
        self.wandb, self.wandb_run = wandb, wandb.run if wandb else None
        self.val_artifact, self.train_artifact = None, None
        self.train_artifact_path, self.val_artifact_path = None, None
        self.result_artifact = None
        self.val_table, self.result_table = None, None
        self.max_imgs_to_log = 16  # cap on images logged per step
        self.data_dict = None
        if self.wandb:
            # Reuse an already-active run if one exists; otherwise start a new one
            self.wandb_run = wandb.init(config=opt,
                                        resume='allow',
                                        project='YOLOv5' if opt.project == 'runs/train' else Path(opt.project).stem,
                                        entity=opt.entity,
                                        name=opt.name if opt.name != 'exp' else None,
                                        job_type=job_type,
                                        id=run_id,
                                        allow_val_change=True) if not wandb.run else wandb.run

        if self.wandb_run:
            if self.job_type == 'Training':
                if isinstance(opt.data, dict):
                    # This means another dataset manager has already processed the dataset info (e.g. ClearML)
                    # and they will have stored the already processed dict in opt.data
                    self.data_dict = opt.data
                self.setup_training(opt)

    def setup_training(self, opt):
        """
        Setup the necessary processes for training YOLO models:
          - Attempt to download model checkpoint and dataset artifacts if opt.resume starts with WANDB_ARTIFACT_PREFIX
          - Update data_dict, to contain info of previous run if resumed and the paths of dataset artifact if downloaded
          - Setup log_dict, initialize bbox_interval

        arguments:
        opt (namespace) -- commandline arguments for this run
        """
        self.log_dict, self.current_epoch = {}, 0
        self.bbox_interval = opt.bbox_interval
        if isinstance(opt.resume, str):
            # Resuming from a W&B artifact: restore checkpoint and training config
            model_dir, _ = self.download_model_artifact(opt)
            if model_dir:
                self.weights = Path(model_dir) / 'last.pt'
                config = self.wandb_run.config
                opt.weights, opt.save_period, opt.batch_size, opt.bbox_interval, opt.epochs, opt.hyp, opt.imgsz = str(
                    self.weights), config.save_period, config.batch_size, config.bbox_interval, config.epochs, \
                    config.hyp, config.imgsz

        if opt.bbox_interval == -1:
            # Auto-select: log bounding boxes roughly 10 times over training
            self.bbox_interval = opt.bbox_interval = (opt.epochs // 10) if opt.epochs > 10 else 1
            if opt.evolve or opt.noplots:
                self.bbox_interval = opt.bbox_interval = opt.epochs + 1  # disable bbox_interval

    def log_model(self, path, opt, epoch, fitness_score, best_model=False):
        """
        Log the model checkpoint as W&B artifact

        arguments:
        path (Path)   -- Path of directory containing the checkpoints
        opt (namespace) -- Command line arguments for this run
        epoch (int)  -- Current epoch number
        fitness_score (float) -- fitness score for current epoch
        best_model (boolean) -- Boolean representing if the current checkpoint is the best yet.
        """
        model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model',
                                        type='model',
                                        metadata={
                                            'original_url': str(path),
                                            'epochs_trained': epoch + 1,
                                            'save period': opt.save_period,
                                            'project': opt.project,
                                            'total_epochs': opt.epochs,
                                            'fitness_score': fitness_score})
        model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
        # 'best' alias is only attached when this checkpoint is the best so far
        wandb.log_artifact(model_artifact,
                           aliases=['latest', 'last', 'epoch ' + str(self.current_epoch), 'best' if best_model else ''])
        LOGGER.info(f'Saving model artifact on epoch {epoch + 1}')

    def val_one_image(self, pred, predn, path, names, im):
        """Per-image validation logging hook; intentionally a no-op in the deprecated wandb integration."""
        pass

    def log(self, log_dict):
        """
        save the metrics to the logging dictionary

        arguments:
        log_dict (Dict) -- metrics/media to be logged in current step
        """
        if self.wandb_run:
            for key, value in log_dict.items():
                self.log_dict[key] = value

    def end_epoch(self):
        """
        commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.
        """
        if self.wandb_run:
            with all_logging_disabled():
                try:
                    wandb.log(self.log_dict)
                except BaseException as e:
                    # Deliberately broad: a wandb failure must never abort training.
                    LOGGER.info(
                        f'An error occurred in wandb logger. The training will proceed without interruption. More info\n{e}'
                    )
                    self.wandb_run.finish()
                    self.wandb_run = None
                self.log_dict = {}

    def finish_run(self):
        """
        Log metrics if any and finish the current W&B run
        """
        if self.wandb_run:
            if self.log_dict:
                with all_logging_disabled():
                    wandb.log(self.log_dict)
            wandb.run.finish()
            LOGGER.warning(DEPRECATION_WARNING)
@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """ source - https://gist.github.com/simon-weber/7853144
    A context manager that will prevent any logging messages triggered during the body from being processed.
    :param highest_level: the maximum logging level in use.
      This would only need to be changed if a custom level greater than CRITICAL is defined.
    """
    # Remember the process-wide disable threshold so it can be restored exactly
    previous_level = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        # Restore even if the body raised, so logging is never left muted
        logging.disable(previous_level)