fix const bitcast should not be constant folded (#3743)

* fix const bitcast should not be constant folded

* fixed const bf16 creation

* LLVM still broken
This commit is contained in:
chenyu
2024-03-14 19:13:52 -04:00
committed by GitHub
parent 557c7a5c54
commit 38ba277ac8
2 changed files with 7 additions and 5 deletions

View File

@@ -137,16 +137,14 @@ class TestBFloat16(unittest.TestCase):
assert tnp.dtype == np.float32
np.testing.assert_allclose(tnp, np.array(data))
@unittest.skipIf(Device.DEFAULT=="LLVM", "broken for LLVM")
def test_bf16_ones(self):
  """Tensor.ones with dtype=bfloat16 keeps the dtype and yields exact ones.

  bf16 const creation is fixed, so this must pass on every backend except
  LLVM (still broken there), which is skipped above. The previous
  @unittest.expectedFailure is removed: combined with the skip it would
  make a now-passing test report UnexpectedSuccess on working backends.
  """
  t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
  assert t.dtype == dtypes.bfloat16
  # numpy has no native bfloat16; .numpy() upcasts, but values must still be exactly 1
  np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
@unittest.skipIf(Device.DEFAULT=="LLVM", "broken for LLVM")
def test_bf16_eye(self):
  """Tensor.eye with dtype=bfloat16 keeps the dtype and yields an exact identity.

  Same rationale as test_bf16_ones: bf16 const creation works now, so the
  stale @unittest.expectedFailure is dropped (it would turn a passing test
  into UnexpectedSuccess); only LLVM is skipped.
  """
  t = Tensor.eye(3, dtype=dtypes.bfloat16)
  assert t.dtype == dtypes.bfloat16
  # 0s and 1s are exactly representable in bfloat16, so the comparison is exact
  np.testing.assert_allclose(t.numpy(), np.eye(3))
@@ -256,6 +254,10 @@ class TestBitCast(unittest.TestCase):
b = a.bitcast(dtypes.int32)
assert b.numpy()[0] == 0x3f800000
def test_bitcast_const(self):
  """Bitcasting a const float32 scalar Tensor to uint32 reinterprets the bits.

  The reference value comes from numpy's own reinterpreting view:
  float32 1.0 viewed as uint32 is 0x3f800000.
  """
  expected = np.array(1, dtype=np.float32).view(np.uint32)
  result = Tensor(1, dtype=dtypes.float32).bitcast(dtypes.uint32)
  np.testing.assert_equal(result.numpy(), expected)
def test_bitcast_upcasted(self):
a = Tensor.zeros(100, 4, dtype=dtypes.int32).contiguous() + 0x3f800000
b = a.bitcast(dtypes.float32)