From 5f1ede7f7e1e5d31e14c99d3fdfd5f2a94c2af16 Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 7 Jan 2026 15:45:42 -0500 Subject: [PATCH] clean up test_dtype (#14055) use less lambda --- test/test_dtype.py | 53 ++++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 28 deletions(-) diff --git a/test/test_dtype.py b/test/test_dtype.py index 2104b67112..5b34d11989 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -48,11 +48,11 @@ def _test_cast(a:Tensor, target_dtype:DType): a = (a > 65504).where(65504, a) expected = list(a.numpy().astype(_to_np_dtype(target_dtype))) - if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected)) + if target_dtype in dtypes.fp8s: expected = [truncate[target_dtype](x) for x in expected] _test_op(lambda: a.cast(target_dtype), target_dtype, expected) def _test_bitcast(a:Tensor, target_dtype:DType, target=None): expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist() - if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected)) + if target_dtype in dtypes.fp8s: expected = [fp8_to_float(x, target_dtype) for x in expected] _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected) class TestDType(unittest.TestCase): @@ -68,37 +68,34 @@ class TestDType(unittest.TestCase): def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE))) - def test_casts_to(self): list(map( - lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE), - get_available_cast_dtypes(self.DTYPE) - )) - def test_casts_from(self): list(map( - lambda dtype: _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype), - get_available_cast_dtypes(self.DTYPE) - )) + def test_casts_to(self): + for dtype in get_available_cast_dtypes(self.DTYPE): + _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE) + + def test_casts_from(self): + for dtype in get_available_cast_dtypes(self.DTYPE): + _test_cast(Tensor(self.DATA, dtype=self.DTYPE), dtype) def test_same_size_ops(self): - list(map( - lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize == self.DTYPE.itemsize else None, - get_available_cast_dtypes(self.DTYPE) - )) + for dtype in get_available_cast_dtypes(self.DTYPE): + if dtype.itemsize == self.DTYPE.itemsize: + _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) + def test_upcast_ops(self): - list(map( - lambda dtype: _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) if dtype.itemsize > self.DTYPE.itemsize else None, - get_available_cast_dtypes(self.DTYPE) - )) + for dtype in get_available_cast_dtypes(self.DTYPE): + if dtype.itemsize > self.DTYPE.itemsize: + _test_ops(a_dtype=self.DTYPE, b_dtype=dtype) + def test_upcast_to_ops(self): - list(map( - lambda dtype: _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) if dtype.itemsize < self.DTYPE.itemsize else None, - get_available_cast_dtypes(self.DTYPE) - )) + for dtype in get_available_cast_dtypes(self.DTYPE): + if dtype.itemsize < self.DTYPE.itemsize: + _test_ops(a_dtype=dtype, b_dtype=self.DTYPE) + def test_bitcast(self): if self.DTYPE == dtypes.bool: raise unittest.SkipTest("no bools in bitcast") - list(map( - lambda dtype: - _test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype) if dtype != dtypes.bool else None, - get_available_cast_dtypes(self.DTYPE) - )) + for dtype in get_available_cast_dtypes(self.DTYPE): + if dtype != dtypes.bool: + _test_bitcast(Tensor(self.DATA[:8], dtype=self.DTYPE), dtype) @unittest.skipIf(Device.DEFAULT == "PYTHON", "skip for now") @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (PTXRenderer, NIRRenderer)), "skip for now") @@ -307,7 +304,7 @@ class TestBitCast(unittest.TestCase): data = rand_for_dtype(dt1, 32).reshape(2, 2, 8) expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2)) if dt2 in dtypes.fp8s: - expected = torch.tensor(list(map(lambda x: fp8_to_float(x, dt2), expected.view(-1).tolist()))).view_as(expected) + expected = torch.tensor([fp8_to_float(x, dt2) for x in expected.view(-1).tolist()]).view_as(expected) _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, expected.tolist()) def test_shape_change_bitcast_exceptions(self):