From bf099520a4f98ad549bd68aeaf45472c6b9c92cd Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 14 Apr 2025 19:47:36 +0800
Subject: [PATCH] add names to grouper rewrites + cleanups [pr] (#9881)

* add names to grouper rewrites + cleanups [pr]

* assign_targets
---
 tinygrad/engine/grouper.py | 48 ++++++++++++++++----------------------
 1 file changed, 20 insertions(+), 28 deletions(-)

diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py
index 78f82835a7..442696f116 100644
--- a/tinygrad/engine/grouper.py
+++ b/tinygrad/engine/grouper.py
@@ -78,6 +78,16 @@ sym = symbolic_simple+PatternMatcher([
   # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
   (UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"),
    lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset)).reshape(t.shape) if x.device.startswith("DISK") else None),
+  # put CAST to smaller dtype before EXPAND
+  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
+   if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize
+   and resolve(prod(vm.shape) > vm.st.real_size()) else None),
+  # store a shrink before COPY, otherwise view after the COPY
+  (UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
+    if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
+  # put UnaryOps before EXPANDs
+  (UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"),), name="alu"),
+   lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
 ])
 
 # support for using a contiguous permuted view instead of the parent view if one exists
@@ -90,32 +100,14 @@ replace_contiguous = PatternMatcher([
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])
 
-# reorder view
+# **** Grouper decides which of the UOps realize
 
 DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY}
 
-reorder_view = PatternMatcher([
-  # put CAST to smaller dtype before EXPAND
-  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
-   if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize
-   and resolve(prod(vm.shape) > vm.st.real_size()) else None),
-  # store a shrink before COPY, otherwise view after the COPY
-  (UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
-    if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
-  # put UnaryOps before EXPANDs
-  (UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"),), name="alu"),
-   lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
-  # put CAST after expanding BUFFER
-  (UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
-   and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
-])
-
-# **** UOp realization
-
 @dataclass(frozen=True)
 class GrouperContext:
-  assigns: dict[UOp, UOp]                     # maps realized buffers to assigns
-  realizes: dict[UOp, None]                   # all the simplified tensor uops we realize
+  assigns: dict[UOp, None]                    # all the buffers that are assigned to
+  realizes: dict[UOp, None]                   # all the tensor uops we realize
   children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
 
 def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
@@ -141,7 +133,7 @@ do_realize = PatternMatcher([
 ])
 
 def append_uop(ctx:GrouperContext, u:UOp) -> None:
-  if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
+  if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = None
   for s in u.src: ctx.children[s.base][u] = None
 create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
 
@@ -171,7 +163,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
   double_reduces: list[UOp] = []
   for r in sink.toposort:
     if r.op is not Ops.REDUCE_AXIS: continue
-    if r.op is Ops.REDUCE_AXIS and len(r.arg) == 3 and r.arg[2] is True: continue
+    if len(r.arg) == 3 and r.arg[2] is True: continue
     if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
     if r in ctx.realizes: continue
     group: dict[UOp, None] = {}
@@ -183,11 +175,11 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
     # can only have one output
     if not forced_realize and len(group) > 1: forced_realize = True
     # can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
+    if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
       parents = deque((r, *group))
       while parents and not forced_realize:
         p = parents.pop().base
-        if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
+        if p.op is Ops.BUFFER and p in ctx.assigns and p not in assign_targets: forced_realize, can_chase = True, False
         if p in ctx.realizes: continue
         parents.extend(p.src)
     if forced_realize or not group:
@@ -432,7 +424,7 @@ if CAPTURE_PROCESS_REPLAY:
 @track_rewrites(name_fxn=lambda ret: f"Schedule {pluralize('Kernel', len({u.base.src[1] for u in ret.values() if u.base.op is Ops.ASSIGN}))}")
 def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   # merge_views + simplify
-  tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous+pm_fuse, ctx={})
+  tensor_map = graph_rewrite_map(big_sink, merge_views+sym+replace_contiguous+pm_fuse, ctx={})
 
   # display the cleaned up tensor graph
   if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
@@ -441,8 +433,8 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   sink = tensor_map[big_sink]
   realize_map = group_realizes(sink)
   tensor_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
-                                 bottom_up=True, input_map=tensor_map)
-  tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map)
+                                 bottom_up=True, input_map=tensor_map, name="create_kernels")
+  tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map, name="create_ast")
 
   # verify Kernels match the spec
   sched_sink = tensor_map[big_sink]
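
For readers new to tinygrad's rewrite machinery, below is a minimal, self-contained sketch of the pattern-rewrite idiom this patch reorganizes. The `Node`, `Pat`, and `rewrite` names are hypothetical stand-ins for tinygrad's `UOp`, `UPat`, `PatternMatcher`, and `graph_rewrite`, not the real API; the sketch only shows the shape of a rule (a pattern plus a lambda returning a replacement node, or `None` to decline) and why tagging a pass with a name, as the `name="create_kernels"` and `name="create_ast"` arguments added above do, makes rewrite traces attributable.

```python
# Minimal, self-contained model of the PatternMatcher idiom. These classes are
# illustrative stand-ins, NOT tinygrad's real UOp/UPat/PatternMatcher API.
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple["Node", ...] = ()

@dataclass(frozen=True)
class Pat:
  op: str          # op this pattern matches on
  name: str = "n"  # binding name passed to the rewrite function

Rule = tuple[Pat, Callable[..., Optional[Node]]]

def rewrite(node: Node, rules: list[Rule], pass_name: str = "") -> Node:
  # bottom-up: rewrite children first, then try each rule on the result
  node = Node(node.op, tuple(rewrite(s, rules, pass_name) for s in node.src))
  for pat, fxn in rules:
    if node.op == pat.op:
      # a rule returns a replacement Node, or None to decline the match
      if (new := fxn(**{pat.name: node})) is not None:
        print(f"[{pass_name}] {node.op} -> {new.op}")  # named pass keeps the trace readable
        return rewrite(new, rules, pass_name)
  return node

# example rule in the spirit of "put UnaryOps before EXPANDs":
# NEG(VIEW(x)) -> VIEW(NEG(x)), i.e. do the unary op once, before broadcasting
swap_unary_view: Rule = (Pat("NEG", name="alu"), lambda alu:
  Node("VIEW", (Node("NEG", alu.src[0].src),)) if alu.src and alu.src[0].op == "VIEW" else None)

if __name__ == "__main__":
  g = Node("NEG", (Node("VIEW", (Node("BUFFER"),)),))
  print(rewrite(g, [swap_unary_view], pass_name="sym"))  # VIEW(NEG(BUFFER))
```

In the real grouper, rules like the three folded into `sym` above fire inside `graph_rewrite_map`, and the `name=` tags added at the bottom of the diff play the role of `pass_name` here: they label which pass produced a given rewrite when inspecting traces (for example with `VIZ=1`, as the `getenv("VIZ")` branch in the diff suggests).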