add names to grouper rewrites + cleanups [pr] (#9881)

* add names to grouper rewrites + cleanups [pr]

* assign_targets
This commit is contained in:
qazal
2025-04-14 19:47:36 +08:00
committed by GitHub
parent ca8aaadd00
commit bf099520a4

View File

@@ -78,6 +78,16 @@ sym = symbolic_simple+PatternMatcher([
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"),
lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset)).reshape(t.shape) if x.device.startswith("DISK") else None),
# put CAST to smaller dtype before EXPAND
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize
and resolve(prod(vm.shape) > vm.st.real_size()) else None),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
# put UnaryOps before EXPANDs
(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"),), name="alu"),
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
])
# support for using a contiguous permuted view instead of the parent view if one exists
@@ -90,32 +100,14 @@ replace_contiguous = PatternMatcher([
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
# reorder view
# **** Grouper decides which of the UOps realize
DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY}
reorder_view = PatternMatcher([
# put CAST to smaller dtype before EXPAND
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize
and resolve(prod(vm.shape) > vm.st.real_size()) else None),
# store a shrink before COPY, otherwise view after the COPY
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
# put UnaryOps before EXPANDs
(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"),), name="alu"),
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
# put CAST after expanding BUFFER
(UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
])
# **** UOp realization
@dataclass(frozen=True)
class GrouperContext:
-  assigns: dict[UOp, UOp] # maps realized buffers to assigns
-  realizes: dict[UOp, None] # all the simplified tensor uops we realize
+  assigns: dict[UOp, None] # all the buffers that are assigned to
+  realizes: dict[UOp, None] # all the tensor uops we realize
children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
@@ -141,7 +133,7 @@ do_realize = PatternMatcher([
])
def append_uop(ctx:GrouperContext, u:UOp) -> None:
-  if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
+  if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = None
for s in u.src: ctx.children[s.base][u] = None
create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
@@ -171,7 +163,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
double_reduces: list[UOp] = []
for r in sink.toposort:
if r.op is not Ops.REDUCE_AXIS: continue
-    if r.op is Ops.REDUCE_AXIS and len(r.arg) == 3 and r.arg[2] is True: continue
+    if len(r.arg) == 3 and r.arg[2] is True: continue
if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
if r in ctx.realizes: continue
group: dict[UOp, None] = {}
@@ -183,11 +175,11 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
# can only have one output
if not forced_realize and len(group) > 1: forced_realize = True
# can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
+    if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
parents = deque((r, *group))
while parents and not forced_realize:
p = parents.pop().base
-        if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
+        if p.op is Ops.BUFFER and p in ctx.assigns and p not in assign_targets: forced_realize, can_chase = True, False
if p in ctx.realizes: continue
parents.extend(p.src)
if forced_realize or not group:
@@ -432,7 +424,7 @@ if CAPTURE_PROCESS_REPLAY:
@track_rewrites(name_fxn=lambda ret: f"Schedule {pluralize('Kernel', len({u.base.src[1] for u in ret.values() if u.base.op is Ops.ASSIGN}))}")
def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
# merge_views + simplify
-  tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous+pm_fuse, ctx={})
+  tensor_map = graph_rewrite_map(big_sink, merge_views+sym+replace_contiguous+pm_fuse, ctx={})
# display the cleaned up tensor graph
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
@@ -441,8 +433,8 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
sink = tensor_map[big_sink]
realize_map = group_realizes(sink)
tensor_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
-    bottom_up=True, input_map=tensor_map)
-  tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map)
+    bottom_up=True, input_map=tensor_map, name="create_kernels")
+  tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map, name="create_ast")
# verify Kernels match the spec
sched_sink = tensor_map[big_sink]