mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
add names to grouper rewrites + cleanups [pr] (#9881)
* add names to grouper rewrites + cleanups [pr] * assign_targets
This commit is contained in:
@@ -78,6 +78,16 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
|
||||
(UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"),
|
||||
lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,), (t.size, x.st.views[0].offset)).reshape(t.shape) if x.device.startswith("DISK") else None),
|
||||
# put CAST to smaller dtype before EXPAND
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
|
||||
if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize
|
||||
and resolve(prod(vm.shape) > vm.st.real_size()) else None),
|
||||
# store a shrink before COPY, otherwise view after the COPY
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
|
||||
if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
|
||||
# put UnaryOps before EXPANDs
|
||||
(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"),), name="alu"),
|
||||
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
|
||||
])
|
||||
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
@@ -90,32 +100,14 @@ replace_contiguous = PatternMatcher([
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
# reorder view
|
||||
# **** Grouper decides which of the UOps realize
|
||||
|
||||
DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY}
|
||||
|
||||
reorder_view = PatternMatcher([
|
||||
# put CAST to smaller dtype before EXPAND
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
|
||||
if (not getenv("CAST_AFTER_EXPAND") or vm.base.op is not Ops.BUFFER) and cast.dtype.itemsize <= vm.dtype.itemsize
|
||||
and resolve(prod(vm.shape) > vm.st.real_size()) else None),
|
||||
# store a shrink before COPY, otherwise view after the COPY
|
||||
(UPat(Ops.COPY, src=(UPat(), UPat(Ops.VIEW, name="v")), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
|
||||
if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device, clone=copy.arg).view(v.st)),
|
||||
# put UnaryOps before EXPANDs
|
||||
(UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat.var("inp"),), name="v"),), name="alu"),
|
||||
lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
|
||||
# put CAST after expanding BUFFER
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="v"), lambda x,v: x.view(x.st+v.st).cast(v.dtype) if getenv("CAST_AFTER_EXPAND")
|
||||
and x.base.op is Ops.BUFFER and resolve(prod(v.shape) > prod(x.shape)) else None),
|
||||
])
|
||||
|
||||
# **** UOp realization
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GrouperContext:
|
||||
assigns: dict[UOp, UOp] # maps realized buffers to assigns
|
||||
realizes: dict[UOp, None] # all the simplified tensor uops we realize
|
||||
assigns: dict[UOp, None] # all the buffers that are assigned to
|
||||
realizes: dict[UOp, None] # all the tensor uops we realize
|
||||
children: defaultdict[UOp, dict[UOp, None]] # children graph of tensor uops
|
||||
|
||||
def realize(ctx:GrouperContext, tr:UOp) -> None: ctx.realizes[tr] = None
|
||||
@@ -141,7 +133,7 @@ do_realize = PatternMatcher([
|
||||
])
|
||||
|
||||
def append_uop(ctx:GrouperContext, u:UOp) -> None:
|
||||
if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = u
|
||||
if u.op is Ops.ASSIGN: ctx.assigns[u.buf_uop] = None
|
||||
for s in u.src: ctx.children[s.base][u] = None
|
||||
create_ctx = PatternMatcher([(UPat(GroupOp.All-{Ops.SINK, Ops.VIEW}, name="u"), append_uop)])
|
||||
|
||||
@@ -171,7 +163,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
|
||||
double_reduces: list[UOp] = []
|
||||
for r in sink.toposort:
|
||||
if r.op is not Ops.REDUCE_AXIS: continue
|
||||
if r.op is Ops.REDUCE_AXIS and len(r.arg) == 3 and r.arg[2] is True: continue
|
||||
if len(r.arg) == 3 and r.arg[2] is True: continue
|
||||
if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
|
||||
if r in ctx.realizes: continue
|
||||
group: dict[UOp, None] = {}
|
||||
@@ -183,11 +175,11 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
|
||||
# can only have one output
|
||||
if not forced_realize and len(group) > 1: forced_realize = True
|
||||
# can only fuse assign if no other assign_target is used in the kernel
|
||||
if not forced_realize and any(x.op is Ops.ASSIGN for x in group):
|
||||
if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
|
||||
parents = deque((r, *group))
|
||||
while parents and not forced_realize:
|
||||
p = parents.pop().base
|
||||
if (assign:=ctx.assigns.get(p)) is not None and assign not in group: forced_realize, can_chase = True, False
|
||||
if p.op is Ops.BUFFER and p in ctx.assigns and p not in assign_targets: forced_realize, can_chase = True, False
|
||||
if p in ctx.realizes: continue
|
||||
parents.extend(p.src)
|
||||
if forced_realize or not group:
|
||||
@@ -432,7 +424,7 @@ if CAPTURE_PROCESS_REPLAY:
|
||||
@track_rewrites(name_fxn=lambda ret: f"Schedule {pluralize('Kernel', len({u.base.src[1] for u in ret.values() if u.base.op is Ops.ASSIGN}))}")
|
||||
def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
# merge_views + simplify
|
||||
tensor_map = graph_rewrite_map(big_sink, merge_views+sym+reorder_view+replace_contiguous+pm_fuse, ctx={})
|
||||
tensor_map = graph_rewrite_map(big_sink, merge_views+sym+replace_contiguous+pm_fuse, ctx={})
|
||||
|
||||
# display the cleaned up tensor graph
|
||||
if getenv("VIZ"): graph_rewrite(tensor_map[big_sink], PatternMatcher([]), name="View Tensor Graph")
|
||||
@@ -441,8 +433,8 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
|
||||
sink = tensor_map[big_sink]
|
||||
realize_map = group_realizes(sink)
|
||||
tensor_map = graph_rewrite_map(sink, create_kernels, KernelContext(realize_map, {v:k.metadata for k,v in tensor_map.items()}),
|
||||
bottom_up=True, input_map=tensor_map)
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map)
|
||||
bottom_up=True, input_map=tensor_map, name="create_kernels")
|
||||
tensor_map = graph_rewrite_map(tensor_map[big_sink], create_ast, bottom_up=True, input_map=tensor_map, name="create_ast")
|
||||
|
||||
# verify Kernels match the spec
|
||||
sched_sink = tensor_map[big_sink]
|
||||
|
||||
Reference in New Issue
Block a user