@@ -48,11 +48,11 @@ def _test_cast(a:Tensor, target_dtype:DType):
     a = (a > 65504).where(65504, a)
   expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
-  if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected))
+  if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected]
   _test_op(lambda: a.cast(target_dtype), target_dtype, expected)
 
 def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
   expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist()
-  if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected))
+  if target_dtype in dtypes.fp8s: expected = [fp8_to_float(x, target_dtype) for x in expected]
   _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected)
 
 class TestDType(unittest.TestCase):
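Note: both fp8 changes in this hunk swap a list(map(lambda ...)) for an equivalent list comprehension. A minimal self-contained sketch of the equivalence, where clamp_e4m3 is a hypothetical stand-in for tinygrad's truncate[target_dtype] table (the real table maps each fp8 dtype to its saturating truncation):

# sketch only: clamp_e4m3 is NOT tinygrad's truncate, just a stand-in
def clamp_e4m3(x: float) -> float:
  return max(-448.0, min(448.0, x))  # 448 is the largest finite e4m3 value

expected = [1.5, 500.0, -1000.0]
old = list(map(lambda x: clamp_e4m3(x), expected))  # removed form
new = [clamp_e4m3(x) for x in expected]             # added form
assert old == new == [1.5, 448.0, -448.0]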
@@ -68,37 +68,34 @@ class TestDType(unittest.TestCase):
   def test_to_np(self):
     _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
 
-  def test_casts_to(self): list(map(
-    lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
-    get_available_cast_dtypes(self.DTYPE)
-  ))
-  def test_casts_from(self): list(map(
-    lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype),
-    get_available_cast_dtypes(self.DTYPE)
-  ))
+  def test_casts_to(self):
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE)
+
+  def test_casts_from(self):
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype)
 
   def test_same_size_ops(self):
-    list(map(
-      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None,
-      get_available_cast_dtypes(self.DTYPE)
-    ))
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      if dtype.itemsize == self.DTYPE.itemsize:
+        _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
+
   def test_upcast_ops(self):
-    list(map(
-      lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
-      get_available_cast_dtypes(self.DTYPE)
-    ))
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      if dtype.itemsize > self.DTYPE.itemsize:
+        _test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
+
   def test_upcast_to_ops(self):
-    list(map(
-      lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
-      get_available_cast_dtypes(self.DTYPE)
-    ))
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      if dtype.itemsize < self.DTYPE.itemsize:
+        _test_ops(a_dtype=dtype, b_dtype=self.DTYPE)
+
   def test_bitcast(self):
     if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
-    list(map(
-      lambda dtype:
-      _test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype) if dtype != dtypes.bool else None,
-      get_available_cast_dtypes(self.DTYPE)
-    ))
+    for dtype in get_available_cast_dtypes(self.DTYPE):
+      if dtype != dtypes.bool:
+        _test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype)
 
   @unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
   @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
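Note: the test methods in this hunk called list(map(...)) purely for side effects. In Python 3, map is lazy, so the list(...) wrapper existed only to force iteration, building a throwaway list of Nones; the loop form also turns the "... if cond else None" conditional expression into an explicit if. A small sketch of the difference, where check is a hypothetical stand-in for _test_cast/_test_ops:

def check(dtype: str) -> None:  # hypothetical stand-in, not a tinygrad helper
  print(f"checking {dtype}")

available = ["int8", "float16", "float32"]
list(map(lambda d: check(d), available))  # old style: runs checks, returns [None, None, None]
for d in available:                       # new style: same checks, no throwaway list
  check(d)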
@@ -307,7 +304,7 @@ class TestBitCast(unittest.TestCase):
     data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
     expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
     if dt2 in dtypes.fp8s:
-      expected = torch.tensor(list(map(lambda x: fp8_to_float(x, dt2), expected.view(-1).tolist()))).view_as(expected)
+      expected = torch.tensor([fp8_to_float(x, dt2) for x in expected.view(-1).tolist()]).view_as(expected)
     _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())
 
   def test_shape_change_bitcast_exceptions(self):
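Note: same refactor inside the torch-based expectation: the fp8 bytes are flattened, decoded element by element, and reshaped back. A self-contained sketch of that flatten-decode-reshape pattern, where toy_fp8_to_float is a toy stand-in for tinygrad's fp8_to_float, not a real fp8 decoder:

import torch

def toy_fp8_to_float(byte: int) -> float:  # toy decoder, stand-in only
  return float(byte) / 16.0

expected = torch.arange(8, dtype=torch.uint8).reshape(2, 4)
decoded = torch.tensor([toy_fp8_to_float(x) for x in expected.view(-1).tolist()]).view_as(expected)
assert decoded.shape == expected.shape  # decoded elementwise, original shape kept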