From 52d8a92eb9a0af436bee77ea46a41e3640cc2189 Mon Sep 17 00:00:00 2001
From: Jonathan Dekermanjian
Date: Fri, 7 Nov 2025 07:20:44 -0700
Subject: [PATCH 1/5] updated dim_shape assignment logic in fit_laplace to handle absent dims on data containers and deterministics

---
 .../inference/laplace_approx/laplace.py      |  4 +-
 .../inference/laplace_approx/test_laplace.py | 38 +++++++++++++++++++
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py
index cd248dcf..72247555 100644
--- a/pymc_extras/inference/laplace_approx/laplace.py
+++ b/pymc_extras/inference/laplace_approx/laplace.py
@@ -226,9 +226,9 @@ def model_to_laplace_approx(
         else:
             dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
             initval = initial_point.get(name, None)
-            dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
+            dim_shapes = initval.shape if initval is not None else batched_rv.shape.eval()[2:]
             laplace_model.add_coords(
-                {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
+                {name: pt.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
             )
 
         pm.Deterministic(name, batched_rv, dims=dims)
diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py
index f02f296a..d9edb958 100644
--- a/tests/inference/laplace_approx/test_laplace.py
+++ b/tests/inference/laplace_approx/test_laplace.py
@@ -193,6 +193,44 @@ def test_fit_laplace_ragged_coords(rng):
     assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()
 
 
+def test_fit_laplace_no_data_or_deterministic_dims(rng):
+    coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
+    with pm.Model(coords=coords) as ragged_dim_model:
+        X = pm.Data("X", np.ones((100, 2)))
+        beta = pm.Normal(
+            "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"]
+        )
+        mu = pm.Deterministic("mu", (X[:, None, :] * beta[None]).sum(axis=-1))
+        sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])
+
+        obs = pm.Normal(
+            "obs",
+            mu=mu,
+            sigma=sigma,
+            observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
+            dims=["obs_idx", "city"],
+        )
+
+        idata = fit_laplace(
+            optimize_method="Newton-CG",
+            progressbar=False,
+            use_grad=True,
+            use_hessp=True,
+        )
+
+    # These should have been dropped when the laplace idata was created
+    assert "laplace_approximation" not in list(idata.posterior.data_vars.keys())
+    assert "unpacked_var_names" not in list(idata.posterior.coords.keys())
+
+    assert idata["posterior"].beta.shape[-2:] == (3, 2)
+    assert idata["posterior"].sigma.shape[-1:] == (3,)
+
+    # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1
+    # strictly positive
+    assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all()
+    assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()
+
+
 def test_model_with_nonstandard_dimensionality(rng):
     y_obs = np.concatenate(
         [rng.normal(-1, 2, size=150), rng.normal(3, 1, size=350), rng.normal(5, 4, size=50)]
     )

From 1e417582deb7325cb21c8537ef9208ad60a88722 Mon Sep 17 00:00:00 2001
From: Jonathan Dekermanjian
Date: Sat, 8 Nov 2025 05:36:13 -0700
Subject: [PATCH 2/5] reverted to numpy arange

---
 pymc_extras/inference/laplace_approx/laplace.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py
index 72247555..7ec838bc 100644
--- a/pymc_extras/inference/laplace_approx/laplace.py
+++ b/pymc_extras/inference/laplace_approx/laplace.py
@@ -228,7 +228,7 @@ def model_to_laplace_approx(
             initval = initial_point.get(name, None)
             dim_shapes = initval.shape if initval is not None else batched_rv.shape.eval()[2:]
             laplace_model.add_coords(
-                {name: pt.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
+                {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
             )
 
         pm.Deterministic(name, batched_rv, dims=dims)

From 5abe4ff11221531d82e7f9db879bb18acf9e3249 Mon Sep 17 00:00:00 2001
From: Jonathan Dekermanjian
Date: Sat, 8 Nov 2025 08:40:32 -0700
Subject: [PATCH 3/5] updated dim handling in model_to_laplace_approx to not force dims on variables that did not have them originally

---
 pymc_extras/inference/laplace_approx/laplace.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py
index 7ec838bc..d78fc3df 100644
--- a/pymc_extras/inference/laplace_approx/laplace.py
+++ b/pymc_extras/inference/laplace_approx/laplace.py
@@ -224,12 +224,15 @@ def model_to_laplace_approx(
         elif name in model.named_vars_to_dims:
            dims = (*batch_dims, *model.named_vars_to_dims[name])
         else:
-            dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
             initval = initial_point.get(name, None)
-            dim_shapes = initval.shape if initval is not None else batched_rv.shape.eval()[2:]
-            laplace_model.add_coords(
-                {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
-            )
+            dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
+            if dim_shapes[0] is not None:
+                dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
+                laplace_model.add_coords(
+                    {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
+                )
+            else:
+                dims = None
 
         pm.Deterministic(name, batched_rv, dims=dims)
 

From 81086e7c41695ce8b81a25a973870a1052745101 Mon Sep 17 00:00:00 2001
From: Jonathan Dekermanjian
Date: Sat, 13 Dec 2025 09:40:04 -0700
Subject: [PATCH 4/5] allow inference data object to automatically set dims and post process the dimension names after

---
 .../inference/laplace_approx/laplace.py | 66 ++++++++++++++++---
 1 file changed, 57 insertions(+), 9 deletions(-)

diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py
index d78fc3df..7ad71083 100644
--- a/pymc_extras/inference/laplace_approx/laplace.py
+++ b/pymc_extras/inference/laplace_approx/laplace.py
@@ -14,6 +14,7 @@
 
 import logging
+import re
 
 from collections.abc import Callable
 from functools import partial
@@ -51,6 +52,58 @@
 _log = logging.getLogger(__name__)
 
 
+def _reset_laplace_dim_idx(idata: az.InferenceData) -> az.InferenceData:
+    """
+    Because `fit_laplace` adds the (temp_chain, temp_draw) dimensions,
+    any variables without explicitly assigned dimensions receive
+    automatically generated indices that are shifted by two during
+    InferenceData creation.
+
+    This helper function corrects that shift by subtracting 2 from the
+    automatically detected dimension indices of the form
+    `<varname>_dim_<idx>`, restoring them to the indices they would have
+    had if the (temp_chain, temp_draw) dimensions were not added.
+
+    Only affects auto-assigned dimensions in `idata.posterior`.
+    """
+
+    pattern = re.compile(r"^(?P<base>.+)_dim_(?P<idx>\d+)$")
+
+    dim_renames = {}
+    var_renames = {}
+
+    for dim in idata.posterior.dims:
+        match = pattern.match(dim)
+        if match is None:
+            continue
+
+        base = match.group("base")
+        idx = int(match.group("idx"))
+
+        # Guard against invalid or unintended renames
+        if idx < 2:
+            raise ValueError(
+                f"Cannot reset Laplace dimension index for '{dim}': "
+                f"index {idx} would become negative."
+            )
+
+        new_dim = f"{base}_dim_{idx - 2}"
+
+        dim_renames[dim] = new_dim
+
+        # Only rename variables if they actually exist
+        if dim in idata.posterior.variables:
+            var_renames[dim] = new_dim
+
+    if dim_renames:
+        idata.posterior = idata.posterior.rename_dims(dim_renames)
+
+    if var_renames:
+        idata.posterior = idata.posterior.rename_vars(var_renames)
+
+    return idata
+
+
 def get_conditional_gaussian_approximation(
     x: TensorVariable,
     Q: TensorVariable | ArrayLike,
@@ -224,15 +277,8 @@ def model_to_laplace_approx(
         elif name in model.named_vars_to_dims:
            dims = (*batch_dims, *model.named_vars_to_dims[name])
         else:
-            initval = initial_point.get(name, None)
-            dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
-            if dim_shapes[0] is not None:
-                dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
-                laplace_model.add_coords(
-                    {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
-                )
-            else:
-                dims = None
+            n_dim = batched_rv.ndim - 2  # (temp_chain, temp_draw) are always first 2 dims
+            dims = (*batch_dims,) + (None,) * n_dim
 
         pm.Deterministic(name, batched_rv, dims=dims)
 
@@ -471,4 +517,6 @@ def fit_laplace(
         ["laplace_approximation", "unpacked_variable_names"]
     )
 
+    idata = _reset_laplace_dim_idx(idata)
+
     return idata

From 1450629440ca2e33d32bec12056fbc4ac4ad2d88 Mon Sep 17 00:00:00 2001
From: Jonathan Dekermanjian
Date: Sun, 14 Dec 2025 07:56:48 -0700
Subject: [PATCH 5/5] added test for sampling dims and consolidated two tests into one

---
 .../inference/laplace_approx/test_laplace.py | 54 +++++--------------
 1 file changed, 13 insertions(+), 41 deletions(-)

diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py
index d9edb958..a9134036 100644
--- a/tests/inference/laplace_approx/test_laplace.py
+++ b/tests/inference/laplace_approx/test_laplace.py
@@ -153,54 +153,22 @@ def test_fit_laplace_coords(include_transformed, rng):
     assert "city" in idata.unconstrained_posterior.coords
 
 
-def test_fit_laplace_ragged_coords(rng):
+@pytest.mark.parametrize(
+    "chains, draws, use_dims",
+    [(1, 500, False), (1, 500, True), (2, 1000, False), (2, 1000, True)],
+)
+def test_fit_laplace_ragged_coords(chains, draws, use_dims, rng):
     coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
     with pm.Model(coords=coords) as ragged_dim_model:
-        X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"])
+        X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"] if use_dims else None)
         beta = pm.Normal(
             "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"]
         )
         mu = pm.Deterministic(
-            "mu", (X[:, None, :] * beta[None]).sum(axis=-1), dims=["obs_idx", "city"]
-        )
-        sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])
-
-        obs = pm.Normal(
-            "obs",
-            mu=mu,
-            sigma=sigma,
-            observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
-            dims=["obs_idx", "city"],
-        )
-
-        idata = fit_laplace(
-            optimize_method="Newton-CG",
-            progressbar=False,
-            use_grad=True,
-            use_hessp=True,
-        )
-
-    # These should have been dropped when the laplace idata was created
-    assert "laplace_approximation" not in list(idata.posterior.data_vars.keys())
-    assert "unpacked_var_names" not in list(idata.posterior.coords.keys())
-
-    assert idata["posterior"].beta.shape[-2:] == (3, 2)
-    assert idata["posterior"].sigma.shape[-1:] == (3,)
-
-    # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1
-    # strictly positive
-    assert (idata["posterior"].beta.sel(feature=0).to_numpy() < 0).all()
-    assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all()
-
-
-def test_fit_laplace_no_data_or_deterministic_dims(rng):
-    coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}
-    with pm.Model(coords=coords) as ragged_dim_model:
-        X = pm.Data("X", np.ones((100, 2)))
-        beta = pm.Normal(
-            "beta", mu=[[-100.0, 100.0], [-100.0, 100.0], [-100.0, 100.0]], dims=["city", "feature"]
+            "mu",
+            (X[:, None, :] * beta[None]).sum(axis=-1),
+            dims=["obs_idx", "city"] if use_dims else None,
         )
-        mu = pm.Deterministic("mu", (X[:, None, :] * beta[None]).sum(axis=-1))
         sigma = pm.Normal("sigma", mu=1.5, sigma=0.5, dims=["city"])
 
         obs = pm.Normal(
             "obs",
             mu=mu,
             sigma=sigma,
             observed=rng.normal(loc=3, scale=1.5, size=(100, 3)),
             dims=["obs_idx", "city"],
         )
@@ -216,6 +184,8 @@
         idata = fit_laplace(
             optimize_method="Newton-CG",
             progressbar=False,
             use_grad=True,
             use_hessp=True,
+            chains=chains,
+            draws=draws,
         )
 
     # These should have been dropped when the laplace idata was created
@@ -224,6 +194,8 @@
     assert idata["posterior"].beta.shape[-2:] == (3, 2)
     assert idata["posterior"].sigma.shape[-1:] == (3,)
+    assert idata["posterior"].chain.shape[0] == chains
+    assert idata["posterior"].draw.shape[0] == draws
 
     # Check that everything got unraveled correctly -- feature 0 should be strictly negative, feature 1
     # strictly positive
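
Illustration (not part of the patch series above): a minimal sketch of the dimension-index
shift that the _reset_laplace_dim_idx helper introduced in PATCH 4 corrects. Only numpy and
xarray are assumed here; the variable name "mu" and the dimension names are hypothetical,
and the sketch mirrors only the rename step, not the full fit_laplace pipeline.

    import numpy as np
    import xarray as xr

    # When (temp_chain, temp_draw) are prepended before the InferenceData is built,
    # an un-dimmed variable's first real dimension gets auto-named "mu_dim_2"
    # instead of "mu_dim_0".
    posterior = xr.Dataset(
        {"mu": (("chain", "draw", "mu_dim_2"), np.zeros((2, 10, 3)))}
    )

    # The helper subtracts 2 from the auto-generated index, restoring the name the
    # dimension would have had without the temporary chain/draw dimensions.
    posterior = posterior.rename_dims({"mu_dim_2": "mu_dim_0"})
    assert "mu_dim_0" in posterior.dims and "mu_dim_2" not in posterior.dims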