diff --git a/test/test_dtype.py b/test/test_dtype.py index 760079d352..9440dab49b 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -137,6 +137,20 @@ class TestBFloat16(unittest.TestCase): assert tnp.dtype == np.float32 np.testing.assert_allclose(tnp, np.array(data)) + @unittest.expectedFailure + def test_bf16_ones(self): + # TODO: fix this with correct bfloat16 cast + t = Tensor.ones(3, 5, dtype=dtypes.bfloat16) + assert t.dtype == dtypes.bfloat16 + np.testing.assert_allclose(t.numpy(), np.ones((3, 5))) + + @unittest.expectedFailure + def test_bf16_eye(self): + # TODO: fix this with correct bfloat16 cast + t = Tensor.eye(3, dtype=dtypes.bfloat16) + assert t.dtype == dtypes.bfloat16 + np.testing.assert_allclose(t.numpy(), np.eye(3)) + @unittest.skipUnless(Device.DEFAULT in ["LLVM", "HIP"], "bfloat16 not supported") class TestBFloat16DType(unittest.TestCase): def test_bf16_to_float(self):