diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index cd248dcf..7ad71083 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -14,6 +14,7 @@ import logging +import re from collections.abc import Callable from functools import partial @@ -51,6 +52,58 @@ _log = logging.getLogger(__name__) +def _reset_laplace_dim_idx(idata: az.InferenceData) -> az.InferenceData: + """ + Because `fit_laplace` adds the (temp_chain, temp_draw) dimensions, + any variables without explicitly assigned dimensions receive + automatically generated indices that are shifted by two during + InferenceData creation. + + This helper function corrects that shift by subtracting 2 from the + automatically detected dimension indices of the form + `_dim_`, restoring them to the indices they would have + had if the (temp_chain, temp_draw) dimensions were not added. + + Only affects auto-assigned dimensions in `idata.posterior`. + """ + + pattern = re.compile(r"^(?P.+)_dim_(?P\d+)$") + + dim_renames = {} + var_renames = {} + + for dim in idata.posterior.dims: + match = pattern.match(dim) + if match is None: + continue + + base = match.group("base") + idx = int(match.group("idx")) + + # Guard against invalid or unintended renames + if idx < 2: + raise ValueError( + f"Cannot reset Laplace dimension index for '{dim}': " + f"index {idx} would become negative." + ) + + new_dim = f"{base}_dim_{idx - 2}" + + dim_renames[dim] = new_dim + + # Only rename variables if they actually exist + if dim in idata.posterior.variables: + var_renames[dim] = new_dim + + if dim_renames: + idata.posterior = idata.posterior.rename_dims(dim_renames) + + if var_renames: + idata.posterior = idata.posterior.rename_vars(var_renames) + + return idata + + def get_conditional_gaussian_approximation( x: TensorVariable, Q: TensorVariable | ArrayLike, @@ -224,12 +277,8 @@ def model_to_laplace_approx( elif name in model.named_vars_to_dims: dims = (*batch_dims, *model.named_vars_to_dims[name]) else: - dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)]) - initval = initial_point.get(name, None) - dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:] - laplace_model.add_coords( - {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} - ) + n_dim = batched_rv.ndim - 2 # (temp_chain, temp_draw) are always first 2 dims + dims = (*batch_dims,) + (None,) * n_dim pm.Deterministic(name, batched_rv, dims=dims) @@ -468,4 +517,6 @@ def fit_laplace( ["laplace_approximation", "unpacked_variable_names"] ) + idata = _reset_laplace_dim_idx(idata) + return idata diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py index f02f296a..a9134036 100644 --- a/tests/inference/laplace_approx/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -153,15 +153,21 @@ def test_fit_laplace_coords(include_transformed, rng): assert "city" in idata.unconstrained_posterior.coords -def test_fit_laplace_ragged_coords(rng): +@pytest.mark.parametrize( + "chains, draws, use_dims", + [(1, 500, False), (1, 500, True), (2, 1000, False), (2, 1000, True)], +) +def test_fit_laplace_ragged_coords(chains, draws, use_dims, rng): coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)} with pm.Model(coords=coords) as ragged_dim_model: - X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"]) + X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"] if use_dims else None) beta = pm.Normal( "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"] ) mu = pm.Deterministic( - "mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "city"] + "mu", + (X[:, None, :] * beta[None]).sum(axis=-1), + dims=["obs_idx", "city"] if use_dims else None, ) sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"]) @@ -178,6 +184,8 @@ def test_fit_laplace_ragged_coords(rng): progressbar=False, use_grad=True, use_hessp=True, + chains=chains, + draws=draws, ) # These should have been dropped when the laplace idata was created @@ -186,6 +194,8 @@ def test_fit_laplace_ragged_coords(rng): assert idata["posterior"].beta.shape[-2:] == (3, 2) assert idata["posterior"].sigma.shape[-1:] == (3,) + assert idata["posterior"].chain.shape[0] == chains + assert idata["posterior"].draw.shape[0] == draws # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1 # strictly positive