multi stage graph_rewrite_map (#9803)

* multistage graph_rewrite_map

* s/merge_map/input_map

* build up kernel_map from the tensor_map
This commit is contained in:
qazal
2025-04-09 15:59:45 +08:00
committed by GitHub
parent 57f4bc3fbb
commit 3bd992dc95
3 changed files with 25 additions and 5 deletions

View File

@@ -28,6 +28,22 @@ class TestRewriteMap(unittest.TestCase):
self.assertIs(sub_map[a+b], e)
self.assertIs(sub_map[(a+b)*c], f)
def test_multistage_substitute(self):
a = UOp.variable('a', 0, 10)
b = UOp.variable('b', 0, 10)
c = UOp.variable('c', 0, 10)
d = UOp.variable('d', 0, 10)
sub1 = {a+b:c}
start = (a+b)*c
# stage 1: (a+b)*c -> c*c
sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True)
self.assertIs(sub_map1[(a+b)*c], c*c)
# stage 2: c*c -> d
sub2 = {c*c:d}
sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True)
# (a+b)*c -> c*c -> d
self.assertIs(sub_map2[(a+b)*c], d)
def test_add_zero(self):
# Build a small graph: add(0, add(const=0, const=5))
zero_node = UOp.const(dtypes.int, 0)

View File

@@ -401,14 +401,15 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
# group into kernels
sink = tensor_map[big_sink]
realize_map = group_realizes(sink)
kernel_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True)
sched_sink = kernel_map[sink]
tensor_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True,
input_map=tensor_map)
sched_sink = tensor_map[sink]
type_verify(list(sched_sink.toposort), kernel_spec)
# map tensors to buffer/const, optionally apply a VIEW on top
becomes_map: dict[UOp, UOp] = {}
for k,v in tensor_map.items():
if (kernel:=kernel_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
if k is v: continue
if k.op is Ops.ASSIGN:
becomes_map[k] = k.src[0]

View File

@@ -958,9 +958,12 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)
@track_matches
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> dict[UOp, UOp]:
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False, input_map=None) -> dict[UOp, UOp]:
rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
new_map = {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
if input_map is not None:
for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
return new_map
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x