revert the removal of CAST_BEFORE_VIEW (#4471)

this brings most of the memory gain for resnet back.
This commit is contained in:
chenyu
2024-05-08 00:14:29 -04:00
committed by GitHub
parent 5dbab7fae6
commit c508eb7425
3 changed files with 10 additions and 5 deletions

View File

@@ -106,12 +106,14 @@ class TestMovedConstFolding(unittest.TestCase):
_check_ast_count(1, Tensor([1.0, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))
def test_cast_padded(self):
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
# NOTE: this is folded due to CAST_BEFORE_VIEW
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# not folded
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int64).numpy(), [0, 1, 1, 1, 1, 0])
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
class TestReduceOpsConstFolding(unittest.TestCase):
def test_const_sum(self):