UOps.BITCAST (#3747)

* UOps.BITCAST

implicitly fixed no const folding for bitcast

* python backend

* ptx

* consistent llvm
This commit is contained in:
chenyu
2024-03-14 21:00:35 -04:00
committed by GitHub
parent 9a00a453c7
commit 75d4344cda
8 changed files with 13 additions and 15 deletions

View File

@@ -137,14 +137,14 @@ class TestBFloat16(unittest.TestCase):
assert tnp.dtype == np.float32
np.testing.assert_allclose(tnp, np.array(data))
@unittest.expectedFailure
@unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
def test_bf16_ones(self):
# TODO: fix this with correct bfloat16 cast
t = Tensor.ones(3, 5, dtype=dtypes.bfloat16)
assert t.dtype == dtypes.bfloat16
np.testing.assert_allclose(t.numpy(), np.ones((3, 5)))
@unittest.expectedFailure
@unittest.skipIf(Device.DEFAULT=="LLVM", "no LLVM bf16 buffer")
def test_bf16_eye(self):
# TODO: fix this with correct bfloat16 cast
t = Tensor.eye(3, dtype=dtypes.bfloat16)

View File

@@ -186,15 +186,13 @@ class TestConstantFolding(unittest.TestCase):
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
assert all(uop.uop is not UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} contains non-folded constant cast"
@unittest.expectedFailure
def test_bitcast_const(self):
# TODO: fix bitcast const should not fold
t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
si = create_schedule([t.lazydata])
assert len(si) == 1
si = si[0]
lin = Device[Device.DEFAULT].get_linearizer(si.ast[0]).linearize()
assert any(uop.uop is UOps.CAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
assert any(uop.uop is UOps.BITCAST for uop in lin.uops.uops), f"{[uop.uop for uop in lin.uops.uops]} does not contain bitcast"
if __name__ == '__main__':
unittest.main(verbosity=2)