mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
fix lazy r const folding with variable shape (#4783)
Const folding of a reduce over a symbolic shape is currently not supported; I think it's possible with a refactor of Tensor.from_node. Also added some failing tests required for symbolic arange.
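A minimal sketch of the behavior the new tests pin down, assuming the public tinygrad API at the time of this commit (Tensor, plus Variable from tinygrad.shape.symbolic, as the tests themselves use):

from tinygrad import Tensor
from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 10).bind(3)  # symbolic dim in [1, 10], bound to 3
t = Tensor.ones(3)
# before this commit the reduce tried to const-fold even with a symbolic shape;
# now the fold is skipped and the sum goes through the normal reduce path
print(t.reshape(vi).sum().item())  # expected: 3.0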
@@ -177,6 +177,17 @@ class TestSymbolicJit(unittest.TestCase):
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
     assert_jit_cache_len(jf, 1)
 
+  def test_ones_sum(self):
+    def f(a): return a.sum().realize()
+    jf = TinyJit(f)
+    for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
+      # TODO: without contiguous, the CONST shapes are different in jit
+      t = Tensor.ones(i).contiguous()
+      symbolic = jf(t.reshape(vi)).item()
+      expected = f(t).item()
+      np.testing.assert_equal(symbolic, expected)
+
   def test_mean(self):
     def f(a): return a.mean().realize()
     def f0(a): return a.mean(0).realize()
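A standalone version of the jitted test above, assuming the top-level TinyJit export; per my reading of the TODO (a guess, not the commit's claim), .contiguous() matters because Tensor.ones produces a lazy CONST, a (1,) value expanded to (i,), whose shape would otherwise differ between jit invocations:

from tinygrad import Tensor, TinyJit
from tinygrad.shape.symbolic import Variable

def f(a): return a.sum().realize()
jf = TinyJit(f)
for i in range(1, 5):
  vi = Variable("i", 1, 10).bind(i)
  t = Tensor.ones(i).contiguous()  # materialize the CONST into a real buffer
  assert jf(t.reshape(vi)).item() == f(t).item()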
@@ -141,6 +141,14 @@ class TestSymbolicOps(unittest.TestCase):
       expected = a.shrink(((3,5),(i,i+2))).numpy()
       np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
+  def test_ones_sum(self):
+    for i in range(1, 5):
+      vi = Variable("i", 1, 10).bind(i)
+      t = Tensor.ones(i)
+      symbolic = t.reshape(vi).sum().item()
+      expected = t.sum().item()
+      np.testing.assert_equal(symbolic, expected)
+
   def test_mean(self):
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
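For background on the pattern both tests share, a small sketch of what binding does (my gloss, using only APIs that appear in the tests):

from tinygrad import Tensor
from tinygrad.shape.symbolic import Variable

vi = Variable("i", 1, 10).bind(3)       # a dim known to lie in [1, 10], bound to 3 at runtime
t = Tensor.ones(3).reshape(vi)          # shape becomes (vi,): symbolic, not the int 3
assert not isinstance(t.shape[0], int)  # downstream code must handle a Node here, not an int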
@@ -173,6 +173,17 @@ class TestSymbolicShrink(unittest.TestCase):
     t = Tensor.rand(3, 5).shrink(((0, 2), (vi, vi+1)))
     assert t.shape == (2, 1)
 
+class TestSymbolicPad(unittest.TestCase):
+  def test_pad(self):
+    v = Variable("v", 1, 100).bind(5)
+    t = Tensor.ones(5).reshape(v).pad(((4, 0),)).reshape(9)
+    assert t.shape == (9,)
+    st = t.lazydata.st
+    print(st)
+    # TODO: fix this, required for symbolic arange
+    with self.assertRaises(RuntimeError):
+      st.expr_idxs()
+
 class TestSymbolicShapeExpr(unittest.TestCase):
   def test_symbolic_expr_idxs(self):
     # taken from symbolic shape llama
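Why pad shows up on the road to symbolic arange: as I understand the lowering, tinygrad expresses arange via a cumsum, and the cumsum via a left-pad followed by length-n windowed sums; the pad(((4, 0),)) plus reshape(9) above is exactly that pattern for n=5. A plain-Python sketch of the trick (not tinygrad code):

n = 5
ones = [1.0] * n
padded = [0.0] * (n - 1) + ones  # pad(((n-1, 0),)): length 2n-1, the 9 above
cumsum = [sum(padded[i:i+n]) for i in range(n)]  # window i contains i+1 ones
assert cumsum == [1.0, 2.0, 3.0, 4.0, 5.0]
# arange is then a cumsum of a constant-step tensor plus an offset, so symbolic
# arange needs this pad over a symbolic dim to work, which is what the TODO tracks.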
@@ -74,6 +74,7 @@ class LazyBuffer:
     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, src, enable_cache=enable_cache)
 
   def const(self, val:ConstType, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
+    assert isinstance(val, (int,float,bool)), f"{val=} has {type(val)=}, not a ConstType"
     shape = self.shape if shape is None else shape
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
 
@@ -181,7 +182,8 @@ class LazyBuffer:
     if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
 
     # const folding
-    if self.is_unrealized_unmasked_const():
+    # TODO: fold this for symbolic?
+    if self.is_unrealized_unmasked_const() and all_int(self.shape):
       return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
 
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
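Why the fold now requires all_int: with a symbolic shape, prod over the reduced axes yields a symbolic Node rather than a number, and the new assert in const() would reject it. A sketch, assuming prod and all_int from tinygrad.helpers:

from tinygrad.helpers import prod, all_int
from tinygrad.shape.symbolic import Variable

shape = (Variable("i", 1, 10), 4)
assert not all_int(shape)  # one dim is a symbolic Node, so the fold is skipped
total = prod(shape)        # a Node expression (i*4), not an int; folding with it
                           # would trip const()'s isinstance(val, (int,float,bool)) assert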