mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
start function and add walk rewrite (#14992)
* start function and add walk rewrite * work * add function on feed_forward * llm progress * stuff * none of that
This commit is contained in:
@@ -349,5 +349,184 @@ class TestStopEarly(unittest.TestCase):
|
||||
ret = (c+d).substitute({c:cn}, extra_pm=pm_cvisit)
|
||||
assert ret == cn+d
|
||||
|
||||
class TestWalkRewrite(unittest.TestCase):
  """Tests for graph_rewrite with walk=True (MLIR Walk Pattern Rewrite Driver semantics).

  walk=True gives a single-pass traversal that does NOT revisit or re-traverse into rewritten subtrees.
  Supports both top-down (default) and bottom-up (bottom_up=True) modes."""

  # *** top-down walk (default): process children first, then try pm on rebuilt node ***

  def test_walk_topdown_simple_substitute(self):
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b}, walk=True)
    # assertIs relies on UOp hash-consing: equal graphs are the same object
    self.assertIs(ret, b+4)

  def test_walk_topdown_does_not_traverse_into_replacement(self):
    """Top-down walk: replacement subtrees are NOT re-entered."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    d = UOp.variable('d', 0, 10)
    # a is replaced by b+c, but b inside the replacement is NOT further substituted to d
    ret_walk = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, walk=True)
    self.assertIs(ret_walk, (b+c)+4)
    # contrast: greedy bottom_up WOULD replace b inside the replacement
    ret_greedy = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, bottom_up=True)
    self.assertIs(ret_greedy, (d+c)+4)

  def test_walk_topdown_no_fixed_point(self):
    """A bouncing pattern applies once and stops instead of looping."""
    a = UOp.const(dtypes.int, 3)
    pm = PatternMatcher([
      (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
    ])
    # the fixed-point driver ping-pongs 3<->4 forever and raises
    with self.assertRaises(RuntimeError):
      graph_rewrite(a, pm, bottom_up=True)
    # the walk driver applies the first match once and is done
    ret = graph_rewrite(a, pm, walk=True)
    self.assertIs(ret, UOp.const(dtypes.int, 4))

  def test_walk_topdown_rewrites_children(self):
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    ret = graph_rewrite((a + 4) + (b + 5), _substitute, {a:c, b:c}, walk=True)
    self.assertIs(ret, (c + 4) + (c + 5))

  def test_walk_topdown_diamond(self):
    # the same node (a) is reachable via two parents; both uses must be rewritten
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite((a + 4) + (a + 5), _substitute, {a:b}, walk=True)
    self.assertIs(ret, (b + 4) + (b + 5))

  def test_walk_topdown_children_rewritten_before_parent(self):
    """Top-down walk processes children first: child substitution changes the rebuilt parent."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    n1 = a.sin()  # sin(a)
    ret = n1.sin()  # sin(sin(a))
    # sin(a)->sqrt(a) fires first (child), parent rebuilds to sin(sqrt(a)), which doesn't match sin(sin(a)) in dvars
    ret_walk = graph_rewrite(ret, _substitute, {a.sin():a.sqrt(), n1.sin():n1.sqrt()}, walk=True)
    self.assertIs(ret_walk, a.sqrt().sin())

  def test_walk_topdown_self_referential_replacement(self):
    """Replacement containing the replaced node works without infinite recursion."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    ret = graph_rewrite(a.sin() + 4, _substitute, {a.sin(): a.sin().sqrt()}, walk=True)
    self.assertIs(ret, a.sin().sqrt() + 4)

  def test_walk_topdown_visit_order(self):
    """Top-down walk fires pm after children are processed (post-order)."""
    visited = []
    def track_visit(ctx, x):
      ctx.append(x.arg if x.op is Ops.CONST else x.op)
      return None  # returning None means "no rewrite", so traversal continues
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), track_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, walk=True)
    self.assertEqual(visited, [1, 2, Ops.ADD])

  # *** bottom-up walk: try bpm on node first, skip children if it matches ***

  def test_walk_bottomup_simple_substitute(self):
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b}, bottom_up=True, walk=True)
    self.assertIs(ret, b+4)

  def test_walk_bottomup_does_not_traverse_into_replacement(self):
    """Bottom-up walk: replacement subtrees are NOT entered."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    d = UOp.variable('d', 0, 10)
    ret = graph_rewrite(a + 4, _substitute, {a:b+c, b:d}, bottom_up=True, walk=True)
    self.assertIs(ret, (b+c)+4)

  def test_walk_bottomup_parent_match_skips_children(self):
    """Bottom-up walk matches parent first: if it matches, children are never visited."""
    a = UOp.variable('a', 0, 10, dtype=dtypes.float)
    n1 = a.sin()
    ret = n1.sin()  # sin(sin(a))
    # sin(sin(a)) matches n1.sin()->n1.sqrt() immediately, children never visited, sin(a) inside replacement untouched
    ret_walk = graph_rewrite(ret, _substitute, {a.sin():a.sqrt(), n1.sin():n1.sqrt()}, bottom_up=True, walk=True)
    self.assertIs(ret_walk, a.sin().sqrt())

  def test_walk_bottomup_no_fixed_point(self):
    """Bottom-up walk also applies once per node, no fixed-point iteration."""
    a = UOp.const(dtypes.int, 3)
    pm = PatternMatcher([
      (UPat(Ops.CONST, arg=3, name="x"), lambda x: x.replace(arg=4)),
      (UPat(Ops.CONST, arg=4, name="x"), lambda x: x.replace(arg=3)),
    ])
    ret = graph_rewrite(a, pm, bottom_up=True, walk=True)
    self.assertIs(ret, UOp.const(dtypes.int, 4))

  def test_walk_bottomup_visit_order(self):
    """Bottom-up walk fires bpm before descending (pre-order)."""
    visited = []
    def track_visit(ctx, x):
      ctx.append(x.arg if x.op is Ops.CONST else x.op)
      return None
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), track_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, bottom_up=True, walk=True)
    # bpm fires on each node before children: +, 1, 2
    self.assertEqual(visited, [Ops.ADD, 1, 2])

  def test_walk_bottomup_unmatched_falls_through_to_children(self):
    """Bottom-up walk: if bpm doesn't match a node, its children are still processed."""
    a = UOp.variable('a', 0, 10)
    b = UOp.variable('b', 0, 10)
    c = UOp.variable('c', 0, 10)
    # only a is in dvars, not a+4. bpm won't match a+4, so it descends and finds a.
    ret = graph_rewrite((a + 4) + (b + 5), _substitute, {a:c, b:c}, bottom_up=True, walk=True)
    self.assertIs(ret, (c + 4) + (c + 5))

  # *** bidirectional walk: bpm fires before children, pm fires after rebuild ***

  def test_walk_bidirectional_visit_order(self):
    """Bidirectional walk: bpm fires pre-order, pm fires post-order."""
    visited = []
    def bpm_visit(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "bpm"))
      return None
    def pm_visit(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "pm"))
      return None
    bpm = PatternMatcher([(UPat(GroupOp.All, name="x"), bpm_visit)])
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), pm_visit)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    graph_rewrite(a + b, pm, ctx=visited, bpm=bpm, walk=True)
    # bpm fires pre-order, pm fires post-order
    self.assertEqual(visited, [
      (Ops.ADD, "bpm"), (1, "bpm"), (1, "pm"), (2, "bpm"), (2, "pm"), (Ops.ADD, "pm"),
    ])

  def test_walk_bidirectional_bpm_short_circuits(self):
    """If bpm matches, children are skipped and pm never fires on that node."""
    visited = []
    def bpm_match(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "bpm"))
      # rewrite const(1) -> const(10), short-circuiting its subtree
      if x.op is Ops.CONST and x.arg == 1: return x.replace(arg=10)
      return None
    def pm_match(ctx, x):
      ctx.append((x.arg if x.op is Ops.CONST else x.op, "pm"))
      return None
    bpm = PatternMatcher([(UPat(GroupOp.All, name="x"), bpm_match)])
    pm = PatternMatcher([(UPat(GroupOp.All, name="x"), pm_match)])
    a = UOp.const(dtypes.int, 1)
    b = UOp.const(dtypes.int, 2)
    ret = graph_rewrite(a + b, pm, ctx=visited, bpm=bpm, walk=True)
    # bpm matches const(1) and short-circuits it, so pm never fires on const(1)
    self.assertNotIn((1, "pm"), visited)
    # but pm still fires on const(2) and the rebuilt ADD
    self.assertIn((2, "pm"), visited)
    self.assertIs(ret, UOp.const(dtypes.int, 10) + b)
||||
# run the test suite when this file is executed directly
if __name__ == '__main__':
  unittest.main()
||||
42
test/unit/test_function.py
Normal file
42
test/unit/test_function.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import unittest
|
||||
from tinygrad.function import function
|
||||
from tinygrad import Tensor
|
||||
|
||||
class TestFunction(unittest.TestCase):
  """Tests for the @function decorator (tinygrad.function.function)."""

  def test_simple(self):
    # all inputs passed explicitly
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    c.realize()

  def test_implicit(self):
    # inp is captured from the closure, not passed as an argument
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor: return a+b+inp

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    c.realize()

  def test_implicit_2(self):
    # two functions, each with its own implicit capture, realized together
    inp = Tensor([7,8,9])
    @function
    def f(a:Tensor, b:Tensor) -> Tensor:
      return a+b+inp
    inp2 = Tensor([7,8,10])
    @function
    def g(a:Tensor, b:Tensor) -> Tensor:
      return a+b+inp2

    a = Tensor([1,2,3])
    b = Tensor([4,5,6])
    c = f(a,b)
    d = g(a,b)
    c.realize(d)
|
||||
# run the test suite when this file is executed directly
if __name__ == '__main__':
  unittest.main()
|
||||
@@ -4,6 +4,7 @@ if int(os.getenv("TYPED", "0")):
|
||||
install_import_hook(__name__)
|
||||
from tinygrad.tensor import Tensor # noqa: F401
|
||||
from tinygrad.engine.jit import TinyJit # noqa: F401
|
||||
from tinygrad.function import function # noqa: F401
|
||||
from tinygrad.uop.ops import UOp
|
||||
Variable = UOp.variable
|
||||
from tinygrad.dtype import dtypes # noqa: F401
|
||||
|
||||
@@ -105,7 +105,7 @@ def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(function.key, None)) is None:
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
# support recursive CALLs
|
||||
function = graph_rewrite(function, pm_schedule, name="schedule to linear")
|
||||
function = graph_rewrite(function, pm_schedule, name="inner schedule to linear")
|
||||
linear = create_schedule(get_kernel_graph(function))
|
||||
if SCACHE: schedule_cache[function.key] = linear
|
||||
else:
|
||||
@@ -121,6 +121,7 @@ def lower_schedule_to_linear(big_sink:UOp) -> UOp|None:
|
||||
print(f"scheduled {len(linear.src):5d} kernels in {(time.perf_counter()-st)*1000:8.2f} ms"+\
|
||||
f" | {' cache hit' if SCACHE and sc_ret is not None else 'CACHE MISS'} {function.key.hex()[:8]}"+\
|
||||
f" | {len(UOpMetaClass.ucache):7d} uops in cache"+("" if frm is None else f" | {frm.filename}:{frm.lineno}"))
|
||||
# TODO: use walk and avoid the remove tags
|
||||
linear = graph_rewrite(linear, pm_post_sched_cache, ctx=({}, big_sink.src[1:]), name="params to buffers")
|
||||
return graph_rewrite(linear, _remove_all_tags, name="remove tags")
|
||||
|
||||
|
||||
48
tinygrad/function.py
Normal file
48
tinygrad/function.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import functools
|
||||
from typing import Generic, TypeVar, Callable, cast
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.helpers import Context
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
@dataclass
class _ImplicitBufCtx:
  """Rewrite context that assigns PARAM slots to implicitly-captured BUFFERs."""
  # first free PARAM slot; implicit buffers are numbered after the explicit inputs
  offset: int
  # BUFFER UOps discovered during the rewrite, in first-seen order (index = slot - offset)
  bufs: list[UOp] = field(default_factory=list)
||||
def _replace_implicit_buffer(ctx:_ImplicitBufCtx, b:UOp):
  """Replace a captured BUFFER with a PARAM; each distinct buffer gets exactly one slot."""
  if b not in ctx.bufs: ctx.bufs.append(b)
  return UOp.param(ctx.offset + ctx.bufs.index(b), b.dtype, b.shape, b._device)

# matches a raw device buffer: BUFFER whose srcs are exactly (UNIQUE, DEVICE)
pm_implicit = PatternMatcher([(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), name="b"), _replace_implicit_buffer)])
|
||||
ReturnType = TypeVar('ReturnType')
class function(Generic[ReturnType]):
  """Decorator that traces a Tensor-valued function once and packages the traced graph as a CALL UOp.

  Explicit Tensor/UOp arguments become PARAMs 0..n-1 (positional args first, then kwargs sorted
  by name); Tensors captured implicitly from the closure become extra PARAMs after them.
  """
  def __init__(self, fxn:Callable[..., ReturnType]):
    self.fxn = fxn

  # descriptor protocol so @function works on methods: bind obj as the first argument
  def __get__(self, obj, objtype=None): return functools.partial(self.__call__, obj) if obj is not None else self

  def __call__(self, *args, **kwargs) -> ReturnType:
    # collect the Tensor/UOp inputs in a deterministic order: positional, then kwargs sorted by name
    # (was `for name,t`: the first tuple element is never used, so bind it to _)
    input_uops: list[UOp] = [(t.uop if isinstance(t, Tensor) else t).multibase
                             for _,t in list(enumerate(args))+sorted(kwargs.items()) if isinstance(t, (Tensor, UOp))]

    # disable realize/schedule while this is running
    # run it and do surgery later
    with Context(ALLOW_DEVICE_USAGE=0):
      ret = self.fxn(*args, **kwargs)
      assert isinstance(ret, Tensor), "only supports one tensor return for now"

    # replace the known inputs with params
    subs = {}
    for i,x in enumerate(input_uops):
      # TODO: this can be better
      if x.op is Ops.BIND: subs[x] = UOp.param(i, x.dtype, x._shape, x._device, x._min_max)
      else: subs[x] = UOp.param(i, x.dtype, x._shape, x._device)
    uret = ret.uop.substitute(subs)

    # replace the implicit BUFFER inputs with params using graph_rewrite
    ctx = _ImplicitBufCtx(offset=len(input_uops))
    uret = graph_rewrite(uret, pm_implicit, ctx=ctx)

    # CALL takes the traced body plus all inputs (explicit, then implicit buffers)
    return cast(ReturnType, Tensor(uret.call(*input_uops, *ctx.bufs, name=self.fxn.__name__), device=ret.device))
||||
@@ -163,7 +163,8 @@ def assign_multi(dest:UOp, src:UOp):
|
||||
def passthrough_multi(root:UOp, multi:UOp):
  # rebuild root on multi's base src, mapping any MULTI in the remaining srcs to its base,
  # then re-wrap the result on the same axis (pushes the op below the MULTI)
  return UOp(root.op, root.dtype, (multi.src[0],)+tuple(x.src[0] if x.op is Ops.MULTI else x for x in root.src[1:]), root.arg).multi(multi.axis)
|
||||
def rewrite_into_call(call:UOp):
  # rewrite only the callee subgraph (src[0]) with multi_pm; the call arguments are untouched
  return call.replace(src=(graph_rewrite(call.src[0], multi_pm, name="subcall"),)+call.src[1:]) if should_resolve_call(call) else None
|
||||
# NOTE: this is the same pattern as Ops.UNROLL
|
||||
multi_pm = PatternMatcher([
|
||||
|
||||
@@ -79,7 +79,7 @@ pm_gather_params = PatternMatcher([ (UPat(Ops.PARAM, name="p"), lambda ctx, p: c
|
||||
def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:
|
||||
if not should_resolve_call(c): return None
|
||||
params: list[UOp] = []
|
||||
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params)
|
||||
graph_rewrite(c.src[0], pm_gather_params, bottom_up=True, ctx=params, name="gather params")
|
||||
params = sorted(params, key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
@@ -89,7 +89,7 @@ def resolve_call(c:UOp, allow_param_mismatch=False) -> UOp|None:
|
||||
for i, (p, a) in enumerate(zip(params, args)):
|
||||
if p.shape != a.shape: raise TypeError(f"arg {i} shape mismatch: expected {p.shape}, got {a.shape}")
|
||||
if p.dtype != a.dtype: raise TypeError(f"arg {i} dtype mismatch: expected {p.dtype}, got {a.dtype}")
|
||||
return c.src[0].substitute(dict(zip(params, args)))
|
||||
return c.src[0].substitute(dict(zip(params, args)), walk=True)
|
||||
|
||||
earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# resolve calls
|
||||
|
||||
@@ -373,11 +373,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def __bool__(self): return self._eval((dtypes.bool,), bool)
|
||||
def __int__(self): return self._eval(dtypes.ints, int)
|
||||
def __float__(self): return float(self._eval(dtypes.floats, float))
|
||||
  def substitute(self, dvars:dict[UOp, UOp], name:str|None=None, extra_pm:PatternMatcher|None=None, walk:bool=False):
    """Replace occurrences of the keys of dvars in this graph with their values.

    walk=True uses the single-pass walk driver (no re-traversal into replacements)."""
    # identity mappings are no-ops; drop them so an all-identity dict exits early
    dvars = {k:v for k,v in dvars.items() if k is not v}
    if len(dvars) == 0: return self
    # only track match stats when a name was given
    with Context(TRACK_MATCH_STATS=(0 if name is None else TRACK_MATCH_STATS.value)):
      return graph_rewrite(self, (extra_pm+_substitute) if extra_pm is not None else _substitute, dvars,
                           bottom_up=True, walk=walk, name=name)
||||
# NOTE: this is not called by Tensor slice (Tensor handles UOps directly), but satisfies SupportsIndex for type checking
|
||||
def __index__(self): return self.__int__()
|
||||
|
||||
@@ -864,10 +865,10 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
if name is not None: src += (UOp(Ops.NOOP, arg=name),)
|
||||
return UOp(Ops.PARAM, dtype, src, arg=slot)
|
||||
|
||||
  def call(self, *srcs:UOp, grad_fxn:Callable|None=None, metadata:tuple[Metadata, ...]=(), name:str|None=None) -> UOp:
    """Wrap this UOp as the body of a CALL with srcs as its arguments; name labels the call."""
    # TODO: reenable this after ENCDEC is fixed
    #assert len(self.ranges) == 0, f"ranges {self.ranges} are leaking out of the call in {self.pyrender()}"
    return UOp(Ops.CALL, self.dtype, (self,)+srcs, CallInfo(grad_fxn, metadata, name))
||||
def custom_kernel(*srcs:UOp, fxn:Callable, grad_fxn:Callable|None=None) -> list[UOp]:
|
||||
contig_srcs = tuple(x.contiguous() if x.op is not Ops.AFTER else x for x in srcs)
|
||||
placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(contig_srcs)]
|
||||
@@ -889,9 +890,10 @@ class KernelInfo:
|
||||
class CallInfo:
  # backward function for this call, if any (not picklable)
  grad_fxn: Callable|None = None
  # metadata captured at trace time
  metadata: tuple[Metadata, ...] = ()
  # optional human-readable name of the called function
  name: str|None = None
  # grad_fxn can't be pickled, but metadata can
  def __reduce__(self): return (CallInfo, (None, self.metadata, self.name))
  def __repr__(self): return f"CallInfo({id(self.grad_fxn) if self.grad_fxn else None}, {self.metadata}, {repr(self.name)})"
||||
|
||||
def should_resolve_call(c:UOp) -> bool:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
@@ -1259,6 +1261,30 @@ class RewriteContext:
|
||||
ret = self.bpm_cache[x] = unwrap(self.bpm).rewrite(x, self.ctx)
|
||||
return ret
|
||||
|
||||
  def walk_rewrite(self, root:UOp) -> UOp:
    """MLIR-style Walk Pattern Rewrite Driver: single-pass, no re-traversal into rewritten subtrees."""
    # explicit stack of (node, processed): processed=False is the first visit, True means
    # "all children are in self.replace, rebuild now" (iterative post-order)
    stack: list[tuple[UOp, bool]] = [(root, False)]
    while stack:
      n, processed = stack.pop()
      # shared node in the DAG already rewritten via another parent
      if n in self.replace: continue
      if not processed:
        # bottom-up: try bpm on original node first, if it rewrites, use result as-is (no traversal into replacement)
        if self.bpm is not None and (rewritten:=self.cached_bpm_rewrite(n)) is not None:
          self.replace[n] = rewritten
          continue
        # no rewrite, process children then come back to rebuild
        stack.append((n, True))
        # reversed so children are popped (and thus visited) in src order
        for x in reversed(n.src):
          if x not in self.replace: stack.append((x, False))
      else:
        # rebuild node with rewritten srcs
        new_src = tuple(self.replace.get(x, x) for x in n.src)
        new_n = UOp(n.op, n.dtype, new_src, n.arg, n.tag) if new_src != n.src else n
        # top-down: try pm on rebuilt node, use result as-is (no re-traversal)
        if self.pm is not None and (rewritten:=self.pm_rewrite(new_n)) is not None: new_n = rewritten
        self.replace[n] = new_n
    return self.replace.get(root, root)
||||
|
||||
def unified_rewrite(self, root:UOp) -> UOp:
|
||||
stack: collections.deque[tuple[UOp, int, UOp]] = collections.deque([(root, 0, root)])
|
||||
on_stack = {root} # all UOps either on the stack or in self.replace, i.e. dont have to be placed again
|
||||
@@ -1326,9 +1352,9 @@ class RewriteContext:
|
||||
return self.replace[root]
|
||||
|
||||
@profile_matches
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None, walk=False) -> UOp:
  """Rewrite the graph rooted at sink with pm. walk=True selects the single-pass walk driver;
  otherwise the fixed-point unified driver is used."""
  # bottom_up routes pm into the bpm slot (pre-order matching), else pm is post-order and bpm optional
  rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
  return rewrite_ctx.walk_rewrite(sink) if walk else rewrite_ctx.unified_rewrite(sink)
||||
|
||||
def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user