diff --git a/test/test_tiny.py b/test/test_tiny.py index 31bb84f595..5f9ea1e629 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -15,6 +15,10 @@ class TestTiny(unittest.TestCase): out = Tensor([1.,2,3]) self.assertListEqual(out.tolist(), [1.0, 2.0, 3.0]) + def test_elu(self): + out = Tensor([1.,2,3]).sum().elu() + self.assertListEqual(out.tolist(), [1.0, 2.0, 3.0]) + def test_plus(self): out = Tensor([1.,2,3]) + Tensor([4.,5,6]) self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0]) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index a4b3f28f1b..4c2f4155ed 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,7 +2,7 @@ from typing import Any, cast, Iterator import functools, operator, itertools from dataclasses import dataclass, field from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace -from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo +from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, ReprocessNode, _substitute, ssimplify, KernelInfo, BottomUpGate from tinygrad.uop.symbolic import sym, symbolic_simple from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP from tinygrad.schedule.kernelize import Kernel @@ -151,6 +151,7 @@ class RangeifyContext: # block on parent until all children have been seen seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict) seen_child: dict[UOp, Any] = field(default_factory=dict) + pending_children: dict[UOp, list[UOp]] = field(default_factory=dict) progress: int = 0 # create ranges @@ -271,13 +272,18 @@ def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp): def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp): if c not in ctx.seen_children: ctx.seen_children[c] = {} # wait here until we have seen all the children + 
ctx.seen_children[c][x.arg[0]] = idx + print("see child", x.arg) if len(ctx.seen_children[c]) != x.arg[1]: ctx.progress += 1 if ctx.progress > 10000: raise RuntimeError("children not making progress") # NOTE: we mark this here - ctx.seen_children[c][x.arg[0]] = idx - raise RewriteNotReady + print("BU GATE") + ctx.pending_children.setdefault(c, []).append(idx) + raise BottomUpGate + #raise RewriteNotReady ctx.progress = 0 + print("CHILDREN", id(c)) if c not in ctx.seen_child: all_rngs = list(zip(*[ch.src[1:] for ch in ctx.seen_children[c].values()])) @@ -321,6 +327,11 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp): def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp): if len(ctx.seen_children[c]) != c.arg: raise RuntimeError("all children should have been seen by now") + if len(pc:=ctx.pending_children[c]): + pcn = pc.pop() + print("reprocess", pcn.src[0].arg) + raise ReprocessNode(pcn) + print("COMPLETE", id(c)) return idx.replace(src=(idx.src[0].src[0],)+idx.src[1:]) def might_end_axis(idx:UOp): @@ -345,7 +356,7 @@ pm_rangeify = pm_mops+PatternMatcher([ (UPat(Ops.INDEX, src=(UPat(Ops.REALIZE, src=(UPat(),), name="x"),), allow_any_len=True, name="idx"), map_partial_realize), # if there are new ended children, tag the SINK - (UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child), + (UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child), (UPat(Ops.INDEX, src=(UPat(Ops.CHILDREN, name="c"),), allow_any_len=True, name="idx"), children_gate), # if we come across this, remove it. 
it was a CHILD unused in an INDEX diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 518cc95824..d514cb027e 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -1032,6 +1032,11 @@ if TRACK_MATCH_STATS or PROFILE: class RewriteNotReady(Exception): pass class BottomUpGate(Exception): pass +class ReprocessNode(Exception): + def __init__(self, node): + self.node = node + super().__init__("reprocess node") + class RewriteContext: def __init__(self, pm, bpm, ctx=None): self.pm: PatternMatcher|None = pm @@ -1053,10 +1058,10 @@ class RewriteContext: def unified_rewrite(self, root:UOp) -> UOp: stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)]) - on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again while stack: if len(stack) > getenv("REWRITE_STACK_LIMIT", 250000): raise RuntimeError("infinite loop in graph_rewrite (stack too big)") n, stage, new_n = stack.pop() + #print(len(stack), stage) if n in self.replace: continue # skip any nodes we have seen try: if stage == 0: @@ -1071,15 +1076,14 @@ class RewriteContext: seen.add(test_n) new_n, test_n = test_n, self.cached_bpm_rewrite(test_n) stack.append((n, 1, new_n)) - for x in reversed(new_n.src): - if x in on_stack: continue - stack.append((x, 0, x)) - on_stack.add(x) + for x in reversed(new_n.src): stack.append((x, 0, x)) # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs except BottomUpGate: self.replace[n] = new_n elif stage == 1: try: new_src = tuple([self.replace[x] for x in new_n.src]) - except KeyError: raise RewriteNotReady + except KeyError: + stack.append((new_n, 0, new_n)) + continue if new_src == new_n.src: # if top down, do the rewrite. 
if no rewrite or bottom up, we are done rewriting this node so we add it to the dict if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None: @@ -1093,11 +1097,11 @@ class RewriteContext: stack.append((new_src_n, 0, new_src_n)) else: # in stage 2, we link the result of new_n to the result of n - try: self.replace[n] = self.replace[new_n] - except KeyError: raise RewriteNotReady - except RewriteNotReady: - # retry this later - stack.appendleft((n, stage, new_n)) + self.replace[n] = self.replace[new_n] + except ReprocessNode as e: + assert e.node is self.replace[e.node] + del self.replace[e.node] + stack.append((e.node, 0, e.node)) return self.replace[root] @track_matches