Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/examples/tutorials/tuning-a-process.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ These conditional search space functions are supported by Ray Tune and can be de
2. Write a custom function to define the search space, where each tunable parameter has a name of the form `"{component_name.field_or_arg_name}"`; then
3. Supply your custom function to the `OptunaSpec` algorithm configuration.

If your search space function accepts an argument named `spec`, Plugboard will pass the [`ProcessSpec`][plugboard.schemas.ProcessSpec] to it. This allows you to inspect the process definition when creating the parameter space — useful, for example, when the process or component parameters should influence the search space.

For example, the following search space makes the velocity depend on the angle:
```python
--8<-- "examples/tutorials/006_optimisation/hello_tuner.py:custom_search_space"
Expand Down
2 changes: 2 additions & 0 deletions plugboard-schemas/plugboard_schemas/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class OptunaSpec(PlugboardBaseModel):
type: The algorithm type to load.
space: Optional; A function defining the search space. Use this to define more complex
search spaces that cannot be represented using the built-in parameter types.
The function must accept a `trial` argument. If it also accepts a `spec` argument,
the `ProcessSpec` will be passed to it.
study_name: Optional; The name of the study.
storage: Optional; The storage URI to save the optimisation results to.
"""
Expand Down
45 changes: 30 additions & 15 deletions plugboard/tune/tune.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Provides `Tuner` class for optimising Plugboard processes."""

from inspect import isfunction
from functools import partial
from inspect import isfunction, signature
import math
from pydoc import locate
import typing as _t
Expand Down Expand Up @@ -61,14 +62,14 @@ def __init__(
self._objective, self._mode, self._metric = self._normalize_objective_and_mode(
objective, mode
)
self._custom_space = bool(algorithm and algorithm.space)
self._algorithm = algorithm
self._custom_space = bool(self._algorithm and self._algorithm.space)
self._max_concurrent = max_concurrent
self._num_samples = num_samples

# Prepare parameters and search algorithm
# Prepare parameters
self._parameters_dict, self._parameters = self._prepare_parameters(parameters)
searcher = self._init_search_algorithm(algorithm, max_concurrent)

# Configure Ray Tune
self._config = ray.tune.TuneConfig(num_samples=num_samples, search_alg=searcher)
self._result_grid: _t.Optional[ray.tune.ResultGrid] = None
self._logger.info("Tuner created")

Expand Down Expand Up @@ -96,13 +97,15 @@ def _check_objective(
raise ValueError("If using a single objective, `mode` must not be a list.")

def _build_algorithm(
self, algorithm: _t.Optional[OptunaSpec] = None
self,
process_spec: ProcessSpec,
algorithm: _t.Optional[OptunaSpec] = None,
) -> ray.tune.search.Searcher:
if algorithm is None:
self._logger.info("Using default Optuna search algorithm")
return self._default_searcher()

algo_kwargs = self._build_algo_kwargs(algorithm)
algo_kwargs = self._build_algo_kwargs(algorithm, process_spec)
algo_cls = self._get_algo_class(algorithm.type)
self._logger.info(
"Using custom search algorithm",
Expand All @@ -114,7 +117,9 @@ def _build_algorithm(
def _default_searcher(self) -> "ray.tune.search.Searcher":
return ray.tune.search.optuna.OptunaSearch(metric=self._metric, mode=self._mode)

def _build_algo_kwargs(self, algorithm: OptunaSpec) -> dict[str, _t.Any]:
def _build_algo_kwargs(
self, algorithm: OptunaSpec, process_spec: ProcessSpec
) -> dict[str, _t.Any]:
"""Prepare keyword args for the searcher, normalising storage/space."""
kwargs = algorithm.model_dump(exclude={"type"})
kwargs["mode"] = self._mode
Expand All @@ -130,14 +135,18 @@ def _build_algo_kwargs(self, algorithm: OptunaSpec) -> dict[str, _t.Any]:

space = kwargs.get("space")
if space is not None:
kwargs["space"] = self._resolve_space_fn(space)
kwargs["space"] = self._resolve_space_fn(space, process_spec)

return kwargs

def _resolve_space_fn(self, space: str) -> _t.Callable:
def _resolve_space_fn(self, space: str, process_spec: ProcessSpec) -> _t.Callable:
space_fn = locate(space)
if not space_fn or not isfunction(space_fn): # pragma: no cover
if not space_fn or not callable(space_fn): # pragma: no cover
raise ValueError(f"Could not locate search space function {space}")
sig = signature(space_fn)
if "spec" in sig.parameters:
self._logger.info("Search space function accepts `spec` argument, passing ProcessSpec")
return partial(space_fn, spec=process_spec)
return space_fn

def _get_algo_class(self, type_path: str) -> _t.Type[ray.tune.search.searcher.Searcher]:
Expand Down Expand Up @@ -232,8 +241,11 @@ def run(self, spec: ProcessSpec) -> ray.tune.Result | list[ray.tune.Result]:
ray.tune.PlacementGroupFactory(placement_bundles),
)

searcher = self._init_search_algorithm(self._algorithm, self._max_concurrent, spec)
config = ray.tune.TuneConfig(num_samples=self._num_samples, search_alg=searcher)

tuner_kwargs: dict[str, _t.Any] = {
"tune_config": self._config,
"tune_config": config,
}
if not self._custom_space:
self._logger.info("Setting Tuner with parameters", params=list(self._parameters.keys()))
Expand Down Expand Up @@ -315,10 +327,13 @@ def _prepare_parameters(
return params_dict, params_space

def _init_search_algorithm(
self, algorithm: _t.Optional[OptunaSpec], max_concurrent: _t.Optional[int]
self,
algorithm: _t.Optional[OptunaSpec],
max_concurrent: _t.Optional[int],
process_spec: ProcessSpec,
) -> "ray.tune.search.Searcher":
"""Create the search algorithm and apply concurrency limits if requested."""
algo = self._build_algorithm(algorithm)
algo = self._build_algorithm(process_spec, algorithm)
if max_concurrent is not None:
algo = ray.tune.search.ConcurrencyLimiter(algo, max_concurrent)
return algo
Expand Down
4 changes: 3 additions & 1 deletion tests/data/dynamic-param-process.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ plugboard:
- source: "a.out_1"
target: "d.in_1"
- source: "d.out_1"
target: "c.in_1"
target: "c.in_1"
parameters:
max_iters: 6
29 changes: 25 additions & 4 deletions tests/integration/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
IntParameterSpec,
ObjectiveSpec,
OptunaSpec,
ProcessSpec,
)
from plugboard.tune import Tuner
from tests.conftest import ComponentTestHelper
Expand Down Expand Up @@ -68,7 +69,21 @@ def custom_space(trial: Trial) -> dict[str, _t.Any] | None:
trial.suggest_float(f"list_param_{i}", -5.0, -5.0 + float(i)) for i in range(n_list)
]
# Set existing parameter
trial.suggest_int("a.iters", 1, 10)
trial.suggest_int("component.a.arg.iters", 1, 10)
# Use the return value to set the list parameter
return {"component.d.arg.list_param": list_param}


def custom_space_with_process_spec(trial: Trial, spec: ProcessSpec) -> dict[str, _t.Any] | None:
"""Custom space function that also takes the process spec."""
# Process spec can be used to obtain parameters or other information to define the search space
iters = spec.args.parameters["max_iters"]
n_list = trial.suggest_int("n_list", 1, 10)
list_param = [
trial.suggest_float(f"list_param_{i}", -5.0, -5.0 + float(i)) for i in range(n_list)
]
# Set existing parameter
trial.suggest_int("component.a.arg.iters", 1, iters)
# Use the return value to set the list parameter
return {"component.d.arg.list_param": list_param}

Expand Down Expand Up @@ -174,7 +189,7 @@ async def test_multi_objective_tune(config: dict, ray_ctx: None) -> None:
categories=[1, -1],
),
],
num_samples=10,
num_samples=20,
mode=["max", "min"],
max_concurrent=2,
)
Expand Down Expand Up @@ -281,7 +296,10 @@ async def test_tune_with_constraint(config: dict, ray_ctx: None) -> None:

@pytest.mark.tuner
@pytest.mark.asyncio
async def test_custom_space_tune(dynamic_param_config: dict, ray_ctx: None) -> None:
@pytest.mark.parametrize("space_func", [custom_space, custom_space_with_process_spec])
async def test_custom_space_tune(
dynamic_param_config: dict, ray_ctx: None, space_func: _t.Callable
) -> None:
"""Tests tuning with a custom search space."""
spec = ConfigSpec.model_validate(dynamic_param_config)
process_spec = spec.plugboard.process
Expand Down Expand Up @@ -312,7 +330,7 @@ async def test_custom_space_tune(dynamic_param_config: dict, ray_ctx: None) -> N
num_samples=10,
mode="max",
max_concurrent=2,
algorithm=OptunaSpec(space="tests.integration.test_tuner.custom_space"),
algorithm=OptunaSpec(space=f"tests.integration.test_tuner.{space_func.__name__}"),
)
tuner.run(
spec=process_spec,
Expand All @@ -327,3 +345,6 @@ async def test_custom_space_tune(dynamic_param_config: dict, ray_ctx: None) -> N
if r.config["n_list"] < 5:
# When n_list < 5, all list_param values are negative
assert all(v < 0.0 for v in r.config["component.d.arg.list_param"])
if space_func.__name__ == "custom_space_with_process_spec":
# The iters parameter must be set based on the process params
assert r.config["component.a.arg.iters"] <= process_spec.args.parameters["max_iters"]
8 changes: 5 additions & 3 deletions tests/unit/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import msgspec
import pytest
import ray.tune
from ray.tune.search.optuna import OptunaSearch

from plugboard.schemas import (
CategoricalParameterSpec,
Expand Down Expand Up @@ -127,5 +127,7 @@ def test_optuna_storage_uri_conversion(temp_dir: str) -> None:
storage=f"sqlite:///{temp_dir}/test_conversion.db",
),
)
algo = tuner._config.search_alg
assert isinstance(algo, ray.tune.search.optuna.OptunaSearch)
with patch("ray.tune.Tuner") as mock_tuner_cls:
tuner.run(spec=MagicMock())
passed_alg = mock_tuner_cls.call_args.kwargs["tune_config"].search_alg
assert isinstance(passed_alg, OptunaSearch)
Loading