Make torch backend more usable, fix bfloat support in the llvm backend (#2765)

* Uncripple dtype tests, TestBFloat16DType never actually runs.

* Fix conversion from/to bfloat16.

Call cast() recursively, so that it works for any type combo.

* Run this test on torch backend as well.

* Add torch.bfloat16.

* Add support for ushort and uint.

* Convert np.uint32 to np.int32 when loading.

* Fix warning.
This commit is contained in:
Maksym Sobolyev
2023-12-17 11:04:26 -08:00
committed by GitHub
parent 9c32474a1f
commit 887f3d9933
3 changed files with 16 additions and 8 deletions

View File

@@ -103,14 +103,14 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
class TestBFloat16DType(unittest.TestCase):
def setUp(self):
if not is_dtype_supported(dtypes.bfloat16): raise unittest.SkipTest("bfloat16 not supported")
if Device.DEFAULT not in ["LLVM", "TORCH"]: raise unittest.SkipTest("bfloat16 not supported")
def test_bf16_to_float(self):
with self.assertRaises(AssertionError):
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])
_test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)
def test_float_to_bf16(self):
with self.assertRaises(AssertionError):
_test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000])
_test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)
# torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)
@@ -125,7 +125,7 @@ class TestBFloat16DType(unittest.TestCase):
t.to(f"disk:{temp('f32')}").realize()
# hack to "cast" f32 -> bf16
dat = open(temp('f32'), "rb").read()
with open(temp('f32'), "rb") as f: dat = f.read()
adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
with open(temp('bf16'), "wb") as f: f.write(adat)