Mirror of https://github.com/tinygrad/tinygrad.git
move bf16 cast hack to Tensor.llvm_bf16_cast (#3788)
@@ -117,7 +117,7 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   _assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
   _assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
 
-@unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP", "METAL"], "bfloat16 not supported")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
 class TestBFloat16(unittest.TestCase):
   def test_bf16_creation_numpy(self):
     data = [-1, 1, 2]
@@ -127,21 +127,19 @@ class TestBFloat16(unittest.TestCase):
     assert tnp.dtype == np.float32
     np.testing.assert_allclose(tnp, np.array(data))
 
-  @unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
   def test_bf16_ones(self):
     # TODO: fix this with correct bfloat16 cast
     t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
     assert t.dtype == dtypes.bfloat16
     np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
 
-  @unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
   def test_bf16_eye(self):
     # TODO: fix this with correct bfloat16 cast
     t = Tensor.eye(3, dtype=dtypes.bfloat16)
     assert t.dtype == dtypes.bfloat16
     np.testing.assert_allclose(t.numpy(), np.eye(3))
 
-@unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP"], "bfloat16 not supported")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
 class TestBFloat16DType(unittest.TestCase):
   def test_bf16_to_float(self):
     _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
@@ -157,7 +155,7 @@ class TestBFloat16DType(unittest.TestCase):
     back = t.cast(dtypes.float32)
     assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
 
-@unittest.skipUnless(Device.DEFAULT in ["HIP"], "bfloat16 not supported")
+@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "bfloat16 not supported")
 class TestBFloat16DTypeCast(unittest.TestCase):
   def test_f16_to_bf16_conversion(self):
     original_tensor = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16)
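The change repeated across all three hunks swaps hard-coded device lists (Device.DEFAULT in ["LLVM", "HIP", "METAL"] and friends) for a single is_dtype_supported(dtypes.bfloat16) gate, so the bfloat16 test classes are skipped based on a capability query rather than on device names. As a rough illustration of what such a gate checks (the device sets below are assumptions for the sketch, not tinygrad's actual support matrix in the test helpers):

# Hypothetical sketch of a dtype-capability gate like the one the new decorators call;
# the device sets are illustrative assumptions, not tinygrad's real support matrix.
from tinygrad.device import Device
from tinygrad.dtype import DType, dtypes

def is_dtype_supported_sketch(dtype:DType, device:str=Device.DEFAULT) -> bool:
  if dtype == dtypes.bfloat16:
    # only some backends have real bf16 buffers; the rest need an explicit cast hack
    return device in {"HIP", "METAL"}
  return True  # assume every other dtype is fine for this sketch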
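For context on the values asserted in TestBFloat16DType: bfloat16 is just the upper 16 bits of an IEEE-754 float32 (1 sign, 8 exponent, 7 mantissa bits), which is the property the "cast hack" in the commit title exploits. A small NumPy sketch (illustrative only, not tinygrad code) shows why 10000 and -10000 come back as 9984 and -9984 after a bf16 round trip while -1, -1000, and 20 survive exactly:

import numpy as np

def bf16_roundtrip(x):
  # drop the low 16 mantissa bits (truncation), then widen back to float32;
  # bfloat16 is exactly the top 16 bits of a float32, so this mimics a bf16 round trip
  bits = np.asarray(x, dtype=np.float32).view(np.uint32)
  return ((bits >> 16) << 16).view(np.float32)

print(bf16_roundtrip([10000, -1, -1000, -10000, 20]))
# -> [ 9984.    -1. -1000. -9984.    20.], matching the expected tuple in the test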
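The commit title moves that hack onto the tensor as Tensor.llvm_bf16_cast. Expressed with generic Tensor ops, the same widening could look roughly like the sketch below (a guess at the shape of such a helper, not the actual implementation; it still assumes the backend can hold and bitcast the raw bf16 bytes):

from tinygrad import Tensor, dtypes

def bf16_to_float_sketch(t:Tensor) -> Tensor:
  # reinterpret the 16 bf16 bits as an integer, widen to 32 bits, move them into
  # the high half of the word, and reinterpret the result as a float32
  assert t.dtype == dtypes.bfloat16
  return t.bitcast(dtypes.uint16).cast(dtypes.uint32).mul(1 << 16).bitcast(dtypes.float32)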