experiments with reprocessing node

George Hotz
2025-10-05 14:12:40 +08:00
parent 6538935441
commit c600446299
3 changed files with 34 additions and 15 deletions

View File

@@ -15,6 +15,10 @@ class TestTiny(unittest.TestCase):
     out = Tensor([1.,2,3])
     self.assertListEqual(out.tolist(), [1.0, 2.0, 3.0])
+  def test_elu(self):
+    out = Tensor([1.,2,3]).sum().elu()
+    self.assertEqual(out.item(), 6.0)
   def test_plus(self):
     out = Tensor([1.,2,3]) + Tensor([4.,5,6])
     self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])
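
The new test_elu is the interesting one for this commit: elu is built from relu and exp of the same input, so the scalar produced by sum() becomes a node with more than one child, exercising the CHILDREN/reprocess path below. The expected value follows from elu(x) == x for x > 0 (a minimal check, assuming only the public Tensor API):

from tinygrad import Tensor

# sum() gives 6.0; elu reads its input more than once (relu and exp branches),
# so the summed node has multiple children in the graph
out = Tensor([1., 2, 3]).sum().elu()
print(out.item())  # 6.0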

View File

@@ -2,7 +2,7 @@ from typing import Any, cast, Iterator
 import functools, operator, itertools
 from dataclasses import dataclass, field
 from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
-from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo
+from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, ReprocessNode, _substitute, ssimplify, KernelInfo, BottomUpGate
 from tinygrad.uop.symbolic import sym, symbolic_simple
 from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
 from tinygrad.schedule.kernelize import Kernel
@@ -151,6 +151,7 @@ class RangeifyContext:
   # block on parent until all children have been seen
   seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
   seen_child: dict[UOp, Any] = field(default_factory=dict)
+  pending_children: dict[UOp, list[UOp]] = field(default_factory=dict)
   progress: int = 0
   # create ranges
@@ -271,13 +272,18 @@ def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
 def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
   if c not in ctx.seen_children: ctx.seen_children[c] = {}
   # wait here until we have seen all the children
+  ctx.seen_children[c][x.arg[0]] = idx
+  print("see child", x.arg)
   if len(ctx.seen_children[c]) != x.arg[1]:
     ctx.progress += 1
     if ctx.progress > 10000: raise RuntimeError("children not making progress")
-    # NOTE: we mark this here
-    ctx.seen_children[c][x.arg[0]] = idx
-    raise RewriteNotReady
+    print("BU GATE")
+    ctx.pending_children.setdefault(c, []).append(idx)
+    raise BottomUpGate
+    #raise RewriteNotReady
   ctx.progress = 0
+  print("CHILDREN", id(c))
   if c not in ctx.seen_child:
     all_rngs = list(zip(*[ch.src[1:] for ch in ctx.seen_children[c].values()]))
@@ -321,6 +327,11 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
 def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
   if len(ctx.seen_children[c]) != c.arg: raise RuntimeError("all children should have been seen by now")
+  if len(pc:=ctx.pending_children[c]):
+    pcn = pc.pop()
+    print("reprocess", pcn.src[0].arg)
+    raise ReprocessNode(pcn)
+  print("COMPLETE", id(c))
   return idx.replace(src=(idx.src[0].src[0],)+idx.src[1:])
 def might_end_axis(idx:UOp):
@@ -345,7 +356,7 @@ pm_rangeify = pm_mops+PatternMatcher([
   (UPat(Ops.INDEX, src=(UPat(Ops.REALIZE, src=(UPat(),), name="x"),), allow_any_len=True, name="idx"), map_partial_realize),
   # if there are new ended children, tag the SINK
-  (UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
+  (UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
   (UPat(Ops.INDEX, src=(UPat(Ops.CHILDREN, name="c"),), allow_any_len=True, name="idx"), children_gate),
   # if we come across this, remove it. it was a CHILD unused in an INDEX
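
Taken together, index_child and children_gate implement a park-and-replay protocol: each INDEX of a CHILD records itself in seen_children and gates out with BottomUpGate until the last sibling arrives; the INDEX of CHILDREN then replays the parked nodes one at a time via ReprocessNode. A minimal self-contained sketch of that protocol (the names and driver here are hypothetical, not tinygrad's API):

class BottomUpGate(Exception): pass      # park this node, stop descending
class ReprocessNode(Exception):          # ask the driver to redo a parked node
  def __init__(self, node):
    self.node = node
    super().__init__("reprocess node")

seen: dict[str, bool] = {}
pending: list[str] = []

def index_child(name:str, total:int):
  seen[name] = True
  if len(seen) != total:
    pending.append(name)                 # wait until all siblings are seen
    raise BottomUpGate

def children_gate():
  if pending: raise ReprocessNode(pending.pop())

work = ["a", "b", "gate"]
while work:
  item = work.pop(0)
  try:
    children_gate() if item == "gate" else index_child(item, total=2)
    print("done", item)
  except BottomUpGate: print("parked", item)
  except ReprocessNode as e: work += [e.node, item]  # redo parked node, then re-gate

This prints parked a, done b, done a, done gate: the first child parks, the second completes and unlocks the gate, and the gate replays the parked child before finishing.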

View File

@@ -1032,6 +1032,11 @@ if TRACK_MATCH_STATS or PROFILE:
 class RewriteNotReady(Exception): pass
 class BottomUpGate(Exception): pass
+class ReprocessNode(Exception):
+  def __init__(self, node):
+    self.node = node
+    super().__init__("reprocess node")
 class RewriteContext:
   def __init__(self, pm, bpm, ctx=None):
     self.pm: PatternMatcher|None = pm
@@ -1053,10 +1058,10 @@
   def unified_rewrite(self, root:UOp) -> UOp:
     stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
-    on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again
     while stack:
       if len(stack) > getenv("REWRITE_STACK_LIMIT", 250000): raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
       n, stage, new_n = stack.pop()
+      #print(len(stack), stage)
       if n in self.replace: continue # skip any nodes we have seen
       try:
         if stage == 0:
if stage == 0:
@@ -1071,15 +1076,14 @@
               seen.add(test_n)
               new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
             stack.append((n, 1, new_n))
-            for x in reversed(new_n.src):
-              if x in on_stack: continue
-              stack.append((x, 0, x))
-              on_stack.add(x)
+            for x in reversed(new_n.src): stack.append((x, 0, x))
           # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
           except BottomUpGate: self.replace[n] = new_n
         elif stage == 1:
           try: new_src = tuple([self.replace[x] for x in new_n.src])
-          except KeyError: raise RewriteNotReady
+          except KeyError:
+            stack.append((new_n, 0, new_n))
+            continue
           if new_src == new_n.src:
             # if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
             if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:
@@ -1093,11 +1097,11 @@
               stack.append((new_src_n, 0, new_src_n))
         else:
           # in stage 2, we link the result of new_n to the result of n
-          try: self.replace[n] = self.replace[new_n]
-          except KeyError: raise RewriteNotReady
-      except RewriteNotReady:
-        # retry this later
-        stack.appendleft((n, stage, new_n))
+          self.replace[n] = self.replace[new_n]
+      except ReprocessNode as e:
+        assert e.node is self.replace[e.node]
+        del self.replace[e.node]
+        stack.append((e.node, 0, e.node))
     return self.replace[root]
 @track_matches
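
The driver-side change in unified_rewrite: the old code turned a missing source at stage 1, or a missing link at stage 2, into RewriteNotReady and retried the whole (n, stage, new_n) tuple later via appendleft; the new code re-pushes the node at stage 0 when a source is missing, and handles ReprocessNode by invalidating the parked node's replace entry and walking it again. A stripped-down sketch of just that stage machine (no pattern matching, hypothetical helper, identity rewrites only):

import collections

def rewrite_sketch(root, srcs):
  # srcs: node -> tuple of source nodes
  replace: dict = {}
  stack = collections.deque([(root, 0)])
  while stack:
    n, stage = stack.pop()
    if n in replace: continue
    if stage == 0:
      stack.append((n, 1))                   # revisit this node after its srcs
      for x in reversed(srcs.get(n, ())): stack.append((x, 0))
    else:
      try: new_src = tuple(replace[x] for x in srcs.get(n, ()))
      except KeyError:
        stack.append((n, 0))                 # a src was invalidated: redo this node
        continue
      replace[n] = n                         # a real rewriter would rebuild n from new_src
  return replace[root]

print(rewrite_sketch("sink", {"sink": ("a", "b")}))  # sink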