diff --git a/docs/examples/tutorials/tuning-a-process.md b/docs/examples/tutorials/tuning-a-process.md index a947be90..7578a9b3 100644 --- a/docs/examples/tutorials/tuning-a-process.md +++ b/docs/examples/tutorials/tuning-a-process.md @@ -90,6 +90,8 @@ These conditional search space functions are supported by Ray Tune and can be de 2. Write a custom function to define the search space, where each tunable parameter has a name of the form `"{component_name.field_or_arg_name}"`; then 3. Supply your custom function to the `OptunaSpec` algorithm configuration. +If your search space function accepts an argument named `spec`, Plugboard will pass the [`ProcessSpec`][plugboard.schemas.ProcessSpec] to it. This allows you to inspect the process definition when creating the parameter space, which is useful, for example, when the search space should depend on the process or component parameters. + For example, the following search space makes the velocity depend on the angle: ```python --8<-- "examples/tutorials/006_optimisation/hello_tuner.py:custom_search_space" diff --git a/plugboard-schemas/plugboard_schemas/tune.py b/plugboard-schemas/plugboard_schemas/tune.py index f7134bfa..7018b6f9 100644 --- a/plugboard-schemas/plugboard_schemas/tune.py +++ b/plugboard-schemas/plugboard_schemas/tune.py @@ -19,6 +19,8 @@ class OptunaSpec(PlugboardBaseModel): type: The algorithm type to load. space: Optional; A function defining the search space. Use this to define more complex search spaces that cannot be represented using the built-in parameter types. + The function must accept a `trial` argument. If it also accepts a `spec` argument, + the `ProcessSpec` will be passed to it. study_name: Optional; The name of the study. storage: Optional; The storage URI to save the optimisation results to. 
""" diff --git a/plugboard/tune/tune.py b/plugboard/tune/tune.py index 7fce1552..827643f5 100644 --- a/plugboard/tune/tune.py +++ b/plugboard/tune/tune.py @@ -1,6 +1,7 @@ """Provides `Tuner` class for optimising Plugboard processes.""" -from inspect import isfunction +from functools import partial +from inspect import isfunction, signature import math from pydoc import locate import typing as _t @@ -61,14 +62,14 @@ def __init__( self._objective, self._mode, self._metric = self._normalize_objective_and_mode( objective, mode ) - self._custom_space = bool(algorithm and algorithm.space) + self._algorithm = algorithm + self._custom_space = bool(self._algorithm and self._algorithm.space) + self._max_concurrent = max_concurrent + self._num_samples = num_samples - # Prepare parameters and search algorithm + # Prepare parameters self._parameters_dict, self._parameters = self._prepare_parameters(parameters) - searcher = self._init_search_algorithm(algorithm, max_concurrent) - # Configure Ray Tune - self._config = ray.tune.TuneConfig(num_samples=num_samples, search_alg=searcher) self._result_grid: _t.Optional[ray.tune.ResultGrid] = None self._logger.info("Tuner created") @@ -96,13 +97,15 @@ def _check_objective( raise ValueError("If using a single objective, `mode` must not be a list.") def _build_algorithm( - self, algorithm: _t.Optional[OptunaSpec] = None + self, + process_spec: ProcessSpec, + algorithm: _t.Optional[OptunaSpec] = None, ) -> ray.tune.search.Searcher: if algorithm is None: self._logger.info("Using default Optuna search algorithm") return self._default_searcher() - algo_kwargs = self._build_algo_kwargs(algorithm) + algo_kwargs = self._build_algo_kwargs(algorithm, process_spec) algo_cls = self._get_algo_class(algorithm.type) self._logger.info( "Using custom search algorithm", @@ -114,7 +117,9 @@ def _build_algorithm( def _default_searcher(self) -> "ray.tune.search.Searcher": return ray.tune.search.optuna.OptunaSearch(metric=self._metric, mode=self._mode) - def 
_build_algo_kwargs(self, algorithm: OptunaSpec) -> dict[str, _t.Any]: + def _build_algo_kwargs( + self, algorithm: OptunaSpec, process_spec: ProcessSpec + ) -> dict[str, _t.Any]: """Prepare keyword args for the searcher, normalising storage/space.""" kwargs = algorithm.model_dump(exclude={"type"}) kwargs["mode"] = self._mode @@ -130,14 +135,18 @@ def _build_algo_kwargs(self, algorithm: OptunaSpec) -> dict[str, _t.Any]: space = kwargs.get("space") if space is not None: - kwargs["space"] = self._resolve_space_fn(space) + kwargs["space"] = self._resolve_space_fn(space, process_spec) return kwargs - def _resolve_space_fn(self, space: str) -> _t.Callable: + def _resolve_space_fn(self, space: str, process_spec: ProcessSpec) -> _t.Callable: space_fn = locate(space) - if not space_fn or not isfunction(space_fn): # pragma: no cover + if not space_fn or not callable(space_fn): # pragma: no cover raise ValueError(f"Could not locate search space function {space}") + sig = signature(space_fn) + if "spec" in sig.parameters: + self._logger.info("Search space function accepts `spec` argument, passing ProcessSpec") + return partial(space_fn, spec=process_spec) return space_fn def _get_algo_class(self, type_path: str) -> _t.Type[ray.tune.search.searcher.Searcher]: @@ -232,8 +241,11 @@ def run(self, spec: ProcessSpec) -> ray.tune.Result | list[ray.tune.Result]: ray.tune.PlacementGroupFactory(placement_bundles), ) + searcher = self._init_search_algorithm(self._algorithm, self._max_concurrent, spec) + config = ray.tune.TuneConfig(num_samples=self._num_samples, search_alg=searcher) + tuner_kwargs: dict[str, _t.Any] = { - "tune_config": self._config, + "tune_config": config, } if not self._custom_space: self._logger.info("Setting Tuner with parameters", params=list(self._parameters.keys())) @@ -315,10 +327,13 @@ def _prepare_parameters( return params_dict, params_space def _init_search_algorithm( - self, algorithm: _t.Optional[OptunaSpec], max_concurrent: _t.Optional[int] + self, + 
algorithm: _t.Optional[OptunaSpec], + max_concurrent: _t.Optional[int], + process_spec: ProcessSpec, ) -> "ray.tune.search.Searcher": """Create the search algorithm and apply concurrency limits if requested.""" - algo = self._build_algorithm(algorithm) + algo = self._build_algorithm(process_spec, algorithm) if max_concurrent is not None: algo = ray.tune.search.ConcurrencyLimiter(algo, max_concurrent) return algo diff --git a/tests/data/dynamic-param-process.yaml b/tests/data/dynamic-param-process.yaml index 76854633..da390c65 100644 --- a/tests/data/dynamic-param-process.yaml +++ b/tests/data/dynamic-param-process.yaml @@ -18,4 +18,6 @@ plugboard: - source: "a.out_1" target: "d.in_1" - source: "d.out_1" - target: "c.in_1" \ No newline at end of file + target: "c.in_1" + parameters: + max_iters: 6 \ No newline at end of file diff --git a/tests/integration/test_tuner.py b/tests/integration/test_tuner.py index 39c43324..1ac2199d 100644 --- a/tests/integration/test_tuner.py +++ b/tests/integration/test_tuner.py @@ -17,6 +17,7 @@ IntParameterSpec, ObjectiveSpec, OptunaSpec, + ProcessSpec, ) from plugboard.tune import Tuner from tests.conftest import ComponentTestHelper @@ -68,7 +69,21 @@ def custom_space(trial: Trial) -> dict[str, _t.Any] | None: trial.suggest_float(f"list_param_{i}", -5.0, -5.0 + float(i)) for i in range(n_list) ] # Set existing parameter - trial.suggest_int("a.iters", 1, 10) + trial.suggest_int("component.a.arg.iters", 1, 10) + # Use the return value to set the list parameter + return {"component.d.arg.list_param": list_param} + + +def custom_space_with_process_spec(trial: Trial, spec: ProcessSpec) -> dict[str, _t.Any] | None: + """Custom space function that also takes the process spec.""" + # Process spec can be used to obtain parameters or other information to define the search space + iters = spec.args.parameters["max_iters"] + n_list = trial.suggest_int("n_list", 1, 10) + list_param = [ + trial.suggest_float(f"list_param_{i}", -5.0, -5.0 + 
float(i)) for i in range(n_list) + ] + # Set existing parameter + trial.suggest_int("component.a.arg.iters", 1, iters) # Use the return value to set the list parameter return {"component.d.arg.list_param": list_param} @@ -174,7 +189,7 @@ async def test_multi_objective_tune(config: dict, ray_ctx: None) -> None: categories=[1, -1], ), ], - num_samples=10, + num_samples=20, mode=["max", "min"], max_concurrent=2, ) @@ -281,7 +296,10 @@ async def test_tune_with_constraint(config: dict, ray_ctx: None) -> None: @pytest.mark.tuner @pytest.mark.asyncio -async def test_custom_space_tune(dynamic_param_config: dict, ray_ctx: None) -> None: +@pytest.mark.parametrize("space_func", [custom_space, custom_space_with_process_spec]) +async def test_custom_space_tune( + dynamic_param_config: dict, ray_ctx: None, space_func: _t.Callable +) -> None: """Tests tuning with a custom search space.""" spec = ConfigSpec.model_validate(dynamic_param_config) process_spec = spec.plugboard.process @@ -312,7 +330,7 @@ async def test_custom_space_tune(dynamic_param_config: dict, ray_ctx: None) -> N num_samples=10, mode="max", max_concurrent=2, - algorithm=OptunaSpec(space="tests.integration.test_tuner.custom_space"), + algorithm=OptunaSpec(space=f"tests.integration.test_tuner.{space_func.__name__}"), ) tuner.run( spec=process_spec, @@ -327,3 +345,6 @@ async def test_custom_space_tune(dynamic_param_config: dict, ray_ctx: None) -> N if r.config["n_list"] < 5: # When n_list < 5, all list_param values are negative assert all(v < 0.0 for v in r.config["component.d.arg.list_param"]) + if space_func.__name__ == "custom_space_with_process_spec": + # The iters parameter must be set based on the process params + assert r.config["component.a.arg.iters"] <= process_spec.args.parameters["max_iters"] diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index 582472a6..4cb0e1a0 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -6,7 +6,7 @@ import msgspec import pytest -import 
ray.tune +from ray.tune.search.optuna import OptunaSearch from plugboard.schemas import ( CategoricalParameterSpec, @@ -127,5 +127,7 @@ def test_optuna_storage_uri_conversion(temp_dir: str) -> None: storage=f"sqlite:///{temp_dir}/test_conversion.db", ), ) - algo = tuner._config.search_alg - assert isinstance(algo, ray.tune.search.optuna.OptunaSearch) + with patch("ray.tune.Tuner") as mock_tuner_cls: + tuner.run(spec=MagicMock()) + passed_alg = mock_tuner_cls.call_args.kwargs["tune_config"].search_alg + assert isinstance(passed_alg, OptunaSearch)