diff --git a/src/tinygp/solvers/quasisep/core.py b/src/tinygp/solvers/quasisep/core.py index 18766014..814cb216 100644 --- a/src/tinygp/solvers/quasisep/core.py +++ b/src/tinygp/solvers/quasisep/core.py @@ -68,6 +68,10 @@ def matmul(self, x: JAXArray) -> JAXArray: """ raise NotImplementedError + @abstractmethod + def parallel_matmul(self, x: JAXArray) -> JAXArray: + raise NotImplementedError + @abstractmethod def scale(self, other: JAXArray) -> QSM: """The multiplication of this matrix times a scalar, as a QSM""" @@ -150,6 +154,10 @@ def transpose(self) -> DiagQSM: def matmul(self, x: JAXArray) -> JAXArray: return self.d[:, None] * x + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return self.matmul(x) + def scale(self, other: JAXArray) -> DiagQSM: return DiagQSM(d=self.d * other) @@ -197,6 +205,17 @@ def impl(f, data): # type: ignore _, f = jax.lax.scan(impl, init, (self.q, self.a, x)) return jax.vmap(jnp.dot)(self.p, f) + @jax.jit + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + def impl(sm, sn): + return (sn[0] @ sm[0], sn[0] @ sm[1] + sn[1]) + + states = jax.vmap(lambda l, x: (l.a, jnp.outer(l.q, x)))(self, x) + f = jax.lax.associative_scan(impl, states)[1] + f = jnp.concatenate((jnp.zeros_like(f[:1]), f[:-1]), axis=0) + return jax.vmap(jnp.dot)(self.p, f) + def scale(self, other: JAXArray) -> StrictLowerTriQSM: return StrictLowerTriQSM(p=self.p * other, q=self.q, a=self.a) @@ -270,6 +289,17 @@ def impl(f, data): # type: ignore _, f = jax.lax.scan(impl, init, (self.p, self.a, x), reverse=True) return jax.vmap(jnp.dot)(self.q, f) + @jax.jit + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + def impl(sm, sn): + return (sn[0] @ sm[0], sn[0] @ sm[1] + sn[1]) + + states = jax.vmap(lambda u, x: (u.a.T, jnp.outer(u.p, x)))(self, x) + f = jax.lax.associative_scan(impl, states, reverse=True)[1] + f = jnp.concatenate((f[1:], jnp.zeros_like(f[:1])), axis=0) + return jax.vmap(jnp.dot)(self.q, f) + def scale(self, other: JAXArray) -> StrictUpperTriQSM: return StrictUpperTriQSM(p=self.p, q=self.q * other, a=self.a) @@ -303,6 +333,10 @@ def transpose(self) -> UpperTriQSM: def matmul(self, x: JAXArray) -> JAXArray: return self.diag.matmul(x) + self.lower.matmul(x) + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return self.diag.parallel_matmul(x) + self.lower.parallel_matmul(x) + def scale(self, other: JAXArray) -> LowerTriQSM: return LowerTriQSM(diag=self.diag.scale(other), lower=self.lower.scale(other)) @@ -359,6 +393,10 @@ def transpose(self) -> LowerTriQSM: def matmul(self, x: JAXArray) -> JAXArray: return self.diag.matmul(x) + self.upper.matmul(x) + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return self.diag.parallel_matmul(x) + self.upper.parallel_matmul(x) + def scale(self, other: JAXArray) -> UpperTriQSM: return UpperTriQSM(diag=self.diag.scale(other), upper=self.upper.scale(other)) @@ -415,6 +453,14 @@ def transpose(self) -> SquareQSM: def matmul(self, x: JAXArray) -> JAXArray: return self.diag.matmul(x) + self.lower.matmul(x) + self.upper.matmul(x) + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return ( + self.diag.parallel_matmul(x) + + self.lower.parallel_matmul(x) + + self.upper.parallel_matmul(x) + ) + def scale(self, other: JAXArray) -> SquareQSM: return SquareQSM( diag=self.diag.scale(other), @@ -505,6 +551,14 @@ def matmul(self, x: JAXArray) -> JAXArray: + self.lower.transpose().matmul(x) ) + @handle_matvec_shapes + def parallel_matmul(self, x: JAXArray) -> JAXArray: + return ( + self.diag.parallel_matmul(x) + + self.lower.parallel_matmul(x) + + self.lower.transpose().parallel_matmul(x) + ) + def scale(self, other: JAXArray) -> SymmQSM: return SymmQSM(diag=self.diag.scale(other), lower=self.lower.scale(other)) diff --git a/tests/test_solvers/test_quasisep/test_core.py b/tests/test_solvers/test_quasisep/test_core.py index ae1de4c4..80c6a861 100644 --- a/tests/test_solvers/test_quasisep/test_core.py +++ b/tests/test_solvers/test_quasisep/test_core.py @@ -6,6 +6,7 @@ import pytest from numpy import random as np_random +from tinygp.kernels.quasisep import Matern52 from tinygp.solvers.quasisep.core import ( DiagQSM, LowerTriQSM, @@ -17,7 +18,7 @@ from tinygp.test_utils import assert_allclose -@pytest.fixture(params=["random", "celerite"]) +@pytest.fixture(params=["random", "celerite", "matern"]) def name(request): return request.param @@ -104,6 +105,17 @@ def get_matrices(name): a = jnp.stack([jnp.diag(v) for v in jnp.exp(-c[None] * dt[:, None])], axis=0) p = jnp.einsum("ni,nij->nj", p, a) + elif name == "matern": + t = jnp.sort(random.uniform(0, 10, N)) + kernel = Matern52(1.5, 1.0) + matrix = kernel.to_symm_qsm(t) + diag += matrix.diag.d + p = matrix.lower.p + q = matrix.lower.q + a = matrix.lower.a + l = matrix.lower.to_dense() + u = l.T + else: raise AssertionError() @@ -166,6 +178,20 @@ def test_strict_tri_matmul(matrices): assert_allclose(mat.T @ m, u @ m) +def test_strict_lower_tri_parallel_matmul(matrices): + _, p, q, a, v, m, _, _ = matrices + mat = StrictLowerTriQSM(p=p, q=q, a=a) + assert_allclose(mat.parallel_matmul(v), mat.matmul(v)) + assert_allclose(mat.parallel_matmul(m), mat.matmul(m)) + + +def test_strict_upper_tri_parallel_matmul(matrices): + _, p, q, a, v, m, _, _ = matrices + mat = StrictLowerTriQSM(p=p, q=q, a=a).T + assert_allclose(mat.parallel_matmul(v), mat.matmul(v)) + assert_allclose(mat.parallel_matmul(m), mat.matmul(m)) + + def test_tri_matmul(matrices): diag, p, q, a, v, m, l, _ = matrices mat = LowerTriQSM(diag=DiagQSM(diag), lower=StrictLowerTriQSM(p=p, q=q, a=a))