--- a +++ b/demo/scripts/demo_hypersearch_nn.py @@ -0,0 +1,45 @@ +import pickle +from pathlib import Path +from argparse import ArgumentParser + +from survival4D.config import HypersearchConfig, NNExperimentConfig, ModelConfig +from survival4D.nn import hypersearch_nn + +DEFAULT_CONF_PATH = Path(__file__).parent.joinpath("default_nn.conf") + + +def parse_args(): + parser = ArgumentParser() + parser.add_argument( + "-c", "--conf-path", dest="conf_path", type=str, default=None, help="Conf path." + ) + return parser.parse_args() + + +def main(): + args = parse_args() + if args.conf_path is None: + conf_path = DEFAULT_CONF_PATH + else: + conf_path = Path(args.conf_path) + exp_config = NNExperimentConfig.from_conf(conf_path) + model_config = ModelConfig.from_conf(conf_path) + hypersearch_config = HypersearchConfig.from_conf(conf_path) + with open(str(exp_config.data_path), 'rb') as f: + c3 = pickle.load(f) + x_full = c3[0] + y_full = c3[1] + del c3 + + opars, clog = hypersearch_nn( + x_data=x_full, y_data=y_full, + method='particle swarm', nfolds=exp_config.n_folds, nevals=exp_config.n_evals, + batch_size=exp_config.batch_size, num_epochs=exp_config.n_epochs, + backend=exp_config.backend, + model_kwargs=model_config.to_dict(), + **hypersearch_config.to_dict() + ) + + +if __name__ == '__main__': + main()