Bitcast support / fast bf16 load (#2011)

* bitcast renderers

* fast llama load

* make it one kernel

* regression testing p1: re-enable test_dtype for all backends

fix GPU

* regression testing p2: fuzz all possible cases against numpy

remove hancoded tests since the fuzzer covers them

* define ushort

* fix indent, probably need flake8 back for CI to catch

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
qazal
2023-12-05 19:19:28 -05:00
committed by GitHub
parent 232ed2af3f
commit be09cc87c1
5 changed files with 26 additions and 28 deletions

View File

@@ -42,7 +42,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
def _test_op(fxn, target_dtype:DType, target): _assert_eq(fxn(), target_dtype, target)
def _test_cast(a:Tensor, target_dtype:DType): _test_op(lambda: a.cast(target_dtype), target_dtype, a.numpy().astype(target_dtype.np).tolist())
def _test_bitcast(a:Tensor, target_dtype:DType, target): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target)
def _test_bitcast(a:Tensor, target_dtype:DType, target=None): _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
class TestDType(unittest.TestCase):
DTYPE: Any = None
@@ -82,6 +82,12 @@ class TestDType(unittest.TestCase):
lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
def test_bitcast(self):
if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
list(map(
lambda dtype: _test_bitcast(Tensor(self.DATA, dtype=self.DTYPE), dtype) if dtype.itemsize == self.DTYPE.itemsize and dtype != dtypes.bool else None,
get_available_cast_dtypes(self.DTYPE)
))
def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype): return
@@ -140,21 +146,7 @@ class TestUint8Dtype(TestDType):
@unittest.skipIf(getenv("CUDA",0)==1 or getenv("PTX", 0)==1, "cuda saturation works differently")
def test_uint8_to_int8_overflow(self): _test_op(lambda: Tensor([255, 254, 253, 252], dtype=dtypes.uint8).cast(dtypes.int8), dtypes.int8, [-1, -2, -3, -4])
@unittest.skipIf(Device.DEFAULT not in {"CPU", "TORCH"}, "only bitcast in CPU and TORCH")
class TestBitCast(unittest.TestCase):
def test_float32_bitcast_to_int32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.int32, [1065353216, 1073741824, 1077936128, 1082130432])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint32 in torch")
def test_float32_bitcast_to_uint32(self): _test_bitcast(Tensor([1,2,3,4], dtype=dtypes.float32), dtypes.uint32, [1065353216, 1073741824, 1077936128, 1082130432])
def test_int32_bitcast_to_float32(self): _test_bitcast(Tensor([1065353216, 1073741824, 1077936128, 1082130432], dtype=dtypes.int32), dtypes.float32, [1.0, 2.0, 3.0, 4.0])
# NOTE: these are the same as normal casts
def test_int8_bitcast_to_uint8(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int8), dtypes.uint8, [255, 254, 253, 252])
def test_uint8_bitcast_to_int8(self): _test_bitcast(Tensor([255, 254, 253, 252], dtype=dtypes.uint8), dtypes.int8, [-1, -2, -3, -4])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
def test_int64_bitcast_to_uint64(self): _test_bitcast(Tensor([-1, -2, -3, -4], dtype=dtypes.int64), dtypes.uint64, [18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612])
@unittest.skipIf(Device.DEFAULT == "TORCH", "no uint64 in torch")
def test_uint64_bitcast_to_int64(self): _test_bitcast(Tensor([18446744073709551615, 18446744073709551614, 18446744073709551613, 18446744073709551612], dtype=dtypes.uint64), dtypes.int64, [-1, -2, -3, -4])
def test_shape_change_bitcast(self):
with self.assertRaises(AssertionError):
_test_bitcast(Tensor([100000], dtype=dtypes.float32), dtypes.uint8, [100000])