Skip to content

Commit 17e46d2

Browse files
committed
Fix passing M=None to function in Eye test
1 parent 24591f6 commit 17e46d2

File tree

2 files changed

+9
-14
lines changed

2 files changed

+9
-14
lines changed

pytensor/tensor/basic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,8 +1453,7 @@ def eye(n, m=None, k=0, dtype=None):
14531453
dtype = config.floatX
14541454
if m is None:
14551455
m = n
1456-
localop = Eye(dtype)
1457-
return localop(n, m, k)
1456+
return Eye(dtype)(n, m, k)
14581457

14591458

14601459
def identity_like(x, dtype: str | np.generic | np.dtype | None = None):

tests/tensor/test_basic.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -934,22 +934,18 @@ def test_infer_static_shape():
934934
class TestEye:
935935
# This is slow for the ('int8', 3) version.
936936
def test_basic(self):
937-
def check(dtype, N, M_=None, k=0):
938-
# PyTensor does not accept None as a tensor.
939-
# So we must use a real value.
940-
M = M_
941-
# Currently DebugMode does not support None as inputs even if this is
942-
# allowed.
943-
if M is None and config.mode in ["DebugMode", "DEBUG_MODE"]:
944-
M = N
937+
def check(dtype, N, M=None, k=0):
945938
N_symb = iscalar()
946939
M_symb = iscalar()
947940
k_symb = iscalar()
941+
test_inputs = [N, k] if M is None else [N, M, k]
942+
inputs = [N_symb, k_symb] if M is None else [N_symb, M_symb, k_symb]
948943
f = function(
949-
[N_symb, M_symb, k_symb], eye(N_symb, M_symb, k_symb, dtype=dtype)
944+
inputs,
945+
eye(N_symb, None if (M is None) else M_symb, k_symb, dtype=dtype),
950946
)
951-
result = f(N, M, k)
952-
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
947+
result = f(*test_inputs)
948+
assert np.allclose(result, np.eye(N, M, k, dtype=dtype))
953949
assert result.dtype == np.dtype(dtype)
954950

955951
for dtype in ALL_DTYPES:
@@ -1755,7 +1751,7 @@ def test_join_matrixV_negative_axis(self):
17551751
got = f(-2)
17561752
assert np.allclose(got, want)
17571753

1758-
with pytest.raises(ValueError):
1754+
with pytest.raises((ValueError, IndexError)):
17591755
f(-3)
17601756

17611757
@pytest.mark.parametrize("py_impl", (False, True))

0 commit comments

Comments
 (0)