Skip to content

Commit dae7ff3

Browse files
committed
Add support for TypedList in numba backend
1 parent c62d733 commit dae7ff3

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from pytensor.tensor.slinalg import Solve
4141
from pytensor.tensor.type import TensorType
4242
from pytensor.tensor.type_other import MakeSlice, NoneConst
43+
from pytensor.typed_list import TypedListType
4344

4445

4546
def global_numba_func(func):
@@ -135,6 +136,8 @@ def get_numba_type(
135136
return CSCMatrixType(numba_dtype)
136137

137138
raise NotImplementedError()
139+
elif isinstance(pytensor_type, TypedListType):
140+
return numba.types.List(get_numba_type(pytensor_type.ttype))
138141
else:
139142
raise NotImplementedError(f"Numba type not implemented for {pytensor_type}")
140143

0 commit comments

Comments
 (0)