fix symbolic usage. use shrink, not reshape (#11762)

* fix test_var

* revert those things

* fix the ones in test tiny

* use better syntax

* it's the same, but that's clearer

* fix pad
This commit is contained in:
George Hotz
2025-08-20 18:35:42 -07:00
committed by GitHub
parent 5276fbc9c5
commit 9f94c25a25
4 changed files with 23 additions and 23 deletions

View File

@@ -229,12 +229,12 @@ class TestSymbolicOps(unittest.TestCase):
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_var(self):
a = Tensor.rand(10, 3)
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
for axis in [None, 0, 1]:
a = Tensor.rand(i, 3)
expected = a.var(axis).numpy()
symbolic = a.reshape(vi, 3).var(axis).reshape(expected.shape).numpy()
expected = a[:i, :].var(axis).numpy()
symbolic = a[:vi, :].var(axis).reshape(expected.shape).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
def test_var_2d(self):

View File

@@ -73,17 +73,17 @@ class TestTiny(unittest.TestCase):
def test_symbolic(self):
i = Variable('i', 1, 10)
with Context(IGNORE_OOB=1):
for s in [2,5]:
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)) + 1
self.assertListEqual(ret.reshape(s).tolist(), [2.0]*s)
ones = Tensor.ones(10).contiguous()
for s in [2,5]:
ret = ones[:i.bind(s)] + 1
self.assertListEqual(ret.contiguous().reshape(s).tolist(), [2.0]*s)
def test_symbolic_reduce(self):
i = Variable('i', 1, 10)
with Context(IGNORE_OOB=1):
for s in [2,5]:
ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum()
self.assertEqual(ret.item(), s)
ones = Tensor.ones(10).contiguous()
for s in [2,5]:
ret = ones[:i.bind(s)].sum()
self.assertEqual(ret.item(), s)
# *** a model ***

View File

@@ -108,7 +108,7 @@ def collapse_to_1(shp:tuple[sint, ...], idxs:tuple[UOp, ...]) -> UOp:
for s,src in list(zip(shp, idxs))[::-1]:
to_sum.append(acc*src)
acc *= s
return sum(to_sum)
return sum(to_sum, start=UOp.const(dtypes.int, 0))
def map_reshape(idx:UOp, r:UOp):
mish = collapse_to_1(idx.shape, idx.src[1:])
@@ -120,7 +120,7 @@ def map_reshape(idx:UOp, r:UOp):
mish //= s
else:
ret.append(UOp.const(dtypes.int, 0))
tret = ret[0].sink(*ret[1:]).simplify().src[::-1] if len(ret) else ()
tret = ret[0].sink(*ret[1:]).simplify(tracked=True).src[::-1] if len(ret) else ()
return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
def map_pad(idx:UOp, r:UOp):
@@ -129,8 +129,8 @@ def map_pad(idx:UOp, r:UOp):
for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
if s == 0 and e == 0: continue
where = UOp.const(dtypes.bool, True)
if e > 0: where = where & (ret[i] < (sh-e))
if s > 0: where = where & (ret[i] >= s)
if resolve(e > 0): where = where & (ret[i] < (sh-e))
if resolve(s > 0): where = where & (ret[i] >= s)
bigwhere = bigwhere & where
# this is safe but dumb
ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
@@ -262,14 +262,14 @@ pm_rangeify = pm_mops+PatternMatcher([
# if we come across this, remove it. it was a CHILD unused in an INDEX
(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),
# CONST can't have axes. remove srcs when we idx
(UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),)), lambda c: c.replace(src=())),
# CONST (or DEFINE_VAR) can't have axes. remove srcs when we idx
(UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c.replace(src=())),
# handle arg on any op with weight. old endrange stuff
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
# move MAP through elementwise ALU / reduce. these are the items with cost
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE})),), allow_any_len=True, name="x"),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE, Ops.BIND})),), allow_any_len=True, name="x"),
lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
(UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
])
@@ -393,9 +393,9 @@ def split_store(x:UOp):
ctx = LocalAddBufferContext()
ret = graph_rewrite(x, to_define_global, ctx=ctx, name="kernel split", bottom_up=True)
store_rngs = x.src[2:]
store_rngs = ret.src[2:]
rng = sorted([u for u in ret.toposort() if u.op is Ops.RANGE], key=lambda x: x.arg)
name = "k"+colored('_', 'BLACK').join(['']+[colored(str(s.vmax+1), "WHITE") if s in store_rngs else colored(str(s.vmax+1), "red") for s in rng])
name = "k"+colored('_', 'BLACK').join(['']+[colored(s.src[0].render(), "WHITE" if s in store_rngs else "red") for s in rng])
# NOTE: the hack for COPY is here
ret = ret.sink(arg=KernelInfo(name=name)) if ret.src[1].op is not Ops.COPY else ret.src[1]

View File

@@ -210,11 +210,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** uop evaluation ***
def simplify(self):
def simplify(self, tracked=False):
# late import!
from tinygrad.uop.symbolic import symbolic
with Context(TRACK_MATCH_STATS=0):
return graph_rewrite(self, symbolic)
with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
return graph_rewrite(self, symbolic, name="simplify")
def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"