Skip to content

Commit 040f77a

Browse files
committed
Try to run full test suite in Numba backend
1 parent dae7ff3 commit 040f77a

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ jobs:
193193
else
194194
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
195195
fi
196-
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
196+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
197197
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
198198
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
199199
pip install pytest-sphinx

pytensor/compile/mode.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ def register_linker(name, linker):
6363
# If a string is passed as the optimizer argument in the constructor
6464
# for Mode, it will be used as the key to retrieve the real optimizer
6565
# in this dictionary
66-
exclude = []
67-
if not config.cxx:
68-
exclude = ["cxx_only"]
66+
67+
exclude = ["cxx_only", "BlasOpt"]
6968
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
7069
# Even if multiple merge optimizer call will be there, this shouldn't
7170
# impact performance.
@@ -346,6 +345,11 @@ def __setstate__(self, state):
346345
optimizer = predefined_optimizers[optimizer]
347346
if isinstance(optimizer, RewriteDatabaseQuery):
348347
self.provided_optimizer = optimizer
348+
349+
# Force numba-required rewrites if using NumbaLinker
350+
if isinstance(linker, NumbaLinker):
351+
optimizer = optimizer.including("numba")
352+
349353
self._optimizer = optimizer
350354
self.call_time = 0
351355
self.fn_time = 0
@@ -443,16 +447,20 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
443447
# string as the key
444448
# Use VM_linker to allow lazy evaluation by default.
445449
FAST_COMPILE = Mode(
446-
VMLinker(use_cloop=False, c_thunks=False),
447-
RewriteDatabaseQuery(include=["fast_compile", "py_only"]),
450+
NumbaLinker(),
451+
# TODO: Fast_compile should just use python code, CHANGE ME!
452+
RewriteDatabaseQuery(
453+
include=["fast_compile", "numba"],
454+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
455+
),
456+
)
457+
FAST_RUN = Mode(
458+
NumbaLinker(),
459+
RewriteDatabaseQuery(
460+
include=["fast_run", "numba"],
461+
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
462+
),
448463
)
449-
if config.cxx:
450-
FAST_RUN = Mode("cvm", "fast_run")
451-
else:
452-
FAST_RUN = Mode(
453-
"vm",
454-
RewriteDatabaseQuery(include=["fast_run", "py_only"]),
455-
)
456464

457465
NUMBA = Mode(
458466
NumbaLinker(),
@@ -565,6 +573,7 @@ def register_mode(name, mode):
565573
Add a `Mode` which can be referred to by `name` in `function`.
566574
567575
"""
576+
# TODO: Remove me
568577
if name in predefined_modes:
569578
raise ValueError(f"Mode name already taken: {name}")
570579
predefined_modes[name] = mode

pytensor/configdefaults.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -370,11 +370,11 @@ def add_compile_configvars():
370370

371371
if rc == 0 and config.cxx != "":
372372
# Keep the default linker the same as the one for the mode FAST_RUN
373-
linker_options = ["c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc"]
373+
linker_options = ["cvm", "c|py", "py", "c", "c|py_nogc", "vm", "vm_nogc", "cvm_nogc", "jax"]
374374
else:
375375
# g++ is not present or the user disabled it,
376376
# linker should default to python only.
377-
linker_options = ["py", "vm_nogc"]
377+
linker_options = ["py", "vm", "vm_nogc", "jax"]
378378
if type(config).cxx.is_default:
379379
# If the user provided an empty value for cxx, do not warn.
380380
_logger.warning(
@@ -388,7 +388,7 @@ def add_compile_configvars():
388388
"linker",
389389
"Default linker used if the pytensor flags mode is Mode",
390390
# Not mutable because the default mode is cached after the first use.
391-
EnumStr("cvm", linker_options, mutable=False),
391+
EnumStr("numba", linker_options, mutable=False),
392392
in_c_key=False,
393393
)
394394

0 commit comments

Comments
 (0)