mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
fix symbolic usage. use shrink, not reshape (#11762)
* fix test_var
* revert those things
* fix the ones in test tiny
* use better syntax
* it's the same, but that's clearer
* fix pad
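The commit replaces symbolic reshape with shrink: slice a full-size tensor with a bound Variable rather than reshaping a concrete tensor to a symbolic shape. A minimal sketch of the new pattern, adapted from the test_symbolic_reduce hunk below; it assumes the top-level tinygrad Tensor/Variable exports:

from tinygrad import Tensor, Variable

# build at the maximum size once, then shrink with a bound Variable instead of
# reshaping a concrete tensor to a symbolic shape (the old, removed pattern)
i = Variable('i', 1, 10)
ones = Tensor.ones(10).contiguous()
for s in [2, 5]:
  ret = ones[:i.bind(s)].sum()   # symbolic shrink: only the first s elements are summed
  assert ret.item() == s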
@@ -229,12 +229,12 @@ class TestSymbolicOps(unittest.TestCase):
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
   def test_var(self):
+    a = Tensor.rand(10, 3)
     for i in range(1, 5):
       vi = Variable("i", 1, 10).bind(i)
       for axis in [None, 0, 1]:
-        a = Tensor.rand(i, 3)
-        expected = a.var(axis).numpy()
-        symbolic = a.reshape(vi, 3).var(axis).reshape(expected.shape).numpy()
+        expected = a[:i, :].var(axis).numpy()
+        symbolic = a[:vi, :].var(axis).reshape(expected.shape).numpy()
         np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
   def test_var_2d(self):
@@ -73,17 +73,17 @@ class TestTiny(unittest.TestCase):
 
   def test_symbolic(self):
     i = Variable('i', 1, 10)
-    with Context(IGNORE_OOB=1):
-      for s in [2,5]:
-        ret = Tensor.ones(s).contiguous().reshape(i.bind(s)) + 1
-        self.assertListEqual(ret.reshape(s).tolist(), [2.0]*s)
+    ones = Tensor.ones(10).contiguous()
+    for s in [2,5]:
+      ret = ones[:i.bind(s)] + 1
+      self.assertListEqual(ret.contiguous().reshape(s).tolist(), [2.0]*s)
 
   def test_symbolic_reduce(self):
     i = Variable('i', 1, 10)
-    with Context(IGNORE_OOB=1):
-      for s in [2,5]:
-        ret = Tensor.ones(s).contiguous().reshape(i.bind(s)).sum()
-        self.assertEqual(ret.item(), s)
+    ones = Tensor.ones(10).contiguous()
+    for s in [2,5]:
+      ret = ones[:i.bind(s)].sum()
+      self.assertEqual(ret.item(), s)
 
   # *** a model ***
 
@@ -108,7 +108,7 @@ def collapse_to_1(shp:tuple[sint, ...], idxs:tuple[UOp, ...]) -> UOp:
   for s,src in list(zip(shp, idxs))[::-1]:
     to_sum.append(acc*src)
     acc *= s
-  return sum(to_sum)
+  return sum(to_sum, start=UOp.const(dtypes.int, 0))
 
 def map_reshape(idx:UOp, r:UOp):
   mish = collapse_to_1(idx.shape, idx.src[1:])
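collapse_to_1 flattens a multi-dimensional index into a single row-major offset by accumulating strides. The hunk above gives the final sum an explicit UOp start value: Python's built-in sum starts from the int 0, so a rank-0 (empty) shape would otherwise yield a plain int instead of a UOp. A pure-Python analogue of the accumulation, with a hypothetical `zero` parameter standing in for UOp.const(dtypes.int, 0):

def collapse_to_1(shp, idxs, zero=0):
  # row-major flatten: offset = sum(idx[k] * stride[k]), strides built back-to-front
  acc, to_sum = 1, []
  for s, src in list(zip(shp, idxs))[::-1]:
    to_sum.append(acc * src)
    acc *= s
  return sum(to_sum, start=zero)

assert collapse_to_1((4, 5), (2, 3)) == 2*5 + 3   # 13
assert collapse_to_1((), ()) == 0                 # empty shape: result is just the start value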
@@ -120,7 +120,7 @@ def map_reshape(idx:UOp, r:UOp):
       mish //= s
     else:
       ret.append(UOp.const(dtypes.int, 0))
-  tret = ret[0].sink(*ret[1:]).simplify().src[::-1] if len(ret) else ()
+  tret = ret[0].sink(*ret[1:]).simplify(tracked=True).src[::-1] if len(ret) else ()
   return r.src[0].index(*tret, dtype=idx.dtype, arg=idx.arg)
 
 def map_pad(idx:UOp, r:UOp):
@@ -129,8 +129,8 @@ def map_pad(idx:UOp, r:UOp):
   for i,(sh,(s,e)) in enumerate(zip(r.shape, r.arg)):
     if s == 0 and e == 0: continue
     where = UOp.const(dtypes.bool, True)
-    if e > 0: where = where & (ret[i] < (sh-e))
-    if s > 0: where = where & (ret[i] >= s)
+    if resolve(e > 0): where = where & (ret[i] < (sh-e))
+    if resolve(s > 0): where = where & (ret[i] >= s)
     bigwhere = bigwhere & where
     # this is safe but dumb
     ret[i] = (ret[i] - s).maximum(0).minimum(r.src[0].shape[i]-1)
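In map_pad the pad amounts s and e can now be symbolic UOps rather than plain ints, so `e > 0` builds a symbolic comparison instead of returning a Python bool; resolve() collapses it to a concrete bool where provable. A small sketch of the difference (the import path for resolve is an assumption based on this tree's layout):

from tinygrad import Variable
from tinygrad.uop.ops import resolve   # assumed location of resolve in this tree

e = Variable("e", 1, 3)     # a pad amount that may be symbolic
cmp = e > 0                 # a UOp comparison, not a Python bool
print(type(cmp).__name__)   # UOp
print(resolve(cmp))         # True: provable from the variable's [1, 3] range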
@@ -262,14 +262,14 @@ pm_rangeify = pm_mops+PatternMatcher([
   # if we come across this, remove it. it was a CHILD unused in an INDEX
   (UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, src=(UPat.var("x"),)),)), lambda x: x),
 
-  # CONST can't have axes. remove srcs when we idx
-  (UPat(Ops.INDEX, src=(UPat(Ops.CONST, name="c"),)), lambda c: c.replace(src=())),
+  # CONST (or DEFINE_VAR) can't have axes. remove srcs when we idx
+  (UPat(Ops.INDEX, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="c"),)), lambda c: c.replace(src=())),
 
   # handle arg on any op with weight. old endrange stuff
   (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
 
   # move MAP through elementwise ALU / reduce. these are the items with cost
-  (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE})),), allow_any_len=True, name="x"),
+  (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE, Ops.BIND})),), allow_any_len=True, name="x"),
   lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
   (UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),
 ])
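The DEFINE_VAR and BIND additions exist because a shrink by a bound Variable puts those ops into the graph that INDEX has to be pushed through: Variable(...) is a DEFINE_VAR UOp, and .bind(...) wraps it in a BIND whose srcs are the DEFINE_VAR and the bound CONST. A quick way to see that structure, assuming the top-level Variable export:

from tinygrad import Variable

v = Variable("i", 1, 10)              # a DEFINE_VAR UOp
b = v.bind(4)                         # an Ops.BIND UOp
print(b.op, [s.op for s in b.src])    # Ops.BIND [Ops.DEFINE_VAR, Ops.CONST]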
@@ -393,9 +393,9 @@ def split_store(x:UOp):
   ctx = LocalAddBufferContext()
   ret = graph_rewrite(x, to_define_global, ctx=ctx, name="kernel split", bottom_up=True)
 
-  store_rngs = x.src[2:]
+  store_rngs = ret.src[2:]
   rng = sorted([u for u in ret.toposort() if u.op is Ops.RANGE], key=lambda x: x.arg)
-  name = "k"+colored('_', 'BLACK').join(['']+[colored(str(s.vmax+1), "WHITE") if s in store_rngs else colored(str(s.vmax+1), "red") for s in rng])
+  name = "k"+colored('_', 'BLACK').join(['']+[colored(s.src[0].render(), "WHITE" if s in store_rngs else "red") for s in rng])
 
   # NOTE: the hack for COPY is here
   ret = ret.sink(arg=KernelInfo(name=name)) if ret.src[1].op is not Ops.COPY else ret.src[1]
@@ -210,11 +210,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
 
   # *** uop evaluation ***
 
-  def simplify(self):
+  def simplify(self, tracked=False):
     # late import!
     from tinygrad.uop.symbolic import symbolic
-    with Context(TRACK_MATCH_STATS=0):
-      return graph_rewrite(self, symbolic)
+    with Context(TRACK_MATCH_STATS=0 if not tracked else TRACK_MATCH_STATS.value):
+      return graph_rewrite(self, symbolic, name="simplify")
   def ssimplify(self) -> UOp|ConstType: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
   def _eval(self, dtype, expected_type:Type[T]) -> T:
     assert self.dtype in dtype, f"eval with wrong dtype {self}"
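With tracked=True, simplify() keeps the current TRACK_MATCH_STATS value instead of forcing it to 0, so the rewrites it performs (such as the call in map_reshape above) stay visible to the match-stats machinery. Existing callers are unchanged; a tiny sketch of both forms, with the Variable import assumed from the top-level package:

from tinygrad import Variable

expr = Variable("i", 1, 10) * 1 + 0
print(expr.simplify().render())               # "i": symbolic rewrite, untracked as before
print(expr.simplify(tracked=True).render())   # same result, but the rewrite is tracked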