Switch to unified view

a b/utils/loggers/comet/comet_utils.py
1
import logging
2
import os
3
from urllib.parse import urlparse
4
5
try:
6
    import comet_ml
7
except (ModuleNotFoundError, ImportError):
8
    comet_ml = None
9
10
import yaml
11
12
logger = logging.getLogger(__name__)
13
14
COMET_PREFIX = 'comet://'
15
COMET_MODEL_NAME = os.getenv('COMET_MODEL_NAME', 'yolov5')
16
COMET_DEFAULT_CHECKPOINT_FILENAME = os.getenv('COMET_DEFAULT_CHECKPOINT_FILENAME', 'last.pt')
17
18
19
def download_model_checkpoint(opt, experiment):
20
    model_dir = f'{opt.project}/{experiment.name}'
21
    os.makedirs(model_dir, exist_ok=True)
22
23
    model_name = COMET_MODEL_NAME
24
    model_asset_list = experiment.get_model_asset_list(model_name)
25
26
    if len(model_asset_list) == 0:
27
        logger.error(f'COMET ERROR: No checkpoints found for model name : {model_name}')
28
        return
29
30
    model_asset_list = sorted(
31
        model_asset_list,
32
        key=lambda x: x['step'],
33
        reverse=True,
34
    )
35
    logged_checkpoint_map = {asset['fileName']: asset['assetId'] for asset in model_asset_list}
36
37
    resource_url = urlparse(opt.weights)
38
    checkpoint_filename = resource_url.query
39
40
    if checkpoint_filename:
41
        asset_id = logged_checkpoint_map.get(checkpoint_filename)
42
    else:
43
        asset_id = logged_checkpoint_map.get(COMET_DEFAULT_CHECKPOINT_FILENAME)
44
        checkpoint_filename = COMET_DEFAULT_CHECKPOINT_FILENAME
45
46
    if asset_id is None:
47
        logger.error(f'COMET ERROR: Checkpoint {checkpoint_filename} not found in the given Experiment')
48
        return
49
50
    try:
51
        logger.info(f'COMET INFO: Downloading checkpoint {checkpoint_filename}')
52
        asset_filename = checkpoint_filename
53
54
        model_binary = experiment.get_asset(asset_id, return_type='binary', stream=False)
55
        model_download_path = f'{model_dir}/{asset_filename}'
56
        with open(model_download_path, 'wb') as f:
57
            f.write(model_binary)
58
59
        opt.weights = model_download_path
60
61
    except Exception as e:
62
        logger.warning('COMET WARNING: Unable to download checkpoint from Comet')
63
        logger.exception(e)
64
65
66
def set_opt_parameters(opt, experiment):
67
    """Update the opts Namespace with parameters
68
    from Comet's ExistingExperiment when resuming a run
69
70
    Args:
71
        opt (argparse.Namespace): Namespace of command line options
72
        experiment (comet_ml.APIExperiment): Comet API Experiment object
73
    """
74
    asset_list = experiment.get_asset_list()
75
    resume_string = opt.resume
76
77
    for asset in asset_list:
78
        if asset['fileName'] == 'opt.yaml':
79
            asset_id = asset['assetId']
80
            asset_binary = experiment.get_asset(asset_id, return_type='binary', stream=False)
81
            opt_dict = yaml.safe_load(asset_binary)
82
            for key, value in opt_dict.items():
83
                setattr(opt, key, value)
84
            opt.resume = resume_string
85
86
    # Save hyperparameters to YAML file
87
    # Necessary to pass checks in training script
88
    save_dir = f'{opt.project}/{experiment.name}'
89
    os.makedirs(save_dir, exist_ok=True)
90
91
    hyp_yaml_path = f'{save_dir}/hyp.yaml'
92
    with open(hyp_yaml_path, 'w') as f:
93
        yaml.dump(opt.hyp, f)
94
    opt.hyp = hyp_yaml_path
95
96
97
def check_comet_weights(opt):
98
    """Downloads model weights from Comet and updates the
99
    weights path to point to saved weights location
100
101
    Args:
102
        opt (argparse.Namespace): Command Line arguments passed
103
            to YOLOv5 training script
104
105
    Returns:
106
        None/bool: Return True if weights are successfully downloaded
107
            else return None
108
    """
109
    if comet_ml is None:
110
        return
111
112
    if isinstance(opt.weights, str):
113
        if opt.weights.startswith(COMET_PREFIX):
114
            api = comet_ml.API()
115
            resource = urlparse(opt.weights)
116
            experiment_path = f'{resource.netloc}{resource.path}'
117
            experiment = api.get(experiment_path)
118
            download_model_checkpoint(opt, experiment)
119
            return True
120
121
    return None
122
123
124
def check_comet_resume(opt):
125
    """Restores run parameters to its original state based on the model checkpoint
126
    and logged Experiment parameters.
127
128
    Args:
129
        opt (argparse.Namespace): Command Line arguments passed
130
            to YOLOv5 training script
131
132
    Returns:
133
        None/bool: Return True if the run is restored successfully
134
            else return None
135
    """
136
    if comet_ml is None:
137
        return
138
139
    if isinstance(opt.resume, str):
140
        if opt.resume.startswith(COMET_PREFIX):
141
            api = comet_ml.API()
142
            resource = urlparse(opt.resume)
143
            experiment_path = f'{resource.netloc}{resource.path}'
144
            experiment = api.get(experiment_path)
145
            set_opt_parameters(opt, experiment)
146
            download_model_checkpoint(opt, experiment)
147
148
            return True
149
150
    return None