diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 102ae908b9..8de58d346a 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -1,7 +1,7 @@ import unittest, operator, math from tinygrad import Tensor, dtypes, Device from tinygrad.dtype import DType, truncate -from tinygrad.helpers import CI, getenv, CPU_LLVM +from tinygrad.helpers import CI, getenv from tinygrad.tensor import _to_np_dtype from tinygrad.device import is_dtype_supported from tinygrad.runtime.ops_python import from_storage_scalar @@ -48,7 +48,7 @@ class ht: int32 = strat.integers(-2147483648, 2147483647) int64 = strat.integers(-9223372036854775808, 9223372036854775807) bool = strat.booleans() -ht.bfloat16 = ht.uint16 +ht.bfloat16 = ht.uint16.filter(lambda x: ((x >> 7) & 0xFF) != 0) # filter bfloat16 patterns with a zero exponent field (subnormals and zeros) ht.fp8e4m3 = ht.uint8 ht.fp8e5m2 = ht.uint8 @@ -138,7 +138,6 @@ class TestDTypeALU(unittest.TestCase): def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op) @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}") - @unittest.skipIf(CPU_LLVM, "bfloat16 precision issues with CPU_LLVM") @given(ht.bfloat16, strat.sampled_from(unary_operations)) def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)