fix _to_const_val and const folding around it (#4017)

* fix _to_const_val and const folding around it

is_unrealized_contiguous_const is too strict and almost never hit if const is expanded.
suffice to check if there's no pad

* that test is folded

* test_const_folding
This commit is contained in:
chenyu
2024-03-31 13:09:23 -04:00
committed by GitHub
parent 2abb474d43
commit 7f859593b8
5 changed files with 67 additions and 8 deletions

View File

@@ -0,0 +1,64 @@
import unittest
from tinygrad import Tensor
from tinygrad.ops import BufferOps
from tinygrad.engine.schedule import create_schedule
def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
asts = [s for s in create_schedule([t.lazydata]) if s.ast[0].op is BufferOps.STORE]
assert len(asts) == desired_count
class TestSimpleConstFolding(unittest.TestCase):
def test_add_literal_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) + 0)
def test_add_tensor_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) + Tensor.zeros(4))
def test_sub_literal_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) - 0)
def test_sub_tensor_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) - Tensor.zeros(4))
def test_mul_literal_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) * 0)
def test_mul_tensor_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) * Tensor.zeros(4))
def test_mul_literal_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) * 1)
def test_mul_tensor_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) * Tensor.ones(4))
def test_div_literal_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) / 1)
def test_div_tensor_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) / Tensor.ones(4))
# TODO: fix pow const folding
@unittest.expectedFailure
def test_pow_literal_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) ** 0)
@unittest.expectedFailure
def test_pow_tensor_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) ** Tensor.zeros(4))
@unittest.expectedFailure
def test_pow_literal_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) ** 1)
@unittest.expectedFailure
def test_pow_tensor_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) ** Tensor.ones(4))
class TestMovedConstFolding(unittest.TestCase):
def test_add_shrunk_zero(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))
def test_add_padded_zero(self):
# TODO: it's 1 now, this might be possible to fold
_check_ast_count(1, Tensor([1, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))
def test_mul_shrunk_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) * Tensor.ones(6).shrink(((1, 5),)))
def test_add_padded_one(self):
_check_ast_count(1, Tensor([1, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))

View File

@@ -338,11 +338,6 @@ class TestJit(unittest.TestCase):
assert isinstance(jf.jit_cache[1].prg, graph_t)
def test_jit_const_inputs(self):
@TinyJit
def f(x,y): return (x+y).realize()
for _ in range(5):
np.testing.assert_equal(f(Tensor.ones(3), Tensor.zeros(3)).numpy(), np.ones(3))
@TinyJit
def g(x,y,z): return (x+y+z).realize()
for i in range(5):