Diff of /bin/tuning.py [000000] .. [13a70a]

Switch to unified view

a b/bin/tuning.py
1
#!/usr/bin/env python3
2
"""CLI module for running raytune tuning experiment."""
3
4
import argparse
5
import logging
6
import shutil
7
from pathlib import Path
8
from typing import Any
9
10
import ray
11
import yaml
12
13
from stimulus.data import loaders
14
from stimulus.learner import raytune_learner, raytune_parser
15
from stimulus.utils import launch_utils, yaml_data, yaml_model_schema
16
17
logger = logging.getLogger(__name__)
18
19
20
def _raise_empty_grid() -> None:
21
    """Raise an error when grid results are empty."""
22
    raise RuntimeError("Ray Tune returned empty results grid")
23
24
25
def get_args() -> argparse.Namespace:
26
    """Get the arguments when using from the commandline.
27
28
    Returns:
29
        Parsed command line arguments.
30
    """
31
    parser = argparse.ArgumentParser(description="Launch check_model.")
32
    parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input csv file.")
33
    parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model file.")
34
    parser.add_argument(
35
        "-e",
36
        "--data_config",
37
        type=str,
38
        required=True,
39
        metavar="FILE",
40
        help="Path to data config file.",
41
    )
42
    parser.add_argument(
43
        "-c",
44
        "--model_config",
45
        type=str,
46
        required=True,
47
        metavar="FILE",
48
        help="Path to yaml config training file.",
49
    )
50
    parser.add_argument(
51
        "-w",
52
        "--initial_weights",
53
        type=str,
54
        required=False,
55
        nargs="?",
56
        const=None,
57
        default=None,
58
        metavar="FILE",
59
        help="The path to the initial weights (optional).",
60
    )
61
    parser.add_argument(
62
        "--ray_results_dirpath",
63
        type=str,
64
        required=False,
65
        nargs="?",
66
        const=None,
67
        default=None,
68
        metavar="DIR_PATH",
69
        help="Location where ray_results output dir should be written. If None, uses ~/ray_results.",
70
    )
71
    parser.add_argument(
72
        "-o",
73
        "--output",
74
        type=str,
75
        required=False,
76
        nargs="?",
77
        const="best_model.pt",
78
        default="best_model.pt",
79
        metavar="FILE",
80
        help="The output file path to write the trained model to",
81
    )
82
    parser.add_argument(
83
        "-bm",
84
        "--best_metrics",
85
        type=str,
86
        required=False,
87
        nargs="?",
88
        const="best_metrics.csv",
89
        default="best_metrics.csv",
90
        metavar="FILE",
91
        help="The path to write the best metrics to",
92
    )
93
    parser.add_argument(
94
        "-bc",
95
        "--best_config",
96
        type=str,
97
        required=False,
98
        nargs="?",
99
        const="best_config.yaml",
100
        default="best_config.yaml",
101
        metavar="FILE",
102
        help="The path to write the best config to",
103
    )
104
    parser.add_argument(
105
        "-bo",
106
        "--best_optimizer",
107
        type=str,
108
        required=False,
109
        nargs="?",
110
        const="best_optimizer.pt",
111
        default="best_optimizer.pt",
112
        metavar="FILE",
113
        help="The path to write the best optimizer to",
114
    )
115
    parser.add_argument(
116
        "--tune_run_name",
117
        type=str,
118
        required=False,
119
        nargs="?",
120
        const=None,
121
        default=None,
122
        metavar="CUSTOM_RUN_NAME",
123
        help=(
124
            "Tells ray tune what the 'experiment_name' (i.e. the given tune_run name) should be. "
125
            "If set, the subdirectory of ray_results is named with this value and its train dir is prefixed accordingly. "
126
            "Default None means that ray will generate such a name on its own."
127
        ),
128
    )
129
    parser.add_argument(
130
        "--debug_mode",
131
        action="store_true",
132
        help="Activate debug mode for tuning. Default false, no debug.",
133
    )
134
    return parser.parse_args()
135
136
137
def main(
138
    model_path: str,
139
    data_path: str,
140
    data_config_path: str,
141
    model_config_path: str,
142
    initial_weights: str | None = None,  # noqa: ARG001
143
    ray_results_dirpath: str | None = None,
144
    output_path: str | None = None,
145
    best_optimizer_path: str | None = None,
146
    best_metrics_path: str | None = None,
147
    best_config_path: str | None = None,
148
    *,
149
    debug_mode: bool = False,
150
) -> None:
151
    """Run the main model checking pipeline.
152
153
    Args:
154
        data_path: Path to input data file.
155
        model_path: Path to model file.
156
        data_config_path: Path to data config file.
157
        model_config_path: Path to model config file.
158
        initial_weights: Optional path to initial weights.
159
        ray_results_dirpath: Directory for ray results.
160
        debug_mode: Whether to run in debug mode.
161
        output_path: Path to write the best model to.
162
        best_optimizer_path: Path to write the best optimizer to.
163
        best_metrics_path: Path to write the best metrics to.
164
        best_config_path: Path to write the best config to.
165
    """
166
    # Convert data config to proper type
167
    with open(data_config_path) as file:
168
        data_config_dict: dict[str, Any] = yaml.safe_load(file)
169
    data_config: yaml_data.YamlSubConfigDict = yaml_data.YamlSubConfigDict(**data_config_dict)
170
171
    with open(model_config_path) as file:
172
        model_config_dict: dict[str, Any] = yaml.safe_load(file)
173
    model_config: yaml_model_schema.Model = yaml_model_schema.Model(**model_config_dict)
174
175
    encoder_loader = loaders.EncoderLoader()
176
    encoder_loader.initialize_column_encoders_from_config(column_config=data_config.columns)
177
178
    model_class = launch_utils.import_class_from_file(model_path)
179
180
    ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config)
181
    ray_config_model = ray_config_loader.get_config()
182
183
    tuner = raytune_learner.TuneWrapper(
184
        model_config=ray_config_model,
185
        data_config_path=data_config_path,
186
        model_class=model_class,
187
        data_path=data_path,
188
        encoder_loader=encoder_loader,
189
        seed=42,
190
        ray_results_dir=ray_results_dirpath,
191
        debug=debug_mode,
192
    )
193
194
    # Ensure output_path is provided
195
    if output_path is None:
196
        raise ValueError("output_path must not be None")
197
    try:
198
        grid_results = tuner.tune()
199
        if not grid_results:
200
            _raise_empty_grid()
201
202
        # Initialize parser with results
203
        parser = raytune_parser.TuneParser(result=grid_results)
204
205
        # Ensure output directory exists
206
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
207
208
        # Save outputs using proper Result object API
209
        parser.save_best_model(output=output_path)
210
        parser.save_best_optimizer(output=best_optimizer_path)
211
        parser.save_best_metrics_dataframe(output=best_metrics_path)
212
        parser.save_best_config(output=best_config_path)
213
214
    except RuntimeError:
215
        logger.exception("Tuning failed")
216
        raise
217
    except KeyError:
218
        logger.exception("Missing expected result key")
219
        raise
220
    finally:
221
        if debug_mode:
222
            logger.info("Debug mode - preserving Ray results directory")
223
        #elif ray_results_dirpath:
224
        #    shutil.rmtree(ray_results_dirpath, ignore_errors=True)
225
226
227
def run() -> None:
228
    """Run the model checking script."""
229
    args = get_args()
230
    main(
231
        data_path=args.data,
232
        model_path=args.model,
233
        data_config_path=args.data_config,
234
        model_config_path=args.model_config,
235
        initial_weights=args.initial_weights,
236
        ray_results_dirpath=args.ray_results_dirpath,
237
        output_path=args.output,
238
        best_optimizer_path=args.best_optimizer,
239
        best_metrics_path=args.best_metrics,
240
        best_config_path=args.best_config,
241
        debug_mode=args.debug_mode,
242
    )
243
244
245
if __name__ == "__main__":
246
    ray.init(address="auto", ignore_reinit_error=True)
247
    run()