Mirror of https://github.com/tinygrad/tinygrad.git, last synced 2026-04-29 03:00:14 -04:00.
Make torch backend more usable, fix bfloat support in the llvm backend (#2765)

* Uncripple dtype tests; TestBFloat16DType never actually runs.
* Fix conversion from/to bfloat16. Call cast() recursively, so that it works for any type combo.
* Run this test on the torch backend as well.
* Add torch.bfloat16.
* Add support for ushort and uint.
* Convert np.uint32 to np.int32 when loading.
* Fix warning.

This commit is contained in:
This commit is contained in:
@@ -103,14 +103,14 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):

 class TestBFloat16DType(unittest.TestCase):
   def setUp(self):
-    if not is_dtype_supported(dtypes.bfloat16): raise unittest.SkipTest("bfloat16 not supported")
+    if Device.DEFAULT not in ["LLVM", "TORCH"]: raise unittest.SkipTest("bfloat16 not supported")

   def test_bf16_to_float(self):
-    with self.assertRaises(AssertionError):
-      _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32, [100000])
+    _test_cast(Tensor([100000], dtype=dtypes.bfloat16), dtypes.float32)

   def test_float_to_bf16(self):
-    with self.assertRaises(AssertionError):
-      _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16, [100000])
+    _test_cast(Tensor([100000], dtype=dtypes.float32), dtypes.bfloat16)

   # torch.tensor([10000, -1, -1000, -10000, 20]).type(torch.bfloat16)

@@ -125,7 +125,7 @@ class TestBFloat16DType(unittest.TestCase):
     t.to(f"disk:{temp('f32')}").realize()

     # hack to "cast" f32 -> bf16
-    dat = open(temp('f32'), "rb").read()
+    with open(temp('f32'), "rb") as f: dat = f.read()
     adat = b''.join([dat[i+2:i+4] for i in range(0, len(dat), 4)])
     with open(temp('bf16'), "wb") as f: f.write(adat)

||||
Reference in New Issue
Block a user