Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-02-07 21:26:21 -05:00
Bitcast support / fast bf16 load (#2011)
* bitcast renderers
* fast llama load
* make it one kernel
* regression testing p1: re-enable test_dtype for all backends, fix GPU
* regression testing p2: fuzz all possible cases against numpy, remove hand-coded tests since the fuzzer covers them
* define ushort
* fix indent, probably need flake8 back for CI to catch

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
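A bitcast reinterprets a tensor's raw bytes as another dtype of the same itemsize instead of converting values, and numpy's ndarray.view is the reference behaviour the new fuzz tests compare against. A minimal numpy-only sketch of the idea:

# cast vs. bitcast with plain numpy (the reference the fuzz tests compare against)
import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)
print(a.astype(np.int32).tolist())  # cast: converts values -> [1, 2, 3, 4]
print(a.view(np.int32).tolist())    # bitcast: same bytes reinterpreted
                                    # -> [1065353216, 1073741824, 1077936128, 1082130432]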
@@ -42,7 +42,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
 
 def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
 def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist())
-def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
+def _test_bitcast(a:Tensor, target_dtype:DType, target=None): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
 
 class TestDType(unittest.TestCase):
   DTYPE: Any = None
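The updated helper derives the expected value from numpy when no explicit target is passed (the "target or ..." expression also falls back for any falsy target). A numpy-only illustration of that fallback, assuming target_dtype.np maps a tinygrad DType to its numpy dtype as used above:

# illustration of the fallback expectation in _test_bitcast:
#   target or a.numpy().view(target_dtype.np).tolist()
import numpy as np

target = None
a_np = np.array([-1, -2, -3, -4], dtype=np.int8)
expected = target or a_np.view(np.uint8).tolist()
print(expected)  # [255, 254, 253, 252]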
@@ -82,6 +82,12 @@ class TestDType(unittest.TestCase):
     lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
     get_available_cast_dtypes(self.DTYPE)
   ))
+  def test_bitcast(self):
+    if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
+    list(map(
+      lambda dtype: _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
+      get_available_cast_dtypes(self.DTYPE)
+    ))
 
 def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return
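The new TestDType.test_bitcast above fuzzes every available dtype whose itemsize matches the class's DTYPE (bool excluded) and checks the bitcast against numpy. A sketch of one such pair spelled out; the import paths are an assumption about this point in the repo's history, since the diff only shows Tensor, dtypes, .bitcast and .numpy in use:

# one dtype pair of the fuzz, spelled out (import paths are an assumption)
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([1, 2, 3, 4], dtype=dtypes.float32)
expected = a.numpy().view(np.int32).tolist()       # numpy reference: reinterpret the bytes
result = a.bitcast(dtypes.int32).numpy().tolist()  # tinygrad bitcast, same itemsize
assert result == expected == [1065353216, 1073741824, 1077936128, 1082130432]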
@@ -140,21 +146,7 @@ class TestUint8Dtype(TestDType):
   @unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
   def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
 
-@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
 class TestBitCast(unittest.TestCase):
-  def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
-  def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
-  def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])
-
-  # NOTE: these are the same as normal casts
-  def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
-  def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
-  def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
-  @unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
-  def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
-
   def test_shape_change_bitcast(self):
     with self.assertRaises(AssertionError):
       _test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])
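The remaining test_shape_change_bitcast expects an AssertionError because a float32 to uint8 bitcast cannot keep the element count (each float32 is four uint8 values). A numpy-only illustration of the size mismatch the assertion guards against:

# why a shape-changing bitcast is rejected: the element count changes
import numpy as np

a = np.array([100000], dtype=np.float32)
print(a.shape, a.view(np.uint8).shape)  # (1,) vs (4,): numpy's view changes the shape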