no subnormal bf16 (#13905)
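The change below filters subnormal bit patterns out of the Hypothesis bfloat16 strategy; with those inputs gone, the CPU_LLVM-specific skip on test_bfloat16_unary and the now-unused CPU_LLVM import are dropped.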
@@ -1,7 +1,7 @@
 import unittest, operator, math
 from tinygrad import Tensor, dtypes, Device
 from tinygrad.dtype import DType, truncate
-from tinygrad.helpers import CI, getenv, CPU_LLVM
+from tinygrad.helpers import CI, getenv
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.device import is_dtype_supported
 from tinygrad.runtime.ops_python import from_storage_scalar
@@ -48,7 +48,7 @@ class ht:
   int32 = strat.integers(-2147483648, 2147483647)
   int64 = strat.integers(-9223372036854775808, 9223372036854775807)
   bool = strat.booleans()
-ht.bfloat16 = ht.uint16
+ht.bfloat16 = ht.uint16.filter(lambda x: ((x >> 7) & 0xFF) != 0) # filter subnormal bfloat16
 ht.fp8e4m3 = ht.uint8
 ht.fp8e5m2 = ht.uint8
 
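The filter's bit test reads the bfloat16 exponent field: the format packs 1 sign bit, 8 exponent bits, and 7 mantissa bits into a uint16, so (x >> 7) & 0xFF isolates the exponent. An all-zero exponent field encodes either zero or a subnormal, so rejecting it drops every subnormal pattern (and, as a side effect, both zeros) from the strategy. A minimal sketch of that bit logic, independent of tinygrad (is_subnormal_bf16 and bf16_bits_to_float are illustrative helpers, not repo code):

    import struct

    def is_subnormal_bf16(bits: int) -> bool:
      # bfloat16 layout: 1 sign bit, 8 exponent bits, 7 mantissa bits.
      # Exponent field zero with a nonzero mantissa means subnormal.
      return ((bits >> 7) & 0xFF) == 0 and (bits & 0x7F) != 0

    def bf16_bits_to_float(bits: int) -> float:
      # bfloat16 is the top half of a float32: shift left 16 bits and
      # reinterpret the result as a big-endian float32.
      return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

    assert not is_subnormal_bf16(0x0080)  # 2**-126, the smallest normal bfloat16
    assert is_subnormal_bf16(0x0001)      # smallest positive subnormal
    assert bf16_bits_to_float(0x0080) == 2.0**-126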
@@ -138,7 +138,6 @@ class TestDTypeALU(unittest.TestCase):
   def test_float16_unary(self, a, op): universal_test_unary(a, dtypes.float16, op)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
-  @unittest.skipIf(CPU_LLVM, "bfloat16 precision issues with CPU_LLVM")
   @given(ht.bfloat16, strat.sampled_from(unary_operations))
   def test_bfloat16_unary(self, a, op): universal_test_unary(from_storage_scalar(a, dtypes.bfloat16), dtypes.bfloat16, op)
 
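In the updated test, Hypothesis draws raw uint16 patterns from the filtered strategy and from_storage_scalar decodes them into bfloat16 values. The same shape of test can be sketched with only Hypothesis and the standard library; here uint16 is assumed to mirror ht.uint16, and decode_bf16 stands in for from_storage_scalar(a, dtypes.bfloat16), neither taken from the repo:

    import math, struct
    from hypothesis import given, strategies as strat

    uint16 = strat.integers(0, 65535)  # assumed shape of ht.uint16
    # Same predicate as the commit: keep only nonzero exponent fields.
    bfloat16 = uint16.filter(lambda x: ((x >> 7) & 0xFF) != 0)

    def decode_bf16(bits: int) -> float:
      # Stand-in for from_storage_scalar(bits, dtypes.bfloat16).
      return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

    @given(bfloat16)
    def test_filtered_bf16_is_never_subnormal(bits):
      a = decode_bf16(bits)
      # Surviving patterns decode to a normal value, an infinity, or a NaN;
      # zero and the subnormal range below 2**-126 are unreachable.
      assert math.isnan(a) or math.isinf(a) or abs(a) >= 2.0**-126

    test_filtered_bf16_is_never_subnormal()

With zero and subnormal inputs unreachable, backend flush-to-zero behavior can no longer produce mismatches against a reference result, which is presumably why the CPU_LLVM skip became unnecessary.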