clean up test_dtype (#14055)

use less lambda
This commit is contained in:
chenyu
2026-01-07 15:45:42 -05:00
committed by GitHub
parent 5bd4593eda
commit 5f1ede7f7e

View File

@@ -48,11 +48,11 @@ def _test_cast(a:Tensor, target_dtype:DType):
a = (a > 65504).where(65504, a)
expected = list(a.numpy().astype(_to_np_dtype(target_dtype)))
if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected))
if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected]
_test_op(lambda: a.cast(target_dtype), target_dtype, expected)
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist()
if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected))
if target_dtype in dtypes.fp8s: expected = [fp8_to_float(x, target_dtype) for x in expected]
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected)
class TestDType(unittest.TestCase):
@@ -68,37 +68,34 @@ class TestDType(unittest.TestCase):
def test_to_np(self):
_test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
def test_casts_to(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
get_available_cast_dtypes(self.DTYPE)
))
def test_casts_from(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype),
get_available_cast_dtypes(self.DTYPE)
))
def test_casts_to(self):
for dtype in get_available_cast_dtypes(self.DTYPE):
_test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE)
def test_casts_from(self):
for dtype in get_available_cast_dtypes(self.DTYPE):
_test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype)
def test_same_size_ops(self):
list(map(
lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
for dtype in get_available_cast_dtypes(self.DTYPE):
if dtype.itemsize == self.DTYPE.itemsize:
_test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
def test_upcast_ops(self):
list(map(
lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
for dtype in get_available_cast_dtypes(self.DTYPE):
if dtype.itemsize > self.DTYPE.itemsize:
_test_ops(a_dtype=self.DTYPE, b_dtype=dtype)
def test_upcast_to_ops(self):
list(map(
lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None,
get_available_cast_dtypes(self.DTYPE)
))
for dtype in get_available_cast_dtypes(self.DTYPE):
if dtype.itemsize < self.DTYPE.itemsize:
_test_ops(a_dtype=dtype, b_dtype=self.DTYPE)
def test_bitcast(self):
if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast")
list(map(
lambda dtype:
_test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype) if dtype != dtypes.bool else None,
get_available_cast_dtypes(self.DTYPE)
))
for dtype in get_available_cast_dtypes(self.DTYPE):
if dtype != dtypes.bool:
_test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype)
@unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now")
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now")
@@ -307,7 +304,7 @@ class TestBitCast(unittest.TestCase):
data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2))
if dt2 in dtypes.fp8s:
expected = torch.tensor(list(map(lambda x: fp8_to_float(x, dt2), expected.view(-1).tolist()))).view_as(expected)
expected = torch.tensor([fp8_to_float(x, dt2) for x in expected.view(-1).tolist()]).view_as(expected)
_test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist())
def test_shape_change_bitcast_exceptions(self):