Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)

multi stage graph_rewrite_map (#9803)

* multistage graph_rewrite_map
* s/merge_map/input_map
* build up kernel_map from the tensor_map
This commit is contained in:
@@ -28,6 +28,22 @@ class TestRewriteMap(unittest.TestCase):
|
||||
self.assertIs(sub_map[a+b], e)
|
||||
self.assertIs(sub_map[(a+b)*c], f)
|
||||
|
||||
def test_multistage_substitute(self):
  """Two substitution passes chained via input_map compose into one original->final map."""
  a = UOp.variable('a', 0, 10)
  b = UOp.variable('b', 0, 10)
  c = UOp.variable('c', 0, 10)
  d = UOp.variable('d', 0, 10)
  start = (a+b)*c
  # stage 1: substituting a+b -> c turns (a+b)*c into c*c
  stage1 = graph_rewrite_map(start, _substitute, {a+b:c}, bottom_up=True)
  self.assertIs(stage1[(a+b)*c], c*c)
  # stage 2: substituting c*c -> d, with stage1 passed as input_map so the maps compose
  stage2 = graph_rewrite_map(stage1[start], _substitute, {c*c:d}, input_map=stage1, bottom_up=True)
  # the composed map carries the original (a+b)*c all the way through c*c to d
  self.assertIs(stage2[(a+b)*c], d)
|
||||
|
||||
def test_add_zero(self):
|
||||
# Build a small graph: add(0, add(const=0, const=5))
|
||||
zero_node = UOp.const(dtypes.int, 0)
|
||||
|
||||
@@ -401,14 +401,15 @@ def get_becomes_map(big_sink:UOp) -> tuple[dict[UOp, UOp], dict[Variable, int]]:
|
||||
# group into kernels
|
||||
sink = tensor_map[big_sink]
|
||||
realize_map = group_realizes(sink)
|
||||
kernel_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True)
|
||||
sched_sink = kernel_map[sink]
|
||||
tensor_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}), bottom_up=True,
|
||||
input_map=tensor_map)
|
||||
sched_sink = tensor_map[sink]
|
||||
type_verify(list(sched_sink.toposort), kernel_spec)
|
||||
|
||||
# map tensors to buffer/const, optionally apply a VIEW on top
|
||||
becomes_map: dict[UOp, UOp] = {}
|
||||
for k,v in tensor_map.items():
|
||||
if (kernel:=kernel_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
|
||||
if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
|
||||
if k is v: continue
|
||||
if k.op is Ops.ASSIGN:
|
||||
becomes_map[k] = k.src[0]
|
||||
|
||||
@@ -958,9 +958,12 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
|
||||
return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)
|
||||
|
||||
@track_matches
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False, input_map=None) -> dict[UOp, UOp]:
  """Rewrite every UOp reachable from sink with pm and return an {original: rewritten} map.

  If input_map (an earlier stage's {original: rewritten} map) is given, its entries are
  chained through this stage's results so the returned map goes from the earliest
  originals straight to the final rewritten UOps.
  """
  rctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
  rewrite = rctx.bottom_up_rewrite if bottom_up else rctx.top_down_rewrite
  new_map: dict[UOp, UOp] = {}
  # walk the graph leaves-first (reverse toposort) and record each node's rewrite
  for node in reversed(list(sink.toposort)):
    new_map[node] = rewrite(node)
  if input_map is not None:
    # compose: src -> mid (previous stage) -> new_map[mid] (this stage); mid itself if untouched here
    for src, mid in input_map.items():
      new_map[src] = new_map.get(mid, mid)
  return new_map
|
||||
|
||||
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
|
||||
|
||||
|
||||
Reference in New Issue
Block a user