update RANGEIFY test_cast_padded (#12421)

* update RANGEIFY test_cast_padded

* update test
This commit is contained in:
chenyu
2025-10-02 16:37:35 +08:00
committed by GitHub
parent 37beef6de3
commit 98163832e4
2 changed files with 5 additions and 6 deletions

View File

@@ -165,15 +165,16 @@ class TestMovedConstFolding(unittest.TestCase):
_check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))
def test_cast_padded(self):
# NOTE: RANGEIFY or not, it's always 1 kernel when calling .numpy, limitation of _check_ast_count
if is_dtype_supported(dtypes.int16):
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
_check_ast_count(1 if RANGEIFY else 0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
if is_dtype_supported(dtypes.uint16):
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
_check_ast_count(1 if RANGEIFY else 0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# folded
if is_dtype_supported(dtypes.int64):
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
_check_ast_count(1 if RANGEIFY else 0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
class TestReduceOpsConstFolding(unittest.TestCase):