Revert "split grouper into insert and finalize stages [pr] (#10222)" (#10224)

This reverts commit 2594e4db15.
This commit is contained in:
qazal
2025-05-09 03:02:38 +03:00
committed by GitHub
parent 2594e4db15
commit b6904bbf83
2 changed files with 22 additions and 18 deletions

View File

@@ -15,7 +15,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite,
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map, Kernel, finalize_kernels, merge_views
from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map, Kernel, create_ast, merge_views
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@@ -605,7 +605,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.zeros(1, dtype=dtypes.int).contiguous().realize().lazydata
c = Tensor.arange(4).realize().lazydata
kernel = UOp(Ops.KERNEL, src=(a, b, c), arg=Kernel(UOp.sink(c.r(Ops.ADD, (0,))+1, c.r(Ops.ADD, (0,))*2)))
kernel = graph_rewrite(kernel, finalize_kernels)
kernel = graph_rewrite(kernel, create_ast)
run_schedule(check_schedule(UOp.sink(a.assign(kernel), b.assign(kernel)), 1))
self.assertEqual(a.buffer.numpy(), [7])
self.assertEqual(b.buffer.numpy(), [12])

View File

@@ -262,6 +262,19 @@ insert_kernels = merge_views+PatternMatcher([
(UPat(GroupOp.All-DONT_PLACE_IN_KERNEL, name="x"), lambda ctx,x: create_kernel(ctx, x) if x in ctx.realizes else None),
])
def append_to_kernel(ctx:KernelContext, x:UOp):
  """One rewrite step that grows a KERNEL by absorbing its non-realized sources.

  Sources that must stay outside the kernel (DONT_PLACE_IN_KERNEL) or that are scheduled to
  realize on their own are kept; every other source is replaced by its own sources, pulling
  that uop into the kernel. Metadata of absorbed uops (except CONST/DEVICE bases) is merged
  into the kernel's metadata. Returns the rewritten KERNEL, or None once at a fixed point.
  """
  kept_srcs: list[UOp] = []
  meta = dict.fromkeys(x.arg.metadata)  # insertion-ordered, dedups metadata entries
  for src in x.src:
    if src.op in DONT_PLACE_IN_KERNEL or src in ctx.realizes:
      kept_srcs.append(src)
      continue
    # absorb: hoist this uop's own sources up into the kernel's source list
    kept_srcs.extend(src.src)
    if src.base.op not in {Ops.CONST, Ops.DEVICE}:
      m = ctx.metadata.get(src)
      if m: meta[m] = None
  new_src = tuple(dedup(kept_srcs))
  if new_src != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(meta)))
# walk back the local graph until we reach a realized parent
# create_kernels: first place KERNEL uops (insert_kernels), then repeatedly apply
# append_to_kernel until every kernel's sources are realized parents.
create_kernels = insert_kernels+PatternMatcher([(UPat(Ops.KERNEL, name="x"), append_to_kernel),])
# **** swizzler
def reduce_push_add_ones(src:UOp, r:UOp, view:UOp):
@@ -376,17 +389,7 @@ fix_kernel_ops = PatternMatcher([
(UPat(Ops.LOAD, src=(UPat.var("glbl"), UPat.var("view"))), check_load_st),
])
def finalize(ctx:dict[UOp, Metadata|None], k:UOp) -> UOp|None:
  """Walk a KERNEL's sources one step back toward realized parents, merging metadata.

  Keeps sources whose op is in DONT_PLACE_IN_KERNEL and replaces all other sources with
  their own sources; metadata for absorbed uops (except CONST/DEVICE bases) is folded into
  the kernel's metadata tuple. Returns the rewritten KERNEL, or None when nothing changed.
  """
  srcs: list[UOp] = []
  meta = dict.fromkeys(k.arg.metadata)  # ordered dedup of metadata entries
  for src in k.src:
    if src.op in DONT_PLACE_IN_KERNEL:
      srcs.append(src)
    else:
      # absorb this uop into the kernel: its sources become the kernel's sources
      srcs.extend(src.src)
      if src.base.op not in {Ops.CONST, Ops.DEVICE} and (md:=ctx.get(src)):
        meta[md] = None
  new_src = tuple(dedup(srcs))
  return k.replace(src=new_src, arg=Kernel(k.arg.ast, tuple(meta))) if new_src != k.src else None
# once we're done fusing, create the local ast with load/stores
def fix_kernel_ast(k:UOp) -> UOp|None:
if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
# replace assign sources with a view of the target buffer
parents_rep: dict[UOp, UOp] = {}
@@ -402,7 +405,7 @@ def finalize(ctx:dict[UOp, Metadata|None], k:UOp) -> UOp|None:
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
return k.replace(arg=Kernel(ast, k.arg.metadata))
# finalize_kernels: fixed-point rewrite that walks kernel sources back to realized parents.
finalize_kernels = PatternMatcher([(UPat(Ops.KERNEL, name="k"), finalize),])
# create_ast: once grouping is done, lower each KERNEL's graph to a local AST with load/stores.
create_ast = PatternMatcher([(UPat(Ops.KERNEL, name="k"), fix_kernel_ast),])
pm_fuse = PatternMatcher([
# FUSE on CONTIGUOUS removes FUSE
@@ -488,10 +491,8 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
# group into kernels
realize_map = group_realizes(tensor_map[big_sink])
tensor_map = graph_rewrite_map(tensor_map[big_sink], insert_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
bottom_up=True, input_map=tensor_map, name="insert_kernels")
tensor_map = graph_rewrite_map(tensor_map[big_sink], finalize_kernels, ctx={v:k.metadata for k,v in tensor_map.items()}, bottom_up=True,
input_map=tensor_map, name="finalize_kernels")
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_kernels, ctx=KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
bottom_up=True, input_map=tensor_map, name="create_kernels")
# if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
kernel_assign: dict[UOp, UOp] = {}
@@ -507,6 +508,9 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
if assign_rep:
tensor_map = graph_rewrite_map(tensor_map[big_sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
# finally, create the AST for kernels
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map, name="create_ast")
# display the final graph
sched_sink = tensor_map[big_sink]
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")