start function and add walk rewrite (#14992)

* start function and add walk rewrite

* work

* add function on feed_forward

* llm progress

* stuff

* none of that
This commit is contained in:
George Hotz
2026-02-25 13:56:27 +08:00
committed by GitHub
parent fde7a40bb0
commit e3fa9896b7
8 changed files with 310 additions and 12 deletions

View File

@@ -349,5 +349,184 @@ class TestStopEarly(unittest.TestCase):
ret = (c+d).substitute({c:cn}, extra_pm=pm_cvisit)
assert ret == cn+d
class TestWalkRewrite(unittest.TestCase):
  """Tests for graph_rewrite with walk=True (MLIR Walk Pattern Rewrite Driver semantics).

  walk=True gives a single-pass traversal that does NOT revisit or re-traverse into rewritten subtrees.
  Supports both top-down (default) and bottom-up (bottom_up=True) modes.
  NOTE: these tests use assertIs — they rely on UOps being hash-consed, so equal graphs are the same object."""

  # *** top-down walk (default): process children first, then try pm on rebuilt node ***

  def test_walk_topdown_simple_substitute(self):
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b}, walk=True)
    self.assertIs(ret, b+4)

  def test_walk_topdown_does_not_traverse_into_replacement(self):
    """Top-down walk: replacement subtrees are NOT re-entered."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    d = UOp.variable('d', 0, 10)
    # a is replaced by b+c, but b inside the replacement is NOT further substituted to d
    ret_walk = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, walk=True)
    self.assertIs(ret_walk, (b+c)+4)
    # contrast: greedy bottom_up WOULD replace b inside the replacement
    ret_greedy = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, bottom_up=True)
    self.assertIs(ret_greedy, (d+c)+4)

  def test_walk_topdown_no_fixed_point(self):
    """A bouncing pattern applies once and stops instead of looping."""
    a = UOp.const(dtypes.int, 3)
    pm = PatternMatcher([
      (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
    ])
    # the greedy (fixed-point) driver detects the 3<->4 bounce and raises
    with self.assertRaises(RuntimeError):
      graph_rewrite(a, pm, bottom_up=True)
    # walk fires at most once per node, so the bounce terminates after one rewrite
    ret = graph_rewrite(a, pm, walk=True)
    self.assertIs(ret, UOp.const(dtypes.int, 4))

  def test_walk_topdown_rewrites_children(self):
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    ret = graph_rewrite((a + 4) + (b + 5), _substitute, {a:c, b:c}, walk=True)
    self.assertIs(ret, (c + 4) + (c + 5))

  def test_walk_topdown_diamond(self):
    # a is shared by two parents; both uses must observe the same rewrite
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite((a + 4) + (a + 5), _substitute, {a:b}, walk=True)
    self.assertIs(ret, (b + 4) + (b + 5))

  def test_walk_topdown_children_rewritten_before_parent(self):
    """Top-down walk processes children first: child substitution changes the rebuilt parent."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    n1 = a.sin()   # sin(a)
    ret = n1.sin() # sin(sin(a))
    # sin(a)->sqrt(a) fires first (child), parent rebuilds to sin(sqrt(a)), which doesn't match sin(sin(a)) in dvars
    ret_walk = graph_rewrite(ret, _substitute, {a.sin():a.sqrt(), n1.sin():n1.sqrt()}, walk=True)
    self.assertIs(ret_walk, a.sqrt().sin())

  def test_walk_topdown_self_referential_replacement(self):
    """Replacement containing the replaced node works without infinite recursion."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    ret = graph_rewrite(a.sin() + 4, _substitute, {a.sin(): a.sin().sqrt()}, walk=True)
    self.assertIs(ret, a.sin().sqrt() + 4)

  def test_walk_topdown_visit_order(self):
    """Top-down walk fires pm after children are processed (post-order)."""
    visited = []
    def track_visit(ctx, x):
      # record CONST args (readable) and other ops; returning None means "no rewrite"
      ctx.append(x.arg if x.op is Ops.CONST else x.op)
      return None
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), track_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, walk=True)
    self.assertEqual(visited, [1, 2, Ops.ADD])

  # *** bottom-up walk: try bpm on node first, skip children if it matches ***

  def test_walk_bottomup_simple_substitute(self):
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b}, bottom_up=True, walk=True)
    self.assertIs(ret, b+4)

  def test_walk_bottomup_does_not_traverse_into_replacement(self):
    """Bottom-up walk: replacement subtrees are NOT entered."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    d = UOp.variable('d', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, bottom_up=True, walk=True)
    self.assertIs(ret, (b+c)+4)

  def test_walk_bottomup_parent_match_skips_children(self):
    """Bottom-up walk matches parent first: if it matches, children are never visited."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    n1 = a.sin()
    ret = n1.sin() # sin(sin(a))
    # sin(sin(a)) matches n1.sin()->n1.sqrt() immediately, children never visited, sin(a) inside replacement untouched
    ret_walk = graph_rewrite(ret, _substitute, {a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True, walk=True)
    self.assertIs(ret_walk, a.sin().sqrt())

  def test_walk_bottomup_no_fixed_point(self):
    """Bottom-up walk also applies once per node, no fixed-point iteration."""
    a = UOp.const(dtypes.int, 3)
    pm = PatternMatcher([
      (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
    ])
    ret = graph_rewrite(a, pm, bottom_up=True, walk=True)
    self.assertIs(ret, UOp.const(dtypes.int, 4))

  def test_walk_bottomup_visit_order(self):
    """Bottom-up walk fires bpm before descending (pre-order)."""
    visited = []
    def track_visit(ctx, x):
      ctx.append(x.arg if x.op is Ops.CONST else x.op)
      return None
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), track_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, bottom_up=True, walk=True)
    # bpm fires on each node before children: +, 1, 2
    self.assertEqual(visited, [Ops.ADD, 1, 2])

  def test_walk_bottomup_unmatched_falls_through_to_children(self):
    """Bottom-up walk: if bpm doesn't match a node, its children are still processed."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    # only a is in dvars, not a+4. bpm won't match a+4, so it descends and finds a.
    ret = graph_rewrite((a + 4) + (b + 5), _substitute, {a:c, b:c}, bottom_up=True, walk=True)
    self.assertIs(ret, (c + 4) + (c + 5))

  # *** bidirectional walk: bpm fires before children, pm fires after rebuild ***

  def test_walk_bidirectional_visit_order(self):
    """Bidirectional walk: bpm fires pre-order, pm fires post-order."""
    visited = []
    def bpm_visit(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "bpm"))
      return None
    def pm_visit(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "pm"))
      return None
    bpm = PatternMatcher([(UPat(GroupOp.All, name="x"), bpm_visit)])
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), pm_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, bpm=bpm, walk=True)
    # bpm fires pre-order, pm fires post-order
    self.assertEqual(visited, [
      (Ops.ADD, "bpm"), (1, "bpm"), (1, "pm"), (2, "bpm"), (2, "pm"), (Ops.ADD, "pm"),
    ])

  def test_walk_bidirectional_bpm_short_circuits(self):
    """If bpm matches, children are skipped and pm never fires on that node."""
    visited = []
    def bpm_match(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "bpm"))
      # rewrite const(1) -> const(10), short-circuiting its subtree
      if x.op is Ops.CONST and x.arg == 1: return x.replace(arg=10)
      return None
    def pm_match(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "pm"))
      return None
    bpm = PatternMatcher([(UPat(GroupOp.All, name="x"), bpm_match)])
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), pm_match)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    ret = graph_rewrite(a + b, pm, ctx=visited, bpm=bpm, walk=True)
    # bpm matches const(1) and short-circuits it, so pm never fires on const(1)
    self.assertNotIn((1, "pm"), visited)
    # but pm still fires on const(2) and the rebuilt ADD
    self.assertIn((2, "pm"), visited)
    self.assertIs(ret, UOp.const(dtypes.int, 10) + b)
# allow running this test file directly
if __name__ == '__main__':
  unittest.main()

View File

@@ -0,0 +1,42 @@
import unittest
from tinygrad.function import function
from tinygrad import Tensor
class TestFunction(unittest.TestCase):
  """Smoke tests for the @function decorator: traced functions must build and realize."""
  def test_simple(self):
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b
    x = Tensor([1,2,3])
    y = Tensor([4,5,6])
    out = f(x, y)
    out.realize()
  def test_implicit(self):
    # a Tensor captured from the enclosing scope becomes an implicit input to the traced function
    captured = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+captured
    x = Tensor([1,2,3])
    y = Tensor([4,5,6])
    out = f(x, y)
    out.realize()
  def test_implicit_2(self):
    # two traced functions, each with its own captured Tensor, realized together
    captured1 = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor:
      return a+b+captured1
    captured2 = Tensor([7,8,10])
    @function
    def g(a:Tensor, b:Tensor) -> Tensor:
      return a+b+captured2
    x = Tensor([1,2,3])
    y = Tensor([4,5,6])
    out_f = f(x, y)
    out_g = g(x, y)
    out_f.realize(out_g)
# allow running this test file directly
if __name__ == '__main__':
  unittest.main()

View File

@@ -4,6 +4,7 @@ if int(os.getenv("TYPED", "0")):
install_import_hook(__name__)
from tinygrad.tensor import Tensor # noqa: F401
from tinygrad.engine.jit import TinyJit # noqa: F401
from tinygrad.function import function # noqa: F401
from tinygrad.uop.ops import UOp
Variable = UOp.variable
from tinygrad.dtype import dtypes # noqa: F401

View File

@@ -105,7 +105,7 @@ def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
if SPEC: type_verify(big_sink, tensor_spec)
# support recursive CALLs
function = graph_rewrite(function, pm_schedule, name="schedule to linear")
function = graph_rewrite(function, pm_schedule, name="inner schedule to linear")
linear = create_schedule(get_kernel_graph(function))
if SCACHE: schedule_cache[function.key] = linear
else:
@@ -121,6 +121,7 @@ def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
# TODO: use walk and avoid the remove tags
linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
return graph_rewrite(linear, _remove_all_tags, name="remove tags")

48
tinygrad/function.py Normal file
View File

@@ -0,0 +1,48 @@
import functools
from typing import Generic, TypeVar, Callable, cast
from dataclasses import dataclass, field
from tinygrad.helpers import Context
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite
from tinygrad.tensor import Tensor
@dataclass
class _ImplicitBufCtx:
  # param-slot offset: slots 0..offset-1 are taken by the explicit Tensor/UOp inputs
  offset: int
  # implicitly-captured BUFFER UOps discovered during the rewrite, in first-seen order
  bufs: list[UOp] = field(default_factory=list)
def _replace_implicit_buffer(ctx:_ImplicitBufCtx, b:UOp):
  """Swap an implicitly-captured BUFFER for a PARAM; its slot is offset + first-seen position."""
  try:
    idx = ctx.bufs.index(b)
  except ValueError:
    # first time we see this buffer: register it at the end of the list
    idx = len(ctx.bufs)
    ctx.bufs.append(b)
  return UOp.param(ctx.offset + idx, b.dtype, b.shape, b._device)
pm_implicit = PatternMatcher([(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), _replace_implicit_buffer)])
ReturnType = TypeVar('ReturnType')

class function(Generic[ReturnType]):
  """Decorator that traces a Tensor-valued function into a reusable CALL UOp.

  The wrapped fxn is run once with device usage disabled; its Tensor/UOp arguments
  are replaced by PARAM placeholders and any implicitly captured BUFFERs become
  extra trailing params, so the traced graph can be re-invoked via UOp.call."""
  def __init__(self, fxn:Callable[..., ReturnType]):
    self.fxn = fxn
  # descriptor protocol: bind obj so decorated methods receive self when called
  def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self
  def __call__(self, *args, **kwargs) -> ReturnType:
    # collect Tensor/UOp inputs in a deterministic order: positionals first, then kwargs sorted by name
    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t).multibase
      for _,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]
    # disable realize/schedule while this is running
    # run it and do surgery later
    with Context(ALLOW_DEVICE_USAGE=0):
      ret = self.fxn(*args, **kwargs)
    assert isinstance(ret, Tensor), "only supports one tensor return for now"
    # replace the known inputs with params (slot i matches input_uops position i)
    subs = {}
    for i,x in enumerate(input_uops):
      # TODO: this can be better
      if x.op is Ops.BIND: subs[x] = UOp.param(i, x.dtype, x._shape, x._device, x._min_max)
      else: subs[x] = UOp.param(i, x.dtype, x._shape, x._device)
    uret = ret.uop.substitute(subs)
    # replace the implicit BUFFER inputs with params using graph_rewrite
    ctx = _ImplicitBufCtx(offset=len(input_uops))
    uret = graph_rewrite(uret, pm_implicit, ctx=ctx)
    # call with the explicit inputs followed by the captured buffers, named after the traced fxn
    return cast(ReturnType, Tensor(uret.call(*input_uops, *ctx.bufs, name=self.fxn.__name__), device=ret.device))

View File

@@ -163,7 +163,8 @@ def assign_multi(dest:UOp, src:UOp):
def passthrough_multi(root:UOp, multi:UOp):
return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
def rewrite_into_call(call:UOp): return call.replace(src=(graph_rewrite(call.src[0], multi_pm),)+call.src[1:]) if should_resolve_call(call) else None
def rewrite_into_call(call:UOp):
return call.replace(src=(graph_rewrite(call.src[0], multi_pm, name="subcall"),)+call.src[1:]) if should_resolve_call(call) else None
# NOTE: this is the same pattern as Ops.UNROLL
multi_pm = PatternMatcher([

View File

@@ -79,7 +79,7 @@ pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: c
def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:
if not should_resolve_call(c): return None
params: list[UOp] = []
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params)
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
params = sorted(params, key=lambda x: x.arg)
args = c.src[1:]
# TODO: this check belongs in spec, not here
@@ -89,7 +89,7 @@ def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:
for i, (p, a) in enumerate(zip(params, args)):
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
return c.src[0].substitute(dict(zip(params, args)))
return c.src[0].substitute(dict(zip(params, args)), walk=True)
earliest_rewrites = mop_cleanup+PatternMatcher([
# resolve calls

View File

@@ -373,11 +373,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return float(self._eval(dtypes.floats, float))
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None):
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None, walk:bool=False):
dvars = {k:v for k,v in dvars.items() if k is not v}
if len(dvars) == 0: return self
with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars, bottom_up=True, name=name)
return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars,
bottom_up=True, walk=walk, name=name)
# NOTE: this is not called by Tensor slice (Tensor handles UOps directly), but satisfies SupportsIndex for type checking
def __index__(self): return self.__int__()
@@ -864,10 +865,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
if name is not None: src += (UOp(Ops.NOOP, arg=name),)
return UOp(Ops.PARAM, dtype, src, arg=slot)
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=()) -> UOp:
def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp:
# TODO: reenable this after ENCDEC is fixed
#assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata))
return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name))
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
@@ -889,9 +890,10 @@ class KernelInfo:
class CallInfo:
grad_fxn: Callable|None = None
metadata: tuple[Metadata, ...] = ()
name: str|None = None
# grad_fxn can't be pickled, but metadata can
def __reduce__(self): return (CallInfo, (None, self.metadata))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata})"
def __reduce__(self): return (CallInfo, (None, self.metadata, self.name))
def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata}, {repr(self.name)})"
def should_resolve_call(c:UOp) -> bool:
# don't resolve real kernel calls, sink or program
@@ -1259,6 +1261,30 @@ class RewriteContext:
ret = self.bpm_cache[x] = unwrap(self.bpm).rewrite(x, self.ctx)
return ret
  def walk_rewrite(self, root:UOp) -> UOp:
    """MLIR-style Walk Pattern Rewrite Driver: single-pass, no re-traversal into rewritten subtrees."""
    # explicit DFS stack of (node, processed): processed=False means descend, True means rebuild
    stack: list[tuple[UOp, bool]] = [(root, False)]
    while stack:
      n, processed = stack.pop()
      # already rewritten via another path (shared node in the DAG) -- nothing to do
      if n in self.replace: continue
      if not processed:
        # bottom-up: try bpm on original node first, if it rewrites, use result as-is (no traversal into replacement)
        if self.bpm is not None and (rewritten:=self.cached_bpm_rewrite(n)) is not None:
          self.replace[n] = rewritten
          continue
        # no rewrite, process children then come back to rebuild
        stack.append((n, True))
        # reversed() so children are popped in source order
        for x in reversed(n.src):
          if x not in self.replace: stack.append((x, False))
      else:
        # rebuild node with rewritten srcs (keep the original object when nothing changed)
        new_src = tuple(self.replace.get(x, x) for x in n.src)
        new_n = UOp(n.op, n.dtype, new_src, n.arg, n.tag) if new_src != n.src else n
        # top-down: try pm on rebuilt node, use result as-is (no re-traversal)
        if self.pm is not None and (rewritten:=self.pm_rewrite(new_n)) is not None: new_n = rewritten
        self.replace[n] = new_n
    return self.replace.get(root, root)
def unified_rewrite(self, root:UOp) -> UOp:
stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again
@@ -1326,9 +1352,9 @@ class RewriteContext:
return self.replace[root]
@profile_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None) -> UOp:
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, walk=False) -> UOp:
rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
return rewrite_ctx.unified_rewrite(sink)
return rewrite_ctx.walk_rewrite(sink) if walk else rewrite_ctx.unified_rewrite(sink)
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)