class TestWalkRewrite(unittest.TestCase):
  """Tests for graph_rewrite with walk=True (MLIR Walk Pattern Rewrite Driver semantics).
  walk=True gives a single-pass traversal that does NOT revisit or re-traverse into rewritten subtrees.
  Supports both top-down (default) and bottom-up (bottom_up=True) modes."""

  # *** top-down walk (default): process children first, then try pm on rebuilt node ***

  def test_walk_topdown_simple_substitute(self):
    """A single variable substitution is applied and the parent is rebuilt around it."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b}, walk=True)
    self.assertIs(ret, b+4)

  def test_walk_topdown_does_not_traverse_into_replacement(self):
    """Top-down walk: replacement subtrees are NOT re-entered."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    d = UOp.variable('d', 0, 10)
    # a is replaced by b+c, but b inside the replacement is NOT further substituted to d
    ret_walk = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, walk=True)
    self.assertIs(ret_walk, (b+c)+4)
    # contrast: greedy bottom_up WOULD replace b inside the replacement
    ret_greedy = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, bottom_up=True)
    self.assertIs(ret_greedy, (d+c)+4)

  def test_walk_topdown_no_fixed_point(self):
    """A bouncing pattern applies once and stops instead of looping."""
    a = UOp.const(dtypes.int, 3)
    pm = PatternMatcher([
      (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
    ])
    # the non-walk driver keeps bouncing 3 -> 4 -> 3 and is expected to raise
    with self.assertRaises(RuntimeError):
      graph_rewrite(a, pm, bottom_up=True)
    ret = graph_rewrite(a, pm, walk=True)
    self.assertIs(ret, UOp.const(dtypes.int, 4))

  def test_walk_topdown_rewrites_children(self):
    """Substitutions in separate subtrees are both applied."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    ret = graph_rewrite((a + 4) + (b + 5), _substitute, {a:c, b:c}, walk=True)
    self.assertIs(ret, (c + 4) + (c + 5))

  def test_walk_topdown_diamond(self):
    """A node shared by two parents is rewritten consistently for both."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite((a + 4) + (a + 5), _substitute, {a:b}, walk=True)
    self.assertIs(ret, (b + 4) + (b + 5))

  def test_walk_topdown_children_rewritten_before_parent(self):
    """Top-down walk processes children first: child substitution changes the rebuilt parent."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    n1 = a.sin()    # sin(a)
    ret = n1.sin()  # sin(sin(a))
    # sin(a)->sqrt(a) fires first (child), parent rebuilds to sin(sqrt(a)), which doesn't match sin(sin(a)) in dvars
    ret_walk = graph_rewrite(ret, _substitute, {a.sin():a.sqrt(), n1.sin():n1.sqrt()}, walk=True)
    self.assertIs(ret_walk, a.sqrt().sin())

  def test_walk_topdown_self_referential_replacement(self):
    """Replacement containing the replaced node works without infinite recursion."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    ret = graph_rewrite(a.sin() + 4, _substitute, {a.sin(): a.sin().sqrt()}, walk=True)
    self.assertIs(ret, a.sin().sqrt() + 4)

  def test_walk_topdown_visit_order(self):
    """Top-down walk fires pm after children are processed (post-order)."""
    visited = []
    def track_visit(ctx, x):
      # record the visit and return None so nothing is rewritten
      ctx.append(x.arg if x.op is Ops.CONST else x.op)
      return None
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), track_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, walk=True)
    self.assertEqual(visited, [1, 2, Ops.ADD])

  # *** bottom-up walk: try bpm on node first, skip children if it matches ***

  def test_walk_bottomup_simple_substitute(self):
    """A single variable substitution also works in bottom-up walk mode."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b}, bottom_up=True, walk=True)
    self.assertIs(ret, b+4)

  def test_walk_bottomup_does_not_traverse_into_replacement(self):
    """Bottom-up walk: replacement subtrees are NOT entered."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    d = UOp.variable('d', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, bottom_up=True, walk=True)
    self.assertIs(ret, (b+c)+4)

  def test_walk_bottomup_parent_match_skips_children(self):
    """Bottom-up walk matches parent first: if it matches, children are never visited."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    n1 = a.sin()
    ret = n1.sin()  # sin(sin(a))
    # sin(sin(a)) matches n1.sin()->n1.sqrt() immediately, children never visited, sin(a) inside replacement untouched
    ret_walk = graph_rewrite(ret, _substitute, {a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True, walk=True)
    self.assertIs(ret_walk, a.sin().sqrt())

  def test_walk_bottomup_no_fixed_point(self):
    """Bottom-up walk also applies once per node, no fixed-point iteration."""
    a = UOp.const(dtypes.int, 3)
    pm = PatternMatcher([
      (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
    ])
    ret = graph_rewrite(a, pm, bottom_up=True, walk=True)
    self.assertIs(ret, UOp.const(dtypes.int, 4))

  def test_walk_bottomup_visit_order(self):
    """Bottom-up walk fires bpm before descending (pre-order)."""
    visited = []
    def track_visit(ctx, x):
      # record the visit and return None so nothing is rewritten
      ctx.append(x.arg if x.op is Ops.CONST else x.op)
      return None
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), track_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, bottom_up=True, walk=True)
    # bpm fires on each node before children: +, 1, 2
    self.assertEqual(visited, [Ops.ADD, 1, 2])

  def test_walk_bottomup_unmatched_falls_through_to_children(self):
    """Bottom-up walk: if bpm doesn't match a node, its children are still processed."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    # only a is in dvars, not a+4. bpm won't match a+4, so it descends and finds a.
    ret = graph_rewrite((a + 4) + (b + 5), _substitute, {a:c, b:c}, bottom_up=True, walk=True)
    self.assertIs(ret, (c + 4) + (c + 5))

  # *** bidirectional walk: bpm fires before children, pm fires after rebuild ***

  def test_walk_bidirectional_visit_order(self):
    """Bidirectional walk: bpm fires pre-order, pm fires post-order."""
    visited = []
    def bpm_visit(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "bpm"))
      return None
    def pm_visit(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "pm"))
      return None
    bpm = PatternMatcher([(UPat(GroupOp.All, name="x"), bpm_visit)])
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), pm_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, bpm=bpm, walk=True)
    # bpm fires pre-order, pm fires post-order
    self.assertEqual(visited, [
      (Ops.ADD, "bpm"), (1, "bpm"), (1, "pm"), (2, "bpm"), (2, "pm"), (Ops.ADD, "pm"),
    ])

  def test_walk_bidirectional_bpm_short_circuits(self):
    """If bpm matches, children are skipped and pm never fires on that node."""
    visited = []
    def bpm_match(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "bpm"))
      # rewrite const(1) -> const(10), short-circuiting its subtree
      if x.op is Ops.CONST and x.arg == 1: return x.replace(arg=10)
      return None
    def pm_match(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "pm"))
      return None
    bpm = PatternMatcher([(UPat(GroupOp.All, name="x"), bpm_match)])
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), pm_match)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    ret = graph_rewrite(a + b, pm, ctx=visited, bpm=bpm, walk=True)
    # bpm matches const(1) and short-circuits it, so pm never fires on const(1)
    self.assertNotIn((1, "pm"), visited)
    # but pm still fires on const(2) and the rebuilt ADD
    self.assertIn((2, "pm"), visited)
    self.assertIn((Ops.ADD, "pm"), visited)
    self.assertIs(ret, UOp.const(dtypes.int, 10) + b)
class TestFunction(unittest.TestCase):
  """Tests for the @function decorator: explicit Tensor arguments and implicitly captured Tensors."""

  def test_simple(self):
    """Only explicit arguments are used by the wrapped function."""
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    c.realize()
    # check the value too, realize() alone would pass even if the call computed the wrong thing
    self.assertEqual(c.tolist(), [5, 7, 9])

  def test_implicit(self):
    """The wrapped function also reads a Tensor captured from the enclosing scope."""
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    c.realize()
    self.assertEqual(c.tolist(), [12, 15, 18])

  def test_implicit_2(self):
    """Two wrapped functions with different captured Tensors are realized together."""
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor:
      return a+b+inp
    inp2 = Tensor([7,8,10])
    @function
    def g(a:Tensor, b:Tensor) -> Tensor:
      return a+b+inp2

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    d = g(a,b)
    c.realize(d)
    # each call must see its own captured tensor: the results differ only in the last element
    self.assertEqual(c.tolist(), [12, 15, 18])
    self.assertEqual(d.tolist(), [12, 15, 19])
@dataclass
class _ImplicitBufCtx:
  # rewrite context that collects the BUFFER UOps a traced function uses without receiving them as arguments
  # offset: added to a buffer's position in bufs to get the slot number of the PARAM that replaces it
  offset: int
  # bufs: the collected BUFFER UOps in first-seen order
  bufs: list[UOp] = field(default_factory=list)
  # _pos: BUFFER UOp -> its position in bufs, kept in sync so a lookup does not have to scan the list
  _pos: dict[UOp, int] = field(default_factory=dict, init=False, repr=False, compare=False)

  def __post_init__(self):
    # a caller may pass a pre-filled bufs list; the first occurrence wins, like list.index
    for i,b in enumerate(self.bufs): self._pos.setdefault(b, i)

def _replace_implicit_buffer(ctx:_ImplicitBufCtx, b:UOp):
  """Replace BUFFER `b` with a PARAM whose slot is ctx.offset plus b's position in ctx.bufs.

  The first time a buffer is seen it is appended to ctx.bufs; later sightings reuse the same position.
  """
  pos = ctx._pos.get(b)
  if pos is None:
    pos = ctx._pos[b] = len(ctx.bufs)
    ctx.bufs.append(b)
  return UOp.param(ctx.offset + pos, b.dtype, b.shape, b._device)

# matches a BUFFER whose sources are exactly (UNIQUE, DEVICE) and swaps it for a PARAM
pm_implicit = PatternMatcher([(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), _replace_implicit_buffer)])

ReturnType = TypeVar('ReturnType')
class function(Generic[ReturnType]):
  """Decorator that runs a Tensor-returning callable once and wraps its result graph in a single CALL UOp.

  Tensor/UOp arguments become numbered PARAMs of the call. BUFFERs the callable reaches without receiving
  them as arguments are turned into additional PARAMs numbered after the explicit ones.
  """
  def __init__(self, fxn:Callable[..., ReturnType]):
    # copy __name__/__doc__/__wrapped__ from fxn so the decorated object still introspects like the original
    functools.update_wrapper(self, fxn)
    self.fxn = fxn

  # descriptor hook: when used on a method, bind the instance as the first argument
  def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self

  def __call__(self, *args, **kwargs) -> ReturnType:
    # explicit inputs: positional arguments in order, then keyword arguments sorted by name; non-Tensor/UOp values are skipped
    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t).multibase
                             for _,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]

    # disable realize/schedule while this is running
    # run it and do surgery later
    with Context(ALLOW_DEVICE_USAGE=0):
      ret = self.fxn(*args, **kwargs)
    assert isinstance(ret, Tensor), "only supports one tensor return for now"

    # replace the known inputs with params
    subs = {}
    for i,x in enumerate(input_uops):
      # TODO: this can be better
      if x.op is Ops.BIND: subs[x] = UOp.param(i, x.dtype, x._shape, x._device, x._min_max)
      else: subs[x] = UOp.param(i, x.dtype, x._shape, x._device)
    uret = ret.uop.substitute(subs)

    # replace the implicit BUFFER inputs with params using graph_rewrite
    ctx = _ImplicitBufCtx(offset=len(input_uops))
    uret = graph_rewrite(uret, pm_implicit, ctx=ctx)

    # callables such as functools.partial objects have no __name__; the call is left unnamed in that case
    call_name = getattr(self.fxn, "__name__", None)
    return cast(ReturnType, Tensor(uret.call(*input_uops, *ctx.bufs, name=call_name), device=ret.device))
def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None, walk:bool=False):
  """Return this graph with every key of `dvars` replaced by its value.

  Identity entries (key is value) are ignored; if nothing is left, self is returned and no rewrite runs.
  `extra_pm`, when given, is placed in front of the substitution matcher. `walk` is forwarded to graph_rewrite.
  Match-stat tracking is forced to 0 for unnamed substitutions and left at its current value for named ones.
  """
  mapping = {old:new for old,new in dvars.items() if old is not new}
  if not mapping: return self
  with Context(TRACK_MATCH_STATS=(TRACK_MATCH_STATS.value if name is not None else 0)):
    matcher = _substitute if extra_pm is None else (extra_pm+_substitute)
    return graph_rewrite(self, matcher, mapping, bottom_up=True, walk=walk, name=name)

def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp:
  """Build a CALL UOp with this UOp as the first source followed by `srcs`.

  grad_fxn, metadata and name are stored in the CALL's CallInfo argument.
  """
  # TODO: reenable this after ENCDEC is fixed
  #assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
  info = CallInfo(grad_fxn, metadata, name)
  return UOp(Ops.CALL, self.dtype, (self, *srcs), info)
def walk_rewrite(self, root:UOp) -> UOp:
  """Single-pass rewrite in the style of MLIR's Walk Pattern Rewrite Driver.

  Every node is handled once and a replacement is never traversed again. bpm (if set) is tried on the
  original node before its sources; when it rewrites, the result is stored as-is and the sources are not
  visited. Otherwise the sources are handled first, the node is rebuilt if any source changed, and pm
  (if set) is tried once on the rebuilt node.
  """
  pending: list[tuple[UOp, bool]] = [(root, False)]
  while pending:
    node, srcs_done = pending.pop()
    # shared nodes can be queued more than once; only the first visit does any work
    if node in self.replace: continue

    if srcs_done:
      # second visit: all sources have an entry in self.replace (or were left untouched)
      srcs = tuple(self.replace.get(s, s) for s in node.src)
      rebuilt = node if srcs == node.src else UOp(node.op, node.dtype, srcs, node.arg, node.tag)
      # post-order matcher gets one shot at the rebuilt node, its output is final
      if self.pm is not None:
        out = self.pm_rewrite(rebuilt)
        if out is not None: rebuilt = out
      self.replace[node] = rebuilt
      continue

    # first visit: pre-order matcher sees the original node, a hit short-circuits the whole subtree
    if self.bpm is not None:
      out = self.cached_bpm_rewrite(node)
      if out is not None:
        self.replace[node] = out
        continue

    # no pre-order hit: come back to this node after its sources (pushed reversed so src[0] pops first)
    pending.append((node, True))
    pending.extend((s, False) for s in reversed(node.src) if s not in self.replace)
  return self.replace.get(root, root)
@profile_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, walk=False) -> UOp:
  """Rewrite the graph rooted at `sink` with `pm` and return the new root.

  With bottom_up=False, `pm` is passed as the first matcher of the RewriteContext and `bpm` as the second.
  With bottom_up=True, `pm` is passed as the second matcher, the first is None, and `bpm` is not used.
  walk=True selects the single-pass walk_rewrite driver, otherwise unified_rewrite runs.
  `name` is not read in this body (presumably consumed by the profile_matches decorator, TODO confirm).
  """
  first_pm, second_pm = (None, pm) if bottom_up else (pm, bpm)
  rewriter = RewriteContext(first_pm, second_pm, ctx)
  if walk: return rewriter.walk_rewrite(sink)
  return rewriter.unified_rewrite(sink)