diff --git a/test/backend/test_const_folding.py b/test/backend/test_const_folding.py index a954a76b41..b9cffbb528 100644 --- a/test/backend/test_const_folding.py +++ b/test/backend/test_const_folding.py @@ -27,6 +27,12 @@ class TestMovedConstFolding(unittest.TestCase): def test_add_padded_one(self): _check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),))) + def test_copy_padded_const(self): + schedule = Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").schedule() + assert not any(si.ast.op is Ops.COPY for si in schedule), "const copy should be folded" + # TODO: this is wrong, should be [0, 1, 1, 1, 1, 0] + np.testing.assert_equal(Tensor.ones(4, device="CPU:0").pad(((1, 1),)).to("CPU:1").numpy(), [1, 1, 1, 1, 1, 1]) + def test_cast_padded(self): # NOTE: it's always 1 kernel when calling .numpy, limitation of _check_ast_count if is_dtype_supported(dtypes.int16):