Commit 44df98a

XFAIL/SKIP float16 tests

1 parent f1b6ff1

3 files changed: +58 additions, -3 deletions
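All three files apply the same two guards when the default mode compiles through the Numba linker: tests dedicated to float16 are marked xfail, and dtype loops skip their float16 entries. A minimal sketch of the decorator form, using the import paths added in this commit (the test body below is only a placeholder, not a test from the suite):

    import pytest

    from pytensor.compile import get_default_mode
    from pytensor.link.numba import NumbaLinker

    @pytest.mark.xfail(
        condition=isinstance(get_default_mode().linker, NumbaLinker),
        reason="Numba does not support float16",
    )
    def test_some_float16_case():
        # Placeholder: any computation that needs float16 support goes here.
        ...

The condition is evaluated once, when the test module is imported, so the mark reflects whichever linker the default mode resolves to at collection time.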

tests/tensor/rewriting/test_basic.py

Lines changed: 5 additions & 0 deletions
@@ -18,6 +18,7 @@
 from pytensor.graph.rewriting.basic import check_stack_trace, out2in
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.rewriting.utils import rewrite_graph
+from pytensor.link.numba import NumbaLinker
 from pytensor.printing import debugprint, pprint
 from pytensor.raise_op import Assert, CheckAndRaise
 from pytensor.scalar import Composite, float64
@@ -1206,6 +1207,10 @@ def test_sum_bool_upcast(self):
         f(5)


+@pytest.mark.xfail(
+    condition=isinstance(get_default_mode().linker, NumbaLinker),
+    reason="Numba does not support float16",
+)
 class TestLocalOptAllocF16(TestLocalOptAlloc):
     dtype = "float16"

tests/tensor/test_math.py

Lines changed: 46 additions & 3 deletions
@@ -24,6 +24,7 @@
 from pytensor.graph.replace import vectorize_node
 from pytensor.graph.traversal import ancestors, applys_between
 from pytensor.link.c.basic import DualLinker
+from pytensor.link.numba import NumbaLinker
 from pytensor.printing import pprint
 from pytensor.raise_op import Assert
 from pytensor.tensor import blas, blas_c
@@ -858,6 +859,10 @@ def test_basic_2(self, axis, np_axis):
             ([1, 0], None),
         ],
     )
+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Numba does not support float16",
+    )
     def test_basic_2_float16(self, axis, np_axis):
         # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
         data = (random(20, 30).astype("float16") - 0.5) * 20
@@ -1114,6 +1119,10 @@ def test2(self):
             v_shape = eval_outputs(fct(n, axis).shape)
             assert tuple(v_shape) == nfct(data, np_axis).shape

+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Numba does not support float16",
+    )
     def test2_float16(self):
         # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
         data = (random(20, 30).astype("float16") - 0.5) * 20
@@ -1981,6 +1990,10 @@ def test_mean_single_element(self):
         res = mean(np.zeros(1))
         assert res.eval() == 0.0

+    @pytest.mark.xfail(
+        condition=isinstance(get_default_mode().linker, NumbaLinker),
+        reason="Numba does not support float16",
+    )
     def test_mean_f16(self):
         x = vector(dtype="float16")
         y = x.mean()
@@ -3153,7 +3166,9 @@ class TestSumProdReduceDtype:
     op = CAReduce
     axes = [None, 0, 1, [], [0], [1], [0, 1]]
     methods = ["sum", "prod"]
-    dtypes = list(map(str, ps.all_types))
+    dtypes = tuple(map(str, ps.all_types))
+    if isinstance(mode.linker, NumbaLinker):
+        dtypes = tuple(d for d in dtypes if d != "float16")

     # Test the default dtype of a method().
     def test_reduce_default_dtype(self):
@@ -3313,10 +3328,13 @@ def test_reduce_precision(self):
 class TestMeanDtype:
     def test_mean_default_dtype(self):
         # Test the default dtype of a mean().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         for idx, dtype in enumerate(map(str, ps.all_types)):
+            if is_numba and dtype == "float16":
+                continue
             axis = axes[idx % len(axes)]
             x = matrix(dtype=dtype)
             m = x.mean(axis=axis)
@@ -3411,10 +3429,13 @@ def test_prod_without_zeros_default_dtype(self):

     def test_prod_without_zeros_default_acc_dtype(self):
         # Test the default dtype of a ProdWithoutZeros().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         for idx, dtype in enumerate(map(str, ps.all_types)):
+            if is_numba and dtype == "float16":
+                continue
             axis = axes[idx % len(axes)]
             x = matrix(dtype=dtype)
             p = ProdWithoutZeros(axis=axis)(x)
@@ -3442,13 +3463,17 @@ def test_prod_without_zeros_default_acc_dtype(self):
     @pytest.mark.slow
     def test_prod_without_zeros_custom_dtype(self):
         # Test ability to provide your own output dtype for a ProdWithoutZeros().
-
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)
         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         idx = 0
         for input_dtype in map(str, ps.all_types):
+            if is_numba and input_dtype == "float16":
+                continue
             x = matrix(dtype=input_dtype)
             for output_dtype in map(str, ps.all_types):
+                if is_numba and output_dtype == "float16":
+                    continue
                 axis = axes[idx % len(axes)]
                 prod_woz_var = ProdWithoutZeros(axis=axis, dtype=output_dtype)(x)
                 assert prod_woz_var.dtype == output_dtype
@@ -3464,13 +3489,18 @@ def test_prod_without_zeros_custom_dtype(self):
     @pytest.mark.slow
     def test_prod_without_zeros_custom_acc_dtype(self):
         # Test ability to provide your own acc_dtype for a ProdWithoutZeros().
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)

         # We try multiple axis combinations even though axis should not matter.
         axes = [None, 0, 1, [], [0], [1], [0, 1]]
         idx = 0
         for input_dtype in map(str, ps.all_types):
+            if is_numba and input_dtype == "float16":
+                continue
             x = matrix(dtype=input_dtype)
             for acc_dtype in map(str, ps.all_types):
+                if is_numba and acc_dtype == "float16":
+                    continue
                 axis = axes[idx % len(axes)]
                 # If acc_dtype would force a downcast, we expect a TypeError
                 # We always allow int/uint inputs with float/complex outputs.
@@ -3746,7 +3776,20 @@ def test_scalar_error(self):
         with pytest.raises(ValueError, match="cannot be scalar"):
             self.op(4, [4, 1])

-    @pytest.mark.parametrize("dtype", (np.float16, np.float32, np.float64))
+    @pytest.mark.parametrize(
+        "dtype",
+        (
+            pytest.param(
+                np.float16,
+                marks=pytest.mark.xfail(
+                    condition=isinstance(get_default_mode().linker, NumbaLinker),
+                    reason="Numba does not support float16",
+                ),
+            ),
+            np.float32,
+            np.float64,
+        ),
+    )
     def test_dtype_param(self, dtype):
         sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype)
         assert sol.eval().dtype == dtype
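The last hunk above rewrites the dtype parametrization with pytest.param so that only the float16 case carries the xfail mark, while float32 and float64 keep running normally. A self-contained sketch of that pattern (RUNNING_ON_NUMBA is a stand-in flag for the real isinstance(get_default_mode().linker, NumbaLinker) check):

    import numpy as np
    import pytest

    RUNNING_ON_NUMBA = False  # stand-in for the linker check used in the real tests

    @pytest.mark.parametrize(
        "dtype",
        (
            pytest.param(
                np.float16,
                marks=pytest.mark.xfail(
                    condition=RUNNING_ON_NUMBA,
                    reason="Numba does not support float16",
                ),
            ),
            np.float32,
            np.float64,
        ),
    )
    def test_dtype_roundtrip(dtype):
        # float32/float64 always run; float16 is expected to fail only under Numba.
        assert np.zeros(3, dtype=dtype).dtype == dtype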

tests/tensor/test_slinalg.py

Lines changed: 7 additions & 0 deletions
@@ -10,8 +10,10 @@

 from pytensor import function, grad
 from pytensor import tensor as pt
+from pytensor.compile import get_default_mode
 from pytensor.configdefaults import config
 from pytensor.graph.basic import equal_computations
+from pytensor.link.numba import NumbaLinker
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.slinalg import (
     Cholesky,
@@ -606,6 +608,8 @@ def test_solve_correctness(self):
         )

     def test_solve_dtype(self):
+        is_numba = isinstance(get_default_mode().linker, NumbaLinker)
+
         dtypes = [
             "uint8",
             "uint16",
@@ -626,6 +630,9 @@ def test_solve_dtype(self):

         # try all dtype combinations
         for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
+            if is_numba and (A_dtype == "float16" or b_dtype == "float16"):
+                # Numba does not support float16
+                continue
             A = matrix(dtype=A_dtype)
             b = matrix(dtype=b_dtype)
             x = op(A, b)

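Where float16 is only one of many dtype combinations, the commit skips it at run time rather than marking the whole test. A minimal sketch of that guard, mirroring the loop in test_solve_dtype above (the dtype list is shortened here for illustration):

    import itertools

    from pytensor.compile import get_default_mode
    from pytensor.link.numba import NumbaLinker

    is_numba = isinstance(get_default_mode().linker, NumbaLinker)
    dtypes = ["float16", "float32", "float64"]  # the real test covers many more dtypes

    for A_dtype, b_dtype in itertools.product(dtypes, dtypes):
        if is_numba and (A_dtype == "float16" or b_dtype == "float16"):
            # Numba does not support float16, so these combinations are not attempted.
            continue
        # ... build A and b with these dtypes and check the dtype of the solve result ...

Unlike xfail, which still runs the test and records an expected failure, this simply leaves the unsupported combinations untried.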