|
a |
|
b/bin/tuning.py |
|
|
1 |
#!/usr/bin/env python3 |
|
|
2 |
"""CLI module for running raytune tuning experiment.""" |
|
|
3 |
|
|
|
4 |
import argparse |
|
|
5 |
import logging |
|
|
6 |
import shutil |
|
|
7 |
from pathlib import Path |
|
|
8 |
from typing import Any |
|
|
9 |
|
|
|
10 |
import ray |
|
|
11 |
import yaml |
|
|
12 |
|
|
|
13 |
from stimulus.data import loaders |
|
|
14 |
from stimulus.learner import raytune_learner, raytune_parser |
|
|
15 |
from stimulus.utils import launch_utils, yaml_data, yaml_model_schema |
|
|
16 |
|
|
|
17 |
logger = logging.getLogger(__name__) |
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def _raise_empty_grid() -> None: |
|
|
21 |
"""Raise an error when grid results are empty.""" |
|
|
22 |
raise RuntimeError("Ray Tune returned empty results grid") |
|
|
23 |
|
|
|
24 |
|
|
|
25 |
def get_args() -> argparse.Namespace: |
|
|
26 |
"""Get the arguments when using from the commandline. |
|
|
27 |
|
|
|
28 |
Returns: |
|
|
29 |
Parsed command line arguments. |
|
|
30 |
""" |
|
|
31 |
parser = argparse.ArgumentParser(description="Launch check_model.") |
|
|
32 |
parser.add_argument("-d", "--data", type=str, required=True, metavar="FILE", help="Path to input csv file.") |
|
|
33 |
parser.add_argument("-m", "--model", type=str, required=True, metavar="FILE", help="Path to model file.") |
|
|
34 |
parser.add_argument( |
|
|
35 |
"-e", |
|
|
36 |
"--data_config", |
|
|
37 |
type=str, |
|
|
38 |
required=True, |
|
|
39 |
metavar="FILE", |
|
|
40 |
help="Path to data config file.", |
|
|
41 |
) |
|
|
42 |
parser.add_argument( |
|
|
43 |
"-c", |
|
|
44 |
"--model_config", |
|
|
45 |
type=str, |
|
|
46 |
required=True, |
|
|
47 |
metavar="FILE", |
|
|
48 |
help="Path to yaml config training file.", |
|
|
49 |
) |
|
|
50 |
parser.add_argument( |
|
|
51 |
"-w", |
|
|
52 |
"--initial_weights", |
|
|
53 |
type=str, |
|
|
54 |
required=False, |
|
|
55 |
nargs="?", |
|
|
56 |
const=None, |
|
|
57 |
default=None, |
|
|
58 |
metavar="FILE", |
|
|
59 |
help="The path to the initial weights (optional).", |
|
|
60 |
) |
|
|
61 |
parser.add_argument( |
|
|
62 |
"--ray_results_dirpath", |
|
|
63 |
type=str, |
|
|
64 |
required=False, |
|
|
65 |
nargs="?", |
|
|
66 |
const=None, |
|
|
67 |
default=None, |
|
|
68 |
metavar="DIR_PATH", |
|
|
69 |
help="Location where ray_results output dir should be written. If None, uses ~/ray_results.", |
|
|
70 |
) |
|
|
71 |
parser.add_argument( |
|
|
72 |
"-o", |
|
|
73 |
"--output", |
|
|
74 |
type=str, |
|
|
75 |
required=False, |
|
|
76 |
nargs="?", |
|
|
77 |
const="best_model.pt", |
|
|
78 |
default="best_model.pt", |
|
|
79 |
metavar="FILE", |
|
|
80 |
help="The output file path to write the trained model to", |
|
|
81 |
) |
|
|
82 |
parser.add_argument( |
|
|
83 |
"-bm", |
|
|
84 |
"--best_metrics", |
|
|
85 |
type=str, |
|
|
86 |
required=False, |
|
|
87 |
nargs="?", |
|
|
88 |
const="best_metrics.csv", |
|
|
89 |
default="best_metrics.csv", |
|
|
90 |
metavar="FILE", |
|
|
91 |
help="The path to write the best metrics to", |
|
|
92 |
) |
|
|
93 |
parser.add_argument( |
|
|
94 |
"-bc", |
|
|
95 |
"--best_config", |
|
|
96 |
type=str, |
|
|
97 |
required=False, |
|
|
98 |
nargs="?", |
|
|
99 |
const="best_config.yaml", |
|
|
100 |
default="best_config.yaml", |
|
|
101 |
metavar="FILE", |
|
|
102 |
help="The path to write the best config to", |
|
|
103 |
) |
|
|
104 |
parser.add_argument( |
|
|
105 |
"-bo", |
|
|
106 |
"--best_optimizer", |
|
|
107 |
type=str, |
|
|
108 |
required=False, |
|
|
109 |
nargs="?", |
|
|
110 |
const="best_optimizer.pt", |
|
|
111 |
default="best_optimizer.pt", |
|
|
112 |
metavar="FILE", |
|
|
113 |
help="The path to write the best optimizer to", |
|
|
114 |
) |
|
|
115 |
parser.add_argument( |
|
|
116 |
"--tune_run_name", |
|
|
117 |
type=str, |
|
|
118 |
required=False, |
|
|
119 |
nargs="?", |
|
|
120 |
const=None, |
|
|
121 |
default=None, |
|
|
122 |
metavar="CUSTOM_RUN_NAME", |
|
|
123 |
help=( |
|
|
124 |
"Tells ray tune what the 'experiment_name' (i.e. the given tune_run name) should be. " |
|
|
125 |
"If set, the subdirectory of ray_results is named with this value and its train dir is prefixed accordingly. " |
|
|
126 |
"Default None means that ray will generate such a name on its own." |
|
|
127 |
), |
|
|
128 |
) |
|
|
129 |
parser.add_argument( |
|
|
130 |
"--debug_mode", |
|
|
131 |
action="store_true", |
|
|
132 |
help="Activate debug mode for tuning. Default false, no debug.", |
|
|
133 |
) |
|
|
134 |
return parser.parse_args() |
|
|
135 |
|
|
|
136 |
|
|
|
137 |
def main( |
|
|
138 |
model_path: str, |
|
|
139 |
data_path: str, |
|
|
140 |
data_config_path: str, |
|
|
141 |
model_config_path: str, |
|
|
142 |
initial_weights: str | None = None, # noqa: ARG001 |
|
|
143 |
ray_results_dirpath: str | None = None, |
|
|
144 |
output_path: str | None = None, |
|
|
145 |
best_optimizer_path: str | None = None, |
|
|
146 |
best_metrics_path: str | None = None, |
|
|
147 |
best_config_path: str | None = None, |
|
|
148 |
*, |
|
|
149 |
debug_mode: bool = False, |
|
|
150 |
) -> None: |
|
|
151 |
"""Run the main model checking pipeline. |
|
|
152 |
|
|
|
153 |
Args: |
|
|
154 |
data_path: Path to input data file. |
|
|
155 |
model_path: Path to model file. |
|
|
156 |
data_config_path: Path to data config file. |
|
|
157 |
model_config_path: Path to model config file. |
|
|
158 |
initial_weights: Optional path to initial weights. |
|
|
159 |
ray_results_dirpath: Directory for ray results. |
|
|
160 |
debug_mode: Whether to run in debug mode. |
|
|
161 |
output_path: Path to write the best model to. |
|
|
162 |
best_optimizer_path: Path to write the best optimizer to. |
|
|
163 |
best_metrics_path: Path to write the best metrics to. |
|
|
164 |
best_config_path: Path to write the best config to. |
|
|
165 |
""" |
|
|
166 |
# Convert data config to proper type |
|
|
167 |
with open(data_config_path) as file: |
|
|
168 |
data_config_dict: dict[str, Any] = yaml.safe_load(file) |
|
|
169 |
data_config: yaml_data.YamlSubConfigDict = yaml_data.YamlSubConfigDict(**data_config_dict) |
|
|
170 |
|
|
|
171 |
with open(model_config_path) as file: |
|
|
172 |
model_config_dict: dict[str, Any] = yaml.safe_load(file) |
|
|
173 |
model_config: yaml_model_schema.Model = yaml_model_schema.Model(**model_config_dict) |
|
|
174 |
|
|
|
175 |
encoder_loader = loaders.EncoderLoader() |
|
|
176 |
encoder_loader.initialize_column_encoders_from_config(column_config=data_config.columns) |
|
|
177 |
|
|
|
178 |
model_class = launch_utils.import_class_from_file(model_path) |
|
|
179 |
|
|
|
180 |
ray_config_loader = yaml_model_schema.YamlRayConfigLoader(model=model_config) |
|
|
181 |
ray_config_model = ray_config_loader.get_config() |
|
|
182 |
|
|
|
183 |
tuner = raytune_learner.TuneWrapper( |
|
|
184 |
model_config=ray_config_model, |
|
|
185 |
data_config_path=data_config_path, |
|
|
186 |
model_class=model_class, |
|
|
187 |
data_path=data_path, |
|
|
188 |
encoder_loader=encoder_loader, |
|
|
189 |
seed=42, |
|
|
190 |
ray_results_dir=ray_results_dirpath, |
|
|
191 |
debug=debug_mode, |
|
|
192 |
) |
|
|
193 |
|
|
|
194 |
# Ensure output_path is provided |
|
|
195 |
if output_path is None: |
|
|
196 |
raise ValueError("output_path must not be None") |
|
|
197 |
try: |
|
|
198 |
grid_results = tuner.tune() |
|
|
199 |
if not grid_results: |
|
|
200 |
_raise_empty_grid() |
|
|
201 |
|
|
|
202 |
# Initialize parser with results |
|
|
203 |
parser = raytune_parser.TuneParser(result=grid_results) |
|
|
204 |
|
|
|
205 |
# Ensure output directory exists |
|
|
206 |
Path(output_path).parent.mkdir(parents=True, exist_ok=True) |
|
|
207 |
|
|
|
208 |
# Save outputs using proper Result object API |
|
|
209 |
parser.save_best_model(output=output_path) |
|
|
210 |
parser.save_best_optimizer(output=best_optimizer_path) |
|
|
211 |
parser.save_best_metrics_dataframe(output=best_metrics_path) |
|
|
212 |
parser.save_best_config(output=best_config_path) |
|
|
213 |
|
|
|
214 |
except RuntimeError: |
|
|
215 |
logger.exception("Tuning failed") |
|
|
216 |
raise |
|
|
217 |
except KeyError: |
|
|
218 |
logger.exception("Missing expected result key") |
|
|
219 |
raise |
|
|
220 |
finally: |
|
|
221 |
if debug_mode: |
|
|
222 |
logger.info("Debug mode - preserving Ray results directory") |
|
|
223 |
#elif ray_results_dirpath: |
|
|
224 |
# shutil.rmtree(ray_results_dirpath, ignore_errors=True) |
|
|
225 |
|
|
|
226 |
|
|
|
227 |
def run() -> None: |
|
|
228 |
"""Run the model checking script.""" |
|
|
229 |
args = get_args() |
|
|
230 |
main( |
|
|
231 |
data_path=args.data, |
|
|
232 |
model_path=args.model, |
|
|
233 |
data_config_path=args.data_config, |
|
|
234 |
model_config_path=args.model_config, |
|
|
235 |
initial_weights=args.initial_weights, |
|
|
236 |
ray_results_dirpath=args.ray_results_dirpath, |
|
|
237 |
output_path=args.output, |
|
|
238 |
best_optimizer_path=args.best_optimizer, |
|
|
239 |
best_metrics_path=args.best_metrics, |
|
|
240 |
best_config_path=args.best_config, |
|
|
241 |
debug_mode=args.debug_mode, |
|
|
242 |
) |
|
|
243 |
|
|
|
244 |
|
|
|
245 |
if __name__ == "__main__": |
|
|
246 |
ray.init(address="auto", ignore_reinit_error=True) |
|
|
247 |
run() |