2424from pytensor .graph .replace import vectorize_node
2525from pytensor .graph .traversal import ancestors , applys_between
2626from pytensor .link .c .basic import DualLinker
27+ from pytensor .link .numba import NumbaLinker
2728from pytensor .printing import pprint
2829from pytensor .raise_op import Assert
2930from pytensor .tensor import blas , blas_c
@@ -858,6 +859,10 @@ def test_basic_2(self, axis, np_axis):
858859 ([1 , 0 ], None ),
859860 ],
860861 )
862+ @pytest .mark .xfail (
863+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
864+ reason = "Numba does not support float16" ,
865+ )
861866 def test_basic_2_float16 (self , axis , np_axis ):
862867 # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
863868 data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
@@ -1114,6 +1119,10 @@ def test2(self):
11141119 v_shape = eval_outputs (fct (n , axis ).shape )
11151120 assert tuple (v_shape ) == nfct (data , np_axis ).shape
11161121
1122+ @pytest .mark .xfail (
1123+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
1124+ reason = "Numba does not support float16" ,
1125+ )
11171126 def test2_float16 (self ):
11181127 # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
11191128 data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
@@ -1981,6 +1990,10 @@ def test_mean_single_element(self):
19811990 res = mean (np .zeros (1 ))
19821991 assert res .eval () == 0.0
19831992
1993+ @pytest .mark .xfail (
1994+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
1995+ reason = "Numba does not support float16" ,
1996+ )
19841997 def test_mean_f16 (self ):
19851998 x = vector (dtype = "float16" )
19861999 y = x .mean ()
@@ -3153,7 +3166,9 @@ class TestSumProdReduceDtype:
31533166 op = CAReduce
31543167 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
31553168 methods = ["sum" , "prod" ]
3156- dtypes = list (map (str , ps .all_types ))
3169+ dtypes = tuple (map (str , ps .all_types ))
3170+ if isinstance (mode .linker , NumbaLinker ):
3171+ dtypes = tuple (d for d in dtypes if d != "float16" )
31573172
31583173 # Test the default dtype of a method().
31593174 def test_reduce_default_dtype (self ):
@@ -3313,10 +3328,13 @@ def test_reduce_precision(self):
33133328class TestMeanDtype :
33143329 def test_mean_default_dtype (self ):
33153330 # Test the default dtype of a mean().
3331+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
33163332
33173333 # We try multiple axis combinations even though axis should not matter.
33183334 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
33193335 for idx , dtype in enumerate (map (str , ps .all_types )):
3336+ if is_numba and dtype == "float16" :
3337+ continue
33203338 axis = axes [idx % len (axes )]
33213339 x = matrix (dtype = dtype )
33223340 m = x .mean (axis = axis )
@@ -3411,10 +3429,13 @@ def test_prod_without_zeros_default_dtype(self):
34113429
34123430 def test_prod_without_zeros_default_acc_dtype (self ):
34133431 # Test the default dtype of a ProdWithoutZeros().
3432+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
34143433
34153434 # We try multiple axis combinations even though axis should not matter.
34163435 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34173436 for idx , dtype in enumerate (map (str , ps .all_types )):
3437+ if is_numba and dtype == "float16" :
3438+ continue
34183439 axis = axes [idx % len (axes )]
34193440 x = matrix (dtype = dtype )
34203441 p = ProdWithoutZeros (axis = axis )(x )
@@ -3442,13 +3463,17 @@ def test_prod_without_zeros_default_acc_dtype(self):
34423463 @pytest .mark .slow
34433464 def test_prod_without_zeros_custom_dtype (self ):
34443465 # Test ability to provide your own output dtype for a ProdWithoutZeros().
3445-
3466+ is_numba = isinstance ( get_default_mode (). linker , NumbaLinker )
34463467 # We try multiple axis combinations even though axis should not matter.
34473468 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34483469 idx = 0
34493470 for input_dtype in map (str , ps .all_types ):
3471+ if is_numba and input_dtype == "float16" :
3472+ continue
34503473 x = matrix (dtype = input_dtype )
34513474 for output_dtype in map (str , ps .all_types ):
3475+ if is_numba and output_dtype == "float16" :
3476+ continue
34523477 axis = axes [idx % len (axes )]
34533478 prod_woz_var = ProdWithoutZeros (axis = axis , dtype = output_dtype )(x )
34543479 assert prod_woz_var .dtype == output_dtype
@@ -3464,13 +3489,18 @@ def test_prod_without_zeros_custom_dtype(self):
34643489 @pytest .mark .slow
34653490 def test_prod_without_zeros_custom_acc_dtype (self ):
34663491 # Test ability to provide your own acc_dtype for a ProdWithoutZeros().
3492+ is_numba = isinstance (get_default_mode ().linker , NumbaLinker )
34673493
34683494 # We try multiple axis combinations even though axis should not matter.
34693495 axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
34703496 idx = 0
34713497 for input_dtype in map (str , ps .all_types ):
3498+ if is_numba and input_dtype == "float16" :
3499+ continue
34723500 x = matrix (dtype = input_dtype )
34733501 for acc_dtype in map (str , ps .all_types ):
3502+ if is_numba and acc_dtype == "float16" :
3503+ continue
34743504 axis = axes [idx % len (axes )]
34753505 # If acc_dtype would force a downcast, we expect a TypeError
34763506 # We always allow int/uint inputs with float/complex outputs.
@@ -3746,7 +3776,20 @@ def test_scalar_error(self):
37463776 with pytest .raises (ValueError , match = "cannot be scalar" ):
37473777 self .op (4 , [4 , 1 ])
37483778
3749- @pytest .mark .parametrize ("dtype" , (np .float16 , np .float32 , np .float64 ))
3779+ @pytest .mark .parametrize (
3780+ "dtype" ,
3781+ (
3782+ pytest .param (
3783+ np .float16 ,
3784+ marks = pytest .mark .xfail (
3785+ condition = isinstance (get_default_mode ().linker , NumbaLinker ),
3786+ reason = "Numba does not support float16" ,
3787+ ),
3788+ ),
3789+ np .float32 ,
3790+ np .float64 ,
3791+ ),
3792+ )
37503793 def test_dtype_param (self , dtype ):
37513794 sol = self .op ([1 , 2 , 3 ], [3 , 2 , 1 ], dtype = dtype )
37523795 assert sol .eval ().dtype == dtype
0 commit comments