Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-15 01:48:23 -05:00
experiments with reprocessing node
@@ -15,6 +15,10 @@ class TestTiny(unittest.TestCase):
    out = Tensor([1.,2,3])
    self.assertListEqual(out.tolist(), [1.0, 2.0, 3.0])

  def test_elu(self):
    out = Tensor([1.,2,3]).sum().elu()
    self.assertListEqual(out.tolist(), [1.0, 2.0, 3.0])

  def test_plus(self):
    out = Tensor([1.,2,3]) + Tensor([4.,5,6])
    self.assertListEqual(out.tolist(), [5.0, 7.0, 9.0])

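A quick note on running these (not part of the diff): assuming the file above is test/test_tiny.py and it keeps the usual unittest.main() entry point, the new cases can be exercised on their own with

python3 test/test_tiny.py TestTiny.test_elu TestTiny.test_plus -v
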
@@ -2,7 +2,7 @@ from typing import Any, cast, Iterator
import functools, operator, itertools
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify, KernelInfo
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, ReprocessNode, _substitute, ssimplify, KernelInfo, BottomUpGate
from tinygrad.uop.symbolic import sym, symbolic_simple
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup, unwrap, all_int, DEBUG, SPLIT_REDUCEOP
from tinygrad.schedule.kernelize import Kernel

@@ -151,6 +151,7 @@ class RangeifyContext:
  # block on parent until all children have been seen
  seen_children: dict[UOp, dict[int, UOp]] = field(default_factory=dict)
  seen_child: dict[UOp, Any] = field(default_factory=dict)
  pending_children: dict[UOp, list[UOp]] = field(default_factory=dict)
  progress: int = 0

# create ranges

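To make the bookkeeping concrete, here is a rough standalone sketch (illustrative only, not tinygrad's RangeifyContext): seen_children records which CHILD slot of a parent has reached index_child, the new pending_children field parks the INDEX nodes that were gated so children_gate can replay them later, and progress guards against waiting forever.

from dataclasses import dataclass, field

@dataclass
class ChildTracker:
  seen_children: dict[str, dict[int, str]] = field(default_factory=dict)   # parent -> child slot -> index node
  pending_children: dict[str, list[str]] = field(default_factory=dict)     # parent -> gated index nodes to replay
  progress: int = 0                                                        # guard against waiting forever
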
@@ -271,13 +272,18 @@ def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
  if c not in ctx.seen_children: ctx.seen_children[c] = {}
  # wait here until we have seen all the children
  ctx.seen_children[c][x.arg[0]] = idx
  print("see child", x.arg)
  if len(ctx.seen_children[c]) != x.arg[1]:
    ctx.progress += 1
    if ctx.progress > 10000: raise RuntimeError("children not making progress")
    # NOTE: we mark this here
    ctx.seen_children[c][x.arg[0]] = idx
    raise RewriteNotReady
    print("BU GATE")
    ctx.pending_children.setdefault(c, []).append(idx)
    raise BottomUpGate
    #raise RewriteNotReady
  ctx.progress = 0
  print("CHILDREN", id(c))

  if c not in ctx.seen_child:
    all_rngs = list(zip(*[ch.src[1:] for ch in ctx.seen_children[c].values()]))

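The control flow above is easier to see in isolation. A self-contained toy version of the gate (illustrative names, not tinygrad's API): every arriving sibling records itself, all but the last are parked and cut off with BottomUpGate, and only the final sibling falls through to do the shared work once.

class BottomUpGate(Exception): pass

seen, pending = {}, []
def index_child_toy(slot, n_children, idx):
  seen[slot] = idx
  if len(seen) != n_children:
    pending.append(idx)
    raise BottomUpGate  # done with this node for now, don't walk its sources
  return f"shared work over {sorted(seen)}"

for slot, idx in [(0, "idx0"), (1, "idx1"), (2, "idx2")]:
  try: print(index_child_toy(slot, 3, idx))
  except BottomUpGate: print("gated", idx)
# gated idx0 / gated idx1 / shared work over [0, 1, 2]
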
@@ -321,6 +327,11 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):

def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
  if len(ctx.seen_children[c]) != c.arg: raise RuntimeError("all children should have been seen by now")
  if len(pc:=ctx.pending_children[c]):
    pcn = pc.pop()
    print("reprocess", pcn.src[0].arg)
    raise ReprocessNode(pcn)
  print("COMPLETE", id(c))
  return idx.replace(src=(idx.src[0].src[0],)+idx.src[1:])

def might_end_axis(idx:UOp):

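Likewise, a minimal sketch of the drain that children_gate performs (toy objects, not UOps): once every sibling has checked in, the parked entries are replayed one at a time by raising ReprocessNode, and the driver is expected to push e.node back onto its work stack.

class ReprocessNode(Exception):
  def __init__(self, node):
    self.node = node
    super().__init__("reprocess node")

pending = ["idx_a", "idx_b"]   # stand-ins for INDEX nodes that hit the gate earlier
replayed = []
while pending:
  try: raise ReprocessNode(pending.pop())
  except ReprocessNode as e: replayed.append(e.node)  # the real driver re-queues e.node here
print(replayed)  # ['idx_b', 'idx_a']
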
@@ -345,7 +356,7 @@ pm_rangeify = pm_mops+PatternMatcher([
  (UPat(Ops.INDEX, src=(UPat(Ops.REALIZE, src=(UPat(),), name="x"),), allow_any_len=True, name="idx"), map_partial_realize),

  # if there are new ended children, tag the SINK
  (UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
  (UPat(Ops.INDEX, src=(UPat(Ops.CHILD, src=(UPat(Ops.CHILDREN, name="c"), ), name="x"),), allow_any_len=True, name="idx"), index_child),
  (UPat(Ops.INDEX, src=(UPat(Ops.CHILDREN, name="c"),), allow_any_len=True, name="idx"), children_gate),

  # if we come across this, remove it. it was a CHILD unused in an INDEX

@@ -1032,6 +1032,11 @@ if TRACK_MATCH_STATS or PROFILE:

class RewriteNotReady(Exception): pass
class BottomUpGate(Exception): pass
class ReprocessNode(Exception):
  def __init__(self, node):
    self.node = node
    super().__init__(self, "reprocess node")

class RewriteContext:
  def __init__(self, pm, bpm, ctx=None):
    self.pm: PatternMatcher|None = pm

@@ -1053,10 +1058,10 @@ class RewriteContext:

  def unified_rewrite(self, root:UOp) -> UOp:
    stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
    on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again
    while stack:
      if len(stack) > getenv("REWRITE_STACK_LIMIT", 250000): raise RuntimeError("infinite loop in graph_rewrite (stack too big)")
      n, stage, new_n = stack.pop()
      #print(len(stack), stage)
      if n in self.replace: continue # skip any nodes we have seen
      try:
        if stage == 0:

@@ -1071,15 +1076,14 @@ class RewriteContext:
            seen.add(test_n)
            new_n, test_n = test_n, self.cached_bpm_rewrite(test_n)
          stack.append((n, 1, new_n))
          for x in reversed(new_n.src):
            if x in on_stack: continue
            stack.append((x, 0, x))
            on_stack.add(x)
          for x in reversed(new_n.src): stack.append((x, 0, x))
        # if the bpm matching raised a gate, we are done with this node and dont continue down the srcs
        except BottomUpGate: self.replace[n] = new_n
        elif stage == 1:
          try: new_src = tuple([self.replace[x] for x in new_n.src])
          except KeyError: raise RewriteNotReady
          except KeyError:
            stack.append((new_n, 0, new_n))
            continue
          if new_src == new_n.src:
            # if top down, do the rewrite. if no rewrite or bottom up, we are done rewriting this node so we add it to the dict
            if self.pm is None or (new_src_n:=self.cached_pm_rewrite(new_n)) is None:

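The on_stack set introduced here is the usual visited guard for an explicit DFS stack; in isolation the idea looks like this (toy graph, not tinygrad code): each source is pushed at most once instead of being re-pushed on every visit.

graph = {"sink": ["a", "b"], "a": ["c"], "b": ["c"], "c": []}   # hypothetical DAG
stack, on_stack, order = ["sink"], {"sink"}, []
while stack:
  n = stack.pop()
  order.append(n)
  for s in reversed(graph[n]):
    if s in on_stack: continue  # already queued or already finished, skip the duplicate push
    stack.append(s)
    on_stack.add(s)
print(order)  # ['sink', 'a', 'c', 'b']
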
@@ -1093,11 +1097,11 @@ class RewriteContext:
            stack.append((new_src_n, 0, new_src_n))
        else:
          # in stage 2, we link the result of new_n to the result of n
          try: self.replace[n] = self.replace[new_n]
          except KeyError: raise RewriteNotReady
      except RewriteNotReady:
        # retry this later
        stack.appendleft((n, stage, new_n))
          self.replace[n] = self.replace[new_n]
      except ReprocessNode as e:
        assert e.node is self.replace[e.node]
        del self.replace[e.node]
        stack.append((e.node, 0, e.node))
    return self.replace[root]

@track_matches

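Putting the pieces together, here is a compact, self-contained model of the driver loop this diff touches. It is a sketch under toy assumptions, not tinygrad's real unified_rewrite: nodes are nested tuples, there is no bottom-up fixed point, no stage-2 linking, and nothing raises ReprocessNode; it only shows the two-stage stack, the on_stack guard, and BottomUpGate cutting off the descent.

class BottomUpGate(Exception): pass

def src(n): return n[1:] if isinstance(n, tuple) else ()

def bpm(n):  # toy bottom-up hook: anything under an "opaque" node is left untouched
  if isinstance(n, tuple) and n[0] == "opaque": raise BottomUpGate
  return n

def pm(n):   # toy top-down rule: x + 0 -> x
  if isinstance(n, tuple) and n[0] == "add" and n[2] == "0": return n[1]
  return None

def rewrite(root):
  stack, on_stack, replace = [(root, 0)], {root}, {}
  while stack:
    n, stage = stack.pop()
    if n in replace: continue   # already finished
    if stage == 0:
      try:
        bpm(n)                  # bottom-up hook runs before the sources are visited
        stack.append((n, 1))    # come back to n once its sources are done
        for x in reversed(src(n)):
          if x in on_stack: continue   # the on_stack guard: push each source once
          stack.append((x, 0)); on_stack.add(x)
      except BottomUpGate: replace[n] = n   # gated: finished, never visit its sources
    else:
      new_n = (n[0], *[replace[x] for x in src(n)]) if isinstance(n, tuple) else n
      replace[n] = new_n if (r := pm(new_n)) is None else r
  return replace[root]

expr = ("mul", ("add", "x", "0"), ("opaque", ("add", "y", "0")))
print(rewrite(expr))   # ('mul', 'x', ('opaque', ('add', 'y', '0')))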