mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix _to_const_val and const folding around it (#4017)
* fix _to_const_val and const folding around it: is_unrealized_contiguous_const is too strict and is almost never hit if the const is expanded; it suffices to check that there is no pad
* that test is now folded
* add test_const_folding
This commit is contained in:
64
test/test_const_folding.py
Normal file
64
test/test_const_folding.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.ops import BufferOps
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
|
||||
def _check_ast_count(desired_count:int, t:Tensor):
|
||||
# NOTE: this has side effect because everything can be scheduled only once
|
||||
asts = [s for s in create_schedule([t.lazydata]) if s.ast[0].op is BufferOps.STORE]
|
||||
assert len(asts) == desired_count
|
||||
|
||||
class TestSimpleConstFolding(unittest.TestCase):
|
||||
def test_add_literal_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) + 0)
|
||||
def test_add_tensor_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) + Tensor.zeros(4))
|
||||
|
||||
def test_sub_literal_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) - 0)
|
||||
def test_sub_tensor_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) - Tensor.zeros(4))
|
||||
|
||||
def test_mul_literal_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) * 0)
|
||||
def test_mul_tensor_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) * Tensor.zeros(4))
|
||||
|
||||
def test_mul_literal_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) * 1)
|
||||
def test_mul_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) * Tensor.ones(4))
|
||||
|
||||
def test_div_literal_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) / 1)
|
||||
def test_div_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) / Tensor.ones(4))
|
||||
|
||||
# TODO: fix pow const folding
|
||||
@unittest.expectedFailure
|
||||
def test_pow_literal_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) ** 0)
|
||||
@unittest.expectedFailure
|
||||
def test_pow_tensor_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) ** Tensor.zeros(4))
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_pow_literal_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) ** 1)
|
||||
@unittest.expectedFailure
|
||||
def test_pow_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) ** Tensor.ones(4))
|
||||
|
||||
class TestMovedConstFolding(unittest.TestCase):
|
||||
def test_add_shrunk_zero(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) + Tensor.zeros(6).shrink(((1, 5),)))
|
||||
|
||||
def test_add_padded_zero(self):
|
||||
# TODO: it's 1 now, this might be possible to fold
|
||||
_check_ast_count(1, Tensor([1, 2, 3, 4]) + Tensor.zeros(2).pad(((1, 1),)))
|
||||
|
||||
def test_mul_shrunk_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) * Tensor.ones(6).shrink(((1, 5),)))
|
||||
|
||||
def test_add_padded_one(self):
|
||||
_check_ast_count(1, Tensor([1, 2, 3, 4]) * Tensor.ones(2).pad(((1, 1),)))
|
||||
@@ -338,11 +338,6 @@ class TestJit(unittest.TestCase):
|
||||
assert isinstance(jf.jit_cache[1].prg, graph_t)
|
||||
|
||||
def test_jit_const_inputs(self):
|
||||
@TinyJit
|
||||
def f(x,y): return (x+y).realize()
|
||||
for _ in range(5):
|
||||
np.testing.assert_equal(f(Tensor.ones(3), Tensor.zeros(3)).numpy(), np.ones(3))
|
||||
|
||||
@TinyJit
|
||||
def g(x,y,z): return (x+y+z).realize()
|
||||
for i in range(5):
|
||||
|
||||
Reference in New Issue
Block a user