@@ -934,22 +934,18 @@ def test_infer_static_shape():
934934class TestEye :
935935 # This is slow for the ('int8', 3) version.
936936 def test_basic (self ):
937- def check (dtype , N , M_ = None , k = 0 ):
938- # PyTensor does not accept None as a tensor.
939- # So we must use a real value.
940- M = M_
941- # Currently DebugMode does not support None as inputs even if this is
942- # allowed.
943- if M is None and config .mode in ["DebugMode" , "DEBUG_MODE" ]:
944- M = N
937+ def check (dtype , N , M = None , k = 0 ):
945938 N_symb = iscalar ()
946939 M_symb = iscalar ()
947940 k_symb = iscalar ()
941+ test_inputs = [N , k ] if M is None else [N , M , k ]
942+ inputs = [N_symb , k_symb ] if M is None else [N_symb , M_symb , k_symb ]
948943 f = function (
949- [N_symb , M_symb , k_symb ], eye (N_symb , M_symb , k_symb , dtype = dtype )
944+ inputs ,
945+ eye (N_symb , None if (M is None ) else M_symb , k_symb , dtype = dtype ),
950946 )
951- result = f (N , M , k )
952- assert np .allclose (result , np .eye (N , M_ , k , dtype = dtype ))
947+ result = f (* test_inputs )
948+ assert np .allclose (result , np .eye (N , M , k , dtype = dtype ))
953949 assert result .dtype == np .dtype (dtype )
954950
955951 for dtype in ALL_DTYPES :
@@ -1755,7 +1751,7 @@ def test_join_matrixV_negative_axis(self):
17551751 got = f (- 2 )
17561752 assert np .allclose (got , want )
17571753
1758- with pytest .raises (ValueError ):
1754+ with pytest .raises (( ValueError , IndexError ) ):
17591755 f (- 3 )
17601756
17611757 @pytest .mark .parametrize ("py_impl" , (False , True ))
0 commit comments