mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
@@ -414,11 +414,11 @@ class TestDtypeUsage(unittest.TestCase):
|
||||
t = Tensor([[1, 2], [3, 4]], dtype=d)
|
||||
(t*t).max().item()
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), f"no bfloat16 on {Device.DEFAULT}")
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16) or Device.DEFAULT == "PYTHON", f"no bfloat16 on {Device.DEFAULT}")
|
||||
class TestOpsBFloat16(unittest.TestCase):
|
||||
def test_cast(self):
|
||||
# TODO: helper_test_op breaks in unrelated part
|
||||
# TODO: wrong output with GPU=1 / PYTHON=1 on mac
|
||||
# TODO: wrong output with GPU=1 on mac
|
||||
data = [60000.0, 70000.0, 80000.0]
|
||||
np.testing.assert_allclose(Tensor(data).cast("bfloat16").numpy(), torch.tensor(data).type(torch.bfloat16).float().numpy())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user