move bf16 cast hack to Tensor.llvm_bf16_cast (#3788)

This commit is contained in:
chenyu
2024-03-17 18:51:22 -04:00
committed by GitHub
parent 311cf2b7d3
commit 639bd5dbfc
5 changed files with 13 additions and 16 deletions

View File

@@ -216,7 +216,7 @@ class TestDiskTensor(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "RHIP", "no real HIP device exists in CI")
def test_bf16_disk_write_read(self):
t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32)
t = Tensor([10000, -1, -1000, -10000, 20], dtype=dtypes.float32)
t.to(f"disk:{temp('f32')}").realize()
# hack to "cast" f32 -> bf16
@@ -224,9 +224,8 @@ class TestDiskTensor(unittest.TestCase):
adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
with open(temp('bf16'), "wb") as f: f.write(adat)
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm().realize()
back = t.cast(dtypes.float32)
assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20)
t = Tensor.empty(5, dtype=dtypes.bfloat16, device=f"disk:{temp('bf16')}").llvm_bf16_cast(dtypes.float)
assert t.numpy().tolist() == [9984., -1, -1000, -9984, 20]
if __name__ == "__main__":
unittest.main()